In [39]:
import esm
import torch
import pandas as pd
import tqdm

In [40]:
# Load the pretrained ESM-2 checkpoint `esm2_t33_650M_UR50D`; the loader returns
# the network together with the alphabet whose batch converter tokenizes
# sequences for it.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# Switch the network to evaluation mode. Being the cell's last expression, the
# returned module is also echoed as the cell output.
model.eval()

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bia

In [41]:
# Load the source table from a path relative to the notebook's working
# directory; later cells use its "Uniprot_ID" and "full_sequence" columns.
df = pd.read_csv('data/OsmoticStress_with_binary_positions.csv')

In [42]:
# Keep one row per distinct (Uniprot_ID, full_sequence) pair, then renumber the
# rows 0..n-1 so the index has no gaps left by the dropped duplicates.
df_clean = (
    df.loc[:, ["Uniprot_ID", "full_sequence"]]
    .drop_duplicates()
    .reset_index(drop=True)
)

In [43]:
# (rows, columns) of the de-duplicated frame.
df_clean.shape

(2746, 2)

In [44]:
# Drop every row whose `full_sequence` value is a float (presumably the NaN
# placeholder pandas uses for empty CSV cells -- confirm against the raw file),
# then force the remaining values to `str`.
# `.assign` builds a new frame instead of writing a column into the filtered
# result, which avoids pandas' chained-assignment (SettingWithCopy) warning.
# The index is intentionally not reset here, so it keeps gaps where rows were
# removed.
df_clean = (
    df_clean[~df_clean['full_sequence'].apply(lambda x: isinstance(x, float))]
    .assign(full_sequence=lambda frame: frame['full_sequence'].astype(str))
)

In [45]:
# Show the filtered frame (pandas abbreviates long frames when rendering).
df_clean

Unnamed: 0,Uniprot_ID,full_sequence
0,P15703,MRFSTTLATAATALFFTASQVSAIGELAFNLGVKNNDGTCKSTSDY...
1,P38174,MTDAEIENSPASDLKELNLENEGVEQQDQAKADESDPVESKKKKNK...
2,P26637,MSSGLVLENTARRDALIAIEKKYQKIWAEEHQFEIDAPSIEDEPIT...
3,P06169,MSEITLGKYLFERLKQVNVNTVFGLPGDFNLSLLDKIYEVEGMRWA...
4,P00359,MVRVAINGFGRIGRLVMRIALSRPNVEVVALNDPFITNDYAAYMFK...
...,...,...
2741,P38887,MILKLVHCLVALTGLIFAKPYQQQQAVLAPSQDVPLRDIHIGDINF...
2742,P53093,MSYGREDTTIEPDFIEPDAPLAASGGVADNIGGTMQNSGSRGTLDE...
2743,Q04772,MSSDGMNRDVSNSKPNVRFAAPQRLSVAHPAISSPLHMPMSKSSRK...
2744,P21192,MDNVVDPWYINPSGFAKDTQDEEYVQHHDNVNPTIPPPDNYILNNE...


In [46]:
import tqdm
import numpy as np
import torch

def generate_embeddings(model, alphabet, sequences, repr_layer=33):
    """Compute per-token embeddings for each sequence, one forward pass each.

    Parameters
    ----------
    model : callable
        Invoked as ``model(tokens, repr_layers=[repr_layer])`` and expected to
        return a mapping whose ``"representations"`` entry is indexed by layer
        number. When the object exposes ``parameters()`` (as a
        ``torch.nn.Module`` does), the tokens are moved to the device of its
        first parameter before the call; otherwise they stay where the batch
        converter put them (CPU is assumed).
    alphabet : object
        Must provide ``get_batch_converter()``. The converter is called with a
        list holding a single ``(label, sequence)`` pair and must return a
        ``(labels, strings, tokens)`` triple.
    sequences : iterable of str
        Sequences to embed. Each one is processed on its own (batch size 1).
    repr_layer : int, default 33
        Layer whose representations are requested and returned.

    Returns
    -------
    list of numpy.ndarray
        One 2-D array per input sequence, in input order, holding the
        representation of every token in the converter's output.

    Notes
    -----
    NOTE(review): the rows cover every token the batch converter emits. If the
    converter adds start/end markers, each array is longer than its sequence --
    confirm before aligning rows with per-residue labels.

    The model is not switched to evaluation mode here; the caller is expected
    to have done that already.
    """
    batch_converter = alphabet.get_batch_converter()

    # Run the forward pass on whatever device the model lives on, so the
    # function keeps working after the model has been moved to a GPU.
    first_param = (
        next(iter(model.parameters()), None) if hasattr(model, "parameters") else None
    )
    device = first_param.device if first_param is not None else torch.device("cpu")

    embeddings = []
    for sequence in tqdm.tqdm(sequences, desc="Generating Embeddings"):
        # The label (0) is a placeholder; only the token tensor is used.
        _, _, batch_tokens = batch_converter([(0, sequence)])
        batch_tokens = batch_tokens.to(device)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[repr_layer])

        token_embeddings = results['representations'][repr_layer]

        # Index the batch axis explicitly. A bare squeeze() would also drop any
        # other axis that happens to have size 1 and change the array's rank.
        # detach()/cpu() make the numpy conversion valid for any device.
        embeddings.append(token_embeddings[0].detach().cpu().numpy())

    return embeddings

# Work on an independent copy so the column added to df_test later does not
# touch df_clean.
df_test = df_clean.copy()

In [47]:
# Show the working copy (same content as df_clean at this point).
df_test

Unnamed: 0,Uniprot_ID,full_sequence
0,P15703,MRFSTTLATAATALFFTASQVSAIGELAFNLGVKNNDGTCKSTSDY...
1,P38174,MTDAEIENSPASDLKELNLENEGVEQQDQAKADESDPVESKKKKNK...
2,P26637,MSSGLVLENTARRDALIAIEKKYQKIWAEEHQFEIDAPSIEDEPIT...
3,P06169,MSEITLGKYLFERLKQVNVNTVFGLPGDFNLSLLDKIYEVEGMRWA...
4,P00359,MVRVAINGFGRIGRLVMRIALSRPNVEVVALNDPFITNDYAAYMFK...
...,...,...
2741,P38887,MILKLVHCLVALTGLIFAKPYQQQQAVLAPSQDVPLRDIHIGDINF...
2742,P53093,MSYGREDTTIEPDFIEPDAPLAASGGVADNIGGTMQNSGSRGTLDE...
2743,Q04772,MSSDGMNRDVSNSKPNVRFAAPQRLSVAHPAISSPLHMPMSKSSRK...
2744,P21192,MDNVVDPWYINPSGFAKDTQDEEYVQHHDNVNPTIPPPDNYILNNE...


In [48]:
# Pull the sequences out as a plain Python list, in row order, and embed them
# one at a time.
# NOTE: this is a long-running cell -- the recorded progress bar shows roughly
# 10 s per sequence.
sequences = df_test['full_sequence'].tolist()
embeddings = generate_embeddings(model, alphabet, sequences)

Generating Embeddings:  24%|██▍       | 648/2697 [1:55:55<5:46:45, 10.15s/it] 

In [None]:
# Sanity check: the number of embeddings and the number of rows in df_test are
# printed on separate lines so they can be compared before the column is added.
print(len(embeddings), len(df_test), sep="\n")

10
10


In [None]:
# Inspect the scalar type of a single value: first sequence, first token row,
# first feature.
type(embeddings[0][0][0])

numpy.float32

In [None]:
# Attach the embeddings as a new column, converting each array to nested Python
# lists (one inner list per token row). The assignment is positional, so
# `embeddings` must hold exactly one entry per row of df_test.
# NOTE(review): nested lists take far more memory than the float32 arrays;
# consider storing the arrays directly if nothing downstream needs lists.
df_test['full_embedding'] = [e.tolist() for e in embeddings]

In [None]:
# Persist the frame, including the `full_embedding` column, as a pickle file
# (path is relative to the notebook's working directory).
df_test.to_pickle('data/embeddings_test.pkl')