# AbAg Binding Affinity Prediction - IMPROVED Training v2

**Improvements over v1:**
- ✅ GELU activation (instead of ReLU) - smoother gradients
- ✅ Deeper architecture: 150 → 512 → 256 → 128 → 64 → 1 (vs 150 → 256 → 128 → 1)
- ✅ 10x stronger weights for very strong/weak binders
- ✅ Lower learning rate: 0.0001 (vs 0.001) - more stable training
- ✅ Focal loss option - focuses on hard examples
- ✅ Gradient clipping - prevents exploding gradients
- ✅ Better initialization - Xavier (Glorot) uniform initialization

**Expected Results:**
- Overall RMSE: 1.48 → 0.8-1.0
- Very strong RMSE: 2.94 → 1.0-1.5 (roughly a 50-66% reduction)
- Spearman ρ: 0.39 → 0.65-0.75

---

**Dataset:** 390,757 samples (330,762 with features)
**Training time:** ~10-12 hours on T4 GPU (100 epochs)

## 1. Setup - GPU and Dependencies

In [None]:
# Report the PyTorch build and, if CUDA is usable, the first GPU's name and memory.
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("‚ö†Ô∏è WARNING: No GPU detected! Enable GPU in Runtime settings.")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install dependencies
# `!pip` is an IPython shell escape; -q suppresses pip's progress output.
# NOTE(review): a failed install does not stop the cell -- the success message
# below is printed unconditionally. `%pip install` is the IPython magic that
# targets the running kernel's environment; consider it if this runs outside Colab.
# NOTE(review): `transformers` is not imported by any cell visible here.
!pip install -q transformers scikit-learn pandas numpy tqdm matplotlib seaborn
print("‚úÖ All dependencies installed!")

## 2. Mount Google Drive

In [None]:
# Attach Google Drive under /content/drive (Colab only); later cells read data from it.
from google.colab import drive

drive.mount(mountpoint='/content/drive')
print("‚úÖ Google Drive mounted!")

In [None]:
import os
import shutil
from pathlib import Path

# Set up paths - MODIFY THIS to match your Google Drive location
DRIVE_DATA_PATH = "/content/drive/MyDrive/AbAg_data/merged_with_all_features.csv"
OUTPUT_DIR = "/content/drive/MyDrive/AbAg_data/models_v2"

# Create output directory (no-op if it already exists)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Training reads a local copy, which is faster than the Drive mount.
LOCAL_DATA_PATH = "/content/merged_with_all_features.csv"

print(f"Data path: {DRIVE_DATA_PATH}")
print(f"Output directory: {OUTPUT_DIR}")

if not os.path.exists(DRIVE_DATA_PATH):
    print(f"‚ùå Data file not found at: {DRIVE_DATA_PATH}")
    print("\nPlease upload 'merged_with_all_features.csv' to your Google Drive.")
    # Stop here: otherwise DATA_PATH stays undefined and the loading cell
    # fails later with a confusing NameError.
    raise FileNotFoundError(DRIVE_DATA_PATH)

print(f"‚úÖ Data file found! Size: {os.path.getsize(DRIVE_DATA_PATH) / 1e6:.1f} MB")

print("Copying data to local storage for faster I/O...")
# shutil.copy2 raises on failure; a `!cp` shell escape's non-zero exit status
# is ignored by the notebook, which would then report success anyway.
shutil.copy2(DRIVE_DATA_PATH, LOCAL_DATA_PATH)
print("‚úÖ Data copied to local storage!")

DATA_PATH = LOCAL_DATA_PATH

## 3. Imports and Constants

In [None]:
# Imports
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import spearmanr, pearsonr
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import time

# Reproducibility: seed NumPy and PyTorch (CPU, plus every CUDA device if present).
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# pKd bin edges; BIN_LABELS[i] names the interval between BINS[i] and BINS[i+1].
BINS = [0, 5, 7, 9, 11, 16]
BIN_LABELS = [
    'very_weak',
    'weak',
    'moderate',
    'strong',
    'very_strong',
]

print("‚úÖ Imports complete!")

## 4. Dataset Class

In [None]:
class AffinityDataset(Dataset):
    """Map-style dataset pairing each feature row with its scalar target.

    Args:
        features: Array-like of feature rows (in this notebook, the
            ``(n_samples, 150)`` PCA matrix). Stored as a float32 tensor.
        labels: Array-like of ``n_samples`` targets. Stored as float32.

    Raises:
        ValueError: If ``features`` and ``labels`` differ in length.
    """

    def __init__(self, features, labels):
        # torch.tensor always copies, as the legacy FloatTensor constructor
        # did, so later edits to the source arrays do not leak in.
        self.features = torch.tensor(features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)
        # Fail at construction rather than with an IndexError mid-epoch.
        if len(self.features) != len(self.labels):
            raise ValueError(
                f"features has {len(self.features)} rows but labels has "
                f"{len(self.labels)} entries"
            )

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(feature_row, label)`` for sample ``idx``."""
        return self.features[idx], self.labels[idx]

print("‚úÖ Dataset class defined!")

## 5. IMPROVED Model Architecture (GELU + Deeper)

In [None]:
class ImprovedAffinityPredictor(nn.Module):
    """MLP regressor for pKd built from GELU hidden blocks.

    Each hidden width ``h`` contributes ``Linear -> BatchNorm1d(h) -> GELU ->
    Dropout``. A final ``Linear(prev, 1)`` produces one value per sample.

    Args:
        input_dim: Number of input features per sample.
        hidden_dims: Hidden-layer widths, in order. ``None`` (the default)
            means ``(512, 256, 128, 64)``.
        dropout: Dropout probability for every hidden block except the last,
            which uses ``dropout * 0.5``.
    """

    def __init__(self, input_dim=150, hidden_dims=None, dropout=0.3):
        super().__init__()
        # None sentinel replaces the former mutable list default, which would
        # be a single list object shared by every instantiation.
        if hidden_dims is None:
            hidden_dims = (512, 256, 128, 64)
        hidden_dims = list(hidden_dims)
        last_index = len(hidden_dims) - 1

        layers = []
        prev_dim = input_dim

        for i, hidden_dim in enumerate(hidden_dims):
            linear = nn.Linear(prev_dim, hidden_dim)

            # Xavier (Glorot) uniform weights and zero biases.
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

            layers.extend([
                linear,
                nn.BatchNorm1d(hidden_dim),
                nn.GELU(),
                # Only the last hidden block gets the halved dropout rate.
                nn.Dropout(dropout if i < last_index else dropout * 0.5),
            ])
            prev_dim = hidden_dim

        # Output layer: a single regression value per sample.
        output_layer = nn.Linear(prev_dim, 1)
        nn.init.xavier_uniform_(output_layer.weight)
        nn.init.zeros_(output_layer.bias)
        layers.append(output_layer)

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(batch, input_dim)`` tensor to a ``(batch,)`` tensor of predictions."""
        # squeeze(-1) drops only the trailing size-1 output dimension. A bare
        # squeeze() also drops the batch dimension when batch == 1, returning
        # a 0-d tensor that breaks loss shapes and downstream extend() calls.
        return self.network(x).squeeze(-1)

print("‚úÖ Improved model class defined!")
print("   - GELU activation (smoother than ReLU)")
print("   - Deeper: 512 ‚Üí 256 ‚Üí 128 ‚Üí 64")
print("   - Xavier initialization")
print("   - Adaptive dropout")

## 6. Loss Functions (Weighted MSE + Focal Loss)

In [None]:
def _bin_weights_for(targets, bins, bin_weights):
    """Return a tensor shaped like ``targets`` holding each sample's bin weight.

    Intervals are right-closed, ``(low, high]``, and the first one also
    includes its lower edge. This matches the ``pd.cut(..., include_lowest=True)``
    calls that label the data. Targets outside ``[bins[0], bins[-1]]`` keep
    weight 1.

    Args:
        targets: Tensor of regression targets.
        bins: Ascending bin edges (``len(bin_weights) + 1`` values).
        bin_weights: Indexable per-bin weights (e.g. a 1-D tensor).
    """
    weights = torch.ones_like(targets)
    for i, (low, high) in enumerate(zip(bins[:-1], bins[1:])):
        above_low = (targets >= low) if i == 0 else (targets > low)
        weights[above_low & (targets <= high)] = bin_weights[i]
    return weights


class WeightedMSELoss(nn.Module):
    """MSE with each sample scaled by the weight of its affinity bin.

    Args:
        bin_weights: Per-bin weights, indexed by bin position.
        bins_edges: Ascending bin edges. Bin ``i`` is ``(edges[i], edges[i+1]]``,
            with the first bin also closed on the left.
    """

    def __init__(self, bin_weights, bins_edges):
        super().__init__()
        self.bin_weights = bin_weights
        self.bins = bins_edges

    def forward(self, predictions, targets):
        """Return ``mean(w * (predictions - targets) ** 2)`` as a scalar tensor."""
        weights = _bin_weights_for(targets, self.bins, self.bin_weights)
        mse = (predictions - targets) ** 2
        return (mse * weights).mean()


class FocalMSELoss(nn.Module):
    """Bin-weighted MSE whose per-sample term grows faster for large errors.

    Per sample: ``w * e**2 * (1 + |e|**gamma)`` with ``e = prediction - target``.
    Poorly fit samples therefore dominate relative to plain weighted MSE.

    Args:
        bin_weights: Per-bin weights, indexed by bin position.
        bins_edges: Ascending bin edges (same convention as WeightedMSELoss).
        gamma: Exponent of the error-dependent factor.
    """

    def __init__(self, bin_weights, bins_edges, gamma=2.0):
        super().__init__()
        self.bin_weights = bin_weights
        self.bins = bins_edges
        self.gamma = gamma

    def forward(self, predictions, targets):
        """Return the mean focal-weighted, bin-weighted squared error."""
        weights = _bin_weights_for(targets, self.bins, self.bin_weights)

        mse = (predictions - targets) ** 2
        # (e**2) ** (gamma / 2) == |e| ** gamma. Gradients flow through it too.
        # NOTE(review): for gamma < 2 this power has an infinite derivative at
        # e == 0, which can produce NaN gradients for exact predictions.
        focal_weight = mse ** (self.gamma / 2)

        weighted_mse = mse * weights * (1 + focal_weight)
        return weighted_mse.mean()


print("‚úÖ Loss functions defined!")
print("   - WeightedMSELoss: Class-based weighting")
print("   - FocalMSELoss: Focuses on hard examples (gamma=2.0)")

## 7. Load and Prepare Data

In [None]:
print("Loading dataset...")
df = pd.read_csv(DATA_PATH, low_memory=False)
print(f"‚úÖ Loaded {len(df):,} samples")

# Keep only rows whose 150 PCA feature columns AND pKd label are all present.
# Checking just the first PCA column let rows with gaps elsewhere (or a
# missing pKd) through. One NaN in a batch makes that batch's loss and
# gradients NaN, which the optimizer step then writes into the weights.
pca_cols = [f'esm2_pca_{i}' for i in range(150)]
has_all_features = df[pca_cols].notna().all(axis=1)
has_label = df['pKd'].notna()
df_with_features = df[has_all_features & has_label].copy()
print(f"‚úÖ Samples with features: {len(df_with_features):,}")

# Create affinity bins (right-closed intervals; the lowest edge is included)
df_with_features['affinity_bin'] = pd.cut(
    df_with_features['pKd'], bins=BINS, labels=BIN_LABELS, include_lowest=True
)

# Show distribution
print("\nAffinity Distribution:")
for label in BIN_LABELS:
    count = (df_with_features['affinity_bin'] == label).sum()
    pct = count / len(df_with_features) * 100
    print(f"  {label:<15}: {count:6,} ({pct:5.2f}%)")

print(f"\nTotal: {len(df_with_features):,}")

In [None]:
# Extract features and labels as NumPy arrays
X = df_with_features[pca_cols].values
y = df_with_features['pKd'].values

# Train/val/test split (same as v1 for comparison). 15% goes to test, then
# 0.15/0.85 of the remainder (15% of the total) to validation, leaving ~70%.
X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.15, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=0.15/0.85, random_state=42)

print(f"Train set: {len(X_train):,}")
print(f"Val set: {len(X_val):,}")
print(f"Test set: {len(X_test):,}")

# Class weights: balanced inverse bin frequency, then 10x for the two extremes.
y_train_binned = pd.cut(y_train, bins=BINS, labels=BIN_LABELS, include_lowest=True)
bin_counts = y_train_binned.value_counts().sort_index()
total_samples = len(y_train)
bin_weights = {}

for label in BIN_LABELS:
    # A Categorical's value_counts() lists empty categories with count 0, so
    # .get()'s default alone would not prevent a division by zero (inf weight).
    count = max(int(bin_counts.get(label, 0)), 1)
    base_weight = total_samples / (len(BIN_LABELS) * count)

    # 10x stronger weight for very strong and very weak
    if label in ['very_strong', 'very_weak']:
        bin_weights[label] = base_weight * 10
    else:
        bin_weights[label] = base_weight

print("\nIMPROVED Class Weights (10x stronger for extremes):")
for label, weight in bin_weights.items():
    marker = "‚≠ê" if label in ['very_strong', 'very_weak'] else "  "
    print(f"  {marker} {label:<15}: {weight:.2f}")

# Weight tensor in BIN_LABELS order. Use the GPU only when present: a
# hard-coded .cuda() crashed CPU-only runtimes even though the training
# config falls back to 'cpu'.
bin_weights_tensor = torch.tensor(
    [bin_weights[l] for l in BIN_LABELS],
    dtype=torch.float32,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

# Create datasets
train_dataset = AffinityDataset(X_train, y_train)
val_dataset = AffinityDataset(X_val, y_val)
test_dataset = AffinityDataset(X_test, y_test)

print("\n‚úÖ Data preparation complete!")

## 8. Training Configuration

In [None]:
# Hyper-parameters and runtime switches for the v2 training run.
EPOCHS = 100
BATCH_SIZE = 128
LEARNING_RATE = 1e-4  # one tenth of v1's rate, for steadier optimisation
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'
USE_FOCAL_LOSS = True  # False -> weighted MSE
GRADIENT_CLIP = 1.0  # maximum global gradient norm per optimiser step

print("\n".join([
    "IMPROVED Configuration:",
    f"  Epochs: {EPOCHS}",
    f"  Batch size: {BATCH_SIZE}",
    f"  Learning rate: {LEARNING_RATE} (10x lower than v1)",
    f"  Device: {DEVICE}",
    f"  Loss: {'Focal MSE' if USE_FOCAL_LOSS else 'Weighted MSE'}",
    f"  Gradient clipping: {GRADIENT_CLIP}",
    "\n  Model: 150 ‚Üí 512 ‚Üí 256 ‚Üí 128 ‚Üí 64 ‚Üí 1",
    "  Activation: GELU (vs ReLU in v1)",
    "  Class weights: 10x stronger for extremes",
]))

## 9. Create Data Loaders

In [None]:
# Batch the three splits; only the training loader shuffles.
# drop_last=True on the training loader: in train mode BatchNorm1d raises
# "Expected more than 1 value per channel when training" on a one-sample
# batch, which the final partial batch could otherwise be. Because shuffling
# happens every epoch, a different handful of samples is skipped each time.
# pin_memory speeds host-to-GPU copies and is only useful on CUDA.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
                          drop_last=True, pin_memory=(DEVICE == 'cuda'))
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2,
                        pin_memory=(DEVICE == 'cuda'))
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2,
                         pin_memory=(DEVICE == 'cuda'))

print(f"‚úÖ Data loaders created!")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")

## 10. Initialize Model

In [None]:
# Initialize IMPROVED model
model = ImprovedAffinityPredictor(
    input_dim=150,
    hidden_dims=[512, 256, 128, 64],  # Deeper than v1 (256 -> 128)
    dropout=0.3
)
model = model.to(DEVICE)

# Loss: both variants weight each sample by its affinity bin (bin_weights_tensor).
if USE_FOCAL_LOSS:
    criterion = FocalMSELoss(bin_weights_tensor, BINS, gamma=2.0)
    print("Using Focal MSE Loss (focuses on hard examples)")
else:
    criterion = WeightedMSELoss(bin_weights_tensor, BINS)
    print("Using Weighted MSE Loss")

# AdamW applies weight decay decoupled from the adaptive gradient update.
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    # NOTE: 1e-4 is 100x below AdamW's default of 1e-2, i.e. light regularization.
    weight_decay=1e-4,
    betas=(0.9, 0.999)
)

# Cosine annealing with warm restarts. The training loop steps it once per epoch.
# With T_mult=2 the cycles last 20, 40, 80, ... epochs, so within 100 epochs
# the LR jumps back to LEARNING_RATE at epochs 20 and 60 (not every 20 epochs).
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=20,  # length of the first cycle, in scheduler steps (epochs here)
    T_mult=2,  # each later cycle is twice as long as the previous one
    eta_min=LEARNING_RATE * 0.01  # LR floor: 1% of the initial LR
)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n‚úÖ Model initialized!")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Optimizer: AdamW (better than Adam)")
print(f"  Scheduler: Cosine Annealing with Warm Restarts")
print(f"\nModel architecture:")
print(model)

## 11. Training Loop (with Gradient Clipping)

In [None]:
# Training loop. Each epoch runs one optimisation pass over train_loader, one
# no-grad pass over val_loader, a scheduler step, and checkpointing.
best_val_loss = float('inf')
train_losses = []
val_losses = []
learning_rates = []
train_start = time.time()


def _training_state(epoch, train_loss, val_loss):
    """Return the model/optimizer/scheduler snapshot shared by every checkpoint."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }


print("Starting IMPROVED training...")
print("="*70)

for epoch in range(EPOCHS):
    epoch_start = time.time()

    # Training
    model.train()
    train_loss_sum, train_count = 0.0, 0
    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")

    for features, labels in train_pbar:
        features, labels = features.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        predictions = model(features)
        loss = criterion(predictions, labels)
        loss.backward()

        # Clip the global gradient norm. error_if_nonfinite stops the run on a
        # NaN/inf gradient instead of silently writing NaN into the weights.
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP, error_if_nonfinite=True)

        optimizer.step()

        # criterion returns a batch mean. Weighting by batch size gives a true
        # per-sample average, so a short final batch is not over-counted.
        loss_value = loss.item()
        n_batch = labels.size(0)
        train_loss_sum += loss_value * n_batch
        train_count += n_batch
        train_pbar.set_postfix({'loss': f'{loss_value:.4f}'})

    train_loss = train_loss_sum / max(train_count, 1)
    train_losses.append(train_loss)

    # Validation
    model.eval()
    val_loss_sum, val_count = 0.0, 0
    val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]")

    with torch.no_grad():
        for features, labels in val_pbar:
            features, labels = features.to(DEVICE), labels.to(DEVICE)
            predictions = model(features)
            loss_value = criterion(predictions, labels).item()
            n_batch = labels.size(0)
            val_loss_sum += loss_value * n_batch
            val_count += n_batch
            val_pbar.set_postfix({'loss': f'{loss_value:.4f}'})

    val_loss = val_loss_sum / max(val_count, 1)
    val_losses.append(val_loss)

    # Record the LR used this epoch, then advance the schedule
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)
    scheduler.step()

    epoch_time = time.time() - epoch_start

    print(f"Epoch {epoch+1}/{EPOCHS} - Train: {train_loss:.4f}, Val: {val_loss:.4f}, LR: {current_lr:.6f}, Time: {epoch_time:.1f}s")

    # Save best model (lowest validation loss so far), with its config
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = _training_state(epoch, train_loss, val_loss)
        best_state['config'] = {
            'hidden_dims': [512, 256, 128, 64],
            'dropout': 0.3,
            'activation': 'GELU',
            'learning_rate': LEARNING_RATE,
            'batch_size': BATCH_SIZE,
            'focal_loss': USE_FOCAL_LOSS
        }
        torch.save(best_state, f'{OUTPUT_DIR}/best_model_v2.pth')
        print(f"  ‚úÖ New best model saved! (val_loss: {val_loss:.4f})")

    # Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(_training_state(epoch, train_loss, val_loss),
                   f'{OUTPUT_DIR}/checkpoint_v2_epoch_{epoch+1}.pth')
        print(f"  üíæ Checkpoint saved!")

total_time = time.time() - train_start
print(f"\n{'='*70}")
print(f"‚úÖ Training complete! Total time: {total_time/3600:.2f} hours")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"{'='*70}")

## 12. Plot Training Curves

In [None]:
# Two panels: train/val loss per epoch (left) and the learning rate per epoch
# on a log scale (right). The figure is saved to OUTPUT_DIR and displayed.
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
loss_ax, lr_ax = axes

loss_ax.plot(train_losses, label='Train Loss', linewidth=2)
loss_ax.plot(val_losses, label='Val Loss', linewidth=2)
loss_ax.set_xlabel('Epoch', fontsize=12)
loss_ax.set_ylabel('Loss', fontsize=12)
loss_ax.set_title('Training and Validation Loss', fontsize=14)
loss_ax.legend(fontsize=11)
loss_ax.grid(True, alpha=0.3)

lr_ax.plot(learning_rates, linewidth=2, color='green')
lr_ax.set_xlabel('Epoch', fontsize=12)
lr_ax.set_ylabel('Learning Rate', fontsize=12)
lr_ax.set_title('Learning Rate Schedule (Cosine Annealing)', fontsize=14)
lr_ax.grid(True, alpha=0.3)
lr_ax.set_yscale('log')

plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/training_curves_v2.png', dpi=300, bbox_inches='tight')
plt.show()
print("‚úÖ Training curves saved!")

## 13. Evaluation

In [None]:
# Reload the weights with the lowest validation loss (written by the training loop).
# map_location remaps stored tensors onto DEVICE, so the checkpoint still loads
# if the runtime changed since training (e.g. CUDA-saved, now CPU-only).
checkpoint = torch.load(f'{OUTPUT_DIR}/best_model_v2.pth', map_location=DEVICE)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"‚úÖ Best model loaded from epoch {checkpoint['epoch']+1}")
print(f"   Val loss: {checkpoint['val_loss']:.4f}")

In [None]:
# Run the loaded model over the test loader and compute global regression metrics.
model.eval()
test_predictions = []
test_targets = []

print("Running evaluation on test set...")
with torch.no_grad():
    for features, labels in tqdm(test_loader, desc="Testing"):
        features = features.to(DEVICE)
        predictions = model(features)
        # reshape(-1) keeps every batch 1-D, even if the model squeezed a
        # one-sample batch to a 0-d tensor (a 0-d array cannot be iterated).
        test_predictions.append(predictions.reshape(-1).cpu().numpy())
        test_targets.append(labels.reshape(-1).numpy())

test_predictions = np.concatenate(test_predictions)
test_targets = np.concatenate(test_targets)

# Calculate metrics
mse = mean_squared_error(test_targets, test_predictions)
rmse = np.sqrt(mse)
mae = mean_absolute_error(test_targets, test_predictions)
spearman = spearmanr(test_targets, test_predictions)[0]
pearson = pearsonr(test_targets, test_predictions)[0]
# Coefficient of determination: 1 - SS_res / SS_tot
r2 = 1 - (np.sum((test_targets - test_predictions)**2) / np.sum((test_targets - test_targets.mean())**2))

print("="*70)
print("TEST SET PERFORMANCE (v2 IMPROVED)")
print("="*70)
print(f"RMSE:        {rmse:.4f}")
print(f"MAE:         {mae:.4f}")
print(f"Spearman œÅ:  {spearman:.4f}")
print(f"Pearson r:   {pearson:.4f}")
print(f"R¬≤:          {r2:.4f}")
print("="*70)

In [None]:
# Break test-set error down by pKd bin, using the same bins as training.
test_df = pd.DataFrame({'target': test_targets, 'prediction': test_predictions})
test_df['affinity_bin'] = pd.cut(test_df['target'], bins=BINS, labels=BIN_LABELS, include_lowest=True)

print("\nPER-BIN PERFORMANCE:")
print("=" * 70)
print(f"{'Bin':<15} | {'Count':<8} | {'RMSE':<8} | {'MAE':<8}")
print("-" * 70)

for label in BIN_LABELS:
    bin_data = test_df[test_df['affinity_bin'] == label]
    if bin_data.empty:
        continue  # nothing to score for this bin
    bin_rmse = np.sqrt(mean_squared_error(bin_data['target'], bin_data['prediction']))
    bin_mae = mean_absolute_error(bin_data['target'], bin_data['prediction'])
    marker = "‚≠ê" if label in ('very_strong', 'very_weak') else "  "
    print(f"{marker} {label:<13} | {len(bin_data):<8} | {bin_rmse:<8.4f} | {bin_mae:<8.4f}")

print("=" * 70)

## 14. Comparison with v1

In [None]:
# Compare v2 test metrics against the recorded v1 results.
v1_results = {
    'RMSE': 1.4761,
    'MAE': 1.3011,
    'Spearman': 0.3912,
    'Pearson': 0.7265,
    'R2': 0.5188,
    'very_strong_rmse': 2.9394
}

# RMSE on the very_strong test bin. mean_squared_error raises on empty input,
# so report NaN (shown as a non-improvement) if the split has no such samples.
very_strong_df = test_df[test_df['affinity_bin'] == 'very_strong']
if len(very_strong_df) > 0:
    v2_very_strong_rmse = np.sqrt(mean_squared_error(
        very_strong_df['target'],
        very_strong_df['prediction']
    ))
else:
    v2_very_strong_rmse = float('nan')

print("\n" + "="*70)
print("COMPARISON: v1 vs v2 (IMPROVED)")
print("="*70)
print(f"{'Metric':<20} | {'v1':<12} | {'v2 (improved)':<12} | {'Change':<12}")
print("-"*70)

metrics = [
    ('RMSE', v1_results['RMSE'], rmse, 'lower is better'),
    ('MAE', v1_results['MAE'], mae, 'lower is better'),
    ('Spearman œÅ', v1_results['Spearman'], spearman, 'higher is better'),
    ('Pearson r', v1_results['Pearson'], pearson, 'higher is better'),
    ('R¬≤', v1_results['R2'], r2, 'higher is better'),
    ('Very Strong RMSE', v1_results['very_strong_rmse'], v2_very_strong_rmse, 'lower is better')
]

for metric_name, v1_val, v2_val, direction in metrics:
    # Positive change_pct always means "v2 is better", whichever direction applies.
    if 'lower' in direction:
        change_pct = ((v1_val - v2_val) / v1_val) * 100
        symbol = "‚úÖ" if v2_val < v1_val else "‚ùå"
    else:
        change_pct = ((v2_val - v1_val) / v1_val) * 100
        symbol = "‚úÖ" if v2_val > v1_val else "‚ùå"

    print(f"{symbol} {metric_name:<18} | {v1_val:<12.4f} | {v2_val:<12.4f} | {change_pct:+.1f}%")

print("="*70)

## 15. Generate Plots

In [None]:
# Scatter predicted vs. true pKd. The red dashed diagonal marks perfect prediction.
diagonal = [test_targets.min(), test_targets.max()]
plot_title = f'v2 IMPROVED - Test Set Predictions\nSpearman œÅ = {spearman:.4f}, RMSE = {rmse:.4f}'

plt.figure(figsize=(10, 10))
plt.scatter(test_targets, test_predictions, alpha=0.3, s=10)
plt.plot(diagonal, diagonal, 'r--', lw=2)
plt.xlabel('True pKd', fontsize=12)
plt.ylabel('Predicted pKd', fontsize=12)
plt.title(plot_title, fontsize=14)
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.savefig(f'{OUTPUT_DIR}/predictions_vs_targets_v2.png', dpi=300, bbox_inches='tight')
plt.show()
print("‚úÖ Prediction plot saved!")

In [None]:
# Residual diagnostics (predicted - true): residuals vs. predictions on the
# left, and their histogram with mean/std in the title on the right.
residuals = test_predictions - test_targets

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
scatter_ax, hist_ax = axes

scatter_ax.scatter(test_predictions, residuals, alpha=0.3, s=10)
scatter_ax.axhline(y=0, color='r', linestyle='--', lw=2)
scatter_ax.set_xlabel('Predicted pKd', fontsize=12)
scatter_ax.set_ylabel('Residuals (Predicted - True)', fontsize=12)
scatter_ax.set_title('Residuals vs Predictions', fontsize=14)
scatter_ax.grid(True, alpha=0.3)

hist_ax.hist(residuals, bins=50, edgecolor='black', alpha=0.7)
hist_ax.axvline(x=0, color='r', linestyle='--', lw=2)
hist_ax.set_xlabel('Residuals', fontsize=12)
hist_ax.set_ylabel('Frequency', fontsize=12)
hist_ax.set_title(f'Residuals Distribution\nMean = {residuals.mean():.4f}, Std = {residuals.std():.4f}', fontsize=14)
hist_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/residuals_analysis_v2.png', dpi=300, bbox_inches='tight')
plt.show()
print("‚úÖ Residuals plot saved!")

## 16. Save Results

In [None]:
# Build a plain-text summary of the run and write it to OUTPUT_DIR.
results_summary = f"""
AbAg Binding Affinity Prediction - v2 IMPROVED Results
{'='*70}

Model Improvements:
  - GELU activation (vs ReLU in v1)
  - Deeper architecture: 512 ‚Üí 256 ‚Üí 128 ‚Üí 64 (vs 256 ‚Üí 128 in v1)
  - 10x stronger weights for very strong/weak binders
  - Lower learning rate: 0.0001 (vs 0.001 in v1)
  - {'Focal MSE Loss' if USE_FOCAL_LOSS else 'Weighted MSE Loss'}
  - Gradient clipping: {GRADIENT_CLIP}
  - AdamW optimizer with cosine annealing

Training:
  - Epochs: {EPOCHS}
  - Batch size: {BATCH_SIZE}
  - Training samples: {len(X_train):,}
  - Validation samples: {len(X_val):,}
  - Test samples: {len(X_test):,}
  - Total training time: {total_time/3600:.2f} hours

Test Set Performance:
  - RMSE:       {rmse:.4f} (v1: {v1_results['RMSE']:.4f})
  - MAE:        {mae:.4f} (v1: {v1_results['MAE']:.4f})
  - Spearman œÅ: {spearman:.4f} (v1: {v1_results['Spearman']:.4f})
  - Pearson r:  {pearson:.4f} (v1: {v1_results['Pearson']:.4f})
  - R¬≤:         {r2:.4f} (v1: {v1_results['R2']:.4f})

Per-Bin Performance:
"""

# One line per non-empty pKd bin
per_bin_lines = []
for label in BIN_LABELS:
    bin_data = test_df[test_df['affinity_bin'] == label]
    if len(bin_data) > 0:
        bin_rmse = np.sqrt(mean_squared_error(bin_data['target'], bin_data['prediction']))
        bin_mae = mean_absolute_error(bin_data['target'], bin_data['prediction'])
        marker = "‚≠ê" if label in ['very_strong', 'very_weak'] else "  "
        per_bin_lines.append(f"{marker} - {label:<15}: RMSE={bin_rmse:6.4f}, MAE={bin_mae:6.4f}, N={len(bin_data):6,}\n")
results_summary += "".join(per_bin_lines)

results_summary += f"\n{'='*70}\n"

# Explicit UTF-8: the summary contains non-ASCII characters, and open()'s
# default encoding depends on the platform locale.
with open(f'{OUTPUT_DIR}/evaluation_results_v2.txt', 'w', encoding='utf-8') as f:
    f.write(results_summary)

print(results_summary)
print(f"‚úÖ Results saved to {OUTPUT_DIR}/evaluation_results_v2.txt")

In [None]:
# Save per-sample test predictions with their residual and pKd bin
results_df = pd.DataFrame({
    'true_pKd': test_targets,
    'predicted_pKd': test_predictions,
    'residual': residuals,
    'affinity_bin': test_df['affinity_bin']
})

results_df.to_csv(f'{OUTPUT_DIR}/test_predictions_v2.csv', index=False)
print(f"‚úÖ Predictions saved to {OUTPUT_DIR}/test_predictions_v2.csv")

# Save final model weights with the test metrics and the config needed to
# rebuild ImprovedAffinityPredictor. Metrics are cast to built-in float:
# NumPy scalars are rejected by torch.load(weights_only=True), the default
# since PyTorch 2.6, which would otherwise make this file unloadable by default.
torch.save({
    'model_state_dict': model.state_dict(),
    'metrics': {
        'rmse': float(rmse),
        'mae': float(mae),
        'spearman': float(spearman),
        'pearson': float(pearson),
        'r2': float(r2)
    },
    'config': {
        'input_dim': 150,
        'hidden_dims': [512, 256, 128, 64],
        'dropout': 0.3,
        'activation': 'GELU',
        'epochs': EPOCHS,
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'focal_loss': USE_FOCAL_LOSS
    }
}, f'{OUTPUT_DIR}/final_model_v2.pth')

print(f"\n‚úÖ All files saved to Google Drive: {OUTPUT_DIR}")
print(f"\nYou can now download the trained model from Google Drive!")