# 1. DeepLDA: design CVs from equilibrium fluctuations

Reference paper: _Bonati, Rizzi and Parrinello, [JPCL](https://pubs.acs.org/doi/10.1021/acs.jpclett.0c00535) (2020)_ [[arXiv]](https://arxiv.org/abs/2002.06562).

The aim of this tutorial is to illustrate how we can design collective variables in a data-driven way, starting from local fluctuations of a set of physical descriptors in the metastable states. 

### Introduction

**DeepLDA summary**

To this end, we resort to a statistical method called Linear Discriminant Analysis ([LDA](https://en.wikipedia.org/wiki/Linear_discriminant_analysis)). LDA searches for the linear projection $s=\bar{w}^T x$ of the input data $x$ such that the classes are maximally separated. The separation is measured by the so-called Fisher's ratio:

$$\bar{w} = \text{argmax}_w \frac{w^T S_b w}{w^T S_w w}$$

where $S_b$ is the between-class scatter matrix and $S_w$ the within-class one. In the simple case of two classes (states) A and B, these can be easily computed from the means and covariance matrices of the data in the two states:

$$S_w = \Sigma_A + \Sigma_B $$

$$S_b = (\mu_A-\mu_B)(\mu_A-\mu_B)^T $$

![LDA](images/lda.png)

From a practical perspective, the vector $\bar{w}$ is found by solving the generalized eigenvalue problem: $S_b\bar{w} = v S_w \bar{w} $, where the eigenvalue $v$ measures the amount of separation between the states.
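For two states, $S_b$ has rank one, so the generalized eigenvalue problem has a single non-zero eigenvalue, with eigenvector $\bar{w}\propto S_w^{-1}(\mu_A-\mu_B)$. The NumPy sketch below builds the scatter matrices exactly as defined above and checks the eigenvalue relation. The helper name `lda_two_states` and the synthetic Gaussian data are illustrative assumptions, not part of mlcvs.

```python
import numpy as np

def lda_two_states(XA, XB):
    """Two-state LDA on samples XA, XB of shape (n_samples, n_descriptors).

    Returns the unit vector w, the eigenvalue v (Fisher's ratio) and the scatter matrices.
    """
    mu_A, mu_B = XA.mean(axis=0), XB.mean(axis=0)
    # within-class scatter: sum of the covariance matrices of the two states
    S_w = np.cov(XA, rowvar=False) + np.cov(XB, rowvar=False)
    # between-class scatter: outer product of the difference of the means (rank one)
    d_mu = mu_A - mu_B
    S_b = np.outer(d_mu, d_mu)
    # leading solution of S_b w = v S_w w for two classes: w proportional to S_w^{-1} (mu_A - mu_B)
    w = np.linalg.solve(S_w, d_mu)
    w /= np.linalg.norm(w)
    v = (w @ S_b @ w) / (w @ S_w @ w)
    return w, v, S_w, S_b

# usage on synthetic data: two Gaussian clouds in 3 dimensions
rng = np.random.default_rng(0)
XA = rng.normal(loc=[0.0, 0.0, 0.0], scale=[1.0, 0.3, 0.5], size=(500, 3))
XB = rng.normal(loc=[1.0, 0.5, 0.0], scale=[1.0, 0.3, 0.5], size=(500, 3))
w, v, S_w, S_b = lda_two_states(XA, XB)
s_A, s_B = XA @ w, XB @ w   # projected CV values s = w^T x in the two states
```

With more than two states $S_b$ is no longer rank one, and one would solve the full generalized eigenproblem instead of using the closed form.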

Here we employ a non-linear generalization of LDA in which the mapping function is a neural network (NN). This greatly increases the discriminative power of the model, by learning a set of latent variables in which the metastable states are linearly separable. This is achieved by performing a nonlinear featurization of the inputs via a NN and then applying LDA to the outputs of the network. During training, the parameters of the NN are optimized so as to maximize the LDA eigenvalue $v$. In other words, we transform the input space in such a way that the discrimination between the states is maximal.

Furthermore, NNs are well suited to enhanced sampling calculations, because they provide a continuous and differentiable representation and can easily handle many descriptors.

![DeepLDA scheme](images/DeepLDA_scheme.png)

To achieve this we use the following loss function:

$$ \mathcal{L} = -v - \alpha \frac{1}{1+(\bar{s}^2-1)^2} $$

where the first term corresponds to the maximization of the eigenvalue, which measures the separation between the two states, and the second is a Lorentzian regularization of the average value of the output CV, which keeps it close to 1.

Note that to stabilize the learning we also regularize the calculation of $S_w$ by adding the identity matrix multiplied by a parameter $\lambda$: $S_w' = S_w+\lambda\mathbb{1}$. The two regularization parameters $\alpha$ and $\lambda$ are not independent, as the former affects the numerator of Fisher's ratio above and the latter the denominator. Hence, in the following, we will choose only $\lambda$ and set $\alpha=\frac{2}{\lambda}$.
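To make the loss concrete, the sketch below evaluates it in NumPy on a batch of NN outputs $H$. It is a forward pass only: in training the same operations are written in PyTorch so that gradients can flow back into the network. The helper `deeplda_loss` is hypothetical, and we assume that $\bar{s}^2$ denotes the batch average of $s^2$. The generalized eigenproblem with $S_w'$ is reduced to a symmetric one through a Cholesky factorization.

```python
import numpy as np

def deeplda_loss(H, y, lam=0.05):
    """Forward evaluation of the DeepLDA loss (hypothetical helper, NumPy only, no gradients).

    H: NN outputs, shape (n_samples, n_features); y: state labels (0 = A, 1 = B).
    Returns the loss, the eigenvalue v, the projection vector w and the CV s = H w.
    """
    alpha = 2.0 / lam                               # the tutorial's choice alpha = 2/lambda
    HA, HB = H[y == 0], H[y == 1]
    S_w = np.cov(HA, rowvar=False) + np.cov(HB, rowvar=False)
    S_w_reg = S_w + lam * np.eye(H.shape[1])        # S_w' = S_w + lambda * identity
    d_mu = HA.mean(axis=0) - HB.mean(axis=0)
    S_b = np.outer(d_mu, d_mu)
    # reduce S_b w = v S_w' w to a symmetric problem using the Cholesky factor S_w' = L L^T
    L_inv = np.linalg.inv(np.linalg.cholesky(S_w_reg))
    evals, evecs = np.linalg.eigh(L_inv @ S_b @ L_inv.T)
    v = evals[-1]                                   # eigh sorts eigenvalues in ascending order
    w = L_inv.T @ evecs[:, -1]                      # back-transform to the original space
    s = H @ w
    # Lorentzian term, assuming s-bar^2 is the batch average of s^2
    reg = alpha / (1.0 + (np.mean(s**2) - 1.0) ** 2)
    return -v - reg, v, w, s

# usage on synthetic "NN outputs" for two states
rng = np.random.default_rng(1)
H_toy = np.vstack([rng.normal(-0.5, 0.3, size=(200, 5)),
                   rng.normal(0.5, 0.3, size=(200, 5))])
y_toy = np.repeat([0, 1], 200)
loss, v, w, s = deeplda_loss(H_toy, y_toy, lam=0.05)
```

Because the Lorentzian term lies in $(0, \alpha]$, the loss always lies between $-v-\alpha$ and $-v$. A larger $\lambda$ damps the eigenvalue and, through $\alpha = 2/\lambda$, also weakens the regularization.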


To apply this method, the only knowledge required as input is a set of snapshots of the system in the different metastable states. These could be, for instance, the reactant and product of a chemical reaction, the crystalline and liquid configurations of a material, the bound and unbound states of a ligand to a protein, etc.

Of course, a CV built only by compressing the equilibrium fluctuations in the local minima of the FES will not be perfect. Rather, it should be seen as a first step that allows us to observe some transitions and learn something new about the system, which we can later use to refine it and find better CVs.

**Outline**

In this tutorial, we take as example **alanine dipeptide** in vacuum, which is often used as a toy model for enhanced sampling methods. This molecule has two metastable states, called $C7_{eq}$ and $C7_{ax}$, which in the following we will refer to as A and B, respectively. 

To mimic a realistic scenario, we assume that we know nothing about this system except for two realizations, one of state A and one of state B.

![Ala 2 metastable states](images/ala2-belfast-2-transition.png)

Using these two realizations of the molecule we proceed as follows:

1. We perform short unbiased MD simulations in the metastable states and evaluate a set of physical descriptors (e.g. interatomic distances between heavy atoms)
2. We train the DeepLDA CV and inspect it
3. Finally, we apply a bias potential to enhance the fluctuations of the DeepLDA CV and (hopefully) drive the system back and forth between A and B.

## Setup

In [None]:
import mlcvs
import torch
import numpy as np
import matplotlib.pyplot as plt
import subprocess
from pathlib import Path
import pandas as pd

# delete outputs of simulations from `folder`
def clean(folder='./'):
    subprocess.run("rm bck.* COLVAR KERNELS alanine.*", cwd=folder, shell=True)

# execute a bash command in the given folder (subprocess.run always waits for completion; `background` only sets `close_fds`)
def execute(command, folder, background=False):
    cmd = subprocess.run(command, cwd=folder, shell=True, capture_output = True, text=True, close_fds=background)
    if cmd.returncode == 0:
        print(f'Completed: {command}')
    else:
        print(cmd.stderr)

#GMX_CMD = '. /work/sourceme.sh && gmx_mpi'
GMX_CMD = 'gmx_mpi'

In [None]:
# Define a few plotting functions

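def plot_ramachandran(x,y,z,scatter=None, ax=None):
    # Setup plot
    if ax is None:
        _, ax = plt.subplots(figsize=(5,4.), dpi=100)
        ax.set_title('Ramachandran plot')

    # hexbin expects a 1D array of values: convert torch tensors of shape (N,1)
    if isinstance(z, torch.Tensor):
        z = z.detach().cpu().numpy().ravel()

    # Plot hexagonal-binned map colored by the CV value
    h = ax.hexbin(x,y,C=z,cmap='fessa')
    cbar = plt.colorbar(h,ax=ax)
    cbar.set_label('Deep-LDA CV')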

    ax.set_xlabel(r'$\phi$ [rad]')
    ax.set_ylabel(r'$\psi$ [rad]')

def plot_cv_histogram(s,label=None,ax=None,**kwargs):
    # Setup plot
    if ax is None:
        _, ax = plt.subplots(figsize=(5,4.), dpi=100)
        ax.set_title('Histogram')

    if isinstance(s, torch.Tensor):
        s = s.squeeze(1).detach().cpu().numpy()

    # Plot histogram
    ax.hist(s,**kwargs)
    if label is not None:
        ax.set_xlabel(label)

## 1.0 Unbiased simulations in the metastable states

First, we perform short MD simulations starting from the two snapshots of the molecule and characterize each state with a set of descriptors. To proceed in a blind way, we use as input features all the distances between heavy atoms (the list of these descriptors is in the file `plumed-distances.dat`).
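The content of `plumed-distances.dat` is not reproduced here. Judging from the `d_i_j` labels used later in the `PYTORCH_MODEL` line, it defines one distance for each pair of the heavy atoms 2, 5, 6, 7, 9, 11, 15, 16, 17 and 19. A file of that form could be generated as in the sketch below. This is a hypothetical reconstruction (the atom list is inferred, the helper name is ours), not necessarily the file shipped with the tutorial.

```python
from itertools import combinations

# heavy-atom indices inferred from the d_i_j labels of the PYTORCH_MODEL line (assumption)
HEAVY_ATOMS = [2, 5, 6, 7, 9, 11, 15, 16, 17, 19]

def distance_lines(atoms):
    """One PLUMED DISTANCE action per atom pair, labelled d_<i>_<j>."""
    return [f"d_{i}_{j}: DISTANCE ATOMS={i},{j}" for i, j in combinations(atoms, 2)]

lines = distance_lines(HEAVY_ATOMS)
labels = [line.split(":")[0] for line in lines]
print(len(lines))   # 45
print(lines[0])     # d_2_5: DISTANCE ATOMS=2,5
```

With 10 atoms this gives 10·9/2 = 45 descriptors, which matches the 5×9 grid of histograms plotted below.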

**State A**

In [None]:
# CREATE FOLDER AND COPY INPUTS
folder = '1_DeepLDA/0_unbiased-sA/'
Path(folder).mkdir(parents=True, exist_ok=True)
execute(f"cp ../md_inputs/input.ala2.pdb ../md_inputs/input.tpr .", folder=folder)

# WRITE PLUMED INPUT FILE
with open(folder+"plumed.dat","w") as f:
    print("""
# vim:ft=plumed

# Compute torsion angles, as well as energy
MOLINFO STRUCTURE=input.ala2.pdb
phi: TORSION ATOMS=@phi-2
psi: TORSION ATOMS=@psi-2
theta: TORSION ATOMS=6,5,7,9
xi: TORSION ATOMS=16,15,17,19
ene: ENERGY

# Compute descriptors
INCLUDE FILE=../plumed-distances.dat

# Print 
PRINT FMT=%g STRIDE=100 FILE=COLVAR ARG=*

ENDPLUMED
""",file=f)

## RUN GROMACS
num_steps=500000

clean(folder) # note: this deletes all previous results in folder!
execute(f"{GMX_CMD} mdrun -s input.tpr -deffnm alanine -plumed plumed.dat -ntomp 1 -nsteps {num_steps} > alanine.out", folder=folder)

**State B**

In [None]:
folder = '1_DeepLDA/0_unbiased-sB/'
Path(folder).mkdir(parents=True, exist_ok=True)
execute(f"cp ../md_inputs/input.ala2.pdb .", folder=folder)
execute(f"cp ../md_inputs/input.sB.tpr input.tpr", folder=folder)

with open(folder+"plumed.dat","w") as f:
    print("""
# vim:ft=plumed

# Compute torsion angles, as well as energy
MOLINFO STRUCTURE=input.ala2.pdb
phi: TORSION ATOMS=@phi-2
psi: TORSION ATOMS=@psi-2
theta: TORSION ATOMS=6,5,7,9
xi: TORSION ATOMS=16,15,17,19
ene: ENERGY

INCLUDE FILE=../plumed-distances.dat

PRINT FMT=%g STRIDE=100 FILE=COLVAR ARG=*

ENDPLUMED
""",file=f)

## RUN GROMACS
num_steps=500000

clean(folder)
execute(f"{GMX_CMD} mdrun -s input.tpr -deffnm alanine -plumed plumed.dat -ntomp 1 -nsteps {num_steps} > alanine.out", folder=folder)

## 1.1 DeepLDA CV on pairwise distances (heavy atoms) 

### (a) Train CV

To load the PLUMED output of the two unbiased MD runs we can use the [load_dataframe](https://mlcvs.readthedocs.io/en/latest/autosummary/mlcvs.utils.io.load_dataframe.html) function.
From this data, we build our training dataset. Since this is a supervised learning task, the dataset will be of the form (`X,y`), in which `X` are the input samples and `y` the corresponding labels (the states to which they belong).

In [None]:
from mlcvs.utils.io import load_dataframe
# load state A and assign label 0
folder = '1_DeepLDA/0_unbiased-sA/'
colvarA = load_dataframe(folder+"COLVAR")
colvarA['state']=np.full(len(colvarA),'A')
colvarA['label']=np.full(len(colvarA),0)

# load stateB and assign label 1
folder = '1_DeepLDA/0_unbiased-sB/'
colvarB = load_dataframe(folder+"COLVAR")
colvarB['state']=np.full(len(colvarB),'B')
colvarB['label']=np.full(len(colvarB),1)

# concatenate data into a single dataframe
colvar = pd.concat([colvarA,colvarB.reset_index(drop=True)])

# create training dataset 
X = colvar.filter(regex='d_').values
y = colvar['label'].values

# transform them into torch.tensors 
X = torch.Tensor(X)
y = torch.Tensor(y)

We can take a look at the descriptors, by computing their histogram in the two states. 

--> **Question:** Is there any descriptor that is able to discriminate between the states on its own?

In [None]:
descriptors_names = colvar.filter(regex='d_').columns.values

fig,axs = plt.subplots(5,9,figsize=(20,10),sharey=True)

for ax,desc in zip(axs.flatten(),descriptors_names):
    colvar.pivot(columns='state')[desc].plot.hist(bins=50,alpha=0.5,ax=ax,legend=False)
    ax.set_title(desc)

plt.tight_layout()

Here we use the [mlcvs](https://mlcvs.readthedocs.io) package to train a DeepLDA CV on this data. This can be as simple as follows: define the network architecture, specify when to stop the training (e.g. by early stopping on the validation score) and call the `fit` method, which prints the training and validation scores during training. Let's give it a try!
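Early stopping monitors the validation loss and halts training once the loss has failed to improve by at least `min_delta` for `patience` epochs. The class below is an illustrative, hypothetical re-implementation of that logic, including the consecutive/cumulative switch listed in the parameter table. It is not the mlcvs source, whose details may differ.

```python
class EarlyStoppingSketch:
    """Stop when the validation loss fails to improve by at least min_delta.

    consecutive=True: the counter resets whenever an improvement occurs;
    consecutive=False: non-improving epochs are counted cumulatively.
    """

    def __init__(self, patience=20, min_delta=0.1, consecutive=True):
        self.patience = patience
        self.min_delta = min_delta
        self.consecutive = consecutive
        self.best_loss = float("inf")
        self.best_epoch = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, valid_loss, epoch):
        if valid_loss < self.best_loss - self.min_delta:
            # significant improvement: remember it
            self.best_loss = valid_loss
            self.best_epoch = epoch
            if self.consecutive:
                self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

# usage: a validation loss that decreases and then plateaus
stopper = EarlyStoppingSketch(patience=3, min_delta=0.1)
losses = [-1.0, -2.0, -3.0, -3.05, -3.08, -3.09, -3.1]
for epoch, loss in enumerate(losses):
    if stopper(loss, epoch):
        break
```

Here training stops at epoch 5, three epochs after the last improvement larger than `min_delta`. The large `min_delta=0.1` used in the tutorial therefore ignores small gains of the (negative) loss.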

In [None]:
from mlcvs.lda import DeepLDA_CV
nodes = [X.size(1),30,30,5]

model = DeepLDA_CV(nodes)

# TRAIN
model.set_earlystopping(patience=20,min_delta=0.1)
model.fit(X=X,y=y,log_every=100)

However, to better understand what we are doing, we shall consider a more detailed example, in which we analyze the different steps and options. First we create a `TensorDataset`, which we then divide into training and validation sets. From them, we construct `DataLoader`-like objects. The definition of such auxiliary objects is standard PyTorch practice, which allows us to easily train the models on different devices.

In [None]:
from torch.utils.data import TensorDataset,random_split
from mlcvs.utils.data import FastTensorDataLoader

dataset = TensorDataset(X,y)
train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size

train_data, valid_data = random_split(dataset,[train_size,valid_size])
train_loader = FastTensorDataLoader(train_data,batch_size=0,shuffle=True) # here 0 means to use a single batch
valid_loader = FastTensorDataLoader(valid_data)

Then, we need to initialize the neural network and the optimizer, and define when to stop the training (early stopping or after a given number of epochs). The following table lists all the parameters with their explanation. Note that we also standardize the inputs such that their range is between -1 and 1 in the training set.

| Parameter | Type | Description |
| :- | :- | :- |
| **Neural network** |
| nodes | list | NN architecture (the last value is the number of NN outputs, which are the inputs of LDA) |
| activation | string | Activation function (relu,tanh,elu,linear) |
| **Optimization** |
| lrate | float | Learning rate |
| sw_reg | float | S_w matrix regularization ($\lambda$)| 
| l2_reg | float | L2 regularization |
| num_epochs | int | Number of epochs |
| **Early Stopping** |
| es_patience | int | Number of epochs without improvement before stopping |
| es_consecutive | bool | Whether non-improving epochs are counted consecutively (True) or cumulatively (False) |
| es_min_delta | float | Minimum decrease of validation loss |
| **Log** |
| log_every | int | How often (in epochs) to print the train/valid loss during training |
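Mapping each descriptor to the range [-1, 1] amounts to subtracting the midpoint of its training range and dividing by the half-width. The NumPy sketch below shows this transformation. We assume, without having checked the mlcvs source, that `standardize_inputs=True` does something equivalent: the statistics are computed on the training set only and then reused for any other data. The helper names and the toy data are hypothetical.

```python
import numpy as np

def fit_range_scaling(X_train):
    """Midpoint and half-range of each descriptor, computed on the training set only."""
    x_min, x_max = X_train.min(axis=0), X_train.max(axis=0)
    mean = 0.5 * (x_max + x_min)
    half_range = 0.5 * (x_max - x_min)
    half_range[half_range == 0] = 1.0   # guard against constant descriptors
    return mean, half_range

def apply_range_scaling(X, mean, half_range):
    return (X - mean) / half_range

rng = np.random.default_rng(0)
X_train = rng.uniform(0.2, 0.6, size=(1000, 4))   # e.g. distances in nm
X_valid = rng.uniform(0.2, 0.6, size=(200, 4))
mean, half_range = fit_range_scaling(X_train)
X_train_std = apply_range_scaling(X_train, mean, half_range)
X_valid_std = apply_range_scaling(X_valid, mean, half_range)  # may slightly exceed [-1, 1]
```

Reusing the training statistics for validation data (and, later, for the configurations evaluated during the biased run) keeps the input transformation identical to the one the network was trained on.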

In [None]:
from mlcvs.lda import DeepLDA_CV

#------------- PARAMETERS -------------
nodes             = [X.size(1),30,30,5]
activation        = 'tanh'

lrate             = 0.001
sw_reg            = 0.05
l2_reg            = 1e-5

num_epochs        = 1000
earlystop         = True
es_patience       = 20
es_consecutive    = True
es_min_delta      = 0.1

log_every         = 100
#--------------------------------------

# DEVICE: check if there is a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MODEL: initialize the DeepLDA CV object
model = DeepLDA_CV(nodes,activation)
model.to(device)

# OPTIMIZER: here we use adam
opt = torch.optim.Adam(model.parameters(), lr=lrate, weight_decay=l2_reg)
model.set_optimizer(opt)
# set criterion for stopping the learning --> avoid overfitting
model.set_earlystopping(patience=es_patience,consecutive=es_consecutive,min_delta=es_min_delta)

# REGULARIZATION: add regularization to the calculation of the S_w matrix
model.set_regularization(sw_reg=sw_reg)

# TRAIN: fit the model to maximize Fisher's discriminant ratio
model.fit(train_loader,valid_loader, 
            standardize_inputs = True, 
            standardize_outputs = False,
            log_every=log_every)

After the training, we can plot the learning curve to see the training and validation scores.

--> **Exercise**: You can try to repeat the training in the cell above and make sure that you find similar values for the loss function. You can examine the behaviour of the NN when changing the NN architecture, and in particular the number of outputs which regulates the space in which LDA is applied.

In [None]:
fig, ax = plt.subplots(figsize=(5,4),dpi=100)

loss_train = [x.cpu().numpy() for x in model.loss_train]
loss_valid = [x.cpu().numpy() for x in model.loss_valid]

# Loss function
ax.plot(loss_train,'-',label='Train')
ax.plot(loss_valid,'--',label='Valid')
ax.set_ylabel('Loss Function')

#if model.earlystopping_.early_stop:
#    ax.axvline(model.earlystopping_.best_epoch,ls='dotted',color='grey',alpha=0.5,label='Early Stopping')
ax.set_xlabel('#Epochs')
ax.legend(ncol=2)

plt.tight_layout()
plt.show()

To better understand what the network is doing, we can inspect the output of the NN, before the application of LDA. 

--> **Question**: What are the differences with the input descriptors?

In [None]:
model.output_hidden=True
with torch.no_grad():
    hidden = model(X).numpy()
df = pd.DataFrame(hidden)
df['label'] = y.numpy()
model.output_hidden=False

# Plot histograms in the two states
fig,axs = plt.subplots(1,nodes[-1],figsize=(16,4),sharey=True)

hidden_names = [i for i in range(nodes[-1])]

for ax,desc in zip(axs.flatten(),hidden_names):
    df.pivot(columns='label')[desc].plot.hist(bins=50,alpha=0.5,ax=ax,legend=False)
    ax.set_title(desc)

plt.tight_layout()

Finally, we can look at the histogram of the CV (right), which shows that the two states are mapped around -1 and +1. To appreciate the discriminative power, we can also inspect the Ramachandran plot (left) of the two torsion angles $\phi$ and $\psi$, in which the points are colored according to the value of the Deep-LDA CV.

In [None]:
_, axs = plt.subplots(1,2, figsize=(10,4.), dpi=100)

axs[0].set_title('Ramachandran plot')
with torch.no_grad():
    s = model(X)
    
plot_ramachandran(colvar['phi'],colvar['psi'],s,ax=axs[0])

# Calculate CV values over training set
axs[1].set_title(f'Deep-LDA Histogram')
plot_cv_histogram(s,label=model.name_,bins=50, ax=axs[1] )

plt.tight_layout()

#### Number of NN outputs

In [None]:
nn_out = [10,5,3,1]

fig, axs = plt.subplots(2,len(nn_out),figsize=(16,8),dpi=100)

for i,out in enumerate(nn_out):
    print(f'NN outputs = {out}')
    # ARCHITECTURE (new names, so that `nodes` and `model` trained above are not overwritten)
    nodes_i = [X.size(1),30,30,out]
    model_i = DeepLDA_CV(nodes_i)

    # TRAIN
    model_i.set_earlystopping(patience=20,min_delta=0.05,consecutive=True)
    model_i.fit(X=X,y=y,log_every=100)

    # PREDICT
    with torch.no_grad():
        s = model_i(X)
    
    # PLOT
    axs[0][i].set_title(f'NN outputs = {out}')
    plot_ramachandran(colvar['phi'],colvar['psi'],s,ax=axs[0][i])
    axs[1][i].set_title('Deep-LDA Histogram')
    plot_cv_histogram(s,label=model_i.name_,bins=50, ax=axs[1][i] )

plt.tight_layout()

#### Extra: feature relevance

We can compute the relevance of each input feature from the derivatives of the CV $s$ with respect to the descriptors $d_i$:
$$r_i = \sum_{j=1}^{n} \left|\frac{\partial s^{(j)}}{\partial d_i^{(j)}}\right| \sigma(d_i)$$
where the sum runs over the $n$ configurations in the training set and $\sigma(d_i)$ is the standard deviation of the descriptor $d_i$.
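As a cross-check that does not rely on autograd, $r_i$ can be approximated with central finite differences for any scalar function of the descriptors. The NumPy sketch below uses a hypothetical helper, `feature_relevance`, and a linear CV $s = a^T d$ as test case. For this CV the derivatives are known exactly, so the normalized relevance must equal $|a_i|\sigma(d_i)/\sum_k |a_k|\sigma(d_k)$.

```python
import numpy as np

def feature_relevance(cv, D, eps=1e-5, multiply_by_stddev=True):
    """Relevance r_i = sum_j |ds/dd_i| (at sample j) * sigma(d_i), normalized to sum to 1.

    cv: callable mapping an array (n_samples, n_descriptors) to (n_samples,)
    D:  descriptor values, shape (n_samples, n_descriptors)
    """
    n_desc = D.shape[1]
    rank = np.zeros(n_desc)
    for i in range(n_desc):
        step = np.zeros(n_desc)
        step[i] = eps
        # central finite difference of the CV w.r.t. descriptor i, for all samples at once
        grad_i = (cv(D + step) - cv(D - step)) / (2 * eps)
        rank[i] = np.abs(grad_i).sum()
    if multiply_by_stddev:
        rank *= D.std(axis=0)
    return rank / rank.sum()

rng = np.random.default_rng(0)
D = rng.normal(size=(500, 3)) * np.array([1.0, 2.0, 0.5])
a = np.array([1.0, -0.5, 4.0])
r = feature_relevance(lambda X: X @ a, D)
```

Multiplying by $\sigma(d_i)$ converts a sensitivity per unit of descriptor into a sensitivity per typical fluctuation. Without it, descriptors that barely move could still look important.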

In [None]:
# parameters
multiply_by_stddev = True #whether to multiply derivatives by std dev of inputs
order_by_importance = True #plot results ordered by importance

#input names 
input_names = colvar.filter(regex='d_').columns.values
n_input = len(input_names)

#init arrays
in_num=np.arange(n_input)
rank=torch.zeros(n_input)

#compute input std dev
if multiply_by_stddev:
    in_std=torch.std(X,dim=0).numpy()

#compute the derivatives of the outputs w.r.t. inputs over all training set
for (x,Y) in train_loader:    
    x.requires_grad=True
    for x_i,y_i in zip(x,Y):
        # calculate cv 
        s_i = model(x_i) 
        # calculate derivatives
        grad_i = torch.autograd.grad(s_i,x_i)
        # accumulate them
        rank += grad_i[0].abs()

rank = rank.numpy()

#multiply by std dev
if multiply_by_stddev:
    rank = rank * in_std

#normalize to 1
rank/= np.sum(rank)

#sort
if order_by_importance:
    index= rank.argsort()
    input_names = input_names[index]
    rank = rank[index]

#plot
fig=plt.figure(figsize=(5,0.25*n_input), dpi=100)
ax = fig.add_subplot(111)

# set tick positions before their labels (avoids matplotlib's FixedFormatter warning)
ax.set_yticks(in_num)
if order_by_importance:
    ax.barh(in_num, rank,color='fessa1',edgecolor = 'fessa0',linewidth=0.3)
    ax.set_yticklabels(input_names,fontsize=9)
else:
    # original input order from top to bottom: reverse both bars and labels consistently
    ax.barh(in_num, rank[::-1],color='fessa1',edgecolor = 'fessa0',linewidth=0.3)
    ax.set_yticklabels(input_names[::-1],fontsize=9)

ax.set_xlabel('Relevance')
ax.set_ylabel('Inputs')
ax.yaxis.tick_right()

In [None]:
n = 4

fig, axs = plt.subplots(1,n,figsize=(n*5,4),dpi=100)

for i in range(n):
    ax = axs[i]
    name = input_names[::-1][i]
    colvar.pivot(columns='state')[name].plot.hist(bins=50,alpha=0.5,ax=ax,legend=False)
    ax.set_title(name)

plt.tight_layout()

### (b) Bias DeepLDA

Once we have designed our CV from the data, we can use it to enhance the sampling. To do so, we compile the model using `torch.jit` (this is done inside the `export` function), which creates a Python-independent file that can be loaded in PLUMED via the `PYTORCH_MODEL` function. If you look in the export folder below, you will find two files: `model_checkpoint.pt`, which can be used to load the model back into Python, and a compiled one, `model.ptc`, which we will load in the PLUMED input file.

If you are curious about how to export PyTorch-based functions with jit, have a look at the source code of the [export](https://mlcvs.readthedocs.io/en/latest/autosummary/mlcvs.models.NeuralNetworkCV.html#mlcvs.models.NeuralNetworkCV.export) method.

PLUMED uses the PyTorch C++ API (LibTorch) to load the model and evaluate it, together with its derivatives obtained by automatic differentiation. The outputs of the model are stored in components called `deep.node-0,deep.node-1,...`, where `deep` is the label assigned to the PLUMED function (see input below).
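`PYTORCH_MODEL` feeds the values listed in `ARG` to the network in that order, so the list must follow the same column order used to build `X` for training. Rather than typing the 45 labels by hand, the line can be assembled from the training column names, as in this small sketch. `pytorch_model_line` is a hypothetical helper; in the notebook one would pass `descriptors_names` as `names`.

```python
def pytorch_model_line(names, model_file="model.ptc", label="deep"):
    """PLUMED PYTORCH_MODEL action whose ARG order matches the model's input order."""
    return f"{label}: PYTORCH_MODEL FILE={model_file} ARG={','.join(names)}"

line = pytorch_model_line(["d_2_5", "d_2_6", "d_2_7"])
print(line)  # deep: PYTORCH_MODEL FILE=model.ptc ARG=d_2_5,d_2_6,d_2_7
```

A mismatch in the order would not raise an error in PLUMED; the network would silently receive permuted inputs and return a meaningless CV.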

To apply the bias potential, our method of choice is [OPES](https://www.plumed.org/doc-master/user-doc/html/_o_p_e_s.html), as it has several advantages over metadynamics (fewer parameters, fast convergence to a quasi-static regime, kernel merging that allows biasing multiple CVs more efficiently, a well-behaved adaptive sigma, the possibility to limit the amount of bias deposited...). However, this variable can be used with any other CV-based method.

--> **Exercise**: look at the PLUMED input file below and fill the missing parameters, then run gromacs.

**Note**: on Deepnote this could take quite some time. If you prefer, you can open a terminal and run GROMACS in the background by executing the command passed as argument to `execute`, while you keep playing around with the CV training.

In [None]:
folder = '1_DeepLDA/1_opes-deeplda/'
Path(folder).mkdir(parents=True, exist_ok=True)
execute(f"cp ../0_unbiased-sA/input* .", folder=folder)

# export model
model.export(folder)

# write plumed input
with open(folder+"plumed.dat","w") as f:
    print("""
# vim:ft=plumed

# Compute torsion angles, as well as energy
MOLINFO STRUCTURE=input.ala2.pdb
phi: TORSION ATOMS=@phi-2
psi: TORSION ATOMS=@psi-2
theta: TORSION ATOMS=6,5,7,9
xi: TORSION ATOMS=16,15,17,19
ene: ENERGY

# Compute descriptors
INCLUDE FILE=../plumed-distances.dat

# Compute DeepLDA CV
deep: PYTORCH_MODEL FILE=____FILL____ ARG=d_2_5,d_2_6,d_2_7,d_2_9,d_2_11,d_2_15,d_2_16,d_2_17,d_2_19,d_5_6,d_5_7,d_5_9,d_5_11,d_5_15,d_5_16,d_5_17,d_5_19,d_6_7,d_6_9,d_6_11,d_6_15,d_6_16,d_6_17,d_6_19,d_7_9,d_7_11,d_7_15,d_7_16,d_7_17,d_7_19,d_9_11,d_9_15,d_9_16,d_9_17,d_9_19,d_11_15,d_11_16,d_11_17,d_11_19,d_15_16,d_15_17,d_15_19,d_16_17,d_16_19,d_17_19
# Apply OPES bias 
opes: OPES_METAD ARG=deep.node-0 PACE=500 BARRIER=30

# Print 
PRINT FMT=%g STRIDE=500 FILE=COLVAR ARG=*

ENDPLUMED
""",file=f)

## RUN GROMACS
num_steps=2500000

clean(folder)
execute(f"{GMX_CMD} mdrun -s input.tpr -deffnm alanine -plumed plumed.dat -ntomp 1 -nsteps {num_steps} > alanine.out", folder=folder)

Once the simulation is over, we can plot the time evolution of the Deep-LDA CV, in which several transitions between the states A and B (-1 and 1) can be observed. Furthermore, we can look at the Ramachandran plot to see the explored region.

In [None]:
folder = '1_DeepLDA/1_opes-deeplda/'
colvar = load_dataframe(folder+'COLVAR')

fig,axs = plt.subplots(1,2,figsize=(10,4),dpi=100)
# Time evolution (DeepLDA)
colvar.plot.scatter('time','deep.node-0',s=1,ax=axs[0])
axs[0].set_xlabel('Time [ps]')
axs[0].set_ylabel('DeepLDA')
# 2D scatter plot colored with DeepLDA
colvar.plot.scatter('phi','psi',c='deep.node-0',s=1,cmap='fessa',ax=axs[1])
axs[1].set_xlabel(r'$\phi$ [rad]')
axs[1].set_ylabel(r'$\psi$ [rad]')
axs[1].set_aspect('equal')

plt.tight_layout()
plt.show()
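To put a number on the transitions seen in the time series, one option is to count changes of basin using a two-sided threshold, so that excursions into the region around zero are not counted as crossings. The sketch below is hypothetical: the helper name and the default threshold of 1.0 are our choices. Since the sign of the LDA projection is arbitrary, it does not assume which basin is A. In the notebook one would pass `colvar['deep.node-0'].values`.

```python
import numpy as np

def count_transitions(s, threshold=1.0):
    """Count basin-to-basin transitions of a CV time series using a two-sided threshold.

    A frame is assigned to the lower basin if s < -threshold, to the upper basin if
    s > threshold, and left unassigned otherwise; a transition is counted each time
    the assigned basin changes.
    """
    basin = np.zeros(len(s), dtype=int)
    basin[s < -threshold] = -1
    basin[s > threshold] = 1
    visited = basin[basin != 0]          # drop frames in the transition region
    return int(np.count_nonzero(np.diff(visited)))

s_demo = np.array([-1.2, -1.1, 0.0, 0.5, 1.2, 1.3, 0.9, 1.1, -0.2, -1.3])
print(count_transitions(s_demo))  # 2
```

The dip to 0.9 in the toy series is not counted, because the trajectory returns to the same basin without reaching the other one.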

From this simulation we can compute the free energy surface as a function of the DeepLDA CV. To do so, we use the `compute_fes` function, which performs a (weighted) kernel density estimation and uses block averaging to estimate the error.

In [None]:
from mlcvs.utils.fes import compute_fes

s = colvar['deep.node-0'].values

# compute weights
kbT = 2.5
w = np.exp(colvar['opes.bias'].values/kbT)

fig,ax = plt.subplots(figsize=(6,4),dpi=100)
fes,grid,bounds,error = compute_fes(s, weights=w, kbt=kbT, 
                                    blocks=5, bandwidth=0.01, 
                                    plot=True, ax = ax)
ax.set_xlabel('DeepLDA')
ax.set_ylabel('FES [kJ/mol]')
ax.set_ylim(0,50)
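`compute_fes` estimates the reweighted probability with KDE. The sketch below shows the same idea in its simplest form: a weighted histogram converted to $F(s)=-k_BT\log p(s)$, with block averaging for the error bars. The helper is hypothetical. Blocks are contiguous in time and, for simplicity, receive equal weight, and the error is propagated to first order. With the notebook data one would call it with `s`, `w` and `kbt=kbT` as defined above.

```python
import numpy as np

def fes_histogram(s, weights, kbt, bins=50, blocks=5):
    """Reweighted FES F(s) = -kbt * log p(s) from a weighted histogram, with block-averaged errors."""
    edges = np.linspace(s.min(), s.max(), bins + 1)
    centers = 0.5 * (edges[1:] + edges[:-1])
    # contiguous blocks in time: one normalized weighted histogram per block
    probs = np.array([
        np.histogram(s_b, bins=edges, weights=w_b, density=True)[0]
        for s_b, w_b in zip(np.array_split(s, blocks), np.array_split(weights, blocks))
    ])
    p_mean = probs.mean(axis=0)
    p_err = probs.std(axis=0, ddof=1) / np.sqrt(blocks)
    with np.errstate(divide="ignore", invalid="ignore"):
        fes = -kbt * np.log(p_mean)          # empty bins give +inf
        fes_err = kbt * p_err / p_mean       # first-order propagation of the block error
    fes -= fes[np.isfinite(fes)].min()       # set the global minimum to zero
    return centers, fes, fes_err

# usage on synthetic unbiased data (all weights equal to one)
rng = np.random.default_rng(0)
s_gauss = rng.normal(size=200_000)
centers, fes, fes_err = fes_histogram(s_gauss, np.ones_like(s_gauss), kbt=2.5)
```

A KDE replaces the hard bins with smooth kernels of width `bandwidth`. This is why, for a CV as sharply peaked as DeepLDA, a small bandwidth such as 0.01 is used above.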

--> **Exercise**: calculate the FES as a function of other variables, such as the Ramachandran angles `phi` and `psi`.

Note: the function `compute_fes` uses `KernelDensity` from the scikit-learn package to perform Kernel Density Estimation (KDE). While this allows us to easily perform KDE in many dimensions, we need to take into account the periodicity of the CV ourselves. One simple solution is data augmentation: the points close to one periodic boundary are replicated, shifted by $2\pi$, on the other side.

In [None]:
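def augment_periodic(cv,weights=None,bandwidth=0.1):
    """Replicate points close to the periodic boundaries (-np.pi and np.pi) on the other side

    Parameters
    ----------
    cv : np.array
    weights : np.array, optional
    bandwidth : float, optional
        KDE bandwidth; points within 3*bandwidth of a boundary are replicated
    """
    margin = 3*bandwidth
    # copies are inserted next to the original points, preserving the time order used by block averaging

    # points close to -pi --> add a copy shifted by +2pi
    mask = (cv < -np.pi + margin)
    index = np.argwhere(mask)[:,0]
    cv = np.insert( cv, index, cv[mask] + 2*np.pi )
    if weights is not None:
        weights = np.insert( weights, index, weights[mask] )

    # points close to +pi (excluding the copies just added, which are >= pi) --> add a copy shifted by -2pi
    mask = (cv > np.pi - margin) & ( cv < np.pi )
    index = np.argwhere(mask)[:,0]
    cv = np.insert( cv, index, cv[mask] - 2*np.pi )
    if weights is not None:
        weights = np.insert( weights, index, weights[mask] )

    return cv, weights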

phi = colvar['phi'].values
psi = colvar['psi'].values
w = np.exp(colvar['opes.bias'].values/kbT)

bandwidth = 0.05

phi, w_phi = augment_periodic(phi,w,bandwidth)
psi, w_psi = augment_periodic(psi,w,bandwidth)

In [None]:
from mlcvs.utils.fes import compute_fes

# thermal energy in kJ/mol (the weights w_phi and w_psi were computed above)
kbT = 2.5

fig,axs = plt.subplots(1,2,figsize=(12,4),dpi=100)
ax = axs[0]
fes,grid,bounds,error = compute_fes(phi, weights=w_phi, kbt=kbT, 
                                    blocks=5, bandwidth=bandwidth, 
                                    plot=True, ax = ax)
ax.set_xlabel(r'$\phi$ [rad]')
ax.set_ylabel('FES [kJ/mol]')
ax.set_xlim(-np.pi,np.pi)
ax.set_ylim(0,50)

ax = axs[1]
fes,grid,bounds,error = compute_fes(psi, weights=w_psi, kbt=kbT, 
                                    blocks=5, bandwidth=bandwidth, 
                                    plot=True, ax = ax)
ax.set_xlabel(r'$\psi$ [rad]')
ax.set_ylabel('FES [kJ/mol]')
ax.set_xlim(-np.pi,np.pi)
ax.set_ylim(0,50)

plt.tight_layout()

Once we have obtained a reasonable sampling, we can try to understand what the NN has learnt. A naive first step is to compute the correlation between the DeepLDA CV and all the input distances.

--> **Question**: Can we identify some features that are more correlated than others? You can also try to use `method='spearman'` in the correlation function, which, rather than looking for a linear correlation, only assesses how well the relationship between two variables can be described by a monotonic function.

In [None]:
cols = ['deep.node-0']
cols.extend(colvar.filter(regex='d_').columns)
corr = colvar[cols].corr(method='pearson') 

fig,ax = plt.subplots(figsize=(16,4),dpi=100)

corr['deep.node-0'].drop('deep.node-0').plot(kind='bar', ax=ax, rot=35)
ax.set_ylabel('Correlation with DeepLDA')
plt.show()
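The difference between the two coefficients is easy to see on a toy relationship that is monotonic but strongly nonlinear. Spearman's coefficient is Pearson's coefficient computed on the ranks, so it equals 1 for any strictly increasing function, while Pearson's stays below 1. The NumPy sketch below computes ranks with a double `argsort`, which assumes there are no ties; `pandas.DataFrame.corr(method='spearman')` handles ties properly. The toy data are hypothetical.

```python
import numpy as np

def pearson(x, y):
    return np.corrcoef(x, y)[0, 1]

def spearman(x, y):
    # ranks via double argsort (valid when there are no ties)
    rx = np.argsort(np.argsort(x))
    ry = np.argsort(np.argsort(y))
    return pearson(rx, ry)

rng = np.random.default_rng(0)
d = rng.uniform(-2, 2, size=1000)    # a toy "descriptor"
s_sig = np.tanh(3 * d)               # monotonic but strongly nonlinear "CV"
print(pearson(d, s_sig), spearman(d, s_sig))
```

A saturating, sigmoid-like dependence of this kind is typical of a classifier-like CV, which is why the rank-based coefficient can reveal relationships that Pearson's underestimates.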

Since the relationship between the input distances and the DeepLDA CV is not linear, understanding what the NN has learnt might not be an easy job. However, we can check whether the CV is correlated with other physical descriptors that might play a role in the transition between the two states.

--> **Exercise**: compute the correlation between the DeepLDA CV and the torsion angles computed in the PLUMED input file.

In [None]:
# select the DeepLDA CV and the dihedral angles
cols = ['deep.node-0', ______FILL______ ]

# compute correlation
corr = colvar[cols].corr(method='spearman')

# plot
fig,ax = plt.subplots(figsize=(4,4),dpi=100)
corr['deep.node-0'].drop('deep.node-0').plot(kind='bar', ax=ax, rot=35)
ax.set_ylabel('Correlation with DeepLDA')
plt.show()

In [None]:
fig,axs = plt.subplots(1,len(cols)-1,figsize=(16,4),dpi=100)

state = np.zeros(len(colvar))

state[colvar['deep.node-0']<-1.1] = -1
state[colvar['deep.node-0']>1.1] = 1

colvar['state'] = state

for i,desc in enumerate(cols[1:]):
    colvar.plot.scatter('deep.node-0',desc,c='state',s=1,ax=axs[i],cmap='fessa',colorbar=False)

plt.tight_layout()

### Bonus exercise: stretching the CV

As you might have noticed, the DeepLDA CV acts as a powerful classifier, which maps the equilibrium fluctuations into very narrow distributions. Although we are using OPES with adaptive bandwidth estimation, this might still lead to artifacts in the enhanced sampling dynamics. To address this behaviour, we can work on the regularization of the NN (e.g. adding penalty terms to the loss function) or, in this case, more simply stretch the CV. Since the states are mapped (due to the Lorentzian regularization) to values around ±1.1, a function of the kind `s' = s + s^N` (with N odd, so that the transformation stays monotonic) gives a transformation of the CV that is approximately linear around zero and amplifies the fluctuations around ±1. In the plot below you can see the case N=3.

You can try to repeat the enhanced sampling simulation biasing a function of the CV, using the `CUSTOM` PLUMED keyword, e.g.:

`deep_mod: CUSTOM ARG=deep.node-0 FUNC=x+x^3 PERIODIC=NO`
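The effect of the transformation follows from its derivative, $\mathrm{d}s'/\mathrm{d}s = 1 + N s^{N-1}$. The derivative equals 1 at $s=0$ and $1+N$ at $|s|=1$, so for $N=3$ fluctuations near the states are stretched by a factor of about 4. For odd $N$ the derivative is always positive, so the map is monotonic; for even $N$ it is not (e.g. $s+s^2$ decreases for $s<-1/2$). A quick numerical check with hypothetical helper names:

```python
import numpy as np

def stretch(s, N=3):
    """Stretched CV s' = s + s**N (FUNC=x+x^3 in the CUSTOM action for N=3)."""
    return s + s**N

def stretch_factor(s, N=3):
    """Local amplification of fluctuations, ds'/ds = 1 + N s**(N-1)."""
    return 1 + N * s**(N - 1)

s_grid = np.linspace(-1.2, 1.2, 241)
print(stretch_factor(0.0), stretch_factor(1.0))        # 1.0 4.0
print(np.all(np.diff(stretch(s_grid, N=3)) > 0))       # True: monotonic
print(np.all(np.diff(stretch(s_grid, N=2)) > 0))       # False: not monotonic
```

Because the bias acts on `deep_mod`, the force on the descriptors picks up this extra factor through the chain rule, so the same OPES settings push harder where the original CV was compressed.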

In [None]:
x = np.linspace(-1.2,1.2,100)
plt.plot(x,x,label='s')
plt.plot(x,x+x**3,label='s + s^3')
plt.xlabel('DeepLDA CV s')
plt.legend()
plt.show()

In [None]:
folder = '1_DeepLDA/2_opes-deeplda-stretch/'
Path(folder).mkdir(parents=True, exist_ok=True)
execute(f"cp ../0_unbiased-sA/input* .", folder=folder)

# export model
#model.export(folder)
execute(f"cp ../1_opes-deeplda/model* .", folder=folder)

# write plumed input
with open(folder+"plumed.dat","w") as f:
    print("""
# vim:ft=plumed

# Compute torsion angles, as well as energy
MOLINFO STRUCTURE=input.ala2.pdb
phi: TORSION ATOMS=@phi-2
psi: TORSION ATOMS=@psi-2
theta: TORSION ATOMS=6,5,7,9
xi: TORSION ATOMS=16,15,17,19
ene: ENERGY

# Compute descriptors
INCLUDE FILE=../plumed-distances.dat

# Compute DeepLDA CV
deep: PYTORCH_MODEL FILE=model.ptc    ARG=d_2_5,d_2_6,d_2_7,d_2_9,d_2_11,d_2_15,d_2_16,d_2_17,d_2_19,d_5_6,d_5_7,d_5_9,d_5_11,d_5_15,d_5_16,d_5_17,d_5_19,d_6_7,d_6_9,d_6_11,d_6_15,d_6_16,d_6_17,d_6_19,d_7_9,d_7_11,d_7_15,d_7_16,d_7_17,d_7_19,d_9_11,d_9_15,d_9_16,d_9_17,d_9_19,d_11_15,d_11_16,d_11_17,d_11_19,d_15_16,d_15_17,d_15_19,d_16_17,d_16_19,d_17_19

# Stretch the CV
deep_mod: CUSTOM ARG=deep.node-0 FUNC=x+x^3 PERIODIC=NO

# Apply OPES bias 
opes: OPES_METAD ARG=deep_mod PACE=500 BARRIER=30

# Print 
PRINT FMT=%g STRIDE=500 FILE=COLVAR ARG=*

ENDPLUMED
""",file=f)

## RUN GROMACS
num_steps=5000000

clean(folder)
execute(f"{GMX_CMD} mdrun -s input.tpr -deffnm alanine -plumed plumed.dat -ntomp 1 -nsteps {num_steps} > alanine.out", folder=folder)

In [None]:
folder = '1_DeepLDA/2_opes-deeplda-stretch/'
colvar = load_dataframe(folder+'COLVAR')

fig,axs = plt.subplots(1,2,figsize=(10,4),dpi=100)
# Time evolution (DeepLDA)
colvar.plot.scatter('time','deep_mod',s=1,ax=axs[0])
axs[0].set_xlabel('Time [ps]')
axs[0].set_ylabel('DeepLDA (stretched)')
# 2D scatter plot colored with DeepLDA
colvar.plot.scatter('phi','psi',c='deep_mod',s=1,cmap='fessa',ax=axs[1])
axs[1].set_xlabel(r'$\phi$ [rad]')
axs[1].set_ylabel(r'$\psi$ [rad]')
axs[1].set_aspect('equal')

plt.tight_layout()
plt.show()

In [None]:
from mlcvs.utils.fes import compute_fes

s = colvar['deep.node-0'].values
s2 = colvar['deep_mod'].values

# compute weights
kbT = 2.5
w = np.exp(colvar['opes.bias'].values/kbT)

fig,axs = plt.subplots(1,2,figsize=(12,4),dpi=100)
ax = axs[0]
fes,grid,bounds,error = compute_fes(s, weights=w, kbt=kbT, 
                                    blocks=5, bandwidth=0.01, 
                                    plot=True, ax = ax)
ax.set_xlabel('DeepLDA')
ax.set_ylabel('FES [kJ/mol]')
ax.set_ylim(0,50)

ax = axs[1]
fes,grid,bounds,error = compute_fes(s2, weights=w, kbt=kbT, 
                                    blocks=5, bandwidth=0.02, 
                                    plot=True, ax = ax)
ax.set_xlabel('DeepLDA mod (x+x^3)')
ax.set_ylabel('FES [kJ/mol]')
ax.set_ylim(0,50)