In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils import data
import torch.optim as optim
import matplotlib.gridspec as gridspec

import h5py

### Hyperparameters

In [None]:
# Random seed. Seeds numpy's and torch's global generators, which are used by default by
# random_split, weight initialisation, DataLoader shuffling and np.random.shuffle
# in the implied-timescale estimation, so that a Restart & Run All is reproducible.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Use only every `stride`-th frame of the loaded trajectories
stride = 1

# Lag time (in strided frames) used for training
tau = 100//stride

# Number of output states of each subsystem (one lobe per entry)
output_sizes = [8,8]
number_subsystems = len(output_sizes)
# tau list for timescales estimation
# NOTE(review): not referenced elsewhere in this notebook; `msmlags` is used for the ITS plot.
tau_list = [1,2,4,8]

# Batch size for Stochastic Gradient descent
batch_size = 20000
# Fraction of the time-lagged pairs used for validation and testing; the rest is for training
valid_ratio = 0.3
test_ratio = 0.0001
# Number of additional (layer_width x layer_width) hidden layers after the input layer of
# each lobe chi, i.e. every lobe has network_depth + 1 hidden layers in total
network_depth = 3

# Width of every layer of chi
layer_width = 100
# create a list with the number of nodes for each layer
nodes = [layer_width]*network_depth
# data preparation
# how many residues are skipped for distance calculation
skip_res = 6
# Size of the windows for attention mechanism
patchsize = 8
# How many residues are skipped before defining a new window
skip_over = 4

# How strong the fake subsystem is
factor_fake = 2.
# How large the noise in the mask for regularization is
noise = 2.
# Threshold after which the attention weight is set to zero
cutoff=0.9
# Learning rate
learning_rate=0.001
# epsilon
epsilon=1e-5
# score method
score_mode='regularize' # one of ('trunc', 'regularize', 'clamp', 'old')

### Load data

In [None]:
# data set has a total length of 184 µs with a 1 ns resolution (total of 184000 frames)

# One HDF5 file holding one dataset per trajectory.
data_dir = "/group/ag_cmb/scratch/deeptime_data/syt"
loaded_data_stride = 100
data_path = f"{data_dir}/syt_0cal_internal1by1_stride{loaded_data_stride}.hdf5"
# Positions (in the file's key iteration order) of trajectories to skip
exclude_list = []

data_trajs = []
hdf5_names = []
with h5py.File(data_path, "r") as h5file: # 1 frame = 1 ns
    for n, name in enumerate(h5file.keys()):
        if n in exclude_list:
            continue
        hdf5_names.append(name)
        # Read only every `stride`-th frame, as float32.
        frames = h5file[name][::stride].astype('float32')
        # exp(-x) equals the former 1/exp(x) up to float rounding, but does not overflow
        # to inf (and warn) for large x and avoids an extra temporary array.
        data_trajs.append(np.exp(-frames))

### Define dataset

In [None]:
from deeptime.util.data import TrajectoriesDataset

# Pairs of frames separated by the lag time `tau`, built from all loaded trajectories.
dataset = TrajectoriesDataset.from_numpy(data=data_trajs, lagtime=tau)

In [None]:
# Random split into training / validation / test subsets: validation and test get
# fixed fractions of the dataset, training gets whatever remains.
n_val = int(len(dataset) * valid_ratio)
n_test = int(len(dataset) * test_ratio)
split_sizes = [len(dataset) - n_val - n_test, n_val, n_test]
train_data, val_data, test_data = torch.utils.data.random_split(dataset, split_sizes)

### Define networks

In [None]:
from masks import Mask_proteins
from collections import OrderedDict

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)

# Feature-wise mean/std handed to the mask; the trajectories are concatenated only once.
# NOTE(review): `train_data` is a Subset, so `train_data.dataset` is the full dataset and
# these statistics also include validation/test frames.
stacked_frames = np.concatenate(train_data.dataset.trajectories, axis=0)
train_mean = stacked_frames.mean(0)
train_std = stacked_frames.std(0)
del stacked_frames  # release the concatenated copy of all trajectories

input_size = train_data.dataset.trajectories[0].shape[-1]
mask = Mask_proteins(input_size, number_subsystems, skip_res=skip_res, patchsize=patchsize, skip=skip_over,
                     mean=torch.Tensor(train_mean), std=torch.Tensor(train_std), factor_fake=factor_fake,
                     noise=noise, device=device, cutoff=cutoff)
mask.to(device=device)


def build_lobe(in_features, out_features, width, depth):
    """Build the lobe (network chi) of one subsystem.

    Layout: Linear(in_features, width) + ELU, then `depth` blocks of
    Linear(width, width) + ELU, then Linear(width, out_features) followed by a
    softmax over dim 1 (a fuzzy probability distribution over the output states).

    Returns:
        nn.Sequential with the named layers 'Layer_input', 'Elu_input',
        'Layer<d>', 'Elu<d>', 'Layer_output' and 'Softmax'.
    """
    layers = OrderedDict([('Layer_input', nn.Linear(in_features, width)),
                          ('Elu_input', nn.ELU())])
    for d in range(depth):
        layers['Layer' + str(d)] = nn.Linear(width, width)
        layers['Elu' + str(d)] = nn.ELU()
    layers['Layer_output'] = nn.Linear(width, out_features)
    layers['Softmax'] = nn.Softmax(dim=1)
    return nn.Sequential(layers)


# One lobe per subsystem; subsystems may have different numbers of output states.
lobes = [build_lobe(input_size, output_size, layer_width, network_depth).to(device=device)
         for output_size in output_sizes]

print(mask)
print(lobes)

### Create iVAMPnets estimator

In [None]:
from ivampnets import iVAMPnet

In [None]:
# Estimator combining the per-subsystem lobes with the shared mask.
ivampnet = iVAMPnet(
    lobes,
    mask,
    device,
    learning_rate=learning_rate,
    epsilon=epsilon,
    score_mode=score_mode,
)

### Plot mask before training

In [None]:
from examples import plot_mask

# State of the mask before any training, for comparison with the later plots.
plot_mask(mask, skip=10)

### Create data loader

In [None]:
from torch.utils.data import DataLoader

# Training batches are reshuffled every epoch; the whole validation set is
# evaluated as a single, unshuffled batch.
loader_train = DataLoader(train_data, shuffle=True, batch_size=batch_size)
loader_val = DataLoader(val_data, shuffle=False, batch_size=len(val_data))

### Create a tensorboard writer to observe performance during training

In [None]:
# Set to True to log training progress to ./runs/Syt/ (requires the tensorboard package).
# No batch is pre-fetched here: it was only needed for the (unused) model-graph export.
tensorboard_installed = False
if tensorboard_installed:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir='./runs/Syt/')
else:
    writer = None

### Fit the model on the training data

In [None]:
# Mask training runs in three stages with growing mask noise and a growing weight on
# the decomposition score; the trace term is only active in the first stage.
# The mask is plotted between stages (after the first and the second).
# Each entry: (mask noise set before the stage, or None to keep the current value,
#              n_epochs, lam_decomp, lam_trace, end_trace)
mask_training_stages = [
    (None, 50, 20., 1., 20),
    (5., 150, 50., 0., 0),
    (10., 150, 100., 0., 0),
]
for stage_idx, stage in enumerate(mask_training_stages):
    stage_noise, stage_epochs, stage_lam_decomp, stage_lam_trace, stage_end_trace = stage
    if stage_idx > 0:
        plot_mask(mask, skip=10)
    if stage_noise is not None:
        mask.noise = stage_noise
    model = ivampnet.fit(loader_train, n_epochs=stage_epochs, validation_loader=loader_val, mask=True,
                         lam_decomp=stage_lam_decomp, lam_trace=stage_lam_trace, start_mask=0,
                         end_trace=stage_end_trace, tb_writer=writer, clip=False).fetch_model()

### Plot training and validation scores

In [None]:
# Scores recorded by `fit`, plotted against the step on log-log axes.
ax = plt.gca()
for scores, curve_label in ((ivampnet.train_scores, 'training'),
                            (ivampnet.validation_scores, 'validation')):
    ax.loglog(*scores.T, label=curve_label)
ax.set_xlabel('step')
ax.set_ylabel('score')
ax.legend();

### Plot the mask after training

In [None]:
# Mask after the noisy mask-training stages.
plot_mask(mask, skip=10)

In [None]:
# plot_protein_its is imported here as well; it is used for the timescale figure below.
from examples import plot_protein_its, plot_protein_mask

# Protein-level view of the trained mask (see examples.plot_protein_mask).
plot_protein_mask(mask, skip_start=4)

### Finally train without noise

In [None]:
# The noise is only needed to make training of the mask meaningful. At this point the
# mask should be well trained, so the noise is switched off and mask training is
# disabled from here on (mask=False).
mask.noise = 0.
model = ivampnet.fit(
    loader_train, n_epochs=300, validation_loader=loader_val, mask=False,
    lam_decomp=100., lam_trace=0., start_mask=0, end_trace=0,
    tb_writer=writer, clip=False,
).fetch_model()

In [None]:
# Optionally, continue training without enforcing the decomposition score (lam_decomp=0).
# Watch the independence score: if it rises significantly, the progress has to be reverted.
# The save_criteria parameter can be used to control this.
model = ivampnet.fit(
    loader_train, n_epochs=300, validation_loader=loader_val, mask=False,
    lam_decomp=0., lam_trace=0., start_mask=0, end_trace=0,
    tb_writer=writer, clip=False, save_criteria=0.012,
).fetch_model()

### Estimate implied timescales

In [None]:
# Cheap error estimation: instead of retraining chi, evaluate the trained model `runs`
# times on random subsets (a fraction `percentage`) of the trajectories.
runs = 5
percentage = 0.9
msmlags = np.array([1, 2, 4, 6, 10, 15, 20, 25]) * 10
N_trajs = len(dataset.trajectories)
indexes_traj = np.arange(N_trajs)
# Trajectories per subset. Named separately so the validation-split size `n_val`
# from the data-split cell is not overwritten.
n_trajs_used = int(N_trajs * percentage)

its = [[] for _ in range(runs)]
for run in range(runs):
    for tau_i in msmlags:
        # A fresh random subset of trajectories is drawn for every lag time.
        np.random.shuffle(indexes_traj)
        indexes_used = indexes_traj[:n_trajs_used]
        # Frames t and t + tau_i of each selected trajectory.
        data_t = np.concatenate([dataset.trajectories[a][:-tau_i] for a in indexes_used], axis=0)
        data_tau = np.concatenate([dataset.trajectories[a][tau_i:] for a in indexes_used], axis=0)
        its[run].append(model.timescales(data_t, data_tau, tau_i, batchsize=10000))

# Release the large concatenated arrays of the last iteration.
del data_t, data_tau

In [None]:
# Reorder `its` from its[run][lag][subsystem] into one array per subsystem of shape
# (runs, len(msmlags), output_size - 1); subsystems can have different output sizes.
n_lags = len(msmlags)
its_reorder = []
for sub in range(number_subsystems):
    sub_its = np.zeros((runs, n_lags, output_sizes[sub] - 1))
    for run_idx in range(runs):
        its_of_run = its[run_idx]
        for lag_idx in range(n_lags):
            sub_its[run_idx, lag_idx] = its_of_run[lag_idx][sub]
    its_reorder.append(sub_its)

In [None]:
axes, fig = plot_protein_its(its_reorder, msmlags, ylog=True, multiple_runs=True, percent=0.9)
# Tick positions are in strided frames; labels convert them to ns (x) and µs (y),
# assuming 1 frame = 1 ns before striding.
x_ticks = np.array([1, 5, 10, 20, 40]) * 10
x_ticks_labels = x_ticks * stride  # for estimating the right units!
y_ticks = np.array([1000, 10000, 100000]) / stride
y_ticks_labels = y_ticks * stride / 1000
# Set to True to write the figure to disk once, after all panels are formatted.
save_figure = False

for n in range(number_subsystems):
    ax = axes[n]
    # Diagonal ITS == lag time and the shaded region below it.
    ax.plot(msmlags, msmlags, 'k')
    ax.fill_between(msmlags, msmlags[0], msmlags, color='k', alpha=0.2)
    ax.set_xlabel('Lagtime [ns]', fontsize=16)
    if n == 0:
        # Raw string: '\m' is an invalid escape sequence in a regular string literal.
        ax.set_ylabel(r'Implied Timescales [$\mu$s]', fontsize=16)
    ax.legend(fontsize=14, loc='lower right')
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_ticks_labels, fontsize=14)
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(y_ticks_labels, fontsize=14)
    ax.tick_params(direction='out', length=6, width=2, colors='k',
                   grid_color='k', grid_alpha=0.5)
    # x-range spans exactly the evaluated lag times.
    ax.set_xlim(msmlags[0], msmlags[-1])
    ax.set_ylim(0.01 * 1000, 200 * 1000)

if save_figure:
    fig.savefig('./Syt_its.pdf', bbox_inches='tight')

plt.show()

In [None]:
# reproduces Fig. 5b

In [None]:
# ivampnet.save_params('./Syt_params')