## Using ESM-1b (protein language model) embeddings as input for a classifier
## GitHub: https://github.com/facebookresearch/esm

In [1]:
import esm # this is a module already in the SE3 kernel on digs, but the github link has the information to pip install esm 


In [2]:
import torch 

# Load the pretrained ESM-1b model together with its alphabet (token vocabulary).
# The checkpoint is fetched into the local torch hub cache when it is not already there.
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
# Callable that turns a list of (name, sequence) pairs into the token tensor the model consumes.
batch_converter = alphabet.get_batch_converter() 
# As the last expression of the cell, this also displays the model architecture.
model.eval()  # inference mode 

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm1b_t33_650M_UR50S.pt" to /home/sanaam/.cache/torch/hub/checkpoints/esm1b_t33_650M_UR50S.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm1b_t33_650M_UR50S-contact-regression.pt" to /home/sanaam/.cache/torch/hub/checkpoints/esm1b_t33_650M_UR50S-contact-regression.pt


ProteinBertModel(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj):

## single sequence embedding

In [11]:
# input: SINGLE sequence, LM_model (language model), lm_converter (data processor for the language model)
def get_ESM1b_embedding(input_sequence, LM_model = model, lm_converter = batch_converter ):
    """Embed a single protein sequence with the ESM language model.

    Args:
        input_sequence (str): amino-acid sequence; it is upper-cased before
            being tokenized.
        LM_model: language model to run (defaults to the notebook-level `model`).
        lm_converter: data processor for the language model (defaults to the
            notebook-level `batch_converter`).

    Returns:
        tuple: `(embedding, attention_map)` where
            - `embedding` is the layer-33 representation with the first and
              last (start/end) tokens removed, shape (1, L, D);
              (1, L, 1280) for ESM-1b.
            - `attention_map` holds every layer/head attention matrix stacked
              on the last axis, shape (1, L, L, n_layers * n_heads);
              (1, L, L, 660) for ESM-1b.

    Note:
        Calls `LM_model.eval()`, so the model is left in eval mode.
    """
    data_ = [("prot", input_sequence.upper())]  # (name, sequence) pair expected by the converter
    _, _, seq = lm_converter(data_)
    LM_model.eval()
    with torch.no_grad():
        output = LM_model(seq, repr_layers=[33], return_contacts=True)
        # Drop the start and end tokens along the sequence axis.
        embedding = output['representations'][33][:, 1:-1, :]
        attention_map = output['attentions'][:, :, :, 1:-1, 1:-1]
        # Read layer/head counts from the tensor instead of hard-coding 33*20,
        # so the reshape stays valid if the model configuration differs.
        B, n_layers, n_heads, L, _ = attention_map.shape
        attention_map = attention_map.reshape(B, n_layers * n_heads, L, L).permute(0, 2, 3, 1)
        return embedding, attention_map

In [12]:
# Example amino-acid sequence used to demonstrate the single-sequence embedding.
example_sequence = 'RQLALEAKGETPSAVTRLSVVAKSEPQDEQSRSQSPRRIILSRLKAGEVDLLEEELGHLTTLTDVVKGADSLSAILPGDIAEDDITAVLCFVIEADQITFETVEVSPKISTPPVLKLAAEQAPTGRVEREKTTR'



In [13]:
# Run the language model once on the example sequence (default model and converter).
example_embedding, example_attention_map = get_ESM1b_embedding(example_sequence)

In [14]:
# Sanity check: expect (1, L, 1280) and (1, L, L, 660), with L = 134 for this sequence.
example_embedding.shape, example_attention_map.shape

(torch.Size([1, 134, 1280]), torch.Size([1, 134, 134, 660]))

## Getting embeddings for a batch of sequences

In [15]:
# the batch of sequences has to be formatted in this manner: 
# batch_sequences = [(protein_name_1,protein_sequence_1),(protein_name_2,protein_sequence_2),(protein_name_3,protein_sequence_3)]



In [17]:
# Example batch: three (name, sequence) pairs whose sequences have different lengths.
batch_sequences = [('prot_1','RQLALEAKGETPSAVTRLSVVAKSEPQDEQSRSQSPRRIIL'),
                 ('prot_2','SAILPGDIAEDDITAVLCFVIEADQITFETVEVSPKISTPPVLKLAAEQAPTGRVEREKTTR'),
                 ('prot_3','SQSPRRIILSRLKAGEVDLLEEELGHLTTLTDVVKGADSLSAIL')]

In [28]:
# input: batch sequence, LM_model (language model), lm_converter (data processor for the language model)
def get_ESM1b_embedding_batch(sequence_batch, LM_model = model, lm_converter = batch_converter, average = False ):
    """Embed a batch of protein sequences with the ESM language model.

    Args:
        sequence_batch (list[tuple[str, str]]): `(name, sequence)` pairs, e.g.
            `[("prot_1", "MKT..."), ("prot_2", "GSA...")]`.
        LM_model: language model to run (defaults to the notebook-level `model`).
        lm_converter: data processor for the language model (defaults to the
            notebook-level `batch_converter`).
        average (bool): if True, return one vector per sequence by averaging
            the per-residue embeddings over that sequence's own residues.

    Returns:
        tuple: `(embedding, attention_map)` where
            - `embedding` is the layer-33 representation with the first and
              last token columns removed, shape (B, L, D), or (B, D) when
              `average` is True. L is the longest sequence length in the batch.
            - `attention_map` holds every layer/head attention matrix stacked
              on the last axis, shape (B, L, L, n_layers * n_heads).

    Note:
        Sequences shorter than the longest one are padded by the converter, so
        in the non-averaged outputs the positions past a sequence's own length
        presumably correspond to end/padding tokens rather than residues --
        mask them before use. Calls `LM_model.eval()`, so the model is left in
        eval mode.
    """
    _, _, seq = lm_converter(sequence_batch)  # token tensor of shape (B, L + 2)
    LM_model.eval()
    with torch.no_grad():
        output = LM_model(seq, repr_layers=[33], return_contacts=True)
        # Drop the first (start) and last token columns along the sequence axis.
        embedding = output['representations'][33][:, 1:-1, :]  # (B, L, D)
        attention_map = output['attentions'][:, :, :, 1:-1, 1:-1]  # (B, n_layers, n_heads, L, L)
        # Take every dimension from the tensor itself: the token tensor is 2-D,
        # so unpacking three values from seq.shape would raise, and the
        # layer/head counts should not be hard-coded as 33*20.
        B, n_layers, n_heads, L, _ = attention_map.shape
        attention_map = attention_map.reshape(B, n_layers * n_heads, L, L).permute(0, 2, 3, 1)

        if average:
            # Average over real residues only. A plain mean over axis 1 would
            # also include the end token and padding of the shorter sequences.
            pad_idx = LM_model.embed_tokens.padding_idx
            # Non-padding tokens per sequence, minus the start and end tokens.
            n_residues = (seq != pad_idx).sum(dim=1) - 2
            positions = torch.arange(L, device=seq.device)
            residue_mask = positions[None, :] < n_residues[:, None]  # (B, L) bool
            residue_mask = residue_mask.unsqueeze(-1).to(embedding)  # (B, L, 1), embedding dtype/device
            # clamp avoids a 0/0 for an empty sequence.
            embedding = (embedding * residue_mask).sum(dim=1) / residue_mask.sum(dim=1).clamp(min=1)

        return embedding, attention_map
    
    

In [32]:
# Per-residue embeddings and attention maps for the whole batch.
batch_embedding, batch_attention_map = get_ESM1b_embedding_batch(batch_sequences)
# Same call with average=True: one embedding vector per sequence.
# Note that batch_attention_map is reassigned from this second call.
batch_embedding_average, batch_attention_map = get_ESM1b_embedding_batch(batch_sequences, average=True)

In [33]:
# Sanity check: expect (3, 62, 1280), (3, 62, 62, 660) and (3, 1280) for the example batch.
batch_embedding.shape, batch_attention_map.shape, batch_embedding_average.shape

(torch.Size([3, 62, 1280]),
 torch.Size([3, 62, 62, 660]),
 torch.Size([3, 1280]))

## When you load the sequences for the classifier:
## put them in the batch format, where input_data = [(name1, sequence1), (name2, sequence2), (name3, sequence3)],
## then pass that list to get_ESM1b_embedding_batch to get the embeddings back.
## I really like the idea of averaging the embedding along the sequence dimension (average=True) -- that way you get a number_of_sequences x embedding_dimension matrix (2D instead of 3D).
## I'm not sure exactly how to use the attention maps yet -- that part is left for you to work out.
