# Antibody-Antigen Binding Affinity Prediction - Training

**Ultra-Speed Training v2.6** - IgT5 + ESM-2 Dual Encoder Model

## Instructions:
1. Upload your CSV file (`ab_ag_affinity_complete.csv`) to the Google Drive folder `MyDrive/AbAg_Training`. It must contain the columns `antibody_sequence`, `antigen_sequence`, and `pKd`.
2. Run all cells in order
3. Training auto-saves checkpoints to `outputs/checkpoint_latest.pth`. If the runtime disconnects, re-run all cells and training resumes from the latest checkpoint.

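**Expected time:** ~21-22 hours for 50 epochs on a T4 GPU. Colab sessions are usually shorter than this, so expect to resume from a checkpoint at least once.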

## 1. Mount Google Drive & Setup

In [None]:
# Mount Google Drive so the dataset and all training outputs persist across
# Colab sessions (training is resumed from checkpoints stored there).
from google.colab import drive
drive.mount('/content/drive')

# Create the working directory on Drive and make it the current directory.
# The relative paths used later (CONFIG['data_path'], CONFIG['output_dir'])
# resolve against this directory.
import os
work_dir = '/content/drive/MyDrive/AbAg_Training'
os.makedirs(work_dir, exist_ok=True)
%cd {work_dir}

# Check that the dataset is present. This only prints a message; if the file is
# missing, the training cell fails later when pd.read_csv tries to open it.
dataset_file = 'ab_ag_affinity_complete.csv'
if os.path.exists(dataset_file):
    print(f"Dataset found: {dataset_file}")
    # Line count of the CSV (includes the header row)
    !wc -l {dataset_file}
else:
    print(f"ERROR: Please upload '{dataset_file}' to Google Drive folder 'AbAg_Training'")

## 2. Install Dependencies

In [None]:
# Install third-party packages used by the training cell:
# - faesm: optional FlashAttention ESM-2; the training cell falls back to the
#   standard Hugging Face ESM-2 when it cannot be imported.
# - bitsandbytes (and presumably accelerate, for device_map="auto"): back the
#   INT8-quantized IgT5 load.
# - sentencepiece: needed by the T5 tokenizer used for IgT5.
!pip install -q transformers pandas scipy scikit-learn tqdm sentencepiece faesm bitsandbytes accelerate
print("Dependencies installed!")

## 3. Configuration

Adjust these settings if needed:

In [None]:
# === CONFIGURATION ===
# Edit the values below; the training cells read these settings from CONFIG.
CONFIG = dict(
    data_path='ab_ag_affinity_complete.csv',  # Dataset CSV, relative to the Drive working dir
    output_dir='outputs',                     # Checkpoints, best model, predictions, metrics
    epochs=50,                                # Number of training epochs
    batch_size=16,                            # Per-step batch size (16 works well on T4)
    accumulation_steps=3,                     # Gradient accumulation (effective batch = 48)
    learning_rate=3e-3,                       # AdamW learning rate
    weight_decay=0.02,                        # AdamW weight decay
    dropout=0.35,                             # Dropout in the regression head
    early_stopping_patience=10,               # Stop after N validation rounds without improvement
    validation_frequency=1,                   # Validate every N epochs
)

print("Configuration:")
for name, value in CONFIG.items():
    print(f"  {name}: {value}")

## 4. Training Code

This cell defines the model, dataset, loss, and evaluation helpers used for training. Run it as-is; no modifications are needed.

In [None]:
"""
ULTRA SPEED Training v2.6 - All Advanced Optimizations
IgT5 + ESM-2 Dual Encoder for Antibody-Antigen Binding Prediction
"""

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.checkpoint import checkpoint
import pandas as pd
import numpy as np
from scipy import stats
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from tqdm import tqdm
from pathlib import Path
import time
import shutil
import gc
import random
from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer, AutoModel, BitsAndBytesConfig
import threading
import subprocess
import sys
import os
import csv
import json
from datetime import datetime

# ============================================================================
# NUCLEAR FIX: Force disable torch.compile globally
# ============================================================================
import torch._dynamo
import torch.compiler

torch._dynamo.config.suppress_errors = True
torch.compiler.disable()
os.environ['TORCH_COMPILE_DISABLE'] = '1'
os.environ['TORCH_CUDAGRAPH_DISABLE'] = '1'

print("torch.compile DISABLED (prevents CUDA graphs errors)")

# Try to import FAESM for FlashAttention
try:
    from faesm.esm import FAEsmForMaskedLM
    FLASH_ATTN_AVAILABLE = True
    print("FlashAttention (FAESM) available")
except ImportError:
    FLASH_ATTN_AVAILABLE = False
    print("Using standard ESM-2 (no FlashAttention)")

# Enable backend optimizations
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# ============================================================================
# DISK CLEANUP
# ============================================================================
def cleanup_disk_space():
    """Best-effort housekeeping between epochs.

    Purges pip's download cache (using the pip of the interpreter running this
    notebook), releases cached CUDA memory and runs the garbage collector.
    Each step is attempted independently; failures are printed, never raised,
    so cleanup cannot interrupt training.
    """
    try:
        subprocess.run(
            [sys.executable, '-m', 'pip', 'cache', 'purge'],
            capture_output=True, timeout=120, check=False,
        )
    except (OSError, subprocess.SubprocessError) as err:
        # pip unavailable, interpreter not executable, or the purge hung past the timeout
        print(f"  pip cache purge skipped: {err}")

    try:
        torch.cuda.empty_cache()
    except RuntimeError as err:
        print(f"  torch.cuda.empty_cache failed: {err}")

    gc.collect()

# ============================================================================
# SEQUENCE BUCKETING
# ============================================================================
class BucketBatchSampler(Sampler):
    """Batch sampler that groups samples of similar antibody-sequence length.

    Each sample goes into the smallest bucket whose bound is >= the length of its
    'antibody_sequence'. Sequences longer than every bound go into the largest
    bucket. Every batch is drawn from a single bucket to reduce padding. Only the
    antibody length is used; antigen length is ignored.

    Args:
        dataset: Indexable dataset whose items are dicts with an
            'antibody_sequence' string. If it exposes a pandas ``df`` with that
            column (as AbAgDataset does), lengths are read from the column
            directly instead of materializing every item.
        batch_size: Number of indices per batch.
        drop_last: Drop each bucket's final incomplete batch.
        buckets: Upper length bounds of the buckets (any iterable of ints).
    """

    def __init__(self, dataset, batch_size, drop_last=True, buckets=(256, 384, 512)):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.buckets = sorted(buckets)
        self.bucket_indices = {b: [] for b in self.buckets}

        for idx, seq_len in enumerate(self._antibody_lengths(dataset)):
            # Buckets are sorted ascending, so the first fitting bound is the smallest.
            bucket = next((b for b in self.buckets if b >= seq_len), self.buckets[-1])
            self.bucket_indices[bucket].append(idx)

        total = len(dataset)
        print(f"\nBucket Distribution:")
        for bucket in self.buckets:
            count = len(self.bucket_indices[bucket])
            share = count / total * 100 if total else 0.0  # guard empty dataset
            print(f"  <={bucket}: {count:,} samples ({share:.1f}%)")

    @staticmethod
    def _antibody_lengths(dataset):
        """Return the antibody-sequence length of every item, in index order."""
        df = getattr(dataset, 'df', None)
        if df is not None and 'antibody_sequence' in getattr(df, 'columns', ()):
            # Assumes dataset[i] is row i of df (true for AbAgDataset, which uses iloc).
            return [len(seq) for seq in df['antibody_sequence']]
        return [len(dataset[i]['antibody_sequence']) for i in range(len(dataset))]

    def __iter__(self):
        """Yield lists of indices; each list comes from one bucket.

        Indices are shuffled within each bucket, and the resulting batches are
        shuffled across buckets. Consecutive optimizer steps therefore do not all
        see the same length range, as they would if buckets were emitted one after
        another.
        """
        batches = []
        for bucket in self.buckets:
            indices = self.bucket_indices[bucket].copy()
            random.shuffle(indices)
            for start in range(0, len(indices), self.batch_size):
                batch = indices[start:start + self.batch_size]
                if len(batch) == self.batch_size or not self.drop_last:
                    batches.append(batch)
        random.shuffle(batches)
        yield from batches

    def __len__(self):
        """Number of batches yielded per epoch."""
        count = 0
        for bucket in self.buckets:
            n = len(self.bucket_indices[bucket])
            count += n // self.batch_size
            if not self.drop_last and n % self.batch_size > 0:
                count += 1
        return count

# ============================================================================
# MODEL
# ============================================================================
class IgT5ESM2Model(nn.Module):
    """Dual-encoder regressor: IgT5 embeds the antibody, ESM-2 embeds the antigen.

    Both encoders are frozen and always run in eval mode. Only the MLP head
    (regressor_block1-4 + regressor_final) is trainable. It maps the concatenated
    embeddings to one scalar per pair, which the notebook trains against pKd.

    Args:
        dropout: Dropout probability in each regressor block.
        use_quantization: Try to load IgT5 with 8-bit bitsandbytes weights; fall
            back to a regular full-precision load if that fails.
        use_checkpointing: Activation-checkpoint the regressor blocks in training mode.
    """

    IGT5_NAME = "Exscientia/IgT5"
    ESM2_NAME = "facebook/esm2_t33_650M_UR50D"

    def __init__(self, dropout=0.3, use_quantization=True, use_checkpointing=True):
        super().__init__()
        self.use_checkpointing = use_checkpointing

        print("Loading models...")

        # Antibody encoder
        print("  Loading IgT5 for antibody...")
        self.igt5_tokenizer = T5Tokenizer.from_pretrained(self.IGT5_NAME, do_lower_case=False)
        self.igt5_model = self._load_igt5(use_quantization)

        # Antigen encoder
        print("  Loading ESM-2 for antigen...")
        self.esm2_tokenizer = AutoTokenizer.from_pretrained(self.ESM2_NAME, use_fast=True)
        if FLASH_ATTN_AVAILABLE:
            # NOTE(review): outputs are read as outputs['last_hidden_state'] in
            # get_batch_embeddings; confirm FAESM's FAEsmForMaskedLM returns that key.
            self.esm2_model = FAEsmForMaskedLM.from_pretrained(self.ESM2_NAME)
        else:
            self.esm2_model = AutoModel.from_pretrained(self.ESM2_NAME)

        # Freeze both encoders: only the regression head is trained.
        self.igt5_model.requires_grad_(False)
        self.esm2_model.requires_grad_(False)

        # Regression head on the concatenated embeddings
        igt5_dim = self.igt5_model.config.d_model
        esm2_dim = self.esm2_model.config.hidden_size
        combined_dim = igt5_dim + esm2_dim

        self.regressor_block1 = nn.Sequential(
            nn.Linear(combined_dim, 1024), nn.GELU(), nn.Dropout(dropout), nn.LayerNorm(1024)
        )
        self.regressor_block2 = nn.Sequential(
            nn.Linear(1024, 512), nn.GELU(), nn.Dropout(dropout), nn.LayerNorm(512)
        )
        self.regressor_block3 = nn.Sequential(
            nn.Linear(512, 256), nn.GELU(), nn.Dropout(dropout), nn.LayerNorm(256)
        )
        self.regressor_block4 = nn.Sequential(
            nn.Linear(256, 128), nn.GELU(), nn.Dropout(dropout)
        )
        self.regressor_final = nn.Linear(128, 1)

        print("  Model ready!")

    def _load_igt5(self, use_quantization):
        """Load the IgT5 encoder, 8-bit quantized when requested and possible.

        Either building the quantization config or the quantized load itself can
        fail (e.g. bitsandbytes or a GPU is missing). In that case the encoder is
        loaded with full-precision weights instead.
        """
        if use_quantization:
            try:
                quantization_config = BitsAndBytesConfig(
                    load_in_8bit=True,
                    llm_int8_threshold=6.0,
                    llm_int8_has_fp16_weight=False
                )
                model = T5EncoderModel.from_pretrained(
                    self.IGT5_NAME, quantization_config=quantization_config, device_map="auto"
                )
                print("  Using INT8 quantization")
                return model
            except (ImportError, ValueError, RuntimeError) as err:
                print(f"  Quantization not available ({err}), loading full-precision weights")
        return T5EncoderModel.from_pretrained(self.IGT5_NAME)

    def train(self, mode=True):
        """Set the head's train/eval mode while keeping both encoders in eval mode.

        The encoders are frozen, so any dropout inside them (T5 configs usually set
        a non-zero dropout_rate) would only inject noise into training-time
        embeddings that is absent at evaluation time.
        """
        super().train(mode)
        self.igt5_model.eval()
        self.esm2_model.eval()
        return self

    @staticmethod
    def _space_residues(sequence):
        """Return ``sequence`` with its residues separated by single spaces.

        IgT5's tokenizer works on space-separated residues (its model card feeds
        inputs like "E V Q L ..."). Strings that already contain whitespace are
        returned unchanged. NOTE(review): assumes the CSV stores raw, unspaced
        residue strings; confirm how heavy/light chains are joined in the data.
        """
        if any(ch.isspace() for ch in sequence):
            return sequence
        return ' '.join(sequence)

    def get_batch_embeddings(self, sequences, model, tokenizer, device, pooling='mean',
                             space_residues=False):
        """Embed a batch of sequences with a frozen encoder (no gradients).

        Args:
            sequences: List of sequence strings.
            model: Encoder whose output has a 'last_hidden_state' entry.
            tokenizer: Matching tokenizer; padded, truncated to 512 tokens.
            device: Device the token tensors are moved to.
            pooling: 'mean' averages hidden states over real (non-padding) tokens
                using the attention mask. Any other value takes the first token's
                hidden state.
            space_residues: Space-separate residues before tokenizing (needed for IgT5).

        Returns:
            Tensor of shape (batch, hidden_dim).
        """
        if space_residues:
            sequences = [self._space_residues(seq) for seq in sequences]
        inputs = tokenizer(
            sequences, return_tensors='pt', padding=True, truncation=True, max_length=512
        ).to(device, non_blocking=True)

        with torch.no_grad():
            outputs = model(**inputs)
            # Key access works for Hugging Face ModelOutput objects and plain dicts.
            hidden = outputs['last_hidden_state']
            if pooling == 'mean':
                attention_mask = inputs.get('attention_mask')
                if attention_mask is None:
                    embeddings = hidden.mean(dim=1)
                else:
                    # Exclude padding so an embedding does not depend on which
                    # other (longer) sequences share its batch.
                    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
                    embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
            else:
                embeddings = hidden[:, 0, :]
        return embeddings

    def forward(self, antibody_seqs, antigen_seqs, device):
        """Predict one scalar per (antibody, antigen) pair; returns shape (batch,)."""
        ab_embeddings = self.get_batch_embeddings(
            antibody_seqs, self.igt5_model, self.igt5_tokenizer, device,
            pooling='mean', space_residues=True,
        )
        # First-token pooling for ESM-2 (presumably its <cls> token).
        ag_embeddings = self.get_batch_embeddings(
            antigen_seqs, self.esm2_model, self.esm2_tokenizer, device, pooling='cls'
        )

        combined = torch.cat([ab_embeddings, ag_embeddings], dim=1)

        if self.use_checkpointing and self.training:
            x = checkpoint(self.regressor_block1, combined, use_reentrant=False)
            x = checkpoint(self.regressor_block2, x, use_reentrant=False)
            x = checkpoint(self.regressor_block3, x, use_reentrant=False)
            x = checkpoint(self.regressor_block4, x, use_reentrant=False)
        else:
            x = self.regressor_block1(combined)
            x = self.regressor_block2(x)
            x = self.regressor_block3(x)
            x = self.regressor_block4(x)

        return self.regressor_final(x).squeeze(-1)

# ============================================================================
# LOSS & UTILITIES
# ============================================================================
class FocalMSELoss(nn.Module):
    """Squared error re-weighted so that large errors count more.

    With ``e = pred - target`` the loss is ``mean((1 + e**2) ** gamma * e**2)``.
    ``gamma=0`` reduces to plain MSE. When ``label_smoothing`` is positive, each
    target is first pulled toward the batch-mean target by that fraction.
    """

    def __init__(self, gamma=2.0, label_smoothing=0.0):
        super().__init__()
        self.gamma = gamma
        self.label_smoothing = label_smoothing

    def forward(self, pred, target):
        smoothing = self.label_smoothing
        if smoothing > 0:
            batch_mean = target.mean()
            target = (1 - smoothing) * target + smoothing * batch_mean
        squared_error = (pred - target).pow(2)
        weight = (1 + squared_error).pow(self.gamma)
        return torch.mean(weight * squared_error)

class AbAgDataset(Dataset):
    """Map-style dataset over a DataFrame of antibody/antigen pairs.

    The frame is re-indexed 0..n-1. Each item is a dict holding the two sequence
    strings and the 'pKd' label as a float32 scalar tensor.
    """

    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        return {
            'antibody_sequence': row['antibody_sequence'],
            'antigen_sequence': row['antigen_sequence'],
            'pKd': torch.tensor(row['pKd'], dtype=torch.float32),
        }

def collate_fn(batch):
    """Collate dataset items into per-field batches.

    Sequences stay as Python lists of strings (tokenization happens inside the
    model); the pKd scalars are stacked into a 1-D tensor.
    """
    antibody_seqs, antigen_seqs, affinities = [], [], []
    for sample in batch:
        antibody_seqs.append(sample['antibody_sequence'])
    for sample in batch:
        antigen_seqs.append(sample['antigen_sequence'])
    for sample in batch:
        affinities.append(sample['pKd'])
    return {
        'antibody_seqs': antibody_seqs,
        'antigen_seqs': antigen_seqs,
        'pKd': torch.stack(affinities),
    }

class EarlyStopping:
    """Signal a stop once a higher-is-better score stops improving.

    A call counts as an improvement when the score exceeds the best so far by
    more than ``min_delta``. After ``patience`` consecutive calls without
    improvement, the call returns True. NaN scores (e.g. a Spearman correlation of
    constant predictions) never count as improvements and are never stored as the
    best score, because a NaN best would make every later comparison False.
    """

    def __init__(self, patience=10, min_delta=0.0001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0        # consecutive calls without improvement
        self.best_score = None  # best valid score so far (None until the first one)
        self.best_epoch = 0     # epoch at which best_score was recorded

    def __call__(self, score, epoch):
        """Record ``score`` for ``epoch``; return True when training should stop."""
        score_is_valid = score is not None and not np.isnan(score)
        no_valid_best = self.best_score is None or np.isnan(self.best_score)
        if score_is_valid and (no_valid_best or score > self.best_score + self.min_delta):
            self.best_score = score
            self.best_epoch = epoch
            self.counter = 0
            return False

        self.counter += 1
        best = 'n/a' if self.best_score is None else f"{self.best_score:.4f}"
        print(f"  EarlyStopping: {self.counter}/{self.patience} (best: {best} @ epoch {self.best_epoch+1})")
        return self.counter >= self.patience

def compute_metrics(targets, predictions):
    """Regression and ranking metrics for predicted vs. true pKd.

    Args:
        targets: 1-D array-like of true pKd values.
        predictions: 1-D array-like of predicted pKd values, same length.

    Returns:
        Dict of plain Python floats, so it is JSON-serializable even when the
        inputs are float32: rmse, mae, r2, spearman, pearson, and recall_pkd9.
        recall_pkd9 is the percentage of samples with true pKd >= 9 that are also
        predicted >= 9, or 0.0 when there are no such samples.
    """
    targets = np.asarray(targets, dtype=np.float64)
    predictions = np.asarray(predictions, dtype=np.float64)

    mse = mean_squared_error(targets, predictions)
    spearman = stats.spearmanr(targets, predictions)[0]
    pearson = stats.pearsonr(targets, predictions)[0]

    # High-affinity recall (pKd >= 9)
    strong = targets >= 9.0
    n_strong = int(strong.sum())
    recall = (strong & (predictions >= 9.0)).sum() / n_strong if n_strong else 0.0

    return {
        'rmse': float(np.sqrt(mse)),
        'mae': float(mean_absolute_error(targets, predictions)),
        'r2': float(r2_score(targets, predictions)),
        'spearman': float(spearman),
        'pearson': float(pearson),
        'recall_pkd9': float(recall * 100),
    }

def quick_eval(model, loader, device, max_batches=50):
    """Cheap validation pass used between training epochs.

    Runs the model under no_grad and bfloat16 autocast. The model is left in
    eval mode.

    Args:
        model: Called as ``model(antibody_seqs, antigen_seqs, device)``.
        loader: DataLoader yielding ``collate_fn`` batches.
        device: Passed through to the model.
        max_batches: Evaluate at most this many batches (default 50);
            ``None`` evaluates the whole loader.

    Returns:
        The metrics dict from ``compute_metrics``.
    """
    model.eval()
    predictions, targets = [], []

    with torch.no_grad():
        for i, batch in enumerate(loader):
            if max_batches is not None and i >= max_batches:
                break
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                preds = model(batch['antibody_seqs'], batch['antigen_seqs'], device)
            predictions.extend(preds.float().cpu().numpy())
            targets.extend(batch['pKd'].numpy())

    return compute_metrics(np.array(targets), np.array(predictions))

def full_eval(model, loader, device, desc="Eval"):
    """Evaluate ``model`` on every batch of ``loader`` with a progress bar.

    Runs under no_grad and bfloat16 autocast and leaves the model in eval mode.

    Returns:
        (metrics, predictions, targets): the ``compute_metrics`` dict plus 1-D
        numpy arrays of predictions and true pKd, in loader order.
    """
    model.eval()
    pred_chunks, target_chunks = [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc=desc):
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                batch_preds = model(batch['antibody_seqs'], batch['antigen_seqs'], device)
            pred_chunks.append(batch_preds.float().cpu().numpy())
            target_chunks.append(batch['pKd'].numpy())

    all_preds = np.concatenate(pred_chunks) if pred_chunks else np.array([])
    all_targets = np.concatenate(target_chunks) if target_chunks else np.array([])
    return compute_metrics(all_targets, all_preds), all_preds, all_targets

print("Training code loaded!")

## 5. Run Training

In [None]:
# ============================================================================
# MAIN TRAINING LOOP
# ============================================================================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n{'='*70}")
print("ANTIBODY-ANTIGEN BINDING PREDICTION - TRAINING")
print(f"{'='*70}")
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"PyTorch: {torch.__version__}")

# Load data
print(f"\nLoading data from: {CONFIG['data_path']}")
df = pd.read_csv(CONFIG['data_path'])
print(f"Loaded {len(df):,} samples")

# Every row needs both sequences and a label. A single NaN pKd would make the
# training loss NaN, and a NaN sequence is not a string the tokenizers can encode.
required_columns = ['antibody_sequence', 'antigen_sequence', 'pKd']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
    raise ValueError(f"Dataset is missing required column(s): {missing_columns}")
n_loaded = len(df)
df = df.dropna(subset=required_columns)
if len(df) < n_loaded:
    print(f"Dropped {n_loaded - len(df):,} rows with missing sequence or pKd values")

# Split data: 70% train, 15% validation, 15% test (fixed seed), plus a 5%
# subsample of validation used for the fast per-epoch check.
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)
val_df_quick = val_df.sample(frac=0.05, random_state=42)

print(f"\nDataset splits:")
print(f"  Train: {len(train_df):,} (70%)")
print(f"  Val:   {len(val_df):,} (15%)")
print(f"  Test:  {len(test_df):,} (15%)")

# Create datasets
train_dataset = AbAgDataset(train_df)
val_dataset_quick = AbAgDataset(val_df_quick)
val_dataset_full = AbAgDataset(val_df)
test_dataset = AbAgDataset(test_df)

# Create data loaders with bucketing
print("\nCreating data loaders...")
# Workers only do pandas row lookups (tokenization happens inside the model).
# Requesting more workers than CPUs just triggers a DataLoader warning, and
# Colab VMs typically have only 2 vCPUs.
num_workers = min(4, os.cpu_count() or 1)
common_loader_kwargs = {
    'num_workers': num_workers,
    'pin_memory': torch.cuda.is_available(),  # pinning only helps host->GPU copies
    'collate_fn': collate_fn,
}
eval_batch_size = CONFIG['batch_size'] * 2

train_sampler = BucketBatchSampler(train_dataset, CONFIG['batch_size'], drop_last=True)
train_loader = DataLoader(
    train_dataset, batch_sampler=train_sampler,
    prefetch_factor=4, persistent_workers=True, **common_loader_kwargs
)
val_loader_quick = DataLoader(
    val_dataset_quick, batch_size=eval_batch_size, shuffle=False, **common_loader_kwargs
)
val_loader_full = DataLoader(
    val_dataset_full, batch_size=eval_batch_size, shuffle=False, **common_loader_kwargs
)
test_loader = DataLoader(
    test_dataset, batch_size=eval_batch_size, shuffle=False, **common_loader_kwargs
)

import warnings

# Initialize model
print("\nInitializing model...")
model = IgT5ESM2Model(
    dropout=CONFIG['dropout'], use_quantization=True, use_checkpointing=True
).to(device)

# Loss and optimizer
criterion = FocalMSELoss(gamma=2.0, label_smoothing=0.05)
# Only the regression head is trainable (the encoders are frozen inside the
# model), so the optimizer gets just those parameters. Fused AdamW requires
# floating-point tensors on a supported device, and the frozen (possibly INT8)
# encoder weights could make it reject the full parameter list.
trainable_params = [p for p in model.parameters() if p.requires_grad]
adamw_kwargs = {'lr': CONFIG['learning_rate'], 'weight_decay': CONFIG['weight_decay']}
try:
    optimizer = torch.optim.AdamW(trainable_params, fused=True, **adamw_kwargs)
    print("Using fused optimizer")
except (RuntimeError, ValueError, TypeError):
    # e.g. CPU-only runtime or a PyTorch build without fused AdamW
    optimizer = torch.optim.AdamW(trainable_params, **adamw_kwargs)

# LR scheduler (stepped once per completed epoch by the training loop)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['epochs'])

# Setup
output_dir = Path(CONFIG['output_dir'])
output_dir.mkdir(exist_ok=True)
early_stopping = EarlyStopping(patience=CONFIG['early_stopping_patience'])

start_epoch = 0
best_spearman = -1

# Try to resume from checkpoint
checkpoint_path = output_dir / 'checkpoint_latest.pth'
if checkpoint_path.exists():
    print(f"\nResuming from checkpoint...")
    try:
        # weights_only=False: this is our own file, and checkpoints may hold numpy
        # scalars that the weights_only=True default of PyTorch >= 2.6 refuses to load.
        ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(ckpt['model_state_dict'], strict=False)
        try:
            optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        except (KeyError, ValueError) as e:
            # e.g. a checkpoint whose optimizer covered every parameter, not just the head
            print(f"Optimizer state not restored ({e}); using a fresh optimizer")

        # 'batch_idx' marks a mid-epoch save: that epoch did not finish, so rerun it.
        start_epoch = ckpt['epoch'] + (0 if 'batch_idx' in ckpt else 1)
        best_spearman = float(ckpt.get('best_val_spearman', -1))

        if 'scheduler_state_dict' in ckpt:
            scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        else:
            # No saved scheduler state: replay one step per completed epoch.
            # CosineAnnealingLR updates recursively from the current LR, so start
            # from each group's initial LR rather than the restored one.
            for group in optimizer.param_groups:
                group['lr'] = group.get('initial_lr', CONFIG['learning_rate'])
            with warnings.catch_warnings():
                warnings.simplefilter('ignore')  # "lr_scheduler.step() before optimizer.step()"
                for _ in range(start_epoch):
                    scheduler.step()

        es_state = ckpt.get('early_stopping')
        if isinstance(es_state, dict):
            early_stopping.counter = es_state.get('counter', 0)
            early_stopping.best_score = es_state.get('best_score')
            early_stopping.best_epoch = es_state.get('best_epoch', 0)
        elif best_spearman > -1:
            early_stopping.best_score = best_spearman

        print(f"Resumed from epoch {start_epoch}, best Spearman: {best_spearman:.4f}")
    except Exception as e:
        print(f"Could not load checkpoint: {e}")

def _atomic_save(payload, path):
    """torch.save ``payload`` to ``path`` via a temporary file and an atomic rename.

    If the runtime disconnects during the write, the previous file at ``path``
    stays intact instead of being replaced by a truncated checkpoint.
    """
    tmp_path = path.with_name(path.name + '.tmp')
    torch.save(payload, tmp_path)
    os.replace(tmp_path, path)


def _resume_state(epoch, **extra):
    """Build the full resumable training state for ``checkpoint_latest.pth``.

    Reads the notebook-level model, optimizer, scheduler, best_spearman and
    early_stopping at call time. ``extra`` is merged in (e.g. ``batch_idx`` for
    mid-epoch saves).
    """
    es_best = early_stopping.best_score
    return {
        'epoch': epoch,
        **extra,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        # Plain Python floats: numpy scalars are rejected by torch.load(weights_only=True)
        'best_val_spearman': float(best_spearman),
        'early_stopping': {
            'counter': early_stopping.counter,
            'best_score': None if es_best is None else float(es_best),
            'best_epoch': early_stopping.best_epoch,
        },
    }


print(f"\n{'='*70}")
print(f"Starting training for {CONFIG['epochs']} epochs...")
print(f"{'='*70}\n")

accumulation_steps = CONFIG['accumulation_steps']

# Training loop
for epoch in range(start_epoch, CONFIG['epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['epochs']}")
    print("-" * 50)

    # Cleanup
    if epoch > 0:
        cleanup_disk_space()

    # Train
    model.train()
    optimizer.zero_grad(set_to_none=True)
    total_loss = 0
    num_batches = len(train_loader)
    pbar = tqdm(enumerate(train_loader), total=num_batches, desc="Training")

    for batch_idx, batch in pbar:
        targets = batch['pKd'].to(device, non_blocking=True)

        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            predictions = model(batch['antibody_seqs'], batch['antigen_seqs'], device)
            loss = criterion(predictions, targets) / accumulation_steps

        loss.backward()

        # Step every `accumulation_steps` batches, and also on the last batch so a
        # trailing partial group of gradients is applied now instead of leaking
        # into the next epoch. (That last group is still scaled by
        # 1/accumulation_steps.)
        is_last_batch = batch_idx + 1 == num_batches
        if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

        batch_loss = loss.item() * accumulation_steps
        total_loss += batch_loss
        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

        # Save checkpoint every 500 batches ('batch_idx' marks it as mid-epoch)
        if (batch_idx + 1) % 500 == 0:
            _atomic_save(_resume_state(epoch, batch_idx=batch_idx),
                         output_dir / 'checkpoint_latest.pth')

    train_loss = total_loss / max(num_batches, 1)

    # Validation
    if (epoch + 1) % CONFIG['validation_frequency'] == 0:
        print("\nValidating...")
        val_metrics = quick_eval(model, val_loader_quick, device)
        val_spearman = float(val_metrics['spearman'])
        print(f"Val Spearman: {val_spearman:.4f} | Recall@pKd>=9: {val_metrics['recall_pkd9']:.1f}%")

        if val_spearman > best_spearman:
            best_spearman = val_spearman
            _atomic_save({
                'epoch': epoch, 'model_state_dict': model.state_dict(),
                'best_val_spearman': best_spearman
            }, output_dir / 'best_model.pth')
            print(f"Saved best model (Spearman: {best_spearman:.4f})")

        # Early stopping
        if early_stopping(val_spearman, epoch):
            print(f"\nEarly stopping triggered at epoch {epoch+1}")
            break

    scheduler.step()
    print(f"Train Loss: {train_loss:.4f} | LR: {optimizer.param_groups[0]['lr']:.6f}")

    # Save end-of-epoch checkpoint (no 'batch_idx': the epoch is complete)
    _atomic_save(_resume_state(epoch), output_dir / 'checkpoint_latest.pth')

print(f"\n{'='*70}")
print("TRAINING COMPLETE!")
print(f"Best Validation Spearman: {best_spearman:.4f}")
print(f"{'='*70}")

## 6. Final Evaluation

In [None]:
# Load best model and evaluate
print("\nLoading best model for final evaluation...")
best_model_path = output_dir / 'best_model.pth'
if best_model_path.exists():
    # weights_only=False: this is our own file, and it may store best_val_spearman
    # as a numpy scalar, which PyTorch >= 2.6's weights_only=True default rejects.
    ckpt = torch.load(best_model_path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt['model_state_dict'])
    best_epoch = ckpt.get('epoch')  # may be absent; don't add 1 to a placeholder string
    print(f"Loaded best model from epoch {best_epoch + 1 if best_epoch is not None else 'unknown'}")
else:
    print(f"No {best_model_path.name} found - evaluating the current in-memory model")


def _print_split_results(split_name, metrics, n_samples, spearman_suffix=''):
    """Print one split's metrics in the notebook's report format."""
    print(f"\n{split_name} Results ({n_samples:,} samples):")
    print(f"  Spearman: {metrics['spearman']:.4f}{spearman_suffix}")
    print(f"  Pearson:  {metrics['pearson']:.4f}")
    print(f"  RMSE:     {metrics['rmse']:.4f}")
    print(f"  MAE:      {metrics['mae']:.4f}")
    print(f"  R2:       {metrics['r2']:.4f}")
    print(f"  Recall@pKd>=9: {metrics['recall_pkd9']:.1f}%")


# Full validation evaluation
print(f"\n{'='*70}")
print("FULL VALIDATION SET EVALUATION")
print(f"{'='*70}")
val_metrics, val_preds, val_targets = full_eval(model, val_loader_full, device, "Validation")
_print_split_results("Validation", val_metrics, len(val_targets))

# Test set evaluation
print(f"\n{'='*70}")
print("TEST SET EVALUATION (UNSEEN DATA)")
print(f"{'='*70}")
test_metrics, test_preds, test_targets = full_eval(model, test_loader, device, "Test")
_print_split_results("Test", test_metrics, len(test_targets), spearman_suffix="  <- TRUE PERFORMANCE")

# Save predictions
pd.DataFrame({'true_pKd': val_targets, 'pred_pKd': val_preds}).to_csv(output_dir / 'val_predictions.csv', index=False)
pd.DataFrame({'true_pKd': test_targets, 'pred_pKd': test_preds}).to_csv(output_dir / 'test_predictions.csv', index=False)

# Save metrics. default=float converts numpy scalars (e.g. float32 metric values),
# which json cannot serialize natively.
with open(output_dir / 'final_metrics.json', 'w') as f:
    json.dump({'validation': val_metrics, 'test': test_metrics}, f, indent=2, default=float)

print(f"\n{'='*70}")
print("EVALUATION COMPLETE!")
print(f"Results saved to: {output_dir}")
print(f"{'='*70}")

## 7. Download Results (Optional)

Run this cell to list the saved model and result files and to see how to download them:

In [None]:
# List every file in the output directory with its size, then show how to
# download the best model. All paths come from output_dir (CONFIG['output_dir']),
# not a hard-coded 'outputs/'.
print("Output files:")
for path in sorted(output_dir.glob('*')):
    size_mb = path.stat().st_size / 1e6
    print(f"  {path.name} ({size_mb:.1f} MB)")

print(f"\nFiles are saved in Google Drive: {output_dir.resolve()}")
print("You can download them directly from Drive or use:")
print("  from google.colab import files")
print(f"  files.download('{(output_dir / 'best_model.pth').as_posix()}')")