# 07 · SHAP Explainability & Attention Overlays (SpectraMind V50)

**Purpose.** Quantify and visualize feature influence on predicted transmission spectra using SHAP, and (optionally) fuse with saved attention weights from the spectral GNN branch. Artifacts are saved to `outputs/notebooks/07_shap_explainability/`.

**What this notebook does**
1. Loads a trained V50 model checkpoint or scripted model from the latest (or chosen) Hydra run directory.
2. Loads a small validation/sample set and runs forward predictions.
3. Computes SHAP values (Deep/Gradient/Kernel fallback) and exports:
   - Beeswarm summary
   - Per‑wavelength bar/heatmap
   - Spectrum‑overlay plots (|SHAP| vs. wavelength on top of μ/σ predictions)
4. If available, fuses SHAP saliency with saved spectral attention weights to highlight consensus regions.
5. Captures environment (package versions, git/DVC/Hydra config snapshot) for reproducibility.

**Pre‑requisites**
- You have at least one completed training/prediction run under `outputs/` (Hydra timestamped folder), e.g. `outputs/2025-08-18/11-52-30/` with `checkpoints/` or a scripted model file and `config.yaml`.
- Optional: saved artifacts such as `wavelengths.npy`, `attn_weights.npy`, and a small validation tensor/Parquet for quick analysis.

**Reproducibility note**
We record package versions, fix the random seeds, and resolve the active Hydra run folder. With `RANDOM_SEED` fixed, all figures and reports are reproducible except where the SHAP estimator itself is inherently stochastic.


In [None]:
import os, sys, json, glob, time, math, random, platform, subprocess, pathlib
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Optional installs (no-ops if already present)
try:
    import shap  # noqa: F401
except Exception:
    %pip -q install shap
    import shap

try:
    import torch
    from torch import nn
except Exception:
    %pip -q install torch --extra-index-url https://download.pytorch.org/whl/cpu
    import torch
    from torch import nn

# Seed every random number generator the notebook touches (Python, NumPy, Torch CPU/CUDA).
RANDOM_SEED = 1234
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

# Paths: everything this notebook writes goes below OUT_NOTEBOOK_DIR.
ROOT = Path.cwd()
OUT_NOTEBOOK_DIR = ROOT.joinpath("outputs", "notebooks", "07_shap_explainability")
OUT_NOTEBOOK_DIR.mkdir(parents=True, exist_ok=True)

# `_shap` is a second alias of the shap module, used for the version report.
import shap as _shap

print("Python:", platform.python_version())
print("Torch:", torch.__version__)
print("SHAP:", _shap.__version__)
print("CWD:", ROOT)


In [None]:
# --- Find latest Hydra run (or pick one) ---
def find_latest_hydra_run(outputs_root=ROOT/"outputs"):
    """Pick the preferred Hydra run directory below *outputs_root*.

    A run directory is any directory that contains a ``.hydra`` entry.
    Runs are ranked first by how many model artifacts they hold (a
    ``checkpoints`` folder, any ``*.pt`` file, any ``*.pth`` file), then by
    modification time, then by path.

    Returns:
        The highest-ranked run directory, or None when no ``.hydra`` entry
        is found.
    """
    def _rank(run_dir):
        # One point per kind of model artifact present in the run directory.
        artifact_score = (
            int((run_dir / "checkpoints").exists())
            + int(any(run_dir.glob("*.pt")))
            + int(any(run_dir.glob("*.pth")))
        )
        return (artifact_score, run_dir.stat().st_mtime, run_dir)

    ranked = [_rank(marker.parent) for marker in outputs_root.rglob(".hydra")]
    if not ranked:
        return None
    # The largest (score, mtime, path) tuple is the best run.
    return max(ranked)[2]

RUN_DIR = find_latest_hydra_run()
print("Detected RUN_DIR:", RUN_DIR)
if RUN_DIR is None:
    print("⚠️ No Hydra run found under outputs/. You can set RUN_DIR manually below.")

# Manual override example:
# RUN_DIR = ROOT / "outputs" / "2025-08-18" / "11-52-30"

# Load run config if present: the run root is checked before the .hydra subfolder.
CFG_PATH = None
if RUN_DIR is not None:
    _cfg_candidates = (RUN_DIR / "config.yaml", RUN_DIR / ".hydra" / "config.yaml")
    CFG_PATH = next((c for c in _cfg_candidates if c.exists()), None)

if not CFG_PATH:
    print("⚠️ Could not locate config.yaml (will proceed with defaults).")
else:
    print("Using config:", CFG_PATH)
    print("\n--- Hydra Config (head) ---")
    # Only the first 50 lines are echoed to keep the cell output short.
    _cfg_head = Path(CFG_PATH).read_text().splitlines()[:50]
    print("\n".join(_cfg_head))


In [None]:
# --- Load model (scripted .pt/.pth or Lightning/PyTorch checkpoint) ---
def _try_scripted_model(run_dir):
    """Load the first usable TorchScript archive found directly in *run_dir*.

    Every ``*.pt`` file and then every ``*.pth`` file is tried in sorted
    (deterministic) order. A file that ``torch.jit.load`` rejects, such as
    a plain tensor written with ``torch.save``, is reported and skipped, so
    one unrelated file cannot hide a valid scripted model next to it.

    Args:
        run_dir: Directory to search (non-recursive), or None.

    Returns:
        The loaded module, mapped to CPU and switched to eval mode, or None
        when *run_dir* is None or no file could be loaded.
    """
    if run_dir is None:
        return None
    for pattern in ("*.pt", "*.pth"):
        for path in sorted(run_dir.glob(pattern)):
            try:
                m = torch.jit.load(str(path), map_location="cpu")
            except Exception as e:
                # Not a TorchScript archive (or unreadable): keep searching.
                print("Scripted load failed:", path.name, "-", e)
                continue
            m.eval()
            print("Loaded scripted model:", path.name)
            return m
    return None

def _try_state_dict(run_dir):
    """Instantiate the repository model and load weights from a run checkpoint.

    The first ``*.ckpt`` (else the first ``*.pth``) in ``run_dir/checkpoints``
    is used, in sorted order so the choice is deterministic. The checkpoint
    may be a dict with a ``"state_dict"`` entry or a raw state dict. A leading
    ``"model."`` prefix is removed from parameter names; only the leading
    prefix is stripped, so names that merely contain ``"model."`` (for
    example ``"submodel.weight"``) are left intact.

    Args:
        run_dir: Hydra run directory, or None.

    Returns:
        The model in eval mode, or None when *run_dir* is None, no
        checkpoint exists, the model class cannot be imported, or loading
        fails. Failures are printed, never raised.
    """
    if run_dir is None:
        return None
    ckpt_dir = run_dir / "checkpoints"
    found = sorted(ckpt_dir.glob("*.ckpt")) + sorted(ckpt_dir.glob("*.pth"))
    if not found:
        return None
    ck = found[0]

    try:
        # Make the repository importable without stacking duplicate entries
        # on sys.path when the cell is re-run.
        repo_root = str(ROOT)
        if repo_root not in sys.path:
            sys.path.insert(0, repo_root)
        from src.model.v50_model import V50Model  # adjust to your repo
    except Exception as ie:
        print("Repo model import failed:", ie)
        return None

    try:
        model = V50Model()
        # NOTE(security): torch.load unpickles the file; only open checkpoints
        # from a trusted source.
        sd = torch.load(str(ck), map_location="cpu")
        # Lightning or raw dict
        state_dict = sd.get("state_dict", sd)
        prefix = "model."
        clean_sd = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        missing, unexpected = model.load_state_dict(clean_sd, strict=False)
        print("Loaded state_dict from", ck.name, "missing:", len(missing), "unexpected:", len(unexpected))
        model.eval()
        return model
    except Exception as e:
        print("State_dict load failed:", e)
        return None

# Resolution order: scripted archive first, then repository class + checkpoint.
model = _try_scripted_model(RUN_DIR)
if model is None:
    model = _try_state_dict(RUN_DIR)

if model is None:
    # Fallback: define a tiny surrogate so the notebook remains runnable for demo
    class TinySurrogate(nn.Module):
        """Small MLP stand-in, used only when no trained model could be loaded."""

        def __init__(self, in_dim=512, out_dim=283):
            super().__init__()
            hidden = 256
            self.net = nn.Sequential(
                nn.Linear(in_dim, hidden),
                nn.ReLU(),
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Linear(hidden, out_dim),
            )

        def forward(self, x):
            return self.net(x)

    model = TinySurrogate()
    print("⚠️ Using TinySurrogate demo model (no trained weights found).")

# Use the GPU when one is visible; the model stays in eval mode throughout.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device).eval()


In [None]:
# --- Load a small sample batch ---
def load_sample_inputs(run_dir, n_samples=32, in_dim=512):
    """Return a small input batch for explanation.

    Candidate files are checked in this order: ``run_dir/cache/val_inputs.pt``,
    ``run_dir/pred_inputs.pt``, ``ROOT/data/processed/val_inputs.pt`` and
    ``ROOT/data/processed/val.parquet``. The first existing file that yields
    data wins. A ``.pt`` file that does not contain a single Tensor is
    reported and skipped.

    Args:
        run_dir: Hydra run directory, or None to check only repository paths.
        n_samples: Maximum number of rows to return (default 32).
        in_dim: Feature width of the synthetic fallback batch (default 512).

    Returns:
        A tensor with at most *n_samples* rows. When no candidate is usable
        this is random normal data of shape ``(n_samples, in_dim)``.
    """
    # Try common locations in this repo structure, adjust as needed
    candidates = []
    if run_dir is not None:
        candidates += [run_dir / "cache" / "val_inputs.pt", run_dir / "pred_inputs.pt"]
    candidates += [
        ROOT / "data" / "processed" / "val_inputs.pt",
        ROOT / "data" / "processed" / "val.parquet",
    ]
    for c in candidates:
        if not c.exists():
            continue
        if c.suffix == ".pt":
            t = torch.load(str(c), map_location="cpu")
            if isinstance(t, torch.Tensor):
                return t[:n_samples]  # small subset
            # Make the skip visible instead of silently moving on.
            print("Skipping", c, "- expected a Tensor, got", type(t).__name__)
        elif c.suffix == ".parquet":
            df = pd.read_parquet(c)
            x = torch.tensor(df.values, dtype=torch.float32)
            return x[:n_samples]
    # Fallback synthetic demo
    return torch.randn(n_samples, in_dim)

# Fetch the sample batch and place it on the same device as the model.
X = load_sample_inputs(RUN_DIR).to(device)
print("Sample batch:", tuple(X.shape))

# Wavelengths (attempt to load, else synthesize 283 channels)
def load_wavelengths(run_dir):
    """Return the wavelength grid, loading it from disk when available.

    Checked in order: ``ROOT/data/meta/wavelengths.npy``,
    ``ROOT/data/meta/wavelengths.csv``, then ``run_dir/wavelengths.npy``
    (only when *run_dir* is given). If none exists, a synthetic grid of 283
    evenly spaced values from 0.8 to 5.0 is returned.
    """
    meta_dir = ROOT / "data" / "meta"
    search = [meta_dir / "wavelengths.npy", meta_dir / "wavelengths.csv"]
    if run_dir:
        search.append(run_dir / "wavelengths.npy")
    for candidate in search:
        if not candidate.exists():
            continue
        if candidate.suffix == ".npy":
            return np.load(candidate)
        # CSV: collapse the table to a flat array.
        return pd.read_csv(candidate).values.squeeze()
    return np.linspace(0.8, 5.0, 283)  # µm synthetic grid

WAVELENGTHS = load_wavelengths(RUN_DIR)
# Show both ends of the grid as a quick sanity check.
_wl_head, _wl_tail = WAVELENGTHS[:5], WAVELENGTHS[-5:]
print("Wavelength samples:", _wl_head, "...", _wl_tail)


In [None]:
# --- Forward pass (µ/σ if your head is probabilistic; here we assume deterministic head) ---
with torch.no_grad():
    Y_pred = model(X).detach().cpu().numpy()  # [B, 283] expected
print("Pred shape:", Y_pred.shape)

# Save quick preview of the first sample's predicted spectrum
_preview_png = OUT_NOTEBOOK_DIR/"preview_prediction.png"
fig, ax = plt.subplots(figsize=(9, 4))
ax.plot(WAVELENGTHS, Y_pred[0], lw=1.5)
ax.set(
    xlabel="Wavelength [µm]",
    ylabel="Predicted transmission",
    title="Prediction preview (sample 0)",
)
fig.tight_layout()
fig.savefig(_preview_png, dpi=150)
plt.close(fig)
print("Saved:", _preview_png)


## Compute SHAP values
We attempt DeepExplainer (for typical PyTorch nets). If that fails, we try GradientExplainer; if that also fails, we fall back to KernelExplainer on a subsample (slower, but robust).

In [None]:
def compute_shap(model, X, max_bg=64, method_order=("deep", "gradient", "kernel")):
    """Compute SHAP values for *X*, trying several explainers in turn.

    The device is taken from the model's first parameter, falling back to
    ``X.device``. Plain ``nn.Module`` objects have no ``.device`` attribute,
    so reading ``model.device`` would fail every explainer.

    Args:
        model: Callable torch module mapping a batch tensor to predictions.
        X: Input batch tensor to explain.
        max_bg: Maximum number of leading rows of *X* used as background.
        method_order: Explainer names tried in order; recognised values are
            ``"deep"``, ``"gradient"`` and ``"kernel"``.

    Returns:
        ``(explainer, shap_values)`` from the first explainer that succeeds.
        The form of ``shap_values`` is whatever that explainer returns.

    Raises:
        RuntimeError: If every method fails. The last failure is chained
            as the cause.
    """
    model.eval()
    try:
        device = next(model.parameters()).device
    except (StopIteration, AttributeError):
        # Parameter-free module or non-module callable: follow the inputs.
        device = X.device

    X_cpu = X.detach().cpu()
    # background for SHAP
    bg = X_cpu[: min(max_bg, len(X_cpu))]

    def _predict_numpy(inp):
        # KernelExplainer works on NumPy arrays, so wrap the torch forward pass.
        with torch.no_grad():
            batch = torch.tensor(inp, dtype=X.dtype, device=device)
            return model(batch).detach().cpu().numpy()

    last_exc = None
    for m in method_order:
        try:
            if m == "deep":
                e = shap.DeepExplainer(model, bg.to(device))
                return e, e.shap_values(X.to(device))
            if m == "gradient":
                e = shap.GradientExplainer(model, bg.to(device))
                return e, e.shap_values(X.to(device))
            if m == "kernel":
                # Summarise the background with k-means to limit evaluations.
                e = shap.KernelExplainer(_predict_numpy, shap.kmeans(bg.numpy(), min(20, len(bg))))
                return e, e.shap_values(X_cpu.numpy(), nsamples=100)
            print(f"{m} explainer skipped: unknown method name")
        except Exception as exc:
            last_exc = exc
            print(f"{m} explainer failed:", exc)
    raise RuntimeError("All SHAP methods failed.") from last_exc

explainer, shap_vals = compute_shap(model, X)
print(type(explainer))
# Reduce the explainer output to one [B, in_dim] array by averaging over outputs.
if isinstance(shap_vals, list):  # some explainers return list per output; take first/mean
    try:
        shap_array = np.stack(shap_vals, axis=0)  # [O, B, in_dim] or similar
        shap_vals_mean = shap_array.mean(axis=0)
    except Exception:
        # Entries could not be stacked: fall back to the first output only.
        shap_vals_mean = np.array(shap_vals[0])
else:
    shap_vals_mean = np.array(shap_vals)
    if shap_vals_mean.ndim == X.ndim + 1:
        # One axis more than the inputs: presumably a trailing per-output axis
        # ([B, in_dim, O], as newer SHAP releases reportedly return for
        # multi-output models) - TODO confirm for your SHAP version.
        # Average it away so later cells always see [B, in_dim].
        shap_vals_mean = shap_vals_mean.mean(axis=-1)

print("SHAP shape:", shap_vals_mean.shape)


In [None]:
# --- SHAP summary (beeswarm) ---
_beeswarm_png = OUT_NOTEBOOK_DIR/"shap_summary_beeswarm.png"
try:
    shap.summary_plot(shap_vals_mean, X.detach().cpu().numpy(), show=False, max_display=25)
    plt.tight_layout()
    plt.savefig(_beeswarm_png, dpi=150, bbox_inches='tight')
    plt.close()
    print("Saved:", _beeswarm_png)
except Exception as e:
    # Best effort: a failed beeswarm must not stop the remaining plots.
    print("Beeswarm failed:", e)

# --- Mean |SHAP| per input feature (bar) ---
mean_abs = np.abs(shap_vals_mean).mean(axis=0)
_mean_abs_png = OUT_NOTEBOOK_DIR/"shap_mean_abs_per_feature.png"
fig, ax = plt.subplots(figsize=(9, 3))
ax.plot(mean_abs, lw=1)
ax.set(
    title="Mean |SHAP| per input feature",
    xlabel="Input feature index",
    ylabel="Mean |SHAP|",
)
fig.tight_layout()
fig.savefig(_mean_abs_png, dpi=150)
plt.close(fig)
print("Saved:", _mean_abs_png)


## SHAP over wavelength (overlay)
If your model input already aligns per-wavelength (e.g., features map directly to 283 spectral channels), we can overlay |SHAP| on the predicted spectrum. If inputs are not 1:1 with wavelengths, set or derive a mapping.

In [None]:
# --- Attempt to interpret last 283 features as wavelength-local features ---
def to_wavelength_shap(mean_abs, expected=283):
    """Map a per-feature importance vector onto *expected* spectral channels.

    If *mean_abs* has at least *expected* entries, its last *expected*
    entries are returned. Otherwise it is linearly interpolated onto
    *expected* points, with both sequences laid out evenly on [0, 1].
    """
    n_features = len(mean_abs)
    if n_features < expected:
        # Too few features: stretch them across the full channel range.
        src_grid = np.linspace(0, 1, n_features)
        dst_grid = np.linspace(0, 1, expected)
        return np.interp(dst_grid, src_grid, mean_abs)
    return mean_abs[-expected:]

wl_shap = to_wavelength_shap(mean_abs, expected=283)

# Two y-axes: spectrum on the left, |SHAP| on the right, shared wavelength axis.
_overlay_png = OUT_NOTEBOOK_DIR/"overlay_spectrum_shap.png"
fig, ax1 = plt.subplots(figsize=(10, 4))
ax2 = ax1.twinx()
ax1.plot(WAVELENGTHS, Y_pred[0], color='C0', lw=1.5, label='Predicted spectrum (sample 0)')
ax2.plot(WAVELENGTHS, wl_shap, color='C3', lw=1.0, alpha=0.9, label='|SHAP| (approx per-λ)')
ax1.set(
    xlabel="Wavelength [µm]",
    ylabel="Transmission",
    title="Spectrum & |SHAP| overlay",
)
ax2.set_ylabel("|SHAP|")
fig.tight_layout()
fig.savefig(_overlay_png, dpi=150)
plt.close(fig)
print("Saved:", _overlay_png)


## SHAP × Attention fusion (optional)
If your spectral GNN saved attention weights per wavelength (e.g., `attn_weights.npy` of shape `[heads, 283]` or `[283]`), we can blend them with |SHAP| to highlight consensus regions.

In [None]:
def load_attention(run_dir):
    """Load saved attention weights for a run, if any.

    Looks for ``attn_weights.npy`` in *run_dir* and then in
    ``run_dir/artifacts``. A 2-D array is averaged over its first axis;
    arrays of any other rank are returned unchanged.

    Returns:
        The (possibly averaged) array, or None when *run_dir* is None or
        neither file exists.
    """
    if run_dir is None:
        return None
    locations = (run_dir / "attn_weights.npy", run_dir / "artifacts" / "attn_weights.npy")
    found = next((loc for loc in locations if loc.exists()), None)
    if found is None:
        return None
    weights = np.load(found)
    return weights.mean(axis=0) if weights.ndim == 2 else weights

attn = load_attention(RUN_DIR)
if attn is not None and len(attn) != 283:
    # interpolate to 283
    attn = np.interp(np.linspace(0, 1, 283), np.linspace(0, 1, len(attn)), attn)

if attn is None:
    print("No attention weights found (skipping fusion plot).")
else:
    # Attention is min-max scaled; |SHAP| is scaled by its maximum.
    # The small epsilon guards against division by zero.
    attn_lo, attn_hi = attn.min(), attn.max()
    attn_norm = (attn - attn_lo) / (attn_hi - attn_lo + 1e-12)
    shap_norm = wl_shap / (wl_shap.max() + 1e-12)
    fusion = 0.5*attn_norm + 0.5*shap_norm

    _fusion_png = OUT_NOTEBOOK_DIR/"shap_attention_fusion.png"
    fig, ax = plt.subplots(figsize=(10, 3.5))
    ax.plot(WAVELENGTHS, shap_norm, lw=1.0, label='|SHAP| (norm)')
    ax.plot(WAVELENGTHS, attn_norm, lw=1.0, label='Attention (norm)')
    ax.plot(WAVELENGTHS, fusion, lw=2.0, label='Fusion', color='C2')
    ax.set(title='SHAP × Attention consensus', xlabel='Wavelength [µm]')
    ax.legend()
    fig.tight_layout()
    fig.savefig(_fusion_png, dpi=150)
    plt.close(fig)
    print("Saved:", _fusion_png)


## HTML report export (optional)
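Exports a lightweight HTML report that embeds the saved PNGs through relative `<img>` links, so keep `shap_report.html` in the same folder as the images. (A SHAP force‑plot snippet is not generated by the cell below.)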

In [None]:
# Imported under an alias because the name `html` is used below for the
# list of report fragments.
from html import escape as _escape_html

report_path = OUT_NOTEBOOK_DIR/"shap_report.html"
# Figures are linked by relative file name; only those that exist on disk
# end up in the report.
pngs = [
    "preview_prediction.png",
    "shap_summary_beeswarm.png",
    "shap_mean_abs_per_feature.png",
    "overlay_spectrum_shap.png",
    "shap_attention_fusion.png"
]
html = [
    "<html><head><meta charset='utf-8'><title>SHAP Report</title></head><body>",
    f"<h1>SHAP Report · {time.strftime('%Y-%m-%d %H:%M:%S')}</h1>",
]
if CFG_PATH:
    html.append("<h2>Hydra Config (head)</h2><pre>")
    head = "\n".join(Path(CFG_PATH).read_text().splitlines()[:80])
    # Config text is arbitrary: escape &, < and > so it cannot break the
    # surrounding markup or be read as HTML.
    html.append(_escape_html(head))
    html.append("</pre>")

for png in pngs:
    p = OUT_NOTEBOOK_DIR/png
    if p.exists():
        html.append(f"<h3>{png}</h3><img src='{png}' style='max-width:100%;'>")

html.append("</body></html>")
Path(report_path).write_text("\n".join(html), encoding="utf-8")
print("Saved:", report_path)


## (Optional) Generate predictions via CLI before explaining
If you haven’t produced predictions/artifacts yet, and your environment exposes the CLI, you can trigger a fast run below (adjust flags to your config).

In [None]:
USE_CLI = False  # set True to enable
if USE_CLI:
    # Base command followed by the Hydra-style overrides for a quick run.
    cmd = ["spectramind", "predict"]
    cmd += ["predict.fast_dev_run=True", "loader.limit=64"]
    print("Running:", " ".join(cmd))
    try:
        subprocess.run(cmd, check=True)
    except Exception as e:
        # Best effort: report the failure and let the notebook continue.
        print("CLI failed:", e)


## Environment capture & finishing up

In [None]:
def pip_freeze():
    """Write the output of ``pip freeze`` to ``OUT_NOTEBOOK_DIR/pip_freeze.txt``.

    pip is run through the current interpreter (``sys.executable -m pip``).

    Returns:
        True when the file was written, False if anything went wrong.
    """
    freeze_cmd = [sys.executable, "-m", "pip", "freeze"]
    try:
        listing = subprocess.check_output(freeze_cmd, text=True)
        Path(OUT_NOTEBOOK_DIR/"pip_freeze.txt").write_text(listing)
    except Exception:
        # Environment capture is best effort; signal failure to the caller.
        return False
    return True

ok = pip_freeze()
print("Saved pip_freeze:", ok)

# Minimal provenance record for this notebook run.
meta = {
    "timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
    "python": platform.python_version(),
    "torch": torch.__version__,
    "shap": _shap.__version__,
    "device": str(device),
    "run_dir": str(RUN_DIR) if RUN_DIR else None,
    "config": str(CFG_PATH) if CFG_PATH else None,
}
# Serialise once; the same text is written to disk and echoed.
_meta_json = json.dumps(meta, indent=2)
Path(OUT_NOTEBOOK_DIR/"meta.json").write_text(_meta_json)
print(_meta_json)
