# Phase 7: Explainability & Interpretability

## 1. Objective
To understand **what the models have learned**. We use **Grad-CAM (Gradient-weighted Class Activation Mapping)** to visualize which temporal regions of the PPG signal contribute most to the Heart Rate prediction.
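
`GradCAM1D` comes from `xai_utils` and is not shown in this notebook. The NumPy sketch below illustrates the standard Grad-CAM combination step adapted to a 1D regression output. It assumes the target layer's activations, and the gradients of the predicted HR with respect to them, have already been captured for one window as `(channels, time)` arrays, for example with forward/backward hooks. `gradcam_1d` is a hypothetical helper and is not the project's implementation.

```python
import numpy as np

def gradcam_1d(activations: np.ndarray, gradients: np.ndarray, signal_len: int) -> np.ndarray:
    """Hypothetical 1D Grad-CAM combination step for a regression output.

    activations, gradients: shape (channels, time), taken from the target conv
    layer for a single window; gradients are d(predicted HR) / d(activation).
    Returns a heatmap of length `signal_len`, scaled to [0, 1].
    """
    # Channel weights: global-average-pool the gradients over the time axis
    weights = gradients.mean(axis=1)  # (channels,)
    # Weighted sum of the activation maps; ReLU keeps only evidence that raises the prediction
    cam = np.maximum((weights[:, None] * activations).sum(axis=0), 0.0)  # (time,)
    # Linearly upsample from the conv layer's time resolution to the input length
    src = np.linspace(0.0, 1.0, num=cam.shape[0])
    dst = np.linspace(0.0, 1.0, num=signal_len)
    cam = np.interp(dst, src, cam)
    # Normalise to [0, 1], leaving an all-zero map unchanged
    peak = cam.max()
    return cam / peak if peak > 0 else cam
```

For example, `gradcam_1d(np.array([[1., 2., 0.], [0., 1., 3.]]), np.array([[1., 1., 1.], [-1., -1., -1.]]), 5)` returns `[1.0, 1.0, 1.0, 0.5, 0.0]`. The last conv time step has a negative weighted sum, meaning its activations push the predicted HR down, so the ReLU zeroes it. For regression, this makes the map highlight regions that increase the prediction rather than every region that influences it.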

## 2. Setup
We compare two models trained on **PPG-DaLiA**:
1.  **Supervised CNN** (Baseline)
2.  **SSL Specialized CNN** (Phase 4.2)

We expect the SSL model to focus more robustly on physiological features (systolic peaks) rather than noise.

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader

# Import project modules
from models import CNNBaseline
from train import PPGDataset
from xai_utils import GradCAM1D, plot_saliency

device = torch.device("cpu")  # Visualization is fast enough on CPU
```

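```python
# Load models
def load_model(path):
    """Build a CNNBaseline, load its saved state dict, and switch to eval mode."""
    model = CNNBaseline().to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()  # inference mode: dropout off, BatchNorm uses running stats (if present)
    return model

model_sup = load_model("best_model_cnn_ppg_dalia.pth")
model_ssl = load_model("best_model_ssl_specialized_ppg_dalia.pth")

# Target layer for Grad-CAM: first module of the last conv block
# (assumes `conv3` is an nn.Sequential whose first element is the Conv1d)
target_layer_sup = model_sup.conv3[0]
target_layer_ssl = model_ssl.conv3[0]

print("Models loaded and target layers selected.")
```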

## 3. Visualization
We select a sample from the test set and overlay the Grad-CAM heatmap.

```python
# Load test data
dataset = PPGDataset("preprocessed_data/ppg_dalia/test")
loader = DataLoader(dataset, batch_size=1, shuffle=True)

# Draw one random test window (shuffle=True, so it changes on every run)
x, y_true = next(iter(loader))
x = x.to(device)

# --- Explain Supervised Model ---
cam_sup = GradCAM1D(model_sup, target_layer_sup)
y_pred_sup = model_sup(x)
# Backpropagate from the predicted HR (a single-element output, since batch_size=1)
model_sup.zero_grad()
y_pred_sup.backward()
heatmap_sup = cam_sup(x)

# --- Explain SSL Model ---
cam_ssl = GradCAM1D(model_ssl, target_layer_ssl)
y_pred_ssl = model_ssl(x)
model_ssl.zero_grad()
y_pred_ssl.backward()
heatmap_ssl = cam_ssl(x)

# --- Plot Comparison ---
signal = x.detach().cpu().numpy().flatten()
fig, ax = plt.subplots(2, 1, figsize=(12, 6), sharex=True)

true_hr = y_true.item()
plot_saliency(signal, heatmap_sup, title=f"Supervised CNN (Pred: {y_pred_sup.item():.1f}, True: {true_hr:.1f})", ax=ax[0])
plot_saliency(signal, heatmap_ssl, title=f"SSL Specialized (Pred: {y_pred_ssl.item():.1f}, True: {true_hr:.1f})", ax=ax[1])

plt.tight_layout()
plt.show()
```

## 4. Interpretation
**Red regions** mark high Grad-CAM values, i.e. the time steps that contribute most to the predicted heart rate.

*   **Physiological Alignment**: Check whether the highlighted regions coincide with the systolic peaks (sharp upstrokes).
*   **Artifact Rejection**: Check whether the model ignores flat or noisy segments where the signal breaks down (e.g. sensor dropout or motion artifacts).
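
The two checks above can also be expressed as numbers. The sketch below uses hypothetical helpers that are not part of `xai_utils`. It assumes a heatmap already upsampled to the signal length, and systolic-peak indices obtained separately (for example with `scipy.signal.find_peaks`). It reports the share of Grad-CAM mass within a window around the peaks, and the share that falls on near-flat segments. Near-flat segments are detected with an assumed rolling-standard-deviation threshold.

```python
import numpy as np

def peak_alignment_score(heatmap: np.ndarray, peak_idx: np.ndarray, half_width: int) -> float:
    """Hypothetical metric: fraction of Grad-CAM mass within +/- half_width samples of any peak."""
    mask = np.zeros(heatmap.shape[0], dtype=bool)
    for p in peak_idx:
        lo, hi = max(int(p) - half_width, 0), min(int(p) + half_width + 1, heatmap.shape[0])
        mask[lo:hi] = True
    total = heatmap.sum()
    return float(heatmap[mask].sum() / total) if total > 0 else 0.0

def flat_segment_mass(signal: np.ndarray, heatmap: np.ndarray, win: int, std_thresh: float) -> float:
    """Hypothetical metric: fraction of Grad-CAM mass on samples whose rolling std < std_thresh."""
    # Pad by reflection so every sample gets a full window (centred for odd `win`)
    pad = win // 2
    padded = np.pad(signal, (pad, win - 1 - pad), mode="reflect")
    windows = np.lib.stride_tricks.sliding_window_view(padded, win)  # shape (len(signal), win)
    flat = windows.std(axis=1) < std_thresh
    total = heatmap.sum()
    return float(heatmap[flat].sum() / total) if total > 0 else 0.0
```

`half_width`, `win`, and `std_thresh` are placeholders and not tuned settings. `half_width` should roughly cover a systolic upstroke at the dataset's sampling rate. A higher peak-alignment score and a lower flat-segment mass would both point toward the behaviour the checklist describes.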

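If the SSL model focuses more cleanly on the peaks than the Supervised model (which may attend to random noise), this is consistent with **better representation learning**. However, a single random window is anecdotal, so the comparison should be repeated over many test windows before drawing a conclusion.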
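
To compare the two models over many windows, one option is a paired comparison of a per-window focus score. An example score is the fraction of Grad-CAM mass near systolic peaks, computed for both models on the same test windows. The sketch below assumes those scores are already available as two equal-length sequences. `paired_comparison` is a hypothetical helper. It reports the mean difference, the share of windows where the SSL model scores higher, and a percentile-bootstrap 95% interval for the mean difference.

```python
import numpy as np

def paired_comparison(score_ssl, score_sup, n_boot=2000, seed=0):
    """Hypothetical paired comparison of per-window scores from two models on the same windows."""
    ssl = np.asarray(score_ssl, dtype=float)
    sup = np.asarray(score_sup, dtype=float)
    diff = ssl - sup  # positive -> SSL model scores higher on that window
    rng = np.random.default_rng(seed)
    # Bootstrap the mean difference by resampling windows with replacement
    idx = rng.integers(0, diff.size, size=(n_boot, diff.size))
    boot_means = diff[idx].mean(axis=1)
    lo, hi = np.percentile(boot_means, [2.5, 97.5])
    return {
        "mean_diff": float(diff.mean()),
        "win_rate": float((diff > 0).mean()),  # share of windows where SSL is higher
        "ci95": (float(lo), float(hi)),
    }
```

A bootstrap interval that excludes zero, together with a high win rate, would support the hypothesis far better than any single heatmap. The percentile bootstrap is only one choice; a Wilcoxon signed-rank test on the paired scores is a common alternative.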