# Phase 5: Evaluation and Interpretation

This is the final, crucial step. Phase 5: Evaluation and Interpretation has two main goals: rigorously measuring the model's performance on the held-out (unseen) Test Set, and extracting biological insights from the trained model (interpretability).

Below is a from-scratch implementation that combines the test-set evaluation metrics with a basic interpretability method (analysis of the learned embedding space).

##### Prerequisites
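We assume the model and data utilities from the previous phases (`DTIModel`, `DTIDataset`, `custom_collate`) and the held-out `test_data` split are already available: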

In [None]:
import os
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.manifold import TSNE
from sklearn.metrics import roc_auc_score, average_precision_score, confusion_matrix
from torch.utils.data import DataLoader

# --- ASSUMED IMPORTS/SETUP FROM PREVIOUS PHASES ---
# from my_model_scripts import DTIModel, custom_collate, DTIDataset 
# from my_data_scripts import test_data # The DataFrame for the held-out test set

# Define constants from previous phases
DRUG_IN_FEAT = 71
TARGET_IN_FEAT = 21
EMBEDDING_DIM = 128
HIDDEN_DIM = 64
GNN_LAYERS = 3
CNN_KERNEL_SIZE = 8
MAX_LEN = 1200
BATCH_SIZE = 32
CHECKPOINT_PATH = 'dti_model_best.pt' # Path where Phase 4 saved the best weights
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### 1. Model Evaluation on the Test Set (Proving Performance)
This function runs the final trained model on the test data and calculates the key binary classification metrics.

In [None]:
def _load_dti_model(checkpoint_path: str) -> nn.Module:
    """Build a DTIModel from the module-level hyperparameters, load weights if present, and switch to eval mode."""
    model = DTIModel(
        drug_in_feat=DRUG_IN_FEAT, target_in_feat=TARGET_IN_FEAT, hidden_dim=HIDDEN_DIM,
        gnn_layers=GNN_LAYERS, cnn_kernel_size=CNN_KERNEL_SIZE, embedding_dim=EMBEDDING_DIM
    ).to(DEVICE)

    if os.path.isfile(checkpoint_path):
        # torch.load unpickles the file: only load checkpoints you produced yourself.
        model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    else:
        # Deliberate best-effort fallback kept from the original: evaluate the
        # untrained model rather than aborting. The metrics are then meaningless.
        print(f"Error: Checkpoint file not found at {checkpoint_path}. Using uninitialized model.")

    model.eval()  # disable dropout / use running batch-norm statistics
    return model


def _predict_probabilities(model: nn.Module, loader: DataLoader):
    """Run ``model`` over ``loader`` without gradients and return (labels, scores) as aligned 1-D arrays."""
    label_chunks, score_chunks = [], []
    with torch.no_grad():
        for drug_batch, target_batch, labels in loader:
            # NOTE(review): assumes DTIModel.forward returns probabilities in [0, 1]
            # (e.g. after a sigmoid), since they are thresholded below. Confirm in Phase 3.
            scores = model(drug_batch.to(DEVICE), target_batch.to(DEVICE))
            score_chunks.append(scores.cpu().numpy().ravel())
            label_chunks.append(labels.cpu().numpy().ravel())

    if not score_chunks:
        raise ValueError("Test set is empty; cannot compute evaluation metrics.")
    return np.concatenate(label_chunks), np.concatenate(score_chunks)


def _compute_binary_metrics(labels: np.ndarray, probabilities: np.ndarray, threshold: float = 0.5) -> dict:
    """Compute threshold-free (AUROC/AUPRC) and thresholded (accuracy/precision/recall/confusion) metrics."""
    # Ranking metrics are undefined with a single class; roc_auc_score would raise.
    if np.unique(labels).size < 2:
        print("Warning: test labels contain a single class; AUROC and AUPRC are reported as NaN.")
        auroc = auprc = float('nan')
    else:
        auroc = roc_auc_score(labels, probabilities)
        auprc = average_precision_score(labels, probabilities)

    binary_predictions = (probabilities >= threshold).astype(int)

    # labels=[0, 1] forces a 2x2 matrix even when one class is absent from both
    # labels and predictions. Without it, the 4-way unpack below fails.
    tn, fp, fn, tp = confusion_matrix(labels, binary_predictions, labels=[0, 1]).ravel()
    accuracy = (tp + tn) / len(labels)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0

    return {
        'Test AUROC': auroc,
        'Test AUPRC': auprc,
        'Test Accuracy': accuracy,
        'Test Precision': precision,
        'Test Recall': recall,
        'TN': tn, 'FP': fp, 'FN': fn, 'TP': tp
    }


def evaluate_on_test_set(test_df: pd.DataFrame, checkpoint_path: str, threshold: float = 0.5):
    """
    Load the best model, predict on the held-out test set, and compute metrics.

    Args:
        test_df: Held-out test DataFrame, passed unchanged to ``DTIDataset``.
        checkpoint_path: Path to the ``state_dict`` saved in Phase 4. If the file
            is missing, a message is printed and the freshly initialized
            (untrained) model is evaluated instead.
        threshold: Probability cut-off that turns scores into binary predictions
            for accuracy, precision, recall and the confusion matrix. AUROC and
            AUPRC do not use it. Defaults to 0.5.

    Returns:
        tuple: (metrics_dict, all_labels, all_predictions, model).
        ``all_labels`` and ``all_predictions`` are 1-D NumPy arrays in test-set
        order (the loader does not shuffle). If the test labels contain only one
        class, AUROC and AUPRC are NaN.

    Raises:
        ValueError: If the test set yields no samples.
    """
    print("1. Loading best model and setting up Test Evaluation...")
    model = _load_dti_model(checkpoint_path)

    test_dataset = DTIDataset(test_df, max_len=MAX_LEN)
    # shuffle=False keeps predictions aligned with the rows of test_df.
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=custom_collate)

    all_labels, all_probabilities = _predict_probabilities(model, test_loader)
    metrics = _compute_binary_metrics(all_labels, all_probabilities, threshold)

    print("\n--- Final Test Metrics ---")
    print(pd.Series(metrics).drop(['TN', 'FP', 'FN', 'TP']))

    return metrics, all_labels, all_probabilities, model

# --- DUMMY EXECUTION (Needs real test_data and checkpoint) ---
# test_metrics, test_labels, test_probabilities, final_model = evaluate_on_test_set(test_data, CHECKPOINT_PATH)

### 2. Global Interpretation (Embedding Visualization)
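This function uses t-SNE to project the high-dimensional drug and protein embeddings ($\mathbf{V}_D$ and $\mathbf{V}_P$) down to 2D. Each point is coloured by its binding label. The plots therefore show whether the learned embeddings separate binding from non-binding pairs, and whether the model clusters similar compounds/targets.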

In [None]:
def visualize_embeddings(model: nn.Module, test_loader: DataLoader, perplexity: float = 30):
    """
    Extract drug and target embeddings from the test set and visualize them with t-SNE.

    Draws two side-by-side scatter plots (drug embeddings, target embeddings),
    each coloured by the sample's binding label, then shows the figure.

    Args:
        model: Trained DTI model. NOTE(review): assumed to expose ``drug_encoder``
            and ``target_encoder`` sub-modules that each return one embedding row
            per sample. Confirm against the Phase 3 model definition.
        test_loader: Loader yielding ``(drug_batch, target_batch, labels)``.
        perplexity: Requested t-SNE perplexity. scikit-learn requires it to be
            strictly smaller than the number of samples, so it is clamped to
            ``n_samples - 1`` for small test sets. Defaults to 30.

    Raises:
        ValueError: If the loader yields fewer than two samples (t-SNE needs at least two).
    """
    print("\n2. Extracting Embeddings for t-SNE Visualization...")
    model.eval()
    drug_chunks, target_chunks, label_chunks = [], [], []

    with torch.no_grad():
        for drug_batch, target_batch, labels in test_loader:
            # Take V_D and V_P from the encoders, before any fusion in the model head.
            drug_chunks.append(model.drug_encoder(drug_batch.to(DEVICE)).cpu().numpy())
            target_chunks.append(model.target_encoder(target_batch.to(DEVICE)).cpu().numpy())
            label_chunks.append(labels.cpu().numpy().ravel())

    if not drug_chunks:
        raise ValueError("test_loader yielded no batches; nothing to visualize.")

    V_D_matrix = np.vstack(drug_chunks)
    V_P_matrix = np.vstack(target_chunks)
    labels = np.concatenate(label_chunks)

    n_samples = V_D_matrix.shape[0]
    if n_samples < 2:
        raise ValueError(f"t-SNE needs at least 2 samples, got {n_samples}.")
    # sklearn's TSNE raises if perplexity >= n_samples; clamp so small test sets still plot.
    effective_perplexity = min(perplexity, n_samples - 1)

    # Raw strings: '\m' is an invalid escape sequence in a normal string literal.
    panels = [
        (V_D_matrix, r't-SNE of Drug Embeddings ($\mathbf{V}_D$)'),
        (V_P_matrix, r't-SNE of Target Embeddings ($\mathbf{V}_P$)'),
    ]

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    for ax, (matrix, title) in zip(axes, panels):
        coords = TSNE(n_components=2, random_state=42, perplexity=effective_perplexity).fit_transform(matrix)
        sns.scatterplot(x=coords[:, 0], y=coords[:, 1], hue=labels, palette='viridis',
                        legend='full', alpha=0.7, ax=ax)
        ax.set_title(title, fontsize=14)
        ax.set_xlabel('t-SNE Component 1')
        ax.set_ylabel('t-SNE Component 2')
        # Retitle seaborn's legend in place rather than rebuilding it.
        legend = ax.get_legend()
        if legend is not None:
            legend.set_title('Binding Label')

    fig.tight_layout()
    plt.show()

# --- FINAL EXECUTION (REQUIRES DUMMY DATA/MODEL TO BE REPLACED) ---
# test_dataset = DTIDataset(test_data, max_len=MAX_LEN) 
# test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=custom_collate)
# visualize_embeddings(final_model, test_loader)

### 3. Local Interpretation (Attention/Saliency - Conceptual)
The Phase 3 model in this basic implementation has no explicit attention mechanism, so importance cannot be read directly from attention weights. Instead, here is the conceptual workflow for post-hoc local interpretation, which is vital for the biological context:
- Goal: Determine which atoms (Drug) and which amino acids (Target) are most important for the positive prediction of binding.
- Technique (Post-hoc): Use Gradient-based Saliency Maps.
    - Take a single True Positive DTI pair.
    - Calculate the gradient of the final prediction score with respect to the input features: the node-feature matrix $\mathbf{X}$ for the drug and the one-hot-encoded (OHE) sequence matrix for the protein.
    - The magnitude of this gradient at each feature position (atom or residue) indicates its importance to the prediction.
    - Visualize: Map the high-magnitude gradients back onto the 2D molecule structure (RDKit) and the 1D protein sequence to highlight the predicted pharmacophore and binding pocket residues.