# Protein Embedding Generator for CAFA6

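This notebook generates one mean-pooled embedding vector per protein sequence using an ESM2 model. ESM3 can be selected, but its embedding path is only a placeholder and is not implemented yet.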

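Outputs (written to `/kaggle/working/`): `protein_embeddings_train.npy`, `protein_embeddings_test.npy`, and the matching protein ID lists `train_ids.npy` and `test_ids.npy`. Row *i* of each embedding array corresponds to entry *i* of its ID list.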

## Model Selection

In [None]:
# Choose model: 'esm2' or 'esm3'
MODEL_TYPE = 'esm2'  # Change to 'esm3' for ESM3

# Checkpoint identifier used for each supported backend.
_MODEL_NAMES = {
    'esm2': 'facebook/esm2_t6_8M_UR50D',
    # Example identifier only; adjust as needed.
    'esm3': 'esm3_sm_open_v0',
}

# Reject anything other than the two supported backends up front.
if MODEL_TYPE not in _MODEL_NAMES:
    raise ValueError("MODEL_TYPE must be 'esm2' or 'esm3'")

if MODEL_TYPE == 'esm3':
    # ESM3 needs the separate `esm` package (may require `pip install esm`).
    import esm

model_name = _MODEL_NAMES[MODEL_TYPE]

print(f'Selected model: {MODEL_TYPE} - {model_name}')

## Setup and Load Data

In [None]:
import os
import numpy as np
from Bio import SeqIO
import torch
from transformers import AutoTokenizer, AutoModel
import warnings
warnings.filterwarnings('ignore')

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# Input FASTA locations inside the Kaggle competition dataset.
data_dir = '/kaggle/input/cafa-6-protein-function-prediction/Train/'
train_seq_file = os.path.join(data_dir, 'train_sequences.fasta')
# NOTE(review): the test superset is read from the Train/ directory as well;
# confirm this matches the competition's file layout.
test_seq_file = os.path.join(data_dir, 'testsuperset.fasta')


def _read_fasta(path):
    """Return a dict mapping each FASTA record id in ``path`` to its sequence.

    Records are stored in file order; if an id repeats, the later record
    overwrites the earlier one.
    """
    id_to_seq = {}
    for rec in SeqIO.parse(path, 'fasta'):
        id_to_seq[rec.id] = str(rec.seq)
    return id_to_seq


train_sequences = _read_fasta(train_seq_file)
test_sequences = _read_fasta(test_seq_file)
print(f'Loaded {len(train_sequences)} train and {len(test_sequences)} test sequences')

## Load Model and Generate Embeddings

In [None]:
if MODEL_TYPE == 'esm2':
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.to(device)
    model.eval()
elif MODEL_TYPE == 'esm3':
    # ESM3 embedding extraction is not implemented. Fail fast here instead of
    # loading a model that the batching code below has no way to use.
    raise NotImplementedError(
        'ESM3 embedding extraction is not implemented; set MODEL_TYPE = "esm2".'
    )
else:
    raise ValueError("MODEL_TYPE must be 'esm2' or 'esm3'")

print(f'Model loaded: {model_name}')


def _mean_pool(hidden_states, attention_mask):
    """Average token vectors per sequence, ignoring padding positions.

    Args:
        hidden_states: the model's ``last_hidden_state``,
            shape (batch, seq_len, hidden).
        attention_mask: the tokenizer's mask, shape (batch, seq_len); 1 marks
            a non-padding token (tokenizer-added special tokens included),
            0 marks padding.

    Returns:
        Tensor of shape (batch, hidden).
    """
    # A plain ``.mean(dim=1)`` would also average the pad positions, making a
    # short sequence's embedding depend on the longest sequence in its batch.
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    summed = (hidden_states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)  # avoid division by zero
    return summed / counts


def _embed_batch(batch_seqs):
    """Return a NumPy array with one pooled embedding row per input sequence."""
    inputs = tokenizer(batch_seqs, return_tensors='pt', truncation=True,
                       max_length=1024, padding=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    pooled = _mean_pool(outputs.last_hidden_state, inputs['attention_mask'])
    return pooled.cpu().numpy()


def get_embedding(sequence):
    """Return the pooled embedding of a single sequence as a 1-D array.

    Uses the same tokenization and pooling as the batched path, so the result
    matches what the batch loop produces for this sequence.
    """
    return _embed_batch([sequence])[0]


def _embed_split(sequences, split_name, batch_size):
    """Embed every sequence of ``sequences`` (id -> sequence) in dict order.

    Returns:
        (ids, embeddings) where ``embeddings[i]`` belongs to ``ids[i]``.
    """
    ids = list(sequences.keys())
    embeddings = []
    for batch_idx, start in enumerate(range(0, len(ids), batch_size)):
        batch_ids = ids[start:start + batch_size]
        embeddings.extend(_embed_batch([sequences[pid] for pid in batch_ids]))
        # Report progress every 10 batches.
        if batch_idx % 10 == 0:
            print(f'Processed {start + len(batch_ids)} / {len(ids)} {split_name} proteins')
    return ids, embeddings


# Batch processing
batch_size = 10

train_ids, train_embeddings = _embed_split(train_sequences, 'train', batch_size)
test_ids, test_embeddings = _embed_split(test_sequences, 'test', batch_size)

print(f'Embeddings generated. Train: {len(train_embeddings)}, Test: {len(test_embeddings)}')

## Save Embeddings

In [None]:
# Save to numpy files in Kaggle's writable output directory.
output_dir = '/kaggle/working/'
os.makedirs(output_dir, exist_ok=True)


def _save_split(split_name, ids, embeddings):
    """Write one split's embedding matrix and its ID list as .npy files.

    Files written: ``protein_embeddings_<split>.npy`` and ``<split>_ids.npy``.

    Raises:
        ValueError: if the number of IDs differs from the number of embedding
            rows, since downstream code pairs them by position.
    """
    if len(ids) != len(embeddings):
        raise ValueError(
            f'{split_name}: {len(ids)} ids but {len(embeddings)} embeddings'
        )
    np.save(os.path.join(output_dir, f'protein_embeddings_{split_name}.npy'),
            np.array(embeddings))
    np.save(os.path.join(output_dir, f'{split_name}_ids.npy'), np.array(ids))


_save_split('train', train_ids, train_embeddings)
_save_split('test', test_ids, test_embeddings)

print(f'Embeddings saved to {output_dir}')