# Embeddings

We explore the added value of protein language model embeddings for predicting antigen specificity.

We develop `script_11_compute_embeddings.py` to compute embeddings for the sequences in the global dataset, for later use as model inputs.

Embedders from [`bio-embeddings`](https://docs.bioembeddings.com/v0.2.3/#), in order of preference:
- `prottrans_t5_xl_u50`
- `esm1b`

Notes
- Installing `bio-embeddings` with pip is cumbersome. Installing `jsonnet` with pip failed, so I installed it separately through conda. Afterwards, installing `bio-embeddings[all]` worked.
- Download the model files separately; the links are also in [my other GitHub repo](https://github.com/ursueugen/ir-ageing/blob/main/02a_aminoacid_embeddings.ipynb).
    - Downloading is slow (~8 GB per model for the large ones), so leave it running overnight.
    - Model download links:
        - esm1b:
            - model_file: http://data.bioembeddings.com/public/embeddings/embedding_models/esm1b/esm1b_t33_650M_UR50S.pt
        - prottrans_t5_xl_u50:
            - model_directory: http://data.bioembeddings.com/public/embeddings/embedding_models/t5/prottrans_t5_xl_u50.zip
            - half_precision_model_directory: http://data.bioembeddings.com/public/embeddings/embedding_models/t5/half_prottrans_t5_xl_u50.zip
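Because the files are large, a download helper that skips files already on disk and unpacks the T5 archive helps. Below is a minimal standard-library sketch; the `models/` layout and the function names are hypothetical, and it is not part of `bio-embeddings`. Depending on how the archive is laid out, the directory to pass to the embedder may be a subfolder of the extracted one.

```python
from __future__ import annotations

import urllib.request
import zipfile
from pathlib import Path
from urllib.parse import urlparse

# URLs copied from the notes above (half-precision T5 omitted; add it if needed).
MODEL_URLS = {
    "esm1b": [
        "http://data.bioembeddings.com/public/embeddings/embedding_models/esm1b/esm1b_t33_650M_UR50S.pt",
    ],
    "prottrans_t5_xl_u50": [
        "http://data.bioembeddings.com/public/embeddings/embedding_models/t5/prottrans_t5_xl_u50.zip",
    ],
}


def local_path(url: str, root: Path) -> Path:
    """Destination for `url` under `root`, keeping the server-side file name."""
    return root / Path(urlparse(url).path).name


def unpack_if_zip(path: Path) -> Path:
    """Extract a .zip next to itself (once) and return the extraction dir."""
    if path.suffix != ".zip":
        return path
    target = path.with_suffix("")
    if not target.exists():
        with zipfile.ZipFile(path) as zf:
            zf.extractall(target)
    return target


def fetch_model(name: str, root: Path) -> list[Path]:
    """Download (if missing) and unpack every file listed for one model."""
    root.mkdir(parents=True, exist_ok=True)
    results = []
    for url in MODEL_URLS[name]:
        dest = local_path(url, root)
        if not dest.exists():  # avoid re-downloading multi-GB files
            urllib.request.urlretrieve(url, dest)
        results.append(unpack_if_zip(dest))
    return results
```

For example, `fetch_model("esm1b", Path("models"))` would leave `models/esm1b_t33_650M_UR50S.pt` on disk. An interrupted download leaves a partial file that the existence check would then skip, so delete it before retrying.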

```python
from pathlib import Path
import pickle
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F

from NegativeClassOptimization import ml
from NegativeClassOptimization import utils
from NegativeClassOptimization import preprocessing
```

```python
ag_pos = "3VRL"
ag_neg = "1ADQ"
df_train, df_test = utils.load_sample_binary_dataset(ag_pos, ag_neg, num_samples=20000)
```
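To make the setup explicit, here is one plausible way to build such a positive-vs-negative dataset from the global dataframe. This is a hypothetical stand-in, not the actual `utils.load_sample_binary_dataset`: it assumes `num_samples` is split evenly between the two antigens, labels the positive antigen `y=1`, and uses a simple shuffled 80/20 split. The real helper may differ, for example in deduplication or split ratio.

```python
from __future__ import annotations

import pandas as pd


def sample_binary_dataset(
    df: pd.DataFrame,
    ag_pos: str,
    ag_neg: str,
    num_samples: int,
    test_frac: float = 0.2,
    seed: int = 42,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Hypothetical sketch: sample both antigens, label, shuffle, split."""
    half = num_samples // 2  # assumption: balanced classes
    pos = df[df["Antigen"] == ag_pos].sample(n=half, random_state=seed)
    neg = df[df["Antigen"] == ag_neg].sample(n=half, random_state=seed)
    data = pd.concat([pos.assign(y=1), neg.assign(y=0)], ignore_index=True)
    data = data.sample(frac=1.0, random_state=seed).reset_index(drop=True)  # shuffle
    n_test = int(len(data) * test_frac)
    test = data.iloc[:n_test].reset_index(drop=True)
    train = data.iloc[n_test:].reset_index(drop=True)
    return train, test
```

With the global dataframe loaded further below, `sample_binary_dataset(df, "3VRL", "1ADQ", 20000)` would return 16,000 training and 4,000 test rows under these assumptions.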

## Adding embeddings

```python
slide = df_train["Slide"].iloc[0]
slide
```

```
'TLLFPHWYFDV'
```

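```python
# Per-residue embeddings: one row per amino acid of the slide
esm1b_embedder = preprocessing.load_embedder("ESM1b")
esm1b_embedding = esm1b_embedder.embed(slide)
# Per-protein embedding: a single vector for the whole slide (keep the result)
esm1b_per_protein = esm1b_embedder.reduce_per_protein(esm1b_embedding)

pt_embedder = preprocessing.load_embedder("ProtTransT5XLU50")
pt_embedding = pt_embedder.embed(slide)
pt_per_protein = pt_embedder.reduce_per_protein(pt_embedding)

print(esm1b_embedding.shape, pt_embedding.shape)
```

```
(11, 1280) (11, 1024)
```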
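The per-residue arrays have one row per amino acid of the 11-residue slide (1280 dimensions for ESM1b, 1024 for ProtT5). `reduce_per_protein` collapses them into one fixed-size vector per sequence. My understanding is that `bio-embeddings` does this by averaging over residues for these two embedders, but treat that as an assumption and check the wrapper in `preprocessing`. The sketch reproduces a mean-pooling reduction on random stand-in arrays with the same shapes.

```python
import numpy as np

# Random stand-ins with the per-residue shapes printed above (11 residues).
rng = np.random.default_rng(0)
esm1b_per_residue = rng.normal(size=(11, 1280))
pt_per_residue = rng.normal(size=(11, 1024))


def mean_pool(per_residue: np.ndarray) -> np.ndarray:
    """Average over the residue axis: (L, D) -> (D,)."""
    return per_residue.mean(axis=0)


esm1b_per_protein = mean_pool(esm1b_per_residue)
pt_per_protein = mean_pool(pt_per_residue)
print(esm1b_per_protein.shape, pt_per_protein.shape)  # (1280,) (1024,)
```

Mean pooling discards residue order, which matters if position within the slide carries signal; the per-residue arrays keep it.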


Demo: computing embeddings for the slides in a dataframe.

```python
df = utils.load_global_dataframe()
print(df.shape)
df.head(1)
```

```
(460483, 8)
```

| | ID_slide_Variant | CDR3 | Best | Slide | Energy | Structure | UID | Antigen |
|---|---|---|---|---|---|---|---|---|
| 0 | 5319791_04a | CARSAAFITTVGWYFDVW | True | AAFITTVGWYF | -94.7 | 128933-BRRSLUDUUS | 1ADQ_5319791_04a | 1ADQ |
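The global dataframe has 460,483 rows, and the same slide may appear in more than one of them. Since each slide goes through two large models, it is cheaper to embed every distinct slide once and then look the vectors up per row. Below is a minimal sketch with a toy embedding function; `embed_unique`, `feature_matrix` and `toy_embed` are hypothetical names, and a real `embed` could wrap `reduce_per_protein(embedder.embed(seq))`.

```python
from typing import Callable, Dict

import numpy as np
import pandas as pd


def embed_unique(slides: pd.Series, embed: Callable[[str], np.ndarray]) -> Dict[str, np.ndarray]:
    """Call `embed` once per distinct slide and return {slide: vector}."""
    return {slide: embed(slide) for slide in slides.unique()}


def feature_matrix(slides: pd.Series, cache: Dict[str, np.ndarray]) -> np.ndarray:
    """One row per entry of `slides`, in the same order as the dataframe."""
    return np.stack([cache[slide] for slide in slides])


calls = []


def toy_embed(seq: str) -> np.ndarray:
    # Hypothetical placeholder for a real embedder call; records each call.
    calls.append(seq)
    return np.array([len(seq), seq.count("F")], dtype=float)


# Toy rows: the first slide is listed twice.
toy_df = pd.DataFrame({"Slide": ["AAFITTVGWYF", "TLLFPHWYFDV", "AAFITTVGWYF"]})
cache = embed_unique(toy_df["Slide"], toy_embed)
X = feature_matrix(toy_df["Slide"], cache)
print(len(calls), X.shape)  # 2 (3, 2)
```

Keeping the cache as a plain `{slide: vector}` dictionary also matches the structure of the loop below, so it can be pickled and reused across antigen pairs.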


```python
slide_embeddings_per_residue = {}
slide_embeddings_per_prot = {}

# Embed the slides of the first 100 rows; .unique() skips repeated slides
for slide in df["Slide"].iloc[:100].unique():
    esm1b_emb = esm1b_embedder.embed(slide)
    esm1b_emb_per_prot = esm1b_embedder.reduce_per_protein(esm1b_emb)

    pt_emb = pt_embedder.embed(slide)
    pt_emb_per_prot = pt_embedder.reduce_per_protein(pt_emb)

    # Store embeddings as plain Python lists, keyed by slide and model
    slide_embeddings_per_residue[slide] = {
        "ESM1b": esm1b_emb.tolist(),
        "ProtTransT5XLU50": pt_emb.tolist(),
    }
    slide_embeddings_per_prot[slide] = {
        "ESM1b": esm1b_emb_per_prot.tolist(),
        "ProtTransT5XLU50": pt_emb_per_prot.tolist(),
    }

# with open("test.pkl", "wb") as f:
#     pickle.dump(slide_embeddings_per_residue, f)
```
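The loop stores embeddings as nested lists keyed by slide and model name. To feed them to a model, they need to become arrays whose rows line up with a dataframe such as `df_train`. Below is a minimal sketch of that conversion plus a pickle round-trip; the function names are hypothetical. Flattening per-residue embeddings into one long vector assumes every slide has the same length (11 in the example above); `np.stack` raises `ValueError` otherwise, in which case padding or per-protein vectors are needed.

```python
import pickle
from pathlib import Path
from typing import Dict, Iterable, List

import numpy as np

# Layout produced by the loop: {slide: {model_name: nested list}}
EmbeddingStore = Dict[str, Dict[str, list]]


def to_matrix(store: EmbeddingStore, slides: Iterable[str], model: str,
              flatten: bool = False) -> np.ndarray:
    """Stack the embeddings of `slides` (in the given order) into one array."""
    rows: List[np.ndarray] = [np.asarray(store[s][model], dtype=np.float32) for s in slides]
    if flatten:
        # (L, D) per-residue embedding -> vector of length L * D
        rows = [r.reshape(-1) for r in rows]
    return np.stack(rows)  # ValueError if the rows have different shapes


def save_store(store: EmbeddingStore, path: Path) -> None:
    with open(path, "wb") as f:
        pickle.dump(store, f)


def load_store(path: Path) -> EmbeddingStore:
    with open(path, "rb") as f:
        return pickle.load(f)
```

For example, `to_matrix(slide_embeddings_per_prot, df_train["Slide"], "ProtTransT5XLU50")` would give an `(n_rows, 1024)` matrix, provided every slide in `df_train` has already been embedded (the demo above only covers the first 100 rows of the global dataframe).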