# String Method Analysis with Markov State Models
## Imports

In [None]:
import os
import pickle
import sys
import logging
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from glob import glob


logging.getLogger("stringmethod").setLevel(logging.ERROR)
sys.path.append("../string-method-gmxapi/")
import src.analysis.string_tica_msm as my_msm
import src.analysis.plotting as my_plot
import src.analysis.cvs as my_cvs
from src.analysis.utils import natural_sort

In [None]:
%load_ext lab_black
%load_ext autoreload
%autoreload 2

## Load data

This notebook needs to run in the string simulation folder; the next cell changes into it. You can also set up a path here for writing the figures.

In [None]:
name_sim = "LB-CHARMM/"
simulation_directory = f"/data/sperez/Projects/string_sims/data/raw/{name_sim}/"
os.chdir(simulation_directory)
os.getcwd()

In [None]:
with open("cv.pkl", "rb") as file:
    cvs, ndx_groups = pickle.load(file)

The `load_swarm_data` function loads the swarm data into `cv_coordinates`. If you set `extract=True`, it reads the data from the swarm files. If you have done this before, you can set `extract=False` so that the function just reads `postprocessing/cv_coordinates.npy`. Use `first_iteration` to exclude the initial swarms as equilibration. Use `last_iteration` to exclude later iterations, for example to estimate FES convergence by comparing blocks of data.

In [None]:
cv_coordinates = my_msm.load_swarm_data(
    extract=True, first_iteration=100, #last_iteration=300
)

In [None]:
files = natural_sort(glob("./strings/string[0-9]*txt"))
strings = np.array([np.loadtxt(file).T for file in files])

## Dimensionality reduction with TICA

The following cell computes the TICA projection of the string CVs and discards the TICs with the lowest kinetic variance. This reduces the CV space to a lower-dimensional space adapted to the kinetic variance. Use the `drop` keyword to exclude CVs that are not well converged in the string simulation or that change very little from the beginning to the end of the string. In the best case, `drop=[]` just works.
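
The internals of `my_msm.cvs_to_tica` are not shown in this notebook. To illustrate the technique only, here is a minimal NumPy/SciPy sketch of symmetrized TICA for swarm data. It assumes each swarm contributes one (start, end) pair of CV vectors and that the lag time is the swarm duration. The function `tica_from_pairs` and the toy data are hypothetical. It keeps the `dim` components with the largest autocorrelation, which is the spirit of discarding the TICs with the lowest kinetic variance, but it is not necessarily what `cvs_to_tica` does.

```python
import numpy as np
from scipy.linalg import eigh


def tica_from_pairs(x0, xt, dim=2, eps=1e-10):
    """Symmetrized TICA from (start, end) pairs of short trajectories.

    x0, xt: arrays of shape (n_pairs, n_cvs) with the CVs at the start and
    end of each swarm, so the lag time is the swarm duration.
    Returns x0 projected onto the `dim` slowest components and their
    eigenvalues (autocorrelations at the lag time).
    """
    # Common mean computed from both ends of the pairs
    mean = 0.5 * (x0.mean(axis=0) + xt.mean(axis=0))
    a, b = x0 - mean, xt - mean
    n = a.shape[0]
    # Symmetrized instantaneous and time-lagged covariance matrices
    c0 = (a.T @ a + b.T @ b) / (2 * n)
    ctau = (a.T @ b + b.T @ a) / (2 * n)
    c0 += eps * np.eye(c0.shape[0])  # small regularization
    # Generalized eigenproblem: ctau v = lambda c0 v
    evals, evecs = eigh(ctau, c0)
    order = np.argsort(evals)[::-1]  # slowest (largest lambda) first
    evals, evecs = evals[order], evecs[:, order]
    return a @ evecs[:, :dim], evals[:dim]


# Toy swarm data: 1000 pairs of 5 CVs, end = start + small noise
rng = np.random.default_rng(0)
x0 = rng.normal(size=(1000, 5))
xt = x0 + 0.1 * rng.normal(size=(1000, 5))
proj, evals = tica_from_pairs(x0, xt, dim=2)
print(proj.shape)  # (1000, 2)
```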

In [None]:
tica = my_msm.cvs_to_tica(cv_coordinates, drop=[20, 21, 22, 23, 32, 33, 34, 35])

## Cluster

The next cell plots the VAMP score obtained when building an MSM with each number of cluster centers in `n_clustercenters`. At some point the VAMP score should saturate. Choose the smallest number of clusters that reaches the saturated VAMP score as the value of `k` for the next steps. This might take a little while. `n_jobs` is the number of parallel processes used. The third value returned by `get_vamp_vs_k` holds the VAMP scores, in case you want to save them or use them in some other way.
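
For intuition about what is being scored, here is a sketch of the VAMP-2 score for a hard clustering, computed directly from (start, end) cluster pairs. `vamp2_discrete` is a hypothetical helper. It scores the training data only, so it may differ from what `get_vamp_vs_k` computes (for example, if that function cross-validates). The score grows with the number of slow processes the clustering resolves, which is why it saturates once `k` is large enough.

```python
import numpy as np


def vamp2_discrete(start_states, end_states, n_states):
    """VAMP-2 score of a clustering, from (start, end) cluster index pairs.

    With one-hot state indicators as basis functions the score reduces to
    sum_ij C_ij**2 / (n_i * m_j), where C is the transition count matrix,
    n_i the number of pairs starting in i and m_j the number ending in j.
    The constant singular value (1) is included.
    """
    counts = np.zeros((n_states, n_states))
    np.add.at(counts, (start_states, end_states), 1)
    n = counts.sum(axis=1)  # pairs starting in each state
    m = counts.sum(axis=0)  # pairs ending in each state
    nz = counts > 0  # only observed transitions contribute
    return np.sum(counts[nz] ** 2 / np.outer(n, m)[nz])


# Three states that never exchange: the score equals the number of states
s0 = np.array([0, 0, 1, 1, 2, 2])
print(vamp2_discrete(s0, s0, 3))  # 3.0
# Two states mixing completely within one lag: only the constant part remains
print(vamp2_discrete(np.array([0, 0, 1, 1]), np.array([0, 1, 0, 1]), 2))  # 1.0
```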

In [None]:
n_clustercenters = [5, 10, 30, 50, 75, 100, 200, 500][::-1]
fig, ax, scores = my_msm.get_vamp_vs_k(n_clustercenters, tica, n_jobs=4)

If the calculation fails, there is something wrong with your MSM: either you have too few transitions, or there are too many CVs in the TICA for all the states to be well connected. Possible solutions:
+ Reduce the maximum number of clusters in `n_clustercenters` (e.g. drop 200 and 500) and see if you get a saturated curve.
+ Reduce the number of CVs that go into your TICA calculation.
+ Run more iterations of the string method.
+ Use `allow_failed_msms=True`, but be careful :)
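
Failures of this kind are often a connectivity problem. As a diagnostic sketch, you can count what fraction of the swarm transitions between cluster states falls inside the largest strongly connected set of states. That set is the part of the data a reversible MSM is typically estimated on. The helper below is hypothetical and not part of `my_msm`. A fraction well below 1 suggests trying the remedies listed above.

```python
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components


def largest_connected_fraction(start_states, end_states, n_states):
    """Fraction of (start, end) transitions inside the largest strongly
    connected set of states."""
    start_states = np.asarray(start_states)
    end_states = np.asarray(end_states)
    counts = coo_matrix(
        (np.ones(len(start_states)), (start_states, end_states)),
        shape=(n_states, n_states),
    ).tocsr()  # duplicate entries are summed
    _, labels = connected_components(counts, directed=True, connection="strong")
    # Pick the component that contains the most transitions
    out_counts = np.asarray(counts.sum(axis=1)).ravel()
    largest = np.argmax(np.bincount(labels, weights=out_counts))
    inside = (labels[start_states] == largest) & (labels[end_states] == largest)
    return inside.mean()


# States 0 and 1 exchange; state 2 is left once and never re-entered
start = [0, 1, 0, 1, 2]
end = [1, 0, 1, 0, 0]
print(largest_connected_fraction(start, end, 3))  # 0.8
```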

## MSM Deeptime

Choose the number of clusters, `k`, based on the previous calculation. Also set `n_jobs` to however many processors you can use.

In [None]:
k = 100
clusters = my_msm.k_means_cluster(
    tica, k, stride=1, max_iter=500, n_jobs=4, seed=28101990
)

In [None]:
%%time
msm, weights = my_msm.get_msm(clusters)
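
`get_msm` returns the MSM and a set of weights. Here is a rough sketch of how MSM weights can be assigned to individual frames. It assumes the discrete trajectories are stored as (start, end) cluster pairs and uses a symmetrized count matrix as a simple reversible estimator. This is not necessarily what `get_msm` does; deeptime, for instance, offers a maximum-likelihood reversible estimator instead. The helper `msm_frame_weights` is hypothetical. It shares each state's stationary probability equally among the frames assigned to that state, so the frame weights sum to one.

```python
import numpy as np


def msm_frame_weights(dtrajs, n_states):
    """Stationary distribution and per-frame weights from a simple MSM.

    dtrajs: int array of shape (n_swarms, 2), the cluster index at the start
    and end of each swarm. The count matrix is symmetrized, (C + C.T) / 2,
    a crude reversible estimate (not the maximum-likelihood estimator).
    """
    counts = np.zeros((n_states, n_states))
    np.add.at(counts, (dtrajs[:, 0], dtrajs[:, 1]), 1)
    sym = 0.5 * (counts + counts.T)
    # For symmetric counts, detailed balance gives pi_i proportional to row sums
    pi = sym.sum(axis=1)
    pi /= pi.sum()
    # Share each state's probability equally among the frames assigned to it
    frames_per_state = np.bincount(dtrajs.ravel(), minlength=n_states)
    weights = pi[dtrajs] / frames_per_state[dtrajs]
    return pi, weights


dtrajs_demo = np.array([[0, 0], [0, 1], [1, 0]])
pi_demo, w_demo = msm_frame_weights(dtrajs_demo, 2)
print(pi_demo)  # [0.66666667 0.33333333]
```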

## CVs for projection

Make a `cv_proj` NumPy array with shape `(n_iterations * n_swarms_per_iteration, n_frames_per_swarm, 2)`. `n_frames_per_swarm` is usually 2, since you only record the value of the CVs at the beginning and end of each swarm. The last dimension holds the CVs onto which you want to project your FES, using the weights obtained from the MSM. The FES is then the negative log of a *weighted* histogram of the projection CVs, using the MSM weights. The projection CVs can be anything you can calculate for a structure, not necessarily the CVs of the string. In the example below, each projection CV is the mean of two string CVs.
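
The same array can be built more compactly by indexing the CV axis with a list of CV indices and averaging over it. The sketch below uses random stand-in data (`coords_demo`, hypothetical) with the same (n_swarms, n_frames, n_cvs) layout as `cv_coordinates`.

```python
import numpy as np


def mean_of_cvs(coords, ids):
    """Average the CVs listed in `ids` over the last axis of `coords`
    (n_swarms, n_frames, n_cvs), keeping a trailing axis of length 1."""
    return coords[:, :, ids].mean(axis=2, keepdims=True)


# Random stand-in for cv_coordinates: 50 swarms, 2 frames, 36 CVs
rng = np.random.default_rng(1)
coords_demo = rng.normal(size=(50, 2, 36))
cv_proj_demo = np.concatenate(
    [mean_of_cvs(coords_demo, [0, 1]), mean_of_cvs(coords_demo, [10, 11])],
    axis=2,
)
print(cv_proj_demo.shape)  # (50, 2, 2)
```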

In [None]:
cv_proj = np.concatenate(
    [
        np.mean([cv_coordinates[:, :, 0:1], cv_coordinates[:, :, 1:2]], axis=0),
        np.mean([cv_coordinates[:, :, 10:11], cv_coordinates[:, :, 11:12]], axis=0),
    ],
    axis=2,
)

## Project FES

Do the projection and take the log. You have to choose a bandwidth for the [KDE](https://en.wikipedia.org/wiki/Kernel_density_estimation) of the histogram: big enough to reduce noise, but not so big that it removes features. If you pass `None`, an automatic rule of thumb is used instead (Scott's rule, assuming `get_kde` is built on `scipy.stats.gaussian_kde`).

**Warning:** The actual bandwidth used by the algorithm is the covariance matrix of the data (since the Gaussians are multidimensional) scaled by `bandwidth`. So if you change the data, the spread of the KDE Gaussians changes, even if `bandwidth` stays the same.
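
To see what the warning means in practice, here is a sketch that assumes `get_kde` behaves like `scipy.stats.gaussian_kde` with `bw_method=bandwidth` and per-frame weights. This is an assumption; the code of `get_kde` is not shown here. With a scalar `bw_method`, the kernel covariance equals the weighted data covariance times `bandwidth**2`, so the same `bandwidth` gives wider kernels for more spread-out data. The data and weights are random stand-ins.

```python
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(2)
# Stand-in 2-CV samples, shape (n_dims, n_points) as gaussian_kde expects
data = rng.normal(size=(2, 2000)) * np.array([[0.5], [0.1]])
w_kde = rng.uniform(0.5, 1.5, size=2000)  # stand-in for MSM frame weights

kde = gaussian_kde(data, bw_method=0.05, weights=w_kde)
# A scalar bw_method is used directly as kde.factor, and the kernel
# covariance is the weighted data covariance times factor**2
data_cov = np.cov(data, aweights=w_kde)
print(np.allclose(kde.covariance, data_cov * kde.factor**2))  # True

# Evaluate on a grid and turn the density into a free energy in kT
x = np.linspace(data[0].min(), data[0].max(), 50)
y = np.linspace(data[1].min(), data[1].max(), 50)
X, Y = np.meshgrid(x, y)
p = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)
with np.errstate(divide="ignore"):
    F_demo = -np.log(p)  # grid points with zero density become +inf
F_demo -= F_demo.min()
F_demo[F_demo > 40] = np.nan  # same cut as in the notebook
```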

In [None]:
bandwidth = 0.05
p_of_cv, extent = my_msm.get_kde(cv_proj, weights, bandwidth)
F0 = -np.log(p_of_cv)
F = F0 - F0.min()
F[F > 40] = np.nan

## Plot FES

In [None]:
f_max = 25
fig, ax = my_plot.plot_2D_heatmap(
    F,
    extent,
    f_max=f_max,
    cbar_label="Free Energy (kT)",
    xlabel="SF (nm)",
    ylabel="IG (nm)",
)
fig.tight_layout()

## Bootstrap to get error

The problem with calculating errors in MD is that most statistical techniques assume the data are uncorrelated. MD data are usually highly correlated, because frames are close in time and share starting structures. Treating correlated data as independent gives artificially low error estimates.

For this reason we use blocking, in our case combined with bootstrapping (blocking + bootstrapping). This is very well explained in this [very useful video](https://www.youtube.com/watch?v=gHXXGYIgasE&t=1854s) by Prof. Giovanni Bussi.

The uncertainty is calculated as half the width of the interval containing 95% of the free-energy values generated in the bootstraps, bin by bin.

This part is probably going to be slow and may run overnight: it actually estimates `len(blocks) * n_boot` MSMs! The good thing is that once you have worked out a reasonable number of blocks for your system (and similar systems), you can just use `blocks=[my_reasonable_number_blocks]`. In general, 100-150 iterations per block seems reasonable.
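
The idea of blocking plus bootstrapping can be sketched in a few lines of NumPy. This is a simplified, hypothetical version. It keeps the frame weights fixed and uses a plain weighted histogram. `get_error`, by contrast, re-estimates the MSM for every bootstrap sample (hence the `len(blocks) * n_boot` MSMs) and uses a KDE. The error is half the width of the central 95% interval of the bootstrapped FES in each bin. The toy data are a strongly correlated AR(1) series, so you can compare the error obtained with few large blocks against many small ones.

```python
import numpy as np


def block_bootstrap_fes_error(x, w, bins, n_blocks, n_boot=100, seed=0):
    """Half-width of the 95% block-bootstrap interval of a 1D FES (kT).

    x, w: 1D arrays of samples in time order and their (fixed) weights.
    The data are cut into `n_blocks` contiguous blocks; each bootstrap
    sample draws `n_blocks` blocks with replacement.
    """
    rng = np.random.default_rng(seed)
    blocks = np.array_split(np.arange(len(x)), n_blocks)
    fes = np.empty((n_boot, len(bins) - 1))
    for b in range(n_boot):
        picked = rng.integers(0, n_blocks, size=n_blocks)
        idx = np.concatenate([blocks[i] for i in picked])
        p, _ = np.histogram(x[idx], bins=bins, weights=w[idx], density=True)
        with np.errstate(divide="ignore"):
            fes[b] = -np.log(p)  # empty bins become +inf
        fes[b] -= fes[b].min()
    fes[~np.isfinite(fes)] = np.nan  # ignore empty bins in the percentiles
    lo, hi = np.nanpercentile(fes, [2.5, 97.5], axis=0)
    return 0.5 * (hi - lo)


# Strongly correlated toy series: AR(1) process with unit weights
rng = np.random.default_rng(3)
noise = rng.normal(size=20000)
x_demo = np.zeros(20000)
for t in range(1, len(x_demo)):
    x_demo[t] = 0.99 * x_demo[t - 1] + noise[t]
w_ones = np.ones_like(x_demo)
bins_demo = np.linspace(-7, 7, 15)
for nb in [2, 8, 32]:
    err = block_bootstrap_fes_error(x_demo, w_ones, bins_demo, nb)
    print(nb, np.nanmean(err))
```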

In [None]:
# Block-bootstrap error of the 2D FES, for each number of blocks in `blocks`
n_boot = 100
blocks = [2, 4, 8, 16, 32]
errors = my_msm.get_error(
    cv_proj,
    clusters,
    extent,
    n_boot=n_boot,
    bandwidth=0.05,
    nbin=55,
    n_jobs=4,
    blocks=blocks,
)

In [None]:
fig, ax = plt.subplots(1, 1)
errors[:, ~np.isfinite(F)] = np.nan
label = f"n_boot={n_boot}"
mean = np.nanmean(errors, axis=(1, 2))
n_finite = np.sum(np.isfinite(errors), axis=(1, 2))  # finite FES bins per block number
std_err = np.nanstd(errors, axis=(1, 2)) / np.sqrt(n_finite)
ax.plot(np.array(blocks), mean, marker="o", label=label)
ax.fill_between(np.array(blocks), mean + std_err, mean - std_err, alpha=0.3)
ax.legend()
ax.set_xlabel("Number of blocks", size=15)
ax.set_ylabel("FES error (kT)", size=15)

From the previous plot, choose the number of blocks for the final error estimate: one that is low but still gives you the plateau (or highest) error.

In [None]:
number_blocks = 16
f_max = 20
e_max = None

e = errors[blocks.index(number_blocks)].copy()

fig, ax = plt.subplots(1, 2, figsize=(10 * 2, 7), sharex=True, sharey=True)
_ = my_plot.plot_2D_heatmap(
    F,
    extent,
    f_max=f_max,
    cbar_label="Free Energy (kT)",
    xlabel="SF (nm)",
    ylabel="IG (nm)",
    fig=fig,
    ax=ax[0],
)
_ = my_plot.plot_2D_heatmap(
    e,
    extent,
    f_max=e_max,
    cbar_label="FES Uncertainty (kT)",
    xlabel="SF (nm)",
    cmap=plt.cm.viridis_r,
    fig=fig,
    ax=ax[1],
)
ax[1].set_title("Bootstrap Error (95%)")
fig.tight_layout()

# Path CVs

Path CVs are calculated following the article by [Branduardi et al.](https://aip.scitation.org/doi/pdf/10.1063/1.2432340). We assign two CVs to the path: `s_path`, which measures the progress of the trajectory along the path, and `z_path`, which measures the distance of the trajectory from the path.
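
For reference, take a path of $P$ beads $X_1, \dots, X_P$ in CV space and a configuration $X$. The path CVs can then be written as below. In this normalized form, $s$ runs from 0 at the first bead to 1 at the last; other conventions, such as $s$ running from 1 to $P$, are also common. That `my_cvs.cvs_to_path` uses exactly this normalization is an assumption, consistent with the 0-to-1 range expected for `S` later in this notebook. $\lambda$ sets how sharply the weight of each bead decays with its squared distance from $X$.

$$
s(X) = \frac{1}{P-1}\,\frac{\sum_{i=1}^{P} (i-1)\, e^{-\lambda \lVert X - X_i \rVert^2}}{\sum_{i=1}^{P} e^{-\lambda \lVert X - X_i \rVert^2}},
\qquad
z(X) = -\frac{1}{\lambda} \ln \sum_{i=1}^{P} e^{-\lambda \lVert X - X_i \rVert^2}
$$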

Define the transition path as the average of the last `av_last_n_it` strings. Then obtain a reasonable guess for the parameter lambda using the heuristics of the paper.

In [None]:
n_strings = strings.shape[0]
av_last_n_it = 25
path = np.mean(strings[n_strings - av_last_n_it :, :, :], axis=0)
lam = my_cvs.get_path_lambda(path)
print(f"Lambda value for path {lam:.2f}")

Let's check whether this lambda gives a well-behaved path CV. The progress variable (s) should increase along the path itself, and the distance-from-path variable (z) should be low and roughly constant.
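
Here is a minimal NumPy sketch of these checks. The hypothetical helpers `path_lambda` and `path_cvs` stand in for `my_cvs.get_path_lambda` and `my_cvs.cvs_to_path`, whose code is not shown. The sketch sets lambda to 2.3 divided by the mean squared distance between adjacent beads. This is a commonly used rule of thumb that gives a neighbouring bead a weight of about 0.1. It is an assumption here, not necessarily the heuristic `get_path_lambda` implements. The toy path is a straight line of evenly spaced beads.

```python
import numpy as np


def path_lambda(path):
    """Rule-of-thumb lambda = 2.3 / <squared distance between adjacent beads>,
    so that a neighbouring bead gets a weight of about exp(-2.3) = 0.1.
    `path` has shape (n_cvs, n_beads), like `path` in this notebook."""
    d2 = np.sum(np.diff(path, axis=1) ** 2, axis=0)
    return 2.3 / d2.mean()


def path_cvs(x, path, lam):
    """Normalized progress s (0 to 1) and distance z of point x (n_cvs,)."""
    d2 = np.sum((path - x[:, None]) ** 2, axis=0)  # squared distance to each bead
    a = -lam * d2
    a_max = a.max()
    w = np.exp(a - a_max)  # shifted exponents (log-sum-exp trick)
    n_beads = path.shape[1]
    s = np.sum(np.arange(n_beads) * w) / (np.sum(w) * (n_beads - 1))
    z = -(a_max + np.log(np.sum(w))) / lam
    return s, z


# Toy straight path: 10 evenly spaced beads in a 3-CV space
path_demo = np.linspace([0.0, 0.0, 0.0], [1.0, 2.0, 0.5], 10).T  # (3, 10)
lam_demo = path_lambda(path_demo)
s_beads = np.array([path_cvs(p, path_demo, lam_demo)[0] for p in path_demo.T])
print(np.all(np.diff(s_beads) > 0))  # True: s increases along the path
```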

## Representation of Path CV

### Load data and calculate path cvs

In [None]:
cv_of_path = []
for p in path.T:
    cv_of_path.append(my_cvs.cvs_to_path(p, path=path, lam=lam))
cv_of_path = np.array(cv_of_path)

In [None]:
cvs_path = []
for i in range(cv_coordinates.shape[0]):
    cvs_path.append([])
    for j in range(cv_coordinates.shape[1]):
        cvs_path[i].append(
            my_cvs.cvs_to_path(cv_coordinates[i, j, :], path=path, lam=lam)
        )
cvs_path = np.array(cvs_path)

### Path CV on final string

If the path and lambda you have calculated are OK:
+ `S` should increase with bead number in an approximate range 0 to 1. 
+ `Z` should be small and oscillating about some constant value. (We don't care too much about z anyway)

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10 * 2, 7))
ax[0].plot(cv_of_path[:, 0], marker="o")
ax[1].plot(cv_of_path[:, 1], marker="o")
ax[0].set_xlabel("bead number")
ax[0].set_ylabel("S")
ax[1].set_xlabel("bead number")
ax[1].set_ylabel("Z")
ax[0].set_title(f"Path-variable {lam = :.2f}")
ax[1].set_title(f"Path-variable {lam = :.2f}")

### Path CV projected on IG vs SF

Check how the path variables project onto the canonical inactivation 2-CV FES (IG vs SF).

The projection code is general: you can project any property, provided it is in a NumPy array with the right shape.
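
`project_property_on_cv_kde` does this with a KDE. As a simpler illustration of the same idea, here is a hypothetical histogram-based sketch. The projected value in each bin is the weighted mean of the property over the frames that fall in that bin. To use it on this notebook's data you would flatten the frames, for example with `cv_proj.reshape(-1, 2)` and `cvs_path[:, :, 0].ravel()`. This assumes `weights` has one entry per frame in the same order.

```python
import numpy as np


def project_property_hist(cv_xy, prop, w, bins=50):
    """Weighted mean of `prop` in each 2D bin of the projection CVs.

    cv_xy: (n_frames, 2) projection CVs; prop, w: (n_frames,) property
    values and frame weights. Rows of the result index the first CV.
    Empty bins are NaN.
    """
    num, xe, ye = np.histogram2d(
        cv_xy[:, 0], cv_xy[:, 1], bins=bins, weights=w * prop
    )
    den, _, _ = np.histogram2d(cv_xy[:, 0], cv_xy[:, 1], bins=[xe, ye], weights=w)
    with np.errstate(invalid="ignore", divide="ignore"):
        avg = num / den  # 0 / 0 -> NaN in empty bins
    return avg, (xe[0], xe[-1], ye[0], ye[-1])


# Toy data: uniform points in the unit square, property = first CV
rng = np.random.default_rng(4)
xy = rng.uniform(0, 1, size=(5000, 2))
w_frames = rng.uniform(0.5, 1.5, size=5000)
avg, extent_demo = project_property_hist(xy, xy[:, 0], w_frames, bins=10)
```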

In [None]:
bandwidth = 0.05

In [None]:
s_of_cv, extent = my_msm.project_property_on_cv_kde(
    cv_proj, weights=weights, proper=cvs_path[:, :, 0:1], bandwidth=bandwidth
)

In [None]:
z_of_cv, extent = my_msm.project_property_on_cv_kde(
    cv_proj, weights=weights, proper=cvs_path[:, :, 1:2], bandwidth=bandwidth
)

The z graph is not really interesting and I don't interpret it much.

The s projection (left) is very important. If your definitions of the path and lambda are correct, the color should go from low to high as the system moves from reactants to products (or vice versa). This tells you whether the `s_path` CV captures the right physics.

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10 * 2, 7), sharex=True, sharey=True)
_ = my_plot.plot_2D_heatmap(
    s_of_cv,
    extent,
    cbar_label="s[path]",
    xlabel="SF (nm)",
    ylabel="IG (nm)",
    f_min=0,
    f_max=1,
    fig=fig,
    cmap=plt.cm.Spectral,
    ax=ax[0],
    n_colors=200,
    c_density=F,
    c_min=0,
    c_max=20,
    c_color="k",
)
ax[0].contour(F, levels=20, extent=extent, vmin=0, vmax=20, colors="k")
ax[0].grid(None)
_ = my_plot.plot_2D_heatmap(
    z_of_cv,
    extent,
    cbar_label="z[path]",
    xlabel="SF (nm)",
    cmap=plt.cm.magma,
    f_min=0,
    f_max=0.15,
    fig=fig,
    ax=ax[1],
    n_colors=200,
    c_density=F,
    c_min=0,
    c_max=20,
    c_color="w",
)
ax[1].contour(F, levels=20, extent=extent, vmin=0, vmax=20, colors="w")
ax[1].grid(None)
fig.tight_layout()

## Calculate FES projected on path CV

Compute a preliminary FES, for example to optimize `bandwidth`.

In [None]:
s_path = cvs_path[:, :, 0:1]

In [None]:
%%time
bandwidth = 0.25
nbins = 100
p_of_cv, extent = my_msm.get_kde(s_path, weights, bandwidth, nbins=nbins)
F0 = -np.log(p_of_cv)
F = F0 - F0.min()
F[F > 40] = np.nan
s = np.linspace(extent[0], extent[1], nbins)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 7))
ax.plot(s, F, marker=".", label=f"bandwidth = {bandwidth}")
ax.legend()
ax.set_xlabel("s[path]", size=15)
ax.set_ylabel("F (kT)", size=15)
ax.set_ylim([0, 20])
ax.set_xlim([0, 1])

## Calculate errors FES s-path

In [None]:
%%time
import src.analysis as spc

blocks = [2, 4, 8, 16, 32]
n_blocks = len(blocks)
n_boot = 100
errors = spc.get_error(
    s_path,
    clusters,
    extent,
    n_boot=n_boot,
    bandwidth=bandwidth,
    nbin=nbins,
    blocks=blocks,
    seed=28101990,
    n_jobs=4,
)

Choose the number of blocks that gives you a high error.

In [None]:
fig, ax = plt.subplots(1, 1)
errors[:, ~np.isfinite(F)] = np.nan
label = f"n_boot={n_boot}"
mean = np.nanmean(errors, axis=1)
n_finite = np.sum(np.isfinite(errors), axis=1)  # finite FES bins per block number
std_err = np.nanstd(errors, axis=1) / np.sqrt(n_finite)
ax.plot(np.array(blocks), mean, marker="o", label=label)
ax.fill_between(np.array(blocks), mean + std_err, mean - std_err, alpha=0.3)
ax.legend()
ax.set_xlabel("Number of blocks", size=15)
ax.set_ylabel("FES error (kT)", size=15)

In [None]:
fig, ax = plt.subplots(1, 1)
error_block = 8
n_blocks = len(blocks)
s = np.linspace(extent[0], extent[1], nbins)
error = errors[blocks.index(error_block), :]
ax.fill_between(s, F + error, F - error, alpha=0.3)
ax.plot(s, F, label=f"{n_boot=}", marker=".")
ax.legend()
ax.set_xlabel("s[path]", size=15)
ax.set_ylabel("F (kT)", size=15)

### Study convergence of the FES along the path with simulation length

In [None]:
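%%time
n_swarms = 36  # swarms per bead
n_beads = 18
step = 50  # iterations added per convergence point
step = step * n_swarms * n_beads  # the same step in number of swarm trajectories

FES_vs_t = []
FES_vs_t.append(np.linspace(extent[0], extent[1], nbins))  # row 0: s grid
for i in tqdm(range(step, s_path.shape[0] + step, step)):
    # Use only the first i swarm trajectories
    s_sub = s_path[:i, :, :]
    c = cv_coordinates[:i, :, :]
    t = my_msm.cvs_to_tica(c, drop=[20, 21, 22, 23, 32, 33, 34, 35])
    cl = my_msm.k_means_cluster(
        t, k, stride=1, max_iter=500, n_jobs=4, seed=28101990
    )
    try:
        _, w = my_msm.get_msm(cl, n_jobs=4)
    except Exception:
        # Skip data subsets for which the MSM estimation fails
        continue
    # The extent returned here comes from the data subset and may differ
    # slightly from the full-data grid stored in FES_vs_t[0]
    p_of_cv, _ = my_msm.get_kde(s_sub, w, bandwidth, nbins=nbins)
    f0 = -np.log(p_of_cv)
    f = f0 - f0.min()
    f[f > 40] = np.nan
    FES_vs_t.append(f)
FES_vs_t.append(f)
FES_vs_t = np.array(FES_vs_t)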

In [None]:
fig, ax = spc.plot_FES_1d_vs_t(FES_vs_t, xlabel="s[path]", error=error)
ax.set_ylim([0, 25])
ax.set_xlim([0, 1])
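
One simple way to put a number on this convergence plot is the RMS difference between each intermediate FES and the final one, after removing the arbitrary additive constant. The helper `fes_rmsd` is hypothetical and not part of `my_msm` or `spc`. With this notebook's array, the call would be something like `[fes_rmsd(f, FES_vs_t[-1]) for f in FES_vs_t[1:]]`, since row 0 holds the s grid. The comparison only makes sense if all rows share the same s grid.

```python
import numpy as np


def fes_rmsd(f_a, f_b):
    """RMS difference (kT) between two FES curves on the same grid.

    Only bins that are finite in both curves are compared, and each curve
    is shifted to zero mean over those bins, because a FES is only defined
    up to an additive constant."""
    ok = np.isfinite(f_a) & np.isfinite(f_b)
    if not np.any(ok):
        return np.nan
    a = f_a[ok] - f_a[ok].mean()
    b = f_b[ok] - f_b[ok].mean()
    return np.sqrt(np.mean((a - b) ** 2))


# Toy curves: a constant offset does not count as a difference
s_grid = np.linspace(0, 1, 100)
f_final = 10 * (s_grid - 0.5) ** 2
f_shifted = f_final + 3.0
f_noisy = f_final + 0.5 * np.sin(20 * s_grid)
print(fes_rmsd(f_shifted, f_final), fes_rmsd(f_noisy, f_final))
```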

## FES s-path vs property

In [None]:
s_path = cvs_path[:, :, 0:1]

## Path vs SF (checks)

It is interesting to have a 2D FES of the path cv vs another cv to see at which point in the transition the other cv changes. 

This is easy to do with this code: just prepare an array with the correct shape in the variable `other_cv`. In this case I am averaging CVs, which is useful for KcsA, but it can be anything.

### SF

In [None]:
other_cv_id = [0, 1]
cv_name = "SF"
cv_fig_label = "SF (nm)"

In [None]:
other_cv = my_cvs.average_strings_to_cv(cv_coordinates, other_cv_id)

In [None]:
# Use a new name so that the `cvs` loaded from cv.pkl is not overwritten
cvs_2d = np.concatenate([s_path, other_cv], axis=2)

In [None]:
%%time
bandwidth = 0.05
p_of_cv, extent = my_msm.get_kde(cvs_2d, weights, bandwidth)
F0 = -np.log(p_of_cv)
F = F0 - F0.min()
F[F > 40] = np.nan


In [None]:
fig, ax = my_plot.plot_2D_heatmap(
    F,
    extent,
    f_max=25,
    f_min=0,
    cbar_label="Free Energy (kT)",
    xlabel="s[path]",
    ylabel=cv_fig_label,
)
fig.tight_layout()