In [15]:
import re
import torch
import json
import sentencepiece

from transformers import T5EncoderModel, T5Tokenizer
from torch import Tensor
from typing import Dict
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")

Downloading (…)"spiece.model";: 100%|██████████| 238k/238k [00:00<00:00, 904kB/s]
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Downloading (…)al_tokens_map.json";: 100%|██████████| 1.79k/1.79k [00:00<00:00, 1.79MB/s]
Downloading (…)enizer_config.json";: 100%|██████████| 24.0/24.0 [00:00<00:00, 23.9kB/s]
Downloading (…)"config.json";: 100%|██████████| 546/546 [00:00<00:00, 546kB/s]
Downloading (…)"pytorch_model.bin";: 100%|██████████| 11.3G/11.3G [40:12<00:00, 4.67MB/s]  
Some weights of the model checkpoint at Rostlab/prot_t5_xl_uniref50 were not used when initializing T5EncoderModel: ['decoder.block.5.layer.2.DenseReluDense.wi.weight', 'decoder.block.20.layer.0.layer_norm.weight', 'decoder.block.19.layer.2.DenseReluDense.wo.weight', 'decoder.block.14.layer.2.DenseReluD

In [8]:
model = model.to(device)

In [9]:
def get_per_protein_embedding(sequence):
    # this will replace all rare/ambiguous amino acids by X and introduce white-space between all amino acids
    sequence = " ".join(list(re.sub(r"[UZOB]", "X", sequence)))

    # tokenize sequences and pad up to the longest sequence in the batch
    ids = tokenizer.batch_encode_plus(sequence, add_special_tokens=True, padding="longest")
    input_ids = torch.tensor(ids['input_ids']).to(device)
    attention_mask = torch.tensor(ids['attention_mask']).to(device)

    # generate embeddings
    with torch.no_grad():
        embedding_repr = model(input_ids=input_ids,attention_mask=attention_mask)

    # extract embeddings for the first ([0,:]) sequence in the batch while removing padded & special tokens ([0,:7]) 
    emb_0 = embedding_repr.last_hidden_state[0,:7] # shape (7 x 1024)

    # if you want to derive a single representation (per-protein embedding) for the whole protein
    emb_0_per_protein = emb_0.mean(dim=0) # shape (1024)
    return emb_0_per_protein

In [10]:
q = get_per_protein_embedding("MAGNIFICENT")
print(q)
print(q.shape)

tensor([ 0.0357, -0.1899, -0.1827,  ...,  0.0174, -0.2099, -0.0339],
       device='cuda:0')
torch.Size([1024])


In [14]:
def save_protein_embeddings(records_path: str):
    # Load the records from the JSON file
    with open(records_path, 'r') as f:
        records = json.load(f)

    # Create an empty tensor dictionary to store the embeddings
    embedding_dict: Dict[str, Tensor] = {}

    for record in tqdm(records):
        
        record_id = record['id']
        sequence = record['sequence']
        
        # Get the protein embedding for the sequence
        embedding = get_per_protein_embedding(sequence)
        
        # Save the embedding with its ID in the tensor dictionary
        embedding_dict[record_id] = embedding

    # Save the tensor dictionary to a file
    torch.save(embedding_dict, 'embeddings.pt')

save_protein_embeddings('data/records.json')

100%|██████████| 2135/2135 [23:30<00:00,  1.51it/s]


In [32]:
PAD_TOKEN = '<pad>'
EOS_TOKEN = '<eos>'
PAD_TOKEN_ID = 14
EOS_TOKEN_ID = 15


def add_eos_token(sequence: list, eos_token_id):
    # if a sequence is already 1024 tokens long, change the last token with eos_token_id
    if len(sequence) == 1024:
        sequence[-1] = eos_token_id
    else:
        sequence.append(eos_token_id)
    return sequence


def pad_ion_sequences(sequence: list, pad_token_id, max_length):
    # if a sequence is already 1024 tokens long, return it
    if len(sequence) == 1024:
        return sequence
    else:
        padded_sequence = sequence + [pad_token_id] * (max_length - len(sequence))
        return padded_sequence


def save_ion_embeddings(records_path: str):
    
    with open(records_path, 'r') as f:
        records = json.load(f)

    # Create an empty tensor dictionary to store the embeddings
    embedding_dict: Dict[str, Tensor] = {}

    for record in tqdm(records):

        record_id = record['id']
        sequence = record['metal_binding_sites']
        
        # Get the protein embedding for the sequence
        eosed_sequence = add_eos_token(sequence, EOS_TOKEN_ID)

        padded_sequence = pad_ion_sequences(eosed_sequence, PAD_TOKEN_ID, 1024)
        
        # Save the embedding with its ID in the tensor dictionary
        embedding_dict[record_id] = torch.tensor(padded_sequence, dtype=torch.long)

    #Save the tensor dictionary to a file
    torch.save(embedding_dict, 'target_embeddings.pt')


save_ion_embeddings('data/records.json')

100%|██████████| 2135/2135 [00:00<00:00, 39726.01it/s]


In [36]:
# Check all the target_embeddings are the same length without for loop
target_embeddings = torch.load('target_embeddings.pt')
for key, value in target_embeddings.items():
    if len(value) != 1024:
        print('Error in length of one of the target_embeddings')
        print(key, len(value))

    elif len(value) == 1024:
        print('|', end='')

||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||