### Example notebook to build a hierarchical deep MSM

This notebook aims to be a template for users trying to build a deep reversible Markov State model with coarse-graining layers and an additional attention mechanism in order to find important residues for the kinetics.

The code is based on the package deeptime, into which it should soon be integrated. It should be regarded as a
pre-alpha version.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
# import mdshare  # for trajectory data

from tqdm.notebook import tqdm  # progress bar

In [None]:
# Run on the GPU when torch can see one, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # switch on cuDNN's benchmark mode for GPU runs
    torch.backends.cudnn.benchmark = True
# number of threads torch may use on the CPU
torch.set_num_threads(12)

In [None]:
# Example loader for your own data (feel free to swap in the features you prefer)
def load_trajectories(files, pdb):
    '''
    Read the given trajectories and featurize them with the residue_mindist features
    used in the paper.
    Other features are possible, see
    http://www.emma-project.org/latest/api/generated/pyemma.coordinates.featurizer.html#
    ------
    Inputs:
    files: list of strings. Locations of all trajectories which should be loaded.
    pdb: string. Location of the corresponding pdb file.

    Returns:
    data: the result of pyemma.coordinates.load for `files` with the residue_mindist featurizer.
            According to the original notes this is a list of np.array (one per trajectory),
            or directly the np.array if only one trajectory is supplied.
    '''
    # imported here so that pyemma is only needed when raw trajectories are featurized
    import pyemma

    featurizer = pyemma.coordinates.featurizer(pdb)
    mindist_options = dict(
        residue_pairs='all',
        scheme='closest-heavy',
        ignore_nonprotein=True,
        threshold=None,
        periodic=True,
    )
    featurizer.add_residue_mindist(**mindist_options)
    return pyemma.coordinates.load(files, features=featurizer)

In [None]:
# Load the data with the example loader defined above.
# NOTE: the paths below are placeholders -- replace them with the locations of your own files.
files = ['/path/to/file1', '/path/to/file2']
pdb = '/path/to/pdb.pdb'
output_all_files = load_trajectories(files, pdb)
# Optionally cache the processed features so that later runs can load them directly:
# np.save('/path/to/save', output_all_files)
# output_all_files = np.load('/path/to/save.npy')
traj_whole = output_all_files

# (number of frames, number of features) of the first trajectory
traj_data_points, input_size = traj_whole[0].shape
# Skip data to make the data less correlated
skip = 25
# only the first trajectory is used from here on, keeping every `skip`-th frame
data = [traj_whole[0][::skip]]

# Inverts input_size = m*(m+1)/2 for m and adds 3 to obtain the number of residues
# (presumably because residue pairs closer than three positions are not featurized -- confirm)
n_residues = int(-1/2 + np.sqrt(1/4+input_size*2) + 3)

In [None]:
# Hyperparameters -- adapt them to the specific problem at hand

# Number of output nodes/states of the MSM or Koopman model, i.e. the nodes of chi.
# Read from first to last entry, the list defines how the output is coarse-grained.
output_sizes = [4, 3, 2]

# Time shift between the two datasets in the default training.
# tau_chi (pretraining of the vampnet) is usually smaller than tau (deepMSM).
tau = (20 * 2 * 25) // skip  # 5, 20
tau_chi = 25 // skip

# Batch size for stochastic gradient descent
batch_size = 512
# Larger batch size for fine-tuning the weights at the end of training
batch_size_large = 20000

# Fractions of the data used for validation and test; the remainder is the training set
valid_ratio = 0.3
test_ratio = 0.2

# Number of hidden layers of the network chi
network_depth = 4

# Width of every layer of chi
layer_width = 30

# Attention-mask hyperparameters
mask_const = False       # whether the trained attention mask is constant over time
patchsize = 4            # size of the sliding window
mask_depth = 4           # number of hidden layers of the attention network (time-dependent case)
mask_width = 30          # width of the attention hidden layers
factor_att = True        # whether to use a factor which scales the input on average back to input
regularizer_noise = 1.0  # noise used for regularization

# Learning rate for the ADAM optimizer
learning_rate = 5e-4

# Number of nodes for each layer, one entry per layer
nodes = network_depth * [layer_width]

# Epsilon for the numerical inversion of correlation matrices
epsilon = np.array(1e-7, dtype='float32')

### Split the data into training, validation, and test sets

In [None]:
from deeptime.util.data import TrajectoryDataset

# Time-lagged dataset of the subsampled trajectory at the pretraining lag tau_chi
dataset = TrajectoryDataset(
    lagtime=tau_chi,
    trajectory=data[0],
)

In [None]:
# Sizes of the validation and test subsets; everything that is left goes into training
n_val = int(valid_ratio * len(dataset))
n_test = int(test_ratio * len(dataset))
train_data, val_data, test_data = torch.utils.data.random_split(
    dataset,
    [len(dataset) - n_val - n_test, n_val, n_test],
)

### Define the structure of the VAMPnet

In [None]:
from copy import deepcopy

from helper import Mean_std_layer, Mask, pred_batchwise, plot_mask, get_its, get_ck, plot_cg

# Standardize the input with statistics of the training frames only.
# `train_data` is a torch Subset: `train_data.dataset` is the full dataset it was split from,
# so the training rows have to be selected explicitly through `train_data.indices`.
train_frames = train_data.dataset.data[np.asarray(train_data.indices)]
normalizer = Mean_std_layer(input_size, mean=torch.Tensor(train_frames.mean(0)),
                           std=torch.Tensor(train_frames.std(0)))

# attention mask acting on the input features
mask = Mask(data[0].shape[1],mask_const, mask_depth, mask_width, patchsize, fac=factor_att, noise=regularizer_noise, device=device)

# chi: maps the input features to a fuzzy state assignment.
# NOTE(review): the five hidden layers are written out by hand; `network_depth` and `nodes`
# from the hyperparameter cell are not read here.
lobe = nn.Sequential(
    nn.Linear(data[0].shape[1], layer_width), nn.ELU(),
    nn.Linear(layer_width, layer_width), nn.ELU(),
    nn.Linear(layer_width, layer_width), nn.ELU(),
    nn.Linear(layer_width, layer_width), nn.ELU(),
    nn.Linear(layer_width, layer_width), nn.ELU(),
    nn.Linear(layer_width, output_sizes[0]),
    nn.Softmax(dim=1)  # obtain fuzzy probability distribution over output states
)
# normalizer + chi, used for pretraining without the mask
lobe_vampnet = nn.Sequential(
    normalizer,
    lobe
)
# chi on its own, wrapped in a Sequential
lobe_msm = nn.Sequential(
    lobe)
# normalizer followed by the attention mask
lobe_mask = nn.Sequential(
    normalizer,
    mask)
# normalizer + mask + chi
lobe_vampnet_mask = nn.Sequential(
    lobe_mask,
    lobe)
# independent copy of chi, taken before the modules are moved to the device
lobe_timelagged = deepcopy(lobe).to(device=device)
lobe = lobe.to(device=device)
lobe_vampnet.to(device=device)
lobe_msm.to(device=device)
lobe_vampnet_mask.to(device=device)
lobe_mask.to(device=device)
print(lobe)

### Define the estimators

In [None]:
from deeptime.decomposition.deep import VAMPNet
from deepmsm import DeepMSM

# All estimators use the learning rate set in the hyperparameter cell
# (previously the value was repeated here as a literal).
vampnet = VAMPNet(lobe=lobe_vampnet, learning_rate=learning_rate, device=device) # for pretraining the VAMPnet without mask
vampnet_mask = VAMPNet(lobe=lobe_vampnet_mask, learning_rate=learning_rate, device=device)
# output_sizes[0] states for the MSM itself, the remaining entries define the coarse-graining layers
deepmsm = DeepMSM(lobe=lobe, output_dim=output_sizes[0], coarse_grain=output_sizes[1:], mask=lobe_mask,
                  learning_rate=learning_rate, device=device)

### Create DataLoader for validation and training data

In [None]:
from torch.utils.data import DataLoader

# shuffled mini-batches for training
loader_train = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)
# the whole validation set as one single batch
loader_val = DataLoader(
    val_data,
    batch_size=len(val_data),
    shuffle=False,
)

In [None]:
# Optional: record the training performance with tensorboard.
# Neither training nor any of the methods needs it; if you do not wish to install
# the additional package just leave the flag set to False!
tensorboard_installed = False
writer = None
if tensorboard_installed:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter()
    # one training batch serves as example input for tracing the graph of chi
    input_model, _ = next(iter(loader_train))
    writer.add_graph(lobe, input_to_model=input_model.to(device))

### Plot the mask before training

In [None]:
# Mask of the still untrained network, as a reference for the plots further down
plot_mask(
    data=data[0],
    lobe=lobe_vampnet,
    mask=mask,
    mask_const=mask_const,
    device=device,
)

### Train the vampnet

In [None]:
# Pretrain chi as a plain VAMPnet (normalizer + lobe, no attention mask)
model = vampnet.fit(
    loader_train,
    n_epochs=30,
    validation_loader=loader_val,
    progress=tqdm,
).fetch_model()

# training and validation scores on log-log axes
for curve_label, curve in (('training', vampnet.train_scores),
                           ('validation', vampnet.validation_scores)):
    plt.loglog(*curve.T, label=curve_label)
plt.xlabel('step')
plt.ylabel('score')
plt.legend();

In [None]:
# Continue training with the attention mask in front of chi
model = vampnet_mask.fit(
    loader_train,
    n_epochs=30,
    validation_loader=loader_val,
    progress=tqdm,
).fetch_model()

# training and validation scores on log-log axes
for curve_label, curve in (('training', vampnet_mask.train_scores),
                           ('validation', vampnet_mask.validation_scores)):
    plt.loglog(*curve.T, label=curve_label)
plt.xlabel('step')
plt.ylabel('score')
plt.legend();

### Plot the mask

In [None]:
# Mask after the VAMPnet training with attention
plot_mask(
    data=data[0],
    lobe=lobe_vampnet_mask,
    mask=mask,
    mask_const=mask_const,
    device=device,
)

In [None]:
# Smallest and largest assignment probability of every state over the subsampled trajectory
state_probabilities = model.transform(data[0])
lowest = np.min(state_probabilities, axis=0)
highest = np.max(state_probabilities, axis=0)
for ix in range(len(lowest)):
    print(f"State {ix+1}: [{lowest[ix]}, {highest[ix]}]")

### Extract the parameters of the trained vampnet

In [None]:
# Keep a handle on the parameters of the trained VAMPnet (mask + chi) ...
state_dict_vampnet = vampnet_mask.lobe.state_dict()
# ... and load them back into the very same lobe; this shows how a stored state dict is restored
vampnet_mask.lobe.load_state_dict(state_dict_vampnet)

### Train the deepMSM

In [None]:
# Step 1: train only the matrix S
deepmsm.fit(
    loader_train,
    n_epochs=100,
    validation_loader=loader_val,
    train_mode='s',
    tb_writer=writer,
)

# training and validation scores on log-log axes
for curve_label, curve in (('training', deepmsm.train_scores),
                           ('validation', deepmsm.validation_scores)):
    plt.loglog(*curve.T, label=curve_label)
plt.xlabel('step')
plt.ylabel('score')
plt.legend();

In [None]:
# Step 2: train S and u together
deepmsm.fit(
    loader_train,
    n_epochs=1000,
    validation_loader=loader_val,
    train_mode='us',
    tb_writer=writer,
)

# training and validation scores on log-log axes
for curve_label, curve in (('training', deepmsm.train_scores),
                           ('validation', deepmsm.validation_scores)):
    plt.loglog(*curve.T, label=curve_label)
plt.xlabel('step')
plt.ylabel('score')
plt.legend();

In [None]:
# Step 3: train chi, u, and S in an iterative manner
deepmsm.fit_routine(
    loader_train,
    n_epochs=50,
    validation_loader=loader_val,
    rel=0.001,
    reset_u=False,
    max_iter=1000,
    tb_writer=writer,
)

# training and validation scores on log-log axes
for curve_label, curve in (('training', deepmsm.train_scores),
                           ('validation', deepmsm.validation_scores)):
    plt.loglog(*curve.T, label=curve_label)
plt.xlabel('step')
plt.ylabel('score')
plt.legend();

### Plot the mask of the trained deepMSM

In [None]:
# Mask after the deepMSM training, with the color scale capped at vmax
plot_mask(
    data=data[0],
    lobe=lobe_msm,
    mask=mask,
    mask_const=mask_const,
    device=device,
    vmax=0.25,
)

### Save the final deepMSM

In [None]:
# Store the parameters of the trained deepMSM; later cells reload them from './test_params.npz'
deepmsm.save_params('./test_params')

### Extract the model and estimate the transition matrix

In [None]:
model_msm = deepmsm.fetch_model()
# Evaluate on the held-out test pairs only: `test_data` is a torch Subset whose `.dataset`
# is the full dataset, so the test rows are selected through `test_data.indices`.
test_idx = np.asarray(test_data.indices)
T = model_msm.get_transition_matrix(test_data.dataset.data[test_idx],
                                    test_data.dataset.data_lagged[test_idx])

In [None]:
# Eigenvalues of the estimated transition matrix (displayed as the cell output)
np.linalg.eigvals(T)

### Estimate the transition matrix for different tau values

In [None]:
# Lag times (in frames of the subsampled data) at which transition matrices are estimated
steps = 8
tau_msm = tau
# CK test: the multiples 1..steps of the MSM lag time
tau_ck = tau_msm * np.arange(1, steps + 1)
# implied timescales: three short lag times in front of the CK lag times
tau_its = np.concatenate([np.array([1, 3, 5]), tau_ck])

In [None]:
# start from the parameters of the trained deepMSM saved earlier
deepmsm.load_params('./test_params.npz')
T_results = np.ones((len(tau_its) ,output_sizes[0], output_sizes[0]))
its_all_vamp = []
for i, tau_i in enumerate(tau_its):
    if tau_i == tau_chi:
        # T was estimated earlier at the training lag tau_chi, so it can only be reused for
        # that lag (checking the position in tau_its instead is wrong once skip changes)
        T_results[i]=T
    else:
        # split the data with the new tau
        dataset = TrajectoryDataset(lagtime=tau_i, trajectory=data[0])
        n_val = int(len(dataset)*valid_ratio)
        n_test = int(len(dataset)*test_ratio)
        train_data, val_data, test_data = torch.utils.data.random_split(dataset, [len(dataset) - n_val - n_test, n_val, n_test])
        loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)
        # reset u and S to be retrained for the new tau
        deepmsm.reset_u_S(loader_train)
        # reset the optimizers for u and S
        deepmsm.reset_opt_u_S(lr=10)
        # two rounds, each training S alone and then u and S together
        for _ in range(2):
            # train for S
            model_msm_i = deepmsm.fit(loader_train, n_epochs=1000, validation_loader=loader_val, train_mode='s').fetch_model()
            # train for u and S
            model_msm_i = deepmsm.fit(loader_train, n_epochs=100, validation_loader=loader_val, train_mode='us').fetch_model()
        # retrieve the transition matrix for the specific tau on the held-out test pairs only:
        # `test_data.dataset` is the full dataset, the test rows are given by `test_data.indices`
        test_idx = np.asarray(test_data.indices)
        T_results[i]  = model_msm_i.get_transition_matrix(test_data.dataset.data[test_idx],
                                                          test_data.dataset.data_lagged[test_idx])

### Estimate implied timescales

In [None]:
# Implied timescales from the transition matrices estimated at every lag time in tau_its
its = get_its(
    T_results,
    tau_its,
    calculate_K=False,
)

In [None]:
# Conversion factor from frames of the subsampled data to microseconds
# (presumably 200 ps between saved frames -- adapt to your trajectories)
fac = 200.*skip*1e-6 
# fac = 0.0002

plt.figure(figsize=(6,4));

# tick positions in frames; the arrays list the desired tick values in microseconds
label_x = np.array([.1,0.3,1, 2, 5,10,100,1000])/fac # array is in microsecond
label_y = np.array([.1,1, 2, 5,10, 100, 1000])/fac
# fig = plt.figure(figsize = (8,8))
# one curve per implied timescale (output_sizes[0]-1 of them)
for j in range(0,output_sizes[0]-1):
    plt.semilogy(tau_its, its[::-1][j], lw=5)
#     plt.fill_between(tau_its, all_its_vamp_min[::-1][j], all_its_vamp_max[::-1][j], alpha = 0.3)
# black diagonal timescale = lag time
plt.semilogy(tau_its,tau_its, 'k')
# raw strings: '\m' is not a valid escape sequence in a normal string literal
plt.xlabel(r'lag [$\mu$s]', fontsize=26)
plt.xticks(label_x, label_x*fac, fontsize=22)
plt.ylabel(r'timescale [$\mu$s]', fontsize=26)
plt.yticks(label_y, np.round(label_y*fac, decimals=1), fontsize=22)
# shade the region below the diagonal
plt.fill_between(tau_its,tau_its,0.1,alpha = 0.2,color='k');
plt.ylim(0.01/fac, 3/fac)
plt.xlim(tau_its[0], 0.3/fac)
plt.show()

### Estimate CK-test

In [None]:
# tau_its consists of a few ITS-only lag times followed by the CK lag times tau_ck;
# skip exactly those leading entries instead of hard-coding their number
n_its_only = len(tau_its) - len(tau_ck)
predicted, estimated = get_ck(T_results[n_its_only:], tau_ck)

In [None]:
import matplotlib.gridspec as gridspec
output_size = output_sizes[0]
fig = plt.figure(figsize = (16,16))
# grid with one panel per transition i -> j
gs1 = gridspec.GridSpec(output_size, output_size)
gs1.update(wspace=0.1, hspace=0.05)
states = output_size
for index_i in range(states):
    for index_j in range(states):
        ax = plt.subplot(gs1[index_i*output_size+index_j])
        # predicted: solid blue, estimated: dashed red
        ax.plot(tau_ck, predicted[index_i, index_j], color='b', lw=4)
        ax.plot(tau_ck, estimated[index_i, index_j], color = 'r', lw=4, linestyle = '--')
#         ax.fill_between(tau_ck,lx_min[index_i, index_j],lx_max[index_i, index_j], alpha = 0.25 )
#         ax.errorbar(tau_ck, rx_mean[index_i, index_j], yerr= np.array([rx_mean[index_i][index_j]-rx_min[index_i][index_j], rx_max[index_i][index_j]-rx_mean[index_i][index_j]]), color = 'r', lw=4, linestyle = '--')
        title = str(index_i+1)+ '->' +str(index_j+1)
        
        ax.text(.75,.8, title,
            horizontalalignment='center',
            transform=ax.transAxes,  fontdict = {'size':26})
    
        ax.set_ylim((-0.1,1.1));
        ax.set_xlim((0, tau_ck[-1]+5));
        
        # y ticks only in the first column
        if (index_j == 0):
            ax.axes.get_yaxis().set_ticks([0, 1])
            ax.yaxis.set_tick_params(labelsize=32)
        
        else:
            ax.axes.get_yaxis().set_ticks([])
        
        # x ticks only in the last row; positions are in frames, labels are scaled by fac
        if (index_i == output_size -1):
            
            xticks = np.array([20,60])
            
            ax.xaxis.set_ticks(xticks);
            ax.xaxis.set_ticklabels((xticks*fac));
            ax.xaxis.set_tick_params(labelsize=32)
        else:
            ax.axes.get_xaxis().set_ticks([])
            
        # shared unit label, anchored to one panel of the last row
        if (index_i == output_size - 1 and index_j == output_size - 4):
            # raw string: '\m' is not a valid escape sequence in a normal string literal
            ax.text(2.16, -0.4, r"[$\mu$s]",
                horizontalalignment='center',
                transform=ax.transAxes,  fontdict = {'size':28})
plt.show()

### Retrain the model for tau_msm

In [None]:
# Go back to the saved deepMSM parameters and rebuild the data splits for the lag time tau_msm
deepmsm.load_params('./test_params.npz')
dataset = TrajectoryDataset(
    lagtime=tau_msm,
    trajectory=data[0],
)
n_val = int(valid_ratio * len(dataset))
n_test = int(test_ratio * len(dataset))
train_data, val_data, test_data = torch.utils.data.random_split(
    dataset,
    [len(dataset) - n_val - n_test, n_val, n_test],
)
loader_train = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)
loader_val = DataLoader(
    val_data,
    batch_size=len(val_data),
    shuffle=False,
)
# u and S have to be retrained for the new tau, so reset them ...
deepmsm.reset_u_S(loader_train)
# ... together with their optimizers
deepmsm.reset_opt_u_S(lr=10)

In [None]:
# Five rounds, each training S alone first and then u and S together
for _ in range(5):
    # train for S
    model_msm_final = deepmsm.fit(
        loader_train,
        n_epochs=1000,
        validation_loader=loader_val,
        train_mode='s',
        tb_writer=writer,
    ).fetch_model()
    # train for u and S
    model_msm_final = deepmsm.fit(
        loader_train,
        n_epochs=100,
        validation_loader=loader_val,
        train_mode='us',
        tb_writer=writer,
    ).fetch_model()

In [None]:
# Timescales (scaled by fac) on the held-out test pairs only: `test_data` is a torch Subset
# whose `.dataset` is the full dataset, so the test rows are selected via `test_data.indices`.
test_idx = np.asarray(test_data.indices)
model_msm_final.timescales(test_data.dataset.data[test_idx], test_data.dataset.data_lagged[test_idx], tau_msm)*fac

In [None]:
# Transition matrix on the held-out test pairs only: `test_data` is a torch Subset whose
# `.dataset` is the full dataset, so the test rows are selected via `test_data.indices`.
test_idx = np.asarray(test_data.indices)
model_msm_final.get_transition_matrix(test_data.dataset.data[test_idx], test_data.dataset.data_lagged[test_idx])

### Estimate coarse-grain model

In [None]:
# initialize the coarse-graining layer with pcca
deepmsm.reset_cg(idx=0, lr=0.1)
deepmsm.initialize_cg_layer(idx=0, data_loader=loader_train, factor=1.)

In [None]:
# train the first coarse-graining layer
model_cg_1 = deepmsm.fit_cg(loader_train, n_epochs=3000, validation_loader=loader_val, train_mode='single', idx=0).fetch_model()

In [None]:
# Timescales of the first coarse-grained model on the held-out test pairs only: `test_data`
# is a torch Subset whose `.dataset` is the full dataset, so select via `test_data.indices`.
test_idx = np.asarray(test_data.indices)
model_cg_1.timescales_cg(test_data.dataset.data[test_idx], test_data.dataset.data_lagged[test_idx], tau=tau_msm, idx=0)*fac

In [None]:
# Plot the learned coarse-graining of the first layer (cg_list[0])
plot_cg(deepmsm.cg_list[0])

In [None]:
# Reset and initialize the second coarse-graining layer (idx=1)
deepmsm.reset_cg(
    idx=1,
    lr=0.1,
)
deepmsm.initialize_cg_layer(
    idx=1,
    data_loader=loader_train,
    factor=1.,
)

In [None]:
# Train only the second coarse-graining layer (idx=1)
model_cg_2 = deepmsm.fit_cg(
    loader_train,
    n_epochs=3000,
    validation_loader=loader_val,
    train_mode='single',
    idx=1,
).fetch_model()

In [None]:
# Timescales of the second coarse-grained model on the held-out test pairs only: `test_data`
# is a torch Subset whose `.dataset` is the full dataset, so select via `test_data.indices`.
test_idx = np.asarray(test_data.indices)
model_cg_2.timescales_cg(test_data.dataset.data[test_idx], test_data.dataset.data_lagged[test_idx], tau=tau_msm, idx=1)*fac

In [None]:
# Plot the learned coarse-graining of the second layer (cg_list[1])
plot_cg(deepmsm.cg_list[1])

In [None]:
# Train all representations at the same time
model_cg_all = deepmsm.fit_cg(
    loader_train,
    n_epochs=1000,
    validation_loader=loader_val,
    train_mode='all',
).fetch_model()

In [None]:
# Compare the timescales (scaled by fac) of the full model and of both coarse-grained models.
# Evaluate on the held-out test pairs only: `test_data` is a torch Subset whose `.dataset`
# is the full dataset, so the test rows are selected through `test_data.indices`.
test_idx = np.asarray(test_data.indices)
x_test = test_data.dataset.data[test_idx]
x_test_lagged = test_data.dataset.data_lagged[test_idx]
print(model_msm_final.timescales(x_test, x_test_lagged, tau_msm)*fac)
print(model_cg_1.timescales_cg(x_test, x_test_lagged, tau=tau_msm, idx=0)*fac)
print(model_cg_2.timescales_cg(x_test, x_test_lagged, tau=tau_msm, idx=1)*fac)