## Generate the WLALL figure from the manuscript

Simon Olsson 2018

Featurizes and analyses 24 of the 25 trajectories of the WLALL peptide in the original data set. Trajectory 15 is left out because a rare event occurs within its first few tens of nanoseconds that is not reversibly sampled.

Generates the manuscript figure for the WLALL peptide.

Please note: since this notebook uses random sampling for error estimation, results cannot be expected to reproduce exactly. The notebook requires internet access, as the primary data is downloaded. Complete execution of the notebook may take anywhere from tens of minutes to hours, depending on available hardware, internet connection speed and server load.

In [None]:
%matplotlib inline
import mdtraj as md
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import mdshare
import matplotlib as mpl
import numpy as np
import pyemma as pe
import msmtools
from sklearn.preprocessing import LabelBinarizer

from graphtime import markov_random_fields
from graphtime import utils as _ut

# Single- and double-column figure widths in inches (matplotlib figsize units).
single_column_width = 3.307
double_column_width = 6.968

# Global font configuration applied to every figure drawn in this notebook.
font = {
    'sans-serif': "Arial",
    'family': "sans-serif",
    'size': 8,
}
mpl.rc('font', **font)
mpl.rc('mathtext', fontset='custom')

Load and prepare data for model estimation

In [None]:
def to_red_tmat(Tfull, dtrajs_):
    """
    Restrict a transition matrix to the states observed in ``dtrajs_``.

    Slices out the sub-matrix of ``Tfull`` whose rows and columns are the
    distinct state indices occurring in ``dtrajs_``, then renormalizes every
    row so the result is again a row-stochastic transition matrix on that
    subset.

    Parameters
    ----------
    Tfull : ndarray (n, n)
        Full transition matrix indexed by integer state labels.
    dtrajs_ : array_like or sequence of array_like
        Discrete trajectories, or any collection of integer state labels.
        The individual trajectories may have different lengths.

    Returns
    -------
    ndarray (m, m)
        Row-renormalized sub-matrix (a new array; ``Tfull`` is not modified).
        Row/column ``k`` corresponds to the ``k``-th smallest observed state,
        i.e. states appear in ascending order, matching the sorted output of
        ``np.intersect1d`` / ``np.unique``. A row of the sub-matrix that sums
        to zero yields NaNs.
    """
    # np.unique returns the observed states sorted. Iterating a Python set
    # (as done previously) follows hash-table slot order, which is not always
    # ascending (e.g. {8, 1} iterates as 8, 1), so the reduced matrix could be
    # silently permuted relative to the sorted state lists it is compared with.
    # Concatenating per-trajectory ravels also supports ragged trajectory lists.
    states = np.unique(np.concatenate([np.ravel(d) for d in dtrajs_]))
    sub = Tfull[np.ix_(states, states)]  # fancy indexing returns a copy
    return sub / sub.sum(axis=1)[:, None]

def discr_feats(ftrajs, feat_describe):
    """
    Discretize torsion-angle trajectories into small integer states per feature.

    Each feature column is binned according to the prefix of its label:

    * ``PHI``: 1 if the angle is < 0, else 0 (two states).
    * ``PSI``: 1 if the angle is < 80, else 0 (two states). The psi feature
      whose label ends in residue number ``1`` (N-terminal) is dropped.
    * ``CHI``: three rotamer bins (0, 1, 2), obtained by shifting the angle by
      +240 modulo 360 and cutting at 125 and 250.
    * any other label: the column is kept and left as 0.

    The thresholds are in degrees, so the angles are expected in degrees.

    Parameters
    ----------
    ftrajs : sequence of ndarray (n_frames, n_features)
        Feature trajectories, one column per entry of ``feat_describe``.
    feat_describe : iterable of str
        Feature labels, e.g. as returned by a PyEMMA featurizer's
        ``describe()``; the last whitespace-separated token of a ``PSI``
        label must be the residue number.

    Returns
    -------
    discr_trajs : list of ndarray (n_frames, n_kept) of int
        One discretized trajectory per input trajectory.
    labels : list of str
        Labels of the kept columns, in column order.
    """
    labels = list(feat_describe)

    # Which columns survive depends only on the labels, so decide it once.
    # (Deriving it from the first frame of each trajectory, as before, raised
    # for an empty ``ftrajs`` or a zero-length trajectory.)
    keep = [
        i for i, fstr in enumerate(labels)
        if not (fstr.startswith("PSI") and int(fstr.split()[-1]) == 1)
    ]

    discr_trajs = []
    for ft in ftrajs:
        dftraj = np.zeros((ft.shape[0], len(keep)), dtype=int)
        for out_col, i in enumerate(keep):
            fstr = labels[i]
            angle = ft[:, i]
            if fstr.startswith("PHI"):  # split into two states
                dftraj[:, out_col] = (angle < 0).astype(int)
            elif fstr.startswith("PSI"):  # two states (N-terminal psi already dropped)
                dftraj[:, out_col] = (angle < 80).astype(int)
            elif fstr.startswith("CHI"):  # split into 3 rotamers
                tv = (angle + 180 + 60) % 360
                dftraj[:, out_col] = (tv > 125).astype(int) + (tv > 250).astype(int)
        discr_trajs.append(dftraj)
    return discr_trajs, [labels[i] for i in keep]


def featurize(X):
    """
    Encode every frame of binary feature trajectories as a single integer.

    Each frame is read as a big-endian binary number: its entries are
    converted to strings, concatenated and parsed in base 2, so the first
    column is the most significant bit. Entries are expected to be 0/1.

    Parameters
    ----------
    X : iterable of iterables of frames
        One trajectory per element; each frame is a sequence of 0/1 values.

    Returns
    -------
    list of list of int
        One list of integer state labels per input trajectory.
    """
    def _frame_to_state(frame):
        bit_string = ''.join(str(bit) for bit in frame)
        return int(bit_string, 2)

    return [[_frame_to_state(frame) for frame in traj] for traj in X]


In [None]:
# Fetch the topology and all trajectory files into the local
# 'pentapeptide_data' directory.
pdb = mdshare.fetch(
    'pentapeptide-impl-solv.pdb',
    working_directory='pentapeptide_data',
)
files = mdshare.fetch(
    'pentapeptide-*-500ns-impl-solv.xtc',
    working_directory='pentapeptide_data',
)

In [None]:
# Backbone torsions (in degrees) for every frame of trajectories 00-24.
# `dihe` holds one 2-D (frames x torsions) array per trajectory.
feat = pe.coordinates.featurizer('pentapeptide_data/pentapeptide-impl-solv.pdb')
feat.add_backbone_torsions(deg=True)

source = pe.coordinates.source(
    [f'pentapeptide_data/pentapeptide-{i:02}-500ns-impl-solv.xtc' for i in range(25)],
    features=feat,
)
dihe = source.get_output()

In [None]:
# Sign pattern of every torsion: 1 where the angle is negative, else 0.
# NOTE(review): not referenced elsewhere in the visible notebook.
bindihe = [np.less(d, 0).astype(int) for d in dihe]

In [None]:
# Discretize torsions (two states per phi/psi, N-terminal psi removed);
# `nlbls` holds the labels of the retained columns.
dfeats, nlbls = discr_feats(ftrajs=dihe, feat_describe=feat.describe())

Estimate Markov state models and dynamic graphical models for multiple lag times.

In [None]:
from functools import reduce
# Per-lag results, one entry per lag in the loop below:
#   MRF_MSM[k]: reduced DGM transition matrices, one per cross-validation training set.
#   MSM_MSM[k]: [list of MSMs, list of index arrays locating the common states in each MSM's active_set].
MRF_MSM = []
MSM_MSM = []
X = dfeats.copy()  # shallow copy of the list of discretized trajectories

# Trajectory indices without trajectory 15 (see introduction), randomly permuted
# and split into 5 folds; _I[i] is the training set that leaves out fold i.
# The shuffle is unseeded, so the fold assignment differs between runs.
idx=np.array([i for i in np.arange(25) if i not in [15]])
np.random.shuffle(idx)
_x=np.array_split(idx, 5)
_I=[np.concatenate([_x[j] for j in range(5) if j!=i]) for i in range(5)]
# One integer state per frame: the binary feature vector read as a base-2 number.
dtrajs_ = featurize(X)

# L1-penalized logistic-regression settings forwarded to estimate_dMRF.
lr_kwargs = {'fit_intercept': True, 'penalty': 'l1', 'C': 1.0, 'tol': 0.0001, 'solver': 'saga'} 
for lag in [2,5,10,15,20,30,50]:
    # One dynamic graphical model per training set; the 0/1 features are mapped to -1/+1.
    _dmrfs = [markov_random_fields.estimate_dMRF([2*X[i]-1 for i in I], 
                                                     lag = lag, 
                                                     Encoder = LabelBinarizer(neg_label = -1, pos_label = 1),
                                                     logistic_regression_kwargs = lr_kwargs
                                                    ) 
                  for I in _I]
    # NOTE(review): indexing these matrices by the integer codes from featurize assumes
    # graphtime orders joint states the same way (first feature = most significant bit) - confirm.
    Ts = [_D.generate_transition_matrix() for _D in _dmrfs ]
    # Reference MSMs on the integer-encoded trajectories of the same training sets.
    bmsm = [pe.msm.estimate_markov_model(dtrajs=[dtrajs_[i] for i in I], lag = lag) for I in _I]
    # States present in the active set of every MSM (sorted, as returned by np.intersect1d).
    ms_interesection=reduce(np.intersect1d, [m.active_set for m in bmsm ])
    # For each MSM: positions in its active_set of the common states, in ascending state order.
    bm_as=[np.where(m.active_set==ms_interesection.reshape(-1,1))[1] for m in bmsm ]
    
    # Restrict each DGM matrix to the common states and renormalize. Later cells compare these
    # element-wise with MSM quantities indexed via bm_as, so to_red_tmat must return the
    # states in ascending order as well.
    Ts_red = [to_red_tmat(T, [ms_interesection]) for T in Ts]
    MRF_MSM.append(Ts_red)
    MSM_MSM.append([bmsm, bm_as])

In [None]:
# Re-estimate DGMs and MSMs on the same cross-validation training sets at lag 20,
# i.e. repeat the lag-20 iteration of the previous cell.
# NOTE(review): `_dmrfs`, `Ts` and `ms_interesection_selected_lag` are not referenced in
# the visible remainder of the notebook, and `bmsm` is reassigned before its next use,
# so this (expensive) cell appears not to feed the figure.
lag=20
_dmrfs = [markov_random_fields.estimate_dMRF([2*X[i]-1 for i in I], 
                                                 lag = lag, 
                                                 Encoder = LabelBinarizer(neg_label = -1, pos_label = 1),
                                                 logistic_regression_kwargs = lr_kwargs
                                                ) 
              for I in _I]
Ts = [_D.generate_transition_matrix() for _D in _dmrfs ]
bmsm = [pe.msm.estimate_markov_model(dtrajs=[dtrajs_[i] for i in I], lag = lag) for I in _I]
ms_interesection_selected_lag=reduce(np.intersect1d, [m.active_set for m in bmsm ])


Compute statistics and prepare input for plotting

In [None]:
import msmtools
# Implied timescales of the 6 slowest processes (in trajectory frames) at every lag,
# summarized as a confidence interval over the 5 cross-validation models.
# Result shape: (n_lags, 2, 6); index 0/1 on axis 1 are used as lower/upper bounds when plotting.
# msmtools.analysis.timescales is called with its default tau of 1 and lists the stationary
# process first, hence the [1:7] slice and the multiplication by the lag l.
its_mrf = np.array([msmtools.util.statistics.confidence_interval([msmtools.analysis.timescales(t)[1:7]*l for t in M])  for l,M in zip([2,5,10,15,20,30,50],MRF_MSM)])
# PyEMMA's MSM.timescales already includes the lag time, so `l` is unused here.
its_msm = np.array([msmtools.util.statistics.confidence_interval([t.timescales(k=6) for t in M[0]])  for l,M in zip([2,5,10,15,20,30,50],MSM_MSM)])

In [None]:
# Bayesian MSM at lag 20 on every trajectory except trajectory 15 (see introduction).
bayes_msm = pe.msm.bayesian_markov_model(
    [traj for k, traj in enumerate(dtrajs_) if k != 15],
    lag=20,
)

In [None]:
# Sample mean and confidence bounds of the Bayesian MSM implied timescales.
bsmts, bscts = (
    bayes_msm.sample_mean('timescales'),
    bayes_msm.sample_conf('timescales'),
)

In [None]:

# Stationary distributions on the common state subset, for every lag:
#   *_conf: confidence interval over the cross-validation models ([lag][0] lower, [lag][1] upper).
#   *_avg:  mean over those models.
# The MRF values follow the state order of the reduced matrices from to_red_tmat; the MSM values
# follow ascending common-state order via the bm_as index arrays. The scatter plot pairs them
# element-wise, so those two orders must agree.
# NOTE(review): the MRF distributions come from renormalized reduced matrices (sum to 1 on the
# subset), whereas the MSM distributions are the full stationary vectors restricted to the subset
# without renormalization; they are on the same footing only if each active set equals the subset.
statdist_mrf_conf = np.array([msmtools.util.statistics.confidence_interval([msmtools.analysis.statdist(t) for t in M])  for l,M in zip([2,5,10,15,20,30,50],MRF_MSM)])
statdist_msm_conf = np.array([msmtools.util.statistics.confidence_interval([t.stationary_distribution[a_s] for t,a_s in zip(*M) ])  for l,M in zip([2,5,10,15,20,30,50],MSM_MSM)])

statdist_mrf_avg = np.array([np.mean([msmtools.analysis.statdist(t) for t in M], axis=0)  for l,M in zip([2,5,10,15,20,30,50],MRF_MSM)])
statdist_msm_avg = np.array([np.mean([t.stationary_distribution[a_s] for t,a_s in zip(*M) ], axis=0)  for l,M in zip([2,5,10,15,20,30,50],MSM_MSM)])


# Asymmetric error bars (distance from the mean to the lower/upper bound) at index 4,
# i.e. lag 20 in [2,5,10,15,20,30,50].
xerr = np.vstack([statdist_mrf_avg[4]-statdist_mrf_conf[4][0], statdist_mrf_conf[4][1]-statdist_mrf_avg[4]])
yerr = np.vstack([statdist_msm_avg[4]-statdist_msm_conf[4][0], statdist_msm_conf[4][1]-statdist_msm_avg[4]])


# Lag-20 MSMs built from a growing random subset (1 to 24) of the trajectories, excluding 15.
# This reassigns idx, bmsm, ms_interesection and bm_as from earlier cells; in the visible
# notebook bm_as is only used afterwards to size the (unused) colors_ list of the plotting cell.
idx = [i for i in np.arange(25) if i not in [15]]
np.random.shuffle(idx)

bmsm = [pe.msm.estimate_markov_model(dtrajs=[dtrajs_[i][:] for i in idx[:I]], lag = 20) for I in range(1,25)]
ms_interesection=reduce(np.intersect1d, [m.active_set for m in bmsm ])
bm_as=[np.where(m.active_set==ms_interesection.reshape(-1,1))[1] for m in bmsm ]




In [None]:
# Assemble the WLALL figure:
#   A structure render, B stationary-distribution comparison,
#   C phi/psi encoding illustration, D implied-timescale panels.
# Reads its_msm, its_mrf, bsmts, bscts, the statdist_* arrays, xerr/yerr, bm_as and dihe.
fig = plt.figure(figsize=(single_column_width, 2*single_column_width))
# Time per trajectory frame; presumably ns, since the axis labels below are in ns.
dt = 0.1
# Fine 200 x 100 grid; panels are placed by slicing row/column ranges of it.
gs = gridspec.GridSpec(200, 100,left=0.1,bottom=-0.05,top=1.0,right=1.0, wspace=0.05, hspace=.05)
# 3 x 2 array of implied-timescale axes in the lower half of the grid.
axts = np.array([[plt.subplot(gs[100+i*18+i*12:100+(i+1)*18+i*12, 8+j*35+j*15:8+(j+1)*35+j*15]) for i in range(3)] for j in range(2)]).T
# 1 x 2 array: [0,0] stationary-distribution scatter, [0,1] phi/psi histogram.
# (The trailing [::-1,:] reverses an axis of length 1 and therefore has no effect.)
axsd = np.array([[plt.subplot(gs[50:80, 8+j*33+j*18:8+(j+1)*33+j*18]) for i in range(1)] for j in range(2)]).T[::-1,:]

# One panel per implied timescale; panel i == 5 is hidden (axis off) and only hosts the legend.
for i,_ax in enumerate(axts.flatten()[:]):
    # confidence intervals of MSM/DGM ITS over the cross-validation models
    _ax.fill_between(np.array([2,5,10,15,20,30,50])*dt,its_msm[:,0,i]*dt, its_msm[:,1,i]*dt,alpha=0.35, label="MSM")
    _ax.fill_between(np.array([2,5,10,15,20,30,50])*dt,its_mrf[:,0,i]*dt, its_mrf[:,1,i]*dt,alpha=0.35, label="DGM")
    
    # Bayesian MSM mean (dotted line) and confidence band, drawn around lag 20 only
    _ax.hlines(bsmts[i]*dt, 19.5*dt,20.5*dt, lw=2, linestyle=":")
    _ax.fill_between(np.array([19.5,20.5])*dt,bscts[0][i]*np.ones(2)*dt, bscts[1][i]*np.ones(2)*dt,lw=0,alpha=0.35,color='k',label="Bayesian MSM")

    # Diagonal where timescale == lag time, with the region below it shaded
    _ax.plot(np.array([2,5,10,15,20,30,50])*dt,np.array([2,5,10,15,20,30,50])*dt, color="k")
    _ax.fill_between(np.array([2,5,10,15,20,30,50])*dt,np.zeros(7), np.array([2,5,10,15,20,30,50])*dt, color='k',alpha=0.25,)
    _ax.set_ylim([0.5,8])
    if i==2:
        _ax.set_ylabel(r'implied timescale / $\mathrm{ns}$')
    if i>2:
        _ax.set_xlabel(r'lag time $\tau$ / $\mathrm{ns}$')
    if i==5:
        _ax.legend()
        _ax.set_ylim(10,20)
        _ax.axis('off')
    else:
        _ax.set_title('implied timescale {:d}'.format(i+1))


# NOTE(review): colors_ is not used in the visible code.
colors_ = [plt.cm.viridis(c) for c in np.linspace(0,1,bm_as[0].shape[0])]
# Correlate stationary distributions on common sub-sets (lag index 4, i.e. lag 20)
axsd[0,0].errorbar(statdist_mrf_avg[4],statdist_msm_avg[4], xerr=xerr, yerr=yerr,fmt='.',ms=0,zorder=1,ecolor='k')
axsd[0,0].scatter(statdist_mrf_avg[4],statdist_msm_avg[4], s=10, c='m',zorder=10)

axsd[0,0].set_xlim(1e-6,2);
axsd[0,0].set_ylim(1e-6,2)
axsd[0,0].plot([1e-6,2],[1e-6,2], ls=':', color='k')
axsd[0,0].set_xlabel(r'dMRF, $\pi_i$')
axsd[0,0].set_ylabel(r'MSM, $\pi_i$')
axsd[0,0].loglog()

# illustration of sub-system encoding: 2-D histogram of torsion columns 0 and 3 of dihe,
# converted from degrees to radians.
# NOTE(review): verify against feat.describe() that columns 0 and 3 are the phi/psi pair
# implied by the axis labels.
axsd[0,1].hist2d(np.vstack(dihe)[:,0]*np.pi/180, np.vstack(dihe)[:,3]*np.pi/180,bins=128, norm=mpl.colors.LogNorm(), alpha=0.4)
# Discretization thresholds used by discr_feats: phi < 0 and psi < 80 degrees.
axsd[0,1].vlines(0,-1.5*np.pi,1.5*np.pi,linestyles=':', color='k' )
axsd[0,1].hlines(80.*np.pi/180,-1.5*np.pi,1.5*np.pi,linestyles=':', color='k' )
axsd[0,1].set_xticks([-np.pi,-np.pi/2,0,np.pi/2,np.pi])
axsd[0,1].set_yticks([-np.pi,-np.pi/2,0,np.pi/2,np.pi])
axsd[0,1].set_xticklabels([r'$-\pi$',r'$-\frac{\pi}{2}$',"0",r'$\frac{\pi}{2}$',r'$\pi$'])
axsd[0,1].set_yticklabels([r'$-\pi$',r'$-\frac{\pi}{2}$',"0",r'$\frac{\pi}{2}$',r'$\pi$'])
axsd[0,1].set_ylabel(r'$\psi$')
axsd[0,1].set_xlabel(r'$\phi$')
# Quadrant labels: the +/-1 encoding (2*x - 1) of the two binary features, written as "phi / psi".
axsd[0,1].text(-np.pi/2., -np.pi/2., r"$1 / 1$", va='center', ha='center')
axsd[0,1].text(np.pi/2., -np.pi/2., r"$-1 / 1$", va='center', ha='center')
axsd[0,1].text(np.pi/2., np.pi/1.5, r"$-1 / -1$", va='center', ha='center')
axsd[0,1].text(-np.pi/2., np.pi/1.5, r"$1 / -1$", va='center', ha='center')
axsd[0,1].set_xlim((-np.pi, np.pi))
axsd[0,1].set_ylim((-np.pi, np.pi))



#embedding of structure render; 'penta_render_trimmed.png' must already exist in the working
# directory (this notebook does not download it)
structure_ax = plt.subplot(gs[1:40, 0:-10])
structure_ax.imshow(plt.imread('penta_render_trimmed.png'),interpolation='nearest')
structure_ax.axis('off')
# Panel letters
for a,lbl in zip([axts[0,0], axsd[0,0], axsd[0,1]], ('D', 'B', 'C')):
    a.text(-0.5, 1.10, lbl, transform=a.transAxes,
      fontsize=12, va='top')

# ('A') is the string 'A', not a tuple; zip still yields exactly one (axis, 'A') pair.
for a,lbl in zip([structure_ax], ('A')):
    a.text(-0.20, 0.95, lbl, transform=a.transAxes,
      fontsize=12, va='top')
    


#fig.savefig('Fig3_re.pdf', dpi=600)