# String Method Analysis Markov-State-Models
## Imports

In [None]:
import os
import pickle
import sys
import logging
import numpy as np
import matplotlib.pyplot as plt


logging.getLogger("stringmethod").setLevel(logging.ERROR)
sys.path.append("../string-method-gmxapi/")
import src.analysis as spc

In [None]:
%load_ext lab_black
%load_ext autoreload
%autoreload 2

## Load data

This notebook must run from inside the string simulation folder; the next cell changes to that directory. It also sets the path where the figures are written.

In [None]:
name_sim = "C2I_lb_v1/"
simulation_directory = f"/data/sperez/Projects/string_sims/data/raw/{name_sim}/"
path_report = f"/data/sperez/Projects/string_sims/reports/figures/{name_sim}"
os.chdir(simulation_directory)
os.getcwd()

In [None]:
with open("cv.pkl", "rb") as file:
    cvs, ndx_groups = pickle.load(file)

The `load_swarm_data` function loads the swarm data into `cv_coordinates`. If you set `extract=True`, it reads the data from the swarm files. If you have done this previously, you can set `extract=False` so the function just reads `postprocessing/cv_coordinates.npy`. Use `first_iteration` to exclude the initial swarms as equilibration and `last_iteration` to exclude later iterations, for example if you want to estimate the FES convergence by comparing blocks of data.

In [None]:
cv_coordinates = spc.load_swarm_data(
    extract=True, first_iteration=100, last_iteration=None
)

## Dimensionality reduction with TICA

The following cell computes the TICA projection of the string CVs and discards the TICs with the lowest kinetic variance. This reduces the CV space to a lower-dimensional space adapted to the kinetic variance. You can use the `drop` keyword to drop CVs that are not well converged in the string simulation or that change very little from the beginning to the end of the string. The best-case scenario is that `drop=[]` just works.

In [None]:
tica = spc.cvs_to_tica(cv_coordinates, drop=[20, 21, 22, 23, 32, 33, 34, 35])
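
For intuition, TICA can be written as a generalized eigenvalue problem between the instantaneous and the time-lagged covariance of the CVs. The following is a minimal NumPy sketch of that idea, not the implementation behind `spc.cvs_to_tica`. It assumes the data comes as pairs of CV vectors recorded at the start and end of each swarm trajectory; the names `tica_sketch`, `x0` and `xt` are hypothetical and the data is synthetic.

```python
import numpy as np


def tica_sketch(x0, xt, n_components):
    """Toy TICA on time-lagged pairs (x0[i], xt[i]), each of shape (n_pairs, n_cvs)."""
    mean = 0.5 * (x0.mean(axis=0) + xt.mean(axis=0))
    a, b = x0 - mean, xt - mean
    n_pairs = len(a)
    # Symmetrized instantaneous and time-lagged covariance matrices.
    c00 = (a.T @ a + b.T @ b) / (2 * n_pairs)
    c0t = (a.T @ b + b.T @ a) / (2 * n_pairs)
    # Whiten with c00, dropping numerically degenerate directions.
    s, u = np.linalg.eigh(c00)
    keep = s > 1e-10 * s.max()
    whiten = u[:, keep] / np.sqrt(s[keep])
    # Eigenvalues of the whitened lagged covariance are autocorrelations.
    eigvals, eigvecs = np.linalg.eigh(whiten.T @ c0t @ whiten)
    order = np.argsort(eigvals)[::-1][:n_components]
    return eigvals[order], whiten @ eigvecs[:, order]


# Synthetic data: CV 0 is slow (autocorrelation 0.9), CVs 1 and 2 are fast noise.
rng = np.random.default_rng(0)
x0 = rng.normal(size=(5000, 3))
xt = rng.normal(size=(5000, 3))
xt[:, 0] = 0.9 * x0[:, 0] + np.sqrt(1 - 0.9**2) * xt[:, 0]
eigvals, components = tica_sketch(x0, xt, n_components=2)
```

The leading component should load mostly on the slow CV, while the components with small eigenvalues carry little kinetic variance and are the ones that get discarded.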

## Cluster

The next cell plots the VAMP score of an MSM built with each value in `n_clustercenters`. At some point the VAMP score should saturate. Choose the smallest number of clusters that reaches the saturated score as the value of `k` for the next steps. This might take a little while.

In [None]:
n_clustercenters = [5, 10, 30, 50, 75, 100, 200, 500][::-1]
fig, ax = spc.get_vamp_vs_k(n_clustercenters, tica)
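
If you prefer a rule over reading the plot by eye, the "smallest k at saturation" choice can be sketched as below. This is only an illustration: `smallest_saturated_k` is a hypothetical helper, the 2% tolerance is an arbitrary assumption, and the scores are made-up numbers, not output of `spc.get_vamp_vs_k`. It also assumes positive scores.

```python
import numpy as np


def smallest_saturated_k(ks, scores, rel_tol=0.02):
    """Return the smallest k whose score is within rel_tol of the best score."""
    ks = np.asarray(ks)
    scores = np.asarray(scores, dtype=float)
    saturated = scores >= (1.0 - rel_tol) * scores.max()
    return int(ks[saturated].min())


# Made-up scores for illustration only; read the real ones off the plot.
ks = [5, 10, 30, 50, 75, 100, 200, 500]
scores = [1.80, 2.40, 2.90, 3.05, 3.12, 3.15, 3.16, 3.16]
k_min = smallest_saturated_k(ks, scores)
```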

If the calculation fails, there is something wrong with your MSM: either you have too few transitions or there are too many CVs in TICA for all the states to be well connected. Solutions:
+ Reduce the maximum number of clusters in `n_clustercenters` (drop 200 and 500) and see if you get a saturated curve.
+ Reduce the number of cvs that went into your TICA calculation.
+ Do more iterations of the string method.

## MSM Deeptime

Choose the number of clusters, `k`, for the clustering from the previous calculation. Also change `n_proc` to however many processors you can use.

In [None]:
k = 100
clusters = spc.k_means_cluster(tica, k, stride=1, max_iter=500, n_proc=8, seed=28101990)

In [None]:
%%time
msm, weights = spc.get_msm(clusters)
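
As a rough picture of what the MSM contributes: the swarms start from states along the string rather than from an equilibrium distribution, and the stationary distribution of the MSM tells you how much each trajectory should count. The sketch below uses a naive row-normalized transition matrix and is an assumption about the general technique, not the code behind `spc.get_msm`, which may use a different estimator. `msm_weights_sketch`, `start_states` and `end_states` are hypothetical names and the data is synthetic.

```python
import numpy as np


def msm_weights_sketch(start_states, end_states, n_states):
    """Stationary weights for swarm trajectories from a naive row-normalized MSM."""
    counts = np.zeros((n_states, n_states))
    np.add.at(counts, (start_states, end_states), 1.0)
    # Row-normalize transition counts (assumes every state appears as a start).
    transition = counts / counts.sum(axis=1, keepdims=True)
    # Stationary distribution: left eigenvector of T with eigenvalue 1.
    eigvals, eigvecs = np.linalg.eig(transition.T)
    pi = np.real(eigvecs[:, np.argmin(np.abs(eigvals - 1.0))])
    pi = pi / pi.sum()
    # Spread each state's probability over the trajectories that start there.
    n_start = counts.sum(axis=1)
    return pi, pi[start_states] / n_start[start_states]


# 100 swarms start in state 0 (10 of them end in state 1) and
# 100 swarms start in state 1 (20 of them end in state 0).
start_states = np.array([0] * 100 + [1] * 100)
end_states = np.array([0] * 90 + [1] * 10 + [0] * 20 + [1] * 80)
pi, toy_weights = msm_weights_sketch(start_states, end_states, n_states=2)
```

Both states were sampled equally often, but state 0 is twice as populated at equilibrium, so trajectories starting in state 0 receive twice the weight.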

## CVs for projection

Make a `cv_proj` NumPy array with shape `(n_iteration * n_swarms_iterations, n_frames_per_iter, 2)`. `n_frames_per_iter` is usually 2, since you only record the value of the CVs at the beginning and end of the swarm. The last dimension holds the CVs onto which you would like to project your FES using the weights obtained from the MSM. The FES is then the negative log of a *weighted* histogram of the projection CVs using the weights from the MSM. The projection CVs can be anything that you can calculate for a structure, not necessarily the CVs of the string. In the example below, each projection CV is the mean of two string CVs.

In [None]:
cv_proj = spc.cvs_to_SF_IG(cv_coordinates, [0, 1], [10, 11])
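
The "negative log of a weighted histogram" can be sketched with plain NumPy as follows. This is an illustration under assumptions, not what the notebook does in the next section (which uses a KDE through `spc.get_kde`): the function name `fes_from_weighted_histogram` is hypothetical, the data is synthetic, and the weights are uniform instead of coming from the MSM.

```python
import numpy as np


def fes_from_weighted_histogram(x, y, weights, bins=30):
    """Free energy (kT) as -log of a weighted, normalized 2D histogram."""
    hist, x_edges, y_edges = np.histogram2d(x, y, bins=bins, weights=weights)
    prob = hist / hist.sum()
    with np.errstate(divide="ignore"):
        free_energy = -np.log(prob)  # empty bins become +inf
    # Shift so the global minimum is zero, then mark empty bins as NaN.
    free_energy -= free_energy[np.isfinite(free_energy)].min()
    free_energy[~np.isfinite(free_energy)] = np.nan
    return free_energy, x_edges, y_edges


# Synthetic projection CVs with uniform weights; real weights come from the MSM.
rng = np.random.default_rng(1)
x, y = rng.normal(size=(2, 20000))
uniform_weights = np.full(x.size, 1.0 / x.size)
F_hist, x_edges, y_edges = fes_from_weighted_histogram(x, y, uniform_weights)
```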

## Project FES

Do the projection and take the negative log. You have to choose a bandwidth for the [KDE](https://en.wikipedia.org/wiki/Kernel_density_estimation) of the histogram. It should be big enough to reduce noise but not so big that it removes features. If you give `None`

In [None]:
bandwidth = 0.05
p_of_cv, extent = spc.get_kde(cv_proj, weights, bandwidth)
F0 = -np.log(p_of_cv)
F = F0 - F0.min()
F[F > 40] = np.nan
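
To see why the bandwidth matters, here is a small one-dimensional sketch of a weighted Gaussian KDE on synthetic two-well data. It is not the code behind `spc.get_kde`; `weighted_kde_1d` is a hypothetical helper and the two bandwidths are chosen only to exaggerate the effect.

```python
import numpy as np


def weighted_kde_1d(samples, weights, grid, bandwidth):
    """Weighted Gaussian KDE evaluated on a 1D grid."""
    z = (grid[:, None] - samples[None, :]) / bandwidth
    kernels = np.exp(-0.5 * z**2) / (bandwidth * np.sqrt(2.0 * np.pi))
    return kernels @ (weights / weights.sum())


# Two wells at -1 and +1 with uniform weights (synthetic data).
rng = np.random.default_rng(2)
samples = np.concatenate([rng.normal(-1, 0.2, 500), rng.normal(1, 0.2, 500)])
toy_weights = np.ones_like(samples)
grid = np.linspace(-3, 3, 121)  # grid[60] is x = 0, grid[80] is x = 1

barriers = {}
for bw in (0.2, 2.0):
    free_energy = -np.log(weighted_kde_1d(samples, toy_weights, grid, bw))
    # Free energy at the midpoint x = 0 relative to the well at x = 1.
    barriers[bw] = free_energy[60] - free_energy[80]
```

With the narrow bandwidth the barrier between the wells is clearly visible; with the wide one the two wells merge into a single minimum and the barrier disappears.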

## Plot FES

In [None]:
fig, ax = spc.plot_2D_heatmap(
    F,
    extent,
    f_max=20,
    cbar_label="Free Energy (kT)",
    xlabel="SF (nm)",
    ylabel="IG (nm)",
)
fig.tight_layout()
fig.savefig(path_report + "FES.png")

## Bootstrap to get error

The problem with calculating errors in MD is that most statistical techniques for this rely on the data being uncorrelated. MD data is usually highly correlated due to proximity in time and starting structure. Correlated data produces artificially low error estimates.

For this reason we use blocking, in our case combined with bootstrapping. This is explained very well in this [very useful video](https://www.youtube.com/watch?v=gHXXGYIgasE&t=1854s) by Prof. Giovanni Bussi.

The uncertainty is calculated as half of the interval containing 95% of the probability of the distribution of histograms generated in the bootstraps.

This part is probably going to be slow and may need to run overnight: it builds `len(blocks) * n_boot` MSMs! The good thing is that once you have figured out a reasonable number of blocks for your system (and similar systems), you can just set `blocks=[my_reasonable_number_blocks]`.

In [None]:
blocks = [2, 4, 8, 16, 32]
n_blocks = len(blocks)
n_boot = 200
calculate = False
if calculate:
    errors = spc.get_error(
        cv_proj,
        clusters,
        extent,
        n_boot=n_boot,
        bandwidth=0.05,
        nbin=55,
        blocks=blocks,
        seed=28101990,
    )
    np.save(f"postprocessing/errors_{n_boot}_{n_blocks}.npy", errors)
else:
    errors = np.load(f"postprocessing/errors_{n_boot}_{n_blocks}.npy")
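
The blocking plus bootstrapping idea can be illustrated on a simpler statistic than the FES. The sketch below estimates the error of the mean of a correlated synthetic series as half of the 95% bootstrap interval, resampling whole blocks. It is not the code behind `spc.get_error`, which applies the same idea to histograms built from MSMs; `block_bootstrap_error` is a hypothetical helper.

```python
import numpy as np


def block_bootstrap_error(data, n_blocks, n_boot=200, seed=0):
    """Half-width of the 95% bootstrap interval of the mean, resampling blocks."""
    rng = np.random.default_rng(seed)
    usable = (len(data) // n_blocks) * n_blocks  # drop the remainder
    block_means = data[:usable].reshape(n_blocks, -1).mean(axis=1)
    # Resample whole blocks with replacement and recompute the mean each time.
    picks = rng.integers(0, n_blocks, size=(n_boot, n_blocks))
    boot_means = block_means[picks].mean(axis=1)
    low, high = np.percentile(boot_means, [2.5, 97.5])
    return 0.5 * (high - low)


# Correlated synthetic series: AR(1) with coefficient 0.95 between steps.
rng = np.random.default_rng(3)
noise = rng.normal(size=8000)
series = np.empty(8000)
series[0] = noise[0]
for i in range(1, 8000):
    series[i] = 0.95 * series[i - 1] + noise[i]

err_per_frame = block_bootstrap_error(series, n_blocks=8000)  # ignores correlation
err_blocked = block_bootstrap_error(series, n_blocks=16)
```

Treating every frame as independent gives a much smaller error than resampling 16 long blocks, which is the artificially low estimate that blocking is meant to avoid.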

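Choose the number of blocks at which the error reaches its plateau (or its maximum). Lower errors at other block counts are likely underestimates caused by correlation in the data.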

Note: `e_min` and `e_max` are chosen to remove the extremely high or low error values caused by poor sampling or high free energy. These regions of the "free error surface" are not what we care about, so we remove them from the statistics and the visualization.

In [None]:
e_max = 6
e_min = 1.0e-03
e = errors.copy()
e[e > e_max] = np.nan
e[e <= e_min] = np.nan
_ = plt.plot(np.array(blocks), np.nanmean(e, axis=(1, 2)), marker="o")
_ = plt.xlabel("Number of blocks", size=15)
_ = plt.ylabel("FES error (kT)", size=15)

From the previous plot you can pick an adequate number of blocks: one that is low but still gives you the plateau (or highest) error.

In [None]:
number_blocks = 16
f_max = 20
e_max = 6
e_min = 1.0e-03

e = errors[blocks.index(number_blocks)].copy()
e[e > e_max] = np.nan
e[e <= e_min] = np.nan

fig, ax = plt.subplots(1, 2, figsize=(10 * 2, 7), sharex=True, sharey=True)
_ = spc.plot_2D_heatmap(
    F,
    extent,
    f_max=f_max,
    cbar_label="Free Energy (kT)",
    xlabel="SF (nm)",
    ylabel="IG (nm)",
    fig=fig,
    ax=ax[0],
)
_ = spc.plot_2D_heatmap(
    e,
    extent,
    f_max=e_max,
    cbar_label="FES Uncertainty (kT)",
    xlabel="SF (nm)",
    cmap=plt.cm.viridis_r,
    fig=fig,
    ax=ax[1],
)
ax[1].set_title("Bootstrap Error (95%)")
fig.tight_layout()
fig.savefig(path_report + "FES_error.png")