In [1]:
# ============================================================
# Imports
# ============================================================
import os
import re
import torch
import h5py
import pandas as pd
from tqdm.auto import tqdm

from transformers import (
    T5Tokenizer, 
    T5EncoderModel, 
    EsmModel, 
    EsmTokenizer,
)
# import ablang2
from bio_embeddings.embed import ProtTransBertBFDEmbedder,SeqVecEmbedder

tqdm.pandas()

  from .autonotebook import tqdm as notebook_tqdm
2025-12-18 18:11:57.596336: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# ============================================================
# DEVICE SETUP
# ============================================================

# Prefer the first CUDA GPU when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(f"\n[INFO] Using device: {device}\n")


[INFO] Using device: cuda:0



In [3]:
# List the contents of the DeepInterAware data directory (shell escape).
!ls ../DeepInterAware/data

AB-Bind  AVIDa_hIL6  BioMap  CoVAbDab  example	HIV  SAbDab  SabDab2  Yeast


In [4]:
# ============================================================
# LOAD DATA
# ============================================================

# Antibody and antigen tables from the SAbDab split of DeepInterAware.
ab_df = pd.read_csv("../DeepInterAware/data/SAbDab/antibody.csv")
ag_df = pd.read_csv("../DeepInterAware/data/SAbDab/antigen.csv")

# Unique sequences across both tables, in order of first appearance.
# Missing values are dropped first: a NaN entry would otherwise reach the
# embedding / HDF5 loop and abort it partway through a long run.
all_seqs = pd.concat([ab_df["ab_seq"], ag_df["ag_seq"]], axis=0).dropna()
seq_list = pd.unique(all_seqs.values.ravel())

In [5]:
# Peek at the first three unique sequences.
seq_list[:3]

array(['DVQMTQSPSYLAASPGESVSISCKATENINTYLAWYQAKPGKTTKLLLYSGSTLQSGTPSRFSGSGSGTDFTLTISSLEPEDFAVYYCQQHNEYPLTFGSGTKLEIKEVELVESGGDLVQPGRSLKLSCAASGFTFSNLAMAWVRQTPTKGLEWVASISPAGITTYYRDSVKGRFTISRDNARNTQYLQMDSLRSEDTATYYCARHTGKSSFDYWGQGVMVTVSSG',
       'DIVITQSPSSMYASLGERVTITCKASQDINSYLSWFQQKPGKSPKTLIYRANRLVDGVPSRFSGSGSGQDYSLTISSLEYEDMGIYYCLQYDEFPLTFGAGTKLELKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGECEVQLQESGPELVKPGASVKIPCKASGYTFTDYNMDWVKQSHGKSLEWIGDINPNNGGTIYNQKFKGKATLTVDKSSSTAYMELRSLTSEDTAVYYCARPDYYGSYGWYFDVWGTGTTVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEPKS',
       'DIQMTQSPASLSVSVGETVTITCRASENIYSNLIWYQQKQGKSPQLLVYAATNLADGVPSRFSGSGSGTQYSLKINSLQSEDFGSYYCQHFWGTPLTFGAGTKLEIKRADAAPTVSIFPPSSEQLTSGGASVVCFLNNFYPKDINVKWKIDGSERQNGVLNSWTDQDSKDSTYSMSSTLTLTKDEYERHNSYTCEATHKTSTSPIVKSFNRNECQVQLLQSGAELVRPGSSVKISCKASGYVFTSYWMHWVKQRPGQGLEWIGQIYPGDGGTHYNGNFRDKATLTADKSSSTAYMHLSLTSEDSAV

In [6]:
# ============================================================
# LOAD ABLANG2 (ANTIBODY MODEL)
# ============================================================
def load_ablang2():
    """Load the pretrained paired AbLang2 antibody model.

    ``ablang2`` is imported lazily inside the function because the
    module-level import is commented out in this notebook; without the local
    import any call would fail with ``NameError``.

    Returns:
        The object returned by ``ablang2.pretrained`` for the
        ``"ablang2-paired"`` weights, created on the module-level ``device``.

    Raises:
        ImportError: if the ``ablang2`` package is not installed.
    """
    import ablang2  # local import: the top-of-notebook import is disabled

    model_name = "ablang2-paired"
    print(f"[INFO] Loading AbLang2: {model_name}")

    model = ablang2.pretrained(
        model_to_use=model_name,
        random_init=False,
        device=device
    )
    return model

# ============================================================
# LOAD ProtT5 (ANTIGEN MODEL)
# ============================================================
def load_prott5():
    """Load the ProtT5 encoder checkpoint and its tokenizer.

    Returns:
        (model, tokenizer): the ``T5EncoderModel`` cast to half precision on
        CUDA (full precision otherwise), moved to the module-level ``device``
        and switched to eval mode, together with the matching ``T5Tokenizer``.
    """
    checkpoint = "Rostlab/prot_t5_xl_half_uniref50-enc"
    print(f"[INFO] Loading ProtT5: {checkpoint}")

    tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False)
    encoder = T5EncoderModel.from_pretrained(checkpoint)

    # FP16 on GPU, FP32 on CPU.
    if device.type == "cuda":
        encoder = encoder.half()
    else:
        encoder = encoder.float()

    encoder = encoder.to(device)
    encoder = encoder.eval()
    return encoder, tokenizer

# ============================================================
# LOAD ESM2 (Antigen Model)
# ============================================================
def load_esm2():
    """Load the ``facebook/esm2_t36_3B_UR50D`` model and its tokenizer.

    Returns:
        (model, tokenizer): the ``EsmModel`` cast to half precision on CUDA
        (full precision otherwise), placed on the module-level ``device`` and
        set to eval mode, together with the matching ``EsmTokenizer``.
    """
    checkpoint = "facebook/esm2_t36_3B_UR50D"
    print(f"[INFO] Loading ESM2 model: {checkpoint}")

    tokenizer = EsmTokenizer.from_pretrained(checkpoint, do_lower_case=False)
    encoder = EsmModel.from_pretrained(checkpoint)

    # Half precision on GPU, single precision on CPU.
    encoder = encoder.half() if device.type == "cuda" else encoder.float()

    return encoder.to(device).eval(), tokenizer

# ============================================================
# EMBEDDING FUNCTIONS
# ============================================================

def embed_ablang(hseq: str, lseq: str):
    """Embed a paired antibody with the module-level AbLang2 model ``ab_model``.

    The heavy and light chains are joined as ``"VH|VL"`` and upper-cased
    before tokenisation.

    Args:
        hseq: heavy-chain sequence.
        lseq: light-chain sequence.

    Returns:
        NumPy array holding the model's last hidden states with the batch
        axis removed, i.e. shape (L, D) where L is the number of token
        positions produced by the tokenizer.
    """
    paired = f"{hseq}|{lseq}".upper()

    tokens = ab_model.tokenizer(
        [paired],
        pad=True,
        w_extra_tkns=False,
        device=device
    )

    with torch.no_grad():
        hidden = ab_model.AbRep(tokens).last_hidden_states

    # Remove only the batch axis. A bare squeeze() would also collapse the
    # token axis for a one-token input and return (D,) instead of (1, D).
    return hidden.squeeze(0).cpu().numpy()


def embed_prot(seq: str, drop_special_tokens: bool = False):
    """Embed one protein sequence with the module-level ProtT5 model.

    The sequence is upper-cased, the rare residues U/Z/O/B are replaced by X,
    and residues are space-separated before tokenisation. Upper-casing keeps
    lowercase input from bypassing the residue replacement.

    Args:
        seq: amino-acid sequence (one-letter codes).
        drop_special_tokens: when True, the trailing row is removed so that
            only residue rows remain. The default (False) keeps the previous
            output unchanged.

    Returns:
        NumPy array of the encoder's last hidden states for every position
        kept by the attention mask. Because the tokenizer is called with
        ``add_special_tokens=True`` the default output has one row more than
        the number of residues (the end-of-sequence token), i.e. (L + 1, 1024);
        with ``drop_special_tokens=True`` it is (L, 1024).
    """
    seq = re.sub(r"[UZOB]", "X", seq.upper())
    spaced = " ".join(seq)

    inputs = ag_tokenizer(
        spaced,
        return_tensors="pt",
        add_special_tokens=True
    ).to(device)

    with torch.no_grad():
        outputs = ag_model(**inputs)

    token_emb = outputs.last_hidden_state.squeeze(0)
    mask = inputs["attention_mask"].squeeze(0).bool()
    kept = token_emb[mask]

    if drop_special_tokens:
        # The last attended position is the end-of-sequence token.
        kept = kept[:-1]

    return kept.cpu().numpy()

def embed_esm(seq: str, drop_special_tokens: bool = False):
    """Embed one protein sequence with the module-level ESM2 model.

    The sequence is upper-cased before tokenisation so lowercase input is
    not treated as unknown tokens.

    Args:
        seq: amino-acid sequence (one-letter codes).
        drop_special_tokens: when True, the first and last rows (the
            ``<cls>`` and ``<eos>`` positions added by the tokenizer) are
            removed so that only residue rows remain. The default (False)
            keeps the previous output unchanged.

    Returns:
        NumPy array of the last hidden states with the batch axis removed:
        one row per token position, feature size 2560 for the checkpoint
        loaded by ``load_esm2``. With the default settings this is
        (L + 2, 2560); with ``drop_special_tokens=True`` it is (L, 2560).
    """
    tokens = esm_tokenizer(
        seq.upper(),
        return_tensors="pt",
        add_special_tokens=True
    )

    tokens = {k: v.to(device) for k, v in tokens.items()}

    with torch.no_grad():
        output = esm_model(**tokens)

    hidden = output.last_hidden_state.squeeze(0)

    if drop_special_tokens:
        # A single unpadded sequence is laid out as <cls> residues... <eos>.
        hidden = hidden[1:-1]

    return hidden.cpu().numpy()

In [7]:

# Only the SeqVec embedder is instantiated here. The three loaders below are
# left commented out, so embed_prot / embed_ablang / embed_esm cannot be
# called until the matching line is re-enabled.
# ag_model, ag_tokenizer = load_prott5()
# ab_model = load_ablang2()
# esm_model, esm_tokenizer = load_esm2()
seqvec_emb = SeqVecEmbedder()

In [11]:
# ============================================================
# SAVE EMBEDDINGS (HDF5)
# ============================================================

OUTPUT_PATH = "./data/SabDab/SabDab_SeqVec_Full.h5"
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
print(f"[INFO] Saving embeddings to:\n{OUTPUT_PATH}\n")

# One gzip-compressed dataset per sequence, keyed by the sequence string.
# Mode "w" truncates any existing file, so re-running this cell recomputes
# every embedding from scratch.
with h5py.File(OUTPUT_PATH, "w") as hf:
    for seq in tqdm(seq_list):
        # Flatten to 1-D as before, but hand h5py a NumPy array directly
        # instead of a Python list of floats. The float64 cast keeps the
        # stored dtype identical to what the list-based version wrote.
        # NOTE(review): storing the embedder's native dtype and shape would
        # shrink the file, but would change what existing readers receive.
        emb = seqvec_emb.embed(seq).reshape(-1).astype("float64")
        hf.create_dataset(seq, data=emb, compression="gzip")

[INFO] Saving embeddings to:
./data/SabDab/SabDab_SeqVec_Full.h5



100%|██████████| 2722/2722 [40:28<00:00,  1.12it/s]  


In [9]:
# ============================================================
# SANITY CHECK
# ============================================================

# Embed the first sequence once more and report the raw (unflattened) shape.
print("Sample embedding shape:", seqvec_emb.embed(seq_list[0]).shape)

Sample embedding shape: (3, 226, 1024)


In [10]:
# Read back a single stored embedding as a smoke test of the HDF5 file.
load_seq = dict()

with h5py.File(OUTPUT_PATH, mode="r") as hf:
    for seq in tqdm(hf.keys()):
        load_seq[seq] = hf[seq][:]
        # Stop after the first key; one entry is enough for the check.
        break

  0%|          | 0/2722 [00:00<?, ?it/s]

In [13]:
# Display the single entry read back from the HDF5 file (a flat 1-D array).
load_seq

{'AAALTQPLSVSVSPGQTAIFTCSGDNLGDKYVYWFQQRPGQSPALLIYQDNKRPSGIPERFSGSNSGNTATLTISGTQSTDEADYYCQTWDSTVVFGGGTKLQVQLQESGPGLVAASDTLSLTCTVSGGSLAAFYWSWIRQAPGKGLEWIGYIYYSGSAYYSPSLESRVTMSDAAAAAAAAAAAAAVYYCVRAAAAAAFASWGQGTLVTV': array([ 0.01722717, -0.01322937, -0.05004883, ...,  0.06414795,
        -0.12078857, -0.19470215])}

In [None]:
# load_seq['DIVITQSPSSMYASLGERVTITCKASQDINSYLSWFQQKPGKSPKTLIYRANRLVDGVPSRFSGSGSGQDYSLTISSLEYEDMGIYYCLQYDEFPLTFGAGTKLELKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGECEVQLQESGPELVKPGASVKIPCKASGYTFTDYNMDWVKQSHGKSLEWIGDINPNNGGTIYNQKFKGKATLTVDKSSSTAYMELRSLTSEDTAVYYCARPDYYGSYGWYFDVWGTGTTVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKVEPKS']