# Notebook for evaluating tempering performance

In [None]:
from eval_util import *
import torch
from energy.a4 import A4, A6, AldpBoltzmann
from network.egnn import remove_mean
import numpy as np
import matplotlib.pyplot as plt
import mdtraj

In [None]:
# ============================================================
#                ✨ Hyperparameters ✨
# ============================================================

# Torch device used for all tensors created in this notebook.
device = 'cuda'

# Which target system to evaluate: one of 'a4', 'a6', 'aldp'.
target_name = 'a4'

# File paths:
#   sample_path    -> tempering samples, read with np.load (must be filled in).
#   gt_sample_path -> ground-truth trajectory, read with mdtraj.load.
# NOTE(review): the default GT file name ('6A_..._600.0') looks like it belongs
# to the 'a6' target (A6 is constructed at 600), not 'a4' — confirm the pairing.
sample_path = ''
gt_sample_path = '6A_trajectory_1.0_600.0.h5'

### Define targets

In [None]:
# Instantiate the target for the selected system and record its atom count
# (n_particles drives every coordinate reshape in this notebook).
# The first constructor argument is presumably the temperature — confirm.
# Targets are created on the `device` hyperparameter instead of a hard-coded 'cuda'.
if target_name == 'a4':
    n_particles = 43
    target = A4(500, device, scaling=1.0)
elif target_name == 'a6':
    n_particles = 63
    target = A6(600, device, scaling=1.0)
elif target_name == 'aldp':
    n_particles = 22
    target = AldpBoltzmann(300, device)
else:
    raise ValueError(f"Unknown target: {target_name}")

### Load Data

In [None]:
# Load the ground-truth trajectory, centre and align it, and build the reference
# log-density histogram used for the energy TVD.
gt_traj = mdtraj.load(gt_sample_path)
# Fail fast when the trajectory does not belong to the selected target: every
# reshape below assumes exactly n_particles atoms per frame.
if gt_traj.n_atoms != n_particles:
    raise ValueError(
        f"{gt_sample_path} has {gt_traj.n_atoms} atoms but target "
        f"'{target_name}' expects n_particles={n_particles}"
    )

data_low = torch.from_numpy(remove_mean(gt_traj.xyz, n_particles, 3)).to(device).reshape(-1, n_particles*3)
# Re-centre every frame on its per-frame mean position over atoms.
data_low = data_low.reshape(-1, n_particles, 3) - data_low.reshape(-1, n_particles, 3).mean(dim=1, keepdim=True)
data_low = data_low.reshape(-1, n_particles*3)
# The first (flattened) frame is the alignment reference for GT and samples alike.
reference = data_low[0].clone()
# superimpose the data onto the reference using all atoms
data_low = superimpose_B_onto_A(reference.reshape(n_particles, 3), data_low.reshape(-1, n_particles, 3), np.arange(n_particles))
data_low = data_low.reshape(-1, n_particles*3)

# Reference log-density histogram from 25,000 random GT frames; its bin edges
# (`bins`) are reused for the sample histograms in the metrics loop.
log_p = target.log_prob(data_low[np.random.choice(data_low.shape[0], 25000, replace=False)].reshape(-1, n_particles*3))
hist, bins = np.histogram(log_p.cpu().numpy(), bins=100, density=True)
hist = hist / np.sum(hist)
# get distance
def get_dist(x, n_atoms=None):
    """Return all unique pairwise interatomic distances for a batch of frames.

    Args:
        x: tensor reshapeable to (batch, n_atoms, 3), e.g. flattened
            (batch, n_atoms * 3) coordinates.
        n_atoms: atoms per frame. Defaults to the notebook-level
            ``n_particles``, so existing calls behave as before.

    Returns:
        1-D CPU tensor of length batch * n_atoms * (n_atoms - 1) / 2 holding the
        upper-triangle (i < j) distances of each frame, concatenated frame by frame.
    """
    if n_atoms is None:
        n_atoms = n_particles
    coords = x.reshape(-1, n_atoms, 3)
    # (B, N, 1, 3) - (B, 1, N, 3) broadcasts to all (B, N, N, 3) displacement vectors.
    diff = coords[:, :, None, :] - coords[:, None, :, :]
    dists = (diff ** 2).sum(-1).sqrt().cpu()
    iu = torch.triu_indices(n_atoms, n_atoms, 1)
    return dists[:, iu[0], iu[1]].flatten()

# Reference pairwise-distance distribution from 25,000 random ground-truth
# frames. Its bin edges (dist_bins) are reused for the distance TVD later on.
dist = get_dist(
    data_low[np.random.choice(data_low.shape[0], 25000, replace=False)]
)
dist_hist, dist_bins = np.histogram(dist.cpu().numpy(), bins=100, density=True)
dist_hist = dist_hist / dist_hist.sum()  # per-bin probabilities summing to 1

### Load Tempering Samples

In [None]:
# Tempering samples; the metrics loop later reshapes them to (-1, n_particles*3).
all_samples = np.load(file=sample_path)

### Set TICA Parameters

In [None]:
import mdtraj as md
import pyemma

# Fit a 2-D TICA model on backbone torsions of the full ground-truth trajectory.
data_tica = torch.from_numpy(remove_mean(mdtraj.load(gt_sample_path).xyz, n_particles, 3)).to(device).reshape(-1, n_particles, 3)

# Build an mdtraj.Trajectory from the tensor
xyz = data_tica.detach().cpu().numpy()            # (T, N, 3)
# NOTE(review): the metrics loop builds the *sample* trajectory from
# energy/AAAA.pdb / energy/AAAAA.pdb for a4/a6, while this uses the target's
# OpenMM topology. Confirm both share the same atom order, otherwise the
# phi/psi atom indices (and therefore the TICA inputs) differ.
top = md.Topology.from_openmm(target.system.topology)
traj = md.Trajectory(xyz, topology=top)

# Backbone torsions (phi, psi), wrapped into [0, 2*pi).
# NOTE(review): plain wrapping is still discontinuous at 0 / 2*pi; a cos/sin
# embedding would be required to fully respect periodicity.
_, phi = md.compute_phi(traj)    # radians, shape (T, n_phi)
_, psi = md.compute_psi(traj)    # radians, shape (T, n_psi)
feat_tors = np.hstack([phi % (np.pi*2), psi % (np.pi*2)])  # (T, n_phi + n_psi)

# TICA on the torsion feature matrix
tica = pyemma.coordinates.tica([feat_tors], lag=8, dim=2)
Y = tica.transform(feat_tors)  # (T, 2)


In [None]:
from mmd import MMD_loss

# Median-heuristic kernel width from the ground-truth TICA projection.
# Take every 100th frame, form all pairwise Euclidean distances (the zero
# diagonal included), and use the squared median distance as sigma^2.
sigma_2 = Y[None, ::100] - Y[::100, None]      # pairwise differences, (M, M, d)
sigma_2 = (sigma_2 ** 2).sum(-1) ** 0.5        # pairwise distances, (M, M)
sigma_2 = np.median(sigma_2.ravel()) ** 2      # squared median distance
# Remaining positional args (2, 10) are passed as-is to MMD_loss
# (presumably kernel multiplier / number of kernels — confirm in mmd.py).
MMD_fnc = MMD_loss(2, 10, sigma_2)

### Calculate Metrics

In [None]:
Energy_TVD = []
Distance_TVD = []
Sample_W2 = []
TICA_MMD = []


def _hist_probs(values, bin_edges):
    """Histogram ``values`` on the fixed ``bin_edges`` and return per-bin probabilities.

    NOTE(review): np.histogram drops values outside [bin_edges[0], bin_edges[-1]]
    and the result is renormalised over in-range values only, so out-of-range
    samples are not penalised by the TVD (all-out-of-range input yields NaNs).
    """
    counts, _ = np.histogram(values, bins=bin_edges, density=True)
    return counts / np.sum(counts)


def _tvd(p, q):
    """Total variation distance 0.5 * sum|p - q| between histograms on the same bins."""
    return np.sum(np.abs(p - q)) / 2


# Only samples [1000, budget) are evaluated; the first 1000 are skipped.
# The budget is chosen to align with the inference time scaling budget of RNC.
budget = 50000
pool = torch.from_numpy(all_samples)[1000:budget].to(device).reshape(-1, n_particles*3)
pool_xyz = pool.detach().cpu().numpy().reshape(-1, n_particles, 3)

# Topology for computing torsions of the generated samples (loop-invariant).
if target_name == 'a4':
    top = md.load('energy/AAAA.pdb').topology
elif target_name == 'a6':
    top = md.load('energy/AAAAA.pdb').topology
elif target_name == 'aldp':
    top = md.Topology.from_openmm(target.system.topology)

for _ in range(3):  # repeat 3 times, each time eval with 5000 samples
    # thin the pool to 5000 samples
    samples = pool[np.random.choice(pool.shape[0], 5000, replace=False)]

    # ---- Energy TVD: log-density histograms on the GT bins ----
    thin_log_p = target.log_prob(data_low[np.random.choice(data_low.shape[0], 5000, replace=False)].reshape(-1, n_particles*3))
    thin_hist = _hist_probs(thin_log_p.cpu().numpy(), bins)
    s_hist = _hist_probs(target.log_prob(samples).cpu().numpy(), bins)
    TVD = _tvd(thin_hist, s_hist)
    print(f"Energy TVD: {TVD}")
    Energy_TVD.append(TVD)

    # ---- remove mean and align samples onto the GT reference frame ----
    samples = samples.reshape(-1, n_particles, 3) - samples.reshape(-1, n_particles, 3).mean(dim=1, keepdim=True)
    samples = superimpose_B_onto_A(reference.reshape(n_particles, 3), samples, np.arange(n_particles))
    samples = samples.reshape(-1, n_particles*3)

    # ---- Sample W2 on flattened, aligned coordinates ----
    s1 = samples[np.random.choice(samples.shape[0], 5000, replace=False), None].reshape(-1, n_particles, 3).cpu()
    s2 = data_low[np.random.choice(data_low.shape[0], 5000, replace=False), None].reshape(-1, n_particles, 3).cpu()
    w2 = compute_distribution_distances(s1.reshape(-1, n_particles*3)[:, None], s2.reshape(-1, n_particles*3)[:, None])
    w2 = w2[1][1]  # entry selected from compute_distribution_distances' output (see eval_util)
    print(f"Sample W2: {w2}")
    Sample_W2.append(w2)

    # ---- Distance TVD: pairwise-distance histograms on the GT bins ----
    dist_samples = get_dist(samples.reshape(-1, n_particles, 3))
    s_hist = _hist_probs(dist_samples.cpu().numpy(), dist_bins)
    thin_dist = get_dist(data_low[np.random.choice(data_low.shape[0], 5000, replace=False)])
    thin_dist_hist = _hist_probs(thin_dist.cpu().numpy(), dist_bins)

    plt.hist(dist_samples.cpu().numpy(), bins=dist_bins, density=True, alpha=0.5, label='samples')
    plt.hist(thin_dist.cpu().numpy(), bins=dist_bins, density=True, alpha=0.5, label='data')
    plt.legend()
    plt.show()

    TVD = _tvd(thin_dist_hist, s_hist)
    print(f"Distance TVD: {TVD}")
    Distance_TVD.append(TVD)

    # ---- TICA MMD: torsions of raw (un-thinned pool) samples projected with the GT TICA ----
    xyz = pool_xyz[np.random.choice(pool_xyz.shape[0], 5000, replace=False)]  # (T, N, 3)
    traj = md.Trajectory(xyz, topology=top)
    _, phi = md.compute_phi(traj)    # radians, shape (T, n_phi)
    _, psi = md.compute_psi(traj)    # radians, shape (T, n_psi)
    feat_tors = np.hstack([phi % (np.pi*2), psi % (np.pi*2)])  # (T, n_phi + n_psi)
    Y_samples = tica.transform(feat_tors)  # (T, 2)
    Y_thin_sample = Y[np.random.choice(Y.shape[0], 5000, replace=False)]

    # calculate MMD between GT and sample TICA projections
    mm = MMD_fnc(torch.from_numpy(Y_thin_sample).to(device),
                 torch.from_numpy(Y_samples).to(device))
    mm = mm.item()
    print(f"TICA MMD: {mm}")
    TICA_MMD.append(mm)


# calculate mean and std
print(f'mean and std of Energy_TVD: {np.mean(Energy_TVD):.4f} +- {np.std(Energy_TVD):.4f}')
print(f'mean and std of Distance_TVD: {np.mean(Distance_TVD):.4f} +- {np.std(Distance_TVD):.4f}')
print(f'mean and std of Sample_W2: {np.mean(Sample_W2):.4f} +- {np.std(Sample_W2):.4f}')
print(f'mean and std of TICA_MMD: {np.mean(TICA_MMD):.4f} +- {np.std(TICA_MMD):.4f}')