# Embeddings

We explore the added value of embeddings in the prediction of antigen specificity.

We develop `script_11_compute_embeddings.py` to compute embeddings for the sequences in the global dataset, for later use as model inputs.

From [`bio-embeddings`](https://docs.bioembeddings.com/v0.2.3/#):
- Preferred model: `prottrans_t5_xl_u50`, followed by `esm1b`.

Notes
- Installing `bio-embeddings` with pip is cumbersome. `jsonnet` failed to install via pip and had to be installed separately through conda; afterwards, installing `bio-embeddings[all]` worked.
- Download model files separately, check link from [my other github repo](https://github.com/ursueugen/ir-ageing/blob/main/02a_aminoacid_embeddings.ipynb).
    - Downloading is slow, leave overnight (~8GB per model, for the large ones).
    - Links for downloading models
        - esm1b:
            - model_file: http://data.bioembeddings.com/public/embeddings/embedding_models/esm1b/esm1b_t33_650M_UR50S.pt
        - prottrans_t5_xl_u50:
            - model_directory: http://data.bioembeddings.com/public/embeddings/embedding_models/t5/prottrans_t5_xl_u50.zip
            - half_precision_model_directory: http://data.bioembeddings.com/public/embeddings/embedding_models/t5/half_prottrans_t5_xl_u50.zip

In [1]:
from pathlib import Path
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import math


import torch
import torch.nn as nn
import torch.nn.functional as F
# import torchvision
# import torchvision.transforms as transforms


from NegativeClassOptimization import ml
from NegativeClassOptimization import utils
from NegativeClassOptimization import preprocessing
from NegativeClassOptimization import pipelines

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Antigen pair for the 1-vs-1 binary dataset (`ag_pos` / `ag_neg` as passed to the loader).
ag_pos = "3VRL"
ag_neg = "1ADQ"
num_samples = 200
train_frac = 0.8  # fraction of rows assigned to df_train; the remainder goes to df_test

df = utils.load_1v1_binary_dataset(ag_pos, ag_neg, num_samples=num_samples)

# Split on the number of rows actually returned rather than on `num_samples`,
# so the split stays train_frac / (1 - train_frac) even if the loader returns
# a different row count than requested.
# NOTE(review): this is an ordered (unshuffled) split — confirm the loader shuffles rows.
n_train = int(len(df) * train_frac)
df_train = df.iloc[:n_train]
df_test = df.iloc[n_train:]

## Adding embeddings

In [3]:
# First training slide, used below as a single-sequence embedding example.
slide = df_train["Slide"].iat[0]
slide

'CARHLLWYFDV'

In [4]:
# Load both pretrained embedders; later cells reuse these objects.
esm1b_embedder = preprocessing.load_embedder("ESM1b")
pt_embedder = preprocessing.load_embedder("ProtTransT5XLU50")

# Per-residue embeddings: one row per residue of `slide`
# (e.g. (11, 1280) for ESM1b on an 11-residue slide).
esm1b_embedding = esm1b_embedder.embed(slide)
pt_embedding = pt_embedder.embed(slide)

# Per-protein embeddings: the per-residue matrix reduced to a single vector.
# Previously these calls were made but their results were discarded.
esm1b_embedding_per_prot = esm1b_embedder.reduce_per_protein(esm1b_embedding)
pt_embedding_per_prot = pt_embedder.reduce_per_protein(pt_embedding)

print(esm1b_embedding.shape, pt_embedding.shape)
print(np.shape(esm1b_embedding_per_prot), np.shape(pt_embedding_per_prot))

(11, 1280) (11, 1024)


Demo: computing embeddings for slides taken from the global dataframe.

In [5]:
# Global dataset; note this rebinds `df` (previously the 1-vs-1 subset).
df = utils.load_global_dataframe()
n_rows, n_cols = df.shape
print((n_rows, n_cols))
df.iloc[:1]

(460483, 8)


Unnamed: 0,ID_slide_Variant,CDR3,Best,Slide,Energy,Structure,UID,Antigen
0,5319791_04a,CARSAAFITTVGWYFDVW,True,AAFITTVGWYF,-94.7,128933-BRRSLUDUUS,1ADQ_5319791_04a,1ADQ


In [6]:
# Demo on a few slides: compute per-residue and per-protein embeddings with both
# models and store them as plain lists, keyed by slide sequence.
n_demo_slides = 3

slide_embeddings_per_residue = {}
slide_embeddings_per_prot = {}

# A dedicated loop variable avoids silently overwriting the global `slide`
# set in an earlier cell.
for demo_slide in df["Slide"].iloc[:n_demo_slides]:

    esm1b_emb = esm1b_embedder.embed(demo_slide)
    esm1b_emb_per_prot = esm1b_embedder.reduce_per_protein(esm1b_emb)

    pt_emb = pt_embedder.embed(demo_slide)
    pt_emb_per_prot = pt_embedder.reduce_per_protein(pt_emb)

    slide_embeddings_per_residue[demo_slide] = {
        "ESM1b": esm1b_emb.tolist(),
        "ProtTransT5XLU50": pt_emb.tolist(),
    }
    slide_embeddings_per_prot[demo_slide] = {
        "ESM1b": esm1b_emb_per_prot.tolist(),
        "ProtTransT5XLU50": pt_emb_per_prot.tolist(),
    }

# These demo dicts are not persisted; the full-dataset per-protein embeddings used
# later are read from a pickle on disk (presumably produced by
# script_11_compute_embeddings.py).

## VSH8 hand-engineered embeddings

In [None]:
# Hand-engineered VSHE_1..VSHE_8 descriptors, one row per amino-acid letter.
vsh8_matrix = preprocessing.get_vsh8_embedding_matrix()
vsh8_matrix

Unnamed: 0,VSHE_1,VSHE_2,VSHE_3,VSHE_4,VSHE_5,VSHE_6,VSHE_7,VSHE_8
A,0.15,1.11,1.35,0.92,0.02,0.91,0.36,0.48
R,1.47,1.45,1.24,1.27,1.55,1.47,1.3,0.83
N,0.99,0.0,0.37,0.69,0.55,0.85,0.73,0.8
D,1.15,0.67,0.41,0.01,2.68,1.31,0.03,0.56
C,0.18,1.67,0.46,0.21,0.0,1.2,1.61,0.19
Q,0.96,0.12,0.18,0.16,0.09,0.42,0.2,0.41
E,1.18,0.4,0.1,0.36,2.16,0.17,0.91,0.02
G,0.2,1.53,2.63,2.28,0.53,1.18,2.01,1.34
H,0.43,0.25,0.37,0.19,0.51,1.28,0.93,0.65
I,1.27,0.14,0.3,1.8,0.3,1.61,0.16,0.13


## Simple network on embeddings vs one-hot

We evaluate the added value of embeddings by training a simple network (`ml.SNN`) on them, for comparison with one-hot encoding.

In [5]:
# Precomputed per-protein embeddings for the global dataset; presumably the same
# {slide: {model_name: list}} layout as `slide_embeddings_per_prot` above.
# `pickle` and `Path` are already imported in the first code cell.
# NOTE: pickle.load can execute arbitrary code — only load files you produced yourself.
emb_dict_path = Path("../data/slack_1/global/embeddings/slide_embeddings_per_prot.pkl")

with emb_dict_path.open("rb") as f:
    emb_dict = pickle.load(f)

In [25]:
from typing import List, Optional


def embed_slide(
    slide: str,
    emb_dict: dict,
    emb_choice: str = "ProtTransT5XLU50",
) -> Optional[np.ndarray]:
    """Look up the precomputed embedding of a slide.

    Args:
        slide: Slide sequence, used as the key into `emb_dict`.
        emb_dict: Mapping ``{slide: {model_name: embedding}}``.
        emb_choice: Which model's embedding to return. Defaults to
            ``"ProtTransT5XLU50"`` (the previously hard-coded choice);
            ``"ESM1b"`` is the other model stored by the cells above.

    Returns:
        The embedding as a NumPy array, or ``None`` if `slide` is not in
        `emb_dict`. Callers are expected to filter out the ``None`` values.

    Raises:
        KeyError: if `slide` is present but has no entry for `emb_choice`.
    """
    if slide not in emb_dict:
        # Missing slides are signalled with None rather than an exception so
        # that Series.apply over many slides does not abort.
        return None
    return np.array(emb_dict[slide][emb_choice])


slide = "AAFITTVGWYF"  # Slide of row 0 of the global dataframe shown above
emb = embed_slide(slide, emb_dict)
# embed_slide returns None for slides missing from emb_dict; fail with a clear
# message instead of an opaque TypeError from len(None).
assert emb is not None, f"Slide {slide} not in embedding dictionary."
len(emb)

1024

In [7]:
# Small network (ml.SNN, 10 hidden units) whose input size is taken from the
# embedding itself, so the cell works for any model (1024-dim ProtTransT5XLU50,
# 1280-dim ESM1b) without editing a hard-coded dimension.
torch.manual_seed(0)  # reproducible weight initialisation for this demo
model = ml.SNN(
    num_hidden_units=10,
    input_dim=len(emb),
)

# utils.num_trainable_params(model) ~ 10k
# Forward pass on a single example: reshape the vector to a batch of one row.
model(torch.tensor(emb).reshape(1, -1))

tensor([[0.5382]], grad_fn=<SigmoidBackward0>)

In [32]:
# `pipelines` is imported in the first code cell alongside the other project modules.
# step_1_process_data (next cell) presumably populates `pipe.df_train_val`, used below.
pipe = pipelines.BinaryclassPipeline()

In [3]:
# Reuse the antigen pair and sample size defined in the first data cell instead of
# repeating the literals, so both parts of the notebook stay in sync.
pipe.step_1_process_data(
    ag_pos=ag_pos,
    ag_neg=ag_neg,
    N=num_samples,
)



In [17]:
# Only slides present in emb_dict yield an embedding; the length shows how many
# train/val slides are covered by the precomputed embeddings.
embs = [
    vec
    for vec in (embed_slide(s, emb_dict) for s in pipe.df_train_val["Slide"])
    if vec is not None
]
len(embs)

92

In [26]:
def embed_df(df: pd.DataFrame, emb_dict: dict) -> pd.DataFrame:
    """Return a copy of `df` with an ``embedding`` column added.

    Each value is ``embed_slide(row["Slide"], emb_dict)``: a NumPy array, or
    ``None`` when the slide has no precomputed embedding.

    The input frame is not modified; previously the column was written into
    `df` in place, which silently altered the caller's frame
    (e.g. ``pipe.df_train_val``).
    """
    embeddings = df["Slide"].apply(lambda x: embed_slide(x, emb_dict))
    return df.assign(embedding=embeddings)


df_e = embed_df(pipe.df_train_val, emb_dict)

In [35]:
# Drop slides without an embedding, then stack the remaining vectors into one array.
embedded_only = df_e["embedding"].dropna()
X_emb = preprocessing.arr_from_list_series(embedded_only)
X_emb

array([[ 0.04806438,  0.04111566, -0.02125612, ...,  0.0954366 ,
         0.01638312, -0.02267723],
       [ 0.10578969,  0.05308867, -0.27451834, ...,  0.20939325,
        -0.14231008, -0.10286613],
       [ 0.02021367,  0.09454548, -0.16444129, ...,  0.21021718,
        -0.06837969, -0.06736568],
       ...,
       [ 0.18288158,  0.04464582, -0.09879034, ...,  0.14656641,
         0.0064978 , -0.0356589 ],
       [ 0.12855272,  0.07887488, -0.04265216, ...,  0.17842086,
        -0.04470585, -0.08178511],
       [ 0.08736168,  0.08419388, -0.20473045, ...,  0.16757761,
         0.00387542, -0.06436346]])