### Notebook to construct a deep reversible Markov State Model where additional information of experimental values can be included

The training steps include:
1. Pretraining of an ordinary VAMPnet (can be used to compare the results)
2. Initialize the values for the parameter matrix $\mathbf{S}$ and the reweighting vector $\mathbf{u}$ and train for matching specific observables
3. Training for $\mathbf{u}$ and $\mathbf{S}$ and train for matching specific observables
4. Training for $\boldsymbol{\chi}$ and $\mathbf{u}$ and $\mathbf{S}$ and train for matching specific observables

The analysis consists of:
1. Compare the performance on predefined observables not included in the training against a model which is trained without the additional experimental information


In [None]:
# import all the relevant packages
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils import data
import torch.optim as optim
import matplotlib.gridspec as gridspec

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Shows which device was selected ("cuda:0" on a CUDA machine, "cpu" otherwise)
print(device)

In [None]:
# Define where the data lies on the local machine
# Needs to be adapted for own data!!!
# Identifier of the simulated system; it is substituted into the path template below
test_system = '2F4K'
# File name of the matching PDB file (not used in the cells visible here)
pdb_system = '2f4k_villin.pdb'

# Path prefix of the trajectory data; every "{0}" placeholder is replaced by test_system
root = '/group/ag_cmb/simulation-data/DESRES-Science2011-FastProteinFolding/DESRES-Trajectory_{0}-0-protein/{0}-0-protein/{0}-0-protein'.format(test_system)

In [None]:
# Load the data
# NOTE(review): the path is hard-coded and independent of `root` - adapt it for your own data
output_all_files = np.load('/srv/public/andreas/data/desres/2f4k/villin_skip1.npy')
traj_whole = output_all_files

# Shape of the first entry, unpacked as (number of rows, number of columns);
# presumably rows are time frames and columns are input features - TODO confirm
traj_data_points, input_size = traj_whole[0].shape
# Stride used to subsample the frames and to rescale the lag times in the following cells
skip=1

In [None]:
# Hyperparameter definitions, should be adapted for specific problems
# NOTE: this cell reads `skip`, so the data-loading cell has to be executed first

# number of output nodes/states of the MSM or Koopman model, therefore also nodes of chi
# (more than one entry adds a coarse-graining layer between consecutive sizes)
output_sizes = [3,2]
# tau list for timescales estimation
tau_list = [25,50,100,150,200]
number_taus = len(tau_list)
# convert the lag times to units of the strided trajectory
tau_list = np.array(tau_list)//skip

# Tau, how much is the timeshift of the two datasets in the default training
tau = 2*25//skip # NOTE(review): "5, 20" were presumably earlier trial values - confirm
tau_chi = 2*25//skip

# Batch size for Stochastic Gradient descent
batch_size = 10000

# Fractions of the trajectory points used for training and for validation
train_ratio = 0.33
valid_ratio = 0.33

# How many hidden layers the network chi has
network_depth = 5

# Width of every layer of chi
layer_width = 100

# Learning rate used for the ADAM optimizer
learning_rate = 5e-4

# create a list with the number of nodes for each layer
nodes = [layer_width]*network_depth

# epsilon for numerical inversion of correlation matrices:
# eigenvalues that are not larger than this threshold are discarded (0-d float32 array)
epsilon = np.array(1e-7).astype('float32')

In [None]:
# Keep every `skip`-th row of the first trajectory, wrapped in a single-element list
traj_whole_new = [traj_whole[0][::skip]]
# Number of columns (input features) of the strided trajectory
input_size = traj_whole_new[0].shape[1]

In [None]:
# function for data generation
def estimate_koopman_op(trajs, tau, force_symmetric = False, eps = None):
    '''Estimates the koopman operator for a given trajectory at the lag time
        specified. The formula for the estimation is:
            K = C00 ^ -1 @ C01

    Parameters
    ----------
    trajs: numpy array with size [traj_timesteps, traj_dimensions], or list
        If tau > 0: a single trajectory or a list of trajectories, from which
        the instantaneous and the time-lagged data are cut.
        If tau <= 0: a pair (data, time-lagged data) of already aligned
        arrays, possibly in random order.

    tau: int
        Time shift at which the koopman operator is estimated

    force_symmetric: boolean, default = False
        if true, calculates the symmetrized version of K instead

    eps: float, optional
        Eigenvalues of C00 that are not larger than eps are discarded before
        the inversion. Defaults to the module-level `epsilon`.

    Returns
    -------
    koopman_op: numpy array with shape [traj_dimensions, traj_dimensions]
        Koopman operator estimated at timeshift tau

    '''
    if eps is None:
        eps = epsilon

    # if tau larger 0, interpret trajs as either a list of trajectories
    # or a single trajectory
    # otherwise interpret trajs as a list of the data and time-lagged data
    # possibly in random order
    if tau > 0:
        if isinstance(trajs, list):
            traj = np.concatenate([t[:-tau] for t in trajs], axis = 0)
            traj_lag = np.concatenate([t[tau:] for t in trajs], axis = 0)
        else:
            traj = trajs[:-tau]
            traj_lag = trajs[tau:]
    else:
        traj, traj_lag = trajs[0], trajs[1]

    c_0 = traj.T @ traj
    c_tau = traj.T @ traj_lag

    # if you want to symmetrize the correlation matrices
    if force_symmetric:
        c_0 = c_0 + traj_lag.T @ traj_lag
        c_tau = c_tau + traj_lag.T @ traj

    # c_0 is symmetric by construction, so use the symmetric solver: it
    # always returns real eigenvalues and orthonormal eigenvectors, which the
    # reconstruction V diag(1/lambda) V^T below relies on
    eigv_all, eigvec_all = np.linalg.eigh(c_0)
    include = eigv_all > eps
    eigv = eigv_all[include]
    eigvec = eigvec_all[:,include]
    c0_inv = eigvec @ np.diag(1/eigv) @ np.transpose(eigvec)

    return c0_inv @ c_tau



# utility function for plotting implied timescales
# the hyperparameters allow for it to calculate errorbars and work from different input data types
def get_its(data, lags, calculate_K = True, multiple_runs = False):
    '''Calculates implied timescales for a list of lag times.

    Parameters
    ----------
    data: numpy array, list, or list of those
        If calculate_K is True: trajectory data passed on to
        estimate_koopman_op for every lag time.
        If calculate_K is False: one precomputed Koopman matrix per lag time,
        in the same order as lags.
        If multiple_runs is True: a list with one such entry per run.

    lags: sequence of int
        Lag times at which the timescales are evaluated

    calculate_K: boolean, default = True
        whether the Koopman matrices have to be estimated from data

    multiple_runs: boolean, default = False
        whether data holds several independent runs (e.g. for errorbars)

    Returns
    -------
    its: numpy array with shape [output_size - 1, len(lags)], or a list of
        such arrays (one per run) if multiple_runs is True
    '''

    def get_single_its(data):

        if isinstance(data, list):
            outputsize = data[0].shape[1]
        else:
            outputsize = data.shape[1]

        single_its = np.zeros((outputsize-1, len(lags)))

        for t, tau_lag in enumerate(lags):
            if calculate_K:
                # module-level function (there is no `self` in a plain function)
                koopman_op = estimate_koopman_op(data, tau_lag)
            else:
                koopman_op = data[t]
            k_eigvals, _ = np.linalg.eig(np.real(koopman_op))
            # sort by magnitude and drop the largest eigenvalue, which
            # belongs to the stationary process
            k_eigvals = np.sort(np.absolute(k_eigvals))[:-1]
            single_its[:,t] = (-tau_lag / np.log(k_eigvals))

        return single_its

    if not multiple_runs:
        return get_single_its(data)

    return [get_single_its(data_run) for data_run in data]


### Helper functions

In [None]:
def _inv(x, ret_sqrt=False):
    '''Utility function that returns the inverse of a matrix, with the
    option to return the square root of the inverse matrix.
    Parameters
    ----------
    x: numpy array with shape [m,m]
        matrix to be inverted

    ret_sqrt: bool, optional, default = False
        if True, the square root of the inverse matrix is returned instead
    Returns
    -------
    x_inv: numpy array with shape [m,m]
        inverse of the original matrix
    '''

    # Calculate eigvalues and eigvectors
    eigval_all, eigvec_all = torch.symeig(x, eigenvectors=True)

    # Filter out eigvalues below threshold and corresponding eigvectors
    eig_th = torch.Tensor(epsilon)
    index_eig = eigval_all > eig_th
    eigval = eigval_all[index_eig]
    eigvec = eigvec_all[:,index_eig]

    # Build the diagonal matrix with the filtered eigenvalues or square
    # root of the filtered eigenvalues according to the parameter
    if ret_sqrt:
        diag = torch.diag(torch.sqrt(1/eigval))
    else:
        diag = torch.diag(1/eigval)
    # Rebuild the square root of the inverse matrix
    x_inv = torch.matmul(eigvec, torch.matmul(diag, eigvec.T))

    return x_inv


def _prep_data(data_t, data_tau):
    '''Utility function that transorms the input data from a tensorflow - 
    viable format to a structure used by the following functions in the
    pipeline.
    Parameters
    ----------
    data: tensorflow tensor with shape [b, 2*o]
        original format of the data
    Returns
    -------
    x: tensorflow tensor with shape [o, b]
        transposed, mean-free data corresponding to the left, lag-free lobe
        of the network

    y: tensorflow tensor with shape [o, b]
        transposed, mean-free data corresponding to the right, lagged lobe
        of the network

    b: tensorflow float32
        batch size of the data

    o: int
        output size of each lobe of the network

    '''


    # Subtract the mean
    x = data_t - torch.mean(data_t, dim=0, keepdim=True)
    y = data_tau - torch.mean(data_tau, dim=0, keepdim=True)

    return x, y


### Definition of the classes for the model and the attention network

In [None]:
class Mask(torch.nn.Module):
    ''' Attention mask either independent from the time point (mask_const=True) or dependent.
    If dependent the attention is estimated via a NN with depth and width given as input, which are
    otherwise ignored.
    The attention mechanism assumes that distances are used. skip_res is number of residues skiped when estimating
    the distance.

    Parameters
    ----------
    input_size: number of input features (residue-pair distances); the number of residues is derived from it
    mask_const: if True a single learnable weight per residue and head is used for all inputs
    depth, width: number and width of the linear layers of the input dependent mask
    heads: number of attention heads that are averaged
    sensitivity: factor multiplied onto the scores before the softmax over the residues
    fac: if True the masked input is rescaled by n_residues**2
    '''
    def __init__(self, input_size, mask_const, depth=0, width=100 , heads=1, sensitivity=1, fac=True):
        super(Mask, self).__init__()

        skip_res = 2
        # invert input_size = k*(k+1)/2 with k = n_residues - skip_res
        self.n_residues = int(-1/2 + np.sqrt(1/4+input_size*2) + skip_res)
        self.residues_1 = []
        self.residues_2 = []
        # estimate the pairs: feature j is the distance between residues_1[j] and residues_2[j]
        for n1 in range(self.n_residues-skip_res):
            for n2 in range(n1+skip_res, self.n_residues):
                self.residues_1.append(n1)
                self.residues_2.append(n2)
        self.mask_const = mask_const
        if mask_const:
            self.alpha = torch.randn((1, self.n_residues, heads)) * 0.5
            self.weight = torch.nn.Parameter(data=self.alpha, requires_grad=True)
        else:
            nodes = [input_size] + [width] * depth

            # NOTE(review): kept as a plain list, so these layers are NOT registered
            # as submodules (parameters(), to(), state_dict() skip them); callers
            # have to iterate over self.hfc explicitly. Consider nn.ModuleList.
            self.hfc = [nn.Linear(nodes[i], nodes[i+1]) for i in range(len(nodes)-1)]
            self.softmax = nn.Linear(nodes[-1], self.n_residues * heads)

        self.sensitivity = sensitivity
        if fac:
            self.fac = self.n_residues * self.n_residues
        else:
            self.fac = 1.
        self.heads = heads

    def _residue_weights(self, x=None):
        '''Return the head-averaged softmax weights per residue with shape [b or 1, n_residues].'''
        if self.mask_const:
            scores = self.weight
        else:
            y = x
            for layer in self.hfc:
                y = layer(y)
            y = self.softmax(y)
            # the original code referenced the undefined names `heads` / `n_residues` here
            scores = torch.reshape(y, (-1, self.n_residues, self.heads))

        # softmax over the residues, then average over the heads
        return torch.sum(F.softmax(scores*self.sensitivity, dim=1), dim=2)/self.heads

    def forward(self, x):
        '''Scale every pair feature of x by the weights of its two residues (and by self.fac).'''
        # weights for each residue
        weight_sf = self._residue_weights(x)

        # estimate for applying it
        weight_1 = weight_sf[:,self.residues_1]
        weight_2 = weight_sf[:,self.residues_2]

        masked_x = x * weight_1 * weight_2 * self.fac# include factor

        return masked_x

    def get_softmax(self, x=None):
        '''Return the residue weights; x is only needed if the mask is input dependent.'''
        return self._residue_weights(x)

    def set_weights(self, weights):
        '''Replace the constant mask weights by the given values (constant mask only).'''
        if self.mask_const:
            alpha = torch.Tensor(weights)

            # NOTE: this creates a new Parameter object; optimizers built before
            # this call still reference the old one
            self.weight = torch.nn.Parameter(data=alpha, requires_grad=True)
        else:
            print('not implemented yet. You need to define all layers in the mask')
            

            
class Coarse_grain(torch.nn.Module):
    ''' Learnable coarse-graining layer that maps input_dim state outputs onto output_dim outputs.
    The layer holds a weight matrix of shape [input_dim, output_dim]. A softmax over the
    output dimension turns every row into a non-negative assignment that sums to one, and
    the input is multiplied with this matrix from the right.
    sen is a factor multiplied onto the weights before the softmax.
    '''
    def __init__(self, input_dim, output_dim, sen=1):
        super().__init__()

        self.N, self.M = input_dim, output_dim
        self.sen = sen

        # random start values, kept in self.alpha as well
        self.alpha = 0.5 * torch.randn((self.N, self.M))
        self.weight = torch.nn.Parameter(data=self.alpha, requires_grad=True)

    def forward(self, x):
        # x: [b, input_dim] -> [b, output_dim]
        return x @ self.get_softmax()

    def get_softmax(self):
        # row-wise softmax of the scaled weights
        scaled = self.weight * self.sen
        return torch.softmax(scaled, dim=1)

    def reset_params(self):
        # draw new random weights and overwrite the parameter in place
        fresh = torch.randn((self.N, self.M)) * 0.5
        with torch.no_grad():
            self.weight.copy_(fresh)
        

        
class U_layer(torch.nn.Module):
    ''' Layer holding the trainable reweighting vector u.
    The kernel has shape [1, output_dim], is initialised with 1/output_dim and is passed
    through the given activation before use. From the state outputs chi_t and chi_tau the
    layer computes the correlation matrices that are needed further down the pipeline.
    '''
    def __init__(self, output_dim, activation):
        super().__init__()

        self.M = output_dim

        self.alpha = torch.Tensor(1, self.M).fill_(1/self.M)
        self.u_kernel = torch.nn.Parameter(data=self.alpha, requires_grad=True)
        self.acti = activation

    def forward(self, chi_t, chi_tau):
        '''Return the list [v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t].'''
        inv_batch = 1./chi_t.shape[0]

        kernel_u = self.acti(self.u_kernel)

        def normalized_u(chi):
            # rescale the transformed kernel so that <mean(chi), u> = 1
            mean = torch.mean(chi, dim=0, keepdim=True)
            return kernel_u / torch.sum(mean * kernel_u, dim=1, keepdim=True)

        # quantities of the time-shifted data
        u = normalized_u(chi_tau)
        corr_tau = inv_batch * (chi_tau.T @ chi_tau)
        v = corr_tau @ u.T
        weight_tau = chi_tau @ u.T
        mu = inv_batch * weight_tau
        Sigma = (chi_tau * mu).T @ chi_tau

        # the same construction for the lag-free data
        u_t = normalized_u(chi_t)
        mu_t = inv_batch * (chi_t @ u_t.T)
        Sigma_t = (chi_t * mu_t).T @ chi_t

        # reweighted time-shifted output and the resulting correlation matrices
        gamma = chi_tau * weight_tau
        C_00 = inv_batch * (chi_t.T @ chi_t)
        C_11 = inv_batch * (gamma.T @ gamma)
        C_01 = inv_batch * (chi_t.T @ gamma)

        return [v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t]
    
class S_layer(torch.nn.Module):
    ''' Layer holding the trainable symmetric matrix S.
    The kernel has shape [output_dim, output_dim], is initialised with 0.1 and is passed
    through the given activation before use. Symmetry is enforced by adding the
    transpose, and the diagonal is corrected such that S @ v is a vector of ones.
    If renorm is True the symmetric part is rescaled beforehand by a smooth
    approximation of the maximum norm of W1 @ v.
    '''
    def __init__(self, output_dim, activation, renorm=True):
        super(S_layer, self).__init__()

        self.M = output_dim

        self.alpha = torch.Tensor(self.M, self.M).fill_(0.1)
        self.S_kernel = torch.nn.Parameter(data=self.alpha, requires_grad=True)
        self.acti = activation
        self.renorm = renorm

    def forward(self, v, C_00, C_11, C_01, Sigma):
        '''Return the list [VampE_matrix, K, S].

        v has shape [output_dim, 1]; C_00, C_11, C_01 and Sigma have shape
        [output_dim, output_dim].
        '''
        # transform the kernel weights
        kernel_w = self.acti(self.S_kernel)

        # enforce symmetry
        W1 = kernel_w + kernel_w.T

        # normalize the weights
        norm = W1 @ v

        if self.renorm:
            # 20-norm as a differentiable stand-in for the maximum entry of norm
            quasi_inf_norm = torch.sum((norm**20))**(1./20)
            W1 = W1 / quasi_inf_norm
            norm = W1 @ v

        # diagonal correction that makes S @ v = 1
        w2 = (1 - torch.squeeze(norm)) / torch.squeeze(v)
        S = W1 + torch.diag(w2)

        # calculate K
        K = S @ Sigma

        # VAMP-E matrix for the computation of the loss
        VampE_matrix = S.T @ C_00 @ S @ C_11 - 2*S.T @ C_01

        return [VampE_matrix, K, S]
            
class VampNet(nn.Module):
    ''' VAMPnet class
    TODO:
    - revVAMP
    - revDMSM
    - DMSM
    - VAMP
    
    inputs:
    
    input_size: (int) size of the input features
    output_sizes: (list & int) list of output sizes. If more than one elemtent, expects coarse graining
    nodes: (list & int) list of output size of hidden layers
    
    
    '''
    def __init__(self, input_size, output_sizes, nodes, train_mean, train_std, 
                 valid_T=False, reversible=False,
                 mask_const=True, mask_depth=0, mask_width=0, heads=1, sensitivity=1, fac=True,
                 softmax_fac=1.):
        '''Build the attention mask, the hidden layers, the softmax output layer and,
        depending on the flags, the gamma layer or the u/S layers.

        input_size: (int) number of input features
        output_sizes: (list of int) first entry is the size of the softmax output, every
            further entry adds a Coarse_grain layer from the previous size to this size
        nodes: (list of int) width of each hidden layer
        train_mean, train_std: values used to normalise the input in forward_before_sm
        valid_T, reversible: select the model variant (VAMPnet, DMSM, revVAMPnet, revDMSM)
        mask_const, mask_depth, mask_width, heads, sensitivity, fac: passed on to Mask
        softmax_fac: factor multiplied onto the logits before the output softmax
        '''
        super(VampNet, self).__init__()
        
        # which physical constraints are enacted which can need more parameters with different contraints
        # Furthermore the best learning practice is different
        self.valid_T = valid_T
        self.reversible = reversible
        if valid_T:
            model='DMSM'
        else:
            model='VAMPnet'
        if reversible:
            rev='rev'
        else:
            rev=''
        print('The trained model is a '+rev+model) 
        
        # the gamma layer is only used for the non-reversible DMSM
        if valid_T and not reversible:
            self.gamma = True
        else:
            self.gamma = False
        if reversible:
            # NOTE: the lambdas read self.factor_S / self.factor_u at call time,
            # so changing these attributes later changes the activations
            if valid_T:
                # activations for revDMSM: scaled exponential, always positive
                self.factor_S = 1
                self.factor_u = 0.01
                self.renorm = True
                acti_S = lambda x: self.factor_S * torch.exp(x)
                acti_u = lambda x: self.factor_u * torch.exp(x)
                
            else: 
                # activations for reversible VAMPnets: scaled identity
                self.factor_S = 0.001
                self.factor_u = .000001

                linear_S = lambda x: self.factor_S * x
                linear_u = lambda x: self.factor_u * x
                self.renorm = False
                acti_S = linear_S
                acti_u = linear_u
            
            
            # one u and one S layer per output size; these attributes only exist
            # if reversible is True
            self.u_layers = [U_layer(o, acti_u) for o in output_sizes]
            self.S_layers = [S_layer(o, acti_S, self.renorm) for o in output_sizes]
                
        # NOTE(review): u_layers, S_layers, hfc and coarse_grain_layer are plain
        # Python lists, so their parameters are not registered with this module
        # (self.parameters(), .to() and state_dict() do not see them); the
        # get_params_* generators have to be used instead.
        
        self.N = len(output_sizes)
        self.output_sizes = output_sizes
        
        self.softmax_fac = softmax_fac
        self.fac = fac
        self.nodes = nodes
        self.mask_const = mask_const
        self.Mask = Mask(input_size, mask_const, mask_depth, mask_width, heads, sensitivity, fac=fac)
        
        # hidden layers: input_size -> nodes[0] -> ... -> nodes[-1]
        self.hfc = [nn.Linear(input_size, nodes[0])]
        for i in range(len(nodes)-1):
            self.hfc.append(nn.Linear(nodes[i], nodes[i+1]))
            
        
        self.fc_softmax = nn.Linear(nodes[-1], output_sizes[0])
        
        if self.gamma:
            self.gamma_layer = nn.Linear(nodes[-1], output_sizes[0])
            # is just needed once, not for coarse graining, since it convert the same as chi
        
        self.coarse_grain_layer = [Coarse_grain(output_sizes[n], output_sizes[n+1]) for n in range(len(output_sizes)-1)]
        
        self.train_mean = torch.Tensor(train_mean)
        self.train_std = torch.Tensor(train_std)
    
    def forward_before_sm(self, x):
        
        x = (x-self.train_mean)/self.train_std
        shape = x.shape
        b = shape[0]
        
        x = self.Mask(x) # b x input_size

        
        for layer_list_i in self.hfc:

            x = F.elu(layer_list_i(x))
            
        return x
        
    def forward(self, x):
        
        x = self.forward_before_sm(x)

        x_output = F.softmax(self.softmax_fac*self.fc_softmax(x), dim=1)
        
        return x_output
    
    def forward_cg(self, x, id):
        
        x_output = self.coarse_grain_layer[id](x)
        
        return x_output
    
    def forward_all(self, x):
        outputs = []
        
        x = self.forward(x)
        
        outputs.append(x)
        
        for layer in self.coarse_grain_layer:
            x = layer(x)
            outputs.append(x)
            
        return outputs
    
    def forward_gamma(self, x):
        
        x_gamma = F.elu(self.gamma_layer(x)) + 1. # plus 1 so it is always positive
        
        return x_gamma
    
        
    def get_attention(self, x=None):
        return self.Mask.get_softmax(x=x)
    
    def set_soft_fac(self, new_value):
        self.softmax_fac = torch.Tensor(new_value)
        
    def set_soft_fac_cg(self, id_cg, new_value):
        
        self.coarse_grain_layer[id_cg].sen = new_value
    
    def get_params_vamp(self):
        if self.mask_const:
            for param in self.Mask.parameters():
                yield param
        else:
            for layer in self.Mask.hfc:
                for param in layer.parameters():
                    yield param
            for param in self.Mask.softmax.parameters():
                yield param
                    
        for layer in self.hfc:
            
            for param in layer.parameters():
                yield param
        
        for param in self.fc_softmax.parameters():
            yield param
    
    def get_params_DMSM(self, all=True):
        if all:
            if self.mask_const:
                for param in self.Mask.parameters():
                    yield param
            else:
                for layer in self.Mask.hfc:
                    for param in layer.parameters():
                        yield param
                for param in self.Mask.softmax.parameters():
                    yield param

            for layer in self.hfc:

                for param in layer.parameters():
                    yield param

            for param in self.fc_softmax.parameters():
                yield param
        for param in self.gamma_layer.parameters():
            yield param

    def get_params_rev(self, all=True, u_flag=True, S_flag=True):
        if all:
            if self.mask_const:
                for param in self.Mask.parameters():
                    yield param
            else:
                for layer in self.Mask.hfc:
                    for param in layer.parameters():
                        yield param
                for param in self.Mask.softmax.parameters():
                    yield param

            for layer in self.hfc:

                for param in layer.parameters():
                    yield param

            for param in self.fc_softmax.parameters():
                yield param
        if u_flag:
            for param in self.u_layers[0].parameters():
                yield param
        if S_flag:
            for param in self.S_layers[0].parameters():
                yield param
    
    def get_params_all(self):
        if self.mask_const:
            for param in self.Mask.parameters():
                yield param
        else:
            for layer in self.Mask.hfc:
                for param in layer.parameters():
                    yield param
            for param in self.Mask.softmax.parameters():
                yield param
                    
        for layer in self.hfc:
            
            for param in layer.parameters():
                yield param
        
        for param in self.fc_softmax.parameters():
            yield param
        for layer in self.coarse_grain_layer:
            for param in layer.parameters():
                yield param
        
    
    def get_params_wo_mask(self):
        for layer in self.hfc:
            for param in layer.parameters():
                yield param
       
        for param in self.fc_softmax.parameters():
            yield param
    
    def get_params_softmax(self):
        
        for param in self.fc_softmax.parameters():
            yield param
            
    
    def get_params_mask(self):
        if self.mask_const:
            for param in self.Mask.parameters():
                yield param
        else:
            for layer in self.Mask.hfc:
                for param in layer.parameters():
                    yield param
            for param in self.Mask.softmax.parameters():
                yield param
            
    def get_params_cg(self, index=[]):
        for id in index:
            for param in self.coarse_grain_layer[id].parameters():
                yield param
    
    def get_params_cg_rev(self, index=[], all=True):
        
        for id in index:
            if all:
                for param in self.coarse_grain_layer[id].parameters():
                    yield param
            for param in self.u_layers[id+1].parameters():
                yield param
            for param in self.S_layers[id+1].parameters():
                yield param
                
                
    def set_rev_var(self, layer_id=0, S=True):
        '''Initialise the kernels of the u layer (and optionally the S layer) of
        level layer_id from the Koopman matrix estimated with the current chi.

        The state outputs are evaluated on the module-level tensors
        tensor_train_X1 / tensor_train_X2 and, for layer_id > 0, passed through
        the first layer_id coarse-graining layers.

        layer_id: (int) level whose u/S kernels are overwritten
        S: (bool) if True the S kernel is initialised as well
        '''
        chi_t = self.forward(tensor_train_X1)
        chi_tau = self.forward(tensor_train_X2)

        for i in range(layer_id):
            chi_t = self.forward_cg(chi_t, i)
            chi_tau = self.forward_cg(chi_tau, i)

        Data_chi_X = chi_t.detach().numpy()
        Data_chi_Y = chi_tau.detach().numpy()
        fullbatch = Data_chi_X.shape[0]

        c_0 = 1/fullbatch * Data_chi_X.T @ Data_chi_X
        c_tau = 1/fullbatch * Data_chi_X.T @ Data_chi_Y

        # pseudo-inverse of c_0 on the eigenspace with eigenvalues above epsilon
        eigv_all, eigvec_all = np.linalg.eigh(c_0)
        include = eigv_all > epsilon
        eigv = eigv_all[include]
        eigvec = eigvec_all[:,include]
        c0_inv = eigvec @ np.diag(1/eigv) @ np.transpose(eigvec)

        K_vamp = c0_inv @ c_tau

        # estimate pi, the stationary distribution vector
        # K_vamp is not symmetric, so eig may return complex arrays: select the
        # eigenvalue closest to 1 by its distance in the complex plane and keep
        # only the real part of the eigenvector, because the kernels are real
        eigv, eigvec = np.linalg.eig(K_vamp.T)
        ind_pi = np.argmin(np.abs(eigv-1))

        pi_vec = np.real(eigvec[:,ind_pi])
        pi = pi_vec / np.sum(pi_vec, keepdims=True)

        # reverse the consruction of u
        u_optimal = c0_inv @ pi

        # invert the activation of the u layer (exponential or linear)
        if self.valid_T:
            u_kernel = np.log(np.abs(u_optimal/self.factor_u))
        else:
            u_kernel = u_optimal / self.factor_u

        with torch.no_grad():
            for param in self.u_layers[layer_id].parameters():

                param.copy_(torch.Tensor(u_kernel[None,:]))

        if S:

            # Sigma is evaluated with the u kernel that was just set
            _, _, _, _, Sigma_input, _, _ = self.u_layers[layer_id](chi_t, chi_tau)
            Sigma = Sigma_input.detach().numpy()

            eigv_all, eigvec_all = np.linalg.eigh(Sigma)
            include = eigv_all > epsilon
            eigv = eigv_all[include]
            eigvec = eigvec_all[:,include]
            sigma_inv = eigvec @ np.diag(1/eigv) @ np.transpose(eigvec)

            # reverse the construction of S
            S_nonrev = K_vamp @ sigma_inv
            S_rev_add = 1/2 * (S_nonrev + S_nonrev.T)
            # halved because the S layer adds the kernel to its transpose;
            # afterwards the activation of the S layer is inverted
            if self.valid_T:
                kernel_S = S_rev_add / 2.
                kernel_S = np.log(np.abs(kernel_S/self.factor_S))
            else:
                kernel_S = S_rev_add / 2. / self.factor_S

            with torch.no_grad():
                for param in self.S_layers[layer_id].parameters():

                    param.copy_(torch.Tensor(kernel_S))
            
    def get_weights(self):
        
        weights_dict = {}
        weights_dict['hfc'] = []
        weights_dict['sm'] = []

        for layer in  self.hfc:
            for param in layer.parameters():
                weights_dict['hfc'].append(param.detach().numpy().copy())

        for param in self.fc_softmax.parameters():
            weights_dict['sm'].append(param.detach().numpy().copy())

        if self.reversible:
            weights_dict['S'] = [param.detach().numpy().copy() for param in self.S_layers[0].parameters()]
            weights_dict['u'] = [param.detach().numpy().copy() for param in self.u_layers[0].parameters()]

        if self.mask_const:
            weights_dict['Mask'] = [param.detach().numpy().copy() for param in self.Mask.parameters()]
        else:
            weights_dict['Mask_hf'] = []
            weights_dict['Mask_sm'] = []

            for layer in self.Mask.hfc:
                for param in layer.parameters():
                    weights_dict['Mask_hf'].append(param.detach().numpy().copy())
            for param in self.Mask.softmax.parameters():
                weights_dict['Mask_sm'].append(param.detach().numpy().copy())

        weights_dict['cg'] = []
        weights_dict['S_cg'] = []
        weights_dict['u_cg'] = []
        for i, layer in enumerate(self.coarse_grain_layer):
            for param in layer.parameters():
                weights_dict['cg'].append(param.detach().numpy().copy())
            weights_dict['S_cg'].append([param.detach().numpy().copy() for param in self.S_layers[i+1].parameters()][0])
            weights_dict['u_cg'].append([param.detach().numpy().copy() for param in self.u_layers[i+1].parameters()][0])
        
        return weights_dict
    
    def set_weights(self, weights_dict):
        with torch.no_grad():
            i = 0
            for layer in self.hfc:
                for param in layer.parameters():    
                    param.copy_(torch.Tensor(weights_dict['hfc'][i])) 
                    i+=1
            i = 0
            for param in self.fc_softmax.parameters():
                param.copy_(torch.Tensor(weights_dict['sm'][i]))
                i+=1

            if self.reversible:
                i=0
                for param in self.S_layers[0].parameters():
                    param.copy_(torch.Tensor(weights_dict['S'][i]))
                    i+=1
                i=0
                for param in self.u_layers[0].parameters():
                    param.copy_(torch.Tensor(weights_dict['u'][i]))
                    i+=1

            if self.mask_const:
                i=0
                for param in self.Mask.parameters():
                    param.copy_(torch.Tensor(weights_dict['Mask'][i]))
                    i+=1
            else:
                i=0
                for layer in self.Mask.hfc:
                    for param in layer.parameters():
                        param.copy_(torch.Tensor(weights_dict['Mask_hf'][i]))
                        i+=1
                i=0
                for param in self.Mask.softmax.parameters():
                    param.copy_(torch.Tensor(weights_dict['Mask_sm'][i]))
                    i+=1
            i=0
            for layer in self.coarse_grain_layer:

                for param in layer.parameters():
                    param.copy_(torch.Tensor(weights_dict['cg'][i]))
                for param in self.S_layers[i+1].parameters():
                    param.copy_(torch.Tensor(weights_dict['S_cg'][i]))
                for param in self.u_layers[i+1].parameters():
                    param.copy_(torch.Tensor(weights_dict['u_cg'][i]))
                i+=1

In [None]:
def VAMP_score(chi_t, chi_tau, corr=False):
    '''Calculates the VAMP-2 score of a pair of network outputs (PyTorch).

    Both inputs are passed through ``_prep_data`` first (defined elsewhere;
    presumably mean removal - TODO confirm). The covariance matrices are
    normalised by ``batch_size - 1``.

    Parameters
    ----------
    chi_t: torch tensor, first dimension is the batch
        network output for the frames at time t.

    chi_tau: torch tensor, same layout as chi_t
        network output for the time-lagged frames.

    corr: bool, default False
        if True, the trace of the instantaneous covariance matrix is
        returned as a second value.

    Returns
    -------
    score: torch tensor with shape [1]
        squared norm (``torch.norm`` default) of the VAMP matrix.
    (score, trace_cov_00): returned instead of score when corr is True.
    '''
    n_frames = chi_t.shape[0]

    x, y = _prep_data(chi_t, chi_tau)

    # Instantaneous, time-lagged and cross covariance estimates
    scale = 1/(n_frames - 1)
    cov_00 = scale * torch.matmul(x.T, x)
    cov_11 = scale * torch.matmul(y.T, y)
    cov_01 = scale * torch.matmul(x.T, y)

    # _inv with ret_sqrt=True presumably yields the inverse square root - TODO confirm
    left = _inv(cov_00, ret_sqrt = True)
    right = _inv(cov_11, ret_sqrt = True)

    # VAMP matrix: left @ (cov_01 @ right)
    vamp_matrix = torch.matmul(left, torch.matmul(cov_01, right))

    score = (torch.norm(vamp_matrix)**2).unsqueeze(0)
    if not corr:
        return score
    return score, torch.trace(cov_00)

In [None]:
def vampe_loss(chi, gamma, eps=0.):
    """Return ``-trace(C00 D C11 D - 2 C01 D)`` for the two feature sets.

    ``C00``, ``C11`` and ``C01`` are the uncentered second moments of ``chi`` and
    ``gamma`` divided by the number of rows; ``D`` is the diagonal matrix holding
    the reciprocal column means of ``gamma``.

    Parameters
    ----------
    chi : torch.Tensor, rows are samples
    gamma : torch.Tensor, rows are samples (same number of rows as ``chi``)
    eps : float, default 0.
        Added to the column means of ``gamma`` before taking the reciprocal.
        With the default of 0 the result is unchanged from the previous
        behaviour; a small positive value keeps the result finite when a
        column of ``gamma`` averages to zero (otherwise ``1/0`` gives inf/NaN).

    Returns
    -------
    torch.Tensor
        Zero-dimensional tensor.
    """
    b = chi.shape[0]

    c00 = 1/b*(torch.matmul(chi.T, chi))
    c11 = 1/b*(torch.matmul(gamma.T, gamma))
    c01 = 1/b*(torch.matmul(chi.T, gamma))

    gamma_mean = torch.mean(gamma, dim=0)
    if eps:
        # Guard against division by zero for columns that are never populated.
        gamma_mean = gamma_mean + eps
    gamma_dia_inv = torch.diag(1/gamma_mean)

    first_term = c00 @ gamma_dia_inv @ c11 @ gamma_dia_inv
    second_term = 2 * (c01 @ gamma_dia_inv)
    vampe = torch.trace(first_term - second_term)

    return -vampe

In [None]:
def vampe_loss_rev(chi_t, chi_tau, layer_id=0, return_mu=False, return_mu_K_Sigma=False):
    """Evaluate the u- and S-layer of ``Full_net`` and return ``-trace(matrix)``.

    ``layer_id`` selects the pair ``Full_net.u_layers[layer_id]`` /
    ``Full_net.S_layers[layer_id]``. ``return_mu`` takes precedence over
    ``return_mu_K_Sigma`` when both are set.

    Returns
    -------
    score                          (default)
    (score, mu_t)                  if return_mu
    (score, mu_t, K, Sigma_t)      if return_mu_K_Sigma
    """
    u_out = Full_net.u_layers[layer_id](chi_t, chi_tau)
    v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t = u_out
    matrix, K, _ = Full_net.S_layers[layer_id](v, C_00, C_11, C_01, Sigma)
    score = -torch.trace(matrix)

    if return_mu:
        return score, mu_t
    if return_mu_K_Sigma:
        return score, mu_t, K, Sigma_t
    return score
    
def vampe_loss_rev_only_S(v, C_00, C_11, C_01, Sigma, layer_id=0):
    """Return ``-trace(matrix)`` from ``Full_net.S_layers[layer_id]`` for given u-layer outputs.

    The arguments are the first five outputs of the matching u-layer, so the
    S-layer can be evaluated without running the u-layer again.
    """
    s_layer = Full_net.S_layers[layer_id]
    matrix, _K, _ = s_layer(v, C_00, C_11, C_01, Sigma)
    return -torch.trace(matrix)

In [None]:
# Model configuration and construction of the network used throughout the notebook.
# NOTE(review): fac_bf_sm is not passed to the VampNet constructor below; its
# consumer is not visible in this section.
fac_bf_sm = True
valid_T=True # if valid transition matrix is enforced
reversible=True # if reversibility is enforced

# attention stuff
mask_const=True # if the trained attention mask is constant over time
heads=1
mask_depth=2 # if time dependent how many hidden layers has the attention network
mask_width=100 # the width of the attention hidden layers
sensitivity=1. # factor before attention softmax for clearer assignment
factor_att=True # if to use a factor which scales the input on average back to input

softmax_fac=1. # factor before classification softmax

# We will use a pre trained network on the whole trajectory to estimate the true observable values.
# The mean/std arrays stored next to that model are handed to the constructor
# (presumably for input normalisation - TODO confirm against VampNet).
train_mean = np.load('./dicts/tau2500skip1msm_mean.npy')
train_std = np.load('./dicts/tau2500skip1msm_std.npy')

# Single global model instance; the loss and training helpers below access it by name.
Full_net = VampNet(input_size, output_sizes, nodes, train_mean, train_std, 
                 valid_T=valid_T, reversible=reversible,
                 mask_const=mask_const, mask_depth=mask_depth, mask_width=mask_width, heads=heads, sensitivity=sensitivity,
                 fac=factor_att,
                 softmax_fac=softmax_fac)

In [None]:
weights_dict = Full_net.get_weights()

In [None]:
# One Adam optimizer per training stage; the training helpers read these by their global names.
# optimizer_vamp is used by train_for_VAMPnet.
optimizer_vamp = optim.Adam(Full_net.get_params_vamp(), lr=learning_rate*10)
optimizer_full = optim.Adam(Full_net.get_params_all(), lr=learning_rate) # also for all coarse grain

if Full_net.reversible:
    # optimizer_rev -> train_for_rev, optimizer_rev_u_S -> train_for_u_S,
    # optimizer_rev_S -> train_for_S.
    # Judging by the flag names, all=False restricts the parameters to u and S and
    # u_flag/S_flag=False drops one of them - TODO confirm in get_params_rev.
    optimizer_rev = optim.Adam(Full_net.get_params_rev(), lr=learning_rate/10)
    optimizer_rev_u_S = optim.Adam(Full_net.get_params_rev(all=False), lr=learning_rate)
    optimizer_rev_S = optim.Adam(Full_net.get_params_rev(all=False, u_flag=False), lr=learning_rate*10)
    optimizer_rev_u = optim.Adam(Full_net.get_params_rev(all=False, S_flag=False), lr=learning_rate*10)

### Estimation of the real value of the observables and Data manipulation

1. Load a model trained on the full data set and estimate the real observable values
2. Identify a good microscopic observable which could possibly be measured in an experiment (here a contact)
3. Manipulate the data by removing folding and unfolding events from the data by observing the eigenfunction corresponding to the folding process and identify the frames which represent the crossing

In [None]:
# First load the model estimated on the whole data set
# NOTE: allow_pickle=True unpickles arbitrary Python objects - only load files from a trusted source.
weights_msm = np.load('./dicts/tau2500skip1msm.npy', allow_pickle=True).item()
Full_net.set_weights(weights_msm)
   
# Lag time in frames used for all "true" reference quantities below
tau = 50*2*25//skip
# State assignments of the reference model for every frame (detached, no gradients needed)
chi_true = Full_net.forward(torch.Tensor(traj_whole_new[0])).detach()
chi_whole = chi_true.numpy()
# The features are mapped through -log before thresholding
# (presumably they were stored as exp(-distance) - TODO confirm).
distances = - np.log(traj_whole_new[0])
# Binary contact map: 1 where the value is below 0.45 (units not visible here)
contacts = (distances <0.45).astype('int')
# Define which distance is the microscopic observable
contact_obs = [41]
# Define helper functions to estimate the observables
def plot_mu(tensor_train_X1, tensor_train_X2, frames):
    """Return the state populations obtained by weighting ``chi_true[frames]`` with mu.

    The frame weights ``mu`` are the sixth output of ``Full_net.u_layers[0]``
    evaluated on the network outputs of the two inputs. The weighted
    reference assignments are summed over the frame axis and returned as a
    numpy array. Despite the name, nothing is plotted.
    """
    chi_now = Full_net(tensor_train_X1)
    chi_lagged = Full_net(tensor_train_X2)
    _, _, _, _, _, mu, _ = Full_net.u_layers[0](chi_now, chi_lagged)
    weighted = mu * chi_true[frames]
    return torch.sum(weighted, dim=0).detach().numpy()
# True stationary distribution of the three states
prob_states_true = plot_mu(torch.Tensor(traj_whole_new[0][:-tau]), torch.Tensor(traj_whole_new[0][tau:]), np.arange(traj_whole_new[0].shape[0]-tau))
# functions to estimate the difference of the true observable value and the one from the current model
def obs_loss(obs_value, mu, exp_value):
    """Absolute difference between ``exp_value`` and the mu-weighted sum of ``obs_value``.

    ``obs_value * mu`` is summed over all elements to give the model estimate.
    """
    estimate = (obs_value * mu).sum()
    return torch.abs(exp_value - estimate)

def obs_time_loss(obs_value, mu, chi, K, Sigma, exp_value):
    """Absolute difference between ``exp_value`` and the model value ``a (Sigma K) a^T``.

    ``a`` is the row vector of per-state averages of ``obs_value``, using
    ``mu * chi`` as frame weights normalised separately for every state.
    """
    weights = mu * chi

    # Per-state normalisation of the frame weights (kept as a 1 x n_states row)
    per_state_norm = torch.sum(weights, dim=0, keepdims=True)
    # Average observable value inside each state
    state_avg = torch.sum(weights * obs_value, dim=0, keepdims=True) / per_state_norm

    # Sigma @ K: weight of an unconditional jump from state i to state j
    joint = Sigma @ K
    model_value = state_avg @ (joint @ state_avg.T)  # shape 1x1

    return torch.abs(model_value - exp_value)

def _sym_eig(mat):
    """Eigenvalues (ascending) and eigenvectors (columns) of a symmetric matrix.

    ``torch.symeig`` is deprecated and no longer available in recent PyTorch
    releases; ``torch.linalg.eigh`` is its replacement. The old call is kept
    as a fallback for installations that predate ``torch.linalg.eigh``.
    """
    linalg = getattr(torch, 'linalg', None)
    if linalg is not None and hasattr(linalg, 'eigh'):
        return linalg.eigh(mat)
    return torch.symeig(mat, eigenvectors=True)


def vampe_loss_rev_obs(chi_t, chi_tau, layer_id=0, return_mu=False, all_eigvals=False):
    """Score of the reversible model plus the spectrum of the symmetrised S matrix.

    Evaluates ``Full_net.u_layers[layer_id]`` and ``Full_net.S_layers[layer_id]``,
    builds ``Sigma^(1/2) S Sigma^(1/2)`` and diagonalises it.

    Parameters
    ----------
    chi_t, chi_tau : torch.Tensor
        Network outputs for the instantaneous and the time-lagged frames.
    layer_id : int
        Index of the u/S layer pair.
    return_mu : bool
        Unused; kept so existing calls keep working.
    all_eigvals : bool
        If True the full ascending eigenvalue vector is returned, otherwise only
        its first (smallest) entry.

    Returns
    -------
    list
        ``[-trace(matrix), S_similar, K, eigenvalue(s)]``.
    """
    v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t = Full_net.u_layers[layer_id](chi_t, chi_tau)
    matrix, K, S = Full_net.S_layers[layer_id](v, C_00, C_11, C_01, Sigma)
    vampe = torch.trace(matrix)

    eigval_all, eigvec_all = _sym_eig(Sigma)

    # Filter out eigenvalues below the threshold (global `epsilon`) and the
    # corresponding eigenvectors
    eig_th = torch.Tensor(epsilon)
    index_eig = eigval_all > eig_th
    eigval = eigval_all[index_eig]
    eigvec = eigvec_all[:,index_eig]

    # Diagonal matrix with the square roots of the retained eigenvalues
    diag = torch.diag(torch.sqrt(eigval))

    # Rebuild the square root of Sigma (no inversion is involved here)
    Sigma_sqrt = torch.matmul(eigvec, torch.matmul(diag, eigvec.T))

    S_similar = Sigma_sqrt @ S @ Sigma_sqrt

    eigval_all, eigvec_all = _sym_eig(S_similar)

    if all_eigvals:
        ret = [-vampe, S_similar, K, eigval_all]
    else:
        # Eigenvalues are sorted ascending, so index 0 is the smallest one
        ret = [-vampe, S_similar, K, eigval_all[0]]

    return ret

# Time-lagged pairs of the reference state assignments
chi_t = chi_true[:-tau]
chi_tau = chi_true[tau:]


# With all_eigvals left at False the last return value is eigval_all[0]
_, S_similar, K_test, eigval_fold = vampe_loss_rev_obs(chi_t, chi_tau)
# True eigenvalue of the folding process
# NOTE(review): this is the first entry of the eigenvalue vector returned by the
# symmetric eigensolver - confirm it is the eigenvalue of the folding process.
true_value = eigval_fold.detach()

def get_K_rev(tensor_t=torch.Tensor(traj_whole_new[0][:-tau]), tensor_tau=torch.Tensor(traj_whole_new[0][tau:])):
    """Return the detached matrix ``K`` of the first u/S layer pair of ``Full_net``.

    The default arguments are evaluated once, when this ``def`` is executed, so
    they hold the full trajectory split with the value ``tau`` had at that time.
    """
    chi_now = Full_net(tensor_t)
    chi_lagged = Full_net(tensor_tau)
    v, C_00, C_11, C_01, Sigma, _, _ = Full_net.u_layers[0](chi_now, chi_lagged)
    s_out = Full_net.S_layers[0](v, C_00, C_11, C_01, Sigma)
    _, K, _ = s_out
    return K.detach()
# microscopic observable value for each time frame, if a contact is formed or not
# (one column vector of 0/1 values per entry in contact_obs)
obs_value_time = [(contacts[:,i][:,None]).astype('float32') for i in contact_obs]
# and the reverse observable (1 - value of the first observable)
obs_value_time.append(((obs_value_time[0]-1.)*-1.).astype('float32'))



chi_t = torch.Tensor(chi_whole[:-tau])
chi_tau = torch.Tensor(chi_whole[tau:])
# Estimate the true auto correlation value of the contact staying formed
exp_value_time = []
for obs_v in obs_value_time:
    obs_value_tensor = torch.Tensor(obs_v[:-tau])
    # The target is set to 0, so the "error" returned by obs_time_loss is the
    # absolute value of the model estimate itself.
    exp_value_tensor = torch.Tensor(np.array([0.]))

    score_curr, mu_t, K, Sigma = vampe_loss_rev(chi_t, chi_tau, return_mu_K_Sigma=True)
    error_obs = obs_time_loss(obs_value_tensor, mu_t, chi_t, K, Sigma, exp_value_tensor)
    exp_value_time.append(error_obs.detach())
print(exp_value_time)
# Estimate the expectation value that the contact is formed
obs_exp_value = [(contacts[:,i][:,None]).astype('float32') for i in contact_obs]
exp_exp_value = []
for obs_v in obs_exp_value:
    obs_value_tensor = torch.Tensor(obs_v[:-tau])
    # Same trick as above: with a zero target obs_loss returns |estimate|.
    exp_value_tensor = torch.Tensor(np.array([0.]))

    score_curr, mu_t = vampe_loss_rev(chi_t, chi_tau, return_mu=True)
    error_obs = obs_loss(obs_value_tensor, mu_t, exp_value_tensor)
    exp_exp_value.append(error_obs.detach())
print(exp_exp_value)

# Direct estimate from the trajectory for comparison: mean of obs(t) * obs(t + tau)
for obs in obs_value_time:
#     plt.plot(obs)
#     plt.show()
    print(np.sum(obs[:-tau]*obs[tau:])/(obs[tau:].shape[0]))
# remove data depending on where you land

# K of the reference model (default arguments of get_K_rev) and its eigenpairs,
# sorted by ascending eigenvalue
K_exp = get_K_rev()
np.linalg.eigvals(K_exp)  # result is discarded
eigvals, eigvec = np.linalg.eig(K_exp)
sort_ind = np.argsort(eigvals)
eigvals_sort = eigvals[sort_ind]
eigvec_sort = eigvec[:,sort_ind]

# Manipulate the data: locate the frames that belong to crossings between the two
# extremes of one eigenfunction of K, in both directions.
data_corrupt = 'barrier' 
if data_corrupt == 'barrier':
    # NOTE(review): K is evaluated on traj_whole[0] while the rest of this cell
    # uses traj_whole_new[0] - confirm both hold the same frames.
    K = get_K_rev(torch.Tensor(traj_whole[0][:-tau]), torch.Tensor(traj_whole[0][tau:]))
    # K = estimate_koopman_op(chi_whole, tau)
    np.linalg.eigvals(K)  # result is discarded
    eigvals, eigvecs = np.linalg.eig(K)
    # Sort eigenpairs by ascending eigenvalue; index -3 is the third-largest one
    sort_id = np.argsort(eigvals)
    eigvals_sort = eigvals[sort_id]
    eigvecs_sort = eigvecs[:,sort_id]
    print(eigvals_sort[-3])
    # Estimate the eigenfunction corresponding to the folding process
    eigfunc = chi_whole @ eigvecs_sort[:,-3]
    min_eigfunc = eigfunc.min()
    max_eigfunc = eigfunc.max()
    # Find data points which are close to the folded and unfolded state
    # (within 0.05 of the minimum / maximum of the eigenfunction)
    starting_points = np.where(eigfunc < (min_eigfunc + 0.05))[0]
    end_points = np.where(eigfunc > (max_eigfunc - 0.05))[0]
    
    # truncate the whole trajectory into folding and unfolding events
    transition_forward_ind = []
    transition_backward_ind = []

    # Forward events: a frame s near the minimum counts as the start of a crossing
    # when the next frame near the minimum comes only after the next frame near
    # the maximum (e). The last entry of starting_points is never used as a start.
    for j, s in enumerate(starting_points):

        distance_end = end_points - s

        where_positive = np.where(distance_end>0)[0]
        if len(where_positive)>0:
            # first frame near the maximum that follows s
            e = end_points[where_positive[0]]
            if j+1 < starting_points.shape[0]:
                next_s = starting_points[j+1]
                if next_s > e:
                    # extend the event by tau frames into the past, clamped at frame 0
                    s_new = s-tau
                    if s_new < 0:
                        s_new = 0
                    transition_forward_ind.append(np.arange(s_new, e))
                    print('Found new transition {}, {}'.format(s_new, e))


    # Backward events: same search with the roles of the two frame sets swapped
    for j, s in enumerate(end_points):

        distance_end = starting_points - s

        where_positive = np.where(distance_end>0)[0]
        if len(where_positive)>0:
            e = starting_points[where_positive[0]]
            if j+1 < end_points.shape[0]:
                next_s = end_points[j+1]
                if next_s > e:
                    s_new = s-tau
                    if s_new < 0:
                        s_new = 0
                    transition_backward_ind.append(np.arange(s_new, e))
                    print('Found new back transition {}, {}'.format(s_new, e))

    # Find the frames which are not part of the folding/unfolding
    non_transition_ind = []

    len_for = len(transition_forward_ind)
    len_back = len(transition_backward_ind)

    # Iterate as often as the longer of the two event lists
    if len_for > len_back:
        length_total = len_for
    else:
        length_total = len_back
    # s marks the first frame of the current non-transition stretch
    s = 0
    for i in range(length_total):
        # NOTE(review): when one list is shorter, the previous `forward`/`backward`
        # value is reused for the remaining iterations (and the name is undefined
        # if that list is empty) - confirm this is intended.
        if i < len_for:
            forward = transition_forward_ind[i]
        if i < len_back:
            backward = transition_backward_ind[i]

        if forward[0] < backward[0]:
            # forward event comes first: keep [s, start of forward event)
            e = forward[0]
            non_transition_ind.append(np.arange(s,e))
            s = forward[-1]+1

            # keep the gap between the two events, if there is one
            if s < backward[0]:
                e = backward[0]
                non_transition_ind.append(np.arange(s,e))
            # NOTE(review): no +1 here, unlike `forward[-1]+1` above, so the last
            # frame of the backward event also starts the next stretch - verify.
            s = backward[-1]
        else:
            # backward event comes first: mirror image of the branch above
            e = backward[0]
            non_transition_ind.append(np.arange(s,e))
            s = backward[-1]+1

            if s < forward[0]:
                e = forward[0]
                non_transition_ind.append(np.arange(s,e))
            # NOTE(review): same missing +1 as in the other branch - verify.
            s = forward[-1]

    # Remaining frames up to the last one that still has a lagged partner
    non_transition_ind.append(np.arange(s,traj_whole[0].shape[0]-tau))


    non_transition_ind = np.concatenate(non_transition_ind)
    test_for = np.concatenate(transition_forward_ind)
    test_back = np.concatenate(transition_backward_ind)
    print(test_for.shape, test_back.shape, non_transition_ind.shape)
    # Visual check: eigenfunction with the three frame classes drawn on top
    plt.plot(eigfunc,'.')
    plt.plot(non_transition_ind, eigfunc[non_transition_ind], '.')
    plt.plot(test_for, eigfunc[test_for], '.')
    plt.plot(test_back, eigfunc[test_back], '.')
    plt.show()


    # Build the train / validation / test frame indices.
    # Non-transition frames are shuffled and cut into three parts of (almost)
    # equal size; the test part also receives the remainder of the division.
    ind_train = []
    ind_valid = []
    ind_test = []
    non_length = non_transition_ind.shape[0]//3
    np.random.shuffle(non_transition_ind)
    ind_train.append(non_transition_ind[:non_length])
    ind_valid.append(non_transition_ind[non_length:2*non_length])
    ind_test.append(non_transition_ind[2*non_length:])
    # add folding events with a probability of p_for, which influences how much less often a folding
    # will occur in the fake trajectory
    p_for = 0.25

    # number of folding events each of the three sets receives
    nr_for = int(p_for*len(transition_forward_ind)//3)
    print('Number of transitions: {}'.format(nr_for))
    ind_trajs_temp = np.arange(len(transition_forward_ind))
    np.random.shuffle(ind_trajs_temp)
    # disjoint events for the three sets, taken from the shuffled order
    for i in range(nr_for):
        ind_train.append(transition_forward_ind[ind_trajs_temp[i]])

        ind_valid.append(transition_forward_ind[ind_trajs_temp[i+nr_for]])

        ind_test.append(transition_forward_ind[ind_trajs_temp[i+2*nr_for]])
    # the same for the unfolding
    p_back = 0.25

    nr_back = int(p_back*len(transition_backward_ind)//3)
    # report the number of unfolding events (previously printed nr_for again)
    print('Number of transitions: {}'.format(nr_back))
    ind_trajs_temp = np.arange(len(transition_backward_ind))
    np.random.shuffle(ind_trajs_temp)
    for i in range(nr_back):
        ind_train.append(transition_backward_ind[ind_trajs_temp[i]])

        ind_valid.append(transition_backward_ind[ind_trajs_temp[i+nr_back]])

        ind_test.append(transition_backward_ind[ind_trajs_temp[i+2*nr_back]])

    ind_train = np.concatenate(ind_train)
    ind_valid = np.concatenate(ind_valid)
    ind_test = np.concatenate(ind_test)
    # Visual check of which frames ended up in which set
    plt.plot(ind_train, eigfunc[ind_train], '.')
    plt.plot(ind_valid, eigfunc[ind_valid], '.')
    plt.plot(ind_test, eigfunc[ind_test], '.')
    plt.show()
    
    # Randomise the frame order inside each set
    np.random.shuffle(ind_train)
    np.random.shuffle(ind_valid)
    np.random.shuffle(ind_test)
    def get_data_for_tau_corrupt(single_traj, obs_value_time, tau, ind_train, ind_valid, ind_test):
        """Select time-lagged frame pairs and observable values for the three index sets.

        The trajectory is split into instantaneous frames ``single_traj[:-tau]``
        and lagged frames ``single_traj[tau:]``; every observable array is cut
        to ``[:-tau]`` as well. The index arrays address these shortened arrays.
        All returned arrays are cast to float32.

        Returns
        -------
        X1_train, X2_train, obs_train, X1_vali, X2_vali, obs_valid,
        X1_test, X2_test, obs_test
            ``X1_*``/``X2_*`` are arrays, ``obs_*`` are lists with one array
            per observable.
        """
        instantaneous = single_traj[:-tau]
        lagged = single_traj[tau:]
        obs_trimmed = [ob[:-tau] for ob in obs_value_time]

        def _select(frames):
            # Frame pairs and observable values for one index set
            first = instantaneous[frames].astype('float32')
            second = lagged[frames].astype('float32')
            obs = [ob[frames].astype('float32') for ob in obs_trimmed]
            return first, second, obs

        X1_train, X2_train, obs_train = _select(ind_train)
        X1_vali, X2_vali, obs_valid = _select(ind_valid)
        X1_test, X2_test, obs_test = _select(ind_test)

        return X1_train, X2_train, obs_train, X1_vali, X2_vali, obs_valid, X1_test, X2_test, obs_test


    X1_train_cor, X2_train_cor, obs_train, X1_vali_cor, X2_vali_cor, obs_valid, X1_test_cor, X2_test_cor, obs_test = get_data_for_tau_corrupt(traj_whole_new[0], obs_exp_value, tau, ind_train, ind_valid, ind_test)

In [None]:
print('The true eigenvalue of the folding process: {:.3}'.format(true_value))

In [None]:
# The indeces for the test trajectory for evaluating the performance on
frames_test = ind_test

### Get data for training chi

In [None]:
if data_corrupt == 'barrier':     
    def get_data_for_tau(single_traj, tau, ind_train, ind_valid, ind_test):
        """Select time-lagged frame pairs for the train, validation and test index sets.

        Instantaneous frames are ``single_traj[:-tau]`` and lagged frames are
        ``single_traj[tau:]``; the index arrays address these shortened arrays.
        All returned arrays are cast to float32.

        Returns
        -------
        X1_train, X2_train, X1_vali, X2_vali, X1_test, X2_test
        """
        instantaneous = single_traj[:-tau]
        lagged = single_traj[tau:]

        def _pair(frames):
            # Frames at time t and at time t + tau for one index set
            return (instantaneous[frames].astype('float32'),
                    lagged[frames].astype('float32'))

        X1_train, X2_train = _pair(ind_train)
        X1_vali, X2_vali = _pair(ind_valid)
        X1_test, X2_test = _pair(ind_test)

        return X1_train, X2_train, X1_vali, X2_vali, X1_test, X2_test


    X1_train, X2_train, X1_vali, X2_vali, X1_test, X2_test = get_data_for_tau(traj_whole_new[0], tau_chi, ind_train, ind_valid, ind_test)

In [None]:
# Wrap the numpy arrays as torch tensors; the training helpers read the
# tensor_* names as globals.
tensor_train_X1 = torch.Tensor(X1_train)
tensor_train_X2 = torch.Tensor(X2_train) # transform to torch tensor
tensor_valid_X1 = torch.Tensor(X1_vali)
tensor_valid_X2 = torch.Tensor(X2_vali)
tensor_test_X1 = torch.Tensor(X1_test)
tensor_test_X2 = torch.Tensor(X2_test)

trainset = data.TensorDataset(tensor_train_X1, tensor_train_X2) # create your dataset
trainloader = data.DataLoader(trainset, batch_size=batch_size,
                              shuffle=True, num_workers=2)

# Same data with a fixed batch size of 300000 frames (used by train_for_rev)
trainloader_full = data.DataLoader(trainset, batch_size=300000,
                              shuffle=True, num_workers=2)

# NOTE: despite the names, testset/testloader wrap the validation tensors
testset = data.TensorDataset(tensor_valid_X1, tensor_valid_X2) # create your dataset
testloader = data.DataLoader(testset, batch_size=batch_size,
                             shuffle=True, num_workers=2)

# train_l is the loader iterated by train_for_VAMPnet
full_batch = False
if full_batch:
    train_l = trainloader_full
else:
    train_l = trainloader

### Training loops for the different models

In [None]:
def train_for_VAMPnet(runs, plot_mask_every=10, verbose=True, plot_training=True, corr=False, best_weights_flag=False):
    """Train ``Full_net`` with the global ``optimizer_vamp`` to maximise the VAMP score.

    Batches come from the global loader ``train_l``; after every epoch the score
    is evaluated on ``tensor_valid_X1`` / ``tensor_valid_X2``.

    Parameters
    ----------
    runs : int
        Number of epochs.
    plot_mask_every : int
        Unused; kept so existing calls keep working.
    verbose : bool
        Print training and validation score after every epoch.
    plot_training : bool
        Plot the score curves after training.
    corr : bool
        If True, the trace of the instantaneous covariance matrix (second
        return value of ``VAMP_score``) is added to the maximised objective
        and tracked separately.
    best_weights_flag : bool
        If True, the weights with the highest validation score (starting from a
        baseline of 0) are restored at the end.
    """
    epoch_loss = np.zeros(runs)
    epoch_loss_corr = np.zeros(runs)
    epoch_loss_valid = np.zeros(runs)
    epoch_loss_valid_corr = np.zeros(runs)
    if best_weights_flag:
        best_score=0.
        best_weights = Full_net.get_weights()
    for epoch in range(runs):  # loop over the dataset multiple times

        opt = optimizer_vamp

        running_epoch_loss = []
        running_epoch_corr = []
        for i, data_batch in enumerate(train_l, 0):
            # each batch is a pair [frames at t, frames at t + tau]
            inputs_t, inputs_tau = data_batch

            # zero the parameter gradients
            opt.zero_grad()

            chi_t = Full_net(inputs_t)
            chi_tau = Full_net(inputs_tau)

            score_list = VAMP_score(chi_t, chi_tau, corr=corr)

            if corr:
                score_curr = score_list[0]
                score_corr = score_list[1]
                loss = -score_curr - score_corr
                running_epoch_corr.append(score_corr.item())
            else:
                score_curr = score_list
                loss = - score_curr

            loss.backward()

            opt.step()

            running_epoch_loss.append(score_curr.item())

        # validation: no gradients are needed, so do not build the autograd graph
        with torch.no_grad():
            chi_t_vali = Full_net(tensor_valid_X1)
            chi_tau_vali = Full_net(tensor_valid_X2)
            loss_vali_list = VAMP_score(chi_t_vali, chi_tau_vali, corr=corr)
        # drop the references to the large activations before the next epoch
        del chi_t_vali
        del chi_tau_vali
        del chi_t
        del chi_tau
        if corr:
            loss_vali = loss_vali_list[0].detach()
            loss_vali_corr = loss_vali_list[1].detach()
            epoch_loss_valid_corr[epoch] = loss_vali_corr.item()
        else:
            loss_vali = loss_vali_list.detach()
        epoch_loss_valid[epoch] = loss_vali.item()
        if best_weights_flag:
            if epoch_loss_valid[epoch]>best_score:
                print('Better validation score, save weights')
                best_weights = Full_net.get_weights()
                best_score = epoch_loss_valid[epoch]
            
        epoch_loss[epoch] = np.mean(running_epoch_loss)
        if corr:
            # only filled when corr is tracked; np.mean of an empty list would be NaN
            epoch_loss_corr[epoch] = np.mean(running_epoch_corr)
        if verbose:
            print('Run {}, total loss: {:.3}, valid loss: {:.3}'.format(epoch+1, epoch_loss[epoch], epoch_loss_valid[epoch]))

    if plot_training:
        plt.plot(np.arange(1,runs+1), epoch_loss, label='VAMP_loss')
        plt.plot(np.arange(1,runs+1), epoch_loss_valid, label='VAMP_loss_valid')
        plt.legend()
        plt.show()
        if corr:
            plt.plot(np.arange(1,runs+1), epoch_loss_corr, label='Corr')
            # validation curve (previously the training curve was plotted twice)
            plt.plot(np.arange(1,runs+1), epoch_loss_valid_corr, label='Corr_valid')
            plt.legend()
            plt.show()
    if best_weights_flag:
        print('Set best weights')
        Full_net.set_weights(best_weights)

In [None]:
# Put the parameters to the value before we loaded the true model
Full_net.set_weights(weights_dict)

### Pretrain a VAMPnet which will be used as initialization for both models (with and without experimental observables)
This model only sees the manipulated data

In [None]:
# Phase 1: 30 epochs with corr=True, i.e. the trace of the instantaneous
# covariance matrix is added to the maximised objective.
optimizer_vamp = optim.Adam(Full_net.get_params_vamp(), lr=learning_rate)
train_for_VAMPnet(30, corr=True)

# Phase 2: 30 epochs on the plain VAMP score with a freshly created Adam optimizer.
optimizer_vamp = optim.Adam(Full_net.get_params_vamp(), lr=learning_rate)
train_for_VAMPnet(30)


In [None]:
train_for_VAMPnet(100)

In [None]:
weights_dict_chi = Full_net.get_weights()

### Helper functions to train for the reversible DMSM

In [None]:
def train_for_S(runs=100, verbose=True, plot_training=True):
    """Optimise with the global ``optimizer_rev_S`` while the u-layer outputs stay fixed.

    The outputs of ``Full_net`` and of ``Full_net.u_layers[0]`` are computed
    once for the training and the validation tensors and detached, so that
    only what ``vampe_loss_rev_only_S`` touches receives gradients.

    Parameters
    ----------
    runs : int
        Number of full-batch optimisation steps.
    verbose : bool
        Print training and validation score after every step.
    plot_training : bool
        Plot both score curves at the end.
    """
    def _frozen_u_outputs(frames_t, frames_tau):
        # Forward pass up to the u-layer; keep the first five outputs, detached
        chi_now = Full_net(frames_t)
        chi_lagged = Full_net(frames_tau)
        v, C_00, C_11, C_01, Sigma, _, _ = Full_net.u_layers[0](chi_now, chi_lagged)
        return tuple(t.detach() for t in (v, C_00, C_11, C_01, Sigma))

    train_stats = _frozen_u_outputs(tensor_train_X1, tensor_train_X2)
    valid_stats = _frozen_u_outputs(tensor_valid_X1, tensor_valid_X2)

    train_curve = np.zeros(runs)
    valid_curve = np.zeros(runs)
    opt = optimizer_rev_S
    for step in range(runs):
        opt.zero_grad()

        # the returned value is maximised, hence the sign flip for the loss
        score_curr = vampe_loss_rev_only_S(*train_stats)
        loss = -score_curr
        loss.backward()
        opt.step()

        train_curve[step] = -loss.item()
        valid_curve[step] = vampe_loss_rev_only_S(*valid_stats).item()
        if verbose:
            print('Run {}, total loss: {:.3}, valid loss: {:.3}'.format(step+1, train_curve[step], valid_curve[step]))

    if plot_training:
        steps_axis = np.arange(1, runs+1)
        plt.plot(steps_axis, train_curve, label='VAMP_loss')
        plt.plot(steps_axis, valid_curve, label='VAMP_loss_valid')
        plt.legend()
        plt.show()

In [None]:
def train_for_u_S(runs=100, verbose=True, plot_training=True):
    """Optimise with the global ``optimizer_rev_u_S`` on fixed network outputs.

    The outputs of ``Full_net`` for the training and the validation tensors
    are computed once and detached; each step then evaluates
    ``vampe_loss_rev`` on them.

    Parameters
    ----------
    runs : int
        Number of full-batch optimisation steps.
    verbose : bool
        Print training and validation score after every step.
    plot_training : bool
        Plot both score curves at the end.
    """
    # Fixed state assignments (no gradient flows back into the chi network)
    chi_train = Full_net(tensor_train_X1).detach()
    chi_train_lag = Full_net(tensor_train_X2).detach()
    chi_val = Full_net(tensor_valid_X1).detach()
    chi_val_lag = Full_net(tensor_valid_X2).detach()

    train_curve = np.zeros(runs)
    valid_curve = np.zeros(runs)
    opt = optimizer_rev_u_S
    for step in range(runs):
        opt.zero_grad()

        # the returned value is maximised, hence the sign flip for the loss
        score_curr = vampe_loss_rev(chi_train, chi_train_lag)
        loss = -score_curr
        loss.backward()
        opt.step()

        score_curr_valid = vampe_loss_rev(chi_val, chi_val_lag)
        train_curve[step] = -loss.item()
        valid_curve[step] = score_curr_valid.item()
        if verbose:
            print('Run {}, total loss: {:.3}, valid loss: {:.3}'.format(step+1, train_curve[step], valid_curve[step]))

    if plot_training:
        steps_axis = np.arange(1, runs+1)
        plt.plot(steps_axis, train_curve, label='VAMP_loss')
        plt.plot(steps_axis, valid_curve, label='VAMP_loss_valid')
        plt.legend()
        plt.show()

### Train a revDMSM without observable knowledge to compare with

In [None]:
Full_net.set_weights(weights_dict_chi)

In [None]:
# Outputs of the pretrained network on the training pairs, as numpy arrays
chi_t = Full_net.forward(tensor_train_X1).detach().numpy()
chi_tau = Full_net.forward(tensor_train_X2).detach().numpy()
# Reference matrix of the plain VAMPnet, compared later against the reversible model.
# estimate_koopman_op is defined elsewhere; the meaning of the second argument (0)
# when a pair of arrays is passed is not visible here - TODO confirm.
K_vamp = estimate_koopman_op([chi_t, chi_tau], 0)

In [None]:
# Fresh Adam optimizers for the reversible part.
# optimizer_rev_u_S is not re-created in this cell.
if Full_net.reversible:
    optimizer_rev = optim.Adam(Full_net.get_params_rev(), lr=learning_rate/10)
    
    optimizer_rev_S = optim.Adam(Full_net.get_params_rev(all=False, u_flag=False), lr=learning_rate*10)
    optimizer_rev_u = optim.Adam(Full_net.get_params_rev(all=False, S_flag=False), lr=learning_rate*10)

In [None]:
# First initialize u and S with the estimates of the VAMPnet, then train for S
# (3000 silent full-batch steps; only the final plot is shown)
if Full_net.reversible:
    Full_net.set_rev_var(S=True)
    train_for_S(runs=3000, verbose=False)

In [None]:
weights_after_S = Full_net.get_weights()

In [None]:
def train_for_rev(runs, plot_mask_every=10, verbose=True, plot_training=True):
    """Train with the global ``optimizer_rev`` on batches from ``trainloader_full``.

    Unlike ``train_for_S`` and ``train_for_u_S`` the network outputs are
    recomputed for every batch, so gradients also reach ``Full_net`` itself
    (as far as the optimizer holds those parameters).

    Parameters
    ----------
    runs : int
        Number of epochs.
    plot_mask_every : int
        Unused; kept so existing calls keep working.
    verbose : bool
        Print training and validation score after every epoch.
    plot_training : bool
        Plot both score curves at the end.
    """
    train_curve = np.zeros(runs)
    valid_curve = np.zeros(runs)

    for epoch in range(runs):
        opt = optimizer_rev
        batch_scores = []

        for inputs_t, inputs_tau in trainloader_full:
            opt.zero_grad()

            chi_now = Full_net(inputs_t)
            chi_lagged = Full_net(inputs_tau)
            score_curr = vampe_loss_rev(chi_now, chi_lagged)

            # the returned value is maximised, hence the sign flip for the loss
            loss = -score_curr
            loss.backward()
            opt.step()

            batch_scores.append(score_curr.item())

        # validation on the complete validation set
        chi_val = Full_net(tensor_valid_X1)
        chi_val_lag = Full_net(tensor_valid_X2)
        score_valid = vampe_loss_rev(chi_val, chi_val_lag)

        train_curve[epoch] = np.mean(batch_scores)
        valid_curve[epoch] = score_valid.item()
        if verbose:
            print('Run {}, total loss: {:.3}, valid loss: {:.3}'.format(epoch+1, train_curve[epoch], valid_curve[epoch]))

    if plot_training:
        epochs_axis = np.arange(1, runs+1)
        plt.plot(epochs_axis, train_curve, label='VAMP_loss')
        plt.plot(epochs_axis, valid_curve, label='VAMP_loss_valid')
        plt.legend()
        plt.show()

In [None]:
# Re-create all optimizers of the reversible part before the recursive training.
# The learning rate of optimizer_rev is lr/100 here and optimizer_rev_u uses lr*1.
if Full_net.reversible:
    optimizer_rev = optim.Adam(Full_net.get_params_rev(), lr=learning_rate/100)
    optimizer_rev_u_S = optim.Adam(Full_net.get_params_rev(all=False), lr=learning_rate)
    optimizer_rev_S = optim.Adam(Full_net.get_params_rev(all=False, u_flag=False), lr=learning_rate*10)
    optimizer_rev_u = optim.Adam(Full_net.get_params_rev(all=False, S_flag=False), lr=learning_rate*1)

In [None]:
weights_before_rec = Full_net.get_weights()

### Training loop until the validation score is converged

In [None]:
def recursive_opt_rev(verbose, plot_training, reset=True):
    """Repeat S -> (u, S) -> full training rounds until the validation score drops.

    One round runs ``train_for_S`` (1000 steps), ``train_for_u_S`` (1000 steps)
    and ``train_for_rev`` (10 epochs). The weights are stored whenever the
    validation score strictly improves, and the loop ends as soon as a round
    scores below the best value. The best weights are restored and returned.

    Parameters
    ----------
    verbose, plot_training : bool
        Forwarded to the three training helpers.
    reset : bool
        If True, ``Full_net.set_rev_var(S=False)`` is called before every round.
    """
    def _validation_score():
        chi_val = Full_net(tensor_valid_X1)
        chi_val_lag = Full_net(tensor_valid_X2)
        return vampe_loss_rev(chi_val, chi_val_lag).item()

    best_score = _validation_score()
    best_weights = Full_net.get_weights()
    current = best_score

    while current >= best_score:
        if reset:
            Full_net.set_rev_var(S=False)

        schedule = ((train_for_S, 1000), (train_for_u_S, 1000), (train_for_rev, 10))
        for trainer, n_runs in schedule:
            trainer(runs=n_runs, verbose=verbose, plot_training=plot_training)

        current = _validation_score()

        print('Old score {}, new score {}'.format(best_score, current))
        if current > best_score:
            print('Score is better and weights are saved')
            best_score = current
            best_weights = Full_net.get_weights()

    # Roll back to the best round and hand out a fresh copy of those weights
    Full_net.set_weights(best_weights)
    return Full_net.get_weights()

In [None]:
Full_net.set_weights(weights_after_S)

In [None]:
weights_after = recursive_opt_rev(verbose=False, plot_training=True, reset=False)

In [None]:
# Re-create the optimizers once more before the second recursive training
# (optimizer_rev_u is back at lr*10 here).
if Full_net.reversible:
    optimizer_rev = optim.Adam(Full_net.get_params_rev(), lr=learning_rate/100)
    optimizer_rev_u_S = optim.Adam(Full_net.get_params_rev(all=False), lr=learning_rate)
    optimizer_rev_S = optim.Adam(Full_net.get_params_rev(all=False, u_flag=False), lr=learning_rate*10)
    optimizer_rev_u = optim.Adam(Full_net.get_params_rev(all=False, S_flag=False), lr=learning_rate*10)

### Training loop until the validation score is converged but with resetting the parameters of u

In [None]:
def recursive_opt_rev(verbose, plot_training, reset=True):
    """Repeat training rounds until the validation score stops improving.

    This redefinition runs the stages of a round in the order
    ``train_for_rev`` (10 runs), ``train_for_S`` (1000 runs) and
    ``train_for_u_S`` (100 runs).  After each round ``vampe_loss_rev`` is
    evaluated on the validation tensors; a strictly higher score stores the
    weights, and the first lower score ends the loop.

    Parameters
    ----------
    verbose, plot_training
        Forwarded unchanged to the three training helpers.
    reset : bool
        If true, ``Full_net.set_rev_var(S=False)`` is called at the start of
        every round (presumably re-initialises u but not S -- confirm against
        the definition of ``set_rev_var``).

    Returns
    -------
    The value of ``Full_net.get_weights()`` after the best stored weights have
    been loaded back into the network.
    """

    def _validation_score():
        # Forward both time-lagged validation sets and reduce to a float.
        feat_t = Full_net(tensor_valid_X1)
        feat_tau = Full_net(tensor_valid_X2)
        return vampe_loss_rev(feat_t, feat_tau).item()

    best_score = _validation_score()
    best_weights = Full_net.get_weights()
    new_score = best_score

    # An equal score keeps the loop running without storing new weights.
    while new_score >= best_score:
        if reset:
            Full_net.set_rev_var(S=False)

        train_for_rev(runs=10, verbose=verbose, plot_training=plot_training)
        train_for_S(runs=1000, verbose=verbose, plot_training=plot_training)
        train_for_u_S(runs=100, verbose=verbose, plot_training=plot_training)

        new_score = _validation_score()
        print('Old score {}, new score {}'.format(best_score, new_score))

        if new_score > best_score:
            print('Score is better and weights are saved')
            best_score, best_weights = new_score, Full_net.get_weights()

    # Roll back to the best round before handing the weights to the caller.
    Full_net.set_weights(best_weights)
    return Full_net.get_weights()

In [None]:
# Run the recursive optimisation again with the default reset=True, i.e.
# Full_net.set_rev_var(S=False) is called at the start of every round.
weights_after = recursive_opt_rev(verbose=False, plot_training=True)

In [None]:
# K of the reversible model, as returned by get_K_rev with its default arguments
K = get_K_rev()
# Compare the eigenvalues of the revDMSM result (first) with those of the
# VAMPnet result K_vamp (second)
np.linalg.eigvals(K), np.linalg.eigvals(K_vamp)

In [None]:
# Store the network weights as they are after the recursive optimisation;
# they are loaded again before u and S are trained for the larger tau.
weights_after_rec = Full_net.get_weights()

### Train u and S of the DMSM for larger tau, the data was prepared above

In [None]:
# Lag time for the larger-tau model (presumably counted in saved trajectory
# frames -- confirm against how X1_*_cor / X2_*_cor were built)
tau = 50*2*25//skip

# Convert the time-lagged train / validation / test arrays to torch tensors.
# These assignments overwrite the tensors of the same name used so far.
tensor_train_X1 = torch.Tensor(X1_train_cor)
tensor_train_X2 = torch.Tensor(X2_train_cor) # transform to torch tensor
tensor_valid_X1 = torch.Tensor(X1_vali_cor)
tensor_valid_X2 = torch.Tensor(X2_vali_cor)
tensor_test_X1 = torch.Tensor(X1_test_cor)
tensor_test_X2 = torch.Tensor(X2_test_cor)

trainset = data.TensorDataset(tensor_train_X1, tensor_train_X2) # create the training dataset
trainloader = data.DataLoader(trainset, batch_size=batch_size,
                              shuffle=True, num_workers=2)

# NOTE(review): the batch size is hard-coded to 300000; this is a single full
# batch only if the training set has at most that many pairs -- confirm.
trainloader_full = data.DataLoader(trainset, batch_size=300000,
                              shuffle=True, num_workers=2)

# Despite its name, testset / testloader is built from the validation tensors.
testset = data.TensorDataset(tensor_valid_X1, tensor_valid_X2) # create the validation dataset
testloader = data.DataLoader(testset, batch_size=batch_size,
                             shuffle=True, num_workers=2)

In [None]:
# Start from the weights obtained after the recursive optimisation.
Full_net.set_weights(weights_after_rec)
# Overwrite every parameter of the first S layer with an all-ones matrix of
# shape (output_sizes[0], output_sizes[0]); no_grad keeps autograd out of it.
with torch.no_grad():
    for param in Full_net.S_layers[0].parameters():
        param.copy_(torch.Tensor(np.ones((output_sizes[0], output_sizes[0]))))
# Overwrite every parameter of the first u layer with an all-ones row vector
# of shape (1, output_sizes[0]).
with torch.no_grad():
    for param in Full_net.u_layers[0].parameters():
        param.copy_(torch.Tensor(np.ones((1, output_sizes[0]))))
        
# Fresh optimizers, so no Adam state from the previous training is carried over.
optimizer_rev_u_S = optim.Adam(Full_net.get_params_rev(all=False), lr=learning_rate)
optimizer_rev_S = optim.Adam(Full_net.get_params_rev(all=False, u_flag=False), lr=learning_rate*10)

In [None]:
# Five rounds, each running train_for_S and then train_for_u_S for 1000 runs,
# with the training curves plotted.
for _ in range(5):
    train_for_S(runs=1000, verbose=False, plot_training=True)

    train_for_u_S(runs=1000, verbose=False, plot_training=True)

In [None]:
# K of the reversible model, evaluated on the test tensors
K_msm = get_K_rev(tensor_test_X1, tensor_test_X2)
# Network outputs chi for the training tensors, converted to numpy arrays
chi_t = Full_net.forward(tensor_train_X1).detach().numpy()
chi_tau = Full_net.forward(tensor_train_X2).detach().numpy()
# Koopman estimate from the same chi.
# NOTE(review): this uses the training tensors while K_msm above uses the test
# tensors; the meaning of the second argument (0) is not visible here.
K_vamp_msm = estimate_koopman_op([chi_t, chi_tau], 0)

In [None]:
# Compare the eigenvalues of VAMPnet and revDMSM, they should be similar
# Compare the eigenvalues of the revDMSM (K_msm, first) and of the VAMPnet
# estimate (K_vamp_msm, second); they should be similar
np.linalg.eigvals(K_msm), np.linalg.eigvals(K_vamp_msm)

In [None]:
# Store the weights of the model trained without experimental information;
# they are loaded again for the "before" evaluation further down.
weights_msm = Full_net.get_weights()

In [None]:
# Simple renaming of the data
tensor_train_X1_cor = torch.Tensor(X1_train_cor)
tensor_train_X2_cor = torch.Tensor(X2_train_cor) # transform to torch tensor
tensor_valid_X1_cor = torch.Tensor(X1_vali_cor)
tensor_valid_X2_cor = torch.Tensor(X2_vali_cor)
tensor_test_X1_cor = torch.Tensor(X1_test_cor)
tensor_test_X2_cor = torch.Tensor(X2_test_cor)

trainset_cor = data.TensorDataset(tensor_train_X1_cor, tensor_train_X2_cor) # create your datset
trainloader_cor = data.DataLoader(trainset_cor, batch_size=batch_size,
                              shuffle=True, num_workers=2)

trainloader_full_cor = data.DataLoader(trainset_cor, batch_size=X1_train_cor.shape[0],
                              shuffle=True, num_workers=2)

testset_cor = data.TensorDataset(tensor_valid_X1_cor, tensor_valid_X2_cor) # create your datset
testloader_cor = data.DataLoader(testset_cor, batch_size=X1_vali_cor.shape[0],
                             shuffle=True, num_workers=2)

### Estimate the observable values and compare to the true ones

These values should be different, since we manipulated the data and did not use any information from the true simulation yet

In [None]:
# Evaluate the model on the training data before any experimental information
# is used.
# NOTE: vampe_loss_rev_obs and obs_time_loss are defined in cells further down
# in the notebook; those cells have to be executed before this one.
chi_t = Full_net(torch.Tensor(X1_train_cor))
chi_tau = Full_net(torch.Tensor(X2_train_cor))


# With the default all_eigvals=False, vampe_loss_rev_obs returns six values
# (score, S_similar, K, folding eigenvalue, mu_t, Sigma_t). Unpacking only
# four raised a ValueError; the two trailing values are not needed here.
_, S_similar, K_test, eigval_fold, _, _ = vampe_loss_rev_obs(chi_t, chi_tau)
eigval_fold_before = eigval_fold.detach()

# With a reference value of 0, obs_time_loss returns the absolute value of the
# model estimate itself, so this list collects the estimates per observable.
exp_value_time_before = []
for obs_v in obs_train:
    obs_value_tensor = torch.Tensor(obs_v)
    exp_value_tensor = torch.Tensor(np.array([0.]))

    score_curr, mu_t, K, Sigma = vampe_loss_rev(chi_t, chi_tau, return_mu_K_Sigma=True)
    error_obs = obs_time_loss(obs_value_tensor, mu_t, chi_t, K, Sigma, exp_value_tensor)
    exp_value_time_before.append(error_obs)

In [None]:
# Display the model estimates from before the training next to the reference
# values exp_exp_value
exp_value_time_before, exp_exp_value

In [None]:
# Display the eigenvalue selected for the folding process before the training
# next to true_value, the target used for the timescale loss later on
eigval_fold_before, true_value

### Estimate model with information of some experimental observables
Here we use the expectation value of the contact being formed and the timescale of the folding process, but it can be easily adapted to the other cases presented in the paper

### Loss functions

In [None]:
def timescale_loss(obs_value, exp_value):
    """Return the summed absolute deviation of ``obs_value`` from ``exp_value``.

    Both arguments are combined element-wise, so they have to be tensors of
    broadcastable shape; the result is a scalar tensor.
    """
    deviation = torch.abs(exp_value - obs_value)
    return deviation.sum()

    
def obs_loss(obs_value, mu, exp_value):
    """Return the absolute error of a weighted observable estimate.

    The estimate is the sum over all entries of ``obs_value * mu``; the
    returned tensor is ``|exp_value - estimate|``.
    """
    weighted_obs = obs_value * mu
    return torch.abs(exp_value - torch.sum(weighted_obs))

def obs_time_loss(obs_value, mu, chi, K, Sigma, exp_value):
    """Return the absolute error of a time-lagged observable estimate.

    Parameters
    ----------
    obs_value : torch.Tensor
        Observable values, multiplied element-wise with ``mu * chi``
        (presumably one value per frame, shape (frames, 1) -- confirm).
    mu : torch.Tensor
        Weights multiplied element-wise with ``chi``.
    chi : torch.Tensor
        State assignments; the sums below run over its first dimension.
    K, Sigma : torch.Tensor
        Matrices combined as ``Sigma @ K``.
    exp_value : torch.Tensor
        Reference value the estimate is compared against.

    Returns
    -------
    torch.Tensor
        ``|ai (Sigma K) ai^T - exp_value|``, where ``ai`` is the per-state
        weighted average of ``obs_value``.
    """
    state_weight = mu * chi

    # Total weight of each state; it normalises the per-state averages below.
    state_norm = torch.sum(state_weight, dim=0, keepdim=True)
    # obs value within a state, the weighting factor needs to be normalized for each state
    ai = torch.sum(state_weight * obs_value, dim=0, keepdim=True) / state_norm
    # prob to observe an unconditional jump state i to j
    X = Sigma @ K
    a_sim = torch.matmul(ai, torch.matmul(X, ai.T))  # shape 1x1

    return torch.abs(a_sim - exp_value)

### Since we want to match the eigenvalue of a specific process, we need to define which process that is
Our approach: we pretrained a VAMPnet, from which we can identify which state is the folded state and which is the unfolded state. This helps us to identify the process which changes the most between these two states.

In [None]:
# get frames for folded and unfolded structure as extreme values of the eigenfunction of the folding process
frame1_fu, frame2_fu = np.argmax(eigfunc), np.argmin(eigfunc) 
# get the corresponding states of the new chi
pred_fu = Full_net.forward(torch.Tensor(traj_whole[0][[frame1_fu,frame2_fu]]))
state1_fu = np.argmax(pred_fu[0].detach().numpy())
state2_fu = np.argmax(pred_fu[1].detach().numpy())
def _symeig_compat(matrix):
    """Return (eigenvalues, eigenvectors) of a symmetric matrix, ascending order.

    ``torch.symeig`` is deprecated and no longer available in current PyTorch
    releases, so ``torch.linalg.eigh`` is used whenever it exists;
    ``torch.symeig`` is only kept as a fallback for old installations.
    """
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
        # UPLO='U' reads the upper triangle, like the default of torch.symeig.
        return torch.linalg.eigh(matrix, UPLO='U')
    return torch.symeig(matrix, eigenvectors=True)


def vampe_loss_rev_obs(chi_t, chi_tau, layer_id=0, return_mu=False, all_eigvals=False):
    """Evaluate the reversible model and extract the eigenvalue of the folding process.

    Parameters
    ----------
    chi_t, chi_tau : torch.Tensor
        Network outputs for the two time-lagged inputs; passed to the u layer.
    layer_id : int
        Index into ``Full_net.u_layers`` and ``Full_net.S_layers``.
    return_mu : bool
        Unused; kept so that existing calls keep working.
    all_eigvals : bool
        Selects the form of the return value, see below.

    Returns
    -------
    list
        ``[-vampe, S_similar, K, eigval_all]`` if ``all_eigvals`` is true,
        otherwise ``[-vampe, S_similar, K, eigval_fold, mu_t, Sigma_t]``, where
        ``vampe`` is the trace of the matrix returned by the S layer and
        ``eigval_fold`` is the eigenvalue of ``S_similar`` picked for the
        folding process.

    Notes
    -----
    Reads the notebook globals ``Full_net``, ``epsilon``, ``state1_fu`` and
    ``state2_fu``.
    """
    v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t = Full_net.u_layers[layer_id](chi_t, chi_tau)
    matrix, K, S = Full_net.S_layers[layer_id](v, C_00, C_11, C_01, Sigma)
    vampe = torch.trace(matrix)

    eigval_all, eigvec_all = _symeig_compat(Sigma)

    # Filter out eigenvalues below the threshold and the corresponding eigenvectors.
    # NOTE(review): torch.Tensor(epsilon) only yields the threshold value when
    # epsilon is array-like -- confirm how epsilon is defined.
    eig_th = torch.Tensor(epsilon)
    index_eig = eigval_all > eig_th
    eigval = eigval_all[index_eig]
    eigvec = eigvec_all[:, index_eig]

    # Square root and inverse square root of Sigma, rebuilt from the retained
    # eigenpairs.
    diag = torch.diag(torch.sqrt(eigval))
    diag_inv = torch.diag(1/torch.sqrt(eigval))
    Sigma_sqrt = torch.matmul(eigvec, torch.matmul(diag, eigvec.T))
    Sigma_sqrt_inv = torch.matmul(eigvec, torch.matmul(diag_inv, eigvec.T))

    # Symmetric matrix whose eigenvectors are mapped back with Sigma^(-1/2).
    S_similar = Sigma_sqrt @ S @ Sigma_sqrt

    eigval_all, eigvec_all = _symeig_compat(S_similar)
    eigvecs_K = Sigma_sqrt_inv @ eigvec_all

    # Find the process which is the folding: the eigenvector whose components
    # at the two reference states have the smallest product. The product is
    # unaffected by the arbitrary sign of an eigenvector.
    process_id = torch.argmin(eigvecs_K[state1_fu, :]*eigvecs_K[state2_fu, :]).detach()

    if all_eigvals:
        ret = [-vampe, S_similar, K, eigval_all]
    else:
        ret = [-vampe, S_similar, K, eigval_all[process_id], mu_t, Sigma_t]

    return ret

### Training function for updating u and S while matching the observable

In [None]:
def train_for_obs_timescale(exp_value_ts, obs_value, exp_value, weight_loss_ts=1, weight_loss=1, runs=100, S_flag=True, u_flag=False, verbose=False, plot_training=False):
    """Train u and/or S on fixed chi while matching a timescale and an observable.

    The network outputs for the training tensors are computed once and
    detached, so only the parameters held by the selected optimizer change.

    Parameters
    ----------
    exp_value_ts
        Target for the eigenvalue of the folding process.
    obs_value, exp_value
        Sequences of observable values and of reference tensors; only the
        first entry of each is used in the loss.
    weight_loss_ts, weight_loss
        Weights of the timescale error and of the observable error.
    runs : int
        Number of optimizer steps.
    S_flag, u_flag : bool
        Select the optimizer: ``optimizer_rev_u_S_cor`` if both are true,
        ``optimizer_rev_u_cor`` if only ``u_flag`` is true, otherwise
        ``optimizer_rev_S_cor``.
    verbose : bool
        Print the recorded value after every step.
    plot_training : bool
        Plot the recorded curves after training.
    """
    feat_t = Full_net(tensor_train_X1_cor).detach()
    feat_tau = Full_net(tensor_train_X2_cor).detach()

    loss_trace = np.zeros(runs)
    vampe_trace = np.zeros(runs)
    ts_error_trace = np.zeros(runs)
    obs_error_trace = np.zeros(runs)

    if not u_flag:
        optimizer = optimizer_rev_S_cor
    elif S_flag:
        optimizer = optimizer_rev_u_S_cor
    else:
        optimizer = optimizer_rev_u_cor

    observables = [torch.Tensor(ob) for ob in obs_value]
    targets = [ex.detach() for ex in exp_value]

    for step in range(runs):
        optimizer.zero_grad()

        score_curr, _, _, eigval_fold, mu_t, _ = vampe_loss_rev_obs(feat_t, feat_tau)
        ts_error = timescale_loss(eigval_fold, exp_value_ts)
        obs_error = obs_loss(observables[0], mu_t, targets[0])

        loss = - score_curr + weight_loss * obs_error + weight_loss_ts * ts_error
        loss.backward()
        optimizer.step()

        # The recorded "total loss" is the negative of the optimised loss.
        loss_trace[step] = -loss.item()
        vampe_trace[step] = score_curr.item()
        ts_error_trace[step] = ts_error.item()
        obs_error_trace[step] = obs_error.item()
        if verbose:
            print('Run {}, total loss: {:.3}'.format(step+1, loss_trace[step]))

    if plot_training:
        curves = (
            (loss_trace, 'Total_loss'),
            (vampe_trace, 'VAMP_loss'),
            (ts_error_trace, 'Error_loss_ts'),
            (obs_error_trace, 'Error_loss'),
        )
        # One figure per recorded quantity.
        for values, curve_label in curves:
            plt.plot(np.arange(1, runs+1), values, label=curve_label)
            plt.legend()
            plt.show()

### Training function to update all parameters while matching the observables

In [None]:
def train_for_obs_timescale_all(exp_value_ts, obs_value, exp_value, weight_loss_ts=1, weight_loss=1, runs=100, verbose=False, plot_training=False):
    """Train with ``optimizer_rev_cor`` while matching a timescale and an observable.

    Unlike ``train_for_obs_timescale``, the network outputs are recomputed in
    every step without detaching, so gradients also reach the network that
    produces chi.

    Parameters
    ----------
    exp_value_ts
        Target for the eigenvalue of the folding process.
    obs_value, exp_value
        Sequences of observable values and of reference tensors; only the
        first entry of each is used in the loss.
    weight_loss_ts, weight_loss
        Weights of the timescale error and of the observable error.
    runs : int
        Number of optimizer steps.
    verbose : bool
        Print the recorded value after every step.
    plot_training : bool
        Plot the recorded curves after training.
    """
    epoch_loss = np.zeros(runs)
    epoch_vampe = np.zeros(runs)
    epoch_error = np.zeros(runs)
    epoch_error1 = np.zeros(runs)

    opt = optimizer_rev_cor
    obs_value_tensor = [torch.Tensor(ob) for ob in obs_value]
    exp_value_tensor = [ex.detach() for ex in exp_value]

    for epoch in range(runs):
        opt.zero_grad()

        chi_t = Full_net(tensor_train_X1_cor)
        chi_tau = Full_net(tensor_train_X2_cor)
        score_curr, S_similar, K, eigval_fold, mu_t, Sigma_t = vampe_loss_rev_obs(chi_t, chi_tau)

        error_obs = timescale_loss(eigval_fold, exp_value_ts)
        error_obs1 = obs_loss(obs_value_tensor[0], mu_t, exp_value_tensor[0])

        loss = - score_curr + weight_loss * error_obs1 + weight_loss_ts * error_obs

        loss.backward()
        opt.step()

        # The recorded "total loss" is the negative of the optimised loss.
        epoch_loss[epoch] = -loss.item()
        epoch_vampe[epoch] = score_curr.item()
        epoch_error[epoch] = error_obs.item()
        epoch_error1[epoch] = error_obs1.item()
        if verbose:
            print('Run {}, total loss: {:.3}'.format(epoch+1, epoch_loss[epoch]))

    if plot_training:
        plt.plot(np.arange(1, runs+1), epoch_loss, label='Total_loss')
        plt.legend()
        plt.show()

        plt.plot(np.arange(1, runs+1), epoch_vampe, label='VAMP_loss')
        plt.legend()
        plt.show()

        # Timescale error: labelled as in train_for_obs_timescale, so it can be
        # told apart from the observable error plotted below.
        plt.plot(np.arange(1, runs+1), epoch_error, label='Error_loss_ts')
        plt.legend()
        plt.show()

        plt.plot(np.arange(1, runs+1), epoch_error1, label='Error_loss')
        plt.legend()
        plt.show()
        

### Train only for the expectation value

In [None]:
def train_for_obs(obs_value, exp_value, weight_loss=1, runs=100, S_flag=True, u_flag=False, verbose=False, plot_training=False):
    """Train u and/or S on fixed chi while matching one observable.

    The network outputs for the training tensors are computed once and
    detached, so only the parameters held by the selected optimizer change.

    Parameters
    ----------
    obs_value, exp_value
        Sequences of observable values and of reference tensors; only the
        first entry of each is used in the loss.
    weight_loss
        Weight of the observable error.
    runs : int
        Number of optimizer steps.
    S_flag, u_flag : bool
        Select the optimizer: ``optimizer_rev_u_S_cor`` if both are true,
        ``optimizer_rev_u_cor`` if only ``u_flag`` is true, otherwise
        ``optimizer_rev_S_cor``.
    verbose : bool
        Print the recorded value after every step.
    plot_training : bool
        Plot the recorded curves after training.
    """
    feat_t = Full_net(tensor_train_X1_cor).detach()
    feat_tau = Full_net(tensor_train_X2_cor).detach()

    loss_trace = np.zeros(runs)
    vampe_trace = np.zeros(runs)
    obs_error_trace = np.zeros(runs)

    if not u_flag:
        optimizer = optimizer_rev_S_cor
    elif S_flag:
        optimizer = optimizer_rev_u_S_cor
    else:
        optimizer = optimizer_rev_u_cor

    observables = [torch.Tensor(ob) for ob in obs_value]
    targets = [ex.detach() for ex in exp_value]

    for step in range(runs):
        optimizer.zero_grad()

        score_curr, mu_t = vampe_loss_rev(feat_t, feat_tau, return_mu=True)
        obs_error = obs_loss(observables[0], mu_t, targets[0])

        loss = - score_curr + weight_loss * obs_error
        loss.backward()
        optimizer.step()

        # The recorded "total loss" is the negative of the optimised loss.
        loss_trace[step] = -loss.item()
        vampe_trace[step] = score_curr.item()
        obs_error_trace[step] = obs_error.item()
        if verbose:
            print('Run {}, total loss: {:.3}'.format(step+1, loss_trace[step]))

    if plot_training:
        curves = (
            (loss_trace, 'Total_loss'),
            (vampe_trace, 'VAMP_loss'),
            (obs_error_trace, 'Error_loss'),
        )
        # One figure per recorded quantity.
        for values, curve_label in curves:
            plt.plot(np.arange(1, runs+1), values, label=curve_label)
            plt.legend()
            plt.show()

In [None]:
def train_for_obs_all(obs_value, exp_value, weight_loss=1, runs=100, verbose=False, plot_training=False):
    """Train with ``optimizer_rev_cor`` while matching one observable.

    The network outputs are recomputed in every step without detaching, so
    gradients also reach the network that produces chi.

    Parameters
    ----------
    obs_value, exp_value
        Sequences of observable values and of reference tensors; only the
        first entry of each is used in the loss.
    weight_loss
        Weight of the observable error.
    runs : int
        Number of optimizer steps.
    verbose : bool
        Print the recorded value after every step.
    plot_training : bool
        Plot the recorded curves after training.
    """
    epoch_loss = np.zeros(runs)
    epoch_vampe = np.zeros(runs)
    epoch_error1 = np.zeros(runs)
    opt = optimizer_rev_cor

    obs_value_tensor = [torch.Tensor(ob) for ob in obs_value]
    exp_value_tensor = [ex.detach() for ex in exp_value]

    for epoch in range(runs):
        opt.zero_grad()

        chi_t = Full_net(tensor_train_X1_cor)
        chi_tau = Full_net(tensor_train_X2_cor)
        score_curr, mu_t = vampe_loss_rev(chi_t, chi_tau, return_mu=True)

        error_obs1 = obs_loss(obs_value_tensor[0], mu_t, exp_value_tensor[0])

        loss = - score_curr + weight_loss * error_obs1

        loss.backward()
        opt.step()

        # The recorded "total loss" is the negative of the optimised loss.
        epoch_loss[epoch] = -loss.item()
        epoch_vampe[epoch] = score_curr.item()
        epoch_error1[epoch] = error_obs1.item()
        if verbose:
            print('Run {}, total loss: {:.3}'.format(epoch+1, epoch_loss[epoch]))

    if plot_training:
        plt.plot(np.arange(1, runs+1), epoch_loss, label='Total_loss')
        plt.legend()
        plt.show()

        plt.plot(np.arange(1, runs+1), epoch_vampe, label='VAMP_loss')
        plt.legend()
        plt.show()

        # Only one observable error is recorded; the former plot of a second
        # error array referenced an undefined name and raised a NameError.
        plt.plot(np.arange(1, runs+1), epoch_error1, label='Error_loss')
        plt.legend()
        plt.show()

In [None]:
def eval_obs(tensor_X1, tensor_X2, exp_value_ts, obs_value, exp_value, weight_loss_ts=1., weight_loss=1.):
    """Return the negative combined loss on the given data as a float.

    The combined loss is the same expression the training functions minimise:
    ``-score + weight_loss * observable error + weight_loss_ts * timescale
    error``.  Its negative is returned, so larger values are better.

    Parameters
    ----------
    tensor_X1, tensor_X2
        Time-lagged input tensors fed through ``Full_net``.
    exp_value_ts
        Target for the eigenvalue of the folding process.
    obs_value, exp_value
        Sequences of observable values and of reference tensors; only the
        first entry of each is used.
    weight_loss_ts, weight_loss
        Weights of the timescale error and of the observable error.
    """
    observables = [torch.Tensor(ob) for ob in obs_value]
    targets = [ex.detach() for ex in exp_value]

    feat_t = Full_net(tensor_X1)
    feat_tau = Full_net(tensor_X2)

    # Score and observable error from the plain reversible loss.
    vamp_score, stat_weights = vampe_loss_rev(feat_t, feat_tau, return_mu=True)
    obs_error = obs_loss(observables[0], stat_weights, targets[0])

    # The eigenvalue of the folding process comes from the observable-aware loss.
    _, _, _, fold_eigval, _, _ = vampe_loss_rev_obs(feat_t, feat_tau)
    ts_error = timescale_loss(fold_eigval.detach(), exp_value_ts)

    total = - vamp_score + weight_loss * obs_error + weight_loss_ts * ts_error
    return -total.item()


In [None]:
# Load the weights of the model trained without experimental information, so
# the "before" quantities below refer to that model.
Full_net.set_weights(weights_msm)

In [None]:
# State probabilities on the test data before the training with observables
# (as returned by plot_mu; they are plotted by plot_prob_states further down)
prob_states_before = plot_mu(tensor_test_X1_cor, tensor_test_X2_cor, frames_test)

In [None]:
# Network outputs for the time-lagged test data
chi_t_test = Full_net(torch.Tensor(X1_test_cor))
chi_tau_test = Full_net(torch.Tensor(X2_test_cor))
# With a reference value of 0, obs_loss returns the absolute value of the
# model estimate itself, so this collects the estimate of every test observable.
exp_value_time_before = []
for obs_v in obs_test:
    obs_value_tensor = torch.Tensor(obs_v)
    exp_value_tensor = torch.Tensor(np.array([0.]))

    # NOTE(review): this call does not depend on obs_v and could be made once
    # before the loop.
    score_curr, mu_t = vampe_loss_rev(chi_t_test, chi_tau_test, return_mu=True)
    error_obs = obs_loss(obs_value_tensor, mu_t, exp_value_tensor)
    exp_value_time_before.append(error_obs.detach().numpy())
# Collapse the list into a numpy array without singleton dimensions
exp_value_time_before = np.array(exp_value_time_before).squeeze()

### Now train also for the observable

In [None]:
# Switches for the training cells below:
# change_S gates the re-initialisation of S and the u/S training loops
change_S = True
# retrain_chi (together with change_S) gates the "Train everything" loop
retrain_chi = True
# reset gates the re-initialisation of u (and of S, if change_S) to ones
reset = True

In [None]:
# Reset the weights
Full_net.set_weights(weights_dict_chi)

if reset:
    with torch.no_grad():
        for param in Full_net.u_layers[0].parameters():
            param.copy_(torch.Tensor(np.ones((1, output_sizes[0]))))
    
                
                
optimizer_rev_cor = optim.Adam(Full_net.get_params_rev(), lr=learning_rate)
optimizer_rev_u_S_cor = optim.Adam(Full_net.get_params_rev(all=False), lr=learning_rate)
optimizer_rev_S_cor = optim.Adam(Full_net.get_params_rev(all=False, u_flag=False), lr=learning_rate*1)
optimizer_rev_u_cor = optim.Adam(Full_net.get_params_rev(all=False, S_flag=False), lr=learning_rate*1)

In [None]:
# how much should the exp values be enforced
weight_loss_list= np.linspace(0.1,3,10)
weight_loss = 10.
weight_loss_ts = 10.#10.

In [None]:
# If S is going to be trained and a reset is requested, overwrite every
# parameter of the first S layer with an all-ones matrix of shape
# (output_sizes[0], output_sizes[0]).
if change_S:
    if reset:
        with torch.no_grad():
            for param in Full_net.S_layers[0].parameters():
                param.copy_(torch.Tensor(np.ones((output_sizes[0], output_sizes[0]))))

### First train u and S

In [None]:
# Fixed starting thresholds for the best scores seen so far.
score_best_valid = -5#eval_obs(tensor_valid_X1_cor, tensor_valid_X2_cor, true_value, weight_loss=weight_loss)
score_best_train = -5#eval_obs(tensor_train_X1_cor, tensor_train_X2_cor, true_value, weight_loss=weight_loss)
# Start the "best" snapshots from the current weights, so that they exist even
# if change_S is False or no round beats the thresholds above; otherwise the
# restore step after this loop fails with a NameError.
weights_temp_best_valid = Full_net.get_weights()
weights_temp_best_train = Full_net.get_weights()
if change_S:
    for _ in range(150):
        # 100 optimizer steps on u and S with chi kept fixed
        train_for_obs_timescale(true_value, obs_train, exp_exp_value, weight_loss=weight_loss, weight_loss_ts=weight_loss_ts, runs=100, S_flag=True, u_flag=True, verbose=False, plot_training=False)
        score_temp_valid = eval_obs(tensor_valid_X1_cor, tensor_valid_X2_cor, true_value, obs_valid, exp_exp_value, weight_loss_ts=weight_loss_ts, weight_loss=weight_loss)
        score_temp_train = eval_obs(tensor_train_X1_cor, tensor_train_X2_cor, true_value, obs_train, exp_exp_value, weight_loss=weight_loss, weight_loss_ts=weight_loss_ts)

        # Print the eigenvalues of the current model next to those of K_exp
        K_sim_before3 = get_K_rev(tensor_train_X1_cor, tensor_train_X2_cor)
        print(np.linalg.eigvals(K_sim_before3), np.linalg.eigvals(K_exp) )
        print('The score for training is: {:.5}, and for valid: {:.5}'.format(score_temp_train, score_temp_valid))
        if score_temp_valid > score_best_valid:
            print('Better validation score, save weights')
            score_best_valid = score_temp_valid
            weights_temp_best_valid = Full_net.get_weights()

        if score_temp_train > score_best_train:
            print('Better training score, save weights')
            score_best_train = score_temp_train
            weights_temp_best_train = Full_net.get_weights()

In [None]:
# Continue from the weights with the best validation score so far.
Full_net.set_weights(weights_temp_best_valid)

### Train everything

In [None]:
# Alternate between training with optimizer_rev_cor (10 steps) and training
# only u and S on fixed chi (100 steps). After each of the two phases the
# model is scored on the validation and training data, and the weights are
# stored whenever a score beats the best one seen so far (the best scores
# carry over from the previous training cell).
if (change_S and retrain_chi):
    for c in range(150):
       
        # Phase 1: chi is recomputed in every step, so the whole network is trained
        train_for_obs_timescale_all(true_value, obs_train, exp_exp_value, weight_loss_ts=weight_loss_ts, weight_loss=weight_loss, runs=10, verbose=False, plot_training=False)
        score_temp_valid = eval_obs(tensor_valid_X1_cor, tensor_valid_X2_cor, true_value, obs_valid, exp_exp_value, weight_loss_ts=weight_loss_ts, weight_loss=weight_loss)
        score_temp_train = eval_obs(tensor_train_X1_cor, tensor_train_X2_cor, true_value, obs_train, exp_exp_value, weight_loss=weight_loss, weight_loss_ts=weight_loss_ts)
        print('The score for training is: {:.5}, and for valid: {:.5}'.format(score_temp_train, score_temp_valid))
        if score_temp_valid > score_best_valid:
            print('Better validation score, save weights')
            score_best_valid = score_temp_valid
            weights_temp_best_valid = Full_net.get_weights()
            # Show the state probabilities on the test data next to the true ones
            print(plot_mu(tensor_test_X1_cor, tensor_test_X2_cor, frames_test), prob_states_true)
        if score_temp_train > score_best_train:
            print('Better training score, save weights')
            score_best_train = score_temp_train
            weights_temp_best_train = Full_net.get_weights()
        
        # Phase 2: only u and S are trained, chi is kept fixed
        train_for_obs_timescale(true_value, obs_train, exp_exp_value, weight_loss_ts=weight_loss_ts, weight_loss=weight_loss, runs=100, S_flag=True, u_flag=True, verbose=False, plot_training=False)

        # Print the eigenvalues of the current model next to those of K_exp
        K_sim_before3 = get_K_rev(tensor_train_X1_cor, tensor_train_X2_cor)
        print(np.linalg.eigvals(K_sim_before3), np.linalg.eigvals(K_exp) )
        score_temp_valid = eval_obs(tensor_valid_X1_cor, tensor_valid_X2_cor, true_value, obs_valid, exp_exp_value, weight_loss_ts=weight_loss_ts, weight_loss=weight_loss)
        score_temp_train = eval_obs(tensor_train_X1_cor, tensor_train_X2_cor, true_value, obs_train, exp_exp_value, weight_loss=weight_loss, weight_loss_ts=weight_loss_ts)
        print('The score for training is: {:.5}, and for valid: {:.5}'.format(score_temp_train, score_temp_valid))
        if score_temp_valid > score_best_valid:
            print('Better validation score, save weights')
            score_best_valid = score_temp_valid
            weights_temp_best_valid = Full_net.get_weights()
            print(plot_mu(tensor_test_X1_cor, tensor_test_X2_cor, frames_test), prob_states_true)
        if score_temp_train > score_best_train:
            print('Better training score, save weights')
            score_best_train = score_temp_train
            weights_temp_best_train = Full_net.get_weights()
            
        
            
        

In [None]:
# Continue from the weights with the best validation score so far.
Full_net.set_weights(weights_temp_best_valid)

### Final training for only u and S

In [None]:
# Final rounds in which only u and S are trained on fixed chi (100 steps per
# round). The best-score bookkeeping continues from the previous cells.
for _ in range(100):
    if change_S:
        train_for_obs_timescale(true_value, obs_train, exp_exp_value, weight_loss_ts=weight_loss_ts, weight_loss=weight_loss, runs=100, S_flag=True, u_flag=True, verbose=False, plot_training=False)
        score_temp_valid = eval_obs(tensor_valid_X1_cor, tensor_valid_X2_cor, true_value, obs_valid, exp_exp_value, weight_loss_ts=weight_loss_ts, weight_loss=weight_loss)
        score_temp_train = eval_obs(tensor_train_X1_cor, tensor_train_X2_cor, true_value, obs_train, exp_exp_value, weight_loss=weight_loss, weight_loss_ts=weight_loss_ts)
        # Print the eigenvalues of the current model next to those of K_exp
        K_sim_before3 = get_K_rev(tensor_train_X1_cor, tensor_train_X2_cor)
        print(np.linalg.eigvals(K_sim_before3), np.linalg.eigvals(K_exp) )
        print('The score for training is: {:.5}, and for valid: {:.5}'.format(score_temp_train, score_temp_valid))
        if score_temp_valid > score_best_valid:
            print('Better validation score, save weights')
            score_best_valid = score_temp_valid
            weights_temp_best_valid = Full_net.get_weights()
            # Show the state probabilities on the test data next to the true ones
            print(plot_mu(tensor_test_X1_cor, tensor_test_X2_cor, frames_test), prob_states_true)
            
        if score_temp_train > score_best_train:
            print('Better training score, save weights')
            score_best_train = score_temp_train
            weights_temp_best_train = Full_net.get_weights()

In [None]:
# Keep the weights from the very end of the training, then switch the network
# to the weights with the best validation score for the analysis below.
weights_temp_end = Full_net.get_weights()
Full_net.set_weights(weights_temp_best_valid)
# Alternative: use the weights with the best training score instead
# Full_net.set_weights(weights_temp_best_train)

In [None]:
# Network outputs for the time-lagged test data, using the trained model
chi_t = Full_net(torch.Tensor(X1_test_cor))
chi_tau = Full_net(torch.Tensor(X2_test_cor))


# Only the eigenvalue of the folding process (fourth return value) is used below
_, S_similar, K_test, eigval_fold, _, _ = vampe_loss_rev_obs(chi_t, chi_tau)


# With a reference value of 0, obs_loss returns the absolute value of the
# model estimate itself, so this collects the estimate of every test observable.
exp_value_time_after = []
for obs_v in obs_test:
    obs_value_tensor = torch.Tensor(obs_v)
    exp_value_tensor = torch.Tensor(np.array([0.]))

    # NOTE(review): this call does not depend on obs_v and could be made once
    # before the loop.
    score_curr, mu_t = vampe_loss_rev(chi_t, chi_tau, return_mu=True)
    error_obs = obs_loss(obs_value_tensor, mu_t, exp_value_tensor)
    exp_value_time_after.append(error_obs.detach().numpy())
# Collapse the list into a numpy array without singleton dimensions
exp_value_time_after = np.array(exp_value_time_after).squeeze()

print(eigval_fold)

In [None]:
# Display the target eigenvalue for comparison with eigval_fold printed above
true_value

In [None]:
# K of the trained model on the test data; display its eigenvalues (first)
# next to those of K_exp (second)
K_sim_after = get_K_rev(tensor_test_X1_cor, tensor_test_X2_cor)
np.linalg.eigvals(K_sim_after), np.linalg.eigvals(K_exp) 

In [None]:
# Estimate stationary distribution of states
prob_states_after = plot_mu(tensor_test_X1_cor, tensor_test_X2_cor, frames_test)

In [None]:
# Display the eigenvalues of K_msm, the model trained without experimental
# information
np.linalg.eigvals(K_msm)

### Define plotting functions

In [None]:
# Factor that converts a time given in lag-time steps to the unit shown on the
# implied-timescale axis (microseconds per the axis label; presumably 200 ps
# per saved frame times skip -- confirm against the trajectory data)
fac_unit = 200.*skip*1e-6
def plot_obs(ts=True):
    """Plot the non-leading processes of K_exp, K_msm and K_sim_after.

    For each matrix the eigenvalues are sorted and the largest one is dropped.
    The values of ``K_exp`` are drawn as dashed reference lines, those of
    ``K_msm`` ('Before') and ``K_sim_after`` ('After') as markers.

    Parameters
    ----------
    ts : bool
        If true, plot implied timescales ``-tau / ln|lambda| * fac_unit``;
        otherwise plot the eigenvalues themselves.

    Notes
    -----
    Reads the notebook globals ``K_exp``, ``K_msm``, ``K_sim_after``, ``tau``
    and ``fac_unit``.  The tick labels are fixed to two processes.
    """

    def _process_values(K_matrix):
        # Sort ascending and drop the largest eigenvalue.
        eigvals = np.sort(np.linalg.eigvals(K_matrix))[:-1]
        if ts:
            # The absolute value is taken for all three matrices alike, so a
            # negative or complex eigenvalue does not turn into nan / complex.
            return - tau / np.log(np.abs(eigvals)) * fac_unit
        return eigvals

    its_true = _process_values(K_exp)
    its_before = _process_values(K_msm)
    its_after = _process_values(K_sim_after)

    for i, its_i in enumerate(its_true):
        # Only the first reference line carries a legend entry.
        label = 'True' if i == 0 else ''
        plt.hlines(its_i, i-0.25, i+0.25, ls='--', colors='k', label=label, lw=2)

    plt.plot(its_before, 'v', label='Before', ms=10)
    plt.plot(its_after, 'o', label='After', ms=10)
    plt.title('Implied Timescales')
    indicator = ['I', 'II']
    plt.xticks([0, 1], indicator, fontsize=14)
    plt.xlabel('Process', fontsize=16)
    # Raw string: '\m' is not a valid escape sequence in a normal string.
    plt.ylabel(r'Implied Timescale [$\mu$s]', fontsize=16)
    plt.legend()
    plt.show()
    
    
def plot_prob_states():
    """Plot the state probabilities before and after training against the true ones.

    The entries of ``prob_states_true`` are drawn as dashed reference lines,
    ``prob_states_before`` and ``prob_states_after`` as markers, all scaled by
    100 to percent.

    Notes
    -----
    Reads the notebook globals ``prob_states_true``, ``prob_states_before``
    and ``prob_states_after``.  The tick labels are fixed to three states.
    """
    for i, pi_i in enumerate(prob_states_true):
        # Only the first reference line carries a legend entry.
        label = 'True' if i == 0 else ''
        plt.hlines(pi_i*100, i-0.25, i+0.25, ls='--', colors='k', label=label, lw=2)

    plt.plot(prob_states_before*100, 'v', label='Before', ms=10)
    plt.plot(prob_states_after*100, 'o', label='After', ms=10)
    plt.title('State Probability')
    indicator = ['I', 'II', 'III']
    plt.xticks([0, 1, 2], indicator, fontsize=14)
    plt.xlabel('State', fontsize=16)
    # Raw string: '\%' is not a valid escape sequence in a normal string.
    plt.ylabel(r'Probability [$\%$]', fontsize=16)
    plt.legend()
    plt.show()
    
    
def plot_obs_train():
    """Plot the observable estimates before and after training against the references.

    Every entry of ``exp_exp_value`` is drawn as a dashed reference line;
    ``exp_value_time_before`` and ``exp_value_time_after`` are drawn as
    markers.  All three are read from the notebook globals.  The tick labels
    are fixed to a single observable.
    """
    for position, reference in enumerate(exp_exp_value):
        # Only the first reference line carries a legend entry.
        legend_entry = 'True' if position == 0 else ''
        plt.hlines(reference.numpy(), position-0.25, position+0.25,
                   ls='--', colors='k', label=legend_entry, lw=2)

    plt.plot(exp_value_time_before, 'v', label='Before', ms=10)
    plt.plot(exp_value_time_after, 'o', label='After', ms=10)
    plt.title('Observable')
    tick_names = ['I']
    plt.xticks([0], tick_names, fontsize=14)
    plt.xlabel('Observable', fontsize=16)
    plt.ylabel('Value', fontsize=16)
    plt.legend()
    plt.show()

### Plot results

In [None]:
# Implied timescales of the non-leading processes (default ts=True)
plot_obs()
# State probabilities in percent
plot_prob_states()
# Observable estimates against their reference values
plot_obs_train()