In [1]:
# ============================================================
# Imports
# ============================================================
import os
import re
import torch
import h5py
import pandas as pd
from tqdm.auto import tqdm

from transformers import (
    T5Tokenizer, 
    T5EncoderModel, 
    EsmModel, 
    EsmTokenizer,
)
import ablang2

tqdm.pandas()

2025-12-19 16:30:56.452570: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2025-12-19 16:30:56.484549: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-12-19 16:30:56.484578: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-12-19 16:30:56.485684: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-12-19 16:30:56.492866: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructio

[2025-12-19 16:30:59,246] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/mhossai5/.conda/envs/llm_tor2/compiler_compat/ld: /home/mhossai5/local/cuda-11.8/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'
/home/mhossai5/.conda/envs/llm_tor2/compiler_compat/ld: /home/mhossai5/local/cuda-11.8/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'
/home/mhossai5/.conda/envs/llm_tor2/compiler_compat/ld: /home/mhossai5/local/cuda-11.8/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'
/home/mhossai5/.conda/envs/llm_tor2/compiler_compat/ld: /home/mhossai5/local/cuda-11.8/lib64/libcufile.so: undefined reference to `std::string::substr(unsigned long, unsigned long) const@GLIBCXX_3.4'
/home/mhossai5/.conda/envs/llm_tor2/compiler_compat/ld: /home/mhossai5/local/cuda-11.8/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)@GLIBCXX_3.4'
/home/mhossai5/.conda/envs/llm_tor2/compiler_compat/ld: /home/mhos

In [2]:
# ============================================================
# DEVICE SETUP
# ============================================================

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"\n[INFO] Using device: {device}\n")


[INFO] Using device: cuda:0



In [3]:
!ls ./data/SabDab-RELAAI

AbAg_record_level_balanced_dataset.csv


In [4]:
# ============================================================
# LOAD DATA
# ============================================================

df   = pd.read_csv("./data/SabDab-RELAAI/AbAg_record_level_balanced_dataset.csv")

# Unique sequences
seq_list = pd.unique(df[["ab_seq","ag_seq"]].values.ravel())

In [5]:
seq_list[:3]

array(['EVQLVESGGGLVQPGGSLRLSCAASGYTFTSYWLHWVRQAPGKGLEWVGMIDPSNSDTRFNPNFKDRFTISADTSKNTAYLQMNSLRAEDTAVYYCATYRSYVTPLDYWGQGTLVTVSSDIQMTQSPSSLSASVGDRVTITCKSSQSLLYTSSQKNYLAWYQQKPGKAPKLLIYWASTRESGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQYYAYPWTFGQGTKVEIK',
       'MKYQLPNFTAETPIQNVILHEHHIFLGATNYIYVLNEEDLQKVAEYKTGPVLEHPDCFPCQDCSSKANLSGGVWKDNINMALVVDTYYDDQLISCGSVNRGTCQRHVFPHNHTADIQSEVHCIFSPQIEEPSQCPDCVVSALGAKVLSSVKDRFINFFVGNTINSSYFPDHPLHSISVRRLKETKDGFMFLTDQSYIDVLPEFRDSYPIKYVHAFESNNFIYFLTVQRETLDAQTFHTRIIRFCSINSGLHSYMEMPLECILTKEVFNILQAAYVSKPGAQLARQIGASLNDDILFGVFAQSKPDSAEPMDRSAMCAFPIKYVNDFFNKIVNKNNVRCLQHFYGPNHEHCEYRTEFTTALQRVDLFMGQFSEVLLTSISTFIKGDLTIANLGTSEGRFMQVVVSRSGPSTPHVNFLLDSHPVSPEVIVEHTLNNGYTLVITGKKITKIPLNGLGCRHFQSCSQCLSAPPFVQCGWCHDKCVRSEECLSGTWTQQICLPA',
       'QVQLVESGGGVVQPGRSLRLSCAASGFTFSSYGMHWVRQAPGKGLEWVAVMYYDGSNKDYVDSVKGRFTISRDNSKNTLYLQMNRLRAEDTAVYYCAREKDHYDILTGYNYYYGLDVWGQGTTVTVSSDIQMTQSPSSLSASVGDRVTITCRASQGIRNDLGWYQQKPGKAPKRLIYAASSLESGVPSRFSGSGSGTEFTLTISSVQPEDFVTYYCLQHNSNPLTFGGGTKVEIK'],


In [6]:
# ============================================================
# LOAD ABLANG2 (ANTIBODY MODEL)
# ============================================================
def load_ablang2():
    model_name = "ablang2-paired"
    print(f"[INFO] Loading AbLang2: {model_name}")

    model = ablang2.pretrained(
        model_to_use=model_name,
        random_init=False,
        device=device
    )
    return model

# ============================================================
# LOAD ProtT5 (ANTIGEN MODEL)
# ============================================================
def load_prott5():
    model_name = "Rostlab/prot_t5_xl_half_uniref50-enc"
    print(f"[INFO] Loading ProtT5: {model_name}")

    tokenizer = T5Tokenizer.from_pretrained(
        model_name,
        do_lower_case=False
    )
    model = T5EncoderModel.from_pretrained(model_name)

    model = model.half() if device.type == "cuda" else model.float()
    model = model.to(device).eval()

    return model, tokenizer

# ============================================================
# LOAD ESM2 (Antigen Model)
# ============================================================
def load_esm2():
    model_name = "facebook/esm2_t36_3B_UR50D"
    print(f"[INFO] Loading ESM2 model: {model_name}")

    tokenizer = EsmTokenizer.from_pretrained(
        model_name,
        do_lower_case=False
    )
    model = EsmModel.from_pretrained(model_name)

    # FP16 for GPU, FP32 for CPU
    if device.type == "cuda":
        model = model.half()
    else:
        model = model.float()

    model = model.to(device).eval()
    return model, tokenizer

# ============================================================
# EMBEDDING FUNCTIONS
# ============================================================

def embed_ablang(hseq: str, lseq: str):
    """
    AbLang2 antibody embedding (VH | VL)
    Output shape: (L, D)
    """
    seq = f"{hseq}|{lseq}".upper()

    tokens = ab_model.tokenizer(
        [seq],
        pad=True,
        w_extra_tkns=False,
        device=device
    )

    with torch.no_grad():
        output = ab_model.AbRep(tokens).last_hidden_states

    return output.squeeze().cpu().numpy()


def embed_prot(seq: str):
    """
    ProtT5 antigen embedding
    Output shape: (L, 1024)
    """
    seq = re.sub(r"[UZOB]", "X", seq)
    seq = " ".join(seq)

    inputs = ag_tokenizer(
        seq,
        return_tensors="pt",
        add_special_tokens=True
    ).to(device)

    with torch.no_grad():
        outputs = ag_model(**inputs)

    token_emb = outputs.last_hidden_state.squeeze(0)
    mask = inputs["attention_mask"].squeeze(0).bool()

    return token_emb[mask].cpu().numpy()

def embed_esm(seq: str):
    """
    Returns ESM2 embeddings for antigen. shape 2560
    """
    tokens = esm_tokenizer(
        seq,
        return_tensors="pt",
        add_special_tokens=True
    )

    tokens = {k: v.to(device) for k, v in tokens.items()}

    with torch.no_grad():
        output = esm_model(**tokens)

    return output.last_hidden_state.squeeze(0).cpu().numpy()

In [7]:

# ag_model, ag_tokenizer = load_prott5()
# ab_model = load_ablang2()
esm_model, esm_tokenizer = load_esm2()

[INFO] Loading ESM2 model: facebook/esm2_t36_3B_UR50D


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t36_3B_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# ============================================================
# SAVE EMBEDDINGS (HDF5)
# ============================================================

OUTPUT_PATH = "./data/SabDab-RELAAI/SabDab_RELAAI_esmt36_Full.h5"
os.makedirs(os.path.dirname(OUTPUT_PATH),exist_ok =True)
print(f"[INFO] Saving embeddings to:\n{OUTPUT_PATH}\n")

with h5py.File(OUTPUT_PATH, "w") as hf:
    for seq in tqdm(seq_list):
        emb = embed_esm(seq).reshape(-1).tolist()
        hf.create_dataset(seq, data=emb, compression="gzip")

[INFO] Saving embeddings to:
./data/SabDab-RELAAI/SabDab_RELAAI_esmt36_Full.h5



  0%|          | 0/1692 [00:00<?, ?it/s]

In [9]:
# ============================================================
# SANITY CHECK
# ============================================================

print("Sample embedding shape:",
      embed_esm(seq_list[0]).shape)

Sample embedding shape: (234, 2560)


In [10]:
 tokens = esm_tokenizer(
        seq_list[0],
        return_tensors="pt",
        add_special_tokens=True
)

tokens = {k: v.to(device) for k, v in tokens.items()}
tokens

{'input_ids': tensor([[ 0,  9,  7, 16,  4,  7,  9,  8,  6,  6,  6,  4,  7, 16, 14,  6,  6,  8,
           4, 10,  4,  8, 23,  5,  5,  8,  6, 19, 11, 18, 11,  8, 19, 22,  4, 21,
          22,  7, 10, 16,  5, 14,  6, 15,  6,  4,  9, 22,  7,  6, 20, 12, 13, 14,
           8, 17,  8, 13, 11, 10, 18, 17, 14, 17, 18, 15, 13, 10, 18, 11, 12,  8,
           5, 13, 11,  8, 15, 17, 11,  5, 19,  4, 16, 20, 17,  8,  4, 10,  5,  9,
          13, 11,  5,  7, 19, 19, 23,  5, 11, 19, 10,  8, 19,  7, 11, 14,  4, 13,
          19, 22,  6, 16,  6, 11,  4,  7, 11,  7,  8,  8, 13, 12, 16, 20, 11, 16,
           8, 14,  8,  8,  4,  8,  5,  8,  7,  6, 13, 10,  7, 11, 12, 11, 23, 15,
           8,  8, 16,  8,  4,  4, 19, 11,  8,  8, 16, 15, 17, 19,  4,  5, 22, 19,
          16, 16, 15, 14,  6, 15,  5, 14, 15,  4,  4, 12, 19, 22,  5,  8, 11, 10,
           9,  8,  6,  7, 14,  8, 10, 18,  8,  6,  8,  6,  8,  6, 11, 13, 18, 11,
           4, 11, 12,  8,  8,  4, 16, 14,  9, 13, 18,  5, 11, 19, 19, 23, 16, 16,
   

In [10]:
load_seq ={}

with h5py.File(OUTPUT_PATH, "r") as hf:
    for seq in tqdm(hf.keys()):
        load_seq[seq] =hf[seq][:]
        break

  0%|          | 0/2722 [00:00<?, ?it/s]

In [13]:
load_seq

{'AAALTQPLSVSVSPGQTAIFTCSGDNLGDKYVYWFQQRPGQSPALLIYQDNKRPSGIPERFSGSNSGNTATLTISGTQSTDEADYYCQTWDSTVVFGGGTKLQVQLQESGPGLVAASDTLSLTCTVSGGSLAAFYWSWIRQAPGKGLEWIGYIYYSGSAYYSPSLESRVTMSDAAAAAAAAAAAAAVYYCVRAAAAAAFASWGQGTLVTV': array([ 0.01722717, -0.01322937, -0.05004883, ...,  0.06414795,
        -0.12078857, -0.19470215])}

In [None]:
# load_seq['DIVITQSPSSMYASLGERVTITCKASQDINSYLSWFQQKPGKSPKTLIYRANRLVDGVPSRFSGSGSGQDYSLTISSLEYEDMGIYYCLQYDEFPLTFGAGTKLELKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGECEVQLQESGPELVKPGASVKIPCKASGYTFTDYNMDWVKQSHGKSLEWIGDINPNNGGTIYNQKFKGKATLTVDKSSSTAYMELRSLTSEDTAVYYCARPDYYGSYGWYFDVWGTGTTVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEPKS']