# Testing the ProtT5-XL-UniRef50 Model

In [1]:
import torch
from transformers import T5EncoderModel, T5Tokenizer
import re
import numpy as np

# Load the pretrained tokenizer and the encoder-only T5 model from the
# "Rostlab/prot_t5_xl_uniref50" checkpoint. do_lower_case=False is passed
# so the tokenizer does not lower-case the input residues.
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")

# Run on the first CUDA device when one is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Move the weights to the chosen device and switch to inference mode
model = model.to(device).eval()

Some weights of the model checkpoint at Rostlab/prot_t5_xl_uniref50 were not used when initializing T5EncoderModel: ['decoder.block.8.layer.1.EncDecAttention.k.weight', 'decoder.block.18.layer.2.layer_norm.weight', 'decoder.block.4.layer.1.EncDecAttention.o.weight', 'decoder.block.4.layer.2.DenseReluDense.wi.weight', 'decoder.block.23.layer.1.layer_norm.weight', 'decoder.block.4.layer.0.SelfAttention.v.weight', 'decoder.block.13.layer.0.SelfAttention.v.weight', 'decoder.block.7.layer.0.SelfAttention.k.weight', 'decoder.block.1.layer.0.SelfAttention.v.weight', 'decoder.block.9.layer.1.EncDecAttention.o.weight', 'decoder.block.21.layer.2.layer_norm.weight', 'decoder.block.23.layer.2.DenseReluDense.wi.weight', 'decoder.block.3.layer.0.SelfAttention.k.weight', 'decoder.block.10.layer.1.EncDecAttention.q.weight', 'decoder.block.5.layer.0.SelfAttention.k.weight', 'decoder.block.19.layer.2.layer_norm.weight', 'decoder.block.13.layer.0.SelfAttention.o.weight', 'decoder.block.3.layer.1.layer_no

In [2]:
# Toy example: two short sequences with residues separated by spaces.
# Rare amino acids (U, Z, O, B) are mapped to X before tokenization.
ex_seq = [re.sub(r"[UZOB]", "X", raw_seq) for raw_seq in ["A T W", "A T E"]]

# Tokenize the batch (padded to a common length) and move the tensors to the model's device
ids = tokenizer.batch_encode_plus(ex_seq, add_special_tokens=True, padding=True)
input_ids, attention_mask = (
    torch.tensor(ids[key]).to(device) for key in ('input_ids', 'attention_mask')
)

# Forward pass without gradient tracking; keep the last hidden state as a NumPy array
with torch.no_grad():
    encoder_output = model(input_ids=input_ids, attention_mask=attention_mask)
embedding = encoder_output.last_hidden_state.cpu().numpy()

# Count the attended (non-padding) positions of every row once, as plain ints
token_counts = (attention_mask == 1).sum(dim=1).tolist()

# Keep all attended positions except the last one of each row; that last
# position is presumably the end-of-sequence token added by
# add_special_tokens=True -- confirm against the tokenizer.
features = [
    row_embedding[:n_tokens - 1]
    for row_embedding, n_tokens in zip(embedding, token_counts)
]
features

[array([[ 0.1265764 , -0.13151374, -0.2601224 , ...,  0.00458627,
         -0.05920463, -0.04529901],
        [ 0.2802877 , -0.10432567, -0.425026  , ...,  0.18578516,
         -0.09256832, -0.15548012],
        [ 0.07381413,  0.01820545, -0.00343402, ...,  0.11101398,
         -0.12984815, -0.02367226]], dtype=float32),
 array([[ 0.15986352, -0.2527481 , -0.15840194, ...,  0.06446997,
         -0.05601623, -0.13918552],
        [ 0.23181392, -0.2732637 , -0.30261785, ..., -0.03905839,
          0.1586108 , -0.00670971],
        [ 0.16332299, -0.10366669,  0.10422383, ...,  0.167325  ,
         -0.17385639,  0.02975434]], dtype=float32)]

# Trying Real Protein Sequences

In [36]:
from Bio import SeqIO

# Use biopython to parse the FASTA file and collect the sequences as
# space-separated residue strings.
sequences = []
with open("sequence.fasta") as file:
    for record in SeqIO.parse(file, 'fasta'):
        # Map rare amino acids (U, Z, O, B) to X, exactly as the example cell
        # above does, so real sequences get the same preprocessing.
        residues = re.sub(r"[UZOB]", "X", str(record.seq))

        # Add a space between characters so each amino acid is vectorized
        sequences.append(' '.join(residues))

# Cut list down for testing
sequences = sequences[0:10]
sequences[2]

'M P D Q I S V S E F V A E T L E D Y K A P T A S S F T M R T A Q C R D T V A A I E E'

In [37]:
# Tokenize the FASTA-derived sequences as one padded batch and move the tensors to the model's device
ids = tokenizer.batch_encode_plus(sequences, add_special_tokens=True, padding=True)
input_ids, attention_mask = (
    torch.tensor(ids[key]).to(device) for key in ('input_ids', 'attention_mask')
)

# Forward pass without gradient tracking; keep the last hidden state as a NumPy array.
# NOTE(review): the original comment mentions decoder_input_ids "for decoder weights";
# that argument is not passed here -- confirm whether it is still relevant.
with torch.no_grad():
    encoder_output = model(input_ids=input_ids, attention_mask=attention_mask)
embedding = encoder_output.last_hidden_state.cpu().numpy()

# Count the attended (non-padding) positions of every row once, as plain ints
token_counts = (attention_mask == 1).sum(dim=1).tolist()

# Keep all attended positions except the last one of each row; that last
# position is presumably the end-of-sequence token added by
# add_special_tokens=True -- confirm against the tokenizer.
features = [
    row_embedding[:n_tokens - 1]
    for row_embedding, n_tokens in zip(embedding, token_counts)
]
features

: 

: 