# AIDO.Protein-16B: Training with ModelGenerator

## ProteinGym Supervised DMS Benchmark
The Deep Mutational Scanning (DMS) Benchmark in ProteinGym is a comprehensive collection of 283 standardized DMS assays, comprising more than 2.7 million mutated protein sequences from over 200 diverse protein families. These assays capture a wide range of functional properties, such as ligand binding, thermostability, viral replication, drug resistance, and more. The dataset spans diverse taxa (humans, other eukaryotes, prokaryotes, and viruses) and includes a variety of mutation types, such as single amino acid substitutions and indels (insertions or deletions). The primary goal of the DMS Benchmark is to model protein fitness landscapes, which represent the relationship between genetic mutations and their effects on protein fitness or functionality.

We finetune the [AIDO.Protein-16B](https://huggingface.co/genbio-ai/AIDO.Protein-16B) and [AIDO.Protein-16B-v1](https://huggingface.co/genbio-ai/AIDO.Protein-16B-v1) models on the DMS benchmark.
AIDO.ModelGenerator implements both Linear Probing and LoRA finetuning, following a 5-fold cross-validation scheme with the random split strategy proposed in the original [ProteinGym paper](https://www.biorxiv.org/content/10.1101/2023.12.07.570727v1). ModelGenerator also implements both DDP and FSDP for efficient finetuning.

For tasks involving indels with limited data, we apply Linear Probing, while LoRA is used for substitution tasks and other indel tasks to balance computational efficiency with fine-tuning effectiveness.

In [1]:
! which mgen

/opt/miniconda/envs/genbio/bin/mgen


In [2]:
! mgen --help

usage: mgen [-h] [-c CONFIG] [--print_config[=flags]]
            {fit,validate,test,predict} ...

Lightning Trainer command line tool

options:
  -h, --help            Show this help message and exit.
  -c CONFIG, --config CONFIG
                        Path to a configuration file in json or yaml format.
  --print_config[=flags]
                        Print the configuration after applying all other
                        arguments and exit. The optional flags customizes the
                        output and are one or more keywords separated by
                        comma. The supported flags are: skip_default,
                        skip_null.

subcommands:
  For more details of each subcommand, add it as an argument followed by
  --help.

  Available subcommands:
    fit                 Runs the full optimization routine.
    validate            Perform one evaluation epoch over the validation set.
    test                Perform one evaluation epoch over the test set. It's


## Start Training by Command Line

### Command to run LoRA finetuning with DDP on 1 node, 3 GPUs per node:

Config file: [substitution_LoRA_DDP.yaml](../ModelGenerator/experiments/AIDO.Protein/DMS/configs/substitution_LoRA_DDP.yaml)

#### Add Logger to yaml:

```yaml
trainer:
  logger:
  - class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      name: null
      project: null
```

Start training with the command below:

```bash
export HF_HOME=/tmp/hf_cache

TASK_NAME='A4GRB6_PSEAI_Chen_2020'
MUTATION_TYPE='singles_substitutions'
RUN_NAME=${TASK_NAME}_fold0

mgen fit --config experiments/AIDO.Protein/DMS/configs/substitution_LoRA_DDP.yaml \
    --data.train_split_files "[\"${MUTATION_TYPE}/${TASK_NAME}.tsv\"]" \
    --data.cv_test_fold_id 0 \
    --data.batch_size 2 \
    --trainer.logger.name ${RUN_NAME} \
    --trainer.logger.project AIDO_Demo \
    --trainer.num_nodes 1 \
    --trainer.devices auto
```

```
wandb: Currently logged in as: hnsfyfyzlp (tsinghua-lipan) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.22.2
wandb: Run data is saved locally in ./wandb/run-20251023_184315-A4GRB6_PSEAI_Chen_2020_fold0
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run A4GRB6_PSEAI_Chen_2020_fold0
wandb: ⭐️ View project at https://wandb.ai/tsinghua-lipan/AIDO_Demo
wandb: 🚀 View run at https://wandb.ai/tsinghua-lipan/AIDO_Demo/runs/A4GRB6_PSEAI_Chen_2020_fold0
Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.
singles_substitutions/A4GRB6_PSEAI_Chen_(…): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.43M/1.43M [00:00<00:00, 1.90MB/s]
Generating train split: 5004 examples [00:00, 207951.11 examples/s]
label: mean = [-1.93724161], std = [2.1050639]
/opt/miniconda/envs/genbio/lib/python3.12/site-packages/huggingface_hub/file_download.py:942: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
/opt/miniconda/envs/genbio/lib/python3.12/site-packages/huggingface_hub/file_download.py:942: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
/opt/miniconda/envs/genbio/lib/python3.12/site-packages/huggingface_hub/file_download.py:942: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:17<00:00,  1.37s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:17<00:00,  1.37s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:18<00:00,  1.39s/it]
trainable params: 11,948,544 || all params: 16,071,144,192 || trainable%: 0.0743
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name     | Type                | Params | Mode 
---------------------------------------------------------
0 | metrics  | ModuleDict          | 0      | train
1 | loss     | MSELoss             | 0      | train
2 | backbone | aido_protein_16b_v1 | 16.1 B | train
3 | adapter  | MLPPoolAdapter      | 295 K  | train
---------------------------------------------------------
12.2 M    Trainable params
16.1 B    Non-trainable params
16.1 B    Total params
64,285.757Total estimated model params size (MB)
1829      Modules in train mode
1160      Modules in eval mode
Sanity Checking: |                                                                                                                                                                  | 0/? [00:00<?, ?it/s]/opt/miniconda/envs/genbio/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=68` in the `DataLoader` to improve performance.
[rank1]:[W1023 18:43:42.316982181 CPUAllocator.cpp:245] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event                 
[rank2]:[W1023 18:43:42.317155840 CPUAllocator.cpp:245] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event
[rank0]:[W1023 18:43:42.319645894 CPUAllocator.cpp:245] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event
/opt/miniconda/envs/genbio/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=68` in the `DataLoader` to improve performance.
Epoch 0:   1%|▎                                | 4/511 [00:03<06:24,  1.32it/s, v_num=old0, train_loss=1.330, train_pearson=1.000, train_spearman=1.000, train_mae=1.150, train_r2=-622., train_mse=1.330][rank0]:[W1023 18:43:48.316836962 collection.cpp:647] Warning: Optimizer.step#AdamW.step (function operator())
[rank2]:[W1023 18:43:48.410975745 collection.cpp:647] Warning: Optimizer.step#AdamW.step (function operator())
[rank1]:[W1023 18:43:48.496507396 collection.cpp:647] Warning: Optimizer.step#AdamW.step (function operator())
Epoch 0:  24%|███████▎                      | 124/511 [01:16<03:57,  1.63it/s, v_num=old0, train_loss=0.884, train_pearson=1.000, train_spearman=1.000, train_mae=0.693, train_r2=-0.833, train_mse=0.884]
```
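The parameter counts printed in the log above can be cross-checked with a little arithmetic. The sketch below uses only numbers copied from that log; treating the adapter as exactly 295,000 parameters is an approximation of the rounded `295 K` entry, and reading the `12.2 M` figure as "LoRA weights plus adapter" is an interpretation rather than something the log states.

```python
# Numbers copied from the training log.
lora_trainable = 11_948_544        # "trainable params" reported for the backbone
backbone_total = 16_071_144_192    # "all params" reported for the backbone
adapter_params = 295_000           # ASSUMPTION: approximates the rounded "295 K" entry

# Share of backbone weights that LoRA leaves trainable, in percent.
trainable_pct = 100 * lora_trainable / backbone_total

# LoRA weights plus the adapter, in millions of parameters.
total_trainable_millions = (lora_trainable + adapter_params) / 1e6

print(f"trainable%: {trainable_pct:.4f}")
print(f"trainable params incl. adapter: {total_trainable_millions:.1f} M")
```

Both values agree with the log (`trainable%: 0.0743` and `12.2 M Trainable params`), which is consistent with the adapter being trained alongside the LoRA weights.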

<img src="images/16b_training_curve.png" alt="AIDO.Protein" width="80%" style="background-color:white;"/>
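The command above trains a single fold. Covering the full 5-fold scheme means repeating it with `--data.cv_test_fold_id` set to 0 through 4. The sketch below builds those command lines in Python; the helper name `build_fit_command` is hypothetical, and every flag and value is copied from the command above rather than taken from additional `mgen` documentation.

```python
import shlex

CONFIG = "experiments/AIDO.Protein/DMS/configs/substitution_LoRA_DDP.yaml"

def build_fit_command(task_name, mutation_type, fold_id, project="AIDO_Demo"):
    """Return the `mgen fit` argument list for one cross-validation fold."""
    run_name = f"{task_name}_fold{fold_id}"
    # The CLI expects a list literal, so the file name is wrapped in brackets.
    split_files = f'["{mutation_type}/{task_name}.tsv"]'
    return [
        "mgen", "fit",
        "--config", CONFIG,
        "--data.train_split_files", split_files,
        "--data.cv_test_fold_id", str(fold_id),
        "--data.batch_size", "2",
        "--trainer.logger.name", run_name,
        "--trainer.logger.project", project,
        "--trainer.num_nodes", "1",
        "--trainer.devices", "auto",
    ]

commands = [
    build_fit_command("A4GRB6_PSEAI_Chen_2020", "singles_substitutions", fold_id)
    for fold_id in range(5)
]

for command in commands:
    # shlex.join quotes each argument so the line can be pasted into a shell.
    print(shlex.join(command))
```

Each list could be passed to `subprocess.run(command, check=True)` to launch the folds one after another; printing them first makes it easy to inspect the quoting before anything runs.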

## Create Objects Step by Step in a Jupyter Notebook

Starting training with a single command is simple, but it hides how the pieces fit together. Let's break down the contents of the YAML configuration file and build each object by hand.

Config file: [substitution_LoRA_DDP.yaml](../ModelGenerator/experiments/AIDO.Protein/DMS/configs/substitution_LoRA_DDP.yaml)


In [1]:
import os, sys, pathlib, torch
print(f"torch.__version__: {torch.__version__}")
os.environ['HF_HOME'] = '/tmp/hf_cache'

torch.__version__: 2.6.0+cu118


Dataset: [genbio-ai/ProteinGYM-DMS](https://huggingface.co/datasets/genbio-ai/ProteinGYM-DMS)


```yaml
data:
  class_path: modelgenerator.data.DMSFitnessPrediction
  init_args:
    path: genbio-ai/ProteinGYM-DMS
    train_split_files:
    - singles_substitutions/VRPI_BPT7_Tsuboyama_2023_2WNM.tsv
    train_split_name: 'train'
    random_seed: 42
    batch_size: 32
    cv_num_folds: 5
    cv_test_fold_id: 0
    cv_enable_val_fold: true
    cv_fold_id_col: fold_id
```

In [8]:
from modelgenerator.data import DMSFitnessPrediction

datamodule = DMSFitnessPrediction(
    path='genbio-ai/ProteinGYM-DMS', 
    train_split_files=['singles_substitutions/VRPI_BPT7_Tsuboyama_2023_2WNM.tsv'],
    train_split_name='train', 
    random_seed=42, 
    batch_size=4, 
    num_workers=8,
    cv_num_folds=5, 
    cv_test_fold_id=0, 
    cv_enable_val_fold=True, 
    cv_fold_id_col='fold_id')

datamodule.setup()

print("#Train: ", len(datamodule.train_dataset))
print("#Val: ", len(datamodule.val_dataset))
print("#Test: ", len(datamodule.test_dataset))

Repo card metadata block was not found. Setting CardData to empty.
label: mean = [-0.68593335], std = [0.93930578]


#Train:  634
#Val:  216
#Test:  197
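The `cv_*` arguments rely on a precomputed `fold_id` column: one fold is held out for testing, one for validation, and the rest are used for training, which is consistent with the roughly 3:1:1 sizes printed above. The sketch below illustrates that partitioning in plain Python. The function name is hypothetical, and choosing the fold after the test fold as the validation fold is an assumption, not necessarily what `DMSFitnessPrediction` does.

```python
def split_by_fold(fold_ids, test_fold_id=0, num_folds=5, enable_val_fold=True):
    """Return (train, val, test) lists of row indices from precomputed fold ids."""
    # ASSUMPTION: the validation fold is the one after the test fold, wrapping around.
    val_fold_id = (test_fold_id + 1) % num_folds if enable_val_fold else None
    train, val, test = [], [], []
    for row, fold in enumerate(fold_ids):
        if fold == test_fold_id:
            test.append(row)
        elif fold == val_fold_id:
            val.append(row)
        else:
            train.append(row)
    return train, val, test

# Hypothetical fold assignment: 20 rows, 4 rows in each of 5 folds.
fold_ids = [row % 5 for row in range(20)]
train, val, test = split_by_fold(fold_ids, test_fold_id=0)
print(len(train), len(val), len(test))  # 12 4 4
```

Changing `test_fold_id` from 0 to 4 rotates the held-out fold, so every row is used for testing exactly once across the five runs.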


```yaml
model:
  class_path: modelgenerator.tasks.SequenceRegression
  init_args:
    backbone:
      class_path: modelgenerator.backbones.aido_protein_16b_v1
      init_args:
        use_peft: true
        max_length: 2048
    adapter:
      class_path: modelgenerator.adapters.MLPPoolAdapter
      init_args:
        hidden_sizes:
        - 128
        dropout: 0.1
        dropout_in_middle: false
    optimizer:
      class_path: torch.optim.AdamW
      init_args:
        lr: 0.0001
        weight_decay: 0.01
    lr_scheduler:
      class_path: modelgenerator.lr_schedulers.CosineWithWarmup
      init_args:
        warmup_ratio: 0.05
```
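The `lr_scheduler` entry names `CosineWithWarmup` with `warmup_ratio: 0.05`. As a rough picture of what such a schedule does, the sketch below implements a generic linear warmup followed by cosine decay as a multiplier on the base learning rate. It is not ModelGenerator's implementation: the function name, the decay to zero, and the total of 10,000 steps are assumptions made for illustration.

```python
import math

def cosine_with_warmup_factor(step, total_steps, warmup_ratio):
    """Multiplier applied to the base learning rate at a given step."""
    warmup_steps = round(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear ramp from 0 up to 1 during warmup.
        return step / max(1, warmup_steps)
    # Cosine decay from 1 down to 0 over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

base_lr = 1e-4          # matches `lr: 0.0001` in the config
total_steps = 10_000    # ASSUMPTION: total number of optimizer steps
for step in (0, 250, 500, 5_250, 10_000):
    lr = base_lr * cosine_with_warmup_factor(step, total_steps, warmup_ratio=0.05)
    print(f"step {step:>6}: lr = {lr:.2e}")
```

Under these assumptions the learning rate climbs for the first 500 steps, peaks at the configured `lr`, and then decays smoothly for the rest of training.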

Initializing the model differs in one respect: the backbone, adapter, optimizer, and learning-rate scheduler are constructed by the task itself, so instead of passing finished objects we pass callables (built here with `functools.partial`) that the task invokes later.

In [None]:
import torch
from functools import partial
from modelgenerator.tasks import SequenceRegression
from modelgenerator.backbones import aido_protein_16b_v1
from modelgenerator.adapters import MLPPoolAdapter
from modelgenerator.lr_schedulers import CosineWithWarmup

modelmodule = SequenceRegression(
    backbone    = partial(aido_protein_16b_v1, use_peft=True, max_length=2048),
    adapter     = partial(MLPPoolAdapter, hidden_sizes=[128], dropout=0.1, dropout_in_middle=False),
    optimizer   = partial(torch.optim.AdamW, lr=0.0001, weight_decay=0.01),
    lr_scheduler= partial(CosineWithWarmup, warmup_ratio=0.05)
)

modelmodule.setup(stage='fit')
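The reason for passing callables is easiest to see with the optimizer: it needs the model's parameters, which do not exist until the task has built its backbone. The toy classes below are hypothetical stand-ins, not ModelGenerator classes; they only illustrate how `functools.partial` fixes the hyperparameters now and defers construction until the task is ready.

```python
from functools import partial

class ToyBackbone:
    """Hypothetical stand-in for a backbone."""
    def __init__(self, max_length, use_peft=False):
        self.max_length = max_length
        self.use_peft = use_peft
        self.weights = [0.0, 0.0]

class ToyOptimizer:
    """Hypothetical optimizer that needs the weights at construction time."""
    def __init__(self, params, lr):
        self.params = params
        self.lr = lr

class ToyTask:
    """Stores factories at init time and builds the objects in setup()."""
    def __init__(self, backbone, optimizer):
        self.backbone_fn = backbone
        self.optimizer_fn = optimizer
        self.backbone = None
        self.optimizer = None

    def setup(self):
        # The backbone is created first so the optimizer can receive its weights.
        self.backbone = self.backbone_fn()
        self.optimizer = self.optimizer_fn(self.backbone.weights)

task = ToyTask(
    backbone=partial(ToyBackbone, max_length=2048, use_peft=True),
    optimizer=partial(ToyOptimizer, lr=1e-4),
)
print(task.backbone is None)   # nothing has been built yet
task.setup()
print(task.backbone.max_length, task.optimizer.lr)
```

A side benefit of this pattern is that creating the task object stays cheap; the expensive step of loading weights happens only when setup runs.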



```yaml
trainer:
  accelerator: auto
  devices: auto
  logger: false
  callbacks:
  - class_path: lightning.pytorch.callbacks.ModelCheckpoint # save the checkpoint with the best val_spearman; the "val_mcc" label in the filename holds the Spearman value
    init_args:
      filename: epoch_{epoch}-val_mcc:{val_spearman:.3f}
      monitor: val_spearman
      mode: max
  - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping
    dict_kwargs:
      monitor: val_spearman
      mode: max
      patience: 10
  max_steps: 10000
  gradient_clip_val: 0.1
  default_root_dir: logs
```

In [None]:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

trainer = Trainer(
        accelerator="auto", 
        devices="auto", 
        logger=False,
        callbacks=[
            ModelCheckpoint(
                filename="epoch_{epoch}-val_mcc:{val_spearman:.3f}",
                monitor="val_spearman",
                mode="max"
            ),
            EarlyStopping(
                monitor="val_spearman",
                mode="max",
                patience=10
            )
        ],
        max_steps=10000,
        gradient_clip_val=0.1,
        default_root_dir="logs",
)

Trainer will use only 1 of 3 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=3)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
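The `EarlyStopping` callback above monitors `val_spearman` with `mode="max"` and `patience=10`. The sketch below reproduces the basic patience logic in plain Python to show when training would stop. It is a simplification rather than Lightning's code: the function name is hypothetical, `min_delta` is assumed to be 0, and the score list is made up.

```python
def early_stop_index(scores, patience=10, min_delta=0.0):
    """Return the index of the validation check that triggers stopping, or None."""
    best = float("-inf")
    wait = 0
    for index, score in enumerate(scores):
        if score > best + min_delta:
            # A new best score resets the patience counter.
            best = score
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return index
    return None

# Hypothetical validation Spearman values: improvement for 3 checks, then a plateau.
scores = [0.10, 0.20, 0.30] + [0.25] * 12
print(early_stop_index(scores))  # 12
```

In this made-up run the best score occurs at check 2, and training stops at check 12 after ten checks without improvement. The checkpoint callback, which monitors the same metric, would have kept the weights from check 2.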


In [10]:
trainer.fit(modelmodule, datamodule=datamodule, ckpt_path=None)

Repo card metadata block was not found. Setting CardData to empty.
label: mean = [-0.68593335], std = [0.93930578]


OutOfMemoryError: CUDA out of memory. Tried to allocate 136.00 MiB. GPU 0 has a total capacity of 79.18 GiB of which 44.56 MiB is free. Process 1712574 has 33.05 GiB memory in use. Including non-PyTorch memory, this process has 46.07 GiB memory in use. Of the allocated memory 45.26 GiB is allocated by PyTorch, and 214.93 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
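The figures in this error message explain the failure. The sketch below does the arithmetic using numbers copied from the error and from the `Total estimated model params size (MB)` line of the earlier training log. It assumes that estimate (4 bytes per parameter, 1 MB = 10^6 bytes) reflects how the weights are actually held in memory, which may not be the case.

```python
GIB = 1024 ** 3

# Copied from the error message.
gpu_total_gib = 79.18
other_process_gib = 33.05   # memory held by a different process on GPU 0
this_process_gib = 46.07    # memory held by this notebook when it failed

# Copied from the earlier training log; ASSUMPTION: 1 MB = 1e6 bytes.
estimated_model_mb = 64_285.757
estimated_model_gib = estimated_model_mb * 1e6 / GIB

available_gib = gpu_total_gib - other_process_gib
print(f"estimated model size:    {estimated_model_gib:.2f} GiB")
print(f"available to this run:   {available_gib:.2f} GiB")
print(f"fits on the shared GPU:  {estimated_model_gib < available_gib}")
```

Under these assumptions the weights alone need more memory than the other process leaves free, so the allocation fails before training starts. Freeing GPU 0 (for example by stopping the other process) or running on a GPU that is not shared would be the first things to try.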