In [1]:
!pip install fair-esm obonet


Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Collecting obonet
  Downloading obonet-1.1.1-py3-none-any.whl.metadata (6.7 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading obonet-1.1.1-py3-none-any.whl (9.2 kB)
Installing collected packages: fair-esm, obonet
Successfully installed fair-esm-2.0.0 obonet-1.1.1


In [3]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import esm
import pandas as pd
import numpy as np
import obonet
import networkx as nx
from tqdm import tqdm
import gc

# --- Dataset layout ---
# The competition bundle is split into 'Train' and 'Test' folders; the
# filenames below are assumed from the competition's standard layout.
BASE_DIR = "/kaggle/input/cafa-6-protein-function-prediction"

CONFIG = dict(
    model_name="esm2_t33_650M_UR50D",  # Use "esm2_t36_3B_UR50D" if you have A100
    batch_size=16,                     # Drop to 8 on memory errors
    lr=1e-3,
    epochs=10,                         # Raise to 20-30 for final training
    hidden_dim=512,
    # Every entry is resolved relative to BASE_DIR. If one of these files is
    # not found, inspect the 'Train' folder contents directly.
    paths={
        key: os.path.join(BASE_DIR, *parts)
        for key, parts in (
            ("train_seq", ("Train", "train_sequences.fasta")),
            ("train_terms", ("Train", "train_terms.tsv")),
            ("train_tax", ("Train", "train_taxonomy.tsv")),
            ("test_seq", ("Test", "testsuperset.fasta")),
            ("test_tax", ("Test", "testsuperset-taxon-list.tsv")),
            ("obo", ("Train", "go-basic.obo")),  # Usually in Train or root
            ("ia", ("IA.tsv",)),
        )
    },
)

# The ontology file may sit at the dataset root instead of inside 'Train'.
if not os.path.exists(CONFIG["paths"]["obo"]):
    CONFIG["paths"].update(obo=os.path.join(BASE_DIR, "go-basic.obo"))

# Prefer the GPU whenever one is visible to torch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
print("Paths configured:", CONFIG["paths"])


Using device: cuda
Paths configured: {'train_seq': '/kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta', 'train_terms': '/kaggle/input/cafa-6-protein-function-prediction/Train/train_terms.tsv', 'train_tax': '/kaggle/input/cafa-6-protein-function-prediction/Train/train_taxonomy.tsv', 'test_seq': '/kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta', 'test_tax': '/kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset-taxon-list.tsv', 'obo': '/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo', 'ia': '/kaggle/input/cafa-6-protein-function-prediction/IA.tsv'}


In [9]:
# --- Helper Function: Load Fasta ---
def load_fasta(path):
    """Parse a FASTA file into parallel lists of record IDs and sequences.

    The record ID is derived from the first whitespace-delimited token of each
    header line, so text in the free-form description can never be mistaken
    for the ID:

    * ``>sp|P9WHI7|RECN_MYCT ...`` -> ``P9WHI7`` (second ``|``-separated field)
    * ``>Q12345 9606``             -> ``Q12345`` (token used verbatim)

    Blank lines are skipped, and sequence lines that appear before the first
    header are ignored.

    Args:
        path: Path to the FASTA file.

    Returns:
        ``(headers, seqs)``: two lists of equal length, where ``headers[i]`` is
        the ID of the i-th record and ``seqs[i]`` is its sequence with the
        line breaks removed.

    Raises:
        ValueError: If a header line contains nothing after ``>``.
    """
    headers = []
    seqs = []
    current_id = None  # None means "no record opened yet"
    chunks = []
    with open(path, "r") as f:
        for line_no, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            if line.startswith(">"):
                # Close the previous record. Compare against None rather than
                # truthiness so a record is never dropped because of its ID.
                if current_id is not None:
                    headers.append(current_id)
                    seqs.append("".join(chunks))
                fields = line[1:].split()
                if not fields:
                    raise ValueError(
                        f"{path}:{line_no}: FASTA header has no identifier"
                    )
                token = fields[0]
                parts = token.split("|")
                # db|ACCESSION|ENTRY -> ACCESSION; otherwise keep the token.
                current_id = parts[1] if len(parts) > 1 and parts[1] else token
                chunks = []
            elif current_id is not None:
                chunks.append(line)
    if current_id is not None:
        headers.append(current_id)
        seqs.append("".join(chunks))
    return headers, seqs


In [6]:
# --- REPLACEMENT CELL FOR CELL 3 (HIGH SPEED VERSION) ---
import torch
import esm
import numpy as np
import gc
import os
from tqdm import tqdm

# Sequences longer than this many residues are truncated before embedding.
_MAX_SEQ_LEN = 1024


def _target_batch_size(seq_len):
    """Return the batch size used for a batch whose longest sequence has ``seq_len`` residues."""
    if seq_len < 600:
        return 32
    if seq_len < 1000:
        return 8
    return 1  # long sequences are embedded one at a time


def _pool_batch(model, batch_converter, batch, repr_layer):
    """Embed one batch and mean-pool every sequence over its residue positions.

    Args:
        model: ESM model, already moved to ``device``.
        batch_converter: Converter obtained from ``alphabet.get_batch_converter()``.
        batch: List of ``(id, sequence)`` tuples.
        repr_layer: Index of the layer whose representations are pooled.

    Returns:
        List of 1-D numpy arrays, one per entry of ``batch``, in the same order.
    """
    with torch.no_grad():
        _, _, tokens = batch_converter(batch)
        tokens = tokens.to(device)
        results = model(tokens, repr_layers=[repr_layer], return_contacts=False)
        token_reps = results["representations"][repr_layer]
        # Position 0 holds the prepended start token, so the residues of
        # sequence s occupy positions 1 .. len(s); padding is excluded.
        return [
            token_reps[j, 1:len(s) + 1].mean(0).cpu().numpy()
            for j, (_, s) in enumerate(batch)
        ]


def _pool_batch_with_oom_fallback(model, batch_converter, batch, repr_layer):
    """Like ``_pool_batch``, but retry sequence-by-sequence after a CUDA OOM.

    Returns:
        List aligned with ``batch``. An entry is ``None`` only if that
        sequence still ran out of memory when embedded on its own.

    Raises:
        RuntimeError: Any runtime error that is not an out-of-memory error.
    """
    try:
        return _pool_batch(model, batch_converter, batch, repr_layer)
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise
    # Recover outside the except block so the tensors of the failed forward
    # pass are no longer referenced when the cache is emptied.
    torch.cuda.empty_cache()
    pooled = []
    for item in batch:
        try:
            pooled.append(_pool_batch(model, batch_converter, [item], repr_layer)[0])
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            pooled.append(None)
        if pooled[-1] is None:
            torch.cuda.empty_cache()
    return pooled


def get_embeddings_smart(fasta_path, name_prefix):
    """Compute (or load cached) mean-pooled ESM embeddings for a FASTA file.

    Results are cached in ``{name_prefix}_embeds.npy`` and
    ``{name_prefix}_ids.npy``. When both files exist they are loaded and
    returned without touching the model. Otherwise sequences are embedded in
    order of increasing length (so batches contain similar lengths), with a
    batch size that shrinks as sequences get longer, and written back to
    their original positions so row ``i`` always corresponds to ``ids[i]``.

    Args:
        fasta_path: FASTA file to embed; read with ``load_fasta``.
        name_prefix: Prefix of the two cache files.

    Returns:
        ``(embeddings, ids)``: a float32 array of shape
        ``(num_sequences, embed_dim)`` and an array of the matching IDs.

    Raises:
        RuntimeError: Any model error other than CUDA out-of-memory. A batch
            that runs out of memory is retried one sequence at a time; a
            sequence that still does not fit keeps an all-zero row and its ID
            is reported in a printed warning.
    """
    embed_file = f"{name_prefix}_embeds.npy"
    id_file = f"{name_prefix}_ids.npy"

    # Reuse the cache only when both halves exist; a lone embeddings file
    # cannot be matched back to IDs.
    if os.path.exists(embed_file) and os.path.exists(id_file):
        print(f"✅ Found {embed_file}! Skipping generation.")
        return np.load(embed_file), np.load(id_file)

    print(f"Reading {fasta_path}...")
    ids, seqs = load_fasta(fasta_path)

    # Visit sequences shortest-first; results are scattered back by index.
    lengths = [len(s) for s in seqs]
    sorted_indices = np.argsort(lengths)

    print("Loading Model...")
    model, alphabet = esm.pretrained.load_model_and_alphabet(CONFIG["model_name"])
    model.eval()
    if device.type == "cuda":
        # Half precision is only used on the GPU; on CPU the model stays float32.
        model.half()
    model.to(device)
    batch_converter = alphabet.get_batch_converter()

    # Read the last layer index and embedding width from the loaded model so
    # changing CONFIG["model_name"] needs no other edits here.
    repr_layer = model.num_layers
    embed_dim = model.embed_tokens.embedding_dim

    num_seqs = len(seqs)
    # float16 buffer keeps RAM usage down while embedding.
    embeddings_out = np.zeros((num_seqs, embed_dim), dtype=np.float16)
    failed_ids = []

    def flush(batch, batch_indices):
        """Embed ``batch`` and store each vector at its original row index."""
        pooled = _pool_batch_with_oom_fallback(model, batch_converter, batch, repr_layer)
        for idx, (seq_id, _), emb in zip(batch_indices, batch, pooled):
            if emb is None:
                failed_ids.append(seq_id)  # row stays all-zero
            else:
                embeddings_out[idx] = emb.astype(np.float16)

    batch = []
    batch_indices = []

    print(f"Processing {num_seqs} sequences with Smart Batching...")
    for i in tqdm(sorted_indices):
        seq = seqs[i]
        batch.append((ids[i], seq[:_MAX_SEQ_LEN]))
        batch_indices.append(i)

        # Lengths are ascending, so the sequence just added is the longest in
        # the batch and decides how large the batch may grow.
        if len(batch) >= _target_batch_size(len(seq)):
            flush(batch, batch_indices)
            batch = []
            batch_indices = []

    # Whatever is left after the loop goes through the same OOM-safe path.
    if batch:
        flush(batch, batch_indices)

    if failed_ids:
        print(
            f"⚠️ {len(failed_ids)} sequence(s) could not be embedded (out of memory) "
            f"and keep all-zero rows: {failed_ids[:10]}"
        )

    # Save to disk (Float32 for training stability)
    final_emb = embeddings_out.astype(np.float32)
    id_array = np.array(ids)

    np.save(embed_file, final_emb)
    np.save(id_file, id_array)

    # Drop the model before emptying the CUDA cache so its memory is released.
    del model
    gc.collect()
    torch.cuda.empty_cache()

    return final_emb, id_array

# --- Build or load the embeddings ---
# Each call returns its cached .npy files when they are present on disk;
# otherwise it embeds the FASTA file from scratch.
train_emb, train_ids = get_embeddings_smart(
    fasta_path=CONFIG["paths"]["train_seq"],
    name_prefix="train",
)
test_emb, test_ids = get_embeddings_smart(
    fasta_path=CONFIG["paths"]["test_seq"],
    name_prefix="test",
)


✅ Found train_embeds.npy! Skipping generation.
Reading /kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta...
Loading Model...
Processing 224309 sequences with Smart Batching...


100%|██████████| 224309/224309 [3:56:27<00:00, 15.81it/s]  
