In [None]:
import sys
# Make the project's model/dataset modules importable; guarded so re-running
# this cell does not append the same entry to sys.path again.
if "../gcae_models/" not in sys.path:
    sys.path.append("../gcae_models/")

In [None]:
import dataset
from gcae_tcn import GVPEncoder, GVPDecoder, TCNModel
from gcae_transformer import TransformerEncoder
import matplotlib.pyplot as plt 
import mdtraj as md 
import os
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.data import Batch
from tqdm import tqdm 
from torch.utils.data import DataLoader
import torch_cluster

import matplotlib.animation as animation
from IPython.display import HTML

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE 

import umap
import math

from functools import partial 

In [None]:
import matplotlib.gridspec as gridspec

# Global figure styling shared by every plot in this notebook.
plt.rcParams.update({
    "font.family": "serif",  # switch to "sans-serif" if preferred
    "font.size": 14,
    "axes.labelsize": 16,
    "axes.titlesize": 16,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
})

# Series colours; plot_3d uses 0 for z0, 1 for the forecast and 2 for z1.
colors = [
    "#0072B2",
    "#E69F00",
    "darkred",
]

In [None]:
from forecasting_main import Config, Trainer
from forecasting_utils import get_forecasting_dataloader


In [None]:
def determine_device():
    """Return the preferred torch device: Apple MPS first, then CUDA, else CPU."""
    candidates = (
        ('mps', torch.backends.mps.is_available),
        ('cuda', torch.cuda.is_available),
    )
    for name, is_available in candidates:
        if is_available():
            return torch.device(name)
    return torch.device('cpu')

Load the trained latent forecaster (sampling-only `Trainer` restored from `load_path`)

In [None]:
# --- Forecaster configuration ---
# NOTE(review): `dataset` rebinds the name of the `dataset` module imported at the
# top of the notebook; later cells must not use `dataset.<attr>` after this runs.
dataset = 'latent_traj'
data_path = 'more_recent_models/tcn/pentapeptide/Falsenorm_10smoothing_2000ep_0.0005lr_0.1dr_2batch_3layers_10edge_16latentdim_0.0reg_0.0temp_TCN3layers_TCN5kernel/'
latent_filename = 'MORErecomputed_test_latents.pt'

time_lag, stride = 1, 1
latent_dim, hidden_dim, num_layers = 16, 512, 5
random_time_sampling = shuffle_batch = False
load_path = os.path.join(data_path, f'forecast_ckpts/latest_model_{time_lag}lag_{stride}stride.pt')

sigma_coef = 1.0
beta_fn = 't^2'

debug = overfit = False
sample_only = True

In [None]:
#if dataset == 'latent_traj':
conf = Config(
    dataset=dataset,
    debug=debug,
    overfit=overfit,
    sigma_coef=sigma_coef,
    beta_fn=beta_fn,
    latent_dim=latent_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    data_path=data_path,
    latent_filename=latent_filename,
    time_lag=time_lag,
    stride=stride,
    random_time_sampling=random_time_sampling,
    shuffle_batch=shuffle_batch
)

trainer = Trainer(
    conf, 
    load_path = load_path,
    sample_only = sample_only
)

Load the autoencoder: hyper-parameters, test data, and the encoder/decoder/temporal-model checkpoints

In [None]:
# Autoencoder hyper-parameters; these must match the checkpoint in `model_path`.
model_path = data_path
temporal_model = 'tcn'
model_num_load = 1999
latent_dim = 16

top_k = 10
dr = 0.1
n_layers = 3
lr = 0.0005

T = 512
n_heads = 4
n_encoder_layers = 2
channel_size = [64, 64, 64]
kernel_size = 5
lambda_reg = 0
lambda_temp = 0

device = determine_device()

# The forecaster config cell rebinds the name `dataset` to the string
# 'latent_traj', shadowing the project module imported at the top of the
# notebook. Use an explicit alias so this cell works on a fresh kernel.
import dataset as gcae_dataset


def _load_trajectory(path, top):
    """Load `path` with the XTC reader if it ends in '.xtc', otherwise with the DCD reader."""
    if path.endswith('.xtc'):
        return md.load_xtc(path, top=top)
    return md.load_dcd(path, top=top)


test_file = 'pentapeptide/split_25files_5001len_1chunks/file20_part0.xtc'
traj0_file = 'pentapeptide/split_25files_5001len_1chunks/file0_part0.xtc'
pdb = 'pentapeptide/pentapeptide-impl-solv.pdb'
# Reader is chosen from traj0_file's own extension (not test_file's).
traj0 = _load_trajectory(traj0_file, pdb)

test_structures = gcae_dataset.generate_structures([test_file], pdb, traj0, split=True, smooth=10)
test_frame_dataset = gcae_dataset.LigandDataset(test_structures, top_k=top_k)
# The sequence windows index frames of `test_frame_dataset`, i.e. of the *test*
# trajectory, so size them from that trajectory rather than from traj0.
# NOTE(review): assumes generate_structures yields one structure per input
# frame (smooth=10 keeps the length) — confirm.
n_test_frames = _load_trajectory(test_file, pdb).n_frames
test_seq_dataset = gcae_dataset.SequenceDataset([n_test_frames], sequence_length=T, stride=T, include_partial=True)


node_h_dim = (100, 16)
edge_h_dim = (32, 1)
node_num = md.load(pdb).topology.n_residues

In [None]:
# Load the held-out trajectory and compute every frame's RMSD to frame `frame_num`.
frame_num = 4340
traj = (md.load_xtc if test_file[-4:] == '.xtc' else md.load_dcd)(test_file, top=pdb)

# Align all frames onto frame 0 of traj0, then centre each frame so the RMSD
# call below can skip its own centring (precentered=True).
traj.superpose(traj0, frame=0)
traj.center_coordinates()

rmsd = md.rmsd(traj, traj, frame=frame_num, precentered=True)

In [None]:
def collate_sparse(batch, frame_dataset):
    """Collate a list of (start, end) frame windows into one PyG ``Batch``.

    Every ``frame_dataset[i]`` with ``start <= i < end`` becomes one element of
    the returned batch, in window order. Two attributes are attached:

    * ``seq_ptr`` -- ``torch.long`` offsets; window ``j`` owns elements
      ``seq_ptr[j]:seq_ptr[j + 1]``.
    * ``seq_len`` -- ``end - start`` of the *last* window (with ``batch_size=1``
      this is simply the window length).

    Raises:
        ValueError: if ``batch`` is empty, instead of failing with an opaque
            UnboundLocalError when computing ``seq_len``.
    """
    if not batch:
        raise ValueError("collate_sparse received an empty batch")

    frames, seq_ptr = [], [0]
    for start, end in batch:
        frames.extend(frame_dataset[i] for i in range(start, end))
        seq_ptr.append(len(frames))

    out = Batch.from_data_list(frames)
    out.seq_ptr = torch.tensor(seq_ptr, dtype=torch.long)
    out.seq_len = end - start
    return out

In [None]:
# One (start, end) window per batch; collate_sparse materialises its frames.
test_loader = DataLoader(
    test_seq_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    drop_last=False,
    collate_fn=partial(collate_sparse, frame_dataset=test_frame_dataset),
)

In [None]:
# Build the encoder/decoder and both temporal models, then restore the weights
# saved at epoch `model_num_load`.
# Edge input dims (32, 1) match make_edge_features in this notebook (16 RBF +
# 16 positional scalar channels, one unit-vector channel). Node input dims
# (6, 3) are assumed to match dataset.LigandDataset's node features — TODO confirm.
encoder = GVPEncoder((6, 3), node_h_dim, (32, 1), edge_h_dim,
                     latent_dim=latent_dim,
                     n_layers=n_layers,
                     drop_rate=dr,
                     node_num=node_num).to(device)

decoder = GVPDecoder((6, 3), node_h_dim, (32, 1), edge_h_dim,
                     latent_dim=latent_dim,
                     n_layers=n_layers,
                     drop_rate=dr, dense_mode=True, node_num=node_num).to(device)

# The TCN consumes windows of T frames (SequenceDataset uses sequence_length=T),
# so tie its input length to T instead of repeating the literal 512.
tcn = TCNModel(input_size=latent_dim, channel_size=channel_size, input_length=T, kernel_size=kernel_size).to(device)
transformer = TransformerEncoder(latent_dim, max_seq_len=T, nhead=n_heads, num_layers=n_encoder_layers, dropout=dr).to(device)

# Optimizer / loss are not used by the inference cells of this notebook; kept so
# the names exist for anyone resuming training from here.
if temporal_model == 'tcn':
    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()) + list(tcn.parameters()), lr=lr)
else:
    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()) + list(transformer.parameters()), lr=lr)
loss_function = nn.MSELoss()


def _load_checkpoint(prefix):
    """Load '{prefix}_epoch-{model_num_load}.pt' from `model_path` onto `device` (tensors only)."""
    path = os.path.join(model_path, f'{prefix}_epoch-{model_num_load}.pt')
    return torch.load(path, map_location=device, weights_only=True)


# Modules are already on `device`; load_state_dict copies into them in place.
encoder_state_dict = _load_checkpoint('encoder')
encoder.load_state_dict(encoder_state_dict)

decoder_state_dict = _load_checkpoint('decoder')
decoder.load_state_dict(decoder_state_dict)

if temporal_model == 'tcn':
    tcn_state_dict = _load_checkpoint('tcn')
    tcn.load_state_dict(tcn_state_dict)
else:
    transformer_state_dict = _load_checkpoint('transformer')
    transformer.load_state_dict(transformer_state_dict)


In [None]:
# Diffusion-coefficient options passed to trainer.test / trainer.EM.
#   'g_sigma' -> None (presumably the trainer's built-in schedule — confirm in forecasting_main)
#   'g_other' -> sigma_coef * trainer.wide(1 - t) ** 4, looked up lazily at call time
diffusion_fns = dict(
    g_sigma=None,
    g_other=lambda t: sigma_coef * trainer.wide(1 - t).pow(4),
)

k = 'g_sigma'  # key of the diffusion function used by the cells below

In [None]:
def plot_3d(base_coords, forecasted_coords, true_coords, title=None, ax=None):
    """Overlay three 3-D point chains (start z0, forecast, target z1) on one axis.

    Each coordinate tensor is detached and copied to CPU; its columns 0/1/2 are
    used as x/y/z. Every chain is drawn as a marker line plus a scatter layer.
    If `ax` is None, a new figure with a single 3-D subplot is created.
    """
    series = []
    for coords, style, color, label in (
        (base_coords, ':', colors[0], 'z0'),
        (forecasted_coords, '-', colors[1], "Forecasted"),
        (true_coords, '--', colors[2], "z1"),
    ):
        arr = coords.detach().cpu().numpy()
        series.append(((arr[:, 0], arr[:, 1], arr[:, 2]), style, color, label))

    if ax is None:
        ax = plt.figure().add_subplot(111, projection='3d')

    for (xs, ys, zs), style, color, label in series:
        ax.plot(xs, ys, zs, marker='o', linestyle=style, color=color, label=label)
        ax.scatter(xs, ys, zs, c=color, s=30)

    if title:
        ax.set_title(title)
    for set_label, axis_name in ((ax.set_xlabel, 'x'), (ax.set_ylabel, 'y'), (ax.set_zlabel, 'z')):
        set_label(axis_name)
    ax.legend()


In [None]:
# One-step-ahead evaluation with the selected diffusion function. Returns
# (z0s, samples, z1s) — presumably conditioning latents, one-step samples and
# their targets; confirm in forecasting_main.
(z0s, samples, z1s) = trainer.test(diffusion_fns[k])

In [None]:
# Pre-computed latent trajectory (a 2-D tensor; the cells below use .shape[0]
# and [i, :]). Loaded onto the CPU so the file opens regardless of the device
# it was saved from; later cells move slices to `device` explicitly.
# weights_only=True: only tensors are expected, so refuse arbitrary pickles.
load_latents = torch.load(os.path.join(data_path, latent_filename), map_location='cpu', weights_only=True)


In [None]:
# Rows from the 80% mark onward (the held-out tail of the latent trajectory).
end_load_latents = load_latents[int(load_latents.shape[0] * 0.8):, :]

In [None]:
print(end_load_latents.size())  # torch.Size, same as .shape

In [None]:
# Put the trainer.test latents on the compute device.
z0s, z1s = z0s.to(device), z1s.to(device)

In [None]:
# Conditioning latent for the chained forecast: first held-out row as a (1, D) batch.
cond = end_load_latents[0, :].unsqueeze(dim=0)
print(cond.shape)
cond = cond.to(device)

In [None]:
sample_len = 1_000  # number of chained forecast steps

In [None]:
# Free-running ("chained") forecast: each trainer.EM sample becomes the
# conditioning input of the next step (EM is presumably an Euler-Maruyama
# sampler — see forecasting_main).
# The chain always restarts from the first held-out latent, so re-running this
# cell reproduces the same starting point instead of continuing from the last
# sample; `cond` is left holding the final sample as before.
chain_state = end_load_latents[0, :].unsqueeze(0).to(device)
chained_samples = []
for _ in tqdm(range(sample_len), desc="chained forecast"):
    chain_state = trainer.EM(chain_state, chain_state.to(device), diffusion_fn=diffusion_fns[k], return_avg=False)
    chained_samples.append(chain_state)
cond = chain_state


In [None]:
# Ground-truth latents aligned with the chain: entry j is row j of
# end_load_latents as a (1, D) tensor.
og_latents = [torch.unsqueeze(end_load_latents[row, :], dim=0) for row in range(sample_len)]
    



In [None]:
# Map global frame indices (row positions in `load_latents`) to latents:
#   forecasted_latents[g] -- chained forecast for frame g (g = start_index + 1 ...),
#   og_latents[g]         -- stored latent of frame g, the preceding step's input.
# og_latents is built straight from `end_load_latents` (rather than by rebinding
# a list of the same name) so this cell can be re-run safely.
start_index = int(0.8*load_latents.shape[0])
first_forecasted_index = start_index + 1
end_forecasted_index = first_forecasted_index + sample_len
forecasted_indices = list(range(first_forecasted_index, end_forecasted_index))
print(len(forecasted_indices))
assert len(forecasted_indices) == sample_len
# zip() below would otherwise silently truncate to the shorter input.
assert len(chained_samples) == sample_len, "chained_samples is stale; re-run the chained forecast cell"
assert end_load_latents.shape[0] >= sample_len, "not enough held-out latents for sample_len"
print("first index: ", forecasted_indices[0], "last index: ", forecasted_indices[-1])

forecasted_latents = {idx: z.to('cpu') for idx, z in zip(forecasted_indices, chained_samples)}
og_latents = {
    start_index + row: torch.unsqueeze(end_load_latents[row, :], dim=0).to('cpu')
    for row in range(sample_len)
}

In [None]:
# Make sure every network sits on the active device before inference.
encoder, decoder = encoder.to(device), decoder.to(device)
if temporal_model != 'transformer':
    tcn = tcn.to(device)
else:
    transformer = transformer.to(device)

In [None]:
import numpy as np 

def _normalize(tensor, dim=-1):
    '''
    Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
    '''
    return torch.nan_to_num(
        torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))


def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
    '''
    From https://github.com/jingraham/neurips19-graph-protein-design
    
    Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
    That is, if `D` has shape [...dims], then the returned tensor will have
    shape [...dims, D_count].
    '''
    D_mu = torch.linspace(D_min, D_max, D_count, device=device)
    D_mu = D_mu.view([1, -1])
    D_sigma = (D_max - D_min) / D_count
    D_expand = torch.unsqueeze(D, -1)

    RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
    return RBF


def _positional_embeddings(edge_index, 
                            num_embeddings=None,
                            period_range=[2, 1000], device='cpu'):
    # From https://github.com/jingraham/neurips19-graph-protein-design
    num_embeddings = 16#num_embeddings or self.num_positional_embeddings
    d = edge_index[0] - edge_index[1]
    
    frequency = torch.exp(
        torch.arange(0, num_embeddings, 2, dtype=torch.float32, device=device)
        * -(np.log(10000.0) / num_embeddings)
    )
    angles = d.unsqueeze(-1) * frequency
    E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
    return E



def make_edge_features(X_ca, top_k, num_rbf, device):
    '''
    Build a k-nearest-neighbour graph over the points in `X_ca` plus its edge features.

    Args:
        X_ca: coordinate tensor with one row per node (presumably (N, 3)).
        top_k: neighbours per node for torch_cluster.knn_graph.
        num_rbf: number of RBF distance channels.
        device: device the returned tensors are moved to.

    Returns:
        edge_index: neighbour index tensor from knn_graph,
        (edge_s, edge_v): scalar features [RBF(|e|) (num_rbf), positional (16)]
        and unit edge vectors with an inserted channel axis at -2.

    All intermediates are computed on X_ca's own device, so the knn indices,
    RBF centres and positional frequencies always agree; only the results are
    moved to `device`. (A CPU `X_ca` with a GPU/MPS `device` previously mixed
    devices inside _rbf/_positional_embeddings.)
    '''
    work_device = X_ca.device
    edge_index = torch_cluster.knn_graph(X_ca, k=top_k)
    E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]]
    pos_embeddings = _positional_embeddings(edge_index, device=work_device)

    rbf = _rbf(E_vectors.norm(dim=-1), D_count=num_rbf, device=work_device)

    edge_s = torch.cat([rbf, pos_embeddings], dim=-1)
    edge_v = _normalize(E_vectors).unsqueeze(-2)

    edge_s, edge_v = map(torch.nan_to_num, (edge_s, edge_v))

    return edge_index.to(device), (edge_s.to(device), edge_v.to(device))

In [None]:
# Re-resolve the compute device (MPS, then CUDA, then CPU).
device = determine_device()

In [None]:
# Decode three structures for every forecasted frame:
#   fore_preds -- decoder(chained-forecast latent) on edges built from the
#                 previous forecasted structure (the true frame for the very
#                 first step, the last forecast of the previous batch at the
#                 start of each later batch),
#   og_preds   -- decoder(temporal-model output for the true frame),
#   og_starts  -- decoder(stored latent of the preceding frame);
# the last two use edges built from the ground-truth frame.
encoder.eval()
decoder.eval()
if temporal_model == 'tcn':
    tcn.eval()
else:
    transformer.eval()

test_loss = 0.0
all_latents = []
all_og_pred = []
all_forecasted_pred = []
all_og_starts = []
first = True
previous_coord = None

with torch.no_grad():
    for batch_idx, batch in enumerate(test_loader):
        print("batch: ", batch_idx)
        B = batch.num_graphs            # frames in this window
        N = batch.x.size(0) // B        # nodes per frame
        print("B:", B)

        global_start = test_seq_dataset.get_global_start(batch_idx)
        global_indices = list(range(global_start, global_start + B))

        print("  Index start: ", global_indices[0], " index end: ", global_indices[-1])

        # Positions inside this window that have a chained forecast.
        matched_local_indices = [
            local_i for local_i, g_idx in enumerate(global_indices) if g_idx in forecasted_latents
        ]

        if not matched_local_indices:
            continue

        # Allocate on `device`: rows of fore_preds are fed back into
        # make_edge_features/decoder as the next step's coordinates, so they
        # must live on the same device as the models.
        n_matched = len(matched_local_indices)
        fore_preds = torch.empty(n_matched, N, 3, device=device)
        og_preds = torch.empty(n_matched, N, 3, device=device)
        og_starts = torch.empty(n_matched, N, 3, device=device)

        print(f"Processing batch {batch_idx}, global_start={global_start}")

        batch = batch.to(device)
        nodes = (batch.node_s, batch.node_v)
        edgesOG = (batch.edge_s, batch.edge_v)

        z = encoder(nodes, batch.edge_index, edgesOG)
        z_seq = torch.stack([z[start:end] for start, end in zip(batch.seq_ptr[:-1], batch.seq_ptr[1:])])
        z_seq_out = tcn(z_seq) if temporal_model == 'tcn' else transformer(z_seq)

        z_out = torch.cat([seq for seq in z_seq_out], dim=0)
        z_out_cpy = z_out.clone()
        GT = batch.x.view(B, N, 3)      # ground-truth coordinates per frame

        for ii, local_i in enumerate(matched_local_indices):
            global_i = global_indices[local_i]
            z0_global = global_i - 1

            forecast_replacement = forecasted_latents[global_i]
            og_latent = og_latents[z0_global]
            z_out[local_i] = forecast_replacement
            if local_i > 0:
                z_out_cpy[local_i - 1] = og_latent

            if first:
                print("FIRST INDEX: ", global_i)
                coords_curr = GT[local_i].clone()
                first = False
            elif ii == 0:
                print("start of new batch!")
                coords_curr = previous_coord
            else:
                coords_curr = fore_preds[ii - 1, :, :]

            edge_index, edges = make_edge_features(coords_curr, top_k=top_k, num_rbf=16, device=device)
            og_edge_index, og_edges = make_edge_features(GT[local_i].clone(), top_k=top_k, num_rbf=16, device=device)

            forecast_pred = decoder(torch.unsqueeze(z_out[local_i], dim=0), edge_index, edges)

            og_pred = decoder(torch.unsqueeze(z_out_cpy[local_i], dim=0), og_edge_index, og_edges)

            # Decode the preceding frame's stored latent directly. Indexing
            # z_out_cpy[local_i - 1] wrapped around to the window's last frame
            # when local_i == 0 (the first frame of every later batch).
            og_start = decoder(og_latent.to(device=device, dtype=z_out_cpy.dtype), og_edge_index, og_edges)

            if ii == n_matched - 1:
                print("END OF CURRENT BATCH. moving on")
                previous_coord = forecast_pred

            fore_preds[ii, :, :] = forecast_pred
            og_preds[ii, :, :] = og_pred
            og_starts[ii, :, :] = og_start
        # Downstream cells expect CPU tensors.
        all_forecasted_pred.append(fore_preds.cpu())
        all_og_pred.append(og_preds.cpu())
        all_og_starts.append(og_starts.cpu())







In [None]:
def _concat_batches(per_batch):
    """Concatenate a list of per-batch (n_i, N, 3) tensors along dim 0.

    An already-concatenated tensor is returned unchanged: a later cell rebinds
    the `all_*` names to these tensors, and torch.cat over a tensor would
    otherwise re-split it into (N, 3) rows and collapse the sample axis.
    """
    if torch.is_tensor(per_batch):
        return per_batch
    return torch.cat(per_batch, dim=0)


forecasted_pred_coords = _concat_batches(all_forecasted_pred)

og_pred_coords = _concat_batches(all_og_pred)

og_start_coords = _concat_batches(all_og_starts)

print(forecasted_pred_coords.shape)
print(og_pred_coords.shape)
print(og_start_coords.shape)

In [None]:
# From here on the `all_*` names refer to the concatenated tensors rather than
# the per-batch lists filled by the decoding loop.
all_og_starts, all_forecasted_pred, all_og_pred = (
    og_start_coords,
    forecasted_pred_coords,
    og_pred_coords,
)

In [None]:
# Locate non-finite (nan/inf) entries in the reconstructed ground-truth
# coordinates. all_og_pred is 3-D (sample, node, xyz), so torch.where returns
# one index tensor per dimension.
bad = torch.isfinite(all_og_pred).logical_not()
rows, cols, z = coords = torch.where(bad)
print("ROWS", rows, "COLS", cols, "z", z)

In [None]:
# Sample indices (rows of the concatenated coordinate tensors) to plot.
testing_frames = [1, 100, 330, 462]


In [None]:
# Grid of 3-D panels (two per row): start structure (z0), chained forecast and
# reconstructed target (z1) for each index in `testing_frames`.
n_cols = 2
n_rows = math.ceil(len(testing_frames) / n_cols)
fig = plt.figure(figsize=(12, 4 * n_rows))
for i, frame in enumerate(testing_frames):
    ax = fig.add_subplot(n_rows, n_cols, i + 1, projection='3d')
    plot_3d(
        base_coords=og_start_coords[frame],
        forecasted_coords=forecasted_pred_coords[frame],
        true_coords=og_pred_coords[frame],
        title=f"Sample {frame}",
        ax=ax
    )

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("figs", exist_ok=True)
plt.savefig("figs/better_chain_forecast_a.pdf")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, FFMpegWriter, PillowWriter

# Animate start structure (z0), chained forecast and reconstructed target (z1)
# across forecast steps and write the result to an MP4 via ffmpeg.

# Frames to animate: the first 1000 samples, clamped to the number actually
# decoded so a shorter run does not raise IndexError.
frames = range(0, min(1000, len(all_forecasted_pred)))

def to_np(a):
    """Return `a` as a NumPy array; torch tensors are detached and copied to CPU."""
    try:
        return a.detach().cpu().numpy()
    except AttributeError:
        return np.asarray(a)

def stack_all(coords_src, idx_list):
    """Vertically stack `coords_src[i]` (as NumPy) for every i in `idx_list`."""
    return np.vstack([to_np(coords_src[i]) for i in idx_list])

# Global axis limits (5% padding) so the view doesn't jump between frames.
all_xyz = np.vstack([
    stack_all(all_og_starts, frames),
    stack_all(all_forecasted_pred, frames),
    stack_all(all_og_pred, frames),
])
rng = all_xyz.max(axis=0) - all_xyz.min(axis=0)
pad = 0.05 * (rng + 1e-12)
xyz_min = all_xyz.min(axis=0) - pad
xyz_max = all_xyz.max(axis=0) + pad

# Figure/axes
fig = plt.figure(figsize=(12, 10), dpi=300)
ax = fig.add_subplot(111, projection='3d')
ax.set_xlim(xyz_min[0], xyz_max[0])
ax.set_ylim(xyz_min[1], xyz_max[1])
ax.set_zlim(xyz_min[2], xyz_max[2])
ax.set_box_aspect((xyz_max - xyz_min))
ax.set_xlabel('x'); ax.set_ylabel('y'); ax.set_zlabel('z')

# Create the artists once (3 lines + 3 scatters, same styles as plot_3d);
# update() only swaps their data.
line_b,   = ax.plot([], [], [], marker='o', linestyle=':',  color=colors[0], label='z0')
scat_b     = ax.scatter([], [], [], c=colors[0], s=30)

line_f,   = ax.plot([], [], [], marker='o', linestyle='-',  color=colors[1], label='Forecasted')
scat_f     = ax.scatter([], [], [], c=colors[1], s=30)

line_t,   = ax.plot([], [], [], marker='o', linestyle='--', color=colors[2], label='z1')
scat_t     = ax.scatter([], [], [], c=colors[2], s=30)

ax.legend()
artists = [line_b, line_f, line_t, scat_b, scat_f, scat_t]

def get_coords(frame):
    """Return x/y/z columns of the start, forecast and target structures for `frame`."""
    b = to_np(all_og_starts[frame])
    f = to_np(all_forecasted_pred[frame])
    t = to_np(all_og_pred[frame])
    return b[:,0], b[:,1], b[:,2], f[:,0], f[:,1], f[:,2], t[:,0], t[:,1], t[:,2]

def init():
    """Initial animation state; the artists already exist (empty)."""
    ax.set_title("Initializing…")
    return artists

def update(k):
    """Load animation step `k` (i.e. sample frames[k]) into the existing artists."""
    frame = frames[k]
    bx,by,bz, fx,fy,fz, tx,ty,tz = get_coords(frame)

    line_b.set_data_3d(bx, by, bz)
    line_f.set_data_3d(fx, fy, fz)
    line_t.set_data_3d(tx, ty, tz)

    # 3D scatter collections keep their data in _offsets3d.
    scat_b._offsets3d = (bx, by, bz)
    scat_f._offsets3d = (fx, fy, fz)
    scat_t._offsets3d = (tx, ty, tz)

    ax.set_title(f"Sample {frame}")
    return artists

# blit=False: update() changes the axes title, which is not among the returned
# artists, so a blitted redraw would never refresh it.
anim = FuncAnimation(fig, update, init_func=init, frames=len(frames),
                     interval=500, blit=False)

# Save before plt.show(), which may close the figure under inline backends;
# savers do not create missing directories.
os.makedirs("figs", exist_ok=True)
anim.save("figs/ACTUAL_CHAINEDforecast_vs_truth.mp4", writer=FFMpegWriter(fps=8, bitrate=5000))

plt.show()


In [None]:
import torch
import numpy as np

@torch.no_grad()
def tf_vs_free_run_gaps_EM(trainer, samples, chained_samples, z0, t0=0, H=100, diffusion_fn=None):
    """
    Compare one-step-ahead forecasts with chained (free-running) forecasts.

    Args:
        trainer: unused; kept for signature compatibility.
        samples: one-step-ahead predictions, assumed 2-D (rows x latent dim).
        chained_samples: list of chained-forecast tensors (presumably (1, D)
            each), concatenated along dim 0.
        z0: targets compared against both prediction sets (this notebook passes
            the true next-step latents z1s); must broadcast against both.
        t0, H, diffusion_fn: unused; kept for signature compatibility.

    Returns:
        (mse_tf, mse_free, gap) as NumPy arrays: per-row MSE over dim 1 for the
        one-step and chained predictions, and gap = mse_free - mse_tf.

    Every input is detached and copied to CPU first, so predictions and targets
    living on different devices (e.g. samples on CPU, z1s on GPU/MPS) compare
    without a device-mismatch error.
    """
    tf_preds = samples.detach().cpu()
    free_preds = torch.cat([s.detach().cpu() for s in chained_samples], dim=0)
    tgt = z0.detach().cpu()

    print("tgt:", tgt.shape)
    print("free preds: ", free_preds.shape)
    print("tf_preds: ", tf_preds.shape)
    mse_tf = ((tf_preds - tgt) ** 2).mean(dim=1).numpy()
    mse_free = ((free_preds - tgt) ** 2).mean(dim=1).numpy()
    gap = mse_free - mse_tf
    return mse_tf, mse_free, gap



In [None]:
# Targets are the true next-step latents z1s returned by trainer.test.
mse_tf, mse_free, gap = tf_vs_free_run_gaps_EM(
    trainer, samples, chained_samples, z0=z1s, H=1000,
)

In [None]:
for label, value in (
    ("MSE step ahead forecast", mse_tf.mean()),
    ("MSE chain forecast", mse_free.mean()),
    ("Gap btwn", gap.mean()),
):
    print(label, value)

In [None]:
for label, values in (
    ("step ahead MSE (first 5):   ", mse_tf),
    ("chain MSE (first 5): ", mse_free),
    ("GAP (first 5):      ", gap),
):
    print(label, np.round(values[:5], 6))
print("GAP median / max:   ", float(np.median(gap)), float(np.max(gap)))

In [None]:
# Per-step MSE of the one-step and chained forecasts over the first (up to)
# 500 forecast steps.
H = 1000
horizons = np.arange(1, H+1)
# Never slice past the available MSE values (plot() needs equal lengths).
n_plot = min(500, len(mse_tf), len(mse_free))
fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(horizons[:n_plot], mse_tf[:n_plot], label='Step Ahead Prediction', c=colors[2], linewidth=2)
ax.plot(horizons[:n_plot], mse_free[:n_plot], label='Chained Forecast', c=colors[1], linewidth=2)
ax.set_xlabel("Forecast horizon (steps)")
ax.set_ylabel('MSE')
ax.set_title("MSE vs Horizon")
ax.legend()
# savefig does not create missing directories.
os.makedirs("figs", exist_ok=True)
fig.savefig("figs/forecast_mse.pdf")