# Embed Sequences

The primary purpose of this notebook is to take the SCOP CSV of sequences and embed them all.

In [1]:
import os
import itertools
from multiprocessing import Pool

from tqdm import tqdm
import numpy as np
import pandas as pd

import torch
import faiss
import esm

from sklearn.metrics.pairwise import cosine_similarity
from Bio import SeqIO

### Metadata Mapping: Contains unique integer index

In [2]:
scop_csv_path = '/scratch/gpfs/jr8867/datasets/scop/scop_data.csv'
scop_df = pd.read_csv(scop_csv_path)
scop_df

Unnamed: 0,index,uid,fa,sf,seq
0,0,Q03131,4000119,3000038,MSGPRSRTTSRRTPVRIGAVVVASSTSELLDGLAAVADGRPHASVV...
1,1,P09147,4000088,3000038,MRVLVTGGSGYIGSHTCVQLLQNGHDVIILDNLCNSKRSVLPVIER...
2,2,P61889,4000045,3000039,MKVAVLGAAGGIGQALALLLKTQLPSGSELSLYDIAPVTPGVAVDL...
3,3,P00334,4000029,3000038,MSFTLTNKNVIFVAGLGGIGLDTSKELLKRDLKNLVILDRIENPAA...
4,4,O33830,4000089,3000039,MPSVKIGIIGAGSAVFSLRLVSDLCKTPGLSGSTVTLMDIDEERLD...
...,...,...,...,...,...
35972,35972,P20585,4004015,3000587,MSRRKPASGGLAASSSAPARQAVLSRFFQSTGSLKSTSSSTGAADQ...
35973,35973,P20585,4004015,3002020,MSRRKPASGGLAASSSAPARQAVLSRFFQSTGSLKSTSSSTGAADQ...
35974,35974,P52701,4004015,3001688,MSRQSTLYSFFPKSPALSDANKASARASREGGRAAAAPGASPSPGG...
35975,35975,P52701,4004015,3000587,MSRQSTLYSFFPKSPALSDANKASARASREGGRAAAAPGASPSPGG...


# Loading the  Embedding Model

In [3]:
cuda_available = torch.cuda.is_available()
print("CUDA available:", cuda_available)
device = torch.device("cuda" if cuda_available else "cpu")

CUDA available: True


In [4]:
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.to(device)
model.eval() # Ensures that the model is in evaluation mode, not training mode

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [5]:
# Clear CUDA cache to free up memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [29]:
# deprecated?
def embed_sequence(sequence):
    """
    Convert a protein sequence to embedding using ESM-2 with mean pooling.
    
    Args:
    - sequence (str): A protein sequence as string.

    Returns:
    - embedding (np.array): Array of shape (D,) where D is embedding size.
    """

    # Tokenize sequence
    data = [(str(0), sequence)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    # Move to GPU if available
    batch_tokens = batch_tokens.to(device)

    # Forward pass to get embeddings
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=False)

    # Extract embedding (mean pooling excluding padding)
    token_embeddings = results["representations"][33]  # Extract last hidden layer (layer 33 for this model)
    
    # Apply mean pooling
    valid_tokens = batch_tokens[0] != alphabet.padding_idx  # Mask out padding tokens
    seq_embedding = token_embeddings[0, valid_tokens].mean(dim=0).cpu().numpy()

    return seq_embedding

def embed_sequence_batch(sequences, batch_size=8):
    """
    Convert a batch of protein sequences to embeddings using ESM-2 with mean pooling.
    
    Args:
    - sequences (list): List of protein sequences as strings.
    - batch_size (int): Number of sequences to process at once.

    Returns:
    - embeddings (np.array): Array of shape (N, D) where N is number of sequences, D is embedding size.
    """
    all_embeddings = []

    sub_batch_size = 1
    
    # Process in sub-batches to avoid memory issues
    for i in range(0, len(sequences), sub_batch_size):
        sub_batch_seqs = sequences[i:i+sub_batch_size]
        sub_batch_data = [(str(j), seq) for j, seq in enumerate(sub_batch_seqs)]
        
        batch_labels, batch_strs, batch_tokens = batch_converter(sub_batch_data)
        batch_tokens = batch_tokens.to(device)
        
        # Clear CUDA cache to free up memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[33], return_contacts=False)
        
        token_embeddings = results["representations"][33]
        
        # Apply mean pooling for each sequence in the sub-batch
        for j in range(len(sub_batch_seqs)):
            valid_tokens = batch_tokens[j] != alphabet.padding_idx
            seq_embedding = token_embeddings[j, valid_tokens].mean(dim=0).cpu().numpy()
            all_embeddings.append(seq_embedding)
        
        # Clear CUDA cache again
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    return np.array(all_embeddings)


In [32]:

def process_and_save_embeddings(df, output_dir, batch_size=16, save_every=100):
    """
    Process sequences in batches, update FAISS index, and save periodically.
    Uses the existing 'index' column from the DataFrame as FAISS indices.
    
    Args:
    - df (DataFrame): DataFrame containing sequences, metadata, and an 'index' column
    - output_dir (str): Directory to save FAISS index and metadata
    - batch_size (int): Number of sequences to process in each batch for FAISS updates
    - save_every (int): Save after processing this many sequences
    
    Returns:
    - index: Final FAISS index
    - metadata: DataFrame with metadata
    """
    os.makedirs(output_dir, exist_ok=True)
    
    # Setup FAISS index with explicit IDs
    dimension = 1280  # ESM-2 embedding dimension
    index = faiss.IndexIDMap(faiss.IndexFlatL2(dimension))
    
    # Process in batches
    total_processed = 0
    
    for i in tqdm(range(0, len(df), batch_size), ncols=100):
        batch_df = df.iloc[i:i+batch_size]
        
        # Get embeddings for this batch
        batch_embeddings = embed_sequence_batch(batch_df['seq'].tolist())
        
        # Get the original indices from the DataFrame
        batch_indices = batch_df['index'].values
        
        # Add to FAISS index with explicit IDs
        index.add_with_ids(
            batch_embeddings.astype('float32'),
            np.array(batch_indices, dtype=np.int64)
        )
        
        total_processed += len(batch_embeddings)
        
        # Save periodically
        if total_processed % save_every == 0 or i + batch_size >= len(df):
            faiss_path = os.path.join(output_dir, 'protein_embeddings.index')
            
            # Save FAISS index
            faiss.write_index(index, faiss_path)
            
            # No need to save metadata separately since we're using the original indices
            print(f"Saved {total_processed} embeddings to {faiss_path}")
            print(f"Using original DataFrame indices as FAISS IDs")
    
    return index, df  # Return the original DataFrame as metadata

In [None]:
# Example usage
output_dir = '/scratch/gpfs/jr8867/embeddings/scop'
batch_size = 16  # Batch size for FAISS updates
save_every = 100  # Save after processing this many sequences

# Process all sequences
index, metadata = process_and_save_embeddings(scop_df, output_dir, batch_size, save_every)

In [38]:
# Function to search for similar proteins
def search_similar_proteins(query_seq, index, metadata, k=5):
    """
    Search for similar proteins using a query sequence.
    
    Args:
    - query_seq (str): Query protein sequence
    - index: FAISS index
    - metadata (DataFrame): Metadata DataFrame
    - k (int): Number of nearest neighbors to return
    
    Returns:
    - DataFrame: Metadata of similar proteins
    """
    # Embed query sequence
    query_embedding = embed_sequence(query_seq)
    query_embedding = np.array([query_embedding]).astype('float32')
    
    # Search in FAISS index
    D, I = index.search(query_embedding, k)
    
    # Get metadata for results
    result_indices = I[0]
    similar_proteins = metadata[metadata['index'].isin(result_indices)]
    
    # Add distance information
    distances = {idx: dist for idx, dist in zip(result_indices, D[0])}
    similar_proteins['distance'] = similar_proteins['index'].map(distances)
    
    # Sort by distance
    similar_proteins = similar_proteins.sort_values('distance')
    
    return similar_proteins

In [39]:
example_seq = scop_df.iloc[0]['seq']
index = faiss.read_index('/scratch/gpfs/jr8867/embeddings/scop/protein_embeddings.index')
similar = search_similar_proteins(example_seq, index, scop_df)
print("Similar proteins:")
print(similar)


Similar proteins:
    index     uid       fa       sf  \
0       0  Q03131  4000119  3000038   
32     32  Q59771  4000004  3000046   
23     23  P9WNX3  4000051  3000006   
22     22  P9WNX3  4000099  3000044   
24     24  P9WNX3  4000037  3000019   

                                                  seq  distance  
0   MSGPRSRTTSRRTPVRIGAVVVASSTSELLDGLAAVADGRPHASVV...  0.000000  
32  MSIDSALNWDGEMTVTRFDRETGAHFVIRLDSTQLGPAAGGTRAAQ...  3.769551  
23  MSLPVVLIADKLAPSTVAALGDQVEVRWVDGPDRDKLLAAVPEADA...  4.216374  
22  MSLPVVLIADKLAPSTVAALGDQVEVRWVDGPDRDKLLAAVPEADA...  4.216374  
24  MSLPVVLIADKLAPSTVAALGDQVEVRWVDGPDRDKLLAAVPEADA...  4.216374  


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  similar_proteins['distance'] = similar_proteins['index'].map(distances)
