In [1]:
import os
import numpy as np
import pandas as pd
from Bio import SeqIO
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.patches import Patch
import mpl_stylesheet
import re
import gc
# Apply the project's presentation plotting style once, right after the imports
# (mono font family, font size 20, 'banskt' colors, 300 dpi).
mpl_stylesheet.banskt_presentation(
    fontfamily='mono',
    fontsize=20,
    colors='banskt',
    dpi=300,
)

In [2]:
from transformers import T5Tokenizer, T5EncoderModel, T5Model
import torch
# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')


## Download models from: 
## https://github.com/sacdallago/bio_embeddings/blob/develop/bio_embeddings/utilities/defaults.yml

# prottrans_t5_bfd:
#   model_directory: "http://data.bioembeddings.com/public/embeddings/embedding_models/t5/prottrans_t5_bfd.zip"
# prottrans_t5_uniref50:
#   model_directory: "http://data.bioembeddings.com/public/embeddings/embedding_models/t5/prottrans_t5_uniref50.zip"
# prottrans_t5_xl_u50:
#   model_directory: "http://data.bioembeddings.com/public/embeddings/embedding_models/t5/prottrans_t5_xl_u50.zip"
#   half_precision_model_directory: "http://data.bioembeddings.com/public/embeddings/embedding_models/t5/half_prottrans_t5_xl_u50.zip"


In [3]:
# Load the T5Model weights from the local checkpoint directory,
# then move the module onto the selected compute device.
fullmodel = T5Model.from_pretrained("models/prottrans_t5_xl_u50")
fullmodel = fullmodel.to(device)

In [4]:
# Build the tokenizer from the local checkpoint directory.
# do_lower_case=False leaves the residue letters exactly as given.
tokenizer = T5Tokenizer.from_pretrained(
    'models/prottrans_t5_xl_u50',
    do_lower_case=False,
)


You are using the legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


In [None]:
# Alternative (disabled): load the encoder-only, half-precision checkpoint instead of the full T5Model
#model = T5EncoderModel.from_pretrained("models/half_prottrans_t5_xl_u50").to(device)

In [7]:
# Only GPUs support half precision well; on CPU keep the model in full
# (float32) precision (not recommended, much slower).
#
# Two fixes compared with the usual copy-pasted snippet:
#   * `device` is a torch.device, so inspect `device.type` rather than
#     comparing the object with the string 'cpu'.
#   * nn.Module has no `.full()` method; `.float()` is the float32 cast.
if device.type == 'cpu':
    fullmodel.float()
else:
    fullmodel.half()
# Release any Python-side garbage left over after the dtype conversion.
gc.collect()

# prepare your protein sequences as a list
sequence_examples = ["PRTEINO", "SEQWENCE"]

# Replace the rare/ambiguous amino acids (U, Z, O, B) with X and put a single
# space between residues. A str is already iterable character by character,
# so no intermediate list() is needed for the join.
sequence_examples = [" ".join(re.sub(r"[UZOB]", "X", sequence)) for sequence in sequence_examples]


In [8]:
# Display the preprocessed sequences (space-separated residues, rare ones mapped to X)
sequence_examples

['P R T E I N X', 'S E Q W E N C E']

In [13]:

# Tokenize the whole batch; shorter sequences are padded to the longest one
ids = tokenizer.batch_encode_plus(
    sequence_examples,
    add_special_tokens=True,
    padding="longest",
)

# Turn the token ids and the padding mask into tensors on the compute device
input_ids, attention_mask = (
    torch.tensor(ids[field]).to(device)
    for field in ('input_ids', 'attention_mask')
)

# Forward pass without gradient tracking; the decoder is fed the same
# token ids as the encoder
with torch.no_grad():
    embedding_repr = fullmodel(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=input_ids,
    )

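## A more general way to trim each embedding to its true length: count the
## non-padding tokens via the attention mask and drop the final special token.
## NOTE: `embedding` is not defined in this notebook; it presumably stands for a
## hidden-state tensor such as embedding_repr.encoder_last_hidden_state -- confirm before use.
# features = [] 
# for seq_num in range(len(embedding)):
#     seq_len = (attention_mask[seq_num] == 1).sum()
#     seq_emd = embedding[seq_num][:seq_len-1]
#     features.append(seq_emd)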
    
# # extract residue embeddings for the first ([0,:]) sequence in the batch and remove padded & special tokens ([0,:7]) 
# emb_0 = embedding_repr.last_hidden_state[0,:7] # shape (7 x 1024)
# # same for the second ([1,:]) sequence but taking into account different sequence lengths ([1,:8])
# emb_1 = embedding_repr.last_hidden_state[1,:8] # shape (8 x 1024)

# # if you want to derive a single representation (per-protein embedding) for the whole protein
# emb_0_per_protein = emb_0.mean(dim=0) # shape (1024)

In [16]:
# Inspect the complete model output object returned by the forward pass
embedding_repr

Seq2SeqModelOutput(last_hidden_state=tensor([[[ 2.8491e-01,  2.3376e-01, -5.3174e-01,  ..., -1.8677e-01,
           5.0812e-02, -2.1863e-01],
         [ 2.8247e-01,  2.5098e-01, -5.5420e-01,  ..., -2.5162e-02,
           1.6577e-01, -2.6807e-01],
         [ 3.0688e-01,  2.4023e-01, -2.4438e-01,  ...,  1.5662e-01,
           1.2756e-01, -2.9004e-01],
         ...,
         [ 2.2644e-01,  4.3848e-01, -6.0889e-01,  ...,  7.1096e-04,
          -3.1763e-01, -2.4683e-01],
         [-2.0645e-02,  2.4255e-01, -8.5986e-01,  ..., -9.9170e-01,
          -9.6191e-01, -4.6484e-01],
         [-6.1859e-02,  1.2140e-01, -3.4814e-01,  ..., -7.8027e-01,
          -6.7236e-01, -1.3416e-01]],

        [[ 2.2974e-01,  1.0706e-01, -1.4465e-01,  ...,  4.3701e-02,
           1.6098e-03, -1.1646e-01],
         [ 2.4048e-01,  1.0980e-01, -6.5369e-02,  ..., -6.6895e-02,
           7.4036e-02, -5.2277e-02],
         [ 2.6465e-01,  3.2776e-02,  1.1334e-01,  ..., -2.9816e-02,
           4.1107e-02, -1.1060e-01],
  

In [17]:
# Shape of last_hidden_state; the first axis matches the two input sequences
print(embedding_repr.last_hidden_state.shape)

torch.Size([2, 9, 1024])


In [19]:
# Shape of encoder_last_hidden_state, for comparison with last_hidden_state
print(embedding_repr.encoder_last_hidden_state.shape)

torch.Size([2, 9, 1024])


In [21]:
# Display the last_hidden_state tensor
# NOTE(review): per the Seq2SeqModelOutput docs this is the decoder-side output -- confirm it is the one you want
embedding_repr.last_hidden_state

tensor([[[ 2.8491e-01,  2.3376e-01, -5.3174e-01,  ..., -1.8677e-01,
           5.0812e-02, -2.1863e-01],
         [ 2.8247e-01,  2.5098e-01, -5.5420e-01,  ..., -2.5162e-02,
           1.6577e-01, -2.6807e-01],
         [ 3.0688e-01,  2.4023e-01, -2.4438e-01,  ...,  1.5662e-01,
           1.2756e-01, -2.9004e-01],
         ...,
         [ 2.2644e-01,  4.3848e-01, -6.0889e-01,  ...,  7.1096e-04,
          -3.1763e-01, -2.4683e-01],
         [-2.0645e-02,  2.4255e-01, -8.5986e-01,  ..., -9.9170e-01,
          -9.6191e-01, -4.6484e-01],
         [-6.1859e-02,  1.2140e-01, -3.4814e-01,  ..., -7.8027e-01,
          -6.7236e-01, -1.3416e-01]],

        [[ 2.2974e-01,  1.0706e-01, -1.4465e-01,  ...,  4.3701e-02,
           1.6098e-03, -1.1646e-01],
         [ 2.4048e-01,  1.0980e-01, -6.5369e-02,  ..., -6.6895e-02,
           7.4036e-02, -5.2277e-02],
         [ 2.6465e-01,  3.2776e-02,  1.1334e-01,  ..., -2.9816e-02,
           4.1107e-02, -1.1060e-01],
         ...,
         [ 7.5378e-02,  1

In [22]:

# Display the encoder_last_hidden_state tensor to compare against last_hidden_state
embedding_repr.encoder_last_hidden_state

tensor([[[ 0.1168, -0.0473,  0.0298,  ...,  0.0785,  0.0415,  0.1687],
         [ 0.2410, -0.1147, -0.2081,  ...,  0.0519, -0.1573,  0.0936],
         [ 0.4048, -0.1521,  0.0276,  ...,  0.0848, -0.0267,  0.0267],
         ...,
         [ 0.3108, -0.0559, -0.2308,  ...,  0.0483,  0.3550, -0.1281],
         [-0.1562, -0.0110,  0.2014,  ..., -0.0067,  0.0096, -0.0553],
         [-0.3003,  0.0337,  0.1202,  ...,  0.1093, -0.1835,  0.0505]],

        [[ 0.1117, -0.1654, -0.2759,  ...,  0.0377,  0.1609, -0.0633],
         [ 0.2302, -0.1619, -0.1133,  ...,  0.0074, -0.0874,  0.0403],
         [ 0.2764,  0.0426, -0.2500,  ...,  0.2140, -0.0558,  0.0230],
         ...,
         [ 0.0187, -0.0552,  0.1316,  ..., -0.1865,  0.1500, -0.1359],
         [ 0.1115, -0.0443,  0.2025,  ..., -0.0111, -0.1198, -0.0564],
         [-0.0795, -0.0544,  0.1328,  ...,  0.0110, -0.0254,  0.0034]]],
       device='cuda:0', dtype=torch.float16)