# Evaluating fast-folding proteins
This notebook contains evaluations for fast-folding proteins (click to expand):

<br>

<details>
<summary><b>Pairwise distances</b></summary>
Calculate the pairwise distances between all beads for all samples and compare their distribution to the reference. Since pairwise distance matrices are symmetric, we only use the upper triangle, with a default (but adjustable) offset of three from the diagonal. Bead pairs that are close in sequence are skipped, which emphasizes global structure over local structure.
</details>

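A minimal numpy sketch of the upper-triangle extraction described above. It assumes that the offset is the diagonal offset `k` of `np.triu_indices`, so only pairs with `j - i >= offset` are kept. Whether `get_pwd_triu_batch` uses exactly this convention is an assumption, and `pairwise_distances_triu` is a hypothetical helper, not part of the repository.

```python
import numpy as np


def pairwise_distances_triu(coords: np.ndarray, offset: int = 3) -> np.ndarray:
    """Upper-triangle pairwise distances for a batch of structures.

    coords: array of shape (num_samples, num_beads, 3).
    offset: diagonal offset k for np.triu_indices; pairs (i, j) with
            j - i >= offset are kept, pairs close in sequence are dropped.
    Returns an array of shape (num_samples, num_pairs).
    """
    num_beads = coords.shape[1]
    i, j = np.triu_indices(num_beads, k=offset)
    # Vector from bead j to bead i, for every sample and every kept pair
    diff = coords[:, i, :] - coords[:, j, :]
    return np.linalg.norm(diff, axis=-1)


# Example: two random 10-bead structures
rng = np.random.default_rng(0)
coords = rng.normal(size=(2, 10, 3))
pwd = pairwise_distances_triu(coords, offset=3)
print(pwd.shape)  # (2, 28)
```

With 10 beads and an offset of three, 7 + 6 + ... + 1 = 28 pairs remain per structure.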
<br>

<details>
<summary><b>Time-lagged independent component analysis (TICA)</b></summary>
Similar to PCA, TICA projects a high-dimensional input onto a lower-dimensional linear subspace. Whereas PCA looks for the directions of maximal variance, TICA looks for the directions of maximal autocorrelation at a chosen lag time, i.e. the linear combinations of input features that decorrelate most slowly. The first two time-lagged independent components (TICs) therefore approximate the reaction coordinates of the slowest conformational changes, which is highly relevant for protein (un)folding. We use these two TICs to build a 2D free energy landscape, in which we can compare samples to the reference distribution (MD simulation).
</details>

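To make the comparison with PCA concrete, here is a from-scratch numpy sketch of TICA: whiten the features with the instantaneous covariance, then diagonalize the time-lagged covariance and keep the eigenvectors with the largest autocorrelation. It is for illustration only and is not taken from `TicEvaluator`; the function name `tica`, the symmetrized covariance estimate (which assumes reversible dynamics), and the toy data are assumptions.

```python
import numpy as np


def tica(x: np.ndarray, lagtime: int, dim: int = 2):
    """Project x (num_frames, num_features) onto its `dim` slowest TICs.

    Returns the projected data and the lag-time autocorrelation of each TIC.
    """
    x = x - x.mean(axis=0)
    x0, xt = x[:-lagtime], x[lagtime:]
    n = len(x0)
    # Instantaneous covariance C0 and time-lagged covariance Ct, both symmetrized
    c0 = 0.5 * (x0.T @ x0 + xt.T @ xt) / n
    ct = 0.5 * (x0.T @ xt + xt.T @ x0) / n
    # Whiten with C0^(-1/2), then diagonalize the whitened lagged covariance
    evals0, evecs0 = np.linalg.eigh(c0)
    whiten = evecs0 / np.sqrt(evals0)
    evals, evecs = np.linalg.eigh(whiten.T @ ct @ whiten)
    # Largest autocorrelation first: these are the slowest components
    order = np.argsort(evals)[::-1][:dim]
    components = whiten @ evecs[:, order]
    return x @ components, evals[order]


# Toy data: one slow AR(1) process hidden in two of three noisy features
rng = np.random.default_rng(0)
num_frames = 20_000
slow = np.zeros(num_frames)
for t in range(1, num_frames):
    slow[t] = 0.95 * slow[t - 1] + rng.normal()
fast = rng.normal(size=num_frames)
x = np.column_stack([slow + fast, slow - fast, rng.normal(size=num_frames)])

proj, autocorr = tica(x, lagtime=10, dim=2)
print(autocorr)  # first value close to 0.95**10 (about 0.6), second close to 0
```

The first TIC recovers the slow process, even though no single input feature isolates it; PCA would instead rank directions by variance alone.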
<br>

<details>
<summary><b>RMSD to folded structure</b></summary>
Align each sample to the <i>native</i> (folded) structure and calculate the root-mean-square deviation (RMSD) of the bead positions.
</details>

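A common way to do this alignment is the Kabsch algorithm (optimal rotation from an SVD), followed by the RMSD over all beads. The sketch below assumes this approach; we have not checked how `RmsdEvaluator` aligns structures internally, and `kabsch_rmsd` is a hypothetical helper.

```python
import numpy as np


def kabsch_rmsd(sample: np.ndarray, reference: np.ndarray) -> float:
    """RMSD between two (num_beads, 3) structures after optimal superposition."""
    # Remove translation by centering both structures
    p = sample - sample.mean(axis=0)
    q = reference - reference.mean(axis=0)
    # Optimal rotation from the SVD of the 3x3 covariance matrix (Kabsch)
    u, _, vt = np.linalg.svd(p.T @ q)
    d = np.sign(np.linalg.det(u @ vt))  # guard against improper rotations (reflections)
    rot = u @ np.diag([1.0, 1.0, d]) @ vt
    aligned = p @ rot
    return float(np.sqrt(np.mean(np.sum((aligned - q) ** 2, axis=1))))


# Example: a rotated and translated copy of a structure has RMSD ~0
rng = np.random.default_rng(1)
ref = rng.normal(size=(10, 3))
rot, _ = np.linalg.qr(rng.normal(size=(3, 3)))
if np.linalg.det(rot) < 0:
    rot[:, 0] *= -1  # make it a proper rotation
moved = ref @ rot + np.array([5.0, -2.0, 1.0])
print(kabsch_rmsd(moved, ref))  # ~0 up to floating-point error
```

The reflection guard matters for proteins: a mirror image of a chiral structure should not be superimposable, so its RMSD stays non-zero.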
<br>

<details>
<summary><b>Contacts</b></summary>
Contacts occur when two beads are far apart in the protein sequence but close in the 3D structure. We calculate the pairwise distances between all beads and binarize the resulting all-vs-all distance map with a cutoff of 10 Å. This gives a square, symmetric matrix, called a contact map, that shows which bead pairs are in contact. Summing over all contact maps and dividing by the number of samples gives the contact probability for each pair of beads.
</details>

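The contact-probability computation described above can be sketched in numpy as follows. The 10 Å cutoff comes from the text; the `min_seq_sep` parameter is a hypothetical stand-in for "far apart in the sequence", since the exact sequence separation used by `ContactEvaluator` is not stated here. `contact_probability` is likewise a hypothetical helper.

```python
import numpy as np


def contact_probability(coords: np.ndarray, cutoff: float = 10.0, min_seq_sep: int = 3) -> np.ndarray:
    """Fraction of samples in which each bead pair is in contact.

    coords: (num_samples, num_beads, 3), in angstrom.
    min_seq_sep: hypothetical; pairs closer than this along the sequence
                 are never counted as contacts.
    """
    diff = coords[:, :, None, :] - coords[:, None, :, :]
    dist = np.linalg.norm(diff, axis=-1)          # (num_samples, N, N)
    contact_maps = dist < cutoff                  # binarize at the cutoff
    n = coords.shape[1]
    i, j = np.indices((n, n))
    contact_maps &= np.abs(i - j) >= min_seq_sep  # keep sequence-distant pairs only
    return contact_maps.mean(axis=0)              # sum over samples / num_samples


# Example: one fully collapsed and one fully extended 5-bead structure
compact = np.zeros((5, 3))
extended = np.column_stack([20.0 * np.arange(5), np.zeros(5), np.zeros(5)])
prob = contact_probability(np.stack([compact, extended]), cutoff=10.0, min_seq_sep=3)
print(prob)  # 0.5 for pairs at least 3 apart in sequence, 0 elsewhere
```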
<br>

<details>
<summary><b>State transition probabilities</b></summary>
For time-ordered simulations, we can assign each frame to a cluster in the TIC free energy landscape (see the TICA explanation above), count the transitions between clusters, and normalize the counts into state transition probabilities.
</details>

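A minimal sketch of the counting step, assuming a time-ordered discrete trajectory of cluster labels `0 .. n_states - 1`. It follows the same idea as the sliding-window counting and row normalization used later in the notebook (deeptime's `TransitionCountEstimator` with `count_mode="sliding"`, then sklearn's `normalize`), but `transition_matrix` itself is a hypothetical helper.

```python
import numpy as np


def transition_matrix(dtraj: np.ndarray, n_states: int, lagtime: int = 1) -> np.ndarray:
    """Row-normalized transition counts from a discrete trajectory."""
    counts = np.zeros((n_states, n_states))
    # Sliding window: every pair (t, t + lagtime) contributes one count
    np.add.at(counts, (dtraj[:-lagtime], dtraj[lagtime:]), 1)
    # Normalize each row to sum to one (row = start state, column = end state)
    row_sums = counts.sum(axis=1, keepdims=True)
    return np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0)


dtraj = np.array([0, 0, 1, 1, 0, 2, 2])
T = transition_matrix(dtraj, n_states=3)
print(T)
```

Rows of states that are never left within the trajectory stay at zero instead of producing a division by zero.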
<br>
<br>


> <span style="color: blue;">**! Note 1:**</span> <br> This notebook assumes that sampled molecules are saved. Use the [sampling script](../sample.py) to sample molecules and save them to a specified location.

> <span style="color: blue;">**! Note 2:**</span> <br> In this notebook, we assume that the raw data is not available, and therefore use the pre-saved reference distribution from [saved_references](../saved_references/).

### Imports

In [None]:
import torch
import numpy as np
import mdtraj as md
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from os.path import join
from deeptime.clustering import MiniBatchKMeans
from deeptime.markov import TransitionCountEstimator
from sklearn.preprocessing import normalize
import seaborn as sns
from pathlib import Path

import sys
sys.path.insert(0, "../")
from evaluate.evaluators import (
    PwdEvaluator,
    TicEvaluator,
    RmsdEvaluator,
    ContactEvaluator,
    get_pwd_triu_batch,
    js_divergence,
)
from datasets.dataset_utils_empty import Molecules

np.random.seed(0)

### Arguments

> <span style="color: blue;">**! Note 1:**</span> <br>
The results in the [paper](https://pubs.acs.org/doi/10.1021/acs.jctc.3c00702) are based on a number of samples equal to the size of the training set (see the table below), so this is the number of samples required to reproduce them. If you have generated more structures than that, you can randomly subsample them. This is especially relevant for the dynamics setting, where the long simulated trajectories contain more frames than needed; there, however, the subset must preserve time order (see Note 2).

> <span style="color: blue;">**! Note 2:**</span> <br>
Don't use random subsampling when analyzing dynamics, as samples need to be in time order.

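A short sketch of the two subsampling cases. The random subset follows the notes above; the uniform stride for dynamics is only an assumed way to shrink a trajectory while keeping time order (it also multiplies the effective time step between frames by `stride`). The sizes are placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
num_samples, subsample = 1000, 100  # placeholder sizes

# i.i.d. samples: random subset without replacement (order does not matter)
random_idx = rng.choice(num_samples, size=subsample, replace=False)

# Dynamics (assumption): keep every `stride`-th frame so time order is preserved
stride = num_samples // subsample
strided_idx = np.arange(0, num_samples, stride)[:subsample]
```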
<details>
<summary>Dataset sizes (click to expand)</summary>
<br>

| Protein name | Dataset size (total) | Training set size (70%) | Validation set size (10%) | Test set size (20%) |
|--------------|----------------------|-------------------------|---------------------------|---------------------|
| chignolin    | 534,743              | 374,320                 | 53,474                    | 106,949             |
| trp_cage     | 1,044,000            | 730,800                 | 104,400                   | 208,800             |
| bba          | 1,114,545            | 780,181                 | 111,455                   | 222,909             |
| villin       | 627,907              | 439,535                 | 62,791                    | 125,581             |
| protein_g    | 1,849,251            | 1,294,476               | 184,925                   | 369,850             |
</details>

In [None]:
# Specify
protein_name = "chignolin" # Options: chignolin, trp_cage, bba, villin, protein_g
gen_mode = "iid" # Mode of generation, "iid" or "langevin"
subsample = None # Give integer if only a random subset of the samples should be analyzed (if None, all samples are used)
append_exp_name = "" # Name of the experiment, leave as an empty string if not applicable

# Inferred
append_exp_name_str = '_' + append_exp_name if append_exp_name else ''
eval_folder = f"../saved_models/{protein_name}/main_eval_output_{gen_mode}{append_exp_name_str}"
sample_path = Path(eval_folder, f"sample-{gen_mode}.pt")
pdb_file = f"../datasets/folded_pdbs/{Molecules[protein_name.upper()].value}-0-c-alpha.pdb"

### Loading samples

In [None]:
# Load sampled molecules
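sampled_mol = torch.load(sample_path)
if subsample is not None:
    # Random subset of `subsample` samples, drawn without replacement.
    # Not suitable for Langevin dynamics, which must stay in time order (see Note 2 under Arguments)
    subset_idx = np.random.choice(len(sampled_mol), size=subsample, replace=False)
    sampled_mol = sampled_mol[subset_idx]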
print(f"Size of samples set (num_samples x num_backbone_atoms x 3): {sampled_mol.shape}")

# Load topology from pdb file
topology = md.load(pdb_file).topology
print(f"Atoms: {[a for a in topology.atoms]}")


### Pairwise distances

In [None]:
# Initialize evaluator
pwd_evaluator = PwdEvaluator(val_data=None, 
                             plots_folder=eval_folder, 
                             offset=3, 
                             mol_name=protein_name,
                             evalset='testset') # The evalset is the set we'll compare to in the next evaluation steps

# Get samples pairwise distances
pwd_sampled = get_pwd_triu_batch(sampled_mol, pwd_evaluator.offset)

In [None]:
# Jensen-Shannon divergence to reference
pwd_js = pwd_evaluator.js_divergence_pwd(pwd_evaluator.gt_hist, 
                                         pwd_sampled, 
                                         pwd_evaluator.gt_max, 
                                         pwd_evaluator.resolution)

print(f"JS divergence to reference distribution: {pwd_js:.4f}")

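Both JS divergences in this notebook compare histograms. For reference, a minimal numpy version of the Jensen-Shannon divergence using the natural logarithm; the base used by the repository's `js_divergence` is not stated here, so absolute values may differ by a constant factor. `js_divergence_sketch` is a hypothetical name. Be aware that `scipy.spatial.distance.jensenshannon` returns the Jensen-Shannon *distance*, i.e. the square root of this quantity.

```python
import numpy as np


def js_divergence_sketch(p, q) -> float:
    """Jensen-Shannon divergence (natural log) between two histograms."""
    p = np.asarray(p, dtype=float).ravel()
    q = np.asarray(q, dtype=float).ravel()
    p, q = p / p.sum(), q / q.sum()  # normalize both to probability vectors
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0  # terms with a == 0 contribute 0
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))

    return float(0.5 * kl(p, m) + 0.5 * kl(q, m))


print(js_divergence_sketch([1, 2, 3], [1, 2, 3]))  # 0.0 for identical histograms
print(js_divergence_sketch([1, 0], [0, 1]))        # ln(2), the maximum, for disjoint ones
```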
### TIC analysis

In [None]:
# Initialize evaluator
tic_evaluator = TicEvaluator(val_data=None,
                             mol_name=protein_name,
                             eval_folder=eval_folder,
                             data_folder=None,
                             folded_pdb_folder="../datasets/folded_pdbs",
                             evalset='testset') # The evalset is the set we'll compare to in the next evaluation steps

# Get samples TIC free energy landscape
sample_tic_features = tic_evaluator.get_tic_features(sampled_mol, tic_evaluator.folded)
transformed_samples = tic_evaluator.tica(sample_tic_features)
prob_samp, _, _ = np.histogram2d(
    transformed_samples[:, 0],
    transformed_samples[:, 1],
    bins=[tic_evaluator.bin_edges_x, tic_evaluator.bin_edges_y],
    density=True,
)

In [None]:
# TIC plot
save_plot = True # Whether or not to save the displayed plots

# Reference
fig = tic_evaluator._plot_tic(tic_evaluator.gt_prob,
                              file_name=join(tic_evaluator.plots_folder, "TICA_reference.png"),
                              title="Reference testset",
                              save_plot=save_plot)

# Samples
fig = tic_evaluator._plot_tic(prob_samp,
                              file_name=join(tic_evaluator.plots_folder, f"TICA_sampled_{gen_mode}.png"),
                              title="Samples",
                              save_plot=save_plot)


In [None]:
# JS divergence
tica_js = js_divergence(prob_samp.flatten(), tic_evaluator.gt_prob.flatten())
print(f"JS divergence to reference distribution: {tica_js:.4f}")

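The 2D histograms in this section are probability densities over TIC space; a free energy landscape in units of k_B T follows from `F = -ln(p)`. The sketch below assumes uniform bin widths, in which case density and probability differ only by a constant that the shift to a minimum of zero removes. Whether `_plot_tic` plots exactly this quantity is an assumption, and `free_energy_kT` is hypothetical.

```python
import numpy as np


def free_energy_kT(prob: np.ndarray) -> np.ndarray:
    """Free energy in units of k_B T from a 2D histogram: F = -ln(p)."""
    f = np.full(prob.shape, np.inf)  # empty bins get infinite free energy
    occupied = prob > 0
    f[occupied] = -np.log(prob[occupied])
    # Shift so that the most populated bin sits at F = 0
    return f - f[occupied].min()


prob = np.array([[0.5, 0.25], [0.25, 0.0]])
print(free_energy_kT(prob))  # [[0, ln 2], [ln 2, inf]]
```

For example, `free_energy_kT(prob_samp)` would give the landscape of the samples from the TIC cell above.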
### RMSD

In [None]:
# Initialize evaluator
rmsd_evaluator = RmsdEvaluator(mol_name=protein_name,
                               folded_pdb=f"../datasets/folded_pdbs/{Molecules[protein_name.upper()].value}.pdb",
                               eval_folder=eval_folder)

In [None]:
# Plot
save_plot = True # Whether or not to save the displayed plots

# Recommended settings for plots
cutoff_dict = {'chignolin': 10, 'trp_cage': 12, 'bba': 14, 'villin': 14, 'protein_g': 20}
cutoff = cutoff_dict[protein_name.lower()]
nbins = 100 

# Load evaluation from reference data
rmsd_evaluator.eval("Reference", None, nbins, cutoff)

# Evaluate samples
rmsd_evaluator.eval(gen_mode, sampled_mol, nbins, cutoff)
rmsd_evaluator._plot_rmsd(save=save_plot)

### Contacts

In [None]:
# Contact cutoff (in angstrom), recommended: 10
cutoff = 10

# Initialize evaluator
contacts = ContactEvaluator(mol_name=protein_name,
                            folded_pdb=f"../datasets/folded_pdbs/{Molecules[protein_name.upper()].value}.pdb",
                            eval_folder=eval_folder,
                            contact_cutoff=cutoff)

In [None]:
# Plot
save_plot = True # Whether or not to save the displayed plots

# Plot normalized count over contacts (contact probability) for samples
_ = contacts._plot_contact_normcount(
    sampled_mol, gen_mode, save=save_plot, take_log=True
)

### Dynamics

> <span style="color: blue;">**! Note 1:**</span> <br>
Only for Langevin dynamics samples! I.i.d. samples are not in order of time, and therefore results will not be meaningful. Make sure the Langevin dynamics samples are in time order, and not randomly subsampled.

> <span style="color: blue;">**! Note 2:**</span> <br>
Since protein G could not be compared between our DFF method and Flow-CGNet, we do not provide presets for this protein.

In [None]:
# Preset cluster centers, determined via standard K-means clustering
# (elbow method to find the number of clusters)

assert gen_mode == "langevin", "This analysis is only for Langevin dynamics samples (see Note 1 above)"

num_clusters_per_protein = {
    "chignolin": 3,
    "villin": 3,
    "trp_cage": 3,
    'bba': 4,
}

cluster_centers = {
    'chignolin': np.array([[ 0.69400153, -0.34598462],
                           [-0.48732213,  0.00642035],
                           [ 1.87483537,  0.06285344]]),

    'trp_cage': np.array([[-2.15921372,  0.0062795 ],
                          [ 0.47752285, -0.38050238],
                          [ 0.40182245,  2.0690773 ]]),

    'bba': np.array([[-0.5756589 , -0.60663654],
                     [ 1.7861676 , -0.87717611],
                     [ 0.91295128,  1.07518898],
                     [-0.49210152,  0.40313689]]),

    'villin': np.array([[ 1.08971813, -0.98522752],
                        [-2.49001353, -2.31375028],
                        [-0.12929561,  0.53703407]]),
}

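With `max_iter=0`, the `MiniBatchKMeans` call below is presumably meant to keep the preset centers fixed, in which case the cluster assignment reduces to a nearest-center lookup in TIC space. The sketch below shows that lookup in plain numpy with the chignolin presets; `assign_to_centers` is hypothetical and not part of deeptime.

```python
import numpy as np


def assign_to_centers(points: np.ndarray, centers: np.ndarray) -> np.ndarray:
    """Index of the nearest center (Euclidean distance) for each point."""
    # (num_points, num_centers) matrix of distances, then argmin per point
    dist = np.linalg.norm(points[:, None, :] - centers[None, :, :], axis=-1)
    return dist.argmin(axis=1)


chignolin_centers = np.array([[ 0.69400153, -0.34598462],
                              [-0.48732213,  0.00642035],
                              [ 1.87483537,  0.06285344]])
labels = assign_to_centers(chignolin_centers + 0.01, chignolin_centers)
print(labels)  # [0 1 2]
```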
In [None]:
# Since we need to do K-means clustering in TIC-space, we will transform the samples first
# (this code is the same as in the "TIC analysis" section)

# Initialize TIC evaluator
tic_evaluator = TicEvaluator(val_data=None,
                             mol_name=protein_name,
                             eval_folder=eval_folder,
                             data_folder=None,
                             folded_pdb_folder="../datasets/folded_pdbs",
                             evalset='testset') # The evalset is the set we'll compare to in the next evaluation steps

# Get samples TIC free energy landscape
sample_tic_features = tic_evaluator.get_tic_features(sampled_mol, tic_evaluator.folded)
transformed_samples = tic_evaluator.tica(sample_tic_features)

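# K-means clustering initialized with the preset centers.
# max_iter=0 is meant to keep these centers fixed, so that each sample is
# assigned to its nearest preset center instead of re-fitting the clusters.
kmeans = MiniBatchKMeans(n_clusters=num_clusters_per_protein[protein_name],
                         max_iter=0,
                         batch_size=64,
                         init_strategy='kmeans++',
                         n_jobs=16,
                         tolerance=1e-7,
                         initial_centers=cluster_centers[protein_name])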

assignments = kmeans.fit_transform(transformed_samples)

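# Count transitions between clusters at lag time 1 (sliding window)
count_matrix = TransitionCountEstimator.count(count_mode="sliding",
                                              dtrajs=[assignments.astype('int')],
                                              lagtime=1)

# Row-normalize the counts: entry (i, j) is the probability of moving from start class i to end class j
count_matrix = normalize(count_matrix, axis=1, norm='l1')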

In [None]:
# Plot 1: show cluster assignments in TIC-space
save_plot = True # Whether or not to save the displayed plots

plt.figure(figsize=(6,5))
cmap = plt.get_cmap("tab10")
handles=[]
for i in range(num_clusters_per_protein[protein_name]):
    plt.scatter(*transformed_samples[assignments==i].T, color=cmap(i), s=5, alpha=.1)
    handles.append(mpatches.Patch(color=cmap(i), label=f'Class {i+1}'))
plt.legend(handles=handles)
plt.title('Transformed state assignments')
if save_plot:
    plt.savefig(join(eval_folder, f'{protein_name}_{gen_mode}_kmeans.png'))
plt.show()

In [None]:
# Plot 2: transition probability matrix
save_plot = True # Whether or not to save the displayed plots

n_clusters=num_clusters_per_protein[protein_name]
plt.figure(figsize=(4,4))
sns.heatmap(count_matrix, annot=True, square=True, cmap = 'Greens', fmt='.3f', 
            xticklabels=np.arange(1,n_clusters+1), yticklabels=np.arange(1,n_clusters+1))
plt.xlabel("End class", fontsize=14)
plt.ylabel("Start class", fontsize=14)
plt.title(f"Transition probability matrix", fontsize=14, pad=10)
if save_plot:
    plt.savefig(join(eval_folder, f'{protein_name}_{gen_mode}_transitions.png'))
plt.show()