### Example notebook to build a deep MSM with observables

This notebook is a template for users who want to build a deep reversible Markov state model with additional experimental observables.
We simulate this situation by first estimating observables on the whole data set. These values represent the true "experimental" observables. Afterwards, we disturb our data to mimic a simulation with a systematic bias, e.g. one caused by the force field. We then estimate a deep reversible MSM again, this time with the additional observables, in the hope of recovering the kinetics of the undisturbed data.

The code is based on the package deeptime, into which this code should soon be integrated. It should be regarded as a pre-alpha version.

In case you have real experimental data available, you must skip the part with the data disturbance. However, it is advisable to also train a baseline model without the additional information to check whether the performance improves.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
# import mdshare  # for trajectory data

from tqdm.notebook import tqdm  # progress bar

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
if _cuda_ok:
    # only set on the GPU path
    torch.backends.cudnn.benchmark = True
# fixed number of CPU threads for torch, adapt to your machine
torch.set_num_threads(12)

In [None]:
# Example of how to load your data (feel free to swap in other features)
def load_trajectories(files, pdb):
    '''
    Load trajectories and directly compute their residue_mindist features.

    The featurizer is built from ``pdb`` and configured via
    ``add_residue_mindist`` (all residue pairs, closest heavy atoms,
    non-protein residues ignored, no threshold, periodic).
    Other features can be used instead, see
    http://www.emma-project.org/latest/api/generated/pyemma.coordinates.featurizer.html#
    ------
    Inputs:
    files: list of strings. Locations of all trajectories which should be loaded.
    pdb: string. Location of the corresponding pdb file.

    Returns:
    The result of ``pyemma.coordinates.load``: a list of np.array with the
    residue_mindist features of every trajectory in files. If only one
    trajectory is supplied, the np.array is returned directly.
    '''
    import pyemma

    featurizer = pyemma.coordinates.featurizer(pdb)
    mindist_options = dict(residue_pairs='all', scheme='closest-heavy',
                           ignore_nonprotein=True, threshold=None, periodic=True)
    featurizer.add_residue_mindist(**mindist_options)
    return pyemma.coordinates.load(files, features=featurizer)

In [None]:
# Load the data with the example loader defined above.
# NOTE: the paths below are placeholders - replace them with your own files.
files = ['/path/to/file1', 'path/to/file2']
pdb = '/path/to/pdb.pdb'
output_all_files = load_trajectories(files, pdb)
# You can save the processed features once and afterwards load them directly:
# np.save('/path/to/save', output_all_files)
# output_all_files = np.load('/path/to/save.npy')  # need to specify where your data lies
traj_whole = output_all_files

# number of frames and number of input features of the first trajectory
traj_data_points, input_size = traj_whole[0].shape
# Skip data to make the data less correlated
skip=1
# NOTE(review): only the first trajectory (traj_whole[0]) is used from here on,
# even if several files were loaded - confirm this is intended.
data = [traj_whole[0][::skip]]

# Solves input_size = (n-2)*(n-3)/2 for n; presumably the number of residues
# behind the residue_mindist features - TODO confirm against the featurizer.
n_residues = int(-1/2 + np.sqrt(1/4+input_size*2) + 3)

In [None]:
# Hyperparameter definitions, should be adapted for specific problems

# Number of output nodes/states of the MSM or Koopman model, therefore also
# the number of output nodes of chi
output_size = 3

# Lag times in frames (both are divided by the frame skip defined above).
# tau is used for the deepMSM; tau_chi is used for pretraining the VAMPnet and
# is usually smaller than the tau of the deepMSM
tau = 50*25//skip
tau_chi = 25//skip

# Batch size for Stochastic Gradient descent
batch_size = 512
# Larger batch size for fine tuning weights at the end of training
batch_size_large = 20000

# Fraction of the data used for validation and for testing; the remainder is
# used for training
valid_ratio = 0.3
test_ratio = 0.3

# How many hidden layers the network chi has
# NOTE(review): the lobe defined further below spells out its layers by hand
# and does not read network_depth or nodes - confirm which one is intended.
network_depth = 4

# Width of every layer of chi
layer_width = 30

# Mask hyperparameter
mask_const=False # if the trained attention mask is constant over time
patchsize=4 # size of the sliding window
mask_depth=4 # if time dependent how many hidden layers has the attention network
mask_width=30 # the width of the attention hidden layers
factor_att=True # if to use a factor which scales the input on average back to input
regularizer_noise = 1.0 # noise to regularize

# Learning rate used for the ADAM optimizer
learning_rate = 5e-4

# create a list with the number of nodes for each layer
nodes = [layer_width]*network_depth

# epsilon for numerical inversion of correlation matrices (stored as float32)
epsilon = np.array(1e-7).astype('float32')

### Split the data into train, validation, and test set

In [None]:
from deeptime.util.data import TrajectoryDataset

# Time-lagged dataset of the (single) trajectory at the pretraining lag tau_chi
dataset = TrajectoryDataset(lagtime=tau_chi, trajectory=data[0])

In [None]:
# Randomly split the dataset: validation and test sizes follow the ratios from
# the hyperparameters, the training set receives whatever is left.
n_total = len(dataset)
n_val = int(n_total*valid_ratio)
n_test = int(n_total*test_ratio)
split_sizes = [n_total - n_val - n_test, n_val, n_test]
train_data, val_data, test_data = torch.utils.data.random_split(dataset, split_sizes)

### Define the structure of the VAMPnet

In [None]:
from copy import deepcopy

from helper import Mean_std_layer, pred_batchwise, get_its, get_ck, estimate_mu

# Input normalisation layer.
# NOTE(review): train_data comes from random_split, so `.dataset` is presumably
# the full underlying dataset and the statistics cover all frames, not only the
# training split - confirm this is intended.
_all_frames = train_data.dataset.data
normalizer = Mean_std_layer(input_size, mean=torch.Tensor(_all_frames.mean(0)),
                            std=torch.Tensor(_all_frames.std(0)))

# chi: normalizer -> input layer -> 4 hidden (layer_width x layer_width) layers
# -> output layer, with ELU activations in between.
# NOTE(review): the number of hidden layers is spelled out here and does not
# read the network_depth hyperparameter.
_layers = [normalizer, nn.Linear(data[0].shape[1], layer_width), nn.ELU()]
for _ in range(4):
    _layers += [nn.Linear(layer_width, layer_width), nn.ELU()]
_layers += [
    nn.Linear(layer_width, output_size),
    nn.Softmax(dim=1),  # obtain fuzzy probability distribution over output states
]
lobe = nn.Sequential(*_layers)

# independent copy of the freshly initialised network, then move both to the device
lobe_timelagged = deepcopy(lobe).to(device=device)
lobe = lobe.to(device=device)

print(lobe)

### Define the estimators

In [None]:
from deeptime.decomposition.deep import VAMPNet
from deepmsm import DeepMSM

# Both estimators are built around the same lobe (chi). The learning rate is
# taken from the hyperparameter cell instead of being repeated as a literal,
# so changing `learning_rate` there actually takes effect here.
vampnet = VAMPNet(lobe=lobe, learning_rate=learning_rate, device=device) # for pretraining the VAMPnet without mask
deepmsm = DeepMSM(lobe=lobe, output_dim=output_size, learning_rate=learning_rate, device=device)

### Create DataLoader for validation and training data

In [None]:
from torch.utils.data import DataLoader

# Training data is served in shuffled mini-batches of size batch_size
loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# The validation data is served as one single, unshuffled batch
loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)

In [None]:
# Optionally record the training performance with tensorboard.
# It is not necessary for training or using the methods: set the flag to False
# to switch it off. If the package cannot be imported, logging is skipped
# automatically instead of aborting the notebook.
tensorboard_installed = True
writer = None
if tensorboard_installed:
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        # tensorboard is missing: continue without a writer
        tensorboard_installed = False
    else:
        writer = SummaryWriter()
        # log the graph of chi, using one training batch as example input
        input_model, _ = next(iter(loader_train))
        writer.add_graph(lobe, input_to_model=input_model.to(device))

### Train the vampnet

In [None]:
# Pretrain chi as a plain VAMPnet and keep the fitted model
model = vampnet.fit(
    loader_train,
    n_epochs=10,
    validation_loader=loader_val,
    progress=tqdm,
).fetch_model()

# Learning curves of the pretraining on log-log axes
for _scores, _label in ((vampnet.train_scores, 'training'),
                        (vampnet.validation_scores, 'validation')):
    plt.loglog(*_scores.T, label=_label)
plt.xlabel('step')
plt.ylabel('score')
plt.legend();

In [None]:
# Output of the pretrained VAMPnet for the whole trajectory
state_probabilities = model.transform(data[0])
# Print, for every output state, the smallest and largest value along axis 0
_lowest = np.min(state_probabilities, axis=0)
_highest = np.max(state_probabilities, axis=0)
for state_nr, (p_low, p_high) in enumerate(zip(_lowest, _highest)):
    print(f"State {state_nr+1}: [{p_low}, {p_high}]")

### Extract the parameters of the trained vampnet

In [None]:
from copy import deepcopy

# Snapshot the pretrained VAMPnet weights. state_dict() hands out references to
# the live parameter tensors, so without the deep copy the "saved" weights would
# silently follow every later training step of this lobe and could never be
# used to restore the pretrained state.
state_dict_vampnet = deepcopy(vampnet.lobe.state_dict())
# Example of how to load the saved weights back into the lobe
vampnet.lobe.load_state_dict(state_dict_vampnet)

### Train for the deepMSM

In [None]:
# First stage: train only for the matrix S
deepmsm.fit(
    loader_train,
    n_epochs=100,
    validation_loader=loader_val,
    train_mode='s',
    tb_writer=writer,
)
# training and validation curves on log-log axes
for _scores, _label in ((deepmsm.train_scores, 'training'),
                        (deepmsm.validation_scores, 'validation')):
    plt.loglog(*_scores.T, label=_label)
plt.xlabel('step')
plt.ylabel('score')
plt.legend();

In [None]:
# Second stage: train for S and u together
deepmsm.fit(
    loader_train,
    n_epochs=1000,
    validation_loader=loader_val,
    train_mode='us',
    tb_writer=writer,
)
# training and validation curves on log-log axes
for _scores, _label in ((deepmsm.train_scores, 'training'),
                        (deepmsm.validation_scores, 'validation')):
    plt.loglog(*_scores.T, label=_label)
plt.xlabel('step')
plt.ylabel('score')
plt.legend();

In [None]:
# Train for chi, u, and S in an iterative manner. Three passes are run; the
# second one resets u in order to escape possible local minima.
for _reset_u in (False, True, False):
    deepmsm.fit_routine(loader_train, n_epochs=5, validation_loader=loader_val, rel=0.001,
                        reset_u=_reset_u, max_iter=1000, tb_writer=writer)
    # add the current learning curves after every pass
    for _scores, _label in ((deepmsm.train_scores, 'training'),
                            (deepmsm.validation_scores, 'validation')):
        plt.loglog(*_scores.T, label=_label)
    plt.xlabel('step')
    plt.ylabel('score')
    plt.legend();

### Save the final deepMSM

In [None]:
# Persist the trained deepMSM parameters; later cells reload them from './test_params.npz'
deepmsm.save_params('./test_params')

### Extract the model and estimate the transition matrix

In [None]:
# Fetch the trained model and estimate its transition matrix.
# NOTE(review): test_data comes from random_split, so `.dataset` is presumably
# the full underlying dataset rather than only the test split - confirm.
model_msm = deepmsm.fetch_model()
T = model_msm.get_transition_matrix(test_data.dataset.data, test_data.dataset.data_lagged)

In [None]:
# Display the eigenvalues of the estimated transition matrix
np.linalg.eigvals(T)

### Estimate the transition matrix for different tau values

In [None]:
# Lag times (in frames) at which the model is evaluated
steps = 8
tau_msm = tau
# CK-test lags: the first `steps` multiples of tau_msm
tau_ck = tau_msm*np.arange(1, steps+1)
# implied-timescale lags: three short extra lags followed by the CK-test lags
_short_lags = np.array([1, 3, 5])
tau_its = np.concatenate((_short_lags, tau_ck))

In [None]:
# Re-estimate the transition matrix for every lag time in tau_its, starting
# from the saved deepMSM parameters.
deepmsm.load_params('./test_params.npz')
T_results = np.ones((len(tau_its) ,output_size, output_size))
its_all_vamp = []
for i, tau_i in enumerate(tau_its):
    # Every lag time is estimated explicitly. The matrix T computed earlier
    # belongs to the lag tau_chi, which differs from tau_its[0], so it must not
    # be reused as the first entry (it would be attributed to the wrong lag).
    # split the data with the new tau
    dataset = TrajectoryDataset(lagtime=tau_i, trajectory=data[0])
    n_val = int(len(dataset)*valid_ratio)
    n_test = int(len(dataset)*test_ratio)
    train_data, val_data, test_data = torch.utils.data.random_split(dataset, [len(dataset) - n_val - n_test, n_val, n_test])
    loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)
    # reset u and S to be retrained for the new tau
    deepmsm.reset_u_S(loader_train)
    # reset the optimizers for u and S
    deepmsm.reset_opt_u_S(lr=1)
    for _ in range(5):
        # train for S
        model_msm_i = deepmsm.fit(loader_train, n_epochs=1000, validation_loader=loader_val, train_mode='s').fetch_model()
        # train for u and S
        model_msm_i = deepmsm.fit(loader_train, n_epochs=100, validation_loader=loader_val, train_mode='us').fetch_model()
    # retrieve the transition matrix for the specific tau
    T_results[i]  = model_msm_i.get_transition_matrix(test_data.dataset.data, test_data.dataset.data_lagged)

### Estimate implied timescales

In [None]:
# Implied timescales from the transition matrices at all lag times in tau_its
its = get_its(T_results, tau_its, calculate_K=False)

In [None]:
fac = 200.*skip*1e-6  # factor to change from frames into microseconds, adapt for your data!!!
# fac = 0.0002

plt.figure(figsize=(6,4));

# tick positions in frames; the values inside the arrays are in microseconds
label_x = np.array([.1,0.3,1, 2, 5,10,100,1000])/fac
label_y = np.array([.1,1, 2, 5,10, 100, 1000])/fac
# fig = plt.figure(figsize = (8,8))
for j in range(0,output_size-1):
    plt.semilogy(tau_its, its[::-1][j], lw=5)
#     plt.fill_between(tau_its, all_its_vamp_min[::-1][j], all_its_vamp_max[::-1][j], alpha = 0.3)
# black diagonal: timescale equal to the lag time; the area below is shaded
plt.semilogy(tau_its,tau_its, 'k')
# raw strings: '\m' is not a valid escape sequence in a normal string literal
plt.xlabel(r'lag [$\mu$s]', fontsize=26)
# round the x labels like the y labels to hide floating point noise
plt.xticks(label_x, np.round(label_x*fac, decimals=1), fontsize=22)
plt.ylabel(r'timescale [$\mu$s]', fontsize=26)
plt.yticks(label_y, np.round(label_y*fac, decimals=1), fontsize=22)
plt.fill_between(tau_its,tau_its,0.1,alpha = 0.2,color='k');
plt.ylim(0.01/fac, 3/fac)
plt.xlim(tau_its[0], 1/fac)
plt.show()

### Estimate CK-test

In [None]:
# CK-test: skip the first three matrices, which belong to the extra short lags
# [1, 3, 5] in tau_its, so that the remaining ones line up with tau_ck
predicted, estimated = get_ck(T_results[3:], tau_ck)

In [None]:
import matplotlib.gridspec as gridspec

# One panel per state pair (i -> j): predicted (solid blue) vs. estimated
# (dashed red) transition probabilities over the CK-test lag times.
fig = plt.figure(figsize = (16,16))
gs1 = gridspec.GridSpec(output_size, output_size)
gs1.update(wspace=0.1, hspace=0.05)
states = output_size
for index_i in range(states):
    for index_j in range(states):
        ax = plt.subplot(gs1[index_i*output_size+index_j])
        ax.plot(tau_ck, predicted[index_i, index_j], color='b', lw=4)
        ax.plot(tau_ck, estimated[index_i, index_j], color = 'r', lw=4, linestyle = '--')
#         ax.fill_between(tau_ck,lx_min[index_i, index_j],lx_max[index_i, index_j], alpha = 0.25 )
#         ax.errorbar(tau_ck, rx_mean[index_i, index_j], yerr= np.array([rx_mean[index_i][index_j]-rx_min[index_i][index_j], rx_max[index_i][index_j]-rx_mean[index_i][index_j]]), color = 'r', lw=4, linestyle = '--')
        title = str(index_i+1)+ '->' +str(index_j+1)

        ax.text(.75,.8, title,
            horizontalalignment='center',
            transform=ax.transAxes,  fontdict = {'size':26})

        ax.set_ylim((-0.1,1.1));
        ax.set_xlim((0, tau_ck[-1]+5));

        # y ticks only in the first column
        if (index_j == 0):
            ax.axes.get_yaxis().set_ticks([0, 1])
            ax.yaxis.set_tick_params(labelsize=32)

        else:
            ax.axes.get_yaxis().set_ticks([])

        # x ticks only in the last row, labelled via fac (tick positions are in frames)
        if (index_i == output_size -1):

            xticks = np.array([2000,6000])

            ax.xaxis.set_ticks(xticks);
            ax.xaxis.set_ticklabels(((xticks*fac*1000).astype('int')/1000));
            ax.xaxis.set_tick_params(labelsize=32)
        else:
            ax.axes.get_xaxis().set_ticks([])

        # NOTE(review): for output_size < 4 this condition is never true, so the
        # unit label is not drawn - confirm the intended column for your layout.
        if (index_i == output_size - 1 and index_j == output_size - 4):
            # raw string: '\m' is not a valid escape sequence in a normal string
            ax.text(2.16, -0.4, r"[$\mu$s]",
                horizontalalignment='center',
                transform=ax.transAxes,  fontdict = {'size':28})
plt.show()

### Retrain the model for tau_msm

In [None]:
# Start again from the saved deepMSM parameters
deepmsm.load_params('./test_params.npz')
# rebuild the dataset and the random split for the lag time of the final MSM
dataset = TrajectoryDataset(lagtime=tau_msm, trajectory=data[0])
n_total = len(dataset)
n_val = int(n_total*valid_ratio)
n_test = int(n_total*test_ratio)
train_data, val_data, test_data = torch.utils.data.random_split(
    dataset, [n_total - n_val - n_test, n_val, n_test])
loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)
# u and S have to be retrained for the new tau, so reset them ...
deepmsm.reset_u_S(loader_train)
# ... together with their optimizers
deepmsm.reset_opt_u_S(lr=1)

In [None]:
# Alternate five times between training S alone and training u and S together;
# model_msm_final ends up as the model fetched after the last fit.
for _ in range(5):
    # train for S
    model_msm_final = deepmsm.fit(loader_train, n_epochs=1000, validation_loader=loader_val, train_mode='s', tb_writer=writer).fetch_model()
    # train for u and S
    model_msm_final = deepmsm.fit(loader_train, n_epochs=100, validation_loader=loader_val, train_mode='us', tb_writer=writer).fetch_model()

In [None]:
# Display the timescales of the reference model, scaled by fac (frames -> microseconds)
model_msm_final.timescales(test_data.dataset.data, test_data.dataset.data_lagged, tau_msm)*fac

In [None]:
# Transition matrix of the reference ("true") model
T_true = model_msm_final.get_transition_matrix(test_data.dataset.data, test_data.dataset.data_lagged)
# eigenvalues sorted in ascending order; keep the two smallest ones
all_eigval_true = np.sort(np.linalg.eigvals(T_true))[:2]

### We have now trained our reference model which uses the true data

### Define observable 
Here we define an observable as a specific contact being formed or not. First, we estimate all observables that we want to use as the true "experimental" ones.

In [None]:
from helper import TimeLaggedDatasetObs
from deeptime.markov.tools.analysis import mfpt

In [None]:
# Keep the state of the reference estimator
# NOTE(review): whether this is a copy or a live reference depends on DeepMSM.state_dict - confirm.
state_dict_true = deepmsm.state_dict()

# Define which contact you wanna use
contact_obs = 41 # index in the input array of the contact we use as microscopic observable
# NOTE(review): taking -log presumably undoes an exp(-distance) input feature - confirm for your data.
distances = - np.log(traj_whole[0][:,contact_obs])
contacts = (distances <0.45).astype('int')

# two observables per frame: column 0 = contact formed, column 1 = contact not formed
obs_values = np.array([contacts, 1 - contacts]).T

# state assignments and stationary distribution under the reference model
chi_true = model_msm_final(traj_whole[0])
mu_true = model_msm_final.get_mu(traj_whole[0][tau_msm:])
states_mu_true = estimate_mu(mu_true, chi_true, np.arange(tau_msm, chi_true.shape[0]))

dataset = TimeLaggedDatasetObs.from_trajectory(lagtime=tau_msm, data=traj_whole[0],
                                               data_obs_ev=obs_values, data_obs_ac=obs_values)

# we know that the unfolded state is the most probable one, folded second for Villin, adapt for your data!!!
sort_id = np.argsort(states_mu_true)
index_unfolded, index_folded = sort_id[-1], sort_id[-2]

# If we wanna identify the folding process we need to define which states are
# the ones most extreme to that particular process
obs_true = model_msm_final.observables(
    dataset.data, dataset.data_lagged, dataset.data_obs_ev, dataset.data_obs_ac,
    state1=[index_unfolded], state2=[index_folded])
# the true values: expectation values, autocorrelations, eigenvalue
ev_true, ac_true, eigval_true = obs_true[0], obs_true[1], obs_true[2]

# estimate folding and unfolding rates (mean first passage times scaled to microseconds)
_time_unit = tau_msm * fac
mfpt_fold_true = mfpt(T_true, index_folded, index_unfolded) * tau_msm * fac
mfpt_unfold_true = mfpt(T_true, index_unfolded, index_folded) * tau_msm * fac

### Manipulate the data


We seek to remove transitions between the folded and unfolded state to mimic a force field which overestimates the energy barrier between these two states.

If you have real experimental data, skip this part. Always work on your whole simulation data!

In [None]:
# We seek to manipulate the folding process, so first we identify the process
T = model_msm_final.get_transition_matrix(test_data.dataset.data, test_data.dataset.data_lagged)
eigvals, eigvecs = np.linalg.eig(T)
# sort eigenvalues (and the matching eigenvectors) in ascending order
sort_id = np.argsort(eigvals)
eigvals_sort = eigvals[sort_id]
eigvecs_sort = eigvecs[:,sort_id]

# Estimate the eigenfunction corresponding to the folding process which is the fastest one for our model
eigfunc = chi_true[::skip] @ eigvecs_sort[:,-3]
min_eigfunc = eigfunc.min()
max_eigfunc = eigfunc.max()

# Now find the transitions

# Find data points which are close to the two extremes of the eigenfunction
# (the folded and the unfolded state)
starting_points = np.where(eigfunc < (min_eigfunc + 0.05))[0]
end_points = np.where(eigfunc > (max_eigfunc - 0.05))[0]
# make sure the trajectory starts in the "starting" state
if starting_points[0] > end_points[0]:
    starting_points, end_points = end_points, starting_points
transition_forward = []
transition_backward = []
flag = True
counter = 0
while flag:

    if counter%2==0: # if forward transition
        # find the last frame before changing to end state
        last_frame = np.where((starting_points - end_points[0])<0)[0][-1]
        # The transition already starts tau_msm before; clamp at frame 0 so that
        # no negative (wrap-around) frame indices are produced
        first_frame = max(starting_points[last_frame]-tau_msm, 0)
        transition_forward.append(np.arange(first_frame, end_points[0]))

        starting_points = starting_points[last_frame+1:]
    else: # backward direction

        # find the last frame before changing to end state
        last_frame = np.where((end_points - starting_points[0])<0)[0][-1]
        # The transition already starts tau_msm before (clamped as above)
        first_frame = max(end_points[last_frame]-tau_msm, 0)
        transition_backward.append(np.arange(first_frame, starting_points[0]))

        end_points = end_points[last_frame+1:]


    counter +=1
    # stop as soon as one of the two states is not visited any more
    if len(end_points)==0 or len(starting_points)==0:
        flag=False
# now find the frames which are not part of a transition: remove the frames of
# BOTH directions (forward and backward) from all frames that have a lagged partner
non_transition = np.arange(eigfunc.shape[0]-tau_msm)
non_transition = np.setdiff1d(non_transition, np.concatenate(transition_forward + transition_backward))

In [None]:
# we can check our results by plotting the classifications
# skip_fig thins out the plotted points; it has to be applied to x and y alike,
# otherwise the two arrays have different lengths for skip_fig > 1
skip_fig = 1
plt.plot(non_transition[::skip_fig], eigfunc[non_transition][::skip_fig], '.')
forwards = np.concatenate(transition_forward, axis=0)
plt.plot(forwards[::skip_fig], eigfunc[forwards][::skip_fig], '.')
backwards = np.concatenate(transition_backward, axis=0)
plt.plot(backwards[::skip_fig], eigfunc[backwards][::skip_fig], '.')
# only show the first 20000 frames
plt.xlim(0,20000)
plt.show()

In [None]:
# The non_transition frames are definitely in the manipulated data set, we leave these untouched
ind_train = []
ind_valid = []
ind_test = []
# we assign them randomly and in (roughly) equal thirds to the
# training/validation/test set
non_length = non_transition.shape[0]//3
np.random.shuffle(non_transition)
ind_train.append(non_transition[:non_length])
ind_valid.append(non_transition[non_length:2*non_length])
ind_test.append(non_transition[2*non_length:])

# now take only a percentage of forward and backward events into the data
p_for = 0.25 # percentage of how many forward events end into the data
p_back = 0.25 # percentage for the backward event

nr_for = int(p_for*len(transition_forward)//3) # Number of transitions in each data set (training, validation, test)
print('Number of forward transitions in each data set: {}'.format(nr_for))
ind_trajs_temp = np.arange(len(transition_forward))
np.random.shuffle(ind_trajs_temp) # shuffle where they end in
for i in range(nr_for):
    # three disjoint groups of nr_for shuffled transitions, one per data set
    ind_train.append(transition_forward[ind_trajs_temp[i]])

    ind_valid.append(transition_forward[ind_trajs_temp[i+nr_for]])

    ind_test.append(transition_forward[ind_trajs_temp[i+2*nr_for]])

# the same for the unfolding
nr_back = int(p_back*len(transition_backward)//3)
# report the number of BACKWARD transitions here (nr_back, not nr_for)
print('Number of backward transitions in each data set: {}'.format(nr_back))
ind_trajs_temp = np.arange(len(transition_backward))
np.random.shuffle(ind_trajs_temp)
for i in range(nr_back):
    ind_train.append(transition_backward[ind_trajs_temp[i]])

    ind_valid.append(transition_backward[ind_trajs_temp[i+nr_back]])

    ind_test.append(transition_backward[ind_trajs_temp[i+2*nr_back]])

# flatten to one frame-index array per data set and shuffle the frame order
ind_train = np.concatenate(ind_train)
ind_valid = np.concatenate(ind_valid)
ind_test = np.concatenate(ind_test)

np.random.shuffle(ind_train)
np.random.shuffle(ind_valid)
np.random.shuffle(ind_test)

### Prepare the data for training

In [None]:
# The frame indexes selected above define the manipulated train/validation/test
# sets; no observables are attached at this stage.
train_data, val_data, test_data = (
    TimeLaggedDatasetObs.from_frames(lagtime=tau_msm, data=traj_whole[0], frames=_frames,
                                     data_obs_ev=None, data_obs_ac=None)
    for _frames in (ind_train, ind_valid, ind_test)
)
# Loaders for the manipulated data: shuffled mini-batches for training, one
# single batch for validation.
loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)

### Train a model on the manipulated data but without including additional information about the observables

In [None]:
from copy import deepcopy

# Since we start again with the result of the VAMPnet, we need to load that first
vampnet.lobe.load_state_dict(state_dict_vampnet)
# and train it on the new data
model = vampnet.fit(loader_train, n_epochs=10,
                    validation_loader=loader_val, progress=tqdm).fetch_model()
plt.loglog(*vampnet.train_scores.T, label='training')
plt.loglog(*vampnet.validation_scores.T, label='validation')
plt.xlabel('step')
plt.ylabel('score')
plt.legend();
plt.show()
# Snapshot the retrained VAMPnet weights. state_dict() returns references to
# the live parameter tensors, so a deep copy is needed to keep this state
# unchanged while the lobe is trained further.
state_dict_vampnet = deepcopy(vampnet.lobe.state_dict())
# reset u and S to be retrained for the new data
# NOTE(review): presumably set_rev_var re-initialises the reversible-model
# variables from the new training data - confirm in DeepMSM.
deepmsm.set_rev_var(loader_train)
# deepmsm.reset_u_S_wo()
# deepmsm.reset_u_S(loader_train)
# reset the optimizers
# deepmsm.reset_opt_u_S(lr=1)
deepmsm.reset_opt_all(lr=1)
deepmsm.reset_scores()

In [None]:
# First stage on the manipulated data: train only for the matrix S
deepmsm.fit(
    loader_train,
    n_epochs=100,
    validation_loader=loader_val,
    train_mode='s',
    tb_writer=writer,
)
# training and validation curves on log-log axes
for _scores, _label in ((deepmsm.train_scores, 'training'),
                        (deepmsm.validation_scores, 'validation')):
    plt.loglog(*_scores.T, label=_label)
plt.xlabel('step')
plt.ylabel('score')
plt.legend()
plt.show()

In [None]:
# Second stage on the manipulated data: train for u and S together
deepmsm.fit(
    loader_train,
    n_epochs=1000,
    validation_loader=loader_val,
    train_mode='us',
    tb_writer=writer,
)
# training and validation curves on log-log axes
for _scores, _label in ((deepmsm.train_scores, 'training'),
                        (deepmsm.validation_scores, 'validation')):
    plt.loglog(*_scores.T, label=_label)
plt.xlabel('step')
plt.ylabel('score')
plt.legend()
plt.show()

In [None]:
# Train for chi, u, and S in an iterative manner. Three passes are run; only
# the second one resets u.
for _reset_u in (False, True, False):
    deepmsm.fit_routine(loader_train, n_epochs=5, validation_loader=loader_val, rel=0.001,
                        reset_u=_reset_u, max_iter=1000, tb_writer=writer)
    # add the current learning curves after every pass
    for _scores, _label in ((deepmsm.train_scores, 'training'),
                            (deepmsm.validation_scores, 'validation')):
        plt.loglog(*_scores.T, label=_label)
    plt.xlabel('step')
    plt.ylabel('score')
    plt.legend();

### Create the data with observables

In [None]:
# Same manipulated frame selection as before, but now the observable values
# are attached (for expectation values and for autocorrelations).
train_data, val_data, test_data = (
    TimeLaggedDatasetObs.from_frames(lagtime=tau_msm, data=traj_whole[0], frames=_frames,
                                     data_obs_ev=obs_values, data_obs_ac=obs_values)
    for _frames in (ind_train, ind_valid, ind_test)
)
# Loaders for the manipulated data: shuffled mini-batches for training, one
# single batch for validation.
loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)

### Estimate the observables with the model without additional information about experimental values

In [None]:
# Observables on the manipulated test data, before any experimental
# information is used.
# NOTE(review): model_msm_final was fetched before the training on the
# manipulated data; whether it reflects the current deepmsm weights depends on
# what fetch_model returns - confirm.
obs_before = model_msm_final.observables(
    test_data.data, test_data.data_lagged,
    test_data.data_obs_ev, test_data.data_obs_ac,
    state1=[index_unfolded], state2=[index_folded])
# expectation values, autocorrelations, eigenvalue
ev_before, ac_before, eigval_before = obs_before[0], obs_before[1], obs_before[2]

### Compare the defined observables
The "before" values should differ from the true ones; otherwise our data manipulation had no effect on these particular observables.

In [None]:
# display: true vs. "before" expectation values
ev_true, ev_before

In [None]:
# display: true vs. "before" autocorrelation values
ac_true, ac_before

In [None]:
# display: true vs. "before" eigenvalue of the selected process
eigval_true, eigval_before

In [None]:
# check all the eigenvalues of the transition matrix on the manipulated test data
T_before = model_msm_final.get_transition_matrix(test_data.data, test_data.data_lagged)
# eigenvalues sorted in ascending order; keep the two smallest ones
all_eigval_before = np.sort(np.linalg.eigvals(T_before))[:2]

In [None]:
# Check the estimation of the defined states: stationary distribution on the
# manipulated test data, evaluated at the time-lagged test frames (ind_test+tau_msm)
mu_before = model_msm_final.get_mu(test_data.data_lagged)
states_mu_before = estimate_mu(mu_before, chi_true[::skip], ind_test+tau_msm)
print(states_mu_true, states_mu_before)

In [None]:
# estimate folding and unfolding rates: mean first passage times between the
# folded and unfolded state, scaled by tau_msm * fac (frames -> microseconds)
mfpt_fold_before = mfpt(T_before, index_folded, index_unfolded) * tau_msm * fac
mfpt_unfold_before = mfpt(T_before, index_unfolded, index_folded) * tau_msm * fac

In [None]:
# compare "before" (first) with the true value (second)
print(mfpt_fold_before, mfpt_fold_true)
print(mfpt_unfold_before, mfpt_unfold_true)

In [None]:
# save the weights of the model trained without observables
# NOTE(review): whether this is a copy or a live reference depends on DeepMSM.state_dict - confirm.
state_dict_before = deepmsm.state_dict()

### Now train with the observables

In [None]:
# Since we start again with the result of the VAMPnet, we need to load that first
vampnet.lobe.load_state_dict(state_dict_vampnet)
# reset u and S before training with the observables
# (alternative reset calls are kept below, commented out)
# deepmsm.set_rev_var(loader_train)
deepmsm.reset_u_S_wo()
# deepmsm.reset_u_S(loader_train)
# reset the optimizers and the recorded scores
# deepmsm.reset_opt_u_S(lr=1)
deepmsm.reset_opt_all(lr=0.1)
deepmsm.reset_scores()

### Define the regularization parameter

In [None]:
# Regularization weights of the observable terms; set an entry to 0 to switch
# the corresponding term off.
xi_ev = np.full(2, 10.)   # expectation values (one weight per observable)
xi_ac = np.full(2, 10.)   # autocorrelations (one weight per observable)
xi_its = np.full(1, 10.)  # eigenvalue / timescale term

In [None]:
# First stage with observables: train only S for 1000 epochs, with the true
# ("experimental") values as targets
deepmsm.fit_obs(
    loader_train, 1000,
    validation_loader=loader_val, train_mode='s',
    exp_ev=ev_true, exp_ac=ac_true, exp_its=eigval_true,
    xi_ev=xi_ev, xi_ac=xi_ac, xi_its=xi_its,
    its_state1=[index_unfolded], its_state2=[index_folded],
    tb_writer=writer,
)
# absolute values are plotted so that the curves can be shown on log-log axes
for _scores, _label in ((deepmsm.train_scores, 'training'),
                        (deepmsm.validation_scores, 'validation')):
    plt.loglog(*np.abs(_scores.T), label=_label)
plt.xlabel('step')
plt.ylabel('score')
plt.legend();

In [None]:
# Training loop: alternate between training everything (5 epochs) and training
# only u and S (100 epochs), and remember the estimator state whenever the last
# validation score got smaller.
# NOTE(review): weights_temp is stored but not loaded back in this notebook, and
# whether state_dict() returns a copy depends on DeepMSM - confirm.
_fit_obs_kwargs = dict(
    validation_loader=loader_val,
    exp_ev=ev_true, exp_ac=ac_true, exp_its=eigval_true,
    xi_ev=xi_ev, xi_ac=xi_ac, xi_its=xi_its,
    its_state1=[index_unfolded], its_state2=[index_folded],
    tb_writer=writer,
)
score_temp = deepmsm.validation_scores[-1,-1]
weights_temp = deepmsm.state_dict()
for i in range(10):
    # first everything, then only u and s
    for _n_epochs, _mode in ((5, 'all'), (100, 'us')):
        deepmsm.fit_obs(loader_train, _n_epochs, train_mode=_mode, **_fit_obs_kwargs)
    # save the weights if the validation score is better!
    _latest_score = deepmsm.validation_scores[-1,-1]
    if _latest_score < score_temp:
        score_temp = _latest_score
        print('new score: {:.3}'.format(score_temp))
        weights_temp = deepmsm.state_dict()

In [None]:
# Observables on the manipulated test data after training with the observables.
# NOTE(review): model_msm_final was fetched before this training; whether it
# reflects the current deepmsm weights depends on what fetch_model returns - confirm.
obs_after = model_msm_final.observables(
    test_data.data, test_data.data_lagged,
    test_data.data_obs_ev, test_data.data_obs_ac,
    state1=[index_unfolded], state2=[index_folded])
# expectation values, autocorrelations, eigenvalue
ev_after, ac_after, eigval_after = obs_after[0], obs_after[1], obs_after[2]

In [None]:
# display: "after" vs. true expectation values
ev_after, ev_true

In [None]:
# display: "after" vs. true autocorrelation values
ac_after, ac_true

In [None]:
# display: "after" vs. true eigenvalue of the selected process
eigval_after, eigval_true

In [None]:
# check all the eigenvalues of the transition matrix
T_after = model_msm_final.get_transition_matrix(test_data.data, test_data.data_lagged)
# eigenvalues sorted in ascending order; keep the two smallest ones
all_eigval_after = np.sort(np.linalg.eigvals(T_after))[:2]

In [None]:
# estimate the stationary distribution of predefined states on the manipulated
# test data, evaluated at the time-lagged test frames (ind_test+tau_msm)
mu_after = model_msm_final.get_mu(test_data.data_lagged)
states_mu_after = estimate_mu(mu_after, chi_true[::skip], ind_test+tau_msm)
print(states_mu_true, states_mu_after)

In [None]:
# estimate folding and unfolding rates: mean first passage times between the
# folded and unfolded state, scaled by tau_msm * fac (frames -> microseconds)
mfpt_fold_after = mfpt(T_after, index_folded, index_unfolded) * tau_msm * fac
mfpt_unfold_after = mfpt(T_after, index_unfolded, index_folded) * tau_msm * fac

In [None]:
# compare "after" (first) with the true value (second); both lines use the same
# order as the corresponding "before" comparison
print(mfpt_fold_after, mfpt_fold_true)
print(mfpt_unfold_after, mfpt_unfold_true)

### Plot the final comparison

In [None]:
# In every figure the dashed black lines mark the true values; the markers show
# the model trained with observables ('+ Obs') and the one trained without.
# Only the first dashed line gets a legend entry (the other labels are empty).
labels = ['True', '', '']

# expectation values of the two observables
for i in range(2):
    plt.hlines(ev_true[i], i-0.25, i+0.25,'k', '--',label=labels[i])
plt.plot(ev_after, 'o', ms=10, label='+ Obs')
plt.plot(ev_before, 'o', ms=10, label='without')
plt.xlabel('Contact', fontsize=18)
plt.ylabel('Value [%]', fontsize=18)
plt.xticks([0,1], ['Formed', 'Unformed'], fontsize=16)
plt.title('Expectation Values', fontsize=18)
plt.legend(fontsize=14)
plt.show()

# autocorrelation values of the two observables
for i in range(2):
    plt.hlines(ac_true[i], i-0.25, i+0.25,'k', '--',label=labels[i])
plt.plot(ac_after, 'o', ms=10, label='+ Obs')
plt.plot(ac_before, 'o', ms=10, label='without')
plt.xlabel('Contact staying', fontsize=18)
plt.ylabel('Value [%]', fontsize=18)
plt.xticks([0,1], ['Formed', 'Unformed'], fontsize=16)
plt.title('Autocorrelation Values', fontsize=18)
plt.legend(fontsize=14)
plt.show()

# the two smallest eigenvalues of the transition matrix (ascending order)
for i in range(2):
    plt.hlines(all_eigval_true[i], i-0.25, i+0.25,'k', '--',label=labels[i])
plt.plot(all_eigval_after, 'o', ms=10, label='+ Obs')
plt.plot(all_eigval_before, 'o', ms=10, label='without')
plt.xlabel('Eigenvalue', fontsize=18)
plt.ylabel('Value', fontsize=18)
plt.xticks([0,1], ['Folding', 'Misfolding'], fontsize=16)
plt.title('Eigenvalue', fontsize=18)
plt.legend(fontsize=14)
plt.show()

# stationary probabilities of the three states, sorted in ascending order
for i in range(3):
    plt.hlines(np.sort(states_mu_true)[i], i-0.25, i+0.25,'k', '--',label=labels[i])
plt.plot(np.sort(states_mu_after), 'o', ms=10, label='+ Obs')
plt.plot(np.sort(states_mu_before), 'o', ms=10, label='without')
plt.xlabel('State', fontsize=18)
plt.ylabel('Probability [%]', fontsize=18)
plt.xticks([0,1,2], ['Misfolded', 'Folded', 'Unfolded'], fontsize=16)
plt.title('Stationary Distribution', fontsize=18)
plt.legend(fontsize=14)
plt.show()

# mean first passage times; position 0 holds the folding time and position 1
# the unfolding time, so the tick labels have to follow that order
for i in range(2):
    plt.hlines([mfpt_fold_true, mfpt_unfold_true][i], i-0.25, i+0.25,'k', '--',label=labels[i])
plt.plot([mfpt_fold_after, mfpt_unfold_after], 'o', ms=10, label='+ Obs')
plt.plot([mfpt_fold_before, mfpt_unfold_before], 'o', ms=10, label='without')
plt.xlabel('Process', fontsize=18)
# raw string with mathtext delimiters so that the unit is rendered as a mu
plt.ylabel(r'MFPT [$\mu$s]', fontsize=18)
plt.xticks([0,1], ['Folding', 'Unfolding'], fontsize=16)
plt.title('Mean First Passage Time', fontsize=18)
plt.legend(fontsize=14)
plt.show()