# MSM estimation and validation

<a rel="license" href="http://creativecommons.org/licenses/by/4.0/"><img alt="Creative Commons Licence" style="border-width:0" src="https://i.creativecommons.org/l/by/4.0/88x31.png" title='This work is licensed under a Creative Commons Attribution 4.0 International License.' align="right"/></a>

In this notebook, we will cover how to estimate a Markov state model (MSM) and do model validation;
we also show how to save and restore model and estimator objects.
For this notebook, you need to know how to do data loading/visualization as well as dimension reduction.


**Remember**:
- to run the currently highlighted cell, hold <kbd>&#x21E7; Shift</kbd> and press <kbd>&#x23ce; Enter</kbd>;
- to get help for a specific function, place the cursor within the function's brackets, hold <kbd>&#x21E7; Shift</kbd>, and press <kbd>&#x21E5; Tab</kbd>;
- you can find the full documentation for PyEMMA at [PyEMMA.org](http://www.pyemma.org) and for deeptime at [deeptime-ml.github.io](https://deeptime-ml.github.io/).

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import numpy as np
import mdshare
import pyemma

## Loading MD data and repeating the clustering step

Let's load the alanine dipeptide backbone torsions and discretize them with 200 $k$-means centers...

In [None]:
pdb = mdshare.fetch('alanine-dipeptide-nowater.pdb', working_directory='data')
files = mdshare.fetch('alanine-dipeptide-*-250ns-nowater.xtc', working_directory='data')

feat = pyemma.coordinates.featurizer(pdb)
feat.add_backbone_torsions(periodic=False)

data = pyemma.coordinates.load(files, features=feat)

from deeptime.clustering import KMeans
cluster = KMeans(200, max_iter=50, progress=tqdm)\
    .fit_fetch(np.concatenate(data)[::10])

... and plot the free energy along with the cluster centers:

In [None]:
fig, ax = plt.subplots()
pyemma.plots.plot_free_energy(*np.concatenate(data).T, ax=ax, legacy=False)
ax.scatter(*cluster.cluster_centers.T, s=15, c='k')
ax.set_xlabel(r'$\Phi$ / rad')  # raw strings avoid invalid escape sequences like '\P'
ax.set_ylabel(r'$\Psi$ / rad')
fig.tight_layout()

## Implied time scales and lag time selection

The first step after obtaining the discretized dynamics is finding a suitable lag time.
The systematic approach is to estimate MSMs at various lag times and observe how the implied timescales (ITSs) of these models behave.
In particular, we are looking for lag time ranges in which the implied timescales are constant.

To that end, we iterate over a range of lag times and estimate a Markov state model for each of them (restricted to the largest connected set of states); the four slowest implied timescales of these models are plotted below.

In [None]:
from deeptime.markov import TransitionCountEstimator
from deeptime.markov.msm import MaximumLikelihoodMSM

dtrajs = [cluster.transform(traj) for traj in data]

lags = [1, 2, 5, 10, 20, 50]
models = []
for lag in tqdm(lags, leave=False):
    counts_estimator = TransitionCountEstimator(lag, "sliding")
    counts = counts_estimator.fit_fetch(dtrajs)
    counts = counts.submodel_largest()
    
    msm_estimator = MaximumLikelihoodMSM()
    models.append(msm_estimator.fit_fetch(counts))
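
The loop above delegates counting and estimation to deeptime. As a rough sketch of the idea, the following NumPy-only code counts transitions with a sliding window and row-normalizes the counts. The helper names `count_matrix_sliding` and `nonreversible_mle` and the toy trajectories are made up for illustration. Row normalization gives the *non-reversible* maximum likelihood estimate. deeptime's `MaximumLikelihoodMSM` estimates a reversible transition matrix by default, so its results will differ.

```python
import numpy as np


def count_matrix_sliding(dtrajs, lag, n_states):
    """Count every pair (x_t, x_{t+lag}) in every trajectory (sliding window, lag >= 1)."""
    C = np.zeros((n_states, n_states))
    for dtraj in dtrajs:
        dtraj = np.asarray(dtraj)
        # start a window at every frame t that still has a partner at t + lag
        np.add.at(C, (dtraj[:-lag], dtraj[lag:]), 1)
    return C


def nonreversible_mle(C):
    """Non-reversible maximum likelihood estimate: normalize each row of C.
    Assumes every state has outgoing counts (otherwise the row would be NaN)."""
    return C / C.sum(axis=1, keepdims=True)


# made-up toy discrete trajectories over three states
dtrajs_toy = [np.array([0, 0, 1, 1, 2, 2, 1, 0, 0, 1]),
              np.array([2, 2, 2, 1, 1, 0])]
C = count_matrix_sliding(dtrajs_toy, lag=1, n_states=3)
T = nonreversible_mle(C)
print(C)                # 9 + 5 = 14 transition counts in total
print(T.sum(axis=1))    # every row of T sums to one
```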

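If the discretized dynamics are Markovian, the transition matrix at lag time $n\tau$ is the $n$-th power of the one at lag time $\tau$. As a consequence, the implied timescales do not depend on the lag time:

$$\begin{eqnarray*}
T(n \tau) & = & (T(\tau))^n\\[0.75em]
\lambda(n \tau) & = & (\lambda(\tau))^n\\[0.75em]
\mathrm{ITS}(n \tau) & = & - \frac{n \tau}{\ln \lambda(n \tau)} = - \frac{n \tau}{\ln (\lambda(\tau))^n} = - \frac{n \tau}{n \ln \lambda(\tau)} = - \frac{\tau}{\ln \lambda(\tau)} = \mathrm{ITS}(\tau)
\end{eqnarray*}$$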
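
This lag-time invariance can be checked numerically. The sketch below uses a made-up transition matrix that is symmetric, and therefore reversible. It computes $T(n\tau)$ as a matrix power and evaluates $\mathrm{ITS}_i = -n\tau / \ln \lambda_i(n\tau)$ for the two non-stationary eigenvalues. `implied_timescales_from_matrix` is a hypothetical helper, not a deeptime function.

```python
import numpy as np

# made-up symmetric (hence reversible) transition matrix at lag time tau = 1 step
T_tau = np.array([[0.90, 0.08, 0.02],
                  [0.08, 0.90, 0.02],
                  [0.02, 0.02, 0.96]])
tau = 1


def implied_timescales_from_matrix(T, lag, k=2):
    """ITS_i = -lag / ln(lambda_i) for the k largest non-stationary eigenvalues."""
    # eigvalsh is valid because T (and every power of it) is symmetric here
    eigvals = np.sort(np.linalg.eigvalsh(T))[::-1]
    return -lag / np.log(eigvals[1:k + 1])


its_by_lag = {}
for n in [1, 2, 5, 10]:
    T_n = np.linalg.matrix_power(T_tau, n)   # T(n tau) = T(tau)^n
    its_by_lag[n * tau] = implied_timescales_from_matrix(T_n, n * tau)
    print(n * tau, its_by_lag[n * tau])
```

This matrix has non-stationary eigenvalues $0.94$ and $0.82$, so every lag time yields the same timescales of about $16.2$ and $5.0$ steps. With MSMs estimated from simulation data, the ITSs only become approximately constant once the lag time is long enough for the discretized dynamics to be (nearly) Markovian.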

We compute the implied timescales of all estimated models with `deeptime.util.validation.implied_timescales()` and plot the four slowest of them with `deeptime.plots.plot_implied_timescales()`:

In [None]:
from deeptime.util.validation import implied_timescales
from deeptime.plots import plot_implied_timescales

its = implied_timescales(models)
ax = plot_implied_timescales(its, n_its=4)
ax.set_yscale('log')
ax.set_xlabel('lagtime (ps)')
ax.set_ylabel('timescales (ps)');

The above plot tells us that there are three resolved processes (blue, orange, green) which are largely invariant to the MSM lag time.
The fourth ITS (red) is smaller than the lag time (black line, grey-shaded area);
it corresponds to a process which is faster than the lag time and, thus, is not resolved.
Since the implied timescales are, like the corresponding eigenvalues, sorted in decreasing order,
we know that all other remaining processes must be even faster.
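
Deciding whether a process is resolved amounts to comparing each ITS with the lag time. The following is a minimal sketch under idealized Markovian assumptions, $\lambda(n\tau) = \lambda(\tau)^n$. It uses hypothetical eigenvalues $0.99$, $0.95$ and $0.5$ at a lag time of one step. `resolved_mask` is an illustrative helper.

```python
import numpy as np


def resolved_mask(lagtimes, timescales):
    """True where a process's ITS exceeds the lag time (shape: n_lags x n_processes)."""
    return np.asarray(timescales) > np.asarray(lagtimes, dtype=float)[:, None]


# hypothetical eigenvalues at a lag time of one step, sorted in decreasing order
lambdas_1 = np.array([0.99, 0.95, 0.5])
lagtimes = np.array([1, 2, 5, 10, 20, 50])
# idealized Markovian case: lambda(n) = lambda(1)**n
its_toy = np.array([-n / np.log(lambdas_1 ** n) for n in lagtimes])
mask = resolved_mask(lagtimes, its_toy)
print(mask)
```

The fastest process (ITS about 1.4 steps) is only resolved at lag time 1. The second one (ITS about 19.5 steps) drops below the lag time from lag 20 on. Because the ITSs are sorted in decreasing order, each row of `mask` is a run of `True` followed by `False`.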

## Error bars for the timescales

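Error bars can be obtained by Bayesian sampling of transition matrices with `BayesianMSM`. Here, we draw 50 samples per lag time and use the `"effective"` count mode, which estimates statistically uncorrelated transition counts: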

In [None]:
from deeptime.markov.msm import BayesianMSM

def its_bayesian_msm(data, lagtimes):
    models = []
    for lagtime in tqdm(lagtimes):
        counts = TransitionCountEstimator(lagtime, "effective").fit_fetch(data).submodel_largest()
        models.append(BayesianMSM(n_samples=50).fit_fetch(counts))

    return implied_timescales(models)

its = its_bayesian_msm(dtrajs, [1, 2, 5, 10, 20, 50])

ax = plot_implied_timescales(its, n_its=4)
ax.set_yscale('log')
ax.set_xlabel('lagtime (ps)')
ax.set_ylabel('timescales (ps)');
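
To get a feeling for where such error bars come from, the sketch below samples transition matrices row by row from a Dirichlet posterior. It assumes a flat prior and a made-up count matrix, then reports a 95% credible interval of the slowest implied timescale. This is a simplification: deeptime's `BayesianMSM` samples *reversible* transition matrices, which this non-reversible sampler does not. `sample_transition_matrices` and `slowest_timescale` are hypothetical helpers.

```python
import numpy as np

rng = np.random.default_rng(42)

# made-up transition counts at a single lag time (rows: from-state, columns: to-state)
C = np.array([[90., 8., 2.],
              [6., 88., 6.],
              [1., 9., 90.]])


def sample_transition_matrices(C, n_samples, rng):
    """Sample each row independently from a Dirichlet posterior (flat prior: alpha = counts + 1)."""
    samples = np.empty((n_samples,) + C.shape)
    for i, row in enumerate(C):
        samples[:, i, :] = rng.dirichlet(row + 1.0, size=n_samples)
    return samples


def slowest_timescale(T, lag=1.0):
    """ITS of the slowest process, -lag / ln|lambda_2|."""
    eigvals = np.sort(np.abs(np.linalg.eigvals(T)))[::-1]
    return -lag / np.log(eigvals[1])


samples = sample_transition_matrices(C, n_samples=1000, rng=rng)
its_samples = np.array([slowest_timescale(T) for T in samples])
lower, upper = np.percentile(its_samples, [2.5, 97.5])   # 95% credible interval
print(f"slowest ITS: mean {its_samples.mean():.1f} steps, 95% interval [{lower:.1f}, {upper:.1f}]")
```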

## Effect of the discretization on the implied timescales

Let's look at the influence of the discretization on the ITSs by repeating the analysis with $20$, $50$, and $100$ $k$-means centers:

In [None]:
def its_msm(data, lagtimes):
    models = []
    for lag in tqdm(lagtimes, leave=False):
        counts = TransitionCountEstimator(lag, "sliding").fit_fetch(data).submodel_largest()
        models.append(MaximumLikelihoodMSM().fit_fetch(counts))
    return implied_timescales(models)

lags = [1, 2, 5, 10, 20, 50]

cluster_20 = KMeans(20, max_iter=50).fit_fetch(np.concatenate(data)[::10])
its_20 = its_msm([cluster_20.transform(x) for x in data], lags)

cluster_50 = KMeans(50, max_iter=50).fit_fetch(np.concatenate(data)[::10])
its_50 = its_msm([cluster_50.transform(x) for x in data], lags)

cluster_100 = KMeans(100, max_iter=50).fit_fetch(np.concatenate(data)[::10])
its_100 = its_msm([cluster_100.transform(x) for x in data], lags);

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(12, 6))
data_concat = np.concatenate(data)

for i, (cls, its_k) in enumerate([(cluster_20, its_20), (cluster_50, its_50), (cluster_100, its_100)]):
    # top row: free energy surface with the k-means cluster centers
    pyemma.plots.plot_free_energy(*data_concat.T, ax=axes[0, i], cbar=False, legacy=False)
    axes[0, i].scatter(*cls.cluster_centers.T, s=15, c='k')
    axes[0, i].set_title(f'{cls.n_clusters} cluster centers')
    # bottom row: implied timescales for this discretization
    ax_its = plot_implied_timescales(its_k, ax=axes[1, i], n_its=4)
    ax_its.set_yscale('log')
    ax_its.set_xlabel('lagtime (ps)')

fig.tight_layout()

## Estimating the maximum likelihood Markov model

In [None]:
counts_estimator = TransitionCountEstimator(lagtime=10, count_mode='sliding')
counts = counts_estimator.fit_fetch(dtrajs).submodel_largest()

msm_estimator = MaximumLikelihoodMSM()
msm = msm_estimator.fit_fetch(counts)

print(f'fraction of states used = {msm.state_fraction}')
print(f'fraction of counts used = {msm.count_fraction}')

In [None]:
msm.timescales(k=4)

The state space can be restricted to the largest connected set (`submodel_largest()`) or to any other selection of states:

In [None]:
counts = counts.submodel([0, 1, 3, 7])
print(f"States: {counts.states}, state symbols: {counts.state_symbols}")
msm = MaximumLikelihoodMSM().fit_fetch(counts)

print(f'fraction of states used = {msm.state_fraction}')
print(f'fraction of counts used = {msm.count_fraction}')
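
Both the largest connected set and the two printed fractions can be illustrated with a small made-up count matrix. The sketch below assumes that a state belongs to a (strongly) connected set if it can both reach and be reached from every other member through observed transitions. It finds these sets via the transitive closure of $C > 0$ and then computes the fraction of states and of counts that survive the restriction. This is meant only to mirror, roughly, what `submodel_largest()`, `state_fraction` and `count_fraction` report; it is not deeptime's implementation, and `strongly_connected_sets` is a hypothetical helper.

```python
import numpy as np


def strongly_connected_sets(C):
    """Group states that can reach each other via observed transitions (C > 0)."""
    n = C.shape[0]
    reach = (C > 0) | np.eye(n, dtype=bool)        # reachable in at most one step
    while True:                                      # transitive closure by repeated squaring
        new_reach = reach | ((reach.astype(int) @ reach.astype(int)) > 0)
        if (new_reach == reach).all():
            break
        reach = new_reach
    mutual = reach & reach.T                         # i -> j and j -> i
    sets, assigned = [], np.zeros(n, dtype=bool)
    for i in range(n):
        if not assigned[i]:
            members = np.flatnonzero(mutual[i])
            assigned[members] = True
            sets.append(members)
    return sorted(sets, key=len, reverse=True)       # largest set first


# made-up counts: state 3 is entered from state 0 but never left
C = np.array([[5, 2, 0, 1],
              [3, 4, 1, 0],
              [0, 2, 6, 0],
              [0, 0, 0, 0]])
largest = strongly_connected_sets(C)[0]
C_sub = C[np.ix_(largest, largest)]
state_fraction = len(largest) / C.shape[0]
count_fraction = C_sub.sum() / C.sum()
print(largest, state_fraction, count_fraction)
```

Here states 0, 1 and 2 form the largest connected set. Three of four states (0.75) and 23 of 24 counts are kept; only the single transition into the dead-end state 3 is lost.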

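The state space can be restricted even further. Note that the indices passed to `submodel()` always refer to the _states_ of the current count model, not to the original state symbols (cluster indices):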

In [None]:
counts = counts.submodel([0, 3])
print(f"States: {counts.states}, state symbols: {counts.state_symbols}")
msm = MaximumLikelihoodMSM().fit(counts).fetch_model()

print(f'fraction of states used = {msm.state_fraction}')
print(f'fraction of counts used = {msm.count_fraction}')
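
Because the indices passed to `submodel()` refer to the states of the current count model, chaining restrictions can be confusing. The toy class below reproduces the bookkeeping of the two cells above. `ToyCountModel` is a hypothetical stand-in, not deeptime's `TransitionCountModel`, and the sketch assumes the full model's state symbols are simply `0, ..., 9`:

```python
import numpy as np


class ToyCountModel:
    """Hypothetical stand-in for a count model: a count matrix plus the original
    state symbols (e.g. cluster indices) that its rows and columns refer to."""

    def __init__(self, count_matrix, state_symbols):
        self.count_matrix = np.asarray(count_matrix)
        self.state_symbols = np.asarray(state_symbols)

    @property
    def states(self):
        # states are always numbered 0, ..., n_states - 1 within the current model
        return np.arange(self.count_matrix.shape[0])

    def submodel(self, states):
        """Restrict to `states`; these index the current model, not the symbols."""
        states = np.asarray(states)
        return ToyCountModel(self.count_matrix[np.ix_(states, states)],
                             self.state_symbols[states])


full = ToyCountModel(np.arange(100).reshape(10, 10), state_symbols=np.arange(10))
sub = full.submodel([0, 1, 3, 7])   # keeps symbols 0, 1, 3, 7
subsub = sub.submodel([0, 3])       # states 0 and 3 of `sub`, i.e. symbols 0 and 7
print(subsub.states, subsub.state_symbols)
```

`subsub.states` is `[0, 1]`, while `subsub.state_symbols` is `[0, 7]`: state `3` of the first submodel corresponds to the original symbol `7`.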

## Estimating the Bayesian Markov model

In [None]:
count_estimator = TransitionCountEstimator(lagtime=10, count_mode='effective')
counts = count_estimator.fit_fetch(dtrajs).submodel_largest()
bayesian_msm_estimator = BayesianMSM()
bayesian_msm = bayesian_msm_estimator.fit_fetch(counts)

In [None]:
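# statistics of the three slowest timescales over the posterior samples;
# L and R are the lower and upper bounds of the (default 95%) confidence interval
stats = bayesian_msm.gather_stats('timescales', k=3)
stats.L, stats.R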

## The Chapman-Kolmogorov test

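To see whether our model satisfies Markovianity, we perform (and visualize) a Chapman-Kolmogorov (CK) test. This test compares the model's predictions for lag times $k\tau$ with MSMs re-estimated at those lag times.
Since we aim at modeling the dynamics between metastable states rather than between microstates, this test is conducted in the space of metastable states.
The latter are identified automatically using PCCA++ (which is explained later).
We usually choose the number of metastable states according to the implied timescales plot by identifying a gap between the ITSs. $n$ resolved slow processes separate $n + 1$ metastable states, so the three slow processes found above suggest four metastable states.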

In [None]:
test_model = MaximumLikelihoodMSM(lagtime=10).fit_fetch(dtrajs);

In [None]:
models = []
for lagtime in [10, 20, 30, 40, 50, 80, 100]:
    models.append(MaximumLikelihoodMSM(lagtime=lagtime).fit_fetch(dtrajs))
ck_test = test_model.ck_test(models, n_metastable_sets=4)

In [None]:
from deeptime.plots import plot_ck_test

plot_ck_test(ck_test, xlabel='lagtime (ps)', sharey=True);
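
For intuition, the idea behind the CK test can be reproduced on synthetic data with NumPy alone. The sketch below assumes a made-up 3-state transition matrix `T_true` and generates a discrete trajectory from it, so the data are Markovian by construction. It then compares the prediction $T(\tau)^k$ with a row-normalized MSM re-estimated at lag $k\tau$. Unlike `ck_test()`, it works on the microstates directly rather than on PCCA++ metastable sets and reports only the largest deviation. The helpers `simulate` and `estimate_T` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)

# made-up "true" 3-state transition matrix used to generate synthetic data
T_true = np.array([[0.95, 0.04, 0.01],
                   [0.03, 0.94, 0.03],
                   [0.01, 0.04, 0.95]])


def simulate(T, n_steps, rng, start=0):
    """Generate a discrete trajectory by inverse-CDF sampling from the rows of T."""
    cdf = np.cumsum(T, axis=1)
    cdf[:, -1] = 1.0                      # guard against round-off in the cumulative sum
    u = rng.random(n_steps)
    traj = np.empty(n_steps, dtype=int)
    traj[0] = start
    for t in range(1, n_steps):
        traj[t] = np.searchsorted(cdf[traj[t - 1]], u[t], side='right')
    return traj


def estimate_T(dtraj, lag, n_states):
    """Sliding-window counts at `lag`, row-normalized (non-reversible estimate)."""
    C = np.zeros((n_states, n_states))
    np.add.at(C, (dtraj[:-lag], dtraj[lag:]), 1)
    return C / C.sum(axis=1, keepdims=True)


dtraj = simulate(T_true, 200_000, rng)
T_tau = estimate_T(dtraj, lag=1, n_states=3)
ck_errors = {}
for k in [2, 5, 10]:
    predicted = np.linalg.matrix_power(T_tau, k)           # model prediction T(tau)^k
    reestimated = estimate_T(dtraj, lag=k, n_states=3)     # MSM re-estimated at lag k * tau
    ck_errors[k] = np.abs(predicted - reestimated).max()
    print(k, ck_errors[k])
```

Because these synthetic data are Markovian, the deviations stay at the level of statistical noise. For a model that is not Markovian at lag $\tau$, one would instead expect systematic deviations that grow with $k$.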

In [None]:
counts = TransitionCountEstimator(lagtime=10, count_mode='effective')\
    .fit_fetch(dtrajs).submodel_largest()
test_bmsm = BayesianMSM().fit_fetch(counts)

models = []
for lagtime in tqdm([10, 20, 30, 40, 50, 80, 100]):
    counts = TransitionCountEstimator(lagtime=lagtime, count_mode='effective')\
        .fit_fetch(dtrajs).submodel_largest()
    models.append(BayesianMSM().fit_fetch(counts))

In [None]:
ck_test = test_bmsm.ck_test(models, n_metastable_sets=4)
plot_ck_test(ck_test, xlabel='lagtime (ps)', sharey=True);

## Persisting and restoring models

In [None]:
import pickle

with open('cluster_50.pkl', 'wb') as f:
    pickle.dump(cluster_50, f)
    
with open('cluster_50.pkl', 'rb') as f:
    cluster_50_restored = pickle.load(f)
    
print(cluster_50_restored.n_clusters)

In [None]:
with open('msm.pkl', 'wb') as f:
    pickle.dump(msm, f)
with open('msm.pkl', 'rb') as f:
    msm_restored = pickle.load(f)
    
print(f"Timescales {msm.timescales()}, restored {msm_restored.timescales()}")

## Hands-on

#### Exercise 1

Load the heavy atom distances into memory, perform PCA and TICA (lag time $3$ steps), both with two output dimensions,
then discretize each projection with $100$ $k$-means centers, using a stride of $10$ for the clustering. Compare the two discretizations by generating implied timescale plots for both of them.

In [None]:
feat =  #FIXME
feat. #FIXME
data =  #FIXME

from sklearn.decomposition import PCA
pca = PCA(n_components=2).fit(np.concatenate(data))
tica = #FIXME

pca_output = #FIXME
tica_output = [tica.transform(traj) for traj in data]

cls_pca_estimator = KMeans(100, max_iter=50)
cls_pca = #FIXME
cls_tica = #FIXME

dtrajs_pca = [cls_pca.transform(pca.transform(traj)) for traj in data]
dtrajs_tica = # FIXME

its_pca = its_msm(dtrajs_pca, [1, 2, 5, 10, 20, 50])
its_tica = #FIXME

###### Solution

In [None]:
feat = pyemma.coordinates.featurizer(pdb)
pairs = feat.pairs(feat.select_Heavy())
feat.add_distances(pairs, periodic=False)
data = pyemma.coordinates.load(files, features=feat)

from sklearn.decomposition import PCA
pca = PCA(n_components=2).fit(np.concatenate(data))

from deeptime.decomposition import TICA
tica_estimator = TICA(lagtime=3, dim=2)
tica = tica_estimator.fit_fetch(data)

pca_output = [pca.transform(traj) for traj in data]
tica_output = [tica.transform(traj) for traj in data]

cls_pca = KMeans(100, max_iter=50).fit(np.concatenate(pca_output)[::10]).fetch_model()
cls_tica = KMeans(100, max_iter=50).fit(np.concatenate(tica_output)[::10]).fetch_model()

dtrajs_pca = [cls_pca.transform(pca.transform(traj)) for traj in data]
dtrajs_tica = [cls_tica.transform(tica.transform(traj)) for traj in data]

lags = [1, 2, 5, 10, 20, 50]
its_pca = implied_timescales([MaximumLikelihoodMSM(lagtime=lag).fit_fetch(dtrajs_pca) for lag in lags])
its_tica = implied_timescales([MaximumLikelihoodMSM(lagtime=lag).fit_fetch(dtrajs_tica) for lag in lags])

Let's visualize the ITS convergence for both projections:

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(12, 6))
pyemma.plots.plot_feature_histograms(np.concatenate(pca_output), ax=axes[0, 0])
pyemma.plots.plot_feature_histograms(np.concatenate(tica_output), ax=axes[1, 0])
axes[0, 0].set_title('PCA')
axes[1, 0].set_title('TICA')
pyemma.plots.plot_density(*np.concatenate(pca_output).T, ax=axes[0, 1], cbar=False, alpha=0.1)
axes[0, 1].scatter(*cls_pca.cluster_centers.T, s=15, c='C1')
axes[0, 1].set_xlabel('PC 1')
axes[0, 1].set_ylabel('PC 2')
pyemma.plots.plot_density(*np.concatenate(tica_output).T, ax=axes[1, 1], cbar=False, alpha=0.1)
axes[1, 1].scatter(*cls_tica.cluster_centers.T, s=15, c='C1')
axes[1, 1].set_xlabel('IC 1')
axes[1, 1].set_ylabel('IC 2')
ax_its = plot_implied_timescales(its_pca, ax=axes[0, 2], n_its=4)
ax_its.set_yscale('log')
ax_its.set_xlabel('lagtime (ps)')
ax_its = plot_implied_timescales(its_tica, ax=axes[1, 2], n_its=4)
ax_its.set_yscale('log')
ax_its.set_xlabel('lagtime (ps)')
axes[0, 2].set_ylim(1, 2000)
axes[1, 2].set_ylim(1, 2000)
fig.tight_layout()

Although PCA yields a projection with some well-defined basins,
the ITS plot shows that only one "slow" process is resolved, and it is more than one order of magnitude faster than the slowest process found with the backbone torsions.

TICA does find three slow processes which agree (in terms of the implied timescales) with the backbone torsions example above.

We conclude that this PCA projection is not suitable to resolve the slow dynamics of alanine dipeptide and we will continue to estimate/validate the TICA-based projection.

#### Exercise 2

Estimate a Bayesian MSM at lag time $10$ ps and perform/show a CK test for four metastable states.

In [None]:
counts_estimator = TransitionCountEstimator(lagtime=10, count_mode="effective")
counts = counts_estimator.fit_fetch(dtrajs_tica).submodel_largest()
bayesian_msm = # FIXME
models = # FIXME
ck_test = # FIXME
plot_ck_test( #FIXME

###### Solution

In [None]:
counts_estimator = TransitionCountEstimator(lagtime=10, count_mode="effective")
counts = counts_estimator.fit_fetch(dtrajs_tica).submodel_largest()

test_model = BayesianMSM(n_samples=50).fit_fetch(counts)

models = []
for i in tqdm(range(1, 10)):
    counts = TransitionCountEstimator(lagtime=i * test_model.lagtime, count_mode="effective")\
        .fit_fetch(dtrajs_tica).submodel_largest()
    models.append(BayesianMSM(n_samples=50).fit_fetch(counts))

ck_test = test_model.ck_test(models, n_metastable_sets=4)
plot_ck_test(ck_test, xlabel='lagtime (ps)');

We again see a good agreement between model prediction and re-estimation.

## Wrapping up
In this notebook, we have learned how to estimate a regular or Bayesian MSM from discretized molecular simulation data with `deeptime` and `pyemma`, and how to perform basic model validation.

In detail, we have selected a suitable lag time by
- computing implied timescales from MSMs and Bayesian MSMs with `deeptime.util.validation.implied_timescales()`, and
- using `deeptime.plots.plot_implied_timescales()` to visualize the convergence of the implied timescales.

We then have used
- `deeptime.markov.TransitionCountEstimator()` to estimate transition counts,
- `deeptime.markov.msm.MaximumLikelihoodMSM()` to estimate a regular MSM,
- `deeptime.markov.msm.BayesianMSM()` to estimate a Bayesian MSM,
- the `timescales()` method of an estimated MSM object to access its implied timescales,
- the `ck_test()` method of an estimated MSM (or Bayesian MSM) to perform a Chapman-Kolmogorov test, and
- `deeptime.plots.plot_ck_test()` to visualize the latter.