### Notebook to construct a deep reversible Markov State Model with an attention mechanism and coarse-graining layers

The training steps include:
1. Pretraining of an ordinary VAMPnet (can be used to compare the results)
2. Initialize the values for the parameter matrix $\mathbf{S}$ and the reweighting vector $\mathbf{u}$
3. Training for $\mathbf{u}$ and $\mathbf{S}$
4. Training for $\boldsymbol{\chi}$ and $\mathbf{u}$ and $\mathbf{S}$
5. For estimating timescales, only $\mathbf{u}$ and $\mathbf{S}$ have to be retrained
6. For coarse-graining the parameters will be initialized by PCCA and then trained with the VAMP-E score 

The analysis consists of:
1. Validation of the model via implied timescale and CK tests
2. Estimating a network graph based on the estimated stationary distribution and transition matrix
3. Analysis of the eigenfunctions
4. Building a hierarchical model via the coarse-graining layers
5. Saving structures with the help of $\textit{mdtraj}$



In [None]:
# import all the relevant packages
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils import data
import torch.optim as optim
import matplotlib.gridspec as gridspec
# mdtraj will be needed to save structures, pyemma will be needed for the PCCA initialization!

In [None]:
# Select the compute device: the first CUDA GPU if torch can see one,
# otherwise the CPU. Later cells read the module-level name `device`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Show which device was selected ("cuda:0" on a CUDA machine, "cpu" otherwise).

print(device)

In [None]:
# Define where the data lies on the local machine
# Needs to be adapted for own data!!!
# test_system is substituted three times into the directory template below.
test_system = '2F4K'
# File name of the structure file (presumably the topology used when saving
# structures with mdtraj -- TODO confirm against the later cells).
pdb_system = '2f4k_villin.pdb'

# Directory of the raw trajectory files for `test_system`.
root = '/group/ag_cmb/simulation-data/DESRES-Science2011-FastProteinFolding/DESRES-Trajectory_{0}-0-protein/{0}-0-protein/{0}-0-protein'.format(test_system)

In [None]:
# Load the precomputed input features from a .npy file.
# NOTE: this path is hard-coded and independent of `root`; adapt it for own data.
output_all_files = np.load('/srv/public/andreas/data/desres/2f4k/villin_skip1.npy')
traj_whole = output_all_files

# Number of frames and number of features per frame of the first trajectory.
traj_data_points, input_size = traj_whole[0].shape
# Skip data to make the data less correlated: keep every `skip`-th frame.
skip=25
# NOTE(review): only traj_whole[0] is kept here; if the file holds several
# trajectories the others are dropped -- confirm this is intended.
traj_whole_new = [traj_whole[0][::skip]]

In [None]:
# Hyperparameter definitions, should be adapted for specific problems

# number of output nodes/states of the MSM or Koopman model, therefore also nodes of chi
# The list defines how the output will be coarse grained from first to last entry
output_sizes = [4,3,2]
# tau list for timescales estimation; the integer division by `skip` converts
# the values to steps of the strided data (here [1, 2, 4, 6, 8] for skip=25)
tau_list = [25,50,100,150,200]
number_taus = len(tau_list)
tau_list = np.array(tau_list)//skip

# Tau, how much is the timeshift of the two datasets in the default training
# tau for pretraining the vampnet usually smaller than the tau for the deepMSM
# Both values are in steps of the strided data: tau = 250//skip, tau_chi = 25//skip.
tau = 5*2*25//skip # 5, 20
tau_chi = 25//skip

# Batch size for Stochastic Gradient descent
batch_size = 10000
# A second, larger batch size (its use is not visible in this part of the notebook)
batch_size_large = 20000

# Which trajectory points percentage is used as training, validation, and rest for test
# (read as globals by get_data_for_tau)
train_ratio = 0.7
valid_ratio = 0.2

# How many hidden layers the network chi has
network_depth = 4

# Width of every layer of chi
layer_width = 100

# Learning rate used for the ADAM optimizer
learning_rate = 5e-4

# create a list with the number of nodes for each layer
nodes = [layer_width]*network_depth

# epsilon for numerical inversion of correlation matrices: eigenvalues that are
# not larger than this value are discarded (see estimate_koopman_op and _inv).
# Stored as a 0-dimensional float32 numpy array.
epsilon = np.array(1e-7).astype('float32')

In [None]:
# function for data generation

def get_data_for_tau(traj_all, tau):
    '''Builds shuffled, time-lagged training/validation/test sets.

    Every frame ``t`` of every trajectory is paired with frame ``t + tau`` of
    the same trajectory, so no pair straddles a trajectory boundary. The pairs
    are shuffled with ``np.random.shuffle`` and split according to the
    module-level ``train_ratio`` and ``valid_ratio``; the remainder is the
    test set.

    Parameters
    ----------
    traj_all: list of numpy arrays with shape [traj_timesteps, traj_dimensions]
        Trajectories to draw the pairs from

    tau: int
        Non-negative time shift in frames between the two members of a pair

    Returns
    -------
    X1_train, X2_train, X1_vali, X2_vali, X1_test, X2_test: float32 numpy arrays
        Instantaneous (X1) and time-lagged (X2) frames of each split

    length_train, length_vali: int
        Number of pairs in the training and validation split

    traj_ord, traj_ord_lag: numpy arrays
        All pairs in their original (unshuffled) time order

    Raises
    ------
    ValueError
        If tau is negative
    '''
    if tau < 0:
        raise ValueError('tau must be non-negative, got {}'.format(tau))

    # An explicit stop index is used instead of ``[:-tau]`` because ``[:-0]`` is
    # an empty slice, which would silently return no data for tau == 0.
    # Concatenating once avoids re-copying the growing array per trajectory.
    traj_ord = np.concatenate(
        [single_traj[:max(len(single_traj) - tau, 0)] for single_traj in traj_all],
        axis=0)
    traj_ord_lag = np.concatenate(
        [single_traj[tau:] for single_traj in traj_all],
        axis=0)

    length_data = traj_ord.shape[0]

    # The same permutation is applied to both arrays so pairs stay aligned
    shuffle_indexes = np.arange(length_data)
    np.random.shuffle(shuffle_indexes)

    traj = traj_ord[shuffle_indexes]
    traj_lag = traj_ord_lag[shuffle_indexes]

    length_train = int(np.floor(length_data * train_ratio))
    length_vali = int(np.floor(length_data * valid_ratio))
    end_vali = length_train + length_vali

    # Inputs for training
    X1_train = traj[:length_train].astype('float32')
    X2_train = traj_lag[:length_train].astype('float32')

    # Inputs for validation
    X1_vali = traj[length_train:end_vali].astype('float32')
    X2_vali = traj_lag[length_train:end_vali].astype('float32')

    # Inputs for test: everything that is left
    X1_test = traj[end_vali:].astype('float32')
    X2_test = traj_lag[end_vali:].astype('float32')

    return X1_train, X2_train, X1_vali, X2_vali, X1_test, X2_test, length_train, length_vali, traj_ord, traj_ord_lag




### Helper functions

In [None]:
def estimate_koopman_op(trajs, tau, force_symmetric = False):
    '''Estimates the koopman operator for a given trajectory at the lag time
        specified. The formula for the estimation is:
            K = C00 ^ -1 @ C01

    C00 is inverted through its eigendecomposition; eigenvalues that are not
    larger than the module-level ``epsilon`` are discarded, so the result is a
    truncated pseudo-inverse when C00 is (nearly) singular.

    Parameters
    ----------
    trajs: numpy array with size [traj_timesteps, traj_dimensions], or a list
        of such arrays, if tau > 0. If tau <= 0, a pair
        (data, time_lagged_data) of already aligned arrays, possibly in
        random order
        Trajectory described by the returned koopman operator

    tau: int
        Time shift at which the koopman operator is estimated

    force_symmetric: boolean, default = False
        if true, calculates the symmetrized version of K instead

    Returns
    -------
    koopman_op: numpy array with shape [traj_dimensions, traj_dimensions]
        Koopman operator estimated at timeshift tau

    '''
    # if tau larger 0, interpret trajs as either a list of trajectories
    # or a single trajectory
    # otherwise interpret trajs as a list of the data and time-lagged data
    # possibly in random order
    if tau > 0:
        if isinstance(trajs, list):
            traj = np.concatenate([t[:-tau] for t in trajs], axis = 0)
            traj_lag = np.concatenate([t[tau:] for t in trajs], axis = 0)
        else:
            traj = trajs[:-tau]
            traj_lag = trajs[tau:]
    else:
        traj = trajs[0]
        traj_lag = trajs[1]

    c_0 = traj.T @ traj
    c_tau = traj.T @ traj_lag

    # if you want to symmetrize the correlation matrices
    if force_symmetric:
        c_0 = c_0 + traj_lag.T @ traj_lag
        c_tau = c_tau + traj_lag.T @ traj

    # c_0 is symmetric, so the symmetric solver is used: it returns real
    # eigenvalues and orthonormal eigenvectors, which the reconstruction
    # V diag(1/lambda) V^T below relies on. The general solver gives neither
    # guarantee for degenerate eigenvalues.
    eigv_all, eigvec_all = np.linalg.eigh(c_0)
    include = eigv_all > epsilon
    eigv = eigv_all[include]
    eigvec = eigvec_all[:,include]
    c0_inv = eigvec @ np.diag(1/eigv) @ np.transpose(eigvec)

    koopman_op = c0_inv @ c_tau

    return koopman_op



# utility function for plotting implied timescales
# the hyperparameters allow for it to calculate errorbars and work from different input data types
def get_its(data, lags, calculate_K = True, multiple_runs = False):
    '''Computes implied timescales t_i = -tau / ln|lambda_i| for every lag.

    Parameters
    ----------
    data: trajectory data (array or list of arrays) if calculate_K is True,
        otherwise a sequence holding one Koopman matrix per entry of lags.
        With multiple_runs=True, an iterable of such objects, one per run

    lags: sequence of int
        Lag times at which the timescales are evaluated

    calculate_K: bool, default = True
        if True the Koopman matrix is estimated from data with
        estimate_koopman_op, otherwise data[t] is used for the t-th lag

    multiple_runs: bool, default = False
        if True, data is processed run by run

    Returns
    -------
    its: numpy array with shape [n_states - 1, len(lags)], or a list of such
        arrays (one per run) if multiple_runs is True. The eigenvalue of
        largest magnitude is dropped at every lag
    '''

    def get_single_its(run_data):
        # the number of states is the second dimension of the (first) array
        reference = run_data[0] if type(run_data) == list else run_data
        n_states = reference.shape[1]

        timescales = np.zeros((n_states-1, len(lags)))

        for column, tau_lag in enumerate(lags):
            koopman_op = estimate_koopman_op(run_data, tau_lag) if calculate_K else run_data[column]
            spectrum = np.linalg.eig(np.real(koopman_op))[0]
            # ascending in magnitude; discard the last (largest) eigenvalue
            magnitudes = np.sort(np.absolute(spectrum))[:-1]
            timescales[:,column] = (-tau_lag / np.log(magnitudes))

        return np.array(timescales)

    if multiple_runs:
        return [get_single_its(data_run) for data_run in data]

    return get_single_its(data)

In [None]:
def _inv(x, ret_sqrt=False):
    '''Utility function that returns the inverse of a matrix, with the
    option to return the square root of the inverse matrix.
    Parameters
    ----------
    x: numpy array with shape [m,m]
        matrix to be inverted

    ret_sqrt: bool, optional, default = False
        if True, the square root of the inverse matrix is returned instead
    Returns
    -------
    x_inv: numpy array with shape [m,m]
        inverse of the original matrix
    '''

    # Calculate eigvalues and eigvectors
    eigval_all, eigvec_all = torch.symeig(x, eigenvectors=True)

    # Filter out eigvalues below threshold and corresponding eigvectors
    eig_th = torch.Tensor(epsilon)
    index_eig = eigval_all > eig_th
#     print(index_eig)
    eigval = eigval_all[index_eig]
    eigvec = eigvec_all[:,index_eig]

    # Build the diagonal matrix with the filtered eigenvalues or square
    # root of the filtered eigenvalues according to the parameter
    if ret_sqrt:
        diag = torch.diag(torch.sqrt(1/eigval))
    else:
        diag = torch.diag(1/eigval)
#     print(diag.shape, eigvec.shape)    
    # Rebuild the square root of the inverse matrix
    x_inv = torch.matmul(eigvec, torch.matmul(diag, eigvec.T))

    return x_inv


def _prep_data(data_t, data_tau):
    '''Utility function that transorms the input data from a tensorflow - 
    viable format to a structure used by the following functions in the
    pipeline.
    Parameters
    ----------
    data: tensorflow tensor with shape [b, 2*o]
        original format of the data
    Returns
    -------
    x: tensorflow tensor with shape [o, b]
        transposed, mean-free data corresponding to the left, lag-free lobe
        of the network

    y: tensorflow tensor with shape [o, b]
        transposed, mean-free data corresponding to the right, lagged lobe
        of the network

    b: tensorflow float32
        batch size of the data

    o: int
        output size of each lobe of the network

    '''


    # Subtract the mean
    x = data_t #- torch.mean(data_t, dim=0, keepdim=True)
    y = data_tau# - torch.mean(data_tau, dim=0, keepdim=True)

    return x, y


### Definition of the classes for the model and the attention network

In [None]:
class Mask(torch.nn.Module):
    ''' Attention mask applied to residue-pair input features.

    A vector ``b`` with ``n_residues + patchsize - 1`` entries is either a free
    parameter shared by all frames (mask_const=True) or predicted for every
    frame from the input by a fully connected network (mask_const=False).
    The weight of residue ``r`` is the product of the ``patchsize`` consecutive
    entries ``b[r], ..., b[r + patchsize - 1]``, normalised with a softmax over
    the residues. The feature belonging to the residue pair ``(n1, n2)`` is
    multiplied by ``weight[n1] * weight[n2] * n_residues**2``.

    The features are assumed to be one value per residue pair ``(n1, n2)`` with
    ``n2 >= n1 + 3`` (skip_res = 3), ordered by ``n1`` first and ``n2`` second.
    The number of residues is derived from ``input_size`` under that
    assumption, so ``input_size`` has to be a triangular number.

    Parameters
    ----------
    input_size: int
        number of input features (residue pairs)

    mask_const: bool
        if True the attention does not depend on the input frame

    depth: int, default = 0
        number of hidden layers of the attention network (mask_const=False only)

    width: int, default = 100
        width of the hidden layers of the attention network (mask_const=False only)

    patchsize: int, default = 4
        number of consecutive entries of ``b`` multiplied for one residue

    sensitivity: default = 1
        stored on the module; not used by the methods of this class

    fac: bool, default = True
        if True the softmax output of the attention network is multiplied by
        the length of ``b`` before the patch product (mask_const=False only)

    noise: float, default = 0.
        if larger than 0, Gaussian noise scaled by
        ``noise * (1 - alpha / max(alpha))`` is added to the masked features
    '''
    def __init__(self, input_size, mask_const, depth=0, width=100 , patchsize=4, sensitivity=1, fac=True,
                noise=0.):
        super(Mask, self).__init__()

        skip_res = 3
        self.noise = noise
        # inverse of input_size = m*(m+1)/2 with m = n_residues - skip_res
        self.n_residues = int(-1/2 + np.sqrt(1/4+input_size*2) + skip_res)
        self.patchsize = patchsize
        self.number_weights = self.n_residues + (patchsize-1)

        # bs_per_res[i][r] is the index of the i-th entry of b used by residue r
        self.bs_per_res = [[n1+i for n1 in range(self.n_residues)] for i in range(patchsize)]

        # residue indices of every pair, in feature order
        self.residues_1 = []
        self.residues_2 = []
        for n1 in range(self.n_residues-skip_res):
            for n2 in range(n1+skip_res, self.n_residues):
                self.residues_1.append(n1)
                self.residues_2.append(n2)

        self.mask_const = mask_const

        if mask_const:
            self.alpha = torch.randn((1, self.number_weights)) * 0.5
            self.weight = torch.nn.Parameter(data=self.alpha, requires_grad=True)
        else:
            nodes = [input_size] + [width]*depth

            self.hfc = [nn.Linear(nodes[i], nodes[i+1]) for i in range(len(nodes)-1)]
            self.softmax = nn.Linear(nodes[-1], self.number_weights, bias=True)
            # registers the hidden layers as submodules
            self.layers = nn.ModuleList(self.hfc)

        self.sensitivity = sensitivity
        if fac:
            self.fac = self.number_weights
        else:
            self.fac = 1.

    def _residue_weights(self, x=None):
        ''' Returns the softmax-normalised weight of every residue with shape
        [1, n_residues] (mask_const=True) or [batch, n_residues].
        Shared by forward and get_softmax so both always agree.
        '''
        if self.mask_const:
            b_factors = self.weight
        else:
            y = x
            for layer in self.hfc:
                y = F.elu(layer(y))
            # make the factors all positive
            b_factors = F.softmax(self.softmax(y), dim=1)*self.fac

        # stack the patchsize shifted views of b and multiply them per residue
        patches = torch.cat([b_factors[None,:,indices] for indices in self.bs_per_res], dim=0)
        # take the softmax over the residues
        return F.softmax(torch.prod(patches, dim=0), dim=1)

    def forward(self, x):
        ''' Multiplies every input feature with the attention of its residue pair.

        Parameters
        ----------
        x: torch tensor with shape [batch, input_size]

        Returns
        -------
        masked_x: torch tensor with shape [batch, input_size]
        '''
        weights_for_res = self._residue_weights(x)

        weight_1 = weights_for_res[:,self.residues_1]
        weight_2 = weights_for_res[:,self.residues_2]
        alpha = weight_1 * weight_2 * self.n_residues**2
        masked_x = x * alpha

        if self.noise > 0.:
            max_attention_value = torch.max(alpha, dim=1, keepdim=True)[0].detach()
            # created on the device of the data instead of a global device
            random_numbers = torch.randn(alpha.shape, device=alpha.device, dtype=alpha.dtype) * self.noise
            # features with low attention receive the most noise
            masked_x = masked_x + (1 - alpha/max_attention_value) * random_numbers

        return masked_x

    def get_softmax(self, x=None):
        ''' Returns the residue weights used by forward.
        x is only needed if the mask depends on the input (mask_const=False).
        '''
        return self._residue_weights(x)

    def set_weights(self, weights):
        ''' Replaces the trainable vector b of a constant mask by weights.
        Only prints a message for an input dependent mask.
        '''
        if self.mask_const:
            # keep the new parameter on the device of the one it replaces
            alpha = torch.Tensor(weights).to(self.weight.device)

            self.weight = torch.nn.Parameter(data=alpha, requires_grad=True)
        else:
            print('not implemented yet. You need to define all layers in the mask')

            

            
class Coarse_grain(torch.nn.Module):
    ''' Trainable soft assignment of input_dim states to output_dim states.

    The layer holds one weight matrix with shape [input_dim, output_dim]. A
    softmax over the output dimension turns every row into a probability
    vector; state probabilities are coarse grained by multiplying them with
    this matrix. sen is a factor applied to the weights before the softmax.
    '''
    def __init__(self, input_dim, output_dim, sen=1):
        super(Coarse_grain, self).__init__()

        self.N = input_dim
        self.M = output_dim
        self.sen = sen

        self.alpha = torch.randn((self.N, self.M)) * 0.5
        self.weight = torch.nn.Parameter(data=self.alpha, requires_grad=True)

    def forward(self, x):
        ''' Maps x with shape [batch, input_dim] to shape [batch, output_dim]. '''
        return x @ self.get_softmax()

    def get_softmax(self):
        ''' Returns the row-stochastic assignment matrix. '''
        return F.softmax(self.sen * self.weight, dim=1)

    def get_cg_uS(self, chi_n, chi_tau_n, u_n, S_n, u_t_n, renorm):
        ''' Transfers u and S of the fine model to the coarse grained model.

        Parameters
        ----------
        chi_n, chi_tau_n: torch tensors with shape [batch, input_dim]
            fine state assignments of the instantaneous and time-lagged data

        u_n, u_t_n: torch tensors with shape [1, input_dim]
            fine reweighting vectors for the time-lagged and instantaneous data

        S_n: torch tensor with shape [input_dim, input_dim]
            fine matrix S

        renorm: bool
            if True and the coarse S has negative entries, the off-diagonal
            part is rescaled before the diagonal is recomputed

        Returns
        -------
        list with chi_t_m, chi_tau_m, u_m, u_t_m, S_m, mu_t, Sigma_t, K and
        the VAMP-E matrix of the coarse grained model
        '''
        n_frames = chi_n.shape[0]
        assignment = F.softmax(self.sen * self.weight, dim=1)

        chi_t_m = chi_n @ assignment
        chi_tau_m = chi_tau_n @ assignment

        # pseudo inverse of the assignment matrix from its SVD: non-zero
        # singular values are inverted, zero ones are moved to the end as they are
        left, singular, right = torch.svd(assignment)
        nonzero = singular > 0
        zero = singular <= 0
        singular_inv = torch.cat((1/singular[nonzero], singular[zero]))
        left_sorted = torch.cat((left[:,nonzero], left[:,zero]), dim=1)
        right_sorted = torch.cat((right[:,nonzero], right[:,zero]), dim=1)
        G = right_sorted @ torch.diag(singular_inv) @ left_sorted.T

        def project_u(u_fine, chi_coarse):
            # map u to the coarse states and renormalise it with the mean of chi
            u_coarse = (G @ u_fine.T).T
            chi_mean = torch.mean(chi_coarse, dim=0, keepdim=True)
            return u_coarse / torch.sum(chi_mean * u_coarse, dim=1, keepdim=True)

        u_m = project_u(u_n, chi_tau_m)
        u_t_m = project_u(u_t_n, chi_t_m)

        W1 = G @ S_n @ G.T
        corr_tau = 1./n_frames * torch.matmul(chi_tau_m.T, chi_tau_m)
        v = torch.matmul(corr_tau, u_m.T)

        def add_diagonal(off_diag):
            # the diagonal is chosen such that the rows of S fulfil S @ v = 1
            row_norm = off_diag @ v
            diagonal = (1 - torch.squeeze(row_norm)) / torch.squeeze(v)
            return row_norm, off_diag + torch.diag(diagonal)

        norm, S_m = add_diagonal(W1)
        if renorm and (S_m<0).sum()>0: # check if actually non-negativity is violated
            # make sure that the largest value of norm is < 1, using a
            # 20-norm as smooth stand-in for the maximum
            W1 = W1 / torch.sum((norm**20))**(1./20)
            norm, S_m = add_diagonal(W1)

        # estimate the VAMP-E matrix and other helpful instances
        mu = 1./n_frames * torch.matmul(chi_tau_m, u_m.T)
        Sigma = torch.matmul((chi_tau_m * mu).T, chi_tau_m)

        mu_t = 1./n_frames * torch.matmul(chi_t_m, u_t_m.T)
        Sigma_t = torch.matmul((chi_t_m * mu_t).T, chi_t_m)

        gamma = chi_tau_m * (torch.matmul(chi_tau_m, u_m.T))

        C_00 = 1./n_frames * torch.matmul(chi_t_m.T, chi_t_m)
        C_11 = 1./n_frames * torch.matmul(gamma.T, gamma)
        C_01 = 1./n_frames * torch.matmul(chi_t_m.T, gamma)

        K = S_m @ Sigma

        # VAMP-E matrix for the computation of the loss
        VampE_matrix = S_m.T @ C_00 @ S_m @ C_11 - 2*S_m.T @ C_01

        return [
            chi_t_m,
            chi_tau_m,
            u_m,
            u_t_m,
            S_m,
            mu_t,
            Sigma_t,
            K,
            VampE_matrix,
        ]

    def reset_params(self):
        ''' Overwrites the weights in place with fresh random values. '''
        with torch.no_grad():
            self.weight.copy_(torch.randn((self.N, self.M)) * 0.5)
        
class U_layer(torch.nn.Module):
    ''' Layer holding the trainable reweighting vector u.

    The raw kernel with shape [1, output_dim] starts at 1/output_dim and is
    passed through activation before it is normalised against the mean of
    the state assignments.
    '''
    def __init__(self, output_dim, activation):
        super(U_layer, self).__init__()

        self.M = output_dim

        self.alpha = torch.Tensor(1, self.M).fill_(1/self.M)
        self.u_kernel = torch.nn.Parameter(data=self.alpha, requires_grad=True)
        self.acti = activation

    def forward(self, chi_t, chi_tau):
        ''' Computes u and the matrices derived from it.

        Parameters
        ----------
        chi_t, chi_tau: torch tensors with shape [batch, output_dim]
            state assignments of the instantaneous and time-lagged data

        Returns
        -------
        list with u, u_t, v, C_00, C_11, C_01, Sigma, mu_t and Sigma_t
        '''
        n_frames = chi_t.shape[0]
        kernel_u = self.acti(self.u_kernel)

        def normalise(chi):
            # scale the kernel such that sum(mean(chi) * u) = 1
            chi_mean = torch.mean(chi, dim=0, keepdim=True)
            return kernel_u / torch.sum(chi_mean * kernel_u, dim=1, keepdim=True)

        def reweighted_cov(chi, weights):
            # frame weights mu and the covariance of chi weighted with them
            frame_weights = 1./n_frames * torch.matmul(chi, weights.T)
            return frame_weights, torch.matmul((chi * frame_weights).T, chi)

        # u is the normalized and transformed kernel of this layer
        u = normalise(chi_tau)
        u_t = normalise(chi_t)

        # note: corr_tau is the correlation matrix of the time-shifted data
        # presented in the paper at page 6, "Normalization of transition density"
        corr_tau = 1./n_frames * torch.matmul(chi_tau.T, chi_tau)
        v = torch.matmul(corr_tau, u.T)

        _, Sigma = reweighted_cov(chi_tau, u)
        mu_t, Sigma_t = reweighted_cov(chi_t, u_t)

        gamma = chi_tau * (torch.matmul(chi_tau, u.T))

        C_00 = 1./n_frames * torch.matmul(chi_t.T, chi_t)
        C_11 = 1./n_frames * torch.matmul(gamma.T, gamma)
        C_01 = 1./n_frames * torch.matmul(chi_t.T, gamma)

        return [u, u_t, v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t]
    
    
class S_layer(torch.nn.Module):
    ''' Layer holding the trainable symmetric matrix S.

    The raw kernel with shape [output_dim, output_dim] starts at 0.1 and is
    passed through activation. With renorm=True the symmetrised kernel is
    always rescaled before the diagonal of S is computed.
    '''
    def __init__(self, output_dim, activation, renorm=True):
        super(S_layer, self).__init__()

        self.M = output_dim

        self.alpha = torch.Tensor(self.M, self.M).fill_(0.1)
        self.S_kernel = torch.nn.Parameter(data=self.alpha, requires_grad=True)
        self.acti = activation
        self.renorm = renorm

    def forward(self, v, C_00, C_11, C_01, Sigma):
        ''' Builds S from the kernel and evaluates K and the VAMP-E matrix.

        Parameters
        ----------
        v: torch tensor with shape [output_dim, 1]
            normalisation vector, as returned by U_layer

        C_00, C_11, C_01, Sigma: torch tensors with shape [output_dim, output_dim]
            matrices as returned by U_layer

        Returns
        -------
        list with the VAMP-E matrix, K = S @ Sigma and S
        '''
        # transform the kernel weights and enforce symmetry
        kernel_w = self.acti(self.S_kernel)
        W1 = kernel_w + kernel_w.T

        # normalize the weights
        norm = W1 @ v
        if self.renorm:
            # make sure that the largest value of norm is < 1, using a
            # 20-norm as smooth stand-in for the maximum
            W1 = W1 / torch.sum((norm**20))**(1./20)
            norm = W1 @ v

        # the diagonal is chosen such that S @ v = 1
        w2 = (1 - torch.squeeze(norm)) / torch.squeeze(v)
        S = W1 + torch.diag(w2)

        # calculate K
        K = S @ Sigma

        # VAMP-E matrix for the computation of the loss
        VampE_matrix = S.T @ C_00 @ S @ C_11 - 2*S.T @ C_01

        return [VampE_matrix, K, S]
            
class VampNet(nn.Module):
    ''' VAMPnet class
    TODO:
    - revVAMP
    - revDMSM
    - DMSM
    - VAMP
    
    inputs:
    
    input_size: (int) size of the input features
    output_sizes: (list & int) list of output sizes. If more than one elemtent, expects coarse graining
    nodes: (list & int) list of output size of hidden layers
    
    
    '''
    def __init__(self, input_size, output_sizes, nodes, train_mean, train_std, 
                 valid_T=False, reversible=False,
                 mask_const=True, mask_depth=0, mask_width=0, patchsize=4, sensitivity=1, fac=True,
                 noise=0.,
                 softmax_fac=1.):
        '''Builds the attention mask, the network chi and all auxiliary layers.

        Parameters
        ----------
        input_size: int
            number of input features

        output_sizes: list of int
            number of output states; every further entry adds one
            coarse-graining layer from the previous size to this one

        nodes: list of int
            output size of every hidden layer of chi

        train_mean, train_std:
            values subtracted from / divided into the input before the mask;
            converted to float tensors on the global ``device``

        valid_T, reversible: bool
            select the model type that is printed (VAMPnet, DMSM, revVAMPnet
            or revDMSM). u and S layers are only created if reversible is True

        mask_const, mask_depth, mask_width, patchsize, sensitivity, fac, noise:
            passed on to Mask

        softmax_fac: float
            factor applied to the output of the last linear layer before the
            softmax in forward
        '''
        super(VampNet, self).__init__()
        
        # which physical constraints are enacted which can need more parameters with different contraints
        # Furthermore the best learning practice is different
        self.valid_T = valid_T
        self.reversible = reversible
        if valid_T:
            model='DMSM'
        else:
            model='VAMPnet'
        if reversible:
            rev='rev'
        else:
            rev=''
        print('The trained model is a '+rev+model) 
        
        # flag that is only set for the non-reversible DMSM
        if valid_T and not reversible:
            self.gamma = True
        else:
            self.gamma = False
        # NOTE: self.renorm, the u layers and the S layers only exist if reversible is True
        if reversible:
            if valid_T:
                # activations for revDMSM: softplus keeps u and S positive
                self.renorm = True
                acti_S = torch.nn.Softplus()
                acti_u = torch.nn.Softplus()
                
            else: 
                # activations for reversible VAMPnets: plain linear scaling
                # with small factors, no renormalisation of S
                self.factor_S = 0.001
                self.factor_u = .000001

                linear_S = lambda x: self.factor_S * x
                linear_u = lambda x: self.factor_u * x
                self.renorm = False
                acti_S = linear_S
                acti_u = linear_u
            
            
            # one u and one S layer per output size; the ModuleLists register
            # the layers kept in the plain Python lists as submodules
            self.u_layers = [U_layer(o, acti_u) for o in output_sizes]
            self.S_layers = [S_layer(o, acti_S, self.renorm) for o in output_sizes]
            self.u_layers_tensor = nn.ModuleList(self.u_layers)
            self.S_layers_tensor = nn.ModuleList(self.S_layers)
                
        
        
        self.N = len(output_sizes)
        self.output_sizes = output_sizes
        
        self.softmax_fac = softmax_fac
        self.fac = fac
        self.nodes = nodes
        self.mask_const = mask_const
        self.Mask = Mask(input_size, mask_const, mask_depth, mask_width, patchsize, sensitivity, fac=fac, noise=noise)
        
        # hidden layers of chi: input_size -> nodes[0] -> ... -> nodes[-1]
        self.hfc = nn.ModuleList([nn.Linear(input_size, nodes[0])])
        for i in range(len(nodes)-1):
            self.hfc.append(nn.Linear(nodes[i], nodes[i+1]))
            
        
        # output layer of chi, softmaxed in forward
        self.fc_softmax = nn.Linear(nodes[-1], output_sizes[0])
        
        # created for every model type, independent of self.gamma
        self.gamma_layer = nn.Linear(nodes[-1], output_sizes[0])
            # is just needed once, not for coarse graining, since it convert the same as chi
        
        # coarse-graining layers between consecutive output sizes; they are
        # moved to the global `device` right away
        self.coarse_grain_layer = [Coarse_grain(output_sizes[n], output_sizes[n+1]).to(device) for n in range(len(output_sizes)-1)]
        self.coarse_grain_layer_tensor = nn.ModuleList(self.coarse_grain_layer)
        # plain tensor attributes (not registered as buffers), placed on the global `device`
        self.train_mean = torch.Tensor(train_mean).to(device)
        self.train_std = torch.Tensor(train_std).to(device)
    
    def forward_before_sm(self, x):
        '''Whiten x, apply the attention mask and run the hidden ELU layers.

        x: input batch; the result of the last hidden layer is returned,
           i.e. the features that feed the softmax and gamma heads.
        '''
        # Whitening with the stored training statistics.
        hidden = self.Mask((x - self.train_mean) / self.train_std)

        for dense in self.hfc:
            hidden = F.elu(dense(hidden))

        return hidden
        
    def forward(self, x):
        '''Return the softmax output for x.

        The hidden features are mapped by fc_softmax, scaled by softmax_fac
        and normalised with a softmax over dim 1.
        '''
        logits = self.fc_softmax(self.forward_before_sm(x))
        return F.softmax(self.softmax_fac * logits, dim=1)
    
    def forward_cg(self, x, id):
        '''Apply the coarse-graining layer with index `id` to x.'''
        return self.coarse_grain_layer[id](x)
    
    def forward_all(self, x):
        '''Return the outputs of every level as a list.

        Entry 0 is self.forward(x); each following entry is the previous
        one passed through the next coarse-graining layer.
        '''
        chi = self.forward(x)
        outputs = [chi]

        for cg_layer in self.coarse_grain_layer:
            chi = cg_layer(chi)
            outputs.append(chi)

        return outputs
    
    def forward_gamma(self, x, whole=True):
        '''Return the gamma head output for x.

        whole: if True, x is raw input and is first sent through
               forward_before_sm; if False, x is used as given.
        The gamma layer is followed by a ReLU, so the result is
        non-negative (it can be exactly zero).
        '''
        features = self.forward_before_sm(x) if whole else x
        return F.relu(self.gamma_layer(features))
    
        
    def get_attention(self, x=None):
        '''Return what the attention mask's get_softmax reports for x.'''
        mask = self.Mask
        return mask.get_softmax(x=x)
    
    def set_soft_fac(self, new_value):
        '''Replace the factor applied to the logits before the softmax.

        new_value: Python scalar, array or tensor holding the new factor.

        The value is stored as a float32 tensor on the device of the
        softmax layer. Building it with torch.Tensor(new_value) would raise
        for a Python float, allocate an uninitialised tensor of that length
        for a Python int, and always stay on the CPU.
        '''
        value = torch.as_tensor(new_value, dtype=torch.float32)
        # Copy so that the stored factor never aliases the caller's data.
        self.softmax_fac = value.detach().clone().to(self.fc_softmax.weight.device)
        
    def set_soft_fac_cg(self, id_cg, new_value):
        '''Set the `sen` attribute of coarse-graining layer `id_cg`.'''
        target_layer = self.coarse_grain_layer[id_cg]
        target_layer.sen = new_value
    
    def get_params_vamp(self):
        '''Yield the parameters trained for the plain VAMPnet.

        Order: attention mask (all of its parameters if it is constant,
        otherwise its hidden layers followed by its softmax layer),
        the hidden layers, the softmax layer.
        '''
        mask = self.Mask
        if self.mask_const:
            yield from mask.parameters()
        else:
            for mask_layer in mask.hfc:
                yield from mask_layer.parameters()
            yield from mask.softmax.parameters()

        for hidden_layer in self.hfc:
            yield from hidden_layer.parameters()

        yield from self.fc_softmax.parameters()
    
    def get_params_DMSM(self, all=True):
        '''Yield the parameters trained for the DMSM.

        all: if True, the mask, hidden and softmax parameters come first;
             the gamma layer parameters are yielded in either case.
        '''
        if all:
            mask = self.Mask
            if self.mask_const:
                yield from mask.parameters()
            else:
                for mask_layer in mask.hfc:
                    yield from mask_layer.parameters()
                yield from mask.softmax.parameters()

            for hidden_layer in self.hfc:
                yield from hidden_layer.parameters()

            yield from self.fc_softmax.parameters()

        yield from self.gamma_layer.parameters()

    def get_params_rev(self, all=True, u_flag=True, S_flag=True):
        '''Yield the parameters trained for the reversible model (level 0).

        all: include the mask, hidden and softmax parameters
        u_flag: include the parameters of the first u-layer
        S_flag: include the parameters of the first S-layer
        '''
        if all:
            mask = self.Mask
            if self.mask_const:
                yield from mask.parameters()
            else:
                for mask_layer in mask.hfc:
                    yield from mask_layer.parameters()
                yield from mask.softmax.parameters()

            for hidden_layer in self.hfc:
                yield from hidden_layer.parameters()

            yield from self.fc_softmax.parameters()

        if u_flag:
            yield from self.u_layers[0].parameters()
        if S_flag:
            yield from self.S_layers[0].parameters()
    
    def get_params_all(self, u_flag=False, S_flag=False):
        '''Yield mask, hidden, softmax and all coarse-graining parameters.

        u_flag / S_flag: additionally yield the parameters of the first
        u-layer / S-layer.
        '''
        mask = self.Mask
        if self.mask_const:
            yield from mask.parameters()
        else:
            for mask_layer in mask.hfc:
                yield from mask_layer.parameters()
            yield from mask.softmax.parameters()

        for hidden_layer in self.hfc:
            yield from hidden_layer.parameters()

        yield from self.fc_softmax.parameters()

        for cg_layer in self.coarse_grain_layer:
            yield from cg_layer.parameters()

        if u_flag:
            yield from self.u_layers[0].parameters()
        if S_flag:
            yield from self.S_layers[0].parameters()
        
    
    def get_params_wo_mask(self):
        '''Yield the hidden-layer parameters followed by the softmax ones.'''
        for hidden_layer in self.hfc:
            yield from hidden_layer.parameters()

        yield from self.fc_softmax.parameters()
    
    def get_params_softmax(self):
        '''Yield the parameters of the softmax layer only.'''
        yield from self.fc_softmax.parameters()
            
    
    def get_params_mask(self):
        '''Yield the attention-mask parameters.

        A constant mask contributes all of its parameters; otherwise the
        mask's hidden layers come first, followed by its softmax layer.
        '''
        mask = self.Mask
        if self.mask_const:
            yield from mask.parameters()
            return

        for mask_layer in mask.hfc:
            yield from mask_layer.parameters()
        yield from mask.softmax.parameters()
            
    def get_params_cg(self, index=()):
        '''Yield the parameters of the selected coarse-graining layers.

        index: iterable of positions into self.coarse_grain_layer; the
               default selects nothing. (An immutable default replaces the
               former shared list.)
        '''
        for layer_idx in index:
            yield from self.coarse_grain_layer[layer_idx].parameters()
    
    def get_params_cg_rev(self, index=(), all=True):
        '''Yield the trainable parameters of the selected reversible levels.

        index: iterable of positions into self.coarse_grain_layer; the
               default selects nothing. (An immutable default replaces the
               former shared list.)
        all: if True the coarse-graining layer itself is included; the
             u- and S-layer belonging to its output (position + 1) are
             always yielded.
        '''
        for layer_idx in index:
            if all:
                yield from self.coarse_grain_layer[layer_idx].parameters()
            yield from self.u_layers[layer_idx + 1].parameters()
            yield from self.S_layers[layer_idx + 1].parameters()
                
                
    def set_rev_var(self, layer_id=0, S=True):
        '''Initialise the u-layer (and optionally the S-layer) of one level from data.

        layer_id: index of the level whose u/S parameters are overwritten;
                  the outputs are first sent through coarse-graining layers
                  0 .. layer_id-1.
        S: if True the S-layer kernel is initialised as well.

        Relies on the notebook globals pred_batchwise, tensor_train_X1,
        tensor_train_X2, device and epsilon. Prints the eigenvalues of the
        instantaneous covariance, the estimated pi and the resulting u.
        '''
        
        # softmax outputs of the time-lagged training pairs
        chi_t = torch.Tensor(pred_batchwise(tensor_train_X1)).to(device)
        chi_tau = torch.Tensor(pred_batchwise(tensor_train_X2)).to(device)
    
        # propagate to the requested coarse-graining level
        for i in range(layer_id):
            chi_t = self.forward_cg(chi_t, i)
            chi_tau = self.forward_cg(chi_tau, i)
            
        Data_chi_X = chi_t.detach().to('cpu').numpy()
        Data_chi_Y = chi_tau.detach().to('cpu').numpy()
        fullbatch = Data_chi_X.shape[0]


        # uncentred correlation matrices over the whole training set
        # (c_1 is computed but not used further below)
        c_0 = 1/fullbatch * Data_chi_X.T @ Data_chi_X
        c_tau = 1/fullbatch * Data_chi_X.T @ Data_chi_Y
        c_1 = 1/fullbatch * Data_chi_Y.T @ Data_chi_Y

        # inverse of c_0 restricted to eigenvalues above the global epsilon
        eigv_all, eigvec_all = np.linalg.eigh(c_0)
        print(eigv_all)
        include = eigv_all > epsilon
        eigv = eigv_all[include]
        eigvec = eigvec_all[:,include]
        c0_inv = eigvec @ np.diag(1/eigv) @ np.transpose(eigvec)

        K_vamp = c0_inv @ c_tau

        # estimate pi, the stationary distribution vector:
        # eigenvector of K^T whose eigenvalue is closest to 1, normalised to sum 1
        eigv, eigvec = np.linalg.eig(K_vamp.T)
        ind_pi = np.argmin((eigv-1)**2)

        pi_vec = np.real(eigvec[:,ind_pi])
        pi = pi_vec / np.sum(pi_vec, keepdims=True)
        print('pi', pi)
        # reverse the construction of u
        u_optimal = c0_inv @ pi
        print('u optimal', u_optimal)
        if self.valid_T:
            # inverse of softplus applied to |u| (the u-layer activation is Softplus here)
            u_kernel = np.log(np.exp(np.abs(u_optimal))-1)
        else:
            # undo the linear activation factor
            u_kernel = u_optimal / self.factor_u
        
        with torch.no_grad():
            # every parameter of the u-layer receives the same (1, n) kernel
            for param in self.u_layers[layer_id].parameters():
            
                param.copy_(torch.Tensor(u_kernel[None,:]))  
            
        if S:
            
            # the 7th output of the u-layer is used as Sigma
            _, _, _, _, _, _, Sigma_input, _, _ = self.u_layers[layer_id](chi_t, chi_tau)
            Sigma = Sigma_input.detach().to('cpu').numpy()
            
            # truncated inverse of Sigma, same threshold as above
            eigv_all, eigvec_all = np.linalg.eigh(Sigma)
            include = eigv_all > epsilon
            eigv = eigv_all[include]
            eigvec = eigvec_all[:,include]
            sigma_inv = eigvec @ np.diag(1/eigv) @ np.transpose(eigvec)

            # reverse the construction of S
            S_nonrev = K_vamp @ sigma_inv
            # symmetrise; the extra division by 2 below presumably compensates for
            # the S-layer adding the kernel to its transpose - TODO confirm against S_layer
            S_rev_add = 1/2 * (S_nonrev + S_nonrev.T)
            if self.valid_T:
                kernel_S = S_rev_add / 2.
                # inverse softplus of the absolute values
                kernel_S = np.log(np.exp(np.abs(kernel_S))-1)
            else:
                kernel_S = S_rev_add / 2. / self.factor_S

            with torch.no_grad():
                for param in self.S_layers[layer_id].parameters():

                    param.copy_(torch.Tensor(kernel_S)) 
            
    def get_weights(self):
        '''Return copies of the model weights as a dict of numpy arrays.

        Keys: 'hfc' and 'sm' (flat parameter lists), 'S' and 'u' (only for a
        reversible model), 'Mask' or 'Mask_hf'/'Mask_sm' depending on
        mask_const, 'cg' (flat list over all coarse-graining layers),
        'S_cg'/'u_cg' (first parameter of the S-/u-layer of each coarse
        level; left empty for a non-reversible model, which has no such
        layers), 'train_mean' and 'train_std'.
        '''
        def to_numpy(params):
            # detached CPU copies, independent of later training steps
            return [p.detach().to('cpu').numpy().copy() for p in params]

        weights_dict = {}
        weights_dict['hfc'] = [w for layer in self.hfc for w in to_numpy(layer.parameters())]
        weights_dict['sm'] = to_numpy(self.fc_softmax.parameters())

        if self.reversible:
            weights_dict['S'] = to_numpy(self.S_layers[0].parameters())
            weights_dict['u'] = to_numpy(self.u_layers[0].parameters())

        if self.mask_const:
            weights_dict['Mask'] = to_numpy(self.Mask.parameters())
        else:
            weights_dict['Mask_hf'] = [w for layer in self.Mask.hfc for w in to_numpy(layer.parameters())]
            weights_dict['Mask_sm'] = to_numpy(self.Mask.softmax.parameters())

        weights_dict['cg'] = []
        weights_dict['S_cg'] = []
        weights_dict['u_cg'] = []
        for i, layer in enumerate(self.coarse_grain_layer):
            weights_dict['cg'].extend(to_numpy(layer.parameters()))
            # u-/S-layers only exist when the model was built as reversible
            if self.reversible:
                weights_dict['S_cg'].append(to_numpy(self.S_layers[i+1].parameters())[0])
                weights_dict['u_cg'].append(to_numpy(self.u_layers[i+1].parameters())[0])

        weights_dict['train_mean'] = self.train_mean.to('cpu').numpy()
        weights_dict['train_std'] = self.train_std.to('cpu').numpy()

        return weights_dict
    
    def set_weights(self, weights_dict):
        '''Load weights produced by get_weights back into the model.

        weights_dict: dict with the keys written by get_weights. Lists are
        consumed in the same flat per-parameter order in which get_weights
        filled them; the 'cg' list in particular runs over the parameters
        of all coarse-graining layers. The u-/S-layers are only touched for
        a reversible model, which is the only variant that has them.

        The whitening statistics are placed on the device the current
        statistics already live on.
        '''
        def load(params, arrays, start=0):
            # copy arrays[start], arrays[start+1], ... into params; return next index
            position = start
            for param in params:
                param.copy_(torch.Tensor(arrays[position]))
                position += 1
            return position

        target_device = self.train_mean.device
        self.train_mean = torch.Tensor(weights_dict['train_mean']).to(target_device)
        self.train_std = torch.Tensor(weights_dict['train_std']).to(target_device)

        with torch.no_grad():
            position = 0
            for layer in self.hfc:
                position = load(layer.parameters(), weights_dict['hfc'], position)

            load(self.fc_softmax.parameters(), weights_dict['sm'])

            if self.reversible:
                load(self.S_layers[0].parameters(), weights_dict['S'])
                load(self.u_layers[0].parameters(), weights_dict['u'])

            if self.mask_const:
                load(self.Mask.parameters(), weights_dict['Mask'])
            else:
                position = 0
                for layer in self.Mask.hfc:
                    position = load(layer.parameters(), weights_dict['Mask_hf'], position)
                load(self.Mask.softmax.parameters(), weights_dict['Mask_sm'])

            cg_position = 0
            for i, layer in enumerate(self.coarse_grain_layer):
                cg_position = load(layer.parameters(), weights_dict['cg'], cg_position)
                if self.reversible:
                    # one stored array per level, shared by all parameters of that layer
                    for param in self.S_layers[i+1].parameters():
                        param.copy_(torch.Tensor(weights_dict['S_cg'][i]))
                    for param in self.u_layers[i+1].parameters():
                        param.copy_(torch.Tensor(weights_dict['u_cg'][i]))

### Definition of the VAMP-2 and VAMP-E score

In [None]:
def VAMP_score(chi_t, chi_tau, corr=False):
    '''Compute the VAMP-2 score of a pair of time-lagged network outputs.

    Parameters
    ----------
    chi_t: torch tensor, first dimension is the batch
        network output at time t
    chi_tau: torch tensor, first dimension is the batch
        network output at time t + tau
    corr: bool
        if True, the trace of the instantaneous covariance is returned too

    Returns
    -------
    score: torch tensor with shape [1]
        squared norm (torch.norm, default arguments) of
        C00^{-1/2} C01 C11^{-1/2}. The inverse square roots come from
        _inv(..., ret_sqrt=True) and the inputs are preprocessed by
        _prep_data - both helpers are defined elsewhere in the file and
        are presumably mean removal and a regularised inverse square root;
        TODO confirm.
    (score, trace(C00)) if corr is True.
    '''
    n_frames = chi_t.shape[0]

    x, y = _prep_data(chi_t, chi_tau)

    # covariance matrices with the unbiased 1/(n-1) normalisation
    scale = 1/(n_frames - 1)
    cov_00 = scale * (x.T @ x)
    cov_11 = scale * (y.T @ y)
    cov_01 = scale * (x.T @ y)

    half_inv_00 = _inv(cov_00, ret_sqrt = True)
    half_inv_11 = _inv(cov_11, ret_sqrt = True)

    # VAMP matrix; its squared norm is the score
    vamp_matrix = half_inv_00 @ (cov_01 @ half_inv_11)
    score = (torch.norm(vamp_matrix)**2).unsqueeze(0)

    if not corr:
        return score
    return score, torch.trace(cov_00)

In [None]:
def vampe_loss(chi, gamma, eps=0.):
    '''Return -trace(C00 G C11 G - 2 C01 G) for the outputs chi and gamma.

    chi: tensor (frames x states), used for C00 and the left side of C01
    gamma: tensor (frames x states), used for C11 and the right side of C01
    eps: value added to the column means of gamma before they are inverted.
         The default 0. reproduces the unregularised result exactly; pass a
         small positive number to keep the value finite when a column of
         gamma is identically zero.

    C00, C11 and C01 are the 1/frames-normalised products chi^T chi,
    gamma^T gamma and chi^T gamma; G is the diagonal matrix holding the
    inverse column means of gamma.
    '''
    n_frames = chi.shape[0]

    c00 = 1/n_frames*(torch.matmul(chi.T, chi))
    c11 = 1/n_frames*(torch.matmul(gamma.T, gamma))
    c01 = 1/n_frames*(torch.matmul(chi.T, gamma))

    gamma_mean = torch.mean(gamma, dim=0)
    if eps:
        # guard against division by zero for columns that never activate
        gamma_mean = gamma_mean + eps
    gamma_dia_inv = torch.diag(1/gamma_mean)

    first_term = c00 @ gamma_dia_inv @ c11 @ gamma_dia_inv
    second_term = 2 * (c01 @ gamma_dia_inv)
    vampe = torch.trace(first_term - second_term)

    return -vampe

### Definition of the losses needed for the reversible model if trained for different instances

In [None]:
def vampe_loss_rev(chi_t, chi_tau, layer_id=0, return_mu=False, return_mu_K_Sigma=False):
    '''Evaluate the reversible model of level `layer_id` on (chi_t, chi_tau).

    The pair is sent through the u-layer and then the S-layer of the global
    Full_net; the negated trace of the first S-layer output is returned.

    return_mu: also return mu_t (takes precedence over return_mu_K_Sigma)
    return_mu_K_Sigma: also return mu_t, K and Sigma_t
    '''
    _, _, v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t = Full_net.u_layers[layer_id](chi_t, chi_tau)
    vampe_matrix, K, _ = Full_net.S_layers[layer_id](v, C_00, C_11, C_01, Sigma)
    neg_trace = -torch.trace(vampe_matrix)

    if return_mu:
        return neg_trace, mu_t
    if return_mu_K_Sigma:
        return neg_trace, mu_t, K, Sigma_t
    return neg_trace
    
def vampe_loss_rev_only_S(v, C_00, C_11, C_01, Sigma, layer_id=0):
    '''Negated trace of the S-layer output for precomputed u-layer results.

    The arguments are passed unchanged to the S-layer `layer_id` of the
    global Full_net, so the u-layer does not have to be re-evaluated.
    '''
    s_outputs = Full_net.S_layers[layer_id](v, C_00, C_11, C_01, Sigma)
    vampe_matrix, _, _ = s_outputs
    return -torch.trace(vampe_matrix)

def vampe_loss_rev_cg(chi_t, chi_tau, u, S, u_t, layer_id=0, return_mu=False, return_mu_K_Sigma=False, renorm=True):
    '''Negated trace of the VAMP-E matrix of coarse-graining layer `layer_id`.

    All quantities come from get_cg_uS of that layer of the global
    Full_net, called with (chi_t, chi_tau, u, S, u_t, renorm).

    return_mu: also return mu_t (takes precedence over return_mu_K_Sigma)
    return_mu_K_Sigma: also return mu_t, K and Sigma_t
    '''
    cg_layer = Full_net.coarse_grain_layer[layer_id]
    _, _, _, _, _, mu_t, Sigma_t, K, VampE_matrix = cg_layer.get_cg_uS(
        chi_t, chi_tau, u, S, u_t, renorm)

    neg_trace = -torch.trace(VampE_matrix)

    if return_mu:
        return neg_trace, mu_t
    if return_mu_K_Sigma:
        return neg_trace, mu_t, K, Sigma_t
    return neg_trace

In [None]:
# plot coarse graining matrix
def plot_cg(id):
    '''Show the softmax matrix of coarse-graining layer `id` as an image.'''
    cg_matrix = Full_net.coarse_grain_layer[id].get_softmax().detach().to('cpu').numpy()
    plt.imshow(cg_matrix)
    for set_label, text in ((plt.xlabel, 'From State'), (plt.ylabel, 'To State')):
        set_label(text, fontsize=18)
    plt.show()

In [None]:
# transform a trajectory which might not fit into memory at once, predict batchwise
def pred_batchwise(traj, batchsize=10000):
    '''Evaluate Full_net.forward on traj in chunks and return one numpy array.

    traj: array or tensor whose first dimension indexes frames
    batchsize: maximum number of frames sent through the network at once

    The chunks are evaluated under torch.no_grad(), so no autograd graph is
    built for what is pure inference. When the number of frames is an exact
    multiple of batchsize no empty trailing chunk is evaluated. An empty
    traj is still passed to the network once so that the result keeps the
    output width.
    '''
    n_frames = traj.shape[0]
    pred_all = []
    with torch.no_grad():
        # max(..., 1) keeps a single (empty) call for a zero-length trajectory
        for start in range(0, max(n_frames, 1), batchsize):
            batch = torch.Tensor(traj[start:start + batchsize]).to(device)
            pred_all.append(Full_net.forward(batch).to('cpu').numpy())

    return np.concatenate(pred_all, axis=0)

In [None]:
# plotting the mask
n_residues = int(-1/2 + np.sqrt(1/4+input_size*2) + 3)


def plot_mask(return_values=False, skip=5, vmax=1, top=10):
    '''Plot the attention weights of the global Full_net.

    For a constant mask a single column is drawn. Otherwise, for every
    output state the `top` frames of traj_whole_new[0] with the highest
    value for that state are selected and their attention is averaged,
    giving one column per state.

    return_values: if True, return the plotted attention array
        (one row per state in the non-constant case)
    skip: spacing of the y tick labels
    vmax: upper limit of the colour scale
    top: number of frames averaged per state (non-constant mask only)
    '''
    n_rows = int(-1/2 + np.sqrt(1/4+input_size*2) + 3)

    if Full_net.mask_const:
        weights = Full_net.get_attention().detach().to('cpu').numpy()
        att_atom = np.reshape(weights, (n_rows, 1))
        image, x_label, n_columns = att_atom, 'System', 1
    else:
        pred_temp = pred_batchwise(traj_whole_new[0], batchsize=10000)
        # indices of the `top` highest-valued frames per state
        best_frames = np.argsort(pred_temp, axis=0)[-top:]
        n_columns = pred_temp.shape[1]
        per_state = []
        for state in range(n_columns):
            frames = best_frames[:, state]
            weights = Full_net.get_attention(torch.Tensor(traj_whole_new[0][frames]).to(device))
            weights = weights.detach().to('cpu').numpy()
            per_state.append(np.mean(weights, axis=0, keepdims=True))
        att_atom = np.concatenate(per_state)
        image, x_label = att_atom.T, 'State'

    plt.imshow(image, vmin=0, vmax=vmax, aspect='auto')
    plt.xlabel(x_label, fontsize=18)
    plt.ylabel('Input', fontsize=18)
    plt.xticks(np.arange(n_columns), ['{}'.format(i) for i in range(n_columns)], fontsize=16)
    y_positions = np.arange(0, n_rows, skip)
    plt.yticks(y_positions, ['x{}'.format(i) for i in range(0, n_rows, skip)], fontsize=16)
    plt.show()

    if return_values:
        return att_atom
        
# plot_mask(vmax=2, top=10)


### Training loops for the different models

In [None]:
def train_for_VAMPnet(runs, opt_list, weight_corr=1., plot_mask_every=10, verbose=True, plot_training=True, corr=False, best_weights_flag=False):
    '''Train the global Full_net on the VAMP-2 score for `runs` epochs.

    runs: number of epochs
    opt_list: optimizers that are zeroed and stepped on every batch
    weight_corr: weight of the covariance-trace term (only used if corr);
        set to 0 once the epoch mean of that term exceeds 0.98
    plot_mask_every: kept for interface compatibility; not used here
    verbose: print the training and validation score after every epoch
    plot_training: plot the score curves at the end
    corr: add the trace of the instantaneous covariance to the objective
    best_weights_flag: remember the weights with the best validation score
        and restore them at the end

    Uses the notebook globals train_l, tensor_valid_X1, tensor_valid_X2,
    device, pred_batchwise and VAMP_score. Returns nothing.
    '''
    epoch_loss = np.zeros(runs)
    epoch_loss_corr = np.zeros(runs)
    epoch_loss_valid = np.zeros(runs)
    epoch_loss_valid_corr = np.zeros(runs)
    noise = Full_net.Mask.noise
    if best_weights_flag:
        best_score = 0.
        best_weights = Full_net.get_weights()

    for epoch in range(runs):  # loop over the dataset multiple times
        running_epoch_loss = []
        running_epoch_corr = []
        chi_t = chi_tau = None
        for inputs_t, inputs_tau in train_l:
            for opt in opt_list:
                opt.zero_grad()

            # each loader item carries one extra leading axis, hence the [0]
            chi_t = Full_net(inputs_t[0].to(device))
            chi_tau = Full_net(inputs_tau[0].to(device))

            score_list = VAMP_score(chi_t, chi_tau, corr=corr)
            if corr:
                score_curr, score_corr = score_list[0], score_list[1]
                loss = -score_curr - score_corr * weight_corr
                running_epoch_corr.append(score_corr.item())
            else:
                score_curr = score_list
                loss = -score_curr

            loss.backward()
            for opt in opt_list:
                opt.step()

            running_epoch_loss.append(score_curr.item())

        # validation without mask noise; the noise level is restored even
        # if the prediction fails
        Full_net.Mask.noise = 0.
        try:
            chi_t_vali = torch.Tensor(pred_batchwise(tensor_valid_X1)).to(device).detach()
            chi_tau_vali = torch.Tensor(pred_batchwise(tensor_valid_X2)).to(device).detach()
        finally:
            Full_net.Mask.noise = noise
        loss_vali_list = VAMP_score(chi_t_vali, chi_tau_vali, corr=corr)
        # drop the references to the large tensors before the next epoch
        del chi_t_vali, chi_tau_vali
        chi_t = chi_tau = None

        if corr:
            loss_vali = loss_vali_list[0].detach()
            epoch_loss_valid_corr[epoch] = loss_vali_list[1].detach().item()
        else:
            loss_vali = loss_vali_list.detach()
        epoch_loss_valid[epoch] = loss_vali.item()

        if best_weights_flag and epoch_loss_valid[epoch] > best_score:
            print('Better validation score, save weights')
            best_weights = Full_net.get_weights()
            best_score = epoch_loss_valid[epoch]

        epoch_loss[epoch] = np.mean(running_epoch_loss)
        if corr:
            # without the correlation term there is nothing to average
            epoch_loss_corr[epoch] = np.mean(running_epoch_corr)
        if epoch_loss_corr[epoch] > 0.98:
            weight_corr = 0.
        if verbose:
            print('Run {}, total loss: {:.3}, valid loss: {:.3}'.format(epoch+1, epoch_loss[epoch], epoch_loss_valid[epoch]))

    if plot_training:
        epochs = np.arange(1, runs+1)
        plt.plot(epochs, epoch_loss, label='VAMP_loss')
        plt.plot(epochs, epoch_loss_valid, label='VAMP_loss_valid')
        plt.legend()
        plt.show()
        if corr:
            plt.plot(epochs, epoch_loss_corr, label='Corr')
            # the validation curve plots the validation values, not the training ones again
            plt.plot(epochs, epoch_loss_valid_corr, label='Corr_valid')
            plt.legend()
            plt.show()

    if best_weights_flag:
        print('Set best weights')
        Full_net.set_weights(best_weights)

In [None]:
def train_for_cg(id, runs, plot_mask_every=10, verbose=True, plot_training=True):
    '''Train coarse-graining layer `id` of the global Full_net on the VAMP-2 score.

    The softmax outputs of the whole training set are computed once, sent
    through the coarse-graining layers below `id` and detached; every epoch
    is then a single full-batch step of optimizer_cg[id].

    id: index of the coarse-graining layer to train
    runs: number of epochs
    plot_mask_every: the layer is plotted with plot_cg every this many epochs
    verbose: print the score after every epoch
    plot_training: plot the score curve at the end
    '''
    epoch_loss = np.zeros(runs)

    fine_t = torch.Tensor(pred_batchwise(tensor_train_X1))
    fine_tau = torch.Tensor(pred_batchwise(tensor_train_X2))

    # propagate through the already existing lower levels
    for level in range(id):
        fine_t = Full_net.forward_cg(fine_t.to(device), level)
        fine_tau = Full_net.forward_cg(fine_tau.to(device), level)

    fine_t, fine_tau = fine_t.detach(), fine_tau.detach()

    for epoch in range(runs):  # one full-batch step per epoch
        opt = optimizer_cg[id]
        opt.zero_grad()

        coarse_t = Full_net.forward_cg(fine_t.to(device), id)
        coarse_tau = Full_net.forward_cg(fine_tau.to(device), id)

        score_curr = VAMP_score(coarse_t, coarse_tau)
        (-score_curr).backward()
        opt.step()

        epoch_loss[epoch] = score_curr.item()
        if verbose:
            print('Run {}, total loss: {:.3}'.format(epoch+1, epoch_loss[epoch]))

        if (epoch+1) % plot_mask_every == 0:
            plot_cg(id)

    if plot_training:
        plt.plot(np.arange(1,runs+1), epoch_loss, label='VAMP_loss')
        plt.legend()
        plt.show()

In [None]:
def train_for_cg_rev(id, runs, plot_mask_every=10, verbose=True, plot_training=True):
    '''Train the reversible coarse-graining level `id` of the global Full_net.

    The level-0 u- and S-layers are evaluated on the whole training set,
    the result is propagated through the coarse-graining layers below `id`
    via get_cg_uS and detached; every epoch is then one full-batch step of
    optimizer_cg[id] on the value returned by vampe_loss_rev_cg.

    id: index of the coarse-graining layer to train
    runs: number of epochs
    plot_mask_every: the layer is plotted with plot_cg every this many epochs
    verbose: print the value after every epoch
    plot_training: plot the curve at the end
    '''
    epoch_loss = np.zeros(runs)

    chi_0 = torch.Tensor(pred_batchwise(tensor_train_X1)).to(device)
    chi_lag = torch.Tensor(pred_batchwise(tensor_train_X2)).to(device)

    u, u_t, v, C_00, C_11, C_01, Sigma, _, _ = Full_net.u_layers[0](chi_0, chi_lag)
    _, _, S = Full_net.S_layers[0](v, C_00, C_11, C_01, Sigma)

    # propagate data, u and S through the already existing lower levels
    for level in range(id):
        chi_0, chi_lag, u, u_t, S, _, _, _, _ = Full_net.coarse_grain_layer[level].get_cg_uS(
            chi_0, chi_lag, u, S, u_t, Full_net.renorm)

    # the inputs of the trained level are constants from here on
    chi_0, chi_lag, u, u_t, S = (t.detach() for t in (chi_0, chi_lag, u, u_t, S))

    for epoch in range(runs):  # one full-batch step per epoch
        opt = optimizer_cg[id]
        opt.zero_grad()

        score_curr = vampe_loss_rev_cg(chi_0, chi_lag, u, S, u_t, id,
                                       return_mu=False, return_mu_K_Sigma=False, renorm=Full_net.renorm)
        (-score_curr).backward()
        opt.step()

        epoch_loss[epoch] = score_curr.item()
        if verbose:
            print('Run {}, total loss: {:.3}'.format(epoch+1, epoch_loss[epoch]))

        if (epoch+1) % plot_mask_every == 0:
            plot_cg(id)

    if plot_training:
        plt.plot(np.arange(1,runs+1), epoch_loss, label='VAMP_loss')
        plt.legend()
        plt.show()



In [None]:
# number of models to be trained
runs=10
# one weight dictionary (see VampNet.get_weights) per trained model
weights_chi_list = []


for r in range(runs):
    
    # a fresh call per model; how get_data_for_tau splits the data is defined elsewhere in the file
    X1_train, X2_train, X1_vali, X2_vali, X1_test, X2_test, length_train, length_vali, _, _ = get_data_for_tau(traj_whole_new, tau_chi)

    tensor_train_X1 = torch.Tensor(X1_train)
    tensor_train_X2 = torch.Tensor(X2_train) # transform to torch tensor
    tensor_valid_X1 = torch.Tensor(X1_vali)
    tensor_valid_X2 = torch.Tensor(X2_vali)
    tensor_test_X1 = torch.Tensor(X1_test)
    tensor_test_X2 = torch.Tensor(X2_test)

    trainset = data.TensorDataset(tensor_train_X1, tensor_train_X2) # create your dataset

    # A BatchSampler is passed as `sampler`, so every loader item is built from a whole
    # list of indices; the training loop accordingly reads the batch as inputs[0].
    trainloader = data.DataLoader(trainset, sampler=data.BatchSampler(data.RandomSampler(trainset), batch_size=batch_size, drop_last=True))
    # trainloader = data.DataLoader(trainset, batch_size=batch_size,
    #                               shuffle=True, num_workers=2)
    # same construction with the whole training set as a single batch
    trainloader_full = data.DataLoader(trainset, sampler=data.BatchSampler(data.RandomSampler(trainset), batch_size=X1_train.shape[0], drop_last=True))
    # trainloader_full = data.DataLoader(trainset, batch_size=batch_size_large,
    #                               shuffle=True, num_workers=2)

    # validation pairs as one single batch (not used by the training functions visible
    # in this part of the file, which call pred_batchwise on tensor_valid_X1/X2 instead)
    testset = data.TensorDataset(tensor_valid_X1, tensor_valid_X2) # create your dataset
    testloader = data.DataLoader(testset, sampler=data.BatchSampler(data.RandomSampler(testset), batch_size=X1_vali.shape[0], drop_last=True))
    # testloader = data.DataLoader(testset, batch_size=batch_size,
    #                              shuffle=True, num_workers=2)

    # train_l is the loader read by train_for_VAMPnet
    full_batch = False
    if full_batch:
        train_l = trainloader_full
    else:
        train_l = trainloader


    # store this so that all the data are transformed into a whitened dataset w.r.t. the same mean and std
    # we assume a gaussian distribution of our data
    # NOTE(review): a feature with zero standard deviation would lead to a division by zero
    # in VampNet.forward_before_sm - confirm that the input contains no constant features
    train_mean = X1_train.mean(0)
    train_std = X1_train.std(0)



    noise=1.
    fac_bf_sm = True
    valid_T=True # if valid transition matrix is enforced
    reversible=True # if reversibility is enforced

    # attention stuff
    mask_const=False # if the trained attention mask is constant over time
    patchsize=4
    mask_depth=4 # if time dependent how many hidden layers has the attention network
    mask_width=layer_width # the width of the attention hidden layers
    sensitivity=1. # factor before attention softmax for clearer assignment
    factor_att=True # if to use a factor which scales the input on average back to input

    softmax_fac=1. # factor before classification softmax


    # the training functions operate on this global model
    Full_net = VampNet(input_size, output_sizes, nodes, train_mean, train_std, 
                     valid_T=valid_T, reversible=reversible,
                     mask_const=mask_const, mask_depth=mask_depth, mask_width=mask_width, patchsize=patchsize, sensitivity=sensitivity,
                     fac=factor_att, noise=noise,
                     softmax_fac=softmax_fac)
    Full_net.to(device)
    
    # only mask, hidden and softmax parameters are optimised in this pretraining step
    optimizer_vamp = optim.Adam(Full_net.get_params_vamp(), learning_rate/2)
    # opt_list = [optimizer_vamp, optimizer_vamp_mask]
    opt_list = [optimizer_vamp]
    # 60 epochs with the additional covariance-trace term, then 200 epochs on the plain
    # VAMP-2 score, keeping the weights with the best validation score
    train_for_VAMPnet(60, opt_list, corr=True)
    train_for_VAMPnet(200, opt_list, corr=False, best_weights_flag=True)
    
    weights_chi_list.append(Full_net.get_weights())

In [None]:
# lag times for the CK test: the first `steps` integer multiples of tau_msm
steps = 8
tau_msm = tau
tau_ck = np.arange(1,(steps+1))*tau_msm

In [None]:
# lag times for the implied timescales: two short lags (3 and 5) followed by the CK lags
lag = np.concatenate([np.array([3, 5]), tau_ck])

### Validation via ITS and CK-test

In [None]:
# lag = (np.linspace(3.22,6.3, 10)**4).astype('int')
# one Koopman matrix per trained model and lag time
K_results_6 = np.ones((runs, len(lag) ,output_sizes[0], output_sizes[0]))
# placeholder; overwritten by the get_its result below
its_6_all_vamp = []
for r in range(runs):
    # load model r into the global network and transform the first trajectory
    Full_net.set_weights(weights_chi_list[r])
    pred = pred_batchwise(traj_whole_new[0], batchsize=10000)
    
    for i, tau_i in enumerate(lag):

        K_results_6[r,i]  = estimate_koopman_op(pred, tau_i, force_symmetric = False)

# implied timescales for all models at once (get_its is defined elsewhere in the file)
its_6_all_vamp = get_its(K_results_6, lag, False, multiple_runs=True)

In [None]:
# mean, minimum and maximum of the implied timescales over axis 0
# (presumably the trained models - confirm against the layout returned by get_its)
all_its_vamp_np = np.array(its_6_all_vamp)
all_its_vamp_mean = all_its_vamp_np.mean(0)
all_its_vamp_min = all_its_vamp_np.min(0)
all_its_vamp_max = all_its_vamp_np.max(0)

In [None]:
# conversion factor from frames to microseconds
# (presumably 200 ps per saved frame times the stride `skip` - confirm for own data)
fac = 200.*skip*1e-6 
# fac = 0.0002

plt.figure(figsize=(6,4));

# tick positions in frames; the labels are shown in microseconds
label_x = np.array([.1,0.3,1, 2, 5,10,100,1000])/fac # array is in microsecond
label_y = np.array([.1,1, 2, 5,10, 100, 1000])/fac
# fig = plt.figure(figsize = (8,8))
for j in range(0,output_sizes[0]-1):
    # mean over the trained models with the min/max range as a shaded band
    plt.semilogy(lag, all_its_vamp_mean[::-1][j], lw=5)
    plt.fill_between(lag, all_its_vamp_min[::-1][j], all_its_vamp_max[::-1][j], alpha = 0.3)
# timescales below the lag time itself (grey area) cannot be resolved
plt.semilogy(lag,lag, 'k')
# raw strings: '\m' is not a valid escape sequence in a normal string literal
plt.xlabel(r'lag [$\mu$s]', fontsize=26)
plt.xticks(label_x, label_x*fac, fontsize=22)
plt.ylabel(r'timescale [$\mu$s]', fontsize=26)
plt.yticks(label_y, np.round(label_y*fac, decimals=1), fontsize=22)
plt.fill_between(lag,lag,0.1,alpha = 0.2,color='k');
plt.ylim(0.01/fac, 3/fac)
plt.xlim(lag[0], 0.3/fac)
plt.show()

In [None]:
# continue the analysis with the second trained model (index 1)
weights_vamp = weights_chi_list[1]
Full_net.set_weights(weights_vamp)

In [None]:
def get_ck(K, lag):
    """Build both sides of the Chapman-Kolmogorov test.

    Parameters
    ----------
    K : sequence of square matrices
        ``K[n]`` is the Koopman matrix estimated at lag time ``lag[n]``.
    lag : sequence of int
        Lag times. ``lag[n] // lag[0]`` is used as the matrix power, so the
        entries are expected to be integer multiples of ``lag[0]``.

    Returns
    -------
    list
        ``[predicted, estimated]``, each of shape (n_states, n_states, steps).
        ``predicted[:, :, n]`` is ``K[0]`` raised to ``lag[n] // lag[0]``,
        ``estimated[:, :, n]`` is ``K[n]``. Slot 0 of both is the identity.
    """
    steps = len(lag)
    base = np.asarray(K[0])
    # take the number of states from the matrices themselves
    n_states = base.shape[0]

    predicted = np.zeros((n_states, n_states, steps))
    estimated = np.zeros((n_states, n_states, steps))
    predicted[:, :, 0] = np.identity(n_states)
    estimated[:, :, 0] = np.identity(n_states)

    for n in range(1, steps):
        # e_i @ M is simply row i of M, so whole matrices can be assigned at
        # once and the matrix power is computed once per lag time
        power = int(lag[n] // lag[0])
        predicted[:, :, n] = np.linalg.matrix_power(base, power)
        estimated[:, :, n] = K[n]

    return [predicted, estimated]

In [None]:
# Chapman-Kolmogorov data for the VAMPnet.
# The network output is computed once with the currently loaded weights, so
# the `runs` repetitions below differ only in the random resampling of frames.
test_on = 10000  # number of time-lagged pairs drawn (with replacement) per estimate
pred = pred_batchwise(traj_whole_new[0], batchsize=10000)

K_ck_vamp = np.zeros((runs, tau_ck.shape[0], output_sizes[0], output_sizes[0]))
all_lx = []
all_rx = []
for r in range(runs):
    for i, tau_ck_i in enumerate(tau_ck):
        # time-lagged pairs (x_t, x_{t+tau}) for this lag
        pred_t = pred[:-tau_ck_i]
        pred_tau = pred[tau_ck_i:]
        frames_index = np.arange(pred_t.shape[0])
        indexes = np.random.choice(frames_index, size=test_on, replace=True)
        ck_traj = pred_t[indexes]
        ck_traj_tau = pred_tau[indexes]
        # the pairs are already shifted, hence lag 0 is passed to the estimator
        K_ck_vamp[r,i] = estimate_koopman_op([ck_traj, ck_traj_tau], 0)
    # lx_side: powers of the first matrix (prediction), rx_side: direct estimates
    lx_side, rx_side = get_ck(K_ck_vamp[r], tau_ck)
    all_lx.append(lx_side)
    all_rx.append(rx_side)

In [None]:
# Mean and min/max envelope over the repetitions (axis 0) for both sides of
# the CK test: lx = predicted, rx = estimated.
all_lx_arr = np.array(all_lx)
lx_mean = np.mean(all_lx_arr, axis=0)
lx_min = np.min(all_lx_arr, axis=0)
lx_max = np.max(all_lx_arr, axis=0)
all_rx_arr = np.array(all_rx)
rx_mean = np.mean(all_rx_arr, axis=0)
rx_min = np.min(all_rx_arr, axis=0)
rx_max = np.max(all_rx_arr, axis=0)

In [None]:
# Chapman-Kolmogorov plot for the VAMPnet: one panel per transition i -> j.
# Blue line and band: predicted (mean and min/max over repetitions);
# red dashed error bars: directly estimated values.
import matplotlib.gridspec as gridspec
output_size = output_sizes[0]
fig = plt.figure(figsize = (16,16))
gs1 = gridspec.GridSpec(output_size, output_size)
gs1.update(wspace=0.1, hspace=0.05)
states = output_size
for index_i in range(states):
    for index_j in range(states):
        ax = plt.subplot(gs1[index_i*output_size+index_j])
        ax.plot(tau_ck, lx_mean[index_i, index_j], color='b', lw=4)
        ax.fill_between(tau_ck,lx_min[index_i, index_j],lx_max[index_i, index_j], alpha = 0.25 )
        # asymmetric error bars: distance from the mean down to min and up to max
        ax.errorbar(tau_ck, rx_mean[index_i, index_j], yerr= np.array([rx_mean[index_i][index_j]-rx_min[index_i][index_j], rx_max[index_i][index_j]-rx_mean[index_i][index_j]]), color = 'r', lw=4, linestyle = '--')
        title = str(index_i+1)+ '->' +str(index_j+1)
        
        ax.text(.75,.8, title,
            horizontalalignment='center',
            transform=ax.transAxes,  fontdict = {'size':26})
    
        ax.set_ylim((-0.1,1.1));
        ax.set_xlim((0, tau_ck[-1]+5));
        
        # y ticks only on the left-most column
        if (index_j == 0):
            ax.axes.get_yaxis().set_ticks([0, 1])
            ax.yaxis.set_tick_params(labelsize=32)
        
        else:
            ax.axes.get_yaxis().set_ticks([])
        
        # x ticks only on the bottom row; positions are in frames, labels in
        # units of `fac`
        if (index_i == output_size -1):
            
            xticks = np.array([20,60])
            # kept from the original cell; not applied to the labels below
            float_formatter = lambda x: np.array([("%.1f" % y if y > 0.001 else "0") for y in x])
            
            ax.xaxis.set_ticks(xticks);
            ax.xaxis.set_ticklabels((xticks*fac));
            ax.xaxis.set_tick_params(labelsize=32)
        else:
            ax.axes.get_xaxis().set_ticks([])
            
        # unit label, placed relative to one fixed panel of the bottom row
        if (index_i == output_size - 1 and index_j == output_size - 4):
            # raw string: '\m' is not a valid escape sequence
            ax.text(2.16, -0.4, r"[$\mu$s]",
                horizontalalignment='center',
                transform=ax.transAxes,  fontdict = {'size':28})

# fig.savefig('../figs/ck_villin_states_{}_vamp.pdf'.format(states), bbox_inches='tight')
plt.show()

### Check for negative entries in the Koopman matrix

In [None]:
chi_t = pred_batchwise(tensor_train_X1)
chi_tau = pred_batchwise(tensor_train_X2)
K_vamp = estimate_koopman_op([chi_t, chi_tau], 0)
print(np.linalg.eigvals(K_vamp))
print(chi_t.max(0))

In [None]:
K_vamp

### Training a reversible deep MSM

In [None]:
Full_net.set_weights(weights_vamp)

### Helper functions to train for u and S individually

In [None]:
def train_for_S(runs=100, verbose=True, plot_training=True):
    """Optimise with ``optimizer_rev_S`` while everything upstream is frozen.

    The network output for the training and validation pairs is computed once,
    passed through ``Full_net.u_layers[0]`` and detached, so that
    ``vampe_loss_rev_only_S`` only back-propagates into whatever it adds on
    top of these fixed inputs.

    Parameters
    ----------
    runs : int
        Number of full-batch optimisation steps.
    verbose : bool
        Print training and validation score after every step.
    plot_training : bool
        Plot both score curves at the end.
    """

    def _frozen_inputs(frames_t, frames_tau):
        # Detached (v, C_00, C_11, C_01, Sigma) of the u-layer for one data split.
        chi_a = torch.Tensor(pred_batchwise(frames_t)).to(device)
        chi_b = torch.Tensor(pred_batchwise(frames_tau)).to(device)
        _, _, v, C_00, C_11, C_01, Sigma, _, _ = Full_net.u_layers[0](chi_a, chi_b)
        return tuple(t.detach() for t in (v, C_00, C_11, C_01, Sigma))

    train_inputs = _frozen_inputs(tensor_train_X1, tensor_train_X2)
    valid_inputs = _frozen_inputs(tensor_valid_X1, tensor_valid_X2)

    epoch_loss = np.zeros(runs)
    epoch_loss_valid = np.zeros(runs)
    opt = optimizer_rev_S
    for epoch in range(runs):  # every epoch is a single full-batch step

        opt.zero_grad()
        score_curr = vampe_loss_rev_only_S(*train_inputs)
        # the score is maximised, so its negative is minimised
        loss = - score_curr
        loss.backward()
        opt.step()

        epoch_loss[epoch] = score_curr.item()

        # validation is monitoring only: no autograd graph needed
        with torch.no_grad():
            score_curr_valid = vampe_loss_rev_only_S(*valid_inputs)
        epoch_loss_valid[epoch] = score_curr_valid.item()
        if verbose:
            print('Run {}, total loss: {:.3}, valid loss: {:.3}'.format(epoch+1, epoch_loss[epoch], epoch_loss_valid[epoch]))

    if plot_training:
        plt.plot(np.arange(1,runs+1), epoch_loss, label='VAMP_loss')
        plt.plot(np.arange(1,runs+1), epoch_loss_valid, label='VAMP_loss_valid')
        plt.legend()
        plt.show()

In [None]:
def train_for_u_S(runs=100, verbose=True, plot_training=True):
    """Optimise with ``optimizer_rev_u_S`` on a fixed network output.

    The network output for the training and validation pairs is computed once
    via ``pred_batchwise`` and detached; every step then maximises
    ``vampe_loss_rev`` on the full training set.

    Parameters
    ----------
    runs : int
        Number of full-batch optimisation steps.
    verbose : bool
        Print training and validation score after every step.
    plot_training : bool
        Plot both score curves at the end.
    """

    def _fixed_chi(frames):
        # Network output as a detached tensor on the compute device.
        return torch.Tensor(pred_batchwise(frames)).to(device).detach()

    chi_t = _fixed_chi(tensor_train_X1)
    chi_tau = _fixed_chi(tensor_train_X2)
    chi_t_valid = _fixed_chi(tensor_valid_X1)
    chi_tau_valid = _fixed_chi(tensor_valid_X2)

    epoch_loss = np.zeros(runs)
    epoch_loss_valid = np.zeros(runs)
    opt = optimizer_rev_u_S
    for epoch in range(runs):  # every epoch is a single full-batch step

        opt.zero_grad()
        score_curr = vampe_loss_rev(chi_t, chi_tau)
        # the score is maximised, so its negative is minimised
        loss = - score_curr
        loss.backward()
        opt.step()

        # validation is monitoring only: no autograd graph needed
        with torch.no_grad():
            score_curr_valid = vampe_loss_rev(chi_t_valid, chi_tau_valid)
        epoch_loss[epoch] = score_curr.item()
        epoch_loss_valid[epoch] = score_curr_valid.item()
        if verbose:
            print('Run {}, total loss: {:.3}, valid loss: {:.3}'.format(epoch+1, epoch_loss[epoch], epoch_loss_valid[epoch]))

    if plot_training:
        plt.plot(np.arange(1,runs+1), epoch_loss, label='VAMP_loss')
        plt.plot(np.arange(1,runs+1), epoch_loss_valid, label='VAMP_loss_valid')
        plt.legend()
        plt.show()

In [None]:
def get_its(data, lags, calculate_K = True, multiple_runs = False):
    """Compute implied timescales ``-tau / ln|lambda|`` for every lag time.

    Parameters
    ----------
    data : array or list
        If ``calculate_K`` is True, the input handed to
        ``estimate_koopman_op(data, tau)``. Otherwise the Koopman matrices
        themselves, ``data[t]`` belonging to ``lags[t]``. With
        ``multiple_runs`` one such object per run.
    lags : sequence
        Lag times.
    calculate_K : bool
        Estimate the Koopman matrix here (True) or take it from ``data``.
    multiple_runs : bool
        Treat ``data`` as an iterable of independent runs.

    Returns
    -------
    numpy.ndarray or list of numpy.ndarray
        One array of shape (n_states - 1, len(lags)) per run (a bare array if
        ``multiple_runs`` is False). Rows are ordered by ascending eigenvalue
        magnitude, i.e. the fastest process first; the largest eigenvalue is
        dropped.
    """

    def get_single_its(data):

        if isinstance(data, list):
            outputsize = data[0].shape[1]
        else:
            outputsize = data.shape[1]

        single_its = np.zeros((outputsize-1, len(lags)))

        for t, tau_lag in enumerate(lags):
            if calculate_K:
                koopman_op = estimate_koopman_op(data, tau_lag)
            else:
                koopman_op = data[t]
            # only the spectrum is needed, the eigenvectors are not
            k_eigvals = np.linalg.eigvals(np.real(koopman_op))
            # sort by magnitude and drop the largest (stationary) eigenvalue
            k_eigvals = np.sort(np.absolute(k_eigvals))[:-1]
            single_its[:,t] = (-tau_lag / np.log(k_eigvals))

        return single_its

    if not multiple_runs:
        return get_single_its(data)

    return [get_single_its(data_run) for data_run in data]

### Pretraining u and S

In [None]:
# Define optimizers
if Full_net.reversible:
    optimizer_rev = optim.Adam(Full_net.get_params_rev(), lr=learning_rate/10)
    
    optimizer_rev_S = optim.Adam(Full_net.get_params_rev(all=False, u_flag=False), lr=learning_rate*1000)
    optimizer_rev_u = optim.Adam(Full_net.get_params_rev(all=False, S_flag=False), lr=learning_rate*100)

In [None]:
# initialize u and S with values coming from the VAMPnet
if Full_net.reversible:
    Full_net.set_rev_var(S=True)
    train_for_S(runs=3000, verbose=False)

In [None]:
weights_after_S = Full_net.get_weights()

In [None]:
optimizer_rev_u_S = optim.Adam(Full_net.get_params_rev(all=False), lr=learning_rate*10)
optimizer_rev_S = optim.Adam(Full_net.get_params_rev(all=False, u_flag=False), lr=learning_rate*100)

In [None]:
Full_net.set_weights(weights_after_S)

In [None]:
if Full_net.reversible:
    for _ in range(10):
        train_for_u_S(runs=1000, verbose=False)
        train_for_S(runs=1000, verbose=False)

### Training for everything

In [None]:
def train_for_rev(runs, opt_chi, opt_u, opt_S, rel=0.01, reset=False, plot_mask_every=10, verbose=True, plot_training=True):
    """Train the whole model on ``vampe_loss_rev`` and keep the best weights.

    Every epoch consists of
    1. one pass over ``trainloader_full`` in which all three optimizers step,
    2. a refinement loop on the full training set in which only ``opt_S`` and
       ``opt_u`` step (the network output is recomputed with
       ``pred_batchwise``), repeated until the score gain of a step drops
       below ``rel``,
    3. a validation pass; the weights with the highest validation score seen
       so far are remembered and restored at the end.

    Parameters
    ----------
    runs : int
        Number of epochs.
    opt_chi, opt_u, opt_S : torch.optim.Optimizer
        Optimizers stepped as described above.
    rel : float
        Minimum score improvement for the refinement loop to continue.
    reset : bool
        If True, ``Full_net.set_rev_var(S=False)`` is called at the start of
        every refinement iteration.
    plot_mask_every : int
        Currently unused (the mask plotting is disabled).
    verbose : bool
        Print the per-epoch scores.
    plot_training : bool
        Plot training and validation curves at the end.
    """

    epoch_loss = np.zeros(runs)
    epoch_loss_valid = np.zeros(runs)
    opt_list = [opt_chi, opt_u, opt_S]

    best_score = 0.
    weights_best = Full_net.get_weights()
    for epoch in range(runs):  # loop over the dataset multiple times

        running_epoch_loss = []

        for inputs_t, inputs_tau in trainloader_full:
            # zero the parameter gradients
            for opt in opt_list:
                opt.zero_grad()

            # the loader wraps each batch in an extra leading dimension
            chi_t = Full_net(inputs_t[0].to(device))
            chi_tau = Full_net(inputs_tau[0].to(device))

            score_curr = vampe_loss_rev(chi_t, chi_tau)
            # the score is maximised, so its negative is minimised
            loss = - score_curr
            loss.backward()

            for opt in opt_list:
                opt.step()

            running_epoch_loss.append(score_curr.item())

        print(running_epoch_loss)
        score_before = np.mean(running_epoch_loss)
        while True:

            opt_S.zero_grad()
            opt_u.zero_grad()
            if reset:
                Full_net.set_rev_var(S=False)
            chi_t = torch.Tensor(pred_batchwise(tensor_train_X1)).to(device)
            chi_tau = torch.Tensor(pred_batchwise(tensor_train_X2)).to(device)

            score_curr = vampe_loss_rev(chi_t, chi_tau)
            loss = -score_curr
            loss.backward()

            opt_S.step()
            opt_u.step()
            # keep only the Python float: holding on to the tensor would keep
            # its autograd graph alive and mix tensor/float comparisons
            score_now = score_curr.item()
            print(score_before, score_now)
            if score_now - score_before < rel:
                break
            score_before = score_now

        # validation (monitoring only, no gradients required)
        with torch.no_grad():
            chi_t_valid = torch.Tensor(pred_batchwise(tensor_valid_X1)).to(device)
            chi_tau_valid = torch.Tensor(pred_batchwise(tensor_valid_X2)).to(device)
            score_valid = vampe_loss_rev(chi_t_valid, chi_tau_valid)

        epoch_loss[epoch] = np.mean(running_epoch_loss)
        epoch_loss_valid[epoch] = score_valid.item()
        if verbose:
            print('Run {}, total loss: {:.3}, valid loss: {:.3}'.format(epoch+1, epoch_loss[epoch], epoch_loss_valid[epoch]))
        if epoch_loss_valid[epoch]> best_score:
            best_score = epoch_loss_valid[epoch]
            weights_best = Full_net.get_weights()
            print('better weights')
#         if (((epoch+1) % plot_mask_every)==0):
#             plot_mask(vmax=1/n_residues*10)
    if plot_training:
        plt.plot(np.arange(1,runs+1), epoch_loss, label='VAMP_loss')
        plt.plot(np.arange(1,runs+1), epoch_loss_valid, label='VAMP_loss_valid')
        plt.legend()
        plt.show()

    print('Set best weights')
    Full_net.set_weights(weights_best)

In [None]:
weights_before_rec = Full_net.get_weights()

In [None]:
opt1 = optim.Adam(Full_net.get_params_vamp(), lr=learning_rate/5,) 
opt2 = optim.Adam(Full_net.get_params_rev(all=False, u_flag=False), lr=learning_rate*10)
opt3 = optim.Adam(Full_net.get_params_rev(all=False, S_flag=False), lr=learning_rate*1) # 

In [None]:
Full_net.set_weights(weights_before_rec)

In [None]:
plot_mask( vmax=1/n_residues * 3)

In [None]:
chi_t = pred_batchwise(tensor_train_X1)
chi_tau = pred_batchwise(tensor_train_X2)
K_vamp = estimate_koopman_op([chi_t, chi_tau], 0)
print(np.linalg.eigvals(K_vamp))
print(chi_t.max(0))

In [None]:
weights_vamp = Full_net.get_weights()

In [None]:
Full_net.set_weights(weights_vamp)

### Train first without resetting u

In [None]:
train_for_rev(100, opt1, opt2, opt3, rel=0.0001, reset=False)

### Afterwards, reset u to check whether the result was stuck in a suboptimal region

In [None]:
Full_net.set_rev_var(S=False)
train_for_rev(100, opt1, opt2, opt3, rel=0.0001, reset=True)

In [None]:
weights_before_rec = Full_net.get_weights()

In [None]:
Full_net.set_weights(weights_before_rec)

In [None]:
def get_K_rev(tensor_t=None, tensor_tau=None):
    """Return the transition matrix of the reversible model for given pairs.

    Parameters
    ----------
    tensor_t, tensor_tau : torch.Tensor, optional
        Input frames at time t and t + tau. If omitted, the first trajectory
        of ``traj_whole_new`` shifted by the current global ``tau`` is used.
        The defaults are resolved at call time, so they follow later changes
        of ``tau`` or the data and no tensor is pinned at definition time.

    Returns
    -------
    torch.Tensor
        The detached matrix ``K`` returned by ``Full_net.S_layers[0]``.
    """
    if tensor_t is None:
        tensor_t = torch.Tensor(traj_whole_new[0][:-tau])
    if tensor_tau is None:
        tensor_tau = torch.Tensor(traj_whole_new[0][tau:])

    chi_t = torch.Tensor(pred_batchwise(tensor_t)).to(device)
    chi_tau = torch.Tensor(pred_batchwise(tensor_tau)).to(device)
    # the u-layer provides the inputs of the S-layer, which yields K as its
    # second output
    _,_,v, C_00, C_11, C_01, Sigma, _, _ = Full_net.u_layers[0](chi_t, chi_tau)
    _, K, _ = Full_net.S_layers[0](v, C_00, C_11, C_01, Sigma)

    return K.detach()

In [None]:
K_rev = get_K_rev().to('cpu')

### Compare the eigenvalues for VAMPnet and revDMSM case

In [None]:
# The eigenvalues of the VAMPnet will be higher since there are no restriction for the eigenfunctions!
np.linalg.eigvals(K_rev), np.linalg.eigvals(K_vamp)

In [None]:
weights_after_rec = Full_net.get_weights()

### Estimate K_revs for timescales and cktest


In [None]:
# Re-estimate the reversible model at every lag time: chi is kept fixed at
# `weights_after_rec`, S is reset to ones and then S / (u, S) are retrained
# for the new lag. The resulting matrices feed both the implied timescales
# and the CK test.
# NOTE: this cell overwrites the global tensors and loaders
# (tensor_train_X1, trainloader, ...) with the data of the last lag time.
# NOTE(review): every run starts from the same weights; unless
# get_data_for_tau or the training is stochastic the runs may be identical --
# confirm.
K_results_6_rev = np.ones((runs, len(lag) ,output_sizes[0], output_sizes[0]))

for r in range(runs):    
    
    for i, tau_i in enumerate(lag):
        print(r, i, tau_i)
        X1_train, X2_train, X1_vali, X2_vali, X1_test, X2_test, length_train, length_vali, _, _ = get_data_for_tau(traj_whole_new, tau_i)

        tensor_train_X1 = torch.Tensor(X1_train)
        tensor_train_X2 = torch.Tensor(X2_train) # transform to torch tensor
        tensor_valid_X1 = torch.Tensor(X1_vali)
        tensor_valid_X2 = torch.Tensor(X2_vali)
        tensor_test_X1 = torch.Tensor(X1_test)
        tensor_test_X2 = torch.Tensor(X2_test)

        trainset = data.TensorDataset(tensor_train_X1, tensor_train_X2) # create your dataset

        trainloader = data.DataLoader(trainset, sampler=data.BatchSampler(data.RandomSampler(trainset), batch_size=batch_size, drop_last=True))
        # one batch containing the whole training set
        trainloader_full = data.DataLoader(trainset, sampler=data.BatchSampler(data.RandomSampler(trainset), batch_size=X1_train.shape[0], drop_last=True))

        testset = data.TensorDataset(tensor_valid_X1, tensor_valid_X2) # create your dataset
        testloader = data.DataLoader(testset, sampler=data.BatchSampler(data.RandomSampler(testset), batch_size=X1_vali.shape[0], drop_last=True))

        full_batch = False
        if full_batch:
            train_l = trainloader_full
        else:
            train_l = trainloader
        
        Full_net.set_weights(weights_after_rec)
        
        # reset S to ones; u keeps the values from weights_after_rec
        # (its reset is left disabled)
        with torch.no_grad():
            for param in Full_net.S_layers[0].parameters():
                param.copy_(torch.Tensor(np.ones((output_sizes[0], output_sizes[0]))))
#         with torch.no_grad():
#             for param in Full_net.u_layers[0].parameters():
#                 param.copy_(torch.Tensor(np.ones((1, output_sizes[0]))))

        # fresh optimizers, so no Adam state is carried over between lag times
        optimizer_rev_u_S = optim.Adam(Full_net.get_params_rev(all=False), lr=learning_rate*10)
        optimizer_rev_S = optim.Adam(Full_net.get_params_rev(all=False, u_flag=False), lr=learning_rate*100)
        for _ in range(10):
            train_for_S(runs=1000, verbose=False, plot_training=True)

            train_for_u_S(runs=1000, verbose=False, plot_training=True)
        # the matrix is evaluated on the test split
        K_results_6_rev[r,i]  = get_K_rev(tensor_test_X1, tensor_test_X2).to('cpu')

its_6_all_rev = get_its(K_results_6_rev, lag, False, multiple_runs=True)

# `lag` is tau_ck preceded by a few extra short lags; the CK test only uses
# the trailing tau_ck part, so derive the offset instead of hard-coding it
K_ck_rev = K_results_6_rev[:, len(lag) - len(tau_ck):]
all_lx_rev = []
all_rx_rev = []
for r in range(runs):
    lx_side, rx_side = get_ck(K_ck_rev[r], tau_ck)
    all_lx_rev.append(lx_side)
    all_rx_rev.append(rx_side)

In [None]:
# Implied-timescale plot for the reversible model: the mean over runs is
# drawn as a line, the min/max over runs as a shaded band.
all_its_rev_np = np.array(its_6_all_rev)
all_its_rev_mean = all_its_rev_np.mean(0)
all_its_rev_min = all_its_rev_np.min(0)
all_its_rev_max = all_its_rev_np.max(0)
# fac converts frame counts into the plotted time unit (microseconds);
# NOTE(review): presumably 200 ps per stored frame times `skip` -- confirm.
fac = 200.*skip*1e-6 
# fac = 0.0002

plt.figure(figsize=(6,4));

label_x = np.array([.1,0.3,1, 2, 5,10,100,1000])/fac # array is in microsecond
label_y = np.array([.1,1, 2, 5,10, 100, 1000])/fac
# fig = plt.figure(figsize = (8,8))
for j in range(0,output_sizes[0]-1):
    # rows of the ITS arrays are sorted fastest-first, so reverse them to
    # plot the slowest process first
    plt.semilogy(lag, all_its_rev_mean[::-1][j], lw=5)
    plt.fill_between(lag, all_its_rev_min[::-1][j], all_its_rev_max[::-1][j], alpha = 0.3)
# timescales below the lag time cannot be resolved: mark that region in grey
plt.semilogy(lag,lag, 'k')
# raw strings: '\m' is not a valid escape sequence in a normal string literal
plt.xlabel(r'lag [$\mu$s]', fontsize=26)
plt.xticks(label_x, label_x*fac, fontsize=22)
plt.ylabel(r'timescale [$\mu$s]', fontsize=26)
plt.yticks(label_y, np.round(label_y*fac, decimals=1), fontsize=22)
plt.fill_between(lag,lag,0.1,alpha = 0.2,color='k');
plt.ylim(0.01/fac, 3/fac)
plt.xlim(lag[0], 0.3/fac)

In [None]:
# Chapman-Kolmogorov plot for the reversible model.
# Statistics over the runs (axis 0): lx = predicted, rx = estimated.
all_lx_rev_arr = np.array(all_lx_rev)
lx_rev_mean = np.mean(all_lx_rev_arr, axis=0)
lx_rev_min = np.min(all_lx_rev_arr, axis=0)
lx_rev_max = np.max(all_lx_rev_arr, axis=0)
all_rx_rev_arr = np.array(all_rx_rev)
rx_rev_mean = np.mean(all_rx_rev_arr, axis=0)
rx_rev_min = np.min(all_rx_rev_arr, axis=0)
rx_rev_max = np.max(all_rx_rev_arr, axis=0)

# One panel per transition i -> j. Blue line and band: predicted;
# red dashed error bars: directly estimated values.
import matplotlib.gridspec as gridspec
output_size = output_sizes[0]
fig = plt.figure(figsize = (16,16))
gs1 = gridspec.GridSpec(output_size, output_size)
gs1.update(wspace=0.1, hspace=0.05)
states = output_size
for index_i in range(states):
    for index_j in range(states):
        ax = plt.subplot(gs1[index_i*output_size+index_j])
        ax.plot(tau_ck, lx_rev_mean[index_i, index_j], color='b', lw=4)
        ax.fill_between(tau_ck,lx_rev_min[index_i, index_j],lx_rev_max[index_i, index_j], alpha = 0.25 )
        # asymmetric error bars: distance from the mean down to min and up to max
        ax.errorbar(tau_ck, rx_rev_mean[index_i, index_j], yerr= np.array([rx_rev_mean[index_i][index_j]-rx_rev_min[index_i][index_j], rx_rev_max[index_i][index_j]-rx_rev_mean[index_i][index_j]]), color = 'r', lw=4, linestyle = '--')
        title = str(index_i+1)+ '->' +str(index_j+1)
        
        ax.text(.75,.8, title,
            horizontalalignment='center',
            transform=ax.transAxes,  fontdict = {'size':26})
    
        ax.set_ylim((-0.1,1.1));
        ax.set_xlim((0, tau_ck[-1]+5));
        
        # y ticks only on the left-most column
        if (index_j == 0):
            ax.axes.get_yaxis().set_ticks([0, 1])
            ax.yaxis.set_tick_params(labelsize=32)
        
        else:
            ax.axes.get_yaxis().set_ticks([])
        
        # x ticks only on the bottom row; positions are in frames, labels in
        # units of `fac`
        if (index_i == output_size -1):
            
            xticks = np.array([20,60])
            # kept from the original cell; not applied to the labels below
            float_formatter = lambda x: np.array([("%.1f" % y if y > 0.001 else "0") for y in x])
            
            ax.xaxis.set_ticks(xticks);
            ax.xaxis.set_ticklabels((xticks*fac));
            ax.xaxis.set_tick_params(labelsize=32)
        else:
            ax.axes.get_xaxis().set_ticks([])
            
        # unit label, placed relative to one fixed panel of the bottom row
        if (index_i == output_size - 1 and index_j == output_size - 4):
            # raw string: '\m' is not a valid escape sequence
            ax.text(2.16, -0.4, r"[$\mu$s]",
                horizontalalignment='center',
                transform=ax.transAxes,  fontdict = {'size':28})

# fig.savefig('../figs/ck_villin_states_{}_vamp.pdf'.format(states), bbox_inches='tight')
plt.show()

### Plotting the network graphs for VAMPnet and revDMSM

In [None]:
import networks

In [None]:
# Stationary distribution of the VAMPnet model: the left eigenvector of
# K_vamp (eigenvector of K_vamp.T) belonging to the largest eigenvalue,
# normalised to sum to one. Index 0 of the descending sort is the largest
# eigenvalue; the second eigenvector is not a distribution.
eigvals, eigvec = np.linalg.eig(K_vamp.T)
sort_ind = np.argsort(eigvals)[::-1]
pi_vamp = eigvec[:,sort_ind[0]]
pi_vamp = pi_vamp/pi_vamp.sum()

In [None]:
# Stationary distribution of the reversible model: eigenvector of K_rev.T
# for the largest eigenvalue (first index of the descending sort),
# normalised to sum to one.
eigvals, eigvec = np.linalg.eig(K_rev.T)
sort_ind = np.argsort(eigvals)[::-1]
pi_rev = eigvec[:,sort_ind[0]]
pi_rev = pi_rev/pi_rev.sum()

In [None]:
pos = 0.25*np.array([[0.,0.],[1.,0.],[0.,1.], [1.,1.]])

In [None]:
# labels have to be adapted to the specific use case
state_labels = ['F', 'U', 'M', 'PF']

In [None]:
K_vamp

In [None]:
fig, pos = networks.plot_network(K_vamp, state_sizes=pi_rev, pos = pos, arrow_label_format='%2.4f',
                                 state_scale=1., arrow_curvature=1, state_labels=state_labels,
                     arrow_scale=2., state_colors='lightsalmon', arrow_threshold=1e-7)

In [None]:
fig, pos = networks.plot_network(K_rev.numpy(), state_sizes=pi_rev, pos = pos, arrow_label_format='%2.4f',
                                 state_scale=1., arrow_curvature=1, state_labels=state_labels,
                     arrow_scale=2., state_colors='lightsalmon', arrow_threshold=1e-7)

### Save structures for defined states

In [None]:
# has to be modified depending on where the data is stored on the user's machine
import mdtraj as md
number_of_frames = 20
md_traj_small = md.load_dcd(root+'-000.dcd',
                      top=root+'.pdb'
                )[:number_of_frames]
md_traj_super = md.load_pdb('../../dmsm_recap/vampnets/folded_states/2f4k_villin.pdb')
# number of residues for the system
n_residues=35

In [None]:
# estimate secondary structure for alignment
def get_heavy_atoms_dssp(dssp, in_all=True):
    """Collect the heavy atoms of residues that are in a helix or strand.

    Parameters
    ----------
    dssp : array
        Simplified DSSP codes indexed as ``dssp[frame][residue]``.
    in_all : bool
        If True a residue only counts when it is 'H' or 'E' in every frame,
        otherwise only the first frame is checked.

    Returns
    -------
    list_atoms : list of int
        Indices (in ``md_traj_super.topology``) of all non-hydrogen atoms of
        the selected residues.
    list_res2nd : list of bool
        One flag per residue (``n_residues`` entries), True if selected.
    """
    nd_s = ('H', 'E')
    frames_top = dssp.shape[0] if in_all else 1

    # a residue qualifies only if it has secondary structure in all
    # inspected frames
    list_res2nd = [
        all(dssp_frame[res] in nd_s for dssp_frame in dssp[:frames_top])
        for res in range(n_residues)
    ]

    list_atoms = []
    for res in range(n_residues):
        if not list_res2nd[res]:
            continue
        residue = md_traj_super.topology.residue(res)
        # Filter by element rather than by name: only the amide hydrogen is
        # literally called 'H', other hydrogens carry names such as 'HA'.
        list_atoms += [
            atom.index for atom in residue.atoms
            if atom.element is None or atom.element.symbol != 'H'
        ]
    return list_atoms, list_res2nd



In [None]:
dssp = md.compute_dssp(md_traj_small, simplified=True)[:,:n_residues]
dssp_super = md.compute_dssp(md_traj_super, simplified=True)[:,:n_residues]
index_heavy_2nd = get_heavy_atoms_dssp(dssp)

In [None]:
def structures(cg_ind, vmax=1, skip=10, top=10):
    """Plot the attention values and return them.

    Parameters
    ----------
    cg_ind : int
        Number of coarse-graining layers (``Full_net.forward_cg``) applied to
        the network output before the most probable frames are selected.
        Ignored if ``Full_net.mask_const`` is set.
    vmax : float
        Upper limit of the colour scale.
    skip : int
        Spacing of the y-axis tick labels.
    top : int
        Number of highest-probability frames taken per state.

    Returns
    -------
    If ``Full_net.mask_const`` is set: ``att_atom``, the constant attention
    reshaped to (n_residues, 1).
    Otherwise a tuple ``(att_residue, top_x_state)``: the attention of the
    selected frames stacked over states (states first, then frames), and the
    frame indices of shape (top, states), most probable frame first.
    """
    # inverts input_size = m*(m+1)/2 for m and adds 3
    n_residues = int(-1/2 + np.sqrt(1/4+input_size*2) + 3)
    if Full_net.mask_const:
        attention = Full_net.get_attention()
        # move to the CPU first (as in the branch below); .numpy() fails on
        # CUDA tensors
        attention_np = attention.detach().to('cpu').numpy()
        att_atom = np.reshape(attention_np, (n_residues,1))
        plt.imshow(att_atom, vmin=0, vmax=vmax, aspect='auto')
        plt.xlabel('System', fontsize=18)
        plt.ylabel('Input', fontsize=18)
        plt.xticks(np.arange(1),['{}'.format(i) for i in range(1)], fontsize=16)
        plt.yticks(np.arange(0,n_residues,skip),['x{}'.format(i) for i in range(0,n_residues,skip)], fontsize=16)
        plt.show()
    #     plt.savefig('./Figs/2x3_mix_Mask.pdf', bbox_inches='tight')

        return att_atom

    else:
        pred_temp = pred_batchwise(traj_whole[0], batchsize=10000)
        pred_temp = torch.Tensor(pred_temp).to(device)
        for i in range(cg_ind):
            pred_temp = Full_net.forward_cg(pred_temp, i)
        pred_temp = pred_temp.detach().to('cpu').numpy()
        print(pred_temp.shape)
        # per state (column): indices of the `top` frames with the highest
        # value, largest first
        arg_sort = np.argsort(pred_temp, axis=0)
        top_x_state = arg_sort[-top:][::-1]
        states = pred_temp.shape[1]
        att_atom = []
        for state in range(states):
            frames = top_x_state[:,state]
            attention = Full_net.get_attention(torch.Tensor(traj_whole[0][frames]).to(device))
            attention_np = attention.detach().to('cpu').numpy()
#             att_atom.append(np.mean(attention_np, axis=0, keepdims=True))
            print(attention_np.shape)
            att_atom.append(attention_np[None,:,:])
        att_residue = np.concatenate(att_atom, axis=0)

        # the plot shows the attention averaged over the selected frames
        plt.imshow(np.mean(att_residue,axis=1).T, vmin=0, vmax=vmax, aspect='auto')
        plt.xlabel('State', fontsize=18)
        plt.ylabel('Input', fontsize=18)
        plt.xticks(np.arange(states),['{}'.format(i) for i in range(states)], fontsize=16)
        plt.yticks(np.arange(0,n_residues,skip),['x{}'.format(i) for i in range(0,n_residues,skip)], fontsize=16)
        plt.show()
    #     plt.savefig('./Figs/2x3_mix_Mask.pdf', bbox_inches='tight')

        return att_residue, top_x_state




In [None]:
Full_net.set_weights(weights_after_rec)

### Save structures for revDMSM

In [None]:
# For every state of the reversible model, write the `number_of_frames` most
# probable frames to a PDB file with the attention values as B-factors.
# NOTE(review): the unpacking below requires structures() to return a tuple,
# i.e. Full_net.mask_const must be False -- confirm.
traj_nr=0
threshold = 0   # attention values below this are set to zero
out_ind=0
attention, frames = structures(out_ind, top=number_of_frames, vmax=3)

# loop invariants
frames_per_file = 10000   # number of frames stored in each .dcd file
root_new = '/group/ag_cmb/simulation-data/DESRES-Science2011-FastProteinFolding/DESRES-Trajectory_{0}-{1}-protein/{0}-{1}-protein/{0}-{1}-protein'.format(test_system, traj_nr)
top_file = '/group/ag_cmb/scratch/deeptime_data/{0}/system-protein.pdb'.format(test_system)

for o_temp in range(frames.shape[1]):
    
    o = o_temp
    print(o)
    # thresholded copy of the attention; entries outside the
    # (number_of_frames, n_residues) window stay zero
    attention_clean = np.zeros_like(attention[o])
    att_window = attention[o, :number_of_frames, :n_residues]
    attention_clean[:number_of_frames, :n_residues] = np.where(att_window < threshold, 0, att_window)
    # duplicate the per-residue values, then expand them to one value per atom
    attention_fixed = np.concatenate([attention_clean, attention_clean], axis=1)
    bfactors = np.repeat(attention_fixed, [res.n_atoms for res in md_traj_super.top.residues], axis=1)

    frames_of_state = frames[:,o]
    for i in range(number_of_frames):
        # map the global frame index to (file, frame within file)
        file_number = frames_of_state[i]//frames_per_file
        file_number_frame = frames_of_state[i]%frames_per_file
        # read only the required frame instead of the whole trajectory file
        md_traj_start = md.load_frame(root_new+'-{:03}.dcd'.format(file_number),
                                      int(file_number_frame),
                                      top=top_file)
        md_traj_small.xyz[i] = md_traj_start.xyz[0]

    md_traj_small.superpose(md_traj_super) 
    dssp = md.compute_dssp(md_traj_small, simplified=True)[:,:n_residues]

    # align on residues with stable secondary structure if there are any
    index_heavy_2nd, list_res = get_heavy_atoms_dssp(dssp)
    print('Index of residues in 2nd structure', np.arange(n_residues)[list_res])
    if index_heavy_2nd:
        md_traj_small_temp = md_traj_small.superpose(md_traj_small, frame=0, atom_indices=index_heavy_2nd)
        print('found secondary structure and aligning with {} atoms'.format(len(index_heavy_2nd)))
    else:
        md_traj_small_temp = md_traj_small.superpose(md_traj_small, frame=0, atom_indices=None)
        print('did not find secondary structure and aligning with pdb {} atoms'.format(len(index_heavy_2nd)))

    # the template has four fields; the former fifth argument was never used
    md_traj_small_temp.save_pdb('/group/ag_cmb/scratch/deeptime_data/{0}/attention/rev_{1}_{2}_smooth{3}.pdb'.format(test_system, output_sizes[out_ind], o_temp, patchsize), bfactors=bfactors)

In [None]:
Full_net.set_weights(weights_vamp)

### Save structures for VAMPnet

In [None]:
# For every state of the VAMPnet, write the `number_of_frames` most probable
# frames to a PDB file with the attention values as B-factors.
# NOTE(review): the unpacking below requires structures() to return a tuple,
# i.e. Full_net.mask_const must be False -- confirm.
traj_nr=0
threshold = 0   # attention values below this are set to zero
out_ind=0
attention, frames = structures(out_ind, top=number_of_frames, vmax=3)

# loop invariants
frames_per_file = 10000   # number of frames stored in each .dcd file
root_new = '/group/ag_cmb/simulation-data/DESRES-Science2011-FastProteinFolding/DESRES-Trajectory_{0}-{1}-protein/{0}-{1}-protein/{0}-{1}-protein'.format(test_system, traj_nr)
top_file = '/group/ag_cmb/scratch/deeptime_data/{0}/system-protein.pdb'.format(test_system)

for o_temp in range(frames.shape[1]):
    
    o = o_temp
    print(o)
    # thresholded copy of the attention; entries outside the
    # (number_of_frames, n_residues) window stay zero
    attention_clean = np.zeros_like(attention[o])
    att_window = attention[o, :number_of_frames, :n_residues]
    attention_clean[:number_of_frames, :n_residues] = np.where(att_window < threshold, 0, att_window)
    # duplicate the per-residue values, then expand them to one value per atom
    attention_fixed = np.concatenate([attention_clean, attention_clean], axis=1)
    bfactors = np.repeat(attention_fixed, [res.n_atoms for res in md_traj_super.top.residues], axis=1)

    frames_of_state = frames[:,o]
    for i in range(number_of_frames):
        # map the global frame index to (file, frame within file)
        file_number = frames_of_state[i]//frames_per_file
        file_number_frame = frames_of_state[i]%frames_per_file
        # read only the required frame instead of the whole trajectory file
        md_traj_start = md.load_frame(root_new+'-{:03}.dcd'.format(file_number),
                                      int(file_number_frame),
                                      top=top_file)
        md_traj_small.xyz[i] = md_traj_start.xyz[0]

    md_traj_small.superpose(md_traj_super) 
    dssp = md.compute_dssp(md_traj_small, simplified=True)[:,:n_residues]

    # align on residues with stable secondary structure if there are any
    index_heavy_2nd, list_res = get_heavy_atoms_dssp(dssp)
    print('Index of residues in 2nd structure', np.arange(n_residues)[list_res])
    if index_heavy_2nd:
        md_traj_small_temp = md_traj_small.superpose(md_traj_small, frame=0, atom_indices=index_heavy_2nd)
        print('found secondary structure and aligning with {} atoms'.format(len(index_heavy_2nd)))
    else:
        md_traj_small_temp = md_traj_small.superpose(md_traj_small, frame=0, atom_indices=None)
        print('did not find secondary structure and aligning with pdb {} atoms'.format(len(index_heavy_2nd)))

    # the template has four fields; the former fifth argument was never used
    md_traj_small_temp.save_pdb('/group/ag_cmb/scratch/deeptime_data/{0}/attention/vamp_{1}_{2}_smooth{3}.pdb'.format(test_system, output_sizes[out_ind], o_temp, patchsize), bfactors=bfactors)

In [None]:
Full_net.set_weights(weights_after_rec)

### Train revDMSM for different tau

In [None]:
from torch.utils import data

In [None]:
# Split the trajectory for the current lag time `tau` into train / validation /
# test arrays (X1 and X2 are presumably the time-lagged partners — see
# get_data_for_tau).
(X1_train, X2_train, X1_vali, X2_vali, X1_test, X2_test,
 length_train, length_vali, _, _) = get_data_for_tau(traj_whole_new, tau)

# Convert every split to a torch tensor.
tensor_train_X1, tensor_train_X2 = torch.Tensor(X1_train), torch.Tensor(X2_train)
tensor_valid_X1, tensor_valid_X2 = torch.Tensor(X1_vali), torch.Tensor(X2_vali)
tensor_test_X1, tensor_test_X2 = torch.Tensor(X1_test), torch.Tensor(X2_test)


def _random_batch_loader(dataset, size):
    """Return a DataLoader whose sampler hands out whole random index batches.

    Each sample drawn by the loader is a list of `size` random indices; the
    last incomplete batch is dropped.
    """
    index_batches = data.BatchSampler(data.RandomSampler(dataset), batch_size=size, drop_last=True)
    return data.DataLoader(dataset, sampler=index_batches)


# Training pairs, served in small and in large batches.
trainset = data.TensorDataset(tensor_train_X1, tensor_train_X2)
trainloader = _random_batch_loader(trainset, batch_size)
trainloader_full = _random_batch_loader(trainset, batch_size_large)

# Despite the name, the "test" loader wraps the validation split.
testset = data.TensorDataset(tensor_valid_X1, tensor_valid_X2)
testloader = _random_batch_loader(testset, batch_size)

# Select which training loader the training routines should use.
full_batch = False
train_l = trainloader_full if full_batch else trainloader

### Retrain for the new tau

In [None]:
# Start from the stored weights, then overwrite every parameter of the first
# S layer and the first u layer with ones before retraining them.
Full_net.set_weights(weights_after_rec)
n_out = output_sizes[0]
with torch.no_grad():
    for param in Full_net.S_layers[0].parameters():
        param.copy_(torch.ones(n_out, n_out))
    for param in Full_net.u_layers[0].parameters():
        param.copy_(torch.ones(1, n_out))

# Fresh optimizers: one for the parameters returned with the default u_flag,
# one for those returned with u_flag=False (see Full_net.get_params_rev).
optimizer_rev_u_S = optim.Adam(
    Full_net.get_params_rev(all=False),
    lr=learning_rate*10,
)
optimizer_rev_S = optim.Adam(
    Full_net.get_params_rev(all=False, u_flag=False),
    lr=learning_rate*100,
)
Full_net.set_rev_var()

In [None]:
# Five rounds of alternating training: 1000 runs of train_for_S followed by
# 1000 runs of train_for_u_S, each plotting its training curve.
for _ in range(5):
    train_for_S(runs=1000, verbose=False, plot_training=True)

    train_for_u_S(runs=1000, verbose=False, plot_training=True)

### Compare again with the VAMPnet, the difference should become smaller with larger tau

In [None]:
# K of the reversible model, evaluated on the test pairs and moved to the CPU.
K_msm = get_K_rev(tensor_test_X1, tensor_test_X2).to('cpu')
# Network outputs for both halves of the test pairs.
chi_t = pred_batchwise(tensor_test_X1)
chi_tau = pred_batchwise(tensor_test_X2)
# Estimate computed directly from the network outputs, for comparison with K_msm
# (the meaning of the second argument, 0, is defined by estimate_koopman_op).
K_vamp_msm = estimate_koopman_op([chi_t, chi_tau], 0)

In [None]:
# Display the eigenvalues of both estimates side by side.
np.linalg.eigvals(K_msm), np.linalg.eigvals(K_vamp_msm)

In [None]:
# Keep the current network parameters so they can be restored later.
weights_msm = Full_net.get_weights()

In [None]:
# Restore the parameters stored in `weights_msm`.
Full_net.set_weights(weights_msm)

### Estimate eigenfunctions

In [None]:
# Set the noise attribute of the mask layer to zero (presumably switches off
# stochastic noise in the attention mask for the analysis — confirm).
Full_net.Mask.noise=0.

In [None]:
# Time-lagged pairs built from the first trajectory: frame t and frame
# t + tau*skip.
# NOTE(review): this cell uses traj_whole and a lag of tau*skip, whereas other
# cells use traj_whole_new and a lag of tau — confirm this is intended.
tensor_t = torch.Tensor(traj_whole[0][:-tau*skip])
tensor_tau = torch.Tensor(traj_whole[0][tau*skip:])
# K of the reversible model on these pairs (moved to the CPU) and the
# corresponding network outputs.
K_msm = get_K_rev(tensor_t, tensor_tau).to('cpu')
chi_t = pred_batchwise(tensor_t)
chi_tau = pred_batchwise(tensor_tau)

In [None]:
# Eigen-decomposition of K_msm (right eigenvectors are the columns of eigvec).
eigvals, eigvec = np.linalg.eig(K_msm)
print(eigvals, eigvec)
# Indices that order the eigenvalues from largest to smallest.
sort_id = np.argsort(eigvals)[::-1]
print(eigvals[sort_id])

In [None]:
# Compute DSSP secondary-structure codes for every frame of the 63 DCD files,
# once with the simplified alphabet and once with the full one, keeping only
# the first n_residues columns.
dssp_all = []
dssp_all_high = []
for dcc_traj in range(63):
    # File suffix with the index zero-padded to three digits (-000.dcd ... -062.dcd).
    s = '-{:03d}.dcd'.format(dcc_traj)
    md_temp = md.load_dcd(root + s, top=root + '.pdb')[::1]
    for simplified, collected in ((True, dssp_all), (False, dssp_all_high)):
        dssp_temp = md.compute_dssp(md_temp, simplified=simplified)[:, :n_residues]
        collected.append(dssp_temp)
# Stack the per-file results along the frame axis.
dssp_all_np = np.concatenate(dssp_all, axis=0)
dssp_all_high_np = np.concatenate(dssp_all_high, axis=0)

In [None]:
# Network output for every frame of the first trajectory.
pred_ord = pred_batchwise(traj_whole[0])

In [None]:
def moving_average(a, n=3) :
    """Return the running mean of `a` over windows of `n` consecutive values.

    The input is flattened and accumulated as float. The result has
    ``len(a) - n + 1`` entries and is empty when `a` has fewer than `n`
    values.
    """
    running_total = np.cumsum(a, dtype=float)
    # First window sum is the running total itself; every later window sum is
    # the difference of two running totals that lie n positions apart.
    first_window = running_total[n - 1:n]
    later_windows = running_total[n:] - running_total[:-n]
    return np.concatenate((first_window, later_windows)) / n

### Include only secondary structure

In [None]:
# Half-open residue ranges [start, end) that are selected as "native".
start_ind = [2, 13, 21]
end_ind = [12, 20, 33]

# Residue indices inside the three ranges ...
native_ind = np.concatenate(
    [np.arange(first, last) for first, last in zip(start_ind, end_ind)]
)
# ... and all remaining indices: before the first range, between consecutive
# ranges, and from the end of the last range up to n_residues.
gap_bounds = zip([0] + end_ind, start_ind + [n_residues])
not_native_ind = np.concatenate(
    [np.arange(first, last) for first, last in gap_bounds]
)

In [None]:
# DSSP codes of the reference structure md_traj_super, in the simplified and
# in the full alphabet, restricted to the first n_residues columns.
dssp_super = md.compute_dssp(md_traj_super, simplified=True)[:,:n_residues]
dssp_super_high = md.compute_dssp(md_traj_super, simplified=False)[:,:n_residues]

In [None]:
# Plot the two slowest non-trivial eigenfunctions and, along the same ordering,
# the fraction of native secondary structure.
sk = 25  # keep every sk-th configuration of the sorted list
possible_not_native = len(not_native_ind)
possible_native = len(native_ind)
# Per frame: fraction of the selected residues whose full DSSP code equals the
# one of the reference structure (assumes md_traj_super has a single frame so
# the comparison broadcasts — confirm).
dssp_frame_native = (dssp_all_high_np[:,native_ind] == dssp_super_high[:,native_ind]).sum(1)/possible_native
# Per frame: fraction of the remaining residues whose simplified DSSP code is not 'C'.
dssp_frame_non = (dssp_all_np[:,not_native_ind] != 'C').sum(1)/possible_not_native
for state_i in range(2):
    # sort_id[0] is the largest eigenvalue; skip it and take the next ones.
    ind = sort_id[1+state_i]
    print('Eigenfunction {} of total states {} with eigenvalue {:.3}'.format(state_i, output_sizes[0], eigvals[ind]))

    # Eigenfunction value of every frame, then every sk-th frame in ascending order.
    eigfunc = pred_ord @ eigvec[:,ind]
    sort_ind = np.argsort(eigfunc)[::sk]

    # First figure: the eigenfunction itself over the ordered configurations.
    # The axis labels previously described a different (commented-out) plot.
    plt.plot(eigfunc[sort_ind], '.')
    plt.xlabel(r'Configuration ordered by $\psi_{}$'.format(state_i+1), fontsize=14)
    plt.ylabel(r'$\psi_{}$ [a.u.]'.format(state_i+1), fontsize=14)
    plt.show()

    # Second figure: native ratio along the same ordering, smoothed over 30 points.
    plt.plot(moving_average(dssp_frame_native[sort_ind],30), '.')
    plt.xlabel(r'Configuration ordered by $\psi_{}$'.format(state_i+1), fontsize=14)
    plt.ylabel('Ratio of native helical residues', fontsize=14)
    plt.show()

### Slowest eigenfunction

In [None]:
# Detailed view of the slowest eigenfunction (state = 0).
state=0
native_contacts_conc = dssp_frame_native
ind = sort_id[1+state]
# Report the eigenfunction selected in this cell (was printing the stale loop
# variable `state_i` left over from an earlier cell).
print('Eigenfunction {} of total states {} with eigenvalue {:.3}'.format(state, output_sizes[0], eigvals[ind]))
eigfunc = pred_ord @ eigvec[:,ind]

# Frames in ascending order of the eigenfunction, and every sk-th of them.
sort_ind = np.argsort(eigfunc)
ave_size = 30

sort_ind_sk = sort_ind[::sk]
# Marker frames: the frames whose eigenfunction value is closest to +add and -add.
add = 0.05
argmin2 = np.argmin((eigfunc-add)**2)

argmin1 = np.argmin((eigfunc+add)**2)


plt.plot(eigfunc[sort_ind_sk], '.')

# Vertical lines at the positions of the two marker frames in the subsampled ordering.
plt.vlines(np.argwhere(sort_ind==argmin2)//sk,eigfunc.min(),eigfunc.max(), colors='C1', zorder=3)
plt.vlines(np.argwhere(sort_ind==argmin1)//sk,eigfunc.min(),eigfunc.max(), colors='C2', zorder=3)
# NOTE(review): a truncated, unparsable fragment stood here
# ("rgwhere(sort_ind_sk==argmin)+10, 0, '$\psi=0.2$')"), presumably the tail of
# a commented-out plt.text annotation; it referenced the undefined name
# `argmin` and was removed so the cell runs.
plt.xlabel(r'Configuration ordered by $\psi_1$', fontsize=16)
plt.ylabel(r'$\psi_1$ [a.u.]', fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
# plt.savefig('./First_eigenfunction_villin_ordered_explain.png',dpi=1000, bbox_inches='tight')
plt.show()

# Smoothed native ratio along the same ordering, coloured by the eigenfunction value.
plt.scatter(np.arange(eigfunc[::sk].shape[0]-ave_size+1),moving_average(native_contacts_conc[sort_ind_sk],ave_size), c=eigfunc[sort_ind_sk][:-ave_size+1], cmap='plasma')
plt.vlines(np.argwhere(sort_ind==argmin2)//sk,0,1, colors='C1', zorder=3)
plt.vlines(np.argwhere(sort_ind==argmin1)//sk,0,1, colors='C2', zorder=3)
plt.xlabel(r'Configuration ordered by $\psi_1$', fontsize=16)
plt.ylabel('Native contacts ratio', fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
# plt.savefig('./First_eigenfunction_villin_ordered_colored.png',dpi=1000, bbox_inches='tight')
plt.show()

### Save structures

In [None]:
# Choose three groups of 20 frames along the sorted eigenfunction: the lowest
# values, the region midway between the two marker frames, and the highest values.
pos_marker_1 = np.argwhere(sort_ind == argmin1)
pos_marker_2 = np.argwhere(sort_ind == argmin2)
mean_12 = pos_marker_1 + (pos_marker_2 - pos_marker_1)//2
mid_pos = mean_12[0, 0]
frames_eig = [
    sort_ind[0:20],
    sort_ind[mid_pos-10:mid_pos+10],
    sort_ind[-20:],
]

In [None]:
# Write one PDB file per frame group in frames_eig (file prefix "eig1").
traj_nr=0
threshold = 0
out_ind=0

for o_temp in range(len(frames_eig)):
    
    o = o_temp
    print(o)
    

    frames_of_state = frames_eig[o]
    # A global frame index is split into (DCD file number, frame inside that
    # file), assuming 10000 frames per file.
    frames_per_file = 10000
    # NOTE(review): number_of_frames is defined elsewhere; it must not exceed
    # the number of frames in each group (20 here) — confirm.
    for i in range(number_of_frames):
        file_number = frames_of_state[i]//frames_per_file
        file_number_frame = frames_of_state[i]%frames_per_file
    #             print(numbers_all[i], traj_nr[i], file_number, file_number_frame)
        root_new = '/group/ag_cmb/simulation-data/DESRES-Science2011-FastProteinFolding/DESRES-Trajectory_{0}-{1}-protein/{0}-{1}-protein/{0}-{1}-protein'.format(test_system, traj_nr)
        # The whole DCD file is loaded for every single frame that is copied.
        md_traj_start = md.load_dcd(root_new+'-{:03}.dcd'.format(file_number),
                              top='/group/ag_cmb/scratch/deeptime_data/{0}/system-protein.pdb'.format(test_system)
                        )
        # Copy the coordinates of the selected frame into slot i of the
        # pre-existing scratch trajectory md_traj_small.
        md_traj_small.xyz[i] = md_traj_start.xyz[file_number_frame]

    # Align to the reference structure, then assign simplified DSSP codes.
    md_traj_small.superpose(md_traj_super) 
    dssp = md.compute_dssp(md_traj_small, simplified=True)[:,:n_residues]

    # Atom indices used for the second alignment (see get_heavy_atoms_dssp).
    index_heavy_2nd, list_res = get_heavy_atoms_dssp(dssp)
    print('Index of residues in 2nd structure', np.arange(n_residues)[list_res])
    if index_heavy_2nd:
        # Re-align all frames onto frame 0 using only the returned atoms.
        md_traj_small_temp = md_traj_small.superpose(md_traj_small, frame=0, atom_indices=index_heavy_2nd)
        print('found secondary structure and aligning with {} atoms'.format(len(index_heavy_2nd)))
    else:
        # No atoms returned: align onto frame 0 with the default atom selection.
        md_traj_small_temp = md_traj_small.superpose(md_traj_small, frame=0, atom_indices=None)
        print('did not find secondary structure and aligning with pdb {} atoms'.format(len(index_heavy_2nd)))

    # The format string has four placeholders; the fifth argument (`test`) does
    # not appear in the file name.
    md_traj_small_temp.save_pdb('/group/ag_cmb/scratch/deeptime_data/{0}/attention/eig1_{1}_{2}_smooth{3}.pdb'.format(test_system, output_sizes[out_ind], o_temp, patchsize, test))

### Second slowest eigenfunction

In [None]:
# Detailed view of the second slowest eigenfunction (state = 1).
state=1
native_contacts_conc = dssp_frame_native
ind = sort_id[1+state]
# Report the eigenfunction selected in this cell (was printing the loop
# variable `state_i` left over from an earlier cell).
print('Eigenfunction {} of total states {} with eigenvalue {:.3}'.format(state, output_sizes[0], eigvals[ind]))
eigfunc = pred_ord @ eigvec[:,ind]

# Frames in ascending order of the eigenfunction, and every sk-th of them.
sort_ind = np.argsort(eigfunc)
ave_size = 30

sort_ind_sk = sort_ind[::sk]
# Marker frames: the frames whose eigenfunction value is closest to +add and -add.
add = 0.1
argmin2 = np.argmin((eigfunc-add)**2)

argmin1 = np.argmin((eigfunc+add)**2)


plt.plot(eigfunc[sort_ind_sk], '.')

# Vertical lines at the positions of the two marker frames in the subsampled ordering.
plt.vlines(np.argwhere(sort_ind==argmin2)//sk,eigfunc.min(),eigfunc.max(), colors='C1', zorder=3)
plt.vlines(np.argwhere(sort_ind==argmin1)//sk,eigfunc.min(),eigfunc.max(), colors='C2', zorder=3)

plt.xlabel(r'Configuration ordered by $\psi_2$', fontsize=16)
plt.ylabel(r'$\psi_2$ [a.u.]', fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
# plt.savefig('./First_eigenfunction_villin_ordered_explain.png',dpi=1000, bbox_inches='tight')
plt.show()

# Smoothed native ratio along the same ordering, coloured by the eigenfunction value.
plt.scatter(np.arange(eigfunc[::sk].shape[0]-ave_size+1),moving_average(native_contacts_conc[sort_ind_sk],ave_size), c=eigfunc[sort_ind_sk][:-ave_size+1], cmap='plasma')
plt.vlines(np.argwhere(sort_ind==argmin2)//sk,0,1, colors='C1', zorder=3)
plt.vlines(np.argwhere(sort_ind==argmin1)//sk,0,1, colors='C2', zorder=3)
plt.xlabel(r'Configuration ordered by $\psi_2$', fontsize=16)
plt.ylabel('Native contacts ratio', fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
# plt.savefig('./First_eigenfunction_villin_ordered_colored.png',dpi=1000, bbox_inches='tight')
plt.show()

In [None]:
# Choose three groups of 20 frames along the sorted eigenfunction (lowest
# values, midway between the two marker frames, highest values) and write one
# PDB file per group (file prefix "eig2").
frames_eig = []
# Position halfway between the two marker frames in the sorted ordering.
mean_12 = np.argwhere(sort_ind==argmin1)+(np.argwhere(sort_ind==argmin2) - np.argwhere(sort_ind==argmin1))//2
# mean_23 = np.argwhere(sort_ind==argmin2)+(np.argwhere(sort_ind==argmin3) - np.argwhere(sort_ind==argmin2))//2
frames_eig.append(sort_ind[0:20])
frames_eig.append(sort_ind[mean_12[0,0]-10:mean_12[0,0]+10])
# frames_eig.append(sort_ind[mean_23[0,0]-5:mean_23[0,0]+5])
frames_eig.append(sort_ind[-20:])
traj_nr=0
threshold = 0
out_ind=0

for o_temp in range(len(frames_eig)):
    
    o = o_temp
    print(o)
    

    frames_of_state = frames_eig[o]
    # A global frame index is split into (DCD file number, frame inside that
    # file), assuming 10000 frames per file.
    frames_per_file = 10000
    # NOTE(review): number_of_frames is defined elsewhere; it must not exceed
    # the number of frames in each group (20 here) — confirm.
    for i in range(number_of_frames):
        file_number = frames_of_state[i]//frames_per_file
        file_number_frame = frames_of_state[i]%frames_per_file
    #             print(numbers_all[i], traj_nr[i], file_number, file_number_frame)
        root_new = '/group/ag_cmb/simulation-data/DESRES-Science2011-FastProteinFolding/DESRES-Trajectory_{0}-{1}-protein/{0}-{1}-protein/{0}-{1}-protein'.format(test_system, traj_nr)
        # The whole DCD file is loaded for every single frame that is copied.
        md_traj_start = md.load_dcd(root_new+'-{:03}.dcd'.format(file_number),
                              top='/group/ag_cmb/scratch/deeptime_data/{0}/system-protein.pdb'.format(test_system)
                        )
        # Copy the coordinates of the selected frame into slot i of the
        # pre-existing scratch trajectory md_traj_small.
        md_traj_small.xyz[i] = md_traj_start.xyz[file_number_frame]

    # Align to the reference structure, then assign simplified DSSP codes.
    md_traj_small.superpose(md_traj_super) 
    dssp = md.compute_dssp(md_traj_small, simplified=True)[:,:n_residues]

    # Atom indices used for the second alignment (see get_heavy_atoms_dssp).
    index_heavy_2nd, list_res = get_heavy_atoms_dssp(dssp)
    print('Index of residues in 2nd structure', np.arange(n_residues)[list_res])
    if index_heavy_2nd:
        # Re-align all frames onto frame 0 using only the returned atoms.
        md_traj_small_temp = md_traj_small.superpose(md_traj_small, frame=0, atom_indices=index_heavy_2nd)
        print('found secondary structure and aligning with {} atoms'.format(len(index_heavy_2nd)))
    else:
        # No atoms returned: align onto frame 0 with the default atom selection.
        md_traj_small_temp = md_traj_small.superpose(md_traj_small, frame=0, atom_indices=None)
        print('did not find secondary structure and aligning with pdb {} atoms'.format(len(index_heavy_2nd)))

    # The format string has four placeholders; the fifth argument (`test`) does
    # not appear in the file name.
    md_traj_small_temp.save_pdb('/group/ag_cmb/scratch/deeptime_data/{0}/attention/eig2_{1}_{2}_smooth{3}.pdb'.format(test_system, output_sizes[out_ind], o_temp, patchsize, test))

### Building hierarchical models with coarse-graining layers

In [None]:
def set_cg(id, fac=1):
    """Initialize the weights of coarse-graining layer `id` from a PCCA.

    The training data is passed through the u layer, the S layer and the
    first `id` coarse-graining layers. The resulting matrix K is
    row-normalized, PCCA with ``output_sizes[id+1]`` groups is run on it, and
    the logarithm of the PCCA memberships times `fac` is copied into the
    weight of the layer.

    Parameters
    ----------
    id : index of the coarse-graining layer to initialize.
    fac : factor applied to the log-memberships before they are stored.
    """
    target_layer = Full_net.coarse_grain_layer[id]

    # Network outputs for the time-lagged training pairs.
    chi_t = torch.Tensor(pred_batchwise(tensor_train_X1)).to(device)
    chi_tau = torch.Tensor(pred_batchwise(tensor_train_X2)).to(device)

    u, u_t, v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t = Full_net.u_layers[0](chi_t, chi_tau)
    matrix, K, S = Full_net.S_layers[0](v, C_00, C_11, C_01, Sigma)

    # Apply every coarse-graining layer that precedes the one being initialized.
    for level in range(id):
        preceding_layer = Full_net.coarse_grain_layer[level]
        chi_t, chi_tau, u, u_t, S, mu_t, Sigma_t, K, VampE_matrix = preceding_layer.get_cg_uS(
            chi_t, chi_tau, u, S, u_t, renorm=Full_net.renorm)

    # Row-normalize K before handing it to PCCA.
    K_np = K.detach().to('cpu').numpy().astype('float64')
    T_np = K_np / K_np.sum(axis=1, keepdims=True)

    # pyemma is only needed for this initialization, hence the local import.
    from pyemma.msm import PCCA
    pcca = PCCA(T_np, output_sizes[id+1])
    memberships = pcca.memberships
    print(memberships)

    log_memberships = np.log(memberships)*fac
    with torch.no_grad():
        target_layer.weight.copy_(torch.Tensor(log_memberships))

In [None]:
# One Adam optimizer per coarse-graining layer, each over the parameters that
# Full_net.get_params_cg returns for that single layer index.
optimizer_cg = [
    optim.Adam(Full_net.get_params_cg([layer_idx]), lr=0.1)
    for layer_idx in range(len(output_sizes)-1)
]

In [None]:
# Full_net.coarse_grain_layer[0].reset_params()
# Initialize the first coarse-graining layer from PCCA (fac = 1), plot it, and
# create a fresh optimizer for its parameters.
set_cg(0,1.)
plot_cg(0)
optimizer_cg[0] = optim.Adam(Full_net.get_params_cg([0]), lr=0.1)

In [None]:
# Keep the parameters as they are before the coarse-graining layers are trained.
weights_before_rev_cg = Full_net.get_weights()

In [None]:
# Restore the parameters stored in `weights_before_rev_cg`.
Full_net.set_weights(weights_before_rev_cg)

In [None]:
# Fresh optimizer for the first coarse-graining layer (discards any state the
# previous optimizer object had accumulated).
optimizer_cg[0] = optim.Adam(Full_net.get_params_cg([0]), lr=0.1)
# set_cg(0,2)


In [None]:
# Train coarse-graining layer 0; the meaning of the remaining positional
# arguments (2000, 500, False) is defined by train_for_cg_rev, which is not
# visible in this part of the file.
train_for_cg_rev(0,2000,500,False)

In [None]:
# Continue training coarse-graining layer 0 with a different second argument
# (3000); see train_for_cg_rev for the meaning of the positional arguments.
train_for_cg_rev(0,3000,500,False)

In [None]:
# Estimate the reversible transition matrix at the coarse graining level
# NOTE(review): the global name `id` shadows the built-in id() for the rest of
# the notebook session.
id=0
# Network outputs for the TEST pairs (the variable names say "train", but the
# tensors passed in are tensor_test_X1 / tensor_test_X2).
chi_X1_train = torch.Tensor(pred_batchwise(tensor_test_X1)).to(device)
chi_X2_train = torch.Tensor(pred_batchwise(tensor_test_X2)).to(device)

u, u_t, v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t = Full_net.u_layers[0](chi_X1_train, chi_X2_train)

matrix, K, S = Full_net.S_layers[0](v, C_00, C_11, C_01, Sigma)


# Apply the coarse-graining layers that precede level `id` (none for id = 0).
for i in range(id):
    chi_X1_train, chi_X2_train, u, u_t, S, mu_t, Sigma_t, K, VampE_matrix = Full_net.coarse_grain_layer[i].get_cg_uS(
                                                                    chi_X1_train, chi_X2_train, u, S, u_t, renorm=Full_net.renorm)

# Cut the inputs of the final evaluation off from the autograd graph.
chi_X1_train = chi_X1_train.detach()
chi_X2_train = chi_X2_train.detach()
u = u.detach()
u_t = u_t.detach()
S = S.detach()

# Evaluate coarse-graining level `id`; with return_mu_K_Sigma=True the result
# is indexable and entry [2] is used as a matrix in the next cell.
score_curr = vampe_loss_rev_cg(chi_X1_train, chi_X2_train, u, S, u_t, id, 
                                   return_mu=False, return_mu_K_Sigma=True, renorm=Full_net.renorm)

In [None]:
# the eigenvalues of the coarse-grained model should recover the largest eigenvalues of K_msm
# torch.eig is no longer available in current PyTorch releases, so the
# eigenvalues are computed with NumPy, as is already done for K_msm. Both
# results are displayed as the value of the cell (the former print() call put
# a None into the displayed tuple).
cg_matrix = score_curr[2].detach().to('cpu').numpy()
np.linalg.eigvals(cg_matrix), np.linalg.eigvals(K_msm)

### Train for the second coarse-graining layer

In [None]:
# Full_net.coarse_grain_layer[1].reset_params()
# Initialize the second coarse-graining layer from PCCA (fac = 1) and plot it.
set_cg(1,1)
plot_cg(1)
# optimizer_cg[1] = optim.Adam(Full_net.get_params_cg([1]), lr=0.01)

In [None]:
# Fresh optimizer for the parameters of the second coarse-graining layer.
optimizer_cg[1] = optim.Adam(Full_net.get_params_cg([1]), lr=0.1)

In [None]:
# Train coarse-graining layer 1; see train_for_cg_rev for the meaning of the
# remaining positional arguments (2000, 500, False).
train_for_cg_rev(1,2000,500,False)

In [None]:
# Continue training coarse-graining layer 1 with a different second argument (3000).
train_for_cg_rev(1,3000,500,False)

In [None]:
# Same evaluation as for the first coarse-graining level, now for level 1.
# NOTE(review): the global name `id` shadows the built-in id().
id=1
# Network outputs for the TEST pairs (the variable names say "train", but the
# tensors passed in are tensor_test_X1 / tensor_test_X2).
chi_X1_train = torch.Tensor(pred_batchwise(tensor_test_X1)).to(device)
chi_X2_train = torch.Tensor(pred_batchwise(tensor_test_X2)).to(device)

u, u_t, v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t = Full_net.u_layers[0](chi_X1_train, chi_X2_train)

matrix, K, S = Full_net.S_layers[0](v, C_00, C_11, C_01, Sigma)


# Apply the coarse-graining layers that precede level `id` (layer 0 here).
for i in range(id):
    chi_X1_train, chi_X2_train, u, u_t, S, mu_t, Sigma_t, K, VampE_matrix = Full_net.coarse_grain_layer[i].get_cg_uS(
                                                                    chi_X1_train, chi_X2_train, u, S, u_t, renorm=Full_net.renorm)

# Cut the inputs of the final evaluation off from the autograd graph.
chi_X1_train = chi_X1_train.detach()
chi_X2_train = chi_X2_train.detach()
u = u.detach()
u_t = u_t.detach()
S = S.detach()

# Evaluate coarse-graining level `id`; entry [2] of the result is used as a
# matrix in the next cell.
score_curr = vampe_loss_rev_cg(chi_X1_train, chi_X2_train, u, S, u_t, id, 
                                   return_mu=False, return_mu_K_Sigma=True, renorm=Full_net.renorm)

In [None]:
# Eigenvalues of the matrix returned for the second coarse-graining level next
# to those of K_msm. torch.eig is no longer available in current PyTorch
# releases, so NumPy is used; both results are displayed as the cell value
# (the former print() call put a None into the displayed tuple).
cg_matrix = score_curr[2].detach().to('cpu').numpy()
np.linalg.eigvals(cg_matrix), np.linalg.eigvals(K_msm)

In [None]:
# Keep the parameters obtained after both coarse-graining layers were trained.
weights_after_cg = Full_net.get_weights()

### Train the coarse graining matrices to optimize the score of all 3 models

In [None]:
def train_for_everything_rev(opt_list, runs, plot_mask_every=10, verbose=True, plot_training=True):
    """Run `runs` full-batch optimization steps over the whole model hierarchy.

    Every step evaluates the u layer, the S layer and each coarse-graining
    layer on the training outputs. Each level contributes minus the trace of
    its VAMP-E matrix to a score list; the loss is the negated sum of the
    first three entries. All optimizers in `opt_list` are then stepped.

    Parameters
    ----------
    opt_list : optimizers; zero_grad() and step() are called on each of them.
    runs : number of optimization steps.
    plot_mask_every : plot_cg(0) and plot_cg(1) are called every this many steps.
    verbose : print the total and per-level values after every step.
    plot_training : plot the four recorded curves once training has finished.
    """
    # The network outputs are computed once and reused in every step.
    chi_t_in = torch.Tensor(pred_batchwise(tensor_train_X1)).to(device)
    chi_tau_in = torch.Tensor(pred_batchwise(tensor_train_X2)).to(device)

    # Per-step records: total, and one curve per model level.
    epoch_loss = np.zeros(runs)
    epoch_loss_6 = np.zeros(runs)
    epoch_loss_3 = np.zeros(runs)
    epoch_loss_2 = np.zeros(runs)

    for epoch in range(runs):
        chi_t, chi_tau = chi_t_in, chi_tau_in

        for opt in opt_list:
            opt.zero_grad()

        # Finest level: u layer followed by S layer.
        u, u_t, v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t = Full_net.u_layers[0](chi_t, chi_tau)
        VampE_matrix, K, S = Full_net.S_layers[0](v, C_00, C_11, C_01, Sigma)
        score_curr_list = [-torch.trace(VampE_matrix)]

        # Each coarse-graining layer consumes the quantities of the level before it.
        for j in range(len(output_sizes)-1):
            cg_layer = Full_net.coarse_grain_layer[j]
            chi_t, chi_tau, u, u_t, S, mu_t, Sigma_t, K, VampE_matrix = cg_layer.get_cg_uS(
                chi_t, chi_tau, u, S, u_t, Full_net.renorm)
            score_curr_list.append(-torch.trace(VampE_matrix))

        loss = - score_curr_list[0] - score_curr_list[1] - score_curr_list[2]

        loss.backward()
        for opt in opt_list:
            opt.step()

        epoch_loss[epoch] = -loss.item()
        epoch_loss_6[epoch] = score_curr_list[0].item()
        epoch_loss_3[epoch] = score_curr_list[1].item()
        epoch_loss_2[epoch] = score_curr_list[2].item()

        if verbose:
            print('Run {}, total loss: {:.3}, , 6 loss: {:.3}, , 3 loss: {:.3}, , 2 loss: {:.3}'.format(
                epoch+1, epoch_loss[epoch], epoch_loss_6[epoch], epoch_loss_3[epoch], epoch_loss_2[epoch]))

        if (epoch+1) % plot_mask_every == 0:
            plot_cg(0)
            plot_cg(1)

    if plot_training:
        # One figure per recorded curve.
        steps = np.arange(1, runs+1)
        curves = (
            (epoch_loss, 'VAMP_loss'),
            (epoch_loss_6, 'VAMP_loss 6'),
            (epoch_loss_3, 'VAMP_loss 3'),
            (epoch_loss_2, 'VAMP_loss 2'),
        )
        for curve, label in curves:
            plt.plot(steps, curve, label=label)
            plt.legend()
            plt.show()
        
def get_params_all_rev(coarse=False, u=False, S=False):
    """Yield the parameters of the selected parts of Full_net.

    Parameters
    ----------
    coarse : include the parameters of every coarse-graining layer.
    u : include the parameters of the first u layer.
    S : include the parameters of the first S layer.

    The groups are yielded in the order coarse, u, S. Nothing is yielded when
    all flags are False.
    """
    if coarse:
        for cg_layer in Full_net.coarse_grain_layer:
            yield from cg_layer.parameters()
    if u:
        yield from Full_net.u_layers[0].parameters()
    if S:
        yield from Full_net.S_layers[0].parameters()



In [None]:
# Separate Adam optimizers for the three parameter groups: all coarse-graining
# layers, the first u layer, and the first S layer.
opt1 = optim.Adam(get_params_all_rev(coarse=True), lr=learning_rate*100)
opt2 = optim.Adam(get_params_all_rev(u=True), lr=learning_rate*10)
opt3 = optim.Adam(get_params_all_rev(S=True), lr=learning_rate*100)

In [None]:
# weights_temp = Full_net.get_weights()
# Restore the parameters stored in `weights_after_cg` before the joint training.
Full_net.set_weights(weights_after_cg)

In [None]:
# 1000 joint optimization steps with all three optimizers; the coarse-graining
# matrices are plotted every 500 steps.
train_for_everything_rev([opt1, opt2, opt3], 1000, plot_mask_every=500, verbose=True)

In [None]:
# Keep a copy of the current parameters.
weights_temp = Full_net.get_weights()

In [None]:
# Keep the current parameters as the final model.
weights_final = Full_net.get_weights()

In [None]:
# Restore the parameters stored in `weights_final`.
Full_net.set_weights(weights_final)

### Prepare to plot the hierarchical model

In [None]:
# get the coarse graining matrices
# One NumPy array per coarse-graining layer, taken from the layer's
# get_softmax() output.
# NOTE(review): the loop variable `id` shadows the built-in id() and stays
# defined as a notebook global after the loop.
list_weights = []
for id in range(len(output_sizes)-1):
    attention = Full_net.coarse_grain_layer[id].get_softmax()
    attention_np = attention.detach().to('cpu').numpy()
    list_weights.append(attention_np)

In [None]:
def get_pi(K):
    """Return the left eigenvector of `K` for its largest eigenvalue, scaled to sum to one.

    For a transition matrix this is the stationary distribution. The selected
    eigenvalue and the resulting vector are printed.
    """
    # Left eigenvectors of K are the right eigenvectors of its transpose.
    eigenvalues, eigenvectors = np.linalg.eig(K.T)

    # Position of the largest eigenvalue in ascending sort order.
    largest = np.argsort(eigenvalues)[-1]
    stationary = eigenvectors[:, largest]
    stationary = stationary / stationary.sum()
    print(eigenvalues[largest], stationary)

    return stationary

In [None]:
def get_K_rev_cg(id=0, tensor_t=None, tensor_tau=None):
    """Return the matrix K of the model after `id` coarse-graining layers.

    Parameters
    ----------
    id : number of coarse-graining layers to apply; 0 returns the K produced
        by the S layer.
    tensor_t, tensor_tau : time-lagged input tensors. When omitted they are
        built from the first trajectory of ``traj_whole_new`` with the current
        value of ``tau``.

    Returns
    -------
    K as a NumPy array on the CPU.
    """
    # Resolve the defaults at call time: as default argument values they were
    # evaluated once at definition time, which froze the lag time and
    # trajectory of that moment and kept full-trajectory tensors alive.
    if tensor_t is None:
        tensor_t = torch.Tensor(traj_whole_new[0][:-tau])
    if tensor_tau is None:
        tensor_tau = torch.Tensor(traj_whole_new[0][tau:])

    chi_t = torch.Tensor(pred_batchwise(tensor_t)).to(device)
    chi_tau = torch.Tensor(pred_batchwise(tensor_tau)).to(device)
    u, u_t, v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t = Full_net.u_layers[0](chi_t, chi_tau)
    matrix, K, S = Full_net.S_layers[0](v, C_00, C_11, C_01, Sigma)

    # Each coarse-graining layer replaces K (and the quantities it depends on)
    # by the values of the next coarser level.
    for i in range(id):
        chi_t, chi_tau, u, u_t, S, mu_t, Sigma_t, K, VampE_matrix = Full_net.coarse_grain_layer[i].get_cg_uS(
                                                                        chi_t, chi_tau, u, S, u_t, Full_net.renorm)

    return K.detach().to('cpu').numpy()

In [None]:
def get_mu_rev_cg(id=0, tensor_t=None, tensor_tau=None):
    """Return `mu_t` of the model after `id` coarse-graining layers.

    Parameters
    ----------
    id : number of coarse-graining layers to apply; 0 returns the mu_t
        produced by the u layer.
    tensor_t, tensor_tau : time-lagged input tensors. When omitted they are
        built from the first trajectory of ``traj_whole_new`` with the current
        value of ``tau``.

    Returns
    -------
    mu_t as a NumPy array on the CPU.
    """
    # Resolve the defaults at call time: as default argument values they were
    # evaluated once at definition time, which froze the lag time and
    # trajectory of that moment and kept full-trajectory tensors alive.
    if tensor_t is None:
        tensor_t = torch.Tensor(traj_whole_new[0][:-tau])
    if tensor_tau is None:
        tensor_tau = torch.Tensor(traj_whole_new[0][tau:])

    chi_t = torch.Tensor(pred_batchwise(tensor_t)).to(device)
    chi_tau = torch.Tensor(pred_batchwise(tensor_tau)).to(device)
    u, u_t, v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t = Full_net.u_layers[0](chi_t, chi_tau)
    matrix, K, S = Full_net.S_layers[0](v, C_00, C_11, C_01, Sigma)

    # Each coarse-graining layer replaces mu_t (and the quantities it depends
    # on) by the values of the next coarser level.
    for i in range(id):
        chi_t, chi_tau, u, u_t, S, mu_t, Sigma_t, K, VampE_matrix = Full_net.coarse_grain_layer[i].get_cg_uS(
                                                                        chi_t, chi_tau, u, S, u_t, Full_net.renorm)

    return mu_t.detach().to('cpu').numpy()

In [None]:
def get_state_probs_rev_cg(id=0, tensor_t=None, tensor_tau=None):
    """Return the sum over frames of ``chi_t * mu_t`` after `id` coarse-graining layers.

    Parameters
    ----------
    id : number of coarse-graining layers to apply; 0 uses the outputs of the
        network and the mu_t of the u layer directly.
    tensor_t, tensor_tau : time-lagged input tensors. When omitted they are
        built from the first trajectory of ``traj_whole_new`` with the current
        value of ``tau``.

    Returns
    -------
    NumPy array on the CPU with one value per state (the sum runs over dim 0).
    """
    # Resolve the defaults at call time: as default argument values they were
    # evaluated once at definition time, which froze the lag time and
    # trajectory of that moment and kept full-trajectory tensors alive.
    if tensor_t is None:
        tensor_t = torch.Tensor(traj_whole_new[0][:-tau])
    if tensor_tau is None:
        tensor_tau = torch.Tensor(traj_whole_new[0][tau:])

    chi_t = torch.Tensor(pred_batchwise(tensor_t)).to(device)
    chi_tau = torch.Tensor(pred_batchwise(tensor_tau)).to(device)
    u, u_t, v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t = Full_net.u_layers[0](chi_t, chi_tau)
    matrix, K, S = Full_net.S_layers[0](v, C_00, C_11, C_01, Sigma)

    # Each coarse-graining layer replaces chi_t and mu_t by the values of the
    # next coarser level.
    for i in range(id):
        chi_t, chi_tau, u, u_t, S, mu_t, Sigma_t, K, VampE_matrix = Full_net.coarse_grain_layer[i].get_cg_uS(
                                                                        chi_t, chi_tau, u, S, u_t, Full_net.renorm)
    # Weight every frame's state assignment with mu_t and sum over the frames.
    state_probs = torch.sum(chi_t * mu_t, dim=0)
    return state_probs.detach().to('cpu').numpy()

In [None]:
# estimate all transition matrices
# The argument is the number of coarse-graining layers applied: none, one, two.
K_6 = get_K_rev_cg(0)
K_3 = get_K_rev_cg(1)
K_2 = get_K_rev_cg(2)

In [None]:
# check if the eigenvalues match
# Displays the eigenvalues of the three model levels side by side.
np.linalg.eigvals(K_6), np.linalg.eigvals(K_3), np.linalg.eigvals(K_2)

In [None]:
# Stationary distribution of each model level, in the order K_6, K_3, K_2.
pi = [get_pi(K_6), get_pi(K_3), get_pi(K_2)]

In [None]:
# Determine a display order of the states for every model level, starting from
# the coarsest level, and plot each reordered coarse-graining matrix.
order = []
order_pi = True  # weight the assignment strength with the stationary probability

max_matrix = False
fontsize_l = 36 
fontsize_s = 32
# list_weights[::-1] starts with the last (coarsest) coarse-graining matrix.
for t, coarse_matrix in enumerate(list_weights[::-1]):
    # NOTE(review): vmax is computed here but not passed to imshow below, which
    # hard-codes vmax=1.
    if max_matrix:
        vmax = coarse_matrix.max()
    else:
        vmax = 1.
    if t == 0:
#         order.append(np.arange(coarse_matrix.shape[1]))
        # The coarsest states are ordered by ascending stationary probability.
        order.append(np.argsort(pi[-1]))
    # Order of the output (coarser) states of this matrix.
    sort_temp = order[-1]
#     max_index = np.argmax(coarse_matrix[:,sort_temp], axis=1)
#     sort_index = np.argsort(max_index)
    sort_in = []
    for index_out in sort_temp:
        # Input states whose largest entry points to output state index_out.
        max_out_state = np.where(np.argmax(coarse_matrix, axis=1)==index_out)[0]
        print(t, max_out_state)
        print(pi[-2-t][max_out_state]*coarse_matrix[max_out_state,index_out])
        if order_pi:
            ratio_prob = pi[-2-t][max_out_state]*coarse_matrix[max_out_state,index_out]
        else:
            ratio_prob = coarse_matrix[max_out_state,index_out]
        # Within one output state, input states are sorted by ascending ratio_prob.
        prob_sort = np.argsort(ratio_prob)
        sort_in.append(max_out_state[prob_sort])
    # Order of the input (finer) states; it is the output order of the next matrix.
    sort_index = np.concatenate(sort_in)
    order.append(sort_index)
    
    # The entries are raised to the 10th power before plotting, which visually
    # suppresses small values.
    plt.imshow((coarse_matrix[sort_index][:,sort_temp]).T**10, vmin=0., vmax=1, cmap=plt.cm.Reds)
    plt.xticks(np.arange(coarse_matrix.shape[0]), fontsize=fontsize_s)
    plt.yticks(np.arange(coarse_matrix.shape[1]), fontsize=fontsize_s)
    plt.ylabel('To state', fontsize=fontsize_l)
    plt.xlabel('From state', fontsize=fontsize_l)
#     plt.savefig('./figs/matrix_{}to{}_{}_{}.svg'.format(coarse_matrix.shape[0], coarse_matrix.shape[1], patchsize, test), bbox_inches='tight')
    plt.show()

In [None]:
# get stationary distribution for each model, they should be close
# The argument is the number of coarse-graining layers applied: none, one, two.
mu_6 = get_mu_rev_cg(0)
mu_3 = get_mu_rev_cg(1)
mu_2 = get_mu_rev_cg(2)

In [None]:
# Per-state sums of chi_t * mu_t for each model level (none, one, two
# coarse-graining layers applied).
state_probs_6 = get_state_probs_rev_cg(0)
state_probs_3 = get_state_probs_rev_cg(1)
state_probs_2 = get_state_probs_rev_cg(2)

In [None]:
# Stationary distribution of every level, rearranged into the display order
# determined above (order[-1] belongs to the finest level, order[-3] to the coarsest).
output_ave = [
    get_pi(K_6)[order[-1]],
    get_pi(K_3)[order[-2]],
    get_pi(K_2)[order[-3]],
]
print(output_ave)

In [None]:
# Reorder the coarse-graining matrices as well: rows follow the display order
# of the finer level, columns the display order of the coarser level.
list_weights_new = [
    cg_matrix[order[-level-1]][:, order[-level-2]]
    for level, cg_matrix in enumerate(list_weights)
]

In [None]:
# define the names of the states
# One list of labels per model level, finest level first.
# NOTE(review): names_6 holds four labels; draw_neural_net indexes
# names[n][m] for m < layer_size, so output_sizes[0] must not exceed 4 — confirm.
names_6 = ['M','F','PF','U']
names_3 = ['M', 'F', 'U']
names_2 = ['M', 'U']
names = [names_6, names_3, names_2]

In [None]:
def draw_neural_net(ax, left, right, bottom, top, layer_sizes, activation, weights = None, max_plus_out = None, bias=None, fontsize=12, names=names):
    '''
    Draw a layered graph (neural-network cartoon) using matplotlib.

    Every layer is drawn as a row of labelled circles whose radius scales with
    the node value; consecutive layers are connected by lines.

    :usage:
        >>> fig = plt.figure(figsize=(12, 12))
        >>> draw_neural_net(fig.gca(), .1, .9, .1, .9, [4, 7, 2], activation)

    :parameters:
        - ax : matplotlib.axes.AxesSubplot
            The axes on which to plot the cartoon (get e.g. by plt.gca())
        - left : float
            Left bound of the horizontal range in which the nodes are placed
        - right : float
            Right bound of the horizontal range in which the nodes are placed
        - bottom : float
            Lower bound of the vertical range (only used for the layer spacing)
        - top : float
            Vertical position of the first layer
        - layer_sizes : list of int
            Number of nodes in every layer
        - activation : sequence of sequences
            activation[n][m] is the value of node m in layer n; it sets the
            radius and, if a layer contains negative values, the colour
        - weights : sequence indexed as weights[n][m][o], or None
            Weight of the edge from node m of layer n to node o of layer n+1.
            Only edges with weight > 0.01 are drawn and labelled with the
            weight in percent. With None every edge is drawn as a plain black
            line without label.
        - max_plus_out, bias :
            Accepted for backward compatibility; not used.
        - fontsize : int
            NOTE: currently overridden inside the function (see below).
        - names : sequence of sequences of str
            names[n][m] is the label of node m in layer n
    '''
    # NOTE(review): the argument is overridden here, so the value passed by the
    # caller has no effect. Kept as is so that existing figures do not change.
    fontsize=28
    v_spacing = ((top - bottom)/float(len(layer_sizes) - 1))/2
    h_spacing = (right - left)/float(max(layer_sizes))

    # Nodes
    for n, layer_size in enumerate(layer_sizes):
        names_i = names[n]
        max_min = min(activation[n])
        # Layers are centred horizontally.
        left_layer = (left + right)/2. - h_spacing*(layer_size - 1)/2.
        node_y = -n*v_spacing + top
        for m in range(layer_size):
            x = activation[n][m]
            if max_min >= 0:
                # Layer without negative values: one fixed colour for all nodes.
                r = 1
                b = 1-(.25)**0.8
                g = 1-(.25)**0.8
            elif x < 0:
                # Negative node: red channel scaled by the most negative value.
                r = x/max_min
                b = 0
                g = 0.25
            else:
                r = 0
                b = x
                g = 0.25
            radius = h_spacing/3. * x
            node_x = left_layer + m*h_spacing
            circle = plt.Circle((node_x, node_y), radius,
                                color=(r,g,b), ec='k', zorder=4)
            ax.add_artist(circle)

            # Horizontal label offset: depends on the label length for large
            # nodes; labels of small nodes (x < 0.1) are shifted further left.
            names_new = names_i[m]
            if x >= 0.1:
                if len(names_new) < 2:
                    label_shift = h_spacing/20.
                else:
                    label_shift = h_spacing/15.
            else:
                label_shift = h_spacing/8.
            ax.text(node_x - label_shift, node_y - v_spacing/32., '{}'.format(names_new), zorder=10, fontsize=fontsize)

    # Edges
    fac = 1/5  # relative position of the edge label along the edge
    for n, (layer_size_a, layer_size_b) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        left_layer_a = (left + right)/2. - h_spacing*(layer_size_a - 1)/2.
        left_layer_b = (left + right)/2. - h_spacing*(layer_size_b - 1)/2.
        start_y = -n*v_spacing + top
        end_y = -(n+1)*v_spacing + top
        for m in range(layer_size_a):
            start_x = left_layer_a + m*h_spacing
            for o in range(layer_size_b):
                end_x = left_layer_b + o*h_spacing

                if weights is None:
                    # No weights given: plain black edge, no label.
                    ax.add_artist(plt.Line2D([start_x, end_x],
                                             [start_y, end_y], c='k', label = 'line'))
                    continue

                weight = weights[n][m][o]
                if not weight > 0.01:
                    # Negligible edge: draw nothing. (Previously the line of an
                    # earlier edge was added again here, or a NameError was
                    # raised if no line existed yet.)
                    continue

                # Three discrete line widths depending on the weight.
                if weight > 0.5:
                    lw = 1
                elif weight > 0.25:
                    lw = 0.5
                else:
                    lw = 0.25
                # Every drawn edge has weight > 0.01, so it is always solid black.
                line = plt.Line2D([start_x , end_x],
                                  [start_y, end_y], c='k', label = 'line', linewidth = lw, linestyle = '-')

                # Edge label: weight in percent, without decimals for (almost) 100 %.
                diff_x = end_x - start_x
                label_format = '{:.0f}' if weight >= 0.999 else '{:.1f}'
                ax.text(start_x + diff_x*fac + (m*o)*fac/256, -n*v_spacing + top - v_spacing*fac, label_format.format(weight*100), fontsize=fontsize)
                ax.add_artist(line)

In [None]:
# Draw the hierarchical model: node values are the reordered stationary
# distributions, edge weights the reordered coarse-graining matrices.
fig = plt.figure(figsize=(24, 24))
ax = fig.gca()
ax.axis('off')
# NOTE(review): draw_neural_net currently overrides the font size it receives.
fontsize = 20
draw_neural_net(ax, .1, .9, .1, .9, output_sizes, output_ave, weights=list_weights_new, max_plus_out=[1.]*len(output_sizes), fontsize=fontsize, names=names)
# fig.savefig('./figs/coarse_grained_graph_{}_{}.svg'.format(patchsize, test), bbox_inches='tight')

### Save structures for each state of each model plus attention values as bfactors

In [None]:
def structures(cg_ind, vmax=1, skip=10, top=10):
    """Plot attention values and return them as NumPy arrays.

    Two code paths exist, selected by ``Full_net.mask_const``:

    * constant mask: the single attention vector from
      ``Full_net.get_attention()`` is plotted as one column and returned as an
      array of shape ``(n_residues, 1)``.
    * otherwise: the state assignments of ``traj_whole_new[0]`` are pushed
      through the first ``cg_ind`` coarse-graining layers, the ``top`` frames
      with the largest value are selected for every state, and the attention of
      exactly those frames is evaluated. The mean over the selected frames is
      plotted per state.

    NOTE(review): the two branches return different things (a single array vs.
    a 2-tuple). Callers that unpack two values only work with the second
    branch. The interface is left untouched here on purpose.

    Parameters
    ----------
    cg_ind : int
        Number of coarse-graining layers applied (``Full_net.forward_cg`` is
        called for ``i = 0 .. cg_ind-1``). Ignored in the constant-mask branch.
    vmax : float
        Upper colour limit handed to ``plt.imshow``.
    skip : int
        Spacing of the y-axis tick labels.
    top : int
        Number of highest-scoring frames kept per state.

    Returns
    -------
    att_atom : np.ndarray
        Only in the constant-mask branch, shape ``(n_residues, 1)``.
    att_residue, top_x_state : np.ndarray, np.ndarray
        Only in the other branch: the attention of the selected frames stacked
        over states along axis 0, and the selected frame indices with one
        column per state (best frame first).
    """
    # Inverts n*(n+1)/2 = input_size and adds 3; presumably the input consists of
    # pairwise residue features that exclude close neighbours -- TODO confirm.
    n_residues = int(-1/2 + np.sqrt(1/4+input_size*2) + 3)
    y_ticks = np.arange(0, n_residues, skip)
    y_tick_labels = ['x{}'.format(i) for i in range(0, n_residues, skip)]

    if Full_net.mask_const:
        attention = Full_net.get_attention()
        # Move to host memory before converting: Tensor.numpy() raises for CUDA
        # tensors, and the other branch below already does the same.
        attention_np = attention.detach().to('cpu').numpy()
        att_atom = np.reshape(attention_np, (n_residues,1))
        plt.imshow(att_atom, vmin=0, vmax=vmax, aspect='auto')
        plt.xlabel('System', fontsize=18)
        plt.ylabel('Input', fontsize=18)
        plt.xticks(np.arange(1),['{}'.format(i) for i in range(1)], fontsize=16)
        plt.yticks(y_ticks, y_tick_labels, fontsize=16)
        plt.show()
    #     plt.savefig('./Figs/2x3_mix_Mask.pdf', bbox_inches='tight')

        return att_atom

    else:
        # State assignments of the first trajectory, optionally coarse-grained.
        pred_temp = pred_batchwise(traj_whole_new[0], batchsize=10000)
        pred_temp = torch.Tensor(pred_temp).to(device)
        for i in range(cg_ind):
            pred_temp = Full_net.forward_cg(pred_temp, i)
        pred_temp = pred_temp.detach().to('cpu').numpy()
        print(pred_temp.shape)

        # Per state (column): indices of the `top` frames with the largest
        # value, ordered from largest to smallest.
        arg_sort = np.argsort(pred_temp, axis=0)
        top_x_state = arg_sort[-top:][::-1]
        states = pred_temp.shape[1]

        att_atom = []
        for state in range(states):
            frames = top_x_state[:,state]
            attention = Full_net.get_attention(torch.Tensor(traj_whole_new[0][frames]).to(device))
            attention_np = attention.detach().to('cpu').numpy()
            print(attention_np.shape)
            # Add a leading state axis so the per-state results can be stacked.
            att_atom.append(attention_np[None,:,:])
        att_residue = np.concatenate(att_atom, axis=0)

        # Average over the selected frames; states on x, inputs on y.
        plt.imshow(np.mean(att_residue,axis=1).T, vmin=0, vmax=vmax, aspect='auto')
        plt.xlabel('State', fontsize=18)
        plt.ylabel('Input', fontsize=18)
        plt.xticks(np.arange(states),['{}'.format(i) for i in range(states)], fontsize=16)
        plt.yticks(y_ticks, y_tick_labels, fontsize=16)
        plt.show()
    #     plt.savefig('./Figs/2x3_mix_Mask.pdf', bbox_inches='tight')

        return att_residue, top_x_state



In [None]:
# For every hierarchy level: fetch the per-state attention and the selected
# frame indices from structures(), load the corresponding coordinates from the
# DCD files, align them and write one PDB file per state with the attention
# values passed as `bfactors`.
# Relies on names defined in earlier cells: number_of_frames, n_residues, order,
# md_traj_super, md_traj_small, get_heavy_atoms_dssp, patchsize, test.
traj_nr=0
threshold = 0
for out_ind in range(len(output_sizes)):
    # NOTE(review): the two-value unpacking matches only the non-constant-mask
    # return of structures().
    attention, frames = structures(out_ind, top=number_of_frames, vmax=3)
    for o_temp in range(frames.shape[1]):
        # Translate the loop position into a state index via the externally
        # defined `order` list (indexed from the end by hierarchy level).
        order_o = order[-out_ind-1]
        o = order_o[o_temp]
        print(o)
        # Copy the attention of state `o`, replacing every value below
        # `threshold` by 0. Entries outside the looped ranges stay 0.
        attention_clean = np.zeros_like(attention[o])
        for f in range(number_of_frames):
            for i in range(n_residues):
                if attention[o,f,i] < threshold:
                    attention_clean[f,i]=0
                else:
                    attention_clean[f,i]=attention[o,f,i] 
        # The attention columns are duplicated before the per-atom expansion.
        # NOTE(review): presumably because the topology holds twice as many
        # residues as attention entries -- confirm.
        attention_fixed = np.concatenate([attention_clean, attention_clean], axis=1)
        # Repeat each column once per atom of the corresponding residue.
        bfactors = np.repeat(attention_fixed, [res.n_atoms for res in md_traj_super.top.residues], axis=1)

        frames_of_state = frames[:,o]
        # Used to split a global frame index into (file number, frame within file).
        frames_per_file = 10000
        for i in range(number_of_frames):
            file_number = frames_of_state[i]//frames_per_file
            file_number_frame = frames_of_state[i]%frames_per_file
        #             print(numbers_all[i], traj_nr[i], file_number, file_number_frame)
            root_new = '/group/ag_cmb/simulation-data/DESRES-Science2011-FastProteinFolding/DESRES-Trajectory_{0}-{1}-protein/{0}-{1}-protein/{0}-{1}-protein'.format(test_system, traj_nr)
            # The whole DCD file is loaded although only one frame of it is used.
            md_traj_start = md.load_dcd(root_new+'-{:03}.dcd'.format(file_number),
                                  top='/group/ag_cmb/scratch/deeptime_data/{0}/system-protein.pdb'.format(test_system)
                            )
            # Overwrite frame i of the pre-allocated trajectory with the selected coordinates.
            md_traj_small.xyz[i] = md_traj_start.xyz[file_number_frame]

        # First align everything to the reference structure, then compute the
        # simplified DSSP assignment for the first n_residues residues.
        md_traj_small.superpose(md_traj_super) 
        dssp = md.compute_dssp(md_traj_small, simplified=True)[:,:n_residues]

        index_heavy_2nd, list_res = get_heavy_atoms_dssp(dssp)
        print('Index of residues in 2nd structure', np.arange(n_residues)[list_res])
        # Re-align on frame 0, restricted to the atoms returned by
        # get_heavy_atoms_dssp when that selection is non-empty.
        # NOTE(review): the truth test assumes a list-like return value; a
        # multi-element NumPy array would raise here -- confirm the helper's type.
        if index_heavy_2nd:
            md_traj_small_temp = md_traj_small.superpose(md_traj_small, frame=0, atom_indices=index_heavy_2nd)
            print('found secondary structure and aligning with {} atoms'.format(len(index_heavy_2nd)))
        else:
            md_traj_small_temp = md_traj_small.superpose(md_traj_small, frame=0, atom_indices=None)
            print('did not find secondary structure and aligning with pdb {} atoms'.format(len(index_heavy_2nd)))
        
        # One PDB per (hierarchy level, state position).
        md_traj_small_temp.save_pdb('/group/ag_cmb/scratch/deeptime_data/{0}/attention/new_{1}_{2}_smooth{3}_{4}.pdb'.format(test_system, output_sizes[out_ind], o_temp, patchsize, test), bfactors=bfactors)

### Timescale calculation for the coarse-graining models

In [None]:
def get_params_all_rev():
    """Lazily yield the parameters used by the reversible hierarchy.

    Order: all parameters of every layer in ``Full_net.coarse_grain_layer``,
    then those of ``Full_net.u_layers[0]``, then those of
    ``Full_net.S_layers[0]``.
    """
    for cg_layer in Full_net.coarse_grain_layer:
        yield from cg_layer.parameters()

    yield from Full_net.u_layers[0].parameters()
    yield from Full_net.S_layers[0].parameters()
        
def vampe_loss_rev(chi_t, chi_tau, layer_id=0, return_mu=False, return_mu_K_Sigma=False):
    """Evaluate the u- and S-layer ``layer_id`` and return the negated trace.

    ``chi_t`` / ``chi_tau`` are passed to ``Full_net.u_layers[layer_id]``; its
    outputs feed ``Full_net.S_layers[layer_id]``, and the trace of the first
    S-layer output is negated.

    Returns
    -------
    ``-trace`` alone by default; ``(-trace, mu_t)`` if ``return_mu`` is set
    (takes precedence); ``(-trace, mu_t, K, Sigma_t)`` if only
    ``return_mu_K_Sigma`` is set.
    """
    u_layer = Full_net.u_layers[layer_id]
    s_layer = Full_net.S_layers[layer_id]

    _, _, v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t = u_layer(chi_t, chi_tau)
    trace_matrix, K, _ = s_layer(v, C_00, C_11, C_01, Sigma)
    negated_trace = -torch.trace(trace_matrix)

    if return_mu:
        return negated_trace, mu_t
    if return_mu_K_Sigma:
        return negated_trace, mu_t, K, Sigma_t
    return negated_trace
    
def vampe_loss_rev_only_S(v, C_00, C_11, C_01, Sigma, layer_id=0):
    """Return the negated trace of the first output of ``Full_net.S_layers[layer_id]``.

    The five positional arguments are forwarded unchanged to the S-layer, so
    the u-layer does not have to be evaluated again.
    """
    s_layer = Full_net.S_layers[layer_id]
    trace_matrix, _, _ = s_layer(v, C_00, C_11, C_01, Sigma)
    return -torch.trace(trace_matrix)

def vampe_loss_rev_cg(chi_t, chi_tau, u, S, u_t, layer_id=0, return_mu=False, return_mu_K_Sigma=False, renorm=True):
    """Negated trace of the matrix produced by one coarse-graining layer.

    Calls ``Full_net.coarse_grain_layer[layer_id].get_cg_uS`` with the given
    inputs and negates the trace of its last return value.

    Returns
    -------
    ``-trace`` alone by default; ``(-trace, mu_t)`` if ``return_mu`` is set
    (takes precedence); ``(-trace, mu_t, K, Sigma_t)`` if only
    ``return_mu_K_Sigma`` is set.
    """
    cg_layer = Full_net.coarse_grain_layer[layer_id]
    # Only the last four of the nine outputs are needed here.
    _, _, _, _, _, mu_t, Sigma_t, K, VampE_matrix = cg_layer.get_cg_uS(
        chi_t, chi_tau, u, S, u_t, renorm)

    negated_trace = -torch.trace(VampE_matrix)

    if return_mu:
        return negated_trace, mu_t
    if return_mu_K_Sigma:
        return negated_trace, mu_t, K, Sigma_t
    return negated_trace
    
def train_for_S(chi_t, chi_tau, opt, runs=100, verbose=True, plot_training=True):
    """Run ``runs`` optimizer steps on the S-layer with fixed u-layer outputs.

    The outputs of ``Full_net.u_layers[0]`` are computed once for the training
    pair and once for the validation pair (``tensor_valid_X1`` /
    ``tensor_valid_X2`` from the enclosing notebook scope) and detached, so the
    loss of this function only back-propagates into whatever
    ``vampe_loss_rev_only_S`` evaluates from them.

    Parameters
    ----------
    chi_t, chi_tau : torch.Tensor
        Network outputs for the instantaneous and time-lagged training frames.
    opt : torch.optim.Optimizer
        Optimizer that is stepped once per run.
    runs : int
        Number of optimizer steps.
    verbose : bool
        Print training and validation value after every step.
    plot_training : bool
        Plot both curves after the last step.
    """

    def _fixed_u_outputs(x_t, x_tau):
        # No graph is needed because the results are treated as constants; the
        # explicit detach() additionally covers outputs that are returned as-is.
        with torch.no_grad():
            _, _, v, C_00, C_11, C_01, Sigma, _, _ = Full_net.u_layers[0](x_t, x_tau)
        return tuple(t.detach() for t in (v, C_00, C_11, C_01, Sigma))

    train_inputs = _fixed_u_outputs(chi_t, chi_tau)

    chi_t_valid = torch.Tensor(pred_batchwise(tensor_valid_X1)).to(device)
    chi_tau_valid = torch.Tensor(pred_batchwise(tensor_valid_X2)).to(device)
    valid_inputs = _fixed_u_outputs(chi_t_valid, chi_tau_valid)

    epoch_loss = np.zeros(runs)
    epoch_loss_valid = np.zeros(runs)
    for epoch in range(runs):

        opt.zero_grad()

        # vampe_loss_rev_only_S returns the negated trace; negating it again
        # makes the optimizer minimise the plain trace.
        score_curr = vampe_loss_rev_only_S(*train_inputs)
        loss = - score_curr
        loss.backward()
        opt.step()

        epoch_loss[epoch] = -loss.item()

        # Pure evaluation: skip autograd bookkeeping.
        with torch.no_grad():
            score_curr_valid = vampe_loss_rev_only_S(*valid_inputs)
        epoch_loss_valid[epoch] = score_curr_valid.item()

        if verbose:
            print('Run {}, total loss: {:.3}, valid loss: {:.3}'.format(epoch+1, epoch_loss[epoch], epoch_loss_valid[epoch]))

    if plot_training:
        run_axis = np.arange(1,runs+1)
        plt.plot(run_axis, epoch_loss, label='VAMP_loss')
        plt.plot(run_axis, epoch_loss_valid, label='VAMP_loss_valid')
        plt.legend()
        plt.show()
def train_for_u_S(chi_t, chi_tau, opt, runs=100, verbose=True, plot_training=True):
    """Run ``runs`` optimizer steps on the value returned by ``vampe_loss_rev``.

    Unlike ``train_for_S`` the u-layer is re-evaluated in every step, so
    gradients can reach both the u- and the S-layer parameters held by ``opt``.
    Validation uses ``tensor_valid_X1`` / ``tensor_valid_X2`` from the enclosing
    notebook scope.

    Parameters
    ----------
    chi_t, chi_tau : torch.Tensor
        Network outputs for the instantaneous and time-lagged training frames.
    opt : torch.optim.Optimizer
        Optimizer that is stepped once per run.
    runs : int
        Number of optimizer steps.
    verbose : bool
        Print training and validation value after every step.
    plot_training : bool
        Plot both curves after the last step.
    """
    chi_t_valid = torch.Tensor(pred_batchwise(tensor_valid_X1)).to(device).detach()
    chi_tau_valid = torch.Tensor(pred_batchwise(tensor_valid_X2)).to(device).detach()

    epoch_loss = np.zeros(runs)
    epoch_loss_valid = np.zeros(runs)
    for epoch in range(runs):

        opt.zero_grad()

        # vampe_loss_rev returns the negated trace; negating it again makes the
        # optimizer minimise the plain trace.
        score_curr = vampe_loss_rev(chi_t, chi_tau)
        loss = - score_curr
        loss.backward()
        opt.step()

        # Pure evaluation: skip autograd bookkeeping.
        with torch.no_grad():
            score_curr_valid = vampe_loss_rev(chi_t_valid, chi_tau_valid)

        epoch_loss[epoch] = -loss.item()
        epoch_loss_valid[epoch] = score_curr_valid.item()
        if verbose:
            print('Run {}, total loss: {:.3}, valid loss: {:.3}'.format(epoch+1, epoch_loss[epoch], epoch_loss_valid[epoch]))

    if plot_training:
        run_axis = np.arange(1,runs+1)
        plt.plot(run_axis, epoch_loss, label='VAMP_loss')
        plt.plot(run_axis, epoch_loss_valid, label='VAMP_loss_valid')
        plt.legend()
        plt.show()


def train_for_everything_rev(chi_t_in, chi_tau_in, opt, runs, plot_mask_every=10, verbose=True, plot_training=True):
    """Jointly optimise all hierarchy levels for ``runs`` optimizer steps.

    Each step evaluates ``Full_net.u_layers[0]`` and ``Full_net.S_layers[0]``
    on the given inputs, then passes the result through every coarse-graining
    layer in turn (``len(output_sizes) - 1`` of them). The negated trace of the
    matrix returned at every level is collected, and the optimizer minimises
    the negative of their sum.

    The number of levels is taken from ``output_sizes``. Earlier the loss and
    the logging were hard-wired to exactly three levels, which raised an
    IndexError for fewer levels and silently ignored deeper ones. For
    ``output_sizes == [6, 3, 2]`` the printed text and plot labels are
    unchanged.

    Parameters
    ----------
    chi_t_in, chi_tau_in : torch.Tensor
        Network outputs for the instantaneous and time-lagged frames.
    opt : torch.optim.Optimizer
        Optimizer that is stepped once per run.
    runs : int
        Number of optimizer steps.
    plot_mask_every : int
        ``plot_cg`` is called for every coarse-graining layer each time this
        many steps have been completed.
    verbose : bool
        Print the total and the per-level values after every step.
    plot_training : bool
        Plot the total and the per-level curves after the last step.
    """
    n_levels = len(output_sizes)
    total_history = np.zeros(runs)
    # One column per hierarchy level, in the order of output_sizes.
    level_history = np.zeros((runs, n_levels))

    for epoch in range(runs):
        opt.zero_grad()

        # Finest level.
        u, u_t, v, C_00, C_11, C_01, Sigma, mu_t, Sigma_t = Full_net.u_layers[0](chi_t_in, chi_tau_in)
        VampE_matrix, K, S = Full_net.S_layers[0](v, C_00, C_11, C_01, Sigma)
        level_scores = [-torch.trace(VampE_matrix)]

        # Coarser levels: every layer consumes the outputs of the previous one.
        chi_t, chi_tau = chi_t_in, chi_tau_in
        for j in range(n_levels - 1):
            chi_t, chi_tau, u, u_t, S, mu_t, Sigma_t, K, VampE_matrix = Full_net.coarse_grain_layer[j].get_cg_uS(
                chi_t, chi_tau, u, S, u_t, Full_net.renorm)
            level_scores.append(-torch.trace(VampE_matrix))

        # Each entry is already a negated trace, so the loss is the plain sum of traces.
        loss = -sum(level_scores)
        loss.backward()
        opt.step()

        total_history[epoch] = -loss.item()
        level_history[epoch] = [score.item() for score in level_scores]

        if verbose:
            summary = 'Run {}, total loss: {:.3}'.format(epoch+1, total_history[epoch])
            summary += ''.join(', , {} loss: {:.3}'.format(size, value)
                               for size, value in zip(output_sizes, level_history[epoch]))
            print(summary)

        if (epoch+1) % plot_mask_every == 0:
            # NOTE(review): assumes plot_cg takes the index of a coarse-graining
            # layer; formerly called with the fixed indices 0 and 1.
            for j in range(n_levels - 1):
                plot_cg(j)

    if plot_training:
        run_axis = np.arange(1,runs+1)
        plt.plot(run_axis, total_history, label='VAMP_loss')
        plt.legend()
        plt.show()
        for level, size in enumerate(output_sizes):
            plt.plot(run_axis, level_history[:, level], label='VAMP_loss {}'.format(size))
            plt.legend()
            plt.show()

def _hierarchy_score(chi_t, chi_tau):
    """Return, as a float, the sum over all hierarchy levels of the negated matrix traces."""
    with torch.no_grad():  # evaluation only, no graph required
        u, u_t, v, C_00, C_11, C_01, Sigma, _, _ = Full_net.u_layers[0](chi_t, chi_tau)
        VampE_matrix, _, S = Full_net.S_layers[0](v, C_00, C_11, C_01, Sigma)
        level_scores = [-torch.trace(VampE_matrix)]
        for j in range(len(output_sizes)-1):
            chi_t, chi_tau, u, u_t, S, _, _, _, VampE_matrix = Full_net.coarse_grain_layer[j].get_cg_uS(
                chi_t, chi_tau, u, S, u_t, Full_net.renorm)
            level_scores.append(-torch.trace(VampE_matrix))
        return sum(level_scores).item()


def get_Ks_for_tau(tau_i):
    """Retrain u, S and the coarse-graining layers for lag ``tau_i`` and return one matrix per level.

    Steps:
    1. Build the time-lagged pairs with ``get_data_for_tau`` and restore
       ``weights_final``.
    2. Print the validation value of ``vampe_loss_rev``, reset all parameters
       of ``S_layers[0]`` and ``u_layers[0]`` to ones, and print it again.
    3. Alternate 250 steps of ``train_for_everything_rev`` with each of two
       Adam optimizers. After every round the summed hierarchy score on the
       training pairs is printed; training stops as soon as it no longer
       improves (the last improving weights are restored) or after 52 rounds.
    4. Return ``get_K_rev_cg(o, ...)`` for every level ``o``, printing its
       eigenvalues.

    The number of hierarchy levels follows ``len(output_sizes)``; the score was
    formerly hard-wired to exactly three levels.
    """
    (X1_train_cor, X2_train_cor, X1_vali_cor, X2_vali_cor,
     _, _, _, _, _, _) = get_data_for_tau(traj_whole_new, tau_i)

    tensor_train_X1 = torch.Tensor(X1_train_cor)
    tensor_train_X2 = torch.Tensor(X2_train_cor)
    tensor_valid_X1 = torch.Tensor(X1_vali_cor)
    tensor_valid_X2 = torch.Tensor(X2_vali_cor)

    def _predict(frames):
        return torch.Tensor(pred_batchwise(frames)).to(device).detach()

    Full_net.set_weights(weights_final)
    print(vampe_loss_rev(_predict(tensor_valid_X1), _predict(tensor_valid_X2)).item())

    # NOTE(review): this snapshot is taken BEFORE u and S are reset, so a first
    # round that does not improve the score restores the pre-reset weights.
    # Kept as in the original; confirm that this is intended.
    weights_temp = Full_net.get_weights()
    with torch.no_grad():
        for param in Full_net.S_layers[0].parameters():
            param.copy_(torch.Tensor(np.ones((output_sizes[0], output_sizes[0]))))
        for param in Full_net.u_layers[0].parameters():
            param.copy_(torch.Tensor(np.ones((1, output_sizes[0]))))

    optimizer_rev_u_S_cg = optim.Adam(Full_net.get_params_rev(all=False), lr=learning_rate/10)
    optimizer_rev_S_cg = optim.Adam(Full_net.get_params_rev(all=False, u_flag=False), lr=learning_rate*10)

    # Predictions are recomputed wherever model parameters were touched in
    # between, because pred_batchwise is defined elsewhere and its dependence on
    # those parameters is not visible here.
    print(vampe_loss_rev(_predict(tensor_valid_X1), _predict(tensor_valid_X2)).item())

    chi_t = _predict(tensor_train_X1)
    chi_tau = _predict(tensor_train_X2)

    score_old = _hierarchy_score(chi_t, chi_tau)
    print('Old score', score_old)

    # The former flag/counter loop allowed at most 52 rounds.
    max_rounds = 52
    for _ in range(max_rounds):
        train_for_everything_rev(chi_t, chi_tau, optimizer_rev_S_cg, runs=250, plot_mask_every=5000, verbose=False, plot_training=False)
        train_for_everything_rev(chi_t, chi_tau, optimizer_rev_u_S_cg, runs=250, plot_mask_every=5000, verbose=False, plot_training=False)

        score = _hierarchy_score(_predict(tensor_train_X1), _predict(tensor_train_X2))
        print(score)
        if score > score_old:
            weights_temp = Full_net.get_weights()
            score_old = score
        else:
            # No improvement (this also catches NaN): roll back and stop.
            Full_net.set_weights(weights_temp)
            print('Score decreased')
            break

    K_list = []
    for o in range(len(output_sizes)):
        K_list.append(get_K_rev_cg(o, tensor_t=tensor_train_X1, tensor_tau=tensor_train_X2))
        print(np.linalg.eigvals(K_list[-1]))
    return K_list

In [None]:
# Re-estimate the matrices of all three hierarchy levels for every lag time in
# `lag` and derive implied timescales from them.
# NOTE(review): `lag` is not assigned in this cell (both candidate definitions
# below are commented out), so it must exist from an earlier cell. `step_size`
# and `max_tau` are only referenced by those commented-out lines.
step_size = tau*skip
max_tau = 10000
# lag = np.arange(50, max_tau, step_size)
# lag = (np.linspace(3.22,6.3, 10)**4).astype('int')
print(lag)
# One square matrix per lag time for each of the three levels in output_sizes.
K_results_6 = np.ones((len(lag) ,output_sizes[0], output_sizes[0]))
K_results_3 = np.ones((len(lag) ,output_sizes[1], output_sizes[1]))
K_results_2 = np.ones((len(lag) ,output_sizes[2], output_sizes[2]))
for i, tau_i in enumerate(lag):
    print(tau_i)
#     K_results_rev[i]= training_for_tau_both(tau_i)
    # get_Ks_for_tau returns a list with one matrix per hierarchy level.
    K_list = get_Ks_for_tau(tau_i)
    K_results_6[i]  = K_list[0]
    K_results_3[i]  = K_list[1]
    K_results_2[i]  = K_list[2]

# Meaning of the third positional argument (False) is defined by get_its -- see its definition.
its_6 = get_its(K_results_6, lag, False)
its_3 = get_its(K_results_3, lag, False)
its_2 = get_its(K_results_2, lag, False)

### Plot timescales for each model

In [None]:
# Implied-timescale plot for the first hierarchy level (output_sizes[0] states).
# `fac` rescales the axis values for the tick labels; presumably frames ->
# microseconds (the label arrays below are annotated as microseconds) -- confirm.
fac = 200.*skip*1e-6 
plt.figure(figsize=(6,4));

label_x = np.array([.1,0.3,1, 2, 5,10,100,1000])/fac # array is in microsecond
label_y = np.array([.1,1, 2, 5,10, 100, 1000])/fac
# fig = plt.figure(figsize = (8,8))
# One curve per row of the reversed its array, output_sizes[0]-1 curves in total.
for j in range(0,output_sizes[0]-1):
    plt.semilogy(lag, its_6[::-1][j], lw=5)
#     plt.fill_between(tau_list, all_its_mean[i] -all_its_std[i], all_its_mean[i] + all_its_std[i], alpha = 0.3)
# Diagonal timescale == lag, with the area below it shaded.
plt.semilogy(lag,lag, 'k')
# Raw strings: '\m' is not a valid escape sequence in a normal string literal.
plt.xlabel(r'lag [$\mu$s]', fontsize=26)
plt.xticks(label_x, label_x*fac, fontsize=22)
plt.ylabel(r'timescale [$\mu$s]', fontsize=26)
plt.yticks(label_y, np.round(label_y*fac, decimals=1), fontsize=22)
plt.fill_between(lag,lag,0.1,alpha = 0.2,color='k');
plt.ylim(0.01/fac, 3/fac)
plt.xlim(lag[0], 0.3/fac)
# plt.title('Villin', fontsize=16)
# plt.savefig('figs/its_{}_{}_{}.pdf'.format(output_sizes[0], patchsize, test), bbox_inches='tight')
plt.show()

In [None]:
# Implied-timescale plot for the second hierarchy level (output_sizes[1] states).
# `fac` rescales the axis values for the tick labels; presumably frames ->
# microseconds (the label arrays below are annotated as microseconds) -- confirm.
fac = 200.*skip*1e-6 
plt.figure(figsize=(6,4));

label_x = np.array([.1,0.3,1, 2, 5,10,100,1000])/fac # array is in microsecond
label_y = np.array([.1,1, 2, 5,10, 100, 1000])/fac
# fig = plt.figure(figsize = (8,8))
# One curve per row of the reversed its array, output_sizes[1]-1 curves in total.
for j in range(0,output_sizes[1]-1):
    plt.semilogy(lag, its_3[::-1][j], lw=5)
#     plt.fill_between(tau_list, all_its_mean[i] -all_its_std[i], all_its_mean[i] + all_its_std[i], alpha = 0.3)
# Diagonal timescale == lag, with the area below it shaded.
plt.semilogy(lag,lag, 'k')
# Raw strings: '\m' is not a valid escape sequence in a normal string literal.
plt.xlabel(r'lag [$\mu$s]', fontsize=26)
plt.xticks(label_x, label_x*fac, fontsize=22)
plt.ylabel(r'timescale [$\mu$s]', fontsize=26)
plt.yticks(label_y, np.round(label_y*fac, decimals=1), fontsize=22)
plt.fill_between(lag,lag,0.1,alpha = 0.2,color='k');
plt.ylim(0.01/fac, 3/fac)
plt.xlim(lag[0], 0.3/fac)
# plt.title('Villin', fontsize=16)
# plt.savefig('figs/its_{}_{}_{}.pdf'.format(output_sizes[1], patchsize, test), bbox_inches='tight')
plt.show()

In [None]:
# Implied-timescale plot for the third hierarchy level (output_sizes[2] states).
# `fac` rescales the axis values for the tick labels; presumably frames ->
# microseconds (the label arrays below are annotated as microseconds) -- confirm.
fac = 200.*skip*1e-6 
plt.figure(figsize=(6,4));

label_x = np.array([.1,0.3,1, 2, 5,10,100,1000])/fac # array is in microsecond
label_y = np.array([.1,1, 2, 5,10, 100, 1000])/fac
# fig = plt.figure(figsize = (8,8))
# One curve per row of the reversed its array, output_sizes[2]-1 curves in total.
for j in range(0,output_sizes[2]-1):
    plt.semilogy(lag, its_2[::-1][j], lw=5)
#     plt.fill_between(tau_list, all_its_mean[i] -all_its_std[i], all_its_mean[i] + all_its_std[i], alpha = 0.3)
# Diagonal timescale == lag, with the area below it shaded.
plt.semilogy(lag,lag, 'k')
# Raw strings: '\m' is not a valid escape sequence in a normal string literal.
plt.xlabel(r'lag [$\mu$s]', fontsize=26)
plt.xticks(label_x, label_x*fac, fontsize=22)
plt.ylabel(r'timescale [$\mu$s]', fontsize=26)
plt.yticks(label_y, np.round(label_y*fac, decimals=1), fontsize=22)
plt.fill_between(lag,lag,0.1,alpha = 0.2,color='k');
plt.ylim(0.01/fac, 3/fac)
plt.xlim(lag[0], 0.3/fac)
# plt.title('Villin', fontsize=16)
# plt.savefig('figs/its_{}_{}_{}.pdf'.format(output_sizes[2], patchsize, test), bbox_inches='tight')
plt.show()