Imports
-------

In [None]:
%matplotlib inline


import numpy as np

from msmbuilder.example_datasets import FsPeptide
from msmbuilder.featurizer import DihedralFeaturizer
from msmbuilder.preprocessing import RobustScaler
from msmbuilder.decomposition import tICA
from msmbuilder.cluster import KMeans
from msmbuilder.msm import MarkovStateModel
from msmbuilder.tpt import net_fluxes, paths

import mdtraj as md
import msmexplorer as msme

from mdentropy.metrics import DihedralMutualInformation

# Shared, explicitly seeded random state; later cells pass it as `random_state`
# to the clustering and sampling steps.
rs = np.random.RandomState(seed=42)

Load Trajectories
-----------------

In [None]:
# Fetch the Fs peptide example dataset and keep only its trajectories.
fs_peptide = FsPeptide()
trajectories = fs_peptide.get().trajectories

Extract Dihedrals
-----------------

In [None]:
# Featurizer is constructed with no arguments, i.e. with the library's defaults.
dihedrals = DihedralFeaturizer()
# Featurize every trajectory; `data` feeds the scaling step that follows.
data = dihedrals.transform(trajectories)

Scale Dihedral Data
-------------------

In [None]:
# Learn the scaling parameters from the dihedral features, then apply them to
# the same data.
robust_scaler = RobustScaler()
robust_scaler.fit(data)
scaled_data = robust_scaler.transform(data)

Dimensionality Reduction
------------------------

In [None]:
# tICA hyperparameters, named so they are easy to find and tune.
TICA_N_COMPONENTS = 2
TICA_LAG_TIME = 10  # presumably in frames -- TODO confirm against the dataset's stride

tica = tICA(n_components=TICA_N_COMPONENTS, lag_time=TICA_LAG_TIME)
tica_data = tica.fit_transform(scaled_data)

Clustering
----------

In [None]:
# Number of clusters requested from KMeans in the reduced tICA space.
N_CLUSTERS = 12

# Seeded with the shared random state `rs`.
kmeans = KMeans(n_clusters=N_CLUSTERS, random_state=rs)
assignments = kmeans.fit_transform(tica_data)

Build Markov Model
------------------

In [None]:
# Lag time handed to the Markov state model (presumably in frames -- TODO confirm).
MSM_LAG_TIME = 1

msm = MarkovStateModel(lag_time=MSM_LAG_TIME)
# Keep the transformed assignments: the pathway cell passes them to
# `msm.draw_samples`.
msm_assignments = msm.fit_transform(assignments)

Infer Top Folding Pathway
-------------------------

In [None]:
# Source = least-populated MSM state, sink = most-populated MSM state.
sources, sinks = [msm.populations_.argmin()], [msm.populations_.argmax()]
net_flux = net_fluxes(sources, sinks, msm)

# Bind the result to `top_paths`, NOT `paths`: rebinding `paths` would replace
# the imported `paths` function with its own return value, and re-running this
# cell would then fail because that value is not callable.
# `num_paths=1` asks for a single path explicitly (this was previously 0).
top_paths, _ = paths(sources, sinks, net_flux, num_paths=1)

# Draw (trajectory index, frame index) pairs for every MSM state.
samples = msm.draw_samples(msm_assignments, n_samples=1000, random_state=rs)

# Gather the coordinates of the sampled frames, visiting states in the order
# they appear along the first returned path.
xyz = [
    trajectories[traj_id][frame].xyz
    for state in top_paths[0]
    for traj_id, frame in samples[state]
]
pathway = md.Trajectory(np.concatenate(xyz, axis=0), trajectories[0].topology)

Calculate Mutual Information
----------------------------

In [None]:
# Dihedral mutual information estimated over the sampled pathway.
dmutinf = DihedralMutualInformation(n_bins=3, method='knn', normed=True)
M = dmutinf.partial_transform(pathway)

# Subtract the diagonal from the matrix in place so only the off-diagonal
# (pairwise) terms remain.
diagonal_only = np.eye(*M.shape) * M.diagonal()
M -= diagonal_only

# Label each residue by its index, skipping the ACE / NME residues
# (presumably the terminal capping groups -- TODO confirm).
excluded_names = ['ACE', 'NME']
topology = trajectories[0].topology
labels = [
    str(residue.index)
    for residue in topology.residues
    if not (residue.name in excluded_names)
]

# The 75th percentile of all matrix entries is passed as the plot threshold.
threshold = np.percentile(M, 75)
ax = msme.plot_chord(M, threshold=threshold, labels=labels, labelsize=12)