# VAMPnets

<a rel="license" href="http://creativecommons.org/licenses/by/4.0/"><img alt="Creative Commons Licence" style="border-width:0" src="https://i.creativecommons.org/l/by/4.0/88x31.png" title='This work is licensed under a Creative Commons Attribution 4.0 International License.' align="right"/></a><br><br>

In this session we will see an example of how to use VAMPnets to extract a coarse-grained model from raw data using an unsupervised deep learning approach. We will load data from a 2D toy model with xxx states, build and train a neural network that assigns each data point to one of the states, and finally visualize the information we extracted from the dataset.
After this, we will follow the same process to analyze a trajectory of the molecule alanine dipeptide, since it is a 30D system whose dynamics can be easily visualized in a 2D space.


<a id="ref-1" href="https://www.nature.com/articles/s41467-017-02388-1">Here</a> you can find literature on the used method.

**Remember**:
- to run the currently highlighted cell, hold <kbd>&#x21E7; Shift</kbd> and press <kbd>&#x23ce; Enter</kbd>;
- to get help for a specific function, place the cursor within the function's brackets, hold <kbd>&#x21E7; Shift</kbd>, and press <kbd>&#x21E5; Tab</kbd>.

### Import the required packages

In case you haven't installed pytorch: [Installation instructions](https://pytorch.org/get-started/locally/).

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import mdshare
import pyemma
import deeptime as dt
import torch
import torch.nn as nn

from tqdm.notebook import tqdm
from deeptime.plots import plot_implied_timescales
from deeptime.util.validation import implied_timescales

In [None]:
# use the GPU if CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

torch.set_num_threads(12)

### Guided example: 2D toy model
We start by loading the data for the 2D model by using the package `mdshare`. The `fetch` function fetches the data from our servers. **Do not use `mdshare` for your own data!**

In [None]:
file = mdshare.fetch("hmm-doublewell-2d-100k.npz", working_directory="data")
with np.load(file) as fh:
    data = fh["trajectory"]

Next we want to visualize how the data are distributed in the 2D space.

#### Exercise
Plot the density of the data using a function from the `pyemma` package

In [None]:
pyemma.plots.plot_density(data[:,0], data[:,1]) ##FIXME
plt.show()

### Hyperparameter selection
The next step is a bit tricky, as hyperparameter selection requires some experience to be done correctly. We provided some default values that will allow for a smooth training of our model. The meaning of every hyperparameter is explained in the next cell.

In [None]:
# Tau: the lag time (in steps) between the two time-shifted datasets
tau = 1

# Batch size for Stochastic Gradient descent
batch_size = 3000

# Fraction of the trajectory points used for validation
val_ratio = 0.1

# How many hidden layers the network has
network_depth = 4

# "Width" of every layer
layer_width = 20

# Learning rate used for the ADAM optimizer
learning_rate = 5e-3

# How many output states the network has
output_size = 2

# List of nodes of each layer
nodes = [data.shape[1]] + [layer_width for _ in range(network_depth)] + [output_size]

# Iteration over the training set in the fitting process;
# basically how many iterations our training algorithm will do
nb_epoch = 20
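
To get a feeling for what the `nodes` list describes, the sketch below counts the weights and biases of a plain fully connected network with these layer sizes. It assumes an input dimension of 2 (as for the toy model) and ignores any additional parameters such as batch normalization; `count_parameters` and the `example_*` names are hypothetical helpers for illustration, not part of `deeptime`.

```python
def count_parameters(layer_sizes):
    """Count weights and biases of a plain fully connected network."""
    total = 0
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        total += fan_in * fan_out + fan_out  # weight matrix + bias vector
    return total

# Assumed values mirroring the hyperparameter cell above.
example_input_dim = 2
example_depth = 4
example_width = 20
example_outputs = 2

example_nodes = [example_input_dim] + [example_width] * example_depth + [example_outputs]
n_params = count_parameters(example_nodes)
print(example_nodes, n_params)
```

Under these assumptions the network is small (1362 parameters), which is one reason why a few epochs are enough for the toy model.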

### Data preprocessing

Now we can prepare our data so that it can be used for training our VAMPnets model. We want two arrays of paired data points, selected from the main trajectory at indices $i$ and $i+\tau$. We want the pairs to be shuffled while maintaining the correspondence between each non-time-lagged data point and its time-lagged partner. Finally, we want to split our data into a training set and a validation set: the former is used for training the network, and the latter is needed to test whether the network is overfitting (i.e. the resulting transformation works only on the training set but not on other data from the same distribution).
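
The preprocessing described here can be sketched in plain NumPy. This is only an illustration under simple assumptions (a single trajectory stored as an array of shape `(n_frames, n_features)`); `make_lagged_pairs` is a hypothetical helper, and the notebook cells below rely on `deeptime` and `torch` utilities instead.

```python
import numpy as np

def make_lagged_pairs(traj, lag, val_ratio, seed=0):
    """Return shuffled (x_t, x_{t+lag}) pairs split into training and validation."""
    x_t, x_lagged = traj[:-lag], traj[lag:]      # pair frame i with frame i + lag
    perm = np.random.default_rng(seed).permutation(len(x_t))
    x_t, x_lagged = x_t[perm], x_lagged[perm]    # same permutation keeps pairs intact
    n_val = int(len(x_t) * val_ratio)
    train = (x_t[n_val:], x_lagged[n_val:])
    val = (x_t[:n_val], x_lagged[:n_val])
    return train, val

# Toy trajectory: row i is [2*i, 2*i + 1], so lagged partners differ by exactly 2.
toy = np.arange(20, dtype=np.float32).reshape(10, 2)
(train_t, train_lagged), (val_t, val_lagged) = make_lagged_pairs(toy, lag=1, val_ratio=0.2)
print(train_t.shape, val_t.shape)
```

Because both arrays are indexed with the same permutation, every shuffled point still sits next to its time-lagged partner.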
<br>
<br>

In [None]:
dataset = dt.util.data.TrajectoryDataset(lagtime=tau, trajectory=data.astype(np.float32))

In [None]:
n_val = int(len(dataset)*val_ratio)
train_data, val_data = torch.utils.data.random_split(dataset, [len(dataset) - n_val, n_val])

In [None]:
from deeptime.util.torch import MLP
lobe = MLP(units=nodes, nonlinearity=nn.ELU, output_nonlinearity=nn.Softmax)

In [None]:
vampnet = dt.decomposition.deep.VAMPNet(lobe=lobe, learning_rate=learning_rate)

In [None]:
from torch.utils.data import DataLoader

loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)

In [None]:
model = vampnet.fit(loader_train, n_epochs=nb_epoch, validation_loader=loader_val, progress=tqdm).fetch_model()

### Model validation

Once the previous cell has finished running, we have successfully (🤞) trained our VAMPnet. We can plot the training information to see how well the training proceeded, and by plotting both training and validation scores we can make sure that our model did not overfit. Before running the next cell, consider that our network's training and validation scores should converge to a value slightly lower than $2$: the score is computed from the singular values of the estimated Koopman operator (for the VAMP-2 score, the sum of their squares), we only have 2 output nodes, no singular value can exceed $1$, and the largest singular value is always $=1$.

In [None]:
plt.loglog(*vampnet.train_scores.T, label='training')
plt.loglog(*vampnet.validation_scores.T, label='validation')
plt.xlabel('step')
plt.ylabel('score')
plt.legend();
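
To see where the upper bound of $2$ comes from, here is a minimal NumPy sketch that estimates a Koopman matrix from perfectly sharp two-state assignments and sums its squared singular values. It is an illustration only: the two-state Markov chain and its probabilities are made up, no mean removal is applied, and it is not meant to reproduce how `deeptime` computes the score internally.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical two-state Markov chain standing in for the two wells.
p_stay = np.array([0.95, 0.90])  # probability of staying in state 0 / state 1
n_steps = 20000
states = np.zeros(n_steps, dtype=int)
draws = rng.random(n_steps)
for t in range(1, n_steps):
    prev = states[t - 1]
    states[t] = prev if draws[t] < p_stay[prev] else 1 - prev

# One-hot vectors mimic a perfectly sharp two-node softmax output.
chi = np.eye(2)[states]
chi_0, chi_t = chi[:-1], chi[1:]  # instantaneous and time-lagged data (tau = 1)

# Covariance matrices of the (non-mean-free) features.
c00 = chi_0.T @ chi_0 / len(chi_0)
ctt = chi_t.T @ chi_t / len(chi_t)
c0t = chi_0.T @ chi_t / len(chi_0)

def inv_sqrt(c):
    """Inverse square root of a symmetric positive definite matrix."""
    eigval, eigvec = np.linalg.eigh(c)
    return eigvec @ np.diag(eigval ** -0.5) @ eigvec.T

koopman = inv_sqrt(c00) @ c0t @ inv_sqrt(ctt)
singular_values = np.linalg.svd(koopman, compute_uv=False)
vamp2 = np.sum(singular_values ** 2)
print(singular_values, vamp2)
```

The first singular value equals $1$ because the two outputs always sum to one; the second one is smaller than $1$ and reflects the slow exchange between the two states, so the sum of squares stays below $2$.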

Now we can finally reap the results of our work: if the network was properly trained it should automatically separate the two wells in our system. We can verify this hypothesis by first transforming our dataset with the network using the `model.transform` method.

In [None]:
transformed_data = model.transform(data)

Now we can visualize how the network assigns every point; in the following cell we color every data point by the value of the first output node, i.e. its probability of belonging to the first state:

In [None]:
plt.scatter(*data.T, c=transformed_data[:,0])

If you are looking at two balls with clearly different colors, your network reached its optimal state during the training.

We can further analyze the output of the network by visualizing the decision landscape:

In [None]:
xmax = np.max(np.abs(data[:, 0]))
ymin = np.min(data[:, 1])
ymax = np.max(data[:, 1])
grid = np.meshgrid(np.linspace(-xmax-1, xmax+1, 150), np.linspace(ymin-1, ymax+1, 50))
xy = np.dstack(grid).reshape(-1, 2)
z = model.transform(xy)[:,0]

cb = plt.contourf(grid[0], grid[1], z.reshape(grid[0].shape), levels=15, cmap='coolwarm')
plt.colorbar(cb);

Since this is a very simple system, the network should enforce a very sharp classification, with most of the points belonging to either `state 1` or `state 2`, with only a few points in between having a mixed value.
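
One way to put a number on how sharp the classification is would be to count the points whose largest output value exceeds a threshold. The sketch below uses a small made-up output array; `sharp_fraction` and the threshold of 0.9 are hypothetical choices for illustration.

```python
import numpy as np

def sharp_fraction(memberships, threshold=0.9):
    """Fraction of points whose largest membership exceeds `threshold`."""
    return float(np.mean(memberships.max(axis=1) > threshold))

# Hypothetical network output: three sharply assigned points and one mixed point.
example_output = np.array([
    [0.99, 0.01],
    [0.02, 0.98],
    [0.95, 0.05],
    [0.55, 0.45],
])
hard_states = example_output.argmax(axis=1)  # most likely state per point
print(sharp_fraction(example_output), hard_states)
```

With the variables of this notebook, `sharp_fraction(transformed_data)` should give the corresponding fraction for the toy model.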

As a last step, we can verify that the network preserves the slow information in the system by plotting the implied timescales present in our transformed data:

In [None]:
lagtimes = np.arange(1, 11)
its = implied_timescales([dt.decomposition.VAMP(lagtime=lag, observable_transform=model).fit(data).fetch_model() for lag in lagtimes])
fig, axes = plt.subplots(1, 1, figsize=(6, 4))

plot_implied_timescales(its, ax=axes)
axes.set_yscale('log')
axes.set_xlabel('lagtime (steps)')
axes.set_ylabel('timescale (steps)')
fig.tight_layout()
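
An implied timescale is obtained from an eigenvalue $\lambda_i(\tau)$ estimated at lag time $\tau$ via $t_i = -\tau / \ln|\lambda_i(\tau)|$. For Markovian dynamics it should not depend on the lag time, which is why flat curves in the plot above are a good sign. The following sketch checks this for a made-up two-state transition matrix; the matrix and the helper `implied_timescale` are assumptions for illustration and are unrelated to the `deeptime` function used above.

```python
import numpy as np

def implied_timescale(eigenvalue, lagtime):
    """t = -tau / ln|lambda(tau)|, in the same units as the lag time."""
    return -lagtime / np.log(np.abs(eigenvalue))

# Hypothetical two-state transition matrix at lag time 1.
P = np.array([[0.95, 0.05],
              [0.10, 0.90]])

timescales = []
for lag in (1, 2, 5):
    P_lag = np.linalg.matrix_power(P, lag)            # exact propagation to lag time `lag`
    eigenvalues = np.sort(np.linalg.eigvals(P_lag).real)[::-1]
    timescales.append(implied_timescale(eigenvalues[1], lag))  # skip the stationary eigenvalue 1
print(timescales)
```

All three values agree because the second eigenvalue of `P_lag` is the second eigenvalue of `P` raised to the power `lag`.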

## Hands-on session: Alanine Dipeptide
In the following three cells, you are given the loading code for the alanine dipeptide trajectories (along with their 2 dihedral angles), a plot that shows how to visualize information about the molecule using the dihedral data, and a set of hyperparameters. Build and train a network that classifies alanine dipeptide samples, and set the number of epochs so that your network converges to a stable score. Plot your results and compare them with the provided examples.

#### Cell 1: Loading
**NOTE: do NOT use the dihedral information for the training! It would be easier to do so, but the interesting aspect of this exercise lies in seeing how easily the network extracts a low-dimensional representation from a high-dimensional space.**

In [None]:
ala_coords_file = mdshare.fetch(
    "alanine-dipeptide-3x250ns-heavy-atom-positions.npz", working_directory="data"
)
with np.load(ala_coords_file) as fh:
    data = fh["arr_0"]

dihedral_file = mdshare.fetch(
    "alanine-dipeptide-3x250ns-backbone-dihedrals.npz", working_directory="data"
)
with np.load(dihedral_file) as fh:
    dihedral = fh["arr_0"]

#### Cell 2: Visualization
Since the dynamics of the molecule are completely described by its position in the dihedral plane, we can use these two variables every time we need to pass an x-axis and y-axis to a plotting function

In [None]:
pyemma.plots.plot_density(*dihedral.T, cmap="viridis")
plt.show()
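
A density plot of this kind is essentially a normalized 2D histogram over the dihedral plane. As a rough sketch of that idea, the snippet below bins a made-up array of angle pairs with `np.histogram2d`; the random data, the assumption that angles are given in radians, and the choice of 50 bins are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical stand-in for dihedral data: angles in radians, shape (n_frames, 2).
fake_dihedral = rng.uniform(-np.pi, np.pi, size=(5000, 2))

counts, phi_edges, psi_edges = np.histogram2d(
    fake_dihedral[:, 0], fake_dihedral[:, 1],
    bins=50, range=[[-np.pi, np.pi], [-np.pi, np.pi]],
)
density = counts / counts.sum()  # fraction of frames falling into each bin
print(density.shape, density.sum())
```

The same two columns can serve as x- and y-coordinates for any other per-frame quantity, for example the network outputs.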

#### Cell 3: Hyperparameters
The `nb_epoch` variable is missing a value. Experiment with the training and find a number of epochs that ensures that your network converges every time you train it.

In [None]:
tau = 1

batch_size = 10000

val_ratio = 0.1

network_depth = 6

layer_width = 30

learning_rate = 5e-3

output_size = 6

nodes = [data.shape[1]] + [layer_width for _ in range(network_depth)] + [output_size]

nb_epoch = 30  ## FIXME
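
When experimenting with the number of epochs, a simple heuristic for "converged" is that the most recent scores barely change anymore. The sketch below implements one such check; `has_converged`, its window of 5 scores and its relative tolerance of 1% are hypothetical choices, and the score lists are made up.

```python
def has_converged(scores, window=5, tol=1e-2):
    """Return True if the last `window` scores vary by at most `tol` (relative)."""
    if len(scores) < window:
        return False
    recent = list(scores[-window:])
    spread = max(recent) - min(recent)
    return spread <= tol * abs(max(recent))

# Hypothetical score histories, one value per recorded step.
still_rising = [1.0, 1.5, 2.1, 2.8, 3.3, 3.7]
plateaued = [1.0, 2.5, 3.6, 3.90, 3.91, 3.92, 3.92, 3.93]
print(has_converged(still_rising), has_converged(plateaued))
```

After training, something like `has_converged(vampnet.validation_scores[:, 1])` could be applied to the recorded validation scores, assuming the second column holds the score values as in the plotting cells of this notebook.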

In [None]:
dataset = dt.util.data.TrajectoryDataset(lagtime=tau, trajectory=data.astype(np.float32))

In [None]:
n_val = int(len(dataset)*val_ratio)
train_data, val_data = torch.utils.data.random_split(dataset, [len(dataset) - n_val, n_val])

In [None]:
from deeptime.util.torch import MLP
lobe = MLP(units=nodes, nonlinearity=nn.ELU, output_nonlinearity=nn.Softmax)

In [None]:
vampnet = dt.decomposition.deep.VAMPNet(lobe=lobe, learning_rate=learning_rate)

In [None]:
from torch.utils.data import DataLoader

loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)

In [None]:
model = vampnet.fit(loader_train, n_epochs=nb_epoch, validation_loader=loader_val, progress=tqdm).fetch_model()

In [None]:
plt.loglog(*vampnet.train_scores.T, label='training')
plt.loglog(*vampnet.validation_scores.T, label='validation')
plt.xlabel('step')
plt.ylabel('score')
plt.legend();

In [None]:
transformed_data = model.transform(data)

In [None]:
lagtimes = np.arange(1, 11)
its = implied_timescales([dt.decomposition.VAMP(lagtime=lag, observable_transform=model).fit(data).fetch_model() for lag in lagtimes])
fig, axes = plt.subplots(1, 1, figsize=(6, 4))

plot_implied_timescales(its, ax=axes)
axes.set_yscale('log')
axes.set_xlabel('lagtime (steps)')
axes.set_ylabel('timescale (steps)')
fig.tight_layout()

In [None]:
for i in range(output_size):
    plt.scatter(*dihedral.T, c=transformed_data[:,i], s=0.5)
    plt.show()

In [None]:
colorcode = np.argmax(transformed_data, axis=1)

In [None]:
plt.scatter(*dihedral.T, c=colorcode, s=0.5)
plt.show()

In [None]:
## Your network code goes here

When you are done, the results should look like this:

#### Dihedral space separation
<img style="float: left;" src="./img/space_division.png"/>

#### Output values for each node
<img  style="float: left;" src="./img/prob_state1.png"/>
<img  style="float: left;" src="./img/prob_state2.png"/>
<img  style="float: left;" src="./img/prob_state3.png"/>
<img  style="float: left;" src="./img/prob_state4.png"/>
<img  style="float: left;" src="./img/prob_state5.png"/>
<img  style="float: left;" src="./img/prob_state6.png"/>

#### Timescales
<img style="float: left;" src="./img/timescales.png"/>