# AIDO.Protein-RAG-16B

[AIDO.Protein-RAG-16B](https://huggingface.co/genbio-ai/AIDO.Protein-RAG-16B) is a multimodal protein language model that integrates Multiple Sequence Alignment (MSA) and structural data, building upon the AIDO.Protein-16B foundation. The training process comprises three main stages:

1. 2D RoPE encoding fine-tuning
2. Initial training on 100 billion tokens from UniRef50/UniClust30 MSA data
3. Subsequent training on 80 billion tokens from AlphaFold Database MSA and structural data

<img src="images/rag_1.png" alt="AIDO.Protein-RAG" width="300" style="background-color:white;"/>

<img src="images/rag_2.png" alt="AIDO.Protein-RAG" width="400" style="background-color:white;"/>

| Hyperparameter              | (1) 1D -> 2D finetuning | (2) UniRef50/UniClust30 MSA finetuning | (3) AFDB MSA & Structure tokens finetuning |
| --------------------------- | :---------------------: | :------------------------------------: | :----------------------------------------: |
| Initialized parameters      |   AIDO.Protein-16B      |       Stage (1)                        |                      Stage (2)             |
| Data                        |   ColabFoldDB, UniRef   |       HHblits_MSA, Retriever_MSA       |        AFDB MSA & Structure tokens         |
| Global Batch Size           |           512           |                  256                   |                    256                     |
| Sequence length             |          2048           |                 12800                  |                   12800                    |
| Per Device Micro Batch Size |            1            |                   1                    |                     1                      |
| Precision                   |     Mixed FP32-FP16     |            Mixed FP32-FP16             |              Mixed FP32-FP16               |
| LR                          |       [5e-6,5e-5]       |              [1e-6, 1e-5]              |                    1e-5                    |
| Num Tokens                  |       10 billion        |              100 billion               |                 80 billion                 |
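The table gives the global batch size, the sequence length and the total token budget for each stage. From these you can estimate the scale of each stage. The sketch below is a rough estimate only. It assumes every sample is packed to the full sequence length, which real runs may not do. The step counts it prints are derived numbers and are not stated in the source. `approx_steps` is a hypothetical helper.

```python
def approx_steps(global_batch_size, seq_len, num_tokens):
    """Rough optimizer-step count, assuming every sample fills the whole sequence length."""
    tokens_per_step = global_batch_size * seq_len
    return tokens_per_step, num_tokens / tokens_per_step

# (stage, global batch size, sequence length, number of tokens) from the table above
for stage, bs, seq_len, n_tok in [
    ("(1) 1D -> 2D", 512, 2048, 10e9),
    ("(2) UniRef50/UniClust30 MSA", 256, 12800, 100e9),
    ("(3) AFDB MSA & Structure", 256, 12800, 80e9),
]:
    per_step, steps = approx_steps(bs, seq_len, n_tok)
    print(f"{stage}: {per_step:,} tokens/step, ~{steps:,.0f} steps")
# (1) 1D -> 2D: 1,048,576 tokens/step, ~9,537 steps
# (2) UniRef50/UniClust30 MSA: 3,276,800 tokens/step, ~30,518 steps
# (3) AFDB MSA & Structure: 3,276,800 tokens/step, ~24,414 steps
```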

Reference: [Retrieval Augmented Protein Language Models for Protein Structure Prediction](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1)

## Step-by-Step Example

This section shows how to manually load the model and tokenizer, preprocess the input MSA and PDB files, and obtain the protein embedding.

```python
import os
import pathlib

import numpy as np
import torch
import torch.nn.functional as F

# Set the Hugging Face cache directory before modelgenerator (and its HF dependencies) is imported
os.environ['HF_HOME'] = '/tmp/hf_cache'

import modelgenerator
from modelgenerator.huggingface_models.fm4bio import FM4BioForMaskedLM
from modelgenerator.huggingface_models.fm4bio import FM4BioTokenizer

# Locate the installed package (used below to find the protein vocabulary file)
modelgenerator_path = str(pathlib.Path(modelgenerator.__file__).parent)
print(f"modelgenerator_path: {modelgenerator_path}")

# Helper utilities used in this tutorial: MSA loading/selection (misc) and PDB parsing (protein)
from utils import misc
from utils import protein
```

```
modelgenerator_path: /jfs/pan-li/Demo/ModelGenerator/modelgenerator
```


#### Load model

```python
# Load the pretrained weights and move the model to the GPU in inference mode
model = FM4BioForMaskedLM.from_pretrained("genbio-ai/AIDO.Protein-RAG-16B")
model = model.cuda().eval()
print(model)
```

```
Loading checkpoint shards: 100%|██████████| 13/13 [00:08<00:00,  1.59it/s]
Some weights of FM4BioForMaskedLM were not initialized from the model checkpoint at genbio-ai/AIDO.Protein-RAG-16B and are newly initialized: ['output_embed.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

FM4BioForMaskedLM(
  (bert): FM4BioModel(
    (embeddings): FM4BioEmbeddings(
      (word_embeddings): Embedding(640, 2304, padding_idx=0)
      (str_embeddings): Linear(in_features=384, out_features=2304, bias=False)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): FM4BioEncoder(
      (layer): ModuleList(
        (0-35): 36 x FM4BioLayer(
          (attention): FM4BioAttention(
            (ln): RnaRMSNorm()
            (self): FM4BioSelfAttention(
              (query): Linear(in_features=2304, out_features=2304, bias=True)
              (key): Linear(in_features=2304, out_features=2304, bias=True)
              (value): Linear(in_features=2304, out_features=2304, bias=True)
              (dropout): Dropout(p=0, inplace=False)
            )
            (output): FM4BioSelfOutput(
              (dense): Linear(in_features=2304, out_features=2304, bias=True)
              (dropout): Dropout(p=0, inplace=False)
            )
          )
          (ln): RnaRMSNorm()
...
```

The warning lists only `output_embed.weight` as newly initialized. All other weights, including the encoder whose hidden states are used below, are loaded from the checkpoint.

#### Load tokenizer

```python
# The protein vocabulary file ships with the modelgenerator package
vocab_file = os.path.join(modelgenerator_path, "huggingface_models/fm4bio/vocab_protein.txt")
tokenizer = FM4BioTokenizer(vocab_file=vocab_file)
```

#### Load MSA and preprocess

```python
# Load the MSA; the first row is the query sequence, the remaining rows are its homologs
msa = misc.load_msa_txt("sample_data/MK01_HUMAN_Brenan_2016.txt.gz")
print(f"Number of MSA sequences: {len(msa)}")
for seq in msa[:5]:
    print(seq)

query_sequence = msa[0]
print(f"length of query: {len(query_sequence)}")
msa = msa[1:]  # keep only the homologs
```

```
Number of MSA sequences: 82905
MAAAAAAGAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNVNKVRVAIKKISPFEHQTYCQRTLREIKILLRFRHENIIGINDIIRAPTIEQMKDVYIVQDLMETDLYKLLKTQHLSNDHICYFLYQILRGLKYIHSANVLHRDLKPSNLLLNTTCDLKICDFGLARVADPDHDHTGFLTEYVATRWYRAPEIMLNSKGYTKSIDIWSVGCILAEMLSNRPIFPGKHYLDQLNHILGILGSPSQEDLNCIINLKARNYLLSLPHKNKVPWNRLFPNADSKALDLLDKMLTFNPHKRIEVEQALAHPYLEQYYDPSDEPIAEAPFKFDMELDDLPKEKLKELIFEETARFQPGYRS
-------------VRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNVNKVRVAIKKISPFEHQTYCQRTLREIKILLRFRHENIIGINDIIRAPTIEQMKDVYIVQDLMETDLYKLLKTQHLSNDHICYFLYQILRGLKYIHSANVLHRDLKPSNLLLNTTCDLKICDFGLARVADPDHDHTGFLTEYVATRWYRAPEIMLNSKGYTKSIDIWSVGCILAEMLSNRPIFPGKHYLDQLNHILGILGSPSQEDLNCIINLKARNYLLSLPHKNKVPWNRLFPNADSKALDLLDKMLTFNPHKRIEVEQALAHPYLEQYYDPSDEPIAEAPFKFDMELDDLPKEKLKELIFEETARFQPG---
-------------VRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNINKVRVAIKKISPFEHQTYCQRTLREIKILLRFRHENIIGINDIIRAPTIEQMKDVYIVQDLMETDLYKLLKTQHLSNDHICYFLYQILRGLKYIHSANVLHRDLKPSNLLLNTTCDLKICDFGLARVADPDHDHTGFLTEYVATRWYRAPEIMLNSKGYTKSIDIWSVGCILAEMLSNRPIFPGKHYLDQLNHILGILGSP
```

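Concatenating every MSA row would give 82,905 $\times$ 360 = 29,845,800 tokens (gaps included), far more than the model's 12,800-token context. We therefore use a greedy selection strategy to pick a subset of the MSA within a 12,800-token budget. The selection keeps the chosen sequences as dissimilar as possible, maximizing the edit distance between them.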
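To make the idea of greedy diverse selection concrete, here is a minimal sketch under stated assumptions. It is **not** the implementation of `misc.greedy_select`. It uses Hamming distance between aligned rows and max-min (farthest-point) selection, which is one common variant of the technique. Like the notebook, it charges only non-gap characters against the token budget. `hamming` and `greedy_diverse_subset` are hypothetical names.

```python
import random

def hamming(a, b):
    """Number of aligned columns at which two equal-length MSA rows differ."""
    return sum(x != y for x, y in zip(a, b))

def greedy_diverse_subset(msa, num_tokens, seed=0):
    """Hypothetical max-min (farthest-point) selection under a non-gap token budget."""
    if not msa:
        return []
    rng = random.Random(seed)
    remaining = list(msa)
    first = remaining.pop(rng.randrange(len(remaining)))  # random starting row
    selected = [first]
    budget = num_tokens - (len(first) - first.count('-'))
    # Distance from each remaining row to its closest already-selected row
    min_dist = [hamming(row, first) for row in remaining]
    while remaining:
        # Pick the row farthest from everything selected so far
        i = max(range(len(remaining)), key=min_dist.__getitem__)
        # Gaps are removed at tokenization time, so only residues count against the budget
        cost = len(remaining[i]) - remaining[i].count('-')
        if cost > budget:
            break
        cand = remaining.pop(i)
        min_dist.pop(i)
        selected.append(cand)
        budget -= cost
        min_dist = [min(d, hamming(row, cand)) for d, row in zip(min_dist, remaining)]
    return selected
```

This pure-Python version rescans all remaining rows after every pick. That is fine for small inputs, but an MSA of about 83k rows would need a vectorized (e.g. NumPy) distance computation.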

```python
# Greedily pick a diverse subset of MSA rows within a 12,800-token budget
f_msa = misc.greedy_select(msa, num_seqs=None, num_tokens=12_800, seed=0)
# Order the selected rows from fewest to most gaps
f_msa.sort(key=lambda x: x.count('-'))

print(f"Number of MSA sequences: {len(f_msa)}")
for seq in f_msa[:5]:
    print(seq)

num_tokens = sum(len(seq) - seq.count('-') for seq in f_msa)
print(f"Number of tokens in MSA (excluding gaps): {num_tokens}")
```

```
Number of MSA sequences: 43
---------NKQTVQNHSFDTIDKYKVTEIIGSGTYGVVAICKELQTDQQFALKKNIVFPDENHQLRMLRELKMLHHFRCPYIVNLKDVYVPNQLNQLRDIEMITSLMEADLRDIFDSQSLSPKHVKWFMYQICLSVYYMQKAKILHRDLKPENILVNSQCDVAICDFGLARGYYNSLQNKKLSSNYVVSRWYRPPELLTNATQYNKTLDMWSVGCIMYELLKGEVLFKGSGSIDQIQRIIKQLGTPAIDDFNGSEAARDYIY-NKFPICKRRSFTKRLPNTCPMAIDLMKRMLTFNMYKRINPLDALLHPYFREFYDQSDLNIPLSPFDTAWEENMLSSNDLKLEAFNTLKSIKK----
----------VYKVRGQSFDIEDTYTVTSVVGHGAYGVVCAALDDRTFQEVAIKRVSVFEDLIDGRRIWREILILRLLRCRNMLRLLRVLPPKPITEFRDLYMVTDLFDTDLFAIIRQKNMSTDMLRRVGARVLQCLADMHTMGIVHRDIKPSNILLRDEENATVCDFGLARAGLLDLTEPLDLTDYVVTRWYRPPELLLM-CSYSFPIDMWAVGCVMAEYVTQRPIFAGRDYIHQLQLVLASVNITGTSFMESTSASAINHMNDVARKYGTRPLSNLLAALPKEGFDLVNRMLAFEPDNRITALEALQHPFFEPLALEEPARTLSPAVELSFDMAEISEYQLRRAIWDEVEHYKKQ---
---------NTVEHSLTQFTVPKRYQNLSPFEYNSHDIVTYATDTNTGKKVTIRKILPFDSVARAHRTYRELKLRIHLNDAQVAQLYDVFTPEDLNNFETLYLVENYVEYDLKRVIYSVVVTEDHIKMVIYCLLRGLKFIHSAGIIHSRIISSNIGIDKDSNVSIFGWDSAATAHIRRKYDEYNESDIYRRWY-APEMIINPEHCNEKVDIWSVGCIMAELIVRQPLFPGTSQSDQLTKIFDITGTPDSK
```

#### Load Structure Tokenizer, PDB file and get structure embeddings

We use [genbio-ai/AIDO.StructureTokenizer](https://huggingface.co/genbio-ai/AIDO.StructureTokenizer) to convert the protein structure into discrete structure tokens and continuous structure embeddings.

<img src="images/aido_structure.png" alt="aido_structure" width="800" style="background-color:white;"/>

```python
# Load the structure tokenizer on the first GPU
str_tokenizer = misc.AIDO_Structure_Tokenizer(device='cuda:0')
```

```python
# Parse the PDB file of the query protein
with open("sample_data/MK01_HUMAN_Brenan_2016.pdb") as f:
    text = f.read()

prot = protein.from_pdb_string(text, molecular_type='protein')
prot
```

```
Molecular weight: 38.433KDa
A (360): MAAAAAAGAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNVNKVRVAIKKISPFEHQTYCQRTLREIKILLRFRHENIIGINDIIRAPTIEQMKDVYIVQDLMETDLYKLLKTQHLSNDHICYFLYQILRGLKYIHSANVLHRDLKPSNLLLNTTCDLKICDFGLARVADPDHDHTGFLTEYVATRWYRAPEIMLNSKGYTKSIDIWSVGCILAEMLSNRPIFPGKHYLDQLNHILGILGSPSQEDLNCIINLKARNYLLSLPHKNKVPWNRLFPNADSKALDLLDKMLTFNPHKRIEVEQALAHPYLEQYYDPSDEPIAEAPFKFDMELDDLPKEKLKELIFEETARFQPGYRS
```

```python
# Per-residue amino-acid types, atom coordinates and atom masks
print(f"prot.aatype: {prot.aatype.shape}\n{prot.aatype}")
print(f"prot.atom_positions: {prot.atom_positions.shape}\n{prot.atom_positions}")
print(f"prot.atom_mask: {prot.atom_mask.shape}\n{prot.atom_mask}")
```

```
prot.aatype: (360,)
[12  0  0  0  0  0  0  7  0  7 14  6 12 19  1  7  5 19 13  3 19  7 14  1
 18 16  2 10 15 18  9  7  6  7  0 18  7 12 19  4 15  0 18  3  2 19  2 11
 19  1 19  0  9 11 11  9 15 14 13  6  8  5 16 18  4  5  1 16 10  1  6  9
 11  9 10 10  1 13  1  8  6  2  9  9  7  9  2  3  9  9  1  0 14 16  9  6
  5 12 11  3 19 18  9 19  5  3 10 12  6 16  3 10 18 11 10 10 11 16  5  8
 10 15  2  3  8  9  4 18 13 10 18  5  9 10  1  7 10 11 18  9  8 15  0  2
 19 10  8  1  3 10 11 14 15  2 10 10 10  2 16 16  4  3 10 11  9  4  3 13
  7 10  0  1 19  0  3 14  3  8  3  8 16  7 13 10 16  6 18 19  0 16  1 17
 18  1  0 14  6  9 12 10  2 15 11  7 18 16 11 15  9  3  9 17 15 19  7  4
  9 10  0  6 12 10 15  2  1 14  9 13 14  7 11  8 18 10  3  5 10  2  8  9
 10  7  9 10  7 15 14 15  5  6  3 10  2  4  9  9  2 10 11  0  1  2 18 10
 10 15 10 14  8 11  2 11 19 14 17  2  1 10 13 14  2  0  3 15 11  0 10  3
 10 10  3 11 12 10 16 13  2 14  8 11  1  9  6 19  6  5  0 10  0  8 14 18
 10  6  5 18 18  3 14 15  3  6 
```
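`prot.aatype` stores one integer per residue. Compare the printed indices with the sequence: M → 12, A → 0, G → 7, P → 14, E → 6. They follow the same 20-letter order as the `restypes` string used later in this document, `'ARNDCQEGHILKMFPSTWYV'`. The sketch below decodes the indices back to letters. You can use it to check that the PDB structure matches the MSA query before combining their embeddings. The helper name `aatype_to_sequence` is hypothetical, and mapping indices ≥ 20 to `'X'` is an assumption.

```python
import numpy as np

# Residue order inferred from the aatype output above (M -> 12, A -> 0, G -> 7, ...)
RESTYPES = 'ARNDCQEGHILKMFPSTWYV'

def aatype_to_sequence(aatype):
    """Convert integer residue types to a one-letter sequence ('X' for indices >= 20, assumed)."""
    return ''.join(RESTYPES[i] if 0 <= i < len(RESTYPES) else 'X'
                   for i in np.asarray(aatype).tolist())
```

In the notebook, `assert aatype_to_sequence(prot.aatype) == query_sequence` would confirm that the structure and the query describe the same 360 residues.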

```python
# Encode the structure into per-residue embeddings (384-dim) and discrete structure tokens
str_embs, str_toks = str_tokenizer.encode(prot.aatype, prot.atom_positions, prot.atom_mask, get_embedding=True)
str_embs = str_embs.cuda()
str_toks = str_toks.cuda()
print(str_embs.shape)
print(str_toks.shape)
```

```
torch.Size([360, 384])
torch.Size([360])
```


#### Tokenize the protein sequences

<img src="images/rag_1.png" alt="AIDO.Protein-RAG" width="300" style="background-color:white;"/>

<img src="images/rag_2.png" alt="AIDO.Protein-RAG" width="400" style="background-color:white;"/>

```python
def tokenize(q_seq, msa, tokenizer, max_context=12800):
    """
    Tokenize a query sequence together with its aligned MSA rows.

    Args:
        q_seq (str): The query sequence.
        msa (list[str] or None): Aligned MSA sequences, each the same length as q_seq.
            If None, only the query is tokenized.
        tokenizer: The FM4BioTokenizer used to encode the sequences.
        max_context (int, optional): Maximum number of tokens to keep. Defaults to 12800.

    Returns:
        tuple:
            - tokens (np.ndarray): Token ids of all non-gap positions, shape (T,).
            - pos_encoding (np.ndarray): 2D positions, shape (2, T); row 0 is the residue
              (column) index, row 1 is the MSA row index (0 = query).
    """
    len_seq = len(q_seq)
    tokens = tokenizer.encode(q_seq, add_special_tokens=False)
    num_seq = 1

    for msa_seq in (msa or []):
        assert len(msa_seq) == len_seq, f"len(msa_seq)={len(msa_seq)}, len_seq={len_seq}"
        tokens.extend(tokenizer.encode(msa_seq, add_special_tokens=False))
        num_seq += 1

    # Column index repeats for every row; row index is constant within a row
    pos_encoding = np.stack([np.tile(np.arange(len_seq), num_seq), np.repeat(np.arange(num_seq), len_seq)])

    # Drop gap tokens together with their positions, then truncate to the context length
    tokens = np.array(tokens)
    tok_mask = (tokens != tokenizer._token_to_id['-'])
    tokens, pos_encoding = tokens[tok_mask][:max_context], pos_encoding[..., tok_mask][..., :max_context]
    return tokens, pos_encoding

tokens, pos_encoding = tokenize(query_sequence, f_msa, tokenizer, max_context=12800)
tokens       = torch.from_numpy(tokens).cuda()
pos_encoding = torch.from_numpy(pos_encoding).cuda()

print(f"tokens: {tokens.shape}\n{tokens}")
print(f"pos_encoding: {pos_encoding.shape}\n{pos_encoding}")
```

```
tokens: torch.Size([12800])
tensor([17,  2,  2,  ...,  3, 20,  9], device='cuda:0')
pos_encoding: torch.Size([2, 12800])
tensor([[  0,   1,   2,  ..., 214, 215, 216],
        [  0,   0,   0,  ...,  40,  40,  40]], device='cuda:0')
```
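To see the 2D position layout without loading the real tokenizer, here is a toy version of the same logic. It uses a hypothetical character-level vocabulary (`toy_vocab`) in place of `FM4BioTokenizer`. Row 0 of the positions is the column (residue) index and row 1 is the MSA row index. Gap tokens are dropped together with their positions, so the remaining tokens keep their original coordinates. Output names are prefixed with `toy_` so they do not overwrite the notebook's `tokens`.

```python
import numpy as np

def toy_tokenize(q_seq, msa, vocab, max_context=12800, gap='-'):
    """Same layout as tokenize() above, with a toy character vocabulary (hypothetical)."""
    rows = [q_seq] + list(msa or [])
    len_seq, num_seq = len(q_seq), len(rows)
    assert all(len(r) == len_seq for r in rows), "MSA rows must be aligned to the query"
    ids = np.array([vocab[c] for r in rows for c in r])
    pos = np.stack([np.tile(np.arange(len_seq), num_seq),      # column index
                    np.repeat(np.arange(num_seq), len_seq)])   # MSA row index
    keep = ids != vocab[gap]
    return ids[keep][:max_context], pos[:, keep][:, :max_context]

toy_vocab = {c: i for i, c in enumerate('-ACD')}
toy_tokens, toy_pos = toy_tokenize("ACD", ["A-D"], toy_vocab)
print(toy_tokens)  # [1 2 3 1 3]
print(toy_pos)     # [[0 1 2 0 2]
                   #  [0 0 0 1 1]]
```

The gap in the MSA row removes the token at column 1 of row 1, which is why that row contributes positions `(0, 1)` and `(2, 1)` only.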


#### Padding the structure embeddings to match token length

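The token embeddings and the structure embeddings are added element-wise. The 384-dimensional structure embeddings are mapped to the hidden size by the `str_embeddings` layer (Linear 384 → 2304) shown in the model printout. The structure input therefore needs one row per input token. Only the query has a structure. Because the query contains no gaps, its 360 residues are the first 360 tokens, so we zero-pad the structure embeddings to cover the remaining MSA tokens.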

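```python
print(f"str_embs: {str_embs.shape}")
print(f"tokens: {tokens.shape}")

# Append zero rows so there is one structure embedding per input token;
# the pad tuple is (last dim left, last dim right, second-to-last dim left, second-to-last dim right)
padding = tokens.shape[0] - str_embs.shape[0]
str_embs = F.pad(str_embs, (0, 0, 0, padding))

print(f"str_embs: {str_embs.shape}")
```

```
str_embs: torch.Size([360, 384])
tokens: torch.Size([12800])
str_embs: torch.Size([12800, 384])
```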
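The padding call above relies on `F.pad`'s convention that the pad tuple lists (before, after) amounts starting from the last dimension. So `(0, 0, 0, padding)` leaves the 384 feature columns alone and appends `padding` zero rows. The NumPy sketch below performs the same operation with a hypothetical helper, `pad_rows`. It also raises an error when the target is shorter than the embeddings, instead of relying on how `F.pad` treats a negative amount.

```python
import numpy as np

def pad_rows(emb, target_len):
    """Zero-pad an (L, D) array along axis 0 to (target_len, D).

    NumPy analogue of F.pad(emb, (0, 0, 0, target_len - L)).
    """
    extra = target_len - emb.shape[0]
    if extra < 0:
        raise ValueError(f"target_len={target_len} is shorter than {emb.shape[0]} rows")
    return np.pad(emb, ((0, extra), (0, 0)))  # pad rows at the end, no feature padding

# Toy shapes standing in for (360, 384) -> (12800, 384)
padded = pad_rows(np.ones((3, 4)), 5)
print(padded.shape)  # (5, 4)
```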


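#### Obtain the protein embedding

```python
# Forward pass; [None] adds a leading batch dimension to each input.
# Optional: add torch.amp.autocast('cuda', dtype=torch.bfloat16) to the `with` statement
# to run in bfloat16 (not used for the output below).
with torch.no_grad():
    lm_output = model(
        input_ids=tokens[None],
        position_ids=pos_encoding[None],
        inputs_str_embeds=str_embs[None],
        output_hidden_states=True,
    )
    last_hidden_state = lm_output.hidden_states[-1]  # (batch, num_tokens, hidden_size)

print(f"last_hidden_state: {last_hidden_state.shape}")
```

```
last_hidden_state: torch.Size([1, 12800, 2304])
```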
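The 12,800 hidden states cover the query and the retained MSA tokens. The sketch below extracts per-residue embeddings for the query only. It assumes, as the `tokenize` layout implies, that query tokens are exactly those with MSA row index 0 in `pos_encoding`. Mean pooling to a single per-sequence vector is an assumption here, not necessarily what ModelGenerator's `Embed` task does. The helper name `query_embeddings` is hypothetical.

```python
import numpy as np

def query_embeddings(hidden, pos_encoding):
    """Hypothetical helper.

    hidden: (T, H) hidden states for one input; pos_encoding: (2, T) from tokenize().
    Returns the query's per-residue embeddings (ordered by residue index) and their mean.
    """
    hidden = np.asarray(hidden)
    pos_encoding = np.asarray(pos_encoding)
    is_query = pos_encoding[1] == 0                       # MSA row 0 is the query
    order = np.argsort(pos_encoding[0][is_query], kind="stable")
    per_residue = hidden[is_query][order]
    return per_residue, per_residue.mean(axis=0)         # mean pooling (assumed)
```

With the notebook variables, call `query_embeddings(last_hidden_state[0].float().cpu().numpy(), pos_encoding.cpu().numpy())`. Because the query has no gaps, this should return a 360 × 2304 per-residue array and a 2304-dimensional pooled vector.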


## ModelGenerator tasks

* **Get embeddings**: input sequence, get per-residue and per-sequence embeddings.
* **Sequence level classification**: input sequence, get classification label (e.g., enzyme/non-enzyme).
* **Token level classification**: input sequence, get per-residue labels (e.g., secondary structure).
* **Sequence level regression**: input sequence, get a real-valued output (e.g., stability).

```python
from modelgenerator.tasks import Embed
from modelgenerator.tasks import SequenceClassification
from modelgenerator.tasks import TokenClassification
from modelgenerator.tasks import SequenceRegression
```

### How to implement these tasks using ModelGenerator?
* **Backbone**: use `genbio-ai/AIDO.Protein-RAG-16B` as the backbone model.
* **Adaptors**: different adaptors can be used for different tasks.
* **Dataset**: different datasets can be used for different tasks.
* **Loss functions**: different loss functions can be used for different tasks.
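To illustrate the **Loss functions** point, here is a sketch of two standard choices: cross-entropy for the classification tasks and mean squared error for regression. These are assumptions for illustration. Which losses ModelGenerator actually configures for each task is not shown in this document. For token-level classification, the same cross-entropy applies after flattening `(length, n_classes)` logits and per-residue labels.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean cross-entropy; logits of shape (N, C), integer class labels of shape (N,)."""
    logits = np.asarray(logits, dtype=float)
    z = logits - logits.max(axis=-1, keepdims=True)            # shift for numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())

def mse(pred, target):
    """Mean squared error between real-valued predictions and targets."""
    pred, target = np.asarray(pred, dtype=float), np.asarray(target, dtype=float)
    return float(((pred - target) ** 2).mean())
```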

The following sections show how to use the predefined task classes in ModelGenerator to load the model and run it on example inputs.

### Embeddings with MSA and Structure

```python
import os
os.environ['HF_HOME'] = '/tmp/hf_cache'

import random
import numpy as np
import torch
from modelgenerator.tasks import Embed

model = Embed.from_config({"model.backbone": "aido_protein_rag_16b"}).eval()
model.backbone.max_length = 12800

# Random placeholder inputs: one 50-residue query, 25 aligned MSA rows (with gaps),
# and a (batch, length, 384) array of structure embeddings
restypes = 'ARNDCQEGHILKMFPSTWYV'
data = {
    'sequences': [''.join(random.choice(restypes) for _ in range(50))],
    'msa': [ [ ''.join(random.choice(restypes+'-') for _ in range(50)) for _ in range(25) ] ],
    'str_emb': np.random.normal(size=(1, 50, 384))
}
transformed_batch = model.transform(data)
for k, v in transformed_batch.items():
    if isinstance(v, torch.Tensor):
        print(f"{k}: {v.shape}")
    elif isinstance(v, list):
        print(f"{k}: list of length {len(v)}")

with torch.no_grad():
    embedding = model(transformed_batch)

print(embedding.shape)  # (batch, query length, hidden size)
```

```
Loading checkpoint shards: 100%|██████████| 13/13 [00:09<00:00,  1.40it/s]

sequences: list of length 1
input_ids: torch.Size([1, 1244])
attention_mask: torch.Size([1, 50])
special_tokens_mask: list of length 1
full_attention_mask: list of length 1
query_tokens_mask: list of length 1
position_ids: list of length 1
inputs_str_embeds: list of length 1
torch.Size([1, 50, 2304])
```
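The examples in this section feed random placeholders. To feed real data, the same dictionary layout can be built from a query, its aligned MSA and its per-residue structure embeddings. The helper below is hypothetical. It only mirrors the keys and shapes used in the examples above (`'sequences'`, `'msa'`, `'str_emb'` with shape `(batch, length, 384)`), and it validates the inputs before they reach `model.transform`.

```python
import numpy as np

STR_EMB_DIM = 384  # structure-embedding size from the structure tokenizer output above

def build_rag_batch(query, msa, str_emb):
    """Hypothetical helper: wrap one query, its aligned MSA rows and its (L, 384)
    structure embeddings into the dict layout used by the task examples."""
    str_emb = np.asarray(str_emb)
    length = len(query)
    if any(len(row) != length for row in msa):
        raise ValueError("every MSA row must be aligned to the query length")
    if str_emb.shape != (length, STR_EMB_DIM):
        raise ValueError(f"expected str_emb of shape {(length, STR_EMB_DIM)}, got {str_emb.shape}")
    return {'sequences': [query], 'msa': [list(msa)], 'str_emb': str_emb[None]}
```

With the step-by-step variables, one could pass `query_sequence`, a pre-selected MSA such as `f_msa`, and the unpadded 360 × 384 structure embeddings (moved to CPU as a NumPy array).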


### Sequence Level Classification

```python
import random
import numpy as np
import torch
from modelgenerator.tasks import SequenceClassification

model = SequenceClassification.from_config({"model.backbone": "aido_protein_rag_16b", "model.n_classes": 2}).eval()
model.backbone.max_length = 12800
restypes = 'ARNDCQEGHILKMFPSTWYV'
data = {
    'sequences': [''.join(random.choice(restypes) for _ in range(50))],
    'msa': [ [ ''.join(random.choice(restypes+'-') for _ in range(50)) for _ in range(25) ] ],
    'str_emb': np.random.normal(size=(1, 50, 384))
}
transformed_batch = model.transform(data)
with torch.no_grad():
    logits = model(transformed_batch)  # (batch, n_classes)

# The classification head has not been fine-tuned in this example, so the prediction is arbitrary
print(logits)
print(torch.argmax(logits, dim=-1))  # predicted class per sequence
```

```
Loading checkpoint shards: 100%|██████████| 13/13 [00:08<00:00,  1.46it/s]

tensor([[0.0039, 0.2004]])
tensor([1])
```
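To turn the logits into class probabilities, apply a softmax over the class axis. In PyTorch, `torch.softmax(logits, dim=-1)` does this directly. The NumPy version below only makes the arithmetic explicit, using the logits printed above.

```python
import numpy as np

def softmax(logits, axis=-1):
    """Numerically stable softmax over the class axis."""
    z = np.asarray(logits, dtype=float)
    z = z - z.max(axis=axis, keepdims=True)  # subtracting the max does not change the result
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

probs = softmax([[0.0039, 0.2004]])  # logits from the example above
print(probs.round(3))  # [[0.451 0.549]]
```

The argmax of the probabilities is the same as the argmax of the logits (class 1 here). The softmax only rescales the scores.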


### Token Level Classification

```python
import random
import numpy as np
import torch
from modelgenerator.tasks import TokenClassification

model = TokenClassification.from_config({"model.backbone": "aido_protein_rag_16b", "model.n_classes": 3}).eval()
model.backbone.max_length = 12800
restypes = 'ARNDCQEGHILKMFPSTWYV'
data = {
    'sequences': [''.join(random.choice(restypes) for _ in range(50))],
    'msa': [ [ ''.join(random.choice(restypes+'-') for _ in range(50)) for _ in range(25) ] ],
    'str_emb': np.random.normal(size=(1, 50, 384))
}
transformed_batch = model.transform(data)
with torch.no_grad():
    logits = model(transformed_batch)  # (batch, length, n_classes)

print(logits)
print(torch.argmax(logits, dim=-1))  # predicted class per residue
```

```
Loading checkpoint shards: 100%|██████████| 13/13 [00:08<00:00,  1.54it/s]

tensor([[[ 0.1389, -0.0476, -0.1429],
         [ 0.0948, -0.0613,  0.0507],
         [ 0.0497, -0.1755, -0.0152],
         [-0.0173, -0.0536,  0.1043],
         [ 0.0619, -0.0745, -0.0095],
         [ 0.0890, -0.0762,  0.0661],
         [ 0.1128,  0.0329,  0.0497],
         [-0.0100,  0.0207, -0.0390],
         [ 0.1484,  0.0014,  0.0852],
         [ 0.0504, -0.0116, -0.0466],
         [-0.0045,  0.0447, -0.1606],
         [ 0.0580,  0.0081, -0.1067],
         [ 0.0450,  0.0838, -0.0925],
         [ 0.0133, -0.0746, -0.0793],
         [ 0.0226,  0.0126,  0.0038],
         [-0.0122,  0.1079, -0.0010],
         [ 0.0734,  0.0792,  0.1325],
         [-0.0793,  0.0099, -0.0040],
         [ 0.0741,  0.0202, -0.0121],
         [ 0.1228,  0.0773,  0.0103],
         [ 0.0661, -0.0741, -0.1237],
         [ 0.1132,  0.0276, -0.0024],
         [ 0.0739, -0.0825, -0.0299],
         [ 0.0052,  0.0290,  0.0185],
         [ 0.0949, -0.0175,  0.0382],
         [ 0.1304,  0.0236, -0.0514],
         ...
```
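The model returns one score per class for every query residue, with shape `(batch, length, n_classes)`. Taking the argmax over the last axis gives one label index per residue. The sketch below turns those indices into a readable string. The 3-state label set `('H', 'E', 'C')` is a hypothetical secondary-structure mapping chosen for illustration. The example above does not define what classes 0 to 2 mean.

```python
import numpy as np

# Hypothetical label names for classes 0, 1, 2 (e.g. helix, strand, coil)
LABELS = ('H', 'E', 'C')

def decode_token_labels(logits, labels=LABELS):
    """logits: (length, n_classes) for one sequence -> one label character per residue."""
    idx = np.asarray(logits).argmax(axis=-1)
    return ''.join(labels[i] for i in idx)

# First, second and fourth rows of the logits printed above
print(decode_token_labels([[0.1389, -0.0476, -0.1429],
                           [0.0948, -0.0613, 0.0507],
                           [-0.0173, -0.0536, 0.1043]]))  # HHC
```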

### Sequence Level Regression

```python
import random
import numpy as np
import torch
from modelgenerator.tasks import SequenceRegression

model = SequenceRegression.from_config({"model.backbone": "aido_protein_rag_16b"}).eval()
model.backbone.max_length = 12800
restypes = 'ARNDCQEGHILKMFPSTWYV'
data = {
    'sequences': [''.join(random.choice(restypes) for _ in range(50))],
    'msa': [ [ ''.join(random.choice(restypes+'-') for _ in range(50)) for _ in range(25) ] ],
    'str_emb': np.random.normal(size=(1, 50, 384))
}
transformed_batch = model.transform(data)
with torch.no_grad():
    preds = model(transformed_batch)  # one real value per sequence: (batch, 1)

print(preds.shape)
```

```
Loading checkpoint shards: 100%|██████████| 13/13 [00:09<00:00,  1.38it/s]

torch.Size([1, 1])
```
