# AIDO.Protein-RAG-16B

[AIDO.Protein-RAG-16B](https://huggingface.co/genbio-ai/AIDO.Protein-RAG-16B) is a multimodal protein language model that integrates multiple sequence alignments (MSAs) and structural data. It builds on the AIDO.Protein-16B foundation model. Training proceeds in three stages:

1. Fine-tuning from 1D to 2D RoPE position encoding (10 billion tokens from ColabFoldDB and UniRef)
2. Training on 100 billion tokens of UniRef50/UniClust30 MSA data
3. Training on 80 billion tokens of AlphaFold Database (AFDB) MSA and structural data
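Stage (1) names the technique but does not describe it. A common way to extend rotary position embeddings (RoPE) to two dimensions is to split each feature vector in half. One half is rotated by the residue (column) position and the other half by the MSA row position. The sketch below assumes that split. It is a generic construction and not AIDO.Protein-RAG's actual implementation, and the names `rope_1d` and `rope_2d` are hypothetical.

```python
import numpy as np


def rope_1d(x: np.ndarray, pos: float, base: float = 10000.0) -> np.ndarray:
    """Standard RoPE: rotate each (even, odd) feature pair of x by pos * freq."""
    d = x.shape[-1]
    assert d % 2 == 0, "feature dim must be even"
    freqs = base ** (-np.arange(0, d, 2) / d)  # one frequency per feature pair
    angles = pos * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x, dtype=float)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out


def rope_2d(x: np.ndarray, row: int, col: int) -> np.ndarray:
    """Hypothetical 2D RoPE for one MSA cell: the first half of the features
    encodes the column (residue) index, the second half the MSA row index."""
    half = x.shape[-1] // 2
    assert half % 2 == 0, "feature dim must be divisible by 4"
    return np.concatenate(
        [rope_1d(x[..., :half], col), rope_1d(x[..., half:], row)], axis=-1
    )


# Query-key scores depend only on the relative (row, column) offset.
rng = np.random.default_rng(0)
q, k = rng.normal(size=8), rng.normal(size=8)
s1 = rope_2d(q, row=1, col=3) @ rope_2d(k, row=4, col=5)
s2 = rope_2d(q, row=2, col=10) @ rope_2d(k, row=5, col=12)
print(np.isclose(s1, s2))  # True: both pairs are offset by 3 rows and 2 columns
```

Because each rotation is orthogonal, vector norms are preserved. The dot product between rotated queries and keys sees only position differences, which is what allows a model trained on 1D positions to be adapted to a 2D (row, column) layout.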

<img src="images/rag_1.png" alt="AIDO.Protein-RAG" width="300" style="background-color:white;"/>

<img src="images/rag_2.png" alt="AIDO.Protein-RAG" width="400" style="background-color:white;"/>

| Hyperparameter              | (1) 1D -> 2D RoPE fine-tuning | (2) UniRef50/UniClust30 MSA fine-tuning | (3) AFDB MSA & structure token fine-tuning |
| --------------------------- | :---------------------------: | :-------------------------------------: | :----------------------------------------: |
| Initialized from            |       AIDO.Protein-16B        |                Stage (1)                |                 Stage (2)                  |
| Data                        |      ColabFoldDB, UniRef      |       HHblits_MSA, Retriever_MSA        |        AFDB MSA & structure tokens         |
| Global batch size           |              512              |                   256                   |                    256                     |
| Sequence length             |             2048              |                  12800                  |                   12800                    |
| Per-device micro batch size |               1               |                    1                    |                     1                      |
| Precision                   |        Mixed FP32-FP16        |             Mixed FP32-FP16             |              Mixed FP32-FP16               |
| Learning rate               |         [5e-6, 5e-5]          |              [1e-6, 1e-5]               |                    1e-5                    |
| Training tokens             |          10 billion           |               100 billion               |                 80 billion                 |
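The table implies a rough training length for each stage. Each optimizer step processes at most global batch size × sequence length tokens, so dividing the token budget by that product gives a lower bound on the number of optimizer steps. The sketch below assumes every sample is packed to the full sequence length. Padding or shorter samples would mean more steps. The helper name is hypothetical.

```python
def min_optimizer_steps(global_batch: int, seq_len: int, num_tokens: int) -> int:
    """Lower bound on optimizer steps, assuming fully packed sequences."""
    tokens_per_step = global_batch * seq_len
    return -(-num_tokens // tokens_per_step)  # integer ceiling division


# (global batch size, sequence length, training tokens) per stage, from the table
stages = {
    "(1) 1D -> 2D RoPE": (512, 2048, 10 * 10**9),
    "(2) UniRef50/UniClust30 MSA": (256, 12800, 100 * 10**9),
    "(3) AFDB MSA & structure": (256, 12800, 80 * 10**9),
}

for name, (batch, seq_len, tokens) in stages.items():
    print(f"{name}: {batch * seq_len:,} tokens/step, "
          f">= {min_optimizer_steps(batch, seq_len, tokens):,} steps")
```

Stages (2) and (3) halve the global batch size but use a 6.25× longer context. Each of their steps therefore sees about 3.1× as many tokens as a stage (1) step.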


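## Creating the data module step by step in a Jupyter notebook

The YAML `data` block below configures `DMSFitnessPrediction`. The Python cell that follows builds the same object directly.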

```yaml
data:
  class_path: modelgenerator.data.DMSFitnessPrediction
  init_args:
    path: genbio-ai/ProteinGYM-DMS-RAG
    is_rag_dataset: true
    train_split_files: null
    config_name: NCAP_I34A1_Doud_2015
    x_col: sequences
    extra_cols:
    - msa
    - str_emb
    extra_col_aliases:
    - msa
    - str_emb
    train_split_name: 'train'
    random_seed: 42
    batch_size: 1
    cv_num_folds: 5
    cv_test_fold_id: 0
    cv_enable_val_fold: false
    cv_replace_val_fold_as_test_fold: true
    cv_fold_id_col: fold_id
    msa_random_seed: 1
    max_context_length: 6400
```
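The YAML block and the Python constructor call express the same thing. `class_path` names the class, and `init_args` supplies its keyword arguments. The `mgen` CLI appears to follow the LightningCLI convention for this. The `instantiate` helper below is a hypothetical, simplified sketch of that mapping, not ModelGenerator's code. To keep it runnable without ModelGenerator installed, it is demonstrated with a standard-library class.

```python
import importlib
from typing import Any


def instantiate(spec: dict[str, Any]) -> Any:
    """Build an object from a {class_path, init_args} mapping (illustrative only)."""
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("init_args", {}))


# Same shape as the `data:` block above, with a stand-in class from the stdlib.
spec = {"class_path": "fractions.Fraction", "init_args": {"numerator": 3, "denominator": 4}}
print(instantiate(spec))  # 3/4
```

With ModelGenerator installed, the `data` mapping would in principle resolve to `modelgenerator.data.DMSFitnessPrediction(**init_args)`, matching the explicit call in the next cell.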

The examples below use version 3.0.0 of the Hugging Face `datasets` library: `pip install datasets==3.0.0`

```python
import os
os.environ["HF_HOME"] = "/tmp/hf_cache"

from modelgenerator.data import DMSFitnessPrediction

datamodule = DMSFitnessPrediction(
    path="genbio-ai/ProteinGYM-DMS-RAG",
    is_rag_dataset=True,
    train_split_files=None,
    config_name="NCAP_I34A1_Doud_2015",
    x_col="sequences",
    extra_cols=["msa", "str_emb"],
    extra_col_aliases=["msa", "str_emb"],
    train_split_name="train",
    random_seed=42,
    batch_size=1,
    cv_num_folds=5,
    cv_test_fold_id=0,
    cv_enable_val_fold=False,
    cv_replace_val_fold_as_test_fold=True,
    cv_fold_id_col="fold_id",
    msa_random_seed=1,
    max_context_length=6400,
)

datamodule.setup()
```

```
Repo card metadata block was not found. Setting CardData to empty.
label: mean = [-2.74618815], std = [1.2256391]
```
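The `label: mean = ..., std = ...` line printed by `setup()` suggests that the data module computes label statistics, presumably to standardize the regression targets. The sketch below assumes a plain z-score transform, which should be checked against the ModelGenerator source. It shows how to convert between raw and standardized fitness values. The constant and function names are hypothetical.

```python
import numpy as np

# Statistics printed by datamodule.setup() above
LABEL_MEAN = np.array([-2.74618815])
LABEL_STD = np.array([1.2256391])


def standardize(y: np.ndarray) -> np.ndarray:
    """Raw fitness -> zero-mean, unit-variance target (assumed transform)."""
    return (y - LABEL_MEAN) / LABEL_STD


def destandardize(z: np.ndarray) -> np.ndarray:
    """Prediction in standardized units -> raw fitness scale."""
    return z * LABEL_STD + LABEL_MEAN


raw = np.array([-2.74618815, 0.0, -5.0])
z = standardize(raw)
print(z[0])                                # 0.0: the mean maps to zero
print(np.allclose(destandardize(z), raw))  # True
```

Rank-based metrics such as Spearman correlation are unaffected by this affine transform. Mapping predictions back matters only when raw fitness values are needed.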


```python
print(len(datamodule.train_dataset))
print(len(datamodule.val_dataset))
print(len(datamodule.test_dataset))
```

```
7540
1922
1922
```
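The configuration uses `cv_num_folds=5`, `cv_test_fold_id=0`, `cv_enable_val_fold=false` and `cv_replace_val_fold_as_test_fold=true`. The printed sizes (7540 + 1922 = 9462 variants, with validation and test both 1922) fit one reading of these flags. Under that reading, fold 0 is held out as the test set, the other four folds form the training set, and the validation split reuses the test fold. The sketch below illustrates that reading with a hypothetical `fold_ids` array. It is inferred from the flag names and outputs, not taken from ModelGenerator.

```python
import numpy as np


def split_by_fold(fold_ids: np.ndarray, test_fold: int, replace_val_with_test: bool = True):
    """Return (train, val, test) index arrays for one cross-validation fold."""
    test_idx = np.flatnonzero(fold_ids == test_fold)
    train_idx = np.flatnonzero(fold_ids != test_fold)
    # No separate validation fold: optionally validate on the held-out test fold.
    val_idx = test_idx if replace_val_with_test else np.array([], dtype=int)
    return train_idx, val_idx, test_idx


# Hypothetical fold assignment for 10 variants spread over 5 folds
fold_ids = np.array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
train, val, test = split_by_fold(fold_ids, test_fold=0)
print(len(train), len(val), len(test))  # 8 2 2
```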


```python
sample = datamodule.train_dataset[0]
print_dict(sample)  # print_dict is a notebook helper not defined in this section
```

```
sequences: type=builtins.str, v[:20]=LASQGTKRSYEQMETDGERQ
labels: type=numpy.ndarray, v.dtype=float64, v.shape=(1,)
msa: type=builtins.list, len(v)=91
str_emb: type=numpy.ndarray, v.dtype=float32, v.shape=(498, 384)
```
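`print_dict` is not defined in this section. A minimal, hypothetical helper that prints a similar per-field summary, without the terminal colors, could look like this:

```python
import numpy as np


def format_dict(d: dict, preview: int = 20) -> list[str]:
    """One summary line per entry: fully qualified type plus a short preview."""
    lines = []
    for key, v in d.items():
        type_name = f"{type(v).__module__}.{type(v).__qualname__}"
        if isinstance(v, np.ndarray):
            detail = f"v.dtype={v.dtype}, v.shape={v.shape}"
        elif isinstance(v, str):
            detail = f"v[:{preview}]={v[:preview]}"
        elif isinstance(v, (list, tuple)):
            detail = f"len(v)={len(v)}"
        else:
            detail = f"v={v!r}"
        lines.append(f"{key}: type={type_name}, {detail}")
    return lines


def print_dict(d: dict, preview: int = 20) -> None:
    print("\n".join(format_dict(d, preview)))


print_dict({"sequences": "LASQGTKRSYEQMETDGERQNATE", "labels": np.array([1.7])})
# sequences: type=builtins.str, v[:20]=LASQGTKRSYEQMETDGERQ
# labels: type=numpy.ndarray, v.dtype=float64, v.shape=(1,)
```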


```python
sample
```

```
{'sequences': 'LASQGTKRSYEQMETDGERQNATEIRASVGKMIGGIGRFYIQMCTELKLSDYEGRLIQNSLTIERMVLSAFDERRNKYLEEHPSAGKDPKKTGGPIYRRVNGKWMRELILYDKEEIRRIWRQANNGDDATAGLTHMMIWHSNLNDATYQRTRALVRTGMDPRMCSLMQGSTLPRRSGAAGAAVKGVGTMVMELVRMIKRGINDRNFWRGENGRKTRIAYERMCNILKGKFQTAAQKAMMDQVRESRNPGNAEFEDLTFLARSALILRGSVAHKSCLPACVYGPAVASGYDFEREGYSLVGIDPFRLLQNSQVYSLIRPNENPAHKSQLVWMACHSAAFEDLRVLSFIKGTKVLPRGKLSTRGVQIASNENMETMESSTLELRSRYWAIRTRSGGNTNQQRASAGQISIQPTFSVQRNLPFDRTTIMAAFNGNTEGRTSDMRTEIIRMMESARPEDVSFQGRGVFELSDEKAASPIVPSFDMSNEGSYFFGDNAEEYDN',
 'labels': array([1.72476094]),
 'msa': ['MASQGTKRSYEQMETDGERQNATEIRASVGKMIGGIGRFYIQMCTELKLSDYEGRLIQNSLTIERMVLSAFDERRNKYLEEHPSAGKDPKKTGGPIYRRVNGKWMRELILYDKEEIRRIWRQANNGDDATAGLTHMMIWHSNLNDATYQRTRALVRTGMDPRMCSLMQGSTLPRRSGAAGAAVKGVGTMVMELVRMIKRGINDRNFWRGENGRKTRIAYERMCNILKGKFQTAAQKAMMDQVRESRNPGNAEFEDLTFLARSALILRGSVAHKSCLPACVYGPAVASGYDFEREGYSLVGIDPFRLLQNSQVYSLIRPNENPAHKSQLVWMACHSAAFEDLRVLSFIKGTKVLPRGKLSTRGVQIASNENMETMESSTLELRSRYWAIRTRSGGNTNQQRASAGQISIQPTFSVQRNLPFDRTTIMAAFNGNTEGRTSDMRT
...
```
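This section does not describe how the data module fits 91 aligned rows into `max_context_length: 6400`. If each row costs one token per residue, 91 rows of a few hundred residues would far exceed that budget. The presence of `msa_random_seed` suggests that rows are subsampled. The sketch below is a hypothetical subsampling strategy under those assumptions, not ModelGenerator's implementation. It keeps the query first, then adds rows in a seeded random order until the next row would exceed the budget. The toy rows are 498 residues long, matching the 498 rows of `str_emb` in the sample above.

```python
import random


def sample_msa_rows(query: str, msa: list[str], max_context_length: int, seed: int = 1) -> list[str]:
    """Hypothetical: keep the query, then add randomly ordered MSA rows until the
    next row would push the residue count (assumed 1 token/residue) past the budget."""
    rows = [query]
    used = len(query)
    rng = random.Random(seed)
    for row in rng.sample(msa, k=len(msa)):  # shuffled copy; the input list is untouched
        if used + len(row) > max_context_length:
            break
        rows.append(row)
        used += len(row)
    return rows


# Hypothetical toy MSA: 90 homolog rows, each 498 residues long
query = "M" * 498
msa = ["A" * 498 for _ in range(90)]
picked = sample_msa_rows(query, msa, max_context_length=6400)
print(len(picked), sum(map(len, picked)))  # 12 5976
```

Under these assumptions, only the query plus 11 aligned rows (12 × 498 = 5976 tokens) fit in 6400 tokens. A fixed seed makes the selection reproducible across runs.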

## Training from the command line

Config files:
* [substitution_LoRA_DDP.yaml](../ModelGenerator/experiments/AIDO.Protein-RAG/DMS_RAG/configs/substitution_LoRA_DDP.yaml)
* [wandb.yaml](../ModelGenerator/experiments/AIDO.Protein-RAG/DMS_RAG/configs/wandb.yaml)

```bash
export HF_HOME=/tmp/hf_cache

mgen fit \
    --config experiments/AIDO.Protein-RAG/DMS_RAG/configs/substitution_LoRA_DDP.yaml \
    --config experiments/AIDO.Protein-RAG/DMS_RAG/configs/wandb.yaml \
    --trainer.logger.project AIDO_Protein_DMS_LoRA_DDP \
    --trainer.logger.name Q2N0S5_9HIV1_Haddox_2018_fold0 \
    --trainer.logger.id Q2N0S5_9HIV1_Haddox_2018_fold0_98hto81p \
    --trainer.default_root_dir logs/DMS_Benchmark \
    --trainer.logger.save_dir logs/DMS_Benchmark \
    --data.train_split_files "[singles_substitutions/Q2N0S5_9HIV1_Haddox_2018.tsv]" \
    --data.cv_test_fold_id 0 \
    --trainer.precision bf16-mixed \
    --trainer.accumulate_grad_batches 1 \
    --trainer.num_nodes 1 \
    --data.batch_size 1 \
    --model.init_args.backbone.init_args.config_overwrites.str_embedding_in 384 \
    --trainer.callbacks.patience 10
```
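The command above trains on fold 0 of a single DMS assay. To cover all five folds, one option is to vary `--data.cv_test_fold_id` and the logger name in a loop. The sketch below only builds the argument lists, using the flags from the command above. It omits `--trainer.logger.id` on the assumption that the logger can generate its own run ID. The `fold_command` helper is hypothetical. Each list could be launched with `subprocess.run(cmd, check=True)` after exporting `HF_HOME` as shown.

```python
import shlex

CONFIG_DIR = "experiments/AIDO.Protein-RAG/DMS_RAG/configs"


def fold_command(dataset: str, fold: int) -> list[str]:
    """Build the `mgen fit` argument list for one cross-validation fold."""
    return [
        "mgen", "fit",
        "--config", f"{CONFIG_DIR}/substitution_LoRA_DDP.yaml",
        "--config", f"{CONFIG_DIR}/wandb.yaml",
        "--trainer.logger.project", "AIDO_Protein_DMS_LoRA_DDP",
        "--trainer.logger.name", f"{dataset}_fold{fold}",
        "--trainer.default_root_dir", "logs/DMS_Benchmark",
        "--trainer.logger.save_dir", "logs/DMS_Benchmark",
        "--data.train_split_files", f"[singles_substitutions/{dataset}.tsv]",
        "--data.cv_test_fold_id", str(fold),
        "--trainer.precision", "bf16-mixed",
        "--trainer.accumulate_grad_batches", "1",
        "--trainer.num_nodes", "1",
        "--data.batch_size", "1",
        "--model.init_args.backbone.init_args.config_overwrites.str_embedding_in", "384",
        "--trainer.callbacks.patience", "10",
    ]


for fold in range(5):  # cv_num_folds is 5 in the data config
    print(shlex.join(fold_command("Q2N0S5_9HIV1_Haddox_2018", fold)))
```

Passing a list instead of a shell string avoids quoting problems with the bracketed `--data.train_split_files` value. `shlex.join` is used only to print a copy-pasteable form.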

<img src="images/rag_dms_curve.png" alt="AIDO.Protein-RAG curve" width="80%" style="background-color:white;"/>