In [33]:
import re
import torch
import json
import sentencepiece
import torch.nn as nn

from transformers import T5EncoderModel, T5Tokenizer
from torch import Tensor
from typing import Dict
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [42]:
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc", model_max_length=1024)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")

In [50]:
decoder_layer = nn.TransformerDecoderLayer(d_model=1024, nhead=8)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
memory = torch.rand(10, 32, 1024)
tgt = torch.rand(20, 32, 1024)
out = transformer_decoder(tgt, memory)
out.shape

torch.Size([20, 32, 1024])

In [44]:
model = model.to(device)

In [57]:
def get_protein_embedding(sequence):

    seq = " ".join(list(re.sub(r"[UZOB]", "X", sequence)))

    # tokenize sequences and pad up to the longest sequence in the batch
    ids = tokenizer.batch_encode_plus(seq, add_special_tokens=True, padding='max_length')
    input_ids = torch.tensor(ids['input_ids']).to(device)
    attention_mask = torch.tensor(ids['attention_mask']).to(device)

    # generate embeddings
    with torch.no_grad():
        embedding_repr = model(input_ids=input_ids,attention_mask=attention_mask)

    print(embedding_repr.last_hidden_state.shape)
    return embedding_repr.last_hidden_state[0]

In [58]:
import json

# Load the records from the JSON file
with open('data/records.json', 'r') as f:
    records = json.load(f)

# Get the first 5 sequences
sequence_examples = [record['sequence'] for record in records[5:10]]

# Print the sequence examples
for sequence in sequence_examples:
    print(sequence)

MKEELDAFHQIFTTTKEAIERFMAMLTPVIENAEDDHERLYYHHIYEEEEQRLSRLDVLIPLIEKFQDETDEGLFSPSNNAFNRLLQELNLEKFGLHNFIEHVDLALFSFTDEERQTLLKELRKDAYEGYQYVKEKLAEINARFDHDYADPHAHHDEHRDHLADMPSAGSSHEEVQPVAHKKKGFTVGSLIQ
MKNWKKYAFASASVVALAAGLAACGNLTGNSKKAADSGDKPVIKMYQIGDKPDNLDELLANANKIIEEKVGAKLDIQYLGWGDYGKKMSVITSSGENYDIAFADNYIVNAQKGAYADLTELYKKEGKDLYKALDPAYIKGNTVNGKIYAVPVAANVASSQNFAFNGTLLAKYGIDISGVTSYETLEPVLKQIKEKAPDVVPFAIGKVFIPSDNFDYPVANGLPFVIDLEGDTTKVVNRYEVPRFKEHLKTLHKFYEAGYIPKDVATSDTSFDLQQDTWFVREETVGPADYGNSLLSRVANKDIQIKPITNFIKKNQTTQVANFVISNNSKNKEKSMEILNLLNTNPELLNGLVYGPEGKNWEKIEGKENRVRVLDGYKGNTHMGGWNTGNNWILYINENVTDQQIENSKKELAEAKESPALGFIFNTDNVKSEISAIANTMQQFDTAINTGTVDPDKAIPELMEKLKSEGAYEKVLNEMQKQYDEFLKNKK
MKFKTFSKSAVLLTASLAVLAACGSKNTASSPDYKLEGVTFPLQEKKTLKFMTASSPLSPKDPNEKLILQRLEKETGVHIDWTNYQSDFAEKRNLDISSGDLPDAIHNDGASDVDLMNWAKKGVIIPVEDLIDKYMPNLKKILDEKPEYKALMTAPDGHIYSFPWIEELGDGKESIHSVNDMAWINKDWLKKLGLEMPKTTDDLIKVLEAFKNGDPNGNGEADEIPFSFISGNGNEDFKFLFAAFGIGDNDDHLVVGNDGKVDFTADNDNYKEGVKFIRQLQEKGLIDKEAFEHDWNSYIAKGHDQKFGVYFTWD

In [59]:
get_protein_embedding(sequence_examples[0])

OutOfMemoryError: CUDA out of memory. Tried to allocate 5.98 GiB (GPU 0; 12.00 GiB total capacity; 16.48 GiB already allocated; 0 bytes free; 16.49 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [62]:
def get_per_protein_embedding(sequence):
    # this will replace all rare/ambiguous amino acids by X and introduce white-space between all amino acids
    sequence = " ".join(list(re.sub(r"[UZOB]", "X", sequence)))

    # tokenize sequences and pad up to the longest sequence in the batch
    ids = tokenizer.batch_encode_plus(sequence, add_special_tokens=True, padding="longest")
    input_ids = torch.tensor(ids['input_ids']).to(device)
    attention_mask = torch.tensor(ids['attention_mask']).to(device)

    # generate embeddings
    with torch.no_grad():
        embedding_repr = model(input_ids=input_ids,attention_mask=attention_mask)

    # extract embeddings for the first ([0,:]) sequence in the batch while removing padded & special tokens ([0,:7]) 
    emb_0 = embedding_repr.last_hidden_state
    print(f"emb_0 shape: {emb_0.shape}")
    # if you want to derive a single representation (per-protein embedding) for the whole protein
    emb_0_per_protein = emb_0.mean(dim=0) # shape (1024)
    return emb_0_per_protein

In [65]:
seq = "ADTIVAVELDTYPNTDIGDPSYPHIGIDIKSVRSKKTAKWNMQNGKVGTAHIIYNSVGKRLSAVVSYPNGDSATVSF"
q = get_per_protein_embedding(seq)
print(q)
print(q.shape)

emb_0 shape: torch.Size([153, 2, 1024])
tensor([[ 0.0903, -0.1629, -0.1658,  ..., -0.0564, -0.1935, -0.0765],
        [-0.0403, -0.1693, -0.0632,  ..., -0.0043, -0.0497, -0.0558]],
       device='cuda:0')
torch.Size([2, 1024])


In [68]:
inputs = tokenizer(seq, return_tensors="pt", padding=True, truncation=True, max_length=1024)
inputs.to(device)
# Generate embeddings
with torch.no_grad():
    embeddings = model(**inputs).last_hidden_state

print(embeddings.shape)
# Post-processing the embeddings
# This step depends on how you want to use these embeddings. For example:
protein_embeddings = embeddings.mean(dim=1)  # Averaging over the sequence length

print(protein_embeddings.shape)  # Should b

torch.Size([1, 2, 1024])
torch.Size([1, 1024])


In [14]:
def save_protein_embeddings(records_path: str):
    # Load the records from the JSON file
    with open(records_path, 'r') as f:
        records = json.load(f)

    # Create an empty tensor dictionary to store the embeddings
    embedding_dict: Dict[str, Tensor] = {}

    for record in tqdm(records):
        
        record_id = record['id']
        sequence = record['sequence']
        
        # Get the protein embedding for the sequence
        embedding = get_per_protein_embedding(sequence)
        
        # Save the embedding with its ID in the tensor dictionary
        embedding_dict[record_id] = embedding

    # Save the tensor dictionary to a file
    torch.save(embedding_dict, 'embeddings.pt')

save_protein_embeddings('data/records.json')

100%|██████████| 2135/2135 [23:30<00:00,  1.51it/s]


In [32]:
PAD_TOKEN = '<pad>'
EOS_TOKEN = '<eos>'
PAD_TOKEN_ID = 14
EOS_TOKEN_ID = 15


def add_eos_token(sequence: list, eos_token_id):
    # if a sequence is already 1024 tokens long, change the last token with eos_token_id
    if len(sequence) == 1024:
        sequence[-1] = eos_token_id
    else:
        sequence.append(eos_token_id)
    return sequence


def pad_ion_sequences(sequence: list, pad_token_id, max_length):
    # if a sequence is already 1024 tokens long, return it
    if len(sequence) == 1024:
        return sequence
    else:
        padded_sequence = sequence + [pad_token_id] * (max_length - len(sequence))
        return padded_sequence


def save_ion_embeddings(records_path: str):
    
    with open(records_path, 'r') as f:
        records = json.load(f)

    # Create an empty tensor dictionary to store the embeddings
    embedding_dict: Dict[str, Tensor] = {}

    for record in tqdm(records):

        record_id = record['id']
        sequence = record['metal_binding_sites']
        
        # Get the protein embedding for the sequence
        eosed_sequence = add_eos_token(sequence, EOS_TOKEN_ID)

        padded_sequence = pad_ion_sequences(eosed_sequence, PAD_TOKEN_ID, 1024)
        
        # Save the embedding with its ID in the tensor dictionary
        embedding_dict[record_id] = torch.tensor(padded_sequence, dtype=torch.long)

    #Save the tensor dictionary to a file
    torch.save(embedding_dict, 'target_embeddings.pt')


save_ion_embeddings('data/records.json')

100%|██████████| 2135/2135 [00:00<00:00, 39726.01it/s]


In [36]:
# Check all the target_embeddings are the same length without for loop
target_embeddings = torch.load('target_embeddings.pt')
for key, value in target_embeddings.items():
    if len(value) != 1024:
        print('Error in length of one of the target_embeddings')
        print(key, len(value))

    elif len(value) == 1024:
        print('|', end='')

||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||