# Get ESM protein embeddings for TP53 variants

William Colgan, May 4, 2022

### Setup

In [1]:
import torch
import numpy as np
import pandas as pd

### Load data

In [33]:
# Variant-level table; later cells read its AA_wt, AA_variant, Position and Allele columns.
df = pd.read_csv(filepath_or_buffer='data/A549_Nutlin_Zscores.csv')

In [34]:
# remove stop codons (encoded as 'Z') from both the wild-type and variant columns
df = df[(df['AA_wt'] != 'Z') & (df['AA_variant'] != 'Z')]

# get wildtype amino acids: one residue per position, ordered by Position.
# The embedding cell slices tp53_wt at Position - 1, so the string must be
# sorted by Position rather than relying on the row order of the CSV.
wt_by_position = (
    df.drop_duplicates(subset=['Position'])
    .sort_values('Position')
)

# Fail loudly if a position is missing or numbering does not start at 1;
# otherwise every variant sequence built from tp53_wt would be shifted.
positions = wt_by_position['Position'].tolist()
assert positions == list(range(1, len(positions) + 1)), (
    "Position values must be contiguous and start at 1 so that "
    "tp53_wt[Position - 1] is the wild-type residue"
)

tp53_wt = "".join(wt_by_position['AA_wt'])


### Load ESM model

In [24]:
# Fetch the pretrained ESM-2 checkpoint and its token alphabet through torch.hub
# (downloaded on first use, served from the local hub cache afterwards).
model, alphabet = torch.hub.load(
    repo_or_dir="facebookresearch/esm:main",
    model="esm2_t33_650M_UR50D",
)

# Turns (label, sequence) pairs into the token tensor the model consumes.
batch_converter = alphabet.get_batch_converter()

# Inference mode: dropout off, so repeated runs give the same embeddings.
model.eval()

Using cache found in /Users/william/.cache/torch/hub/facebookresearch_esm_main


ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

### Get embeddings

In [36]:
# Compute one mean-pooled ESM-2 embedding per variant row of df.
# Each variant sequence is the wild-type string (tp53_wt) with the residue at
# Position replaced by AA_variant; the result is written to var_embeddings[i].
EMBED_DIM = 1280        # width of the model's representations (see model repr)
REPR_LAYER = 33         # final transformer layer of esm2_t33
CHECKPOINT_EVERY = 500  # rows between progress prints / intermediate saves

var_embeddings = np.zeros((df.shape[0], EMBED_DIM))
for i in range(df.shape[0]):
    row = df.iloc[i]  # look the row up once instead of once per column
    var = row["AA_variant"]
    pos = int(row["Position"]) - 1  # Position is 1-based; string index is 0-based
    seq = tp53_wt[:pos] + var + tp53_wt[pos + 1:]
    batch_data = [(row['Allele'], seq)]
    batch_labels, batch_strs, batch_tokens = batch_converter(batch_data)
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[REPR_LAYER], return_contacts=False)
        token_embeddings = results["representations"][REPR_LAYER].numpy()
    # The batch converter wraps the sequence in a start and an end special
    # token. Average over the residue tokens only (indices 1..len(seq)), as in
    # the ESM README example, so the special tokens do not dilute the embedding.
    # NOTE: this changes the values relative to files saved by the earlier
    # all-token average; regenerate downstream results.
    var_embeddings[i,:] = token_embeddings[0, 1:len(seq) + 1, :].mean(0)
    if i % CHECKPOINT_EVERY == 0:
        print(i)
        # Intermediate save so a crash does not lose all completed rows.
        np.save('./data/t33_650M_TP53_Embeddings.npy', var_embeddings)
np.save('./data/t33_650M_TP53_Embeddings.npy', var_embeddings)

0
500
1000
1500
2000
2500
3000
3500
4000
4500
5000
5500
6000
6500
7000
7500
