In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# List every file under the read-only Kaggle input directory, one full path per line.
for folder, _, names_in_folder in os.walk('/kaggle/input'):
    for name in names_in_folder:
        print(os.path.join(folder, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/cafa-6-protein-function-prediction/sample_submission.tsv
/kaggle/input/cafa-6-protein-function-prediction/IA.tsv
/kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta
/kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset-taxon-list.tsv
/kaggle/input/cafa-6-protein-function-prediction/Train/train_terms.tsv
/kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta
/kaggle/input/cafa-6-protein-function-prediction/Train/train_taxonomy.tsv
/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo


In [3]:
# CELL 1 - Install Dependencies

import sys
!{sys.executable} -m pip install transformers biopython scipy -q

# Verify installation: import one symbol from each freshly installed package.
# Any ImportError is caught and reported instead of aborting the cell.
try:
    from Bio import SeqIO
    from transformers import AutoModel
    import scipy.sparse as sp
    # Reached only when all three imports above succeeded.
    print("‚úÖ All dependencies installed successfully!")
    print("   - transformers")
    print("   - biopython") 
    print("   - scipy")
    print("\n‚úÖ Now run CELL 2")
except ImportError as e:
    # The exception text names the module that could not be imported.
    print(f"‚ùå Installation failed: {e}")
    print("   Try restarting the kernel and running this cell again")

‚úÖ All dependencies installed successfully!
   - transformers
   - biopython
   - scipy

‚úÖ Now run CELL 2


In [4]:
"""
CELL 2 - LOAD AND PREPARE DATA
This loads sequences, annotations, and creates datasets
Time: ~2-3 minutes
"""

import pandas as pd
import numpy as np
from Bio import SeqIO
from collections import defaultdict
import scipy.sparse as sp
import gc
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

print("="*70)
print("üìä CELL 2: LOADING DATA")
print("="*70)

# Path to data
DATA_PATH = '/kaggle/input/cafa-6-protein-function-prediction'

# ----------------------------------------------------------------------------
# 2.1: Load sequences with correct ID parsing
# ----------------------------------------------------------------------------

def parse_uniprot_id(fasta_header):
    """Return the accession field of a pipe-delimited FASTA header.

    For a header such as 'sp|P9WHI7|RECN_MYCT' the second field ('P9WHI7') is
    returned. A header containing no '|' separator is returned unchanged.
    """
    fields = fasta_header.split('|')
    if len(fields) < 2:
        # No separator present: the header itself is used as the identifier.
        return fasta_header
    return fields[1]

def load_fasta(filepath, max_samples=None):
    """Read a FASTA file into a DataFrame with 'protein_id' and 'sequence' columns.

    Args:
        filepath: Path of the FASTA file to parse.
        max_samples: If truthy, stop after this many records (taken in file
            order). None or 0 reads the whole file.

    Returns:
        pandas.DataFrame with one row per parsed record.
    """
    rows = []
    print(f"Loading {filepath}...")
    n_read = 0
    for entry in SeqIO.parse(filepath, "fasta"):
        rows.append({
            'protein_id': parse_uniprot_id(entry.id),
            'sequence': str(entry.seq)
        })
        n_read += 1
        # Stop as soon as the requested number of records has been collected.
        if max_samples and n_read >= max_samples:
            break
    return pd.DataFrame(rows)

# Load training sequences (use subset for faster training).
# With USE_SUBSET the subset is the FIRST 10,000 records in file order, because
# load_fasta stops reading once max_samples records have been collected.
USE_SUBSET = True  # Set False for full training
if USE_SUBSET:
    print("‚ö†Ô∏è Using 10k subset for faster training")
    train_sequences = load_fasta(f'{DATA_PATH}/Train/train_sequences.fasta', max_samples=10000)
else:
    train_sequences = load_fasta(f'{DATA_PATH}/Train/train_sequences.fasta')

print(f"‚úÖ Train sequences: {len(train_sequences):,}")

# ----------------------------------------------------------------------------
# 2.2: Load GO annotations
# ----------------------------------------------------------------------------

print("\nüìã Loading GO annotations...")
# NOTE(review): header=None makes pandas treat the first line of the file as a
# data row. If train_terms.tsv ships with a header line, that row is only
# removed by the isin() filter below - confirm against the file.
train_annotations = pd.read_csv(
    f'{DATA_PATH}/Train/train_terms.tsv',
    sep='\t', header=None,
    names=['protein_id', 'go_term', 'aspect']
)

# Filter to loaded proteins: keep only annotations whose protein was read above.
train_protein_ids = set(train_sequences['protein_id'].values)
train_annotations = train_annotations[train_annotations['protein_id'].isin(train_protein_ids)]

print(f"‚úÖ Annotations: {len(train_annotations):,}")

# ----------------------------------------------------------------------------
# 2.3: Filter GO terms by frequency
# ----------------------------------------------------------------------------

print("\nüî§ Filtering GO terms...")
MIN_FREQ = 10  # Terms must appear at least 10 times
# Counts are taken over the already-filtered annotations, so the retained
# vocabulary depends on which proteins were loaded (subset vs. full).
term_counts = train_annotations['go_term'].value_counts()
frequent_terms = term_counts[term_counts >= MIN_FREQ].index.tolist()
train_annotations = train_annotations[train_annotations['go_term'].isin(frequent_terms)]

# Sorted so that the term -> column order of the label matrix is deterministic.
all_go_terms = sorted(frequent_terms)
print(f"‚úÖ Using {len(all_go_terms):,} GO terms")

# ----------------------------------------------------------------------------
# 2.4: Create sparse label matrix
# ----------------------------------------------------------------------------

print("\nüî¢ Creating label matrix...")
def create_sparse_labels(annotations_df, protein_ids, go_terms):
    """Build a binary protein-by-term label matrix in CSR format.

    Args:
        annotations_df: DataFrame with 'protein_id' and 'go_term' columns, one
            row per annotation.
        protein_ids: Sequence of protein IDs; position i becomes matrix row i.
        go_terms: Sequence of GO term IDs; position j becomes matrix column j.

    Returns:
        scipy.sparse CSR matrix of shape (len(protein_ids), len(go_terms)) and
        dtype float32 holding 1.0 wherever a protein is annotated with a term.
        Annotations whose protein or term is not listed are ignored, and
        repeated (protein, term) pairs still yield a single 1.0.
    """
    n_proteins, n_terms = len(protein_ids), len(go_terms)
    protein_to_idx = {pid: i for i, pid in enumerate(protein_ids)}
    term_to_idx = {term: i for i, term in enumerate(go_terms)}

    # Vectorised lookup instead of DataFrame.iterrows(): unknown IDs map to NaN.
    row_lookup = annotations_df['protein_id'].map(protein_to_idx).to_numpy(dtype=np.float64)
    col_lookup = annotations_df['go_term'].map(term_to_idx).to_numpy(dtype=np.float64)
    known = ~(np.isnan(row_lookup) | np.isnan(col_lookup))
    rows = row_lookup[known].astype(np.int64)
    cols = col_lookup[known].astype(np.int64)

    label_matrix = sp.csr_matrix(
        (np.ones(len(rows), dtype=np.float32), (rows, cols)),
        shape=(n_proteins, n_terms),
        dtype=np.float32,
    )
    # Coordinate-style construction adds duplicate entries together; collapse
    # them and clamp back to 1.0 so the matrix stays strictly binary.
    label_matrix.sum_duplicates()
    label_matrix.data[:] = 1.0
    return label_matrix

# Rows of the label matrix follow the row order of train_sequences, so the
# positional indices produced by the split below address both consistently.
train_labels_sparse = create_sparse_labels(
    train_annotations,
    train_sequences['protein_id'].tolist(),
    all_go_terms
)
print(f"‚úÖ Label matrix: {train_labels_sparse.shape}")

# ----------------------------------------------------------------------------
# 2.5: Train/validation split
# ----------------------------------------------------------------------------

print("\n‚úÇÔ∏è Splitting data...")
# Split positional indices (not the frame itself) so the same index arrays can
# slice the sequences DataFrame and the sparse label matrix. The fixed
# random_state makes the 85/15 split repeatable.
train_indices, val_indices = train_test_split(
    np.arange(len(train_sequences)), test_size=0.15, random_state=42
)

train_df = train_sequences.iloc[train_indices].reset_index(drop=True)
val_df = train_sequences.iloc[val_indices].reset_index(drop=True)
train_labels_split = train_labels_sparse[train_indices]
val_labels_split = train_labels_sparse[val_indices]

print(f"‚úÖ Train: {len(train_df):,} | Val: {len(val_df):,}")

# ----------------------------------------------------------------------------
# 2.6: Create datasets
# ----------------------------------------------------------------------------

print("\nüîß Creating datasets...")

class SparseProteinDataset(Dataset):
    """Dataset pairing protein sequences with rows of a sparse label matrix.

    Args:
        sequences_df: DataFrame with 'sequence' and 'protein_id' columns. Its
            index is reset so that positional access lines up with label rows.
        labels_sparse: Sparse matrix whose row i holds the labels of the i-th
            sequence, or None when no labels are available.
        go_terms: List of GO term IDs, stored for reference.
    """
    def __init__(self, sequences_df, labels_sparse, go_terms):
        self.sequences_df = sequences_df.reset_index(drop=True)
        self.labels_sparse = labels_sparse
        self.go_terms = go_terms

    def __len__(self):
        return len(self.sequences_df)

    def __getitem__(self, idx):
        """Return a dict with 'sequence', 'protein_id' and 'labels' for item idx.

        'labels' is a 1-D float tensor with one entry per label column, or
        None when the dataset was built without labels.
        """
        row = self.sequences_df.iloc[idx]
        labels = None
        if self.labels_sparse is not None:
            # Densify one row at a time. ravel() (rather than squeeze()) keeps
            # the result 1-D even when the matrix has a single column, so the
            # collated batch always has shape (batch, n_terms).
            labels = torch.FloatTensor(self.labels_sparse[idx].toarray().ravel())
        return {
            'sequence': row['sequence'],
            'protein_id': row['protein_id'],
            'labels': labels
        }

# Wrap each split in a dataset; labels are densified per item on access.
train_dataset = SparseProteinDataset(train_df, train_labels_split, all_go_terms)
val_dataset = SparseProteinDataset(val_df, val_labels_split, all_go_terms)

# Per-step batch size. Cell 4 reads this constant when printing its config.
BATCH_SIZE = 4
# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"‚úÖ DataLoaders ready (batch size: {BATCH_SIZE})")

# ----------------------------------------------------------------------------
# 2.7: Load GO hierarchy and IA weights
# ----------------------------------------------------------------------------

print("\nüå≥ Loading GO hierarchy...")
# Simplified hierarchy (each term is its own parent)
go_hierarchy_df = pd.DataFrame({'child': all_go_terms, 'parent': all_go_terms})

print("\n‚öñÔ∏è Loading IA weights...")
ia_df = pd.read_csv(f'{DATA_PATH}/IA.tsv', sep='\t', header=None, names=['go_term', 'ia_weight'])
ia_df = ia_df[ia_df['go_term'].isin(all_go_terms)]
ia_weights_dict = dict(zip(ia_df['go_term'], ia_df['ia_weight']))

# Fill missing with mean
mean_ia = np.mean(list(ia_weights_dict.values()))
for term in all_go_terms:
    if term not in ia_weights_dict:
        ia_weights_dict[term] = mean_ia

print(f"‚úÖ IA weights loaded")

# Cleanup
del train_annotations, train_labels_sparse
gc.collect()

print("\n" + "="*70)
print("‚úÖ CELL 2 COMPLETE - Data loaded and ready!")
print("="*70)
print("\nNext: Run CELL 3 (Model Definition)")


üìä CELL 2: LOADING DATA
‚ö†Ô∏è Using 10k subset for faster training
Loading /kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta...
‚úÖ Train sequences: 10,000

üìã Loading GO annotations...
‚úÖ Annotations: 107,297

üî§ Filtering GO terms...
‚úÖ Using 1,603 GO terms

üî¢ Creating label matrix...
‚úÖ Label matrix: (10000, 1603)

‚úÇÔ∏è Splitting data...
‚úÖ Train: 8,500 | Val: 1,500

üîß Creating datasets...
‚úÖ DataLoaders ready (batch size: 4)

üå≥ Loading GO hierarchy...

‚öñÔ∏è Loading IA weights...
‚úÖ IA weights loaded

‚úÖ CELL 2 COMPLETE - Data loaded and ready!

Next: Run CELL 3 (Model Definition)


In [5]:
"""
CELL 3 - MODEL DEFINITION
This defines the neural network and helper classes
Time: ~1 minute
"""

import torch.nn as nn
from transformers import AutoTokenizer, AutoModel

print("="*70)
print("üß† CELL 3: DEFINING MODEL")
print("="*70)

# ----------------------------------------------------------------------------
# 3.1: GO Hierarchy Class
# ----------------------------------------------------------------------------

class GOHierarchy:
    """Stores parent/child links between GO terms and propagates scores upward.

    Args:
        hierarchy_df: DataFrame with 'child' and 'parent' columns, one row per
            edge. Edges that mention a term outside all_go_terms are ignored.
        all_go_terms: Ordered list of GO term IDs; position i is column i of
            any prediction matrix passed to propagate_predictions.
    """
    def __init__(self, hierarchy_df, all_go_terms):
        self.all_go_terms = all_go_terms
        self.term_to_idx = {term: i for i, term in enumerate(all_go_terms)}
        self.children = {term: set() for term in all_go_terms}
        self.parents = {term: set() for term in all_go_terms}

        # Iterate the two columns directly; iterrows() builds a Series per row.
        for child, parent in zip(hierarchy_df['child'], hierarchy_df['parent']):
            if child in self.term_to_idx and parent in self.term_to_idx:
                self.children[parent].add(child)
                self.parents[child].add(parent)

    def propagate_predictions(self, predictions):
        """Raise every ancestor's score to at least the score of its descendants.

        Args:
            predictions: 2-D array or tensor of shape (n_samples, n_terms),
                with columns ordered like all_go_terms.

        Returns:
            A new object of the same kind as the input (NumPy array, or tensor
            on the input's device) in which each term's score is the maximum of
            its own score and the scores of all terms below it. The input is
            not modified.
        """
        was_torch = isinstance(predictions, torch.Tensor)
        if was_torch:
            # detach() so tensors that track gradients can be converted too.
            pred_np = predictions.detach().cpu().numpy()
        else:
            pred_np = predictions

        propagated = pred_np.copy()

        # Self-loops cannot change a score, so they are left out up front.
        edges = [
            (self.term_to_idx[term], self.term_to_idx[parent])
            for term in self.all_go_terms
            for parent in self.parents.get(term, ())
            if parent != term
        ]

        # A single sweep only reaches grandparents when terms happen to be
        # ordered child-first. Repeat until nothing changes so the result does
        # not depend on term order. Scores only ever increase and are drawn
        # from the existing values, so the loop terminates (cycles included).
        changed = bool(edges)
        while changed:
            changed = False
            for child_idx, parent_idx in edges:
                raised = np.maximum(propagated[:, parent_idx], propagated[:, child_idx])
                if not np.array_equal(raised, propagated[:, parent_idx], equal_nan=True):
                    propagated[:, parent_idx] = raised
                    changed = True

        if was_torch:
            return torch.from_numpy(propagated).to(predictions.device)
        return propagated

print("‚úÖ GOHierarchy class defined")

# ----------------------------------------------------------------------------
# 3.2: IA Weights Class
# ----------------------------------------------------------------------------

class IAWeights:
    """Per-term Information Accretion weights held as a float tensor.

    Weights are ordered like all_go_terms; a term missing from ia_dict is
    given a weight of 1.0.
    """
    def __init__(self, ia_dict, all_go_terms):
        per_term = []
        for term in all_go_terms:
            per_term.append(ia_dict.get(term, 1.0))
        self.weights = torch.FloatTensor(per_term)

    def get_weights(self, device='cpu'):
        """Return the weight tensor placed on the requested device."""
        return self.weights.to(device)

print("‚úÖ IAWeights class defined")

# ----------------------------------------------------------------------------
# 3.3: Main Model Architecture
# ----------------------------------------------------------------------------

class ProteinFunctionPredictor(nn.Module):
    """Pretrained ESM2 encoder followed by a three-layer MLP head.

    The head maps the encoder's position-0 hidden state to one logit per GO
    term.

    Args:
        num_go_terms: Number of output logits.
        esm_model_name: Hugging Face model identifier passed to AutoModel.
        hidden_dim: Width of the first hidden layer; the second is half of it.
        dropout: Dropout probability used after each hidden layer.
        freeze_esm: When True, encoder parameters are excluded from gradients.
    """
    def __init__(self, num_go_terms, esm_model_name="facebook/esm2_t12_35M_UR50D",
                 hidden_dim=512, dropout=0.3, freeze_esm=True):
        super().__init__()
        self.num_go_terms = num_go_terms
        self.esm_model_name = esm_model_name

        # Pretrained encoder, optionally kept fixed during training.
        self.esm = AutoModel.from_pretrained(esm_model_name)
        if freeze_esm:
            self.esm.requires_grad_(False)

        self.embedding_dim = self.esm.config.hidden_size

        # Classifier head: embedding -> hidden_dim -> hidden_dim // 2 -> logits.
        half_dim = hidden_dim // 2
        head_layers = [
            nn.Linear(self.embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half_dim, num_go_terms),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Return raw logits of shape (batch, num_go_terms) for tokenized input."""
        # Gradients through the encoder are enabled only in training mode.
        grad_mode = torch.enable_grad() if self.training else torch.no_grad()
        with grad_mode:
            encoder_out = self.esm(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the last hidden state serves as the sequence embedding
        # (presumably the tokenizer's leading special token - verify).
        first_token = encoder_out.last_hidden_state[:, 0, :]
        return self.classifier(first_token)

print("‚úÖ ProteinFunctionPredictor class defined")

# ----------------------------------------------------------------------------
# 3.4: Trainer Class
# ----------------------------------------------------------------------------

class CAFA6Trainer:
    """Bundles the model with its tokenizer, loss function and optimizer.

    Args:
        model: Module exposing `esm_model_name`; moved to `device`.
        go_hierarchy: Hierarchy object, stored for later use.
        ia_weights: Object whose get_weights(device) returns per-term weights.
        device: Device string used for the model and the loss weights.
        learning_rate: AdamW learning rate.
    """
    def __init__(self, model, go_hierarchy, ia_weights, device='cuda', learning_rate=2e-4):
        self.model = model.to(device)
        self.go_hierarchy = go_hierarchy
        self.ia_weights = ia_weights
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model.esm_model_name)

        # Per-term weights rescaled to mean 1, used as positive-class weights.
        raw_weights = ia_weights.get_weights(device)
        normalized_weights = raw_weights / ia_weights.get_weights(device).mean()
        self.criterion = nn.BCEWithLogitsLoss(pos_weight=normalized_weights)

        # Only parameters that still require gradients are optimized.
        trainable_params = list(filter(lambda p: p.requires_grad, model.parameters()))
        self.optimizer = torch.optim.AdamW(
            trainable_params, lr=learning_rate, weight_decay=0.01
        )

print("‚úÖ CAFA6Trainer class defined")

print("\n" + "="*70)
print("‚úÖ CELL 3 COMPLETE - Model architecture defined!")
print("="*70)
print("\nNext: Run CELL 4 (Training)")


üß† CELL 3: DEFINING MODEL
‚úÖ GOHierarchy class defined
‚úÖ IAWeights class defined
‚úÖ ProteinFunctionPredictor class defined
‚úÖ CAFA6Trainer class defined

‚úÖ CELL 3 COMPLETE - Model architecture defined!

Next: Run CELL 4 (Training)


In [6]:
"""
CELL 4 - TRAIN MODEL
This trains the model on training data
Time: ~10-15 minutes for 3 epochs
"""

from tqdm.auto import tqdm

print("="*70)
print("üèãÔ∏è CELL 4: TRAINING MODEL")
print("="*70)

# ----------------------------------------------------------------------------
# 4.1: Initialize components
# ----------------------------------------------------------------------------

print("\nüîß Initializing...")
go_hierarchy = GOHierarchy(go_hierarchy_df, all_go_terms)
ia_weights = IAWeights(ia_weights_dict, all_go_terms)

model = ProteinFunctionPredictor(
    num_go_terms=len(all_go_terms),
    esm_model_name="facebook/esm2_t12_35M_UR50D",
    hidden_dim=512, dropout=0.3, freeze_esm=True
)

trainer = CAFA6Trainer(
    model=model, go_hierarchy=go_hierarchy, ia_weights=ia_weights,
    device='cuda' if torch.cuda.is_available() else 'cpu', learning_rate=2e-4
)

print(f"‚úÖ Model ready: {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable params")

# ----------------------------------------------------------------------------
# 4.2: Training configuration
# ----------------------------------------------------------------------------

NUM_EPOCHS = 3
ACCUMULATION_STEPS = 4
MAX_SEQ_LENGTH = 512

print(f"\n‚öôÔ∏è Config: {NUM_EPOCHS} epochs, batch={BATCH_SIZE}, accumulation={ACCUMULATION_STEPS}")

# ----------------------------------------------------------------------------
# 4.3: Training loop
# ----------------------------------------------------------------------------

# Train for NUM_EPOCHS epochs with gradient accumulation, validate after each
# epoch, and checkpoint the model whenever the validation loss improves.
best_val_loss = float('inf')
training_history = {'train_loss': [], 'val_loss': []}

for epoch in range(NUM_EPOCHS):
    print(f"\n{'='*70}")
    print(f"EPOCH {epoch+1}/{NUM_EPOCHS}")
    print('='*70)
    
    # Training
    model.train()
    total_loss = 0
    batch_count = 0
    trainer.optimizer.zero_grad()
    
    for batch_idx, batch in enumerate(tqdm(train_loader, desc="Training")):
        sequences = batch['sequence']
        labels = batch['labels'].to(trainer.device)
        
        encoded = trainer.tokenizer(sequences, padding=True, truncation=True,
                                   max_length=MAX_SEQ_LENGTH, return_tensors='pt')
        input_ids = encoded['input_ids'].to(trainer.device)
        attention_mask = encoded['attention_mask'].to(trainer.device)
        
        logits = model(input_ids, attention_mask)
        # Scale so the accumulated gradient is the mean over the micro-batches.
        loss = trainer.criterion(logits, labels) / ACCUMULATION_STEPS
        loss.backward()
        
        if (batch_idx + 1) % ACCUMULATION_STEPS == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            trainer.optimizer.step()
            trainer.optimizer.zero_grad()
        
        # Undo the scaling so the logged value is the plain per-batch loss.
        total_loss += loss.item() * ACCUMULATION_STEPS
        batch_count += 1
        
        del input_ids, attention_mask, logits, loss, encoded
        if batch_idx % 20 == 0:
            torch.cuda.empty_cache()
            gc.collect()
    
    # Apply gradients left over when the number of batches is not a multiple of
    # ACCUMULATION_STEPS. Without this step they would be discarded by the
    # zero_grad() at the start of the next epoch (or never used after the last
    # epoch), so the trailing batches would not contribute to training.
    if batch_count % ACCUMULATION_STEPS != 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        trainer.optimizer.step()
        trainer.optimizer.zero_grad()
    
    avg_train_loss = total_loss / batch_count
    training_history['train_loss'].append(avg_train_loss)
    
    # Validation
    model.eval()
    val_loss = 0
    val_batch_count = 0
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            sequences = batch['sequence']
            labels = batch['labels'].to(trainer.device)
            
            encoded = trainer.tokenizer(sequences, padding=True, truncation=True,
                                       max_length=MAX_SEQ_LENGTH, return_tensors='pt')
            input_ids = encoded['input_ids'].to(trainer.device)
            attention_mask = encoded['attention_mask'].to(trainer.device)
            
            logits = model(input_ids, attention_mask)
            loss = trainer.criterion(logits, labels)
            val_loss += loss.item()
            val_batch_count += 1
            
            del input_ids, attention_mask, logits, loss, encoded
    
    val_loss /= val_batch_count
    training_history['val_loss'].append(val_loss)
    
    print(f"\nüìà Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Val Loss={val_loss:.4f}")
    
    # Keep only the checkpoint with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': trainer.optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': val_loss,
            'go_terms': all_go_terms
        }, '/kaggle/working/best_model.pt')
        print(f"‚úÖ Best model saved! (Val Loss: {val_loss:.4f})")
    
    torch.cuda.empty_cache()
    gc.collect()

# Load best model: restore the weights saved at the epoch with the lowest
# validation loss, so later cells use that state rather than the final epoch's.
checkpoint = torch.load('/kaggle/working/best_model.pt')
model.load_state_dict(checkpoint['model_state_dict'])
# 'epoch' is stored zero-based, hence the +1 when reporting it.
print(f"\n‚úÖ Loaded best model from epoch {checkpoint['epoch']+1}")

print("\n" + "="*70)
print("‚úÖ CELL 4 COMPLETE - Model trained!")
print("="*70)
print("\nNext: Run CELL 5 (Generate Predictions)")


üèãÔ∏è CELL 4: TRAINING MODEL

üîß Initializing...


config.json:   0%|          | 0.00/778 [00:00<?, ?B/s]

2025-11-18 06:57:13.518757: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763449033.701920      48 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763449033.750594      48 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

model.safetensors:   0%|          | 0.00/136M [00:00<?, ?B/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

‚úÖ Model ready: 789,571 trainable params

‚öôÔ∏è Config: 3 epochs, batch=4, accumulation=4

EPOCH 1/3


Training:   0%|          | 0/2125 [00:00<?, ?it/s]

Validation:   0%|          | 0/375 [00:00<?, ?it/s]


üìà Epoch 1: Train Loss=0.0566, Val Loss=0.0192
‚úÖ Best model saved! (Val Loss: 0.0192)

EPOCH 2/3


Training:   0%|          | 0/2125 [00:00<?, ?it/s]

Validation:   0%|          | 0/375 [00:00<?, ?it/s]


üìà Epoch 2: Train Loss=0.0203, Val Loss=0.0191
‚úÖ Best model saved! (Val Loss: 0.0191)

EPOCH 3/3


Training:   0%|          | 0/2125 [00:00<?, ?it/s]

Validation:   0%|          | 0/375 [00:00<?, ?it/s]


üìà Epoch 3: Train Loss=0.0197, Val Loss=0.0189
‚úÖ Best model saved! (Val Loss: 0.0189)

‚úÖ Loaded best model from epoch 3

‚úÖ CELL 4 COMPLETE - Model trained!

Next: Run CELL 5 (Generate Predictions)


In [7]:
"""
CELL 5 - GENERATE PREDICTIONS FOR ALL TEST PROTEINS
This is the CRITICAL cell - predicts for all 224k proteins
Time: ~60-90 minutes
"""

print("="*70)
print("üîÆ CELL 5: GENERATING PREDICTIONS FOR ALL TEST PROTEINS")
print("="*70)

# ----------------------------------------------------------------------------
# 5.1: Load FULL test superset
# ----------------------------------------------------------------------------

print("\nüì• Loading FULL test superset...")
print("‚ö†Ô∏è Loading ALL 224,309 proteins (not just 5,000)")

test_sequences_full = []
for record in SeqIO.parse(f'{DATA_PATH}/Test/testsuperset.fasta', 'fasta'):
    test_sequences_full.append({
        'protein_id': parse_uniprot_id(record.id),
        'sequence': str(record.seq)
    })

test_df_full = pd.DataFrame(test_sequences_full)
print(f"‚úÖ Loaded {len(test_df_full):,} test proteins")

# ----------------------------------------------------------------------------
# 5.2: Create test dataset and loader
# ----------------------------------------------------------------------------

class TestProteinDataset(Dataset):
    """Unlabelled dataset yielding the sequence and ID of each protein.

    Args:
        sequences_df: DataFrame with 'sequence' and 'protein_id' columns; its
            index is reset so items are addressed by position.
    """
    def __init__(self, sequences_df):
        self.sequences_df = sequences_df.reset_index(drop=True)

    def __len__(self):
        return self.sequences_df.shape[0]

    def __getitem__(self, idx):
        """Return {'sequence': ..., 'protein_id': ...} for the idx-th row."""
        record = self.sequences_df.iloc[idx]
        return dict(sequence=record['sequence'], protein_id=record['protein_id'])

test_dataset_full = TestProteinDataset(test_df_full)
test_loader_full = DataLoader(test_dataset_full, batch_size=4, shuffle=False, num_workers=0)

print(f"‚úÖ Test loader ready: {len(test_dataset_full):,} proteins")

# ----------------------------------------------------------------------------
# 5.3: Generate predictions (THIS TAKES TIME!)
# ----------------------------------------------------------------------------

print("\nüîÆ Generating predictions...")
print("‚è±Ô∏è This will take 60-90 minutes - be patient!")

model.eval()
all_predictions_full = []
all_protein_ids_full = []

# Read the batch size back from the loader so the progress arithmetic below
# stays correct if the loader's batch size is ever changed.
test_batch_size = test_loader_full.batch_size

with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(test_loader_full, desc="Predicting")):
        sequences = batch['sequence']
        protein_ids = batch['protein_id']
        
        # Tokenize
        encoded = trainer.tokenizer(
            sequences, padding=True, truncation=True,
            max_length=MAX_SEQ_LENGTH, return_tensors='pt'
        )
        input_ids = encoded['input_ids'].to(trainer.device)
        attention_mask = encoded['attention_mask'].to(trainer.device)
        
        # Predict. Move the probabilities to host memory once: the hierarchy
        # propagation works on NumPy arrays, so passing a device tensor would
        # copy the data to the CPU, back to the device, and to the CPU again.
        logits = model(input_ids, attention_mask)
        probs = torch.sigmoid(logits).cpu().numpy()
        
        # Propagate through GO hierarchy (NumPy in, NumPy out)
        probs = go_hierarchy.propagate_predictions(probs)
        
        all_predictions_full.append(probs)
        all_protein_ids_full.extend(protein_ids)
        
        # Memory cleanup
        del input_ids, attention_mask, logits, probs, encoded
        
        if batch_idx % 100 == 0:
            torch.cuda.empty_cache()
            gc.collect()
        
        # Progress update every 10k proteins (count of proteins handled before
        # this batch)
        proteins_seen = batch_idx * test_batch_size
        if proteins_seen % 10000 == 0:
            print(f"   Processed {proteins_seen:,} / {len(test_df_full):,} proteins...")

# Combine predictions into one (n_proteins, n_terms) array
test_predictions_full = np.vstack(all_predictions_full)
print(f"\n‚úÖ Predictions complete!")
print(f"   Shape: {test_predictions_full.shape}")
print(f"   Proteins: {len(all_protein_ids_full):,}")

print("\n" + "="*70)
print("‚úÖ CELL 5 COMPLETE - All predictions generated!")
print("="*70)
print("\nNext: Run CELL 6 (Create Submission)")


üîÆ CELL 5: GENERATING PREDICTIONS FOR ALL TEST PROTEINS

üì• Loading FULL test superset...
‚ö†Ô∏è Loading ALL 224,309 proteins (not just 5,000)
‚úÖ Loaded 224,309 test proteins
‚úÖ Test loader ready: 224,309 proteins

üîÆ Generating predictions...
‚è±Ô∏è This will take 60-90 minutes - be patient!


Predicting:   0%|          | 0/56078 [00:00<?, ?it/s]

   Processed 0 / 224,309 proteins...
   Processed 10,000 / 224,309 proteins...
   Processed 20,000 / 224,309 proteins...
   Processed 30,000 / 224,309 proteins...
   Processed 40,000 / 224,309 proteins...
   Processed 50,000 / 224,309 proteins...
   Processed 60,000 / 224,309 proteins...
   Processed 70,000 / 224,309 proteins...
   Processed 80,000 / 224,309 proteins...
   Processed 90,000 / 224,309 proteins...
   Processed 100,000 / 224,309 proteins...
   Processed 110,000 / 224,309 proteins...
   Processed 120,000 / 224,309 proteins...
   Processed 130,000 / 224,309 proteins...
   Processed 140,000 / 224,309 proteins...
   Processed 150,000 / 224,309 proteins...
   Processed 160,000 / 224,309 proteins...
   Processed 170,000 / 224,309 proteins...
   Processed 180,000 / 224,309 proteins...
   Processed 190,000 / 224,309 proteins...
   Processed 200,000 / 224,309 proteins...
   Processed 210,000 / 224,309 proteins...
   Processed 220,000 / 224,309 proteins...

‚úÖ Predictions complete!

In [8]:
"""
CELL 6 - CREATE SUBMISSION FILE
This creates the final submission using Top-K method
Time: ~5-10 minutes
"""

print("="*70)
print("üìù CELL 6: CREATING SUBMISSION FILE")
print("="*70)

# ----------------------------------------------------------------------------
# 6.1: Top-K submission creator
# ----------------------------------------------------------------------------

def create_submission_topk(protein_ids, predictions, go_terms, k=50, output_file='submission.tsv'):
    """Write the k highest-scoring GO terms of every protein to a TSV file.

    Each output line is '<protein_id>\\t<go_term>\\t<score>' with the score
    formatted to six decimals. Lines for one protein are ordered from highest
    to lowest score.

    Args:
        protein_ids: Sequence of protein IDs, aligned with rows of predictions.
        predictions: 2-D array of scores, shape (n_proteins, n_terms).
        go_terms: Sequence of GO term IDs, aligned with columns of predictions.
        k: Number of terms to keep per protein; all terms are kept when k is
            at least the number of columns.
        output_file: Destination path; an existing file is overwritten.

    Returns:
        output_file, unchanged.

    Raises:
        ValueError: If k is smaller than 1.
    """
    # k == 0 would otherwise select EVERY term, because [-0:] is a full slice.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    print(f"\nüìä Creating Top-{k} submission for {len(protein_ids):,} proteins...")
    
    total_rows = 0
    # Stream rows to disk protein by protein instead of collecting every
    # formatted line in memory first (tens of millions of strings for the full
    # test set).
    with open(output_file, 'w') as f:
        for i in tqdm(range(len(protein_ids)), desc="Building submission"):
            pid = protein_ids[i]
            probs = predictions[i]
            
            # Get top k predictions, highest score first
            if k >= len(probs):
                topk_idx = np.argsort(probs)[::-1]
            else:
                # Partial selection of the k largest, then sort just those k.
                topk_idx = np.argpartition(probs, -k)[-k:]
                topk_idx = topk_idx[np.argsort(probs[topk_idx])][::-1]
            
            # NOTE(review): very small scores are written as 0.000000; confirm
            # that the competition's submission format accepts zero scores.
            f.writelines(f"{pid}\t{go_terms[j]}\t{probs[j]:.6f}\n" for j in topk_idx)
            total_rows += len(topk_idx)
    
    print(f"‚úÖ Saved: {output_file}")
    print(f"   Total predictions: {total_rows:,}")
    print(f"   Proteins: {len(protein_ids):,}")
    # Report the number actually written per protein when fewer than k terms exist.
    print(f"   Predictions per protein: {min(k, len(go_terms))}")
    
    return output_file

# ----------------------------------------------------------------------------
# 6.2: Generate submissions
# ----------------------------------------------------------------------------

# Write three submission files from the same prediction matrix, differing only
# in how many top-scoring terms are kept per protein. Each call returns the
# path it wrote, which the validation code below reuses.
print("\n1Ô∏è‚É£ Creating Top-50 submission (RECOMMENDED):")
submission_50 = create_submission_topk(
    all_protein_ids_full,
    test_predictions_full,
    all_go_terms,
    k=50,
    output_file='/kaggle/working/submission_full_top50.tsv'
)

print("\n2Ô∏è‚É£ Creating Top-100 submission:")
submission_100 = create_submission_topk(
    all_protein_ids_full,
    test_predictions_full,
    all_go_terms,
    k=100,
    output_file='/kaggle/working/submission_full_top100.tsv'
)

print("\n3Ô∏è‚É£ Creating Top-20 submission:")
submission_20 = create_submission_topk(
    all_protein_ids_full,
    test_predictions_full,
    all_go_terms,
    k=20,
    output_file='/kaggle/working/submission_full_top20.tsv'
)

# ----------------------------------------------------------------------------
# 6.3: Validate submissions
# ----------------------------------------------------------------------------

print("\n" + "="*70)
print("‚úÖ VALIDATION")
print("="*70)

import os

# Report the size on disk of each file next to the row count it SHOULD contain
# (proteins x k). The rows actually written are not counted here, so this is a
# sanity report rather than a check that can fail.
for k, filepath in [(50, submission_50), (100, submission_100), (20, submission_20)]:
    size_mb = os.path.getsize(filepath) / 1e6
    print(f"\nTop-{k}: {filepath}")
    print(f"  Size: {size_mb:.1f} MB")
    print(f"  Expected predictions: {len(all_protein_ids_full) * k:,}")

# Sample check: echo the first five lines of the Top-50 file for eyeballing.
print("\nSample from Top-50:")
with open(submission_50, 'r') as f:
    for i, line in enumerate(f):
        if i < 5:
            print(f"  {line.strip()}")
        else:
            break

print("\n" + "="*70)
print("üéâ PIPELINE COMPLETE!")
print("="*70)

print("\nüì§ SUBMISSION FILES READY:")
print("   1. submission_full_top50.tsv  ‚Üê SUBMIT THIS FIRST")
print("   2. submission_full_top100.tsv")
print("   3. submission_full_top20.tsv")

print("\nüìä WHAT CHANGED:")
print(f"   OLD: 5,000 proteins, 1.9 predictions per protein ‚Üí Score: 0.00")
print(f"   NEW: {len(all_protein_ids_full):,} proteins, 50 predictions per protein ‚Üí Should be >0!")

print("\nüí° EXPECTED RESULTS:")
print("   ‚úÖ Non-zero score (almost guaranteed)")
print("   ‚úÖ Full test set coverage")
print("   ‚úÖ No threshold issues")

print("\nüîÑ TO IMPROVE FURTHER:")
print("   1. Set USE_SUBSET=False in Cell 2 (train on full data)")
print("   2. Increase NUM_EPOCHS to 5-10 in Cell 4")
print("   3. Use larger model: esm2_t33_650M_UR50D")
print("   4. Reduce MIN_FREQ to 5 (more GO terms)")
print("   5. Tune per-term thresholds on validation set")

print("\n" + "="*70)
print("‚úÖ READY TO SUBMIT! Download submission_full_top50.tsv")
print("="*70)


üìù CELL 6: CREATING SUBMISSION FILE

1Ô∏è‚É£ Creating Top-50 submission (RECOMMENDED):

üìä Creating Top-50 submission for 224,309 proteins...


Building submission:   0%|          | 0/224309 [00:00<?, ?it/s]

‚úÖ Saved: /kaggle/working/submission_full_top50.tsv
   Total predictions: 11,215,450
   Proteins: 224,309
   Predictions per protein: 50

2Ô∏è‚É£ Creating Top-100 submission:

üìä Creating Top-100 submission for 224,309 proteins...


Building submission:   0%|          | 0/224309 [00:00<?, ?it/s]

‚úÖ Saved: /kaggle/working/submission_full_top100.tsv
   Total predictions: 22,430,900
   Proteins: 224,309
   Predictions per protein: 100

3Ô∏è‚É£ Creating Top-20 submission:

üìä Creating Top-20 submission for 224,309 proteins...


Building submission:   0%|          | 0/224309 [00:00<?, ?it/s]

‚úÖ Saved: /kaggle/working/submission_full_top20.tsv
   Total predictions: 4,486,180
   Proteins: 224,309
   Predictions per protein: 20

‚úÖ VALIDATION

Top-50: /kaggle/working/submission_full_top50.tsv
  Size: 303.7 MB
  Expected predictions: 11,215,450

Top-100: /kaggle/working/submission_full_top100.tsv
  Size: 607.4 MB
  Expected predictions: 22,430,900

Top-20: /kaggle/working/submission_full_top20.tsv
  Size: 121.5 MB
  Expected predictions: 4,486,180

Sample from Top-50:
  A0A0C5B5G6	GO:0005515	0.342592
  A0A0C5B5G6	GO:0005829	0.220537
  A0A0C5B5G6	GO:0005576	0.156827
  A0A0C5B5G6	GO:0042802	0.118702
  A0A0C5B5G6	GO:0016020	0.097402

üéâ PIPELINE COMPLETE!

üì§ SUBMISSION FILES READY:
   1. submission_full_top50.tsv  ‚Üê SUBMIT THIS FIRST
   2. submission_full_top100.tsv
   3. submission_full_top20.tsv

üìä WHAT CHANGED:
   OLD: 5,000 proteins, 1.9 predictions per protein ‚Üí Score: 0.00
   NEW: 224,309 proteins, 50 predictions per protein ‚Üí Should be >0!

üí° EXPECTED RE