# ProtT5 (ProtTrans) protein-sequence embeddings


In [1]:
#@title Import dependencies. { display-mode: "form" }
# Load ProtT5 in half-precision (more specifically: the encoder-part of ProtT5-XL-U50) 
# import requests
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Using device: {}".format(device))

#@title Load encoder-part of ProtT5 in half-precision. { display-mode: "form" }
# Load ProtT5 in half-precision (more specifically: the encoder-part of ProtT5-XL-U50 in half-precision)
transformer_link = "Rostlab/prot_t5_xl_half_uniref50-enc"
print("Loading: {}".format(transformer_link))
model = T5EncoderModel.from_pretrained(transformer_link)
# Use half precision only on the GPU; on the CPU cast to full (float32) precision.
# Check device.type: comparing a torch.device to the plain string 'cpu' is not a
# reliable device check, and nn.Module has no .full() method (the float32 cast is .float()).
model = model.float() if device.type == 'cpu' else model.half()
model = model.to(device)
model = model.eval()  # inference mode for layers such as dropout
tokenizer = T5Tokenizer.from_pretrained(transformer_link, do_lower_case=False)

# Small demo batch; get_emb receives its sequences as an argument.
# The 'O' in "PRTEINO" is one of the rare residues get_emb maps to X.
sequence_examples = ["PRTEINO", "GPSGLGLPAGLYAFNSGGISLDLGINDPVPFNTVGSQFGTAISQLDADTFVISETGFYKITVIANTATASVLGGLTIQVNGVPVPGTGSSLISLGAPIVIQAITQITTTPSLVEVIVTGLGLSLALGTSASIIIEKVA"]
def get_emb(sequence_examples, strip_special_tokens=False):
    """Run ProtT5 on a batch of protein sequences and return the first sequence's embedding.

    Uses the module-level ``tokenizer``, ``model`` and ``device`` created in the
    model-loading cell.

    Args:
        sequence_examples: iterable of amino-acid strings in one-letter code.
            A single string is treated as a batch containing one sequence.
        strip_special_tokens: if True, keep only the rows that correspond to
            residues of the first sequence. This drops the end-of-sequence
            token the tokenizer appends, plus any padding added to match a
            longer sequence in the batch. The default (False) keeps the
            historical output, ``last_hidden_state[0]``. That output has one
            more row than the sequence has residues, and extra padding rows
            if a later sequence in the batch is longer.

    Returns:
        torch.Tensor of hidden states for the first sequence, on ``device``.

    Raises:
        ValueError: if ``sequence_examples`` is empty.
    """
    if isinstance(sequence_examples, str):
        # Iterating a bare string would embed every residue as its own "sequence".
        sequence_examples = [sequence_examples]
    sequences = list(sequence_examples)
    if not sequences:
        raise ValueError("get_emb() needs at least one sequence")

    # Map rare/ambiguous amino acids (U, Z, O, B) to X and put a space between
    # residues, the one-token-per-residue format the ProtT5 tokenizer expects.
    spaced = [" ".join(re.sub(r"[UZOB]", "X", seq)) for seq in sequences]

    # Tokenize and pad every sequence up to the longest one in the batch.
    ids = tokenizer.batch_encode_plus(
        spaced, add_special_tokens=True, padding="longest", return_tensors="pt"
    )
    input_ids = ids["input_ids"].to(device)
    attention_mask = ids["attention_mask"].to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask)

    emb_0 = embedding_repr.last_hidden_state[0]
    if strip_special_tokens:
        # attention_mask marks the real tokens (residues + end-of-sequence token).
        # Dropping the last real token and everything after it leaves one row per residue.
        n_residues = int(attention_mask[0].sum()) - 1
        emb_0 = emb_0[:n_residues]
    return emb_0

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda:0
Loading: Rostlab/prot_t5_xl_half_uniref50-enc


In [None]:
import os
import torch

def print_files_in_directory(directory, embed_fn=None, out_name='proT5_emb.pt', overwrite=True):
    """Embed every ``seq.pt`` found under ``directory`` and save the result beside it.

    Despite its name, this function does not print file names. It walks
    ``directory`` recursively and loads each ``seq.pt`` with ``torch.load``.
    It then embeds the loaded object with ``embed_fn`` and writes the returned
    tensor to ``out_name`` in the same folder. A progress line is printed
    every 100 saved files.

    Args:
        directory: root directory to search.
        embed_fn: callable that takes a one-element list holding the loaded
            sequence and returns the tensor to save. Defaults to ``get_emb``
            from the embedding cell (looked up at call time).
        out_name: file name of the embedding written in each folder.
        overwrite: if False, folders that already contain ``out_name`` are
            skipped, so an interrupted run can be resumed without recomputing.

    Returns:
        Number of embedding files written.
    """
    if embed_fn is None:
        embed_fn = get_emb

    n_saved = 0
    for root, _dirs, files in os.walk(directory):
        # File names within one directory are unique, so a membership test
        # is equivalent to scanning the list for 'seq.pt'.
        if 'seq.pt' not in files:
            continue
        out_path = os.path.join(root, out_name)
        if not overwrite and os.path.exists(out_path):
            continue
        # NOTE(review): seq.pt presumably holds the amino-acid string - confirm.
        # torch.load may unpickle objects (depending on the torch version's
        # weights_only default); only run this on trusted data.
        seq = torch.load(os.path.join(root, 'seq.pt'))
        # NOTE(review): the tensor is saved on whatever device embed_fn returns it on
        # (the GPU when get_emb runs on CUDA); loading it on a CPU-only machine then
        # needs torch.load(..., map_location='cpu').
        torch.save(embed_fn([seq]), out_path)
        n_saved += 1
        if n_saved % 100 == 0:
            print('Saved {} files'.format(n_saved))
    return n_saved
                

# Root directory that is searched recursively for seq.pt files; an embedding
# file is written next to each one found.
directory_path = '../data/casp12_data_30/'
print_files_in_directory(directory=directory_path)


GSMANKPMQPITSTANKIVWSDPTRLSTTFSASLLRQRVKVGIAELNNVSGQYVSVYKRPAPKPEGGADAGVIMPNENQSIRTVISGSAENLATLKAEWETHKRNVDTLFASGNAGLGFLDPTAAIVSSDTTA
torch.Size([134, 1024])
../data/casp12_data_30/test/FM#T0859
../data/casp12_data_30/test/FM#T0859/seq.pt
MAKSHHHHHHTSVENNALSLVARGCAVAAPCRTKVAEQLLEIGAKAGMAGLAGAAVKDMADRMTSDELEHLITLQMMGNDEITTKYLSSLHDKYGSGAASNPNIGKDLTDAEKVELGGSGSGTGTPPPSENDPKQQNEKTVDKLNQKQESAIKKIDNTIKNALKDHDIIGTLKDMDGKPVPKENGGYWDAMQEMQNTLRGLRNHADTLKNVNNPEAQAAYGRATDAINKIESALKGYGI
torch.Size([240, 1024])
../data/casp12_data_30/test/FM#T0862
../data/casp12_data_30/test/FM#T0862/seq.pt
MSAETVNNYDYSDWYENAAPTKAPVEVIPPCDPTADEGLFHICIAAISLVVMLVLAILARRQKLSDNQRGLTGLLSPVNFLDHTQHKGLAVAVYGVLFCKLVGMVLSHHPLPFTKEVANKEFWMILALLYYPTLYYPLLACGTLHNKVGYVLGSLLSWTHFGILVWQKVDCPKTPQIYKYYALFGSLPQIACLAFLSFQYPLLLFKGLQNTETANASEDLSSSYYRDYVKKILKKKKPTKISSSTSKPKLFDRLRDAVKSYIYTPEDVFRFPLKLAISVVVAFIALYQMALLLISGVLPTLHIVRRGVDENIAFLLAGFNIILSNDRQEVVRIVVYYLWCVEICYVSAVTLSCLVNLLMLMRSMVLHRSNLKGLYRGDSLNVFNCHRSIRPSRPALVCWMGFTSYQAAFLCLGMAI