## Dynamic graphical model of fast folder BBA
Simon Olsson 2018

This notebook makes use of previously published simulation data ([Lindorff-Larsen et al. 2011](http://science.sciencemag.org/content/334/6055/517)). We do not have the rights to distribute these data, but they can be requested directly from D. E. Shaw Research.

Note that the random seed is fixed. The notebook makes extensive use of nested point estimates (i.e., statistical estimates that are subject to fluctuations and on which further estimates are conditioned), and the results can therefore be numerically unstable for some random seeds.


In [None]:
import mdtraj as md
%matplotlib inline
import matplotlib.pyplot as plt

import matplotlib as mpl
import numpy as np
import pyemma as pe
from onedeeIsing import Ising_tmatrix
from scipy.spatial import distance_matrix
from sklearn.preprocessing import LabelBinarizer
from graphtime.markov_random_fields import estimate_dMRF
from graphtime.utils import simulate_MSM as generate

# Figure widths (inches, as passed to figsize below) for double- and
# single-column publication figures.
double_column_width = 6.968
single_column_width= 3.307
# Global figure font: 8 pt Arial as the sans-serif family.
font = {'sans-serif': "Arial",
        'family': "sans-serif",
  'size'   : 8}

mpl.rc('font', **font)
mpl.rcParams['mathtext.fontset'] = 'custom'
# Fixed NumPy global seed: the nested estimates in this notebook can be
# unstable for other seeds (see introduction).
np.random.seed(313808)

In [None]:
# Note:
# we do not have the rights to distribute these data.
# Please inquire about them directly with D. E. Shaw Research.
# Adjust paths as necessary.
# Featurize both 1FME trajectories (112 and 52 DCD chunks, respectively) by
# their backbone phi/psi torsions in degrees and load the features.
feat = pe.coordinates.featurizer('../DESRES-Trajectory_1FME-0-protein/1FME-0-protein.pdb')
feat.add_backbone_torsions(deg=True)
source = pe.coordinates.source([['../DESRES-Trajectory_1FME-0-protein/1FME-0-protein-{:03d}.dcd'.format(i) for i in range(112)],
                                ['../DESRES-Trajectory_1FME-1-protein/1FME-1-protein-{:03d}.dcd'.format(i) for i in range(52)]], features=feat)
# dihe: presumably one (n_frames, n_features) array per trajectory.
dihe = source.get_output()

In [None]:
def discr_feats(ftrajs, feat_describe):
    """Discretize torsion feature trajectories into a few states per feature.

    Parameters
    ----------
    ftrajs : list of ndarray
        One (n_frames, n_features) array of angles in degrees per trajectory.
    feat_describe : list of str
        One label per feature column (e.g. ``"PHI 0 GLU 2"``). The first three
        characters select the rule; the last whitespace-separated token is
        read as the residue sequence number.

    Returns
    -------
    discr_trajs : list of ndarray
        Integer trajectories with the PSI columns of residue 1 removed.
        PHI -> 1 if phi < 0 else 0; PSI -> 1 if psi < 80 else 0;
        CHI -> one of three rotamer bins {0, 1, 2}; other features -> 0.
    labels : list of str
        Entries of ``feat_describe`` for the retained columns, in order.
    """
    # The retained columns depend only on the labels, so derive them once.
    # Deriving them from the data broke for an empty ``ftrajs`` (NameError)
    # and for zero-length trajectories (IndexError on frame 0).
    keep = [i for i, fstr in enumerate(feat_describe)
            if not (fstr[:3] == "PSI" and int(fstr.split()[-1]) == 1)]
    discr_trajs = []
    for ft in ftrajs:
        dftraj = np.zeros(ft.shape, dtype=int)
        for i, fstr in enumerate(feat_describe):
            kind = fstr[:3]
            if kind == "PHI":  # split into two states
                dftraj[:, i] = (ft[:, i] < 0).astype(int)
            elif kind == "PSI":  # split into two states; N-terminal psi is dropped below
                dftraj[:, i] = (ft[:, i] < 80).astype(int)
            elif kind == "CHI":  # split into 3 rotamers
                tv = (ft[:, i] + 180 + 60) % 360
                dftraj[:, i] = (tv > 125).astype(int) + (tv > 250).astype(int)
        discr_trajs.append(dftraj[:, keep])
    return discr_trajs, [feat_describe[i] for i in keep]

In [None]:
# Binarize each backbone torsion and drop the residue-1 PSI; nlbls holds the
# labels of the retained feature columns.
dfeats, nlbls = discr_feats(dihe, feat.describe())

In [None]:
# remap features {0, 1} -> {-1, 1} (spin encoding used by the dMRF).
# Only zeros are rewritten; any other value is left unchanged.
dfeats_fixed = []
for df in dfeats:
    _t = df.copy()
    _t[np.where(_t==0)] = -1  
    dfeats_fixed.append(_t)

In [None]:
# dMRF estimated on all MD data with L1-penalized logistic regression (saga).
# Regularization C = 1000/sqrt(N), where N is the length of the FIRST
# trajectory only.
C = 1000./len(dfeats_fixed[0])**0.5
logistic_regression_kwargs={'fit_intercept': True, 'penalty': 'l1', 'C': C, 
                            'tol': 0.0001, 'solver': 'saga'}
# Spins are encoded as -1/+1 by the LabelBinarizer.
dmrf_all_data = estimate_dMRF(dfeats_fixed, 
            lag=300, stride=10, 
            logistic_regression_kwargs=logistic_regression_kwargs,
            Encoder = LabelBinarizer(neg_label = -1, pos_label = 1))

### Build MSM
TICA dimensionality reduction, clustering and lag-time optimization

In [None]:
import pyemma as pe

In [None]:
# TICA on the +/-1 features for a range of lag times (in frames), to pick a lag.
tica_objs = [pe.coordinates.tica(dfeats_fixed, lag = lag) for lag in [5,10,50,100,200,300,500,900,1500,2000,2500,3000]]

In [None]:
# Left: first 10 TICA implied timescales vs. lag.
# Right: number of TICA dimensions retained (labelled as 95% kinetic variance) vs. lag.
fig, ax = plt.subplots(ncols = 2, figsize=(8, 3))
ax[0].semilogy([to.lag for to in tica_objs], [to.timescales[:10] for to in tica_objs])
ax[0].set_xlabel('lag time / steps')
ax[0].set_ylabel('implied timescale / steps')
ax[1].plot([to.lag for to in tica_objs], [to.ndim for to in tica_objs])
ax[1].set_xlabel('lag time / steps')
ax[1].set_ylabel('number of dimensions for 95% kinetic variance')
fig.tight_layout()


In [None]:
# Index into tica_objs of the lag with the fewest TICA dimensions (first
# occurrence on ties); later cells select the TICA model via tica_objs[_].
# NOTE(review): '_' is also IPython's last-output variable. This relies on
# IPython not overwriting '_' once the user has assigned it; a descriptive
# name (e.g. tica_idx) would be more robust.
_ = np.argmin([to.ndim for to in tica_objs])

In [None]:
# Project every trajectory onto the selected TICA model.
Y = tica_objs[_].get_output()

In [None]:
# All trajectories' projections stacked into one array (for histograms).
Ys = np.vstack(Y)

In [None]:
# Log-scaled 2D density of the MD data in the first two TICs.
a=plt.hist2d(Ys[:, 0], Ys[:, 1], norm=mpl.colors.LogNorm(), bins=256)#, interpolation='gaussian')

In [None]:
# k-means with 384 centres on the TICA projection; stride=2 subsamples the
# data used for the fit.
cluster_obj = pe.coordinates.cluster_kmeans(data = Y, k=384, stride=2)

In [None]:
# Implied timescales (Bayesian errors) on the discrete trajectories subsampled
# every 10 frames, so these lags are in units of 10 frames.
its = pe.msm.its([dt[::10] for i,dt in enumerate(cluster_obj.dtrajs) ], lags = [ 5, 10, 20, 30, 50, 70, 90, 100, 120,150,200,250], nits=6, errors='bayes')

In [None]:
# Implied-timescale plot on a log axis.
pe.plots.plot_implied_timescales(its, ylog=True)

In [None]:
# Bayesian MSM on the full-resolution discrete trajectories (``[::1]`` is a
# no-op) at lag 1500 frames, i.e. 150 in the subsampled ITS units above.
msm = pe.msm.bayesian_markov_model([dt[::1] for i,dt in enumerate(cluster_obj.dtrajs) ], lag = 1500)

In [None]:
# Chapman-Kolmogorov test with 4 metastable sets.
ckt = msm.cktest(4)

In [None]:
# Plot the CK test (diag=True: presumably only the self-transition panels).
pe.plots.plot_cktest(ckt, diag=True)

In [None]:
# Coarse-grain the MSM into a 4-state hidden Markov model; its
# metastable_assignments map microstates to metastable states below.
HMM = msm.coarse_grain(4)

In [None]:

# Frame indices of each trajectory that ARE assigned to metastable state i
# (inmeta) and that are NOT (Meta_filtered); outer list over trajectories,
# inner list over states.
# NOTE(review): assumes every microstate is in the active set; a -1 entry in
# discrete_trajectories_active would silently pick the last assignment.
inmeta = [[np.where(HMM.metastable_assignments[dt.reshape(-1)] == i)[0] for i in range(HMM.nstates)]
          for dt in msm.discrete_trajectories_active]

Meta_filtered = [[np.where(HMM.metastable_assignments[dt.reshape(-1)] != i)[0] for i in range(HMM.nstates)]
                 for dt in msm.discrete_trajectories_active]


def _contiguous_runs(frames):
    """Split a sorted array of frame indices into runs of consecutive frames."""
    return np.split(frames, np.where(np.diff(frames) > 1)[0] + 1)


# For each trajectory and left-out state i: contiguous stretches of the
# feature trajectory that never visit state i, keeping those longer than 300
# frames. The split points come from gaps in the out-of-state index list
# itself. The earlier version took them from the in-state list and omitted the
# +1, so segments spanned excursions into state i.
not_in_meta_data = [[[df[t] for t in _contiguous_runs(mf[i]) if len(t) > 300]
                     for i in range(HMM.nstates)]
                    for df, mf in zip(dfeats_fixed, Meta_filtered)]

# Regroup by state. not_in_meta_stacked[i] holds every segment, from any
# trajectory, that avoids state i. This works for any number of trajectories;
# the previous ``a+b`` assumed exactly two.
not_in_meta_stacked = [[seg for per_traj in per_state for seg in per_traj]
                       for per_state in zip(*not_in_meta_data)]

In [None]:
# Per-left-out-state L1 regularization strength, C = 2000/sqrt(N), where N is
# the total number of frames in that state's training segments. Counting rows
# directly avoids building a transformed copy of all data, and iterating over
# not_in_meta_stacked avoids a hard-coded number of states.
regl_ = [2000 / sum(len(seg) for seg in segments)**0.5 for segments in not_in_meta_stacked]
regl_

In [None]:
# One dMRF per left-out metastable state M, trained only on the segments that
# avoid M, with the state-specific regularization regl_[M].
# NOTE(review): range(4) hard-codes the number of metastable states.
dMRFs = []

for M in range(4):
    logistic_regression_kwargs={'fit_intercept': True, 'penalty': 'l1', 'C': regl_[M], 
                            'tol': 0.0001, 'solver': 'saga'}
    dMRFs.append(estimate_dMRF(not_in_meta_stacked[M], 
                               lag=400, stride=10, Encoder = LabelBinarizer(neg_label=-1, pos_label=1),
                          logistic_regression_kwargs=logistic_regression_kwargs
                              ))


In [None]:
# Number of active subsystems in each leave-one-out dMRF.
[len(d.get_active_subsystems()) for d in dMRFs]

Visual comparison of $J(\tau)$ estimated on the different sub-sampled data sets, and correlation of each with the estimate on the full data set.

In [None]:
# Coupling matrix of each leave-one-out dMRF, with its biases appended as the
# last column, shown as an image.
fig,ax=plt.subplots(ncols=4, figsize=(12,10))
for M,d in enumerate(dMRFs):
    ax[M].imshow(np.hstack([d.get_subsystem_couplings(), d.get_subsystem_biases().reshape(-1,1)] ))
fig.tight_layout()
#ax[1].imshow(np.hstack([np.vstack(villin_nf_coupl), villin_nf_bias] ))

In [None]:
# Couplings of each leave-one-out dMRF against those estimated on all data.
# The title gives the number of training data points (Ndp) per model.
# NOTE(review): assumes all models share the same active subsystems, so the
# flattened coupling arrays have equal length.
fig,ax = plt.subplots(ncols=4, figsize=(10,2))
for M in range(4):
    ax[M].scatter(dMRFs[M].get_subsystem_couplings().ravel(),  dmrf_all_data.get_subsystem_couplings().ravel())
    ax[M].set_title("MS {}. Ndp {}".format(M+1, len(np.vstack(not_in_meta_stacked[M]))  ))
    ax[M].set_xlabel('coupling (without state)')
ax[0].set_ylabel('coupling (all data)')
plt.tight_layout()

Generate synthetic trajectories

In [None]:
# 100000-step synthetic trajectory from each leave-one-out dMRF, started from
# the first frame of that model's first training segment.
synthts = [d.simulate(nsteps=100000, start= (np.array(not_in_meta_stacked[M][0][0])) ) for M,d in enumerate(dMRFs)] 

In [None]:
# 100000-step synthetic trajectory from the all-data dMRF, started from the
# first MD frame.
syntht_all_nb = dmrf_all_data.simulate(nsteps=100000, start=dfeats_fixed[0][0])

In [None]:
# Replace any -3 entries of the all-data synthetic trajectory with -1.
# NOTE(review): where -3 values come from is not visible here, and the same
# patch is not applied to the leave-one-out trajectories in synthts; confirm.
syntht_all_nb[syntht_all_nb==-3]=-1

In [None]:
# Project the all-data synthetic trajectory onto the selected TICA model.
Y_synth_all_data = tica_objs[_].transform(syntht_all_nb)

In [None]:

# Project each leave-one-out synthetic trajectory onto the same TICA model that
# produced Y, and hence the k-means clusters and the MD histogram Ys. This is
# how the all-data synthetic trajectory is projected too. The previous
# hard-coded tica_objs[9] only agrees with that model if the selected index
# happens to be 9.
Y_synths = [np.vstack(tica_objs[_].transform(s)) for s in synthts]


In [None]:
# 2x3 grid of TIC1/TIC2 densities: panel 0 = MD data, panels 1-4 = synthetic
# data from the dMRF estimated without state I+1, last panel unused.
fig, axs = plt.subplots(nrows=2, ncols = 3,figsize=(single_column_width,.75*single_column_width), sharex=True, sharey=True)
ax = axs.ravel()
ax[0].hist2d(Ys[:, 0], Ys[:, 1], bins=128,norm=mpl.colors.LogNorm(),  label="all data")
ax[0].set_title('All MD data')
ax[0].set_ylabel('TIC2')

#ax[-1].hist2d(Y_synth_all_data[:, 0], Y_synth_all_data[:, 1], bins=128,norm=mpl.colors.LogNorm(),  label="all data")
#ax[-1].set_title('All MD data')
ax[-1].axis('off')#set_ylabel('TIC2')


for I,ys in enumerate(Y_synths):
    ax[I+1].hist2d(ys[:,0],ys[:,1], bins=128, norm=mpl.colors.LogNorm(), label="missing meta {}".format(I))
    #ax[I+1].scatter(cluster_obj.cluster_centers_[HMM.metastable_assignments==I, 0],  cluster_obj.cluster_centers_[HMM.metastable_assignments==I, 1],marker='^', alpha=0.1, color='r')
    ax[I+1].set_title('Without {}'.format(I+1))
    # ax[3] starts the second row: give it the y label.
    if I==2:
        ax[I+1].set_ylabel('TIC2')
    # ax[3], ax[4] are on the bottom row: x labels.
    if I>1:
        ax[I+1].set_xlabel('TIC1')
fig.tight_layout()
plt.savefig('_tica_leave_one_out_bba.pdf', dpi=300)

In [None]:
#pe.coordinates.save_trajs(source, HMM.sample_by_observation_probabilities(10), "BBA_HMM_", fmt="pdb")

In [None]:
# Assign the all-data synthetic TICA projection to the MD k-means centres.
dtraj_Y_synth_all = cluster_obj.transform(Y_synth_all_data)


In [None]:
# MSM at lag 1 (synthetic step) on the all-data synthetic discrete trajectory.
mrfmsm = pe.msm.estimate_markov_model(dtrajs=dtraj_Y_synth_all.reshape(-1), lag=1)

In [None]:
# Assign each leave-one-out synthetic projection to the MD k-means centres.
dtraj_synths = cluster_obj.transform(Y_synths)
# NOTE(review): recomputes exactly what dtraj_Y_synth_all already holds.
dtraj_synths_all = cluster_obj.transform(Y_synth_all_data)

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
def _colorbar(mappable):
    """Attach a colorbar to the axes that hold ``mappable``.

    A new axes is split off to the right of the host axes (5% width,
    0.05 padding) via ``make_axes_locatable`` and used as the colorbar axes.

    Returns the colorbar created by the host axes' figure.
    """
    host = mappable.axes
    cbar_ax = make_axes_locatable(host).append_axes("right", size="5%", pad=0.05)
    return host.figure.colorbar(mappable, cax=cbar_ax)

In [None]:
# Print the RGB values of matplotlib colour-cycle colours C0..C4, the colours
# used for the per-state markers and bars in the figures, as "set_color" lines.
# The original called ``cmap(i)``, but ``cmap`` is never defined in this
# notebook (NameError on a fresh kernel); the "C{i}" colours are resolved
# directly instead.
for i in range(5):
    print("set_color color{}=".format(i) + str(list(mpl.colors.to_rgb("C{}".format(i)))))

In [None]:
# Exploratory: for each leave-one-out model, the mean synthetic features over
# frames assigned to metastable state index 3, with columns sorted by variance
# and mapped from [-1, 1] to [0, 1]; xx_ is the MD mean over trajectory 0 only.
# NOTE(review): yy_ is column-permuted and rescaled while xx_ is not, so the
# two are not directly comparable; both names are overwritten in the figure
# cell below.
yy_=[(syntht[:,np.argsort(syntht.var(axis=0))[:]][np.where(HMM.metastable_assignments[ye[::1]]==3)[0],:].mean(axis=0)+1)/2 for syntht,ye in zip(synthts, dtraj_synths)]
xx_=dfeats_fixed[0][np.where(HMM.metastable_assignments[cluster_obj.dtrajs[0][:]]==3)[0],:].mean(axis=0)[:]

In [None]:
import matplotlib.gridspec as gridspec
from scipy.stats import spearmanr
from itertools import product

fig = plt.figure(figsize=(double_column_width/1.3,double_column_width*1.1))

gs = gridspec.GridSpec(143, 100, left=0.097,bottom=0.02,top=0.98,right=0.999)
# Panel A (ax): per-feature means; panel B (ax2): per-feature variances of
# each leave-one-out DGM vs. MD. ax3 only hosts the shared legend.
ax = plt.subplot(gs[5:23, 2:22])
ax2 = plt.subplot(gs[5:23, 34:55])
ax3 = plt.subplot(gs[5:23, 70:78])

# Loop-invariant MD quantities, computed once instead of on every iteration.
md_subsampled = np.vstack([df[::300] for df in dfeats_fixed])
md_all = np.vstack(dfeats_fixed)
md_states = HMM.metastable_assignments[np.concatenate(cluster_obj.dtrajs)]

ls = []
for i in range(HMM.nstates):
    ax2.scatter(np.vstack(synthts[i]).var(axis=0), md_subsampled.var(axis=0), s=5, label="Without state {:d}".format(i+1))
    l = ax.scatter(np.vstack(synthts[i]).mean(axis=0), md_subsampled.mean(axis=0), s=5, label="Without state {:d}".format(i+1))
    ls.append(l)
ax.set_xlabel('DGM Mean feature')
ax2.set_xlabel('DGM Variance of feature ')

ax.set_ylabel('MD Mean feature')
ax2.set_ylabel('MD Variance of feature')

ax3.axis('off')
# One label per handle (HMM.nstates models, not a hard-coded 5).
ax3.legend(ls, ["Without state {:d}".format(i+1) for i in range(HMM.nstates)], fontsize=9, loc=(-0.20,0.0))

# Panel C: HMM.nstates x HMM.nstates grid. Row j = metastable state j,
# column k = DGM estimated without state k+1; flat index i = j*nstates + k.
ax_ = np.array([[plt.subplot(gs[40+23*(j)+2*j:40+23*(j+1)+j*2, 23*(i)+2*i:23*(i+1)+2*i]) for i in range(HMM.nstates)] for j in range(HMM.nstates)])

axf = ax_.flatten()

ax.text(-.35, 1.30, "A", transform=ax.transAxes,
          fontsize=12,  va='top')

ax2.text(-.35, 1.30, "B", transform=ax2.transAxes,
          fontsize=12,  va='top')
i=0
axf[i].text(-.35, 1.30, "C", transform=axf[i].transAxes,
          fontsize=12, va='top')
for j in range(HMM.nstates):
    # MD mean feature vector over all frames assigned to state j.
    xx_=md_all[np.where(md_states==j)[0],:].mean(axis=0)
    for syntht,ye in zip(synthts, dtraj_synths):
        # DGM mean feature vector over synthetic frames assigned to state j.
        yy_=syntht[np.where(HMM.metastable_assignments[ye]==j)[0],:].mean(axis=0)
        # DGM on x and MD on y, matching panel A and the figure-level axis
        # labels below; the original plotted MD on x here, contradicting them.
        axf[i].scatter(yy_, xx_,s=5, color="C{:d}".format(j))

        axf[i].text(0.05, 0.95, r"$\rho={:0.2f}$".format(np.corrcoef(xx_,yy_)[0,1]), transform=axf[i].transAxes,
          fontsize=9, va='top')

        axf[i].set_xlim((-1,1))
        axf[i].set_ylim((-1,1))

        if i<HMM.nstates:
            axf[i].set_title('State {}'.format(i+1), fontsize = 10)

        # Ticks only on the left column and the bottom row; the shared axis
        # labels are drawn once with fig.text below. The per-axes label
        # branches they replace (i==10, i==22) could never trigger.
        if i%HMM.nstates!=0:
            axf[i].set_yticklabels([])
            axf[i].set_yticks([])

        if i<HMM.nstates*(HMM.nstates-1):
            axf[i].set_xticklabels([])
            axf[i].set_xticks([])

        # Shade the panel whose state j was left out of this model.
        if i%HMM.nstates==j:
            axf[i].set_facecolor((0.85,0.85,0.85))
        i=i+1
fig.text(0,0.42 ,'MD Mean feature' , rotation=90)
fig.text(0.47,0.01 ,'DGM Mean feature' , rotation=0)

plt.savefig('feature_scatter_BBA.pdf')

In [None]:
from scipy import spatial

In [None]:
from scipy import spatial
def mrftraj_to_dtraj(mrftraj, ftrajs, transformer = lambda x:0.5*(x+1)):
    """Map each synthetic dMRF frame to its nearest frame in ``ftrajs``.

    Nearness is the Hamming distance (fraction of differing features) from
    ``scipy.spatial.distance.cdist``.

    Parameters
    ----------
    mrftraj : iterable of 1-D array
        Synthetic configurations, one per step.
    ftrajs : list of 2-D array
        Feature trajectories (frames x features), searched as one
        concatenated array.
    transformer : callable, optional
        Applied to each synthetic frame before comparison. The default maps
        -1/+1 to 0/1; pass ``lambda x: x`` when ``ftrajs`` use the same
        encoding as ``mrftraj``.

    Returns
    -------
    dtraj_out : list of [trajectory index, frame index]
        Location of the closest frame (first one on ties).
    errors : list of float
        Hamming distance to that frame.
    """
    dtraj_out = []
    errors = []
    # ftrajl[k] is the global index of the first frame of trajectory k.
    ftrajl = np.cumsum([0]+[len(d) for d in ftrajs])
    ftraj = np.vstack(ftrajs)
    for m in mrftraj:
        pair_contacts = spatial.distance.cdist(transformer(m).reshape(1,-1), ftraj, metric='hamming').reshape(-1)
        idx = np.argmin(pair_contacts)
        # Trajectory containing global frame idx: the last start offset <= idx.
        # The former ``(ftrajl < idx)`` test assigned the first frame of every
        # later trajectory to the preceding one, with an out-of-range frame index.
        tidx = np.searchsorted(ftrajl, idx, side='right') - 1
        dtraj_out.append([tidx, idx - ftrajl[tidx]])
        errors.append(pair_contacts[idx])

    return dtraj_out, errors

In [None]:
# Map the first 1000 all-data synthetic frames to their nearest MD frames
# (both +/-1 encoded, so the identity transformer). The same computation is
# repeated in a later cell.
resampledtraj,errors = mrftraj_to_dtraj(syntht_all_nb[:1000], dfeats_fixed, transformer=lambda x:x)

In [None]:
# For each leave-one-out model: map its first 1000 synthetic frames to the
# nearest MD frames, keep the [trajectory, frame] indices (rtrajs) and Hamming
# errors (errs), and write the matched MD frames to a PDB file.
rtrajs = []
errs = []
for i in range(HMM.nstates):
    _resampledtraj,_errors = mrftraj_to_dtraj(synthts[i][:1000], dfeats_fixed, transformer=lambda x:x)
    rtrajs.append(_resampledtraj)
    errs.append(_errors)
    pe.coordinates.save_traj(source, np.array(_resampledtraj, dtype=int), "Resampled_bba_without_{}.pdb".format(i))

In [None]:
# Same computation as the earlier all-data resampling cell (recomputed here).
resampledtraj,errors = mrftraj_to_dtraj(syntht_all_nb[:1000], dfeats_fixed, transformer=lambda x:x)

In [None]:
# Write the MD frames matched to the all-data dMRF trajectory. Each entry of
# resampledtraj is already a [trajectory index, frame index] pair, the same
# (N, 2) index format passed to save_traj for the leave-one-out models. The
# previous np.vstack((np.zeros((1,1000)), resampledtraj)).T raised on that
# (N, 2) input and would also have pinned every frame to trajectory 0.
pe.coordinates.save_traj(source, np.array(resampledtraj, dtype=int), "bba_subsampled.pdb")

Export data

In [None]:
# MD feature configurations matched to the synthetic trajectory of the model
# estimated without state 4 (index 3); computed once and reused below.
resampled_feats_3 = np.array([dfeats_fixed[a][f, :] for a, f in np.array(rtrajs[3])])
np.savez("resmtraj3_bba.npz", data=resampled_feats_3)

# Metastable-state assignment of those frames. They are projected with the same
# TICA model used to build cluster_obj (tica_objs[_]); the previous hard-coded
# tica_objs[9] matches it only if the selected index happens to be 9.
np.savez("state_assign_resmtraj_3_bba.npz",
         data=HMM.metastable_assignments[cluster_obj.transform(tica_objs[_].transform(resampled_feats_3)).reshape(-1)])

# Hamming reconstruction errors of the matched frames.
np.savez("recerr_resmtraj_3_bba.npz", data=errs[3])

In [None]:
# Feature column 30; a later figure labels it the "Glu 17 phi rotamer".
sub_sys = 30
# First 2094 synthetic steps of that feature for each leave-one-out model,
# with each model's own mean removed.
meanfree_synthtrajs = [synthts[i][:2094,sub_sys]-(synthts[i][:2094,sub_sys]).mean(axis=0) for i in range(HMM.nstates)]
# MD series subsampled every 400 frames (the lag used for the leave-one-out
# dMRFs), with the mean of the subsampled concatenated data removed.
meanfree_MD = [df[::400,sub_sys]-np.vstack(dfeats_fixed)[::400,sub_sys].mean(axis=0) for df in dfeats_fixed]

In [None]:
# Quick look: unnormalized autocorrelation of MD trajectory index 1 for lags
# 0..n-1 (plotted at x = 1..n); print its zero-lag value.
plt.semilogx(np.arange(1,len(meanfree_MD[1])+1    ), np.correlate(meanfree_MD[1],meanfree_MD[1],mode='full')[len(meanfree_MD[1])-1:])
print(np.correlate(meanfree_MD[1],meanfree_MD[1]))

In [None]:
# Lengths of the subsampled, mean-free MD series.
mfl=[len(mf) for mf in meanfree_MD]

In [None]:
def func(x, a, b, c):
    """Single exponential with offset: ``a * exp(-b * x) + c``.

    Model for ``curve_fit``; the fitted ``b`` is reported as the relaxation
    rate.
    """
    exponent = -b * x
    return c + a * np.exp(exponent)

In [None]:
from scipy.optimize import curve_fit

In [None]:
# Fit a*exp(-b*tau)+c to the unnormalized autocorrelation and report the rate b.
# Synthetic models: lags 1..499 of the first 500 steps. MD: average over
# trajectories, lags 1..500. Lags are scaled by 60/1000 per step (the later
# figure labels this axis in microseconds).
for i in range(4):
    popt, pcov = curve_fit(func, np.arange(1,500)*60/1000., np.correlate(meanfree_synthtrajs[i][:500],meanfree_synthtrajs[i][:500], mode='full')[500:])
    print("relaxation rate {:0.3f} for without state {:d}".format(popt[1],i+1))

popt, pcov = curve_fit(func, np.arange(1,501)*60./1000., np.mean([np.correlate(mf,mf,mode='full')[l:][:min(mfl)-1] for mf,l in zip(meanfree_MD,mfl)], axis=0)[:500])
print("relaxation rate MD", popt[1])    


In [None]:
from itertools import product
fig = plt.figure(figsize=(single_column_width,1.8*1.75*single_column_width/2))

gs = gridspec.GridSpec(180, 90,left=0.16,bottom=0.05,top=0.95,right=0.95, wspace=0.0, hspace=0.)
# Panel A: one small axes per leave-one-out model (3 on the first row, 1 on
# the second), plus a full-width strip on top that only hosts the legend.
axs = [plt.subplot(gs[30*(i)+15:30*(i+1)+15, 3+29*(j):3+29*(j+1)]) for i,j in product(range(1), range(3))]
axs.append(plt.subplot(gs[30*(1)+15:30*(1+1)+15, 3+29*(0):3+29*(0+1)]))


axs.append(plt.subplot(gs[:15, :]))

# Hidden helper axes, used only to position a shared y label.
ax_ylbl = plt.subplot(gs[30:,:3])
ax_ylbl.set_ylabel(r'$p(x)$ (log-scale)')
ylbl=ax_ylbl.axes.yaxis.get_label()
fig.text(ylbl.get_position()[0]+0.05,1.5*ylbl.get_position()[1]+0.09*2/3 , ylbl.get_text(), rotation=90)



ax_ylbl.axis('off')


ls = []
for i, ax in enumerate(axs):
    ax.set_ylim((2e-4,2.0))
    if i>3:
        ax.axis('off')
    if i==4:
        # Legend strip; l, l2, l3 are the handles from the last bar panel.
        ax.legend([l, l2, l3], ["dMRF", "HMM (full MD)", "State left out during estimation"], fontsize=8, loc=(0.27,0.2))
        ax.set_yticks([])
        ax.set_xticks([])
        ax.axis('off')
        
    else:
        # Metastable-state populations from a lag-1 MSM on the synthetic
        # trajectory of the model estimated without state i+1, mapped onto the
        # MD HMM states, overlaid on the HMM's own stationary distribution.
        # NOTE(review): if a state is never visited, the MSM has fewer than 4
        # states and the bar call below fails.
        ising_msm = pe.msm.estimate_markov_model(HMM.metastable_assignments[dtraj_synths[i].reshape(-1)], lag=1)
        l = ax.bar(range(1,5), ising_msm.stationary_distribution,hatch="//", fill=False, label="Ising")
        np.savetxt('bba_hist_dmrf{:}.txt'.format(i), ising_msm.stationary_distribution)
        ls.append(l)
        l2 = ax.bar(range(1,5), HMM.stationary_distribution,fill=True, alpha=0.2, log=True, label="HMM (full MD)")
        ls.append(l2)
        # Star marks the state left out during estimation.
        l3 = ax.scatter([i+1], 0.5, marker="*", s=50, color='purple')

        if i>0:
            ax.set_xticks([1,2,3,4])
            if i>2:
                ax.set_xlabel(r'Meta-stable state / $x$')

        else:
            ax.set_xticks([])

        if i in [0, 3]:
            continue
        else:
            ax.set_yticks([])
np.savetxt('bba_hist_hmm.txt', HMM.stationary_distribution)

fig.text(0.58,0.65, r"Meta-stable state / $x$", rotation=0)

# Panels B and C side by side below panel A.
axs2 = [plt.subplot(gs[105:145, 35*i+20*i:35*(i+1)+20*i]) for i in range(2) ]

# Panel B: distribution of reconstruction (Hamming) errors per model.
bins = np.unique(np.concatenate(errs))
for i, err in enumerate(errs):
    # ``normed`` was removed from Axes.hist; ``density`` is its replacement.
    axs2[0].hist(err, bins=bins, label = "Without state {:d}".format(i+1), histtype='step', lw=1, density=True, log=False)
axs2[0].set_xlabel(r'$\epsilon $')
axs2[0].set_ylabel(r'$p(\epsilon)$')

# Panel C: normalized autocorrelation of feature sub_sys for each model and MD.
for i in range(HMM.nstates):
    a_ = np.correlate(meanfree_synthtrajs[i][:2094],meanfree_synthtrajs[i][:2094], mode='full')[2094:]
    axs2[1].plot(np.arange(1,2094)*60./1000.,a_/a_[0], lw=1, label="Without state {:d}".format(i+1) )
a_ = np.mean([np.correlate(mf,mf,mode='full')[l:][:min(mfl)-1] for mf,l in zip(meanfree_MD,mfl)], axis=0)
# Lag grid sized from the data (len(a_) == min(mfl)-1) instead of a
# hard-coded 1276 points.
axs2[1].plot(np.arange(1,len(a_)+1)*60./1000.,a_/a_[0],lw=1, label="MD", color='k')
axs2[1].set_xlabel(r'$\tau$ / $\mu$s')

axs2[1].set_ylabel(r'$C(\tau)$ / Glu 17 $\phi$ rotamer')
axs2[1].semilogx()
axs2[1].set_xlim((axs2[1].get_xlim()[0],100))
# Shared legend for panel C on a hidden axes at the bottom.
ax3 = plt.subplot(gs[165:, :])
ax3.legend([child for child in axs2[1].get_children() if isinstance(child, mpl.lines.Line2D)],
                 [child.get_label() for child in axs2[1].get_children() if isinstance(child, mpl.lines.Line2D)]
                 ,loc=(0.0,-0.4), ncol=2)
ax3.axis('off')
axs[0].text(-.35, 1.30, "A", transform=axs[0].transAxes,
          fontsize=12,  va='top')
axs2[0].text(-.35, 1.30, "B", transform=axs2[0].transAxes,
          fontsize=12,  va='top')
axs2[1].text(-.35, 1.30, "C", transform=axs2[1].transAxes,
          fontsize=12,  va='top')
plt.savefig('statdist_err_acf_bba.pdf')

Data exports for plotting

In [None]:
# Per-model reconstruction errors (one row per leave-one-out model).
np.savetxt('bba_errs.txt', errs)

# Mean-free synthetic series of feature sub_sys (one row per model).
np.savetxt('bba_synthtrajs.txt', meanfree_synthtrajs)

# Cluster assignment of the all-data synthetic trajectory.
np.savetxt('bba_syntht_all.txt', dtraj_synths_all.reshape(-1))

HMM.save('bba_hmm.pyemma', overwrite=True)

# Trajectory-averaged MD autocorrelation (same expression as in the figure cell).
a_ = np.mean([np.correlate(mf,mf,mode='full')[l:][:min(mfl)-1] for mf,l in zip(meanfree_MD,mfl)], axis=0)

np.savetxt('bba_md_acf.txt', a_)

In [None]:

# TICA of the all-data synthetic trajectory, mapped from -1/+1 to 0/1, for
# lags 1..5 synthetic steps.
tica_dmrfs = [pe.coordinates.tica(data=[(syntht_all_nb+1)/2], lag=lag) for  lag in [1,2,3,4,5]]

fig, ax = plt.subplots(ncols = 2, figsize=(8, 3))
ax[0].semilogy([to.lag for to in tica_dmrfs], [to.timescales[:10] for to in tica_dmrfs])
ax[0].set_xlabel('lag time / steps')
ax[0].set_ylabel('implied timescale / steps')
ax[1].plot([to.lag for to in tica_dmrfs], [to.ndim for to in tica_dmrfs])
ax[1].set_xlabel('lag time / steps')
ax[1].set_ylabel('number of dimensions for 95% kinetic variance')
fig.tight_layout()


# Projection on the lag-2 model (index 1).
hva_saa=tica_dmrfs[1].get_output()

# TIC1/TIC2 density on its own axes: a bare plt.hist2d here would draw on the
# current axes, i.e. on top of the dimensions-vs-lag panel ax[1].
fig_tic, ax_tic = plt.subplots()
a = ax_tic.hist2d(hva_saa[0][:,0],hva_saa[0][:,1], norm=mpl.colors.LogNorm(), bins=256)
ax_tic.set_xlabel('TIC1')
ax_tic.set_ylabel('TIC2')


# 1024 k-means centres, then a lag-1 MSM on the resulting discrete trajectory.
cluster_dmrf = pe.coordinates.cluster_kmeans(hva_saa, 1024, stride=10)

dmrf_msm = pe.msm.estimate_markov_model(cluster_dmrf.dtrajs, lag=1)


In [None]:
# First 15 implied timescales (in synthetic steps) of the all-data dMRF MSM.
plt.semilogy(dmrf_msm.timescales()[:15],'o')

In [None]:
# 4-state HMM implied timescales (Bayesian errors) at lags 1-4 on the all-data
# dMRF clustering; keep the lag-1 model.
# NOTE(review): cluster_dmrf is reassigned by the leave-one-out loop in the
# next cell, so this cell only uses the all-data clustering when run in order.
dmrf_hmm_ts = pe.msm.timescales_hmsm(cluster_dmrf.dtrajs, nstates=4, lags=[1,2,3,4],errors='bayes')
dmrf_hmm_ = dmrf_hmm_ts.models[0]

In [None]:
HMM_blinded_dmrfs=[]
# Not populated anywhere (its producer is disabled); kept so the name exists.
msm_blinded_msms=[]
# Number of hidden states for the HMM of each leave-one-out model
# (4 for the model without state 1, 3 for the others).
__nstates=[4,3,3,3]
for k in range(4):
    # TICA of the k-th leave-one-out synthetic trajectory, -1/+1 -> 0/1.
    tica_dmrfs = [pe.coordinates.tica(data=[(synthts[k]+1)/2], lag=lag) for  lag in [1,2,3,4,5]]

    fig, ax = plt.subplots(ncols = 2, figsize=(8, 3))
    ax[0].semilogy([to.lag for to in tica_dmrfs], [to.timescales[:10] for to in tica_dmrfs])
    ax[0].set_xlabel('lag time / steps')
    ax[0].set_ylabel('implied timescale / steps')
    ax[1].plot([to.lag for to in tica_dmrfs], [to.ndim for to in tica_dmrfs])
    # This panel shows dimensions vs. lag. The TIC1/TIC2 labels previously set
    # here belong to the TIC density, which now gets its own axes below.
    ax[1].set_xlabel('lag time / steps')
    ax[1].set_ylabel('number of dimensions for 95% kinetic variance')
    fig.tight_layout()


    # Projection on the lag-2 model (index 1).
    hva_saa=tica_dmrfs[1].get_output()

    # Separate axes: a bare plt.hist2d would draw on top of ax[1] above.
    fig_tic, ax_tic = plt.subplots()
    a = ax_tic.hist2d(hva_saa[0][:,0],hva_saa[0][:,1], norm=mpl.colors.LogNorm(), bins=256)
    ax_tic.set_xlabel('TIC1')
    ax_tic.set_ylabel('TIC2')
    ax_tic.set_title('Without state {}'.format(k+1))


    # 384 k-means centres, then a Bayesian HMM with __nstates[k] hidden states
    # at lag 1 on the resulting discrete trajectories.
    cluster_dmrf = pe.coordinates.cluster_kmeans(hva_saa, 384, stride=10)
    HMM_blinded_dmrfs.append(pe.msm.bayesian_hidden_markov_model([dt.reshape(-1) for dt in cluster_dmrf.get_output()], __nstates[k], lag=1))
    

In [None]:
from itertools import product
from matplotlib import gridspec

In [None]:
# Scaled MFPT matrices: entry 0 for the all-data dMRF HMM, entries 1..4 for
# the HMMs of the leave-one-out models.
# NOTE(review): the scale factors differ (0.2*0.4 for the leave-one-out HMMs
# vs. 0.2*1e-3 for the all-data HMM) although both HMMs are estimated at lag 1
# on synthetic steps; confirm the units. mfpt_mats is not used later in
# this notebook.
mfpt_mats=[np.zeros((4,4))]+[np.zeros((a,a)) for i,a in enumerate(__nstates)]
#mfpt_mats.append()
for k in range(4):
    for i,j in product(range(__nstates[k]), repeat=2):
        #try:
            mfpt_mats[k+1][i,j] = 0.2*0.4*HMM_blinded_dmrfs[k].mfpt(i,j)
        #except:
        #    print(k,i,j)
for i,j in product(range(4), repeat=2):
    mfpt_mats[0][i,j] = 0.2*dmrf_hmm_.mfpt(i,j)*1e-3

In [None]:
# State lifetimes of each dMRF HMM: MFPT from state j to the set of all other
# states, scaled by 1.5*0.2. Entry 0 is the all-data model, entries 1..4 the
# leave-one-out models.
# NOTE(review): the physical meaning of the 1.5*0.2 factor is not shown here;
# the final figure plots these values in microseconds.
dmrf_lifetimes=[]
dmrf_lifetimes.append([1.5*0.2*dmrf_hmm_.mfpt([j],[i for i in range(dmrf_hmm_.nstates) if i!=j]) for j in range(dmrf_hmm_.nstates)])
for k in range(4):
    dmrf_lifetimes.append([1.5*0.2*HMM_blinded_dmrfs[k].mfpt([j],[i for i in range(HMM_blinded_dmrfs[k].nstates) if i!=j]) for j in range(HMM_blinded_dmrfs[k].nstates)])
    #for i,j in product(range(__nstates[k]), repeat=2):
    #    mfpt_mats[k+1][i,j] = 0.3*HMM_blinded_dmrfs[k].mfpt(i,j)*0.2
    #for i,j in product(range(5), repeat=2):
    #mfpt_mats[0][i,j] = 0.3*dmrf_hmm_.mfpt(i,j)*0.2

In [None]:
# Lifetime of each MD HMM state: MFPT from state j to all other states,
# scaled by 1.5*1e-3*0.2 (plotted in microseconds in the final figure).
lifetimes_hmm = [1.5*1e-3*0.2*HMM.mfpt([j],[i for i in range(HMM.nstates) if i!=j]) for j in range(HMM.nstates)]

In [None]:
# Pairwise MFPTs between MD HMM states, with the same scale factor as
# lifetimes_hmm (0.2*1.5*1e-3).
hmm_mfpts = np.zeros((HMM.nstates,HMM.nstates))
for i,j in product(range(HMM.nstates), repeat=2):
    hmm_mfpts[i,j] = 0.2*HMM.mfpt(i,j)*1.5*1e-3

In [None]:
from scipy.optimize import linear_sum_assignment

In [None]:
# Figure: for each leave-one-out model, (top grid) mean features of its HMM
# states vs. the MD HMM states and (bottom grid) state lifetimes, with states
# matched by a linear assignment.
fig = plt.figure(figsize=(single_column_width*1.3, 1.3*double_column_width*2./3))
from matplotlib import patheffects
gs = gridspec.GridSpec(160, 120, left=0.15,bottom=0.08,top=0.95,right=0.90)#, wspace=None, hspace=None)
# 2x3 grids of axes; only the first 4 of each are used (one per model).
ax_feats = [plt.subplot(gs[35*(i):35*(i+1), 40*(j):40*(j+1)]) for i,j in product(range(2), range(3))]
ax_lifetimes = [plt.subplot(gs[35*(i)+90:35*(i+1)+90, 40*(j):40*(j+1)]) for i,j in product(range(2), range(3))]


for k,_dmrf_hmm in enumerate(HMM_blinded_dmrfs):
    # Mean MD feature vector per MD HMM state (columns), from trajectory 0 only.
    # NOTE(review): this does not depend on k and could be computed once.
    HMM_MD_configurations = np.array([dfeats_fixed[0][np.where(HMM.metastable_assignments[cluster_obj.dtrajs[0][:]]==j)[0],:].mean(axis=0) for j in range(HMM.nstates)]).T

    # Mean synthetic feature vector per hidden state of model k (columns).
    HMM_dMRF_configurations= np.array([synthts[k][_dmrf_hmm.hidden_state_trajectories[0]==i].mean(axis=0) for i in range(_dmrf_hmm.nstates)]).T
    # Matching cost (despite the name, not a correlation): std of the
    # difference relative to the std of the MD state's mean features.
    correlates_ = np.zeros((HMM.nstates,_dmrf_hmm.nstates))
    for i,j in product(range(HMM.nstates), range(_dmrf_hmm.nstates)):
        correlates_[i, j] = np.std(HMM_MD_configurations[:,i]-HMM_dMRF_configurations[:,j])/np.std(HMM_MD_configurations[:,i])
    
    cax = np.array(ax_feats).ravel()[k]
    cax2 = np.array(ax_lifetimes).ravel()[k]

    if k<1:
        cax.set_xticks([])
        cax2.set_xticks([])
    else:
        cax2.set_xlabel('Metastable state')
        cax2.set_xticks(range(1,6))
        cax.set_xlabel('Avg. MD feature')
    # NOTE(review): k only takes values 0..3 here, so the 4 in [1,2,4] never matches.
    if k in [1,2,4]:
        cax.set_yticks([])
        cax2.set_yticks([])
    else:
        cax.set_ylabel('Avg. DGM feature')
        cax2.set_ylabel(r'lifetime / $\mu s$')

        
    # Pair MD state i with dMRF-HMM state j by minimum-cost assignment, then
    # plot their mean features, Pearson correlation, and the matched lifetime.
    for i,j in zip(*linear_sum_assignment(correlates_)):
            cax.scatter(HMM_MD_configurations[:,i], HMM_dMRF_configurations[:,j], s=1, color=f'C{i}')
            t=cax.text(-1,0.8-0.3*i, r'$\rho=%.2f$'%(np.corrcoef(HMM_MD_configurations[:,i], HMM_dMRF_configurations[:,j])[0,1]), color=f'C{i}')
            t.set_path_effects([patheffects.Stroke(linewidth=0.5, foreground='black'),
                       patheffects.Normal()])
            cax2.bar(i+1, dmrf_lifetimes[k+1][j], log=False, color=f'C{i}')
            #cax2.set_ylim([0,0.9])
    cax.set_xlim([-1.1,1.1])
    cax.set_ylim([-1.1,1.1])
    cax.text(0,-1, f"Without {k+1}", va='center',ha='center')
    cax2.text(2.5,10.5, f"Without {k+1}", va='top',ha='center')
    # MD HMM lifetimes as reference stars.
    cax2.scatter(range(1,5),lifetimes_hmm,s=15,color='C7', zorder=10, marker="*",lw=0.1,edgecolors='k')
    cax2.set_xlim(0.5,4.5)
    cax2.set_ylim(0,11)
    if k in [1,2,4]:
        cax2.set_yticks([])
        cax2.set_yticklabels([])    
    
    
# Hide the two unused axes of each grid and tidy the remaining tick layout.
np.array(ax_feats)[-1].axis('off')
np.array(ax_lifetimes)[-1].axis('off')
np.array(ax_feats)[-2].axis('off')
np.array(ax_lifetimes)[-2].axis('off')
np.array(ax_lifetimes)[-3].set_xticks([1,2,3,4])
np.array(ax_lifetimes)[-3].set_xlim(np.array(ax_lifetimes)[0].get_xlim())

ax_feats[0].text(-.35, 1.15, "C", transform=ax_feats[0].transAxes,
          fontsize=12,  va='top')

ax_lifetimes[0].text(-.35, 1.15, "D", transform=ax_lifetimes[0].transAxes,
          fontsize=12,  va='top')

#gs.tight_layout(fig)
fig.savefig('BBA_DGM_METASTABLE.pdf')
