# Phase 4: Model Training and Optimization

Phase 4 involves setting up the training loop, defining the loss function, selecting an optimizer, and implementing techniques such as early stopping and learning-rate scheduling.

Here is the code implementation from scratch, integrating the components from Phase 3.

##### Prerequisites

We'll assume you have the `DTIModel` and `DTIDataset` classes and the `custom_collate` function defined in Phase 3, as well as the `train_data` and `val_data` DataFrames from Phase 1.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np
import os
import time

# --- ASSUMED IMPORTS FROM PREVIOUS PHASES ---
# from my_model_scripts import DTIModel, custom_collate, DTIDataset 
# from my_data_scripts import train_data, val_data 
# We'll define dummy values for demonstration

# --- DUMMY SETUP (REPLACE WITH REAL DATA/MODEL) ---
# Define dummy dimensions based on Phase 2 & 3
# NOTE(review): these are placeholder hyperparameters; the feature dimensions
# must agree with the real featurization and model definition -- confirm
# against Phases 2 and 3 before training.
DRUG_IN_FEAT = 71 # Example number of atom features from Phase 2
TARGET_IN_FEAT = 21 # Amino acid alphabet size ('ACDEFGHIKLMNPQRSTVWXY')
EMBEDDING_DIM = 128 # Forwarded to DTIModel as embedding_dim
HIDDEN_DIM = 64 # Forwarded to DTIModel as hidden_dim
GNN_LAYERS = 3 # Forwarded to DTIModel as gnn_layers
CNN_KERNEL_SIZE = 8 # Forwarded to DTIModel as cnn_kernel_size
MAX_LEN = 1200 # Forwarded to DTIDataset as max_len (presumably the maximum target sequence length -- confirm against Phase 3)
BATCH_SIZE = 32 # Batch size used for both the training and validation DataLoader
LEARNING_RATE = 1e-4 # Initial learning rate handed to the Adam optimizer
NUM_EPOCHS = 50 # Upper bound on training epochs; early stopping may end training sooner

# Set device
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### 1. Training Setup and Initialization

In [None]:
def initialize_training_components(drug_in_feat, target_in_feat, embedding_dim, hidden_dim, gnn_layers, cnn_kernel_size, lr):
    """Build the model, loss function, optimizer, and LR scheduler.

    Args:
        drug_in_feat: Forwarded to ``DTIModel`` as ``drug_in_feat``.
        target_in_feat: Forwarded to ``DTIModel`` as ``target_in_feat``.
        embedding_dim: Forwarded to ``DTIModel`` as ``embedding_dim``.
        hidden_dim: Forwarded to ``DTIModel`` as ``hidden_dim``.
        gnn_layers: Forwarded to ``DTIModel`` as ``gnn_layers``.
        cnn_kernel_size: Forwarded to ``DTIModel`` as ``cnn_kernel_size``.
        lr: Initial learning rate for the Adam optimizer.

    Returns:
        A ``(model, criterion, optimizer, scheduler)`` tuple. The model has
        already been moved to the module-level ``DEVICE``.
    """
    # 1. Initialize Model
    model = DTIModel(
        drug_in_feat=drug_in_feat,
        target_in_feat=target_in_feat,
        hidden_dim=hidden_dim,
        gnn_layers=gnn_layers,
        cnn_kernel_size=cnn_kernel_size,
        embedding_dim=embedding_dim,
    ).to(DEVICE)

    # 2. Define Loss Function
    # BCELoss expects probabilities in [0, 1].
    # NOTE(review): this assumes DTIModel.forward already applies a sigmoid
    # (as stated for Phase 3) -- if the model returns raw logits, switch to
    # nn.BCEWithLogitsLoss instead.
    criterion = nn.BCELoss()

    # 3. Define Optimizer
    # weight_decay adds an L2 penalty on the parameters.
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    # 4. Learning Rate Scheduler
    # Halves the LR once the monitored value (the validation loss passed to
    # scheduler.step) has not improved for 5 consecutive epochs.
    # The `verbose` flag is intentionally not passed: it is deprecated since
    # PyTorch 2.2 and, as far as we can tell, no longer accepted by newer
    # releases. Inspect optimizer.param_groups[0]['lr'] to log the current LR.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    return model, criterion, optimizer, scheduler

# --- Dummy Data Loader Setup (Replace with actual data) ---
# train_dataset = DTIDataset(train_data, max_len=MAX_LEN) 
# val_dataset = DTIDataset(val_data, max_len=MAX_LEN) 

# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=custom_collate)
# val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=custom_collate)

### 2. Helper Functions for Evaluation and Early Stopping

In [None]:
def evaluate_model(model, loader, criterion):
    """Evaluate the model on every batch of ``loader`` without gradients.

    Args:
        model: Module called as ``model(drug_batch, target_batch)``.
        loader: Iterable yielding ``(drug_batch, target_batch, labels)``.
        criterion: Loss called as ``criterion(predictions, labels)``; assumed
            to return the batch mean (the default reduction) -- confirm if a
            custom reduction is used.

    Returns:
        ``(avg_loss, auroc, auprc)``. ``avg_loss`` is the per-sample mean over
        the samples actually seen (``nan`` if the loader yields nothing).
        ``auroc`` and ``auprc`` fall back to ``0.5`` when scikit-learn cannot
        compute them.

    The model's train/eval mode is restored to whatever it was on entry.
    """
    was_training = model.training
    model.eval()

    total_loss = 0.0
    n_samples = 0
    label_chunks = []
    score_chunks = []

    try:
        with torch.no_grad():
            for drug_batch, target_batch, labels in loader:
                drug_batch = drug_batch.to(DEVICE)
                target_batch = target_batch.to(DEVICE)
                labels = labels.to(DEVICE).float().unsqueeze(1)

                predictions = model(drug_batch, target_batch)
                loss = criterion(predictions, labels)

                # Weight the batch-mean loss by the batch size so the final
                # average is per-sample even when the last batch is smaller.
                batch_size = labels.size(0)
                total_loss += loss.item() * batch_size
                n_samples += batch_size

                # Flatten to 1-D so the metrics receive plain vectors.
                label_chunks.append(labels.cpu().numpy().ravel())
                score_chunks.append(predictions.cpu().numpy().ravel())
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    if n_samples == 0:
        # Nothing was evaluated: no meaningful loss or ranking metric exists.
        return float('nan'), 0.5, 0.5

    # Divide by the samples actually processed; len(loader.dataset) is wrong
    # when the loader uses a sampler or drop_last.
    avg_loss = total_loss / n_samples

    y_true = np.concatenate(label_chunks)
    y_score = np.concatenate(score_chunks)

    # Calculate key metrics (AUROC and AUPRC)
    try:
        auroc = roc_auc_score(y_true, y_score)
        auprc = average_precision_score(y_true, y_score)
    except ValueError:
        # Raised e.g. when the evaluated set contains a single class; report
        # neutral placeholder values rather than aborting training.
        auroc, auprc = 0.5, 0.5

    return avg_loss, auroc, auprc

class EarlyStopping:
    """Signal a stop once validation loss has not improved for ``patience`` calls.

    An epoch counts as an improvement only if the loss drops below the best
    loss seen so far by more than ``min_delta``. On every improvement the
    model's ``state_dict`` is written to ``path``. Every other call
    (worse loss, a plateau within ``min_delta``, or a NaN loss) increments
    the counter; once it reaches ``patience``, ``early_stop`` becomes True.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: Minimum decrease in loss that counts as an improvement.
        path: File the best ``state_dict`` is saved to.
    """

    def __init__(self, patience=10, min_delta=0, path='best_checkpoint.pt'):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0  # consecutive calls without improvement
        self.best_loss = np.inf  # lowest validation loss accepted so far
        self.early_stop = False  # set to True once patience is exhausted
        self.path = path

    def __call__(self, val_loss, model):
        """Record ``val_loss``; checkpoint ``model`` on improvement."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            # Save the best model state
            torch.save(model.state_dict(), self.path)
        else:
            # Anything that is not a real improvement counts against the
            # patience budget. A strict "worse than best + min_delta" test
            # would ignore plateaus and NaN losses, so stopping would never
            # trigger on a stalled run.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

### 3. Main Training Loop

In [None]:
def _run_train_epoch(model, train_loader, criterion, optimizer):
    """Run one optimization pass over ``train_loader``; return the mean per-sample loss."""
    model.train()
    running_loss = 0.0
    n_samples = 0

    for drug_batch, target_batch, labels in train_loader:
        drug_batch = drug_batch.to(DEVICE)
        target_batch = target_batch.to(DEVICE)
        labels = labels.to(DEVICE).float().unsqueeze(1)

        optimizer.zero_grad()
        predictions = model(drug_batch, target_batch)
        loss = criterion(predictions, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the batch size so the epoch average
        # is per-sample even when batches differ in size.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        n_samples += batch_size

    # Divide by the samples actually seen; len(train_loader.dataset) is wrong
    # with drop_last or a sampler, and zero for an empty dataset.
    return running_loss / n_samples if n_samples else float('nan')


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, checkpoint_path='dti_model_best.pt'):
    """Train ``model`` with LR scheduling and early stopping on validation loss.

    Each epoch runs one training pass, evaluates on ``val_loader``, feeds the
    validation loss to ``scheduler.step`` and to an ``EarlyStopping`` instance
    (patience 10, min_delta 0.001) that checkpoints the best weights to
    ``checkpoint_path``.

    Args:
        model: Module called as ``model(drug_batch, target_batch)``.
        train_loader: Iterable yielding ``(drug_batch, target_batch, labels)``.
        val_loader: Same batch format, used for validation.
        criterion: Loss called as ``criterion(predictions, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        scheduler: Scheduler whose ``step`` accepts the validation loss.
        num_epochs: Maximum number of epochs to run.
        checkpoint_path: File the best ``state_dict`` is written to.

    Returns:
        ``model`` carrying the best checkpointed weights from this run. The
        weights are restored both on early stop and when all epochs complete;
        if no checkpoint was written during the run, the model is returned
        as-is.
    """
    print(f"\n4. Starting Training on device: {DEVICE}")
    print(f"   Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # Initialize Early Stopping
    early_stopper = EarlyStopping(patience=10, min_delta=0.001, path=checkpoint_path)
    best_epoch = None  # epoch at which the checkpoint was last written

    for epoch in range(1, num_epochs + 1):
        t0 = time.time()

        # --- Training Step ---
        avg_train_loss = _run_train_epoch(model, train_loader, criterion, optimizer)

        # --- Validation Step ---
        val_loss, val_auroc, val_auprc = evaluate_model(model, val_loader, criterion)

        # --- Optimization Steps ---
        scheduler.step(val_loss)  # Update learning rate scheduler

        # Detect an improvement by watching best_loss change, so the reported
        # best epoch is exact rather than inferred from the patience value.
        previous_best = early_stopper.best_loss
        early_stopper(val_loss, model)  # Check for early stopping and save best model
        if early_stopper.best_loss < previous_best:
            best_epoch = epoch

        t1 = time.time()

        print(f"Epoch {epoch:02d} | Time: {t1-t0:.2f}s | Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val AUROC: {val_auroc:.4f} | Val AUPRC: {val_auprc:.4f}")

        if early_stopper.early_stop:
            print(f"🛑 Early stopping triggered. Loading best model from epoch {best_epoch}.")
            break

    # Restore the best weights whether training stopped early or ran all
    # epochs. Skip when nothing was checkpointed in this run, so a stale file
    # at checkpoint_path is never loaded.
    if best_epoch is not None:
        model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))

    print("Training finished.")
    return model

# --- FINAL EXECUTION (REQUIRES DUMMY DATA/MODEL TO BE REPLACED) ---
# model, criterion, optimizer, scheduler = initialize_training_components(
#     DRUG_IN_FEAT, TARGET_IN_FEAT, EMBEDDING_DIM, HIDDEN_DIM, GNN_LAYERS, CNN_KERNEL_SIZE, LEARNING_RATE
# )

# # Pass the model to the training function
# final_best_model = train_model(
#     model, train_loader, val_loader, criterion, optimizer, scheduler, NUM_EPOCHS
# )