# Protein per-residue Embeddings
- Compute per-residue embeddings for every protein with ProtT5 and slice out the embeddings of each annotated domain
## Input
- A list of proteins containing the following fields:

```python
pollys_output = ('class', 'architecture', 'topology', 'homology', 'domain_id', 's35',
's60', 's95', 's100', 's100_count', 'length', 'resolution', 'domain_sequence',
'homology_path', 'protein_sequence', "protein_id", "domain_start", "domain_end")
```

- Note: the example CSV loaded below has a `sequence` column where this list has `domain_sequence`.

In [1]:
import pandas as pd

# Hugging Face identifier of the checkpoint handed to `from_pretrained` below
model_name = "Rostlab/prot_t5_xl_half_uniref50-enc"

# Root folder that receives the embedding files
output_dir = "../data/embeddings"

# Input table of proteins
dataset = pd.read_csv("../data/example_protein_subset.csv")

In [2]:
import os
import time

import torch
from tqdm import tqdm
from transformers import T5EncoderModel, T5Tokenizer

# Run on CUDA when it is available; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Fetch the pre-trained tokenizer and encoder, move the encoder to the
# selected device and switch it to evaluation mode
print(f"Loading model: {model_name}")
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5EncoderModel.from_pretrained(model_name).to(device)
model.eval()
print("Model loaded successfully")

# Read the protein table and report its size and column names
dataset = pd.read_csv("../data/example_protein_subset.csv")
print(f"Dataset loaded with {dataset.shape[0]} proteins")
print(f"Columns: {dataset.columns}")


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Using device: cpu
Loading model: Rostlab/prot_t5_xl_half_uniref50-enc
Model loaded successfully
Dataset loaded with 1 proteins
Columns: Index(['class', 'architecture', 'topology', 'homology', 'domain_id', 's35',
       's60', 's95', 's100', 's100_count', 'length', 'resolution', 'sequence',
       'homology_path', 'protein_id', 'protein_sequence', 'domain_start',
       'domain_end'],
      dtype='object')


In [3]:
# Embed every protein in `dataset` with the loaded encoder and save
#   - one tensor per protein to {output_dir}/proteins/{protein_id}.pt
#   - one tensor per domain  to {output_dir}/domains/{domain_id}.pt
#   - both collections as dictionaries (all_*_embeddings.pt)
# Rows that fail are reported and skipped so one bad row does not abort the run.

# Create output directories if they don't exist
os.makedirs(output_dir, exist_ok=True)
os.makedirs(f"{output_dir}/proteins", exist_ok=True)
os.makedirs(f"{output_dir}/domains", exist_ok=True)

# Dictionaries to store embeddings
all_protein_embeddings = {}  # protein_id -> tensor with one row per residue of the protein
all_domain_embeddings = {}   # domain_id  -> tensor with one row per residue of the domain

# Batch size for processing
batch_size = 1  # Informational only: the loop below embeds one protein per forward pass

# ProtT5 preprocessing: rare/ambiguous residues U, Z, O and B are mapped to X
rare_residue_table = str.maketrans("UZOB", "XXXX")

# (row index, domain_id, error message) for every row that could not be processed
failed_rows = []

# Process each sequence with error handling
start_time = time.time()
for index, row in tqdm(dataset.iterrows(), total=len(dataset), desc="Processing proteins"):
    try:
        domain_id = row["domain_id"]
        protein_id = row["protein_id"]
        # int() so that boundaries read back as floats still work as slice indices
        domain_start = int(row["domain_start"])
        domain_end = int(row["domain_end"])

        # Replace non-standard amino acids with X
        sequence = row["protein_sequence"].translate(rare_residue_table)
        seq_len = len(sequence)

        # The slice below treats domain_start/domain_end as 1-based, inclusive
        # positions in protein_sequence (assumption carried over from the
        # original code — confirm against the dataset). Reject anything that
        # would silently yield an empty or truncated domain.
        if not 1 <= domain_start <= domain_end <= seq_len:
            raise ValueError(
                f"domain range {domain_start}-{domain_end} is outside the "
                f"protein sequence (length {seq_len})"
            )

        # ProtT5 expects residues separated by single spaces so that each
        # residue becomes its own token
        spaced_sequence = " ".join(sequence)

        # Tokenize sequence
        ids = tokenizer.batch_encode_plus(
            [spaced_sequence], add_special_tokens=True, padding=True, return_tensors="pt"
        )
        input_ids = ids['input_ids'].to(device)
        attention_mask = ids['attention_mask'].to(device)

        # One token per residue plus the end-of-sequence token appended by the
        # tokenizer. Anything else means token i is not residue i.
        n_tokens = input_ids.shape[1]
        if n_tokens != seq_len + 1:
            raise ValueError(
                f"expected {seq_len + 1} tokens (residues + end-of-sequence), got {n_tokens}"
            )

        with torch.no_grad():
            embedding = model(input_ids=input_ids, attention_mask=attention_mask)

        # last_hidden_state is (batch, tokens, hidden). Take the only batch
        # entry and keep the first seq_len tokens, which drops the trailing
        # end-of-sequence token (T5 adds no start token).
        # .clone() gives each tensor its own storage: torch.save serialises the
        # whole storage behind a view, which would bloat the saved files.
        per_residue_embeddings = embedding.last_hidden_state[0, :seq_len].clone().cpu()
        domain_embeddings = per_residue_embeddings[domain_start - 1:domain_end].clone()

        # Store in dictionaries keyed by protein_id / domain_id
        all_protein_embeddings[protein_id] = per_residue_embeddings
        all_domain_embeddings[domain_id] = domain_embeddings

        # Save individual embeddings to file; domain files are named by
        # domain_id so several domains of one protein do not overwrite each other
        torch.save(per_residue_embeddings, f"{output_dir}/proteins/{protein_id}.pt")
        torch.save(domain_embeddings, f"{output_dir}/domains/{domain_id}.pt")

        if index % 10 == 0 and index > 0:
            print(f"Processed {index} proteins. Time elapsed: {time.time() - start_time:.2f}s")

    except Exception as e:
        # Best-effort run: report the failure, remember it, continue with the next row
        failed_rows.append((index, row.get('domain_id', 'unknown'), str(e)))
        print(f"Error processing protein {index} ({row.get('domain_id', 'unknown')}): {str(e)}")

# Save all embeddings to one file per collection
torch.save(all_protein_embeddings, f"{output_dir}/proteins/all_protein_embeddings.pt")
torch.save(all_domain_embeddings, f"{output_dir}/domains/all_domain_embeddings.pt")
print(f"All embeddings saved")
if failed_rows:
    print(f"{len(failed_rows)} of {len(dataset)} rows failed and were skipped")
print(f"Total time: {time.time() - start_time:.2f}s")

Processing proteins: 100%|██████████| 1/1 [00:00<00:00,  1.10it/s]

All embeddings saved
Total time: 0.93s



