## Example Usage of CurrAb

The full code used to evaluate the models (inference, per-position inference, and classification) is in [model-eval](./model-eval/).
This notebook provides only a basic example of loading CurrAb and using it to predict masked residues.

### 1. Install required packages

In [1]:
# !pip install transformers torch

### 2. Load model & tokenizer
You can find the 650M-parameter models from our paper, including CurrAb, [on Hugging Face](https://huggingface.co/collections/brineylab/curriculum-paper-685b08a4b6986df7c5a5e3c4).

In [2]:
import torch
from transformers import EsmTokenizer, EsmForMaskedLM

# check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load model & tokenizer
model = EsmForMaskedLM.from_pretrained("brineylab/CurrAb").to(device)
tokenizer = EsmTokenizer.from_pretrained("brineylab/CurrAb")

### 3. Format sequences
If you want to use our test datasets, you can download them from [Zenodo](https://zenodo.org/records/14661302).

For this example, we will use a single paired sequence from the test dataset:

In [3]:
heavy_chain = "EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYDMHWVRQATGKGLEWVSAIGTAGDTYYPGSVKGRFTISRENAKNSLYLQMNSLRAGDTAVYYCARGYCTNGVCYTFGDYGMDVWGQGTTVTVSS"
light_chain = "DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYAASSLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTLWTFGQGTKVEIK"

# format sequences with the <cls> separator
paired_sequence = f"{heavy_chain}<cls>{light_chain}"
heavy_sequence = f"{heavy_chain}<cls>"
light_sequence = f"<cls>{light_chain}"
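
The three formats above follow one pattern: the heavy chain goes before the `<cls>` separator, the light chain goes after it, and a missing chain is left empty. As a minimal sketch of that pattern, the helper below builds any of the three strings. The name `format_chains` is hypothetical and not part of the CurrAb code.

```python
from typing import Optional

SEPARATOR = "<cls>"  # separator token used in the formatting cell above


def format_chains(heavy: Optional[str] = None, light: Optional[str] = None) -> str:
    """Join heavy and light chains around the <cls> separator.

    A missing chain is left empty, so the separator is always present.
    """
    if not heavy and not light:
        raise ValueError("provide at least one chain")
    return f"{heavy or ''}{SEPARATOR}{light or ''}"


# short toy fragments, not full antibody sequences
print(format_chains("EVQLVESGGG", "DIQMTQSPSS"))  # paired
print(format_chains(heavy="EVQLVESGGG"))          # heavy chain only
print(format_chains(light="DIQMTQSPSS"))          # light chain only
```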

### 4. Tokenize sequences

In [4]:
heavy_tokenized = tokenizer(heavy_sequence, return_tensors="pt").to(device)
print(heavy_tokenized)

{'input_ids': tensor([[ 0,  9,  7, 16,  4,  7,  9,  8,  6,  6,  6,  4,  7, 16, 14,  6,  6,  8,
          4, 10,  4,  8, 23,  5,  5,  8,  6, 18, 11, 18,  8,  8, 19, 13, 20, 21,
         22,  7, 10, 16,  5, 11,  6, 15,  6,  4,  9, 22,  7,  8,  5, 12,  6, 11,
          5,  6, 13, 11, 19, 19, 14,  6,  8,  7, 15,  6, 10, 18, 11, 12,  8, 10,
          9, 17,  5, 15, 17,  8,  4, 19,  4, 16, 20, 17,  8,  4, 10,  5,  6, 13,
         11,  5,  7, 19, 19, 23,  5, 10,  6, 19, 23, 11, 17,  6,  7, 23, 19, 11,
         18,  6, 13, 19,  6, 20, 13,  7, 22,  6, 16,  6, 11, 11,  7, 11,  7,  8,
          8,  0,  2]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 

### 5. Run inference

#### a. Extract logits, attentions, or hidden states

In [5]:
with torch.no_grad():
    outputs = model(
        **heavy_tokenized,
        output_attentions=False,  # set to True to output attentions
        output_hidden_states=False,  # set to True to output hidden states
    )
logits = outputs.logits
print("Logits shape:", logits.shape)

Logits shape: torch.Size([1, 129, 33])
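
The second dimension of the logits counts tokens, not residues. In the tokenized output from step 4, the 126-residue heavy chain is preceded by one special token (ID 0), followed by the `<cls>` separator (also ID 0), and closed by a final special token (ID 2), which gives 129 positions. Assuming that layout also holds for paired input, the sketch below maps a 0-based residue index to its row in the logits. The function `token_index` is a hypothetical helper, not part of the CurrAb code.

```python
def token_index(residue_index: int, chain: str, heavy_length: int) -> int:
    """Map a 0-based residue index to its position in the tokenized input.

    Assumed layout (inferred from the tokenizer output in step 4):
    [leading special token] heavy residues [<cls>] light residues [end token]
    """
    if residue_index < 0:
        raise ValueError("residue_index must be non-negative")
    if chain == "heavy":
        if residue_index >= heavy_length:
            raise ValueError("residue_index is outside the heavy chain")
        return 1 + residue_index  # skip the leading special token
    if chain == "light":
        # skip the leading token, the heavy chain, and the <cls> separator
        return 1 + heavy_length + 1 + residue_index
    raise ValueError("chain must be 'heavy' or 'light'")


print(token_index(19, "heavy", heavy_length=126))  # 20
print(token_index(0, "light", heavy_length=126))   # 128
```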


#### b. Predict a masked position

In [6]:
# mask the 20th residue (index 19, because Python indexing starts at 0)
masked_heavy = list(heavy_chain)
masked_heavy[19] = "<mask>"
heavy_masked = "".join(masked_heavy) + "<cls>"
print("Masked sequence:", heavy_masked)

# tokenize masked sequence
masked_tokenized = tokenizer(heavy_masked, return_tensors="pt").to(device)

Masked sequence: EVQLVESGGGLVQPGGSLR<mask>SCAASGFTFSSYDMHWVRQATGKGLEWVSAIGTAGDTYYPGSVKGRFTISRENAKNSLYLQMNSLRAGDTAVYYCARGYCTNGVCYTFGDYGMDVWGQGTTVTVSS<cls>
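
The same list-based approach extends to several positions at once. The sketch below is a hypothetical helper (`mask_residues` is not part of the CurrAb code) that replaces each listed 0-based position with the mask token and rejects out-of-range indices instead of silently wrapping around.

```python
from typing import Iterable


def mask_residues(chain: str, positions: Iterable[int], mask_token: str = "<mask>") -> str:
    """Return `chain` with the residues at the given 0-based positions masked."""
    residues = list(chain)
    for pos in positions:
        if not 0 <= pos < len(residues):
            raise IndexError(f"position {pos} is outside the chain")
        residues[pos] = mask_token
    return "".join(residues)


# toy fragment: mask the 2nd and 4th residues
print(mask_residues("EVQLVE", [1, 3]))  # E<mask>Q<mask>VE
```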


In [7]:
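# run inference with labels (the unmasked sequence);
# no label is set to -100, so the loss is averaged over all token positions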
with torch.no_grad():
    outputs = model(
        **masked_tokenized,
        labels=tokenizer(heavy_sequence, return_tensors="pt").input_ids.to(device),
        output_attentions=False,
        output_hidden_states=False,
    )
print("Loss:", outputs.loss.item())

Loss: 0.2867249846458435
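
If you want the loss at the masked position only, a common approach is to set every other label to -100, the value that PyTorch's cross-entropy loss ignores by default. The sketch below shows the idea on plain Python lists. The helper `mask_only_labels` is hypothetical, and the toy IDs use 32 as a stand-in for `tokenizer.mask_token_id`.

```python
from typing import List

IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default


def mask_only_labels(input_ids: List[int], label_ids: List[int], mask_token_id: int) -> List[int]:
    """Keep the true label where the input is masked; ignore every other position."""
    if len(input_ids) != len(label_ids):
        raise ValueError("input_ids and label_ids must have the same length")
    return [
        label if token == mask_token_id else IGNORE_INDEX
        for token, label in zip(input_ids, label_ids)
    ]


# toy example: the third token is masked, and its true ID is 16
toy_inputs = [0, 9, 32, 4, 2]
toy_labels = [0, 9, 16, 4, 2]
print(mask_only_labels(toy_inputs, toy_labels, mask_token_id=32))
# [-100, -100, 16, -100, -100]
```

With the notebook's tensors, the same effect could be obtained by cloning the label tensor and assigning `labels[masked_tokenized.input_ids != tokenizer.mask_token_id] = -100` before passing it to the model. The resulting loss would differ from the value printed above, because it would no longer include the unmasked positions.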


In [8]:
# get predicted token at the masked position
mask_token_index = (
    (masked_tokenized.input_ids == tokenizer.mask_token_id)
    .nonzero(as_tuple=True)[1][0]
    .item()
)
predicted_token_id = outputs.logits[0, mask_token_index].argmax().item()
predicted_token = tokenizer.decode([predicted_token_id])
print(f"Predicted residue: {predicted_token}")

Predicted residue: L
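
The argmax reports only the single most likely residue. To see how confident the model is, you can turn the logits at the masked position into probabilities with a softmax and list the top candidates. The sketch below does this without any dependencies, on toy values. The function `top_k_predictions` and the toy logits are illustrative, not model output.

```python
import math
from typing import List, Tuple


def top_k_predictions(logits: List[float], tokens: List[str], k: int = 3) -> List[Tuple[str, float]]:
    """Return the k most probable tokens with their softmax probabilities."""
    if len(logits) != len(tokens):
        raise ValueError("logits and tokens must have the same length")
    max_logit = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(tokens, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# toy vocabulary and logits, for illustration only
toy_tokens = ["L", "A", "G", "V"]
toy_logits = [4.0, 1.0, 0.5, 2.0]
for token, prob in top_k_predictions(toy_logits, toy_tokens, k=2):
    print(f"{token}: {prob:.3f}")
```

With the notebook's objects, the inputs would be `outputs.logits[0, mask_token_index].tolist()` and the matching token strings from `tokenizer.convert_ids_to_tokens`. In practice, `torch.softmax` and `torch.topk` perform the same computation directly on the tensor.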
