In [None]:
import dataset
from gcae_tcn import GVPEncoder, GVPDecoder, TCNModel
from gcae_transformer import TransformerEncoder
import matplotlib.pyplot as plt 
import mdtraj as md 
import os
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.data import Batch
from tqdm import tqdm 
from torch.utils.data import DataLoader

import matplotlib.animation as animation
from IPython.display import HTML

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE 

import umap
import math

from functools import partial 

In [None]:
def determine_device():
    """Return the preferred torch device: Apple MPS first, then CUDA, else CPU."""
    # Probed in priority order; the CUDA check only runs when MPS is unavailable.
    for backend_name, is_available in (
        ('mps', torch.backends.mps.is_available),
        ('cuda', torch.cuda.is_available),
    ):
        if is_available():
            return torch.device(backend_name)
    return torch.device('cpu')

In [None]:
import matplotlib.gridspec as gridspec

# Publication-style typography applied to every figure produced below.
_thesis_font_settings = {
    "font.family": "serif",   # or "sans-serif"
    "font.size": 14,
    "axes.labelsize": 16,
    "axes.titlesize": 16,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
}
for _rc_key, _rc_value in _thesis_font_settings.items():
    plt.rcParams[_rc_key] = _rc_value

# Load the dataset and the trained model checkpoint

In [None]:
# Select the system to analyse. `split` is forwarded to generate_structures and is
# enabled for the datasets that are re-windowed with reshape_time_window.
split = False
protein = 'chignolin'  # one of: 'pentapeptide', 'chignolin', 'fs_peptide'
if protein == 'chignolin':
    pdb, files = dataset.load_DESRES(protein, data_dir='./', analyze=True, plot=False)
    test_traj_num = 26
elif protein == 'pentapeptide':
    pdb, files = dataset.load_pentapeptide(data_dir='pentapeptide/')
    files = dataset.reshape_time_window(files, pdb, num_files=25, traj_len=5001, num_split=1)
    test_traj_num = 12
    split = True
elif protein == 'fs_peptide':
    pdb, files = dataset.load_fs_peptide(data_dir='fs_peptide/')
    files = dataset.reshape_time_window(files, pdb, num_files=28, traj_len=10000, num_split=1)
    test_traj_num = 12
    split = True
else:
    # Fail here with a clear message instead of a NameError on `files` below.
    raise ValueError(
        f"Unknown protein {protein!r}; expected 'chignolin', 'pentapeptide' or 'fs_peptide'"
    )

print(files[test_traj_num])
file = files[test_traj_num]   # held-out test trajectory
file0 = files[0]              # reference trajectory


In [None]:
# Checkpoint to evaluate. The architecture hyperparameters below should match the
# ones encoded in the directory name, otherwise the state dicts will not load.
path = 'gcae_experiments/more_recent_models/transformer/chignolin/Falsenorm_10smoothing_5000ep_0.0005lr_0.1dr_3layers_5edge_32latentdim_1e-06reg_1e-06temp_4nheads_2encoderlayers_512maxseqlen_CAUSAL/'
model_num_load = 2649
latent_dim = 32
temporal = 'transformer'  # 'tcn' or 'transformer'
normalize = False

# GVP encoder/decoder settings
top_k = 10
n_layers = 3
dr = 0.1
num_layers = 3  # NOTE(review): not used below; the GVP models receive n_layers
lr = 5e-4
smooth = 10

# Transformer settings
T = 512  # sequence length and window stride of the test sequences
n_heads = 4
n_encoder_layers = 2
# Regulariser weights are zero here, so the reported test loss is pure
# reconstruction (the run name in `path` says 1e-06 was used for both in training).
lambda_reg = 0
lambda_temp = 0

# TCN settings
kernel = 5
layers = [64, 64, 64]


# Reference trajectory (first file), passed to generate_structures.
if files[0][-4:] == '.xtc':
    traj0 = md.load_xtc(files[0], top=pdb)
else:
    traj0 = md.load_dcd(files[0], top=pdb)
topology = md.load(pdb).topology


test_structures = dataset.generate_structures([files[test_traj_num]], pdb, traj0, split=split, smooth=smooth, normalize=normalize)
test_frame_dataset = dataset.LigandDataset(test_structures, top_k=top_k)
# Window the frames of the *test* trajectory (previously len(traj0), i.e. file 0):
# collate_sparse indexes test_frame_dataset with these windows, so they must not
# exceed its length. Assumes LigandDataset defines __len__ — TODO confirm.
test_seq_dataset = dataset.SequenceDataset([len(test_frame_dataset)], sequence_length=T, stride=T, include_partial=True)


device = determine_device()
print(f"On device: {device}")

node_h_dim = (100, 16)
edge_h_dim = (32, 1)
node_num = topology.n_residues  # reuse the topology loaded above instead of re-reading the PDB

In [None]:
def collate_sparse(batch, frame_dataset):
    """Merge a list of (start, end) frame windows into one PyG Batch.

    Parameters
    ----------
    batch : list of (int, int)
        Half-open frame index windows produced by the sequence dataset.
    frame_dataset : indexable
        Per-frame graph dataset; ``frame_dataset[i]`` is one graph.

    Returns
    -------
    torch_geometric.data.Batch
        All frames of all windows, concatenated in order, with two extra fields:
        ``seq_ptr`` (long tensor of cumulative frame offsets, starting at 0) and
        ``seq_len`` (length of the *last* window in ``batch``).
    """
    graphs = []
    offsets = [0]
    for start, stop in batch:
        for frame_idx in range(start, stop):
            graphs.append(frame_dataset[frame_idx])
        offsets.append(len(graphs))

    merged = Batch.from_data_list(graphs)
    merged.seq_ptr = torch.tensor(offsets, dtype=torch.long)
    merged.seq_len = stop - start
    return merged



In [None]:
# Iterate the test windows in order, one (start, end) window per batch; each batch
# is assembled into a single PyG Batch by collate_sparse.
test_loader = DataLoader(
    dataset=test_seq_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    drop_last=False,
    collate_fn=partial(collate_sparse, frame_dataset=test_frame_dataset),
)


In [None]:
encoder = GVPEncoder((6, 3), node_h_dim, (32, 1), edge_h_dim,
                     latent_dim=latent_dim,
                     n_layers=n_layers,
                     drop_rate=dr,
                     node_num=node_num).to(device)

decoder = GVPDecoder((6, 3), node_h_dim, (32, 1), edge_h_dim,
                     latent_dim=latent_dim,
                     n_layers=n_layers,
                     drop_rate=dr, dense_mode=True, node_num=node_num).to(device)

# Both temporal models are built so later cells can refer to either name; only the
# one selected by `temporal` receives trained weights. The TCN input length follows
# T instead of a hard-coded 512.
tcn = TCNModel(input_size=latent_dim, channel_size=layers, input_length=T, kernel_size=kernel).to(device)
transformer = TransformerEncoder(latent_dim, max_seq_len=T, nhead=n_heads, num_layers=n_encoder_layers, dropout=dr).to(device)
temporal_model = tcn if temporal == 'tcn' else transformer

optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()) + list(temporal_model.parameters()), lr=lr
)
loss_function = nn.MSELoss()


def _load_checkpoint(module, prefix):
    """Load `<path>/<prefix>_epoch-<model_num_load>.pt` into `module` and move it to `device`."""
    ckpt_file = os.path.join(path, f'{prefix}_epoch-{model_num_load}.pt')
    module.load_state_dict(torch.load(ckpt_file, map_location=device, weights_only=True))
    module.to(device)


_load_checkpoint(encoder, 'encoder')
_load_checkpoint(decoder, 'decoder')
_load_checkpoint(temporal_model, 'tcn' if temporal == 'tcn' else 'transformer')


In [None]:
# Training-loss history saved next to the checkpoints, plotted on a log scale.
losses = torch.load(os.path.join(path, 'losses.pt'), map_location='cpu')
loss = list(losses.values())

fig_loss, ax_loss = plt.subplots()
ax_loss.plot(range(len(loss)), loss)
ax_loss.set_yscale('log')
ax_loss.set_xlabel("Entry in losses.pt")
ax_loss.set_ylabel("Training loss")
ax_loss.set_title("Training loss")
plt.show()

In [None]:
# Evaluate the restored model on the held-out trajectory and collect the latents
# produced by the temporal model for the projections below.
temporal_model = tcn if temporal == 'tcn' else transformer
encoder.eval()
decoder.eval()
temporal_model.eval()

test_loss = 0.0    # sum (then mean) of the per-batch objective: reconstruction + regularisers
sq_err_sum = 0.0   # running sum of squared reconstruction errors over all coordinates
n_elements = 0     # number of coordinates accumulated in sq_err_sum
all_mu = []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        nodes = (batch.node_s, batch.node_v)
        edges = (batch.edge_s, batch.edge_v)

        GT = batch.x

        z = encoder(nodes, batch.edge_index, edges)
        # Regroup per-frame latents into (n_seq, seq_len, latent_dim) using the
        # cumulative frame offsets stored in seq_ptr by collate_sparse.
        z_seq = torch.stack([z[start:end] for start, end in zip(batch.seq_ptr[:-1], batch.seq_ptr[1:])])
        z_seq_out = temporal_model(z_seq)
        z_out = torch.cat(list(z_seq_out), dim=0)
        pred = decoder(z_out, batch.edge_index, edges)
        reg = torch.mean(torch.norm(z, dim=1) ** 2)

        # Temporal smoothness: squared differences between consecutive frames
        # (axis 1 of z_seq). The former z[:, 1:] - z[:, :-1] differenced
        # neighbouring latent *dimensions* of z instead. A 1-frame window has no
        # frame pairs, so contribute 0 rather than the NaN mean of an empty tensor.
        if z_seq.shape[1] > 1:
            temp = torch.mean((z_seq[:, 1:] - z_seq[:, :-1]) ** 2)
        else:
            temp = z_seq.new_zeros(())

        loss_A = loss_function(GT, pred)
        loss = loss_A + lambda_reg * reg + lambda_temp * temp

        # Weight by element count so a shorter final (partial) window is not
        # over-represented in the RMSD.
        sq_err_sum += loss_A.item() * GT.numel()
        n_elements += GT.numel()

        all_mu.append(z_out.cpu())
        test_loss += loss.item()


latents = torch.cat(all_mu, dim=0)
latent_variances = torch.var(latents, dim=0)
test_loss /= len(test_loader)

# RMSD from reconstruction error only (regularisers excluded). The x10 converts nm
# (mdtraj's length unit) to Å, assuming batch.x keeps mdtraj units — TODO confirm.
# NOTE(review): if batch.x is (n_atoms, 3), sqrt(per-coordinate MSE) equals
# RMSD / sqrt(3); check which convention is intended.
rmsd_angstroms = math.sqrt(sq_err_sum / n_elements) * 10

print(f"RMSD on test data {rmsd_angstroms:.8f} Å")

print("Per-latent-dimension variance across dataset:")
for i, var in enumerate(latent_variances):
    print(f"Dimension {i}: variance = {var.item():.4f}")

torch.save(latents, os.path.join(path, "MORErecomputed_test_latents.pt"))






In [None]:

# 2-D projections of the test latents. A fixed seed makes the stochastic embeddings
# (t-SNE, UMAP) reproducible across runs; note that UMAP runs single-threaded when
# random_state is set.
SEED = 0
latents_np = latents.numpy()

pca = PCA(n_components=2, random_state=SEED)
latent_pca = pca.fit_transform(latents_np)


tsne = TSNE(n_components=2, perplexity=50, learning_rate='auto', random_state=SEED)
latent_tsne = tsne.fit_transform(latents_np)

reducer = umap.UMAP(random_state=SEED)
embedding = reducer.fit_transform(latents_np)

In [None]:
# Side-by-side views of the latent codes: PCA as a log-density hexbin,
# t-SNE and UMAP as semi-transparent scatters.
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 6), constrained_layout=True)

ax1.hexbin(latent_pca[:, 0], latent_pca[:, 1], bins='log')
for scatter_ax, coords in ((ax2, latent_tsne), (ax3, embedding)):
    scatter_ax.scatter(coords[:, 0], coords[:, 1], alpha=0.5)

for label_ax, panel_title, x_label, y_label in (
    (ax1, "PCA of Latent Codes", "PC1", "PC2"),
    (ax2, "t-SNE of Latent Codes", "t-SNE dim 1", "t-SNE dim 2"),
    (ax3, "UMAP of Latent Codes", "UMAP dim 1", "UMAP dim 2"),
):
    label_ax.set_title(panel_title)
    label_ax.set_xlabel(x_label)
    label_ax.set_ylabel(y_label)

plt.show()

In [None]:
# 2-D embedding used for the box query below (alternatives: latent_pca, latent_tsne).
dim_model = embedding
# Axis-aligned query box: center +/- (x_radius, y_radius).
center = (0, 7)
x_radius, y_radius = 1, 1

In [None]:
# Indices of frames whose embedding lies inside the query box (edges inclusive).
x_lo, x_hi = center[0] - x_radius, center[0] + x_radius
y_lo, y_hi = center[1] - y_radius, center[1] + y_radius
valid_points = [
    idx
    for idx in range(dim_model.shape[0])
    if x_lo <= dim_model[idx, 0] <= x_hi and y_lo <= dim_model[idx, 1] <= y_hi
]

print(len(valid_points))

print(valid_points)

In [None]:
frame_num = 507  # reference frame: RMSD below is measured against this frame

# The reader is chosen from the test file's extension and used for both files.
read_traj = md.load_xtc if file[-4:] == '.xtc' else md.load_dcd
traj = read_traj(file, top=pdb)
traj0 = read_traj(file0, top=pdb)

traj.superpose(traj0, frame=0)
traj.center_coordinates()

rmsd = md.rmsd(traj, traj, frame=frame_num, precentered=True)

In [None]:
# Colour each projection by RMSD to `frame_num` and mark that frame with a star.
# The trajectory frames and the latent rows must align one-to-one.
assert len(rmsd) == len(latent_pca), (
    f"RMSD has {len(rmsd)} frames but there are {len(latent_pca)} latent codes"
)

f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 6), constrained_layout=True)

rmsd_scatters = []
for panel_ax, coords, panel_title, x_label, y_label in (
    (ax1, latent_pca, "PCA of Latent Codes", "PC1", "PC2"),
    (ax2, latent_tsne, "t-SNE of Latent Codes", "t-SNE dim 1", "t-SNE dim 2"),
    (ax3, embedding, "UMAP of Latent Codes", "UMAP dim 1", "UMAP dim 2"),
):
    rmsd_scatters.append(panel_ax.scatter(coords[:, 0], coords[:, 1], c=rmsd, cmap='viridis', s=10))
    panel_ax.scatter(coords[frame_num, 0], coords[frame_num, 1], c='r', s=120, marker='*')
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel(x_label)
    panel_ax.set_ylabel(y_label)

# All panels share the same `rmsd` colour values, so one colourbar serves them all.
sc1 = rmsd_scatters[0]
cbar = f.colorbar(sc1, ax=[ax1, ax2, ax3], orientation="vertical")
cbar.set_label("RMSD (nm)")  # md.rmsd returns nanometres

# Output root is overridable via THESIS_FIG_DIR; the directory is created if missing.
fig_dir = os.path.join(
    os.environ.get('THESIS_FIG_DIR', '/Users/lreeder/Documents/Stanford/5fifth_year/thesis/figs'),
    protein,
)
os.makedirs(fig_dir, exist_ok=True)
f.savefig(os.path.join(fig_dir, f'{temporal}_rmsd{frame_num}.pdf'), bbox_inches='tight')


In [None]:
# Animate the first `animation_len` steps of the test trajectory in PCA space,
# drawn over a log-density hexbin of all test frames.
x = latent_pca[:, 0]
y = latent_pca[:, 1]
starting_frame = 0
animation_len = 100
# Clamp so update() never indexes past the end of a short trajectory.
n_anim_frames = max(0, min(animation_len, len(x) - starting_frame))


fig, ax = plt.subplots(figsize=(6, 6))
line, = ax.plot([], [], lw=2, color='red')
point, = ax.plot([], [], 'ro')
ax.hexbin(x, y, bins='log')
ax.set_xlim(x.min() - 0.01, x.max() + 0.01)
ax.set_ylim(y.min() - 0.01, y.max() + 0.01)
# x and y are the first two principal components of the latents, not raw latent dims.
ax.set_title("2D trajectory over time - PC 1 and PC 2 of latent codes")
ax.set_xlabel("PC 1")
ax.set_ylabel("PC 2")
ax.grid(True)
ax.set_aspect('equal')
timestamp = ax.text(0.02, 0.95, '', transform=ax.transAxes, fontsize=12, ha='left', va='top')


def init():
    """Reset the path, marker and timestamp before the first frame is drawn."""
    line.set_data([], [])
    point.set_data([], [])
    timestamp.set_text('')
    return line, point, timestamp


def update(frame):
    """Draw the path up to step ``starting_frame + frame`` and mark the current point."""
    step = starting_frame + frame
    line.set_data(x[starting_frame:step + 1], y[starting_frame:step + 1])
    point.set_data([x[step]], [y[step]])
    # Keep the on-figure timestamp in sync; it is returned so blitting redraws it.
    timestamp.set_text(f'Time step: {step}')
    return line, point, timestamp


ani = animation.FuncAnimation(fig, update, frames=n_anim_frames, init_func=init, interval=200, blit=True)
plt.close(fig)  # otherwise a static copy of the figure is also rendered below the animation
HTML(ani.to_jshtml())
# or for a GIF:
#ani.save('trajectory.gif', writer='pillow', fps=20)

In [None]:
# Time series of the first two principal components over a longer window.
animation_len = 200
window = slice(starting_frame, starting_frame + animation_len)

fig, axs = plt.subplots(2, 1, figsize=(10, 5))
for panel_ax, series, panel_title in zip(axs, (x[window], y[window]), ("PC 1", "PC 2")):
    # Use the actual segment length: near the end of the trajectory the slice can be
    # shorter than animation_len, which previously raised a length mismatch.
    panel_ax.plot(range(len(series)), series, color="darkred", linewidth=2)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("t")
    panel_ax.spines['top'].set_visible(False)
    panel_ax.spines['right'].set_visible(False)
    panel_ax.grid(alpha=0.3)


plt.tight_layout()
plt.show()

# Autocorrelation Plots

In [None]:
# Baseline latents for the ACF comparison. map_location='cpu' lets tensors saved on
# a GPU/MPS machine load anywhere (the ACF code below works on CPU data).
no_time_path = '/Users/lreeder/Documents/Stanford/research/gnn_md/redo_TICA/new_gvp_training/'

no_time_latents = torch.load(os.path.join(no_time_path, 'inference_latents.pt'), map_location='cpu')


# NOTE(review): this path is a *pentapeptide* run with 12 latent dims, while
# `protein` above may differ — confirm the cross-model comparison is intended.
some_time_path = 'more_recent_models/no_time/pentapeptide/nonormal_10smoothing_1000ep_0.0005lr_0.1dr_3layers_10edge_12latentdim_1e-06reg_0.0temp/'
some_time_latents = torch.load(os.path.join(some_time_path, 'test_latents.pt'), map_location='cpu')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.tsa.stattools import acf

def latent_acf(latents, max_lag=200, demean=True):
    """Autocorrelation of each latent dimension and their mean across dimensions.

    Parameters
    ----------
    latents : array (T, d) or list of arrays [(T1, d), (T2, d), ...]
        A single latent trajectory, or several independent trajectories whose
        per-dimension ACFs are averaged.
    max_lag : int
        Largest lag (in frames) to evaluate.
    demean : bool
        Subtract each trajectory's per-dimension mean before computing the ACF.

    Returns
    -------
    mean_acf : ndarray, shape (max_lag + 1,)
        ACF averaged over dimensions.
    acfs : ndarray, shape (d, max_lag + 1)
        Per-dimension ACF (averaged over trajectories for list input).
    """
    def _acf_rows(traj):
        # One FFT-based ACF row per latent dimension -> (d, max_lag + 1)
        centred = traj - traj.mean(axis=0, keepdims=True) if demean else traj
        return np.vstack([acf(centred[:, dim], nlags=max_lag, fft=True) for dim in range(centred.shape[1])])

    if isinstance(latents, list):
        per_traj = np.stack([_acf_rows(traj) for traj in latents], axis=0)  # (n_traj, d, L)
        acfs = per_traj.mean(axis=0)
    else:
        acfs = _acf_rows(latents)

    return acfs.mean(axis=0), acfs

def plot_acf(mean_acf, acfs=None, label=None, show_dims=False):
    """Plot a mean ACF curve, optionally over faint per-dimension curves.

    Parameters
    ----------
    mean_acf : array-like, shape (L,)
        ACF values starting at lag 0.
    acfs : array-like, shape (d, L), optional
        Per-dimension ACFs, drawn underneath when ``show_dims`` is True.
    label : str, optional
        Legend label for the mean curve; the legend is only drawn when given.
    show_dims : bool
        Whether to draw the per-dimension curves.
    """
    lag_axis = np.arange(len(mean_acf))
    fig, ax = plt.subplots(figsize=(6, 4))
    if show_dims and acfs is not None:
        for dim_curve in acfs:
            ax.plot(lag_axis, dim_curve, alpha=0.15, linewidth=1)
    ax.plot(lag_axis, mean_acf, linewidth=2, label=label or "mean ACF", color='darkred')
    ax.axhline(0, linestyle="--", linewidth=1)
    ax.set_xlabel("Lag (frames)")
    ax.set_ylabel("Autocorrelation")
    if label:
        ax.legend()
    fig.tight_layout()
    plt.show()


In [None]:
# Mean latent ACF for the three model families, overlaid on one figure.
mean_acf_base, acfs_base = latent_acf(no_time_latents, max_lag=300)
mean_acf_seq, acfs_seq = latent_acf(latents, max_lag=300)
mean_acf_some_time, acfs_some = latent_acf(some_time_latents, max_lag=300)

colors = ["#0072B2", "#E69F00", 'darkred']
lags = np.arange(len(mean_acf_base))
acf_fig, acf_ax = plt.subplots(figsize=(6, 4))
acf_ax.plot(lags, mean_acf_base, label="Individual frames", linewidth=2, c=colors[0])
acf_ax.plot(lags, mean_acf_some_time, label="Implicit time", linewidth=2, c=colors[1])
acf_ax.plot(lags, mean_acf_seq, label="Explicit sequence model", linewidth=2, c=colors[2])
acf_ax.axhline(0, linestyle="--", linewidth=1)
acf_ax.set_xlabel("Lag (frames)")
acf_ax.set_ylabel("Autocorrelation")
acf_ax.legend()

# Output root is overridable via THESIS_FIG_DIR; the directory is created if missing.
fig_dir = os.path.join(
    os.environ.get('THESIS_FIG_DIR', '/Users/lreeder/Documents/Stanford/5fifth_year/thesis/figs'),
    protein,
)
os.makedirs(fig_dir, exist_ok=True)
acf_fig.savefig(os.path.join(fig_dir, f'{temporal}_ACFvsLag.pdf'), bbox_inches='tight')

In [None]:
def correlation_time(mean_acf, max_fit_lag=100, min_points=5):
    """Estimate an exponential correlation time tau from an autocorrelation curve.

    Fits ``log ACF(lag) ~= -lag / tau + c`` by least squares over lags
    ``1..max_fit_lag``, using only lags whose ACF value is positive (log of a
    non-positive value is undefined).

    Parameters
    ----------
    mean_acf : array-like, shape (L,)
        ACF values starting at lag 0.
    max_fit_lag : int
        Largest lag included in the fit. Clipped to ``L - 1``; previously a value
        larger than the number of computed lags raised an IndexError.
    min_points : int
        Minimum number of positive ACF values required to attempt the fit.

    Returns
    -------
    float
        tau in frames, or NaN when fewer than ``min_points`` positive values remain
        or the fitted slope is not negative (no decay).
    """
    mean_acf = np.asarray(mean_acf, dtype=float)
    n_lags = min(max_fit_lag, len(mean_acf) - 1)
    y = mean_acf[1:n_lags + 1]
    lags = np.arange(1, n_lags + 1)
    # Use only positive ACF to avoid log of nonpositive
    mask = y > 0
    lags, y = lags[mask], y[mask]
    if len(y) < min_points:
        return np.nan
    # Fit log y ~= -lags/tau + c  ->  tau = -1 / slope
    slope, _intercept = np.polyfit(lags, np.log(y), 1)
    return -1.0 / slope if slope < 0 else np.nan

# Correlation times (frames) for the three model families, fitted over lags 1..300.
tau_base, tau_some, tau_seq = (
    correlation_time(acf_curve, max_fit_lag=300)
    for acf_curve in (mean_acf_base, mean_acf_some_time, mean_acf_seq)
)
for tau_prefix, tau_value in (
    ("τ_c (no sequence): ", tau_base),
    ("τ_c (some_time):    ", tau_some),
    ("τ_c (sequence):    ", tau_seq),
):
    print(f"{tau_prefix}{tau_value:.1f} frames")

In [None]:
import numpy as np
from statsmodels.tsa.stattools import acf

def acf_per_dim(latents, max_lag=300, demean=True):
    """Per-dimension autocorrelation functions, averaged over trajectories.

    Parameters
    ----------
    latents : array (T, d) or list of arrays [(T1, d), (T2, d), ...]
        One latent trajectory, or several that share the same width d.
    max_lag : int
        Largest lag (in frames) to evaluate.
    demean : bool
        Subtract each trajectory's per-dimension mean before computing the ACF.

    Returns
    -------
    acf_dims : ndarray, shape (d, max_lag + 1)
        Mean ACF per dimension, averaged over trajectories.
    acf_dims_all : list of ndarray
        One (d, max_lag + 1) array per trajectory.
    """
    trajectories = latents if isinstance(latents, list) else [latents]
    n_dims = trajectories[0].shape[1]  # width taken from the first trajectory

    acf_dims_all = []
    for traj in trajectories:
        series = traj - traj.mean(axis=0, keepdims=True) if demean else traj
        acf_dims_all.append(
            np.vstack([acf(series[:, dim], nlags=max_lag, fft=True) for dim in range(n_dims)])
        )
    return np.stack(acf_dims_all, axis=0).mean(axis=0), acf_dims_all


In [None]:
# Per-dimension ACFs for each model family (no sequence, explicit sequence, implicit time).
# NOTE(review): the same per-dim ACFs are recomputed in a later cell under other names.
(dim_acf_base, acfs_base), (dim_acf_seq, acfs_seq), (dim_acf_some_time, acfs_some) = (
    acf_per_dim(source, max_lag=300) for source in (no_time_latents, latents, some_time_latents)
)

In [None]:
import matplotlib.pyplot as plt
import math

def plot_acf_per_dim(acf_dims, max_cols=4, cutoff=None, title=None):
    """Grid of per-dimension ACF curves on a logarithmic y-axis.

    Parameters
    ----------
    acf_dims : ndarray, shape (d, L)
        Output of ``acf_per_dim``.
    max_cols : int
        Maximum number of panels per row.
    cutoff : int or None
        Number of leading lags to show (e.g. 100); None, or a value larger than
        L, shows all L lags.
    title : str or None
        Optional figure suptitle.
    """
    d, L = acf_dims.shape
    cutoff = L if cutoff is None else min(cutoff, L)
    lags = np.arange(cutoff)
    rows = math.ceil(d / max_cols)
    cols = min(d, max_cols)
    # Shared lower y-limit; the 1e-5 floor keeps the log axis finite when an ACF
    # reaches zero or goes negative (those points are not drawn on a log axis).
    y_floor = max(1e-5, acf_dims[:, 1:cutoff].min())
    fig, axes = plt.subplots(rows, cols, figsize=(3.3 * cols, 2.8 * rows), squeeze=False)
    for idx, ax in enumerate(axes.flat):  # row-major, matching the grid layout
        if idx >= d:
            ax.axis("off")
            continue
        r, c = divmod(idx, cols)
        ax.plot(lags, acf_dims[idx, :cutoff], linewidth=1.8)
        # The y-label and the positive floor both assume a log axis, which was never set.
        ax.set_yscale("log")
        ax.set_title(f"dim {idx}")
        ax.set_xlim(0, cutoff - 1)
        ax.set_ylim(y_floor, 1.0)
        if r == rows - 1:
            ax.set_xlabel("Lag (frames)")
        if c == 0:
            ax.set_ylabel("Autocorr (log)")
    if title:
        fig.suptitle(title, y=1.02)
    fig.tight_layout()
    plt.show()


In [None]:
# Per-dimension ACFs for each model family, then one grid figure per family.
(acf_dims_no, _), (acf_dims_imp, _), (acf_dims_seq, _) = (
    acf_per_dim(traj_latents, max_lag=300)
    for traj_latents in (no_time_latents, some_time_latents, latents)
)

for dims_matrix, grid_title in (
    (acf_dims_no, "No sequence: per-dim ACF"),
    (acf_dims_imp, "Implicit time: per-dim ACF"),
    (acf_dims_seq, "Sequence: per-dim ACF"),
):
    plot_acf_per_dim(dims_matrix, cutoff=100, title=grid_title)


In [None]:
def heatmap_acf(acf_dims, cutoff=None, title=None, tag=None):
    """Render per-dimension ACFs as a heatmap (rows = latent dims, columns = lags) and save it as PDF.

    Parameters
    ----------
    acf_dims : ndarray, shape (d, L)
        Output of ``acf_per_dim``.
    cutoff : int or None
        Number of leading lags to show; None shows all L.
    title : str or None
        Axes title; a default describing the clipping is used when None.
    tag : str or None
        Filename prefix; defaults to the global ``temporal``. Pass a distinct tag
        per model so that figures with the same number of dimensions do not
        overwrite each other and are not labelled with the wrong model.
    """
    d, L = acf_dims.shape
    if cutoff is None:
        cutoff = L
    # Clip to [1e-6, 1] so negative or near-zero correlations all map to the bottom
    # of the colour range. NOTE(review): the colour scale is linear; if a log colour
    # scale was intended, pass norm=matplotlib.colors.LogNorm(1e-6, 1) to imshow.
    A = np.clip(acf_dims[:, :cutoff], 1e-6, 1.0)
    fig, ax = plt.subplots(figsize=(8, max(2.5, d * 0.25)))
    im = ax.imshow(A, aspect='auto', origin='lower')
    ax.set_xlabel("Lag (frames)")
    ax.set_ylabel("Latent dimension")
    ax.set_title(title or "ACF heatmap (clip to [1e-6, 1])")
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label("Autocorrelation")
    fig.tight_layout()

    # Output root is overridable via THESIS_FIG_DIR; the directory is created if missing.
    fig_dir = os.path.join(
        os.environ.get('THESIS_FIG_DIR', '/Users/lreeder/Documents/Stanford/5fifth_year/thesis/figs'),
        protein,
    )
    os.makedirs(fig_dir, exist_ok=True)
    prefix = temporal if tag is None else tag
    fig.savefig(os.path.join(fig_dir, f'{prefix}_ACFdim{d}.pdf'), bbox_inches='tight')


heatmap_acf(acf_dims_imp, cutoff=150, title="Implicit time: ACF heatmap", tag="implicit")
heatmap_acf(acf_dims_seq, cutoff=150, title="Sequence: ACF heatmap")

