# Antibody-Antigen Binding Prediction - OPTIMIZED v2

## Research-Enhanced Training with Advanced Regularization

**Key Improvements (Based on 2024 Research):**
- Cross-attention between antibody and antigen embeddings
- Linear warmup followed by cosine learning-rate decay (stepped per batch), with ReduceLROnPlateau as a per-epoch backup
- Combined loss: Huber + Spearman correlation
- Inverse-frequency weighted sampling across pKd bins, so batches cover the whole pKd range (prevent model collapse)
- Mixup augmentation for embeddings
- Multi-task learning (regression + classification)
- Optuna hyperparameter optimization

**Expected Performance:**
- Test Spearman: **0.45-0.55**
- Recall (pKd>=9): **60-80%**

---

# Step 1: Environment Setup

In [18]:
# Check GPU
import torch
import sys

# Report the PyTorch build and choose the device used by every later cell.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {cuda_ok}")

if not cuda_ok:
    # CPU fallback: the notebook still runs, just far slower.
    device = torch.device('cpu')
    print("WARNING: No GPU!")
else:
    gpu_name = torch.cuda.get_device_name(0)
    # total_memory is in bytes; report decimal gigabytes.
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name} ({gpu_memory:.1f}GB)")
    device = torch.device('cuda')

PyTorch: 2.9.0+cu126
CUDA: True
GPU: NVIDIA A100-SXM4-80GB (85.2GB)


In [19]:
# Install packages
!pip install -q transformers>=4.41.0 sentencepiece optuna
print("Packages installed!")

Packages installed!


In [20]:
# A100 optimizations
# Permit TF32 for CUDA matmuls and for cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# cuDNN benchmark mode (autotunes kernels; assumes fairly stable input shapes
# to pay off - TODO confirm given the variable sequence lengths in this notebook).
torch.backends.cudnn.benchmark = True
# Global float32 matmul precision; kept after the explicit flags above.
torch.set_float32_matmul_precision('high')
print("A100 optimizations enabled")

A100 optimizations enabled


# Step 2: Imports & Utilities

In [21]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
import time
import math
import random
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter  # TensorBoard for PyTorch

from transformers import T5Tokenizer, T5EncoderModel, AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy import stats

import optuna
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner

# Set random seeds for reproducibility
SEED = 42
# Seed Python, NumPy and the PyTorch default generator with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    # Seed the current CUDA device and then every visible device.
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

print("Libraries imported!")
print(f"Random seed set to {SEED}")

Libraries imported!
Random seed set to 42


In [22]:
# Comprehensive metrics
def compute_metrics(targets, predictions):
    """Compute regression and strong-binder classification metrics.

    Args:
        targets: True pKd values (array-like; lists and column vectors are accepted).
        predictions: Predicted pKd values, same length as ``targets``.

    Returns:
        dict with keys ``rmse``, ``mae``, ``r2``, ``spearman``, ``pearson`` and
        ``recall`` / ``precision`` (both in percent) for the ``pKd >= 9`` class.
        ``recall`` / ``precision`` are 0.0 when their denominator is empty.
    """
    # Coerce to flat arrays so the element-wise comparisons below also work
    # for plain Python lists and (N, 1)-shaped inputs.
    targets = np.asarray(targets).ravel()
    predictions = np.asarray(predictions).ravel()

    mse = mean_squared_error(targets, predictions)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(targets, predictions)
    r2 = r2_score(targets, predictions)
    # Both correlations are NaN when either input is constant
    # (e.g. a collapsed model predicting a single value).
    spearman, _ = stats.spearmanr(targets, predictions)
    pearson, _ = stats.pearsonr(targets, predictions)

    # Classification metrics at pKd=9
    strong = targets >= 9.0
    pred_strong = predictions >= 9.0
    tp = np.sum(strong & pred_strong)
    fn = np.sum(strong & ~pred_strong)
    fp = np.sum(~strong & pred_strong)

    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0

    return {
        'rmse': rmse, 'mae': mae, 'r2': r2,
        'spearman': spearman, 'pearson': pearson,
        'recall': recall * 100, 'precision': precision * 100
    }

# Early Stopping
class EarlyStopping:
    """Stop training when a higher-is-better score stops improving.

    Args:
        patience: Number of consecutive non-improving calls before stopping.
        min_delta: Minimum increase over ``best_score`` that counts as an improvement.

    Attributes:
        counter: Consecutive non-improving calls seen so far.
        best_score: Best valid score seen so far (``None`` until one is seen).
        early_stop: Latched to ``True`` once ``counter`` reaches ``patience``.
    """
    def __init__(self, patience=15, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    @staticmethod
    def _is_missing(value):
        """Return True for ``None`` or NaN (NaN is the only value unequal to itself)."""
        return value is None or value != value

    def __call__(self, score):
        """Record ``score`` and return whether training should stop.

        A NaN score (e.g. Spearman of constant predictions) is treated as
        "no improvement" and is never stored as ``best_score``; otherwise every
        later comparison against NaN would be False and improvement could never
        be detected again. A NaN ``best_score`` restored from an older
        checkpoint is replaced by the first valid score.
        """
        best_missing = self._is_missing(self.best_score)
        improved = (not self._is_missing(score)) and (
            best_missing or score > self.best_score + self.min_delta
        )
        if improved:
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

print("Utilities defined!")

Utilities defined!


In [23]:
# Advanced Loss Functions (Data-Informed)

class HuberLoss(nn.Module):
    """Mean Huber loss: quadratic for small residuals, linear for large ones.

    Args:
        delta: Absolute residual at which the loss switches from the quadratic
            to the linear branch.
    """
    def __init__(self, delta=1.0):
        super().__init__()
        self.delta = delta

    def forward(self, pred, target):
        """Return the mean element-wise Huber loss between ``pred`` and ``target``."""
        residual = torch.abs(pred - target)
        quadratic_branch = 0.5 * residual ** 2
        linear_branch = self.delta * (residual - 0.5 * self.delta)
        # Strictly below delta -> quadratic; at or above -> linear.
        per_element = torch.where(residual < self.delta, quadratic_branch, linear_branch)
        return per_element.mean()

class CombinedLoss(nn.Module):
    """Combined loss: Huber + Soft Spearman + Classification

    Data-informed design:
    - Dataset is BIMODAL (peaks at 6-7 and 9-10)
    - Soft Spearman forces model to learn both modes
    - Classification threshold at 9 (therapeutic standard)

    Primary goal: Accurate pKd prediction
    Secondary goal: Identify therapeutic candidates (pKd >= 9)

    Args:
        huber_weight: Weight of the Huber regression term.
        spearman_weight: Weight of the soft-ranking correlation term.
        class_weight: Weight of the auxiliary BCE term (used only when
            ``class_logits`` is passed to ``forward``).
    """
    def __init__(self, huber_weight=0.5, spearman_weight=0.4, class_weight=0.1):
        super().__init__()
        self.huber = HuberLoss(delta=1.0)
        self.huber_weight = huber_weight
        self.spearman_weight = spearman_weight
        self.class_weight = class_weight
        self.bce = nn.BCEWithLogitsLoss()

    def soft_spearman_loss(self, pred, target, temperature=1.0):
        """Differentiable Spearman correlation loss using soft ranking

        This is critical for bimodal data - forces model to learn
        correct ranking across both low and high affinity peaks.

        Args:
            pred: 1-D tensor of predictions, shape [B].
            target: 1-D tensor of targets, shape [B].
            temperature: Divisor applied to pairwise differences before the sigmoid.

        Returns:
            Scalar tensor ``1 - corr``. For fewer than two samples a ranking is
            undefined, so a zero that stays attached to ``pred``'s graph is returned.
        """
        # Rank sums are accumulated over the whole batch; do that in float32 so
        # reduced-precision (e.g. bf16 autocast) predictions do not quantize them.
        pred = pred.float()
        target = target.float()

        if pred.numel() < 2:
            # std() of a single element is NaN and would poison the whole loss.
            return pred.sum() * 0.0

        # Compute pairwise differences
        pred_diff = pred.unsqueeze(1) - pred.unsqueeze(0)  # [B, B]
        target_diff = target.unsqueeze(1) - target.unsqueeze(0)  # [B, B]

        # Soft ranking using sigmoid
        pred_rank = torch.sigmoid(pred_diff / temperature).sum(dim=1)  # [B]
        target_rank = torch.sigmoid(target_diff / temperature).sum(dim=1)  # [B]

        # Normalize ranks
        pred_rank = (pred_rank - pred_rank.mean()) / (pred_rank.std() + 1e-8)
        target_rank = (target_rank - target_rank.mean()) / (target_rank.std() + 1e-8)

        # Correlation
        corr = (pred_rank * target_rank).mean()
        return 1 - corr  # Loss = 1 - correlation

    def forward(self, pred, target, class_logits=None):
        """Return the weighted sum of the Huber, soft-Spearman and optional BCE terms.

        Args:
            pred: Predicted pKd, shape [B].
            target: True pKd, shape [B].
            class_logits: Optional logits for the ``pKd >= 9`` class, shape [B].
        """
        # Compute every term in float32 regardless of the model's output dtype.
        pred = pred.float()
        target = target.float()

        # Primary losses for affinity prediction
        huber_loss = self.huber(pred, target)
        spearman_loss = self.soft_spearman_loss(pred, target)

        loss = self.huber_weight * huber_loss + self.spearman_weight * spearman_loss

        # Classification: therapeutic threshold (pKd >= 9)
        if class_logits is not None:
            class_target = (target >= 9.0).float()  # Therapeutic standard
            class_loss = self.bce(class_logits.float(), class_target)
            # Out-of-place add: avoids mutating a tensor that autograd tracks.
            loss = loss + self.class_weight * class_loss

        return loss

print("Loss functions defined (data-informed, therapeutic threshold=9)!")

Loss functions defined (data-informed, therapeutic threshold=9)!


# Step 3: Load Data

In [24]:
# Mount Google Drive
# Colab-only: this import and mount are not available outside Google Colab.
from google.colab import drive
drive.mount('/content/drive')

# Project folder on Drive; the dataset CSV is read from here in a later cell.
DRIVE_DIR = '/content/drive/MyDrive/AbAg_Training_02'
# All run artifacts (checkpoints, TensorBoard logs) are written below this folder.
OUTPUT_DIR = f'{DRIVE_DIR}/training_output_OPTIMIZED_v2'
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Output: {OUTPUT_DIR}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Output: /content/drive/MyDrive/AbAg_Training_02/training_output_OPTIMIZED_v2


In [25]:
# Load dataset
CSV_FILENAME = 'agab_phase2_full.csv'  # <- CHANGE THIS

csv_path = os.path.join(DRIVE_DIR, CSV_FILENAME)
df = pd.read_csv(csv_path)

# Summarise size, affinity range and the share of strong binders (pKd >= 9).
affinity = df['pKd']
is_strong = affinity >= 9
print(f"Dataset: {len(df):,} samples")
print(f"pKd range: {affinity.min():.1f} - {affinity.max():.1f}")
print(f"Strong binders: {is_strong.sum():,} ({100*is_strong.mean():.1f}%)")

Dataset: 159,735 samples
pKd range: -3.0 - 12.4
Strong binders: 54,741 (34.3%)


In [26]:
# Split data with stratification by pKd bins
# Five equal-width pKd bins are used as the stratification label.
df['pKd_bin'] = pd.cut(df['pKd'], bins=5, labels=False)

# 70% train; the held-out 30% is then halved into validation and test.
train_df, temp_df = train_test_split(
    df,
    test_size=0.3,
    random_state=42,
    stratify=df['pKd_bin'],
)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=42,
    stratify=temp_df['pKd_bin'],
)

split_sizes = (len(train_df), len(val_df), len(test_df))
print(f"Train: {split_sizes[0]:,} | Val: {split_sizes[1]:,} | Test: {split_sizes[2]:,}")

Train: 111,814 | Val: 23,960 | Test: 23,961


In [27]:
# Dataset with stratified sampling support
class AbAgDataset(Dataset):
    """Row-wise view of an antibody/antigen/pKd dataframe.

    Reads the columns ``antibody_sequence``, ``antigen_sequence`` and ``pKd``.
    ``self.weights`` holds one normalised sampling weight per row, inversely
    proportional to the population of the row's pKd bin (five equal-width bins).
    """
    def __init__(self, dataframe):
        self.data = dataframe.reset_index(drop=True)

        # Bin every row by pKd and count how many rows share each bin.
        bin_ids = pd.cut(self.data['pKd'], bins=5, labels=False)
        rows_per_bin = bin_ids.value_counts()

        def inverse_frequency(bin_id):
            # Rows that could not be binned (NaN) get a raw weight of 1.0.
            if pd.notna(bin_id):
                return 1.0 / rows_per_bin[bin_id]
            return 1.0

        raw_weights = bin_ids.map(inverse_frequency).values
        # Normalise so the weights sum to one.
        self.weights = raw_weights / raw_weights.sum()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        sample = {
            'antibody_seqs': record['antibody_sequence'],
            'antigen_seqs': record['antigen_sequence'],
        }
        sample['pKd'] = torch.tensor(record['pKd'], dtype=torch.float32)
        return sample

def collate_fn(batch):
    """Merge per-sample dicts into one batch dict.

    Sequences stay as Python lists of strings (tokenisation happens inside the
    model); the scalar pKd tensors are stacked into a single 1-D tensor.
    """
    antibodies, antigens, affinities = [], [], []
    for sample in batch:
        antibodies.append(sample['antibody_seqs'])
        antigens.append(sample['antigen_seqs'])
        affinities.append(sample['pKd'])
    return {
        'antibody_seqs': antibodies,
        'antigen_seqs': antigens,
        'pKd': torch.stack(affinities),
    }

# One dataset per split; each computes its own per-row sampling weights.
train_dataset, val_dataset, test_dataset = (
    AbAgDataset(split_df) for split_df in (train_df, val_df, test_df)
)

print("Datasets created with stratified sampling weights!")

Datasets created with stratified sampling weights!


# Step 4: Enhanced Model Architecture

**New Features:**
- Cross-attention between Ab and Ag embeddings
- Multi-task output (regression + classification)
- Spectral normalization in regression head

In [28]:
# Cross-Attention Module
class CrossAttention(nn.Module):
    """Post-norm cross-attention block: attention + residual, then FFN + residual.

    Args:
        dim: Embedding size of query and key/value tensors.
        num_heads: Number of attention heads.
        dropout: Dropout used inside the attention and the feed-forward network.
    """
    def __init__(self, dim, num_heads=8, dropout=0.1):
        super().__init__()
        hidden_dim = dim * 4
        self.attention = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        # Position-wise feed-forward network with a 4x expansion.
        feed_forward_layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*feed_forward_layers)

    def forward(self, query, key_value):
        """Attend from ``query`` to ``key_value`` (used as both key and value)."""
        attended, _attn_weights = self.attention(query, key_value, key_value)
        after_attention = self.norm1(query + attended)
        return self.norm2(after_attention + self.ffn(after_attention))

print("Cross-attention module defined!")

Cross-attention module defined!


In [29]:
# Enhanced Model with Cross-Attention
class EnhancedAbAgModel(nn.Module):
    """Frozen IgT5 (antibody) and ESM-2 (antigen) encoders with a trainable pKd head.

    Pipeline: tokenize -> frozen encoders -> padding-aware mean pooling ->
    projection to a common 512-dim space -> optional cross-attention ->
    regression head (pKd) and linear classifier (auxiliary logit).

    Args:
        dropout: Dropout used in the cross-attention blocks and regression head.
        use_cross_attention: Whether to apply the two CrossAttention blocks.
        use_esm2_3b: Load ESM-2 3B (2560-dim) if True, else ESM-2 650M (1280-dim).
    """
    def __init__(self, dropout=0.3, use_cross_attention=True, use_esm2_3b=True):
        super().__init__()

        print("Building enhanced model...")

        # IgT5 for antibodies
        # NOTE(review): confirm that `antibody_sequence` strings are already in the
        # format the IgT5 tokenizer expects (e.g. residue spacing) - not visible here.
        self.igt5_tokenizer = T5Tokenizer.from_pretrained("Exscientia/IgT5")
        self.igt5_model = T5EncoderModel.from_pretrained("Exscientia/IgT5")
        self.igt5_dim = 1024  # IgT5 outputs 1024-dim embeddings

        # ESM-2 for antigens
        if use_esm2_3b:
            print("  Loading ESM-2 3B...")
            self.esm2_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
            self.esm2_model = AutoModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
            self.esm2_dim = 2560
        else:
            print("  Loading ESM-2 650M...")
            self.esm2_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
            self.esm2_model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
            self.esm2_dim = 1280

        # Freeze encoders
        for param in self.igt5_model.parameters():
            param.requires_grad = False
        for param in self.esm2_model.parameters():
            param.requires_grad = False

        # Frozen encoders are pure feature extractors: keep them in eval mode
        # (see `train` below) and run them under no_grad in `forward`.
        # Gradient checkpointing is therefore not enabled - it only trades compute
        # for memory during a backward pass that never reaches these encoders.
        self.igt5_model.eval()
        self.esm2_model.eval()

        # Projection layers to common dimension
        self.common_dim = 512
        self.ab_proj = nn.Linear(self.igt5_dim, self.common_dim)  # igt5_dim -> 512
        self.ag_proj = nn.Linear(self.esm2_dim, self.common_dim)  # esm2_dim -> 512

        # Cross-attention
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            self.cross_attn_ab = CrossAttention(self.common_dim, num_heads=8, dropout=dropout)
            self.cross_attn_ag = CrossAttention(self.common_dim, num_heads=8, dropout=dropout)

        # Regression head with spectral normalization
        self.regression_head = nn.Sequential(
            nn.utils.spectral_norm(nn.Linear(self.common_dim * 2, 512)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(512),

            nn.utils.spectral_norm(nn.Linear(512, 256)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(256),

            nn.utils.spectral_norm(nn.Linear(256, 128)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(128),

            nn.Linear(128, 1)
        )

        # Classification head (auxiliary task)
        self.classifier = nn.Linear(self.common_dim * 2, 1)

        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"  Trainable parameters: {trainable/1e6:.1f}M")

    def train(self, mode=True):
        """Switch the trainable head's mode while keeping the frozen encoders in eval.

        Without this, ``model.train()`` would put the frozen encoders into train
        mode and activate any dropout they contain, making their embeddings noisy.
        """
        super().train(mode)
        self.igt5_model.eval()
        self.esm2_model.eval()
        return self

    @staticmethod
    def _masked_mean(hidden_states, attention_mask):
        """Mean-pool ``hidden_states`` [B, L, D] over positions where the mask is 1.

        Padding positions are excluded, so a sequence's embedding does not depend
        on how much padding its batch neighbours forced onto it. Computed in
        float32; the count is clamped to avoid dividing by zero.
        """
        mask = attention_mask.unsqueeze(-1).to(torch.float32)  # [B, L, 1]
        summed = (hidden_states.to(torch.float32) * mask).sum(dim=1)  # [B, D]
        counts = mask.sum(dim=1).clamp(min=1.0)  # [B, 1]
        return summed / counts

    def forward(self, antibody_seqs, antigen_seqs, device):
        """Predict pKd and the strong-binder logit for paired sequences.

        Args:
            antibody_seqs: List of antibody sequence strings.
            antigen_seqs: List of antigen sequence strings (same length).
            device: Device (``torch.device`` or string) the model lives on.

        Returns:
            Tuple ``(pKd_pred, class_logits)``, each of shape [B], in float32.
        """
        # Tokenize
        ab_tokens = self.igt5_tokenizer(
            antibody_seqs, return_tensors='pt', padding=True,
            truncation=True, max_length=512
        ).to(device)

        ag_tokens = self.esm2_tokenizer(
            antigen_seqs, return_tensors='pt', padding=True,
            truncation=True, max_length=2048
        ).to(device)

        # bf16 autocast only covers the large frozen encoders and only on CUDA;
        # on other devices the encoders run in full precision.
        device_type = torch.device(device).type
        with torch.no_grad(), torch.amp.autocast(
            device_type, dtype=torch.bfloat16, enabled=(device_type == 'cuda')
        ):
            # Encode
            ab_out = self.igt5_model(**ab_tokens).last_hidden_state
            ag_out = self.esm2_model(**ag_tokens).last_hidden_state

        # Padding-aware mean pooling (float32)
        ab_emb = self._masked_mean(ab_out, ab_tokens['attention_mask'])  # [B, igt5_dim]
        ag_emb = self._masked_mean(ag_out, ag_tokens['attention_mask'])  # [B, esm2_dim]

        # The small trainable head runs in float32 so pKd outputs are not
        # rounded to bf16 resolution before the loss and the ranking metrics.
        # Project to common dimension
        ab_proj = self.ab_proj(ab_emb)  # [B, 512]
        ag_proj = self.ag_proj(ag_emb)  # [B, 512]

        # Cross-attention (optional)
        if self.use_cross_attention:
            # Add sequence dimension for attention
            # NOTE(review): both inputs have sequence length 1, so the softmax over
            # keys is always 1 and each block reduces to a learned transform of the
            # other modality's vector plus a residual. Attending over per-residue
            # states would be needed for real cross-attention - confirm intent.
            ab_proj = ab_proj.unsqueeze(1)  # [B, 1, 512]
            ag_proj = ag_proj.unsqueeze(1)  # [B, 1, 512]

            ab_enhanced = self.cross_attn_ab(ab_proj, ag_proj).squeeze(1)
            ag_enhanced = self.cross_attn_ag(ag_proj, ab_proj).squeeze(1)

            combined = torch.cat([ab_enhanced, ag_enhanced], dim=1)
        else:
            combined = torch.cat([ab_proj, ag_proj], dim=1)

        # Predictions
        pKd_pred = self.regression_head(combined).squeeze(-1)
        class_logits = self.classifier(combined).squeeze(-1)

        return pKd_pred, class_logits

print("Enhanced model class defined!")

Enhanced model class defined!


In [30]:
# Step 5: Setup Training with Improved Hyperparameters
# Focus: Accurate pKd prediction with stable training

# Improved hyperparameters
BATCH_SIZE = 32
LEARNING_RATE = 5e-4  # Reduced from 1e-3 for stability
DROPOUT = 0.3
HUBER_WEIGHT = 0.5    # Primary: regression accuracy
SPEARMAN_WEIGHT = 0.4  # Primary: ranking correlation
CLASS_WEIGHT = 0.1     # Secondary: high binder classification
USE_CROSS_ATTENTION = True
WARMUP_EPOCHS = 5      # NEW: warmup period
EPOCHS = 50

# Echo the configuration so it is recorded with the cell output.
_banner = "="*60
print(_banner)
print("HYPERPARAMETERS (Improved for stability)")
print(_banner)
for _label, _value in (
    ("Batch size", BATCH_SIZE),
    ("Learning rate", LEARNING_RATE),
    ("Warmup epochs", WARMUP_EPOCHS),
    ("Dropout", DROPOUT),
    ("Loss weights", f"Huber={HUBER_WEIGHT}, Spearman={SPEARMAN_WEIGHT}, Class={CLASS_WEIGHT}"),
    ("Cross-attention", USE_CROSS_ATTENTION),
    ("Epochs", EPOCHS),
):
    print(f"  {_label}: {_value}")

HYPERPARAMETERS (Improved for stability)
  Batch size: 32
  Learning rate: 0.0005
  Warmup epochs: 5
  Dropout: 0.3
  Loss weights: Huber=0.5, Spearman=0.4, Class=0.1
  Cross-attention: True
  Epochs: 50


In [31]:
# Training functions
def train_epoch(model, loader, optimizer, criterion, device, max_grad_norm=1.0):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Callable as ``model(antibody_seqs, antigen_seqs, device)`` returning
            ``(pKd_pred, class_logits)``.
        loader: Sized iterable of batch dicts with keys ``antibody_seqs``,
            ``antigen_seqs`` and ``pKd``.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable as ``criterion(pKd_pred, targets, class_logits)``.
        device: Device the targets are moved to and that is passed to the model.
        max_grad_norm: Gradient-norm clipping threshold.
    """
    model.train()
    running_loss = 0

    for batch in loader:
        targets = batch['pKd'].to(device)
        pKd_pred, class_logits = model(batch['antibody_seqs'], batch['antigen_seqs'], device)
        batch_loss = criterion(pKd_pred, targets, class_logits)

        # Clear stale gradients, backpropagate, clip, then update.
        optimizer.zero_grad()
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(loader)

def evaluate(model, loader, device):
    """Run inference over ``loader`` and score the predictions.

    Returns:
        Tuple ``(metrics, predictions, targets)`` where ``metrics`` is the dict
        produced by ``compute_metrics`` and the other two are NumPy arrays.
    """
    model.eval()
    collected_preds = []
    collected_targets = []

    with torch.no_grad():
        for batch in loader:
            batch_targets = batch['pKd'].to(device)
            # Only the regression output is scored; the class logits are ignored.
            pKd_pred, _ = model(batch['antibody_seqs'], batch['antigen_seqs'], device)

            collected_preds.extend(pKd_pred.float().cpu().numpy())
            collected_targets.extend(batch_targets.float().cpu().numpy())

    pred_array = np.array(collected_preds)
    target_array = np.array(collected_targets)
    metrics = compute_metrics(target_array, pred_array)
    return metrics, pred_array, target_array

print("Training functions defined!")

Training functions defined!


In [32]:
# Build model
print("Building model...")

# Constructor options gathered in one place; the 3B ESM-2 variant is always used here.
model_options = {
    'dropout': DROPOUT,
    'use_cross_attention': USE_CROSS_ATTENTION,
    'use_esm2_3b': True,
}
model = EnhancedAbAgModel(**model_options).to(device)

print("Model ready!")

Building model...
Building enhanced model...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  Loading ESM-2 3B...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t36_3B_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  Trainable parameters: 8.8M
Model ready!


In [33]:
# Setup DataLoaders, optimizer, schedulers with warmup + plateau
from torch.utils.data import WeightedRandomSampler
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
import datetime
import glob
import shutil

# DataLoaders with stratified sampling
# Use generator for reproducible sampling
g = torch.Generator()
g.manual_seed(SEED)

# Training rows are drawn with replacement according to the per-row weights
# computed by the dataset; one "epoch" is len(train_dataset) draws.
sampler = WeightedRandomSampler(train_dataset.weights, len(train_dataset), replacement=True, generator=g)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler,
                            num_workers=2, collate_fn=collate_fn, pin_memory=True)
# Validation and test loaders iterate in fixed order without resampling.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=2, collate_fn=collate_fn, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=2, collate_fn=collate_fn, pin_memory=True)

# Optimizer with fused AdamW (only if CUDA available)
# Note: model.parameters() also yields the frozen encoder parameters
# (requires_grad=False); they are passed to the optimizer along with the head.
use_fused = torch.cuda.is_available()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01, fused=use_fused)

# Linear warmup then cosine decay (per-batch)
num_training_steps = len(train_loader) * EPOCHS
num_warmup_steps = len(train_loader) * WARMUP_EPOCHS

def lr_lambda(current_step):
    """Return the LR multiplier for ``current_step``.

    Rises linearly from 0 to 1 over ``num_warmup_steps``, then follows half a
    cosine from 1 down to 0 at ``num_training_steps`` (never below 0).
    """
    if current_step < num_warmup_steps:
        # Linear warmup
        return float(current_step) / float(max(1, num_warmup_steps))
    # Cosine decay
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

scheduler = LambdaLR(optimizer, lr_lambda)

# ReduceLROnPlateau as backup (per-epoch)
# NOTE(review): both schedulers drive the same optimizer. LambdaLR presumably
# recomputes the LR from the initial LR at every step, which would overwrite any
# plateau reduction on the next batch - confirm against the training loop.
plateau_scheduler = ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=3,
    min_lr=1e-6
)

# Combined loss
criterion = CombinedLoss(HUBER_WEIGHT, SPEARMAN_WEIGHT, CLASS_WEIGHT)

# Early stopping
early_stopping = EarlyStopping(patience=10, min_delta=0.001)

# Checkpoint paths
model_path = os.path.join(OUTPUT_DIR, 'best_model.pth')
checkpoint_path = os.path.join(OUTPUT_DIR, 'checkpoint_latest.pth')
# Defaults for a fresh run; the resume logic further down overrides them.
best_spearman = -1
start_epoch = 0
start_batch = 0  # For mid-epoch resume
history = {'loss': [], 'spearman': [], 'recall': [], 'lr': []}
log_dir = None  # Will be set below

# Store hyperparameters for checkpoint
hyperparams = {
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'dropout': DROPOUT,
    'huber_weight': HUBER_WEIGHT,
    'spearman_weight': SPEARMAN_WEIGHT,
    'class_weight': CLASS_WEIGHT,
    'use_cross_attention': USE_CROSS_ATTENTION,
    'warmup_epochs': WARMUP_EPOCHS,
    'epochs': EPOCHS,
    'seed': SEED,
}

# Function to find and restore from best available checkpoint
def find_best_checkpoint(latest_is_corrupt=False):
    """Find the most recent plausible backup checkpoint in ``OUTPUT_DIR``.

    Backups are ``*.pth`` files whose name contains ``step_<n>`` or
    ``epoch_<n>``; ``best_model.pth`` and ``checkpoint_latest.pth`` are never
    returned. A file is considered plausible when it is larger than 90% of the
    expected checkpoint size.

    Args:
        latest_is_corrupt: Set to True when ``checkpoint_latest.pth`` is already
            known to be unloadable. With the default (False) the function
            returns None as soon as the latest checkpoint has a plausible size,
            i.e. "no recovery needed". Size alone cannot detect a full-size but
            unreadable file, so callers that have seen a load failure must pass True.

    Returns:
        ``(path, filename, size_mb, step, kind)`` for the backup with the highest
        (approximate) step, or None when no recovery is needed or possible.
    """
    expected_size_mb = 15000  # Checkpoints should be ~15GB
    min_valid_mb = expected_size_mb * 0.9

    # Check if latest checkpoint is valid (should be ~15GB)
    latest_path = os.path.join(OUTPUT_DIR, 'checkpoint_latest.pth')
    if not latest_is_corrupt and os.path.exists(latest_path):
        latest_size = os.path.getsize(latest_path) / (1024 * 1024)
        if latest_size > min_valid_mb:
            return None  # Latest is valid, no recovery needed

    candidates = []
    for f in os.listdir(OUTPUT_DIR):
        if not f.endswith('.pth') or f in ('best_model.pth', 'checkpoint_latest.pth'):
            continue
        filepath = os.path.join(OUTPUT_DIR, f)
        size_mb = os.path.getsize(filepath) / (1024 * 1024)
        if size_mb <= min_valid_mb:
            continue  # Truncated / partially written file

        # Extract step/epoch number for sorting
        try:
            if 'step_' in f:
                num = int(f.split('step_')[1].split('.')[0])
                kind = 'step'
            elif 'epoch_' in f:
                # Convert epoch to approximate step for comparison
                num = int(f.split('epoch_')[1].split('.')[0]) * len(train_loader)
                kind = 'epoch'
            else:
                continue
        except ValueError:
            # Name contains the marker but no parsable number - not one of ours.
            continue
        candidates.append((filepath, f, size_mb, num, kind))

    if not candidates:
        return None  # No valid backups

    # Highest step wins (first one found on ties).
    return max(candidates, key=lambda item: item[3])

# Check for corrupted checkpoint and recover
if os.path.exists(checkpoint_path):
    try:
        # Try to load - this will fail if corrupted.
        # map_location='cpu' keeps this throw-away probe off the GPU.
        test_load = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
        del test_load  # Free memory
    except Exception as e:
        print(f"WARNING: checkpoint_latest.pth is corrupted!")
        print(f"Error: {e}")

        # Find best backup; the load failure above is authoritative, so the
        # size heuristic for the latest file must not short-circuit recovery.
        best_backup = find_best_checkpoint(latest_is_corrupt=True)
        if best_backup:
            backup_path, backup_name, backup_size, _, _ = best_backup
            print(f"Found valid backup: {backup_name} ({backup_size:.1f} MB)")

            # Delete corrupted and restore from backup
            os.remove(checkpoint_path)
            shutil.copy(backup_path, checkpoint_path)
            print(f"Restored checkpoint_latest.pth from {backup_name}")
        else:
            print("No valid backup found. Deleting corrupted checkpoint - will start fresh.")
            os.remove(checkpoint_path)
# Resume from checkpoint if exists
if os.path.exists(checkpoint_path):
    print("Found checkpoint, resuming training...")
    # Deserialize onto the CPU: load_state_dict copies values into the existing
    # (GPU) parameters and optimizer state, so a second full copy of the
    # checkpoint is not left sitting on the GPU for the rest of the session.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Check hyperparameter consistency
    saved_hyperparams = checkpoint.get('hyperparams', {})
    if saved_hyperparams:
        mismatches = []
        for key in ['batch_size', 'learning_rate', 'dropout']:
            if key in saved_hyperparams and saved_hyperparams[key] != hyperparams[key]:
                mismatches.append(f"{key}: saved={saved_hyperparams[key]}, current={hyperparams[key]}")
        if mismatches:
            print(f"WARNING: Hyperparameter mismatch detected!")
            for m in mismatches:
                print(f"  {m}")
            print("Training will continue with CURRENT hyperparameters.")

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    best_spearman = checkpoint.get('best_spearman', -1)
    history = checkpoint.get('history', history)

    # Determine resume point
    saved_epoch = checkpoint['epoch']
    saved_batch = checkpoint.get('batch_idx', 0)

    # Check if this was a mid-epoch save or end-of-epoch save
    if 0 < saved_batch < len(train_loader):
        # Mid-epoch checkpoint - resume from this epoch at the saved batch
        start_epoch = saved_epoch
        start_batch = saved_batch
        print(f"Resuming mid-epoch: epoch {start_epoch+1}, batch {start_batch}")
    else:
        # End-of-epoch checkpoint - start next epoch
        start_epoch = saved_epoch + 1
        start_batch = 0
        print(f"Resuming from epoch {start_epoch+1}")

    # Restore scheduler states
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    else:
        # Fallback for old checkpoints: replay the per-batch schedule up to the
        # saved step (or the start of the resume epoch if no step was stored).
        global_step = checkpoint.get('global_step', start_epoch * len(train_loader))
        for _ in range(global_step):
            scheduler.step()

    if 'plateau_scheduler_state_dict' in checkpoint:
        plateau_scheduler.load_state_dict(checkpoint['plateau_scheduler_state_dict'])

    # Restore early stopping state
    if 'early_stopping_best' in checkpoint:
        early_stopping.best_score = checkpoint['early_stopping_best']
        early_stopping.counter = checkpoint.get('early_stopping_counter', 0)

    # Restore RNG states for reproducibility
    if 'rng_state' in checkpoint:
        torch.set_rng_state(checkpoint['rng_state'])
    if 'cuda_rng_state' in checkpoint and torch.cuda.is_available():
        torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])
    if 'numpy_rng_state' in checkpoint:
        np.random.set_state(checkpoint['numpy_rng_state'])
    if 'python_rng_state' in checkpoint:
        random.setstate(checkpoint['python_rng_state'])
    if 'sampler_rng_state' in checkpoint:
        g.set_state(checkpoint['sampler_rng_state'])

    # Resume TensorBoard to same directory if saved
    if 'log_dir' in checkpoint:
        log_dir = checkpoint['log_dir']
        print(f"Resuming TensorBoard logs to: {log_dir}")

    print(f"Best Spearman so far: {best_spearman:.4f}")
else:
    print("Starting fresh training...")

# TensorBoard logging - create new dir only if not resuming
if log_dir is None:
    run_stamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    log_dir = os.path.join(OUTPUT_DIR, 'runs', run_stamp)
writer = SummaryWriter(log_dir=log_dir)

# Final summary of the run configuration before training starts.
for _summary_line in (
    f"TensorBoard logs: {log_dir}",
    f"Train batches: {len(train_loader)}",
    f"Val batches: {len(val_loader)}",
    f"Warmup steps: {num_warmup_steps}",
    f"Total steps: {num_training_steps}",
    f"Fused optimizer: {use_fused}",
    "Schedulers: Warmup + Cosine (per-batch) + ReduceLROnPlateau (per-epoch)",
    "Ready for training!",
):
    print(_summary_line)

Starting fresh training...
TensorBoard logs: /content/drive/MyDrive/AbAg_Training_02/training_output_OPTIMIZED_v2/runs/20251124-003952
Train batches: 3495
Val batches: 749
Warmup steps: 17475
Total steps: 174750
Fused optimizer: True
Schedulers: Warmup + Cosine (per-batch) + ReduceLROnPlateau (per-epoch)
Ready for training!


# Step 6: Training Loop

In [None]:
# Training loop with checkpoint resuming + TensorBoard logging
_rule = "="*60
print(_rule)
print("TRAINING")
print(_rule)

# Track previous checkpoint files for deletion
prev_step_checkpoint = prev_epoch_checkpoint = None

def permanent_delete(filepath):
    """Best-effort removal of ``filepath``; never raises on failure.

    This is a plain filesystem delete. It does not call any Google Drive API,
    so whether the Drive mount keeps a trash copy is outside its control.

    Args:
        filepath: Path to delete. ``None``, an empty string or a missing path
            is a no-op.
    """
    if not filepath or not os.path.exists(filepath):
        return
    try:
        os.remove(filepath)
    except OSError as exc:
        # Cleanup of an old checkpoint must not abort training; report and move on.
        print(f"WARNING: could not delete {filepath}: {exc}")
        return
    # Force flush to ensure deletion reaches the underlying filesystem.
    if hasattr(os, 'sync'):
        os.sync()

def save_checkpoint(path, epoch, global_step, batch_idx=None):
    """Write a resumable training checkpoint to ``path`` atomically.

    The checkpoint is first written to ``<path>.tmp`` and then moved over
    ``path`` with ``os.replace``. If the save is interrupted, the file that
    was previously at ``path`` is left untouched instead of being truncated.

    Reads the notebook-level training objects (``model``, ``optimizer``,
    ``scheduler``, ``plateau_scheduler``, ``early_stopping``, ``history``,
    ``best_spearman``, ``log_dir``, ``hyperparams`` and the generator ``g``)
    at call time.

    Args:
        path: Destination file for the checkpoint.
        epoch: Value stored under ``'epoch'``.
        global_step: Value stored under ``'global_step'``.
        batch_idx: Stored under ``'batch_idx'`` when not ``None``; omitted
            from the checkpoint otherwise.
    """
    checkpoint_data = {
        'epoch': epoch,
        'global_step': global_step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'plateau_scheduler_state_dict': plateau_scheduler.state_dict(),
        'best_spearman': best_spearman,
        'history': history,
        'early_stopping_best': early_stopping.best_score,
        'early_stopping_counter': early_stopping.counter,
        'log_dir': log_dir,
        'hyperparams': hyperparams,
        # RNG states for exact reproducibility
        'rng_state': torch.get_rng_state(),
        'numpy_rng_state': np.random.get_state(),
        'python_rng_state': random.getstate(),
        # presumably g is the generator driving the training sampler -- confirm
        'sampler_rng_state': g.get_state(),
    }
    if torch.cuda.is_available():
        checkpoint_data['cuda_rng_state'] = torch.cuda.get_rng_state()
    if batch_idx is not None:
        checkpoint_data['batch_idx'] = batch_idx

    tmp_path = f"{path}.tmp"
    try:
        torch.save(checkpoint_data, tmp_path)
        # Same-directory rename: readers see either the old file or the
        # complete new one, never a partially written checkpoint.
        os.replace(tmp_path, path)
    finally:
        # Only still present if the save or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

# Main training loop. Per epoch: train over train_loader, validate, log to
# TensorBoard, save the best model, write checkpoints, check early stopping.
for epoch in range(start_epoch, EPOCHS):
    start = time.time()

    # Train
    model.train()
    total_loss = 0
    num_batches_processed = 0

    # Determine where to start in this epoch: only the first epoch of this
    # run can begin part-way through.
    skip_batches = start_batch if epoch == start_epoch else 0

    if skip_batches > 0:
        print(f"Skipping first {skip_batches} batches (already processed)")

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for batch_idx, batch in enumerate(pbar):
        # Skip batches if resuming mid-epoch. The loader is still iterated,
        # so batch_idx keeps counting from the start of the epoch.
        if batch_idx < skip_batches:
            continue

        ab_seqs = batch['antibody_seqs']
        ag_seqs = batch['antigen_seqs']
        targets = batch['pKd'].to(device)

        pKd_pred, class_logits = model(ab_seqs, ag_seqs, device)
        loss = criterion(pKd_pred, targets, class_logits)

        optimizer.zero_grad()
        loss.backward()
        # Clip the total gradient norm to 1.0 before the optimizer update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()  # this scheduler is stepped once per batch

        # Convert the loss and read the learning rate once, then reuse them
        # for the running total, the progress bar and TensorBoard.
        loss_value = loss.item()
        batch_lr = scheduler.get_last_lr()[0]

        total_loss += loss_value
        num_batches_processed += 1
        global_step = epoch * len(train_loader) + batch_idx
        pbar.set_postfix({'loss': f'{loss_value:.4f}', 'lr': f'{batch_lr:.2e}'})

        # Log to TensorBoard every 100 batches
        if batch_idx % 100 == 0:
            writer.add_scalar('Train/BatchLoss', loss_value, global_step)
            writer.add_scalar('Train/LearningRate', batch_lr, global_step)

        # Save checkpoint every 1000 steps
        if (global_step + 1) % 1000 == 0:
            step_checkpoint_path = os.path.join(OUTPUT_DIR, f'checkpoint_step_{global_step+1}.pth')
            # batch_idx + 1 is the first batch that has NOT been processed yet.
            save_checkpoint(step_checkpoint_path, epoch, global_step + 1, batch_idx + 1)

            # Delete previous step checkpoint permanently
            if prev_step_checkpoint:
                permanent_delete(prev_step_checkpoint)
            prev_step_checkpoint = step_checkpoint_path

            # Also update the latest checkpoint
            save_checkpoint(checkpoint_path, epoch, global_step + 1, batch_idx + 1)

            print(f"\n  Saved step checkpoint at step {global_step+1}")

    # Validate
    metrics, val_preds, val_targets = evaluate(model, val_loader, device)
    elapsed = time.time() - start
    # Read before the plateau scheduler steps, so this is the LR used to train.
    current_lr = scheduler.get_last_lr()[0]

    # Step plateau scheduler (once per epoch, on validation Spearman)
    plateau_scheduler.step(metrics['spearman'])

    # Prediction distribution
    pred_mean = np.mean(val_preds)
    pred_std = np.std(val_preds)
    pred_min = np.min(val_preds)
    pred_max = np.max(val_preds)

    # Log epoch metrics to TensorBoard. The average covers only the batches
    # actually trained in this epoch (skipped batches are not counted).
    avg_loss = total_loss / num_batches_processed if num_batches_processed > 0 else 0
    writer.add_scalar('Train/EpochLoss', avg_loss, epoch)
    writer.add_scalar('Val/Spearman', metrics['spearman'], epoch)
    writer.add_scalar('Val/Recall', metrics['recall'], epoch)
    writer.add_scalar('Val/RMSE', metrics['rmse'], epoch)
    writer.add_scalar('Val/PredMean', pred_mean, epoch)
    writer.add_scalar('Val/PredStd', pred_std, epoch)
    writer.add_histogram('Val/Predictions', val_preds, epoch)
    writer.add_histogram('Val/Targets', val_targets, epoch)

    print(f"Loss: {avg_loss:.4f} | Spearman: {metrics['spearman']:.4f} | "
          f"Recall: {metrics['recall']:.1f}% | LR: {current_lr:.2e} | Time: {elapsed:.1f}s")
    print(f"  Pred dist: mean={pred_mean:.2f}, std={pred_std:.2f}, range=[{pred_min:.2f}, {pred_max:.2f}]")

    # Warning if predictions collapse
    if pred_std < 0.5:
        print(f"  WARNING: Low prediction variance! Model may be collapsing.")

    # Save best model
    if metrics['spearman'] > best_spearman:
        best_spearman = metrics['spearman']
        torch.save({
            'epoch': epoch,
            'global_step': (epoch + 1) * len(train_loader),
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'plateau_scheduler_state_dict': plateau_scheduler.state_dict(),
            'spearman': best_spearman,
            'hyperparams': hyperparams,
        }, model_path)
        print(f"  * Saved best! (Spearman: {best_spearman:.4f})")

    # Update history
    history['loss'].append(avg_loss)
    history['spearman'].append(metrics['spearman'])
    history['recall'].append(metrics['recall'])
    history['lr'].append(current_lr)

    # Update early stopping BEFORE the epoch checkpoints are written, so the
    # early_stopping.best_score / counter stored in them already include this
    # epoch. The actual break happens after the checkpoints are saved.
    should_stop = early_stopping(metrics['spearman'])

    # Save epoch checkpoint (batch_idx=0 signals end of epoch)
    epoch_checkpoint_path = os.path.join(OUTPUT_DIR, f'checkpoint_epoch_{epoch+1}.pth')
    save_checkpoint(epoch_checkpoint_path, epoch, (epoch + 1) * len(train_loader), batch_idx=0)

    # Delete previous epoch checkpoint permanently
    if prev_epoch_checkpoint:
        permanent_delete(prev_epoch_checkpoint)
    prev_epoch_checkpoint = epoch_checkpoint_path

    # Also update the latest checkpoint (for resuming)
    save_checkpoint(checkpoint_path, epoch, (epoch + 1) * len(train_loader), batch_idx=0)

    print(f"  Saved epoch {epoch+1} checkpoint")

    # Early stopping
    if should_stop:
        print(f"\nEarly stopping triggered after {epoch+1} epochs!")
        print(f"Best Spearman: {best_spearman:.4f}")
        break

# Close the TensorBoard writer and report where the logs were written.
writer.close()
print(f"\nTraining complete! Best Spearman: {best_spearman:.4f}")
print(f"TensorBoard logs saved to: {log_dir}")
# This must be an f-string: without the prefix the literal text "{log_dir}"
# was printed instead of the actual log directory.
print(f"To view: tensorboard --logdir {log_dir}")

TRAINING


Epoch 1/50:   0%|          | 0/3495 [00:00<?, ?it/s]



# Step 7: Evaluation

In [None]:
# Load the best model saved during training and evaluate it.
# map_location=device lets a checkpoint saved from GPU tensors load in a
# session whose device differs (e.g. CPU-only after a runtime restart).
# NOTE: weights_only=False unpickles arbitrary Python objects; only load
# checkpoint files produced by this notebook.
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']+1}")

# Validation
val_metrics, val_preds, val_targets = evaluate(model, val_loader, device)
print("\nVALIDATION:")
print(f"  Spearman: {val_metrics['spearman']:.4f}")
print(f"  RMSE: {val_metrics['rmse']:.4f}")
print(f"  Recall: {val_metrics['recall']:.1f}%")

# Test
test_metrics, test_preds, test_targets = evaluate(model, test_loader, device)
print("\nTEST (Final Performance):")
print(f"  Spearman: {test_metrics['spearman']:.4f}")
print(f"  RMSE: {test_metrics['rmse']:.4f}")
print(f"  Recall: {test_metrics['recall']:.1f}%")
print(f"  Precision: {test_metrics['precision']:.1f}%")

In [None]:
# Write the test-set predictions (CSV) and the metrics plus hyperparameters
# (JSON) into OUTPUT_DIR.
predictions_table = pd.DataFrame({
    'true': test_targets,
    'pred': test_preds,
    'error': test_preds - test_targets,
})
predictions_csv = os.path.join(OUTPUT_DIR, 'test_predictions.csv')
predictions_table.to_csv(predictions_csv, index=False)

run_hyperparameters = {
    'batch_size': BATCH_SIZE,
    'learning_rate': LEARNING_RATE,
    'dropout': DROPOUT,
    'huber_weight': HUBER_WEIGHT,
    'spearman_weight': SPEARMAN_WEIGHT,
    'class_weight': CLASS_WEIGHT,
    'use_cross_attention': USE_CROSS_ATTENTION,
    'warmup_epochs': WARMUP_EPOCHS,
    'epochs': EPOCHS,
}
run_summary = {
    'test': test_metrics,
    'val': val_metrics,
    'hyperparameters': run_hyperparameters,
}

metrics_json = os.path.join(OUTPUT_DIR, 'metrics.json')
with open(metrics_json, 'w') as f:
    # default=float converts values json cannot serialise natively.
    json.dump(run_summary, f, indent=2, default=float)

print(f"\nResults saved to {OUTPUT_DIR}")

In [None]:
# Two-panel summary figure: validation Spearman per epoch (left) and
# predicted vs. true pKd on the test set (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
curve_ax, scatter_ax = axes

# Left panel: training curve with the best score as a dashed reference line.
curve_ax.plot(history['spearman'], 'g-o')
curve_ax.axhline(best_spearman, color='r', linestyle='--', label=f'Best: {best_spearman:.4f}')
curve_ax.set(xlabel='Epoch', ylabel='Validation Spearman', title='Training Progress')

# Right panel: test predictions against the y = x line drawn from 4 to 14.
scatter_ax.scatter(test_targets, test_preds, alpha=0.3, s=10)
scatter_ax.plot([4, 14], [4, 14], 'r--', label='Perfect')
scatter_ax.set(
    xlabel='True pKd',
    ylabel='Predicted pKd',
    title=f'Test Set (Spearman: {test_metrics["spearman"]:.4f})',
)

# Legend and light grid on both panels.
for panel in axes:
    panel.legend()
    panel.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'results.png'), dpi=300)
plt.show()

print("Done!")