# CAFA-6 Training Notebook

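This notebook trains a multi-head MLP that predicts Gene Ontology (GO) terms for proteins — with separate output heads for Molecular Function (MF), Biological Process (BP) and Cellular Component (CC) — from mean-pooled ProtT5-XL embeddings of their sequences.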

## Setup Instructions

1. Add the "cafa-6-dataset" dataset to this notebook
2. Enable GPU accelerator (Settings > Accelerator > GPU T4 x2 or P100)
3. (Optional) Add your wandb API key as a Kaggle Secret named `WANDB_API_KEY` to enable Weights & Biases logging
4. Run all cells


In [None]:
# Install dependencies
!pip install -q wandb transformers sentencepiece


In [None]:
# Setup wandb: best-effort load of the API key from Kaggle Secrets.
import os

try:
    # Imported inside the try so that outside Kaggle (where kaggle_secrets does
    # not exist) the notebook falls through to the warning instead of crashing.
    from kaggle_secrets import UserSecretsClient

    user_secrets = UserSecretsClient()
    wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
    os.environ["WANDB_API_KEY"] = wandb_api_key
    print("Wandb API key loaded from Kaggle secrets")
except Exception as e:
    print(f"Warning: Could not load wandb API key: {e}")
    print("You can still train, but wandb logging will be disabled")


In [None]:
import os

import wandb

# Only attempt a login when an API key is available: without one,
# wandb.login() may fall back to an interactive prompt, which would block a
# non-interactive run. USE_WANDB gates every later wandb call.
USE_WANDB = False
if os.environ.get("WANDB_API_KEY"):
    try:
        wandb.login()
        USE_WANDB = True
        print("Logged into wandb successfully!")
    except Exception as e:
        print(f"Wandb login failed: {e}")
else:
    print("WANDB_API_KEY not set; wandb logging disabled")


In [None]:
import sys
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split

# Report the PyTorch build and whether a CUDA GPU is visible.
print(f"PyTorch version: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")


In [None]:
# Run configuration: data paths, model shape, training and embedding settings.
CONFIG = dict(
    # Data paths (Kaggle competition data)
    train_fasta="/kaggle/input/cafa-6-dataset/Train/train_sequences.fasta",
    train_terms="/kaggle/input/cafa-6-dataset/Train/train_terms.tsv",
    go_obo="/kaggle/input/cafa-6-dataset/Train/go-basic.obo",
    # Model
    embedding_dim=1024,  # ProtT5-XL embedding dimension
    hidden_dims=[512, 256],
    dropout=0.1,
    # Training
    batch_size=32,
    learning_rate=1e-3,
    weight_decay=1e-5,
    epochs=100,
    patience=10,
    val_split=0.1,
    seed=42,
    # Embedding generation
    embed_batch_size=4,  # Smaller for Kaggle GPU memory
    max_seq_length=1024,
)

# Output paths: everything is written under the Kaggle working directory.
OUTPUT_DIR = Path("/kaggle/working")
CHECKPOINT_DIR = OUTPUT_DIR.joinpath("checkpoints")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

print("Configuration loaded")


In [None]:
def set_seed(seed: int):
    """Seed every RNG used by the notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and the
    RNGs of all CUDA devices with the same value, then disables cuDNN
    autotuning in favour of deterministic kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

SEED = CONFIG["seed"]
set_seed(SEED)
print(f"Random seed set to {SEED}")


## Load Training Data


In [None]:
def _fasta_header_id(header: str) -> str:
    """Extract the protein ID from a FASTA header line (including the leading '>').

    The ID is the first whitespace-delimited token. UniProt-style tokens of the
    form ``sp|ACCESSION|ENTRY_NAME`` or ``tr|ACCESSION|ENTRY_NAME`` are reduced
    to the accession; any other token is returned unchanged. A header with no
    token yields "".
    """
    tokens = header[1:].split()
    token = tokens[0] if tokens else ""
    parts = token.split("|")
    if len(parts) >= 3 and parts[0] in ("sp", "tr"):
        return parts[1]
    return token


def parse_fasta(fasta_path: str) -> dict[str, str]:
    """Parse a FASTA file and return a dict of protein_id -> sequence.

    Protein IDs come from ``_fasta_header_id``: the first header token, with
    UniProt ``sp|ACC|NAME`` / ``tr|ACC|NAME`` identifiers reduced to ``ACC``.
    This lets them match annotation tables keyed by bare accession (presumably
    the format of train_terms.tsv — verify against the data). Sequence lines
    are whitespace-stripped and concatenated. Lines before the first header are
    ignored, and a repeated ID keeps the last record.
    """
    sequences = {}
    current_id = None
    current_seq = []

    with open(fasta_path) as f:
        for line in f:
            line = line.strip()
            if line.startswith(">"):
                if current_id is not None:
                    sequences[current_id] = "".join(current_seq)
                current_id = _fasta_header_id(line)
                current_seq = []
            else:
                current_seq.append(line)

    # Flush the final record.
    if current_id is not None:
        sequences[current_id] = "".join(current_seq)

    return sequences

# Load training sequences (protein_id -> amino-acid string).
train_fasta_path = CONFIG["train_fasta"]
train_sequences = parse_fasta(train_fasta_path)
print(f"Loaded {len(train_sequences)} training sequences")


In [None]:
def load_annotations(terms_path: str) -> dict[str, set[str]]:
    """Read a tab-separated annotation table into protein_id -> set of GO terms.

    The first line is treated as a header and skipped. Column 0 is the protein
    ID and column 1 the term; extra columns are ignored, as are lines with fewer
    than two tab-separated fields. Duplicate (protein, term) rows collapse into
    one set entry.
    """
    protein_terms: dict[str, set[str]] = {}
    with open(terms_path) as handle:
        next(handle)  # header row
        for row in handle:
            fields = row.strip().split("\t")
            if len(fields) < 2:
                continue
            protein_terms.setdefault(fields[0], set()).add(fields[1])
    return protein_terms

# Load annotations
train_annotations = load_annotations(CONFIG["train_terms"])
print(f"Loaded annotations for {len(train_annotations)} proteins")

# Sanity check: the FASTA IDs and the annotation IDs must use the same
# identifier scheme. Otherwise the dataset finds no annotations for any
# protein and every training target is all zeros.
num_annotated_sequences = sum(pid in train_annotations for pid in train_sequences)
print(f"{num_annotated_sequences} of {len(train_sequences)} training sequences have annotations")
assert num_annotated_sequences > 0, (
    "No FASTA protein IDs match annotation IDs - check the FASTA header parsing"
)


In [None]:
def load_go_terms_from_obo(obo_path: str) -> dict[str, str]:
    """Load a GO term -> ontology ("MF" / "BP" / "CC") mapping from an OBO file.

    Only ``[Term]`` stanzas are read. A term is recorded when its stanza has
    an ``id: GO:...`` line and a ``namespace:`` of molecular_function,
    biological_process or cellular_component; other namespaces are skipped.
    Any other stanza type (e.g. ``[Typedef]``) closes the current term and is
    otherwise ignored. This means its own ``id:``/``namespace:`` lines cannot
    overwrite the state of the last ``[Term]`` stanza before it.
    """
    namespace_map = {
        "molecular_function": "MF",
        "biological_process": "BP",
        "cellular_component": "CC",
    }
    term_to_namespace = {}
    in_term_stanza = False  # header lines before the first stanza are skipped
    current_term = None
    current_namespace = None

    with open(obo_path) as f:
        for line in f:
            line = line.strip()
            if line.startswith("[") and line.endswith("]"):
                # Any stanza header closes the previous stanza.
                if current_term and current_namespace:
                    term_to_namespace[current_term] = current_namespace
                in_term_stanza = line == "[Term]"
                current_term = None
                current_namespace = None
            elif in_term_stanza:
                if line.startswith("id: GO:"):
                    current_term = line[4:]
                elif line.startswith("namespace:"):
                    ns = line.split(":", 1)[1].strip()
                    current_namespace = namespace_map.get(ns)

    # Flush the final stanza (its fields are only set if it was a [Term]).
    if current_term and current_namespace:
        term_to_namespace[current_term] = current_namespace

    return term_to_namespace

# Load GO term namespaces (GO id -> "MF" / "BP" / "CC").
go_obo_path = CONFIG["go_obo"]
term_to_ontology = load_go_terms_from_obo(go_obo_path)
print(f"Loaded {len(term_to_ontology)} GO terms from ontology")


In [None]:
# Build term indices per ontology.
# Vocabulary = every GO term annotated to at least one training protein,
# split by ontology and sorted so the output-column order is deterministic.
all_terms = set().union(*train_annotations.values())

mf_terms = sorted(t for t in all_terms if term_to_ontology.get(t) == "MF")
bp_terms = sorted(t for t in all_terms if term_to_ontology.get(t) == "BP")
cc_terms = sorted(t for t in all_terms if term_to_ontology.get(t) == "CC")

mf_term_to_idx = {t: i for i, t in enumerate(mf_terms)}
bp_term_to_idx = {t: i for i, t in enumerate(bp_terms)}
cc_term_to_idx = {t: i for i, t in enumerate(cc_terms)}

# Annotated terms missing from the ontology mapping get no output column and
# are ignored when targets are built; report them so the drop is visible.
unmapped_terms = sorted(all_terms.difference(term_to_ontology))

print("Term counts:")
print(f"  MF: {len(mf_terms)}")
print(f"  BP: {len(bp_terms)}")
print(f"  CC: {len(cc_terms)}")
if unmapped_terms:
    print(
        f"  Skipped {len(unmapped_terms)} annotated terms not found in the ontology "
        f"(e.g. {unmapped_terms[:5]})"
    )


## Generate ProtT5 Embeddings


In [None]:
from transformers import T5Tokenizer, T5EncoderModel

PROT_T5_CHECKPOINT = "Rostlab/prot_t5_xl_uniref50"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the ProtT5 tokenizer and encoder; the encoder is moved to `device`
# and switched to eval mode (dropout off) for embedding extraction.
print("Loading ProtT5-XL model...")
tokenizer = T5Tokenizer.from_pretrained(PROT_T5_CHECKPOINT, do_lower_case=False)
model_t5 = T5EncoderModel.from_pretrained(PROT_T5_CHECKPOINT).to(device)
model_t5.eval()

print(f"ProtT5 loaded on {device}")


In [None]:
# ProtT5 expects the rare/ambiguous amino acids U, Z, O and B to be replaced
# by X (preprocessing shown on the Rostlab/prot_t5_xl_uniref50 model card).
_RARE_AA_TO_X = str.maketrans("UZOB", "XXXX")


def generate_embeddings(
    sequences: dict[str, str],
    batch_size: int = 4,
    max_length: int = 1024,
    encoder=None,
    encoder_tokenizer=None,
    target_device=None,
):
    """Embed protein sequences with ProtT5 and mean-pool them into one vector each.

    Args:
        sequences: mapping protein_id -> amino-acid sequence (one-letter codes).
        batch_size: number of sequences per forward pass.
        max_length: residue-level cut applied before tokenization, also passed
            to the tokenizer as ``max_length`` with truncation enabled.
        encoder: encoder model to run; defaults to the notebook-level ``model_t5``.
        encoder_tokenizer: tokenizer; defaults to the notebook-level ``tokenizer``.
        target_device: device the tokenized batch is moved to; defaults to the
            notebook-level ``device``.

    Returns:
        ``(protein_ids, embeddings)``: ``protein_ids`` is ``list(sequences)`` and
        row ``i`` of the stacked ``embeddings`` array belongs to ``protein_ids[i]``.

    Note:
        Pooling averages every position whose attention mask is 1, which
        presumably includes any special token the tokenizer appends (e.g.
        ``</s>``). Test-set embeddings must use this exact preprocessing and
        pooling.
    """
    # Fall back to the notebook-level ProtT5 objects when none are passed.
    if encoder is None:
        encoder = model_t5
    if encoder_tokenizer is None:
        encoder_tokenizer = tokenizer
    if target_device is None:
        target_device = device

    protein_ids = list(sequences.keys())
    embeddings = []

    for start in tqdm(range(0, len(protein_ids), batch_size), desc="Generating embeddings"):
        batch_ids = protein_ids[start:start + batch_size]
        # Truncate at the residue level, map rare residues to X, then insert
        # the spaces between residues that the ProtT5 tokenizer expects.
        batch_seqs = [
            " ".join(sequences[pid][:max_length].translate(_RARE_AA_TO_X))
            for pid in batch_ids
        ]

        encoded = encoder_tokenizer(
            batch_seqs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to(target_device)

        with torch.no_grad():
            outputs = encoder(**encoded)
            # Masked mean over the token axis: padded positions (mask 0) are
            # excluded from both the sum and the count.
            attention_mask = encoded["attention_mask"].unsqueeze(-1)
            hidden_states = outputs.last_hidden_state
            summed = (hidden_states * attention_mask).sum(dim=1)
            mean_embedding = summed / attention_mask.sum(dim=1)
            embeddings.append(mean_embedding.cpu().numpy())

    return protein_ids, np.vstack(embeddings)

# Generate embeddings, or reuse a cached copy.
# Embedding the full training set with ProtT5-XL is slow, so the result is
# cached in OUTPUT_DIR. A re-run reuses it when it covers exactly the same
# proteins in the same order. The file name includes max_seq_length; delete
# the file to force regeneration after changing the embedding code.
embedding_cache_path = OUTPUT_DIR / f"train_embeddings_len{CONFIG['max_seq_length']}.npz"

cached_embeddings = None
if embedding_cache_path.exists():
    with np.load(embedding_cache_path, allow_pickle=False) as cache:
        if cache["protein_ids"].tolist() == list(train_sequences):
            cached_embeddings = cache["embeddings"]

if cached_embeddings is not None:
    print(f"\nReusing cached embeddings from {embedding_cache_path}")
    protein_ids, embeddings = list(train_sequences), cached_embeddings
else:
    print("\nGenerating embeddings for training sequences...")
    protein_ids, embeddings = generate_embeddings(
        train_sequences,
        batch_size=CONFIG["embed_batch_size"],
        max_length=CONFIG["max_seq_length"],
    )
    np.savez(embedding_cache_path, protein_ids=np.array(protein_ids), embeddings=embeddings)

print(f"\nGenerated embeddings shape: {embeddings.shape}")

# Free GPU memory: the ProtT5 encoder is not needed to train the MLP.
del model_t5
torch.cuda.empty_cache()


## Create Dataset and DataLoaders


In [None]:
class ProteinDataset(Dataset):
    """Pairs protein embeddings with multi-hot GO targets, one vector per ontology.

    Item ``idx`` is built from ``protein_ids[idx]`` and ``embeddings[idx]``.
    Each annotated term found in one of the three term->index maps (checked in
    MF, BP, CC order, first match wins) sets a 1.0 in that ontology's target
    vector. Other terms are ignored, and proteins missing from
    ``annotations`` get all-zero targets.
    """

    def __init__(
        self,
        protein_ids: list[str],
        embeddings: np.ndarray,
        annotations: dict[str, set[str]],
        mf_term_to_idx: dict[str, int],
        bp_term_to_idx: dict[str, int],
        cc_term_to_idx: dict[str, int],
    ):
        # Store references only; target vectors are built per item on access.
        self.protein_ids = protein_ids
        self.embeddings = embeddings
        self.annotations = annotations
        self.mf_term_to_idx = mf_term_to_idx
        self.bp_term_to_idx = bp_term_to_idx
        self.cc_term_to_idx = cc_term_to_idx

    def __len__(self):
        return len(self.protein_ids)

    def __getitem__(self, idx):
        protein_id = self.protein_ids[idx]
        # (output key, term->index map) per head, in MF, BP, CC priority order.
        heads = (
            ("targets_mf", self.mf_term_to_idx),
            ("targets_bp", self.bp_term_to_idx),
            ("targets_cc", self.cc_term_to_idx),
        )
        targets = {
            key: torch.zeros(len(term_to_idx), dtype=torch.float32)
            for key, term_to_idx in heads
        }
        for term in self.annotations.get(protein_id, set()):
            for key, term_to_idx in heads:
                if term in term_to_idx:
                    targets[key][term_to_idx[term]] = 1.0
                    break

        sample = {
            "protein_id": protein_id,
            "embedding": torch.tensor(self.embeddings[idx], dtype=torch.float32),
        }
        sample.update(targets)
        return sample

# Wrap every embedded protein (in embedding order) with its multi-hot GO targets.
full_dataset = ProteinDataset(
    protein_ids,
    embeddings,
    train_annotations,
    mf_term_to_idx,
    bp_term_to_idx,
    cc_term_to_idx,
)

dataset_size = len(full_dataset)
print(f"Dataset size: {dataset_size}")


In [None]:
# Split into train and validation (random split over dataset positions).
indices = list(range(len(full_dataset)))
train_indices, val_indices = train_test_split(
    indices,
    test_size=CONFIG["val_split"],
    random_state=CONFIG["seed"],
)

train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(full_dataset, val_indices)

# Pinned host memory only speeds up host-to-GPU copies. On a CPU-only runtime
# PyTorch just warns about it, so enable it only when CUDA is available.
pin_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=True,
    num_workers=0,
    pin_memory=pin_memory,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=False,
    num_workers=0,
    pin_memory=pin_memory,
)

print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")


## Define Model


In [None]:
class MultiHeadMLP(nn.Module):
    """Shared MLP backbone with one linear output head per GO ontology.

    The backbone stacks Linear -> ReLU -> Dropout once per entry of
    ``hidden_dims``. Each head is a single Linear layer that returns raw
    logits (no sigmoid) for the MF, BP or CC terms.

    Args:
        embedding_dim: width of the input embedding vectors.
        hidden_dims: widths of the hidden layers. May be empty, in which case
            the heads act directly on the input (a linear probe).
        num_mf_terms, num_bp_terms, num_cc_terms: output sizes of the heads.
        dropout: dropout probability after each hidden layer.
    """

    def __init__(
        self,
        embedding_dim: int,
        hidden_dims: list[int],
        num_mf_terms: int,
        num_bp_terms: int,
        num_cc_terms: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        # Copy so later mutation of the caller's list (e.g. CONFIG["hidden_dims"])
        # cannot make get_config() disagree with the layers actually built.
        self.hidden_dims = list(hidden_dims)
        self.num_mf_terms = num_mf_terms
        self.num_bp_terms = num_bp_terms
        self.num_cc_terms = num_cc_terms
        self.dropout_rate = dropout

        # Shared backbone (an empty Sequential passes its input through unchanged)
        layers = []
        in_dim = embedding_dim
        for hidden_dim in self.hidden_dims:
            layers.extend([
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            in_dim = hidden_dim
        self.backbone = nn.Sequential(*layers)

        # Separate heads for each ontology; in_dim is now the backbone's output
        # width (embedding_dim when there are no hidden layers).
        self.head_mf = nn.Linear(in_dim, num_mf_terms)
        self.head_bp = nn.Linear(in_dim, num_bp_terms)
        self.head_cc = nn.Linear(in_dim, num_cc_terms)

    def forward(self, x):
        """Return a dict of per-ontology logits for input batch ``x``."""
        features = self.backbone(x)
        return {
            "logits_mf": self.head_mf(features),
            "logits_bp": self.head_bp(features),
            "logits_cc": self.head_cc(features),
        }

    def get_config(self):
        """Return the constructor arguments needed to rebuild this model."""
        return {
            "embedding_dim": self.embedding_dim,
            "hidden_dims": self.hidden_dims,
            "num_mf_terms": self.num_mf_terms,
            "num_bp_terms": self.num_bp_terms,
            "num_cc_terms": self.num_cc_terms,
            "dropout": self.dropout_rate,
        }

# Create model. Its input width must equal the width of the embeddings that
# were actually generated, not just the configured value.
embedding_width = embeddings.shape[1]
assert embedding_width == CONFIG["embedding_dim"], (
    f"Embeddings have width {embedding_width}, but CONFIG['embedding_dim'] is "
    f"{CONFIG['embedding_dim']}"
)

model = MultiHeadMLP(
    embedding_dim=CONFIG["embedding_dim"],
    hidden_dims=CONFIG["hidden_dims"],
    num_mf_terms=len(mf_terms),
    num_bp_terms=len(bp_terms),
    num_cc_terms=len(cc_terms),
    dropout=CONFIG["dropout"],
)

model = model.to(device)

num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")


## Training Loop


In [None]:
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Args:
        patience: number of consecutive non-improving calls after which the
            handler returns True.
        min_delta: how far below the best loss so far a new loss must fall to
            count as an improvement. The default 0.0 treats any strict decrease
            as an improvement.
    """

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0  # consecutive calls without improvement
        self.best_loss = float("inf")

    def __call__(self, val_loss: float) -> bool:
        """Record ``val_loss``; return True once ``patience`` calls in a row fail to improve.

        A NaN loss never compares as an improvement, so it counts toward patience.
        """
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience

# Setup training: optimizer, loss, early stopping, and the best-loss tracker
# that the training loop uses for checkpointing.
optimizer = Adam(
    model.parameters(),
    lr=CONFIG["learning_rate"],
    weight_decay=CONFIG["weight_decay"],
)
criterion = nn.BCEWithLogitsLoss()
early_stopping = EarlyStopping(patience=CONFIG["patience"])
best_val_loss = float("inf")

# Initialize wandb with the run config plus dataset and model sizes.
if USE_WANDB:
    wandb_config = dict(CONFIG)
    wandb_config.update(
        num_mf_terms=len(mf_terms),
        num_bp_terms=len(bp_terms),
        num_cc_terms=len(cc_terms),
        num_train_samples=len(train_dataset),
        num_val_samples=len(val_dataset),
        model_params=num_params,
    )
    wandb.init(project="cafa6", config=wandb_config, reinit=True)
    wandb.watch(model, log="gradients", log_freq=100)
    print("Wandb run initialized")


In [None]:
def _batch_loss(batch):
    """Forward one collated batch through ``model``; return (loss, batch size).

    The loss is ``criterion`` (BCE-with-logits) computed per ontology head and
    averaged over the MF, BP and CC heads with equal weight.
    """
    inputs = batch["embedding"].to(device)
    outputs = model(inputs)
    head_losses = [
        criterion(outputs[f"logits_{ont}"], batch[f"targets_{ont}"].to(device))
        for ont in ("mf", "bp", "cc")
    ]
    loss = sum(head_losses) / len(head_losses)
    return loss, inputs.shape[0]


def train_epoch():
    """Train for one epoch over ``train_loader``.

    Returns:
        Mean loss per sample over the epoch. Each batch's loss is weighted by
        its size, so a short final batch does not count as much as a full one.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0

    for batch in tqdm(train_loader, desc="Training", leave=False):
        optimizer.zero_grad()
        loss, batch_size = _batch_loss(batch)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch_size
        num_samples += batch_size

    return total_loss / max(num_samples, 1)


@torch.no_grad()
def validate():
    """Evaluate on ``val_loader`` without gradients.

    Returns:
        Mean loss per sample over the validation set (batch losses weighted by
        batch size).
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0

    for batch in tqdm(val_loader, desc="Validation", leave=False):
        loss, batch_size = _batch_loss(batch)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

    return total_loss / max(num_samples, 1)


In [None]:
# Training loop: train, validate, checkpoint the best model, stop early on plateau.
print(f"\nStarting training for {CONFIG['epochs']} epochs...")
print(f"Device: {device}")
print("-" * 50)

history = {"train_loss": [], "val_loss": []}

try:
    for epoch in range(CONFIG["epochs"]):
        train_loss = train_epoch()
        val_loss = validate()

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

        print(f"Epoch {epoch + 1}/{CONFIG['epochs']} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        # Log to wandb
        if USE_WANDB:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "best_val_loss": min(best_val_loss, val_loss),
            })

        # Save best model (checkpoint "epoch" is 0-based; printed epochs are 1-based)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_loss": val_loss,
                "model_config": model.get_config(),
            }
            torch.save(checkpoint, CHECKPOINT_DIR / "best_model.pt")
            print("  -> Saved best model")

        # Early stopping
        if early_stopping(val_loss):
            print(f"\nEarly stopping triggered at epoch {epoch + 1}")
            break

    # Save last model. Skipped when no epoch ran (CONFIG["epochs"] == 0); then
    # `epoch` and `val_loss` would be undefined or stale from an earlier run.
    if history["val_loss"]:
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "val_loss": val_loss,
            "model_config": model.get_config(),
        }
        torch.save(checkpoint, CHECKPOINT_DIR / "last_model.pt")
finally:
    # Finish the wandb run even if training raises or is interrupted.
    if USE_WANDB:
        wandb.finish()

print("\n" + "=" * 50)
print("Training complete!")
print(f"Best validation loss: {best_val_loss:.4f}")


In [None]:
# Save the per-ontology term vocabularies for inference: output column i of
# each head corresponds to entry i of the matching list.
term_index_path = CHECKPOINT_DIR / "term_index.npz"
np.savez(term_index_path, mf_terms=mf_terms, bp_terms=bp_terms, cc_terms=cc_terms)
print(f"Saved term index to {term_index_path}")


In [None]:
# Plot training history (local visualization)
import matplotlib.pyplot as plt

# Epochs are numbered from 1 to match the per-epoch log printed during training.
epochs_run = range(1, len(history["train_loss"]) + 1)
history_plot_path = OUTPUT_DIR / "training_history.png"

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs_run, history["train_loss"], label="Train Loss", linewidth=2)
ax.plot(epochs_run, history["val_loss"], label="Val Loss", linewidth=2)
ax.set(xlabel="Epoch", ylabel="Loss", title="Training History")
ax.legend()
ax.grid(True, alpha=0.3)
fig.savefig(history_plot_path, dpi=150, bbox_inches="tight")
plt.show()

print(f"\nPlot saved to {history_plot_path}")


## Download Checkpoint

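After training completes, download the checkpoint files from:
- `/kaggle/working/checkpoints/best_model.pt` (weights with the lowest validation loss)
- `/kaggle/working/checkpoints/term_index.npz` (the term order of each output head)

The final-epoch weights are also saved to `/kaggle/working/checkpoints/last_model.pt`. Use these files for inference on the test set. Test-set embeddings must be produced with the same ProtT5 preprocessing and mean pooling that were used for training.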


In [None]:
# List output files (sorted by name so the listing is stable across runs)
print("Output files:")
for path in sorted(CHECKPOINT_DIR.iterdir()):
    size_mb = path.stat().st_size / 1024 / 1024
    print(f"  {path.name} ({size_mb:.1f} MB)")
