In [None]:
from scoremd.data.dataset.protein import SingleProteinDataset
import jax
import jax.numpy as jnp

import matplotlib as mpl
mpl.rcParams.update({"font.size": 14, "axes.titlesize": 22, "axes.labelsize": 18})

system = 'bba'

# Trajectory file(s), TICA file and topology for each supported system.
_dataset_files = {
    'chignolin': dict(
        paths="./storage/deshaw/chignolin-0_ca.h5",
        tica_path="./storage/deshaw/chignolin_tica.pic",
        topology_path='./storage/deshaw/chignolin.pdb',
    ),
    'bba': dict(
        paths=["./storage/deshaw/bba-0_ca.h5", "./storage/deshaw/bba-1_ca.h5"],
        tica_path="./storage/deshaw/bba_tica.pic",
        topology_path="./storage/deshaw/bba.pdb",
    ),
}
if system not in _dataset_files:
    raise NotImplementedError
dataset = SingleProteinDataset(**_dataset_files[system])

# Featurize the raw `_dataset.xyz` frames directly; going through the dataset's regular
# access path would give shuffled frames and destroy the temporal order of the TIC series.
reference_tic0_langevin, reference_tic1_langevin = dataset.get_2d_features(jnp.array(dataset._dataset.xyz))

In [None]:
# Free-energy surface of the reference (unshuffled) frames in TIC space.
reference_xyz = jnp.array(dataset._dataset.xyz)
dataset.plot_2d(reference_xyz, title=f"Reference {system}", free_energy_bar=True)

In [None]:
import matplotlib.pyplot as plt

def plot_tic_over_time(tic0, tic1, value_range=None, ax=None):
    """Plot the TIC 0 and TIC 1 time series on one set of axes.

    Args:
        tic0: Sequence of TIC 0 values, one per (coarse-grained) timestep.
        tic1: Sequence of TIC 1 values, plotted against the same step indices as ``tic0``.
        value_range: Optional ``[[tic0_min, tic0_max], [tic1_min, tic1_max]]``. Defaults to
            the notebook-global ``dataset.range``. The y-limits span the union of both
            intervals so the two series share one scale.
        ax: Axes to draw on; defaults to the current pyplot axes (``plt.gca()``).

    Returns:
        The axes that were drawn on.
    """
    if value_range is None:
        value_range = dataset.range
    if ax is None:
        ax = plt.gca()

    ax.plot(tic0, label='TIC 0')
    ax.plot(tic1, label='TIC 1')
    # One shared y-scale covering both TIC ranges.
    ax.set_ylim(
        min(value_range[0][0], value_range[1][0]),
        max(value_range[0][1], value_range[1][1]),
    )
    ax.set_xlim(0, len(tic0))
    ax.set_xlabel('(Coarse-Grained) Timesteps')
    return ax

# Only the first REFERENCE_FRAMES reference frames, thinned by REFERENCE_STRIDE, keep the plot readable.
REFERENCE_FRAMES, REFERENCE_STRIDE = 100_000, 10

plt.figure(figsize=(16, 4.8))
plt.title('Reference TIC over time')
plot_tic_over_time(
    reference_tic0_langevin[:REFERENCE_FRAMES:REFERENCE_STRIDE],
    reference_tic1_langevin[:REFERENCE_FRAMES:REFERENCE_STRIDE],
)
plt.legend()

In [None]:
import numpy as onp
import os

# Output of the sampling run that produced the model's Langevin trajectories, per system.
_trajectory_files = {
    'chignolin': "../multirun/2025-07-21/16-49-36/0/out/chignolin_langevin_trajectories.npy",
    'bba': "../multirun/2025-09-18/22-00-04/0/out/bba_langevin_trajectories.npy",
}
if system not in _trajectory_files:
    raise NotImplementedError
path = os.path.expanduser(_trajectory_files[system])

model_langevin = jnp.array(onp.load(path))
model_langevin.shape

In [None]:
# Reference TICs with frames flattened to the model's trailing dimension; their min/max
# are used later as the plotting range in TIC space.
n_model_features = model_langevin.shape[-1]
reference_flat = jnp.array(dataset._dataset.xyz).reshape(-1, n_model_features)
t0, t1 = dataset.get_2d_features(reference_flat)

In [None]:
import matplotlib.pyplot as plt
# Free-energy surface of the first model trajectory only.
first_model_trajectory = model_langevin[0]
dataset.plot_2d(first_model_trajectory, title='Model', free_energy_bar=True)
plt.show()

In [None]:
# Pick one model trajectory per system, subsample it for plotting, and choose the frames
# (indices into the *subsampled* trajectory) that get numbered annotations.
if system == 'chignolin':
    trajectory_index = 2
    start_point, every_n, max_step = 110_000, 25, None  # None: no cap after subsampling
    annotations = {'1': 3900, '2': 4740, '3': 5000, '4': 7100}
    legend_loc, legend_anchor = 'upper left', (0.165, 1)
elif system == 'bba':
    trajectory_index = 0
    start_point, every_n, max_step = 0, 5, 6500
    annotations = {'1': 330, '2': 2900, '3': 4000, '4': 6000}
    legend_loc, legend_anchor = 'lower right', None
else:
    raise NotImplementedError

model_one_trajectory = model_langevin[trajectory_index]
# Features are computed on the full trajectory and then sliced, as before.
model_one_trajectory_tic0, model_one_trajectory_tic1 = dataset.get_2d_features(model_one_trajectory)

# Identical frame selection for coordinates and both TICs keeps them aligned frame-for-frame.
model_one_trajectory_limited, model_one_trajectory_tic0_limited, model_one_trajectory_tic1_limited = (
    series[start_point::every_n][:max_step]
    for series in (model_one_trajectory, model_one_trajectory_tic0, model_one_trajectory_tic1)
)

# JAX arrays clamp out-of-bounds indices instead of raising, so an annotation frame past the
# end would silently land on the last frame in the later plots / PDB export. Fail loudly here.
_last_annotated = max(annotations.values(), default=-1)
assert _last_annotated < len(model_one_trajectory_limited), (
    f"annotation frame {_last_annotated} is outside the "
    f"{len(model_one_trajectory_limited)} subsampled frames"
)

In [None]:
import matplotlib as mpl
plt.style.use("tableau-colorblind10")
# Re-apply the notebook's font sizes after switching style sheets.
mpl.rcParams.update({"font.size": 14, "axes.titlesize": 22, "axes.labelsize": 18})

# Text styling for the numbered markers: white label centred on the anchor point, drawn on top.
marker_text_style = dict(
    xytext=(0, 0),
    textcoords="offset points",
    zorder=1000,
    color="white",
    ha="center",
    va="center_baseline",
)

plt.figure(figsize=(16, 4.8))
plot_tic_over_time(model_one_trajectory_tic0_limited, model_one_trajectory_tic1_limited)
plt.legend(loc=legend_loc, bbox_to_anchor=legend_anchor)

# Each marker sits at its annotated timestep, at a fixed height of 0.5 in data (TIC) units.
for label, frame in annotations.items():
    plt.annotate(
        label,
        [frame, 0.5],
        bbox=dict(boxstyle="circle,pad=0.35", fc="black", ec="none"),
        **marker_text_style,
    )

plt.ylabel('TIC')
plt.savefig(f'{system}_both_tics_over_time.pdf', bbox_inches='tight')

In [None]:
# Pool all model frames, shuffle them with a fixed key, and keep as many as the reference has.
all_model_frames = model_langevin.reshape(-1, model_langevin.shape[-1])
n_reference_frames = reference_tic0_langevin.shape[0]
permuted_frames = jax.random.permutation(jax.random.PRNGKey(0), all_model_frames)
shuffled_model_langevin = permuted_frames[:n_reference_frames]
dataset.plot_2d(shuffled_model_langevin)

In [None]:
from matplotlib import cm
import numpy as onp  # local import so this cell does not rely on an earlier cell's import

cmap = cm.GnBu  # segment colour encodes time: first segment at 0.0, last near 1.0

# Thin the already-subsampled trajectory further for the line overlay:
# keep every `line_every`-th frame.
if system == 'chignolin':
    line_every = 5
elif system == 'bba':
    line_every = 2
else:
    raise NotImplementedError

# Convert once to NumPy. This avoids one small array op per slice in the segment loop below,
# and a NumPy array raises IndexError for an out-of-range annotation instead of clamping
# (as JAX arrays do).
line_x = onp.asarray(model_one_trajectory_tic0_limited[::line_every])
line_y = onp.asarray(model_one_trajectory_tic1_limited[::line_every])

dataset.plot_2d(shuffled_model_langevin, range=[[t0.min(), t0.max()], [t1.min(), t1.max()]])

# Draw the path one segment at a time so each segment can get its own colour.
n_segments = len(line_x) - 1
for i in range(n_segments):
    plt.plot(line_x[i:i + 2], line_y[i:i + 2], color=cmap(i / n_segments), rasterized=True)

# Annotation positions index the subsampled trajectory; map them onto the thinned line.
for name, pos in annotations.items():
    plt.annotate(name, [line_x[pos // line_every], line_y[pos // line_every]], xytext=(0, 0),
                 textcoords="offset points", zorder=1000, color="white",
                 ha="center", va="center_baseline",
                 bbox=dict(boxstyle="circle,pad=0.35", fc="black", ec="none"))

plt.savefig(f'{system}_both_dynamics.svg', bbox_inches='tight', dpi=200)

In [None]:
from scoremd.rmsd import kabsch_align_many

# Frames (indices into the subsampled trajectory) that were marked in the plots above.
keyframes = list(annotations.values())

# Second argument is the first subsampled frame (presumably the alignment reference — confirm
# against kabsch_align_many); the second return value is discarded.
model_one_trajectory_limited_aligned, _ = kabsch_align_many(
    model_one_trajectory_limited,
    model_one_trajectory_limited[0],
)

dataset.write_animation(model_one_trajectory_limited_aligned[keyframes, ...], f'{system}_animation.pdb')

In [None]:
# Number of subsampled frames and per-frame shape.
jnp.shape(model_one_trajectory_limited)

In [None]:
# Center every frame, then export the subsampled trajectory as a multi-frame PDB and an XTC.
import mdtraj as md

# (frames, atoms, 3): one row of xyz coordinates per atom.
X = model_one_trajectory_limited.reshape(model_one_trajectory_limited.shape[0], -1, 3)
# Subtract each frame's mean position over atoms, so every frame is centred at the origin.
X_centered = X - X.mean(axis=1, keepdims=True)

# Write all centred frames to PDB, then re-load that file only to obtain an mdtraj topology.
dataset.write_animation(X_centered, f'{system}_all.pdb')
top = md.load(f"{system}_all.pdb").topology  # must match atom count

# NOTE(review): mdtraj.Trajectory interprets xyz in nanometres — confirm the model's
# coordinate units match, or the XTC will be scaled incorrectly.
traj = md.Trajectory(X_centered, top)
traj.save_xtc(f"{system}_all.xtc")