# Overview

## Enhanced sampling and collective variables

- CV-based enhanced sampling
- meaning of CVs:
    1. dimensionality reduction
    2. able to distinguish metastable states of interest
    3. able to promote the sampling along the minimum free energy pathways
- chicken and egg problem
- historical pathway: from physics to data-driven
    - physical intuition
    - linear transformation methods
    - non-linear (e.g. neural network) CVs

## What is `mlcvs`
`mlcvs`, which stands for Machine Learning Collective Variables, is a Python library for the construction of machine learning-based collective variables (CVs) for atomistic simulations.

The main purposes of `mlcvs` are:
- Make these CVs as simple as possible to use.
- Provide a flexible framework for developing new models on top of existing ones.

`mlcvs` lets the user train and export CV models from scratch in only a few lines of code, without requiring any coding expertise.

The library is based on PyTorch and exploits many features of the PyTorch Lightning package to simplify the overall workflow.
The library is designed to be used alongside PLUMED, so it is structured to make the interaction with PLUMED as simple as possible, both in handling data files and in using the trained CVs.

<center><img src="images/graphical_overview_mlcvs.png" width="800" /></center>

## `mlcvs` workflow
The main goal of `mlcvs` is to make the construction of machine learning CVs as straightforward and accessible as possible for all types of users.

The basic workflow consists of a few steps, each corresponding to very few lines of code:
- Import training data (e.g. PLUMED colvar files) using the functions in `utils`.
- Organize the training data into a `DataModule` using the functions in `data`. This makes the best use of the Lightning features.
- Initialize the model as one of the CV classes in `cvs`.
- Initialize a `pytorch_lightning.Trainer`, which takes care of training, validation, logging and other boring stuff :)
- Export the trained model with `model.to_torchscript()`.
- TODO Generate a PLUMED input file
- Enjoy the CV in PLUMED with our awesome interface
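
As a rough illustration of what these steps automate, the sketch below walks through the same sequence in plain PyTorch. It deliberately does not use the `mlcvs` or Lightning API: the synthetic data, the `TinyAutoencoder` class and the hand-written training loop are all assumptions made for this example, and `torch.jit.script` stands in for `model.to_torchscript()`.

```python
import io
import torch

torch.manual_seed(0)

# 1. Training data: random numbers standing in for descriptors read from a colvar file.
X = torch.randn(200, 4)

# 2. Wrap the data in a dataset and a loader.
dataset = torch.utils.data.TensorDataset(X)
loader = torch.utils.data.DataLoader(dataset, batch_size=50, shuffle=True)


# 3. Model: a small autoencoder whose 1D bottleneck plays the role of the CV.
class TinyAutoencoder(torch.nn.Module):  # hypothetical, not an mlcvs class
    def __init__(self, n_in: int = 4, n_cv: int = 1):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(n_in, 8), torch.nn.Tanh(), torch.nn.Linear(8, n_cv)
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(n_cv, 8), torch.nn.Tanh(), torch.nn.Linear(8, n_in)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)  # the exported model returns the CV only


model = TinyAutoencoder()

# 4. Training loop (the part a Lightning Trainer would handle).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = torch.nn.MSELoss()
for epoch in range(20):
    for (batch,) in loader:
        optimizer.zero_grad()
        loss = loss_fn(model.decoder(model.encoder(batch)), batch)
        loss.backward()
        optimizer.step()

# 5. Export to TorchScript and reload it, mimicking an external consumer.
model.eval()
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(model), buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

with torch.no_grad():
    cv = model(X)
    cv_reloaded = reloaded(X)
```

In a real workflow the scripted model would be written to a file rather than to an in-memory buffer, so that it can be read from outside Python.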

## Structure of CVs classes in `mlcvs`

The final products of the `mlcvs` library are of course the CVs.
These are defined as classes that inherit from a `BaseCV` class and from `pytorch_lightning.LightningModule`, which in turn inherits from `torch.nn.Module`.

The first superclass defines a template for all the CVs, along with common utility methods and the handling of pre- and postprocessing in the model.

The second gives access to all the utilities of PyTorch Lightning.

Each CV is characterized by its specific methods, attributes and properties, which are implemented on top of these two superclasses.
The structure of CVs in `mlcvs` is designed to be modular: the core of each model is defined as a series of `BLOCKS`, implemented as `torch.nn.Module` objects, that are automatically executed in sequence, similarly to `torch.nn.Sequential`.
Each CV also has a `loss_fn` attribute that sets the loss function to be minimized when optimizing the trainable blocks. The optimizer used to train the weights of the model, on the other hand, is set as a property of the model.

In addition to the core of the CV class, the user can also prepend and append pre- and postprocessing modules. These are generally expected to be `Transform` objects, as they are not trainable, but in principle they can be generic `torch.nn.Module` objects.
This makes it possible to perform the non-trainable preprocessing operations on the dataset only once, at the beginning of the training, and still include them in the final model for exporting, testing, etc.
It also makes it possible to apply postprocessing to the outputs of the model and to add it after the training is already completed.
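
The modular structure described above can be sketched in plain PyTorch as follows. This is only an illustration under stated assumptions: `SketchCV`, `Standardize`, `forward_blocks` and the block names are hypothetical and are not the actual `BaseCV` implementation, and the training loop is omitted.

```python
import torch


class Standardize(torch.nn.Module):
    """Non-trainable transform (hypothetical): shift and scale with fixed statistics."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        super().__init__()
        # Buffers are saved with the model but are not trainable parameters.
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std


class SketchCV(torch.nn.Module):
    """Hypothetical CV whose core is a list of named blocks run in sequence."""

    BLOCKS = ["encoder", "head"]

    def __init__(self, n_in: int, n_cv: int):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Linear(n_in, 8), torch.nn.Tanh())
        self.head = torch.nn.Linear(8, n_cv)
        self.loss_fn = torch.nn.MSELoss()  # loss minimized during training
        self.preprocessing = None          # optional, applied before the blocks
        self.postprocessing = None         # optional, applied after the blocks

    def forward_blocks(self, x: torch.Tensor) -> torch.Tensor:
        for name in self.BLOCKS:
            x = getattr(self, name)(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.preprocessing is not None:
            x = self.preprocessing(x)
        x = self.forward_blocks(x)
        if self.postprocessing is not None:
            x = self.postprocessing(x)
        return x


torch.manual_seed(0)
X_raw = 5.0 + 2.0 * torch.randn(10, 3)

# Preprocess the dataset once, before training.
pre = Standardize(X_raw.mean(0), X_raw.std(0))
X_pre = pre(X_raw)

model = SketchCV(n_in=3, n_cv=1)
with torch.no_grad():
    cv_train = model(X_pre)  # what the core blocks see during training

# After training: attach the transforms so the final model accepts raw inputs.
post = Standardize(cv_train.mean(0), cv_train.std(0))
model.preprocessing = pre
model.postprocessing = post
with torch.no_grad():
    cv_full = model(X_raw)
    cv_expected = post(cv_train)
```

With both transforms attached, the model maps raw inputs directly to the normalized CV, which is the form one would want to export.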




## Structure of the code

### core
Implements building blocks of the mlcvs classes.
- **loss**: Implements loss functions for the training of mlcvs
- **nn**: Implements trainable machine-learning building blocks of the mlcvs classes, conceptually similar to `torch.nn`
- **stats**: Implements statistical analysis methods (LDA, TICA, PCA, ...) for the mlcvs classes
- **transform**: Implements non-trainable transformations of data

### cvs
Implements ready-to-use mlcvs classes and the `BaseCV` template class.
The CVs are grouped by the criterion used for the optimization:
- **unsupervised**: Only require data about the system (`AutoEncoderCV` and `VariationalAutoEncoderCV`).
- **supervised**: Require either labeled data from the different metastable states of the system (`DeepLDA` and `DeepTDA`) or data and a target to be matched (`RegressionCV`).
- **timelagged**: Require time-lagged data from reactive trajectories (`DeepTICA`).



### data
Implements all the tools used to handle data efficiently in `mlcvs`. The structure is inspired by `pytorch_lightning`, with the addition of a dictionary-like handling of the datasets based on keyword indexing for ease of use.
The key elements are:
- **DictionaryDataset**: A dictionary-like `torch.utils.data.Dataset` that works with tensors and names, e.g. data, labels, target, weights, etc.
- **DictionaryDataModule**: A `pytorch_lightning.LightningDataModule` to be initialized from a `DictionaryDataset`.
- **FastDictionaryLoader**: A DataLoader-like object for sets of tensors. It is adapted to work with dictionaries and to be faster than the standard `DataLoader` (see docs).
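
To make the dictionary-based handling concrete, here is a minimal sketch of the idea. The names `DictDatasetSketch` and `fast_batches` are hypothetical and do not reproduce the actual `mlcvs` classes. The assumption behind the loader is that slicing whole batches directly from the stored tensors avoids the per-sample collation of a standard `DataLoader`.

```python
import torch


class DictDatasetSketch(torch.utils.data.Dataset):
    """Hypothetical dataset holding named tensors of equal length."""

    def __init__(self, **tensors: torch.Tensor):
        lengths = {len(t) for t in tensors.values()}
        if len(lengths) != 1:
            raise ValueError("all tensors must have the same number of samples")
        self.tensors = tensors

    def __len__(self) -> int:
        return len(next(iter(self.tensors.values())))

    def __getitem__(self, index):
        # A string key returns the whole named tensor,
        # any other index returns a dictionary of the selected samples.
        if isinstance(index, str):
            return self.tensors[index]
        return {name: tensor[index] for name, tensor in self.tensors.items()}


def fast_batches(dataset: DictDatasetSketch, batch_size: int, shuffle: bool = False):
    """Yield dictionaries of batches by slicing the stored tensors directly."""
    n = len(dataset)
    order = torch.randperm(n) if shuffle else torch.arange(n)
    for start in range(0, n, batch_size):
        yield dataset[order[start:start + batch_size]]


dataset = DictDatasetSketch(
    data=torch.arange(20.0).reshape(10, 2),
    labels=torch.arange(10),
)
batches = list(fast_batches(dataset, batch_size=4))
```

Each batch is a dictionary with the same keys as the dataset, so a training step can pick what it needs by name, for example `batch["data"]` and `batch["labels"]`.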


### utils
Implements miscellaneous and cross-cutting tools for a smoother workflow in `mlcvs`.
- **io**: Utils for fast and efficient data import from file, optimized for PLUMED colvar files.
- **fes**: Function to recover and plot 1D and 2D Free Energy Surfaces (FES) from biased data. The reweighting function is based on Kernel Density Estimation (KDE), using either `KDEpy` (faster) or `SciPy` (slower).
- **timelagged**: Utils for time-lagged datasets.
- **plot**: Utility functions for often-used plots (e.g. `plot_metrics` and `plot_isolines_2D`) and the `cm_fessa` and `cortina80` color palettes.
- **trainer**: `pytorch_lightning.Callback` functions for metrics logging.
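
The reweighting idea behind the **fes** utilities can be sketched with a hand-written weighted Gaussian KDE in NumPy. This is an illustration only: `fes_1d_sketch` and its signature are hypothetical and do not correspond to the `mlcvs` function, which relies on `KDEpy` or `SciPy`. The example assumes an idealized bias that exactly flattens a double-well potential, so that the biased samples are uniform in the CV and each sample carries a weight $e^{\beta V_\mathrm{bias}}$.

```python
import numpy as np


def fes_1d_sketch(cv, weights, grid, bandwidth, kbt):
    """Weighted Gaussian KDE followed by F(s) = -kBT * log p(s), shifted to min 0."""
    diff = (grid[:, None] - cv[None, :]) / bandwidth
    kernels = np.exp(-0.5 * diff**2) / (bandwidth * np.sqrt(2.0 * np.pi))
    density = kernels @ weights / weights.sum()
    fes = -kbt * np.log(density)
    return fes - fes.min()


kbt = 0.5

# Idealized biased sampling: uniform coverage of the CV range.
cv = np.linspace(-2.0, 2.0, 801)
potential = (cv**2 - 1.0) ** 2   # double well, minima at +/-1, barrier of 1 at 0
bias = -potential                # assumed bias that flattens the potential
weights = np.exp(bias / kbt)     # reweighting factor exp(beta * V_bias)

grid = np.linspace(-1.5, 1.5, 61)
fes = fes_1d_sketch(cv, weights, grid, bandwidth=0.1, kbt=kbt)

s_min = grid[np.argmin(fes)]               # location of the free energy minimum
barrier = fes[np.argmin(np.abs(grid))]     # free energy at s = 0
```

The recovered profile has its minima close to the wells of the underlying potential, while the finite bandwidth slightly lowers the estimated barrier.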



### test
Except for a few simple helper functions, all the classes and functions in `mlcvs` are tested to check for potential crashes of the code. The test functions are implemented at the end of each file and are simply called from the `test` folder.

## Data-driven collective variables

### 1) Unsupervised learning
- Learn structure from unlabeled data
- Only a dimensionality reduction task is enforced
- Data: any

#### 1a) Auto Encoder

#### 1b) Variational Auto Encoder ??

### 2) Supervised learning

- Learn a regression/classification task
- Requires labeled samples (e.g. which state each configuration belongs to)
- Can be used to enforce the requirement that CVs should distinguish metastable states
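
As a small illustration of how labels can be used to build a CV that separates states, the sketch below applies classical two-class Fisher discriminant analysis in NumPy. This is the linear method, not the neural network-based DeepLDA, and the two synthetic Gaussian "states" are an assumption made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic labeled data: descriptors sampled in two metastable states.
state_a = rng.normal(loc=[0.0, 0.0], scale=0.5, size=(200, 2))
state_b = rng.normal(loc=[2.0, 1.0], scale=0.5, size=(200, 2))

mu_a = state_a.mean(axis=0)
mu_b = state_b.mean(axis=0)

# Within-class scatter: sum of the covariances of the two states.
s_w = np.cov(state_a, rowvar=False) + np.cov(state_b, rowvar=False)

# Fisher direction: maximizes between-class over within-class variance.
w = np.linalg.solve(s_w, mu_b - mu_a)
w /= np.linalg.norm(w)

# The CV is the projection of the descriptors on that direction.
cv_a = state_a @ w
cv_b = state_b @ w
separation = cv_b.mean() - cv_a.mean()
```

The projected values of the two states form two well-separated distributions along the CV.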

#### 2a) Neural network-based Discriminant Analysis (DeepLDA)

#### 2b) Targeted Discriminant Analysis (DeepTDA)

#### 2c) Regression CV

### 3) Temporal-informed CVs

- Learn how the system evolves
- Can be used to extract slow modes that describe the transition from one metastable state to another
- BUT already requires reactive simulations
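
The extraction of slow modes from time-lagged data can be illustrated with linear TICA in NumPy. This is a sketch of the classical linear method, not of DeepTICA, and the helper names `time_lagged_pairs` and `tica_sketch`, as well as the synthetic trajectory mixing one slow and one fast process, are assumptions made for the example.

```python
import numpy as np


def time_lagged_pairs(x, lag):
    """Return configurations at time t and at time t + lag."""
    return x[:-lag], x[lag:]


def tica_sketch(x, lag):
    """Solve C(lag) v = lambda C(0) v with symmetrized covariance estimates."""
    x = x - x.mean(axis=0)
    x_t, x_lag = time_lagged_pairs(x, lag)
    n = len(x_t)
    c0 = (x_t.T @ x_t + x_lag.T @ x_lag) / (2 * n)
    ct = (x_t.T @ x_lag + x_lag.T @ x_t) / (2 * n)
    # Whiten with C(0)^(-1/2), then diagonalize the symmetric lagged matrix.
    evals0, evecs0 = np.linalg.eigh(c0)
    whiten = evecs0 / np.sqrt(evals0)
    evals, evecs = np.linalg.eigh(whiten.T @ ct @ whiten)
    order = np.argsort(evals)[::-1]  # slowest mode first
    return evals[order], (whiten @ evecs)[:, order]


rng = np.random.default_rng(0)
n_steps = 5000
slow = np.zeros(n_steps)
fast = np.zeros(n_steps)
for t in range(1, n_steps):
    slow[t] = 0.99 * slow[t - 1] + rng.normal()  # long autocorrelation time
    fast[t] = 0.10 * fast[t - 1] + rng.normal()  # decorrelates almost at once

# Observed descriptors mix the slow and the fast process.
s = slow / slow.std()
features = np.column_stack([s + fast, s - fast])

eigenvalues, eigenvectors = tica_sketch(features, lag=10)
cv = (features - features.mean(axis=0)) @ eigenvectors[:, 0]
correlation = abs(np.corrcoef(cv, slow)[0, 1])
```

The leading eigenvector recovers the slow process from the mixed descriptors, which is the property that makes such a projection a candidate CV.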

#### 3a) Neural network-based TICA (DeepTICA)

#### 3b) Time-lagged Auto Encoder (TAE)

### 4) Multi-task learning

#### 4a) Autoencoder + TDA?