In [None]:
import torch
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")    

import random
from collections import defaultdict
from dataclasses import dataclass
import mdtraj
import functools
import openmm as mm
import simtk.unit as u  # type: ignore [import]
import matplotlib.pyplot as plt
from bgflow import OpenMMBridge, OpenMMEnergy#, LinLog

#from timewarp.sample import sample, main
from timewarp.utils.training_utils import load_model
from timewarp.datasets import RawMolDynDataset
from timewarp.dataloader import (
    DenseMolDynBatch,
    moldyn_dense_collate_fn,
)
from itertools import islice
import os
from utilities.training_utils import set_seed
from typing import Optional, List, Union, Tuple,  DefaultDict, Dict
from timewarp.utils.energy_utils import get_energy_mean_std, plot_all_energy
from simulation.md import (
    get_simulation_environment,
    compute_energy_and_forces,
    compute_energy_and_forces_decomposition,
    get_parameters_from_preset, 
    get_simulation_environment_integrator,
    get_simulation_environment_for_force
)
from timewarp.equivariance.equivariance_transforms import transform_batch
from tqdm.auto import tqdm
import mdtraj as md
from timewarp.utils.tica_utils import tica_features, run_tica, plot_tic01, plot_free_energy

plt.rc('font', size=35) 


### The conservative metastable-state heuristic might flag peptides for which all states were actually discovered. These need to be rerun.

In [None]:
# Evaluation settings. They are hard-coded here rather than read from a saved
# training config (torch.load(savefile)["training_config"]).
dataset = '4AA-huge'  # previously tried: '4AA', '2AA-1-huge'
step_width = 10000    # would correspond to config.step_width

step_width, dataset


In [None]:
# Directory holding the simulated trajectory data for the chosen dataset split.
# NOTE(review): `base_dir` is not defined in the cells visible here; it must
# already exist in the kernel (and be a string, since it is concatenated).
data_type = 'test'  # alternative split: 'train'
trajectory_subdir = f".data/simulated-data/trajectory-data/{dataset}/{data_type}"
data_dir = base_dir + trajectory_subdir

In [None]:
def load_trajectory_from_dir(directory_name, protein, stride=1):
    """Load and concatenate the sampled trajectory chunks of one peptide.

    Chunks are read from ``{directory_name}/{protein}_trajectory_model_{i}.npz``
    for ``i = 0 .. n_chunks - 1``, where ``n_chunks`` is the number of entries
    in ``directory_name`` minus one.

    Args:
        directory_name: Directory (string path) holding the ``.npz`` chunks.
        protein: Peptide name used as the file-name prefix.
        stride: Keep every ``stride``-th frame of each chunk.

    Returns:
        A tuple ``(positions, times)``: the strided ``positions`` arrays of all
        chunks concatenated along axis 0, and an array built from the ``time``
        entry of every chunk that stores one (empty if none do).

    Raises:
        FileNotFoundError: If an expected chunk file does not exist.
        ValueError: If there is no chunk to concatenate.
    """
    # NOTE(review): the "- 1" presumably accounts for exactly one non-chunk
    # entry in the directory; this breaks if the directory holds anything
    # else (e.g. chunks of several peptides) -- confirm against the writer.
    num_chunks = len(os.listdir(directory_name)) - 1

    trajectories = []
    times = []
    for i in range(num_chunks):
        chunk_path = directory_name + f"/{protein}_trajectory_model_{i}.npz"
        # Context manager closes the underlying zip file once the arrays
        # have been read into memory.
        with np.load(chunk_path) as npz:
            trajectories.append(npz['positions'][::stride])
            # Timing information is optional; check for the key explicitly
            # instead of swallowing every exception.
            if 'time' in npz.files:
                times.append(npz['time'])

    if not trajectories:
        raise ValueError(
            f"no trajectory chunks found for {protein!r} in {directory_name!r}"
        )
    return np.concatenate(trajectories, axis=0), np.array(times)

In [None]:
# Settings used by the evaluation cells further down.
num_md_steps = 10_000  # scales the MD time per step in the ESS-per-second estimate
model_stride = 1       # stride applied when loading the sampled trajectories
initial_state = 7

In [None]:
def get_names(directory):
    """Return the first four characters of every ``.npz`` file name in ``directory``.

    One entry is produced per ``.npz`` file, in directory-listing order, so a
    prefix shared by several files appears several times.
    """
    encoded_dir = os.fsencode(directory)
    entry_names = (os.fsdecode(entry) for entry in os.listdir(encoded_dir))
    return [name[:4] for name in entry_names if name.endswith(".npz")]

In [None]:
import arviz as az

def ESS(autocorrelations, spacing=1, cut_off_at_zero=True):
    """Estimate an effective-sample-size factor from an autocorrelation series.

    Computes ``1 / (2 * spacing * sum(|rho_k|) - 1)`` where the sum runs over
    the leading lags of ``autocorrelations``.

    Args:
        autocorrelations: Autocorrelation values ordered by lag (array-like).
        spacing: Multiplier applied to the summed autocorrelation.
        cut_off_at_zero: If True, only the lags before the first value that is
            ``<= 0`` are summed. If the series never reaches zero, or if this
            flag is False, all lags are summed.

    Returns:
        The estimate as a float.
    """
    autocorrelations = np.asarray(autocorrelations)

    # ``None`` as slice end means "use every lag". The previous sentinel -1
    # silently dropped the last lag.
    steps_until_zero = None
    if cut_off_at_zero:
        non_positive = np.where(autocorrelations <= 0)[0]
        # A series that never decays to zero has no cut-off index; fall back
        # to all lags instead of raising IndexError.
        if non_positive.size > 0:
            steps_until_zero = non_positive[0]

    Neff = 1/(-1 + 2 * spacing * np.abs(autocorrelations[:steps_until_zero]).sum())
    return Neff

In [None]:
# Evaluate every test peptide: compare the model-sampled trajectory with the MD
# reference in TICA space and store
#   protein: [acceptance, ratio_tic0, ratio_tic1]
from scipy.spatial import distance

evaluation_outputs = {}
initial_state = 1
#proteins = get_names(data_dir)
proteins = ["SAEL", "RYDT", "CSFQ", "FALS", "CSGS",
    "LPEM", "LYVI", "AYTG", "VCVS", "AAEW",
    "FKVP", "NQFM", "DTDL", "CTSA", "ANYT",
    "VTST", "AWKC", "RGSP", "AVEK", "FIYG",
    "VLSM", "QADY", "DQAL", "TFFL", "FIGE",
    "KKQF", "SLTC", "ITQD", "DFKS", "QDED"]
proteins.sort()
for protein in tqdm(proteins):
    # NOTE(review): `dir_data` is not defined in the cells visible here (only
    # `data_dir` is); presumably it is the directory with the model samples
    # and comes from earlier kernel state -- confirm and define it explicitly.
    sampled_coords, times = load_trajectory_from_dir(dir_data, protein, model_stride)
    if len(sampled_coords) < 1900000:
        # Raise instead of `assert False`: asserts are removed under `python -O`
        # and carry no message.
        raise RuntimeError(f"{protein} trajectory not finished")

    # Reference MD data for this peptide.
    md_prefix = base_dir + f'.data/simulated-data/trajectory-data/{dataset}/{data_type}/{protein}'
    npz_traj = np.load(md_prefix + '-traj-arrays.npz')
    xyz = npz_traj['positions'][::5]  # every 5th stored MD frame
    state0pdbpath = md_prefix + '-traj-state0.pdb'
    topology = md.load(state0pdbpath).topology
    trajectory = md.Trajectory(
        xyz=xyz,
        topology=topology
    )

    # Fit TICA on the MD reference, then project both trajectories with it.
    tica_model = run_tica(trajectory, lagtime=100)
    features = tica_features(trajectory)
    tics = tica_model.transform(features)
    traj_model = md.Trajectory(
        xyz=sampled_coords,
        topology=topology
    )
    feat_model = tica_features(traj_model)
    tics_model = tica_model.transform(feat_model)

    # Fraction of distinct first-atom x coordinates among the samples.
    # NOTE(review): the extra "/ 10" (here and in the timings below) presumably
    # reflects 10 samples per model step -- confirm.
    acceptance = (len(np.unique(sampled_coords[:, 0, 0]))-1)/len(sampled_coords) / 10

    # Wall-clock time per step for the model and for MD.
    time_md = np.load(md_prefix + '-traj-time.npy')
    model_time_per_step = times.sum() / (len(sampled_coords) * 10)
    md_time_per_step = time_md / (2*10**9)

    # Speed-up = (ESS per second of the model) / (ESS per second of MD),
    # computed for TIC 1 first and TIC 0 second, as before.
    speedup_by_tic = {}
    for tic_index in (1, 0):
        autocorrelation_model = az.autocorr(tics_model[:, tic_index])
        autocorrelation_openMM = az.autocorr(tics[:, tic_index])
        ess_model_s = ESS(autocorrelation_model, spacing=1)/(model_time_per_step * 10)
        ess_md_s = ESS(autocorrelation_openMM, spacing=1)/(md_time_per_step * num_md_steps)
        speedup_by_tic[tic_index] = ess_model_s / ess_md_s
    ratio_tic1 = speedup_by_tic[1]
    ratio_tic0 = speedup_by_tic[0]

    # Check whether the model missed states: for every (subsampled) MD point,
    # find the nearest model point in range-normalised TIC 0/1 space and take
    # the worst case over all MD points.
    ranges = np.abs(tics[::10, :2]).max(0)
    max_dist = distance.cdist(tics[::10, :2]/ranges,tics_model[::10, :2]/ranges).min(axis=1).max()
    if max_dist < 0.3:
        evaluation_outputs[protein] = [acceptance,ratio_tic0, ratio_tic1]
    else:
        # Flag the peptide with zero speed-ups and plot the corresponding TICA
        # projections for manual inspection.
        evaluation_outputs[protein] = [acceptance,0, 0]
        fig, ax = plt.subplots(figsize=(10,10))
        plot_tic01(ax, tics, "MD", tics_lims=tics)
        plot_tic01(ax, tics_model, f"{protein} MD - model", cmap='autumn', tics_lims=tics)


In [None]:
save = False  # set to True to write the speed-up figure to disk

In [None]:
# Bar plot of the TIC 0 speed-up per peptide, largest first.
# Column 1 of every evaluation entry holds the TIC 0 speed-up.
tic0_speedups = np.array(list(evaluation_outputs.values()))[:, 1]
values = np.sort(tic0_speedups)[::-1]
not_found_peptides = []

fig1, ax1 = plt.subplots(figsize=(12,6))
ax1.set(
    title='4AA MCMC - Speed-up TIC 0',
    ylabel="Speed-up factor",
    xlabel="Tetrapeptides",
)
ax1.set_xticks([])
ax1.axhline(1, linestyle='--', color='firebrick')  # speed-up of 1: on par with MD
bar_positions = np.arange(len(values))
ax1.bar(bar_positions, values)
ax1.set_xmargin(0.001)
ax1.semilogy()
ax1.grid(axis='y', linewidth=2.5)

# Grey out the trailing bars whose speed-up is not positive (peptides flagged
# by the missing-state check are stored with a speed-up of 0).
shade_start = len(values[values>0]) - 0.45 + len(not_found_peptides)
shade_end = len(values) - 0.5
ax1.axvspan(shade_start, shade_end, alpha=0.2, color='grey')

if save:
    plt.savefig(base_dir+"outputs/figures/4AA-A100-speedup-bar-log-grey.svg", bbox_inches="tight")

In [None]:
# Median TIC 0 speed-up over all evaluated peptides; the 0 entries assigned to
# peptides flagged by the missing-state check are included in `values`.
np.median(values)