<a href="https://colab.research.google.com/github/agemagician/ProtTrans/blob/master/Embedding/PyTorch/Advanced/ProtT5-XL-UniRef50.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Important Notes:
1. ProtT5-XL-UniRef50 has both an encoder and a decoder; for feature extraction we load and use only the encoder.
2. Loading only the encoder roughly halves the inference time and the GPU memory requirements.
3. To use the ProtT5-XL-UniRef50 encoder, you need a Hugging Face transformers version that provides `T5EncoderModel`; if your installed release lacks it, install the latest version from their GitHub repo.
4. If you are interested in both the encoder and the decoder, use `T5Model` rather than `T5EncoderModel`.

<h3>Extracting protein sequence features using the pretrained ProtT5-XL-UniRef50 model</h3>

**1. Load the necessary libraries, including Hugging Face transformers**

In [65]:
!pip install -q SentencePiece transformers

In [68]:
import torch
from transformers import T5EncoderModel, T5Tokenizer
import re
import numpy as np
import gc
import pandas

<b>2. Load the vocabulary and the ProtT5-XL-UniRef50 model</b>

In [69]:
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)

In [70]:
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")

Some weights of the model checkpoint at Rostlab/prot_t5_xl_uniref50 were not used when initializing T5EncoderModel: ['decoder.block.0.layer.0.SelfAttention.v.weight', 'decoder.block.14.layer.0.SelfAttention.v.weight', 'decoder.block.3.layer.0.SelfAttention.o.weight', 'decoder.block.18.layer.2.DenseReluDense.wo.weight', 'decoder.block.21.layer.0.SelfAttention.o.weight', 'decoder.block.6.layer.2.DenseReluDense.wo.weight', 'decoder.block.9.layer.0.SelfAttention.o.weight', 'decoder.block.1.layer.2.DenseReluDense.wi.weight', 'decoder.block.12.layer.2.DenseReluDense.wo.weight', 'decoder.block.17.layer.0.SelfAttention.q.weight', 'decoder.block.8.layer.0.SelfAttention.q.weight', 'decoder.block.5.layer.1.EncDecAttention.q.weight', 'decoder.block.12.layer.2.layer_norm.weight', 'decoder.block.7.layer.2.DenseReluDense.wo.weight', 'lm_head.weight', 'decoder.block.22.layer.1.EncDecAttention.q.weight', 'decoder.block.10.layer.1.EncDecAttention.o.weight', 'decoder.block.20.layer.2.DenseReluDense.wi.we

In [71]:
gc.collect()

1048

<b>3. Load the model onto the GPU if available and switch to inference mode</b>

In [72]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [73]:
model = model.to(device)
model = model.eval()

<b>4. Create or load sequences and map rarely occurring amino acids (U, Z, O, B) to X</b>

In [74]:
# sequences_Example = ["A E T C Z A O","S K T Z P"]
dataframe = pandas.read_csv('../../../data/2022-01-07-23h_subset-selectAB-E.csv', low_memory=False)
# Keep only the rows whose epitope is among the 10 most frequent epitopes
top_10_epitopes = dataframe['epitope_aa'].value_counts().index[:10]
top_10 = dataframe[dataframe['epitope_aa'].isin(top_10_epitopes)]
# Draw 70 rows per epitope without replacement to get a balanced subset
subset = top_10.groupby("epitope_aa").sample(n=70, random_state=42, replace=False)
subset['epitope_aa'].value_counts()[0:10]

CINGVCWTV    70
FLCMKALLL    70
GILGFVFTL    70
GLCTLVAML    70
IVTDFSVIK    70
KLGGALQAK    70
LLWNGPMAV    70
NLVPMVATV    70
RLRAEAQVK    70
YLQPRTFLL    70
Name: epitope_aa, dtype: int64
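
The selection above (keep the most frequent epitopes, then draw the same number of rows for each) can be sketched on a toy table. This is only an illustration, assuming a DataFrame with an `epitope_aa` column; the helper name `balanced_top_k`, the `row_id` column, and the toy values are hypothetical and not part of the notebook's data.

```python
import pandas as pd


def balanced_top_k(df, column, k, n_per_group, seed=42):
    """Keep rows whose `column` value is among the k most frequent values,
    then sample n_per_group rows per value without replacement."""
    top_values = df[column].value_counts().index[:k]
    top_rows = df[df[column].isin(top_values)]
    return top_rows.groupby(column).sample(n=n_per_group, random_state=seed)


# Toy data (hypothetical): "AAA" occurs 5 times, "BBB" 4 times, "CCC" once
toy = pd.DataFrame({
    "epitope_aa": ["AAA"] * 5 + ["BBB"] * 4 + ["CCC"],
    "row_id": list(range(10)),
})

balanced = balanced_top_k(toy, "epitope_aa", k=2, n_per_group=2)
print(balanced["epitope_aa"].value_counts().to_dict())
```

Sampling without replacement raises an error if a group has fewer rows than `n_per_group`, so the per-group size has to be chosen no larger than the smallest selected group.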

In [104]:
subset['sequences'] = [" ".join(cdr3a) + ' [SEP] '+ " ".join(cdr3b) for cdr3a,cdr3b in zip(subset['cdr3_alpha_aa'], subset['cdr3_beta_aa'])]
subset['sequences'].head()

16887    C L V A R G G S Y I P T F [SEP] C A S S H S A ...
15058    C A M R E H T S G T Y K Y I F [SEP] C A S S D ...
16879    C A Y R S L D L S G N T P L V F [SEP] C A S S ...
16888    C L V A S P S G G Y N K L I F [SEP] C A S S L ...
16875    C A V Q A N T N A G K S T F [SEP] C A S S F G ...
Name: sequences, dtype: object

In [105]:
subset['sequences'] = [re.sub(r"[UZOB]", "X", sequence) for sequence in subset['sequences']]
sequences_Example = subset['sequences'].to_list()
sequences_Example

['C L V A R G G S Y I P T F [SEP] C A S S H S A G V F M N T E A F F',
 'C A M R E H T S G T Y K Y I F [SEP] C A S S D S L V R G Y Q E T Q Y F',
 'C A Y R S L D L S G N T P L V F [SEP] C A S S L Y I Q G G E Q Y F',
 'C L V A S P S G G Y N K L I F [SEP] C A S S L A R Q E E T Q Y F',
 'C A V Q A N T N A G K S T F [SEP] C A S S F G R Q A Y E Q Y F',
 'C I V R V P F L Y N Q G G K L I F [SEP] C A S S Y S V K G L N T E A F F',
 'C A L S E S A N S G G Y Q K V T F [SEP] C A S S P R T S G G Y Q E T Q Y F',
 'C L V G V P V G A G S Y Q L T F [SEP] C A S T T G S S E K L F F',
 'C A A S A R G Q A G T A L I F [SEP] C A S S G P G G G A F F',
 'C A V D L T G A G S Y Q L T F [SEP] C A S S L P D R A G E K L F F',
 'C L V G V P G S A R Q L T F [SEP] C A S S L T V S L S P D L N E Q F F',
 'C L V G A P L V F [SEP] C S A T R S S G E P E Q F F',
 'C A V G A G T N A G K S T F [SEP] C A S S Q E S G T E A F F',
 'C L V G D G Y G N N R L A F [SEP] C A S S L P D R A G E K L F F',
 'C A V A D P R E Y G N K L V F [S
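
The preparation of the input strings can be sketched as a pair of small helpers. This is a sketch rather than the notebook's exact code: the function names `prepare_chain` and `prepare_pair` are hypothetical, and the replacement is applied to each chain before joining so that a separator containing U, Z, O, or B would not be altered. For the `[SEP]` marker used here, both orders give the same result.

```python
import re

# Rare or ambiguous residues that are mapped to X
RARE_RESIDUES = re.compile(r"[UZOB]")


def prepare_chain(chain):
    """Replace rare residues with X and separate residues with spaces."""
    return " ".join(RARE_RESIDUES.sub("X", chain))


def prepare_pair(alpha, beta, separator="[SEP]"):
    """Join two prepared chains with a separator marker between them."""
    return f"{prepare_chain(alpha)} {separator} {prepare_chain(beta)}"


print(prepare_chain("AETCZAO"))       # A E T C X A X
print(prepare_pair("CAZ", "SKTZP"))   # C A X [SEP] S K T X P
```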

<b>5. Tokenize and encode the sequences, then load them onto the GPU if possible</b>

In [106]:
ids = tokenizer.batch_encode_plus(sequences_Example, add_special_tokens=True, padding=True)

In [107]:
input_ids = torch.tensor(ids['input_ids']).to(device)
attention_mask = torch.tensor(ids['attention_mask']).to(device)
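
To make the role of `padding=True` and the attention mask concrete, here is a plain-Python sketch of what padding a batch produces. It assumes right-side padding with a pad id of 0 and an end-of-sequence id of 1; the helper `pad_batch` and the token ids are made up for illustration and are not read from the ProtT5 vocabulary.

```python
def pad_batch(token_id_lists, pad_id=0):
    """Right-pad every id list to the longest length in the batch and
    build a mask with 1 for real tokens and 0 for padding."""
    max_len = max(len(ids) for ids in token_id_lists)
    input_ids, attention_mask = [], []
    for ids in token_id_lists:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return input_ids, attention_mask


# Two toy sequences; the final 1 stands in for the end-of-sequence token
toy_ids = [[5, 9, 7, 1], [5, 1]]
padded_ids, mask = pad_batch(toy_ids)
print(padded_ids)  # [[5, 9, 7, 1], [5, 1, 0, 0]]
print(mask)        # [[1, 1, 1, 1], [1, 1, 0, 0]]
```

The mask is what later lets the padded positions be ignored by the model and removed from the output.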

<b>6. Extract the sequences' features and move them to the CPU if needed</b>

In [108]:
with torch.no_grad():
    embedding = model(input_ids=input_ids, attention_mask=attention_mask)

In [109]:
embedding = embedding.last_hidden_state.cpu().numpy()
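
Passing every sequence through the model in one call, as above, may exceed GPU memory for larger datasets. One common workaround is to process the sequences in smaller batches. The sketch below shows only the chunking step; the helper name `iter_batches` and the batch size are hypothetical choices, not part of the notebook.

```python
def iter_batches(items, batch_size):
    """Yield consecutive slices of `items` with at most `batch_size` elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Toy stand-in for the list of prepared sequences
toy_sequences = [f"seq_{i}" for i in range(10)]
batches = list(iter_batches(toy_sequences, batch_size=4))
print([len(batch) for batch in batches])  # [4, 4, 2]
```

Each batch would then be tokenized and embedded as in steps 5 and 6. Because padding is computed per batch, padding and special tokens should also be stripped per batch before the features are collected.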

<b>7. Remove padding (\<pad\>) and special tokens (\</s\>) that are added by the ProtT5-XL-UniRef50 model</b>

In [110]:
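features = []
for seq_num in range(len(embedding)):
    # Number of real tokens (residues plus the trailing </s>), as a plain int
    seq_len = int(attention_mask[seq_num].sum())
    # Keep the residue embeddings; drop </s> and the padding that follows it
    seq_emd = embedding[seq_num][:seq_len - 1]
    features.append(seq_emd)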
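
The trimming logic can be checked on a small synthetic array. This sketch assumes right-side padding and exactly one trailing end-of-sequence token per sequence; the function name `strip_padding` and the toy shapes are hypothetical.

```python
import numpy as np


def strip_padding(embeddings, attention_mask):
    """Return one array per sequence without padding rows and without
    the final end-of-sequence row."""
    features = []
    for emb, mask in zip(embeddings, attention_mask):
        n_tokens = int(np.sum(mask))          # real tokens, including </s>
        features.append(emb[:n_tokens - 1])   # drop </s> and padding
    return features


# Batch of 2 sequences, padded to 4 tokens, with a toy embedding size of 3
toy_embeddings = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)
toy_mask = np.array([[1, 1, 1, 1],
                     [1, 1, 0, 0]])

toy_features = strip_padding(toy_embeddings, toy_mask)
print([f.shape for f in toy_features])  # [(3, 3), (1, 3)]
```

The result is a list rather than a single array because the sequences have different lengths once padding is removed.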

In [111]:
print(features)

[array([[ 0.30897254,  0.2090359 , -0.32464352, ..., -0.04472381,
        -0.21388204,  0.05481818],
       [ 0.13908426,  0.36502844, -0.27454895, ...,  0.2751863 ,
         0.08911769,  0.00329558],
       [ 0.10876895, -0.09918918, -0.13224587, ...,  0.01106232,
         0.02163352,  0.24692625],
       ...,
       [ 0.12894996,  0.07142997, -0.08744176, ...,  0.15462337,
         0.0384613 , -0.22172771],
       [-0.12031344,  0.05632034, -0.3148046 , ...,  0.08684896,
        -0.06580658, -0.21626973],
       [-0.005287  , -0.01445357, -0.19656587, ...,  0.02686592,
        -0.11398099, -0.23774001]], dtype=float32), array([[ 0.11521405, -0.06577437, -0.20553476, ...,  0.0496741 ,
         0.06522491,  0.02348308],
       [ 0.2683668 ,  0.03002981, -0.29187232, ...,  0.11646785,
         0.18926264, -0.01735737],
       [ 0.17001376,  0.09447163, -0.263591  , ...,  0.12828441,
         0.20248675,  0.10915561],
       ...,
       [ 0.22085333,  0.21547958,  0.0383692 , ...,  0.022

In [112]:
# number of sequences
print(len(features))

700


In [121]:
# number of tokens in the first sequence
print(len(features[0]))

31

In [132]:
# embedding size of the first residue of the first sequence (1024 for every residue)
print(len(features[0][0]))

1024
