In [None]:
import torch
import numpy as np
from tqdm.auto import tqdm
import mdtraj as md
import random
from collections import defaultdict
from dataclasses import dataclass
import mdtraj

import functools
import openmm as mm
import simtk.unit as u  # type: ignore [import]
import matplotlib.pyplot as plt
from bgflow import OpenMMBridge, OpenMMEnergy
from openmm import unit

from timewarp.utils.training_utils import load_model
from timewarp.datasets import RawMolDynDataset
from timewarp.utils.openmm import OpenmmPotentialEnergyTorch
from timewarp.dataloader import (
    DenseMolDynBatch,
    moldyn_dense_collate_fn,
)
from itertools import islice
import os
from utilities.training_utils import set_seed
from typing import Optional, List, Union, Tuple,  DefaultDict, Dict
from timewarp.utils.energy_utils import get_energy_mean_std, plot_all_energy
from simulation.md import (
    get_simulation_environment,
    compute_energy_and_forces,
    compute_energy_and_forces_decomposition,
    get_parameters_from_preset, 
    get_simulation_environment_integrator,
    get_simulation_environment_for_force
)
from timewarp.equivariance.equivariance_transforms import transform_batch
from timewarp.utils.evaluation_utils import compute_kinetic_energy
from timewarp.utils.training_utils import (
    end_of_epoch_report,
    load_or_construct_loss,
    load_or_construct_loss_scheduler,
    run_on_dataloader,
)
from timewarp.losses import (
    wrap_or_replace_loss,
)
from timewarp.model_constructor import model_constructor


# Global plot styling: use a large font for every figure in this notebook.
plt.rc("font", size=30)
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Dataset folder name; it is interpolated into the trajectory data path.
dataset = "4AA-huge"
# Forwarded to RawMolDynDataset as its `step_width` argument.
step_width = 1_000_000

In [None]:
# Load the dataset
# NOTE(review): `base_dir` is not defined in any visible cell; it has to be set
# before this cell is executed.
data_type = 'val'
# os.path.join inserts the separator only when it is missing, so the path is
# correct whether or not `base_dir` ends with a slash.
data_dir = os.path.join(base_dir, f".data/simulated-data/trajectory-data/{dataset}/{data_type}")

# Peptide to analyse (alternative used previously: "AACG").
protein = "LAKS"

raw_dataset = RawMolDynDataset(data_dir=data_dir, step_width=step_width, equal_data_spacing=False)
pdb_names = [protein]
raw_iterator = raw_dataset.make_iterator(pdb_names)

# Collate every raw datapoint into its own single-element batch and keep at most
# `max_batches` of them. The lazy stream gets its own name so that `batches`
# is only ever bound to the final list.
max_batches = 5000
batch_stream = (moldyn_dense_collate_fn([datapoint]) for datapoint in raw_iterator)
batches = list(islice(batch_stream, max_batches))

# Topology / initial-state PDB of the selected peptide.
state0pdbpath = os.path.join(data_dir, f"{protein}-traj-state0.pdb")
# Simulation parameter preset (alternative: "alanine-dipeptide").
parameters = "T1B-peptides"

In [None]:
# Build the OpenMM simulation and integrator for the selected peptide and preset.
simulation = get_simulation_environment(state0pdbpath, parameters)
integrator = get_simulation_environment_integrator(parameters)
system = simulation.system

# Pick the OpenMM platform that matches the torch device chosen at the top of the
# notebook, so the notebook also runs on machines without a GPU instead of
# unconditionally requesting the CUDA platform.
openmm_platform_name = "CUDA" if device.type == "cuda" else "CPU"
openmm_potential_energy_torch = OpenmmPotentialEnergyTorch(
    system, integrator, platform_name=openmm_platform_name
)
# Thermal energy R*T in kJ/mol; used to express energies in units of kT.
kbT = (integrator.getTemperature() * unit.MOLAR_GAS_CONSTANT_R).value_in_unit(unit.kilojoule_per_mole)

# Per-atom masses in dalton as a tensor on the compute device.
# `openmm.unit` is used here for consistency with the kbT computation above
# (the `simtk` namespace is the deprecated alias of the same module).
num_atoms = system.getNumParticles()
particle_masses = [system.getParticleMass(i).value_in_unit(unit.dalton) for i in range(num_atoms)]
masses = torch.tensor(particle_masses).to(device)
#parameters = get_parameters_from_preset(parameters)


In [None]:
# NOTE(review): `model_type` is not read by any visible cell.
model_type = 'latest'

# construct model

# NOTE(review): `savefile` (the checkpoint path) is not defined in any visible
# cell; it has to be set before this cell is executed.
# The map_location callable returns each storage unchanged, so the checkpoint is
# loaded without remapping tensors to another device.
state_dict = torch.load(savefile, map_location=lambda storage, loc: storage)

# The checkpoint carries the training configuration next to the weights.
config = state_dict["training_config"]
model = model_constructor(config.model_config)
loss_computer = load_or_construct_loss(config)
# The weights under 'module' are loaded into the loss-wrapped model, after which
# the wrapper's inner `.module` is extracted and moved to the compute device.
model = wrap_or_replace_loss(model, loss_computer)
model.load_state_dict(state_dict['module'])
model = model.module.to(device)

## Sample with model on conditioning samples from the Boltzmann distribution

In [None]:
from timewarp.utils.evaluation_utils import sample_on_batches

# Number of leading batches that are passed to `sample_on_batches`.
n_samples = 200
# NOTE(review): the meaning of each returned array is inferred from its name only
# (model samples, trajectory targets, conditioning coordinates, forward/reverse
# log-likelihoods, acceptance) — confirm against `sample_on_batches`.
(
    y_coords_model,
    y_velocs_model,
    traj_coords,
    traj_velocs,
    traj_coords_conditioning,
    _,  # sixth return value is not used in this notebook
    ll_reverse,
    ll_forward,
    ll_reverse_training,
    ll_forward_training,
    acceptance,
) = sample_on_batches(
    batches[:n_samples],
    model,
    device,
    openmm_potential_energy_torch,
    True,  # NOTE(review): positional flag — check the `sample_on_batches` signature for its meaning
    masses,
    random_velocs=True,
)

In [None]:
# Select a single batch (index 50) for the stand-alone forward pass below.
batch = batches[50]

In [None]:
# Graph / topology inputs of the selected batch.
atom_types = batch.atom_types
adj_list = batch.adj_list
edge_batch_idx = batch.edge_batch_idx
masked_elements = batch.masked_elements

# Conditioning and target coordinates.
x_coords = batch.atom_coords
y_coords_target = batch.atom_coord_targets

# Velocities are drawn from a standard normal; the batch velocities only
# provide the shape and dtype. The draw order (x first, then y) is kept fixed.
x_velocs = torch.randn_like(batch.atom_velocs, device=device).contiguous()
y_velocs_target = torch.randn_like(batch.atom_veloc_targets, device=device).contiguous()

# Single forward pass; the return value is shown as the cell output.
model(
    atom_types=atom_types.to(device),
    adj_list=adj_list.to(device),
    edge_batch_idx=edge_batch_idx.to(device),
    masked_elements=masked_elements.to(device),
    x_coords=x_coords.to(device),
    x_velocs=x_velocs.to(device),
    y_coords=y_coords_target.to(device),
    y_velocs=y_velocs_target.to(device),
)

In [None]:
# Display which peptide is currently being analysed.
protein

In [None]:
# Mean of the `acceptance` array returned by `sample_on_batches`.
# NOTE(review): the bare "0.2" presumably records a previously observed value — confirm.
# 0.2
acceptance.mean()

In [None]:
# Energies
print("Plotting energies...")
energies_model_bg = (
    openmm_potential_energy_torch(torch.from_numpy(y_coords_model))
    .cpu()
    .numpy()
    .flatten()
)
energies_traj = openmm_potential_energy_torch(torch.from_numpy(traj_coords)).cpu().numpy()

# Overlay the potential-energy histograms of the model samples and the reference
# trajectory. The model histogram is clipped to [min(trajectory energy), 500].
energy_hist_style = dict(bins=100, alpha=0.5, density=True)
plt.figure(figsize=(16, 9))
plt.hist(
    energies_model_bg,
    color="orange",
    label="model",
    range=(energies_traj.min(), 500),
    **energy_hist_style,
)
plt.hist(energies_traj, color="green", label="openMM", **energy_hist_style)
plt.xlabel("Energy in kJ/mol")
plt.legend()
plt.title("Sample on data: Potential energy distribution")

In [None]:
# Fraction of model samples whose potential energy exceeds 1000.
(energies_model_bg > 1000).mean()

In [None]:
# Likelihood difference samples
ll_difference = ll_forward - ll_reverse
plt.figure(figsize=(16, 9))
plt.hist(ll_difference, bins=100, label="Model samples", range=(-100,100))
plt.xlabel("Log-likelihood")
plt.title("Log-likelihood difference distribution for model samples")


# Likelihoods: overlay the four log-likelihood distributions on one figure.
plt.figure(figsize=(16, 9))
for ll_values, ll_label in (
    (ll_forward, "Model samples"),
    (ll_reverse, "Model samples reverse"),
    (ll_forward_training, "Training samples"),
    (ll_reverse_training, "Reverse training samples"),
):
    plt.hist(ll_values, alpha=0.5, density=True, bins=100, label=ll_label)
plt.xlabel("Log-likelihood")
plt.title("Sample on data: Log-likelihood distributions")
plt.legend()

In [None]:
velocity_hist_style = dict(alpha=0.5, density=True, bins=100)

# Per-component model velocities against a standard-normal reference sample.
plt.figure(figsize=(16, 9))
plt.hist(np.random.randn(10000), label="Gaussian", **velocity_hist_style)
plt.hist(
    y_velocs_model.flatten(),
    label="Model samples",
    range=(-10, 10),
    **velocity_hist_style,
)
plt.title("Velocity distribution")
plt.legend()

# 0.5 * v**2, summed over axes 1 and 2, for the model velocities and for a
# standard-normal sample of the same shape.
prob = 0.5 * y_velocs_model**2
prob_gauss = 0.5 * np.random.randn(*y_velocs_model.shape)**2
plt.figure(figsize=(16, 9))
plt.hist(prob_gauss.sum(axis=(1, 2)), label="Gaussian", **velocity_hist_style)
plt.hist(prob.sum(axis=(1, 2)), label="Model samples", **velocity_hist_style)
plt.title("Kinetic energy distribution")
plt.legend();

In [None]:
# Displacement relative to the conditioning coordinates, for the reference
# trajectory and for the model samples.
delta_x = traj_coords - traj_coords_conditioning
delta_x_model = y_coords_model - traj_coords_conditioning

fig, axs = plt.subplots(2, 2, figsize=(16, 16), sharey=True, sharex=True)
for i in range(2):
    for j in range(2):
        # Panels show component 0 of the atoms with index i * 5 + j (0, 1, 5 and 6).
        atom_idx = i * 5 + j
        panel = axs[i, j]
        for displacement, source in ((delta_x, "openMM"), (delta_x_model, "model")):
            panel.hist(
                displacement[:, atom_idx, 0],
                bins=100,
                alpha=0.5,
                density=True,
                label=source,
            )
# Only the last panel carries the legend.
axs[-1, -1].legend()


# Sample conditional distribution


In [None]:
from timewarp.utils.evaluation_utils import compute_kinetic_energy

# Conditioning state: one batch from the loaded list.
initial_idx = 108
batch = batches[initial_idx]

# Number of proposals drawn from the model for this single conditioning state.
num_samples = 10000
x_coords=batch.atom_coords.to(device)
random_velocs = True
with torch.no_grad():
    # Draw proposals (y) conditioned on the state (x) together with the third
    # return value `p_xy` (by the method name, the log-probability of each proposal).
    y_coords, y_velocs, p_xy = model.conditional_sample_with_logp(
        atom_types=batch.atom_types.to(device),
        x_coords=x_coords,
        x_velocs=batch.atom_velocs.to(device),
        adj_list=batch.adj_list,
        edge_batch_idx=batch.edge_batch_idx.to(device),
        masked_elements=batch.masked_elements.to(device),
        num_samples=num_samples,
    )

    # Drop dimension 1 of the samples and tile the conditioning coordinates so
    # that x and y line up sample by sample (checked by the shape assert below).
    y_coords = y_coords.squeeze(1)
    y_velocs = y_velocs.squeeze(1)
    x_coords = x_coords.repeat(num_samples, 1, 1)
    # Fresh standard-normal velocities are used for the x side of the reverse move.
    x_velocs = torch.randn_like(y_velocs)
        
    # Reverse direction: log-likelihood of x given y (the roles of x and y are swapped).
    p_yx = model.log_likelihood(
        atom_types=batch.atom_types.repeat(x_coords.shape[0], 1).to(device),
        y_coords=x_coords,
        y_velocs=x_velocs,
        x_coords=y_coords,
        x_velocs=y_velocs,
        adj_list=batch.adj_list,
        edge_batch_idx=batch.edge_batch_idx.to(device),
        masked_elements=batch.masked_elements.repeat(x_coords.shape[0], 1).to(device),
    )

    p_xy = p_xy.reshape(p_yx.shape)
    
    # Potential energies are divided by kbT, i.e. expressed in units of kT.
    e_pot_x = (openmm_potential_energy_torch(x_coords) / kbT).squeeze(-1)
    e_kin_x = compute_kinetic_energy(x_velocs, masses, random_velocs=random_velocs)

    assert y_coords.shape == x_coords.shape
    e_kin_y = compute_kinetic_energy(y_velocs, masses, random_velocs=random_velocs)
    e_kin = e_kin_y - e_kin_x

    e_pot_y = (openmm_potential_energy_torch(y_coords) / kbT).squeeze(-1)
    e_pot = e_pot_y - e_pot_x
    assert e_kin.shape == e_pot.shape
    # Total energy change of the move x -> y.
    energy = e_pot + e_kin

    assert energy.shape == p_xy.shape
    # Exponent of the acceptance ratio: energy difference plus forward minus
    # reverse log-probability.
    exp = energy + p_xy - p_yx

    # Acceptance probability per sample: min(1, exp(-exp)).
    p_acc = torch.min(torch.tensor(1), torch.exp(-exp))

In [None]:
# Mean acceptance probability over all conditional samples.
p_acc.mean()

In [None]:
# Forward minus reverse log-likelihood of the conditional samples, on the host.
forward_ll = p_xy.cpu().numpy()
reverse_ll = p_yx.cpu().numpy()
diff_log_likeli = forward_ll - reverse_ll

plt.figure(figsize=(16, 9))
plt.hist(diff_log_likeli, bins=100, label="Model samples", range=(-100, 200))
plt.xlabel("Log-likelihood")
plt.title("Log-likelihood difference distribution for model samples")


In [None]:
from timewarp.utils.evaluation_utils import compute_internal_coordinates

# Adjacency list of the selected batch as a NumPy array, used by both calls.
adj_list_np = batch.adj_list.cpu().numpy()

# Internal coordinates (bonds, torsions) of the conditional model samples ...
model_coords_np = y_coords.cpu().numpy()
bonds, torsions = compute_internal_coordinates(state0pdbpath, adj_list_np, model_coords_np)

# ... and of the conditioning structure they were sampled from.
conditioning_coords_np = batch.atom_coords.cpu().numpy()
bonds_traj_conditioning, torsions_traj_conditioning = compute_internal_coordinates(
    state0pdbpath, adj_list_np, conditioning_coords_np
)

In [None]:
# Potential energies of the conditional model samples and of the conditioning state.
energies_model = openmm_potential_energy_torch(y_coords.squeeze(1)).cpu().numpy()
energy_conditioning = openmm_potential_energy_torch(x_coords[0]).cpu().numpy()

# Reference energies stored next to the trajectory. The archive lives in the same
# directory as the rest of the dataset, so the path is derived from `data_dir`
# instead of being rebuilt from `base_dir` with a second separator convention.
npz_traj = np.load(os.path.join(data_dir, f"{protein}-traj-arrays.npz"))
traj_energy_calc = npz_traj['energies'][:, 0]
plt.figure(figsize=(16, 9))

# potential energy distribution of the reference trajectory
plt.hist(
    traj_energy_calc,
    bins=100,
    color="green",
    alpha=0.5,
    density=True,
    label="openMM",
)

# potential energy distribution of the model samples, clipped to
# [min(reference energy), 500]
plt.hist(
    energies_model,
    bins=100,
    color="orange",
    alpha=0.5,
    density=True,
    label="model",
    range=(traj_energy_calc.min(), 500)
);


plt.xlabel("Energy in kJ/mol")
plt.legend()
plt.title("Conditional potential energy distribution")
# Show the energy of the conditioning state as the cell output.
energy_conditioning

In [None]:
# Boolean mask selecting the model samples with negative potential energy.
good_energies = (energies_model<0) 

In [None]:
# Number of negative-energy samples divided by the length of `energies_model`.
good_energies.sum()/len(energies_model)

In [None]:
mean_hist_style = dict(bins=100, alpha=0.5, density=True)
plt.figure(figsize=(16, 9))

# Reference trajectory energies; the legend reports their mean.
plt.hist(
    traj_energy_calc,
    color="green",
    label=f"openMM - mean={traj_energy_calc.mean():.1f}",
    **mean_hist_style,
)

# Model energies, clipped to [min(reference energy), 500].
# NOTE(review): the mean in this label is taken over the `good_energies`
# (negative-energy) samples only, while the histogram shows every sample in range.
plt.hist(
    energies_model,
    color="orange",
    label=f"model - mean={energies_model[good_energies].mean():.1f}",
    range=(traj_energy_calc.min(), 500),
    **mean_hist_style,
);

plt.xlabel("Energy in kJ/mol")
plt.legend()
plt.title("Conditional potential energy distribution")
# Show the energy of the conditioning state as the cell output.
energy_conditioning

In [None]:
import matplotlib.cm as cm

# One panel per torsion pair: negative-energy model samples coloured by their
# potential energy, with the conditioning structure marked by a red cross.
fig, axes = plt.subplots(1, 3, figsize=(40, 10), sharey=True)
for i, ax in enumerate(axes):
    phi = torsions[0][:, [i]][good_energies]
    psi = torsions[1][:, [i]][good_energies]
    im = ax.scatter(phi, psi, c=energies_model[good_energies], cmap='rainbow', s=20)
    ax.scatter(
        torsions_traj_conditioning[0][0, i],
        torsions_traj_conditioning[1][0, i],
        marker="x",
        color="red",
        s=500,
        linewidths=5,
    )
    ax.set_xlim(-np.pi, np.pi)
    ax.set_ylim(-np.pi, np.pi)
    ax.set_xlabel("Phi")
# The y axis is shared, so only the first panel is labelled.
axes[0].set_ylabel("Psi")
fig.suptitle("Ramachandran plots - model - energy")
fig.colorbar(im, ax=axes.ravel().tolist())


In [None]:
import matplotlib.cm as cm

# Acceptance probability of every conditional sample as a column vector.
# It is stored under its own name so that the `acceptance` array returned by
# `sample_on_batches` (and reported earlier in the notebook) is not overwritten
# when cells are re-run.
acceptance_conditional = p_acc.reshape(-1, 1).cpu().numpy()
# Only samples with a non-negligible acceptance probability are plotted.
good_acceptance = acceptance_conditional > 0.0001

# One panel per torsion pair: samples coloured by acceptance probability, with
# the conditioning structure marked by a red cross.
fig, axes = plt.subplots(1, 3, figsize=(40, 10), sharey=True)
for i in range(3):
    im = axes[i].scatter(
        torsions[0][:, [i]][good_acceptance],
        torsions[1][:, [i]][good_acceptance],
        c=acceptance_conditional[good_acceptance],
        cmap='rainbow',
        s=20,
    )
    axes[i].scatter(
        torsions_traj_conditioning[0][0, i],
        torsions_traj_conditioning[1][0, i],
        marker="x",
        color="red",
        s=500,
        linewidths=5,
    )
    axes[i].set_xlim(-np.pi, np.pi)
    axes[i].set_ylim(-np.pi, np.pi)
    axes[i].set_xlabel("Phi")
# The y axis is shared, so only the first panel is labelled.
axes[0].set_ylabel("Psi")
fig.suptitle("Ramachandran plot - model - acceptance")
fig.colorbar(im, ax=axes.ravel().tolist())




In [None]:
import matplotlib as mpl

def plot_ramachandrans(torsions, name):
    """Draw three log-scaled 2D histograms of torsion-angle pairs side by side.

    Args:
        torsions: indexable whose element 0 is plotted on the x axis ("Phi") and
            element 1 on the y axis ("Psi"); columns 0, 1 and 2 of each are used,
            one per panel.
        name: text appended to the figure title.

    Returns:
        None. The figure is created through pyplot.
    """
    phi_angles, psi_angles = torsions[0], torsions[1]
    fig, axes = plt.subplots(1, 3, figsize=(33, 10), sharey=True)
    for column, ax in enumerate(axes):
        ax.hist2d(
            phi_angles[:, column],
            psi_angles[:, column],
            bins=100,
            norm=mpl.colors.LogNorm(),
        )
        ax.set_xlim(-np.pi, np.pi)
        ax.set_ylim(-np.pi, np.pi)
        ax.set_xlabel("Phi")
    # The y axis is shared, so only the first panel is labelled.
    axes[0].set_ylabel("Psi")
    fig.suptitle(f"Ramachandran plots - {name}")


In [None]:
# Torsion-angle histograms of the conditional model samples.
plot_ramachandrans(torsions, f'{protein} - model')

In [None]:
# Rebind `batch` to the first batch; the next cell reads its adjacency list.
batch = batches[0]

In [None]:
# Internal coordinates of the positions stored in the trajectory archive, using
# the adjacency list of the currently selected batch.
bonds_md, torsions_md = compute_internal_coordinates(
    state0pdbpath, batch.adj_list.cpu().numpy(), npz_traj['positions']
)
# Torsion-angle histograms of the reference MD data.
plot_ramachandrans(torsions_md, f'{protein} - MD')

In [None]:
fig, axs = plt.subplots(2, 2, figsize=(16, 16), sharey=True, sharex=True)
# Atom-type indices of the first batch element and the names used to label them.
# NOTE(review): the index -> element mapping in `atom_names` is assumed to match
# the dataset's atom-type encoding — confirm.
atom_types = batches[0].atom_types.cpu().detach().numpy()[0]
atom_names = ["C", "H", "N", "O", "S"]
adj_list = batches[0].adj_list

for i in range(2):
    for j in range(2):
        # Bond columns 0, 1, 7 and 8 are shown.
        idx = i * 7 + j
        panel = axs[i, j]
        for bond_lengths, source in ((bonds_md, "Boltzmann"), (bonds, "model")):
            panel.hist(
                bond_lengths[:, idx],
                bins=100,
                alpha=0.5,
                density=True,
                label=source,
            )
        # Title each panel with the two atom types joined by the bond.
        panel.set_title(
            f"{atom_names[atom_types[adj_list[idx, 0]]]}-{atom_names[atom_types[adj_list[idx, 1]]]}",
            y=0.7,
        )
# Only the last panel carries the legend.
axs[-1, -1].legend()
fig.supxlabel("Bondlength in nm")
fig.suptitle("Conditional bondlength distribution")


In [None]:
fig, axs = plt.subplots(5, 5, figsize=(16, 16), sharey=True, sharex=True)
# Atom-type indices of the first batch element and the names used to label them.
# NOTE(review): the index -> element mapping in `atom_names` is assumed to match
# the dataset's atom-type encoding — confirm.
atom_types = batches[0].atom_types.cpu().detach().numpy()[0]
atom_names = ["C", "H", "N", "O", "S"]
adj_list = batches[0].adj_list

for i in range(5):
    for j in range(5):
        # Row-major panel order: bond columns 0 .. 24 are shown.
        idx = i * 5 + j
        panel = axs[i, j]
        for bond_lengths, source in ((bonds_md, "Boltzmann"), (bonds, "model")):
            panel.hist(
                bond_lengths[:, idx],
                bins=100,
                alpha=0.5,
                density=True,
                label=source,
            )
        # Title each panel with the two atom types joined by the bond.
        panel.set_title(
            f"{atom_names[atom_types[adj_list[idx, 0]]]}-{atom_names[atom_types[adj_list[idx, 1]]]}",
            y=0.7,
        )
# Only the last panel carries the legend.
axs[-1, -1].legend()
fig.supxlabel("Bondlength in nm")
fig.suptitle("Conditional bondlength distribution")


In [None]:
# Load the initial-state PDB of the selected peptide with mdtraj.
state0_pdb_file = os.path.join(data_dir, f"{protein}-traj-state0.pdb")
traj = mdtraj.load(state0_pdb_file)


In [None]:
plt.figure(figsize=(12, 6))

# Every bond is characterised by the sorted pair of atom types at its two ends.
# The sorted pairs are computed once and reused inside the loop.
bond_type_pairs = np.sort(atom_types[adj_list], axis=-1)
bond_types = np.unique(bond_type_pairs, axis=0)

for type_number, bond_type in enumerate(bond_types):
    # Columns of all bonds that share this atom-type pair.
    bond_idxs = np.where(np.all(bond_type_pairs == bond_type, axis=1))[0]
    bond_md = bonds_md[:, bond_idxs]
    bond_model = bonds[:, bond_idxs]
    plt.hist(bond_md.flatten(), bins=100, density=True, histtype='step', linewidth=5, label="MD", color="C0")
    plt.hist(bond_model.flatten(), bins=100, density=True, histtype='step', linewidth=5, label="model", color="C1", linestyle="--")
    # Build the legend after the first bond type only, so that "MD" and "model"
    # each appear once instead of once per bond type.
    if type_number == 0:
        plt.legend()
plt.title("Bondlength distribution")
plt.xlabel("Bondlength in nm")