# Postprocessing a run

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%env XLA_PYTHON_CLIENT_MEM_FRACTION=.6

env: XLA_PYTHON_CLIENT_MEM_FRACTION=.6


In [3]:
import lenses
import jax
import numpy as np
from jax import numpy as jnp
from typing import cast
import matplotlib.pyplot as plt
from tqdm import trange

from flox.util import key_chain
from flox.flow import Pipe, Inverted, bind#, Transform, Transformed
import equinox as eqx
from functools import partial

from rigid_flows.flow import build_flow, RigidWithAuxiliary
from rigid_flows.data import Data, DataWithAuxiliary
from rigid_flows.density import OpenMMDensity
from rigid_flows.specs import FlowSpecification, CouplingSpecification, ExperimentSpecification
# from rigid_flows.density import PositionPrior, RotationPrior
from rigid_flows.utils import jit_and_cleanup_cache, scanned_vmap

# Seeded key generator (seed 42); later cells call next(chain) wherever a fresh PRNG key is needed.
chain = key_chain(42)

2023-01-25 16:38:41.427047: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/mi/minvernizzi/.local/lib
2023-01-25 16:38:41.427137: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/mi/minvernizzi/.local/lib




In [4]:
def count_params(model):
    """Return the total number of elements over all array leaves of ``model``.

    Leaves for which ``eqx.is_array`` is false are skipped. The count is
    accumulated in a 0-d int32 array and returned as a Python scalar.
    """
    def _add_leaf_size(running_total, leaf):
        # Only array leaves contribute; everything else leaves the total unchanged.
        if eqx.is_array(leaf):
            return running_total + leaf.size
        return running_total

    initial_total = jnp.zeros((), dtype=jnp.int32)
    total = jax.tree_util.tree_reduce(_add_leaf_size, model, initial_total)
    return total.item()

def ess(logw):
    """Effective sample size of a set of unnormalised log-weights.

    Computes (sum w)^2 / sum(w^2) with w = exp(logw), evaluated in log space
    via logsumexp so that large-magnitude log-weights do not overflow.
    """
    logsumexp = jax.scipy.special.logsumexp
    log_square_of_sum = 2 * logsumexp(logw)
    log_sum_of_squares = logsumexp(2 * logw)
    return jnp.exp(log_square_of_sum - log_sum_of_squares)

In [5]:
# Select the training run (log directory), stage and epoch whose checkpoint is analysed.
# Alternative runs are kept commented out; enable exactly one `+=` line.
logdir_path = 'jonas_logdir/'
logdir_path += 'dragonfly_N128_T100_noaux_2023-01-25_05:35:15'
# logdir_path += 'antelope_N16_T50_noaux_2023-01-25_08:08:54'
# logdir_path += 'tuna_N16_T100_noaux_2023-01-24_16:02:39'
stage = 0
epoch = 9
print(f'+++ epoch {epoch} +++')
specs_path = f"{logdir_path}/config.yaml"
pretrained_model_path = f'{logdir_path}/training_stage_{stage}/epoch_{epoch}/model.eqx'
print(pretrained_model_path)

# Load the configuration stored with the run, then point the base data path at its
# 'eval_100' subdirectory (presumably a separate evaluation data set -- TODO confirm).
specs = ExperimentSpecification.load_from_file(specs_path)
specs = lenses.bind(specs).model.base.path.set(specs.model.base.path+'/eval_100')

# Base and target densities built from the (modified) specs.
base = OpenMMDensity.from_specs(specs.model.auxiliary_shape, specs.model.base)
target = OpenMMDensity.from_specs(specs.model.auxiliary_shape, specs.model.target)
# Model object of the base density; later cells read n_waters, n_atoms, box and its plot helpers.
model = base.omm_model.model

# NOTE(review): `data` is not referenced again in the visible cells.
data = Data.from_specs(specs.model.target, target.box)
# Build a flow with the configured architecture, then fill its leaves from the saved checkpoint.
flow = build_flow(next(chain), specs.model.auxiliary_shape, specs.model.flow)
flow = cast(Pipe[DataWithAuxiliary, RigidWithAuxiliary], eqx.tree_deserialise_leaves(pretrained_model_path, flow))

# Training-set size: 100_000 is assumed when the config leaves num_samples unset.
training_data_size = 100_000 if specs.model.base.num_samples is None else specs.model.base.num_samples
print(f'tot flow parameters: {count_params(flow):_}')
print(f'MD training datapoints = {training_data_size:_}')
print(f'MD eval datapoints = {base.data.pos.shape[0]:_}')
print(f'batchs per epoch = {specs.train[0].num_iters_per_epoch}')
print(f'batch size = {specs.train[0].num_samples}')
# Samples consumed during the first training stage (epochs * batch size * iterations per epoch),
# relative to the training-set size.
print(f'data fraction: {specs.train[0].num_epochs*specs.train[0].num_samples*specs.train[0].num_iters_per_epoch/training_data_size}')

+++ epoch 9 +++
jonas_logdir/dragonfly_N128_T100_noaux_2023-01-25_05:35:15/training_stage_0/epoch_9/model.eqx
tot flow parameters: 7_487_568
MD training datapoints = 100_000
MD eval datapoints = 100_000
batchs per epoch = 2000
batch size = 32
data fraction: 9.6


In [6]:
# Tabulated reference free-energy differences as (value, uncertainty), keyed by
# (number of molecules, base temperature, target temperature).
_reference_table = {
    (16, 250, 100): (-666.09897990553, 0.0558899),  # MBAR, 10_000 samples, 5 replicas
    (16, 250, 50): (-1818.2199636389134, 0.0632776),  # MBAR, 10_000 samples, 10 replicas
    (128, 250, 100): (-5314.324317927606, 0.0944014),  # MBAR, 10_000 samples, 10 replicas
    (128, 250, 50): (-14522.10556489954, 0.135),  # MBAR, 10_000 samples, 20 replicas
}

_base_spec, _target_spec = specs.model.base, specs.model.target
if _base_spec.temperature == _target_spec.temperature:
    # Same temperature on both sides: the reference is set to exactly zero.
    reference_deltaF, reference_deltaF_std = 0, 0
else:
    # Combinations missing from the table leave both values as None.
    _lookup_key = (_base_spec.num_molecules, _base_spec.temperature, _target_spec.temperature)
    reference_deltaF, reference_deltaF_std = _reference_table.get(_lookup_key, (None, None))

reference_deltaF, reference_deltaF_std

(-5314.324317927606, 0.0944014)

In [7]:
# Normalisation constant: losses, energies and free energies are divided by `sc` before being
# plotted or printed. NOTE(review): 3 * n_waters presumably counts atoms (or a per-molecule
# number of degrees of freedom) -- confirm.
sc = 3 * model.n_waters
# When True, each plotting helper is also rendered on its own right after its definition.
test_subplots = False

In [8]:
import json

# Read the loss log stored with the run and drop its first column; the loss plot uses the
# remaining column 0 as the step axis and column 1 as the loss value.
with open(logdir_path+'/loss.json', "r") as f:
    loss_info = np.array(json.load(f))[:,1:]

In [9]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.image as mpimg


def plot_loss():
    """Plot the training loss against the step count on the current axes.

    Reads the notebook globals ``loss_info`` (column 0: steps, column 1: loss),
    ``sc`` (normalisation), ``reference_deltaF`` and ``model``. When a reference
    deltaF is available it is drawn as a dotted horizontal line. An image loaded
    from ``data/water/iceXI-N{n_waters}.png`` is shown as an inset in the
    upper-right corner.
    """
    plt.plot(loss_info[:,0], loss_info[:,1]/sc)
    # Not every configuration has a tabulated reference; skip the line rather
    # than failing on ``None / sc``.
    if reference_deltaF is not None:
        plt.axhline(reference_deltaF/sc, c='k', ls=':')
    # Upper x-limit is the largest step plus the smallest one.
    plt.xlim(0, loss_info[:,0].max()+loss_info[:,0].min())
    plt.ylabel('loss')
    plt.xlabel('steps')

    # Inset picture occupying half the width and height of the axes (loc=1: upper right).
    img = mpimg.imread(f'data/water/iceXI-N{model.n_waters}.png')
    subax = inset_axes(plt.gca(), width="50%", height="50%", loc=1)
    subax.imshow(img)
    subax.axis('off')
    
# Optional standalone preview of the loss panel (enabled via `test_subplots`).
if test_subplots:
    plot_loss()
    plt.show()

In [10]:
# Number of configurations drawn from each density, and the chunk size passed to scanned_vmap.
num_samples = 100_000
batch_size = 128

print('sample base')
keys = jax.random.split(next(chain), num_samples)
base_tr = scanned_vmap(base.sample, batch_size)(keys)
# Per the NB note below, base_tr.ldj holds the base potential of each sample; multiplying by kbT
# presumably converts it from reduced to physical energy units -- TODO confirm.
base_ene = base_tr.ldj * base.omm_model.kbT

print('map base')
# Push the base samples through the trained flow; mapped_tr.ldj is used below as the flow's
# log-det-Jacobian term (assumed from the attribute name -- confirm).
mapped_tr = scanned_vmap(flow.forward, batch_size)(base_tr.obj)

print('compute weights')
## NB: base_tr.ldj = jax.vmap(base.potential)(base_tr.obj)
mapped_ene_tg = scanned_vmap(target.potential, batch_size)(mapped_tr.obj)
mapped_ene = mapped_ene_tg * target.omm_model.kbT

# Unnormalised log importance weights: base potential + flow ldj - target potential of the mapped sample.
logw = base_tr.ldj + mapped_tr.ldj - mapped_ene_tg
# Self-normalised weights (they sum to one); the logsumexp shift avoids overflow in exp.
weights = jnp.exp(logw-jax.scipy.special.logsumexp(logw))

print('sample target')
# Samples drawn from the target density itself, used as the ground-truth comparison in the plots.
keys = jax.random.split(next(chain), num_samples)
target_tr = scanned_vmap(target.sample, batch_size)(keys)
target_ene = target_tr.ldj * target.omm_model.kbT

sample base


In [None]:
def plot_2Dview(data_tr: DataWithAuxiliary, title: str, toPBC: bool = True, skip: int = 100):
    """Draw a 2D view of every ``skip``-th configuration in ``data_tr``.

    Positions are read from ``data_tr.obj.pos``, reshaped to
    ``(-1, model.n_atoms, 3)`` and forwarded to ``model.plot_2Dview`` together
    with ``toPBC`` and ``title``.

    NOTE(review): the argument is annotated as ``DataWithAuxiliary`` but is
    accessed through ``.obj`` -- confirm the annotation.
    """
    frames = data_tr.obj.pos.reshape(-1, model.n_atoms, 3)
    thinned_frames = frames[::skip]
    model.plot_2Dview(thinned_frames, toPBC=toPBC, title=title)

# 2D views of the base, mapped and target ensembles (thinned by the default `skip`).
for _ensemble, _ensemble_name in ((base_tr, 'base'), (mapped_tr, 'mapped'), (target_tr, 'target')):
    plot_2Dview(_ensemble, title=_ensemble_name)

# Histogram of the importance weights relative to the largest one, with log-scaled counts.
relative_weights = jnp.exp(logw - logw.max())
plt.hist(relative_weights, bins=100)
plt.yscale('log')
plt.show()

# Effective sample size, both absolute and as a fraction of the number of samples.
ess_value = ess(logw)
print(f'ESS = {ess_value:g}  ->  {ess_value/len(logw):.2%}')

In [None]:
# Shared histogram settings for the three unweighted energy distributions.
hist_kwargs = dict(bins="auto", density=True, alpha=0.5)

def plot_energy():
    """Overlay energy histograms of the base, mapped and target samples.

    Reads the notebook globals ``base_ene``, ``mapped_ene``, ``target_ene``,
    ``weights`` and ``sc``; all energies are divided by ``sc``. The mapped
    energies are drawn a second time as a step histogram weighted by the
    importance weights ('reweighted'). Draws on the current matplotlib axes.
    """
    unweighted = ((base_ene, 'base'), (mapped_ene, 'mapped'), (target_ene, 'target'))
    for energies, name in unweighted:
        plt.hist(energies/sc, label=name, **hist_kwargs)
    # Importance-weighted version of the mapped distribution.
    plt.hist(
        mapped_ene/sc,
        weights=weights,
        bins=100,
        histtype='step',
        density=True,
        label='reweighted',
    )
    plt.xlabel('potential energy [kJ/mol]')
    plt.legend()

# Optional standalone preview of the energy panel (enabled via `test_subplots`).
if test_subplots:
    plot_energy()
    plt.show()

In [None]:
#TODO add reweighted rdf

def plot_rdf():
    """Plot radial distribution functions of the base, mapped and target samples.

    Reads the notebook globals ``base_tr``, ``mapped_tr``, ``target_tr`` and
    ``model``. The radial range runs from 0.2 to the smallest diagonal entry of
    ``model.box``, using 200 bins. Draws through ``model.plot_rdf``.

    NOTE: the 'reweighted' curve is a placeholder that re-plots the target
    samples with a dotted line; no reweighting is applied yet.
    """
    rdf_settings = dict(r_range=[0.2, np.diag(model.box).min()], n_bins=200)

    def _as_frames(samples):
        # One (n_atoms, 3) coordinate frame per sample.
        return samples.obj.pos.reshape(-1, model.n_atoms, 3)

    # plt.title('Oxygen radial distribution function')
    for samples, name in ((base_tr, 'base'), (mapped_tr, 'mapped'), (target_tr, 'target')):
        model.plot_rdf(_as_frames(samples), label=name, **rdf_settings)
    #TODO actually calculate it!
    model.plot_rdf(_as_frames(target_tr), label='reweighted', ls=':', **rdf_settings)
    # plt.axvline(model.box.max()/2, c='k', ls=':', alpha=.5)
    plt.ylabel("oxygen g(r)")
    plt.legend()

# Optional standalone preview of the RDF panel (enabled via `test_subplots`).
if test_subplots:
    plot_rdf()
    plt.show()

In [None]:
## TFEP
# deltaF = log(N) - logsumexp(logw), i.e. minus the log of the mean importance weight.
# NOTE(review): this header says TFEP while the printed label says LFEP -- confirm the intended name.
deltaF = (jnp.log(len(logw)) - jax.scipy.special.logsumexp(logw)).item()
# Both values are printed divided by `sc`, the same normalisation used in the plots.
print(f'Estimated deltaF from LFEP = {deltaF/sc:g}')
if reference_deltaF is not None:
    print(f'          Reference deltaF = {reference_deltaF/sc:g}')

In [None]:
import shutil

# Three-panel summary figure: training loss | energy histograms | oxygen RDF.
plt.figure(figsize=(15, 4))

plt.subplot(1, 3, 1)
plot_loss()

plt.subplot(1, 3, 2)
plot_energy()

plt.subplot(1, 3, 3)
plot_rdf()

filename = f'fig3-N{model.n_waters}_T{specs.model.target.temperature}.pdf'
plt.savefig(filename, bbox_inches='tight')
plt.show()

# Keep a copy of the figure next to the run's logs. shutil.copy is used instead of a
# `!cp` shell escape: it is portable, safe for paths containing whitespace, and raises
# on failure instead of only printing a shell error. The trailing `;` hides the return value.
shutil.copy(filename, f'{logdir_path}/{filename}');

In [None]:
# Unconditional stop: a "Run All" halts here, so the cells below only execute when run by hand
# (presumably intentional, since they repeat the sampling -- confirm).
raise SystemError

In [None]:
@partial(jax.jit, static_argnames=["num_samples", "base", "batch_size"])
def estimate_deltaF(key, num_samples, base=base, batch_size=128):
    """Estimate deltaF from ``num_samples`` fresh samples of ``base``.

    Samples are drawn with ``base.sample``, mapped through the global ``flow``
    (in chunks of ``batch_size``) and scored with the global
    ``target.potential``. The result is ``log(N) - logsumexp(logw)``, i.e. minus
    the log of the mean importance weight.

    ``num_samples``, ``base`` and ``batch_size`` are static arguments of the
    jitted function.
    """
    sample_keys = jax.random.split(key, num_samples)
    sampled = jax.vmap(base.sample)(sample_keys)
    pushed = scanned_vmap(flow.forward, batch_size)(sampled.obj)
    target_potential = jax.vmap(target.potential)(pushed.obj)
    log_weights = sampled.ldj + pushed.ldj - target_potential
    log_count = jnp.log(len(log_weights))
    return (log_count - jax.scipy.special.logsumexp(log_weights))

In [None]:
# Repeat the estimate on `iterations` independently sampled batches to gauge its spread.
iterations = 10
# NOTE: this rebinds `num_samples` (set in the sampling cell above) to the per-iteration batch size.
num_samples = base.data.pos.shape[0] // iterations

deltaFs = np.zeros(iterations)
for i in trange(iterations):
    deltaFs[i] = estimate_deltaF(next(chain), num_samples)
    print(deltaFs[i])
# The reported spread is numpy's default (population, ddof=0) standard deviation over iterations.
print(f'deltaF = {deltaFs.mean():g} +/- {deltaFs.std():g}')

In [None]:
# @partial(jax.jit, static_argnames=["start", "num_samples", "base", "batch_size"])
# def estimate_deltaF_idx(key, start, num_samples, base=base, batch_size=128):
#     keys = jax.random.split(key, num_samples)
#     base_tr = jax.vmap(base.sample_idx)(keys, jnp.arange(start, start+num_samples))
#     mapped_tr = scanned_vmap(flow.forward, batch_size)(base_tr.obj)
#     logw = base_tr.ldj + mapped_tr.ldj - jax.vmap(target.potential)(mapped_tr.obj)
#     return (jnp.log(len(logw)) - jax.scipy.special.logsumexp(logw))

# deltaFs = np.zeros(iterations)
# for i in trange(iterations):
#     deltaFs[i] = estimate_deltaF_idx(next(chain), i*num_samples, num_samples)
#     print(deltaFs[i])
# print(f'deltaF = {deltaFs.mean():g} +/- {deltaFs.std():g}')

In [None]:
# Per-iteration estimates with a +/- one standard deviation band around their mean,
# compared against the reference value (and its uncertainty band) when one is available.
xlim = [0, len(deltaFs)]
mean_estimate, spread = deltaFs.mean(), deltaFs.std()

plt.plot(deltaFs, '.', c='orange', label='LFEP')
x = np.full(2, mean_estimate)
plt.fill_between(xlim, x-spread, x+spread, color='orange', alpha=0.3)
plt.axhline(mean_estimate, c='orange')

if reference_deltaF is not None:
    plt.axhline(reference_deltaF, c='k', ls=":", label='MBAR reference')
    x = np.full(2, reference_deltaF)
    plt.fill_between(xlim, x-reference_deltaF_std, x+reference_deltaF_std, color='k', alpha=0.1)

plt.xlim(xlim)
plt.legend()
plt.show()

In [None]:
# Signed deviation of the mean estimate from the reference value. Evaluates to None (nothing is
# displayed) when no reference is tabulated, instead of raising a TypeError on `float - None`.
deltaFs.mean()-reference_deltaF if reference_deltaF is not None else None