# üß¨ Antibody-Antigen Binding Prediction - OPTIMIZED Training

## A100 + Google Drive + ESM-2 3B + Optuna Optimization

**This notebook includes all optimizations:**
- ‚úÖ **Google Drive integration** - Auto-loads data, auto-saves results
- ‚úÖ **A100-80GB optimized** - TF32, large batches, optimized memory
- ‚úÖ **ESM-2 3B model** - State-of-the-art protein encoder
- ‚úÖ **Optuna hyperparameter optimization** - Find optimal training settings
- ‚úÖ **Fixed training issues** - Proper warmup, patience, validation

**Key improvements over COMPLETE version:**
- üîß **Full validation** for early stopping (not 5% sample)
- üîß **Reduced warmup** to prevent early stopping during warmup
- üîß **ReduceLROnPlateau** for dynamic LR adjustment
- üîß **Optuna search** for LR, dropout, batch size
- üîß **Better patience** settings

**Expected performance:**
- Test Spearman: **0.42-0.50**
- Training time: ~1-2 hours (with optimization)

---

# Step 1: Environment Setup

In [None]:
# Check GPU - should be A100
import torch
import sys

print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"GPU Memory: {gpu_memory:.2f} GB")
    device = torch.device('cuda')
    
    if 'A100' in gpu_name:
        print("\n‚úÖ A100 GPU detected! Optimizations will be enabled.")
    else:
        print(f"\n‚ö†Ô∏è WARNING: Expected A100 but got {gpu_name}")
else:
    print("‚ö†Ô∏è WARNING: GPU not available!")
    device = torch.device('cpu')

print(f"\nUsing device: {device}")

In [None]:
# Install required packages
print("Installing required packages...\n")

!pip install -q transformers>=4.41.0
!pip install -q sentencepiece
!pip install -q optuna

print("\n‚úÖ All packages installed successfully!")
print("‚úÖ Using Colab's pre-installed numpy, pandas, scikit-learn, scipy")

In [None]:
# Enable A100-specific optimizations
import torch

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.set_float32_matmul_precision('high')

print("‚úÖ A100 optimizations enabled")

# Step 2: Import Libraries & Utilities

In [None]:
# Core imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import json
import os
from tqdm.auto import tqdm
import time

# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Transformers
from transformers import (
    T5Tokenizer, T5EncoderModel,
    AutoTokenizer, AutoModel
)

# Scikit-learn
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy import stats

# Optuna
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler

print("‚úÖ All libraries imported successfully!")
print(f"Optuna version: {optuna.__version__}")

In [None]:
# Comprehensive metrics function
def compute_comprehensive_metrics(targets, predictions):
    """Compute all standard metrics"""
    mse = mean_squared_error(targets, predictions)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(targets, predictions)
    r2 = r2_score(targets, predictions)
    
    spearman, spearman_p = stats.spearmanr(targets, predictions)
    pearson, pearson_p = stats.pearsonr(targets, predictions)
    
    strong_binders = targets >= 9.0
    predicted_strong = predictions >= 9.0
    
    tp = np.sum(strong_binders & predicted_strong)
    fp = np.sum(~strong_binders & predicted_strong)
    tn = np.sum(~strong_binders & ~predicted_strong)
    fn = np.sum(strong_binders & ~predicted_strong)
    
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
    
    return {
        'mse': mse, 'rmse': rmse, 'mae': mae, 'r2': r2,
        'spearman': spearman, 'spearman_p': spearman_p,
        'pearson': pearson, 'pearson_p': pearson_p,
        'recall_pkd9': recall * 100, 'precision_pkd9': precision * 100,
        'f1_pkd9': f1 * 100, 'specificity_pkd9': specificity * 100,
        'n_samples': len(targets), 'n_strong_binders': int(strong_binders.sum())
    }

# Early Stopping with better tracking
class EarlyStopping:
    def __init__(self, patience=15, min_delta=0.001, mode='max', verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_epoch = 0
    
    def __call__(self, score, epoch):
        if self.best_score is None:
            self.best_score = score
            self.best_epoch = epoch
            return False
        
        if self.mode == 'max':
            improved = score > (self.best_score + self.min_delta)
        else:
            improved = score < (self.best_score - self.min_delta)
        
        if improved:
            self.best_score = score
            self.best_epoch = epoch
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"   No improvement for {self.counter}/{self.patience} epochs")
            
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print(f"\n‚ö†Ô∏è Early stopping triggered!")
                    print(f"   Best score: {self.best_score:.4f} at epoch {self.best_epoch+1}")
                return True
        return False

# LR Scheduler with shorter warmup
def get_warmup_cosine_scheduler(optimizer, warmup_epochs, total_epochs):
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return float(epoch + 1) / float(max(1, warmup_epochs))  # Start from 1/warmup, not 0
        progress = float(epoch - warmup_epochs) / float(max(1, total_epochs - warmup_epochs))
        return max(0.1, 0.5 * (1.0 + np.cos(np.pi * progress)))  # Min 10% of LR
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# ReduceLROnPlateau wrapper
class ReduceLROnPlateauWrapper:
    def __init__(self, optimizer, factor=0.5, patience=5, min_lr=1e-6):
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=factor, patience=patience, 
            min_lr=min_lr, verbose=True
        )
    
    def step(self, val_score):
        self.scheduler.step(val_score)

# Focal MSE Loss
class FocalMSELoss(nn.Module):
    def __init__(self, gamma=2.0, label_smoothing=0.0):
        super().__init__()
        self.gamma = gamma
        self.label_smoothing = label_smoothing
    
    def forward(self, pred, target):
        if self.label_smoothing > 0:
            target_mean = target.mean()
            target = (1 - self.label_smoothing) * target + self.label_smoothing * target_mean
        mse = (pred - target) ** 2
        focal_weight = (1 + mse) ** self.gamma
        return (focal_weight * mse).mean()

print("‚úÖ Utility functions defined")

# Step 3: Mount Google Drive & Load Data

In [None]:
# Mount Google Drive
from google.colab import drive

print("Mounting Google Drive...")
drive.mount('/content/drive')
print("‚úÖ Google Drive mounted!")

# Set up paths
DRIVE_DIR = '/content/drive/MyDrive/AbAg_Training_02'
OUTPUT_DIR = f'{DRIVE_DIR}/training_output_OPTIMIZED'

os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"\nüìÇ Working directories:")
print(f"   Data directory: {DRIVE_DIR}")
print(f"   Output directory: {OUTPUT_DIR}")

In [None]:
# Load dataset
CSV_FILENAME = 'agab_phase2_full.csv'  # ‚Üê CHANGE THIS to your filename

csv_path = os.path.join(DRIVE_DIR, CSV_FILENAME)
df = pd.read_csv(csv_path)

print(f"\nüìä Dataset: {len(df):,} samples")
print(f"   pKd range: {df['pKd'].min():.2f} - {df['pKd'].max():.2f}")
print(f"   Strong binders (‚â•9): {(df['pKd']>=9).sum():,} ({100*(df['pKd']>=9).sum()/len(df):.1f}%)")

In [None]:
# Split data
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

print("\nüìä Dataset splits:")
print(f"   Train:  {len(train_df):,}")
print(f"   Val:    {len(val_df):,}")
print(f"   Test:   {len(test_df):,}")

In [None]:
# Dataset and DataLoader
class AbAgDataset(Dataset):
    def __init__(self, dataframe):
        self.data = dataframe.reset_index(drop=True)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        return {
            'antibody_seqs': row['antibody_sequence'],
            'antigen_seqs': row['antigen_sequence'],
            'pKd': torch.tensor(row['pKd'], dtype=torch.float32)
        }

def collate_fn(batch):
    return {
        'antibody_seqs': [item['antibody_seqs'] for item in batch],
        'antigen_seqs': [item['antigen_seqs'] for item in batch],
        'pKd': torch.stack([item['pKd'] for item in batch])
    }

# Create datasets
train_dataset = AbAgDataset(train_df)
val_dataset = AbAgDataset(val_df)  # FULL validation set
test_dataset = AbAgDataset(test_df)

print("‚úÖ Datasets created")

# Step 4: Model Architecture (ESM-2 3B)

In [None]:
# Model with ESM-2 3B
class IgT5ESM2_3B_Model(nn.Module):
    def __init__(self, dropout=0.3, freeze_encoders=True, use_checkpointing=True):
        super().__init__()
        
        print("üî® Building model with ESM-2 3B...")
        
        # IgT5 for antibodies
        print("  üì• Loading IgT5...")
        self.igt5_tokenizer = T5Tokenizer.from_pretrained("Exscientia/IgT5")
        self.igt5_model = T5EncoderModel.from_pretrained("Exscientia/IgT5")
        
        # ESM-2 3B for antigens
        print("  üì• Loading ESM-2 3B (this will take a moment)...")
        self.esm2_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
        self.esm2_model = AutoModel.from_pretrained("facebook/esm2_t36_3B_UR50D")
        print("  ‚úÖ ESM-2 3B loaded!")
        
        # Freeze encoders
        if freeze_encoders:
            for param in self.igt5_model.parameters():
                param.requires_grad = False
            for param in self.esm2_model.parameters():
                param.requires_grad = False
        
        # Gradient checkpointing
        if use_checkpointing:
            self.igt5_model.gradient_checkpointing_enable()
            self.esm2_model.gradient_checkpointing_enable()
        
        # Dimensions
        self.igt5_dim = self.igt5_model.config.d_model  # 512
        self.esm2_dim = self.esm2_model.config.hidden_size  # 2560
        self.combined_dim = self.igt5_dim + self.esm2_dim  # 3072
        
        # Regression head
        self.regression_head = nn.Sequential(
            nn.Linear(self.combined_dim, 1536),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(1536),
            
            nn.Linear(1536, 768),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(768),
            
            nn.Linear(768, 384),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(384),
            
            nn.Linear(384, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(128),
            
            nn.Linear(128, 1)
        )
        
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"  üìä Trainable parameters: {trainable_params/1e6:.1f}M")
    
    def forward(self, antibody_seqs, antigen_seqs, device):
        antibody_tokens = self.igt5_tokenizer(
            antibody_seqs, return_tensors='pt', padding=True,
            truncation=True, max_length=512
        ).to(device)
        
        antigen_tokens = self.esm2_tokenizer(
            antigen_seqs, return_tensors='pt', padding=True,
            truncation=True, max_length=2048
        ).to(device)
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            antibody_outputs = self.igt5_model(**antibody_tokens)
            antibody_embedding = antibody_outputs.last_hidden_state.mean(dim=1)
            
            antigen_outputs = self.esm2_model(**antigen_tokens)
            antigen_embedding = antigen_outputs.last_hidden_state.mean(dim=1)
            
            combined = torch.cat([antibody_embedding, antigen_embedding], dim=1)
            pKd_pred = self.regression_head(combined).squeeze(-1)
        
        return pKd_pred

print("‚úÖ Model class defined")

# Step 5: Optuna Hyperparameter Optimization

**Search space:**
- Learning rate: 1e-4 to 5e-3
- Dropout: 0.1 to 0.5
- Batch size: 32, 48, 64
- Warmup epochs: 1-3
- Label smoothing: 0-0.1

**Quick optimization:** 10 trials (adjust for more thorough search)

In [None]:
# Training functions
def train_epoch(model, loader, optimizer, criterion, device, max_grad_norm):
    model.train()
    total_loss = 0
    
    for batch in loader:
        antibody_seqs = batch['antibody_seqs']
        antigen_seqs = batch['antigen_seqs']
        targets = batch['pKd'].to(device)
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            predictions = model(antibody_seqs, antigen_seqs, device)
            loss = criterion(predictions, targets)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(loader)

def eval_model(model, loader, device):
    model.eval()
    predictions = []
    targets = []
    
    with torch.no_grad():
        for batch in loader:
            antibody_seqs = batch['antibody_seqs']
            antigen_seqs = batch['antigen_seqs']
            batch_targets = batch['pKd'].to(device)
            
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                batch_predictions = model(antibody_seqs, antigen_seqs, device)
            
            predictions.extend(batch_predictions.float().cpu().numpy())
            targets.extend(batch_targets.float().cpu().numpy())
    
    predictions = np.array(predictions)
    targets = np.array(targets)
    metrics = compute_comprehensive_metrics(targets, predictions)
    return metrics, predictions, targets

print("‚úÖ Training functions defined")

In [None]:
# Optuna objective function
def objective(trial):
    """Optuna objective - maximize Spearman correlation"""
    
    # Hyperparameter search space
    lr = trial.suggest_float('learning_rate', 5e-4, 3e-3, log=True)
    dropout = trial.suggest_float('dropout', 0.2, 0.4)
    batch_size = trial.suggest_categorical('batch_size', [32, 48, 64])
    warmup_epochs = trial.suggest_int('warmup_epochs', 1, 3)
    label_smoothing = trial.suggest_float('label_smoothing', 0.0, 0.1)
    weight_decay = trial.suggest_float('weight_decay', 0.005, 0.02, log=True)
    
    # Create model
    model = IgT5ESM2_3B_Model(dropout=dropout, freeze_encoders=True)
    model = model.to(device)
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=4, collate_fn=collate_fn, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=4, collate_fn=collate_fn, pin_memory=True)
    
    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
        fused=True
    )
    
    # Scheduler
    scheduler = get_warmup_cosine_scheduler(optimizer, warmup_epochs, 20)
    
    # Loss
    criterion = FocalMSELoss(gamma=2.0, label_smoothing=label_smoothing)
    
    # Training (shorter for optimization)
    best_spearman = -1
    patience_counter = 0
    
    for epoch in range(20):  # Max 20 epochs for optimization
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device, 1.0)
        
        # Validate
        val_metrics, _, _ = eval_model(model, val_loader, device)
        val_spearman = val_metrics['spearman']
        
        if val_spearman > best_spearman:
            best_spearman = val_spearman
            patience_counter = 0
        else:
            patience_counter += 1
        
        # Early stopping for optimization
        if patience_counter >= 5:
            break
        
        # Pruning
        trial.report(val_spearman, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
        
        scheduler.step()
    
    # Cleanup
    del model, optimizer, train_loader, val_loader
    torch.cuda.empty_cache()
    
    return best_spearman

print("‚úÖ Optuna objective function defined")

In [None]:
# Run hyperparameter optimization
print("="*70)
print("üîç OPTUNA HYPERPARAMETER OPTIMIZATION")
print("="*70)

# Create study
study = optuna.create_study(
    direction='maximize',
    sampler=TPESampler(seed=42),
    pruner=MedianPruner(n_startup_trials=3, n_warmup_steps=5)
)

# Number of trials (increase for better optimization)
N_TRIALS = 10  # Adjust: 10 for quick, 20-30 for thorough

print(f"\nRunning {N_TRIALS} optimization trials...")
print("This will take ~30-60 minutes\n")

study.optimize(objective, n_trials=N_TRIALS, show_progress_bar=True)

print("\n" + "="*70)
print("‚úÖ OPTIMIZATION COMPLETE")
print("="*70)
print(f"\nüèÜ Best Spearman: {study.best_trial.value:.4f}")
print(f"\nBest hyperparameters:")
for key, value in study.best_params.items():
    print(f"   {key}: {value}")

# Step 6: Train Final Model with Best Hyperparameters

In [None]:
# Get best hyperparameters
best_params = study.best_params

# Final configuration
config = {
    'epochs': 50,
    'batch_size': best_params['batch_size'],
    'lr': best_params['learning_rate'],
    'weight_decay': best_params['weight_decay'],
    'dropout': best_params['dropout'],
    'warmup_epochs': best_params['warmup_epochs'],
    'label_smoothing': best_params['label_smoothing'],
    'early_stopping_patience': 15,
    'reduce_lr_patience': 5,
    'max_grad_norm': 1.0
}

print("="*70)
print("üìã FINAL TRAINING CONFIGURATION")
print("="*70)
for key, value in config.items():
    print(f"   {key}: {value}")

In [None]:
# Create final model
print("\nBuilding final model with optimized hyperparameters...\n")

model = IgT5ESM2_3B_Model(
    dropout=config['dropout'],
    freeze_encoders=True,
    use_checkpointing=True
)
model = model.to(device)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True,
                          num_workers=4, collate_fn=collate_fn, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False,
                        num_workers=4, collate_fn=collate_fn, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False,
                         num_workers=4, collate_fn=collate_fn, pin_memory=True)

print(f"\n‚úÖ DataLoaders created:")
print(f"   ‚Ä¢ train: {len(train_loader)} batches")
print(f"   ‚Ä¢ val: {len(val_loader)} batches")
print(f"   ‚Ä¢ test: {len(test_loader)} batches")

In [None]:
# Optimizer and schedulers
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['lr'],
    weight_decay=config['weight_decay'],
    fused=True
)

# Warmup + Cosine scheduler
scheduler = get_warmup_cosine_scheduler(
    optimizer,
    warmup_epochs=config['warmup_epochs'],
    total_epochs=config['epochs']
)

# ReduceLROnPlateau (additional)
reduce_lr = ReduceLROnPlateauWrapper(
    optimizer,
    factor=0.5,
    patience=config['reduce_lr_patience'],
    min_lr=1e-6
)

# Loss
criterion = FocalMSELoss(gamma=2.0, label_smoothing=config['label_smoothing'])

# Early stopping
early_stopping = EarlyStopping(
    patience=config['early_stopping_patience'],
    min_delta=0.001,
    mode='max'
)

print("‚úÖ Training components configured")

In [None]:
# Main training loop
print("="*70)
print("üöÄ STARTING FINAL TRAINING")
print("="*70)

model_save_path = os.path.join(OUTPUT_DIR, 'best_model.pth')
best_spearman = -1
training_history = {'train_loss': [], 'val_spearman': [], 'val_recall': [], 'epoch': [], 'time_per_epoch': []}

print(f"\nTarget: Test Spearman > 0.42\n")

for epoch in range(config['epochs']):
    start_time = time.time()
    print(f"\nEpoch {epoch+1}/{config['epochs']}")
    print("-"*70)
    
    # Train
    model.train()
    total_loss = 0
    pbar = tqdm(train_loader, desc=f"Training")
    
    for batch in pbar:
        antibody_seqs = batch['antibody_seqs']
        antigen_seqs = batch['antigen_seqs']
        targets = batch['pKd'].to(device)
        
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            predictions = model(antibody_seqs, antigen_seqs, device)
            loss = criterion(predictions, targets)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
        optimizer.step()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    train_loss = total_loss / len(train_loader)
    epoch_time = time.time() - start_time
    
    # Validate on FULL validation set
    val_metrics, _, _ = eval_model(model, val_loader, device)
    val_spearman = val_metrics['spearman']
    val_recall = val_metrics['recall_pkd9']
    
    print(f"Train Loss: {train_loss:.4f} | Val Spearman: {val_spearman:.4f} | Recall: {val_recall:.1f}% | Time: {epoch_time:.1f}s")
    
    # Save best
    if val_spearman > best_spearman:
        best_spearman = val_spearman
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_spearman': val_spearman,
            'config': config,
            'best_params': best_params
        }, model_save_path)
        print(f"‚úÖ Saved best model (Spearman: {val_spearman:.4f})")
    
    # Early stopping
    if early_stopping(val_spearman, epoch):
        print(f"\n‚õî Early stopping at epoch {epoch+1}")
        break
    
    # Update schedulers
    scheduler.step()
    reduce_lr.step(val_spearman)
    
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Learning Rate: {current_lr:.6f}")
    
    # Record history
    training_history['train_loss'].append(train_loss)
    training_history['val_spearman'].append(val_spearman)
    training_history['val_recall'].append(val_recall)
    training_history['epoch'].append(epoch + 1)
    training_history['time_per_epoch'].append(epoch_time)

total_time = sum(training_history['time_per_epoch'])
print(f"\n{'='*70}")
print(f"TRAINING COMPLETE!")
print(f"Best Validation Spearman: {best_spearman:.4f}")
print(f"Total training time: {total_time/60:.1f} minutes")
print(f"{'='*70}")

# Step 7: Final Evaluation

In [None]:
# Load best model and evaluate
import torch
import numpy as np
import pandas as pd
import json
import os

print("="*70)
print("FINAL EVALUATION")
print("="*70)

checkpoint = torch.load(model_save_path, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"\n‚úÖ Loaded best model from epoch {checkpoint['epoch']+1}")

# Validation
print("\nEvaluating on validation set...")
val_metrics, val_preds, val_targets = eval_model(model, val_loader, device)

print(f"\nüìä VALIDATION METRICS:")
print(f"   Spearman: {val_metrics['spearman']:.4f}")
print(f"   RMSE:     {val_metrics['rmse']:.4f}")
print(f"   MAE:      {val_metrics['mae']:.4f}")
print(f"   R¬≤:       {val_metrics['r2']:.4f}")
print(f"   Recall:   {val_metrics['recall_pkd9']:.2f}%")

# Test
print("\nEvaluating on TEST set...")
test_metrics, test_preds, test_targets = eval_model(model, test_loader, device)

print(f"\nüìä TEST METRICS (TRUE PERFORMANCE):")
print(f"   Spearman: {test_metrics['spearman']:.4f} ‚Üê FINAL RESULT")
print(f"   RMSE:     {test_metrics['rmse']:.4f}")
print(f"   MAE:      {test_metrics['mae']:.4f}")
print(f"   R¬≤:       {test_metrics['r2']:.4f}")
print(f"   Recall:   {test_metrics['recall_pkd9']:.2f}%")

# Save results
val_results = pd.DataFrame({
    'true_pKd': val_targets, 'pred_pKd': val_preds,
    'error': val_preds - val_targets, 'abs_error': np.abs(val_preds - val_targets)
})
val_results.to_csv(os.path.join(OUTPUT_DIR, 'val_predictions.csv'), index=False)

test_results = pd.DataFrame({
    'true_pKd': test_targets, 'pred_pKd': test_preds,
    'error': test_preds - test_targets, 'abs_error': np.abs(test_preds - test_targets)
})
test_results.to_csv(os.path.join(OUTPUT_DIR, 'test_predictions.csv'), index=False)

# Save all metrics
try:
    training_time = total_time / 60
except NameError:
    training_time = None

all_metrics = {
    'model': 'IgT5 + ESM-2 3B (Optuna Optimized)',
    'gpu': 'A100-80GB',
    'best_params': best_params,
    'validation': {k: float(v) if isinstance(v, (np.floating, np.integer)) else v
                   for k, v in val_metrics.items()},
    'test': {k: float(v) if isinstance(v, (np.floating, np.integer)) else v
            for k, v in test_metrics.items()},
    'config': config,
    'training_time_minutes': training_time,
    'optuna_trials': len(study.trials)
}

with open(os.path.join(OUTPUT_DIR, 'final_metrics.json'), 'w') as f:
    json.dump(all_metrics, f, indent=2)

print(f"\n‚úÖ All results saved to: {OUTPUT_DIR}")

In [None]:
# Visualization
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Training curves
ax1 = axes[0]
ax1.plot(training_history['epoch'], training_history['val_spearman'], 'g-o', linewidth=2)
ax1.axhline(y=best_spearman, color='r', linestyle='--', label=f'Best: {best_spearman:.4f}')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Validation Spearman')
ax1.set_title('Training Progress (Optuna Optimized)', fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Test predictions
ax2 = axes[1]
ax2.scatter(test_targets, test_preds, alpha=0.3, s=10, color='orange')
ax2.plot([4, 14], [4, 14], 'r--', linewidth=2, label='Perfect')
ax2.set_xlabel('True pKd')
ax2.set_ylabel('Predicted pKd')
ax2.set_title(f'Test Set\nSpearman: {test_metrics["spearman"]:.4f}', fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_xlim(4, 14)
ax2.set_ylim(4, 14)

# Optuna optimization history
ax3 = axes[2]
trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
trial_values = [t.value for t in trials]
ax3.plot(range(1, len(trial_values)+1), trial_values, 'b-o', linewidth=2)
ax3.axhline(y=study.best_trial.value, color='r', linestyle='--', 
            label=f'Best: {study.best_trial.value:.4f}')
ax3.set_xlabel('Trial')
ax3.set_ylabel('Validation Spearman')
ax3.set_title('Optuna Optimization History', fontweight='bold')
ax3.legend()
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'results_summary.png'), dpi=300, bbox_inches='tight')
plt.show()

print("‚úÖ Visualization saved")

# üéâ Training Complete!

## Summary

**Optimizations applied:**
- ‚úÖ Optuna hyperparameter search
- ‚úÖ Full validation set (not 5% sample)
- ‚úÖ Reduced warmup epochs
- ‚úÖ ReduceLROnPlateau + Cosine annealing
- ‚úÖ Better early stopping patience

**Results saved to:** `Google Drive/AbAg_Training_02/training_output_OPTIMIZED/`

---

**To improve further:**
- Increase `N_TRIALS` to 20-30 for more thorough optimization
- Try different loss functions (Huber, custom weighted loss)
- Ensemble multiple models

---

**Happy modeling! üß¨üöÄ**