## Dynamic graphical model of the fast folder villin
Simon Olsson 2018

This notebook makes use of previously published simulation data ([Lindorff-Larsen et al. 2011](http://science.sciencemag.org/content/334/6055/517)). We do not have the rights to distribute these data, but they can be requested directly from D. E. Shaw Research.

Note that the random seed is fixed. The notebook makes extensive use of nested point estimates (i.e., statistical estimates that are themselves subject to fluctuations and that condition subsequent estimates), and can therefore be numerically unstable for some random seeds.


In [None]:
import mdtraj as md
%matplotlib inline
import matplotlib.pyplot as plt

import matplotlib as mpl
import numpy as np
import pyemma as pe
from onedeeIsing import Ising_tmatrix

from sklearn.preprocessing import LabelBinarizer
from graphtime.markov_random_fields import estimate_dMRF
from graphtime.utils import simulate_MSM as generate

# Figure widths (inches) for single- and double-column layouts.
double_column_width = 6.968
single_column_width = 3.307

# Global matplotlib font: 8 pt Arial as the sans-serif family.
font = dict([('sans-serif', "Arial"),
             ('family', "sans-serif"),
             ('size', 8)])
mpl.rc('font', **font)
mpl.rcParams['mathtext.fontset'] = 'custom'

# Fix numpy's global random seed (see the note at the top of the notebook).
np.random.seed(101023)

In [None]:
# Backbone phi/psi torsions (degrees) of the villin 2F4K simulation.
# Side-chain chi1 torsions could be added with
#     feat.add_sidechain_torsions(which=['chi1'], deg=True)
feat = pe.coordinates.featurizer('villin/2F4K-0-protein/2F4K-0-protein.pdb')
feat.add_backbone_torsions(deg=True)

# The 63 DCD files are wrapped in ONE inner list, so pyemma reads them as a
# single (fragmented) trajectory; downstream code therefore uses element [0].
dcd_files = ['villin/2F4K-0-protein/2F4K-0-protein-{:03d}.dcd'.format(i) for i in range(63)]
source = pe.coordinates.source([dcd_files], features=feat)

# Featurised data in memory: list of (n_frames, n_features) arrays.
dihe = source.get_output()

In [None]:
def discr_feats(ftrajs, feat_describe):
    """Discretise torsion-angle trajectories into small integer states.

    Parameters
    ----------
    ftrajs : list of ndarray
        Feature trajectories, each of shape (n_frames, n_features), with angles
        in degrees and columns ordered as ``feat_describe``.
    feat_describe : list of str
        Feature labels (e.g. from ``featurizer.describe()``). The first three
        characters select the rule; for ``PSI`` the last whitespace-separated
        token is parsed as the residue number.

    Returns
    -------
    discr_trajs : list of ndarray of int
        One discretised trajectory per input, with every residue-1 ``PSI``
        column removed.
    labels : list of str
        The entries of ``feat_describe`` that were kept, in column order.

    Rules
    -----
    PHI -> 1 if angle < 0 else 0; PSI -> 1 if angle < 80 else 0;
    CHI -> one of three rotamer bins (0, 1, 2); anything else -> 0.
    """
    # The columns to keep depend only on the labels, not on trajectory content.
    # Computing them up front works for an empty `ftrajs` list and for
    # zero-length trajectories (previously NameError / IndexError).
    kept = [i for i, fstr in enumerate(feat_describe)
            if not (fstr[:3] == "PSI" and int(fstr.split()[-1]) == 1)]

    discr_trajs = []
    for ft in ftrajs:
        dftraj = np.zeros(ft.shape, dtype=int)
        for i, fstr in enumerate(feat_describe):
            if fstr[:3] == "PHI":  # split into two states
                dftraj[:, i] = (ft[:, i] < 0).astype(int)
            elif fstr[:3] == "PSI":  # two states; the N-terminal PSI is dropped via `kept`
                dftraj[:, i] = (ft[:, i] < 80).astype(int)
            elif fstr[:3] == "CHI":  # split into 3 rotamers
                tv = (ft[:, i] + 180 + 60) % 360
                dftraj[:, i] = (tv > 125).astype(int) + (tv > 250).astype(int)
        # Fancy indexing already returns a new array.
        discr_trajs.append(dftraj[:, kept])
    return discr_trajs, [feat_describe[i] for i in kept]

In [None]:
# Discretise the torsions; `nlbls` holds the labels of the retained columns.
dfeats, nlbls = discr_feats(ftrajs=dihe, feat_describe=feat.describe())

In [None]:
# Continuous dihedrals of the (single) trajectory with positional columns 38 and
# 66 removed, mirroring `dfeats_fixed` below.
# NOTE(review): `dihe` still contains the N-terminal PSI column that
# `discr_feats` removes from `dfeats`, so the same positional indices [38, 66]
# (and the hard-coded width 68) may refer to different features here than in
# `dfeats_fixed` — TODO confirm. `dihe_fixed` is not used later in this notebook.
dihe_fixed = [dihe[0][:,[i for i in range(68) if i not in [38,66 ]]].copy()]


In [None]:
# Discretised features of the (single) trajectory with columns 38 and 66 removed.
# NOTE(review): the reason these two columns are excluded is not recorded — TODO document.
# np.delete takes the column count from the array instead of a hard-coded width
# of 68 (which would fail, or silently drop trailing columns, if the feature set
# changed), and returns a new array, so `dfeats` is not aliased.
dfeats_fixed = [np.delete(dfeats[0], [38, 66], axis=1)]

In [None]:
# Fit a dynamic MRF to all MD data. Spins use the ±1 encoding (2*x - 1 of the
# binary features). Each subsystem is an L1-regularised logistic regression whose
# inverse regularisation strength C scales as 1000/sqrt(number of frames).
n_md_frames = len(dfeats_fixed[0])
C = 1000. / n_md_frames ** 0.5
logistic_regression_kwargs = dict(fit_intercept=True, penalty='l1', C=C,
                                  tol=0.0001, solver='saga')
dmrf_all_data = estimate_dMRF(
    [2 * dfeats_fixed[0] - 1],
    lag=300,
    stride=10,
    logistic_regression_kwargs=logistic_regression_kwargs,
    Encoder=LabelBinarizer(neg_label=-1, pos_label=1),
)

In [None]:
# Coupling matrix of the all-data dMRF with the bias vector appended as the last column.
fig, ax = plt.subplots(ncols=1, figsize=(16, 8))
coupling_and_bias = np.hstack([dmrf_all_data.get_subsystem_couplings(),
                               dmrf_all_data.get_subsystem_biases().reshape(-1, 1)])
ax.imshow(coupling_and_bias)

### Build MSM

In [None]:
import pyemma as pe


In [None]:
# TICA of the binary features at several lag times (trajectory steps), used to
# choose a lag from the implied timescales below.
tica_lag_scan = [5, 10, 50, 100, 200, 300, 500, 900]
tica_objs = [pe.coordinates.tica(dfeats_fixed, lag=tau) for tau in tica_lag_scan]

In [None]:
# Left: leading TICA implied timescales vs lag. Right: number of TICs reported by
# each TICA object (`ndim`) vs lag.
fig, ax = plt.subplots(ncols=2, figsize=(8, 3))
tica_lag_axis = [to.lag for to in tica_objs]
ax[0].semilogy(tica_lag_axis, [to.timescales[:10] for to in tica_objs])
ax[0].set(xlabel='lag time / steps', ylabel='implied timescale / steps')
ax[1].plot(tica_lag_axis, [to.ndim for to in tica_objs])
ax[1].set(xlabel='lag time / steps',
          ylabel='number of dimensions for 95% kinetic variance')
fig.tight_layout()


In [None]:
# Project the MD features with the lag-300 TICA (element 5 of the lag scan).
tica_main = tica_objs[5]
Y = tica_main.get_output()

In [None]:
# All projected frames as a single (n_frames, n_tics) array.
Ys = np.concatenate(Y, axis=0)

In [None]:
# Log-scaled 2D histogram of the MD data on TIC1 / TIC2.
tic1_md, tic2_md = Ys[:, 0], Ys[:, 1]
a = plt.hist2d(tic1_md, tic2_md, bins=256, norm=mpl.colors.LogNorm())

In [None]:
# k-means microstates in TICA space; stride=10 subsamples the data used to fit the centres.
N_CLUSTERS = 1024
cluster_obj = pe.coordinates.cluster_kmeans(data=Y, k=N_CLUSTERS, stride=10)

In [None]:
# MSM implied timescales (4 slowest) on the microstates over a range of lags.
# Bayesian error bars (errors='bayes') are not computed here.
its_lags = [10, 50, 100, 200, 300, 500, 700, 900, 1000, 1200]
its = pe.msm.its(cluster_obj.dtrajs, lags=its_lags, nits=4)

In [None]:
# Implied timescales on a linear axis.
pe.plots.plot_implied_timescales(
    its,
    ylog=False,
)

In [None]:
# Bayesian MSM at lag 300 steps (the lag also used for the dMRFs).
MSM_LAG = 300
msm = pe.msm.bayesian_markov_model(cluster_obj.dtrajs, lag=MSM_LAG)

In [None]:
# Chapman-Kolmogorov test on 5 metastable sets.
n_ck_sets = 5
ckt = msm.cktest(n_ck_sets)

In [None]:
# Plot only the diagonal (self-transition) panels of the CK test.
pe.plots.plot_cktest(
    ckt,
    diag=True,
)

In [None]:
# Coarse-grain the MSM into a hidden Markov model with 5 metastable states.
N_METASTABLE = 5
HMM = msm.coarse_grain(N_METASTABLE)

In [None]:
# -log of the MSM stationary probability at each k-means centre, on TIC1/TIC2.
# NOTE(review): assumes every cluster is in the MSM active set so the stationary
# distribution has one entry per centre — TODO confirm.
centers = cluster_obj.cluster_centers_
neg_log_pi = -np.log(msm.stationary_distribution)
pe.plots.scatter_contour(centers[:, 0], centers[:, 1], neg_log_pi)

In [None]:
# Observation probabilities of hidden state 0 over the k-means centres, on TIC3/TIC4.
centers = cluster_obj.cluster_centers_
state0_obs = HMM.observation_probabilities[0]
pe.plots.scatter_contour(centers[:, 2], centers[:, 3], state0_obs)

In [None]:


# Metastable (HMM) state of every MD frame via its k-means microstate.
# NOTE(review): assumes every k-means cluster is in the MSM active set, so that
# microstate ids index `HMM.metastable_assignments` directly — TODO confirm.
frame_meta = HMM.metastable_assignments[cluster_obj.dtrajs[0].reshape(-1)]


def _contiguous_runs(frames):
    """Split a sorted array of frame indices into runs of consecutive frames."""
    # np.diff(frames)[k] > 1 means the gap lies between frames[k] and frames[k+1],
    # so the split point is k + 1 (splitting at k would move frames[k] into the
    # next run, creating a spurious time jump at the start of each segment).
    return np.split(frames, np.where(np.diff(frames) > 1)[0] + 1)


# Frame indices inside / outside each metastable state.
inmeta = [np.where(frame_meta == i)[0] for i in range(HMM.nstates)]
Meta_filtered = [np.where(np.isin(frame_meta, [i], invert=True))[0] for i in range(HMM.nstates)]

# Contiguous MD stretches (> 300 frames) that never visit state i; these train the
# leave-one-state-out dMRFs. Gaps must be detected on Meta_filtered[i] itself
# (previously the gaps of inmeta[i] were used to split Meta_filtered[i]).
not_in_meta_data = [[dfeats_fixed[0][t] for t in _contiguous_runs(Meta_filtered[i]) if len(t)>300] for i in range(HMM.nstates)]

# Contiguous stretches (> 2 frames) that stay inside state i.
only_in_meta_data = [[dfeats_fixed[0][t] for t in _contiguous_runs(inmeta[i]) if len(t)>2] for i in range(HMM.nstates)]

In [None]:
# Inverse L1 strength per leave-one-out model: 1000/sqrt(total training frames),
# the same scaling as the all-data fit. (The ±1 mapping does not change row counts.)
regl_ = [1000. / np.vstack(not_in_meta_data[i]).shape[0] ** 0.5 for i in range(5)]
regl_

In [None]:
def _saga_l1_kwargs(inverse_strength):
    """sklearn LogisticRegression settings shared by all dMRF fits."""
    return {'fit_intercept': True, 'penalty': 'l1', 'C': inverse_strength,
            'tol': 0.0001, 'solver': 'saga'}


# One dMRF per metastable state, each trained on data that never visits that state.
dMRFs = []
for M in range(5):
    logistic_regression_kwargs = _saga_l1_kwargs(regl_[M])
    training_spins = [2 * t - 1 for t in not_in_meta_data[M]]
    dMRFs.append(estimate_dMRF(training_spins,
                               lag=300, stride=10,
                               logistic_regression_kwargs=logistic_regression_kwargs,
                               Encoder=LabelBinarizer(neg_label=-1, pos_label=1)))


In [None]:
# Number of active subsystems in each leave-one-out dMRF.
list(map(lambda dmrf: len(dmrf.get_active_subsystems()), dMRFs))

In [None]:
# Couplings (with biases as the last column) of the five leave-one-out dMRFs.
fig, ax = plt.subplots(ncols=5, figsize=(12, 10))
for panel, model in zip(ax, dMRFs):
    couplings = model.get_subsystem_couplings()
    biases = model.get_subsystem_biases().reshape(-1, 1)
    panel.imshow(np.hstack([couplings, biases]))
fig.tight_layout()

In [None]:
# Release all open figures before the long simulations.
plt.close(fig='all')

In [None]:
# 100k-step simulation of each leave-one-out dMRF, started from the first frame of
# its first training segment (±1 encoding) restricted to its active subsystems.
synthts = []
for M, d in enumerate(dMRFs):
    first_frame = (2 * np.array(not_in_meta_data[M][0][0])) - 1
    synthts.append(d.simulate(nsteps=100000, start=first_frame[d.get_active_subsystems()]))

In [None]:
# 100k-step simulation of the all-data dMRF, started from the first MD frame in
# the ±1 encoding. As for the leave-one-out models, the start vector is restricted
# to the model's active subsystems; when every subsystem is active this equals the
# full frame.
all_data_start = (2 * dfeats_fixed[0][0] - 1)[dmrf_all_data.get_active_subsystems()]
syntht_alldata = dmrf_all_data.simulate(nsteps=100000, start=all_data_start)

In [None]:

# TICA projection of each leave-one-out synthetic trajectory. The dMRF samples use
# the ±1 encoding, so (s+1)/2 maps them back to {0, 1} before applying the TICA
# fitted on MD data. The whole trajectory is transformed in one call (as for
# `Y_synth_all`) instead of 100k single-frame calls.
Y_synths = [tica_objs[5].transform((s + 1) / 2) for s in synthts]


In [None]:
def _tic_hist(panel, tics, label):
    """Log-scaled 2D histogram of the first two TICs of `tics` on `panel`."""
    # A fresh LogNorm per panel keeps each colour scale independent.
    return panel.hist2d(tics[:, 0], tics[:, 1], bins=128, norm=mpl.colors.LogNorm(), label=label)


# TIC1/TIC2 density of the MD data (top-left) and of the five leave-one-out DGMs.
fig, axs = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=True,
                        figsize=(single_column_width, .75*single_column_width))
ax = axs.ravel()

_tic_hist(ax[0], Y[0], "all data")
ax[0].set_title('All MD data')
ax[0].set_ylabel('TIC2')

for model_idx, tics in enumerate(Y_synths):
    panel = ax[model_idx + 1]
    _tic_hist(panel, tics, "missing meta {}".format(model_idx))
    panel.set_title('Without {}'.format(model_idx + 1))
    if model_idx == 2:   # first panel of the bottom row
        panel.set_ylabel('TIC2')
    if model_idx > 1:    # bottom row
        panel.set_xlabel('TIC1')

fig.tight_layout()
plt.savefig('tica_leave_one_out.pdf', dpi=300)

In [None]:
# TICA projection of the all-data dMRF trajectory (±1 spins mapped to {0, 1}).
Y_synth_all = tica_objs[5].transform(0.5 * (syntht_alldata + 1))

In [None]:
# Assign synthetic TICA projections to the MD k-means microstates.
dtraj_synths = cluster_obj.transform(Y_synths)          # one per leave-one-out model
dtraj_synths_all = cluster_obj.transform(Y_synth_all)   # all-data model

In [None]:
import matplotlib.gridspec as gridspec

In [None]:
# Exploratory: mean feature values (mapped to {0, 1}) in metastable state 3 for each
# leave-one-out DGM trajectory (yy_) and for the MD data (xx_).
# NOTE(review): yy_ reorders each synthetic trajectory's columns by variance
# (np.argsort(...var)), while xx_ keeps the original column order, so the two are
# not feature-aligned — confirm intent. Both names are overwritten by the next cell.
yy_=[(syntht[:,np.argsort(syntht.var(axis=0))[:]][np.where(HMM.metastable_assignments[ye[::1]]==3)[0],:].mean(axis=0)+1)/2 for syntht,ye in zip(synthts, dtraj_synths)]
xx_=dfeats_fixed[0][np.where(HMM.metastable_assignments[cluster_obj.dtrajs[0][:]]==3)[0],:].mean(axis=0)[:]

In [None]:
from scipy.stats import spearmanr


# Figure: how well the leave-one-state-out DGMs reproduce MD feature statistics.
#   A: per-feature mean, DGM (x) vs MD (y);  B: per-feature variance, DGM vs MD;
#   C: 5x5 grid of per-feature means restricted to metastable state j (rows) for
#      leave-one-out model k (columns), again DGM (x) vs MD (y).
fig = plt.figure(figsize=(double_column_width/1.3,double_column_width*1.1))

gs = gridspec.GridSpec(143, 100, left=0.097,bottom=0.02,top=0.98,right=0.999)
ax = plt.subplot(gs[5:23, 2:22])
ax2 = plt.subplot(gs[5:23, 34:55])
ax3 = plt.subplot(gs[5:23, 70:78])  # legend-only axes

# MD features in the ±1 encoding, subsampled every 300 frames (the dMRF lag);
# loop-invariant, so computed once.
md_spins = 2*dfeats_fixed[0][::300, :]-1
ls = []
for i in range(5):
    ax2.scatter(np.vstack(synthts[i]).var(axis=0), md_spins.var(axis=0), s=5, label="Without state {:d}".format(i+1))
    l = ax.scatter(np.vstack(synthts[i]).mean(axis=0), md_spins.mean(axis=0), s=5, label="Without state {:d}".format(i+1))
    ls.append(l)
ax.set_xlabel('DGM Mean feature')
ax2.set_xlabel('DGM Variance of feature ')

ax.set_ylabel('MD Mean feature')
ax2.set_ylabel('MD Variance of feature')

ax3.axis('off')
ax3.legend(ls, ["Without state {:d}".format(i+1) for i in range(5)], fontsize=9, loc=(-0.20,0.1))

# 5x5 grid for panel C on the same GridSpec, below panels A/B.
ax_ = np.array([[plt.subplot(gs[40+18*(j)+2*j:40+18*(j+1)+j*2, 18*(i)+2*i:18*(i+1)+2*i]) for i in range(5)] for j in range(5)])

axf = ax_.flatten()

ax.text(-.35, 1.30, "A", transform=ax.transAxes,
          fontsize=12,  va='top')

ax2.text(-.35, 1.30, "B", transform=ax2.transAxes,
          fontsize=12,  va='top')
i=0
axf[i].text(-.35, 1.30, "C", transform=axf[i].transAxes,
          fontsize=12, va='top')
for j in range(5):
    # MD mean configuration of state j (±1 encoding); independent of the model.
    xx_=2*dfeats_fixed[0][np.where(HMM.metastable_assignments[cluster_obj.dtrajs[0][:]]==j)[0],:].mean(axis=0)-1
    for syntht,ye in zip(synthts, dtraj_synths):
        # DGM mean configuration over synthetic frames assigned to state j.
        yy_=syntht[np.where(HMM.metastable_assignments[ye[::1]]==j)[0],:].mean(axis=0)
        # x = DGM, y = MD, matching the axis labels set below and panels A/B
        # (previously the data were plotted the other way round).
        axf[i].scatter(yy_, xx_, s=5, color="C{:d}".format(j))

        axf[i].text(0.05, 0.95, r"$\rho={:0.2f}$".format(np.corrcoef(xx_,yy_)[0,1]), transform=axf[i].transAxes,
          fontsize=9, va='top')

        axf[i].set_xlim((-1,1))
        axf[i].set_ylim((-1,1))

        # NOTE(review): columns index the leave-one-out model (without state k+1),
        # yet are titled 'State k+1' — confirm this matches the caption.
        if i<5:
            axf[i].set_title('State {}'.format(i+1), fontsize = 10)

        if i%5==0:
            if i==10:
                axf[i].set_ylabel('MD mean feature')
        else:
            axf[i].set_yticklabels([])
            axf[i].set_yticks([])

        if i>19:
            if i==22:
                axf[i].set_xlabel('DGM mean feature')
        else:
            axf[i].set_xticklabels([])
            axf[i].set_xticks([])

        # Shade diagonal panels: the model evaluated on the state it was trained without.
        if i%5==j:
            axf[i].set_facecolor((0.85,0.85,0.85))
        i=i+1
plt.savefig('feature_scatter.pdf')

In [None]:
from scipy.spatial import distance_matrix

In [None]:
# Per-frame metastable state. `cluster_obj.dtrajs` is a list holding one array;
# indexing a numpy array with that list is interpreted as a 2-D index by recent
# NumPy, which would make np.where(...)[0] return row indices (all zeros). Use the
# single trajectory explicitly, as the other cells do.
frame_meta_md = HMM.metastable_assignments[cluster_obj.dtrajs[0].reshape(-1)]


def _with_frame_index(mask):
    """Return [frame index | discretised features] rows for the frames selected by `mask`."""
    sel = np.where(mask)[0]
    return np.hstack([sel.reshape(-1, 1), dfeats_fixed[0][sel, :]])


def _split_runs(stacked, min_len=300):
    """Split `stacked` into runs of consecutive frame indices (column 0), keep runs
    longer than `min_len`, and drop the index column."""
    # A gap between rows k and k+1 means the split point is k + 1.
    breaks = np.where(np.diff(stacked[:, 0]) > 1)[0] + 1
    return [t[:, 1:] for t in np.split(stacked, breaks, axis=0) if len(t) > min_len]


not_folded = _with_frame_index(np.isin(frame_meta_md, [0, 1, 2, 4]))

not_inmeta = [_with_frame_index(np.isin(frame_meta_md, [xx for xx in range(5) if xx != i])) for i in range(5)]

only_inmeta = [_with_frame_index(np.isin(frame_meta_md, [i])) for i in range(5)]

not_folded_split = _split_runs(not_folded)

Meta_filtered = [_split_runs(not_inmeta[i]) for i in range(5)]

Just_one_Meta = [_split_runs(only_inmeta[i]) for i in range(5)]

In [None]:
from scipy import spatial

In [None]:
from scipy import spatial
def mrftraj_to_dtraj(mrftraj, ftraj, transformer = lambda x:0.5*(x+1)):
    """Map every synthetic frame to its nearest reference frame (Hamming distance).

    Parameters
    ----------
    mrftraj : iterable of array-like
        Synthetic frames; each is passed through ``transformer`` first.
    ftraj : ndarray
        Reference feature trajectory, one frame per row.
    transformer : callable
        Per-frame mapping applied before comparison; the default 0.5*(x+1)
        maps ±1 spins to {0, 1}.

    Returns
    -------
    dtraj_out : list
        Row index in ``ftraj`` of the closest frame (first one on ties).
    errors : list
        Hamming distance (fraction of differing entries) to that frame.
    """
    nearest_rows, distances = [], []
    for frame in mrftraj:
        query = transformer(frame).reshape(1, -1)
        hamming = spatial.distance.cdist(query, ftraj, metric='hamming')[0]
        best = np.argmin(hamming)
        nearest_rows.append(best)
        distances.append(hamming[best])
    return nearest_rows, distances



In [None]:
# Nearest MD frames for the first 1000 steps of the all-data dMRF simulation.
resampledtraj, errors = mrftraj_to_dtraj(mrftraj=syntht_alldata[:1000], ftraj=dfeats_fixed[0])

In [None]:
# For each leave-one-out model: nearest MD frames of its first 1000 steps, their
# Hamming errors, and a PDB of the matched structures.
n_resample = 1000
rtrajs = []
errs = []
for i in range(5):
    _resampledtraj, _errors = mrftraj_to_dtraj(synthts[i][:n_resample], dfeats_fixed[0])
    rtrajs.append(_resampledtraj)
    errs.append(_errors)
    # (trajectory index, frame index) pairs; every frame comes from trajectory 0.
    frame_pairs = np.column_stack((np.zeros(n_resample), _resampledtraj)).astype(int)
    pe.coordinates.save_traj(source, frame_pairs, "resampled_wo_{}.pdb".format(i))

In [None]:
# Structures matched to the all-data dMRF simulation, as (traj 0, frame) pairs.
subsampled_pairs = np.column_stack((np.zeros(1000), resampledtraj)).astype(int)
pe.coordinates.save_traj(source, subsampled_pairs, "subsampled.pdb")

In [None]:
# Five structures per HMM state, drawn according to the observation probabilities.
hmm_state_samples = HMM.sample_by_observation_probabilities(5)
for state, frames in enumerate(hmm_state_samples):
    pe.coordinates.save_traj(source, frames, "HMM_state_{}.pdb".format(state))

Metadata for the synthetic trajectory used to generate the animated GIF

In [None]:
# Matched MD features, their metastable assignments and reconstruction errors for
# the model trained without state 4 (index 3).
resampled_frames_3 = dfeats_fixed[0][rtrajs[3], :]
resampled_states_3 = HMM.metastable_assignments[cluster_obj.transform(tica_objs[5].transform(resampled_frames_3)).reshape(-1)]
np.savez("resmtraj3.npz", data=resampled_frames_3)
np.savez("state_assign_resmtraj_3.npz", data=resampled_states_3)
np.savez("recerr_resmtraj_3.npz", data=errs[3])

In [None]:
from itertools import product

In [None]:
def _mean_free(x):
    """Subtract the mean along axis 0."""
    return x - x.mean(axis=0)


# Feature column 56 as a {0, 1} time series: first 10000 steps of each DGM
# trajectory (±1 mapped to {0, 1}), and MD subsampled every 300 frames.
meanfree_synthtrajs = [_mean_free(0.5 * (synthts[i][:10000, 56] + 1)) for i in range(5)]
meanfree_MD = _mean_free(dfeats_fixed[0][::300, 56])

In [None]:
def func(x, a, b, c):
    """Single exponential with offset, ``a * exp(-b * x) + c``.

    Model for ``curve_fit``; ``b`` is reported as the relaxation rate.
    """
    decay = np.exp(-b * x)
    return c + a * decay

In [None]:
from scipy.optimize import curve_fit

In [None]:
def _acf(x):
    """Unnormalised autocorrelation of `x` at lags 1 .. len(x)-1."""
    # The 'full' output has its zero lag at index len(x)-1.
    return np.correlate(x, x, mode='full')[len(x):]


# Use the MD series length instead of the hard-coded 2094, and truncate the DGM
# series to the same length so all fits cover the same lag range.
n_lag = len(meanfree_MD)
# Lags 1..n_lag-1 converted with 60/1000 per step (µs, as in the ACF plots).
# The fitted rate b does not depend on this shift; only the amplitude does.
lag_time = np.arange(1, n_lag) * 60 / 1000.

for i in range(5):
    popt, pcov = curve_fit(func, lag_time, _acf(meanfree_synthtrajs[i][:n_lag]))
    print("relaxation rate {:0.3f} for without state {:d}".format(popt[1],i+1))

popt, pcov = curve_fit(func, lag_time, _acf(meanfree_MD))
print("relaxation rate MD", popt[1])


In [None]:
from itertools import product

In [None]:
# Figure: (A) metastable-state populations of each DGM vs the HMM on full MD,
# (B) reconstruction-error distributions, (C) feature autocorrelation functions.
fig = plt.figure(figsize=(single_column_width,1.8*1.75*single_column_width/2))

gs = gridspec.GridSpec(180, 90,left=0.16,bottom=0.05,top=0.95,right=0.95, wspace=0.0, hspace=0.)
axs = [plt.subplot(gs[30*(i)+15:30*(i+1)+15, 3+29*(j):3+29*(j+1)]) for i,j in product(range(2), range(3))]
axs.append(plt.subplot(gs[:15, :]))  # legend strip on top

# Shared y-label drawn as figure text next to the bar panels.
ax_ylbl = plt.subplot(gs[30:,:3])
ax_ylbl.set_ylabel(r'$p(x)$ (log-scale)')
ylbl=ax_ylbl.axes.yaxis.get_label()
fig.text(ylbl.get_position()[0]+0.05,1.5*ylbl.get_position()[1]+0.10*2/3 , ylbl.get_text(), rotation=90)

ax_ylbl.axis('off')


ls = []
for i, ax in enumerate(axs):
    ax.set_ylim((2e-4,2.0))
    if i==5:
        # All-data DGM: populations from an MSM on its metastable-state sequence.
        ising_msm = pe.msm.estimate_markov_model(HMM.metastable_assignments[dtraj_synths_all.reshape(-1)], lag=1)
        l = ax.bar(range(1,6), ising_msm.stationary_distribution,hatch="//", fill=False, label="Ising")
        ls.append(l)
        l2 = ax.bar(range(1,6), HMM.stationary_distribution,fill=True, alpha=0.2, log=True, label="HMM (full MD)")
        ax.set_yticks([])
        ax.set_xticks([1,2,3,4,5])
    elif i==6:
        # Legend strip; l, l2, l3 come from the earlier panels.
        ax.legend([l, l2, l3], ["dMRF", "HMM (full MD)", "State left out during estimation"], fontsize=8, loc=(0.27,0.2))
        ax.set_yticks([])
        ax.set_xticks([])
        ax.axis('off')

    else:
        # Leave-one-out DGM i; the star marks the state omitted during estimation.
        ising_msm = pe.msm.estimate_markov_model(HMM.metastable_assignments[dtraj_synths[i].reshape(-1)], lag=1)
        l = ax.bar(range(1,6), ising_msm.stationary_distribution,hatch="//", fill=False, label="Ising")
        ls.append(l)
        l2 = ax.bar(range(1,6), HMM.stationary_distribution,fill=True, alpha=0.2, log=True, label="HMM (full MD)")
        ls.append(l2)
        l3 = ax.scatter([i+1], 0.5, marker="*", s=50, color='purple')

        if i>2:
            ax.set_xticks([1,2,3,4,5])
            if i>3:
                ax.set_xlabel(r'Meta-stable state / $x$')

        else:
            ax.set_xticks([])

        if i in [0, 3]:
            continue
        else:
            ax.set_yticks([])

axs2 = [plt.subplot(gs[105:145, 35*i+20*i:35*(i+1)+20*i]) for i in range(2) ]

bins = np.unique(np.concatenate(errs))
for i, err in enumerate(errs):
    # `normed` was removed from Axes.hist in matplotlib 3.1; `density` is the replacement.
    axs2[0].hist(err, bins=bins, label = "Without state {:d}".format(i+1), histtype='step', lw=1, density=True, log=False)
axs2[0].set_xlabel(r'$\epsilon $')
axs2[0].set_ylabel(r'$p(\epsilon)$')


# Autocorrelations at lags 1.. (in µs via 60/1000 per step), normalised by the lag-1 value.
for i in range(5):
    _a = np.correlate(meanfree_synthtrajs[i][:10000],meanfree_synthtrajs[i][:10000], mode='full')[10000:]
    axs2[1].plot(np.arange(1,10000)*60./1000.,_a/_a[0], label="Without state {:d}".format(i+1) )
_a = np.correlate(meanfree_MD,meanfree_MD,mode='full')[2094:]
axs2[1].plot(np.arange(1,2094)*60./1000.,_a/_a[0], label="MD", color='k')

axs2[1].set_xlabel(r'$\tau$ / $\mu$s')

axs2[1].set_ylabel(r'$C(\tau)$ / Lys 71 $\phi$ rotamer')
axs2[1].semilogx()
axs2[1].set_xlim((axs2[1].get_xlim()[0],100))
ax3 = plt.subplot(gs[165:, :])
ax3.legend([child for child in axs2[1].get_children() if isinstance(child, mpl.lines.Line2D)],
                 [child.get_label() for child in axs2[1].get_children() if isinstance(child, mpl.lines.Line2D)]
                 ,loc=(0.0,-0.4), ncol=2)
ax3.axis('off')
axs[0].text(-.35, 1.30, "A", transform=axs[0].transAxes,
          fontsize=12,  va='top')
axs2[0].text(-.35, 1.30, "B", transform=axs2[0].transAxes,
          fontsize=12,  va='top')
axs2[1].text(-.35, 1.30, "C", transform=axs2[1].transAxes,
          fontsize=12,  va='top')
# NOTE(review): the combined Villin/BBA figure later saves to the same filename and
# overwrites this file on a full run — consider renaming one of them.
plt.savefig('Villin_statdist_corrfunc.pdf')

In [None]:

from matplotlib.colors import hex2color

In [None]:
# Default matplotlib "tab10" cycle, printed as RGB triplets (1-based C indices)
# in a `set_color Cn, [r, g, b]` format.
new_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
              '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
              '#bcbd22', '#17becf']
for color_number, hex_code in enumerate(new_colors, start=1):
    print("set_color C{:}, [{:}, {:}, {:}]".format(color_number, *hex2color(hex_code)))

In [None]:
def _load_bba(name):
    """Load one plain-text BBA result file."""
    return np.loadtxt(name)


# Precomputed BBA results, loaded in the same order as before.
bba_errs = _load_bba('bba_errs.txt')
bba_synthts = _load_bba('bba_synthtrajs.txt')
bba_dmrf_mus = [_load_bba('bba_hist_dmrf{}.txt'.format(i)) for i in range(4)]
bba_hist_hmm = _load_bba('bba_hist_hmm.txt')
bba_syntht_all = _load_bba('bba_syntht_all.txt')
# NOTE(review): if pe.load deserialises via pickle, only load trusted files.
bba_hmm = pe.load('bba_hmm.pyemma')
bba_md_acf = _load_bba('bba_md_acf.txt')

In [None]:
import msmtools as mt

In [None]:
# Combined Villin (left) / BBA (right) figure: HMM state renders, metastable-state
# populations (DGM vs HMM on full MD), mean reconstruction errors with 68 %
# intervals, and autocorrelation functions.
fig = plt.figure(figsize=(double_column_width,0.90*1.8*1.75*single_column_width/2+2./3*single_column_width ))

gs = gridspec.GridSpec(180+60, 180,left=0.08,bottom=0.025,top=0.95,right=0.95, wspace=0.0, hspace=0.)
axs = [plt.subplot(gs[30*(i)+62+15:30*(i+1)+62+15, 3+27*(j):3+27*(j+1)]) for i,j in product(range(2), range(3))]
axs.append(plt.subplot(gs[62:62+15, :]))
axs3 = [plt.subplot(gs[30*(i)+62+15:30*(i+1)+62+15, 10+27*(j)+90:10+27*(j+1)+90]) for i,j in product(range(2), range(3))]

# Shared y-label drawn as figure text next to the Villin bar panels.
ax_ylbl = plt.subplot(gs[30+62:,:3])
ax_ylbl.set_ylabel(r'$p(x)$ (log-scale)')
ylbl=ax_ylbl.axes.yaxis.get_label()
fig.text(ylbl.get_position()[0]+0.03,0.5*ylbl.get_position()[1]+1.0*1/3 , ylbl.get_text(), rotation=90)

ax_ylbl.axis('off')


def _asym_yerr(center, interval):
    """Non-negative (lower, upper) error-bar offsets of `interval` around `center`.

    matplotlib expects offsets from the bar height, i.e. (center - lower,
    upper - center). They are clipped at 0 because matplotlib rejects negative
    offsets, which occur when the mean falls outside the interval.
    """
    lower, upper = interval
    return np.clip(np.vstack((center - lower, upper - center)), 0, None)


ls = []
for i, ax in enumerate(axs):
    ax.set_ylim((2e-4,5.0))
    if i==5:
        # Villin all-data DGM.
        ising_msm = pe.msm.estimate_markov_model(HMM.metastable_assignments[dtraj_synths_all.reshape(-1)], lag=1)
        l = ax.bar(range(1,6), ising_msm.stationary_distribution,hatch="//", fill=False, label="Ising")
        ls.append(l)
        l2 = ax.bar(range(1,6), HMM.stationary_distribution,fill=True, alpha=0.2, log=True, label="HMM (full MD)")
        ax.text(3.,3.0, "all data", va='top',ha='center')
        ax.set_yticks([])
        ax.set_xticks([1,2,3,4,5])
    elif i==6:
        ax.set_yticks([])
        ax.set_xticks([])
        ax.axis('off')

    else:
        # Villin leave-one-out DGM i; the star marks the omitted state.
        ising_msm = pe.msm.estimate_markov_model(HMM.metastable_assignments[dtraj_synths[i].reshape(-1)], lag=1)
        l = ax.bar(range(1,6), ising_msm.stationary_distribution,hatch="//", fill=False, label="Ising")
        ls.append(l)
        l2 = ax.bar(range(1,6), HMM.stationary_distribution,fill=True, alpha=0.2, log=True, label="HMM (full MD)")
        ls.append(l2)
        l3 = ax.scatter([i+1], 0.5, marker="*", s=50, color='purple')

        if i>2:
            ax.set_xticks([1,2,3,4,5])
            if i>3:
                ax.set_xlabel(r'Meta-stable state / $x$')

        else:
            ax.set_xticks([])

        if i in [0, 3]:
            continue
        else:
            ax.set_yticks([])

for i, ax in enumerate(axs3):
    ax.set_ylim((2e-4,5.0))
    if i==4:
        # BBA all-data DGM.
        ising_msm = pe.msm.estimate_markov_model(bba_hmm.metastable_assignments[bba_syntht_all.reshape(-1).astype(int)], lag=1)
        l = ax.bar(range(1,5), ising_msm.stationary_distribution,hatch="//", fill=False, label="DGM")
        l2 = ax.bar(range(1,5), bba_hist_hmm,fill=True, alpha=0.2, log=True, label="HMM (full MD)")
        ax.set_yticks([])
        ax.set_xticks([1,2,3,4])
        ax.set_xlabel(r'Meta-stable state / $x$')
        ax.text(2.5,3.0, "all data", va='top',ha='center')
    elif i==5:
        ax.set_yticks([])
        ax.set_xticks([])
        ax.axis('off')

    else:
        # BBA leave-one-out DGM i: populations are precomputed (bba_dmrf_mus), so no
        # MSM is estimated here (an unused Villin MSM was previously fitted in this branch).
        l = ax.bar(range(1,5), bba_dmrf_mus[i],hatch="//", fill=False, label="Ising")
        l2 = ax.bar(range(1,5), bba_hist_hmm,fill=True, alpha=0.2, log=True, label="HMM (full MD)")
        ls.append(l2)
        l3 = ax.scatter([i+1], 0.5, marker="*", s=50, color='purple')

        if i>1:
            ax.set_xticks([1,2,3,4])

        else:
            ax.set_xticks([])

        if i in [0, 3]:
            continue
        else:
            ax.set_yticks([])


axs2 = [plt.subplot(gs[62+105:62+145, 30*i+20*i:30*(i+1)+20*i]) for i in range(2) ]

axs4 = [plt.subplot(gs[62+105:62+145, 30*i+20*i+100:30*(i+1)+20*i+100]) for i in range(2) ]


# Villin: mean reconstruction error per leave-one-out model with a 68 % interval.
ym = np.array(errs).mean(axis=1)
yconf = mt.util.statistics.confidence_interval(np.array(errs).T, conf=.68)

# The upper offset is upper - mean (it was previously upper + mean).
bl=axs2[0].bar(np.arange(1, ym.shape[0]+1), ym, yerr=_asym_yerr(ym, yconf))
for i,b in enumerate(bl):
    b.set_color(f'C{i}')

axs2[0].set_ylabel(r'$\epsilon $')
axs2[0].set_xlabel(r'Without state')
axs2[0].set_xticks([1,2,3,4,5])

# BBA: same summary from the precomputed errors.
ym = bba_errs.mean(axis=1)
yconf = mt.util.statistics.confidence_interval(bba_errs.T, conf=.68)

bl=axs4[0].bar(np.arange(1, ym.shape[0]+1), ym, yerr=_asym_yerr(ym, yconf))
for i,b in enumerate(bl):
    b.set_color(f'C{i}')
axs4[0].set_ylabel(r'$\epsilon $')
axs4[0].set_xlabel(r'Without state')
axs4[0].set_xticks([1,2,3,4])

# Villin ACFs at lags 1.. (µs via 60/1000 per step), normalised by the lag-1 value.
# NOTE(review): normalising by the lag-1 rather than the lag-0 value — confirm intent.
for i in range(5):
    _a = np.correlate(meanfree_synthtrajs[i][:10000],meanfree_synthtrajs[i][:10000], mode='full')[10000:]
    axs2[1].plot(np.arange(1,10000)*60./1000.,_a/_a[0], label="Without state {:d}".format(i+1) )
_a = np.correlate(meanfree_MD,meanfree_MD,mode='full')[2094:]
axs2[1].plot(np.arange(1,2094)*60./1000.,_a/_a[0], label="MD", color='k')

axs2[1].set_xlabel(r'$\tau$ / $\mu$s')

axs2[1].set_ylabel(r'$C(\tau)$ / Lys 71 $\phi$ rotamer')
axs2[1].semilogx()
axs2[1].set_xlim((axs2[1].get_xlim()[0],50))


# NOTE(review): unlike the Villin series, bba_synthts are not mean-subtracted here,
# and the [2094:] slice assumes each series has exactly 2094 points — TODO confirm.
for i in range(4):
    _a = np.correlate(bba_synthts[i][:],bba_synthts[i][:], mode='full')[2094:]
    axs4[1].plot(np.arange(1,2094)*60./1000.,_a/_a[0], label="Without state {:d}".format(i+1) )

# MD ACF for BBA, already loaded from 'bba_md_acf.txt'.
_a = bba_md_acf
axs4[1].plot(np.arange(1,1277)*60./1000.,_a/_a[0], label="MD", color='k')

axs4[1].set_xlabel(r'$\tau$ / $\mu$s')

axs4[1].set_ylabel(r'$C(\tau)$ / Glu 17 $\phi$ rotamer')
axs4[1].semilogx()
# Lower limit taken from the Villin ACF panel.
axs4[1].set_xlim((axs2[1].get_xlim()[0],50))


# Shared legend: Villin ACF lines plus the population-panel handles.
ax3 = plt.subplot(gs[60+165:, :])
ax3.legend([child for child in axs2[1].get_children() if isinstance(child, mpl.lines.Line2D)]+[l, l2, l3],
                 [child.get_label() for child in axs2[1].get_children() if isinstance(child, mpl.lines.Line2D)]+["DGM", "HMM (full MD)", "State left out"]
                 ,loc=(-0.025,-0.28), ncol=5)
ax3.axis('off')
axs[0].text(-.35, 1.30, "B", transform=axs[0].transAxes,
          fontsize=12,  va='top')
axs2[0].text(-.35, 1.30, "C", transform=axs2[0].transAxes,
          fontsize=12,  va='top')
axs2[1].text(-.35, 1.30, "D", transform=axs2[1].transAxes,
          fontsize=12,  va='top')

axs3[0].text(-.35, 1.30, "F", transform=axs3[0].transAxes,
          fontsize=12,  va='top')
axs4[0].text(-.35, 1.30, "G", transform=axs4[0].transAxes,
          fontsize=12,  va='top')
axs4[1].text(-.35, 1.30, "H", transform=axs4[1].transAxes,
          fontsize=12,  va='top')

# Rendered structures of the HMM states (pre-made PNGs).
gs_renders= [plt.subplot(gs[33*(i)+2:33*(i+1)+2, 29*(j):29*(j+1)]) for i,j in product(range(2), range(3))]
gs_renders_= [plt.subplot(gs[33*(i)+2:33*(i+1)+2, 10+29*(j)+90:10+29*(j+1)+90]) for i,j in product(range(2), range(3))]


for i in range(5):
    gs_renders[i].imshow(plt.imread('villin_hmm{:}.png'.format(i)))
    gs_renders[i].axis('off')
    if i == 0:
        gs_renders[i].text(0.1, 0.95, 'A', transform=gs_renders[i].transAxes,
              fontsize=12,  va='top')

gs_renders[-1].axis('off')

for i in range(4):
    gs_renders_[i].imshow(plt.imread('bba_hmm{:}.png'.format(i+1)))
    gs_renders_[i].axis('off')
    if i == 0:
        gs_renders_[i].text(0.1, 0.95, "E", transform=gs_renders_[i].transAxes,
             fontsize=12,  va='top')
gs_renders_[-1].axis('off')
gs_renders_[-2].axis('off')


# NOTE(review): these calls pass both `va` and `verticalalignment`; which one
# wins depends on the matplotlib version — keep only one.
axs3[1].text(.5, 1.30, "BBA", transform=gs_renders_[1].transAxes,
          fontsize=15,  va='top', horizontalalignment='center', verticalalignment='center')
axs[1].text(0.5, 1.30, "Villin", transform=gs_renders[1].transAxes,
          fontsize=15,  va='top', horizontalalignment='center', verticalalignment='center')
plt.savefig('Villin_statdist_corrfunc.pdf', dpi=300)

In [None]:
from scipy.optimize import linear_sum_assignment

In [None]:

# TICA of the all-data dMRF trajectory (±1 spins mapped to {0, 1}) at dMRF lags 1-5.
tica_dmrfs = [pe.coordinates.tica(data=[(syntht_alldata+1)/2], lag=lag) for  lag in [1,2,3,4,5]]

fig, ax = plt.subplots(ncols = 2, figsize=(8, 3))
ax[0].semilogy([to.lag for to in tica_dmrfs], [to.timescales[:10] for to in tica_dmrfs])
ax[0].set_xlabel('lag time / steps')
ax[0].set_ylabel('implied timescale / steps')
ax[1].plot([to.lag for to in tica_dmrfs], [to.ndim for to in tica_dmrfs])
ax[1].set_xlabel('lag time / steps')
ax[1].set_ylabel('number of dimensions for 95% kinetic variance')
fig.tight_layout()

# Projection at lag 2 (element 1).
hva_saa=tica_dmrfs[1].get_output()

# Dedicated figure: plt.hist2d in this cell would otherwise draw onto the current
# axes, i.e. over the ndim plot in ax[1].
fig_tic, ax_tic = plt.subplots()
a = ax_tic.hist2d(hva_saa[0][:,0],hva_saa[0][:,1], norm=mpl.colors.LogNorm(), bins=256)
ax_tic.set_xlabel('TIC1')
ax_tic.set_ylabel('TIC2')

# 1024 k-means microstates on the dMRF TICA projection.
cluster_dmrf = pe.coordinates.cluster_kmeans(hva_saa, 1024, stride=10)




In [None]:
# MSM of the all-data dMRF microstates at a lag of one dMRF step.
dmrf_msm = pe.msm.estimate_markov_model(dtrajs=cluster_dmrf.dtrajs, lag=1)

In [None]:
# 20 slowest dMRF-MSM timescales, rescaled by 300 (the dMRF lag in MD frames).
plt.semilogy(300 * dmrf_msm.timescales()[:20], 'o')

In [None]:
# Bayesian 6-state HMM implied timescales of the dMRF microstates, lags 1-4.
hmm_lag_scan = [1, 2, 3, 4]
dmrf_hmm_ts = pe.msm.timescales_hmsm(cluster_dmrf.dtrajs, nstates=6,
                                     lags=hmm_lag_scan, errors='bayes')


In [None]:
# Implied timescales of the dMRF HMMs in µs. One dMRF step spans the dMRF lag of
# 300 MD frames; at 0.2 ns per frame (the 0.2*1e-3 µs factor used for HMM MFPTs),
# that is 0.3*0.2 = 0.06 µs. This matches the 0.3*0.2 used for dMRF lifetimes and
# the 60/1000 used for the ACF time axes. The previous dt=.3 overstated the
# timescales five-fold.
pe.plots.plot_implied_timescales(dmrf_hmm_ts, ylog=False, dt=0.3*0.2, units='us')

In [None]:
# HMM estimated at the first lag of the scan.
hmm_lag_index = 0
dmrf_hmm_ = dmrf_hmm_ts.models[hmm_lag_index]

In [None]:
# Mean ±1 configuration of every MD metastable state (columns = states)...
md_frame_states = HMM.metastable_assignments[cluster_obj.dtrajs[0][:]]
HMM_MD_configurations = np.array([2 * dfeats_fixed[0][np.where(md_frame_states == j)[0], :].mean(axis=0) - 1
                                  for j in range(HMM.nstates)]).T

# ...and of every hidden state of the all-data dMRF HMM.
hidden_path = dmrf_hmm_.hidden_state_trajectories[0]
HMM_dMRF_configurations = np.array([syntht_alldata[hidden_path == i].mean(axis=0)
                                    for i in range(dmrf_hmm_.nstates)]).T

In [None]:
# Relative residual spread std(MD_i - dMRF_j) / std(MD_i) between MD state i and
# dMRF-HMM state j (smaller = closer match); rows MD states, columns dMRF states.
correlates_ = np.array([[np.std(HMM_MD_configurations[:, i] - HMM_dMRF_configurations[:, j])
                         / np.std(HMM_MD_configurations[:, i])
                         for j in range(dmrf_hmm_.nstates)]
                        for i in range(HMM.nstates)], dtype=float).reshape(HMM.nstates, dmrf_hmm_.nstates)

In [None]:
# For each leave-one-out DGM: TICA lag scan, TIC1/TIC2 density, k-means, and a
# Bayesian HMM (hidden-state counts chosen per model in __nstates).
HMM_blinded_dmrfs=[]
msm_blinded_msms=[]
__nstates=[6,5,5,4,5]
for k in range(5):
    tica_dmrfs = [pe.coordinates.tica(data=[(synthts[k]+1)/2], lag=lag) for  lag in [1,2,3,4,5]]

    fig, ax = plt.subplots(ncols = 2, figsize=(8, 3))
    ax[0].semilogy([to.lag for to in tica_dmrfs], [to.timescales[:10] for to in tica_dmrfs])
    ax[0].set_xlabel('lag time / steps')
    ax[0].set_ylabel('implied timescale / steps')
    ax[1].plot([to.lag for to in tica_dmrfs], [to.ndim for to in tica_dmrfs])
    # ax[1] shows lag vs number of TICs (previously mislabelled TIC1/TIC2).
    ax[1].set_xlabel('lag time / steps')
    ax[1].set_ylabel('number of dimensions for 95% kinetic variance')
    fig.tight_layout()

    # Projection at lag 2 (element 1).
    hva_saa=tica_dmrfs[1].get_output()

    # Dedicated figure so the histogram does not paint over ax[1].
    fig_tic, ax_tic = plt.subplots()
    a = ax_tic.hist2d(hva_saa[0][:,0],hva_saa[0][:,1], norm=mpl.colors.LogNorm(), bins=256)
    ax_tic.set_xlabel('TIC1')
    ax_tic.set_ylabel('TIC2')

    cluster_dmrf = pe.coordinates.cluster_kmeans(hva_saa, 1024, stride=10)
    HMM_blinded_dmrfs.append(pe.msm.bayesian_hidden_markov_model([dt.reshape(-1) for dt in cluster_dmrf.get_output()], __nstates[k], lag=1))

In [None]:
# Hidden-state count of each leave-one-out HMM.
[model.nstates for model in HMM_blinded_dmrfs]

In [None]:
def _state_lifetimes(model, step=0.3*0.2):
    """MFPT from each hidden state to the set of all other states, times `step`."""
    n = model.nstates
    return [step * model.mfpt([j], [i for i in range(n) if i != j]) for j in range(n)]


# Index 0: all-data dMRF HMM; indices 1-5: leave-one-out HMMs (state k+1 omitted).
dmrf_lifetimes = [_state_lifetimes(dmrf_hmm_)]
dmrf_lifetimes.extend(_state_lifetimes(HMM_blinded_dmrfs[k]) for k in range(5))

In [None]:
# Pairwise MFPTs of the MD HMM converted with 0.2*1e-3 (to µs, matching the
# lifetime axis label used below).
n_md_states = HMM.nstates
hmm_mfpts = np.array([[0.2 * 1e-3 * HMM.mfpt(i, j) for j in range(n_md_states)]
                      for i in range(n_md_states)], dtype=float).reshape(n_md_states, n_md_states)

In [None]:
# Lifetime of each MD metastable state: MFPT to the union of the other states.
lifetimes_hmm = []
for j in range(HMM.nstates):
    other_states = [i for i in range(HMM.nstates) if i != j]
    lifetimes_hmm.append(0.2 * 1e-3 * HMM.mfpt([j], other_states))

In [None]:
# Figure: for each leave-one-out DGM, (A) mean configurations of matched HMM
# states, MD (x) vs DGM (y), and (B) lifetimes of the matched DGM-HMM states
# (bars) against the MD HMM lifetimes (stars). Matching uses a linear assignment on
# the relative residual spread.
fig = plt.figure(figsize=(single_column_width*1.3, 1.3*double_column_width*2./3))
from matplotlib import patheffects
gs = gridspec.GridSpec(160, 120, left=0.15,bottom=0.08,top=0.95,right=0.90)
ax_feats = [plt.subplot(gs[35*(i):35*(i+1), 40*(j):40*(j+1)]) for i,j in product(range(2), range(3))]
ax_lifetimes = [plt.subplot(gs[35*(i)+90:35*(i+1)+90, 40*(j):40*(j+1)]) for i,j in product(range(2), range(3))]

# MD metastable-state mean configurations (±1 encoding) do not depend on k, so
# they are computed once instead of once per model.
HMM_MD_configurations = np.array([2*dfeats_fixed[0][np.where(HMM.metastable_assignments[cluster_obj.dtrajs[0][:]]==j)[0],:].mean(axis=0)-1 for j in range(HMM.nstates)]).T

for k,_dmrf_hmm in enumerate(HMM_blinded_dmrfs):
    HMM_dMRF_configurations= np.array([synthts[k][_dmrf_hmm.hidden_state_trajectories[0]==i].mean(axis=0) for i in range(_dmrf_hmm.nstates)]).T
    # Relative residual spread between MD state i and DGM-HMM state j (lower = better).
    correlates_ = np.zeros((HMM.nstates,_dmrf_hmm.nstates))
    for i,j in product(range(HMM.nstates), range(_dmrf_hmm.nstates)):
        correlates_[i, j] = np.std(HMM_MD_configurations[:,i]-HMM_dMRF_configurations[:,j])/np.std(HMM_MD_configurations[:,i])

    cax = ax_feats[k]
    cax2 = ax_lifetimes[k]

    # Panels k=0,1 sit above others; k=2 (top right) has no panel below it.
    if k<2:
        cax.set_xticks([])
        cax2.set_xticks([])
    else:
        cax2.set_xlabel('Metastable state')
        cax2.set_xticks(range(1,6))
        cax.set_xlabel('Avg. MD feature')
    if k in [1,2,4]:
        cax.set_yticks([])
        cax2.set_yticks([])
    else:
        cax.set_ylabel('Avg. DGM feature')
        cax2.set_ylabel(r'lifetime / $\mu s$')

    # Pair MD state i with DGM-HMM state j by minimum-cost assignment.
    for i,j in zip(*linear_sum_assignment(correlates_)):
        cax.scatter(HMM_MD_configurations[:,i], HMM_dMRF_configurations[:,j], s=1, color=f'C{i}')
        t=cax.text(-1,0.8-0.3*i, r'$\rho=%.2f$'%(np.corrcoef(HMM_MD_configurations[:,i], HMM_dMRF_configurations[:,j])[0,1]), color=f'C{i}')
        t.set_path_effects([patheffects.Stroke(linewidth=0.5, foreground='black'),
                   patheffects.Normal()])
        cax2.bar(i+1, dmrf_lifetimes[k+1][j], log=False, color=f'C{i}')
    # NOTE(review): fixed upper limit of 0.9 µs clips any longer lifetime — confirm.
    cax2.set_ylim([0,0.9])
    cax.text(0,-1, f"Without {k+1}", va='center',ha='center')
    cax2.text(3,0.85, f"Without {k+1}", va='top',ha='center')
    cax2.scatter(range(1,6),lifetimes_hmm,s=15,color='C7', zorder=10, marker="*",lw=0.1,edgecolors='k')


ax_feats[-1].axis('off')
ax_lifetimes[-1].axis('off')

ax_feats[0].text(-.35, 1.15, "A", transform=ax_feats[0].transAxes,
          fontsize=12,  va='top')

ax_lifetimes[0].text(-.35, 1.15, "B", transform=ax_lifetimes[0].transAxes,
          fontsize=12,  va='top')

fig.savefig('VILLIN_DGM_METASTABLE.pdf')