# IgFold + ESM-2 Hybrid Training for Antibody-Antigen Binding

**Training time**: ~4-5 days on Colab (vs 40+ days locally)

**Architecture**:
- **Antibody**: IgFold BERT embeddings (512-dim) - antibody-specific features
- **Antigen**: ESM-2 embeddings (1280-dim) - general protein features
- **Advantage**: Better CDR and paratope representation → Higher Recall@pKd≥9

## Step 1: Mount Google Drive

In [None]:
# Mount Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

import os
# Work from the project folder on Drive: the relative paths used by later
# cells (the CSV passed to --data, outputs_igfold/) resolve against it.
os.chdir('/content/drive/MyDrive/AbAg_Training')
print(f"Current directory: {os.getcwd()}")

## Step 2: Check GPU

In [None]:
# Print the GPU status reported by the NVIDIA driver for this runtime.
!nvidia-smi

## Step 3: Install Dependencies

In [None]:
# Install the packages imported by the two scripts written in the next cells.
!pip install -q transformers==4.57.1 torch pandas scipy scikit-learn tqdm igfold

## Step 4: Create Model Definition

In [None]:
%%writefile model_igfold_hybrid.py
"""
Hybrid IgFold + ESM-2 Model for Antibody-Antigen Binding Prediction
"""

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from igfold import IgFoldRunner


class IgFoldBERTOnly(nn.Module):
    """
    Hybrid binding-affinity regressor built on two frozen sequence encoders.

    - Antibody: IgFold BERT embeddings, mean-pooled over the length axis (512-dim)
    - Antigen: ESM-2 first-token embedding (1280-dim for the default checkpoint)

    The two vectors are concatenated and fed to an MLP that emits one scalar
    per antibody/antigen pair. ESM-2 has ``requires_grad`` disabled and both
    encoders are called under ``torch.no_grad()``, so only the regressor
    receives gradients.

    Args:
        esm_model_name: Hugging Face identifier of the antigen encoder.
        dropout: Dropout probability used inside the regressor.
    """

    def __init__(self, esm_model_name="facebook/esm2_t33_650M_UR50D", dropout=0.3):
        super().__init__()

        # IgFold for antibody
        self.igfold = IgFoldRunner()

        # ESM-2 for antigen
        self.esm = AutoModel.from_pretrained(esm_model_name)
        self.esm_tokenizer = AutoTokenizer.from_pretrained(esm_model_name)

        # Freeze ESM-2: no gradients, and eval mode so it behaves as a
        # deterministic feature extractor.
        for param in self.esm.parameters():
            param.requires_grad = False
        self.esm.eval()

        ab_dim = 512  # IgFold BERT
        # Take the antigen width from the loaded checkpoint instead of assuming
        # 1280, so a different ``esm_model_name`` does not break the regressor.
        ag_dim = self.esm.config.hidden_size
        combined_dim = ab_dim + ag_dim

        # Regressor
        self.regressor = nn.Sequential(
            nn.Linear(combined_dim, 1024),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(1024),
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(512),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1)
        )

    def train(self, mode=True):
        """Set train/eval mode on the module while keeping ESM-2 in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        put the frozen encoder back into training mode on each ``model.train()``.
        """
        super().train(mode)
        self.esm.eval()
        return self

    def get_antibody_embedding(self, antibody_seq):
        """Return the mean-pooled IgFold BERT embedding for one antibody sequence.

        The sequence is submitted to IgFold as a single chain keyed ``"H"``.
        """
        sequences = {"H": antibody_seq}
        with torch.no_grad():
            emb = self.igfold.embed(sequences=sequences)
        bert_emb = emb.bert_embs.mean(dim=1)  # Pool over length
        return bert_emb.squeeze(0)

    def get_antigen_embedding(self, antigen_seq, device):
        """Return the ESM-2 first-token embedding for one antigen sequence.

        Inputs longer than 512 tokens are truncated by the tokenizer.
        """
        tokens = self.esm_tokenizer(
            antigen_seq,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(device)

        with torch.no_grad():
            outputs = self.esm(**tokens)
            ag_emb = outputs.last_hidden_state[:, 0, :]

        return ag_emb.squeeze(0)

    def forward(self, antibody_seqs, antigen_seqs, device):
        """Predict one value per (antibody, antigen) pair.

        Args:
            antibody_seqs: Iterable of antibody sequence strings.
            antigen_seqs: Iterable of antigen sequence strings, aligned with
                ``antibody_seqs``.
            device: Device on which the embeddings are combined and regressed.

        Returns:
            Tensor of predictions with the trailing singleton dimension removed.
        """
        # Sequences are embedded one at a time, then stacked into a batch.
        ab_embeddings = torch.stack(
            [self.get_antibody_embedding(ab_seq) for ab_seq in antibody_seqs]
        ).to(device)
        ag_embeddings = torch.stack(
            [self.get_antigen_embedding(ag_seq, device) for ag_seq in antigen_seqs]
        ).to(device)

        # Combine and predict
        combined = torch.cat([ab_embeddings, ag_embeddings], dim=1)
        return self.regressor(combined).squeeze(-1)


class FocalMSELoss(nn.Module):
    """Squared error re-weighted by ``(1 + squared_error) ** gamma``.

    Larger errors receive larger weights, so hard examples dominate the mean.
    With ``gamma=0`` the weight is 1 and the loss reduces to plain MSE.
    """

    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, pred, target):
        """Return the mean focal-weighted squared error over all elements."""
        squared_error = (pred - target).pow(2)
        weights = torch.pow(1 + squared_error, self.gamma)
        weighted = weights * squared_error
        return torch.mean(weighted)

## Step 5: Create Training Script

In [None]:
%%writefile train_colab_igfold.py
"""
Training script for IgFold + ESM-2 Hybrid
"""

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from scipy import stats
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from tqdm import tqdm
import argparse
from pathlib import Path

from model_igfold_hybrid import IgFoldBERTOnly, FocalMSELoss


class AbAgDataset(Dataset):
    """Map-style dataset over a DataFrame of antibody/antigen pairs with pKd labels.

    Expects the columns ``antibody_sequence``, ``antigen_sequence`` and ``pKd``.
    """

    def __init__(self, df):
        # Re-index from zero so positional lookups are independent of the
        # index the frame carried in (e.g. after a train/test split).
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return the two raw sequences and the label as a float32 tensor."""
        record = self.df.iloc[idx]
        label = torch.tensor(record['pKd'], dtype=torch.float32)
        return {
            'antibody_sequence': record['antibody_sequence'],
            'antigen_sequence': record['antigen_sequence'],
            'pKd': label,
        }


def collate_fn(batch):
    """Merge per-sample dicts into one batch dict.

    Sequences stay as Python lists of strings (they have variable length);
    the scalar ``pKd`` tensors are stacked into a single 1-D tensor.
    """
    antibodies, antigens, labels = [], [], []
    for sample in batch:
        antibodies.append(sample['antibody_sequence'])
        antigens.append(sample['antigen_sequence'])
        labels.append(sample['pKd'])

    return {
        'antibody_seqs': antibodies,
        'antigen_seqs': antigens,
        'pKd': torch.stack(labels),
    }


def train_epoch(model, loader, optimizer, criterion, device, scaler):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The forward pass and loss are computed under CUDA autocast; the backward
    pass and optimizer step go through ``scaler`` (gradient scaling).
    """
    model.train()
    running_loss = 0

    progress = tqdm(loader, desc="Training")
    for step_batch in progress:
        labels = step_batch['pKd'].to(device)

        optimizer.zero_grad()

        with torch.amp.autocast('cuda'):
            outputs = model(step_batch['antibody_seqs'], step_batch['antigen_seqs'], device)
            loss = criterion(outputs, labels)

        # Scale the loss before backward, then let the scaler drive the step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix({'loss': f'{batch_loss:.2e}'})

    return running_loss / len(loader)


def evaluate(model, loader, device):
    """Run the model over ``loader`` and compute regression and ranking metrics.

    Args:
        model: Callable as ``model(antibody_seqs, antigen_seqs, device)`` and
            returning one prediction per sample.
        loader: Iterable of batch dicts with the keys ``'antibody_seqs'``,
            ``'antigen_seqs'`` and ``'pKd'``.
        device: Device handed through to the model.

    Returns:
        dict of plain Python floats with the keys ``rmse``, ``mae``, ``r2``,
        ``spearman``, ``pearson`` and ``recall_pkd9``. ``recall_pkd9`` is the
        percentage of samples with target >= 9.0 whose prediction is also
        >= 9.0, or 0.0 when no target reaches 9.0. Plain floats (rather than
        NumPy scalars) keep a checkpoint that embeds these metrics loadable
        with ``torch.load``'s restricted ``weights_only`` unpickler.
    """
    model.eval()
    prediction_chunks = []
    target_chunks = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            with torch.amp.autocast('cuda'):
                batch_predictions = model(batch['antibody_seqs'], batch['antigen_seqs'], device)

            # Autocast can return half-precision outputs; upcast so the metrics
            # and the 9.0 threshold are not evaluated at float16 resolution.
            prediction_chunks.append(batch_predictions.float().cpu().numpy().reshape(-1))
            # Targets are only needed on the host, so they never visit the device.
            target_chunks.append(batch['pKd'].float().cpu().numpy().reshape(-1))

    predictions = np.concatenate(prediction_chunks)
    targets = np.concatenate(target_chunks)

    rmse = float(np.sqrt(mean_squared_error(targets, predictions)))
    mae = float(mean_absolute_error(targets, predictions))
    r2 = float(r2_score(targets, predictions))
    spearman = float(stats.spearmanr(targets, predictions)[0])
    pearson = float(np.corrcoef(targets, predictions)[0, 1])

    strong_binders = targets >= 9.0
    predicted_strong = predictions >= 9.0
    n_strong = int(strong_binders.sum())
    if n_strong > 0:
        recall_pkd9 = float((strong_binders & predicted_strong).sum()) / n_strong
    else:
        recall_pkd9 = 0.0

    return {
        'rmse': rmse,
        'mae': mae,
        'r2': r2,
        'spearman': spearman,
        'pearson': pearson,
        'recall_pkd9': recall_pkd9 * 100
    }


def main(args):
    """Train the hybrid model with per-epoch validation and checkpointing.

    Reads the CSV at ``args.data``, splits it 70/15/15 into train/val/test with
    a fixed seed, trains for ``args.epochs`` epochs and writes two files into
    ``args.output_dir``: ``best_model.pth`` (weights with the best validation
    Spearman so far) and ``checkpoint_latest.pth`` (full state for ``--resume``).

    Args:
        args: Parsed CLI namespace providing ``data``, ``output_dir``,
            ``resume``, ``epochs``, ``batch_size``, ``lr``, ``weight_decay``,
            ``dropout`` and ``focal_gamma``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    # Querying the device name raises without CUDA, so only do it on a GPU run.
    if device.type == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    df = pd.read_csv(args.data)
    print(f"Loaded {len(df):,} samples")

    train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
    # test_df is held out by this split but is not evaluated in this script.
    val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

    print(f"Train: {len(train_df):,} | Val: {len(val_df):,} | Test: {len(test_df):,}")

    train_dataset = AbAgDataset(train_df)
    val_dataset = AbAgDataset(val_df)

    # Pinned host memory is only useful when batches are copied to a GPU.
    pin_memory = device.type == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=2, collate_fn=collate_fn, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=2, collate_fn=collate_fn, pin_memory=pin_memory)

    model = IgFoldBERTOnly(dropout=args.dropout).to(device)
    criterion = FocalMSELoss(gamma=args.focal_gamma)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    scaler = torch.amp.GradScaler('cuda')

    start_epoch = 0
    best_spearman = -1
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.resume and Path(args.resume).exists():
        # NOTE: weights_only=False performs a full unpickle, so only resume from
        # checkpoints written by this script. It is needed because checkpoints
        # may carry NumPy scalars in their metrics, which the restricted
        # unpickler (the torch.load default since PyTorch 2.6) rejects.
        checkpoint = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # Absent in older checkpoints, and empty when saved from a disabled scaler.
        if checkpoint.get('scaler_state_dict'):
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_spearman = checkpoint.get('best_val_spearman', -1)
        print(f"Resuming from epoch {start_epoch}, Best Spearman: {best_spearman:.4f}")
    elif args.resume:
        print(f"Checkpoint not found at {args.resume}; starting from scratch")

    print(f"\nStarting training for {args.epochs} epochs...\n")

    for epoch in range(start_epoch, args.epochs):
        print(f"Epoch {epoch+1}/{args.epochs}")

        train_loss = train_epoch(model, train_loader, optimizer, criterion, device, scaler)
        val_metrics = evaluate(model, val_loader, device)
        scheduler.step()

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val RMSE: {val_metrics['rmse']:.4f} | Spearman: {val_metrics['spearman']:.4f}")
        print(f"Val Recall@pKd≥9: {val_metrics['recall_pkd9']:.2f}%")

        # A NaN Spearman compares False here, so it never replaces the best model.
        if val_metrics['spearman'] > best_spearman:
            best_spearman = float(val_metrics['spearman'])
            torch.save(model.state_dict(), output_dir / 'best_model.pth')
            print("✓ Saved best model")

        # Store metrics as plain floats so the file contains only tensors and
        # Python primitives and stays loadable with the restricted unpickler.
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'best_val_spearman': float(best_spearman),
            'val_metrics': {name: float(value) for name, value in val_metrics.items()}
        }
        torch.save(checkpoint, output_dir / 'checkpoint_latest.pth')

    print(f"\nTraining complete! Best Spearman: {best_spearman:.4f}")


if __name__ == '__main__':
    # Command-line entry point: --data is the only mandatory flag; every other
    # flag falls back to the default listed in the table below.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, required=True)

    optional_flags = (
        ('--output_dir', str, 'outputs_igfold'),
        ('--resume', str, None),
        ('--epochs', int, 50),
        ('--batch_size', int, 8),
        ('--lr', float, 1e-3),
        ('--weight_decay', float, 0.01),
        ('--dropout', float, 0.3),
        ('--focal_gamma', float, 2.0),
    )
    for flag, value_type, default_value in optional_flags:
        parser.add_argument(flag, type=value_type, default=default_value)

    args = parser.parse_args()
    main(args)

## Step 6: Start Training 🚀

In [None]:
# Launch training. The script rewrites outputs_igfold/checkpoint_latest.pth
# after every epoch; add `--resume outputs_igfold/checkpoint_latest.pth` to
# continue from that checkpoint instead of starting at epoch 1.
!python train_colab_igfold.py \
  --data agab_phase2_full.csv \
  --epochs 50 \
  --batch_size 8 \
  --focal_gamma 2.0 \
  --output_dir outputs_igfold

## Monitor Progress

In [None]:
# Check checkpoint: report the last finished epoch and its validation metrics.
import torch
# NOTE: weights_only=False performs a full unpickle; this file was written by
# the training run above, so it is trusted. It is needed because the metrics
# stored in the checkpoint may be NumPy scalars, which the restricted unpickler
# (the torch.load default since PyTorch 2.6) refuses to load.
checkpoint = torch.load('outputs_igfold/checkpoint_latest.pth', map_location='cpu', weights_only=False)
print(f"Epoch: {checkpoint['epoch'] + 1}")
print(f"Best Spearman: {checkpoint['best_val_spearman']:.4f}")
print(f"Recall@pKd≥9: {checkpoint['val_metrics']['recall_pkd9']:.2f}%")

## Download Results

In [None]:
# Download the best weights and the latest full checkpoint through the browser.
from google.colab import files
files.download('outputs_igfold/best_model.pth')
files.download('outputs_igfold/checkpoint_latest.pth')