# In [None]:
import os
import torch
import shap
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------
# Config
# -----------------------------
patient_npz = "/content/chb_preprocessed/chb15_test.npz"
OUT_DIR = "/content/shap_outputs"
os.makedirs(OUT_DIR, exist_ok=True)

# -----------------------------
# 1. Disable cuDNN globally (temporary)
# -----------------------------
# The original setting is saved here and restored in step 13.
# NOTE(review): if anything between here and step 13 raises, cuDNN stays
# disabled for the rest of the kernel session; restore it manually then.
orig_cudnn_state = torch.backends.cudnn.enabled
torch.backends.cudnn.enabled = False
print("‚öôÔ∏è cuDNN disabled temporarily for SHAP backward pass.")

# -----------------------------
# 2. Load data
# -----------------------------
# `with` closes the .npz file handle once the arrays have been read into memory.
# allow_pickle=True lets object arrays (e.g. channel names) be unpickled, which
# can execute code embedded in the file -- only load trusted archives.
with np.load(patient_npz, allow_pickle=True) as data:
    X = data["specs"]  # (samples, channels, freqs, times)
    y = data["labels"]
    if "ch_names" in data.files:
        # Normalize to plain str so `.startswith` and tick labels work even
        # when the names were stored as bytes.
        ch_names = [
            name.decode() if isinstance(name, bytes) else str(name)
            for name in data["ch_names"]
        ]
    else:
        ch_names = [f"Ch{i}" for i in range(X.shape[1])]
n_freq_bins = X.shape[2]

# Frequency axis: adjust if your spectrogram uses a different max freq
freqs = np.linspace(0, 128, n_freq_bins)  # assumes bins evenly span 0..128 Hz

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# 3. Split per-class and convert to torch
# -----------------------------
# Label 1 is treated as preictal and label 0 as interictal.
X_preictal_np = X[y == 1]
X_interictal_np = X[y == 0]

def to_torch(x_np):
    """Copy a NumPy array into a new float32 tensor on the module-level ``device``."""
    return torch.tensor(x_np, device=device, dtype=torch.float32)

# Move both class subsets to `device` as float32 tensors.
X_preictal, X_interictal = (to_torch(arr) for arr in (X_preictal_np, X_interictal_np))

print(f"Loaded: Preictal {len(X_preictal_np)} windows, Interictal {len(X_interictal_np)} windows")

# -----------------------------
# 4. Prepare model (must exist in scope)
# -----------------------------
# `combined_model` is not created in this cell; it must come from an earlier one
# (presumably a torch nn.Module, given the .to()/.eval() calls).
combined_model.to(device)
combined_model.eval()

# -----------------------------
# 5. SHAP computation strategy (contrastive)
# -----------------------------
def compute_shap(background_tensor, test_tensor, desc, max_bg=64, max_test=128):
    """Explain ``combined_model`` with ``shap.GradientExplainer`` for one contrast.

    The first ``max_bg`` rows of ``background_tensor`` form the SHAP background
    and the first ``max_test`` rows of ``test_tensor`` are explained. Rows are
    taken in stored order, not sampled.

    Args:
        background_tensor: Tensor whose leading rows serve as the reference set.
        test_tensor: Tensor whose leading rows are explained.
        desc: Label shown in the progress message.
        max_bg: Maximum number of background rows.
        max_test: Maximum number of explained rows.

    Returns:
        ``(sv, bg, test)``: the SHAP array with NaN/inf replaced by
        ``np.nan_to_num``, plus the truncated background and test tensors.

    NOTE(review): if ``shap_values`` returns a list (one entry per model
    output, depending on the shap version), only entry 0 is kept. Confirm that
    is the output you mean to explain.
    """
    bg, test = background_tensor[:max_bg], test_tensor[:max_test]

    print(f"\nüß† Explaining {desc}: background={len(bg)}, test={len(test)}")
    raw_values = shap.GradientExplainer(model=combined_model, data=bg).shap_values(test)
    first_output = raw_values[0] if isinstance(raw_values, list) else raw_values
    sv = np.nan_to_num(first_output)
    print(f"  -> SHAP shape: {sv.shape}")
    return sv, bg, test

# Contrastive explanations: each class is explained against the other as background.
sv_pre_vs_inter, bg_inter, test_pre = compute_shap(
    background_tensor=X_interictal,
    test_tensor=X_preictal,
    desc="Preictal vs Interictal (contrast)",
)

sv_inter_vs_pre, bg_pre, test_inter = compute_shap(
    background_tensor=X_preictal,
    test_tensor=X_interictal,
    desc="Interictal vs Preictal (contrast)",
)

# -----------------------------
# 6. Importance aggregation helpers
# -----------------------------
def agg_channel_freq(sv):
    """Summarize absolute SHAP values by channel, by frequency, and jointly.

    ``sv`` is indexed as (samples, channels, freqs, times). An optional
    trailing fifth axis (e.g. per-output values) is averaged away first.

    Returns:
        ``(ch_imp, freq_imp, ch_freq)`` with shapes ``(channels,)``,
        ``(freqs,)`` and ``(channels, freqs)``.
    """
    magnitude = np.abs(sv)
    if magnitude.ndim == 5:
        magnitude = magnitude.mean(axis=-1)
    return (
        magnitude.mean(axis=(0, 2, 3)),  # per channel
        magnitude.mean(axis=(0, 1, 3)),  # per frequency bin
        magnitude.mean(axis=(0, 3)),     # channel x frequency
    )

# Per-contrast |SHAP| summaries: per channel, per frequency bin, channel x frequency.
ch_pre, freq_pre, chfreq_pre = agg_channel_freq(sv=sv_pre_vs_inter)
ch_inter, freq_inter, chfreq_inter = agg_channel_freq(sv=sv_inter_vs_pre)

# -----------------------------
# 7. Normalize (proportional normalization)
# -----------------------------
def prop_normalize(vec):
    """Scale *vec* so it sums to ~1; the 1e-12 epsilon keeps an all-zero input at zero."""
    return vec / (vec.sum() + 1e-12)

# Proportional versions of the channel and frequency summaries.
ch_pre_prop, ch_inter_prop = prop_normalize(ch_pre), prop_normalize(ch_inter)
freq_pre_prop, freq_inter_prop = prop_normalize(freq_pre), prop_normalize(freq_inter)

# The channel x frequency maps are kept as raw mean |SHAP| despite the
# *_prop suffix; the heatmap colorbars below label them that way.
chfreq_pre_prop, chfreq_inter_prop = chfreq_pre, chfreq_inter

# -----------------------------
# 8. EEG band aggregation
# -----------------------------
# Half-open [low, high) ranges in Hz, matched against `freqs` by band_aggregation.
bands = {
    "Delta (0.5-4 Hz)": (0.5, 4),
    "Theta (4-8 Hz)": (4, 8),
    "Alpha (8-13 Hz)": (8, 13),
    "Beta (13-30 Hz)": (13, 30),
    "Gamma (30-100 Hz)": (30, 100),
}

def band_aggregation(freqs, freq_prop):
    """Sum ``freq_prop`` over each band in the module-level ``bands`` mapping.

    A bin belongs to a band when ``low <= freqs < high``. Bands with no
    matching bins get ``0.0``. The result preserves ``bands`` key order.
    """
    def _band_total(low, high):
        in_band = (freqs >= low) & (freqs < high)
        return freq_prop[in_band].sum() if in_band.any() else 0.0

    return {name: _band_total(low, high) for name, (low, high) in bands.items()}

band_pre, band_inter = (band_aggregation(freqs, fp) for fp in (freq_pre_prop, freq_inter_prop))

# -----------------------------
# 9. ΔSHAP (preictal - interictal)
# -----------------------------
# Positive values mean a larger attribution share in the preictal contrast.
ch_delta = ch_pre_prop - ch_inter_prop
band_delta = {name: band_pre[name] - band_inter[name] for name in band_pre}

# -----------------------------
# 10. Replace numeric labels with real EEG channel names
# -----------------------------
# Only applied when every name is a generic "Ch<i>" placeholder.
# NOTE(review): verify this montage against the recording's channel list --
# "T8-P8" appears at both index 14 and 20, and indices beyond 21 fall back to
# "Ch<i>".
channel_map = {
    0: "Fp1-F7", 1: "F7-T7", 2: "T7-P7", 3: "P7-O1",
    4: "Fp1-F3", 5: "F3-C3", 6: "C3-P3", 7: "P3-O1",
    8: "Fp2-F4", 9: "F4-C4", 10: "C4-P4", 11: "P4-O2",
    12: "Fp2-F8", 13: "F8-T8", 14: "T8-P8", 15: "P8-O2",
    16: "Fz-Cz", 17: "Cz-Pz", 18: "P7-T7", 19: "O1-O2",
    20: "T8-P8", 21: "Fpz-Cz"
}

if not any(not name.startswith("Ch") for name in ch_names):
    ch_names = [channel_map.get(idx, f"Ch{idx}") for idx, _ in enumerate(ch_names)]

# -----------------------------
# 11. Plotting + save outputs
# -----------------------------
def save_fig(fig, name):
    """Write *fig* to ``OUT_DIR/name`` with a tight bounding box and report the path."""
    destination = os.path.join(OUT_DIR, name)
    fig.savefig(destination, bbox_inches="tight")
    print(f"Saved {destination}")

# Channel bar plots: proportional attribution per channel for each contrast.
# Each figure is saved before plt.show(), since some backends release the
# figure once it has been shown.
plt.figure(figsize=(12,5))
plt.bar(np.arange(len(ch_names)), ch_pre_prop, color='crimson', alpha=0.9, label='Preictal (prop)')
plt.xticks(np.arange(len(ch_names)), ch_names, rotation=45, ha='right')
plt.ylabel("Proportional SHAP attribution")
plt.title("Preictal ‚Äî Channel importance (vs interictal baseline)")
plt.legend()
plt.tight_layout()
save_fig(plt.gcf(), "preictal_channel_prop.png")
plt.show()

plt.figure(figsize=(12,5))
plt.bar(np.arange(len(ch_names)), ch_inter_prop, color='royalblue', alpha=0.9, label='Interictal (prop)')
plt.xticks(np.arange(len(ch_names)), ch_names, rotation=45, ha='right')
plt.ylabel("Proportional SHAP attribution")
plt.title("Interictal ‚Äî Channel importance (vs preictal baseline)")
plt.legend()
plt.tight_layout()
save_fig(plt.gcf(), "interictal_channel_prop.png")
plt.show()

# ΔSHAP per channel, ordered by |Δ| from largest to smallest.
order = np.argsort(-np.abs(ch_delta))
plt.figure(figsize=(12,5))
plt.bar(np.arange(len(ch_names)), ch_delta[order], color='purple', alpha=0.8)
# Tick labels are permuted with the same `order` so names stay aligned with bars.
plt.xticks(np.arange(len(ch_names)), np.array(ch_names)[order], rotation=45, ha='right')
plt.ylabel("ŒîSHAP (pre - inter) [proportion]")
plt.title("Channel ŒîSHAP (Preictal ‚àí Interictal) ‚Äî sorted by |Œî|")
plt.tight_layout()
save_fig(plt.gcf(), "channel_delta_shap.png")
plt.show()

# Band-level plots: proportional SHAP summed within each entry of `bands`.
# The bands stop at 100 Hz while `freqs` extends to 128 Hz, so the bars of one
# contrast need not sum to 1.
plt.figure(figsize=(8,4))
plt.bar(list(band_pre.keys()), list(band_pre.values()), color='crimson', alpha=0.9)
plt.title("Preictal ‚Äî Band proportional SHAP (vs interictal baseline)")
plt.xticks(rotation=30, ha='right')
plt.tight_layout()
save_fig(plt.gcf(), "preictal_band_prop.png")
plt.show()

plt.figure(figsize=(8,4))
plt.bar(list(band_inter.keys()), list(band_inter.values()), color='royalblue', alpha=0.9)
plt.title("Interictal ‚Äî Band proportional SHAP (vs preictal baseline)")
plt.xticks(rotation=30, ha='right')
plt.tight_layout()
save_fig(plt.gcf(), "interictal_band_prop.png")
plt.show()

# Per-band difference (preictal minus interictal share).
plt.figure(figsize=(8,4))
plt.bar(list(band_delta.keys()), list(band_delta.values()), color='purple', alpha=0.85)
plt.title("Œî Band SHAP (Preictal ‚àí Interictal)")
plt.xticks(rotation=30, ha='right')
plt.tight_layout()
save_fig(plt.gcf(), "band_delta_shap.png")
plt.show()

# Channel-frequency heatmaps. chfreq_*_prop is (channels, freqs), so `.T` puts
# channels on the x-axis and the frequency *bin index* (not Hz) on the y-axis.
# origin='lower' draws bin 0 at the bottom. Values are raw mean |SHAP|.
plt.figure(figsize=(10,6))
plt.imshow(chfreq_pre_prop.T, aspect='auto', origin='lower', cmap='Reds')
plt.colorbar(label="Mean |SHAP| (pre vs inter)")
plt.xlabel("Channel")
plt.ylabel("Frequency bin")
plt.title("Preictal ‚Äî Channel √ó Frequency mean |SHAP| (contrast)")
plt.xticks(np.arange(len(ch_names)), ch_names, rotation=45, ha='right')
plt.tight_layout()
save_fig(plt.gcf(), "ch_freq_pre_heatmap.png")
plt.show()

plt.figure(figsize=(10,6))
plt.imshow(chfreq_inter_prop.T, aspect='auto', origin='lower', cmap='Blues')
plt.colorbar(label="Mean |SHAP| (inter vs pre)")
plt.xlabel("Channel")
plt.ylabel("Frequency bin")
plt.title("Interictal ‚Äî Channel √ó Frequency mean |SHAP| (contrast)")
plt.xticks(np.arange(len(ch_names)), ch_names, rotation=45, ha='right')
plt.tight_layout()
save_fig(plt.gcf(), "ch_freq_inter_heatmap.png")
plt.show()

# -----------------------------
# 12. Diagnostics
# -----------------------------
def print_diag(sv, label):
    """Print the SHAP array's dimensions and per-channel mean/std of |SHAP| (first 8 channels).

    Reads ``sv.shape[0..3]`` as samples, channels, freqs, times.
    """
    dims = sv.shape
    print(f"{label} SHAP: samples={dims[0]}, channels={dims[1]}, freqs={dims[2]}, times={dims[3]}")
    magnitude = np.abs(sv)
    print(f"  mean |SHAP| per-channel (first 8): {np.round(magnitude.mean(axis=(0,2,3))[:8],4)}")
    print(f"  std |SHAP| per-channel (first 8): {np.round(magnitude.std(axis=(0,2,3))[:8],4)}")

print("\nDiagnostics:")
print_diag(sv=sv_pre_vs_inter, label="Pre vs Inter")
print_diag(sv=sv_inter_vs_pre, label="Inter vs Pre")

# -----------------------------
# 13. Restore cuDNN
# -----------------------------
# Put back whatever cuDNN setting was active before step 1.
torch.backends.cudnn.enabled = orig_cudnn_state
print("üîÅ cuDNN re-enabled. Done.")


# In [None]:
import os
import torch
import shap
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------
# Config
# -----------------------------
patient_npz = "/content/chb_preprocessed/chb15_test.npz"
OUT_DIR = "/content/shap_outputs"
os.makedirs(OUT_DIR, exist_ok=True)

# -----------------------------
# 1. Disable cuDNN globally (temporary)
# -----------------------------
# The saved setting is restored in step 9. NOTE(review): an exception before
# then leaves cuDNN disabled for the rest of the kernel session.
orig_cudnn_state = torch.backends.cudnn.enabled
torch.backends.cudnn.enabled = False
print("‚öôÔ∏è cuDNN disabled temporarily for SHAP backward pass.")

# -----------------------------
# 2. Load data
# -----------------------------
# `with` closes the .npz handle after the arrays are read. allow_pickle=True can
# execute code embedded in the file, so only load trusted archives.
with np.load(patient_npz, allow_pickle=True) as data:
    X = data["specs"]  # (samples, channels, freqs, times)
    y = data["labels"]
n_freq_bins = X.shape[2]

# Hz label for each frequency bin (assumes bins evenly span 0..128 Hz).
freqs = np.linspace(0, 128, n_freq_bins)

# Orientation heuristic: if the first bin carries less mean power than the last,
# assume the stored axis runs high -> low. Only the Hz *labels* are reversed.
# X is passed to the model exactly as stored, because flipping the model input
# would explain spectrograms in an orientation the model was (presumably) not
# trained on.
# NOTE(review): the heuristic assumes power falls with frequency. It can misfire
# on per-frequency normalized features -- confirm against the preprocessing.
mean_power_per_freq = X.mean(axis=(0, 1, 3))
if mean_power_per_freq[0] < mean_power_per_freq[-1]:
    print("WARNING: frequency axis looks inverted; reversing frequency labels (model input unchanged).")
    freqs = freqs[::-1].copy()

# -----------------------------
# 3. Split per-class and convert to torch
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Label 1 is treated as preictal and label 0 as interictal.
X_preictal_np = X[y == 1]
X_interictal_np = X[y == 0]
def to_torch(x_np):
    """Return a float32 copy of *x_np* as a tensor placed on the module-level ``device``."""
    converted = torch.tensor(x_np, device=device, dtype=torch.float32)
    return converted

# Move both class subsets to `device` as float32 tensors.
X_preictal, X_interictal = (to_torch(arr) for arr in (X_preictal_np, X_interictal_np))

print(f"Loaded: Preictal {len(X_preictal_np)} windows, Interictal {len(X_interictal_np)} windows")

# -----------------------------
# 4. Prepare model
# -----------------------------
# `combined_model` must already exist from an earlier cell (presumably a torch
# nn.Module, given the .to()/.eval() calls).
combined_model.to(device)
combined_model.eval()

# -----------------------------
# 5. SHAP computation
# -----------------------------
def compute_shap(background_tensor, test_tensor, desc, max_bg=64, max_test=128):
    """Return SHAP values for ``combined_model`` on the leading rows of ``test_tensor``.

    Uses the first ``max_bg`` rows of ``background_tensor`` as the
    GradientExplainer background and explains the first ``max_test`` rows of
    ``test_tensor``, both taken in stored order.

    Returns:
        The SHAP array with NaN/inf replaced by ``np.nan_to_num``.

    NOTE(review): when shap returns a list (one entry per model output), only
    entry 0 is kept. Confirm that is the intended output.
    """
    bg = background_tensor[:max_bg]
    test = test_tensor[:max_test]

    print(f"\nüß† Explaining {desc}: background={len(bg)}, test={len(test)}")
    explanation = shap.GradientExplainer(model=combined_model, data=bg).shap_values(test)
    if isinstance(explanation, list):
        explanation = explanation[0]
    sv = np.nan_to_num(explanation)
    print(f"  -> SHAP shape: {sv.shape}")
    return sv

# Contrastive explanations: each class is explained against the other as background.
sv_pre_vs_inter = compute_shap(
    background_tensor=X_interictal, test_tensor=X_preictal, desc="Preictal vs Interictal"
)
sv_inter_vs_pre = compute_shap(
    background_tensor=X_preictal, test_tensor=X_interictal, desc="Interictal vs Preictal"
)

# -----------------------------
# 6. Frequency-level aggregation
# -----------------------------
def agg_freq(sv):
    """Return mean |SHAP| per frequency bin, shape ``(freqs,)``.

    ``sv`` is indexed as (samples, channels, freqs, times). An optional
    trailing fifth axis is averaged away first.
    """
    magnitude = np.abs(sv)
    if magnitude.ndim == 5:
        magnitude = magnitude.mean(axis=-1)
    return magnitude.mean(axis=(0, 1, 3))

# Mean |SHAP| per frequency bin for each contrast.
freq_pre, freq_inter = agg_freq(sv=sv_pre_vs_inter), agg_freq(sv=sv_inter_vs_pre)

# -----------------------------
# 7. Normalize to proportional scale
# -----------------------------
def prop_normalize(vec):
    """Return *vec* divided by its sum so the entries sum to ~1.

    The 1e-12 epsilon avoids division by zero, so an all-zero input stays zero.
    """
    denominator = vec.sum() + 1e-12
    return vec / denominator

# Each vector now sums to ~1 (unless all-zero), so both contrasts share one scale.
freq_pre_prop, freq_inter_prop = (prop_normalize(vec=v) for v in (freq_pre, freq_inter))

# -----------------------------
# 8. Plot proportional frequency SHAP
# -----------------------------
# One curve per contrast; `freqs` supplies the Hz label for each frequency bin.
plt.figure(figsize=(10,5))
plt.plot(freqs, freq_pre_prop, color='crimson', lw=2, label='Preictal (prop)')
plt.plot(freqs, freq_inter_prop, color='royalblue', lw=2, label='Interictal (prop)')
# Shade the area between each curve and zero.
plt.fill_between(freqs, freq_pre_prop, alpha=0.2, color='crimson')
plt.fill_between(freqs, freq_inter_prop, alpha=0.2, color='royalblue')
plt.xlabel("Frequency (Hz)")
plt.ylabel("Proportional SHAP attribution")
plt.title("Frequency-wise Proportional SHAP ‚Äî Preictal vs Interictal")
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()

# Save before plt.show(); some backends release the figure once it is shown.
out_path = os.path.join(OUT_DIR, "freq_shap_prop.png")
plt.savefig(out_path, bbox_inches="tight")
plt.show()

print(f"‚úÖ Saved frequency SHAP plot at: {out_path}")

# -----------------------------
# 9. Restore cuDNN
# -----------------------------
# Put back the setting captured in step 1.
torch.backends.cudnn.enabled = orig_cudnn_state
print("üîÅ cuDNN re-enabled. Done.")
