# Evaluating alanine dipeptide (ala2)
This notebook contains evaluations for alanine dipeptide:
- Dihedral distributions (Ramachandran plots)
- Jensen-Shannon divergence w.r.t. the reference dihedral distribution

Dihedrals are torsional angles that describe how the protein structure twists around its backbone. A dihedral is defined by 4 consecutive points: it is the angle between the plane defined by the first 3 points and the plane defined by the last 3 points. Since alanine dipeptide has 5 backbone atoms, it has only two dihedrals ($\phi$ and $\psi$, see illustration below), which are its main structural degrees of freedom. 

<img src="../repo_images/dihedrals.png" width="200"/>

> <span style="color: blue;">**! Note 1:**</span> <br> This notebook assumes that sampled molecules are saved. Use the [sampling script](../sample.py) to sample molecules and save them to a specified location.

> <span style="color: blue;">**! Note 2:**</span> <br> In this notebook, we assume that the raw data is not available, and therefore use the pre-saved reference distribution from [saved_references](../saved_references/).

> <span style="color: blue;">**! Note 3:**</span> <br> The [saved_models](../saved_models/) folder contains models trained on the *full* alanine dipeptide training set (split into 4 different folds). <br>To study the effect of training set size, models should be trained from scratch. 

### Imports

In [None]:
import torch
import numpy as np
import mdtraj as md
from pathlib import Path
from os.path import join

import sys
sys.path.insert(0, "../")
from evaluate.evaluators import DihedralEnergiesEvaluator, js_divergence
from evaluate.evaluators_CGflowmatching import (
    get_prob,
    get_torsions,
)

# Seed NumPy's global RNG so that any random subsampling below is reproducible.
np.random.seed(seed=0)

### Arguments

In [None]:
# ---- User settings ----
fold = 1  # cross-validation fold to evaluate
append_exp_name = ""  # optional experiment-name suffix ("" when not applicable)
gen_mode = "langevin"  # generation mode: "iid" or "langevin"
subsample = None  # int: analyze only a random subset of that many samples; None: analyze all samples

# ---- Derived settings ----
# Folder-name suffix: "_<name>" when an experiment name is given, otherwise empty.
if append_exp_name:
    append_exp_name_str = '_' + append_exp_name
else:
    append_exp_name_str = ''
eval_folder = f"../saved_models/alanine/fold{fold}/main_eval_output_{gen_mode}{append_exp_name_str}"
sample_path = Path(eval_folder) / f"sample-{gen_mode}.pt"
pdb_file = "../datasets/folded_pdbs/ala2_cg.pdb"

### Loading samples

In [None]:
# Load sampled molecules. Map them onto the CPU because `.numpy()` is called on
# them later, and that fails for tensors saved from a GPU run.
sampled_mol = torch.load(sample_path, map_location="cpu")
if subsample is not None:
    # Draw `subsample` distinct samples uniformly at random from the whole set.
    # (np.random.permutation(subsample) would only shuffle the first `subsample`
    # samples rather than selecting a random subset.)
    subset_idx = np.random.permutation(len(sampled_mol))[:subsample]
    sampled_mol = sampled_mol[torch.from_numpy(subset_idx)]
print(f"Size of samples set (num_samples x num_backbone_atoms x 3): {sampled_mol.shape}")

# Load topology from pdb file
topology = md.load(pdb_file).topology
print(f"Atoms: {[a for a in topology.atoms]}")

### Obtain dihedral distribution for samples

In [None]:
# Initialize one evaluator per reference split (val / test).
def _make_dihedral_evaluator(split):
    """Build a DihedralEnergiesEvaluator backed by the pre-saved reference for `split`.

    `split` is "val" or "test". The reference pickle lives in the repository-level
    `saved_references` folder. Like every other path in this notebook, it is
    resolved relative to the parent directory ("../").
    """
    return DihedralEnergiesEvaluator(val_data=None,
                                     topology=topology,
                                     plots_folder=eval_folder,
                                     n_bins=61,
                                     saved_ref=f"../saved_references/saved_dih_probs_ala2_fold_{fold}_{split}set.pickle")


dihedral_evaluator_val = _make_dihedral_evaluator("val")
dihedral_evaluator_test = _make_dihedral_evaluator("test")


# Get samples dihedral distribution (histogram with the same binning as the reference).
# detach/cpu make the NumPy conversion safe for grad-tracking or GPU tensors.
sampled_dihedrals = get_torsions(sampled_mol.detach().cpu().numpy(), topology)
sampled_probs = get_prob(sampled_dihedrals, n_bins=dihedral_evaluator_test.n_bins)

### Ramachandran plot

In [None]:
save_plot = True  # Whether or not to save the displayed plots

# Reference dihedral distributions of the validation and test splits.
# Each split gets its own file name: both evaluators share the same plots folder,
# so one common "ramachandran_reference.png" would let the test plot overwrite the val plot.
for split, evaluator in (("val", dihedral_evaluator_val), ("test", dihedral_evaluator_test)):
    evaluator._plot_freeE_2d(evaluator.gt_probs,
                             file_name=join(evaluator.plots_folder, f"ramachandran_reference_{split}.png"),
                             plot_title=f"Reference {split}",
                             save_plot=save_plot)

# Samples dihedral distribution. The file name carries the same experiment suffix
# (`append_exp_name_str`, from the Arguments cell) as the evaluation folder.
dihedral_evaluator_test._plot_freeE_2d(sampled_probs,
                                       file_name=join(dihedral_evaluator_test.plots_folder, f"ramachandran_sampled{append_exp_name_str}.png"),
                                       plot_title="Samples",
                                       save_plot=save_plot)

### Jensen-Shannon divergence to the reference test distribution

In [None]:
# Jensen-Shannon divergence between the sampled dihedral histogram and the
# pre-saved test-split reference histogram.
reference_probs = dihedral_evaluator_test.gt_probs
dihedral_js = js_divergence(sampled_probs, reference_probs)
print(f"JS divergence to reference test distribution: {dihedral_js:.4f}")