# 🧬 Antibody-Antigen Binding Prediction Training

## Complete Step-by-Step Guide with Explanations

This notebook trains a deep learning model to predict antibody-antigen binding affinity (pKd values).

**What you'll learn:**
- How to set up a production-ready training pipeline
- Modern best practices (warmup, early stopping, comprehensive metrics)
- How to properly evaluate deep learning models

**Architecture:**
- IgT5 encoder (antibody sequences)
- ESM-2 encoder (antigen sequences)  
- Trainable regression head

**Total runtime:** ~2-3 hours on a Tesla T4 (depends on dataset size and on when early stopping triggers)

**Data location:** Google Drive folder `AbAg_Training_02`

---

# Step 1: Environment Setup

**What this does:**
- Checks if GPU is available
- Installs required packages
- Enables optimization flags (TF32, cuDNN auto-tuner)

**Why it matters:**
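- GPU is essential: running two large pretrained protein encoders on every batch is impractically slow on CPU
- TF32 only takes effect on Ampere or newer GPUs (e.g. A100); on a Tesla T4 it has no effect, so any speedup from these flags depends on your hardware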
- Ensures all dependencies are installed

In [None]:
# Check GPU availability
import torch
import sys

print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    device = torch.device('cuda')
else:
    print("⚠️ WARNING: GPU not available! Training will be very slow.")
    device = torch.device('cpu')

print(f"Using device: {device}")

In [None]:
# Install required packages (Colab-compatible versions)
print("Installing required packages...\n")

# Install packages compatible with current Colab environment
# Quote the version specifier: unquoted, the shell treats '>' as output redirection
!pip install -q "transformers>=4.41.0"
!pip install -q sentencepiece

print("\n✅ All packages installed successfully!")
print("✅ Using Colab's pre-installed numpy, pandas, scikit-learn, scipy")

In [None]:
# Enable optimization flags
import torch

# Enable TF32 for faster matrix multiplication (Ampere or newer GPUs only; no effect on T4/V100)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Enable cuDNN auto-tuner (selects the fastest convolution algorithms; this model
# has no convolutions, so expect little or no effect here)
torch.backends.cudnn.benchmark = True

# Disable deterministic mode for speed (reproducibility not critical here)
torch.backends.cudnn.deterministic = False

print("✅ Optimizations enabled:")
print("  • TF32 matrix multiplication")
print("  • cuDNN auto-tuner")
print("  • Non-deterministic mode (faster)")
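Whether these flags, and the bfloat16 autocast used later, actually help depends on the GPU generation. The sketch below uses a hypothetical helper, `fast_math_support`, to map a CUDA compute capability to whether TF32 and native bfloat16 tensor-core math are available. Both require compute capability 8.0 (Ampere) or newer. A Tesla T4 reports 7.5, so on a T4 TF32 does nothing and bfloat16 autocast may run on slower emulated kernels. There, float16 autocast with a gradient scaler or plain float32 may be faster, so benchmark on your own setup.

```python
def fast_math_support(capability):
    """Map a CUDA compute capability (major, minor) to available fast-math paths.

    Hypothetical helper: TF32 and native bfloat16 tensor cores both require
    compute capability 8.0 (Ampere) or newer.
    """
    major, _minor = capability
    ampere_or_newer = major >= 8
    return {'tf32': ampere_or_newer, 'native_bf16': ampere_or_newer}


# Query the current GPU if PyTorch with CUDA is available
try:
    import torch
    if torch.cuda.is_available():
        cap = torch.cuda.get_device_capability(0)
        print(f"Compute capability {cap}: {fast_math_support(cap)}")
except ImportError:
    pass
```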

# Step 2: Import Libraries & Define Utilities

**What this does:**
- Imports all necessary libraries
- Defines helper functions for metrics, early stopping, schedulers

**Why it matters:**
- Metrics: We need to measure performance accurately (12 different metrics)
- Early stopping: Prevents overfitting by stopping when performance plateaus
- LR scheduler: Warmup + cosine decay improves training stability

In [None]:
# Core imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import json
import os
from tqdm.auto import tqdm

# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Transformers (for pre-trained protein language models)
from transformers import (
    T5Tokenizer, T5EncoderModel,  # IgT5 for antibodies
    AutoTokenizer, AutoModel        # ESM-2 for antigens
)

# Scikit-learn
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy import stats

print("✅ All libraries imported successfully!")

In [None]:
# Comprehensive metrics function
def compute_comprehensive_metrics(targets, predictions):
    """
    Compute all 12 standard metrics for regression + classification.
    
    Args:
        targets: True pKd values (numpy array)
        predictions: Predicted pKd values (numpy array)
    
    Returns:
        Dictionary with 12 metrics
    """
    # Regression metrics
    mse = mean_squared_error(targets, predictions)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(targets, predictions)
    r2 = r2_score(targets, predictions)
    
    # Correlation metrics (with p-values for statistical significance)
    spearman, spearman_p = stats.spearmanr(targets, predictions)
    pearson, pearson_p = stats.pearsonr(targets, predictions)
    
    # Classification metrics for strong binders (pKd >= 9)
    strong_binders = targets >= 9.0
    predicted_strong = predictions >= 9.0
    
    # True positives, false positives, etc.
    tp = np.sum(strong_binders & predicted_strong)
    fp = np.sum(~strong_binders & predicted_strong)
    tn = np.sum(~strong_binders & ~predicted_strong)
    fn = np.sum(strong_binders & ~predicted_strong)
    
    # Calculate metrics (handle division by zero)
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
    
    return {
        # Regression metrics
        'mse': mse,
        'rmse': rmse,
        'mae': mae,
        'r2': r2,
        # Correlation metrics
        'spearman': spearman,
        'spearman_p': spearman_p,
        'pearson': pearson,
        'pearson_p': pearson_p,
        # Classification metrics
        'recall_pkd9': recall * 100,
        'precision_pkd9': precision * 100,
        'f1_pkd9': f1 * 100,
        'specificity_pkd9': specificity * 100,
        # Sample counts
        'n_samples': len(targets),
        'n_strong_binders': int(strong_binders.sum())
    }

print("✅ Metrics function defined")
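The classification metrics apply the pKd ≥ 9 cutoff to both the targets and the predictions to get binary labels. The standalone sketch below reproduces the confusion counts on a toy example. The helper name `strong_binder_confusion` is illustrative and not part of the notebook. The example also shows why recall should be read together with precision: a model that predicts pKd ≥ 9 for every pair reaches 100% recall.

```python
import numpy as np


def strong_binder_confusion(targets, predictions, threshold=9.0):
    """Confusion counts for 'strong binder' (pKd >= threshold) classification."""
    t = np.asarray(targets) >= threshold
    p = np.asarray(predictions) >= threshold
    return {
        'tp': int(np.sum(t & p)),
        'fp': int(np.sum(~t & p)),
        'tn': int(np.sum(~t & ~p)),
        'fn': int(np.sum(t & ~p)),
    }


targets = np.array([7.0, 8.5, 9.0, 9.5, 10.0])
predictions = np.array([9.1, 8.0, 8.9, 9.6, 9.0])
counts = strong_binder_confusion(targets, predictions)
print(counts)  # {'tp': 2, 'fp': 1, 'tn': 1, 'fn': 1} -> recall 2/3, precision 2/3

# Trivial baseline: call every pair a strong binder
always_strong = strong_binder_confusion(targets, np.full_like(targets, 10.0))
recall = always_strong['tp'] / (always_strong['tp'] + always_strong['fn'])
precision = always_strong['tp'] / (always_strong['tp'] + always_strong['fp'])
print(f"recall={recall:.0%}, precision={precision:.0%}")  # recall=100%, precision=60%
```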

In [None]:
# Early Stopping class
class EarlyStopping:
    """
    Monitors validation metric and stops training when no improvement.
    
    Args:
        patience: How many epochs to wait for improvement
        min_delta: Minimum change to qualify as improvement
        mode: 'max' for metrics like Spearman, 'min' for loss
    """
    def __init__(self, patience=10, min_delta=0.0001, mode='max', verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_epoch = 0
    
    def __call__(self, score, epoch):
        """
        Check if training should stop.
        
        Args:
            score: Current validation metric
            epoch: Current epoch number
        
        Returns:
            True if should stop, False otherwise
        """
        if self.best_score is None:
            self.best_score = score
            self.best_epoch = epoch
            return False
        
        # Check for improvement
        if self.mode == 'max':
            improved = score > (self.best_score + self.min_delta)
        else:
            improved = score < (self.best_score - self.min_delta)
        
        if improved:
            self.best_score = score
            self.best_epoch = epoch
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"   No improvement for {self.counter}/{self.patience} epochs")
            
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print(f"\n⚠️ Early stopping triggered!")
                    print(f"   No improvement for {self.patience} epochs")
                    print(f"   Best score: {self.best_score:.4f} at epoch {self.best_epoch+1}")
                return True
        
        return False

print("✅ EarlyStopping class defined")
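To see when the patience counter fires for a given sequence of validation scores, here is a standalone function that mirrors the `EarlyStopping` logic above in plain Python. The name `first_stop_epoch` is hypothetical. The example also shows that gains smaller than `min_delta` count as "no improvement".

```python
def first_stop_epoch(scores, patience=10, min_delta=1e-4, mode='max'):
    """Return the 0-based epoch at which EarlyStopping-style logic stops, or None."""
    best = None
    counter = 0
    for epoch, score in enumerate(scores):
        if best is None:          # the first score only sets the baseline
            best = score
            continue
        if mode == 'max':
            improved = score > best + min_delta
        else:
            improved = score < best - min_delta
        if improved:
            best, counter = score, 0
        else:
            counter += 1
            if counter >= patience:
                return epoch
    return None


# Spearman plateaus after epoch index 2; with patience=3 training stops at index 5
print(first_stop_epoch([0.30, 0.35, 0.36, 0.36, 0.355, 0.36], patience=3))  # 5
# Gains below min_delta do not reset the counter
print(first_stop_epoch([0.3, 0.30005, 0.30009], patience=2))  # 2
```

The returned value is a 0-based index, matching the `epoch` argument passed to `EarlyStopping`. The notebook prints it as `epoch+1`.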

In [None]:
# Learning rate scheduler with warmup
def get_warmup_cosine_scheduler(optimizer, warmup_epochs, total_epochs):
    """
    Create LR scheduler with linear warmup followed by cosine decay.
    
    Schedule:
    - Epochs 0 to warmup_epochs-1: Linear increase from max_lr/warmup_epochs to max_lr
    - Epochs warmup_epochs to total_epochs: Cosine decay to ~0
    
    Args:
        optimizer: PyTorch optimizer
        warmup_epochs: Number of epochs for warmup phase
        total_epochs: Total number of training epochs
    
    Returns:
        LR scheduler
    """
    def lr_lambda(epoch):
        # Warmup phase: linear increase; (epoch + 1) keeps the first epoch from running at LR = 0
        if epoch < warmup_epochs:
            return float(epoch + 1) / float(max(1, warmup_epochs))
        
        # Cosine decay phase
        progress = float(epoch - warmup_epochs) / float(max(1, total_epochs - warmup_epochs))
        return max(0.0, 0.5 * (1.0 + np.cos(np.pi * progress)))
    
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

print("✅ LR scheduler function defined")
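The easiest way to sanity-check the schedule is to print the LR multiplier for each epoch. The standalone re-implementation below highlights one subtlety. The function name and the `shift_warmup` flag are illustrative, not from the notebook. With a warmup factor of `epoch / warmup_epochs`, epoch 0 gets a multiplier of 0, so the first epoch trains at LR = 0. Using `(epoch + 1) / warmup_epochs` avoids this.

```python
import math


def warmup_cosine_factor(epoch, warmup_epochs, total_epochs, shift_warmup=True):
    """LR multiplier: linear warmup, then cosine decay towards 0."""
    if epoch < warmup_epochs:
        step = epoch + 1 if shift_warmup else epoch
        return step / max(1, warmup_epochs)
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


base_lr = 3e-3
for epoch in [0, 1, 4, 5, 27, 49]:
    unshifted = warmup_cosine_factor(epoch, 5, 50, shift_warmup=False)
    shifted = warmup_cosine_factor(epoch, 5, 50)
    print(f"epoch {epoch:2d}: lr = {base_lr * unshifted:.2e} (epoch/warmup) "
          f"vs {base_lr * shifted:.2e} ((epoch+1)/warmup)")
```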

In [None]:
# Focal MSE Loss with label smoothing
class FocalMSELoss(nn.Module):
    """
    Focal MSE loss: Focuses more on hard-to-predict samples.
    
    Regular MSE: loss = (pred - target)^2
    Focal MSE: loss = (1 + (pred - target)^2)^gamma * (pred - target)^2
    
    This increases weight on large errors, helping model focus on outliers.
    
    Args:
        gamma: Focusing parameter (higher = more focus on hard samples)
        label_smoothing: Shrink targets toward the batch mean (reduces overconfidence)
    """
    def __init__(self, gamma=2.0, label_smoothing=0.0):
        super().__init__()
        self.gamma = gamma
        self.label_smoothing = label_smoothing
    
    def forward(self, pred, target):
        # Apply label smoothing if enabled
        if self.label_smoothing > 0:
            target_mean = target.mean()
            target = (1 - self.label_smoothing) * target + self.label_smoothing * target_mean
        
        # Compute focal MSE
        mse = (pred - target) ** 2
        focal_weight = (1 + mse) ** self.gamma
        return (focal_weight * mse).mean()

print("✅ FocalMSELoss class defined")
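A quick worked example shows how strongly `FocalMSELoss` reweights errors. With `gamma=2`, the per-sample loss is (1 + e²)² · e², which grows roughly like e⁶ for large errors e. A few badly predicted pairs (or noisy labels) can therefore dominate a batch, which is worth keeping in mind when choosing `gamma`. The helper below is a plain-Python illustration, not part of the notebook.

```python
def focal_mse_terms(error, gamma=2.0):
    """Return (squared error, focal weight, focal loss) for one residual."""
    mse = error ** 2
    weight = (1 + mse) ** gamma
    return mse, weight, weight * mse


for err in [0.5, 1.0, 2.0, 3.0]:
    mse, weight, loss = focal_mse_terms(err)
    print(f"|error| = {err:.1f} pKd -> MSE {mse:6.2f}, weight {weight:7.2f}, focal loss {loss:8.2f}")
```

An error of 3 pKd units contributes 900 to the loss, compared with 4 for an error of 1 unit. Plain MSE would give 9 and 1 respectively.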

# Step 3: Mount Google Drive & Load Data

**What this does:**
- Mounts your Google Drive
- Loads CSV from `AbAg_Training_02` folder
- Explores data distribution
- Splits into train/validation/test (70%/15%/15%)
- Creates PyTorch Dataset and DataLoader

**Why it matters:**
- No need to upload data each time
- Results saved directly to Drive
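- A held-out test set avoids direct leakage (note that a random row-level split can still place pairs sharing the same antibody or antigen in different splits)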
- DataLoader handles batching and shuffling automatically

In [None]:
# Mount Google Drive
from google.colab import drive

print("Mounting Google Drive...")
drive.mount('/content/drive')
print("✅ Google Drive mounted!")

# Set up paths
DRIVE_DIR = '/content/drive/MyDrive/AbAg_Training_02'
OUTPUT_DIR = f'{DRIVE_DIR}/training_output'

# Create output directory if it doesn't exist
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"\n📂 Working directories:")
print(f"   Data directory: {DRIVE_DIR}")
print(f"   Output directory: {OUTPUT_DIR}")

In [None]:
# List files in your Drive directory
print("\n📁 Files in AbAg_Training_02:")
files_in_dir = os.listdir(DRIVE_DIR)
csv_files = [f for f in files_in_dir if f.endswith('.csv')]

for f in csv_files:
    file_path = os.path.join(DRIVE_DIR, f)
    file_size = os.path.getsize(file_path) / (1024*1024)  # MB
    print(f"   • {f} ({file_size:.2f} MB)")

if not csv_files:
    print("   ⚠️ No CSV files found!")
    print("   Please upload your dataset to the AbAg_Training_02 folder")

In [None]:
# Load dataset - MODIFY THIS LINE to specify your CSV filename
CSV_FILENAME = 'agab_phase2_full.csv'  # ← CHANGE THIS to your actual CSV filename

csv_path = os.path.join(DRIVE_DIR, CSV_FILENAME)

print(f"Loading dataset from: {csv_path}")
df = pd.read_csv(csv_path)

print("\n📊 Dataset Overview:")
print(f"   Total samples: {len(df):,}")
print(f"   Columns: {list(df.columns)}")
print(f"\n   pKd Statistics:")
print(f"      Min:  {df['pKd'].min():.2f}")
print(f"      Max:  {df['pKd'].max():.2f}")
print(f"      Mean: {df['pKd'].mean():.2f}")
print(f"      Std:  {df['pKd'].std():.2f}")

# Count strong binders (pKd >= 9)
strong_binders = (df['pKd'] >= 9.0).sum()
strong_pct = 100 * strong_binders / len(df)
print(f"\n   Strong Binders (pKd≥9): {strong_binders:,} ({strong_pct:.2f}%)")

# Show first few rows
print("\n   First 3 samples:")
print(df.head(3))
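The `AbAgDataset` class below reads the columns `antibody_sequence`, `antigen_sequence` and `pKd` by name. It can therefore help to check for those columns, and for missing values, right after loading. Here is a minimal sketch, with `check_dataset` as a hypothetical helper:

```python
import pandas as pd

REQUIRED_COLUMNS = ['antibody_sequence', 'antigen_sequence', 'pKd']


def check_dataset(df):
    """Raise if a required column is missing; return the number of rows with NaNs."""
    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")
    return int(df[REQUIRED_COLUMNS].isna().any(axis=1).sum())


# Usage on the loaded dataframe (drop incomplete rows before splitting):
# n_bad = check_dataset(df)
# if n_bad:
#     df = df.dropna(subset=REQUIRED_COLUMNS).reset_index(drop=True)
```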

In [None]:
# Split into train/val/test (70%/15%/15%)
print("Splitting data...\n")

# First split: 70% train, 30% temp
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)

# Second split: Split temp into 50/50 -> 15% val, 15% test
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

# Create quick validation set (5% of val set for speed during training)
val_df_quick = val_df.sample(frac=0.05, random_state=42)

print("📊 Dataset splits:")
print(f"   Train:     {len(train_df):,} samples ({100*len(train_df)/len(df):.1f}%)")
print(f"   Val:       {len(val_df):,} samples ({100*len(val_df)/len(df):.1f}%)")
print(f"   Val Quick: {len(val_df_quick):,} samples ({100*len(val_df_quick)/len(df):.2f}%)")
print(f"   Test:      {len(test_df):,} samples ({100*len(test_df)/len(df):.1f}%)")
print("\n   Note: During training we use 'Val Quick' for speed.")
print("         After training we evaluate on full Val and Test sets.")
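The random row-level split above can place pairs that share the same antigen (or antibody) in both train and test. This tends to make test scores optimistic for unseen targets. If generalization to new antigens matters for your use case, one option is a grouped split. The sketch below is an alternative under that assumption, not the notebook's method. `grouped_split` is a hypothetical helper that assigns each group (here, each distinct `antigen_sequence` value) entirely to one split. The 70/15/15 fractions apply to groups, so the row counts will only be approximately 70/15/15.

```python
import numpy as np
import pandas as pd


def grouped_split(df, group_col, fractions=(0.7, 0.15, 0.15), seed=42):
    """Split df so that each value of group_col appears in exactly one split."""
    rng = np.random.default_rng(seed)
    groups = rng.permutation(np.asarray(df[group_col].unique(), dtype=object))
    n_train = int(round(fractions[0] * len(groups)))
    n_val = int(round(fractions[1] * len(groups)))
    train_groups = list(groups[:n_train])
    val_groups = list(groups[n_train:n_train + n_val])
    test_groups = list(groups[n_train + n_val:])
    return (df[df[group_col].isin(train_groups)],
            df[df[group_col].isin(val_groups)],
            df[df[group_col].isin(test_groups)])


# Usage (would replace the two train_test_split calls):
# train_df, val_df, test_df = grouped_split(df, 'antigen_sequence')
```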

In [None]:
# Define PyTorch Dataset class
class AbAgDataset(Dataset):
    """
    PyTorch Dataset for antibody-antigen pairs.
    
    Returns (per item):
        Dictionary with:
        - antibody_seqs: antibody sequence (string)
        - antigen_seqs: antigen sequence (string)
        - pKd: binding affinity (float32 scalar tensor)
    """
    def __init__(self, dataframe):
        self.data = dataframe.reset_index(drop=True)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        return {
            'antibody_seqs': row['antibody_sequence'],
            'antigen_seqs': row['antigen_sequence'],
            'pKd': torch.tensor(row['pKd'], dtype=torch.float32)
        }

print("✅ Dataset class defined")

In [None]:
# Custom collate function for DataLoader
def collate_fn(batch):
    """
    Combines individual samples into a batch.
    
    Since sequences have variable length, we keep them as lists.
    Tokenization happens inside the model forward pass.
    """
    return {
        'antibody_seqs': [item['antibody_seqs'] for item in batch],
        'antigen_seqs': [item['antigen_seqs'] for item in batch],
        'pKd': torch.stack([item['pKd'] for item in batch])
    }

print("✅ Collate function defined")

In [None]:
# Create DataLoaders
BATCH_SIZE = 16
NUM_WORKERS = 2

# Create datasets
train_dataset = AbAgDataset(train_df)
val_dataset_quick = AbAgDataset(val_df_quick)  # For during training
val_dataset_full = AbAgDataset(val_df)          # For final evaluation
test_dataset = AbAgDataset(test_df)             # For final evaluation

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,  # Shuffle for training
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
    pin_memory=True  # Faster GPU transfer
)

val_loader_quick = DataLoader(
    val_dataset_quick,
    batch_size=BATCH_SIZE,
    shuffle=False,  # No shuffle for validation
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
    pin_memory=True
)

val_loader_full = DataLoader(
    val_dataset_full,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
    pin_memory=True
)

print("✅ DataLoaders created:")
print(f"   • train_loader: {len(train_loader):,} batches")
print(f"   • val_loader_quick: {len(val_loader_quick):,} batches")
print(f"   • val_loader_full: {len(val_loader_full):,} batches")
print(f"   • test_loader: {len(test_loader):,} batches")

# Step 4: Model Architecture

**What this does:**
- Loads pre-trained IgT5 (antibody encoder)
- Loads pre-trained ESM-2 (antigen encoder)
- Freezes encoder weights (saves memory, faster training)
- Creates trainable regression head (combined encoder embedding → 1)

**Why this architecture:**
- Pre-trained encoders already understand protein sequences
- We only train the head to predict binding affinity
- This requires much less data than training from scratch

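**Model size:**
- Almost all parameters sit in the two frozen encoders (ESM-2 650M alone has ~650M)
- Trainable: only the regression head, a few million parameters
- Exact totals depend on each encoder's config and are printed when the model is built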
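The trainable parameter count follows directly from the head's layer sizes. The sketch below uses a hypothetical helper, not in the notebook, that counts Linear weights and biases plus LayerNorm scale and shift for the 1024 → 512 → 256 → 128 → 1 head. The input width is the sum of the two encoders' embedding sizes, so check the result against the statistics printed when the model is built.

```python
def head_param_count(in_dim, hidden=(1024, 512, 256, 128)):
    """Parameters in a Linear/GELU/Dropout/LayerNorm stack ending in Linear(.., 1)."""
    total = 0
    prev = in_dim
    for width in hidden:
        total += prev * width + width  # Linear weight + bias
        total += 2 * width             # LayerNorm weight + bias
        prev = width
    total += prev * 1 + 1              # final Linear(prev, 1)
    return total


print(f"{head_param_count(1792):,}")  # 2,529,025 for a 1792D combined embedding
```

GELU and Dropout have no parameters. The first Linear layer dominates the count, so the total scales almost linearly with the combined embedding width.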

In [None]:
# Define model architecture
class IgT5ESM2Model(nn.Module):
    """
    Dual-encoder model for antibody-antigen binding prediction.
    
    Architecture:
    1. IgT5 encodes antibody sequence -> igt5_dim embedding (d_model from its config)
    2. ESM-2 encodes antigen sequence -> 1280D embedding
    3. Concatenate -> (igt5_dim + 1280)D combined embedding
    4. Regression head (MLP) -> single pKd value
    
    Args:
        dropout: Dropout rate for regularization
        freeze_encoders: Whether to freeze pre-trained weights
        use_checkpointing: Use gradient checkpointing (saves memory only when encoders are fine-tuned)
    """
    def __init__(self, dropout=0.35, freeze_encoders=True, use_checkpointing=True):
        super().__init__()
        
        print("🔨 Building model...")
        
        # Load IgT5 for antibody sequences
        print("  📥 Loading IgT5 (antibody encoder)...")
        self.igt5_tokenizer = T5Tokenizer.from_pretrained("Exscientia/IgT5")
        self.igt5_model = T5EncoderModel.from_pretrained("Exscientia/IgT5")
        
        # Load ESM-2 for antigen sequences
        print("  📥 Loading ESM-2 (antigen encoder)...")
        self.esm2_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm2_model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        
        # Freeze encoders if requested
        if freeze_encoders:
            print("  🔒 Freezing encoder weights...")
            for param in self.igt5_model.parameters():
                param.requires_grad = False
            for param in self.esm2_model.parameters():
                param.requires_grad = False
        
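        # Gradient checkpointing only saves memory when gradients flow through the
        # encoders; with frozen encoders it has no benefit, so it is skipped
        if use_checkpointing and not freeze_encoders:
            self.igt5_model.gradient_checkpointing_enable()
            self.esm2_model.gradient_checkpointing_enable()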
        
        # Get embedding dimensions from each encoder's config
        self.igt5_dim = self.igt5_model.config.d_model
        self.esm2_dim = self.esm2_model.config.hidden_size  # 1280 for esm2_t33_650M
        self.combined_dim = self.igt5_dim + self.esm2_dim
        
        print(f"  📏 Embedding dimensions:")
        print(f"     IgT5: {self.igt5_dim}D")
        print(f"     ESM-2: {self.esm2_dim}D")
        print(f"     Combined: {self.combined_dim}D")
        
        # Build regression head: combined_dim -> 1024 -> 512 -> 256 -> 128 -> 1
        print("  🧠 Building regression head...")
        self.regression_head = nn.Sequential(
            nn.Linear(self.combined_dim, 1024),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(1024),
            
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(512),
            
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(256),
            
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(128),
            
            nn.Linear(128, 1)
        )
        
        # Count parameters
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        frozen_params = total_params - trainable_params
        
        print(f"\n  📊 Model Statistics:")
        print(f"     Total parameters: {total_params/1e6:.1f}M")
        print(f"     Trainable parameters: {trainable_params/1e6:.1f}M")
        print(f"     Frozen parameters: {frozen_params/1e6:.1f}M")
    
    def forward(self, antibody_seqs, antigen_seqs, device):
        """
        Forward pass.
        
        Args:
            antibody_seqs: List of antibody sequences (strings)
            antigen_seqs: List of antigen sequences (strings)
            device: Device to run on (cuda/cpu)
        
        Returns:
            Predicted pKd values (tensor, shape [batch_size])
        """
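        # IgT5 (like ProtT5) expects space-separated residues, e.g. "E V Q L ...";
        # sequences that already contain spaces are passed through unchanged
        antibody_seqs = [s if ' ' in s else ' '.join(s) for s in antibody_seqs]
        
        # Tokenize antibody sequences
        antibody_tokens = self.igt5_tokenizer(
            antibody_seqs,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512
        ).to(device)
        
        # Tokenize antigen sequences (ESM-2 takes plain, unspaced sequences)
        antigen_tokens = self.esm2_tokenizer(
            antigen_seqs,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=1024
        ).to(device)
        
        def masked_mean(hidden, attention_mask):
            # Average over real tokens only; a plain .mean(dim=1) would include padding
            mask = attention_mask.unsqueeze(-1).float()
            return (hidden.float() * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            # Encode antibody (masked mean pooling over sequence)
            antibody_outputs = self.igt5_model(**antibody_tokens)
            antibody_embedding = masked_mean(
                antibody_outputs.last_hidden_state, antibody_tokens['attention_mask']
            )  # [batch, igt5_dim]
            
            # Encode antigen (masked mean pooling over sequence)
            antigen_outputs = self.esm2_model(**antigen_tokens)
            antigen_embedding = masked_mean(
                antigen_outputs.last_hidden_state, antigen_tokens['attention_mask']
            )  # [batch, esm2_dim]
            
            # Concatenate embeddings
            combined = torch.cat([antibody_embedding, antigen_embedding], dim=1)  # [batch, combined_dim]
            
            # Regression head
            pKd_pred = self.regression_head(combined).squeeze(-1)  # [batch]
        
        return pKd_pred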

print("✅ Model class defined")
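Batches are padded to the longest sequence. Pooling with a plain `.mean(dim=1)` would therefore also average the padding positions and dilute the embeddings of short sequences. Masked mean pooling uses the tokenizer's `attention_mask` so that only real tokens contribute. The NumPy sketch below shows the difference on one padded sequence. Its shapes and values are illustrative.

```python
import numpy as np


def masked_mean(hidden, mask):
    """Mean over the sequence axis, counting only positions where mask == 1.

    hidden: [batch, seq_len, dim]; mask: [batch, seq_len] of 0/1.
    """
    mask = mask[:, :, None].astype(hidden.dtype)
    counts = np.clip(mask.sum(axis=1), 1.0, None)  # avoid division by zero
    return (hidden * mask).sum(axis=1) / counts


# One sequence with 2 real tokens and 2 padding positions, 1-D embeddings
hidden = np.array([[[1.0], [3.0], [0.0], [0.0]]])
mask = np.array([[1, 1, 0, 0]])
print(hidden.mean(axis=1))        # [[1.]]  plain mean, diluted by padding
print(masked_mean(hidden, mask))  # [[2.]]  mean over real tokens only
```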

In [None]:
# Instantiate model and move to GPU
model = IgT5ESM2Model(
    dropout=0.35,
    freeze_encoders=True,
    use_checkpointing=True
)
model = model.to(device)

print(f"\n✅ Model built successfully!")
print(f"✅ Model moved to {device}")

# Step 5: Training Configuration

**What this does:**
- Sets all hyperparameters
- Creates optimizer (AdamW with decoupled weight decay for regularization)
- Creates LR scheduler (warmup + cosine decay)
- Creates loss function (Focal MSE + label smoothing)
- Initializes early stopping

**Why these choices:**
- LR=3e-3: relatively high, aimed at fast convergence of the small trainable head
- Weight decay=0.02: extra regularization against overfitting
- Dropout=0.35: moderate regularization (not too aggressive)
- Warmup=5 epochs: Stabilizes early training
- Early stopping patience=10: Gives model time to improve

In [None]:
# Hyperparameters
config = {
    'epochs': 50,
    'batch_size': 16,
    'lr': 3e-3,
    'weight_decay': 0.02,           # Decoupled weight decay (AdamW)
    'dropout': 0.35,
    'warmup_epochs': 5,
    'early_stopping_patience': 10,
    'label_smoothing': 0.05,        # Shrinks targets 5% toward the batch mean
    'max_grad_norm': 1.0,           # Gradient clipping
    'validation_frequency': 1       # Validate every N epochs
}

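# AdamW applies decoupled weight decay (not identical to an L2 penalty for adaptive optimizers).
# Only trainable parameters are passed; the frozen encoder weights are skipped.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=config['lr'],
    weight_decay=config['weight_decay'],
    fused=torch.cuda.is_available()  # The fused kernel requires CUDA tensors
)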

# LR Scheduler with warmup
scheduler = get_warmup_cosine_scheduler(
    optimizer,
    warmup_epochs=config['warmup_epochs'],
    total_epochs=config['epochs']
)

# Loss function with label smoothing
criterion = FocalMSELoss(
    gamma=2.0,
    label_smoothing=config['label_smoothing']
)

# Early stopping
early_stopping = EarlyStopping(
    patience=config['early_stopping_patience'],
    min_delta=0.0001,
    mode='max'  # Maximize Spearman correlation
)

print("✅ Training configuration complete!")
print(f"\n📊 Configuration:")
for key, value in config.items():
    print(f"   {key}: {value}")

# Step 6: Training Loop

**What this does:**
- Trains model for up to 50 epochs
- Quick validation every epoch (on 5% of val set for speed)
- Gradient clipping for stability
- Early stopping to prevent overfitting
- Saves best model to Google Drive

**Expected runtime (scales with dataset size):**
- Tesla T4: ~3 minutes per epoch → ~1.5-2.5 hours total, depending on when early stopping triggers
- V100: ~2 minutes per epoch → ~1-1.5 hours total

**What you'll see:**
- Progress bar for each epoch
- Training loss decreasing
- Validation Spearman increasing (hopefully!)
- "✅ Saved best model" when new best is found
- Early stopping message when training stops

In [None]:
# Training function
def train_epoch(model, loader, optimizer, criterion, device, epoch, max_grad_norm):
    """
    Train for one epoch.
    
    Args:
        model: The neural network
        loader: Training data loader
        optimizer: PyTorch optimizer
        criterion: Loss function
        device: cuda or cpu
        epoch: Current epoch number
        max_grad_norm: Gradient clipping threshold
    
    Returns:
        Average training loss for the epoch
    """
    model.train()  # Set model to training mode
    total_loss = 0
    
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}")
    for batch in pbar:
        antibody_seqs = batch['antibody_seqs']
        antigen_seqs = batch['antigen_seqs']
        targets = batch['pKd'].to(device)
        
        # Forward pass with mixed precision (BFloat16)
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            predictions = model(antibody_seqs, antigen_seqs, device)
            loss = criterion(predictions, targets)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # Gradient clipping (prevents exploding gradients)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        
        # Optimizer step
        optimizer.step()
        
        # Track loss
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / len(loader)

print("✅ Training function defined")
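`torch.nn.utils.clip_grad_norm_` clips by the global L2 norm over all parameters. If that norm exceeds `max_grad_norm`, every gradient is scaled by the same factor `max_norm / (total_norm + 1e-6)`, which preserves the update direction. The NumPy sketch below applies the same rule to a list of arrays rather than to model parameters. It is an illustration only.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients by one factor so their joint L2 norm is at most max_norm."""
    total_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


grads = [np.array([3.0]), np.array([4.0])]  # global norm = 5
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm, clipped)  # global norm 5.0 -> gradients scaled to about 0.6 and 0.8
```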

In [None]:
# Evaluation function
def eval_model(model, loader, device, desc="Evaluating"):
    """
    Evaluate model on validation or test set.
    
    Args:
        model: The neural network
        loader: Validation/test data loader
        device: cuda or cpu
        desc: Description for progress bar
    
    Returns:
        metrics: Dictionary with 12 metrics
        predictions: Numpy array of predictions
        targets: Numpy array of true values
    """
    model.eval()  # Set model to evaluation mode
    predictions = []
    targets = []
    
    with torch.no_grad():  # Disable gradient computation
        for batch in tqdm(loader, desc=desc):
            antibody_seqs = batch['antibody_seqs']
            antigen_seqs = batch['antigen_seqs']
            batch_targets = batch['pKd'].to(device)
            
            # Forward pass with mixed precision
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                batch_predictions = model(antibody_seqs, antigen_seqs, device)
            
            # Collect results
            predictions.extend(batch_predictions.float().cpu().numpy())
            targets.extend(batch_targets.float().cpu().numpy())
    
    # Convert to numpy arrays
    predictions = np.array(predictions)
    targets = np.array(targets)
    
    # Compute all metrics
    metrics = compute_comprehensive_metrics(targets, predictions)
    return metrics, predictions, targets

print("✅ Evaluation function defined")

In [None]:
# Main training loop
print("="*70)
print("STARTING TRAINING")
print("="*70)

# Model checkpoint path (saved to Google Drive)
model_save_path = os.path.join(OUTPUT_DIR, 'best_model.pth')

best_spearman = -1
training_history = {
    'train_loss': [],
    'val_spearman': [],
    'val_recall': [],
    'epoch': []
}

for epoch in range(config['epochs']):
    print(f"\nEpoch {epoch+1}/{config['epochs']}")
    print("-"*70)
    
    # Train for one epoch
    train_loss = train_epoch(
        model, train_loader, optimizer, criterion, device,
        epoch, config['max_grad_norm']
    )
    print(f"Train Loss: {train_loss:.4f}")
    
    # Validate every N epochs
    if (epoch + 1) % config['validation_frequency'] == 0:
        val_metrics, _, _ = eval_model(model, val_loader_quick, device, "Quick Val")
        val_spearman = val_metrics['spearman']
        val_recall = val_metrics['recall_pkd9']
        
        print(f"Val Spearman: {val_spearman:.4f} | Recall@pKd≥9: {val_recall:.2f}%")
        
        # Record history before the early-stopping check so the final epoch is not lost
        training_history['train_loss'].append(train_loss)
        training_history['val_spearman'].append(float(val_spearman))
        training_history['val_recall'].append(float(val_recall))
        training_history['epoch'].append(epoch + 1)
        
        # Save best model to Google Drive
        if val_spearman > best_spearman:
            best_spearman = val_spearman
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Plain float: NumPy scalars may be rejected by torch.load(weights_only=True)
                'val_spearman': float(val_spearman),
                'config': config
            }, model_save_path)
            print(f"✅ Saved best model to Drive: {model_save_path}")
        
        # Early stopping check
        if early_stopping(val_spearman, epoch):
            print(f"\n⛔ Stopping at epoch {epoch+1}")
            break
    
    # LR scheduler step
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Learning Rate: {current_lr:.6f}")

print(f"\n{'='*70}")
print(f"TRAINING COMPLETE!")
print(f"Best Validation Spearman: {best_spearman:.4f}")
print(f"Model saved to: {model_save_path}")
print(f"{'='*70}")

# Step 7: Comprehensive Evaluation

**What this does:**
- Loads the best model from Google Drive
- Evaluates on FULL validation set (100% of val data)
- Evaluates on TEST set (100% of test data) - **This is your TRUE performance!**
- Computes all 12 metrics for both sets
- Saves predictions and metrics to Google Drive

**Why this matters:**
- During training we only validated on 5% of the validation set (for speed)
- Now we get the full, accurate assessment
- Test set was completely unseen during training
- Test performance is what you should report in papers

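**Expected results (indicative; actual values depend on the dataset and split):**
- Spearman: 0.40-0.45
- RMSE: 1.2-1.4 pKd units
- Recall@pKd≥9: 95-100% (check precision too: predicting pKd ≥ 9 for every pair also yields 100% recall)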
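Spearman estimated on a few hundred pairs is noisy. This matters for the quick validation set used for checkpointing and early stopping, and when comparing results against ranges like the ones above. One way to quantify the noise is a percentile bootstrap confidence interval. The sketch below is hypothetical: `bootstrap_spearman_ci` is not part of the notebook. It computes Spearman as the Pearson correlation of average ranks, which matches the usual tie handling.

```python
import numpy as np
import pandas as pd


def spearman(a, b):
    """Spearman's rho as the Pearson correlation of average ranks (ties averaged)."""
    ra = pd.Series(a).rank().to_numpy()
    rb = pd.Series(b).rank().to_numpy()
    return np.corrcoef(ra, rb)[0, 1]


def bootstrap_spearman_ci(targets, predictions, n_boot=1000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for Spearman's rho by resampling pairs with replacement."""
    rng = np.random.default_rng(seed)
    t, p = np.asarray(targets), np.asarray(predictions)
    n = len(t)
    rhos = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)
        rhos.append(spearman(t[idx], p[idx]))
    lo, hi = np.nanpercentile(rhos, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return spearman(t, p), lo, hi


# Usage after evaluation, e.g.:
# rho, lo, hi = bootstrap_spearman_ci(test_targets, test_preds)
# print(f"Spearman {rho:.3f} (95% CI {lo:.3f} to {hi:.3f})")
```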

In [None]:
# Load best model from Google Drive
print("="*70)
print("FINAL COMPREHENSIVE EVALUATION")
print("="*70)

print(f"\nLoading best model from Drive: {model_save_path}")
checkpoint = torch.load(model_save_path, map_location=device)  # Load tensors onto the active device
model.load_state_dict(checkpoint['model_state_dict'])
print(f"✅ Loaded model from epoch {checkpoint['epoch']+1}")
print(f"   Best quick validation Spearman: {checkpoint['val_spearman']:.4f}")

In [None]:
# Evaluate on FULL validation set
print("\n" + "-"*70)
print(f"Evaluating on FULL validation set ({len(val_dataset_full):,} samples)...")
print("-"*70)

val_metrics, val_preds, val_targets = eval_model(
    model, val_loader_full, device, "Full Validation"
)

print(f"\n📊 FULL VALIDATION METRICS:")
print(f"  Samples: {val_metrics['n_samples']:,}")
print(f"  Strong Binders (pKd≥9): {val_metrics['n_strong_binders']}")
print(f"\n  Regression Metrics:")
print(f"    RMSE:        {val_metrics['rmse']:.4f}")
print(f"    MAE:         {val_metrics['mae']:.4f}")
print(f"    MSE:         {val_metrics['mse']:.4f}")
print(f"    R²:          {val_metrics['r2']:.4f}")
print(f"\n  Correlation Metrics:")
print(f"    Spearman ρ:  {val_metrics['spearman']:.4f} (p={val_metrics['spearman_p']:.2e})")
print(f"    Pearson r:   {val_metrics['pearson']:.4f} (p={val_metrics['pearson_p']:.2e})")
print(f"\n  Classification Metrics (pKd≥9):")
print(f"    Recall:      {val_metrics['recall_pkd9']:.2f}%")
print(f"    Precision:   {val_metrics['precision_pkd9']:.2f}%")
print(f"    F1-Score:    {val_metrics['f1_pkd9']:.2f}%")
print(f"    Specificity: {val_metrics['specificity_pkd9']:.2f}%")

In [None]:
# Evaluate on TEST set (UNSEEN DATA!)
print("\n" + "-"*70)
print(f"Evaluating on TEST set ({len(test_dataset):,} samples)...")
print("-"*70)

test_metrics, test_preds, test_targets = eval_model(
    model, test_loader, device, "Test Set"
)

print(f"\n📊 TEST SET METRICS (UNSEEN DATA):")
print(f"  Samples: {test_metrics['n_samples']:,}")
print(f"  Strong Binders (pKd≥9): {test_metrics['n_strong_binders']}")
print(f"\n  Regression Metrics:")
print(f"    RMSE:        {test_metrics['rmse']:.4f}")
print(f"    MAE:         {test_metrics['mae']:.4f}")
print(f"    MSE:         {test_metrics['mse']:.4f}")
print(f"    R²:          {test_metrics['r2']:.4f}")
print(f"\n  Correlation Metrics:")
print(f"    Spearman ρ:  {test_metrics['spearman']:.4f} (p={test_metrics['spearman_p']:.2e})")
print(f"    Pearson r:   {test_metrics['pearson']:.4f} (p={test_metrics['pearson_p']:.2e})")
print(f"\n  Classification Metrics (pKd≥9):")
print(f"    Recall:      {test_metrics['recall_pkd9']:.2f}%")
print(f"    Precision:   {test_metrics['precision_pkd9']:.2f}%")
print(f"    F1-Score:    {test_metrics['f1_pkd9']:.2f}%")
print(f"    Specificity: {test_metrics['specificity_pkd9']:.2f}%")
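
Spearman's ρ is the headline metric because it only measures how well the model *ranks* complexes by affinity, which is what matters when screening candidates. The `spearman_p` and `pearson_p` keys suggest `eval_model` uses `scipy.stats.spearmanr` and `scipy.stats.pearsonr` (an assumption, since its code is not shown). This NumPy-only sketch illustrates the definition: Spearman's ρ is Pearson's r computed on ranks, with tied values sharing their average rank. `average_ranks`, `spearman_rho` and `pearson_r` are hypothetical helpers.

```python
import numpy as np

def average_ranks(x):
    """Rank values 1..n, giving tied values the mean of their ranks."""
    x = np.asarray(x, dtype=float)
    order = np.argsort(x, kind="mergesort")
    sorted_x = x[order]
    ranks = np.empty(len(x), dtype=float)
    i = 0
    while i < len(x):
        j = i
        # Extend j across a run of tied values
        while j + 1 < len(x) and sorted_x[j + 1] == sorted_x[i]:
            j += 1
        ranks[order[i:j + 1]] = (i + j) / 2 + 1  # mean of 1-based ranks i+1..j+1
        i = j + 1
    return ranks

def spearman_rho(a, b):
    """Spearman's rho = Pearson correlation of the rank-transformed data."""
    return float(np.corrcoef(average_ranks(a), average_ranks(b))[0, 1])

def pearson_r(a, b):
    return float(np.corrcoef(np.asarray(a, dtype=float), np.asarray(b, dtype=float))[0, 1])

true_pkd = [6, 7, 8, 9, 10]
pred_pkd = [6.1, 6.2, 6.3, 6.4, 20.0]  # same ordering, one extreme value
print(spearman_rho(true_pkd, pred_pkd))  # ranks agree perfectly
print(pearson_r(true_pkd, pred_pkd))     # about 0.72
```

A single extreme prediction drags Pearson's r down while Spearman's ρ is unaffected, because the ordering of complexes is unchanged.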

In [None]:
# Save predictions and metrics to Google Drive
print("\n" + "-"*70)
print("Saving results to Google Drive...")
print("-"*70)

# Validation predictions
val_results = pd.DataFrame({
    'true_pKd': val_targets,
    'pred_pKd': val_preds,
    'error': val_preds - val_targets,
    'abs_error': np.abs(val_preds - val_targets)
})
val_pred_path = os.path.join(OUTPUT_DIR, 'val_predictions.csv')
val_results.to_csv(val_pred_path, index=False)
print(f"✅ Saved: {val_pred_path}")

# Test predictions
test_results = pd.DataFrame({
    'true_pKd': test_targets,
    'pred_pKd': test_preds,
    'error': test_preds - test_targets,
    'abs_error': np.abs(test_preds - test_targets)
})
test_pred_path = os.path.join(OUTPUT_DIR, 'test_predictions.csv')
test_results.to_csv(test_pred_path, index=False)
print(f"✅ Saved: {test_pred_path}")

# Save all metrics to JSON
all_metrics = {
    # .item() converts NumPy scalars to the matching Python int/float so json.dump accepts them
    'validation_full': {k: v.item() if isinstance(v, np.generic) else v
                        for k, v in val_metrics.items()},
    'test': {k: v.item() if isinstance(v, np.generic) else v
             for k, v in test_metrics.items()},
    'best_quick_val_spearman': float(best_spearman),
    'config': config
}

metrics_path = os.path.join(OUTPUT_DIR, 'final_metrics.json')
with open(metrics_path, 'w') as f:
    json.dump(all_metrics, f, indent=2)
print(f"✅ Saved: {metrics_path}")

print(f"\n{'='*70}")
print(f"✅ EVALUATION COMPLETE!")
print(f"{'='*70}")
print(f"\n📌 KEY RESULTS:")
print(f"  Validation Spearman: {val_metrics['spearman']:.4f}")
print(f"  Test Spearman:       {test_metrics['spearman']:.4f} ← TRUE PERFORMANCE")
print(f"  Test RMSE:           {test_metrics['rmse']:.4f}")
print(f"  Test MAE:            {test_metrics['mae']:.4f}")
print(f"  Test R²:             {test_metrics['r2']:.4f}")
print(f"  Test Recall@pKd≥9:   {test_metrics['recall_pkd9']:.2f}%")
print(f"\n📁 All results saved to: {OUTPUT_DIR}")
print(f"{'='*70}")
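
`json.dump` raises `TypeError` on NumPy scalars such as `np.float32` or `np.int64` and on arrays, and the per-key conversion above only handles top-level values. A recursive converter is one way to make any nested metrics or config dict serializable. `to_jsonable` is a hypothetical helper; `.item()` is NumPy's standard method for turning a scalar into the matching Python type.

```python
import json
import numpy as np

def to_jsonable(obj):
    """Recursively convert NumPy scalars/arrays (and tuples) into plain Python types."""
    if isinstance(obj, dict):
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # .item() keeps integers as int and floats as float (unlike a blanket float())
        return obj.item()
    return obj

metrics = {"n_samples": np.int64(1200), "spearman": np.float32(0.5),
           "nested": {"rmse": np.float64(1.25)}}
print(json.dumps(to_jsonable(metrics)))
```

In the notebook, `json.dump(to_jsonable(all_metrics), f, indent=2)` would cover nested entries as well.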

# Step 8: Results Visualization

**What this does:**
- Plots training curves (loss, Spearman correlation over epochs)
- Creates scatter plots (predictions vs actual values)
- Shows error distribution histogram
- Saves all plots to Google Drive

**Why visualizations matter:**
- See whether the model is learning (loss decreasing, Spearman increasing)
- Identify issues (overfitting, underfitting)
- Understand prediction quality
- Communicate results effectively
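
You can also check the curves for over- or underfitting numerically. This sketch assumes `training_history` is a dict of equal-length lists with the same `epoch`, `train_loss` and `val_spearman` keys the plotting cell uses. `summarize_history` and its `patience` argument are hypothetical, not the notebook's early-stopping code.

```python
def summarize_history(history, patience=3):
    """Report the best epoch and whether validation Spearman has stalled.
    Hypothetical helper; `patience` is an illustrative setting."""
    scores = history["val_spearman"]
    losses = history["train_loss"]
    best_idx = max(range(len(scores)), key=lambda i: scores[i])  # first best on ties
    epochs_since_best = len(scores) - 1 - best_idx
    return {
        "best_epoch": history["epoch"][best_idx],
        "best_spearman": scores[best_idx],
        "epochs_since_best": epochs_since_best,
        # Training loss still falling while validation stalls is a classic overfitting sign
        "loss_still_falling": losses[-1] < losses[best_idx],
        "stalled": epochs_since_best >= patience,
    }

history = {
    "epoch": [1, 2, 3, 4, 5, 6],
    "train_loss": [1.2, 0.9, 0.7, 0.6, 0.5, 0.45],
    "val_spearman": [0.30, 0.38, 0.42, 0.41, 0.40, 0.39],
}
print(summarize_history(history))
```

Training loss that keeps falling after validation Spearman has peaked suggests overfitting. Both curves flattening at poor values suggests underfitting.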

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Training loss
ax1 = axes[0]
ax1.plot(training_history['epoch'], training_history['train_loss'], 'b-o', linewidth=2, markersize=4)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Training Loss', fontsize=12)
ax1.set_title('Training Loss Over Time', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)

# Validation Spearman
ax2 = axes[1]
ax2.plot(training_history['epoch'], training_history['val_spearman'], 'g-o', linewidth=2, markersize=4)
ax2.axhline(y=best_spearman, color='r', linestyle='--', linewidth=2, label=f'Best: {best_spearman:.4f}')
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Validation Spearman', fontsize=12)
ax2.set_title('Validation Spearman Over Time', fontsize=14, fontweight='bold')
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
curves_path = os.path.join(OUTPUT_DIR, 'training_curves.png')
plt.savefig(curves_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✅ Saved: {curves_path}")

In [None]:
# Prediction vs Actual scatter plots
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Validation set
ax1 = axes[0]
ax1.scatter(val_targets, val_preds, alpha=0.3, s=10, color='blue')
ax1.plot([4, 14], [4, 14], 'r--', linewidth=2, label='Perfect prediction')
ax1.set_xlabel('True pKd', fontsize=12)
ax1.set_ylabel('Predicted pKd', fontsize=12)
ax1.set_title(f'Validation Set\nSpearman: {val_metrics["spearman"]:.4f}, RMSE: {val_metrics["rmse"]:.4f}',
              fontsize=13, fontweight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)
ax1.set_xlim(4, 14)
ax1.set_ylim(4, 14)

# Test set
ax2 = axes[1]
ax2.scatter(test_targets, test_preds, alpha=0.3, s=10, color='orange')
ax2.plot([4, 14], [4, 14], 'r--', linewidth=2, label='Perfect prediction')
ax2.set_xlabel('True pKd', fontsize=12)
ax2.set_ylabel('Predicted pKd', fontsize=12)
ax2.set_title(f'Test Set (UNSEEN DATA)\nSpearman: {test_metrics["spearman"]:.4f}, RMSE: {test_metrics["rmse"]:.4f}',
              fontsize=13, fontweight='bold')
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)
ax2.set_xlim(4, 14)
ax2.set_ylim(4, 14)

plt.tight_layout()
scatter_path = os.path.join(OUTPUT_DIR, 'predictions_scatter.png')
plt.savefig(scatter_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✅ Saved: {scatter_path}")
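
The scatter cells fix both axes to the range 4 to 14 pKd. Any sample whose true or predicted value falls outside that window is hidden without warning. This sketch counts such points and computes limits that cover all of them. `count_outside_window` and `padded_limits` are hypothetical helpers.

```python
import numpy as np

def count_outside_window(targets, preds, lo=4.0, hi=14.0):
    """Count points that fall outside the [lo, hi] x [lo, hi] plotting window."""
    targets = np.asarray(targets, dtype=float)
    preds = np.asarray(preds, dtype=float)
    outside = (targets < lo) | (targets > hi) | (preds < lo) | (preds > hi)
    return int(outside.sum())

def padded_limits(targets, preds, pad=0.5):
    """Axis limits covering every true and predicted value plus a small margin."""
    both = np.concatenate([np.asarray(targets, dtype=float), np.asarray(preds, dtype=float)])
    return float(both.min() - pad), float(both.max() + pad)

print(count_outside_window([3.5, 6.0, 9.0], [5.0, 7.0, 14.2]))  # 2
print(padded_limits([3.5, 6.0, 9.0], [5.0, 7.0, 14.2]))
```

If the count is non-zero for `test_targets`/`test_preds`, passing `padded_limits(...)` to `set_xlim`/`set_ylim` (and to the diagonal reference line) keeps every point visible.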

In [None]:
# Error distribution
fig, ax = plt.subplots(figsize=(10, 6))

test_errors = test_preds - test_targets
ax.hist(test_errors, bins=50, edgecolor='black', alpha=0.7, color='steelblue')
ax.axvline(x=0, color='r', linestyle='--', linewidth=2, label='Zero error')
ax.axvline(x=np.mean(test_errors), color='g', linestyle='--', linewidth=2,
           label=f'Mean error: {np.mean(test_errors):.4f}')
ax.set_xlabel('Prediction Error (pKd units)', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title('Test Set: Error Distribution', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
error_path = os.path.join(OUTPUT_DIR, 'error_distribution.png')
plt.savefig(error_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"✅ Saved: {error_path}")

# Error statistics
print(f"\n📊 Error Analysis:")
print(f"   Mean error:    {np.mean(test_errors):.4f} pKd")
print(f"   Std error:     {np.std(test_errors):.4f} pKd")
print(f"   Median |error|: {np.median(np.abs(test_errors)):.4f} pKd")
print(f"   95th %ile |error|: {np.percentile(np.abs(test_errors), 95):.4f} pKd")
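
Because pKd = -log10(Kd) with Kd in molar, pKd errors are log-scale errors. An error of d units means the predicted Kd is off by a factor of 10^|d|. For example, an error of 1.2 to 1.4 pKd units (the expected RMSE range quoted earlier) corresponds to roughly a 16- to 25-fold error in Kd. This sketch applies the conversion to an array of errors such as `test_errors`. `fold_error_in_kd` and `fraction_within_fold` are hypothetical helpers.

```python
import numpy as np

def fold_error_in_kd(pkd_errors):
    """Convert signed pKd errors into multiplicative (fold) errors on Kd.

    pKd = -log10(Kd), so an error of d pKd units means the predicted Kd is off
    by a factor of 10**|d| in one direction or the other."""
    return 10.0 ** np.abs(np.asarray(pkd_errors, dtype=float))

def fraction_within_fold(pkd_errors, fold=10.0):
    """Fraction of predictions whose Kd is within `fold`-fold of the true value."""
    errors = np.abs(np.asarray(pkd_errors, dtype=float))
    return float(np.mean(errors <= np.log10(fold)))

example_errors = np.array([0.0, -0.5, 1.0, 1.3])
print(fold_error_in_kd(example_errors))      # roughly 1, 3.16, 10 and 19.95
print(fraction_within_fold(example_errors))  # 0.75: three of four are within 10-fold
```

In the notebook, `fraction_within_fold(test_errors)` would report the share of test predictions within 10-fold of the measured Kd.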

In [None]:
# Summary of all output files
print("\n" + "="*70)
print("ALL RESULTS SAVED TO GOOGLE DRIVE")
print("="*70)

print(f"\n📁 Output directory: {OUTPUT_DIR}")
print("\nFiles saved:")
print("  1. best_model.pth - Trained model weights (~3.5GB)")
print("  2. val_predictions.csv - Validation predictions")
print("  3. test_predictions.csv - Test predictions")
print("  4. final_metrics.json - All metrics")
print("  5. training_curves.png - Training visualization")
print("  6. predictions_scatter.png - Prediction plots")
print("  7. error_distribution.png - Error analysis")

print("\n✅ All files are saved in your Google Drive!")
print("   You can access them anytime at:")
print(f"   {OUTPUT_DIR}")
print("\n" + "="*70)
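
The listing above is printed unconditionally and does not check the disk. This sketch confirms that each file exists and reports its size. It assumes all seven files live directly in `OUTPUT_DIR`, as the listing states. Note that `best_model.pth` is actually written to `model_save_path`, which may point elsewhere. `check_outputs` and `EXPECTED_FILES` are hypothetical names.

```python
import os

EXPECTED_FILES = [
    "best_model.pth",
    "val_predictions.csv",
    "test_predictions.csv",
    "final_metrics.json",
    "training_curves.png",
    "predictions_scatter.png",
    "error_distribution.png",
]

def check_outputs(output_dir, expected=EXPECTED_FILES):
    """Return {filename: size in MB, or None if the file is missing}."""
    report = {}
    for name in expected:
        path = os.path.join(output_dir, name)
        report[name] = os.path.getsize(path) / 1e6 if os.path.isfile(path) else None
    return report

# Usage in the notebook (OUTPUT_DIR comes from the earlier setup cells):
# for name, size in check_outputs(OUTPUT_DIR).items():
#     status = "MISSING" if size is None else f"{size:.1f} MB"
#     print(f"{name:28s} {status}")
```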

# 🎉 Training Complete!

## Summary

You've successfully trained an antibody-antigen binding prediction model!

### What you accomplished:
✅ Trained dual-encoder model (IgT5 + ESM-2)  
✅ Implemented modern training practices (warmup, early stopping, regularization)  
✅ Evaluated on proper train/val/test splits  
✅ Computed comprehensive metrics (12 total)  
✅ Created publication-ready visualizations  
✅ Saved all results to Google Drive
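
The warmup mentioned above is implemented earlier in the notebook and is not shown in this section. One common pattern is a linear warmup followed by a linear decay of the learning rate, and this sketch computes that multiplier. The function name and step counts are hypothetical, and the notebook's actual schedule may differ.

```python
def linear_warmup_then_decay(step, warmup_steps, total_steps):
    """Learning-rate multiplier: ramps linearly from 0 to 1 over `warmup_steps`,
    then decays linearly back to 0 at `total_steps` (clamped at 0 afterwards)."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

# Multiplier at a few points of a hypothetical 1,000-step run with 100 warmup steps
for s in (0, 50, 100, 550, 1000):
    print(s, linear_warmup_then_decay(s, 100, 1000))
```

In PyTorch, a multiplier function like this is typically wrapped with `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=...)`. Hugging Face `transformers` provides `get_linear_schedule_with_warmup`, which follows the same shape.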

### Key Results:
- **Test Spearman:** Your true, unbiased performance metric
- **Test RMSE:** Prediction error in pKd units
- **Test Recall@pKd≥9:** How well the model identifies strong binders (pKd ≥ 9, i.e. Kd ≤ 1 nM)

### Your Results (Saved in Drive):
All outputs are in: `Google Drive/AbAg_Training_02/training_output/`

### Next Steps:
1. Access results in your Google Drive
2. Analyze error patterns in `test_predictions.csv`
3. Try improving performance:
   - Experiment with hyperparameters (LR, dropout, warmup)
   - Train ensemble of models
   - Add more data
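
One way to look for error patterns in `test_predictions.csv` is to group rows by true pKd range. A consistently negative mean error in the high-pKd bins would mean strong binders are being under-predicted. This sketch assumes the `true_pKd`, `pred_pKd`, `error` and `abs_error` columns written by the evaluation cell, and it uses only the standard library. `error_by_pkd_bin` is a hypothetical helper.

```python
import csv
import io
from collections import defaultdict
from statistics import mean

def error_by_pkd_bin(csv_file, bin_width=1.0):
    """Group rows of a predictions CSV by true-pKd bin and report the count,
    mean signed error and mean absolute error for each bin."""
    bins = defaultdict(list)
    for row in csv.DictReader(csv_file):
        lower = (float(row["true_pKd"]) // bin_width) * bin_width
        bins[lower].append((float(row["error"]), float(row["abs_error"])))
    summary = {}
    for lower in sorted(bins):
        errs = bins[lower]
        summary[(lower, lower + bin_width)] = {
            "n": len(errs),
            "mean_error": mean(e for e, _ in errs),  # sign shows over/under-prediction
            "mae": mean(a for _, a in errs),
        }
    return summary

# Tiny in-memory example; in the notebook, open test_pred_path with newline="" instead
demo = io.StringIO(
    "true_pKd,pred_pKd,error,abs_error\n"
    "6.2,7.0,0.8,0.8\n"
    "6.8,7.2,0.4,0.4\n"
    "9.5,8.5,-1.0,1.0\n"
)
for (lo, hi), stats in error_by_pkd_bin(demo).items():
    print(f"pKd {lo:.0f}-{hi:.0f}: n={stats['n']}, "
          f"mean error={stats['mean_error']:+.2f}, MAE={stats['mae']:.2f}")
```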

### Questions?
Review the code comments - each function is documented with:
- What it does
- Why it matters
- How it works

---

**Happy modeling! 🧬🚀**