# Resume Training or Evaluate Saved Model

**Use this notebook when:**
- Your Colab session timed out
- Training was interrupted
- You want to evaluate an already-trained model
- You want to continue training from a checkpoint

**This notebook will:**
1. Reconnect to Google Drive
2. Load your saved model/checkpoint
3. Run evaluation on the held-out test set (this notebook loads checkpoints but does not include a training loop to resume)
4. Generate all plots and metrics

## 1. Setup - Reconnect Everything

In [None]:
# Check GPU: report the PyTorch build and, when CUDA is present, the device name and total VRAM
import torch

print(f"PyTorch version: {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Install dependencies
!pip install -q transformers scikit-learn pandas numpy tqdm matplotlib seaborn
print("✅ Dependencies installed!")

In [None]:
# Mount Google Drive so the dataset and model directory under MyDrive become readable/writable
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)
print("✅ Google Drive mounted!")

In [None]:
# Import all libraries
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import spearmanr, pearsonr
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import os
import time

# Fix the torch and numpy global RNG seeds so stochastic steps below are repeatable
SEED = 42
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(SEED)

print("✅ All imports complete!")

In [None]:
# Set paths - UPDATE THESE to match your Drive location
DRIVE_DATA_PATH = "/content/drive/MyDrive/AbAg_data/merged_with_all_features.csv"
MODEL_DIR = "/content/drive/MyDrive/AbAg_data/models"

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Constants: pKd bin edges and their names, used by the per-bin evaluation (pd.cut) below.
# BIN_LABELS needs exactly one entry per interval between consecutive edges.
BINS = [0, 5, 7, 9, 11, 16]
BIN_LABELS = ['very_weak', 'weak', 'moderate', 'strong', 'very_strong']
assert len(BIN_LABELS) == len(BINS) - 1, "BIN_LABELS must have one label per bin interval"

print(f"Model directory: {MODEL_DIR}")
print(f"Device: {DEVICE}")

# List available checkpoints so you can choose MODEL_PATH in section 3
if os.path.isdir(MODEL_DIR):
    checkpoint_files = sorted(name for name in os.listdir(MODEL_DIR) if name.endswith('.pth'))
    print("\nAvailable files:")
    for ckpt_name in checkpoint_files:
        size_mb = os.path.getsize(os.path.join(MODEL_DIR, ckpt_name)) / 1e6
        print(f"  - {ckpt_name} ({size_mb:.1f} MB)")
    if not checkpoint_files:
        print("  (no .pth files found)")
else:
    print(f"\n⚠️ Model directory not found: {MODEL_DIR} — check the path.")

## 2. Recreate Model Architecture

In [None]:
# Dataset class
class AffinityDataset(Dataset):
    """In-memory dataset pairing feature rows with scalar pKd labels.

    Args:
        features: array-like, expected shape (n_samples, n_features); stored as float32.
        labels: array-like, expected shape (n_samples,); stored as float32.

    Raises:
        ValueError: if features and labels do not have the same number of rows.
    """

    def __init__(self, features, labels):
        # torch.tensor copies the data (as the legacy torch.FloatTensor constructor did)
        # without relying on the deprecated typed-tensor constructors.
        self.features = torch.tensor(np.asarray(features), dtype=torch.float32)
        self.labels = torch.tensor(np.asarray(labels), dtype=torch.float32)
        # A length mismatch would otherwise surface only as an IndexError mid-iteration.
        if len(self.features) != len(self.labels):
            raise ValueError(
                f"features has {len(self.features)} rows but labels has {len(self.labels)}"
            )

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

print("✅ Dataset class defined!")

In [None]:
# Model architecture (MUST match training)
class AffinityPredictor(nn.Module):
    """MLP regressor: [Linear -> BatchNorm1d -> ReLU -> Dropout] per hidden layer, then Linear -> 1.

    The layer order and the ``network`` attribute name determine the state_dict keys
    (``network.0.weight``, ...), so they must stay identical to the architecture used at
    training time for saved checkpoints to load.

    Args:
        input_dim: number of input features per sample.
        hidden_dims: sizes of the hidden layers, in order (any iterable of ints).
        dropout: dropout probability applied after each hidden activation.
    """

    def __init__(self, input_dim=150, hidden_dims=(256, 128), dropout=0.3):
        super().__init__()

        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, 1))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, input_dim) tensor to predictions of shape (batch,)."""
        # squeeze(-1) drops only the trailing size-1 output dim. A bare squeeze() would also
        # drop the batch dim when batch == 1 and return a 0-d tensor, which breaks callers
        # that iterate over the batch (e.g. list.extend on the numpy result).
        return self.network(x).squeeze(-1)

print("✅ Model class defined!")

In [None]:
# Weighted MSE Loss
class WeightedMSELoss(nn.Module):
    """Mean squared error in which each sample is scaled by the weight of its target's bin.

    Args:
        bin_weights: indexable of per-bin weights; entry i applies to targets in
            [bins_edges[i], bins_edges[i + 1]).
        bins_edges: ascending bin edges (presumably one more entry than bin_weights).

    Targets outside every interval keep weight 1. Because the intervals are half-open,
    this includes a target exactly equal to the last edge.

    NOTE(review): these [low, high) intervals differ from the right-closed intervals that
    pd.cut(..., include_lowest=True) produces in the per-bin evaluation cell. A target on an
    inner edge (e.g. pKd 5.0) is therefore weighted as one bin but reported under the
    neighbouring one. Left as-is so resumed training uses the same loss as before.
    """

    def __init__(self, bin_weights, bins_edges):
        super().__init__()
        self.bin_weights = bin_weights
        self.bins = bins_edges

    def forward(self, predictions, targets):
        """Return the scalar mean of per-sample squared errors times bin weights."""
        sample_weights = torch.ones_like(targets)
        intervals = zip(self.bins[:-1], self.bins[1:])
        for bin_idx, (lower, upper) in enumerate(intervals):
            in_bin = (targets >= lower) & (targets < upper)
            sample_weights[in_bin] = self.bin_weights[bin_idx]

        squared_error = (predictions - targets).pow(2)
        return (squared_error * sample_weights).mean()

print("✅ Weighted MSE loss defined!")

## 3. Load Model

**Choose which model to load:**
- `best_model.pth` - Best validation loss during training
- `final_model.pth` - Final model with metadata
- `checkpoint_epoch_XX.pth` - Specific checkpoint

In [None]:
# Load model - CHANGE THIS if you want a different checkpoint
MODEL_PATH = f"{MODEL_DIR}/best_model.pth"

# Initialize model (hyperparameters must match the ones used at training time)
model = AffinityPredictor(input_dim=150, hidden_dims=[256, 128], dropout=0.3)
model = model.to(DEVICE)

# Stop here if the checkpoint is missing. Continuing with randomly initialised weights would
# produce meaningless metrics and overwrite the result files that later cells save to MODEL_DIR.
if not os.path.exists(MODEL_PATH):
    print(f"❌ Model not found at: {MODEL_PATH}")
    print("   Please check the path or choose a different checkpoint.")
    raise FileNotFoundError(MODEL_PATH)

# map_location lets a checkpoint saved on GPU load in a CPU-only session (and vice versa).
# NOTE(review): on newer PyTorch releases torch.load defaults to weights_only=True. If the
# checkpoint stores non-tensor metadata (e.g. numpy scalars under 'metrics') and loading fails,
# pass weights_only=False — but only for files you trust.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)

# Dict checkpoints keep weights under 'model_state_dict'. Otherwise, treat the file as a bare
# state_dict (as written by torch.save(model.state_dict(), ...)) that has no metadata.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint
    checkpoint = {}
model.load_state_dict(state_dict)
print(f"✅ Model loaded from: {MODEL_PATH}")

if 'epoch' in checkpoint:
    print(f"   Epoch: {checkpoint['epoch']+1}")
if 'val_loss' in checkpoint:
    print(f"   Val loss: {checkpoint['val_loss']:.4f}")
if 'metrics' in checkpoint:
    print(f"   Metrics: {checkpoint['metrics']}")

## 4. Load Test Data

In [None]:
# Copy data to local if needed: copy the CSV from the Drive mount to the VM's local disk once,
# then reuse the local copy on later runs.
import shutil

LOCAL_DATA_PATH = "/content/merged_with_all_features.csv"

if os.path.exists(LOCAL_DATA_PATH):
    print("✅ Using local data")
    DATA_PATH = LOCAL_DATA_PATH
elif os.path.exists(DRIVE_DATA_PATH):
    print("Copying data to local storage...")
    # Copy to a temporary name and rename only after the copy finishes, so a session timeout
    # mid-copy cannot leave a truncated CSV that later runs would pick up as "local data".
    # Unlike `!cp`, shutil raises on failure instead of silently continuing.
    partial_path = LOCAL_DATA_PATH + ".part"
    shutil.copyfile(DRIVE_DATA_PATH, partial_path)
    os.replace(partial_path, LOCAL_DATA_PATH)
    print("✅ Data copied!")
    DATA_PATH = LOCAL_DATA_PATH
else:
    raise FileNotFoundError(
        f"Dataset not found locally ({LOCAL_DATA_PATH}) or on Drive ({DRIVE_DATA_PATH}). "
        "Check DRIVE_DATA_PATH and that Google Drive is mounted."
    )

In [None]:
# Load and prepare data (same split as training)
print("Loading dataset...")
df = pd.read_csv(DATA_PATH, low_memory=False)

pca_cols = [f'esm2_pca_{i}' for i in range(150)]
# Keep this row filter identical to training. Changing it would change which rows land in
# each split below, so the "test" set would no longer be the held-out set.
df_with_features = df[df[pca_cols[0]].notna()].copy()
print(f"✅ Loaded {len(df_with_features):,} samples with features")

# Extract features
X = df_with_features[pca_cols].values
y = df_with_features['pKd'].values

# Same split as training (must use same random_state!)
X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.15, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=0.15/0.85, random_state=42)

# The filter above only checks the first PCA column, so other features or pKd may still be
# missing. Catch that here rather than as an opaque NaN error in the metric cells.
finite_test_rows = (
    np.isfinite(np.asarray(X_test, dtype=float)).all(axis=1)
    & np.isfinite(np.asarray(y_test, dtype=float))
)
n_bad_test = int((~finite_test_rows).sum())
assert n_bad_test == 0, f"{n_bad_test} test rows contain NaN/inf in features or pKd"

print(f"\nData splits:")
print(f"  Train: {len(X_train):,}")
print(f"  Val: {len(X_val):,}")
print(f"  Test: {len(X_test):,}")

## 5. Run Evaluation

In [None]:
# Create test dataset and loader
test_dataset = AffinityDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

# Evaluate: eval() disables dropout and makes BatchNorm use its stored running statistics
model.eval()
pred_batches = []
target_batches = []

print("Running evaluation...")
with torch.no_grad():
    for features, labels in tqdm(test_loader, desc="Testing"):
        features = features.to(DEVICE)
        predictions = model(features)
        # reshape(-1) guarantees a 1-D array even if the model's squeeze() turned a size-1
        # final batch into a 0-d tensor, which list.extend cannot iterate over.
        pred_batches.append(predictions.reshape(-1).cpu().numpy())
        target_batches.append(labels.reshape(-1).numpy())

test_predictions = np.concatenate(pred_batches)
test_targets = np.concatenate(target_batches)
assert len(test_predictions) == len(test_targets) == len(test_dataset)

print("✅ Evaluation complete!")

In [None]:
# Calculate overall metrics on the held-out test set (error magnitude + rank/linear agreement)
mse = mean_squared_error(test_targets, test_predictions)
rmse = np.sqrt(mse)
mae = mean_absolute_error(test_targets, test_predictions)
spearman = spearmanr(test_targets, test_predictions)[0]
pearson = pearsonr(test_targets, test_predictions)[0]

# Coefficient of determination: 1 - SS_res / SS_tot
ss_res = np.sum((test_targets - test_predictions)**2)
ss_tot = np.sum((test_targets - test_targets.mean())**2)
r2 = 1 - (ss_res / ss_tot)

report_rows = [
    ("RMSE:        ", rmse),
    ("MAE:         ", mae),
    ("Spearman ρ:  ", spearman),
    ("Pearson r:   ", pearson),
    ("R²:          ", r2),
]
banner = "=" * 60
print(banner)
print("TEST SET PERFORMANCE")
print(banner)
for caption, value in report_rows:
    print(f"{caption}{value:.4f}")
print(banner)

In [None]:
# Per-bin metrics: bucket each test target into its affinity bin, then score each bin separately
test_df = pd.DataFrame({'target': test_targets, 'prediction': test_predictions})
test_df['affinity_bin'] = pd.cut(
    test_df['target'], bins=BINS, labels=BIN_LABELS, include_lowest=True
)

rule = "=" * 60
print("\nPER-BIN PERFORMANCE:")
print(rule)
print(f"{'Bin':<15} | {'Count':<8} | {'RMSE':<8} | {'MAE':<8}")
print("-" * 60)

for label in BIN_LABELS:
    in_bin = test_df.loc[test_df['affinity_bin'] == label]
    if in_bin.empty:
        continue  # bins with no test samples are omitted from the table
    bin_rmse = np.sqrt(mean_squared_error(in_bin['target'], in_bin['prediction']))
    bin_mae = mean_absolute_error(in_bin['target'], in_bin['prediction'])
    print(f"{label:<15} | {len(in_bin):<8} | {bin_rmse:<8.4f} | {bin_mae:<8.4f}")

print(rule)

## 6. Generate Plots

In [None]:
# Predictions vs targets: points on the red dashed identity line are perfect predictions
fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(test_targets, test_predictions, alpha=0.3, s=10)
identity = [test_targets.min(), test_targets.max()]
ax.plot(identity, identity, 'r--', lw=2)
ax.set_xlabel('True pKd', fontsize=12)
ax.set_ylabel('Predicted pKd', fontsize=12)
ax.set_title(f'Test Set Predictions\nSpearman ρ = {spearman:.4f}, RMSE = {rmse:.4f}', fontsize=14)
ax.grid(True, alpha=0.3)
ax.axis('equal')
fig.savefig(f'{MODEL_DIR}/predictions_vs_targets.png', dpi=300, bbox_inches='tight')
plt.show()
print("✅ Prediction plot saved!")

In [None]:
# Residuals analysis: residual vs prediction (left) and the residual distribution (right)
residuals = test_predictions - test_targets

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
ax_scatter, ax_hist = axes

# Residuals vs predictions
ax_scatter.scatter(test_predictions, residuals, alpha=0.3, s=10)
ax_scatter.axhline(y=0, color='r', linestyle='--', lw=2)
ax_scatter.set_xlabel('Predicted pKd', fontsize=12)
ax_scatter.set_ylabel('Residuals (Predicted - True)', fontsize=12)
ax_scatter.set_title('Residuals vs Predictions', fontsize=14)
ax_scatter.grid(True, alpha=0.3)

# Residuals distribution (numpy std, i.e. population std with ddof=0)
res_mean, res_std = residuals.mean(), residuals.std()
ax_hist.hist(residuals, bins=50, edgecolor='black', alpha=0.7)
ax_hist.axvline(x=0, color='r', linestyle='--', lw=2)
ax_hist.set_xlabel('Residuals', fontsize=12)
ax_hist.set_ylabel('Frequency', fontsize=12)
ax_hist.set_title(f'Residuals Distribution\nMean = {res_mean:.4f}, Std = {res_std:.4f}', fontsize=14)
ax_hist.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig(f'{MODEL_DIR}/residuals_analysis.png', dpi=300, bbox_inches='tight')
plt.show()
print("✅ Residuals plot saved!")

In [None]:
# Per-bin performance visualization: RMSE and sample count for each non-empty affinity bin
bin_metrics = []
for label in BIN_LABELS:
    in_bin = test_df[test_df['affinity_bin'] == label]
    if in_bin.empty:
        continue
    bin_metrics.append({
        'bin': label,
        'RMSE': np.sqrt(mean_squared_error(in_bin['target'], in_bin['prediction'])),
        'MAE': mean_absolute_error(in_bin['target'], in_bin['prediction']),
        'count': len(in_bin),
    })

bin_df = pd.DataFrame(bin_metrics)

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# One entry per panel: (axis, bar heights, extra bar kwargs, y-axis label, title)
panels = [
    (axes[0], bin_df['RMSE'], {}, 'RMSE', 'RMSE by Affinity Bin'),
    (axes[1], bin_df['count'], {'color': 'green'}, 'Sample Count', 'Test Set Distribution'),
]
for panel_ax, heights, bar_kwargs, ylabel, title in panels:
    panel_ax.bar(bin_df['bin'], heights, alpha=0.7, edgecolor='black', **bar_kwargs)
    panel_ax.set_xlabel('Affinity Bin', fontsize=12)
    panel_ax.set_ylabel(ylabel, fontsize=12)
    panel_ax.set_title(title, fontsize=14)
    panel_ax.tick_params(axis='x', rotation=45)
    panel_ax.grid(True, alpha=0.3, axis='y')

fig.tight_layout()
fig.savefig(f'{MODEL_DIR}/per_bin_analysis.png', dpi=300, bbox_inches='tight')
plt.show()
print("✅ Per-bin analysis plot saved!")

## 7. Save Results

In [None]:
# Save detailed results as a human-readable text report next to the model files
results_summary = f"""
AbAg Binding Affinity Prediction - Evaluation Results
{'='*70}

Model: {MODEL_PATH}
Test samples: {len(X_test):,}

Overall Performance:
  - RMSE:       {rmse:.4f}
  - MAE:        {mae:.4f}
  - Spearman ρ: {spearman:.4f}
  - Pearson r:  {pearson:.4f}
  - R²:         {r2:.4f}

Per-Bin Performance:
"""

for label in BIN_LABELS:
    bin_data = test_df[test_df['affinity_bin'] == label]
    if len(bin_data) > 0:
        bin_rmse = np.sqrt(mean_squared_error(bin_data['target'], bin_data['prediction']))
        bin_mae = mean_absolute_error(bin_data['target'], bin_data['prediction'])
        results_summary += f"  - {label:<15}: RMSE={bin_rmse:6.4f}, MAE={bin_mae:6.4f}, N={len(bin_data):6,}\n"

results_summary += f"\n{'='*70}\n"

# Save to file. The report contains non-ASCII characters (ρ, ²), so the encoding is set
# explicitly instead of relying on the platform default.
with open(f'{MODEL_DIR}/evaluation_results.txt', 'w', encoding='utf-8') as f:
    f.write(results_summary)

print(results_summary)
print(f"✅ Results saved to {MODEL_DIR}/evaluation_results.txt")

In [None]:
# Save predictions to CSV for further analysis.
# The residual column is recomputed here rather than read from the `residuals` variable of the
# plotting cell, so this cell still works if the plotting cells were skipped.
results_df = pd.DataFrame({
    'true_pKd': test_targets,
    'predicted_pKd': test_predictions,
    'residual': test_predictions - test_targets,
    'affinity_bin': test_df['affinity_bin']
})

results_df.to_csv(f'{MODEL_DIR}/test_predictions.csv', index=False)
print(f"✅ Predictions saved to {MODEL_DIR}/test_predictions.csv")
print(f"\nAll files saved to Google Drive: {MODEL_DIR}")

## 8. Summary

**Files saved to Google Drive:**
- `evaluation_results.txt` - Detailed metrics
- `test_predictions.csv` - All predictions for analysis
- `predictions_vs_targets.png` - Scatter plot
- `residuals_analysis.png` - Residuals plots
- `per_bin_analysis.png` - Per-bin performance

**You can now:**
1. Download these files from Google Drive
2. Use the model locally (see next notebook/script)
3. Share results
4. Continue training from checkpoint if needed