# Antibody-Antigen Binding Prediction - Optuna + ProtT5 v2.8 (OPTIMIZED)

**IgT5 + ProtT5-XL with Maximum Speed Optimizations**

## Speed Optimizations Applied:
1. **Half-Precision Encoder** (`prot_t5_xl_half_uniref50-enc`) - about 50% faster; per its model card, embedding inference fits on a GPU with 8 GB of VRAM
2. **SDPA (Scaled Dot-Product Attention)** - Native PyTorch optimized attention
3. **Gradient Checkpointing** - Trade 20% speed for 60% memory savings (optional)
4. **torch.compile** - JIT compilation for faster execution (A100 optimized)
5. **BetterTransformer** - Fused kernels via optimum library (optional, deprecated in favor of SDPA)
6. **bfloat16 Mixed Precision** - A100 native precision

## Architecture (v2.8):
- **Antibody encoder**: IgT5 (1024-dim)
- **Antigen encoder**: ProtT5-XL-Half (1024-dim) - OPTIMIZED!
- **Cross-attention fusion**: Bidirectional attention
- **Multi-task output**: Regression (pKd) + Classification
- **Loss**: MSE (0.7) + BCE (0.3)

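The loss entry above is a weighted sum of a regression term and a classification term. A minimal pure-Python sketch of that arithmetic follows, assuming the binary label is derived from the regression target with the `pKd >= 9.0` threshold that `StableCombinedLoss` uses later in this notebook. `combined_loss` is a hypothetical helper name for illustration, not part of the notebook.

```python
import math

def combined_loss(preds, logits, targets,
                  mse_weight=0.7, class_weight=0.3, threshold=9.0):
    """Weighted MSE + BCE-with-logits over plain Python lists (illustrative only)."""
    n = len(targets)
    # Regression term: mean squared error on pKd
    mse = sum((p - t) ** 2 for p, t in zip(preds, targets)) / n
    # Classification term: numerically stable binary cross-entropy on logits
    bce = 0.0
    for z, t in zip(logits, targets):
        label = 1.0 if t >= threshold else 0.0  # assumed "strong binder" label
        bce += max(z, 0.0) - z * label + math.log1p(math.exp(-abs(z)))
    bce /= n
    return mse_weight * mse + class_weight * bce
```

With a perfect pKd prediction and an uninformative logit of 0, only the classification term remains, weighted by 0.3.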
## Expected Speed on A100:
| Optimization (cumulative) | Relative speed | Approx. memory |
|--------------|-------|--------|
| Baseline ProtT5 | 1.0x | 40GB |
| + Half-Precision | 1.5x | 20GB |
| + SDPA | 1.8x | 18GB |
| + torch.compile | 2.2x | 18GB |

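The figures in the table are expectations rather than measurements from this notebook. One way to check them on your own hardware is to time the same workload with and without an optimization and take the ratio. The sketch below is a generic timing helper; `time_call` and `speedup` are hypothetical names. For GPU work the timed function would also need to call `torch.cuda.synchronize()` before returning, because CUDA kernels run asynchronously.

```python
import time

def time_call(fn, *args, warmup=1, repeats=3):
    """Return the best wall-clock time in seconds over `repeats` runs."""
    for _ in range(warmup):
        fn(*args)  # warm-up runs absorb one-off costs such as compilation
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - start)
    return best

def speedup(baseline_seconds, optimized_seconds):
    """Relative speed as used in the table: baseline time / optimized time."""
    return baseline_seconds / optimized_seconds
```

For example, `speedup(time_call(run_baseline), time_call(run_optimized))` would give the "Relative speed" value for one row, where both callables are placeholders for your own benchmark functions.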
## Estimated Training Time (A100 40GB):
- Optuna tuning (15 trials): ~20-30 min
- Final training (50 epochs): ~3-5 hours

## 1. Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
work_dir = '/content/drive/MyDrive/AbAg_Training'
os.makedirs(work_dir, exist_ok=True)
%cd {work_dir}

dataset_file = 'ab_ag_affinity_complete.csv'
if os.path.exists(dataset_file):
    print(f"Dataset found: {dataset_file}")
    !wc -l {dataset_file}
else:
    print(f"ERROR: Upload '{dataset_file}' to 'AbAg_Training' folder")

In [None]:
# Install dependencies with optimization libraries
# Quote the version specifier: an unquoted ">=" is treated by the shell as an output redirect
!pip install -q "transformers>=4.41.0" pandas scipy scikit-learn tqdm sentencepiece optuna
!pip install -q optimum accelerate  # BetterTransformer + acceleration

print("="*60)
print("Dependencies installed!")
print("="*60)
print("  transformers: Model loading + SDPA support")
print("  optimum: BetterTransformer optimization")
print("  accelerate: Mixed precision + device placement")
print("="*60)

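Because the install cell runs quietly, it can help to confirm that the minimum `transformers` version was actually installed. The sketch below uses only the standard library; `version_tuple` and `meets_minimum` are hypothetical helpers, and the parsing is deliberately simple (it keeps the leading integers of each dotted component and ignores pre-release ordering).

```python
import re
from importlib import metadata

def version_tuple(version):
    """Turn '4.41.0' or '2.1.0+cu118' into a tuple of leading integers."""
    numbers = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component, e.g. 'dev0'
        numbers.append(int(match.group()))
    return tuple(numbers)

def meets_minimum(package, minimum):
    """True if `package` is installed with a version >= `minimum`."""
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        return False
    return version_tuple(installed) >= version_tuple(minimum)
```

Calling `meets_minimum("transformers", "4.41.0")` after the install cell returns `False` if the package is missing or too old.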
## 2. GPU Detection & Configuration

In [None]:
import torch
import subprocess

print("="*60)
print("GPU DETECTION")
print("="*60)

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    
    try:
        result = subprocess.run(['nvidia-smi', '--query-gpu=memory.total,memory.free', 
                                '--format=csv,nounits,noheader'],
                               capture_output=True, text=True, timeout=10)
        parts = result.stdout.strip().split(',')
        vram_total = int(parts[0].strip()) / 1024
        vram_free = int(parts[1].strip()) / 1024
        print(f"VRAM: {vram_total:.1f} GB total, {vram_free:.1f} GB free")
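    except Exception as e:
        # nvidia-smi is missing or its output could not be parsed: assume a large GPU
        print(f"Could not query VRAM ({e}); assuming 40 GB free")
        vram_free = 40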
    
    is_a100 = 'A100' in gpu_name
    if is_a100:
        print("\nA100 DETECTED - Using optimized settings!")
        BATCH_OPTIONS = [24, 32, 48]  # A100 can handle larger batches
        USE_QUANTIZATION = False
    elif vram_free >= 20:
        print("\nHigh VRAM GPU detected")
        BATCH_OPTIONS = [16, 24, 32]
        USE_QUANTIZATION = False
    else:
        print("\nStandard GPU detected")
        BATCH_OPTIONS = [8, 12, 16]
        USE_QUANTIZATION = True
else:
    print("No GPU - using CPU")
    BATCH_OPTIONS = [4, 8]
    USE_QUANTIZATION = False

print(f"\nBatch size options: {BATCH_OPTIONS}")
print(f"Use quantization: {USE_QUANTIZATION}")
print("="*60)

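The branching in the detection cell can be expressed as a small pure function, which makes the thresholds easy to test without a GPU. This is a sketch that mirrors the cell above; `pick_batch_options` is a hypothetical name, and `gpu_name=None` stands in for the CPU-only case.

```python
def pick_batch_options(gpu_name, vram_free_gb):
    """Return (batch_size_options, use_quantization) for a detected device."""
    if gpu_name is None:
        return [4, 8], False          # CPU fallback
    if "A100" in gpu_name:
        return [24, 32, 48], False    # A100: larger batches
    if vram_free_gb >= 20:
        return [16, 24, 32], False    # other high-VRAM GPUs
    return [8, 12, 16], True          # smaller GPUs: quantization flag on
```

For instance, a T4 with about 14.5 GB free falls into the last branch, while any A100 takes the second branch regardless of the free-memory reading.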
In [None]:
# === NARROWED CONFIGURATION (Research-Based for ProtT5 v2.8) ===
#
# These ranges are based on:
# - Multi-task Bioassay Pre-training (MBP) 2024: lr=1e-3, dropout=0.1
# - ProtT5 fine-tuning papers: lower dropout than ESM-2
# - Your working v2.7 config (lr=1e-3, dropout=0.1, weight_decay=1e-5)
#
# Key differences from ESM-2:
# - ProtT5 works better with LOWER dropout (0.1 vs 0.3)
# - Same T5 architecture for both encoders -> more stable training

CONFIG = {
    'data_path': 'ab_ag_affinity_complete.csv',
    'output_dir': 'outputs_optuna_v28',
    
    # Optuna Settings
    'n_trials': 15,           # Fewer trials needed with narrow ranges
    'optuna_epochs': 5,       # Quick evaluation per trial
    'optuna_patience': 4,     # Stop if no improvement
    
    # Final Training
    'final_epochs': 50,
    'early_stopping_patience': 15,
    
    # ===== NARROWED SEARCH RANGES (ProtT5 optimized) =====
    # Learning rate: 1e-3 to 5e-3 (MBP 2024 uses 1e-3)
    'lr_range': (1e-3, 5e-3),
    
    # Dropout: 0.10 to 0.30 (ProtT5 works better with less dropout)
    # MBP 2024 uses 0.1, your v2.7 uses 0.1
    'dropout_range': (0.10, 0.30),
    
    # Weight decay: 1e-5 to 0.01 (tighter range)
    'weight_decay_range': (1e-5, 0.01),
    
    # Batch size: A100 optimized options
    'batch_size_options': BATCH_OPTIONS,
    
    # Cross-attention options
    'use_cross_attention': True,
    
    # Loss weights (MSE + BCE)
    'mse_weight': 0.7,
    'class_weight': 0.3,
}

print("NARROWED Configuration (ProtT5 v2.8):")
print(f"  Learning rate: {CONFIG['lr_range']} (lower bound 1e-3, the MBP 2024 value)")
print(f"  Dropout: {CONFIG['dropout_range']} (lower for ProtT5)")
print(f"  Weight decay: {CONFIG['weight_decay_range']}")
print(f"  Batch sizes: {CONFIG['batch_size_options']}")
print(f"  Trials: {CONFIG['n_trials']}")
print(f"  Loss: MSE({CONFIG['mse_weight']}) + BCE({CONFIG['class_weight']})")

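The learning-rate and weight-decay ranges are later sampled with `log=True` in the Optuna objective, which draws values uniformly in log space rather than linear space. A minimal sketch of what that means, using only the standard library (`sample_log_uniform` is a hypothetical helper, not Optuna's implementation):

```python
import math
import random

def sample_log_uniform(low, high, rng):
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))

# Example: draw a few learning rates from the notebook's (1e-3, 5e-3) range
rng = random.Random(42)
learning_rates = [sample_log_uniform(1e-3, 5e-3, rng) for _ in range(5)]
```

Under this scheme half of the draws fall below the geometric mean of the bounds (about 2.2e-3 for this range), so small values are not under-sampled the way they would be with a linear draw.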
## 3. Training Code

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from scipy import stats
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from tqdm import tqdm
from pathlib import Path
import gc
from transformers import T5EncoderModel, T5Tokenizer
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
import json
import os

# ============================================================================
# OPTIMIZATION SETTINGS
# ============================================================================
OPTIMIZATION_CONFIG = {
    # Use half-precision encoder (50% faster, 50% less memory)
    # Source: https://huggingface.co/Rostlab/prot_t5_xl_half_uniref50-enc
    'use_half_precision_encoder': True,
    
    # Enable SDPA (Scaled Dot-Product Attention) - native PyTorch optimization
    # Automatically uses FlashAttention-like kernels when available
    'use_sdpa': True,
    
    # Enable gradient checkpointing (trades 20% speed for 60% memory)
    # Useful if running out of memory
    'use_gradient_checkpointing': False,
    
    # Enable torch.compile (JIT compilation, ~20% speedup)
    # Set to False if you encounter CUDA graph errors
    'use_torch_compile': True,
    
    # Enable BetterTransformer (fused kernels)
    # Deprecated in favor of SDPA, but still useful for some models
    'use_better_transformer': False,
}

print("="*60)
print("PROTRANS OPTIMIZATION CONFIG")
print("="*60)
for k, v in OPTIMIZATION_CONFIG.items():
    status = "ON" if v else "OFF"
    print(f"  {k}: {status}")
print("="*60)

# A100/TF32 optimizations
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')

# When torch.compile is enabled, fall back to eager execution instead of raising on compile errors
if OPTIMIZATION_CONFIG['use_torch_compile']:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nDevice: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# ============================================================================
# CROSS-ATTENTION MODULE
# ============================================================================
class CrossAttention(nn.Module):
    """Cross-attention between antibody and antigen embeddings"""
    def __init__(self, dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, query, key_value):
        attn_out, _ = self.attention(query, key_value, key_value)
        query = self.norm1(query + attn_out)
        ffn_out = self.ffn(query)
        query = self.norm2(query + ffn_out)
        return query

# ============================================================================
# OPTIMIZED MODEL - IgT5 + ProtT5-XL-Half with Speed Optimizations
# ============================================================================
class IgT5ProtT5Model(nn.Module):
    """
    OPTIMIZED Dual-encoder model for antibody-antigen binding prediction.
    
    Speed Optimizations:
    1. prot_t5_xl_half_uniref50-enc: Half-precision encoder (50% faster)
    2. SDPA: Scaled dot-product attention (native PyTorch)
    3. Gradient checkpointing: Trade speed for memory (optional)
    4. torch.compile: JIT compilation (A100 optimized)
    5. BetterTransformer: Fused kernels (optional)
    
    Architecture:
    - Antibody: IgT5 (1024-dim)
    - Antigen: ProtT5-XL-Half (1024-dim) - OPTIMIZED
    - Cross-attention + Multi-task output
    """
    def __init__(self, dropout=0.1, hidden_dims=(512, 256, 128), use_cross_attention=True):
        super().__init__()
        
        print("\n" + "="*60)
        print("BUILDING OPTIMIZED ProtT5 MODEL")
        print("="*60)
        
        # Determine which ProtT5 model to use
        if OPTIMIZATION_CONFIG['use_half_precision_encoder']:
            ag_model_name = "Rostlab/prot_t5_xl_half_uniref50-enc"
            ag_dtype = torch.float16
            print("  Antigen encoder: ProtT5-XL-Half (OPTIMIZED)")
            print("    - 50% less memory")
            print("    - 50% faster inference")
        else:
            ag_model_name = "Rostlab/prot_t5_xl_uniref50"
            ag_dtype = torch.float32
            print("  Antigen encoder: ProtT5-XL (full precision)")
        
        # Determine attention implementation
        attn_impl = "sdpa" if OPTIMIZATION_CONFIG['use_sdpa'] else "eager"
        print(f"  Attention: {attn_impl.upper()}")
        
        # IgT5 for antibody
        print("\n  Loading IgT5 for antibodies...")
        self.ab_tokenizer = T5Tokenizer.from_pretrained("Exscientia/IgT5")
        self.ab_model = T5EncoderModel.from_pretrained(
            "Exscientia/IgT5",
            attn_implementation=attn_impl,
            torch_dtype=torch.float16 if OPTIMIZATION_CONFIG['use_half_precision_encoder'] else torch.float32
        )
        self.ab_dim = 1024
        
        # ProtT5 for antigen (OPTIMIZED)
        print(f"  Loading {ag_model_name}...")
        self.ag_tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")
        self.ag_model = T5EncoderModel.from_pretrained(
            ag_model_name,
            attn_implementation=attn_impl,
            torch_dtype=ag_dtype
        )
        self.ag_dim = 1024
        
        # Enable gradient checkpointing if requested
        if OPTIMIZATION_CONFIG['use_gradient_checkpointing']:
            self.ab_model.gradient_checkpointing_enable()
            self.ag_model.gradient_checkpointing_enable()
            print("  Gradient checkpointing: ENABLED (saves 60% memory)")
        
        # Apply BetterTransformer if requested
        if OPTIMIZATION_CONFIG['use_better_transformer']:
            try:
                from optimum.bettertransformer import BetterTransformer
                self.ab_model = BetterTransformer.transform(self.ab_model)
                self.ag_model = BetterTransformer.transform(self.ag_model)
                print("  BetterTransformer: ENABLED (fused kernels)")
            except Exception as e:
                print(f"  BetterTransformer: FAILED ({e})")
        
        # Freeze encoders (only train projection + attention + heads)
        for p in self.ab_model.parameters(): p.requires_grad = False
        for p in self.ag_model.parameters(): p.requires_grad = False
        print("  Encoders: FROZEN (only train MLP heads)")
        
        # Projection to common dimension
        self.common_dim = 512
        self.ab_proj = nn.Linear(self.ab_dim, self.common_dim)
        self.ag_proj = nn.Linear(self.ag_dim, self.common_dim)
        
        # Cross-attention (optional)
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            self.cross_attn_ab = CrossAttention(self.common_dim, num_heads=8, dropout=dropout)
            self.cross_attn_ag = CrossAttention(self.common_dim, num_heads=8, dropout=dropout)
            print("  Cross-attention: ENABLED")
        
        # Regression head with spectral normalization
        self.regression_head = nn.Sequential(
            nn.utils.spectral_norm(nn.Linear(self.common_dim * 2, hidden_dims[0])),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dims[0]),
            
            nn.utils.spectral_norm(nn.Linear(hidden_dims[0], hidden_dims[1])),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dims[1]),
            
            nn.utils.spectral_norm(nn.Linear(hidden_dims[1], hidden_dims[2])),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dims[2]),
            
            nn.Linear(hidden_dims[2], 1)
        )
        
        # Classification head
        self.classifier = nn.Linear(self.common_dim * 2, 1)
        
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        print(f"\n  Trainable: {trainable/1e6:.1f}M | Total: {total/1e6:.1f}M")
        print("="*60)
    
    def forward(self, ab_seqs, ag_seqs, device):
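        # Tokenize antibodies - the IgT5 tokenizer, like ProtT5, expects SPACE-SEPARATED amino acids
        ab_seqs_spaced = [" ".join(seq) for seq in ab_seqs]
        ab_tokens = self.ab_tokenizer(
            ab_seqs_spaced, return_tensors='pt', padding=True,
            truncation=True, max_length=512
        ).to(device)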
        
        # Tokenize antigens - ProtT5 expects SPACE-SEPARATED amino acids
        ag_seqs_spaced = [" ".join(list(seq)) for seq in ag_seqs]
        ag_tokens = self.ag_tokenizer(
            ag_seqs_spaced, return_tensors='pt', padding=True,
            truncation=True, max_length=2048
        ).to(device)
        
        # Encode (frozen encoders, no gradient)
        with torch.no_grad():
            ab_out = self.ab_model(**ab_tokens).last_hidden_state
            ag_out = self.ag_model(**ag_tokens).last_hidden_state
        
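        # Mean pooling over real tokens only (padding positions are masked out);
        # accumulate in float32 so long sequences do not overflow half precision
        ab_mask = ab_tokens['attention_mask'].unsqueeze(-1).float()
        ag_mask = ag_tokens['attention_mask'].unsqueeze(-1).float()
        ab_emb = (ab_out.float() * ab_mask).sum(dim=1) / ab_mask.sum(dim=1).clamp(min=1.0)
        ag_emb = (ag_out.float() * ag_mask).sum(dim=1) / ag_mask.sum(dim=1).clamp(min=1.0)
        
        # Cast to the dtype of the trainable layers (bfloat16 when built via get_model)
        ab_emb = ab_emb.to(self.ab_proj.weight.dtype)
        ag_emb = ag_emb.to(self.ag_proj.weight.dtype)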
        
        # Project to common dimension
        ab_proj = self.ab_proj(ab_emb)
        ag_proj = self.ag_proj(ag_emb)
        
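        # Cross-attention over pooled embeddings: each side is a length-1 sequence,
        # so every query attends to a single key/value vector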
        if self.use_cross_attention:
            ab_proj = ab_proj.unsqueeze(1)
            ag_proj = ag_proj.unsqueeze(1)
            
            ab_enhanced = self.cross_attn_ab(ab_proj, ag_proj).squeeze(1)
            ag_enhanced = self.cross_attn_ag(ag_proj, ab_proj).squeeze(1)
            
            combined = torch.cat([ab_enhanced, ag_enhanced], dim=1)
        else:
            combined = torch.cat([ab_proj, ag_proj], dim=1)
        
        # Predictions
        pKd_pred = self.regression_head(combined).squeeze(-1)
        class_logits = self.classifier(combined).squeeze(-1)
        
        return pKd_pred.float(), class_logits.float()

# ============================================================================
# STABLE LOSS FUNCTION (MSE + BCE)
# ============================================================================
class StableCombinedLoss(nn.Module):
    """MSE + BCE loss for multi-task learning"""
    def __init__(self, mse_weight=0.7, class_weight=0.3):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.mse_weight = mse_weight
        self.class_weight = class_weight
    
    def forward(self, pred, target, class_logits=None):
        mse_loss = self.mse(pred, target)
        loss = self.mse_weight * mse_loss
        
        if class_logits is not None:
            class_target = (target >= 9.0).float()
            class_loss = self.bce(class_logits, class_target)
            loss += self.class_weight * class_loss
        
        return loss

# ============================================================================
# DATASET
# ============================================================================
class AbAgDataset(Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        return {'ab': self.df.iloc[i]['antibody_sequence'],
                'ag': self.df.iloc[i]['antigen_sequence'],
                'y': torch.tensor(self.df.iloc[i]['pKd'], dtype=torch.float32)}

def collate(batch):
    return {'ab': [b['ab'] for b in batch], 'ag': [b['ag'] for b in batch],
            'y': torch.stack([b['y'] for b in batch])}

print("\n" + "="*60)
print("OPTIMIZED ProtT5 v2.8 CODE LOADED")
print("="*60)
print("Speed optimizations:")
print("  1. Half-precision encoder (prot_t5_xl_half)")
print("  2. SDPA attention (native PyTorch)")
print("  3. Gradient checkpointing (optional)")
print("  4. torch.compile (JIT, optional)")
print("  5. BetterTransformer (fused kernels, optional)")
print("="*60)

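The model's `forward` method space-separates residues before tokenizing, because the ProtT5 tokenizer treats each amino acid as its own token. The ProtT5 model card additionally recommends mapping the rare residues U, Z, O and B to X. The sketch below combines both steps; `space_residues` is a hypothetical helper, and the upper-casing and whitespace stripping are assumptions about how raw sequences might arrive.

```python
import re

def space_residues(sequence):
    """Prepare a raw amino-acid string for a ProtT5-style tokenizer."""
    cleaned = sequence.strip().upper()           # assumed normalisation of raw input
    cleaned = re.sub(r"[UZOB]", "X", cleaned)    # rare/ambiguous residues -> X
    return " ".join(cleaned)                     # one token per residue
```

For example, `space_residues("MKUV")` yields `"M K X V"`. Whether the rare-residue mapping matters depends on how often such residues occur in the dataset.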
## 4. Optuna Hyperparameter Tuning

In [None]:
# Global model cache (avoid reloading frozen encoders during Optuna trials)
MODEL_CACHE = None
COMPILED_FORWARD = None

def get_model(dropout, hidden_dims, use_cross_attention, use_cache=True):
    """Get model with optimizations, reusing frozen encoders from cache."""
    global MODEL_CACHE, COMPILED_FORWARD
    
    if use_cache and MODEL_CACHE is not None:
        # Reuse frozen encoders, only rebuild trainable layers
        m = MODEL_CACHE
        
        # Rebuild projection layers
        m.ab_proj = nn.Linear(m.ab_dim, m.common_dim).to(device).to(torch.bfloat16)
        m.ag_proj = nn.Linear(m.ag_dim, m.common_dim).to(device).to(torch.bfloat16)
        
        # Rebuild cross-attention
        m.use_cross_attention = use_cross_attention
        if use_cross_attention:
            m.cross_attn_ab = CrossAttention(m.common_dim, num_heads=8, dropout=dropout).to(device).to(torch.bfloat16)
            m.cross_attn_ag = CrossAttention(m.common_dim, num_heads=8, dropout=dropout).to(device).to(torch.bfloat16)
        
        # Rebuild regression head
        m.regression_head = nn.Sequential(
            nn.utils.spectral_norm(nn.Linear(m.common_dim * 2, hidden_dims[0])),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dims[0]),
            
            nn.utils.spectral_norm(nn.Linear(hidden_dims[0], hidden_dims[1])),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dims[1]),
            
            nn.utils.spectral_norm(nn.Linear(hidden_dims[1], hidden_dims[2])),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dims[2]),
            
            nn.Linear(hidden_dims[2], 1)
        ).to(device).to(torch.bfloat16)
        
        # Rebuild classifier
        m.classifier = nn.Linear(m.common_dim * 2, 1).to(device).to(torch.bfloat16)
        
        return m
    
    # First call: create full model
    m = IgT5ProtT5Model(dropout, hidden_dims, use_cross_attention).to(device)
    
    # Cast trainable layers to bfloat16
    m.ab_proj = m.ab_proj.to(torch.bfloat16)
    m.ag_proj = m.ag_proj.to(torch.bfloat16)
    if m.use_cross_attention:
        m.cross_attn_ab = m.cross_attn_ab.to(torch.bfloat16)
        m.cross_attn_ag = m.cross_attn_ag.to(torch.bfloat16)
    m.regression_head = m.regression_head.to(torch.bfloat16)
    m.classifier = m.classifier.to(torch.bfloat16)
    
    MODEL_CACHE = m
    return m

def objective(trial, train_df, val_df):
    """Optuna objective with speed optimizations."""
    
    # ===== NARROWED HYPERPARAMETERS =====
    lr = trial.suggest_float('lr', CONFIG['lr_range'][0], CONFIG['lr_range'][1], log=True)
    dropout = trial.suggest_float('dropout', CONFIG['dropout_range'][0], CONFIG['dropout_range'][1])
    wd = trial.suggest_float('weight_decay', CONFIG['weight_decay_range'][0], CONFIG['weight_decay_range'][1], log=True)
    bs = trial.suggest_categorical('batch_size', CONFIG['batch_size_options'])
    
    # Hidden layer config
    hidden_cfg = trial.suggest_categorical('hidden', ['medium', 'large'])
    hidden_map = {
        'medium': [512, 256, 128],
        'large': [512, 512, 256]
    }
    hidden_dims = hidden_map[hidden_cfg]
    use_cross_attention = CONFIG['use_cross_attention']
    
    print(f"\nTrial {trial.number}: lr={lr:.4f}, dropout={dropout:.2f}, wd={wd:.5f}, bs={bs}, hidden={hidden_cfg}")
    
    try:
        model = get_model(dropout, hidden_dims, use_cross_attention)
        
        # Apply torch.compile for speed (A100 optimized)
        if OPTIMIZATION_CONFIG['use_torch_compile']:
            try:
                model = torch.compile(model, mode='reduce-overhead')
                print("  torch.compile: ENABLED")
            except Exception as e:
                print(f"  torch.compile: FAILED ({e})")
        
        train_loader = DataLoader(AbAgDataset(train_df), batch_size=bs, shuffle=True, 
                                  num_workers=2, pin_memory=True, collate_fn=collate)
        val_loader = DataLoader(AbAgDataset(val_df), batch_size=bs*2, shuffle=False, 
                                num_workers=2, pin_memory=True, collate_fn=collate)
        
        opt = torch.optim.AdamW(
            [p for p in model.parameters() if p.requires_grad], 
            lr=lr, weight_decay=wd,
            fused=torch.cuda.is_available()  # fused AdamW kernel, enabled only when parameters live on CUDA
        )
        loss_fn = StableCombinedLoss(
            mse_weight=CONFIG['mse_weight'], 
            class_weight=CONFIG['class_weight']
        )
        best = -1
        
        for epoch in range(CONFIG['optuna_epochs']):
            model.train()
            for batch in train_loader:
                y = batch['y'].to(device)
                
                with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                    pKd_pred, class_logits = model(batch['ab'], batch['ag'], device)
                    loss = loss_fn(pKd_pred, y, class_logits)
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
                opt.zero_grad(set_to_none=True)
            
            # Validate
            model.eval()
            preds, targs = [], []
            with torch.no_grad():
                for batch in val_loader:
                    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                        pKd_pred, _ = model(batch['ab'], batch['ag'], device)
                    preds.extend(pKd_pred.float().cpu().numpy())
                    targs.extend(batch['y'].numpy())
            
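            spearman = stats.spearmanr(targs, preds)[0]
            if np.isnan(spearman):
                # Constant predictions give an undefined correlation; treat as the worst score
                spearman = -1.0
            best = max(best, spearman)
            
            trial.report(spearman, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
        
        print(f"  -> Spearman = {best:.4f}")
        del train_loader, val_loader, opt
        return float(best)
        
    except optuna.TrialPruned:
        # Re-raise so Optuna records the trial as pruned rather than as a completed run scoring -1
        print("  -> Pruned")
        raise
    except Exception as e:
        print(f"  -> Failed: {e}")
        return -1.0
    finally:
        # Release cached memory on every exit path
        gc.collect()
        torch.cuda.empty_cache()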

print("Optuna objective ready with speed optimizations!")

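The study in the next cell uses `MedianPruner`, which stops a trial early when its best intermediate score is worse than the median of earlier trials at the same step. The sketch below is a simplified illustration of that rule for a maximization study; `should_prune_median` is a hypothetical function, and it ignores details such as `n_warmup_steps` that the real pruner handles.

```python
from statistics import median

def should_prune_median(current_best, previous_values_at_step, n_startup_trials=3):
    """Simplified median rule for a study that maximizes its objective."""
    if len(previous_values_at_step) < n_startup_trials:
        return False  # not enough history yet: never prune
    return current_best < median(previous_values_at_step)
```

With earlier trials scoring 0.30, 0.50 and 0.40 at a given epoch, a trial whose best score so far is 0.20 would be pruned, while one at 0.45 would continue.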
In [None]:
print("\n" + "="*60)
print("OPTUNA HYPERPARAMETER SEARCH (ProtT5 v2.8)")
print("="*60)

# Load data
df = pd.read_csv(CONFIG['data_path'])
print(f"Total samples: {len(df):,}")

# Filter invalid pKd values (critical for stable training!)
df = df[(df['pKd'] >= 4.0) & (df['pKd'] <= 14.0)].reset_index(drop=True)
print(f"Valid samples (pKd 4-14): {len(df):,}")

# Subset for tuning (faster Optuna trials)
tune_df = df.sample(n=min(50000, len(df)), random_state=42)
train_tune, val_tune = train_test_split(tune_df, test_size=0.2, random_state=42)
print(f"Tuning set: {len(train_tune):,} train, {len(val_tune):,} val")

# Print search space
print(f"\nSearch Space (NARROWED for ProtT5):")
print(f"  LR: {CONFIG['lr_range'][0]:.4f} - {CONFIG['lr_range'][1]:.4f}")
print(f"  Dropout: {CONFIG['dropout_range'][0]:.2f} - {CONFIG['dropout_range'][1]:.2f}")
print(f"  Weight Decay: {CONFIG['weight_decay_range'][0]:.5f} - {CONFIG['weight_decay_range'][1]:.4f}")
print(f"  Batch Size: {CONFIG['batch_size_options']}")
print(f"  Hidden: ['medium', 'large']")
print(f"  Loss: MSE({CONFIG['mse_weight']}) + BCE({CONFIG['class_weight']})")

# Early stopping for Optuna
class OptunaStop:
    def __init__(self, patience=4):
        self.patience = patience
        self.best = -float('inf')
        self.count = 0
    def __call__(self, study, trial):
        if trial.value is not None and trial.value > self.best:
            self.best = trial.value
            self.count = 0
            print(f"\n*** NEW BEST: Spearman = {trial.value:.4f} ***")
        else:
            self.count += 1
        if self.count >= self.patience:
            print(f"\nStopping: no improvement for {self.patience} trials")
            study.stop()

# Run Optuna
print("\n" + "-"*60)
study = optuna.create_study(
    direction='maximize', 
    sampler=TPESampler(seed=42),
    pruner=MedianPruner(n_startup_trials=3, n_warmup_steps=2)
)
study.optimize(
    lambda t: objective(t, train_tune, val_tune), 
    n_trials=CONFIG['n_trials'],
    callbacks=[OptunaStop(CONFIG['optuna_patience'])], 
    show_progress_bar=True
)

# Results
print("\n" + "="*60)
print("BEST HYPERPARAMETERS FOUND")
print("="*60)
print(f"Best Spearman: {study.best_value:.4f}")
for k, v in study.best_params.items():
    if isinstance(v, float):
        print(f"  {k}: {v:.6f}")
    else:
        print(f"  {k}: {v}")

best_params = study.best_params
out_dir = Path(CONFIG['output_dir'])
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / 'best_params.json', 'w') as f:
    json.dump(best_params, f, indent=2)
print(f"\nSaved to: {out_dir / 'best_params.json'}")

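Since the tuning cell writes `best_params.json`, a later session could reload it and skip the Optuna search. The sketch below parses that JSON and resolves the `hidden` label to layer sizes using the same mapping as the objective; `parse_best_params` and `HIDDEN_MAP` are hypothetical names introduced here for illustration.

```python
import json

# Same label-to-layer-size mapping the Optuna objective uses
HIDDEN_MAP = {"medium": [512, 256, 128], "large": [512, 512, 256]}

def parse_best_params(json_text):
    """Parse saved hyperparameters and attach the resolved hidden layer sizes."""
    params = json.loads(json_text)
    params["hidden_dims"] = HIDDEN_MAP[params["hidden"]]
    return params
```

A typical use would be `parse_best_params(Path("outputs_optuna_v28/best_params.json").read_text())`, assuming the file was produced by the cell above.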
## 5. Final Training with Best Parameters

In [None]:
print("\n" + "="*60)
print("FINAL TRAINING WITH BEST HYPERPARAMETERS (OPTIMIZED)")
print("="*60)

# Best params
lr = best_params['lr']
dropout = best_params['dropout']
wd = best_params['weight_decay']
bs = best_params['batch_size']
hidden_map = {'medium': [512, 256, 128], 'large': [512, 512, 256]}
hidden_dims = hidden_map[best_params['hidden']]
use_cross_attention = CONFIG['use_cross_attention']

print(f"\nHyperparameters:")
print(f"  Learning Rate: {lr:.6f}")
print(f"  Dropout: {dropout:.3f}")
print(f"  Weight Decay: {wd:.6f}")
print(f"  Batch Size: {bs}")
print(f"  Hidden: {hidden_dims}")
print(f"  Cross-Attention: {use_cross_attention}")

print(f"\nOptimizations:")
for k, v in OPTIMIZATION_CONFIG.items():
    status = "ON" if v else "OFF"
    print(f"  {k}: {status}")

# Full data split
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)
val_quick = val_df.sample(frac=0.1, random_state=42)

print(f"\nData Split:")
print(f"  Train: {len(train_df):,}")
print(f"  Val: {len(val_df):,}")
print(f"  Test: {len(test_df):,}")

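# Fresh model with best params: drop the cached tuning model first so that
# two copies of the encoders are not held in GPU memory at once
MODEL_CACHE = None
gc.collect()
torch.cuda.empty_cache()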

print("\nLoading OPTIMIZED ProtT5 v2.8 model...")
model = IgT5ProtT5Model(dropout, hidden_dims, use_cross_attention).to(device)

# Cast trainable layers to bfloat16
model.ab_proj = model.ab_proj.to(torch.bfloat16)
model.ag_proj = model.ag_proj.to(torch.bfloat16)
if model.use_cross_attention:
    model.cross_attn_ab = model.cross_attn_ab.to(torch.bfloat16)
    model.cross_attn_ag = model.cross_attn_ag.to(torch.bfloat16)
model.regression_head = model.regression_head.to(torch.bfloat16)
model.classifier = model.classifier.to(torch.bfloat16)

# Apply torch.compile for maximum speed
if OPTIMIZATION_CONFIG['use_torch_compile']:
    print("\nApplying torch.compile (this may take a moment on first run)...")
    try:
        model = torch.compile(model, mode='reduce-overhead')
        print("  torch.compile: SUCCESS - ~20% speedup expected")
    except Exception as e:
        print(f"  torch.compile: FAILED ({e}) - continuing without JIT")

# Data loaders with optimized settings
train_loader = DataLoader(
    AbAgDataset(train_df), batch_size=bs, shuffle=True, 
    num_workers=4, pin_memory=True, collate_fn=collate,
    prefetch_factor=2, persistent_workers=True
)
val_loader_q = DataLoader(
    AbAgDataset(val_quick), batch_size=bs*2, shuffle=False, 
    num_workers=4, pin_memory=True, collate_fn=collate
)
val_loader_f = DataLoader(
    AbAgDataset(val_df), batch_size=bs*2, shuffle=False, 
    num_workers=4, pin_memory=True, collate_fn=collate
)
test_loader = DataLoader(
    AbAgDataset(test_df), batch_size=bs*2, shuffle=False, 
    num_workers=4, pin_memory=True, collate_fn=collate
)

# Optimizer with A100 fused AdamW
opt = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], 
    lr=lr, weight_decay=wd,
    fused=torch.cuda.is_available()  # fused AdamW kernel (~10% speedup), enabled only on CUDA
)

# Scheduler: ReduceLROnPlateau (MBP 2024)
from torch.optim.lr_scheduler import ReduceLROnPlateau
sched = ReduceLROnPlateau(opt, mode='max', factor=0.6, patience=10, min_lr=1e-6)

# Loss
loss_fn = StableCombinedLoss(mse_weight=CONFIG['mse_weight'], class_weight=CONFIG['class_weight'])

# Early stopping
class EarlyStop:
    def __init__(self, patience):
        self.patience = patience
        self.best = -float('inf')
        self.count = 0
    def __call__(self, score):
        if score > self.best:
            self.best = score
            self.count = 0
            return False
        self.count += 1
        return self.count >= self.patience

early = EarlyStop(CONFIG['early_stopping_patience'])
best_spearman = -1

print(f"\nTraining for {CONFIG['final_epochs']} epochs...")
print("(First epoch may be slower due to torch.compile warmup)\n")

import time
epoch_times = []

for epoch in range(CONFIG['final_epochs']):
    epoch_start = time.time()
    
    model.train()
    total_loss = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CONFIG['final_epochs']}")
    
    for batch in pbar:
        y = batch['y'].to(device)
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            pKd_pred, class_logits = model(batch['ab'], batch['ag'], device)
            loss = loss_fn(pKd_pred, y, class_logits)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        opt.zero_grad(set_to_none=True)
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    # Quick validation
    model.eval()
    preds, targs = [], []
    with torch.no_grad():
        for batch in val_loader_q:
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                pKd_pred, _ = model(batch['ab'], batch['ag'], device)
            preds.extend(pKd_pred.float().cpu().numpy())
            targs.extend(batch['y'].numpy())
    
    spearman = stats.spearmanr(targs, preds)[0]
    current_lr = opt.param_groups[0]['lr']
    epoch_time = time.time() - epoch_start
    epoch_times.append(epoch_time)
    
    print(f"  Loss: {total_loss/len(train_loader):.4f} | Spearman: {spearman:.4f} | LR: {current_lr:.2e} | Time: {epoch_time:.1f}s")
    
    # Step scheduler
    sched.step(spearman)
    
    # Save best
    if spearman > best_spearman:
        best_spearman = spearman
        torch.save({
            'epoch': epoch, 
            'model': model.state_dict(), 
            'spearman': spearman, 
            'params': best_params,
            'config': CONFIG,
            'optimizations': OPTIMIZATION_CONFIG
        }, out_dir / 'best_model.pth')
        print(f"  *** Saved best model (Spearman: {spearman:.4f}) ***")
    
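    # Checkpoint (latest state, for resuming an interrupted run)
    torch.save({
        'epoch': epoch, 
        'model': model.state_dict(), 
        'opt': opt.state_dict(),
        'sched': sched.state_dict(),  # needed to resume the LR schedule
        'best_spearman': best_spearman
    }, out_dir / 'checkpoint.pth')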
    
    if early(spearman):
        print(f"\nEarly stopping at epoch {epoch+1}")
        break

avg_time = sum(epoch_times[1:]) / len(epoch_times[1:]) if len(epoch_times) > 1 else epoch_times[0]
print("\n" + "="*60)
print("Training complete!")
print(f"  Best Spearman: {best_spearman:.4f}")
print(f"  Avg epoch time: {avg_time:.1f}s (excluding warmup)")
print("="*60)
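
The early-stopping rule used in the training loop can be exercised in isolation. The sketch below repeats the `EarlyStop` class from the cell above and feeds it a made-up sequence of validation Spearman values. The scores and `patience=3` are illustrative assumptions, not results from this notebook.

```python
class EarlyStop:
    """Return True after `patience` consecutive calls without a new best score."""
    def __init__(self, patience):
        self.patience = patience
        self.best = -float('inf')
        self.count = 0

    def __call__(self, score):
        # A NaN score compares False here, so it counts as a non-improving epoch.
        if score > self.best:
            self.best = score
            self.count = 0
            return False
        self.count += 1
        return self.count >= self.patience

# Hypothetical per-epoch validation Spearman values (illustrative only)
scores = [0.30, 0.35, 0.34, 0.33, 0.36, 0.35, 0.35, 0.34]

early = EarlyStop(patience=3)
stopped_at = None
for epoch, score in enumerate(scores, start=1):
    if early(score):
        stopped_at = epoch
        break

print(f"Stopped at epoch {stopped_at}, best score {early.best:.2f}")
```

The counter resets whenever a new best is reached, so the improvement at epoch 5 delays the stop until three further epochs pass without beating it.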

## 6. Final Evaluation
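
Precision and recall in the evaluation cell below treat a true pKd >= 9 as a strong binder and a sigmoid probability >= 0.5 as a positive prediction. The following is a minimal sketch of that computation on toy arrays. The helper name `strong_binder_metrics` and the sample values are hypothetical, introduced only for illustration.

```python
import numpy as np

def strong_binder_metrics(true_pkd, class_prob, pkd_cutoff=9.0, prob_cutoff=0.5):
    """Precision/recall for the strong-binder class (hypothetical helper)."""
    t = np.asarray(true_pkd, dtype=float)
    c = np.asarray(class_prob, dtype=float)

    strong = t >= pkd_cutoff        # ground-truth strong binders
    pred_strong = c >= prob_cutoff  # predicted strong binders

    tp = int(np.sum(strong & pred_strong))
    fn = int(np.sum(strong & ~pred_strong))
    fp = int(np.sum(~strong & pred_strong))

    # Guard against empty denominators, as the evaluation cell does
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    return {'tp': tp, 'fn': fn, 'fp': fp, 'recall': recall, 'precision': precision}

# Toy values (illustrative only)
true_pkd = [9.5, 9.1, 8.0, 7.2, 10.0, 6.5]
class_prob = [0.9, 0.4, 0.7, 0.1, 0.8, 0.2]

m = strong_binder_metrics(true_pkd, class_prob)
print(m)
```

Both inputs are expected to be one-dimensional and of equal length. A column-shaped probability array would broadcast against the label mask and inflate the counts, so flattening beforehand is worth checking.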

In [None]:
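# Load best model
# The checkpoint was written by this notebook and stores plain dicts alongside the weights,
# so full unpickling is requested explicitly (PyTorch >= 2.6 defaults to weights_only=True).
ckpt = torch.load(out_dir / 'best_model.pth', map_location=device, weights_only=False)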
model.load_state_dict(ckpt['model'])
print(f"Loaded best model from epoch {ckpt['epoch']+1}")
print(f"Best Spearman during training: {ckpt['spearman']:.4f}")

def evaluate(model, loader, name):
    """Evaluate the ProtT5 v2.8 multi-task model on `loader`.

    Returns (metrics dict, predicted pKd array, true pKd array).
    """
    model.eval()
    preds, class_preds, targs = [], [], []
    
    with torch.no_grad():
        for batch in tqdm(loader, desc=name):
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                pKd_pred, class_logits = model(batch['ab'], batch['ag'], device)
            preds.extend(pKd_pred.float().cpu().numpy())
            class_preds.extend(torch.sigmoid(class_logits).float().cpu().numpy())
            targs.extend(batch['y'].numpy())
    
    t, p = np.array(targs), np.array(preds)
    c = np.array(class_preds)
    
    # Strong-binder classification: true label is pKd >= 9, predicted label is probability >= 0.5
    strong = t >= 9.0
    pred_strong = c >= 0.5
    tp = np.sum(strong & pred_strong)
    fn = np.sum(strong & ~pred_strong)
    fp = np.sum(~strong & pred_strong)
    
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    
    return {
        'spearman': stats.spearmanr(t, p)[0],
        'pearson': stats.pearsonr(t, p)[0],
        'rmse': np.sqrt(mean_squared_error(t, p)),
        'mae': mean_absolute_error(t, p),
        'r2': r2_score(t, p),
        'recall': recall * 100,
        'precision': precision * 100
    }, p, t

# Validation
print("\n" + "="*60)
print("VALIDATION SET RESULTS")
print("="*60)
val_m, val_p, val_t = evaluate(model, val_loader_f, "Validation")
print(f"  Spearman:  {val_m['spearman']:.4f}")
print(f"  Pearson:   {val_m['pearson']:.4f}")
print(f"  RMSE:      {val_m['rmse']:.4f}")
print(f"  MAE:       {val_m['mae']:.4f}")
print(f"  R2:        {val_m['r2']:.4f}")
print(f"  Recall:    {val_m['recall']:.1f}%")
print(f"  Precision: {val_m['precision']:.1f}%")

# Test
print("\n" + "="*60)
print("TEST SET RESULTS (UNSEEN DATA)")
print("="*60)
test_m, test_p, test_t = evaluate(model, test_loader, "Test")
print(f"  Spearman:  {test_m['spearman']:.4f}")
print(f"  Pearson:   {test_m['pearson']:.4f}")
print(f"  RMSE:      {test_m['rmse']:.4f}")
print(f"  MAE:       {test_m['mae']:.4f}")
print(f"  R2:        {test_m['r2']:.4f}")
print(f"  Recall:    {test_m['recall']:.1f}%")
print(f"  Precision: {test_m['precision']:.1f}%")

# Save predictions and results
pd.DataFrame({'true_pKd': val_t, 'pred_pKd': val_p}).to_csv(out_dir / 'val_predictions.csv', index=False)
pd.DataFrame({'true_pKd': test_t, 'pred_pKd': test_p}).to_csv(out_dir / 'test_predictions.csv', index=False)

with open(out_dir / 'final_results.json', 'w') as f:
    json.dump({
        'best_params': best_params,
        'config': {
            'mse_weight': CONFIG['mse_weight'],
            'class_weight': CONFIG['class_weight'],
            'use_cross_attention': CONFIG['use_cross_attention'],
        },
        'validation': {k: float(v) for k, v in val_m.items()},
        'test': {k: float(v) for k, v in test_m.items()},
        'architecture': 'IgT5 + ProtT5-XL (v2.8)'
    }, f, indent=2)

print(f"\nAll results saved to: {out_dir}")
print("\n" + "="*60)
print("DONE! ProtT5 v2.8 Training Complete")
print("="*60)