In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pyemma
import glob
import pandas as pd
import seaborn as sns
import os

import nglview
import mdtraj
from mdtraj import shrake_rupley, compute_rg
from threading import Timer
from nglview.player import TrajectoryPlayer

import MDAnalysis as mda
from MDAnalysis.analysis import diffusionmap,align, rms

In [None]:
# Directory that receives every figure saved by this notebook.
fig_path = r'./figures/estimation/'
# makedirs also creates the parent ./figures directory when it is missing
# (os.mkdir would raise FileNotFoundError); exist_ok keeps the cell re-runnable.
os.makedirs(fig_path, exist_ok=True)

In [None]:
# Get topology and trajectory files and split them into a training set and a
# single held-out test trajectory (the last one in sorted order).

top_path = './data/peptide.gro'
# glob returns paths in filesystem-dependent order; sort so the train/test
# split is the same on every machine and every run.
trajs_path = sorted(glob.glob('./data/md_1us_*_noPBC.xtc'))
assert len(trajs_path) >= 2, (
    'need at least two trajectories for a train/test split, found {}'.format(len(trajs_path)))
train_files = trajs_path[:-1]
test_file = trajs_path[-1]
print('Training files:', *train_files, '\nTest files:', test_file)
# test_file is a single path string: check membership directly
# (set(test_file) would be a set of characters and never intersect the paths).
assert test_file not in train_files

In [None]:
# Interactive 3-D view of the topology: ball-and-stick, spinning for 30 s,
# after which a background timer stops the spin and closes the viewer.
widget = nglview.show_mdtraj(mdtraj.load(top_path))
p = TrajectoryPlayer(widget)
widget.add_ball_and_stick()
p.spin = True


def stop_spin():
    """Turn off the spin animation and close the viewer widget (called by the Timer)."""
    p.spin = False
    widget.close()


spin_timer = Timer(30, stop_spin)
spin_timer.start()
widget

In [None]:
# Ramachandran density and free-energy surface of the backbone dihedrals,
# pooled over all trajectories and all residues.

feat_torsion = pyemma.coordinates.featurizer(top_path)
feat_torsion.add_backbone_torsions(periodic=False)
reader_torsion = pyemma.coordinates.source(trajs_path, features=feat_torsion)
data_plot = reader_torsion.get_output(stride=3)

# get_output yields one (n_frames, n_torsions) array per trajectory; stack them
# so every trajectory contributes, not just the first one.
torsions_all = np.concatenate(data_plot)
# Columns are taken to alternate phi, psi (the ::2 / 1::2 split assumes PyEMMA's
# phi_1, psi_1, phi_2, psi_2, ... ordering); ravel pools all residues.
phi = torsions_all[:, ::2].ravel()
psi = torsions_all[:, 1::2].ravel()

fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=True)
pyemma.plots.plot_density(phi, psi, ax=axes[0])
pyemma.plots.plot_free_energy(phi, psi, ax=axes[1], legacy=False)
for ax in axes.flat:
    # raw strings: '\P' is an invalid escape sequence in a normal string literal
    ax.set_xlabel(r'$\Phi$')
    ax.set_aspect('equal')
axes[0].set_ylabel(r'$\Psi$')
fig.tight_layout()
fig.savefig(os.path.join(fig_path, 'Ramachandran.png'))

***
### Feature selection

In [None]:
stride: int = 10  # frame stride used when loading features for the VAMP comparison

In [None]:
# Select features by comparing VAMP scores at different lag times.
# Each candidate feature set is declared once below and loaded for both the
# training trajectories and the held-out test trajectory.
feat = pyemma.coordinates.featurizer(top_path)

# (plot label, printed description, featurizer methods applied in order)
feature_sets = [
    ('BackBone', 'Backbone torsions dimension: ',
     ['add_backbone_torsions']),
    ('Ca_Dists', 'Ca distances dimension: ',
     ['add_distances_ca']),
    ('SideChain', 'Sidechain torsions dimension:',
     ['add_sidechain_torsions']),
    ('BB+Ca', 'Backbone torsions + Ca distances dimension: ',
     ['add_backbone_torsions', 'add_distances_ca']),
    ('Ca+SC', 'Ca distances + sidechain torsions dimension:',
     ['add_distances_ca', 'add_sidechain_torsions']),
    ('BB+SC', 'Backbone torsions + sidechain torsions dimension:',
     ['add_backbone_torsions', 'add_sidechain_torsions']),
    ('BB+CA+SC', 'Backbone torsions + Ca distances + sidechain torsions dimension:',
     ['add_backbone_torsions', 'add_distances_ca', 'add_sidechain_torsions']),
]


def load_feature_set(featurizer, adders, stride):
    """Reset `featurizer`, add the named features and load train/test data.

    Parameters
    ----------
    featurizer : PyEMMA featurizer, modified in place (its active features are replaced).
    adders : list of featurizer method names, each called with periodic=False.
    stride : frame stride passed to pyemma.coordinates.load.

    Returns
    -------
    (train_data, test_data) as returned by pyemma.coordinates.load for
    `train_files` and `test_file` respectively.
    """
    featurizer.active_features = []
    for adder in adders:
        getattr(featurizer, adder)(periodic=False)
    train = pyemma.coordinates.load(train_files, features=featurizer, stride=stride)
    test = pyemma.coordinates.load(test_file, features=featurizer, stride=stride)
    return train, test


data_list, test_data_list, label_list = [], [], []
for label, description, adders in feature_sets:
    train, test = load_feature_set(feat, adders, stride)
    print(description, feat.dimension())
    data_list.append(train)
    test_data_list.append(test)
    label_list.append(label)

# Individual names kept for any cell that refers to one specific feature set.
(data_backbone, data_ca, data_sidechain, data_backbone_ca,
 data_ca_sidechain, data_backbone_sidechain, data_backbone_ca_sidechain) = data_list
(data_backbone_test, data_ca_test, data_sidechain_test, data_backbone_ca_test,
 data_ca_sidechain_test, data_backbone_sidechain_test, data_backbone_ca_sidechain_test) = test_data_list

In [None]:
def plot_for_lag(ax, lag, data_list, test_data_list, label_list, dim=5):
    """Bar-plot held-out VAMP scores of several feature sets at one lag time.

    A VAMP model is estimated on each training set in `data_list` and scored on
    the matching entry of `test_data_list`. Feature sets with fewer than `dim`
    features are skipped entirely. Models whose estimated dimension is below
    `dim` keep their tick label but get no bar.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    lag : VAMP lag time, in frames of the (possibly strided) input data.
    data_list : training data per feature set; each entry is a list of
        per-trajectory arrays or a single array (pyemma.coordinates.load
        returns a bare array for one trajectory).
    test_data_list : held-out data per feature set, same order as `data_list`.
    label_list : tick label per feature set, same order as `data_list`.
    dim : number of VAMP dimensions to estimate and score.

    Returns
    -------
    list of (label, score) for every bar drawn.
    """
    kept = []  # (vamp model, test data, label) for feature sets with >= dim features
    for data, test_data, label in zip(data_list, test_data_list, label_list):
        first_traj = data if isinstance(data, np.ndarray) else data[0]
        if first_traj.shape[1] < dim:
            continue
        kept.append((pyemma.coordinates.vamp(data, lag=lag, dim=dim), test_data, label))

    scores = []
    for position, (vamp, test_data, label) in enumerate(kept):
        if dim > vamp.dimension():
            continue
        score = vamp.score(test_data=test_data)
        ax.bar(position, score)
        scores.append((label, score))
    ax.set_xticks(range(len(kept)))
    ax.set_xticklabels([label for _, _, label in kept], rotation=60)
    return scores

In [None]:
# Compute held-out VAMP2 scores for every combination of VAMP dimension (rows)
# and lag time (columns). This checks that the feature ranking is robust as a
# function of lag time, and is used to select the features.

dimensions = [5, 10, 40, 70]
lagtimes = [5, 10, 25, 50]  # in frames of the strided data loaded above
frame_dt_ps = 40            # time between saved frames (same 40 ps as the MSM dt_traj)

# squeeze=False keeps `axes` 2-D even for a single dimension or lag time.
fig, axes = plt.subplots(len(dimensions), len(lagtimes),
                         figsize=(5*len(lagtimes), 5*len(dimensions)),
                         sharey=True, squeeze=False)
for i, dim in enumerate(dimensions):
    for j, lag in enumerate(lagtimes):
        plot_for_lag(axes[i, j], lag, data_list, test_data_list, label_list, dim=dim)
for ax, lag in zip(axes[0], lagtimes):
    # Features were loaded with `stride`, so one lag step spans stride * frame_dt_ps.
    ax.set_title('VAMP2 at lag = {}ps'.format(lag*stride*frame_dt_ps), size='large')
for ax, dim in zip(axes[:, 0], dimensions):
    ax.set_ylabel('VAMP dimension = {}'.format(dim), rotation=90, size='large')
fig.tight_layout()
fig.savefig(os.path.join(fig_path, 'VAMPs.png'))

***
### Dimensionality reduction and discretisation

In [None]:
# TICA settings
lag_tica: int = 25        # TICA lag time, in input frames
var_cutoff: float = 0.95  # cumulative kinetic variance the TICA projection must retain
stride_tica: int = 3      # frame stride used when estimating TICA

In [None]:
# Featurise with sidechain torsions + Ca distances (the 'Ca+SC' set from the
# VAMP comparison) and project onto the slowest TICA components, keeping as many
# as needed to retain `var_cutoff` of the kinetic variance.
# The projected trajectories are concatenated for plotting.

feat = pyemma.coordinates.featurizer(top_path)
feat.add_sidechain_torsions(periodic=False)
feat.add_distances_ca(periodic=False)
reader = pyemma.coordinates.source(trajs_path, features=feat)
print('Number of input features', feat.dimension())

tica = pyemma.coordinates.tica(reader, lag=lag_tica, var_cutoff=var_cutoff, stride=stride_tica)
tica_output = tica.get_output()
tica_concatenated = np.concatenate(tica_output)

print('TICA subspace shape', tica_concatenated.shape)

***
### Validation

In [None]:
# Settings for the cluster-count convergence test
n_cluster_test: list = [150, 200, 250, 300, 400, 500]  # k-means cluster counts to compare
max_iter: int = 300      # maximum k-means iterations
stride_cluster: int = 3  # frame stride used when estimating cluster centers

In [None]:
# Implied-timescale (ITS) convergence test for different numbers of k-means
# cluster centers. The left column shows the centers on the density of the first
# two TICA components; the right column shows implied timescales vs. lag time.

its_lags = [1, 2, 5, 10, 20, 50, 100]  # MSM lag times (in frames) for the ITS
its_nits = 5                           # number of implied timescales to compute
frame_dt_ns = 0.04                     # time between frames, in ns (40 ps)

# squeeze=False keeps `axes` 2-D even if only one cluster count is tested.
fig, axes = plt.subplots(len(n_cluster_test), 2,
                         figsize=(12, len(n_cluster_test)*6), squeeze=False)
ic1, ic2 = tica_concatenated[:, :2].T  # identical for every k, so computed once
for i, k in enumerate(n_cluster_test):
    cluster_test = pyemma.coordinates.cluster_kmeans(tica, k=k, max_iter=max_iter, stride=stride_cluster)
    pyemma.plots.plot_density(ic1, ic2, ax=axes[i, 0], cbar=False, alpha=0.1)
    axes[i, 0].scatter(*cluster_test.clustercenters[:, :2].T, s=5, c='C1')
    axes[i, 0].set_xlabel('IC1')
    axes[i, 0].set_ylabel('IC2')
    axes[i, 0].set_title('k = {} centers'.format(k))

    its = pyemma.msm.its(cluster_test.dtrajs, nits=its_nits, lags=its_lags, errors='bayes')
    pyemma.plots.plot_implied_timescales(its, units='ns', dt=frame_dt_ns, ax=axes[i, 1])
    axes[i, 1].set_title('k = {} centers'.format(k))

fig.tight_layout()
fig.savefig(os.path.join(fig_path, 'ITS_compare.png'))

In [None]:
# Settings for the final MSM and its CK test
n_cluster: int = 300     # number of k-means cluster centers
max_iter: int = 300      # maximum k-means iterations
stride_cluster: int = 3  # frame stride used when estimating cluster centers
lag_markov: int = 25     # MSM lag time, in frames of the discrete trajectories
dt_traj: str = '40 ps'   # physical time between frames, passed to the MSM estimator
n_metastate: int = 2     # number of metastable sets used in the CK test

In [None]:
# Chapman-Kolmogorov test of a Bayesian MSM built on the chosen discretisation.
# Uses the configured `n_cluster`, not the `k` left over from the ITS loop above.

cluster = pyemma.coordinates.cluster_kmeans(tica, k=n_cluster, max_iter=max_iter, stride=stride_cluster)
bayesian_msm = pyemma.msm.bayesian_markov_model(cluster.dtrajs, lag=lag_markov, dt_traj=dt_traj, conf=0.95)
fig, axes = pyemma.plots.plot_cktest(bayesian_msm.cktest(n_metastate), figsize=(8, 8))
fig.savefig(os.path.join(fig_path, f'CK_{n_metastate}_state.png'))