# Antibody-Antigen Binding Prediction Training on Colab

**Training time**: ~3-4 days (vs 36 days on RTX 2060)

**Setup**: Follow the instructions below to connect Google Drive.

## Step 1: Mount Google Drive

In [None]:
import os

from google.colab import drive

# Attach Google Drive under /content/drive.
drive.mount('/content/drive')

# Switch to the project folder so the relative file names used in later
# cells resolve against it.
project_dir = '/content/drive/MyDrive/AbAg_Training'
os.chdir(project_dir)
print(f"Current directory: {os.getcwd()}")

## Step 2: Check GPU

In [None]:
# Print the NVIDIA driver's status report for the GPU attached to this runtime.
!nvidia-smi

## Step 3: Install Dependencies

In [None]:
# Install the packages imported by the scripts below. Only transformers is
# pinned to an exact version; -q keeps pip's output short.
!pip install -q transformers==4.57.1 torch pandas scipy scikit-learn tqdm

## Step 4: Create Tokenization Cache (One-time, ~10 minutes)

In [None]:
%%writefile create_tokenization_cache.py
"""
SQLite-Based Tokenization Cache
"""

import sqlite3
import pandas as pd
import torch
from transformers import AutoTokenizer
from tqdm import tqdm
import hashlib
import numpy as np
import argparse
from pathlib import Path


def create_sequence_hash(sequence):
    """Return the hexadecimal MD5 digest of ``sequence``.

    The digest is used as the lookup key for the sequence in the cache.
    """
    digest = hashlib.md5()
    digest.update(sequence.encode())
    return digest.hexdigest()


def create_tokenization_db(db_path):
    """Open ``db_path`` and make sure the token cache schema exists.

    Creates the ``sequences`` table (hash key plus two BLOB columns) and the
    ``idx_hash`` index if they are missing, commits, and returns the open
    connection. The caller is responsible for closing it.
    """
    create_table_sql = '''
        CREATE TABLE IF NOT EXISTS sequences (
            seq_hash TEXT PRIMARY KEY,
            input_ids BLOB,
            attention_mask BLOB
        )
    '''
    create_index_sql = 'CREATE INDEX IF NOT EXISTS idx_hash ON sequences(seq_hash)'

    connection = sqlite3.connect(db_path)
    cur = connection.cursor()
    for statement in (create_table_sql, create_index_sql):
        cur.execute(statement)
    connection.commit()
    return connection


def tokenize_and_cache(csv_path, db_path, max_length=512, batch_size=100):
    """Tokenize every unique sequence in a CSV and cache the tokens in SQLite.

    Reads the ``antibody_sequence`` and ``antigen_sequence`` columns of
    ``csv_path``, tokenizes each distinct sequence with the
    ``facebook/esm2_t33_650M_UR50D`` tokenizer (padded/truncated to
    ``max_length``) and stores ``input_ids`` and ``attention_mask`` as raw
    int64 bytes keyed by the sequence hash. Sequences whose hash is already
    in the database are skipped, so the function can be re-run after an
    interruption without redoing finished work.

    Args:
        csv_path: CSV file containing the two sequence columns.
        db_path: SQLite file to create or extend.
        max_length: Token length every sequence is padded/truncated to.
        batch_size: Number of sequences handled per progress-bar step;
            pending inserts are committed every 10 steps.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Commit cadence counted in batches, so it works for any batch size
    # (10 batches == 1000 sequences at the default batch size).
    commit_every = 10

    print("Loading data...")
    df = pd.read_csv(csv_path)
    print(f"Total samples: {len(df):,}")

    print("\nInitializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

    print("\nCreating cache database...")
    conn = create_tokenization_db(db_path)
    try:
        cursor = conn.cursor()

        # Get unique sequences. Empty CSV cells load as NaN floats, which have
        # no .encode() and cannot be tokenized, so they are dropped here.
        unique_sequences = list(
            set(df['antibody_sequence'].dropna())
            | set(df['antigen_sequence'].dropna())
        )
        print(f"\nUnique sequences to tokenize: {len(unique_sequences):,}")

        # Tokenize in batches
        print("\nTokenizing and caching...")
        batch_starts = tqdm(range(0, len(unique_sequences), batch_size))
        for batch_number, start in enumerate(batch_starts, start=1):
            for seq in unique_sequences[start:start + batch_size]:
                seq_hash = create_sequence_hash(seq)

                # Check if already cached (checked first because tokenizing
                # is the expensive step).
                cursor.execute('SELECT 1 FROM sequences WHERE seq_hash = ?', (seq_hash,))
                if cursor.fetchone():
                    continue

                # Tokenize
                tokens = tokenizer(
                    seq,
                    max_length=max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )

                # Convert to numpy for storage. The dtype is forced to int64
                # because the blobs carry no dtype information and the
                # training dataset decodes them as int64.
                input_ids = tokens['input_ids'].squeeze(0).numpy().astype(np.int64, copy=False)
                attention_mask = tokens['attention_mask'].squeeze(0).numpy().astype(np.int64, copy=False)

                cursor.execute('''
                INSERT INTO sequences (seq_hash, input_ids, attention_mask)
                VALUES (?, ?, ?)
            ''', (seq_hash, input_ids.tobytes(), attention_mask.tobytes()))

            if batch_number % commit_every == 0:
                conn.commit()

        conn.commit()
    finally:
        # Always release the database handle, even if tokenization fails.
        conn.close()

    print(f"\n✅ Tokenization cache created: {db_path}")
    print(f"   Size: {Path(db_path).stat().st_size / (1024*1024):.1f} MB")


if __name__ == '__main__':
    # Command-line entry point: declare the options from a table, parse them,
    # and build the cache.
    option_specs = (
        ('--data', {'type': str, 'required': True}),
        ('--output', {'type': str, 'default': 'tokenization_cache.db'}),
        ('--max_length', {'type': int, 'default': 512}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    tokenize_and_cache(args.data, args.output, args.max_length)

In [None]:
# Run cache creation (only need to do this once!)
# Reads agab_phase2_full.csv and writes tokenization_cache.db; both paths are
# relative to the current working directory.
!python create_tokenization_cache.py \
  --data agab_phase2_full.csv \
  --output tokenization_cache.db

## Step 5: Training Script

In [None]:
%%writefile train_colab.py
"""
Colab Training Script with SQLite Cache
"""

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel
import pandas as pd
import numpy as np
from scipy import stats
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from tqdm import tqdm
import argparse
import json
import time
from pathlib import Path
import sqlite3
import hashlib


class FocalMSELoss(nn.Module):
    """Squared-error loss that up-weights elements with larger errors.

    For each element the squared error ``e`` is multiplied by
    ``(1 + e) ** gamma``; the loss is the mean of the weighted errors.
    """

    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, pred, target):
        """Return the mean of ``(1 + e) ** gamma * e`` with ``e = (pred - target) ** 2``."""
        squared_error = torch.square(pred - target)
        weight = torch.pow(1 + squared_error, self.gamma)
        return torch.mean(weight * squared_error)


class CachedAbAgDataset(Dataset):
    """Dataset of antibody/antigen pairs whose tokens come from a SQLite cache.

    Each item is looked up by the MD5 hash of its sequence in the
    ``sequences`` table of ``cache_db_path`` and decoded from raw int64 bytes.

    The SQLite connection is opened lazily and is private to the process that
    opened it: it is dropped when the dataset is pickled and reopened when
    the process id changes, so DataLoader workers never reuse a handle
    created in another process.

    Args:
        df: DataFrame with ``antibody_sequence``, ``antigen_sequence`` and
            ``pKd`` columns.
        cache_db_path: Path to the SQLite token cache.
        max_length: Stored on the instance; the cached tokens are returned
            as stored and are not re-padded or truncated here.
    """

    def __init__(self, df, cache_db_path, max_length=512):
        self.df = df.reset_index(drop=True)
        self.cache_db_path = cache_db_path
        self.max_length = max_length
        self.conn = None
        self._conn_pid = None  # pid of the process that opened self.conn
        print(f"Dataset: {len(df)} samples (SQLite cache)")

    def __getstate__(self):
        """Return picklable state; the connection is reopened on demand."""
        state = self.__dict__.copy()
        # sqlite3.Connection objects cannot be pickled.
        state['conn'] = None
        state['_conn_pid'] = None
        return state

    def _get_connection(self):
        """Return a SQLite connection that was opened by the current process."""
        import os  # local import: the enclosing script does not import os

        pid = os.getpid()
        # A handle inherited through fork() must not be used in the child,
        # so open a fresh one whenever the owning pid differs.
        if self.conn is None or getattr(self, '_conn_pid', None) != pid:
            self.conn = sqlite3.connect(self.cache_db_path, check_same_thread=False)
            self._conn_pid = pid
        return self.conn

    def _get_tokens(self, sequence):
        """Return ``(input_ids, attention_mask)`` tensors for ``sequence``.

        Raises:
            ValueError: If the sequence hash is not present in the cache.
        """
        seq_hash = hashlib.md5(sequence.encode()).hexdigest()
        cursor = self._get_connection().cursor()
        cursor.execute('SELECT input_ids, attention_mask FROM sequences WHERE seq_hash = ?', (seq_hash,))
        result = cursor.fetchone()

        if result is None:
            raise ValueError(f"Sequence not found in cache: {seq_hash}")

        # frombuffer yields read-only views over the blob; copy so the
        # resulting tensors own writable memory.
        input_ids = np.frombuffer(result[0], dtype=np.int64).copy()
        attention_mask = np.frombuffer(result[1], dtype=np.int64).copy()

        return torch.from_numpy(input_ids), torch.from_numpy(attention_mask)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the token tensors and float32 ``pKd`` target for row ``idx``."""
        # Materialise the row once instead of once per column.
        row = self.df.iloc[idx]

        ab_input_ids, ab_attention_mask = self._get_tokens(row['antibody_sequence'])
        ag_input_ids, ag_attention_mask = self._get_tokens(row['antigen_sequence'])

        return {
            'antibody_input_ids': ab_input_ids,
            'antibody_attention_mask': ab_attention_mask,
            'antigen_input_ids': ag_input_ids,
            'antigen_attention_mask': ag_attention_mask,
            'pKd': torch.tensor(row['pKd'], dtype=torch.float32)
        }


class AbAgModel(nn.Module):
    """Frozen pretrained encoder followed by a trainable regression head.

    The same encoder embeds the antibody and the antigen; the hidden state at
    sequence position 0 of each is concatenated and passed through a
    three-layer MLP that outputs one value per pair.
    """

    def __init__(self, model_name="facebook/esm2_t33_650M_UR50D", dropout=0.2):
        super().__init__()
        self.esm = AutoModel.from_pretrained(model_name)

        # Freeze the encoder: only the regression head receives gradients.
        self.esm.requires_grad_(False)

        embed_dim = self.esm.config.hidden_size
        head_layers = [
            nn.Linear(embed_dim * 2, 512),  # input is two concatenated embeddings
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1),
        ]
        self.regressor = nn.Sequential(*head_layers)

    def _first_position_embedding(self, input_ids, attention_mask):
        """Return the encoder's last hidden state at sequence position 0."""
        encoded = self.esm(input_ids=input_ids, attention_mask=attention_mask)
        return encoded.last_hidden_state[:, 0, :]

    def forward(self, ab_input_ids, ab_attention_mask, ag_input_ids, ag_attention_mask):
        """Return one prediction per pair, with the trailing size-1 dim removed."""
        pair_features = torch.cat(
            [
                self._first_position_embedding(ab_input_ids, ab_attention_mask),
                self._first_position_embedding(ag_input_ids, ag_attention_mask),
            ],
            dim=1,
        )
        scores = self.regressor(pair_features)
        return scores.squeeze(-1)


def train_epoch(model, loader, optimizer, criterion, device, scaler):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Each batch is moved to ``device``, evaluated under CUDA autocast, and
    back-propagated through ``scaler`` before the optimizer step. The
    progress bar shows the loss of the most recent batch.
    """
    model.train()
    running_loss = 0

    # Order matches the positional parameters of the model's forward().
    input_keys = (
        'antibody_input_ids',
        'antibody_attention_mask',
        'antigen_input_ids',
        'antigen_attention_mask',
    )

    progress = tqdm(loader, desc="Training")
    for batch in progress:
        inputs = [batch[key].to(device) for key in input_keys]
        targets = batch['pKd'].to(device)

        optimizer.zero_grad()

        with torch.amp.autocast('cuda'):
            predictions = model(*inputs)
            loss = criterion(predictions, targets)

        # Scaled backward pass, optimizer step, then scale-factor update.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix({'loss': f'{batch_loss:.2e}'})

    return running_loss / len(loader)


def evaluate(model, loader, device):
    """Run ``model`` over ``loader`` without gradients and compute metrics.

    Args:
        model: Module called with the four token tensors of each batch.
        loader: Iterable of batches with the token tensors and ``pKd``.
        device: Device the batch tensors are moved to.

    Returns:
        Dict with ``rmse``, ``mae``, ``r2``, ``spearman``, ``recall_pkd9``
        (percentage of targets >= 9.0 that were also predicted >= 9.0, or 0
        when there are no such targets), plus the raw ``predictions`` and
        ``targets`` arrays.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    prediction_chunks = []
    target_chunks = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            antibody_input_ids = batch['antibody_input_ids'].to(device)
            antibody_attention_mask = batch['antibody_attention_mask'].to(device)
            antigen_input_ids = batch['antigen_input_ids'].to(device)
            antigen_attention_mask = batch['antigen_attention_mask'].to(device)
            batch_targets = batch['pKd'].to(device)

            with torch.amp.autocast('cuda'):
                batch_predictions = model(antibody_input_ids, antibody_attention_mask,
                                        antigen_input_ids, antigen_attention_mask)

            # Autocast can return half-precision outputs; upcast before
            # leaving torch so the metrics and the 9.0 threshold below are
            # not computed on float16-quantised values.
            prediction_chunks.append(batch_predictions.float().cpu().numpy().reshape(-1))
            target_chunks.append(batch_targets.float().cpu().numpy().reshape(-1))

    if not prediction_chunks:
        raise ValueError("Cannot evaluate: the data loader produced no batches")

    predictions = np.concatenate(prediction_chunks)
    targets = np.concatenate(target_chunks)

    rmse = np.sqrt(mean_squared_error(targets, predictions))
    mae = mean_absolute_error(targets, predictions)
    r2 = r2_score(targets, predictions)
    spearman = stats.spearmanr(targets, predictions)[0]

    # Recall for strong binders
    strong_binders = targets >= 9.0
    predicted_strong = predictions >= 9.0
    n_strong = strong_binders.sum()
    recall = (strong_binders & predicted_strong).sum() / n_strong if n_strong > 0 else 0

    return {
        'rmse': rmse,
        'mae': mae,
        'r2': r2,
        'spearman': spearman,
        'recall_pkd9': recall * 100,
        'predictions': predictions,
        'targets': targets
    }


def _save_atomically(obj, path):
    """Write ``obj`` with ``torch.save`` to a temporary sibling, then rename it onto ``path``.

    Writing to a separate file first means an interrupted write does not
    truncate the file that already exists at ``path``.
    """
    path = Path(path)
    tmp_path = path.with_name(path.name + '.tmp')
    torch.save(obj, tmp_path)
    tmp_path.replace(path)


def main(args):
    """Train the binding-affinity model, checkpointing after every epoch.

    Args:
        args: Parsed options providing ``data``, ``cache_db``, ``output_dir``,
            ``resume``, ``epochs``, ``batch_size``, ``lr``, ``weight_decay``,
            ``dropout`` and ``focal_gamma``.

    Side effects:
        Writes ``best_model.pth`` (state dict with the best validation
        Spearman so far) and ``checkpoint_latest.pth`` (full training state)
        into ``args.output_dir``.
    """
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    # Querying the device name raises when CUDA is unavailable, so only do it
    # on the GPU path.
    if device.type == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    df = pd.read_csv(args.data)
    print(f"\nLoaded {len(df):,} samples")

    # 70 / 15 / 15 split with a fixed seed so resumed runs see the same split.
    train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)
    # NOTE(review): test_df is only counted below; nothing in this function
    # evaluates on it.

    print(f"Train: {len(train_df):,} | Val: {len(val_df):,} | Test: {len(test_df):,}")

    train_dataset = CachedAbAgDataset(train_df, args.cache_db)
    val_dataset = CachedAbAgDataset(val_df, args.cache_db)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                             num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size * 2, shuffle=False,
                           num_workers=2, pin_memory=True)

    model = AbAgModel(dropout=args.dropout).to(device)
    criterion = FocalMSELoss(gamma=args.focal_gamma)
    # All parameters are handed to the optimizer (frozen ones simply never get
    # gradients); keeping this layout lets existing optimizer states load.
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    scaler = torch.amp.GradScaler('cuda')

    start_epoch = 0
    best_spearman = -1
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Resume from checkpoint if provided
    if args.resume and Path(args.resume).exists():
        print(f"\n🔄 Resuming from checkpoint: {args.resume}")
        # weights_only=False is required because checkpoints written below
        # contain numpy objects in 'val_metrics'. This unpickles arbitrary
        # objects, so only load checkpoints you created yourself.
        checkpoint = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # Older checkpoints have no scaler state, and a disabled scaler saves
        # an empty dict that must not be loaded.
        if checkpoint.get('scaler_state_dict'):
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_spearman = checkpoint.get('best_val_spearman', -1)
        print(f"✓ Resuming from epoch {start_epoch}")
        print(f"✓ Best Spearman: {best_spearman:.4f}")
    elif args.resume:
        # Make the fallback visible instead of silently restarting.
        print(f"\nCheckpoint not found: {args.resume} - starting from scratch")

    print(f"\n{'='*70}")
    print(f"Starting training for {args.epochs} epochs (from epoch {start_epoch})")
    print(f"{'='*70}\n")

    for epoch in range(start_epoch, args.epochs):
        print(f"\nEpoch {epoch+1}/{args.epochs}")

        train_loss = train_epoch(model, train_loader, optimizer, criterion, device, scaler)
        val_metrics = evaluate(model, val_loader, device)
        scheduler.step()

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val RMSE: {val_metrics['rmse']:.4f} | MAE: {val_metrics['mae']:.4f}")
        print(f"Val Spearman: {val_metrics['spearman']:.4f} | R²: {val_metrics['r2']:.4f}")
        print(f"Val Recall@pKd≥9: {val_metrics['recall_pkd9']:.2f}%")

        if val_metrics['spearman'] > best_spearman:
            best_spearman = val_metrics['spearman']
            _save_atomically(model.state_dict(), output_dir / 'best_model.pth')
            print("✓ Saved best model")

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'best_val_spearman': best_spearman,
            'val_metrics': val_metrics
        }
        _save_atomically(checkpoint, output_dir / 'checkpoint_latest.pth')

    print(f"\n{'='*70}")
    print(f"Training complete! Best Spearman: {best_spearman:.4f}")
    print(f"{'='*70}")


if __name__ == '__main__':
    # Command-line entry point: declare the options from a table, parse them,
    # and start training.
    option_specs = (
        ('--data', {'type': str, 'required': True}),
        ('--cache_db', {'type': str, 'required': True}),
        ('--output_dir', {'type': str, 'default': 'outputs_colab'}),
        ('--resume', {'type': str, 'default': None, 'help': 'Path to checkpoint to resume from'}),
        ('--epochs', {'type': int, 'default': 50}),
        ('--batch_size', {'type': int, 'default': 32}),
        ('--lr', {'type': float, 'default': 1e-3}),
        ('--weight_decay', {'type': float, 'default': 0.01}),
        ('--dropout', {'type': float, 'default': 0.2}),
        ('--focal_gamma', {'type': float, 'default': 2.0}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    main(args)

## Step 6: Start Training 🚀

- **Option A**: Start from scratch (run the command below as is).
- **Option B**: Resume from your local checkpoint (epoch 5) by adding `--resume checkpoint_latest.pth` to the same command.

In [None]:
# Option A: Start from scratch (epoch 1)
# Trains for 50 epochs with batch size 32 and writes checkpoints to
# outputs_colab/. The script also accepts --resume <checkpoint path> to
# continue from a saved checkpoint instead.
!python train_colab.py \
  --data agab_phase2_full.csv \
  --cache_db tokenization_cache.db \
  --epochs 50 \
  --batch_size 32 \
  --focal_gamma 2.0 \
  --output_dir outputs_colab

## Verify Files Before Training

In [None]:
# Check file sizes and integrity
import os
from pathlib import Path

import torch

# Expected sizes in bytes; a file smaller than 95% of its expected size is
# flagged as a mismatch.
files_to_check = {
    'agab_phase2_full.csv': 127 * 1024 * 1024,  # ~127 MB
    'checkpoint_latest.pth': 2500 * 1024 * 1024,  # ~2.5 GB
}

# Zero-based epoch stored in the checkpoint, set only once the checkpoint has
# loaded and its summary fields were readable. The checkpoint is loaded a
# single time; the recommendation below reuses this result.
resume_epoch = None

print("File verification:")
print("-" * 60)
for filename, expected_size in files_to_check.items():
    path = Path(filename)
    if not path.exists():
        print(f"✗ {filename}: NOT FOUND")
        if filename == 'agab_phase2_full.csv':
            print(f"   → REQUIRED: Please upload this file to continue!")
        continue

    actual_size = path.stat().st_size
    size_mb = actual_size / (1024 * 1024)
    expected_mb = expected_size / (1024 * 1024)
    status = "✓" if actual_size > expected_size * 0.95 else "⚠️ SIZE MISMATCH"
    print(f"{status} {filename}: {size_mb:.1f} MB (expected: {expected_mb:.1f} MB)")

    # For checkpoint, try to load it (files of 1 KB or less are not attempted
    # and are treated as unusable).
    if filename.endswith('.pth') and actual_size > 1024:
        checkpoint = None
        try:
            # weights_only=False unpickles arbitrary objects: only use this
            # on checkpoints you created yourself.
            checkpoint = torch.load(filename, map_location='cpu', weights_only=False)
            print(f"   ✓ Valid checkpoint - Epoch {checkpoint['epoch'] + 1}, Spearman: {checkpoint['best_val_spearman']:.4f}")
            resume_epoch = checkpoint['epoch']
        except Exception as e:
            print(f"   ✗ CORRUPTED: {str(e)[:100]}")
            print(f"   → Recommendation: Delete and re-upload, OR use Option A (start fresh)")
        finally:
            # Release the loaded tensors; only the epoch number is needed.
            del checkpoint

print("-" * 60)
print("\nRecommendation:")
if not Path('agab_phase2_full.csv').exists():
    print("❌ Cannot proceed - missing data file!")
elif not Path('checkpoint_latest.pth').exists():
    print("✓ Use Option A (start from scratch)")
elif resume_epoch is not None:
    # 'epoch' is the zero-based index of the last finished epoch, so training
    # continues at one-based epoch number resume_epoch + 2.
    print(f"✓ Use Option B (resume from epoch {resume_epoch + 2})")
else:
    print("⚠️  Checkpoint corrupted - Use Option A (start from scratch)")

## Monitor Progress

In [None]:
# Check checkpoint
import torch

# weights_only=False is needed because the training script stores numpy
# objects under 'val_metrics'; the default safe loader in recent PyTorch
# releases rejects them. Only load checkpoints you created yourself.
checkpoint = torch.load('outputs_colab/checkpoint_latest.pth', map_location='cpu', weights_only=False)
print(f"Epoch: {checkpoint['epoch'] + 1}")
print(f"Best Spearman: {checkpoint['best_val_spearman']:.4f}")

# Leave out the per-sample arrays so the summary stays readable.
metric_summary = {
    name: value
    for name, value in checkpoint['val_metrics'].items()
    if name not in ('predictions', 'targets')
}
print(f"Latest metrics: {metric_summary}")

## Download Results

In [None]:
# Download the best model weights and the latest full checkpoint
from google.colab import files

for artifact in ('outputs_colab/best_model.pth', 'outputs_colab/checkpoint_latest.pth'):
    files.download(artifact)