In [None]:
import argparse
import dataset
from gcae import GraphConvEncode
import matplotlib.pyplot as plt 
from math import sqrt
import mdtraj as md
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torch_geometric
from tqdm import tqdm

import umap
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import matplotlib.animation as animation
from IPython.display import HTML

In [None]:
def determine_device():
    """Pick the best available torch device.

    Preference order: Apple-silicon MPS, then CUDA, then CPU.

    Returns:
        torch.device for the first available backend.
    """
    if torch.backends.mps.is_available():
        backend = 'mps'
    elif torch.cuda.is_available():
        backend = 'cuda'
    else:
        backend = 'cpu'
    return torch.device(backend)

In [None]:
# Publication-style defaults applied to every figure in this notebook.
plt.rc("font", family="serif", size=14)   # family: or "sans-serif"
plt.rc("axes", labelsize=16, titlesize=16)
plt.rc("xtick", labelsize=12)
plt.rc("ytick", labelsize=12)

# 1. GCAE Evaluation

In [None]:
# Select the system to evaluate: 'pentapeptide', 'chignolin' or 'fs_peptide'.
protein = 'chignolin'
# Passed to dataset.generate_structures; set True for the datasets re-windowed below.
split = False

if protein == 'chignolin':
    pdb, files = dataset.load_DESRES(protein, data_dir='./', analyze=True, plot=False)
    test_traj_num = 26
elif protein == 'pentapeptide':
    pdb, files = dataset.load_pentapeptide(data_dir='pentapeptide/')
    files = dataset.reshape_time_window(files, pdb, num_files=25, traj_len=5001, num_split=1)
    test_traj_num = 12
    split = True
elif protein == 'fs_peptide':
    pdb, files = dataset.load_fs_peptide(data_dir='fs_peptide/')
    files = dataset.reshape_time_window(files, pdb, num_files=28, traj_len=10000, num_split=1)
    test_traj_num = 12
    split = True
else:
    # Fail here rather than with a confusing NameError on `files` further down.
    raise ValueError(
        f"Unknown protein {protein!r}; expected 'chignolin', 'pentapeptide' or 'fs_peptide'"
    )

print(files[test_traj_num])

file = files[test_traj_num]   # held-out test trajectory
file0 = files[0]              # reference trajectory; its frame 0 is the superposition target later

In [None]:
# Directory of the trained run to evaluate.
path = 'gcae_experiments/more_recent_models/no_time/chignolin/nonormal_10smoothing_5000ep_0.0005lr_0.1dr_3layers_10edge_16latentdim_1e-06reg_1e-06temp/'
normalize=False

##### MODIFY THESE BASED ON CHOSEN MODEL TO EVALUATE ###########################
top_k = 10
latent_dim = 16
dr = 0.1
num_layers = 3
lr = 5e-4
# Checkpoint file inside `path`. Deliberately NOT called `file`: that global holds the
# test-trajectory path and is read again by the RMSD cells further down.
checkpoint_name = "epoch-4999.pt"
model_path = os.path.join(path, checkpoint_name)
smooth = 10
lambda_reg = 1e-06
lambda_temp = 1e-06
################################################################################

# Reference trajectory handed to generate_structures; loader chosen by file extension.
if files[0][-4:] == '.xtc':
    traj0 = md.load_xtc(files[0], top=pdb)
else:
    traj0 = md.load_dcd(files[0], top=pdb)
topology = md.load(pdb).topology

test_structures = dataset.generate_structures([files[test_traj_num]], pdb, traj0, split=split, smooth=smooth, normalize=normalize)
test_dataset = dataset.LigandDataset(test_structures, top_k=top_k)

# shuffle=False keeps dataset order, which later cells rely on when indexing latents by frame.
test_dataloader = torch_geometric.loader.DenseDataLoader(test_dataset,
                                                      batch_size=256,
                                                      shuffle=False,
                                                      num_workers=4)


device = determine_device()
print(f"On device: {device}")

# Hidden dims passed as tuples; presumably (scalar, vector) channel counts — TODO confirm in gcae.
node_h_dim = (100, 16)
edge_h_dim = (32, 1)
node_num = topology.n_residues

model = GraphConvEncode((6,3), node_h_dim, (32,1), edge_h_dim,
                        latent_dim=latent_dim,
                        n_layers=num_layers,
                        drop_rate=dr,
                        node_num=node_num).to(device)

# The optimizer is not stepped anywhere in this notebook; kept to mirror the training setup.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_function = nn.MSELoss()

state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)



In [None]:
# Training-loss history saved alongside the checkpoint in `path`.
losses = torch.load(os.path.join(path, 'losses.pt'), map_location='cpu')
# Each logged value is indexable; element 0 is the curve plotted here.
# NOTE(review): assumed to be the total training loss with one entry per epoch — confirm
# against the training script.
train_loss_curve = [entry[0] for entry in losses.values()]

fig_loss, ax_loss = plt.subplots()
ax_loss.plot(range(len(train_loss_curve)), train_loss_curve)
ax_loss.set_yscale('log')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Training loss')
ax_loss.set_title('Training loss history')
plt.show()

In [None]:
# Run the encoder/decoder over the test trajectory.
# Outputs:
#   - test_latents: latent codes for every frame, in dataset order;
#   - test_loss: the full objective (MSE + regularizers), averaged per batch;
#   - rmsd_angstroms: computed from the pure reconstruction MSE only.
model.eval()
latent_batches = []
test_loss = 0.0
recon_sq_err_sum = 0.0   # sum of squared reconstruction errors over every element
recon_numel = 0          # number of elements that contributed to recon_sq_err_sum
with torch.no_grad():
    for batch in test_dataloader:
        batch = batch.to(device)
        nodes = (batch.node_s, batch.node_v)
        edges = (batch.edge_s, batch.edge_v)
        GT = batch.x

        # Collapse the dense batch's edge_index to the (2, E) layout the model is called with.
        edge_index = batch.edge_index.permute([1, 0, 2]).reshape(2, -1)

        pred, z = model(nodes, edge_index, edges)
        latent_batches.append(z.cpu())
        reg = torch.mean(torch.norm(z, dim=1) ** 2)
        # NOTE(review): this differences along dim 1 of z; if z is (batch, latent_dim) that is
        # across latent dimensions, not time. Kept identical to the objective — confirm vs training.
        temp = torch.mean((z[:, 1:] - z[:, :-1]) ** 2)

        loss_A = loss_function(GT, pred)
        loss = loss_A + lambda_reg * reg + lambda_temp * temp
        test_loss += loss.item()

        # Accumulate the raw squared error so the RMSD is exact (partial last batch weighted
        # correctly) and free of the regularization terms.
        diff = pred - GT
        recon_sq_err_sum += torch.sum(diff ** 2).item()
        recon_numel += diff.numel()

assert latent_batches, "test_dataloader yielded no batches"
test_latents = torch.cat(latent_batches, dim=0)
latents_var = torch.var(test_latents, dim=0)

for i, var in enumerate(latents_var):
    print(f"Dimension {i}: variance {var}")


test_loss /= len(test_dataloader)
recon_mse = recon_sq_err_sum / recon_numel

# x10 converts nm -> Å, assuming unnormalized coordinates in mdtraj's nm (normalize=False).
rmsd_angstroms = sqrt(recon_mse) * 10

print(f"Test loss (MSE + regularizers): {test_loss:.6g}")
print(f"RMSD reconstruction error on test data {rmsd_angstroms:.4f} Å")


In [None]:
# Expose the test-set latent codes under the name the plotting cells below use.
latents = test_latents
latent_shape = latents.shape
print(latent_shape)

In [None]:
# Log-density of the test latents over their first two dimensions.
latents = np.asarray(latents)
dim1, dim2 = latents[:, 0], latents[:, 1]

f, ax1 = plt.subplots(figsize=(6, 6))
hb = ax1.hexbin(dim1, dim2, bins='log')
cb = f.colorbar(hb, ax=ax1, label='log10(N)')
for set_label, text in ((ax1.set_xlabel, '1st-Dimension'), (ax1.set_ylabel, '2nd-Dimension')):
    set_label(text, fontsize=12.5)
ax1.set_title('Latent space')

plt.tight_layout()
plt.show()

In [None]:
# 2-D projections of the latent codes. t-SNE, UMAP (and randomized PCA on large inputs) are
# stochastic; seeding them keeps the hand-picked box centres / frames below reproducible.
# Note: a fixed random_state makes UMAP run single-threaded.
EMBED_SEED = 0

pca = PCA(n_components=2, random_state=EMBED_SEED)
latent_pca = pca.fit_transform(latents)

tsne = TSNE(n_components=2, perplexity=50, learning_rate='auto', random_state=EMBED_SEED)
latent_tsne = tsne.fit_transform(latents)

reducer = umap.UMAP(random_state=EMBED_SEED)
embedding = reducer.fit_transform(latents)

In [None]:
# Side-by-side view of the three 2-D projections: PCA as a log hexbin, t-SNE / UMAP as scatters.
panel_labels = [
    ("PCA of Latent Codes", "PC1", "PC2"),
    ("t-SNE of Latent Codes", "t-SNE dim 1", "t-SNE dim 2"),
    ("UMAP of Latent Codes", "umap dim 1", "umap dim 2"),
]

f, axes = plt.subplots(1, 3, figsize=(16, 6), constrained_layout=True)
ax1, ax2, ax3 = axes

ax1.hexbin(latent_pca[:, 0], latent_pca[:, 1], bins='log')
for ax, coords in ((ax2, latent_tsne), (ax3, embedding)):
    ax.scatter(coords[:, 0], coords[:, 1], alpha=0.5)

for ax, (title, xlabel, ylabel) in zip(axes, panel_labels):
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

plt.show()

In [None]:
# Embedding to search, plus an axis-aligned box (centre ± radius per axis) used to pick frames.
dim_model = latent_tsne  # alternatives: latent_pca, embedding
center = (90, 0)
x_radius, y_radius = 1, 1

In [None]:
def _in_box(row):
    """True when the first two coordinates of `row` lie inside the centre ± radius box (inclusive)."""
    return (center[0] - x_radius <= row[0] <= center[0] + x_radius
            and center[1] - y_radius <= row[1] <= center[1] + y_radius)


# Indices of all frames whose embedded point falls inside the box.
valid_points = [idx for idx in range(dim_model.shape[0]) if _in_box(dim_model[idx])]

print(len(valid_points))

print(valid_points)

In [None]:
# Reference frame for the RMSD colouring (index into the test trajectory).
frame_num = 4858

# Read the paths from `files` directly: `file` is a notebook-wide global that other cells
# may rebind, which would make this cell load the wrong file.
test_traj_path = files[test_traj_num]
ref_traj_path = files[0]
load_traj = md.load_xtc if test_traj_path[-4:] == '.xtc' else md.load_dcd

traj = load_traj(test_traj_path, top=pdb)
traj0 = load_traj(ref_traj_path, top=pdb)
if not 0 <= frame_num < traj.n_frames:
    raise IndexError(f"frame_num={frame_num} is outside the {traj.n_frames}-frame test trajectory")

# Align onto frame 0 of the reference, then centre so md.rmsd may skip centring (precentered=True).
traj.superpose(traj0, frame=0)
traj.center_coordinates()

# RMSD (mdtraj reports nm) of every test frame to `frame_num`.
rmsd = md.rmsd(traj, traj, frame=frame_num, precentered=True)

In [None]:
# Colour all three projections by RMSD to `frame_num` (red star marks that frame) and save as PDF.
# The scatter colouring needs exactly one RMSD value per latent code.
assert len(rmsd) == len(latent_pca), (
    f"{len(rmsd)} RMSD values vs {len(latent_pca)} latent codes: trajectory frames and "
    "latents are not aligned"
)

f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16,6), constrained_layout=True)

sc1 = ax1.scatter(latent_pca[:, 0], latent_pca[:, 1], c=rmsd, cmap='viridis', s=10)
ax1.scatter(latent_pca[frame_num, 0], latent_pca[frame_num, 1], c='r', s=120, marker='*')
ax1.set_title("PCA of Latent Codes")
ax1.set_xlabel("PC1")
ax1.set_ylabel("PC2")

ax2.scatter(latent_tsne[:, 0], latent_tsne[:, 1], c=rmsd, cmap='viridis', s=10)
ax2.scatter(latent_tsne[frame_num, 0], latent_tsne[frame_num, 1], c='r', s=120, marker='*')
ax2.set_title("t-SNE of Latent Codes")
ax2.set_xlabel("t-SNE dim 1")
ax2.set_ylabel("t-SNE dim 2")

ax3.scatter(embedding[:, 0], embedding[:, 1], c=rmsd, cmap='viridis', s=10)
ax3.scatter(embedding[frame_num, 0], embedding[frame_num, 1], c='r', s=120, marker='*')
ax3.set_title("UMAP of Latent Codes")
ax3.set_xlabel("umap dim 1")
ax3.set_ylabel("umap dim 2")

# All panels share the same data-driven colour range, so one colorbar serves them all.
cbar = f.colorbar(sc1, ax=[ax1, ax2, ax3], orientation="vertical")
cbar.set_label("RMSD (nm)")

# savefig does not create missing directories.
os.makedirs(protein, exist_ok=True)
plt.savefig(os.path.join(protein, f'rmsd{frame_num}.pdf'), bbox_inches='tight')


# Plot latent-space trajectories over time

In [None]:
# Animate a window of the test trajectory through latent dims 1-2 over their hexbin density.
# x/y/z/w (dims 1-4) are also read by the time-series cell below.
x = latents[:,0]
y = latents[:,1]
z = latents[:,2]
w = latents[:,3]
starting_frame = 1000
animation_len = 100
assert starting_frame + animation_len <= len(x), (
    f"animation window [{starting_frame}, {starting_frame + animation_len}) "
    f"runs past the {len(x)} available frames"
)

fig, ax = plt.subplots(figsize=(6,6))
line, = ax.plot([], [], lw=2, color='red')
point, = ax.plot([], [], 'ro')
ax.hexbin(x, y, bins='log')
ax.set_xlim(x.min() - 0.01, x.max() + 0.01)
ax.set_ylim(y.min() - 0.01, y.max() + 0.01)
ax.set_title("2D Trajectory over time - latent dimension 1 and 2")
ax.set_xlabel("Latent dim 1")
ax.set_ylabel("Latent dim 2")
ax.grid(True)
ax.set_aspect('equal')
timestamp = ax.text(0.02, 0.95, '', transform=ax.transAxes, fontsize=12, ha='left', va='top')


def init():
    """Reset the animated artists; with blit=True every artist updated per frame must be returned."""
    line.set_data([], [])
    point.set_data([], [])
    timestamp.set_text('')
    return line, point, timestamp


def update(frame):
    """Draw the path up to `starting_frame + frame`, mark the current point and label the step."""
    step = starting_frame + frame
    line.set_data(x[starting_frame:step + 1], y[starting_frame:step + 1])
    # set_data expects sequences, so the single point is wrapped in lists.
    point.set_data([x[step]], [y[step]])
    timestamp.set_text(f'Time step: {step}')
    return line, point, timestamp


ani = animation.FuncAnimation(fig, update, frames=animation_len, init_func=init, interval=200, blit=True)
plt.close(fig)  # avoid an extra static copy of the figure; the animation still renders from it
HTML(ani.to_jshtml())
# or for a GIF:
#ani.save('trajectory.gif', writer='pillow', fps=20)

In [None]:
# Latent dimensions 1-4 over the animated window (t counts steps from `starting_frame`).
window = slice(starting_frame, starting_frame + animation_len)

fig, axs = plt.subplots(2, 2, figsize=(10, 5))
for k, (ax, series) in enumerate(zip(axs.flat, (x, y, z, w)), start=1):
    ax.plot(range(animation_len), series[window])
    ax.set_title(f"Latent Dimension {k}")
    ax.set_xlabel("t")

plt.tight_layout()
plt.show()

# Load and visualize previously saved latents (computed on test or train data)

In [None]:
# Directory holding a previously saved `test_latents.pt` (placeholder — point it at a real run).
# Kept separate from `path` so the model directory used earlier is not clobbered.
latents_dir = "/path/to/trained/model/"
latent_path = os.path.join(latents_dir, 'test_latents.pt')
if not os.path.exists(latent_path):
    raise FileNotFoundError(
        f"No saved latents at {latent_path!r}; set `latents_dir` to a trained-model directory"
    )
latents = torch.load(latent_path, map_location='cpu')

In [None]:
# The loaded latents may be a torch tensor; work on a NumPy array for plotting.
latents_arr = np.asarray(latents)

# Per-dimension series (latent dims 1-4) from THIS latent set, used by the time-series cell below.
x = latents_arr[:, 0]
y = latents_arr[:, 1]
z = latents_arr[:, 2]
w = latents_arr[:, 3]

# Pair of latent dimensions (0-based column indices) traced in the animation.
dim_x, dim_y = 2, 3
anim_x = latents_arr[:, dim_x]
anim_y = latents_arr[:, dim_y]
starting_frame = 500
animation_len = 100
assert starting_frame + animation_len <= len(anim_x), (
    f"animation window [{starting_frame}, {starting_frame + animation_len}) "
    f"runs past the {len(anim_x)} available frames"
)

fig, ax = plt.subplots(figsize=(6,6))
line, = ax.plot([], [], lw=2, color='red')
point, = ax.plot([], [], 'ro')
ax.hexbin(anim_x, anim_y, bins='log')
ax.set_xlim(anim_x.min() - 0.01, anim_x.max() + 0.01)
ax.set_ylim(anim_y.min() - 0.01, anim_y.max() + 0.01)
# Labels derive from dim_x/dim_y so they always match the plotted dimensions (1-based).
ax.set_title(f"2D Trajectory over time - latent dimension {dim_x + 1} and {dim_y + 1}")
ax.set_xlabel(f"Latent dim {dim_x + 1}")
ax.set_ylabel(f"Latent dim {dim_y + 1}")
ax.grid(True)
ax.set_aspect('equal')
timestamp = ax.text(0.02, 0.95, '', transform=ax.transAxes, fontsize=12, ha='left', va='top')


def init():
    """Reset the animated artists; with blit=True every artist updated per frame must be returned."""
    line.set_data([], [])
    point.set_data([], [])
    timestamp.set_text('')
    return line, point, timestamp


def update(frame):
    """Draw the path up to `starting_frame + frame`, mark the current point and label the step."""
    step = starting_frame + frame
    line.set_data(anim_x[starting_frame:step + 1], anim_y[starting_frame:step + 1])
    # set_data expects sequences, so the single point is wrapped in lists.
    point.set_data([anim_x[step]], [anim_y[step]])
    timestamp.set_text(f'Time step: {step}')
    return line, point, timestamp


ani = animation.FuncAnimation(fig, update, frames=animation_len, init_func=init, interval=200, blit=True)
plt.close(fig)  # avoid an extra static copy of the figure; the animation still renders from it
HTML(ani.to_jshtml())
# or for a GIF:
#ani.save('trajectory.gif', writer='pillow', fps=20)

In [None]:
# Latent dimensions 1-4 of the loaded latents over the same window that is animated above
# (t counts steps from `starting_frame`). Series are read from `latents` directly, so stale
# or reassigned x/y/z/w globals cannot mislabel the panels.
latents_arr = np.asarray(latents)
window = slice(starting_frame, starting_frame + animation_len)
assert latents_arr[window].shape[0] == animation_len, (
    f"time-series window [{starting_frame}, {starting_frame + animation_len}) "
    f"runs past the {latents_arr.shape[0]} available frames"
)

fig, axs = plt.subplots(2,2, figsize=(10,5))
for dim, ax in enumerate(axs.flat):
    ax.plot(range(animation_len), latents_arr[window, dim])
    ax.set_title(f"Latent Dimension {dim + 1}")
    ax.set_xlabel("t")

plt.tight_layout()
plt.show()

# 2D latent visualization

In [None]:
# Log-density of the loaded latents over their first two dimensions.
latents = np.asarray(latents)

f, ax1 = plt.subplots(figsize=(6, 6))
ax1.set_title('Latent space')
ax1.set_xlabel('1st-Dimension', fontsize=12.5)
ax1.set_ylabel('2nd-Dimension', fontsize=12.5)
hb = ax1.hexbin(*latents[:, :2].T, bins='log')
cb = f.colorbar(hb, ax=ax1, label='log10(N)')

plt.tight_layout()
plt.show()

In [None]:
# Search box on the raw latent space (first two dimensions): centre ± radius per axis.
dim_model = latents
center = (4.0, -5.0)
x_radius = y_radius = 0.5

In [None]:
# Indices of frames whose first two latent coordinates fall inside the box (bounds inclusive).
lo_x, hi_x = center[0] - x_radius, center[0] + x_radius
lo_y, hi_y = center[1] - y_radius, center[1] + y_radius
valid_points = [
    idx
    for idx, (px, py) in enumerate(zip(dim_model[:, 0], dim_model[:, 1]))
    if lo_x <= px <= hi_x and lo_y <= py <= hi_y
]

print(len(valid_points))

print(valid_points)

In [None]:
# Reference frame for the RMSD colouring (index into the test trajectory).
frame_num = 4838

# Read the paths from `files` directly: `file` is a notebook-wide global that other cells
# may rebind, which would make this cell load the wrong file.
test_traj_path = files[test_traj_num]
ref_traj_path = files[0]
load_traj = md.load_xtc if test_traj_path[-4:] == '.xtc' else md.load_dcd

traj = load_traj(test_traj_path, top=pdb)
traj0 = load_traj(ref_traj_path, top=pdb)
if not 0 <= frame_num < traj.n_frames:
    raise IndexError(f"frame_num={frame_num} is outside the {traj.n_frames}-frame test trajectory")

# Align onto frame 0 of the reference, then centre so md.rmsd may skip centring (precentered=True).
traj.superpose(traj0, frame=0)
traj.center_coordinates()

# RMSD (mdtraj reports nm) of every test frame to `frame_num`.
rmsd = md.rmsd(traj, traj, frame=frame_num, precentered=True)

In [None]:
# Raw 2-D latent codes coloured by RMSD to `frame_num` (red star marks that frame), saved as PDF.
# The loaded latents must describe the same frames as the trajectory the RMSD was computed on.
assert len(rmsd) == len(latents), (
    f"{len(rmsd)} RMSD values vs {len(latents)} latent codes: the loaded latents must come "
    "from the same trajectory (and frame count) as the RMSD computed above"
)

f, ax1 = plt.subplots(1, 1, figsize=(6,6), constrained_layout=True)

sc1 = ax1.scatter(latents[:, 0], latents[:, 1], c=rmsd, cmap='viridis', s=10)
ax1.scatter(latents[frame_num, 0], latents[frame_num, 1], c='r', s=120, marker='*')
ax1.set_title("Latent Codes")
ax1.set_xlabel("Dim 1")
ax1.set_ylabel("Dim 2")

cbar = f.colorbar(sc1, ax=[ax1], orientation="vertical")
cbar.set_label("RMSD (nm)")

# Save under the selected protein's directory (created if missing, as savefig will not).
os.makedirs(protein, exist_ok=True)
plt.savefig(os.path.join(protein, f'2D_{protein}_rmsd{frame_num}.pdf'), bbox_inches='tight')
