In [None]:
import torch
import numpy as np
from tqdm.auto import tqdm
import mdtraj as md
import random
from collections import defaultdict
from dataclasses import dataclass
import mdtraj

import functools
import openmm as mm
import simtk.unit as u  # type: ignore [import]
import matplotlib.pyplot as plt
from bgflow import OpenMMBridge, OpenMMEnergy
from openmm import unit

from timewarp.utils.training_utils import load_model
from timewarp.datasets import RawMolDynDataset
from timewarp.utils.openmm import OpenmmPotentialEnergyTorch
from timewarp.dataloader import (
    DenseMolDynBatch,
    moldyn_dense_collate_fn,
)
from itertools import islice
import os
from utilities.training_utils import set_seed
from typing import Optional, List, Union, Tuple,  DefaultDict, Dict
from timewarp.utils.energy_utils import get_energy_mean_std, plot_all_energy
from simulation.md import (
    get_simulation_environment,
    compute_energy_and_forces,
    compute_energy_and_forces_decomposition,
    get_parameters_from_preset, 
    get_simulation_environment_integrator,
    get_simulation_environment_for_force
)
from timewarp.equivariance.equivariance_transforms import transform_batch
from timewarp.utils.evaluation_utils import compute_kinetic_energy
from timewarp.utils.tica_utils import tica_features, run_tica, plot_tic01, plot_free_energy
from scipy.spatial import distance
import arviz as az
# Global plot font size and torch device (CUDA when available, otherwise CPU).
plt.rcParams["font.size"] = 35
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Evaluation configuration.
dataset = "4AA-huge"
# dataset = "2AA-1-huge"

# Preset name passed to get_simulation_environment / get_simulation_environment_integrator.
parameters = "T1B-peptides"
step_width = 1000
# Prefix prepended to every ".data/..." and "outputs/..." path in this notebook.
# Leave empty to resolve paths relative to the working directory; otherwise it
# must end with a path separator (paths are built by plain string concatenation).
base_dir = ""


In [None]:
# Split of the reference MD trajectories and the directory that holds them.
data_type = 'test'
data_dir = base_dir + "/".join(
    [".data/simulated-data/trajectory-data", dataset, data_type]
)



In [None]:
def ESS(autocorrelations, spacing=1, cut_off_at_zero=True):
    """Relative effective sample size computed from an autocorrelation function.

    Returns ``1 / tau`` with ``tau = -1 + 2 * spacing * sum(|rho_t|)`` over the
    summed lags. When ``rho_0 == 1`` and ``spacing == 1`` this equals
    ``1 / (1 + 2 * sum_{t>=1} rho_t)``, i.e. effective samples per sample.

    Args:
        autocorrelations: 1-D autocorrelation values indexed by lag. The first
            entry is assumed to be lag 0, as returned by ``arviz.autocorr``.
        spacing: factor multiplying the summed autocorrelation (presumably the
            lag spacing; 1 for consecutive samples).
        cut_off_at_zero: if True, sum only the lags before the first
            non-positive value. If the autocorrelation never drops to <= 0,
            all lags are used. If False, every lag is summed.

    Returns:
        The relative effective sample size ``1 / tau``.
    """
    autocorrelations = np.asarray(autocorrelations)
    end = None  # None -> slice covers every lag, including the last one
    if cut_off_at_zero:
        non_positive = np.flatnonzero(autocorrelations <= 0)
        if non_positive.size:
            end = non_positive[0]
    tau = -1 + 2 * spacing * np.abs(autocorrelations[:end]).sum()
    return 1 / tau

In [None]:
# TODO: `model_npz` (the model exploration samples) is not loaded yet -- see the loop body.
# For every test tetrapeptide:
#   1. fit TICA on the reference MD trajectory,
#   2. project the model samples into the same TIC space,
#   3. pick the model chain whose TIC 0/1 samples best cover the MD samples,
#   4. store the ratio of TIC 0 effective samples per unit time (model / MD).
evaluation_outputs = {}

proteins = sorted([
    "SAEL", "RYDT", "CSFQ", "FALS", "CSGS",
    "LPEM", "LYVI", "AYTG", "VCVS", "AAEW",
    "FKVP", "NQFM", "DTDL", "CTSA", "ANYT",
    "VTST", "AWKC", "RGSP", "AVEK", "FIYG",
    "VLSM", "QADY", "DQAL", "TFFL", "FIGE",
    "KKQF", "SLTC", "ITQD", "DFKS", "QDED",
])

# Model samples are treated as N_CHAINS interleaved chains: chain i is samples[i::N_CHAINS].
N_CHAINS = 100
# OpenMM platform: CUDA when torch sees a GPU, otherwise CPU
# (assumes OpenmmPotentialEnergyTorch passes platform_name to OpenMM -- confirm).
openmm_platform = "CUDA" if torch.cuda.is_available() else "CPU"

for protein in tqdm(proteins):
    # TODO: model_npz = np.load(<model exploration samples for `protein`>)
    # It must provide "positions" (presumably mdtraj-style xyz frames) and "time".
    trajectory_exploration = model_npz["positions"]
    model_time = model_npz["time"]

    # Reference MD trajectory, keeping every 5th saved frame.
    with np.load(os.path.join(data_dir, f"{protein}-traj-arrays.npz")) as npz_traj:
        positions = npz_traj["positions"][::5]

    state0pdbpath = os.path.join(data_dir, f"{protein}-traj-state0.pdb")
    topology = md.load(state0pdbpath).topology
    trajectory = md.Trajectory(xyz=positions, topology=topology)

    # TICA is fitted on MD only; the model samples are projected into the same space.
    tica_model = run_tica(trajectory, lagtime=100)
    tics = tica_model.transform(tica_features(trajectory))
    traj_model = md.Trajectory(xyz=trajectory_exploration, topology=topology)
    tics_model = tica_model.transform(tica_features(traj_model))

    # Chain selection: for every 10th MD point take the distance to the nearest chain
    # point in (TIC 0, TIC 1), scaled by the MD max-abs range, then minimise the worst case.
    md_tics = tics[::10, :2]
    ranges = np.abs(md_tics).max(axis=0)
    md_tics_scaled = md_tics / ranges  # loop-invariant, computed once
    min_dist = 1
    best_i = 0
    for i in range(N_CHAINS):
        chain_tics_scaled = tics_model[i::N_CHAINS, :2] / ranges
        dist = distance.cdist(md_tics_scaled, chain_tics_scaled).min(axis=1).max()
        if dist < min_dist:
            min_dist = dist
            best_i = i
        if dist < 0.1:  # coverage good enough; stop searching
            break

    # Time per sample. Model: total recorded time / number of samples in the chosen chain.
    # MD: 2e9 and 10_000 are hard-coded conversion factors (presumably total integration
    # steps and integration steps per subsampled frame) -- TODO confirm.
    time_md = np.load(os.path.join(data_dir, f"{protein}-traj-time.npy"))
    best_chain_tics = tics_model[best_i::N_CHAINS]
    model_time_per_step = model_time / len(best_chain_tics)
    md_time_per_step = time_md / (2 * 10**9)

    autocorrelation_model = az.autocorr(best_chain_tics[:, 0])
    autocorrelation_openMM = az.autocorr(tics[:, 0])
    ess_model_s = ESS(autocorrelation_model, spacing=1) / model_time_per_step
    ess_md_s = ESS(autocorrelation_openMM, spacing=1) / (md_time_per_step * 10000)

    # Speed-up on TIC 0: model effective samples per unit time relative to MD's.
    evaluation_outputs[protein] = ess_model_s / ess_md_s

    # Potential energies of the chosen model chain (not used further in this cell).
    simulation = get_simulation_environment(state0pdbpath, parameters)
    integrator = get_simulation_environment_integrator(parameters)
    system = simulation.system
    openmm_potential_energy_torch = OpenmmPotentialEnergyTorch(
        system, integrator, platform_name=openmm_platform
    )
    energies_model = openmm_potential_energy_torch(
        torch.from_numpy(trajectory_exploration[best_i::N_CHAINS])
    )

    fig, ax = plt.subplots(figsize=(10, 10))
    plot_tic01(ax, tics, "MD", tics_lims=tics)
    plot_tic01(ax, tics_model, f"{protein}", tics_lims=tics, cmap="autumn")
    # Display and release each figure now instead of keeping one open per peptide
    # until the cell ends (which exceeds matplotlib's max-open-figures warning limit).
    plt.show()
    plt.close(fig)

In [None]:
#np.save(base_dir + f'outputs/new-training/samples/4AA-model-exploration_evaluation_outputs.npy', evaluation_outputs)

In [None]:
# Per-peptide TIC 0 speed-up: model ESS per unit time divided by MD ESS per unit time.
evaluation_outputs

In [None]:
# Median TIC 0 speed-up over all evaluated peptides.
np.median([*evaluation_outputs.values()])

## Look at the Ramachandran plots to see which states are missing

In [None]:
# Peptides for which we find more states than MD
not_found_peptides  = ["ANYT", "AVEK", "DTDL", "LYVI", "NQFM", "QADY", "VLSM", "CTSA"]

for peptide in not_found_peptides:
    evaluation_outputs[peptide] = 1e10

In [None]:
# Set to True to write the speed-up bar chart as an SVG under outputs/figures/.
save = False

In [None]:
# Bar chart of the TIC 0 speed-up per tetrapeptide on a log scale.
# The peptides in `not_found_peptides` come first, with zero height, inside a green band.
# All other peptides follow in descending order of speed-up.
# Selecting by name (rather than by the 1e10 placeholder) means a missing placeholder
# can never zero out a genuine speed-up.
found_speedups = np.sort(
    [speedup for peptide, speedup in evaluation_outputs.items()
     if peptide not in not_found_peptides]
)[::-1]
values = np.concatenate([np.zeros(len(not_found_peptides)), found_speedups])

fig1, ax1 = plt.subplots(figsize=(12, 6))
ax1.set_title('4AA exploration - Speed-up TIC 0')
ax1.set_xticks([])
ax1.axhline(1, linestyle='--', color='firebrick', linewidth=2)  # speed-up of 1 = parity with MD
ax1.set_ylabel("Speed-up factor")
ax1.set_xlabel("Tetrapeptides")
ax1.bar(np.arange(len(values)), values)
ax1.semilogy()
ax1.grid(linewidth=2.5)
ax1.set_xmargin(0.01)
ax1.axvspan(-0.5, len(not_found_peptides) - 0.75, alpha=0.2, color='green')

if save:
    fig1.savefig(base_dir + "outputs/figures/4AA-exploration-speedup-bar-log.svg", bbox_inches="tight")

## Since we previously saved only the well-matching chains, this excludes the peptides for which we find more metastable states than MD. Hence, we have to run these peptides again and save *all* chains.

In [None]:
# TODO: load the model samples (`model_npz`, see the loop body) and save *all* chains.
# Re-run the peptides in `not_found_peptides` on every model sample, not only the best chain.
proteins = sorted(not_found_peptides)  # sorted copy; `not_found_peptides` itself stays unmodified

# OpenMM platform: CUDA when torch sees a GPU, otherwise CPU
# (assumes OpenmmPotentialEnergyTorch passes platform_name to OpenMM -- confirm).
openmm_platform = "CUDA" if torch.cuda.is_available() else "CPU"

for protein in tqdm(proteins):
    # TODO: model_npz = np.load(<model exploration samples for `protein`>)
    # It must provide "positions" (presumably mdtraj-style xyz frames).
    trajectory_exploration = model_npz["positions"]

    # Reference MD trajectory, keeping every 5th saved frame.
    with np.load(os.path.join(data_dir, f"{protein}-traj-arrays.npz")) as npz_traj:
        positions = npz_traj["positions"][::5]

    state0pdbpath = os.path.join(data_dir, f"{protein}-traj-state0.pdb")
    topology = md.load(state0pdbpath).topology
    trajectory = md.Trajectory(xyz=positions, topology=topology)

    # Fit TICA on MD and project all model samples into the same TIC space.
    tica_model = run_tica(trajectory, lagtime=100)
    tics = tica_model.transform(tica_features(trajectory))
    traj_model = md.Trajectory(xyz=trajectory_exploration, topology=topology)
    tics_model = tica_model.transform(tica_features(traj_model))

    # Potential energies of all model samples.
    simulation = get_simulation_environment(state0pdbpath, parameters)
    integrator = get_simulation_environment_integrator(parameters)
    system = simulation.system
    openmm_potential_energy_torch = OpenmmPotentialEnergyTorch(
        system, integrator, platform_name=openmm_platform
    )
    energies_model = openmm_potential_energy_torch(torch.from_numpy(trajectory_exploration))
