In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
from matplotlib import colors
import matplotlib as mpl
import numpy as np
import pyemma
import glob
import pandas as pd
import seaborn as sns
import os

import nglview
import mdtraj 
from mdtraj import shrake_rupley, compute_rg
from threading import Timer
from nglview.player import TrajectoryPlayer

import MDAnalysis as mda
from MDAnalysis.analysis import diffusionmap,align, rms

In [None]:
# Global figure defaults: tight bounding box on save, seaborn "white" theme,
# sans-serif text rendered with Arial when it is installed.
mpl.rcParams['savefig.bbox'] = 'tight'
sns.set_style("white")
# 'font.family' is sans-serif, so the preferred face must be listed under
# 'font.sans-serif'; assigning Arial to 'font.serif' only changed the (unused)
# serif list. DejaVu Sans (bundled with matplotlib) is the fallback.
sns.set_style({'font.family': 'sans-serif', 'font.sans-serif': ['Arial', 'DejaVu Sans']})

In [None]:
fig_path = r'./figures/analysis/'
# makedirs also creates the missing parent ./figures/ (os.mkdir would raise
# FileNotFoundError there) and is a no-op when the directory already exists.
os.makedirs(fig_path, exist_ok=True)

In [None]:
top_path = './data/peptide.gro'
# glob returns files in arbitrary (filesystem-dependent) order; sort so the
# trajectory subset chosen later and the concatenation order are reproducible.
trajs_path = sorted(glob.glob('./data/md_1us_*_noPBC.xtc'))

# Test structures
af_path = r'./data/alphafold2.xtc'
npt_path = r'./data/npt.xtc'

# Saved pyemma models are named after the number of 1-us trajectories.
# The target directory must exist before the models are saved into it.
os.makedirs('models', exist_ok=True)
model_path = os.path.join('models', f'{len(trajs_path)}us.pyemma')

***
### Prepare MSM using hyperparameters determined by msm_estimation.ipynb

In [None]:
# Hyper-parameters selected in msm_estimation.ipynb, grouped by pipeline stage.

# TICA: lag (frames), retained cumulative kinetic variance, fitting stride
lag_tica, var_cutoff, stride_tica = 25, 0.95, 3

# k-means: number of microstates, iteration cap, fitting stride
n_cluster, max_iter, stride_cluster = 300, 300, 5

# MSM: lag (frames), physical time per saved frame, intended number of metastable sets
lag_markov, dt_traj, n_metastate = 25, '40 ps', 2

In [None]:
# Restrict the analysis to the first n_trajs trajectories (in trajs_path order).
n_trajs = 4
trajs_path = trajs_path[:n_trajs]
# model_path was derived from the untruncated list; rename it after the number of
# 1-us trajectories actually used, so the saved models are labelled correctly.
model_path = os.path.join('models', f'{len(trajs_path)}us.pyemma')

In [None]:
# Featurise every frame: side-chain torsions plus C-alpha pair distances (no PBC).
feat = pyemma.coordinates.featurizer(top_path)
feat.active_features = []
feat.add_sidechain_torsions(periodic=False)
feat.add_distances_ca(periodic=False)
reader = pyemma.coordinates.source(trajs_path, features=feat)

# TICA projection; the number of kept ICs is chosen via the var_cutoff kinetic-variance threshold.
tica = pyemma.coordinates.tica(reader, lag=lag_tica, var_cutoff=var_cutoff, stride=stride_tica)
tica.save(model_path, model_name='tica', overwrite=True)
tica_concatenated = np.concatenate(tica.get_output())

# k-means microstates in TICA space. fixed_seed makes the partition (and every
# downstream MSM quantity) reproducible on Restart & Run All.
cluster = pyemma.coordinates.cluster_kmeans(
    tica, k=n_cluster, max_iter=max_iter, stride=stride_cluster, fixed_seed=True)
cluster.save(model_path, model_name='kmeans', overwrite=True)
cluster_dtrajs = cluster.dtrajs
dtrajs_concatenated = np.concatenate(cluster_dtrajs)

# Maximum-likelihood MSM and Bayesian MSM (95 % confidence) on the same discretisation.
msm = pyemma.msm.estimate_markov_model(cluster_dtrajs, lag=lag_markov, dt_traj=dt_traj)
msm.save(model_path, model_name='msm', overwrite=True)

bayesian_msm = pyemma.msm.bayesian_markov_model(cluster_dtrajs, lag=lag_markov, dt_traj=dt_traj, conf=0.95)
bayesian_msm.save(model_path, model_name='bayesian_msm', overwrite=True)

In [None]:
# Check state disconnectivity --- only the largest set of reversibly connected states were used in MSM estimation
all_microstates = set(range(cluster.clustercenters.shape[0]))
excluded_states = all_microstates - set(msm.active_set)

print(f'Fraction of states used = {msm.active_state_fraction:f}')
print(f'Fraction of counts used = {msm.active_count_fraction:f}')
print('Inactive states excluded from estimation: ', excluded_states)

# Shapes of the key intermediate arrays, as a quick consistency check
shape_report = [
    ('TICA space dimension', tica_concatenated.shape),
    ('Microstate assignment', dtrajs_concatenated.shape),
    ('Cluster center dimension', cluster.clustercenters.shape),
    ('Right eigenvectors dimension', msm.eigenvectors_right().shape),
]
for label, shape in shape_report:
    print(f'{label}: {shape}')

In [None]:
# Compute the stationary distribution of metastable states

# Use the metastable-state count configured with the other hyper-parameters
# (n_metastate) instead of repeating a literal here.
nstates = n_metastate
msm.pcca(nstates)

print('state\tπ\t\tG/kT\t\tNo.')
for i, s in enumerate(msm.metastable_sets):
    # Total equilibrium probability of the set and its free energy G/kT = -ln(p)
    p = msm.pi[s].sum()
    print('{}\t{:f}\t{:f}\t{}'.format(i + 1, p, -np.log(p), s.shape[0]))

In [None]:
# Compute mean first passage time -- an average timescale for a transition event to first occur
# Entry [src, dst] is the MFPT from metastable set `src` to metastable set `dst`.
mfpt = np.zeros((nstates, nstates))
for src, dst in np.ndindex(mfpt.shape):
    mfpt[src, dst] = msm.mfpt(msm.metastable_sets[src], msm.metastable_sets[dst])

print('MFPT / steps:')
# Label rows/columns 1..nstates to match the state numbering printed above
state_labels = range(1, nstates + 1)
pd.DataFrame(np.round(mfpt, decimals=1), index=state_labels, columns=state_labels)

***
### Time evolution of RMSD with respect to the NPT-equilibrated structure

In [None]:
# Reference structure: the NPT-equilibrated frame(s); mdtraj uses frame 0 by default.
npt = mdtraj.load(npt_path, top=top_path)
# All selected production trajectories, concatenated in trajs_path order
trajs = mdtraj.load(trajs_path, top=top_path)
trajs.superpose(npt)
# mdtraj reports RMSD in nm; scale by 10 to get Angstrom
rmsds = 10 * mdtraj.rmsd(trajs, npt)

In [None]:
with sns.plotting_context('paper', font_scale=1.5):
    xlim = (0, len(rmsds))
    ylim = (0, 14)  # Angstrom range shared by the trace and its marginal histogram

    fig = plt.figure(figsize=(10, 4))
    gs = fig.add_gridspec(1, 2, width_ratios=(8, 1),
                          left=0.1, right=0.9, bottom=0.1, top=0.9,
                          wspace=0.05, hspace=0.05)

    ax = fig.add_subplot(gs[0, 0])
    # Right-hand marginal: horizontal histogram of RMSD values (shares the y axis)
    ax_histy = fig.add_subplot(gs[0, 1], sharey=ax)

    ax.plot(rmsds, lw=0.4, label='RMSD')

    binwidth = 0.5
    bins = np.arange(*ylim, binwidth)

    ax_histy.hist(rmsds, orientation="horizontal", bins=bins)
    ax_histy.tick_params(axis="y", labelleft=False)

    # Dashed separators at the first frame of each concatenated trajectory, taken
    # from the actual trajectory lengths (same files as `trajs`) instead of
    # assuming a fixed 25001 frames per trajectory.
    traj_starts = np.concatenate(([0], np.cumsum(reader.trajectory_lengths())[:-1]))
    for start in traj_starts:
        ax.axvline(start, ls='--', lw=1.5, c='gray', alpha=0.8)

    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax_histy.set_xlim(*xlim)
    ax_histy.set_ylim(*ylim)

    # Histogram counts are displayed as a fraction of all frames
    ax_histy.set_xticks([0, len(rmsds) / 2, len(rmsds)])
    ax_histy.set_xticklabels([0, 0.5, 1])

    ax.legend()
    ax.set_xlabel('Frame')
    ax.set_ylabel('RMSD (Å)')
    fig.savefig(os.path.join(fig_path, 'time_RMSD.png'))

***
### Stationary distribution, free energy, and metastable assignment

In [None]:
# Featurise the AlphaFold2 model with the same features as the MD data.
feat_t = pyemma.coordinates.featurizer(top_path)
feat_t.active_features = []
feat_t.add_sidechain_torsions(periodic=False)
feat_t.add_distances_ca(periodic=False)

af = pyemma.coordinates.load(af_path, features=feat_t)
test_dic = {'alphafold': [af]}

# cluster.assign returns full microstate (cluster-centre) indices, whereas
# msm.metastable_assignments is indexed by MSM *active* state. Build the
# full -> active lookup; microstates outside the active set map to -1.
full2active = np.full(cluster.clustercenters.shape[0], -1)
full2active[msm.active_set] = np.arange(msm.nstates)

#0 original features #1 tica features #2 microstate #3 metastable states (-1: microstate not in the MSM active set)
for name, obj_list in test_dic.items():
    ff = tica.transform(obj_list[0])
    microstate = cluster.assign(ff)
    active_state = full2active[microstate]
    # np.where masks the (meaningless) lookup done for inactive entries (-1)
    metastate = np.where(active_state >= 0, msm.metastable_assignments[active_state], -1)
    test_dic[name].extend([ff, microstate, metastate])

In [None]:
# !Use uniform scale!
# Panels share IC 1 / IC 2 axes: (0) stationary distribution, (1) reweighted
# free energy, (2) PCCA++ metastable assignment plus the test structures.

# msm.pi and msm.metastable_assignments are indexed by MSM *active* state, but
# dtrajs_concatenated holds full cluster indices. Translate them, and drop the
# frames whose microstate was excluded from the active set (mapped to -1).
full2active = np.full(cluster.clustercenters.shape[0], -1)
full2active[msm.active_set] = np.arange(msm.nstates)
active_traj = full2active[dtrajs_concatenated]
in_active = active_traj >= 0
ic_active = tica_concatenated[in_active, :2].T

metastable_traj = msm.metastable_assignments[active_traj[in_active]]

fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharex=True, sharey=True)

pyemma.plots.plot_contour(
    *ic_active,
    msm.pi[active_traj[in_active]],
    ax=axes[0],
    mask=True,
    cbar_label='stationary distribution')
axes[0].scatter(*cluster.clustercenters[:, :2].T, s=2, c='pink')

# One MSM weight per frame, matching tica_concatenated frame-for-frame
pyemma.plots.plot_free_energy(
    *tica_concatenated[:, :2].T,
    weights=np.concatenate(msm.trajectory_weights()),
    ax=axes[1],
    legacy=False)
axes[1].scatter(*cluster.clustercenters[:, :2].T, s=2, c='k')

_, _, misc = pyemma.plots.plot_state_map(*ic_active, metastable_traj, ax=axes[2], cmap='cool')
axes[2].scatter(*cluster.clustercenters[:, :2].T, s=2, c='k')
misc['cbar'].set_ticklabels([r'$\mathcal{S}_%d$' % (i + 1) for i in range(nstates)])
for name, conf_list in test_dic.items():
    # conf_list[1] holds the TICA coordinates of the test structure(s)
    x, y = conf_list[1][:, 0], conf_list[1][:, 1]
    # One line per frame; format(name, *x, *y) mis-paired coordinates for >1 frame
    for xi, yi in zip(x, y):
        print('Position of {}: {}, {}'.format(name, xi, yi))
    axes[2].scatter(x, y, s=150, c='purple', marker='v')

fig.supxlabel('IC 1')
fig.supylabel('IC 2')
axes[0].set_title('Stationary distribution')
axes[1].set_title('Weighted free energy surface')
axes[2].set_title('State assignment')
fig.tight_layout()

fig.savefig(os.path.join(fig_path, 'state_assignment.png'))

In [None]:
# Core-set distributions: keep only microstates whose PCCA++ membership in a
# metastable set is >= the cutoff, then renormalise each set's row.
core_membership_cutoff = 0.9
dists_active = msm.metastable_distributions.copy()
dists_active[msm.metastable_memberships.T < core_membership_cutoff] = 0
row_sums = dists_active.sum(axis=1, keepdims=True)
# A set with no core microstate would give 0/0 = NaN; leave its row all-zero instead.
dists_active = np.divide(dists_active, row_sums,
                         out=np.zeros_like(dists_active), where=row_sums > 0)

# PCCA++ and eigenvector quantities live on the MSM active set. Scatter them back
# onto full cluster indexing so they line up with cluster.clustercenters
# (identical to the active-set values when every microstate is active).
dists = np.zeros((dists_active.shape[0], cluster.clustercenters.shape[0]))
dists[:, msm.active_set] = dists_active

eigvec = msm.eigenvectors_right()
# Cluster-centre indices of the microstates at the extremes of eigenvector 2
max_eig = msm.active_set[eigvec[:, 1].argmax()]
min_eig = msm.active_set[eigvec[:, 1].argmin()]

In [None]:
with sns.plotting_context('paper', font_scale=1.5):
    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharex=True, sharey=True)
    # Loop-invariant inputs, computed once for both panels
    fe_weights = np.concatenate(msm.trajectory_weights())
    ic1, ic2 = tica_concatenated[:, :2].T
    centres = cluster.clustercenters[:, :2]
    # Star per panel: state 1 -> min_eig, state 2 -> max_eig (extremes of eigenvector 2).
    # NOTE(review): assumes PCCA++ numbers the sets in this order -- confirm.
    star_states = [min_eig, max_eig]
    for i_set, ax in enumerate(axes.flat):
        _, _, misc = pyemma.plots.plot_free_energy(
            ic1, ic2, weights=fe_weights, ax=ax, legacy=False, cmap='viridis', alpha=0.5)
        # Grey shading: weight of each cluster centre in metastable set i_set
        ax.scatter(*centres.T, c=dists[i_set], cmap='Greys', s=15,
                   vmin=dists[i_set].min(), vmax=dists[i_set].max(), alpha=0.8)
        ax.scatter(*centres[star_states[i_set]], color='r', marker='*', s=200)
        ax.set_title(f'State {i_set+1}')
        misc['cbar'].ax.tick_params(labelsize=14)
        # Boltzmann energy unit is kT (lower-case k), as in the G/kT table above
        misc['cbar'].set_label(label='free energy / kT')
        ax.tick_params(bottom=True, top=False, left=True, right=False)
        ax.set_aspect('equal')
    fig.supxlabel('IC 1')
    fig.supylabel('IC 2')
    fig.tight_layout()

    fig.savefig(os.path.join(fig_path, 'SI.png'))

In [None]:
with sns.plotting_context('paper', font_scale=2):
    fig, axes = plt.subplots(2, sharex=True, sharey=True, figsize=(8, 12))
    # Right eigenvectors (rows = MSM active states). Panel k shows column k+1,
    # titled "eigenvector k+2"; column 0 is skipped.
    right_evecs = msm.eigenvectors_right()
    # Centres of the active microstates, so positions line up with eigenvector entries
    active_centres = cluster.clustercenters[msm.active_set]
    fe_weights = np.concatenate(msm.trajectory_weights())

    for i_ev, ax in enumerate(axes):
        evec = right_evecs[:, i_ev + 1]
        # Diverging colour scale centred on 0 so the sign structure is visible
        divnorm = colors.TwoSlopeNorm(vmin=evec.min(), vcenter=0., vmax=evec.max())

        if tica_concatenated.shape[1] == 1:
            # 1-D TICA space: histogram background, centres placed on y = 0
            pyemma.plots.plot_feature_histograms(tica_concatenated, ax=ax)
            x = active_centres[:, 0]
            y = np.repeat(0, active_centres.shape[0])
        else:
            pyemma.plots.plot_free_energy(tica_concatenated[:, 0], tica_concatenated[:, 1], ax=ax,
                                          cmap='viridis', alpha=0.5, weights=fe_weights, legacy=False)
            x = active_centres[:, 0]
            y = active_centres[:, 1]

        ax.tick_params(bottom=True, top=False, left=True, right=False)
        ax.set_title(f'eigenvector {i_ev+2}')
        ax.scatter(x=x, y=y, c=evec, cmap='bwr', norm=divnorm, s=50)
    fig.supxlabel('IC1')
    fig.supylabel('IC2')
    fig.tight_layout()

    fig.savefig(os.path.join(fig_path, 'eigenvector.png'))

***
### Sample metastable-state centres

In [None]:
n_sample = 20    # structures drawn per selected microstate
n_process = 2    # slow processes inspected (right eigenvectors 2 .. n_process + 1)

sample_path = r'./sample/'
if not os.path.exists(sample_path):
    os.mkdir(sample_path)

In [None]:
# For each slow process (right eigenvectors 2 .. n_process + 1) select three
# microstates: where the eigenvector is largest, smallest, and at its median.
# Indices are rows of msm.eigenvectors_right(), presumably MSM active-set
# indices as expected by msm.sample_by_state(subset=...) -- confirm.
smpl_dic = {}
eigvec = msm.eigenvectors_right()

for i in range(n_process):
    evec = eigvec[:, i + 1]
    # The median previously read column i instead of i + 1, i.e. a different
    # eigenvector than the max/min. Local names also avoid clobbering the global
    # max_eig/min_eig used by the SI figure.
    idx_max = evec.argmax()
    idx_min = evec.argmin()
    idx_med = np.argsort(evec)[len(evec) // 2]

    smpl_dic['eig{}_max'.format(i+2)] = idx_max
    smpl_dic['eig{}_min'.format(i+2)] = idx_min
    smpl_dic['eig{}_med'.format(i+2)] = idx_med

In [None]:
def _offdiag_stats(dist_matrix):
    """Return (mean, std) of the strictly upper-triangular entries of a square matrix.

    Used for pairwise-RMSD matrices: the zero diagonal (each structure vs itself)
    is excluded so it does not bias the average. Assumes the matrix is symmetric.
    """
    pairs = dist_matrix[np.triu_indices(dist_matrix.shape[0], k=1)]
    return pairs.mean(), pairs.std()


def _sample_file(name, aligned=False):
    """Path of the multi-frame PDB holding the n_sample structures drawn for `name`."""
    suffix = '_aligned' if aligned else ''
    return os.path.join(sample_path, '{}_{}structures{}.pdb'.format(name, n_sample, suffix))


for name, state in smpl_dic.items():
    path = _sample_file(name)
    path_aligned = _sample_file(name, aligned=True)

    # Generate samples for each state and write them as one trajectory file
    samples = msm.sample_by_state(n_sample, subset=[state])
    pyemma.coordinates.save_trajs(reader, samples, outfiles=[path])

    # Compare C-alpha RMSD within the state before and after aligning the samples
    uni = mda.Universe(top_path, path, dt=40)
    matrix = diffusionmap.DistanceMatrix(uni, select='name CA').run()
    align.AlignTraj(uni, uni, select='name CA', filename=path_aligned).run()

    uni_aligned = mda.Universe(top_path, path_aligned, dt=40)
    matrix_aligned = diffusionmap.DistanceMatrix(uni_aligned, select='name CA').run()
    mean_raw, std_raw = _offdiag_stats(matrix.dist_matrix)
    mean_aln, std_aln = _offdiag_stats(matrix_aligned.dist_matrix)
    print('Averaged pairwise RMSD of {} structures before [{}({})] and after [{}({})] alignment'.format(name, mean_raw, std_raw, mean_aln, std_aln))

# Compute RMSD between extreme states: every (max-sample, min-sample) pair per eigenvector
for i in range(n_process):
    max_traj = mda.Universe(top_path, _sample_file(f'eig{i+2}_max', aligned=True), dt=40)
    min_traj = mda.Universe(top_path, _sample_file(f'eig{i+2}_min', aligned=True), dt=40)
    prmsd = np.zeros((len(max_traj.trajectory), len(min_traj.trajectory)))
    for j in range(len(max_traj.trajectory)):
        # Row j: RMSD of every min-sample frame against max-sample frame j
        # (previously written to row i, leaving the other rows zero).
        r = rms.RMSD(min_traj, max_traj, select='name CA', ref_frame=j).run()
        prmsd[j] = r.rmsd[:, -1]
    print('Averaged RMSD between eig{} max and min structures: {}({})'.format(i+2, prmsd.mean(), prmsd.std()))

***
### Compute radius of gyration (Rg) and solvent-accessible surface area (SASA)

In [None]:
# Metastable decomposition of the Bayesian MSM, using the same number of sets
# as the maximum-likelihood analysis (n_metastate) instead of a literal 2.
bayesian_msm.pcca(n_metastate)

n_sample_metastable = 100   # structures per metastable membership distribution
n_sample_microstate = 20    # structures per (active) microstate

# Generate samples for each metastable membership distribution
sample_by_metastable_distributions = bayesian_msm.sample_by_distributions(
    bayesian_msm.metastable_distributions, n_sample_metastable)
print('Sample for {} metastable states'.format(len(sample_by_metastable_distributions)))

# Generate samples for each microstate
sample_by_microstate = list(bayesian_msm.sample_by_state(n_sample_microstate))
print('Sample for {} states'.format(len(sample_by_microstate)))

# Materialise the sampled frames as in-memory trajectories (outfile=None)
metastable_samples = [pyemma.coordinates.save_traj(reader, smpl, outfile=None, top=top_path) for smpl in sample_by_metastable_distributions]
microstate_samples = [pyemma.coordinates.save_traj(reader, smpl, outfile=None, top=top_path) for smpl in sample_by_microstate]

# Rg per frame, and total SASA per frame (per-residue SASA summed over residues).
# NOTE(review): presumably nm and nm^2 (mdtraj units) -- confirm before reporting.
metastable_rg_all = [compute_rg(sample) for sample in metastable_samples]
metastable_sasa_all = [np.sum(shrake_rupley(sample, mode='residue'), axis=1) for sample in metastable_samples]
# Per-microstate averages, used as observable vectors over the MSM states
markov_average_rg = [compute_rg(sample).mean() for sample in microstate_samples]
markov_average_sasa = [shrake_rupley(sample, mode='residue').sum(axis=1).mean() for sample in microstate_samples]

# Weigh both observables over msm: equilibrium expectation and its std across Bayesian samples
equilibrium_rg = bayesian_msm.expectation(markov_average_rg)
equilibrium_sasa = bayesian_msm.expectation(markov_average_sasa)
equilibrium_std_rg = bayesian_msm.sample_std('expectation', markov_average_rg)
equilibrium_std_sasa = bayesian_msm.sample_std('expectation', markov_average_sasa)

# Compute mean and std gyration and sasa for each set of representative structures
metastable_average_rg, metastable_std_rg = np.array(metastable_rg_all).mean(axis=1), np.array(metastable_rg_all).std(axis=1)
metastable_average_sasa, metastable_std_sasa = np.array(metastable_sasa_all).mean(axis=1), np.array(metastable_sasa_all).std(axis=1)

# One row per metastable state (1..n) plus the MSM-weighted equilibrium row
state_labels = list(range(1, len(metastable_average_rg) + 1)) + ['equilibrium']
dataframe_dict = {'metastable_state': state_labels,
                  'mean_rg': np.append(metastable_average_rg, equilibrium_rg),
                  'std_rg': np.append(metastable_std_rg, equilibrium_std_rg),
                  'mean_sasa': np.append(metastable_average_sasa, equilibrium_sasa),
                  'std_sasa': np.append(metastable_std_sasa, equilibrium_std_sasa)}

results = pd.DataFrame(dataframe_dict).set_index('metastable_state')
# Keep the index when writing: after set_index it is the only place the state labels live
results.to_csv(os.path.join(fig_path, 'observables.csv'))
results