In [1]:
!pip install -q torch_geometric pyg_lib torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-$(python -c 'import torch; print(torch.__version__ )').html

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.7/4.7 MB[0m [31m58.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m113.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m94.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m28.9 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h

In [2]:
!pip install -q rdkit

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.2/36.2 MB[0m [31m53.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25h

In [3]:
!pip install -q transformers

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m564.3/564.3 kB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 4.1.1 requires pyarrow>=21.0.0, but you have pyarrow 19.0.1 which is incompatible.
gradio 5.38.1 requires pydantic<2.12,>=2.0, but you have pydantic 2.12.0a1 which is incompatible.[0m[31m
[0m

In [4]:
!pip install -q pandas scikit-learn tqdm

In [7]:
"""
OPTIMIZED FOR MULTI-GPU TRAINING with DistributedDataParallel
- Proper DDP setup for PyTorch Geometric
- Distributed data loading with DistributedSampler
- Gradient synchronization across GPUs
- Persistent preprocessed data (saved to disk)
- Memory-efficient training with mixed precision
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, Batch
import pandas as pd
import numpy as np
from rdkit import Chem
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from transformers import EsmModel, AutoTokenizer
import warnings
from tqdm import tqdm
import os
import pickle
from pathlib import Path

warnings.filterwarnings('ignore')

# ============================================================================
# CONFIGURATION - OPTIMIZED FOR MULTI-GPU DDP
# ============================================================================

class Config:
    """Centralized configuration for distributed DTA training.

    All settings are plain class attributes read directly off the class
    (e.g. ``Config.BACKEND``, ``config.BATCH_SIZE``).
    """

    # ---- Model architecture ------------------------------------------------
    LIGAND_EMBED_DIM = 192
    # Must equal LIGAND_EMBED_DIM: both embeddings pass through the same
    # cross-attention width.
    PROTEIN_EMBED_DIM = 192
    GNN_HIDDEN_DIM = 96
    GNN_NUM_LAYERS = 2
    ATTENTION_HEADS = 6              # must divide LIGAND_EMBED_DIM
    ATTENTION_DROPOUT = 0.15
    # The regression head consumes cat([ligand_attn, protein_attn, protein_embed]),
    # so its input width is derived from the embedding sizes (192 + 2*192 = 576)
    # instead of being hard-coded and silently going stale.
    MLP_DIMS = [LIGAND_EMBED_DIM + 2 * PROTEIN_EMBED_DIM, 384, 192, 64, 1]
    MLP_DROPOUT = 0.2

    # ---- ESM-2 protein encoder ---------------------------------------------
    ESM_MODEL = "facebook/esm2_t12_35M_UR50D"
    ESM_MAX_LENGTH = 800             # tokens per sequence (padded / truncated)
    FREEZE_ESM = True

    # ---- Optimisation (per-GPU values) -------------------------------------
    # Effective batch = BATCH_SIZE * world_size * GRADIENT_ACCUMULATION.
    BATCH_SIZE = 8
    GRADIENT_ACCUMULATION = 2
    LEARNING_RATE = 3e-5
    WEIGHT_DECAY = 1e-5
    NUM_EPOCHS = 50
    PATIENCE = 15                    # epochs without val improvement before stopping
    GRAD_CLIP_NORM = 1.0

    # ---- Data splits -------------------------------------------------------
    TEST_SIZE = 0.15
    VAL_SIZE = 0.15                  # fraction of the full dataset, not of train+val
    RANDOM_SEED = 42

    # ---- Persistence -------------------------------------------------------
    DATA_CACHE_DIR = Path("/kaggle/working/dta_cache")
    PROCESSED_DATA_FILE = "processed_datasets.pkl"

    # ---- DistributedDataParallel -------------------------------------------
    BACKEND = 'nccl'                 # use 'gloo' for CPU or Windows
    FIND_UNUSED_PARAMETERS = False   # set True if some trainable params get no grad

    # ---- DataLoader (per GPU) ----------------------------------------------
    NUM_WORKERS = 4
    PIN_MEMORY = True
    PERSISTENT_WORKERS = False       # kept off for DDP compatibility

# ============================================================================
# DDP UTILITY FUNCTIONS
# ============================================================================

def setup_ddp(rank, world_size):
    """Initialize the distributed process group for one spawned worker.

    Args:
        rank: Index of this process; also used as its CUDA device index
            (single-node setup).
        world_size: Total number of processes.

    The rendezvous defaults to ``localhost:12355``. Set ``MASTER_ADDR`` /
    ``MASTER_PORT`` in the environment before spawning to override them
    (e.g. when the port is still held by a previous notebook run).
    """
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # Bind this process to its GPU before the process group is created so the
    # backend's communicators are set up on the intended device.
    torch.cuda.set_device(rank)

    dist.init_process_group(
        backend=Config.BACKEND,
        init_method='env://',
        world_size=world_size,
        rank=rank
    )

    # Seeds are offset by rank so per-process randomness differs. DDP is
    # documented to broadcast rank 0's module state when wrapping, so model
    # initialisation still agrees across ranks.
    torch.manual_seed(Config.RANDOM_SEED + rank)
    np.random.seed(Config.RANDOM_SEED + rank)


def cleanup_ddp():
    """Tear down the default process group if one is active.

    Safe to call unconditionally (e.g. from a ``finally`` block). When no
    process group has been initialised it does nothing instead of raising.
    """
    if dist.is_initialized():
        dist.destroy_process_group()


def is_main_process():
    """Return True on rank 0, or whenever no process group is running."""
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True


def get_rank():
    """Return this process's rank, defaulting to 0 outside distributed mode."""
    return dist.get_rank() if dist.is_initialized() else 0


def get_world_size():
    """Return the number of processes, defaulting to 1 outside distributed mode."""
    return dist.get_world_size() if dist.is_initialized() else 1


def reduce_metric(value, world_size):
    """Average a scalar across ranks.

    The per-rank values are SUM-reduced on the GPU of the current rank and
    divided by ``world_size``; every rank receives the same result. Without
    an initialised process group ``value`` is returned unchanged.
    """
    if dist.is_initialized():
        buffer = torch.tensor(value, device=f'cuda:{get_rank()}')
        dist.all_reduce(buffer, op=dist.ReduceOp.SUM)
        return buffer.item() / world_size
    return value


def all_gather_predictions(predictions, labels, num_samples=None):
    """Gather predictions and labels from all ranks into global arrays.

    Args:
        predictions: Sequence of scalar predictions produced on this rank.
        labels: Sequence of scalar targets aligned with ``predictions``.
        num_samples: Optional true dataset size. When given, the gathered
            arrays are cut to this length. This drops the samples that
            ``DistributedSampler`` repeats (with ``drop_last=False``) to give
            every rank an equal share.

    Returns:
        ``(preds, labels)`` as 1-D NumPy arrays, on every rank. In distributed
        mode the shards are interleaved (rank 0 item 0, rank 1 item 0, ...,
        rank 0 item 1, ...). This reproduces the global order a
        ``DistributedSampler`` hands out (rank ``r`` gets positions
        ``r, r + world_size, ...``), so padding duplicates land at the tail.
    """
    if not dist.is_initialized():
        preds = np.asarray(predictions)
        trues = np.asarray(labels)
        if num_samples is not None:
            preds, trues = preds[:num_samples], trues[:num_samples]
        return preds, trues

    world_size = dist.get_world_size()
    device = torch.device('cuda', torch.cuda.current_device())

    local_preds = torch.as_tensor(
        np.asarray(predictions, dtype=np.float32), device=device).reshape(-1)
    local_trues = torch.as_tensor(
        np.asarray(labels, dtype=np.float32), device=device).reshape(-1)

    # 1) Exchange shard sizes so every rank can allocate equal-sized buffers.
    local_size = torch.tensor([local_preds.numel()], dtype=torch.long, device=device)
    size_tensors = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(size_tensors, local_size)
    sizes = np.array([int(t.item()) for t in size_tensors])
    max_size = int(sizes.max())

    # 2) all_gather requires identical shapes: zero-pad each shard to max_size.
    def _gather_padded(local):
        padded = torch.zeros(max_size, dtype=torch.float32, device=device)
        padded[:local.numel()] = local
        buffers = [torch.zeros_like(padded) for _ in range(world_size)]
        dist.all_gather(buffers, padded)
        return torch.stack(buffers).cpu().numpy()  # (world_size, max_size)

    pred_matrix = _gather_padded(local_preds)
    true_matrix = _gather_padded(local_trues)

    # 3) Read column by column (interleave ranks) and drop the zero padding.
    valid = (np.arange(max_size)[None, :] < sizes[:, None]).T
    gathered_preds = pred_matrix.T[valid]
    gathered_trues = true_matrix.T[valid]

    if num_samples is not None:
        gathered_preds = gathered_preds[:num_samples]
        gathered_trues = gathered_trues[:num_samples]
    return gathered_preds, gathered_trues

# ============================================================================
# DATA PROCESSING (Same as original with minor fixes)
# ============================================================================

def smiles_to_graph(smiles):
    """Convert a SMILES string into a PyTorch Geometric ``Data`` graph.

    Node features are 5 floats per atom: atomic number, degree, formal
    charge, hybridization (integer enum value) and an aromatic flag. Each
    bond is stored as two directed edges, and each edge carries the bond
    order (``GetBondTypeAsDouble``) as a 1-dim edge attribute.

    Args:
        smiles: SMILES string. Non-string values (e.g. a pandas ``NaN`` from
            an empty CSV cell) are treated as invalid.

    Returns:
        ``Data(x, edge_index, edge_attr)`` with ``x`` of shape
        ``(num_atoms, 5)``. Returns ``None`` when:
        - the input is empty or not a string,
        - RDKit cannot parse it, or
        - the parsed molecule has no atoms.
        Atom-less molecules are rejected because their feature tensor would
        have shape ``(0,)`` rather than ``(0, 5)`` and could not be batched
        with other graphs.
    """
    # Reject non-string and blank input before handing it to RDKit.
    if not isinstance(smiles, str) or not smiles.strip():
        return None

    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        return None

    atom_features = [
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            int(atom.GetHybridization()),
            int(atom.GetIsAromatic()),
        ]
        for atom in mol.GetAtoms()
    ]
    x = torch.tensor(atom_features, dtype=torch.float)

    edge_indices = []
    edge_features = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        bond_type = bond.GetBondTypeAsDouble()
        # Undirected bond -> two directed edges sharing the same attribute.
        edge_indices.extend([[i, j], [j, i]])
        edge_features.extend([bond_type, bond_type])

    if edge_indices:
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_features, dtype=torch.float).unsqueeze(1)
    else:
        # Bond-less molecules (e.g. a single atom) still need well-shaped tensors.
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, 1), dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


class DTADataset(Dataset):
    """Drug-Target Affinity dataset of (ligand graph, protein tokens, pKi).

    On construction it does one of two things:
    - restores a previously processed state from ``cache_data``, or
    - processes ``df``: rows whose SMILES ``smiles_to_graph`` rejects are
      dropped, and every protein sequence is tokenized, padded/truncated
      to ``max_length``.

    Args:
        df: DataFrame with ``'Ligand SMILES'``, ``'pKi'`` and a protein
            sequence column (see ``_get_protein_sequence``). Not used when
            ``cache_data`` is given.
        tokenizer: Tokenizer callable returning ``input_ids`` /
            ``attention_mask`` tensors.
        max_length: Token length every sequence is padded/truncated to.
        normalize_pki: If True, ``__getitem__`` returns
            ``(pKi - pki_mean) / pki_std``.
        pki_mean, pki_std: Normalization statistics. Any statistic left as
            None is computed from this dataset's pKi values *after* invalid
            SMILES are removed. Pass the training-set values for val/test.
        cache_data: Dict previously returned by ``get_cache_data``.
    """

    def __init__(self, df, tokenizer, max_length=800, normalize_pki=True,
                 pki_mean=None, pki_std=None, cache_data=None):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.normalize_pki = normalize_pki

        if cache_data is not None:
            if is_main_process():
                print("✓ Loading from cache...")
            self.graphs = cache_data['graphs']
            self.protein_tokens = cache_data['protein_tokens']
            self.df = cache_data['df']
        else:
            self._process_data()

        # Statistics are derived only once the final rows are known. They then
        # describe exactly the samples served, and still work when restoring
        # from cache, where `df` may be an empty placeholder.
        if normalize_pki:
            self._init_normalization(pki_mean, pki_std)

    def _init_normalization(self, pki_mean, pki_std):
        """Set ``pki_mean``/``pki_std``, deriving missing ones from ``self.df``."""
        if pki_mean is None:
            pki_mean = self.df['pKi'].mean()
        if pki_std is None:
            pki_std = self.df['pKi'].std()
        # A degenerate std (one row -> NaN, constant column -> 0) would turn
        # every normalized target into NaN/inf; fall back to no scaling.
        if not np.isfinite(pki_std) or pki_std == 0:
            pki_std = 1.0
        self.pki_mean = pki_mean
        self.pki_std = pki_std

    def _process_data(self):
        """Build ligand graphs (dropping unparsable SMILES) and tokenize proteins."""
        valid_indices = []
        self.graphs = []

        if is_main_process():
            print("Processing SMILES to molecular graphs...")

        for idx, smiles in enumerate(tqdm(self.df['Ligand SMILES'],
                                         desc="SMILES",
                                         disable=not is_main_process())):
            graph = smiles_to_graph(smiles)
            if graph is not None:
                self.graphs.append(graph)
                valid_indices.append(idx)

        # Keep df rows aligned 1:1 with self.graphs.
        self.df = self.df.iloc[valid_indices].reset_index(drop=True)

        if is_main_process():
            print(f"✓ Valid molecules: {len(self.df)}")
            print("Pre-tokenizing protein sequences...")

        # Identical sequences are tokenized once and their tensors reused,
        # which saves time and shrinks the pickled cache.
        token_cache = {}
        self.protein_tokens = []
        for idx in tqdm(range(len(self.df)),
                        desc="Proteins",
                        disable=not is_main_process()):
            sequence = self._get_protein_sequence(self.df.iloc[idx])

            if sequence not in token_cache:
                tokens = self.tokenizer(
                    sequence,
                    max_length=self.max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )
                token_cache[sequence] = {
                    'input_ids': tokens['input_ids'].squeeze(0),
                    'attention_mask': tokens['attention_mask'].squeeze(0)
                }

            # Fresh dict per row; the tensors themselves are shared.
            self.protein_tokens.append(dict(token_cache[sequence]))

        if is_main_process():
            print(f"✓ Pre-processing complete!\n")

    def _get_protein_sequence(self, row):
        """Return the protein sequence from the first matching known column.

        Falls back to any column whose name contains ``'sequence'``
        (case-insensitive).

        Raises:
            KeyError: if no candidate column exists in ``row``.
        """
        possible_names = [
            'BindingDB Target Chain Sequence',
            'BindingDB Target Chain Sequence 1',
            'Protein Sequence',
            'protein_sequence',
            'sequence'
        ]

        for name in possible_names:
            if name in row.index:
                return row[name]

        seq_cols = [col for col in row.index if 'sequence' in col.lower()]
        if seq_cols:
            return row[seq_cols[0]]

        raise KeyError(f"Cannot find protein sequence column. Available: {list(row.index)}")

    def get_cache_data(self):
        """Return the processed state (graphs, tokens, filtered df) for caching."""
        return {
            'graphs': self.graphs,
            'protein_tokens': self.protein_tokens,
            'df': self.df
        }

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return one sample dict: graph, token ids, attention mask, pKi tensor."""
        tokens = self.protein_tokens[idx]

        # Read the single needed cell instead of materialising the whole row.
        pki = float(self.df['pKi'].iat[idx])
        if self.normalize_pki:
            pki = (pki - self.pki_mean) / self.pki_std

        return {
            'graph': self.graphs[idx],
            'protein_input_ids': tokens['input_ids'],
            'protein_attention_mask': tokens['attention_mask'],
            'pki': torch.tensor(pki, dtype=torch.float)
        }


def collate_fn(batch):
    """Collate ``DTADataset`` items into a single mini-batch dict.

    Graphs are merged with ``Batch.from_data_list``. The protein token
    tensors and pKi targets are stacked along a new leading batch dimension.
    """
    def stacked(key):
        return torch.stack([sample[key] for sample in batch])

    graph_batch = Batch.from_data_list([sample['graph'] for sample in batch])
    return {
        'graph_batch': graph_batch,
        'protein_input_ids': stacked('protein_input_ids'),
        'protein_attention_mask': stacked('protein_attention_mask'),
        'pki': stacked('pki'),
    }

# ============================================================================
# MODEL ARCHITECTURE (Same as original)
# ============================================================================

class LigandGNN(nn.Module):
    """Encode a batch of molecular graphs into one vector per graph.

    Pipeline:
    1. Linear lift of the raw atom features to ``hidden_dim``, then ReLU.
    2. ``num_layers`` residual blocks: GCNConv -> LayerNorm -> ReLU ->
       dropout, added back onto the block input.
    3. Per-graph mean pooling.
    4. Linear + LayerNorm projection to ``output_dim``.

    Edge attributes are not used.
    """

    def __init__(self, input_dim=5, hidden_dim=96, output_dim=192, num_layers=2):
        super().__init__()

        self.input_projection = nn.Linear(input_dim, hidden_dim)
        self.convs = nn.ModuleList(
            GCNConv(hidden_dim, hidden_dim) for _ in range(num_layers)
        )
        self.norms = nn.ModuleList(
            nn.LayerNorm(hidden_dim) for _ in range(num_layers)
        )
        self.output_projection = nn.Sequential(
            nn.Linear(hidden_dim, output_dim),
            nn.LayerNorm(output_dim),
        )
        self.dropout = nn.Dropout(0.1)

    def forward(self, data):
        """Return a pooled ``(num_graphs, output_dim)`` embedding for a PyG batch."""
        h = F.relu(self.input_projection(data.x))
        for conv, norm in zip(self.convs, self.norms):
            h = h + self.dropout(F.relu(norm(conv(h, data.edge_index))))
        return self.output_projection(global_mean_pool(h, data.batch))


class ProteinEncoder(nn.Module):
    """Embed tokenized protein sequences with a pretrained ESM-2 model.

    The last-layer hidden state at position 0 (presumably the tokenizer's
    ``<cls>`` token — TODO confirm) is projected to ``output_dim`` by a
    small MLP ending in LayerNorm.

    Args:
        model_name: HuggingFace id passed to ``EsmModel.from_pretrained``.
        output_dim: Size of the returned embedding.
        freeze: If True, no ESM parameter receives gradients.
    """

    def __init__(self, model_name, output_dim=192, freeze=True):
        super().__init__()

        self.esm = EsmModel.from_pretrained(model_name)
        if freeze:
            self.esm.requires_grad_(False)

        hidden = self.esm.config.hidden_size
        self.projection = nn.Sequential(
            nn.Linear(hidden, 384),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(384, output_dim),
            nn.LayerNorm(output_dim),
        )

    def forward(self, input_ids, attention_mask):
        """Return ``(projected_position0_embedding, last_hidden_state)``."""
        last_hidden = self.esm(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        pooled = last_hidden[:, 0, :]
        return self.projection(pooled), last_hidden


class BiDirectionalCrossAttention(nn.Module):
    """Let a ligand vector and a protein vector attend to each other.

    Each ``(batch, embed_dim)`` input is treated as a length-1 sequence. The
    ligand queries the protein, and the protein queries the ligand, each
    through its own ``nn.MultiheadAttention``. Each result is added back
    onto its query input and LayerNorm-ed.

    Because every key sequence has exactly one element, the softmax is
    trivially 1. The returned attention weights are therefore all ones in
    eval mode (attention dropout may zero them in training). Each output is
    effectively a projection of the *other* modality plus a residual.
    """

    def __init__(self, embed_dim=192, num_heads=6, dropout=0.15):
        super().__init__()

        def make_attention():
            return nn.MultiheadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                dropout=dropout,
                batch_first=True,
            )

        self.ligand_to_protein = make_attention()
        self.protein_to_ligand = make_attention()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, ligand_embed, protein_embed):
        """Return ``(ligand_out, protein_out, (ligand_weights, protein_weights))``."""
        lig_seq = ligand_embed.unsqueeze(1)    # (batch, 1, embed_dim)
        prot_seq = protein_embed.unsqueeze(1)

        lig_ctx, lig_weights = self.ligand_to_protein(
            query=lig_seq, key=prot_seq, value=prot_seq)
        lig_out = self.norm1(lig_ctx.squeeze(1) + ligand_embed)

        prot_ctx, prot_weights = self.protein_to_ligand(
            query=prot_seq, key=lig_seq, value=lig_seq)
        prot_out = self.norm2(prot_ctx.squeeze(1) + protein_embed)

        return lig_out, prot_out, (lig_weights, prot_weights)


class DTAModel(nn.Module):
    """Drug-target affinity regressor.

    The model combines a ligand GNN, an ESM-2 protein encoder,
    bidirectional cross-attention and an MLP head.

    The head input is ``cat([ligand_attn, protein_attn, protein_embed])``,
    which imposes two constraints:
    - ``config.MLP_DIMS[0]`` must equal
      ``LIGAND_EMBED_DIM + 2 * PROTEIN_EMBED_DIM``;
    - ``PROTEIN_EMBED_DIM == LIGAND_EMBED_DIM``, because the protein
      embedding is fed as query into cross-attention of width
      ``LIGAND_EMBED_DIM``.

    Both are checked up front, so a bad config fails at construction rather
    than inside the first forward pass.

    Raises:
        ValueError: if the embedding / MLP dimensions are inconsistent.
    """

    def __init__(self, config):
        super().__init__()
        self._check_dims(config)

        self.ligand_encoder = LigandGNN(
            input_dim=5,
            hidden_dim=config.GNN_HIDDEN_DIM,
            output_dim=config.LIGAND_EMBED_DIM,
            num_layers=config.GNN_NUM_LAYERS
        )

        self.protein_encoder = ProteinEncoder(
            model_name=config.ESM_MODEL,
            output_dim=config.PROTEIN_EMBED_DIM,
            freeze=config.FREEZE_ESM
        )

        self.cross_attention = BiDirectionalCrossAttention(
            embed_dim=config.LIGAND_EMBED_DIM,
            num_heads=config.ATTENTION_HEADS,
            dropout=config.ATTENTION_DROPOUT
        )

        # Three modules per layer (Linear, act, dropout) so state_dict keys stay
        # `mlp.{3*i}.*`. The last layer gets Identity for activation/dropout.
        mlp_layers = []
        for i in range(len(config.MLP_DIMS) - 1):
            mlp_layers.extend([
                nn.Linear(config.MLP_DIMS[i], config.MLP_DIMS[i+1]),
                nn.ReLU() if i < len(config.MLP_DIMS) - 2 else nn.Identity(),
                nn.Dropout(config.MLP_DROPOUT) if i < len(config.MLP_DIMS) - 2 else nn.Identity()
            ])

        self.mlp = nn.Sequential(*mlp_layers)

    @staticmethod
    def _check_dims(config):
        """Raise ValueError if embedding sizes and MLP input width disagree."""
        lig_dim = config.LIGAND_EMBED_DIM
        prot_dim = config.PROTEIN_EMBED_DIM
        if prot_dim != lig_dim:
            raise ValueError(
                f"PROTEIN_EMBED_DIM ({prot_dim}) must equal LIGAND_EMBED_DIM "
                f"({lig_dim}): both feed the same cross-attention width."
            )
        fused_dim = lig_dim + 2 * prot_dim
        if config.MLP_DIMS[0] != fused_dim:
            raise ValueError(
                f"MLP_DIMS[0] ({config.MLP_DIMS[0]}) must equal "
                f"LIGAND_EMBED_DIM + 2 * PROTEIN_EMBED_DIM ({fused_dim})."
            )

    def forward(self, graph_batch, protein_input_ids, protein_attention_mask):
        """Predict (normalized) pKi.

        Returns:
            ``(pki_pred, attn_weights)``. ``pki_pred`` is the MLP output with
            its last dim squeezed, i.e. shape ``(batch,)`` when the final MLP
            width is 1. ``attn_weights`` is the ``(ligand, protein)`` weight
            pair from the cross-attention.
        """
        ligand_embed = self.ligand_encoder(graph_batch)
        protein_embed, _ = self.protein_encoder(protein_input_ids, protein_attention_mask)
        lig_attn, prot_attn, attn_weights = self.cross_attention(ligand_embed, protein_embed)
        combined = torch.cat([lig_attn, prot_attn, protein_embed], dim=1)
        pki_pred = self.mlp(combined).squeeze(-1)

        return pki_pred, attn_weights

# ============================================================================
# DATA PREPROCESSING (Same as original)
# ============================================================================

def preprocess_and_cache_data(csv_path, config=Config):
    """Split the raw CSV, featurize every split and cache the result to disk.

    STEP 1 of the pipeline; it only needs to run once. If
    ``config.DATA_CACHE_DIR / config.PROCESSED_DATA_FILE`` already exists, it
    is loaded and returned instead.

    NOTE: the cache is not keyed on ``csv_path`` or the config. Delete the
    file after changing either.

    Args:
        csv_path: CSV with ``'pKi'``, ``'Ligand SMILES'`` and a protein
            sequence column.
        config: Configuration namespace (defaults to ``Config``).

    Returns:
        Dict with:
        - ``'train_cache'``, ``'val_cache'``, ``'test_cache'``
          (``DTADataset.get_cache_data()`` output),
        - the training-set ``'pki_mean'`` / ``'pki_std'``,
        - ``'tokenizer_name'``.
    """
    config.DATA_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    cache_file = config.DATA_CACHE_DIR / config.PROCESSED_DATA_FILE

    if cache_file.exists():
        print("=" * 80)
        print("CACHED DATA FOUND - Loading from disk...")
        print("=" * 80)
        # SECURITY: pickle can execute arbitrary code on load. Only point this
        # at a cache file produced by this function.
        with open(cache_file, 'rb') as f:
            cached_data = pickle.load(f)
        print("✓ Loaded cached data successfully!")
        return cached_data

    print("=" * 80)
    print("PREPROCESSING DATA - This will be saved to disk")
    print("=" * 80)

    torch.manual_seed(config.RANDOM_SEED)
    np.random.seed(config.RANDOM_SEED)

    print("Loading CSV...")
    df = pd.read_csv(csv_path)
    print(f"Total samples: {len(df)}")

    # Coerce to numeric first. Non-numeric entries become NaN and are dropped
    # with the other invalid rows, instead of crashing np.isinf on an object
    # column.
    pki = pd.to_numeric(df['pKi'], errors='coerce')
    valid = np.isfinite(pki)
    df = df.assign(pKi=pki)[valid]
    removed = len(valid) - len(df)
    if removed:
        print(f"⚠ Removed {removed} invalid pKi rows")

    # VAL_SIZE is a fraction of the whole dataset, hence the rescaling.
    train_val_df, test_df = train_test_split(
        df, test_size=config.TEST_SIZE, random_state=config.RANDOM_SEED
    )
    train_df, val_df = train_test_split(
        train_val_df, test_size=config.VAL_SIZE / (1 - config.TEST_SIZE),
        random_state=config.RANDOM_SEED
    )

    print(f"Split: Train={len(train_df)} | Val={len(val_df)} | Test={len(test_df)}\n")

    print("Loading ESM-2 tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(config.ESM_MODEL, trust_remote_code=True)

    print("\nProcessing TRAIN dataset...")
    train_dataset = DTADataset(train_df, tokenizer, config.ESM_MAX_LENGTH)

    # Val/test are normalized with the TRAIN statistics to avoid leakage.
    print("\nProcessing VAL dataset...")
    val_dataset = DTADataset(
        val_df, tokenizer, config.ESM_MAX_LENGTH,
        pki_mean=train_dataset.pki_mean,
        pki_std=train_dataset.pki_std
    )

    print("\nProcessing TEST dataset...")
    test_dataset = DTADataset(
        test_df, tokenizer, config.ESM_MAX_LENGTH,
        pki_mean=train_dataset.pki_mean,
        pki_std=train_dataset.pki_std
    )

    print("\n" + "=" * 80)
    print("SAVING PROCESSED DATA TO DISK...")
    cached_data = {
        'train_cache': train_dataset.get_cache_data(),
        'val_cache': val_dataset.get_cache_data(),
        'test_cache': test_dataset.get_cache_data(),
        'pki_mean': train_dataset.pki_mean,
        'pki_std': train_dataset.pki_std,
        'tokenizer_name': config.ESM_MODEL
    }

    # Write to a temp file, then rename. An interrupted dump then never leaves
    # a truncated cache for the fast path above to load.
    tmp_file = cache_file.with_name(cache_file.name + '.tmp')
    with open(tmp_file, 'wb') as f:
        pickle.dump(cached_data, f)
    tmp_file.replace(cache_file)

    print(f"✓ Data saved to: {cache_file}")
    print("=" * 80)

    return cached_data

# ============================================================================
# DDP TRAINER
# ============================================================================

class DDPTrainer:
    """Distributed trainer using DistributedDataParallel.

    Every rank executes the same epoch loop. Only rank 0 prints, computes the
    sklearn metrics and writes checkpoints. Every collective operation (loss
    reduction, prediction gathering, the early-stopping broadcast, the final
    barrier) is reached by all ranks in the same order, so no process is left
    blocking.

    Args:
        model: Un-wrapped ``nn.Module``; moved to device ``rank`` and wrapped in DDP.
        train_loader: Training ``DataLoader``. If its sampler has ``set_epoch``
            (e.g. ``DistributedSampler``), it is called every epoch.
        val_loader: Validation ``DataLoader``.
        config: Namespace providing the ``Config`` hyper-parameters used below.
        rank: Process rank, also used as the CUDA device index.
        pki_mean, pki_std: Statistics used to de-normalize predictions for metrics.
    """

    BEST_CHECKPOINT_PATH = 'best_model_checkpoint.pt'
    BEST_WEIGHTS_PATH = 'best_model_weights.pt'

    def __init__(self, model, train_loader, val_loader, config, rank,
                 pki_mean=0, pki_std=1):
        self.config = config
        self.rank = rank
        self.pki_mean = pki_mean
        self.pki_std = pki_std

        # Move to this rank's GPU, then wrap so gradients are all-reduced.
        self.model = DDP(
            model.to(rank),
            device_ids=[rank],
            find_unused_parameters=config.FIND_UNUSED_PARAMETERS
        )

        self.train_loader = train_loader
        self.val_loader = val_loader

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config.LEARNING_RATE,
            weight_decay=config.WEIGHT_DECAY
        )

        # `verbose=` is deprecated in recent PyTorch releases (and rejected by
        # newer ones). Rank 0 prints the current LR every epoch instead.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=5
        )

        self.scaler = torch.cuda.amp.GradScaler()
        self.criterion = nn.MSELoss()
        self.best_val_loss = float('inf')
        # Only advanced on rank 0; shared with the other ranks via _should_stop().
        self.patience_counter = 0

    def train_epoch(self):
        """Run one training epoch; return the mean batch loss averaged over ranks."""
        self.model.train()
        accum_steps = self.config.GRADIENT_ACCUMULATION
        num_steps = len(self.train_loader)
        total_loss = 0.0
        num_batches = 0
        self.optimizer.zero_grad()

        pbar = tqdm(self.train_loader, desc='Training', disable=(self.rank != 0))

        for batch_idx, batch in enumerate(pbar):
            graph_batch = batch['graph_batch'].to(self.rank)
            protein_input_ids = batch['protein_input_ids'].to(self.rank)
            protein_attention_mask = batch['protein_attention_mask'].to(self.rank)
            pki_true = batch['pki'].to(self.rank)

            # Mixed-precision forward. The loss is divided by accum_steps so the
            # accumulated gradient matches one larger batch.
            with torch.cuda.amp.autocast():
                pki_pred, _ = self.model(graph_batch, protein_input_ids, protein_attention_mask)
                loss = self.criterion(pki_pred, pki_true) / accum_steps

            self.scaler.scale(loss).backward()

            # Step every `accum_steps` batches and on the final (partial) group.
            if (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1) == num_steps:
                self.scaler.unscale_(self.optimizer)  # clip the true gradients
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.GRAD_CLIP_NORM
                )
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

            batch_loss = loss.item() * accum_steps
            total_loss += batch_loss
            num_batches += 1

            if self.rank == 0:
                pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

        # Guard against an empty loader, then average across GPUs (collective).
        avg_loss = total_loss / max(num_batches, 1)
        return reduce_metric(avg_loss, get_world_size())

    def evaluate(self, loader):
        """Evaluate the model on ``loader``.

        Returns:
            ``(avg_loss, r2, rmse, mae)``.
            - ``avg_loss`` is the normalized-scale MSE averaged over batches and
              ranks; it is identical on every rank.
            - ``r2``/``rmse``/``mae`` are computed on de-normalized pKi on rank 0
              and are ``0.0`` on all other ranks.
        """
        self.model.eval()
        total_loss = 0.0
        all_preds = []
        all_trues = []

        with torch.no_grad():
            pbar = tqdm(loader, desc='Evaluating', disable=(self.rank != 0))
            for batch in pbar:
                graph_batch = batch['graph_batch'].to(self.rank)
                protein_input_ids = batch['protein_input_ids'].to(self.rank)
                protein_attention_mask = batch['protein_attention_mask'].to(self.rank)
                pki_true = batch['pki'].to(self.rank)

                with torch.cuda.amp.autocast():
                    pki_pred, _ = self.model(graph_batch, protein_input_ids, protein_attention_mask)

                # Autocast can return a float16 prediction. Upcast so the loss and
                # the gathered / de-normalized values are float32 like the targets.
                pki_pred = pki_pred.float()
                if not torch.isfinite(pki_pred).all():
                    pki_pred = torch.nan_to_num(pki_pred, nan=0.0, posinf=10.0, neginf=0.0)

                loss = self.criterion(pki_pred, pki_true)

                total_loss += loss.item()
                all_preds.extend(pki_pred.cpu().numpy())
                all_trues.extend(pki_true.cpu().numpy())

        # Collective: every rank must reach this call.
        # NOTE(review): DistributedSampler (drop_last=False) pads shards by
        # repeating samples, so a few predictions may be counted twice here.
        all_preds, all_trues = all_gather_predictions(all_preds, all_trues)

        # Denormalize in float64 to avoid losing precision.
        all_preds = np.asarray(all_preds, dtype=np.float64) * self.pki_std + self.pki_mean
        all_trues = np.asarray(all_trues, dtype=np.float64) * self.pki_std + self.pki_mean

        all_preds = np.nan_to_num(all_preds, nan=0.0, posinf=100.0, neginf=0.0)
        all_trues = np.nan_to_num(all_trues, nan=0.0, posinf=100.0, neginf=0.0)

        avg_loss = reduce_metric(total_loss / max(len(loader), 1), get_world_size())

        r2 = rmse = mae = 0.0
        if self.rank == 0 and len(all_trues) > 0:
            r2 = r2_score(all_trues, all_preds)
            rmse = np.sqrt(mean_squared_error(all_trues, all_preds))
            mae = mean_absolute_error(all_trues, all_preds)

        return avg_loss, r2, rmse, mae

    def _report_and_checkpoint(self, epoch, train_loss, val_loss, val_r2, val_rmse, val_mae):
        """Rank 0 only: log metrics, save on improvement, otherwise bump patience."""
        current_lr = self.optimizer.param_groups[0]['lr']
        print(f"Train Loss: {train_loss:.4f} | LR: {current_lr:.2e}")
        print(f"Val Loss: {val_loss:.4f} | R²: {val_r2:.4f} | RMSE: {val_rmse:.4f} | MAE: {val_mae:.4f}")

        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.patience_counter = 0

            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': self.model.module.state_dict(),  # .module unwraps DDP
                'optimizer_state_dict': self.optimizer.state_dict(),
                'val_loss': val_loss,
                'val_r2': val_r2,
                'pki_mean': self.pki_mean,
                'pki_std': self.pki_std
            }
            torch.save(checkpoint, self.BEST_CHECKPOINT_PATH)
            torch.save(self.model.module.state_dict(), self.BEST_WEIGHTS_PATH)
            print("✓ Model saved!")
        else:
            self.patience_counter += 1

    def _should_stop(self):
        """Return rank 0's early-stopping decision, agreed on by every rank.

        Must be called by all ranks at the same point, because it issues a
        broadcast when a process group is active.
        """
        stop = self.patience_counter >= self.config.PATIENCE
        if dist.is_initialized():
            flag = torch.tensor([int(stop)], device=self.rank, dtype=torch.int32)
            dist.broadcast(flag, src=0)
            stop = bool(flag.item())
        return stop

    def _load_best_weights(self):
        """Reload the best saved weights on THIS rank, if a file exists."""
        if os.path.exists(self.BEST_WEIGHTS_PATH):
            state = torch.load(self.BEST_WEIGHTS_PATH, map_location=f'cuda:{self.rank}')
            self.model.module.load_state_dict(state)
        elif self.rank == 0:
            print("⚠ No checkpoint was saved; keeping final-epoch weights.")

    def train(self):
        """Run the full training loop with LR scheduling and early stopping."""
        world_size = get_world_size()
        if self.rank == 0:
            print(f"\nDistributed Training on {world_size} GPUs")
            print(f"Per-GPU batch size: {self.config.BATCH_SIZE}")
            print(f"Effective batch size: {self.config.BATCH_SIZE * world_size * self.config.GRADIENT_ACCUMULATION}")
            print("=" * 80)

        for epoch in range(self.config.NUM_EPOCHS):
            # Reshuffle DistributedSampler shards each epoch; other samplers
            # (e.g. a plain RandomSampler) have no set_epoch.
            sampler = getattr(self.train_loader, 'sampler', None)
            if hasattr(sampler, 'set_epoch'):
                sampler.set_epoch(epoch)

            if self.rank == 0:
                print(f"\nEpoch {epoch + 1}/{self.config.NUM_EPOCHS}")

            train_loss = self.train_epoch()
            val_loss, val_r2, val_rmse, val_mae = self.evaluate(self.val_loader)

            # Step on EVERY rank. val_loss is already all-reduced (identical
            # everywhere), so all replicas keep the same learning rate. Stepping
            # only on rank 0 would let its LR, and hence its weights, drift away
            # from the other replicas.
            self.scheduler.step(val_loss)

            if self.rank == 0:
                self._report_and_checkpoint(epoch, train_loss, val_loss, val_r2, val_rmse, val_mae)

            # No rank may leave the loop on its own: a lone `break` on rank 0
            # would leave the others blocked in the broadcast inside _should_stop.
            if self._should_stop():
                if self.rank == 0:
                    print(f"\nEarly stopping at epoch {epoch + 1}")
                break

            torch.cuda.empty_cache()

        # Ensure rank 0 has finished writing checkpoints before anyone reads them.
        if dist.is_initialized():
            dist.barrier()

        # All ranks reload, so later evaluation uses the same weights everywhere.
        self._load_best_weights()

        if self.rank == 0:
            print("\n" + "=" * 80)
            print("Training completed!")


# ============================================================================
# MAIN DDP TRAINING FUNCTION
# ============================================================================

def train_worker(rank, world_size, cached_data, config):
    """
    Worker function for each GPU process.

    Launched once per process by torch.multiprocessing (spawn/start_processes),
    which supplies the process index as `rank`. Builds the datasets, distributed
    data loaders and model, trains with DDPTrainer, then evaluates the test split.

    Args:
        rank: Process index in [0, world_size). It is forwarded to DDPTrainer,
            which uses it as the CUDA device for this process.
        world_size: Number of participating processes (one per GPU).
        cached_data: Preprocessed cache dict. This function reads the keys
            'tokenizer_name', 'pki_mean', 'pki_std', 'train_cache',
            'val_cache' and 'test_cache'.
        config: Config object/class providing ESM_MAX_LENGTH, BATCH_SIZE,
            NUM_WORKERS, PIN_MEMORY, PERSISTENT_WORKERS and RANDOM_SEED
            (plus whatever DTAModel / DDPTrainer read).
    """
    setup_ddp(rank, world_size)
    # try/finally: tear the process group down even if training/eval raises.
    try:
        if rank == 0:
            print("=" * 80)
            print(f"GPU INFORMATION")
            print("=" * 80)
            for i in range(world_size):
                print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
                print(f"  Memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")
            print("=" * 80 + "\n")

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            cached_data['tokenizer_name'],
            trust_remote_code=True
        )

        def make_dataset(cache_key):
            """Build a DTADataset for one cached split (empty DataFrame + cache_data)."""
            return DTADataset(
                pd.DataFrame(), tokenizer, config.ESM_MAX_LENGTH,
                normalize_pki=True,
                pki_mean=cached_data['pki_mean'],
                pki_std=cached_data['pki_std'],
                cache_data=cached_data[cache_key]
            )

        train_dataset = make_dataset('train_cache')
        val_dataset = make_dataset('val_cache')
        test_dataset = make_dataset('test_cache')

        if rank == 0:
            print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}\n")

        # Each rank sees a disjoint shard. With the default drop_last=False,
        # DistributedSampler pads shards by repeating samples so all ranks get
        # the same count, so a few val/test samples may be counted twice.
        train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=config.RANDOM_SEED
        )
        val_sampler = DistributedSampler(
            val_dataset, num_replicas=world_size, rank=rank, shuffle=False
        )
        test_sampler = DistributedSampler(
            test_dataset, num_replicas=world_size, rank=rank, shuffle=False
        )

        # DataLoader raises ValueError for persistent_workers=True with num_workers == 0.
        use_persistent_workers = bool(config.PERSISTENT_WORKERS) and config.NUM_WORKERS > 0

        def make_loader(dataset, sampler):
            """DataLoader driven by a DistributedSampler (so no `shuffle=` argument)."""
            return DataLoader(
                dataset,
                batch_size=config.BATCH_SIZE,
                sampler=sampler,
                collate_fn=collate_fn,
                num_workers=config.NUM_WORKERS,
                pin_memory=config.PIN_MEMORY,
                persistent_workers=use_persistent_workers
            )

        train_loader = make_loader(train_dataset, train_sampler)
        val_loader = make_loader(val_dataset, val_sampler)
        test_loader = make_loader(test_dataset, test_sampler)

        # Initialize model
        if rank == 0:
            print("Initializing model...")

        model = DTAModel(config)

        # Better initialization: small-gain Xavier for Linear weights, zero biases.
        def init_weights(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight, gain=0.5)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

        model.apply(init_weights)

        if rank == 0:
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            print(f"Total parameters: {total_params:,}")
            print(f"Trainable parameters: {trainable_params:,}\n")

        # Train
        trainer = DDPTrainer(
            model, train_loader, val_loader, config, rank,
            pki_mean=cached_data['pki_mean'],
            pki_std=cached_data['pki_std']
        )
        trainer.train()

        # trainer.train() reloads the best checkpoint on rank 0 only, yet every
        # rank evaluates its own test shard below. Load the best weights on all
        # ranks so the replicas agree. train() ends with a barrier after rank 0's
        # last save, so the file is complete here. map_location='cpu' keeps every
        # rank from materialising the tensors on the GPU they were saved from;
        # load_state_dict copies them into the already-placed parameters.
        best_state = torch.load('best_model_weights.pt', map_location='cpu')
        trainer.model.module.load_state_dict(best_state)

        # Final test evaluation (printed on rank 0 only; evaluate runs on all ranks)
        if rank == 0:
            print("\n" + "=" * 80)
            print("FINAL TEST SET EVALUATION")
            print("=" * 80)

        test_loss, test_r2, test_rmse, test_mae = trainer.evaluate(test_loader)

        if rank == 0:
            print(f"Test Loss: {test_loss:.4f}")
            print(f"Test R²: {test_r2:.4f}")
            print(f"Test RMSE: {test_rmse:.4f}")
            print(f"Test MAE: {test_mae:.4f}")
            print("=" * 80)
    finally:
        # Cleanup
        cleanup_ddp()


def train_from_cache_ddp(cache_path=None, config=Config, start_method='fork'):
    """
    STEP 2: Run this to train using DDP with cached data.

    Loads the preprocessed dataset pickle and launches one `train_worker`
    process per visible GPU.

    Args:
        cache_path: Path to the pickle written by preprocess_and_cache_data().
            Defaults to the Kaggle input location below.
        config: Config object/class forwarded unchanged to every worker.
        start_method: multiprocessing start method for the workers. 'fork'
            (default) works with worker functions defined inside a notebook, and
            children inherit the loaded cache instead of each receiving a pickled
            copy. Use 'spawn' when running from an importable script/module.

    Raises:
        FileNotFoundError: if no cache exists at `cache_path`.
        RuntimeError: if CUDA is unavailable, or if CUDA was already initialized
            in this process while start_method == 'fork'.
    """
    if cache_path is None:
        # Make sure this path is correct for your Kaggle environment
        # For example, if your cached data is in /kaggle/input/your-dataset/
        cache_path = Path('/kaggle/input/cache-data/processed_datasets.pkl')
    else:
        cache_path = Path(cache_path)

    if not cache_path.exists():
        raise FileNotFoundError(
            f"No cached data found at {cache_path}! Run preprocess_and_cache_data() first."
        )

    # Cheap checks first, before the (potentially huge) unpickle.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. DDP requires GPUs.")

    # A forked child cannot re-initialize a CUDA context created in the parent.
    # Fail clearly here instead of with an opaque error inside each worker.
    if start_method == 'fork' and torch.cuda.is_initialized():
        raise RuntimeError(
            "CUDA is already initialized in this process, so forked DDP workers "
            "cannot use it. Restart the kernel and call train_from_cache_ddp() "
            "before running any other CUDA code."
        )

    print("=" * 80)
    print("LOADING CACHED DATA FOR DDP TRAINING")
    print("=" * 80)

    # NOTE: pickle.load can execute arbitrary code; only load caches you created.
    with open(cache_path, 'rb') as f:
        cached_data = pickle.load(f)

    print("✓ Cached data loaded!")

    world_size = torch.cuda.device_count()

    if world_size < 2:
        print("\n⚠ WARNING: Only 1 GPU detected. DDP is not needed.")

    print(f"\nDetected {world_size} GPUs. Starting distributed training...")
    print("=" * 80 + "\n")

    # torch.multiprocessing.spawn() always uses its own start_method ('spawn' by
    # default) and ignores mp.set_start_method(). start_processes() honours the
    # requested method. Each worker is called as train_worker(rank, *args).
    mp.start_processes(
        train_worker,
        args=(world_size, cached_data, config),
        nprocs=world_size,
        join=True,
        start_method=start_method
    )

    print("\n" + "=" * 80)
    print("DDP TRAINING COMPLETED!")
    print("=" * 80)



# ============================================================================
# CONVENIENCE FUNCTIONS
# ============================================================================

def train_single_gpu_from_cache(cache_path=None, config=Config):
    """
    Alternative: Single GPU training (more stable for debugging).

    Use this if you encounter DDP issues or want to debug. It builds the same
    datasets, plain (non-distributed) DataLoaders and freshly initialized
    DTAModel as the DDP path, on cuda:0 when available (CPU otherwise). It does
    NOT run a training loop: feed the returned objects to a single-process trainer.

    Args:
        cache_path: Path to the pickle written by preprocess_and_cache_data().
        config: Config object/class (BATCH_SIZE, NUM_WORKERS, PIN_MEMORY,
            ESM_MAX_LENGTH, ... plus whatever DTAModel reads).

    Returns:
        (model, train_loader, val_loader, test_loader)

    Raises:
        FileNotFoundError: if no cache exists at `cache_path`.
    """
    if cache_path is None:
        cache_path = Path('/kaggle/input/cache-data/processed_datasets.pkl')
    else:
        cache_path = Path(cache_path)

    if not cache_path.exists():
        raise FileNotFoundError(
            f"No cached data found at {cache_path}! Run preprocess_and_cache_data() first."
        )

    print("=" * 80)
    print("LOADING CACHED DATA FOR SINGLE-GPU TRAINING")
    print("=" * 80)

    # NOTE: pickle.load can execute arbitrary code; only load caches you created.
    with open(cache_path, 'rb') as f:
        cached_data = pickle.load(f)

    print("✓ Cached data loaded!")

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    print(f"\nUsing device: {device}")

    if use_cuda:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        cached_data['tokenizer_name'],
        trust_remote_code=True
    )

    def make_dataset(cache_key):
        """Build a DTADataset for one cached split (empty DataFrame + cache_data)."""
        return DTADataset(
            pd.DataFrame(), tokenizer, config.ESM_MAX_LENGTH,
            normalize_pki=True,
            pki_mean=cached_data['pki_mean'],
            pki_std=cached_data['pki_std'],
            cache_data=cached_data[cache_key]
        )

    train_dataset = make_dataset('train_cache')
    val_dataset = make_dataset('val_cache')
    test_dataset = make_dataset('test_cache')

    print(f"\nTrain: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}\n")

    # Pinned host memory only helps host->GPU copies; skip it on CPU-only runs.
    pin_memory = bool(config.PIN_MEMORY) and use_cuda

    def make_loader(dataset, shuffle):
        """Plain DataLoader (no DistributedSampler)."""
        return DataLoader(
            dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=shuffle,
            collate_fn=collate_fn,
            num_workers=config.NUM_WORKERS,
            pin_memory=pin_memory
        )

    train_loader = make_loader(train_dataset, shuffle=True)
    val_loader = make_loader(val_dataset, shuffle=False)
    test_loader = make_loader(test_dataset, shuffle=False)

    # Initialize model
    print("Initializing model...")
    model = DTAModel(config).to(device)

    # Small-gain Xavier for Linear weights, zero biases (same as the DDP worker).
    def init_weights(m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight, gain=0.5)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    model.apply(init_weights)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}\n")

    # Training itself needs a single-process trainer (not defined in this section).
    print("For single-GPU training, use the original Trainer class from the source code.")
    print("This function serves as a template.")

    return model, train_loader, val_loader, test_loader


# ============================================================================
# USAGE EXAMPLES
# ============================================================================
# NOTE: train_single_gpu_from_cache() only builds the datasets, loaders and
# model; it prints a pointer to a separate Trainer rather than running a
# training loop itself.
# This string is the cell's last expression, so Jupyter echoes it as output.

"""
USAGE:

# Step 1: Preprocess data (run once)
cached_data = preprocess_and_cache_data('/path/to/your/data.csv')

# Step 2: Train with DDP (multiple GPUs)
train_from_cache_ddp()

# Alternative: Train with single GPU (for debugging)
train_single_gpu_from_cache()

# To resume training or change config:
config = Config()
config.LEARNING_RATE = 1e-5
config.NUM_EPOCHS = 100
train_from_cache_ddp(config=config)
"""

"\nUSAGE:\n\n# Step 1: Preprocess data (run once)\ncached_data = preprocess_and_cache_data('/path/to/your/data.csv')\n\n# Step 2: Train with DDP (multiple GPUs)\ntrain_from_cache_ddp()\n\n# Alternative: Train with single GPU (for debugging)\ntrain_single_gpu_from_cache()\n\n# To resume training or change config:\nconfig = Config()\nconfig.LEARNING_RATE = 1e-5\nconfig.NUM_EPOCHS = 100\ntrain_from_cache_ddp(config=config)\n"

In [8]:
# Launch DDP training from the on-disk cache. The recorded run failed with
# MemoryError before "✓ Cached data loaded!" was printed, i.e. while unpickling
# the cache; presumably host RAM could not hold it. TODO confirm available RAM.
train_from_cache_ddp()

LOADING CACHED DATA FOR DDP TRAINING


MemoryError: 

In [None]:
import gc

# Run garbage collector to reclaim unreachable objects left by the failed run.
# It cannot free anything still referenced by a live notebook variable.
gc.collect()


In [None]:
# This loads from cache and trains
# No reprocessing needed, even after crashes!
# NOTE(review): `train_from_cache` is not defined in this part of the notebook;
# confirm an earlier cell defines it. The DDP entry point here,
# train_from_cache_ddp(), returns None and cannot be unpacked like this.
model, trainer = train_from_cache()

# Model checkpoints saved:
# (these are the filenames the DDP trainer writes; presumably the same here)
# - best_model_weights.pt
# - best_model_checkpoint.pt