# Step 4: XAI Analysis and Interpretation

This notebook implements **Sections 3.3 and 3.4** of the research proposal. Having trained Model A (Transfer Learning) and Model B (Fine-Tuning), we will now:

1.  Load the best-performing checkpoints for both models.
2.  Run inference on the entire validation set to identify specific examples of **True Positives, False Positives, and False Negatives**.
3.  Set up our three XAI methods: Grad-CAM, SHAP, and Integrated Gradients.
4.  Generate and visualize side-by-side comparisons of the explanations for our curated examples.
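
Before wiring these methods to our models, it helps to see the arithmetic at the core of two of them. The sketch below is a minimal NumPy illustration, not our final implementation. `grad_cam` assumes the activations of one convolutional layer, and the gradients of the target-class score with respect to them, have already been captured (for example with PyTorch hooks) as arrays of shape `(K, H, W)`. `integrated_gradients` assumes a hypothetical `grad_fn` that returns the gradient of the target-class score for a flattened input. The toy logistic function stands in for a real classifier.

```python
import numpy as np


def grad_cam(activations, gradients):
    """Grad-CAM heatmap from one conv layer.

    activations, gradients: arrays of shape (K, H, W), where `gradients` holds
    d(target-class score) / d(activations).
    """
    # Channel importance weights: global-average-pool the gradients spatially.
    weights = gradients.mean(axis=(1, 2))                              # (K,)
    # Weighted sum of the feature maps; ReLU keeps only positive evidence.
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # (H, W)
    # Scale to [0, 1] for display, guarding against an all-zero map.
    peak = cam.max()
    return cam / peak if peak > 0 else cam


def integrated_gradients(grad_fn, x, baseline, steps=64):
    """Integrated Gradients via a midpoint Riemann sum along the straight path."""
    alphas = (np.arange(steps) + 0.5) / steps               # midpoints in (0, 1)
    path = baseline + alphas[:, None] * (x - baseline)      # (steps, D)
    avg_grad = np.mean([grad_fn(p) for p in path], axis=0)  # (D,)
    return (x - baseline) * avg_grad


# Toy stand-in for a classifier: f(x) = sigmoid(w . x + b), with an analytic gradient.
rng = np.random.default_rng(0)
w, b = rng.normal(size=5), 0.1


def f(x):
    return 1.0 / (1.0 + np.exp(-(w @ x + b)))


def grad_f(x):
    s = f(x)
    return s * (1.0 - s) * w


x = rng.normal(size=5)
baseline = np.zeros(5)  # an all-zero input, a common IG baseline
attributions = integrated_gradients(grad_f, x, baseline)
# Completeness check: attributions should sum to roughly f(x) - f(baseline).
print(attributions.sum(), f(x) - f(baseline))
```

In the notebook, `grad_fn` would wrap a PyTorch forward and backward pass and `x` would be a flattened, preprocessed fundus image. The midpoint rule is used because, for smooth models, it converges faster than a left or right Riemann sum with the same number of steps.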

In [None]:
import sys
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
from torchvision import transforms
from PIL import Image

# Add our custom source code to the path
sys.path.append("../src")
# Import our data setup and model class
from data_setup import create_dataloaders
from model import DRClassifier 

# --- Configuration ---
RAW_DATA_DIR = '../data/raw/aptos2019-blindness-detection/'
PROCESSED_DATA_DIR = '../data/processed/'
IMAGE_DIR = os.path.join(RAW_DATA_DIR, 'train_images/')
MODEL_A_PATH = '../models/model_a_best.ckpt'
MODEL_B_PATH = '../models/model_b_best.ckpt'

In [None]:
# --- Load the Trained Models ---

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load Model A (Transfer Learning).
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model_a = DRClassifier.load_from_checkpoint(MODEL_A_PATH, map_location=device).to(device)
model_a.eval()  # Set to evaluation mode

# Load Model B (Fine-Tuning)
model_b = DRClassifier.load_from_checkpoint(MODEL_B_PATH, map_location=device).to(device)
model_b.eval()  # Set to evaluation mode

print("Models loaded successfully.")

### 4.1: Generating and Caching Model Predictions

To keep the workflow efficient, we run inference on the entire validation set once per model and **cache the results** in one CSV file per model. On subsequent runs, if both files exist, we load them directly instead of re-running the time-consuming prediction loop.

This allows us to quickly experiment with the XAI analysis without waiting for inference every time.
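
The same caching pattern can be factored into a small helper. This is a sketch, not code from our `src` package: `load_or_compute` is a hypothetical name, and `compute_fn` is assumed to be any zero-argument callable that returns a DataFrame (for instance, a wrapper around the prediction loop for one model). The `fake_predictions` function below is a stand-in for real inference.

```python
import os
import tempfile

import pandas as pd


def load_or_compute(csv_path, compute_fn):
    """Return the DataFrame cached at csv_path, computing and caching it on a miss."""
    if os.path.exists(csv_path):
        return pd.read_csv(csv_path)
    df = compute_fn()
    # Create the parent directory if needed so the first run does not fail on to_csv.
    os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
    df.to_csv(csv_path, index=False)
    return df


calls = []


def fake_predictions():
    """Stand-in for an expensive inference run; records each call."""
    calls.append(1)
    return pd.DataFrame({"id_code": ["a", "b"], "predicted_label": [0, 2]})


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "results_model_a.csv")
    first = load_or_compute(path, fake_predictions)   # computes and writes the CSV
    second = load_or_compute(path, fake_predictions)  # reads the cached CSV
```

Because the helper caches each file independently, a missing `results_model_b.csv` would trigger inference for Model B only, whereas the cell below regenerates both models' predictions whenever either file is absent.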

In [None]:
# --- Define paths for cached predictions ---
results_a_path = os.path.join(PROCESSED_DATA_DIR, 'results_model_a.csv')
results_b_path = os.path.join(PROCESSED_DATA_DIR, 'results_model_b.csv')

# --- Check if cached results exist ---
if os.path.exists(results_a_path) and os.path.exists(results_b_path):
    print("Loading cached predictions...")
    results_a_df = pd.read_csv(results_a_path)
    results_b_df = pd.read_csv(results_b_path)

else:
    print("Cached predictions not found. Generating and caching new predictions...")
    
    def get_predictions(model, loader, device, dataframe):
        """Run inference over `loader` and return (predictions, labels) as NumPy arrays.

        `dataframe` is used only for progress messages. It must list the images in the
        same order as `loader`, which holds here because batch_size=1 and shuffle=False.
        """
        model.eval()
        all_preds = []
        all_labels = []

        # Pair each batch with its dataframe row so we can report the image ID.
        data_iterator = tqdm(zip(loader, dataframe.iterrows()), total=len(dataframe), desc="Getting Predictions")

        with torch.no_grad():
            for batch_idx, ((images, labels), (_, row)) in enumerate(data_iterator):
                # Log progress every 50 images; tqdm.write keeps the progress bar intact.
                if batch_idx % 50 == 0:
                    tqdm.write(f"  Processing image {batch_idx}/{len(dataframe)}: {row['id_code']}.png")

                images = images.to(device)
                logits = model(images)
                preds = torch.argmax(logits, dim=1)

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        return np.array(all_preds), np.array(all_labels)

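    # Recreate the validation dataloader. create_dataloaders also builds a training
    # loader, so val_df is passed as a placeholder train_df and that loader is discarded.
    val_df = pd.read_csv(os.path.join(PROCESSED_DATA_DIR, 'val_split.csv'))
    _, val_loader_inference = create_dataloaders(
        train_df=val_df,
        val_df=val_df,
        image_dir=IMAGE_DIR,
        batch_size=1,
        num_workers=0  # Keep at 0: load batches in the main process (multi-worker loading can hang in notebooks)
    )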

    # Run inference for both models (val_df supplies image IDs for progress messages)
    print("\n--- Generating predictions for Model A ---")
    preds_a, labels_a = get_predictions(model_a, val_loader_inference, device, val_df)
    
    print("\n--- Generating predictions for Model B ---")
    preds_b, labels_b = get_predictions(model_b, val_loader_inference, device, val_df)

    # Create results dataframes
    results_a_df = val_df.copy()
    results_a_df['predicted_label'] = preds_a
    results_a_df['true_label'] = labels_a

    results_b_df = val_df.copy()
    results_b_df['predicted_label'] = preds_b
    results_b_df['true_label'] = labels_b

    # Cache the results
    results_a_df.to_csv(results_a_path, index=False)
    results_b_df.to_csv(results_b_path, index=False)
    print("Predictions cached successfully.")

In [None]:
# --- Identify cases for analysis (DR grade 0 = no disease, grades 1-4 = diseased) ---
fp_a = results_a_df[(results_a_df['true_label'] == 0) & (results_a_df['predicted_label'] > 0)]
fn_a = results_a_df[(results_a_df['true_label'] > 0) & (results_a_df['predicted_label'] == 0)]
# True positives: diseased images whose exact grade was predicted correctly
tp_a = results_a_df[(results_a_df['true_label'] > 0) & (results_a_df['true_label'] == results_a_df['predicted_label'])]

print("\n--- Analysis Cases for Model A ---")
print(f"Found {len(tp_a)} True Positives (diseased, grade predicted exactly)")
print(f"Found {len(fp_a)} False Positives (healthy, predicted diseased)")
print(f"Found {len(fn_a)} False Negatives (diseased, predicted healthy)")

print("\nExample of a False Negative from Model A:")
print(fn_a.head(1))
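
The masks above leave one group implicit: diseased images that are flagged as diseased but assigned the wrong grade are neither true positives (under the exact-grade definition) nor false negatives. A minimal sketch that labels every row, assuming the `true_label` and `predicted_label` columns produced above; `label_outcomes` and the `TP_wrong_grade` label are hypothetical names chosen for illustration, and the small `demo` frame is made-up data.

```python
import numpy as np
import pandas as pd


def label_outcomes(df):
    """Return a copy of df with an 'outcome' column (grade 0 = healthy)."""
    true_dr = df["true_label"] > 0
    pred_dr = df["predicted_label"] > 0
    exact = df["true_label"] == df["predicted_label"]
    conditions = [
        true_dr & pred_dr & exact,   # diseased, grade predicted exactly
        true_dr & pred_dr & ~exact,  # diseased and flagged, but wrong grade
        ~true_dr & pred_dr,          # healthy, predicted diseased
        true_dr & ~pred_dr,          # diseased, predicted healthy
    ]
    choices = ["TP", "TP_wrong_grade", "FP", "FN"]
    out = df.copy()
    # Rows matching none of the conditions are healthy images predicted healthy.
    out["outcome"] = np.select(conditions, choices, default="TN")
    return out


demo = pd.DataFrame({"true_label": [0, 0, 2, 2, 3], "predicted_label": [0, 1, 2, 0, 1]})
print(label_outcomes(demo)["outcome"].tolist())
```

Calling `label_outcomes(results_b_df)` would give Model B the same breakdown, and `value_counts()` on the `outcome` column summarises either model in one line.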
