# Data Handling

The models and the evaluation are built on a custom DataLoader class.
Every model implements its own subclass of the base DataLoader, which handles model-specific data preprocessing and can process additional information written to the generated FASTA files.

## Adding new datasets

To add a new dataset that can be processed with the loaders, implement a new function ```loader._load_<dataset>(...)```. This function should load the sequences as a 1D array of strings into the ```loader.data``` attribute and set the ```loader.reference``` attribute to the wildtype reference sequence of the respective dataset (otherwise, computing the total variation distance with respect to the wildtype will not work). Additional custom attributes (e.g. fitness) can also be instantiated in this function. Finally, a new case for the dataset has to be added to ```loader.load```, which defines the short name of the dataset and calls the specific load function implemented above.
Now, the new dataset can be loaded as follows and used with the models below.

In [None]:
from genzyme.data import loaderFactory
dataset_name = "ired" # put the name of the new dataset as defined in .load here
loader = loaderFactory()
loader.load(dataset_name)
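
The steps for adding a dataset can be sketched with a stand-in class. This is a minimal illustration under stated assumptions, not the genzyme implementation: `ToyLoader`, `_load_toy`, the `"toy"` short name and the sequences are hypothetical. Only the `data` and `reference` attributes and the dispatch inside `load` follow the description above.

```python
class ToyLoader:
    """Hypothetical stand-in for the base DataLoader (not the genzyme class)."""

    def __init__(self):
        self.data = None        # sequences as a 1D collection of strings
        self.reference = None   # wildtype reference sequence
        self.fitness = None     # optional custom attribute

    def _load_toy(self):
        # Dataset-specific loading; real code would parse a file here.
        self.reference = "MKTAYIAK"
        self.data = ["MKTAYIAK", "MKTAYLAK", "MRTAYIAK"]
        self.fitness = [1.0, 0.8, 0.5]

    def load(self, name):
        # One case per dataset: map the short name to its load function.
        if name == "toy":
            self._load_toy()
        else:
            raise NotImplementedError(f"Unknown dataset: {name}")


loader = ToyLoader()
loader.load("toy")
print(len(loader.data), loader.reference)
```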

# Model Usage

This notebook contains example code for training the models and generating new sequences with them. The hyperparameters for each model as well as preprocessing, training and generation are handled in a ```.yaml``` file stored in the respective subdirectory of ```../configs```.

In [None]:
from omegaconf import OmegaConf
import os
import torch

In [None]:
from genzyme.models import modelFactory
from genzyme.data import loaderFactory

## ZymCTRL

In [None]:
# __file__ is undefined inside a notebook, so the path is given relative to the notebook directory
cfg = OmegaConf.load("../configs/zymctrl/config.yaml")
OmegaConf.resolve(cfg)

The ZymCTRL model can be run with 3 different architectures that differ in the number of attention heads and layers:
- Original (20 heads, 36 layers)
- Small (10 heads, 20 layers)
- Tiny (5 heads, 5 layers)

Since the provided model class is merely a wrapper for the Hugging Face models, the config for the latter is provided separately in the ```model_dir```. It can point to the Hugging Face Hub path in the case of the original model and has to point to a local directory containing the tokenizer and model configs for the custom smaller models (default is ```./../../data/ZymCTRL_<architecture>```). Note that the path must be specified relative to ```zymctrl.py``` if it is not absolute. A mapping of the model names to the ```model_dir``` can be found below.

In [None]:
if cfg.model.name == "zymctrl":
    cfg.model.dir = "AI4PD/ZymCTRL"

elif cfg.model.name == "small":
    cfg.model.dir = './../../data/ZymCTRL_small'

elif cfg.model.name == "tiny":
    cfg.model.dir = './../../data/ZymCTRL_tiny'

else:
    raise NotImplementedError(f"Unknown model name: {cfg.model.name}")
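
The mapping above can also be written as a dictionary lookup, which keeps the names and directories in one place. This is only a sketch of an alternative layout: `MODEL_DIRS` and `resolve_model_dir` are hypothetical names and not part of genzyme.

```python
# Mapping of model names to model directories, as listed above.
MODEL_DIRS = {
    "zymctrl": "AI4PD/ZymCTRL",
    "small": "./../../data/ZymCTRL_small",
    "tiny": "./../../data/ZymCTRL_tiny",
}


def resolve_model_dir(name):
    """Return the model_dir for a model name, or raise for unknown names."""
    try:
        return MODEL_DIRS[name]
    except KeyError:
        raise NotImplementedError(f"Unknown model name: {name}") from None


print(resolve_model_dir("tiny"))
```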

To preprocess the desired dataset, the respective loader class is used. It handles proper batching and tokenization of the sequences and flushes the train and test datasets to disk at the specified ```data_dir```. The datasets can also be kept in memory and returned directly by setting ```save=False```.

In [None]:
from transformers import AutoTokenizer
from genzyme.models.utils import SpecialTokens

if cfg.data.reload:
    loader = loaderFactory("ctrl")
    loader.load(cfg.data.name)
    loader.assign_control_tags(cfg.data.tag)
    loader.set_tokenizer(AutoTokenizer.from_pretrained(cfg.model.dir))

    special = SpecialTokens("<start>", "<end>", "<pad>", "<|endoftext|>", "<sep>", " ")
    loader.preprocess(special, cfg.data.test_split, 0, save=True, data_dir=f"../datasets/{cfg.data.name}/")   

In [None]:
model = modelFactory("zymctrl", cfg = cfg)

### Training

For training, only the train and validation datasets have to be specified; all other hyperparameters are passed via ```cfg```. One can either pass the datasets directly or specify the path to their location on disk.

In [None]:
model.run_training(train_dataset = f'../datasets/{cfg.data.name}/train',
                   eval_dataset = f'../datasets/{cfg.data.name}/test',
                   cfg = cfg)

### Generation

A list of EC numbers can be specified, for each of which ```cfg.generation.n_seqs``` sequences will be generated. The sequences of all prompts will be written to a fasta file whose path can be specified in ```cfg``` with the headers containing the prompt and the model perplexity for each sequence. If ```cfg.generation.keep_in_memory=True```, the sequences will also be returned.

In [None]:
model.generate(cfg)
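
To illustrate what such a header could carry, the sketch below computes a perplexity from per-token log-probabilities (the exponential of the mean negative log-likelihood) and writes one FASTA record. The header layout, the helper names and the example values are assumptions made for illustration; the exact format written by `model.generate` is not shown here.

```python
import io
import math


def perplexity(token_log_probs):
    """Perplexity = exp of the mean negative log-likelihood per token."""
    return math.exp(-sum(token_log_probs) / len(token_log_probs))


def write_fasta(handle, records):
    """Write (prompt, sequence, perplexity) records; the header layout is hypothetical."""
    for i, (prompt, seq, ppl) in enumerate(records):
        handle.write(f">{prompt}_{i} perplexity={ppl:.3f}\n{seq}\n")


# Four tokens, each predicted with probability 0.5.
log_probs = [math.log(0.5)] * 4
buf = io.StringIO()
write_fasta(buf, [("1.1.1.1", "MKTA", perplexity(log_probs))])  # arbitrary example EC number
print(buf.getvalue())
```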

## SEDD & Discrete Flow Modelling

Score Entropy Discrete Diffusion (SEDD) and Discrete Flow Modeling (DFM) are implemented in the same model class and can be switched between by changing the loss and noise type attributes in the config as follows:

| Model | ```cfg.training.loss``` | ```cfg.noise.type``` |
| ----- | ----------------- | -------------- |
| SEDD | ```"dwdse"``` | ```"loglinear"``` |
| DFM | ```"ce"``` | ```"linear"``` |

In [None]:
model = modelFactory("sedd", cfg_dir = "./../../configs/sedd/")

### Training

The model uses the config stored at ```cfg_dir``` as default. Any runtime overrides to the default config can be passed directly as a config object that contains only the attributes that are to be overridden.
Data loading and preprocessing are handled by the ```run_training``` method.

The config that was used for training is stored in the ```cfg.work_dir```.
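
The override semantics can be illustrated with plain dictionaries. This is a sketch that assumes overrides are merged recursively into the defaults (similar in spirit to `OmegaConf.merge`); how the model class combines the two configs internally is not shown here, and `apply_overrides` as well as the default values are hypothetical.

```python
import copy


def apply_overrides(default, overrides):
    """Recursively merge `overrides` into a copy of `default`."""
    merged = copy.deepcopy(default)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = apply_overrides(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


default_cfg = {"training": {"loss": "dwdse", "batch_size": 64}, "noise": {"type": "loglinear"}}
overrides = {"training": {"loss": "ce"}, "noise": {"type": "linear"}}
cfg_dict = apply_overrides(default_cfg, overrides)
print(cfg_dict)
```

Keys that are absent from the overrides (here `batch_size`) keep their default value.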

In [None]:
train_overrides = OmegaConf.create({"data": {"name": "ired", "test_split": 0.2, "grouped": False}, 
                                    "training": {"n_iters": 100000, "distributed": True, "loss": "ce", "batch_size": 128},
                                    "optim": {"lr": 3e-5},
                                    "noise": {"type": "linear"},
                                    "eval": {"batch_size": 64},
                                    "model": {"length": 291}})
model.run_training(train_overrides)

### Generation

The generation function takes as argument the path to the directory that contains the training config. If no particular checkpoint is specified explicitly, the model will try to load the most recent checkpoint it can find in ```model_path/checkpoints```.
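
One way such a lookup could work is sketched below. It assumes that "most recent" means the latest modification time; the actual criterion and the checkpoint file names used by the model are not specified here, so `latest_checkpoint` and the `checkpoint_<n>.pth` names are hypothetical.

```python
import os
import tempfile
from pathlib import Path


def latest_checkpoint(model_path):
    """Return the newest file in <model_path>/checkpoints, or None if there is none."""
    ckpt_dir = Path(model_path) / "checkpoints"
    if not ckpt_dir.is_dir():
        return None
    files = [p for p in ckpt_dir.iterdir() if p.is_file()]
    return max(files, key=lambda p: p.stat().st_mtime, default=None)


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp) / "checkpoints"
    ckpt_dir.mkdir()
    for step, mtime in [(1, 1_000), (2, 2_000)]:
        path = ckpt_dir / f"checkpoint_{step}.pth"
        path.touch()
        os.utime(path, (mtime, mtime))  # set explicit modification times
    newest = latest_checkpoint(tmp).name

print(newest)
```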

Similar to the training, the generate method also allows for overrides of config arguments by passing a config object with the relevant attributes.

Most importantly, the predictors are specific to SEDD or DFM and should only be used as follows:

| Model | ```cfg.sampling.predictor``` |
| ----- | ----------------- |
| SEDD | ```"euler"``` or ```"analytic"``` |
| DFM | ```"euler-dfm"``` |
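
Because the loss, noise type and predictor have to match, a small consistency check before a long run can catch mixed-up settings. The helper below is a sketch built only from the two tables in this section; `VALID_SETTINGS` and `infer_model` are hypothetical names and not part of genzyme.

```python
# Valid settings per model, taken from the two tables in this section.
VALID_SETTINGS = {
    "SEDD": {"loss": "dwdse", "noise": "loglinear", "predictors": {"euler", "analytic"}},
    "DFM": {"loss": "ce", "noise": "linear", "predictors": {"euler-dfm"}},
}


def infer_model(loss, noise, predictor):
    """Return "SEDD" or "DFM" if the three settings are consistent, else raise."""
    for name, settings in VALID_SETTINGS.items():
        if (loss == settings["loss"] and noise == settings["noise"]
                and predictor in settings["predictors"]):
            return name
    raise ValueError(
        f"Inconsistent settings: loss={loss!r}, noise={noise!r}, predictor={predictor!r}"
    )


print(infer_model("ce", "linear", "euler-dfm"))
```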

In [None]:
model_path = "<YOUR WORK DIR>"
gen_overrides = OmegaConf.create({"sampling": {"batch_size": 32, "steps": 1000, "n_samples": 10000, "predictor": "euler-dfm"},
                                  "out_dir": "../gen_data/dfm/"})
model.generate(model_path, gen_overrides)

## Energy-based Models

The abstract base class ```EnergyBasedModel``` implements core functions and an MCMC sampling routine. It has two child classes, ```DeepEBM``` and ```PottsModel```, that each implement their own training methods. They are both derived from ```torch.nn.Module``` as well.

### Potts Model

In [None]:
cfg = OmegaConf.load("../configs/potts/config.yaml")
OmegaConf.resolve(cfg)

#### Preprocessing

The preprocessing is handled by the Potts loader class. The preprocess method returns PyTorch dataloaders for training and testing. When using ```cfg.optimizer.method="l-bfgs"```, make sure to set the ```train_batch_size``` equal to the training set size because the model will only be trained on the first batch from the loader.

Note that this implementation of energy-based models always expects fixed-size input (due to the fixed size of the parameters and because only fixed-length sequences can be sampled at generation time), and therefore the training data has to be truncated / filtered to the same length. This can be achieved with ```loader.unify_seq_len(<LENGTH>)```.

In [None]:
loader = loaderFactory("potts")
loader.load(cfg.data.name)
loader.unify_seq_len(cfg.model.L)
train_data, test_data = loader.preprocess(cfg.training.test_split, 
                                          train_batch_size=cfg.training.batch_size, 
                                          test_batch_size=cfg.training.batch_size, 
                                          d = cfg.model.d, 
                                          shuffle=True)
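
What unifying the sequence length could look like is sketched below. The exact behaviour of `loader.unify_seq_len` is not documented here; this sketch assumes that sequences shorter than the target length are dropped and longer ones are truncated. `unify_length` is a hypothetical helper.

```python
def unify_length(seqs, length):
    """Keep sequences of at least `length` residues and truncate them to `length`."""
    return [s[:length] for s in seqs if len(s) >= length]


seqs = ["MKTAYIAK", "MKT", "MKTAYIAKQR"]
fixed = unify_length(seqs, 8)
print(fixed)
```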

#### Training

In [None]:
model = modelFactory("potts", cfg = cfg)
model.run_training(train_data, test_data, cfg)

In [None]:
torch.save(model, "your/model/name.pt")

To extract a contact map from the Potts model, run ```get_coupling_info```. The contact map can be generated with or without performing average product correction.

In [None]:
import matplotlib.pyplot as plt

contacts = model.get_coupling_info(apc=True)
fig, ax = plt.subplots()
ax.imshow(contacts)
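
Average product correction (APC) subtracts from each score the product of its row mean and column mean divided by the overall mean, which removes background signal shared across positions. The sketch below applies it to a plain list-of-lists matrix. It is an illustration only: how `get_coupling_info` summarises the couplings and whether it excludes the diagonal from the means is not shown here, and this version includes all entries and assumes a non-zero overall mean.

```python
def average_product_correction(scores):
    """Apply APC to a square score matrix given as a list of lists.

    corrected[i][j] = scores[i][j] - row_mean[i] * col_mean[j] / overall_mean
    """
    n = len(scores)
    row_mean = [sum(row) / n for row in scores]
    col_mean = [sum(scores[i][j] for i in range(n)) / n for j in range(n)]
    overall = sum(row_mean) / n
    return [
        [scores[i][j] - row_mean[i] * col_mean[j] / overall for j in range(n)]
        for i in range(n)
    ]


# A rank-one matrix is pure "background" and is removed entirely by APC.
a = [1.0, 2.0, 3.0]
background = [[ai * aj for aj in a] for ai in a]
corrected = average_product_correction(background)
print(corrected)
```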

#### Generation

In order to start a run of the Markov chain, the starting state has to be provided. This can be a random sequence generated via ```get_random_seq``` or any arbitrary sequence in one-hot encoding with ```model.d``` classes.
The model can be re-seeded via ```model.set_seed(<SEED>)```.

In [None]:
x0 = model.get_random_seq()
x0 = torch.nn.functional.one_hot(x0, num_classes = model.d)
model.generate(cfg, x0)
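
For intuition, a single-site Metropolis sampler on a Potts energy can be sketched as follows. This is not the sampling routine of `EnergyBasedModel`, whose proposal and acceptance rules are not described here. It is a generic textbook variant that uses integer-encoded states instead of one-hot tensors, and all names and parameter values are hypothetical.

```python
import math
import random


def potts_energy(seq, h, J):
    """E(x) = -sum_i h[i][x_i] - sum_{i<j} J[i][j][x_i][x_j]."""
    n = len(seq)
    energy = -sum(h[i][seq[i]] for i in range(n))
    energy -= sum(J[i][j][seq[i]][seq[j]] for i in range(n) for j in range(i + 1, n))
    return energy


def metropolis(x0, h, J, n_steps, rng, temperature=1.0):
    """Single-site Metropolis sampling starting from the integer-encoded state x0."""
    x = list(x0)
    d = len(h[0])
    energy = potts_energy(x, h, J)
    for _ in range(n_steps):
        proposal = x.copy()
        proposal[rng.randrange(len(x))] = rng.randrange(d)
        new_energy = potts_energy(proposal, h, J)
        # Accept with probability min(1, exp(-(E_new - E_old) / T)).
        if new_energy <= energy or rng.random() < math.exp(-(new_energy - energy) / temperature):
            x, energy = proposal, new_energy
    return x


rng = random.Random(0)
seq_len, d = 4, 3
h = [[1.0 if a == 0 else 0.0 for a in range(d)] for _ in range(seq_len)]  # fields favour state 0
J = [[[[0.0] * d for _ in range(d)] for _ in range(seq_len)] for _ in range(seq_len)]  # no couplings
x0 = [rng.randrange(d) for _ in range(seq_len)]
sample = metropolis(x0, h, J, n_steps=200, rng=rng)
print(sample)
```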

### Deep Energy-based Model

The deep energy-based model can be run with two different implementations of the energy function:
1. Energy modelled by ESM-2 with a finetuned head
2. Energy modelled as a quadratic form of pretrained ESM-2 embeddings with a learnable matrix

The different models can be specified as follows, where $s$ denotes the ESM-2 model and the subscript $\theta$ indicates that a component is trainable:

| Model | ```cfg.model.em_name``` |
| -- | ---------------------- |
| $f_\theta(x) = s_\theta(x)$ | ```"esm"``` |
| $f_\theta(x) = s(x)^T A_\theta s(x)$ | ```"quadratic"``` |
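
The quadratic form in the second row can be made concrete with a tiny numeric example. The sketch treats the embedding as a single vector and uses plain Python lists; how the ESM-2 embeddings are pooled into such a vector is not specified here, and the values are arbitrary.

```python
def quadratic_energy(s, A):
    """Return s^T A s for an embedding vector s and a square matrix A."""
    n = len(s)
    return sum(s[i] * A[i][j] * s[j] for i in range(n) for j in range(n))


s = [1.0, 2.0]                 # stand-in for a pooled embedding s(x)
A = [[2.0, 0.5], [0.5, 1.0]]   # stand-in for the learnable matrix
energy = quadratic_energy(s, A)
print(energy)  # 2 + 1 + 1 + 4 = 8.0
```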


In [None]:
cfg = OmegaConf.load("../configs/deep_ebm/config.yaml")
OmegaConf.resolve(cfg)
model = modelFactory("debm", cfg = cfg)

#### Preprocessing

Again, the training and validation sequences need to have the same length. Because the ESM-2 model, which is used to model the energy, takes tokenized sequences where amino acids are represented as integers as input, the ESM-2 tokenizer has to be passed to the preprocessing function.
The resulting dataloaders contain pairs of tokenized sequences, one encoded with the ESM-2 tokenizer, the other encoded with a simpler integer encoding with ```model.d``` classes.

In [None]:
loader = loaderFactory("debm")
loader.load("ired")
loader.unify_seq_len(cfg.model.L)
train_dl, test_dl = loader.preprocess(cfg.training.split,
                                      cfg.training.batch_size,
                                      cfg.training.val_batch_size,
                                      tokenizer = model.em_tokenizer)

#### Training

In [None]:
model.run_training(train_dl, test_dl, cfg)

#### Generation

In [None]:
x0 = model.get_random_seq()
x0 = torch.nn.functional.one_hot(x0, num_classes = model.d).double()
model.generate(cfg, x0)

## Random Baseline

The random model is the only model that does not use a config file (due to its simplicity). It will generate randomly drawn amino acid sequences of a fixed length. The model can be run in two modes:
- Without training: The model will draw an amino acid uniformly at random at each position
- With training: The model will draw an amino acid from the maximum likelihood estimate of the categorical distribution at each position (conserves the marginals of the amino acids)
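
The trained mode can be sketched as follows: the maximum likelihood estimate of a categorical distribution is the relative frequency of each amino acid at a position, and sampling draws each position independently. This is an illustration under stated assumptions, not the genzyme implementation; the function names and the toy sequences are hypothetical.

```python
import random
from collections import Counter


def fit_position_frequencies(seqs):
    """MLE of a categorical distribution per position: relative frequencies."""
    freqs = []
    for pos in range(len(seqs[0])):
        counts = Counter(seq[pos] for seq in seqs)
        freqs.append({aa: c / len(seqs) for aa, c in counts.items()})
    return freqs


def sample_sequence(freqs, rng):
    """Draw one residue per position, independently across positions."""
    return "".join(
        rng.choices(list(p.keys()), weights=list(p.values()), k=1)[0] for p in freqs
    )


train = ["MKT", "MKS", "MRT", "MKT"]
freqs = fit_position_frequencies(train)
seq = sample_sequence(freqs, random.Random(0))
print(freqs[1], seq)
```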

### Preprocessing

The random model can be trained either on fixed-length sequences or on sequences of different lengths by padding sequences that are shorter. The model comes with a built-in function that handles padding to the desired maximum length. Preprocessing is only necessary when running the model with training; otherwise it suffices to pass the desired length of the resulting sequences to the constructor.

In [None]:
import numpy as np

loader = loaderFactory("random")
loader.load("ired")
max_len = np.max([len(seq) for seq in loader.get_data()])

model = modelFactory("random", length = max_len)
loader.set_data(model.pad_data(loader.get_data()))
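
Padding to a common length can be sketched as below. The pad symbol and the padding side used by `model.pad_data` are not stated here; this sketch assumes right-padding with `"-"`, and `pad_sequences` is a hypothetical helper. Sequences that are already at least `length` long are returned unchanged.

```python
def pad_sequences(seqs, length, pad_token="-"):
    """Right-pad sequences shorter than `length` with `pad_token`."""
    return [s + pad_token * (length - len(s)) for s in seqs]


padded = pad_sequences(["MK", "MKTA"], 4)
print(padded)
```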

Optionally, the sequences can be truncated/filtered to the same length with the built-in method from the dataloader instead.

In [None]:
loader = loaderFactory("random")
loader.load("ired")
loader.unify_seq_len(290)

### Training

Training is optional; if omitted, the model will generate sequences drawn uniformly at random.

In [None]:
train_data = loader.get_data()
l = len(train_data[0])
model = modelFactory("random", max_len = l)
model.run_training(train_dataset = train_data)

### Generation

In [None]:
model.generate(n_samples = 10000, output_file = "sequences.fasta", keep_in_memory = False)