
# Synthetic X‑ray dataset of a hollow metal sphere

**What this cell is for.**  
Synthesise a simple greyscale dataset that looks like X‑ray cross‑sections of a *hollow* metal sphere.  
Each image contains a circular shell (intensity on the shell only), and a single **global density factor** ($y$) that scales the whole shell. This makes the problem deliberately **linear** so pixel‑wise linear models behave sensibly, while still being realistic enough for a CNN to learn.

---

## 1) Geometry and coordinate system

Let the image be $H\times W=100\times 100$ pixels with centre $(c_x,c_y)=\big(\tfrac{W-1}{2},\tfrac{H-1}{2}\big)$.  
Generate a slightly **elliptical** radius to allow mild shape variation:

$$
r_e(x,y)\;=\;\sqrt{\left(\frac{x-c_x}{s_x}\right)^2 + \left(\frac{y-c_y}{s_y}\right)^2},
\qquad s_x,s_y \approx 1.
$$

Two radii define the **hollow** interior and the **outer** boundary:

- inner radius $R_0 \approx 20$ px (interior is empty),
- outer radius $R_1 \approx 40$ px (outside is background).

The shell mask is

$$
\mathbf{1}_{\text{shell}}(x,y)=
\begin{cases}
1, & R_0 < r_e(x,y) \le R_1,\\[2pt]
0, & \text{otherwise}.
\end{cases}
$$

---

## 2) Radial intensity profile inside the shell

Within the shell create a smooth radial fall‑off using a shape parameter $\gamma\!\in[0.9,1.1]$.

Define the **normalised thickness coordinate**
$$
t(x,y)\;=\;\frac{r_e(x,y)-R_0}{\,R_1 - R_0\,}\in[0,1].
$$

The unshaded base intensity is zero outside the shell and follows a power law inside it:
$$
I_0(x,y)\;=\;\mathbf{1}_{\text{shell}}(x,y)\,\bigl(1-t(x,y)\bigr)^{\gamma}.
$$

Intuitively, $I_0$ is bright at the inner boundary and fades towards the outer boundary.

---

## 3) Low‑frequency illumination (shading)

Apply a gentle multiplicative illumination field to mimic sensor vignetting:

$$
S(x,y)\;=\;1+\alpha \, P_2(X,Y),
\qquad \alpha\in[0.05,0.12],
$$

where $(X,Y)$ are coordinates normalised by $R_1$ and
$$
P_2(X,Y)=c_0+c_1X+c_2Y+c_3XY+c_4X^2+c_5Y^2,
$$
with the polynomial re‑scaled to have zero mean and unit standard deviation so that $\alpha$ directly controls the magnitude.

---

## 4) Image formation and label

Sample a **global density factor** $y\in[0.8,1.2]$ and form the image (before noise) as

$$
I(x,y)\;=\;y\,I_0(x,y)\,S(x,y).
$$

Then add a small constant **offset** and **Gaussian pixel noise** to simulate the sensor:

$$
I_{\text{obs}}(x,y)\;=\;\mathrm{clip}\Big(I(x,y)+\text{offset}+\varepsilon(x,y),\,0,1\Big),
\qquad \varepsilon\sim\mathcal{N}(0,\sigma^2).
$$

The training **label** is $y$ with tiny measurement noise:
$$
y_{\text{obs}} = y + \eta,\qquad \eta\sim\mathcal{N}(0,\sigma_y^2).
$$

---

## 5) Why this cell matters

- Before clipping and noise, every shell pixel is proportional to $y$, so a linear read-out of the pixels is a sensible model and OLS/Ridge provide a solid baseline.  
- Small geometric and illumination variations keep the task realistic and prevent trivial overfitting.  
- Later cells (linear baselines, CNN, SHAP) use these arrays:

```
images       : (N, 100, 100) float32 in [0,1]
y_true_all   : (N,)   noise‑free target values
y_obs_all    : (N,)   observed targets (y_true + tiny noise)
```

---


In [None]:
"""
Synthetic X‑ray dataset of a hollow metal sphere (100×100) — generation & visualisation
---------------------------------------------------------------------------------------

What this cell does:
  • Generates N=1000 grayscale 100×100 images that look like X‑ray cross‑sections of a
    *hollow* metal sphere:
      - inner radius R0 ≈ 20 px (hollow)
      - outer radius R1 ≈ 40 px (metal shell)
      - intensity decays smoothly from the inner boundary to the outer boundary
  • Adds small, realistic variations per image:
      radii jitter, slight centre shift, mild ellipticity,
      gentle low‑frequency “illumination” shading,
      sensor noise, tiny label noise
  • Defines a per‑image target y (“density factor”): it globally scales the shell intensity.
    This makes the pixel‑wise linear model predictive, but noise/variation keep it realistic.
  • Shows:
      - a grid of sample images with y_true and y_obs,
      - the label histogram (y_obs),
      - the mean image across the dataset.


Variables created for later cells:
  images          : (N, 100, 100) float32 in [0,1]
  y_true_all      : (N,) noise‑free target values
  y_obs_all       : (N,) observed targets (y_true + tiny noise)
  helper functions: rmse, print_metrics, plot_pred_vs_actual, show_heatmap
"""

import numpy as np
import matplotlib.pyplot as plt

# ----------------------------
# Reproducibility & settings
# ----------------------------
SEED = 2025
rng = np.random.default_rng(SEED)

IMG_SIZE = 100
N_IMAGES = 1000

# Geometry
R0_BASE = 20.0     # inner (hollow) radius
R1_BASE = 40.0     # outer radius
CENTER   = (IMG_SIZE - 1) / 2.0  # 49.5 (sub‑pixel centre for symmetry)

# Variation & noise (small but meaningful)
RADIUS_JITTER_PX      = 2.0
CENTER_JITTER_PX      = 1.0
ELLIPTICITY_RANGE     = (0.98, 1.02)
PROFILE_GAMMA_RANGE   = (0.90, 1.10)   # 1.0 is linear fall‑off
SHADING_ALPHA_RANGE   = (0.05, 0.12)   # multiplicative, low‑frequency illumination
PIXEL_NOISE_STD       = 0.02           # additive pixel noise
BACKGROUND_OFFSET_MAX = 0.01           # small sensor offset per image
Y_RANGE               = (0.8, 1.2)     # density scale factor
Y_LABEL_NOISE_STD     = 0.01           # tiny measurement noise on labels

# ----------------------------
# Helper: metrics & plotting (re‑used in later cells)
# ----------------------------
def rmse(y_true, y_pred):
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

def print_metrics(split_name, y_true, y_pred):
    from sklearn.metrics import r2_score, mean_absolute_error
    r2  = r2_score(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    rms = rmse(y_true, y_pred)
    print(f"[{split_name}] R² = {r2:.4f} | MAE = {mae:.4f} | RMSE = {rms:.4f}")
    return r2, mae, rms

def plot_pred_vs_actual(y_actual, y_pred, title="Predicted vs Actual"):
    plt.figure(figsize=(5, 4))
    plt.scatter(y_actual, y_pred, s=14, alpha=0.75)
    lo = min(float(np.min(y_actual)), float(np.min(y_pred)))
    hi = max(float(np.max(y_actual)), float(np.max(y_pred)))
    pad = 0.02 * (hi - lo) if hi > lo else 0.01
    plt.plot([lo - pad, hi + pad], [lo - pad, hi + pad], linestyle="--", linewidth=1.0)
    plt.xlabel("Actual y")
    plt.ylabel("Predicted y")
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

def show_heatmap(arr2d, title, cmap=None, with_colorbar=True):
    plt.figure(figsize=(5, 4))
    plt.imshow(arr2d, cmap=cmap)
    if with_colorbar:
        plt.colorbar(fraction=0.046, pad=0.04)
    plt.title(title)
    plt.axis("off")
    plt.tight_layout()
    plt.show()

# ----------------------------
# Coordinates (constant fields)
# ----------------------------
y_coords, x_coords = np.indices((IMG_SIZE, IMG_SIZE))
X0 = x_coords - CENTER
Y0 = y_coords - CENTER
XN = X0 / R1_BASE
YN = Y0 / R1_BASE

def make_low_frequency_shading(alpha, rng):
    """
    Smooth multiplicative shading:
      shade(x,y) = 1 + alpha * P2(x,y),
    where P2 is a random quadratic in (XN, YN), normalised to zero‑mean and unit‑std.
    """
    c = rng.normal(0.0, 1.0, size=6)
    field = (c[0] + c[1]*XN + c[2]*YN + c[3]*XN*YN + c[4]*XN*XN + c[5]*YN*YN)
    field = field - field.mean()
    std = field.std()
    if std > 1e-8:
        field = field / std
    return 1.0 + alpha * field

def generate_one_image(rng):
    """
    Create one hollow‑sphere image with small geometric & illumination variations.
    Returns:
      img    : (H,W) float32 in [0,1]
      y_true : noise‑free density factor
      y_obs  : observed (y_true + tiny label noise)
    """
    # Jittered geometry
    R0 = R0_BASE + rng.uniform(-RADIUS_JITTER_PX, +RADIUS_JITTER_PX)
    R1 = R1_BASE + rng.uniform(-RADIUS_JITTER_PX, +RADIUS_JITTER_PX)
    if R1 <= R0 + 8.0:  # ensure sensible shell thickness
        R1 = R0 + 8.0

    cx = CENTER + rng.uniform(-CENTER_JITTER_PX, +CENTER_JITTER_PX)
    cy = CENTER + rng.uniform(-CENTER_JITTER_PX, +CENTER_JITTER_PX)

    sx = rng.uniform(*ELLIPTICITY_RANGE)  # mild ellipticity
    sy = rng.uniform(*ELLIPTICITY_RANGE)

    gamma = rng.uniform(*PROFILE_GAMMA_RANGE)  # fall‑off shape

    # Elliptical radius from (cx,cy)
    dx = (x_coords - cx) / sx
    dy = (y_coords - cy) / sy
    re = np.sqrt(dx*dx + dy*dy)

    # Ring profile: value 1 at inner boundary → 0 at outer boundary, shaped by gamma
    img = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.float64)
    shell = (re > R0) & (re <= R1)
    t = (re[shell] - R0) / (R1 - R0)  # shell pixels only: 0 at inner boundary, 1 at outer
    img[shell] = (1.0 - t) ** gamma

    # Low‑frequency multiplicative shading
    alpha = rng.uniform(*SHADING_ALPHA_RANGE)
    shading = make_low_frequency_shading(alpha, rng)
    img = img * shading

    # Apply density scaling (true label)
    y_true = float(rng.uniform(*Y_RANGE))
    img = y_true * img

    # Sensor offset + pixel noise
    offset = rng.uniform(0.0, BACKGROUND_OFFSET_MAX)
    img = img + offset
    img = img + rng.normal(0.0, PIXEL_NOISE_STD, size=img.shape)

    # Clip to [0,1]
    img = np.clip(img, 0.0, 1.0).astype(np.float32)

    # Observed label with tiny measurement noise
    y_obs = y_true + float(rng.normal(0.0, Y_LABEL_NOISE_STD))
    return img, y_true, y_obs

def generate_dataset(n_images, rng):
    images = np.empty((n_images, IMG_SIZE, IMG_SIZE), dtype=np.float32)
    y_true = np.empty(n_images, dtype=np.float32)
    y_obs  = np.empty(n_images, dtype=np.float32)
    for i in range(n_images):
        img, yt, yo = generate_one_image(rng)
        images[i] = img
        y_true[i] = yt
        y_obs[i]  = yo
    return images, y_true, y_obs

# ----------------------------
# Generate & summarise
# ----------------------------
images, y_true_all, y_obs_all = generate_dataset(N_IMAGES, rng)

print("Dataset summary:")
print(f"  images.shape = {images.shape} (float32, values in [0,1])")
print(f"  y_true: mean={y_true_all.mean():.4f}, std={y_true_all.std():.4f}, "
      f"min={y_true_all.min():.4f}, max={y_true_all.max():.4f}")
print(f"  y_obs :  mean={y_obs_all.mean():.4f},  std={y_obs_all.std():.4f},  "
      f"min={y_obs_all.min():.4f},  max={y_obs_all.max():.4f}")

# ----------------------------
# Visual checks
# ----------------------------
# Grid of sample images
def show_image_grid(images, y_true, y_obs, n_show=12, title="Sample simulated images"):
    n_show = int(min(n_show, images.shape[0]))
    idxs = rng.choice(images.shape[0], size=n_show, replace=False)
    cols = int(np.ceil(np.sqrt(n_show)))
    rows = int(np.ceil(n_show / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(1.9*cols, 1.9*rows))
    axes = np.atleast_1d(axes).ravel()
    for ax, i in zip(axes, idxs):
        ax.imshow(images[i], cmap="gray", vmin=0.0, vmax=1.0)
        ax.set_title(f"y_true={y_true[i]:.3f}\ny_obs={y_obs[i]:.3f}", fontsize=8)
        ax.axis("off")
    for k in range(n_show, len(axes)):  # hide any spare panels
        axes[k].axis("off")
    plt.suptitle(title, fontsize=12)
    plt.tight_layout()
    plt.show()

show_image_grid(images, y_true_all, y_obs_all, n_show=12,
                title="Synthetic hollow‑sphere X‑ray images")

# Label histogram
plt.figure(figsize=(5.5, 4))
plt.hist(y_obs_all, bins=20, edgecolor="black", alpha=0.85)
plt.xlabel("Observed y")
plt.ylabel("Count")
plt.title("Label distribution (y_obs)")
plt.tight_layout()
plt.show()

# Mean image (across the full dataset)
mean_img = images.mean(axis=0)
show_heatmap(mean_img, title="Mean image (all samples)", cmap="gray", with_colorbar=True)

print("\nAll arrays are now in memory: images, y_true_all, y_obs_all.")



# Linear and Ridge regression

**Purpose of this cell**  
Fit two classical linear baselines on the synthetic images: (i) Ordinary Least Squares (OLS) on flattened pixels and (ii) Ridge regression with feature standardisation and cross‑validated regularisation. Report metrics on Train/Validation/Test splits and visualise both predictions and coefficient maps.

---

## 1) Data layout and splits

- Each image has shape $H\times W=100\times 100$ and is flattened into a vector of length $p=10{,}000$. Stacking all images yields a design matrix $X\in\mathbb{R}^{n\times p}$ and a target vector $y\in\mathbb{R}^{n}$.
- The cell creates fixed splits (seeded): **Train 64%**, **Validation 16%**, **Test 20%**.
- Note the scale: typically $p \gg n_{\text{train}}$ (e.g. $10{,}000$ features vs about $640$ training samples). This matters for OLS.

---

## 2) Ordinary Least Squares (OLS)

**Objective.** With coefficients $w\in\mathbb{R}^{p}$ and intercept $b\in\mathbb{R}$,
$$
\min_{w,b}\ \frac{1}{2}\,\lVert Xw + b\mathbf{1} - y\rVert_2^2.
$$

Let $\tilde X=[X\ \ \mathbf{1}]$ and $\tilde w=[w^\top\ \ b]^\top$. The SVD‑based solver returns the minimum‑norm solution
$$
\tilde w^\star=\tilde X^{+}y,
$$
where $\tilde X^{+}$ is the Moore–Penrose pseudoinverse.

**Why training $R^2$ can be $\approx 1$ when $p\gg n$.**  
When $\operatorname{rank}(\tilde X)=n_{\text{train}}$, the fitted predictions on the training set satisfy
$$
\hat y=\tilde X\tilde w^\star=\tilde X\tilde X^{+}y=Hy,
$$
with $H=\tilde X\tilde X^{+}$ the idempotent **hat matrix** (a projector onto the column space of $\tilde X$; not to be confused with the image height). When $\tilde X$ has full row rank, that column space is all of $\mathbb{R}^{n_{\text{train}}}$, so $H=I$ and the fit interpolates $y$: training error is zero up to round-off and $R^2_{\text{train}}\approx 1$. This indicates capacity, not necessarily genuine signal.
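
The interpolation argument can be checked numerically. The sketch below is a toy illustration, not the notebook's fit: it uses small made-up sizes (`n=20`, `p=100`), targets that are pure noise, and `np.linalg.pinv` as the minimum-norm solver.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 20, 100                        # toy sizes with p > n (the notebook has ~640 vs 10,000)
X = rng.normal(size=(n, p))
y = rng.normal(size=n)                # pure noise: there is no signal to find

X_aug = np.hstack([X, np.ones((n, 1))])   # append the intercept column
X_pinv = np.linalg.pinv(X_aug)            # Moore-Penrose pseudoinverse
w_aug = X_pinv @ y                        # minimum-norm least-squares solution
hat = X_aug @ X_pinv                      # projector onto the column space of X_aug

train_residual = float(np.max(np.abs(X_aug @ w_aug - y)))
hat_is_identity = bool(np.allclose(hat, np.eye(n), atol=1e-8))
print(f"max |residual| = {train_residual:.2e}, hat == I: {hat_is_identity}")
```

Even though `y` carries no information about `X`, the training residual vanishes, which is why a near-perfect training $R^2$ says nothing about generalisation in this regime.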

**Plots and maps.** The cell shows Predicted‑vs‑Actual scatter plots and reshapes the learned $w$ to $H\times W$ to display a coefficient heatmap.

---

## 3) Ridge regression (with standardisation and cross‑validation)

**Why standardise?** Ridge penalises coefficient magnitudes, so each feature is scaled to comparable units:
$$
Z_{ij}=\frac{X_{ij}-\mu_j}{\sigma_j}.
$$

**Objective and solution.**
$$
\min_{w}\ \lVert Zw-y\rVert_2^2+\lambda\lVert w\rVert_2^2
\quad\Longrightarrow\quad
w_\lambda=(Z^\top Z+\lambda I)^{-1}Z^\top y.
$$
The intercept is fitted separately (equivalently, $Z$ and $y$ are centred).
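
As a sanity check on the closed form $w_\lambda=(Z^\top Z+\lambda I)^{-1}Z^\top y$, the sketch below compares it with the algebraically equivalent dual form $Z^\top(ZZ^\top+\lambda I)^{-1}y$, which only needs an $n\times n$ solve. The sizes, the penalty value and the random data are illustrative assumptions; this is not how the notebook's `Ridge` estimator is invoked.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, lam = 30, 200, 1.5              # toy sizes and penalty; all values are illustrative
X = rng.normal(loc=0.5, scale=2.0, size=(n, p))
y = rng.normal(size=n)

# Standardise the features and centre the target (the intercept is handled separately).
Z = (X - X.mean(axis=0)) / X.std(axis=0)
yc = y - y.mean()

# Primal form: solve a p x p system.
w_primal = np.linalg.solve(Z.T @ Z + lam * np.eye(p), Z.T @ yc)
# Dual form: solve an n x n system, cheaper when p >> n.
w_dual = Z.T @ np.linalg.solve(Z @ Z.T + lam * np.eye(n), yc)

max_diff = float(np.max(np.abs(w_primal - w_dual)))
print(f"max |w_primal - w_dual| = {max_diff:.2e}")
```

With $p=10{,}000$ pixels and a few hundred images, the dual system is far smaller than the primal one, which is one reason Ridge stays cheap in this setting.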

**Choosing $\lambda$ (alpha).**  
A 5‑fold cross‑validation over a log‑spaced grid $\lambda\in[10^{-6},10^3]$ is run **on the training split only**. The best $\lambda^\star$ by CV‑MSE is then refit on **Train+Val** and evaluated once on Test.

**Mapping coefficients back to raw pixel space.**  
Let $w^{\text{scaled}}$ be the weights learned on $Z$. With scaler statistics $(\mu_j,\sigma_j)$,
$$
w^{\text{raw}}_j=\frac{w^{\text{scaled}}_j}{\sigma_j},
\qquad
b^{\text{raw}}=b^{\text{scaled}}-\sum_{j=1}^{p}\frac{w^{\text{scaled}}_j\,\mu_j}{\sigma_j}.
$$
The $H\times W$ reshaped $w^{\text{raw}}$ gives an interpretable pixel‑space heatmap.
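
The mapping can be verified by checking that both parameterisations give identical predictions. In this sketch the "fitted" weights are random stand-ins and the data are synthetic, so it only demonstrates the algebra.

```python
import numpy as np

rng = np.random.default_rng(1)
n, p = 50, 8
X = rng.uniform(0.0, 1.0, size=(n, p))
mu, sigma = X.mean(axis=0), X.std(axis=0)
Z = (X - mu) / sigma

# Stand-ins for a model fitted in standardised space (random values, purely illustrative).
w_scaled = rng.normal(size=p)
b_scaled = 0.3

# Map back to raw feature space.
w_raw = w_scaled / sigma
b_raw = b_scaled - np.sum(w_scaled * mu / sigma)

pred_scaled = Z @ w_scaled + b_scaled
pred_raw = X @ w_raw + b_raw
max_gap = float(np.max(np.abs(pred_scaled - pred_raw)))
print(f"max prediction gap = {max_gap:.2e}")
```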

---

## 4) Metrics reported

For predictions $\hat y$ against targets $y$:
- **Coefficient of determination**
$$
R^2=1-\frac{\sum_i (y_i-\hat y_i)^2}{\sum_i (y_i-\bar y)^2},\qquad \bar y=\frac{1}{n}\sum_i y_i.
$$
- **Mean Absolute Error (MAE)**: $\ \text{MAE}=\frac{1}{n}\sum_i |y_i-\hat y_i|$.
- **Root Mean Squared Error (RMSE)**: $\ \text{RMSE}=\sqrt{\frac{1}{n}\sum_i (y_i-\hat y_i)^2}$.

**Noise ceiling for $R^2$.** Using the noise‑free labels $y^{\text{true}}$,
$$
R^2_{\text{ceiling}} = 1 - \frac{\sum_i (y^{\text{true}}_i - y^{\text{obs}}_i)^2}{\sum_i (y^{\text{obs}}_i - \bar y^{\text{obs}})^2},
$$
which estimates $1-\frac{\operatorname{Var}(\text{noise})}{\operatorname{Var}(y^{\text{obs}})}$ when the label noise has zero mean and is independent of $y^{\text{true}}$.
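
A worked example using the label settings from the data-generation cell ($y\sim\mathcal{U}[0.8,1.2]$, $\sigma_y=0.01$) gives a feel for the size of the ceiling. The helper name `r2_noise_ceiling` and the large sample size are assumptions made for this sketch.

```python
import numpy as np

def r2_noise_ceiling(y_true, y_obs):
    """R² reached by a hypothetical oracle that predicts the noise-free label exactly."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_obs = np.asarray(y_obs, dtype=np.float64)
    ss_noise = np.sum((y_true - y_obs) ** 2)
    ss_total = np.sum((y_obs - y_obs.mean()) ** 2)
    return float(1.0 - ss_noise / ss_total)

rng = np.random.default_rng(0)
n = 200_000                                      # large n so the sample estimate is stable
y_true = rng.uniform(0.8, 1.2, size=n)           # label range from the generation cell
y_obs = y_true + rng.normal(0.0, 0.01, size=n)   # label-noise std from the same cell

ceiling = r2_noise_ceiling(y_true, y_obs)
# Population value: Var(y_true) = 0.4**2 / 12 and Var(noise) = 0.01**2.
expected = 1.0 - 0.01**2 / (0.4**2 / 12 + 0.01**2)
print(f"empirical = {ceiling:.4f}, population = {expected:.4f}")
```

The population value is about 0.9926, so a test $R^2$ around 0.99 already sits close to what the label noise allows.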

---

## 5) Typical outcomes and interpretation

- **OLS (Train):** $R^2$ often near 1 due to interpolation when $p\gg n$. Validation/Test provide the honest picture.
- **Ridge:** usually lower Train $R^2$ but **better generalisation**; the $\ell_2$ penalty damps ill‑conditioned directions.
- **Coefficient maps:** high absolute weights concentrate on the ring, where the generative signal lives.
- **Negative $R^2$ on Val/Test:** means the model is worse than predicting the split mean; the tuned Ridge should avoid this unless the task is intentionally adversarial.

---

## 6) Practices encoded in the code

- Feature scaling and Ridge are in a single `Pipeline` to avoid **data leakage** inside CV.
- Hyper‑parameter selection uses **only the training split**; Test is never touched until the end.
- A fixed random seed makes the split reproducible.



In [None]:
"""
Linear & Ridge regression on the in‑memory dataset
--------------------------------------------------

What this cell does:

1) Builds Train / Validation / Test splits: 64% / 16% / 20%.
2) Fits an Ordinary Least Squares (OLS) linear regression on flattened pixels → y_obs.
   • Notes why OLS can achieve R²≈1 on the training set when p ≫ n (10,000 features vs ~640 samples).
3) Trains a tuned Ridge regression:
   • Pipeline(StandardScaler, Ridge)
   • Hyper‑parameter alpha selected by 5‑fold CV on the *training* split only.
   • Report Train & Val performance for the CV‑selected model.
   • Refit the final Ridge on Train+Val with the chosen alpha; report Test performance.
4) Prints metrics (R², MAE, RMSE) and draws Predicted vs Actual plots.
5) Shows coefficient heatmaps for OLS and final Ridge (mapped back to raw‑pixel space).

"""

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, KFold, GridSearchCV
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Sanity check: make sure Cell 1 was run
if "images" not in globals() or "y_obs_all" not in globals():
    raise RuntimeError("Cell 1 has not been run. Please run Cell 1 to create images/y_obs_all in memory.")

SEED = 2025
TEST_FRAC = 0.20
VAL_FRAC_WITHIN_TRAINVAL = 0.20  # 20% of the 80% → 16% overall
rng = np.random.default_rng(SEED)

N, H, W = images.shape
D = H * W

# Flatten images to (N, D)
X_all = images.reshape(N, -1).astype(np.float64)
y_all = y_obs_all.astype(np.float64)
y_true_all_local = y_true_all.astype(np.float64)  # for noise ceiling

# Split Train+Val vs Test
X_trainval, X_test, y_trainval, y_test, y_true_trainval, y_true_test = train_test_split(
    X_all, y_all, y_true_all_local, test_size=TEST_FRAC, random_state=SEED
)
# Split Train vs Val
X_train, X_val, y_train, y_val, y_true_train, y_true_val = train_test_split(
    X_trainval, y_trainval, y_true_trainval, test_size=VAL_FRAC_WITHIN_TRAINVAL, random_state=SEED
)

print("Split summary:")
print(f"  Train: {X_train.shape[0]} samples, {X_train.shape[1]} features")
print(f"  Val  : {X_val.shape[0]} samples")
print(f"  Test : {X_test.shape[0]} samples")
print(f"  Note: p (features) = {D:,},  n_train = {X_train.shape[0]:,}.  Here, p ≫ n.\n")

# ------------------------------------------------------------
# 1) OLS Linear Regression (may interpolate when p ≫ n)
# ------------------------------------------------------------
ols = LinearRegression()  # scikit‑learn uses an SVD‑based least‑squares solver
ols.fit(X_train, y_train)

# Predictions
yhat_tr_ols = ols.predict(X_train)
yhat_va_ols = ols.predict(X_val)
yhat_te_ols = ols.predict(X_test)

print("OLS performance (against observed y):")
print_metrics("Train (OLS)", y_train, yhat_tr_ols)
print_metrics("Val   (OLS)", y_val,   yhat_va_ols)
print_metrics("Test  (OLS)", y_test,  yhat_te_ols)

# Why can Train R² be ~1?
# With p (=10,000) >> n_train (~640), an OLS model has enough degrees of freedom to fit the training
# labels almost perfectly (interpolate), especially as y is largely a global scaling of the image.
# The SVD solution picks one of infinitely many solutions (minimum‑norm) that achieve near‑zero
# training error if the design matrix has rank n_train. This is normal in over‑parameterised linear models.

# Noise ceiling on Test (best achievable R² against *noisy* labels)
r2_ceiling_test = 1.0 - np.sum((y_true_test - y_test) ** 2) / np.sum((y_test - np.mean(y_test)) ** 2)
print(f"\n[Reference] Test‑set R² noise ceiling (label noise): {r2_ceiling_test:.4f}\n")

# Visual diagnostics — Predicted vs Actual
plot_pred_vs_actual(y_train, yhat_tr_ols, "Predicted vs Actual (Train, OLS)")
plot_pred_vs_actual(y_val,   yhat_va_ols, "Predicted vs Actual (Validation, OLS)")
plot_pred_vs_actual(y_test,  yhat_te_ols, "Predicted vs Actual (Test, OLS)")

# Coefficient heatmap (OLS)
coef_ols = ols.coef_.reshape(H, W)
show_heatmap(coef_ols, title="OLS: learned pixel weights", cmap=None, with_colorbar=True)

# ------------------------------------------------------------
# 2) Ridge Regression (tuned) — Pipeline(StandardScaler, Ridge)
# ------------------------------------------------------------
# Standardise features because Ridge penalises the magnitude of coefficients and should
# not be affected by arbitrary feature scales.
pipe = Pipeline([
    ("scaler", StandardScaler(with_mean=True, with_std=True)),
    ("ridge",  Ridge(random_state=SEED))
])

# Hyper‑parameter grid for alpha (L2 strength)
alphas = np.logspace(-6, 3, 25)  # 1e-6 … 1e3

cv = KFold(n_splits=5, shuffle=True, random_state=SEED)
param_grid = {"ridge__alpha": alphas}

grid = GridSearchCV(
    estimator=pipe,
    param_grid=param_grid,
    scoring="neg_mean_squared_error",
    cv=cv,
    n_jobs=-1,
    verbose=0,
    refit=True  # refit the best on the *training* split
)

print("Tuning Ridge (5‑fold CV on the training split)…")
grid.fit(X_train, y_train)

best_alpha = grid.best_params_["ridge__alpha"]
print(f"Best alpha (CV on Train): {best_alpha:.6g}")

# Evaluate the CV‑selected model on Train & Val
ridge_cv_model = grid.best_estimator_
yhat_tr_ridge = ridge_cv_model.predict(X_train)
yhat_va_ridge = ridge_cv_model.predict(X_val)

print("\nRidge (CV‑selected) performance:")
print_metrics("Train (Ridge CV)", y_train, yhat_tr_ridge)
print_metrics("Val   (Ridge CV)", y_val,   yhat_va_ridge)

# OPTIONAL: refit on Train+Val using the chosen alpha, then evaluate on Test
final_ridge = Pipeline([
    ("scaler", StandardScaler(with_mean=True, with_std=True)),
    ("ridge",  Ridge(alpha=best_alpha, random_state=SEED))
])
final_ridge.fit(np.vstack([X_train, X_val]), np.hstack([y_train, y_val]))

yhat_te_ridge = final_ridge.predict(X_test)
print()  # blank line before the Test block (keeps the newline out of the "[...]" label)
print_metrics("Test  (Ridge final)", y_test, yhat_te_ridge)

# Visual diagnostics: Predicted vs Actual (Train/Val use the CV-selected model, Test uses the final refit)
plot_pred_vs_actual(y_train, yhat_tr_ridge, "Predicted vs Actual (Train, Ridge CV model)")
plot_pred_vs_actual(y_val,   yhat_va_ridge, "Predicted vs Actual (Validation, Ridge CV model)")
plot_pred_vs_actual(y_test,  yhat_te_ridge, "Predicted vs Actual (Test, Ridge final model)")

# Coefficient heatmap (Ridge, mapped back to raw pixel space)
#  ŷ = intercept + Σ w_scaled_i * (x_i - mean_i)/std_i
#  ⇒ effective raw‑space weight: w_raw_i = w_scaled_i / std_i
#     effective raw intercept:  b_raw = intercept - Σ w_scaled_i * mean_i / std_i
scaler = final_ridge.named_steps["scaler"]
ridge  = final_ridge.named_steps["ridge"]
w_scaled = ridge.coef_.ravel()
w_raw = w_scaled / (scaler.scale_ + 1e-12)
coef_ridge_raw = w_raw.reshape(H, W)
show_heatmap(coef_ridge_raw, title="Ridge (final): pixel weights (raw‑space)", cmap=None, with_colorbar=True)

# Brief takeaways
print("\nTakeaways:")
print("  • With 10,000 features and ~640 training samples (p ≫ n), OLS can nearly interpolate the training labels,")
print("    often giving R²≈1 on Train. That’s expected and is a symptom of high capacity, not necessarily true signal.")
print("  • Ridge regularisation stabilises the solution, improves conditioning, and typically gives better generalisation.")
print("  • The learned weight maps concentrate on the ring, which is exactly where the signal lies.")
print("  • ‘Noise ceiling’ gives the best possible R² on Test when labels contain measurement noise;")
print("    if your model approaches it, you are close to optimal under the present noise level.")


# CNN regressor on hollow‑sphere images

**Purpose of this cell**  
Train a convolutional regressor that predicts the scalar target from each image. The network is designed to succeed on both the *easy* dataset (no star) and the *hard* dataset (with a tiny star on the shell), while remaining compatible with Deep SHAP.

---

## 1) Inputs, splits, and preprocessing

**Inputs.** Each image is a single‑channel array of shape $H\times W=100\times100$ with values in $[0,1]$. The target is a scalar $y$ per image (observed label), with a noise‑free companion $y^{\text{true}}$ used only for diagnostics later.

**Splits.** A fixed seed produces: Train 64%, Validation 16%, Test 20% (same recipe as the linear cell).

**Background subtraction.** For each image, the median of the four $k\times k$ corner patches (default $k=10$) estimates a background offset $b$. The corrected image uses
$$
I'(x,y)=\max\{\,I(x,y)-b,\,0\,\}.
$$
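
A minimal sketch of the corner-median correction follows. It assumes the median is taken over the pooled pixels of all four patches (the text would also be consistent with a median of four patch medians), and the helper names are hypothetical.

```python
import numpy as np

def corner_median_offset(img, k=10):
    """Hypothetical helper: pooled median of the four k x k corner patches."""
    corners = np.concatenate([
        img[:k, :k].ravel(), img[:k, -k:].ravel(),
        img[-k:, :k].ravel(), img[-k:, -k:].ravel(),
    ])
    return float(np.median(corners))

def subtract_background(img, k=10):
    """I'(x, y) = max(I(x, y) - b, 0) with b estimated from the corners."""
    b = corner_median_offset(img, k)
    return np.maximum(img - b, 0.0), b

# Toy image: flat offset of 0.05 with a bright square in the middle.
img = np.full((100, 100), 0.05, dtype=np.float32)
img[40:60, 40:60] = 0.80
corrected, b = subtract_background(img)
print(f"estimated offset b = {b:.3f}, corrected centre = {corrected[50, 50]:.3f}")
```

The corners are a safe place to look because they lie well outside the outer radius $R_1\approx 40$ px, so they contain only offset and noise.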

**Global normalisation (train statistics).** Let $\mu=\operatorname{mean}(I'_{\text{train}})$ and $\sigma=\operatorname{std}(I'_{\text{train}})$. Normalise every split with
$$
\tilde I(x,y)=\operatorname{clip}\!\left(\frac{I'(x,y)-\mu}{\sigma},\, -8,\ 8\right).
$$
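
The key point is that $\mu$ and $\sigma$ are fitted once on the training split and then reused. The sketch below shows that pattern on random stand-in arrays; the function names and the `eps` guard against a zero standard deviation are assumptions, not taken from the notebook.

```python
import numpy as np

def fit_normaliser(train_images, eps=1e-8):
    """Global mean and std over every pixel of the training split only."""
    mu = float(train_images.mean())
    sigma = float(train_images.std())
    return mu, max(sigma, eps)            # eps guard is an added assumption

def apply_normaliser(images, mu, sigma, clip=8.0):
    """Z-score with train statistics, then clip to [-clip, clip]."""
    return np.clip((images - mu) / sigma, -clip, clip).astype(np.float32)

rng = np.random.default_rng(0)
train = rng.uniform(0.0, 1.0, size=(64, 100, 100)).astype(np.float32)   # stand-in data
test = rng.uniform(0.0, 1.0, size=(16, 100, 100)).astype(np.float32)

mu, sigma = fit_normaliser(train)
train_n = apply_normaliser(train, mu, sigma)
test_n = apply_normaliser(test, mu, sigma)     # reuses train statistics, never refits
print(f"train mean = {train_n.mean():.3f}, train std = {train_n.std():.3f}")
```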

**CoordConv channels.** Three coordinate maps, with $C=(H-1)/2$, are concatenated to the image channel:
- $x_{\text{lin}}=(x-c_x)/C$,
- $y_{\text{lin}}=(y-c_y)/C$,
- $r=\sqrt{x_{\text{lin}}^2+y_{\text{lin}}^2}$ (rescaled to $[0,1]$).

These channels give the network absolute position and radial distance, which plain convolutions cannot see on their own; the shell and the star both live on a ring in $r$.
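
The coordinate maps can be built once and stacked onto every image. This sketch assumes that $r$ is rescaled by dividing by its maximum (the corner value $\sqrt 2$); the function name is hypothetical.

```python
import numpy as np

def make_coord_channels(h=100, w=100):
    """Return x_lin, y_lin and r as (h, w) float32 maps.

    Assumptions: both axes are divided by C = (h - 1) / 2, and r is divided
    by its maximum so that it spans [0, 1].
    """
    ys, xs = np.indices((h, w)).astype(np.float32)
    c = (h - 1) / 2.0
    cx, cy = (w - 1) / 2.0, (h - 1) / 2.0
    x_lin = (xs - cx) / c
    y_lin = (ys - cy) / c
    r = np.sqrt(x_lin**2 + y_lin**2)
    r = r / r.max()
    return x_lin.astype(np.float32), y_lin.astype(np.float32), r.astype(np.float32)

x_lin, y_lin, r = make_coord_channels()
image = np.zeros((100, 100), dtype=np.float32)        # placeholder image channel
stacked = np.stack([image, x_lin, y_lin, r], axis=0)  # (4, H, W) network input
print(stacked.shape, float(x_lin.min()), float(x_lin.max()), float(r.max()))
```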

**Data augmentation.** Random horizontal and vertical flips. Arrays are re‑made contiguous after flipping to avoid “negative stride” issues when converting to PyTorch tensors.
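
The stride issue is easy to reproduce: slicing with `::-1` returns a view whose stride is negative, and `torch.from_numpy` rejects such arrays. A minimal sketch, assuming each flip is applied with probability 0.5 (the text does not state the probability) and a hypothetical function name:

```python
import numpy as np

def random_flip(x, rng):
    """Flip a (C, H, W) array horizontally and/or vertically, each with probability 0.5."""
    if rng.random() < 0.5:
        x = x[:, :, ::-1]           # horizontal flip: a view with a negative stride
    if rng.random() < 0.5:
        x = x[:, ::-1, :]           # vertical flip
    # Guarantee a C-contiguous, positively strided array (copies only when needed).
    return np.ascontiguousarray(x)

rng = np.random.default_rng(0)
x = np.arange(4 * 3 * 3, dtype=np.float32).reshape(4, 3, 3)
flipped_view = x[:, :, ::-1]
out = random_flip(x, rng)
print(flipped_view.strides, out.strides, out.flags["C_CONTIGUOUS"])
```

Flipping all four channels together keeps the coordinate maps aligned with the flipped image.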

---

## 2) Target standardisation

The network predicts a **z‑scored** target on the training split:
$$
z=\frac{y-\mu_y}{\sigma_y},\qquad \hat z=f_\theta(\tilde I, x_{\text{lin}}, y_{\text{lin}}, r).
$$
Predictions are mapped back to label space with
$$
\hat y=\sigma_y\,\hat z+\mu_y.
$$
This stabilises optimisation and aligns scales across experiments.
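
A short round-trip sketch makes the bookkeeping explicit. The labels are random stand-ins drawn from the dataset's range, and `rmse_z` is a hypothetical value used only to show the unit conversion.

```python
import numpy as np

rng = np.random.default_rng(0)
y_train = rng.uniform(0.8, 1.2, size=640)      # stand-in labels in the dataset's range
y_test = rng.uniform(0.8, 1.2, size=200)

mu_y, sigma_y = float(y_train.mean()), float(y_train.std())   # train split only

def to_z(y):
    """Label space -> z-space."""
    return (y - mu_y) / sigma_y

def from_z(z):
    """z-space -> label space."""
    return sigma_y * z + mu_y

z_test = to_z(y_test)
round_trip_error = float(np.max(np.abs(from_z(z_test) - y_test)))

# An RMSE measured in z-space converts to label space by multiplying with sigma_y.
rmse_z = 0.10                                   # hypothetical value
rmse_y = sigma_y * rmse_z
print(f"round-trip error = {round_trip_error:.2e}, RMSE_y = {rmse_y:.4f}")
```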

---

## 3) Dataset and loaders

A `Dataset` returns triples $(X, z, y)$ where $X\in\mathbb{R}^{4\times H\times W}$ contains `[image, x, y, r]`.  
Loaders use a moderate batch size (128 on CUDA/MPS, 64 on CPU).

---

## 4) Network architecture (SHAP‑compatible)

The model exposes `.features`, `.gap`, and `.head` with **exactly two `Linear` layers** in `.head`. All activations are non‑in‑place (`ReLU(inplace=False)`), and the forward method returns a cloned tensor to keep Deep SHAP happy.

### 4.1 Fixed image stem (edge/contrast helpers)

From the image channel only, two fixed maps are computed:
- Sobel gradient magnitude $G=\sqrt{(I_x)^2+(I_y)^2}$,
- Local high‑pass $H=I-\text{mean3}(I)$,
then concatenated: `[image, x, y, r, G, H]` → **6 channels**. These maps are part of the graph (differentiable) but non‑trainable, so external inputs remain 4‑channel.

### 4.2 Ring priors and “starness” branch

From the radius map $r$, three Gaussian bands act as **priors** for the shell:
$$
\text{ring}(r;r_0,\sigma)=\exp\!\left(-\tfrac12\Big(\tfrac{r-r_0}{\sigma}\Big)^2\right).
$$
- Mid‑shell: $r_0\approx0.435,\ \sigma\approx0.060$  
- Inner edge: $r_0\approx0.283,\ \sigma\approx0.040$  
- Outer edge: $r_0\approx0.566,\ \sigma\approx0.040$
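
The band centres can be related back to pixel radii. The sketch below assumes the radius map is the CoordConv $r$ divided by its maximum, so one unit of $r$ corresponds to $C\sqrt2\approx 70$ px; under that assumption the three centres land close to 20, 30 and 40 px.

```python
import numpy as np

def ring(r, r0, sigma):
    """Gaussian band centred on radius r0 with width sigma."""
    return np.exp(-0.5 * ((r - r0) / sigma) ** 2)

# Radius map rescaled to [0, 1]; dividing by the maximum is an assumption.
h = w = 100
ys, xs = np.indices((h, w))
c = (h - 1) / 2.0
r = np.hypot((xs - c) / c, (ys - c) / c)
r = r / r.max()

ring_mid = ring(r, 0.435, 0.060)
edge_inner = ring(r, 0.283, 0.040)
edge_outer = ring(r, 0.566, 0.040)

# Convert the band centres back to pixels to compare with the shell radii.
px_per_unit = c * np.sqrt(2.0)
for name, r0 in [("inner", 0.283), ("mid", 0.435), ("outer", 0.566)]:
    print(f"{name}: r0 = {r0:.3f} -> {r0 * px_per_unit:.1f} px")
```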

A lightweight starness branch runs two $3\times3$ convolutions with ReLU on the **gated image** $I \times \text{ring}_{\text{mid}}$ to produce a `star_map`. A broadcast global‑mean map $g(x,y)=\operatorname{mean}(I)$ supplies scene‑wide brightness context. Concatenation yields **11 channels**:
```
[image, x, y, r, G, H, ring_mid, edge_inner, edge_outer, star_map, gmean]
```

### 4.3 Convolutional trunk

A sequence of Conv → ReLU → small residual blocks, with two max‑pools:
- 100→50 and 50→25 spatial downsamples.
- No BatchNorm/GroupNorm (avoids instability on MPS, and keeps the graph simple for SHAP).

### 4.4 Global pooling that preserves tiny cues

`.gap` is a small custom module:
- Global **mean** pooling per channel,
- Concatenated with global **max** pooling per channel,
so a tensor with $C$ channels becomes $2C$. Formally, for channel $k$,
$$
p^{\text{avg}}_k=\frac{1}{HW}\sum_{i,j} A_{kij},\qquad
p^{\text{max}}_k=\max_{i,j} A_{kij},\qquad
p=\big[p^{\text{avg}},\,p^{\text{max}}\big].
$$
The max path ensures that a small, bright star can dominate the pooled descriptor even if its area is tiny.

### 4.5 Head (two `Linear` layers)

```
Flatten → Linear(2C,128) → ReLU → Dropout(0.10) → Linear(128,1)
```
This outputs $\hat z$ (a scalar). Returning `x.clone()` in `forward` avoids autograd view/in‑place quirks that sometimes trip Deep SHAP.
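
The pooling and head described above could look roughly like the following PyTorch sketch. The class name `ConcatPool` appears in the text, but this implementation, the `make_head` helper and the trunk width `channels = 32` are assumptions for illustration.

```python
import torch
from torch import nn

class ConcatPool(nn.Module):
    """Concatenate global mean and global max pooling: (B, C, H, W) -> (B, 2C, 1, 1)."""
    def forward(self, x):
        avg = x.mean(dim=(2, 3), keepdim=True)
        mx = x.amax(dim=(2, 3), keepdim=True)
        return torch.cat([avg, mx], dim=1)

def make_head(channels):
    """Head with exactly two Linear layers and a non-in-place ReLU."""
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(2 * channels, 128),
        nn.ReLU(inplace=False),
        nn.Dropout(0.10),
        nn.Linear(128, 1),
    )

channels = 32                          # hypothetical trunk width; the source does not state C
gap, head = ConcatPool(), make_head(channels).eval()

feat = torch.zeros(2, channels, 25, 25)
feat[0, 0, 12, 12] = 5.0               # a single bright activation, like a tiny star
pooled = gap(feat)
with torch.no_grad():
    z_hat = head(pooled).clone()       # clone, mirroring the forward pass described above
print(pooled.shape, z_hat.shape)
print(float(pooled[0, 0, 0, 0]), float(pooled[0, channels, 0, 0]))
```

For the single bright pixel, the mean path sees only $5/625=0.008$ while the max path sees the full $5.0$, which is the effect the concatenated pooling is meant to preserve.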

---

## 5) Loss, optimisation, and schedules

**Loss (MSE in z‑space).**
$$
\mathcal{L}(\theta)=\frac{1}{N}\sum_{i=1}^{N}\big(\hat z_i - z_i\big)^2.
$$

**Optimiser and schedule.** AdamW with weight decay $2\times10^{-4}$, initial learning rate $6\times10^{-4}$, cosine annealing to $10^{-5}$.

**Regularisation and stability.**
- Gradient clipping at unit norm.
- Early stopping on validation loss (patience 20 epochs).
- Non‑finite‑loss guard: if a batch yields non‑finite loss (rare on MPS), halve the learning rate and skip that step.
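
The optimiser settings and the non-finite guard can be sketched as below. The tiny stand-in model, the random batch and `EPOCHS = 50` are assumptions (the text gives no epoch count), and early stopping is omitted to keep the example short.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 1))  # stand-in model
EPOCHS = 50                                     # hypothetical; not stated in the text

optimiser = torch.optim.AdamW(model.parameters(), lr=6e-4, weight_decay=2e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=EPOCHS, eta_min=1e-5)
loss_fn = nn.MSELoss()

def train_step(xb, zb):
    """One optimisation step with the non-finite guard and unit-norm clipping."""
    optimiser.zero_grad()
    loss = loss_fn(model(xb).squeeze(1), zb)
    if not torch.isfinite(loss):
        for group in optimiser.param_groups:    # halve the learning rate and skip the step
            group["lr"] *= 0.5
        return None
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimiser.step()
    return float(loss)

xb, zb = torch.randn(32, 16), torch.randn(32)   # one random batch in z-space
losses = []
for epoch in range(EPOCHS):
    losses.append(train_step(xb, zb))
    scheduler.step()                            # one cosine step per epoch
final_lr = optimiser.param_groups[0]["lr"]
print(f"first loss = {losses[0]:.4f}, last loss = {losses[-1]:.4f}, final lr = {final_lr:.2e}")
```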

---

## 6) Metrics and plots

After restoring the best validation checkpoint, the cell reports for Train/Val/Test:
- $R^2=1-\dfrac{\sum(y-\hat y)^2}{\sum(y-\bar y)^2}$,
- MAE and RMSE in label space.  
The “noise ceiling” estimate on Test quantifies the best possible $R^2$ given the label noise.

Two diagnostic plot types are produced:
1) Training curves (loss vs epoch).  
2) Predicted‑vs‑Actual scatter, with the identity line for reference.

A small panel shows Original image → Background‑subtracted input → textual triplet $(y^{\text{true}}, y^{\text{obs}}, \hat y)$.

---

## 7) Why this CNN works on both datasets

- **Easy (no star):** The target is largely a global density scale. The global‑mean map and the average‑pool path allow the model to capture this quickly. Ring priors and edge‑aware features encourage a focus on the shell rather than the background.
- **Hard (with star):** The star is a compact, high‑contrast cue located on the shell. The starness branch creates a concentrated activation where a star‑like pattern aligns with the mid‑ring. The **max** path in ConcatPool ensures this activation can drive the prediction even with tiny area. CoordConv provides the geometry to separate inner/outer edges from mid‑ring content, improving localisation.

---

## 8) SHAP compatibility and later interpretability

- The class exposes `.features`, `.gap`, and `.head`, with exactly **two** `Linear` layers in the head, matching downstream SHAP assumptions.
- All activations are non‑in‑place. The forward pass returns a clone.  
- Deep SHAP is typically wrapped in a small compatibility class that replaces `nn.Flatten` with a functional flatten; additivity checks are disabled when needed.

This allows downstream analyses to: select a good baseline, draw overlays of SHAP attributions, compute deletion curves, and run component‑aware bounding boxes that verify whether the model locks onto the star.

---

## 9) Practical checks and troubleshooting

- If validation loss becomes `NaN` on MPS, reduce the learning rate and verify that inputs are finite after normalisation. The guard in the training loop already performs a back‑off.
- Confirm that normalisation reuses **training** statistics ($\mu$, $\sigma$) for all splits; otherwise leakage or distribution shift can degrade generalisation. Background subtraction is computed per image from its own corners, so it needs no fitted statistics.
- When porting to a different geometry, retune the ring‑prior centres and widths; the general form
  $$\text{ring}(r;r_0,\sigma)=\exp\!\left(-\tfrac12\big((r-r_0)/\sigma\big)^2\right)$$
  remains the same.
- To emphasise microscopic cues further, bias the pooling by increasing the weight of the max branch (e.g., concatenate `[mean, max, max]`) and adjust the first Linear layer shape accordingly.

---

## 10) Take‑home message

The architecture deliberately separates three responsibilities:
1) **Global brightness** (handled by average pooling and the global‑mean map),
2) **Shell geometry** (handled by CoordConv radius and ring priors),
3) **Tiny local anomalies** such as a star (handled by the starness branch and the max‑pool path).

This separation makes the network robust, interpretable with SHAP, and effective across both easy and hard variants of the dataset.


In [None]:
# ============================================================
# Cell 3: CNN that learns global density + shell edges + a tiny star, if present
# (SHAP‑compatible, M1‑safe, robust training)
# ============================================================

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# ----- Preconditions
if "images" not in globals() or "y_true_all" not in globals() or "y_obs_all" not in globals():
    raise RuntimeError("Please run Cell 1 (data generation) first.")
torch.set_grad_enabled(True)

# ----- Device (M1 → MPS if available)
if torch.cuda.is_available():
    DEVICE = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# ----- Repro
SEED = 2025
np.random.seed(SEED); torch.manual_seed(SEED)
if DEVICE == "cuda": torch.cuda.manual_seed_all(SEED)

# ----- Split (same recipe as Cell 2)
TEST_FRAC = 0.20
VAL_FRAC_WITHIN_TRAINVAL = 0.20

N, H, W = images.shape
CENTER = (H - 1) / 2.0

X_all = images
y_all = y_obs_all.astype(np.float32)
y_true_all_local = y_true_all.astype(np.float32)

X_trainval, X_test, y_trainval, y_test, y_true_trainval, y_true_test = train_test_split(
    X_all, y_all, y_true_all_local, test_size=TEST_FRAC, random_state=SEED
)
X_train, X_val, y_train, y_val, y_true_train, y_true_val = train_test_split(
    X_trainval, y_trainval, y_true_trainval, test_size=VAL_FRAC_WITHIN_TRAINVAL, random_state=SEED
)

print("Split summary:")
print(f"  Train: {X_train.shape[0]}  | Val: {X_val.shape[0]}  | Test: {X_test.shape[0]}")

# ----- Preprocessing (corner‑median background → train‑stats normalisation)
ys, xs = np.indices((H, W))

def estimate_background_offset(img: np.ndarray, k: int = 10) -> float:
    patches = [img[:k, :k], img[:k, -k:], img[-k:, :k], img[-k:, -k:]]
    return float(np.median(np.concatenate([p.ravel() for p in patches])))

def subtract_background(X: np.ndarray) -> np.ndarray:
    Xp = np.empty_like(X)
    for i, img in enumerate(X):
        bg = estimate_background_offset(img, k=10)
        im = img - bg
        im[im < 0.0] = 0.0
        Xp[i] = im
    return Xp

X_train_s = subtract_background(X_train)
X_val_s   = subtract_background(X_val)
X_test_s  = subtract_background(X_test)

pix_mean = float(X_train_s.mean())
pix_std  = float(X_train_s.std() + 1e-8)

def norm_clamp(Xs):
    Xn = (Xs - pix_mean) / pix_std
    Xn = np.clip(Xn, -8.0, 8.0).astype(np.float32)
    return Xn

X_train_n = norm_clamp(X_train_s)
X_val_n   = norm_clamp(X_val_s)
X_test_n  = norm_clamp(X_test_s)

# CoordConv channels (exported for SHAP)
x_lin = (xs - CENTER) / CENTER
y_lin = (ys - CENTER) / CENTER
r_map = np.sqrt(x_lin**2 + y_lin**2)
r_map = r_map / (r_map.max() + 1e-12)

coord_tensor = torch.from_numpy(np.stack([x_lin, y_lin, r_map], axis=0).astype(np.float32))
coord_hflip  = torch.flip(coord_tensor, dims=[2])
coord_vflip  = torch.flip(coord_tensor, dims=[1])
coord_bflip  = torch.flip(coord_tensor, dims=[1, 2])

# ----- Target scaling (z‑score on TRAIN)
y_mean = float(np.mean(y_train))
y_std  = float(np.std(y_train) + 1e-8)
def to_z(y):   return ((y - y_mean) / y_std).astype(np.float32)
def from_z(z): return z * y_std + y_mean

z_train = to_z(y_train); z_val = to_z(y_val); z_test = to_z(y_test)

# ----- Augmentation (flips; always contiguous → no negative‑stride errors)
def random_flip_image_and_coords(im: np.ndarray):
    hflip = (np.random.rand() < 0.5)
    vflip = (np.random.rand() < 0.5)
    if hflip: im = im[:, ::-1]
    if vflip: im = im[::-1, :]
    im = np.ascontiguousarray(im, dtype=np.float32)
    if   hflip and  vflip: coord_use = coord_bflip
    elif hflip and not vflip: coord_use = coord_hflip
    elif not hflip and vflip: coord_use = coord_vflip
    else: coord_use = coord_tensor
    return im, coord_use

class ImageRegDataset(Dataset):
    def __init__(self, Xn, z, y, train: bool):
        self.Xn = Xn; self.z = z; self.y = y; self.train = train
    def __len__(self): return self.Xn.shape[0]
    def __getitem__(self, i):
        im = self.Xn[i]
        if self.train:
            im, coord_use = random_flip_image_and_coords(im)
        else:
            im = np.ascontiguousarray(im, dtype=np.float32)
            coord_use = coord_tensor
        x_img = torch.from_numpy(im).unsqueeze(0)              # (1,H,W)
        x     = torch.cat([x_img, coord_use], dim=0)           # (4,H,W)
        z     = torch.tensor(self.z[i:i+1], dtype=torch.float32)
        y     = torch.tensor(self.y[i:i+1], dtype=torch.float32)
        return x, z, y

BATCH_SIZE = 128 if DEVICE in ("cuda", "mps") else 64
PIN = (DEVICE == "cuda")
train_loader = DataLoader(ImageRegDataset(X_train_n, z_train, y_train, train=True),
                          batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN, num_workers=0)
val_loader   = DataLoader(ImageRegDataset(X_val_n,   z_val,   y_val,   train=False),
                          batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN, num_workers=0)
test_loader  = DataLoader(ImageRegDataset(X_test_n,  z_test,  y_test,  train=False),
                          batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN, num_workers=0)

# ----- Fixed helper maps (inside the net; external input stays 4‑ch)
class FixedImageStem(nn.Module):
    """Sobel grad‑mag + local contrast from image channel."""
    def __init__(self):
        super().__init__()
        sobel_x = torch.tensor([[1,0,-1],[2,0,-2],[1,0,-1]], dtype=torch.float32).view(1,1,3,3)
        sobel_y = sobel_x.transpose(2,3).contiguous()
        mean3   = torch.full((1,1,3,3), 1.0/9.0, dtype=torch.float32)
        self.conv_sx = nn.Conv2d(1,1,3,padding=1,bias=False)
        self.conv_sy = nn.Conv2d(1,1,3,padding=1,bias=False)
        self.blur3   = nn.Conv2d(1,1,3,padding=1,bias=False)
        with torch.no_grad():
            self.conv_sx.weight.copy_(sobel_x)
            self.conv_sy.weight.copy_(sobel_y)
            self.blur3.weight.copy_(mean3)
        for p in self.parameters(): p.requires_grad = False
    def forward(self, x4):
        img = x4[:, :1]
        gx = self.conv_sx(img); gy = self.conv_sy(img)
        gradmag   = torch.sqrt(torch.clamp(gx*gx + gy*gy, min=1e-12))
        localmean = self.blur3(img)
        highpass  = img - localmean
        return torch.cat([x4, gradmag, highpass], dim=1)  # 4 → 6 ch

# ----- Ring priors (from r channel) + starness branch
class RingPriorsAndStar(nn.Module):
    """
    Inside‑net priors:
      • ring_mid, edge_inner, edge_outer from r channel (Gaussian bands)
      • starness: convs on (img × ring_mid) to produce a focused 'star map'
    Returns: concatenation of [x6, priors(3), star_map(1), global_mean_map(1)] → 6 + 3 + 1 + 1 = 11 channels
    """
    def __init__(self):
        super().__init__()
        # starness sub-net (lightweight, high gain)
        self.s_c1 = nn.Conv2d(1, 16, 3, padding=1)
        self.s_a1 = nn.ReLU(inplace=False)
        self.s_c2 = nn.Conv2d(16, 1, 3, padding=1)
        nn.init.kaiming_normal_(self.s_c1.weight, nonlinearity="relu")
        nn.init.kaiming_normal_(self.s_c2.weight, nonlinearity="relu")
        nn.init.zeros_(self.s_c1.bias); nn.init.zeros_(self.s_c2.bias)

    def forward(self, x6):
        # x6 = [img, x, y, r, gradmag, highpass]
        img = x6[:, :1]
        r   = x6[:, 3:4]
        # ring priors (values tuned to your geometry: inner~0.283, mid~0.435, outer~0.566)
        ring_mid   = torch.exp(-0.5*((r - 0.435)/0.060)**2)
        edge_inner = torch.exp(-0.5*((r - 0.283)/0.040)**2)
        edge_outer = torch.exp(-0.5*((r - 0.566)/0.040)**2)
        # starness on img × ring_mid
        z = img * ring_mid
        z = self.s_a1(self.s_c1(z))
        star_map = torch.relu(self.s_c2(z)) * ring_mid  # keep it on the ring
        # global context map (mean intensity of img)
        gmean = img.mean(dim=(2,3), keepdim=True).expand_as(img)
        return torch.cat([x6, ring_mid, edge_inner, edge_outer, star_map, gmean], dim=1)  # 11 ch

# ----- A small, norm‑free residual block
class ResBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.pad1 = nn.ReflectionPad2d(1); self.c1 = nn.Conv2d(c, c, 3)
        self.act  = nn.ReLU(inplace=False)
        self.pad2 = nn.ReflectionPad2d(1); self.c2 = nn.Conv2d(c, c, 3)
        nn.init.kaiming_normal_(self.c1.weight, nonlinearity="relu")
        nn.init.kaiming_normal_(self.c2.weight, nonlinearity="relu")
        if self.c1.bias is not None: nn.init.zeros_(self.c1.bias)
        if self.c2.bias is not None: nn.init.zeros_(self.c2.bias)
    def forward(self, x):
        y = self.c1(self.pad1(x)); y = self.act(y); y = self.c2(self.pad2(y))
        return self.act(y + x)

class ConcatPool2d(nn.Module):
    def forward(self, x):
        return torch.cat([torch.mean(x, dim=(2,3), keepdim=True),
                          torch.amax(x, dim=(2,3), keepdim=True)], dim=1)

# ----- Features extractor
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem      = FixedImageStem()          # 4 → 6
        self.ring_star = RingPriorsAndStar()       # 6 → 11
        self.pad0 = nn.ReflectionPad2d(1); self.c0 = nn.Conv2d(11, 64, 3); self.a0 = nn.ReLU(inplace=False)
        self.b1 = ResBlock(64); self.pool1 = nn.MaxPool2d(2)     # 100 → 50
        self.pad1 = nn.ReflectionPad2d(1); self.c1 = nn.Conv2d(64, 128, 3); self.a1 = nn.ReLU(inplace=False)
        self.b2 = ResBlock(128); self.pool2 = nn.MaxPool2d(2)    # 50 → 25
        self.pad2 = nn.ReflectionPad2d(1); self.c2 = nn.Conv2d(128, 160, 3); self.a2 = nn.ReLU(inplace=False)
        self.b3 = ResBlock(160)
        for m in [self.c0, self.c1, self.c2]:
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            if m.bias is not None: nn.init.zeros_(m.bias)
    def forward(self, x4):
        x6  = self.stem(x4)
        x11 = self.ring_star(x6)
        y = self.a0(self.c0(self.pad0(x11))); y = self.b1(y); y = self.pool1(y)
        y = self.a1(self.c1(self.pad1(y)));  y = self.b2(y); y = self.pool2(y)
        y = self.a2(self.c2(self.pad2(y)));  y = self.b3(y)
        return y

# ----- Regressor (SHAP‑compatible: .features / .gap / .head ; TWO Linear)
class CNNRegressor(nn.Module):
    def __init__(self, in_ch=4):
        super().__init__()
        self.features = FeatureExtractor()
        self.gap      = ConcatPool2d()             # preserves “max” path for a tiny star
        self.head     = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2*160, 128), nn.ReLU(inplace=False),
            nn.Dropout(p=0.10),
            nn.Linear(128, 1),
        )
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None: nn.init.zeros_(m.bias)
    def forward(self, x):
        x = self.features(x)
        x = self.gap(x)
        x = self.head(x)
        return x.clone()   # SHAP‑safe

model = CNNRegressor(in_ch=4).to(DEVICE)
print(model)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

# ----- Training (robust but simple; M1‑safe)
LR = 6e-4
WEIGHT_DECAY = 2e-4
EPOCHS = 140
PATIENCE = 20
criterion = nn.MSELoss()
optimiser = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=EPOCHS, eta_min=1e-5)

best_val = float("inf"); best_state = None
train_losses, val_losses = [], []; no_improve = 0

def _finite(x): return torch.isfinite(x).all().item()

for epoch in range(1, EPOCHS+1):
    # Train
    model.train()
    loss_sum = 0.0
    for xb, zb, _yb in train_loader:
        xb, zb = xb.to(DEVICE), zb.to(DEVICE)
        optimiser.zero_grad(set_to_none=True)
        pred_z = model(xb)
        loss = criterion(pred_z, zb)
        if not _finite(loss):
            # rare on MPS: back off LR and skip this step
            for g in optimiser.param_groups: g["lr"] = max(g["lr"]*0.5, 1e-5)
            continue
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimiser.step()
        loss_sum += float(loss.item()) * xb.size(0)
    train_loss = loss_sum / len(train_loader.dataset); train_losses.append(train_loss)

    # Validate
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for xb, zb, _yb in val_loader:
            xb, zb = xb.to(DEVICE), zb.to(DEVICE)
            pred_z = model(xb)
            loss_sum += float(criterion(pred_z, zb).item()) * xb.size(0)
    val_loss = loss_sum / len(val_loader.dataset); val_losses.append(val_loss)
    scheduler.step()

    print(f"Epoch {epoch:03d} | Train L={train_loss:.6f} | Val L={val_loss:.6f} | LR={scheduler.get_last_lr()[0]:.2e}")

    if val_loss < best_val - 1e-7:
        best_val = val_loss
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        no_improve = 0
    else:
        no_improve += 1
        if no_improve >= PATIENCE:
            print(f"Early stopping (patience={PATIENCE}).")
            break

# Restore best
if best_state is not None:
    model.load_state_dict(best_state)

# ----- Metrics in y‑space
def metrics_from_loader(name, loader):
    model.eval()
    preds_y, trues_y = [], []
    with torch.no_grad():
        for xb, zb, yb in loader:
            xb = xb.to(DEVICE)
            zhat = model(xb).cpu().numpy().squeeze(-1)
            yhat = from_z(zhat)
            preds_y.append(yhat.astype(np.float32))
            trues_y.append(yb.numpy().squeeze(-1).astype(np.float32))
    y_pred = np.concatenate(preds_y); y_true = np.concatenate(trues_y)
    ss_res = float(((y_true - y_pred)**2).sum())
    ss_tot = float(((y_true - y_true.mean())**2).sum() + 1e-12)
    r2   = 1.0 - ss_res/ss_tot
    mae  = float(np.mean(np.abs(y_true - y_pred)))
    rmse = float(np.sqrt(np.mean((y_true - y_pred)**2)))
    print(f"[{name}] R² = {r2:.4f} | MAE = {mae:.4f} | RMSE = {rmse:.4f}")
    return y_true, y_pred, r2, mae, rmse

print("\nPerformance (against observed y):")
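# Evaluate Train on an unshuffled, non-augmented loader so the reported metrics are deterministic
train_eval_loader = DataLoader(ImageRegDataset(X_train_n, z_train, y_train, train=False),
                               batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN, num_workers=0)
y_tr, yhat_tr, r2_tr, _, _ = metrics_from_loader("Train",      train_eval_loader)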
y_va, yhat_va, r2_va, _, _ = metrics_from_loader("Validation", val_loader)
y_te, yhat_te, r2_te, _, _ = metrics_from_loader("Test",       test_loader)

# Noise ceiling on Test (label noise only)
r2_ceiling_test = 1.0 - float(((y_true_test - y_te)**2).sum()) / float(((y_te - y_te.mean())**2).sum() + 1e-12)
print(f"\n[Reference] Test noise‑ceiling R² (label noise): {r2_ceiling_test:.4f}\n")

# ----- Plots
plt.figure(figsize=(6,4))
plt.plot(train_losses, label="Train loss"); plt.plot(val_losses, label="Val loss")
plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Training curves (CNN — ring prior + starness)")
plt.legend(); plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()

def plot_pred_vs_actual(y_actual, y_pred, title):
    plt.figure(figsize=(5,4))
    plt.scatter(y_actual, y_pred, s=14, alpha=0.75)
    lo = min(float(np.min(y_actual)), float(np.min(y_pred)))
    hi = max(float(np.max(y_actual)), float(np.max(y_pred)))
    pad = 0.02*(hi - lo) if hi > lo else 0.01
    plt.plot([lo - pad, hi + pad], [lo - pad, hi + pad], linestyle="--", linewidth=1.0)
    plt.xlabel("Actual y"); plt.ylabel("Predicted y"); plt.title(title)
    plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()

plot_pred_vs_actual(y_tr, yhat_tr, "Predicted vs Actual (Train, CNN)")
plot_pred_vs_actual(y_va, yhat_va, "Predicted vs Actual (Validation, CNN)")
plot_pred_vs_actual(y_te, yhat_te, "Predicted vs Actual (Test, CNN)")

# Small test panel (original → preprocessed → prediction)
n_show = min(6, len(X_test))
sel = np.random.choice(len(X_test), size=n_show, replace=False)
fig, axes = plt.subplots(3, n_show, figsize=(2.0*n_show, 5.8))
for j, k in enumerate(sel):
    axes[0, j].imshow(X_test[k],   cmap="gray", vmin=0, vmax=1); axes[0, j].set_title("Original");      axes[0, j].axis("off")
    axes[1, j].imshow(X_test_s[k], cmap="gray", vmin=0, vmax=1); axes[1, j].set_title("Bg‑subtracted"); axes[1, j].axis("off")
    axes[2, j].text(0.02, 0.80, f"y_true≈{y_true_test[k]:.3f}", transform=axes[2, j].transAxes)
    axes[2, j].text(0.02, 0.55, f"y_obs ={y_test[k]:.3f}",      transform=axes[2, j].transAxes)
    axes[2, j].text(0.02, 0.30, f"y_hat ={yhat_te[k]:.3f}",     transform=axes[2, j].transAxes)
    axes[2, j].axis("off")
axes[0,0].set_ylabel("Input"); axes[1,0].set_ylabel("Preproc"); axes[2,0].set_ylabel("Labels")
plt.suptitle("CNN predictions on held‑out test samples", fontsize=12)
plt.tight_layout(); plt.show()

print("Takeaways:")
print("  • Ring priors tell the model *where* to look; the starness branch makes a tiny bright cue dominate via max pooling.")
print("  • A global-mean map gives an easy path to the overall scale so capacity focuses on local features.")



# Deep SHAP on linear and CNN models

**Purpose of this cell**  
Explain and validate model attributions for three regressors trained on the hollow‑sphere images:

- **OLS**: Ordinary Least Squares on flattened pixels.  
- **Ridge**: Standardised pixels + L2‑regularised linear regression.  
- **CNN**: 4‑channel input (image + $x$, $y$, $r$) with a ring‑aware, star‑sensitive architecture.

The analysis produces SHAP/Deep SHAP maps, compares baselines, quantifies attribution fidelity (deletion/AOPC), checks where the explanations sit relative to the **ground‑truth shell**, and demonstrates whether the CNN **locks onto the star** globally. The notes below aim to make each step interpretable and reproducible.

---

## 1) Quick refresher: what SHAP values mean

A SHAP value $\phi_i(f,x)$ reports the contribution of feature $i$ to the prediction $f(x)$ **relative to a baseline** $x'$. In the classical Shapley formulation (from cooperative game theory), for a model with feature set $F$ and $M=|F|$ features,
$$
\phi_i(f,x) \;=\; \sum_{S \subseteq F\setminus\{i\}}\frac{|S|!\,\big(M-|S|-1\big)!}{M!}
\;\Big(f_{S\cup\{i\}}(x) - f_S(x)\Big),
$$
where $f_S$ is the model evaluated when only the subset $S$ is “present” and all other features are clamped to the baseline. Two properties guide interpretation:

- **Local accuracy (additivity).**  
  $$ f(x) = f(x') + \sum_{i=1}^{M}\phi_i(f,x). $$

- **Symmetry/efficiency.** Features that influence $f$ identically receive equal credit; total credit equals the prediction difference $f(x)-f(x')$.

In images, “features” are often **pixels** or **channels×pixels**. The baseline $x'$ is typically a *background distribution* (e.g., a set of training images) or a single reference (e.g., the mean image).
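
To make the Shapley formula concrete, here is a brute‑force sketch that enumerates every subset for a three‑feature toy model. It is only feasible for a handful of features and is not how SHAP computes values for images. The function `toy_model` and the zero baseline are assumptions chosen for illustration.

```python
from itertools import combinations
from math import factorial

def shapley_values(f, x, baseline):
    """Exact Shapley values of f at x; absent features take baseline values."""
    m = len(x)

    def f_subset(present):
        z = [x[i] if i in present else baseline[i] for i in range(m)]
        return f(z)

    phi = []
    for i in range(m):
        others = [j for j in range(m) if j != i]
        total = 0.0
        for size in range(m):
            weight = factorial(size) * factorial(m - size - 1) / factorial(m)
            for subset in combinations(others, size):
                s = set(subset)
                total += weight * (f_subset(s | {i}) - f_subset(s))
        phi.append(total)
    return phi

def toy_model(z):
    # One additive term and one pairwise interaction.
    return 2.0 * z[0] + z[1] * z[2]

x, baseline = [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]
phi = shapley_values(toy_model, x, baseline)

# Local accuracy: the attributions sum to f(x) - f(baseline).
gap = sum(phi) - (toy_model(x) - toy_model(baseline))
```

The interaction term `z[1] * z[2]` is split equally between features 1 and 2, which is the symmetry property at work.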

---

## 2) Why different explainers are used for the three models

### 2.1 Linear models (OLS, Ridge) — exact, fast SHAP
For linear regression $f(x)=w^\top x + b$ with a baseline distribution $X'$ such that $\mathbb{E}[X']=\bar x$, the **interventional** SHAP values reduce to
$$
\phi_i(f,x) = w_i\,(x_i-\bar x_i), \quad \text{and}\quad f(x')=w^\top\bar x + b.
$$
Hence SHAP for OLS and Ridge is exact under the chosen baseline, and **LinearExplainer** is both fast and faithful.

**Intuition.** Hot pixels on the ring receive large positive (or negative) $\phi$ if the corresponding coefficient $w_i$ is large in magnitude; zero‑mean standardisation in Ridge changes the scale but not the location of these attributions.
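
The closed form can be checked numerically without the `shap` package. The sketch below uses random stand‑ins for the coefficients and the background set (all names are illustrative, not the notebook's variables) and verifies local accuracy.

```python
import numpy as np

rng = np.random.default_rng(0)
n_background, n_features = 50, 6
X_bg = rng.normal(size=(n_background, n_features))  # stand-in background set
w = rng.normal(size=n_features)                     # stand-in coefficients
b = 0.3

def predict(X):
    return X @ w + b

x = rng.normal(size=n_features)        # instance to explain
x_bar = X_bg.mean(axis=0)              # baseline mean
phi = w * (x - x_bar)                  # interventional linear SHAP
expected_value = predict(X_bg).mean()  # equals w @ x_bar + b for a linear model
```

For Ridge, the same identity holds in standardised space, with `x` and `X_bg` replaced by their z‑scored versions.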

### 2.2 CNN — Deep SHAP (DeepExplainer)
Deep SHAP for PyTorch approximates Shapley values by propagating **DeepLIFT‑style** contribution rules from each baseline $x'$ to the input $x$ and averaging over the background set. It does not integrate along a path itself, but *Integrated Gradients* (IG) with multiple baselines $\{x'^{(b)}\}$ is a useful mental model:
$$
\phi_i^{\text{IG}}(f,x;x') \approx (x_i-x'_i)\int_0^1 
\frac{\partial f\big(x'+t(x-x')\big)}{\partial x_i}\,dt,
\qquad
\phi_i^{\text{DeepSHAP}}(f,x) \approx \frac{1}{B}\sum_{b=1}^B \phi_i^{\text{IG}}(f,x;x'^{(b)}).
$$
DeepExplainer applies exact rules on piecewise‑linear parts and approximations on others. Two practical adjustments are made in the code to keep it robust:

1. **SHAP‑compatible wrapper** (`SHAPCompatCNN`). This removes `nn.Flatten` as a module, disables in‑place ReLUs, and copies the two linear layers explicitly, so the graph stays simple and traceable.
2. **`check_additivity=False`.** This disables strict additivity checks, which can fail with non‑supported ops (e.g., `sqrt` in the fixed Sobel stem) or small numerical drift. Additivity can still be inspected manually by summing attributions and comparing to $f(x)-f(x')$.

**Baseline choice matters.** A poor baseline can make any attribution method look inconsistent. This cell *evaluates several baselines* and selects one by fidelity metrics (see §6).

---

## 3) Preliminaries and splits

- **Device.** CUDA if available, otherwise Apple MPS, otherwise CPU.  
- **Warnings.** Harmless SHAP warnings about unrecognised modules (e.g., `ReflectionPad2d`, the custom `ConcatPool2d`) are suppressed to keep the output clean.  
- **Splits.** The same 64/16/20 Train/Val/Test split as earlier cells ensures metric comparability. A fixed `rng` seed stabilises sampling for background sets and the test subset used for explanations.

A small **test subset** (default 24 images) is sampled from the Test split to keep the SHAP runs snappy while remaining representative for visualisation and metrics.

---

## 4) Common geometry and regional metrics

Three radial regions are defined using the known sphere geometry:
- **Hollow**: $r \le 18$ px.  
- **Shell**: $18 < r \le 42$ px.  
- **Outside**: $r > 42$ px.

Given any attribution map $S\in\mathbb{R}^{H\times W}$, the analysis reports:

- **Region shares** (area‑biased): fraction of $\sum|S|$ that falls in each region.  
- **Region means** (area‑fair): mean $|S|$ per pixel in each region.  
- **Top‑$p$ concentration**: the fraction of total $|S|$ contained in the top $p\%$ pixels by $|S|$. High concentration indicates focus.

A **radial profile** $\bar s(r)$ is also shown by binning pixels into $r$‑bands and averaging $|S|$ within each band. Expect peaks near the inner/outer shell edges for models that key off shell structure.

**Good vs bad.**
- *Good* (with star): high share/mean in the **shell**, high **top‑1% concentration** (evidence of a tight cue), and a radial profile with sharp peaks near the shell radius.  
- *Bad*: substantial attribution outside the shell or a flat radial profile (model not localising the physics).
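
The regional metrics above can be sketched as follows. This is a minimal NumPy version under the stated radii (18 px and 42 px); the function names and the toy attribution map are assumptions for illustration.

```python
import numpy as np

def region_masks(h, w, r_in=18.0, r_out=42.0):
    ys, xs = np.indices((h, w))
    r = np.hypot(ys - (h - 1) / 2.0, xs - (w - 1) / 2.0)
    masks = {"hollow": r <= r_in,
             "shell": (r > r_in) & (r <= r_out),
             "outside": r > r_out}
    return masks, r

def region_summary(S, masks):
    """Area-biased shares and area-fair means of |S| per region."""
    A = np.abs(S)
    total = A.sum() + 1e-12
    shares = {k: float(A[m].sum() / total) for k, m in masks.items()}
    means = {k: float(A[m].mean()) for k, m in masks.items()}
    return shares, means

def top_p_concentration(S, p=1.0):
    """Fraction of total |S| held by the top p% of pixels."""
    A = np.sort(np.abs(S).ravel())[::-1]
    k = max(1, int(round(A.size * p / 100.0)))
    return float(A[:k].sum() / (A.sum() + 1e-12))

def radial_profile(S, r, n_bins=25):
    """Mean |S| in equal-width radial bands; returns (band centres, means)."""
    edges = np.linspace(0.0, r.max(), n_bins + 1)
    idx = np.clip(np.digitize(r.ravel(), edges) - 1, 0, n_bins - 1)
    sums = np.bincount(idx, weights=np.abs(S).ravel(), minlength=n_bins)
    counts = np.bincount(idx, minlength=n_bins)
    return 0.5 * (edges[:-1] + edges[1:]), sums / np.maximum(counts, 1)

masks, r = region_masks(100, 100)
S = np.where(masks["shell"], 1.0, 0.0)  # toy map: uniform attribution on the shell
S[50, 80] = 50.0                        # plus one hot pixel, also on the shell
shares, means = region_summary(S, masks)
centres, profile = radial_profile(S, r)
```

Shares depend on region area, so a large region can dominate them even with weak per‑pixel attribution; the means correct for that.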

---

## 5) Linear explainers (OLS and Ridge)

### 5.1 Background sets
- **OLS.** Background = a random subset of flattened train images; this sets $\bar x$ in the linear SHAP formula.  
- **Ridge.** The model operates on **standardised** features $Z=(X-\mu)/\sigma$. SHAP is computed in $Z$‑space and reshaped back to image‐space. The map reflects $w_{\text{ridge}}$ scaled by the feature z‑scores.

### 5.2 Expected qualitative results
- OLS: crisp ring attributions; may over‑fit noise in Train if $p\gg n$.  
- Ridge: ring remains but is smoother; coefficients are shrunk toward zero, so diffuse patterns appear.

**Fidelity check.** *Deletion curves* (see §7) should show a noticeable drop in $R^2$ once the top 1–5% of pixels by $|$SHAP$|$ are removed or replaced by baseline values.

---

## 6) CNN attributions — baselines and selection by fidelity

### 6.1 Building CNN inputs for SHAP
Inputs follow the **training recipe** exactly: corner median background subtraction, train‑statistics normalisation, and appending the three CoordConv maps $(x,y,r)$. A mismatch here is the most common source of odd attributions.

### 6.2 Candidate baselines
Four baselines are considered:
1. **Train subset**: $m$ random training samples (recommended).  
2. **Mean image**: single image channel = train mean; $(x,y,r)$ channels kept as deterministic maps.  
3. **Median image**: robust to outliers.  
4. **Blurred mean**: low‑pass filtered mean; often improves stability.
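
A possible way to assemble the four candidates is sketched below. It assumes preprocessed training images of shape `(N, H, W)` and a fixed `(3, H, W)` coordinate stack. The box blur is a stand‑in low‑pass filter (the notebook's actual filter is not shown here), and `candidate_baselines` is a hypothetical helper name.

```python
import numpy as np

def box_blur(img, k=5):
    """k x k box blur with edge padding (k assumed odd)."""
    pad = k // 2
    padded = np.pad(img, pad, mode="edge")
    out = np.zeros(img.shape, dtype=np.float64)
    for dy in range(k):
        for dx in range(k):
            out += padded[dy:dy + img.shape[0], dx:dx + img.shape[1]]
    return out / (k * k)

def candidate_baselines(X_train_n, coords, m=8, seed=0):
    """Return a dict of name -> (B, 4, H, W) baseline arrays."""
    rng = np.random.default_rng(seed)

    def with_coords(imgs):  # imgs: (B, H, W)
        c = np.broadcast_to(coords, (imgs.shape[0],) + coords.shape)
        return np.concatenate([imgs[:, None], c], axis=1).astype(np.float32)

    n = X_train_n.shape[0]
    idx = rng.choice(n, size=min(m, n), replace=False)
    mean_img = X_train_n.mean(axis=0)
    return {
        "train_subset": with_coords(X_train_n[idx]),
        "mean": with_coords(mean_img[None]),
        "median": with_coords(np.median(X_train_n, axis=0)[None]),
        "blurred_mean": with_coords(box_blur(mean_img)[None]),
    }

# Toy data standing in for the preprocessed training images and coordinate maps.
demo_rng = np.random.default_rng(1)
X_demo = demo_rng.normal(size=(20, 16, 16))
coords_demo = np.zeros((3, 16, 16))
baselines = candidate_baselines(X_demo, coords_demo)
```

In every candidate the coordinate channels are left as the deterministic maps, so only the image channel differs between baseline and input.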

### 6.3 How the “best” baseline is chosen
The following fidelity signals are computed for each candidate:

- **Shell IoU/Precision** at $k\in\{1\%,5\%\}$ of top $|$SHAP$|$ pixels:  
  $$
  \text{IoU}=\frac{|M_k\cap \text{Shell}|}{|M_k\cup \text{Shell}|},\qquad
  \text{Precision}=\frac{|M_k\cap \text{Shell}|}{|M_k|}.
  $$

- **Deletion AOPC** (area over the performance curve): remove the top‑$k$% $|$SHAP$|$ pixels in the **image** channel (keeping $(x,y,r)$ intact), recompute $R^2(k)$ and integrate the drop:
  $$
  \mathrm{AOPC}=\int_{k\in\mathcal{K}}\big(R^2(0)-R^2(k)\big)\,dk.
  $$

A simple normalised score combines IoU@1%, IoU@5% and AOPC, and the baseline with the largest score is chosen. This is a pragmatic proxy for *faithful, focused* explanations on this dataset.

**Good vs bad.**
- *Good baseline*: high IoU/Precision against the shell and large AOPC (removing top‑|SHAP| harms performance quickly).  
- *Bad baseline*: low IoU and small AOPC; attributions likely smear over background or contradict physics.
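
The overlap metrics can be sketched as below, assuming a single attribution map and a boolean shell mask. The helper names are illustrative, and ties at the threshold may admit slightly more than $k$% of pixels.

```python
import numpy as np

def top_k_mask(S, k_percent):
    """Boolean mask of the top k% pixels by |S| (ties at the threshold are kept)."""
    A = np.abs(S)
    n_top = max(1, int(round(A.size * k_percent / 100.0)))
    thresh = np.partition(A.ravel(), -n_top)[-n_top]
    return A >= thresh

def iou_and_precision(mask, shell):
    inter = np.logical_and(mask, shell).sum()
    union = np.logical_or(mask, shell).sum()
    return float(inter / max(union, 1)), float(inter / max(mask.sum(), 1))

ys, xs = np.indices((100, 100))
r = np.hypot(ys - 49.5, xs - 49.5)
shell = (r > 18) & (r <= 42)

rng = np.random.default_rng(0)
S = rng.random((100, 100)) * 0.01  # weak background noise
S[shell] += 1.0                    # toy attribution concentrated on the shell

iou1, prec1 = iou_and_precision(top_k_mask(S, 1.0), shell)
iou5, prec5 = iou_and_precision(top_k_mask(S, 5.0), shell)
```

Because the shell covers far more than 5% of the image, IoU at small $k$ is bounded well below 1 even for a perfect map; Precision is the easier number to read in absolute terms.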

---

## 7) Deletion curves (OLS, Ridge, CNN)

**Purpose.** Test whether high‑magnitude SHAP pixels are **causally important** for prediction.

- For OLS/Ridge, the top‑$k$% features by $|$SHAP$|$ are replaced by baseline values (the background mean, which corresponds to zero in the standardised space used by Ridge).  
- For the CNN, only the **image channel** is masked at those pixels (the coordinate channels remain untouched).

The metric reported is the **$R^2$ against observed labels** on the explained subset as a function of $k$. The summarising scalar is **AOPC** (larger is better).

**Good vs bad.**
- *Good*: $R^2$ drops sharply by removing 1–5% of top pixels; large AOPC.  
- *Bad*: flat $R^2$ curve; indicates unfaithful or noisy attributions.
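
A deletion curve and its AOPC can be sketched on a toy linear model as follows. The data, the deletion grid `ks`, and the trapezoidal integration are assumptions for illustration; the notebook's own grid and masking details may differ.

```python
import numpy as np

def r2_score(y, yhat):
    ss_res = float(((y - yhat) ** 2).sum())
    ss_tot = float(((y - y.mean()) ** 2).sum()) + 1e-12
    return 1.0 - ss_res / ss_tot

def deletion_curve(predict, X, y, shap_vals, baseline, ks=(0, 1, 2, 5, 10, 20)):
    """R^2 after replacing each sample's top-k% |SHAP| features by the baseline."""
    order = np.argsort(-np.abs(shap_vals), axis=1)  # most important first
    rows = np.arange(X.shape[0])[:, None]
    scores = []
    for k in ks:
        n_del = int(round(X.shape[1] * k / 100.0))
        cols = order[:, :n_del]
        Xd = X.copy()
        Xd[rows, cols] = baseline[cols]
        scores.append(r2_score(y, predict(Xd)))
    return np.asarray(ks, dtype=float), np.asarray(scores)

def aopc(ks, scores):
    """Trapezoidal integral of the drop R^2(0) - R^2(k) over k."""
    drop = scores[0] - scores
    return float(np.sum(0.5 * (drop[1:] + drop[:-1]) * np.diff(ks)))

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 100))
w = rng.normal(size=100)
y = X @ w + 0.1 * rng.normal(size=200)
baseline = X.mean(axis=0)
shap_vals = w * (X - baseline)  # exact linear SHAP

def predict(Z):
    return Z @ w

ks, r2_shap = deletion_curve(predict, X, y, shap_vals, baseline)
_, r2_rand = deletion_curve(predict, X, y, rng.normal(size=X.shape), baseline)
aopc_shap, aopc_rand = aopc(ks, r2_shap), aopc(ks, r2_rand)
```

Comparing against a random ranking, as in the last lines, gives a reference level: a faithful attribution should produce a clearly larger AOPC than chance.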

---

## 8) Bounding‑box analysis on top‑|SHAP| within the shell

**Goal.** Quantify how **compact** the attribution is within the shell, and whether it behaves like a local cue (e.g., the star).

Procedure per image:
1. Select the **top $p\%$** ($p=1$% by default) pixels by $|$SHAP$|$ **within the shell**.  
2. Extract the **largest weighted connected component** (8‑connectivity; weights = $|$SHAP$|$).  
3. Compute metrics for the tight axis‑aligned bounding box around that component:

   - **Area fraction of shell**:  
     $$\text{AreaFrac}=\frac{|\,\text{Box}\cap\text{Shell}\,|}{|\,\text{Shell}\,|}.$$
   - **|SHAP| share captured by box**:  
     $$\text{Share}=\frac{\sum_{(i,j)\in\text{Box}\cap\text{Shell}}|S_{ij}|}{\sum_{(i,j)\in\text{Shell}}|S_{ij}|}.$$
   - **Angular coverage** (degrees) of selected pixels using circular geometry.  
   - **Radial width** normalised by shell thickness.

**Good vs bad.**
- *Good (star dataset)*: **small** AreaFrac, **large** Share (most attribution captured by a small box), **low** angular coverage, and **narrow** radial width.  
- *Bad*: large boxes capturing little attribution, or angular coverage approaching a full ring (360°).
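
Steps 1 to 3 can be sketched as below for the area fraction and share metrics; angular coverage and radial width are omitted to keep the example short. The breadth‑first labelling is a plain NumPy stand‑in for whatever connected‑component routine the notebook uses, and all function names are hypothetical.

```python
from collections import deque
import numpy as np

def largest_weighted_component(mask, weights):
    """8-connected component of `mask` with the largest total weight."""
    h, w = mask.shape
    seen = np.zeros(mask.shape, dtype=bool)
    best, best_weight = np.zeros(mask.shape, dtype=bool), -1.0
    for sy, sx in zip(*np.nonzero(mask)):
        if seen[sy, sx]:
            continue
        comp, total = [], 0.0
        queue = deque([(sy, sx)])
        seen[sy, sx] = True
        while queue:
            y, x = queue.popleft()
            comp.append((y, x))
            total += float(weights[y, x])
            for dy in (-1, 0, 1):
                for dx in (-1, 0, 1):
                    ny, nx = y + dy, x + dx
                    if 0 <= ny < h and 0 <= nx < w and mask[ny, nx] and not seen[ny, nx]:
                        seen[ny, nx] = True
                        queue.append((ny, nx))
        if total > best_weight:
            best_weight = total
            best = np.zeros(mask.shape, dtype=bool)
            cy, cx = zip(*comp)
            best[list(cy), list(cx)] = True
    return best

def box_metrics(S, shell, p=1.0):
    """Bounding-box metrics for the top-p% |S| pixels within a boolean shell mask."""
    if not shell.any():
        return None
    A = np.abs(S) * shell
    n_top = max(1, int(round(shell.sum() * p / 100.0)))
    thresh = np.partition(A[shell], -n_top)[-n_top]
    top = shell & (A >= thresh) & (A > 0)  # ignore zero-attribution pixels
    if not top.any():
        return None
    comp = largest_weighted_component(top, A)
    ys, xs = np.nonzero(comp)
    y0, x0, y1, x1 = int(ys.min()), int(xs.min()), int(ys.max()), int(xs.max())
    box = np.zeros(shell.shape, dtype=bool)
    box[y0:y1 + 1, x0:x1 + 1] = True
    in_box = box & shell
    return {"area_frac": float(in_box.sum() / shell.sum()),
            "share": float(A[in_box].sum() / (A[shell].sum() + 1e-12)),
            "bbox": (y0, x0, y1, x1)}

ys, xs = np.indices((100, 100))
r = np.hypot(ys - 49.5, xs - 49.5)
shell = (r > 18) & (r <= 42)
S = np.zeros((100, 100))
S[49:52, 79:82] = 5.0  # compact star-like blob on the shell
S[20, 50] = 1.0        # isolated weaker pixel, also on the shell
m = box_metrics(S, shell, p=1.0)
```

Taking the largest weighted component before boxing is what keeps the isolated pixel from stretching the box across the ring.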

---

## 9) Global star‑capture analysis

This section verifies that the model focuses on the star **globally**, not only via per‑image boxes.

1. **Star proxy centre.** For each image, compute a ring‑weighted, high‑tail map and pick its maximum as a proxy star centre $(y_s,x_s)$.  
2. **Star disc mask.** $D=\{(i,j): (i-y_s)^2+(j-x_s)^2 \le r_d^2\}$ with disc radius $r_d=4$ px.  
3. **SHAP share within disc** (normalised by shell attribution):
   $$
   \text{Share}_{\text{disc}}=\frac{\sum_{(i,j)\in D\cap\text{Shell}}|S_{ij}|}{\sum_{(i,j)\in \text{Shell}}|S_{ij}|}.
   $$
4. **Enrichment** relative to the area fraction of the disc:  
   $$
   \text{Enrichment}=\frac{\text{Share}_{\text{disc}}}{|D\cap\text{Shell}|/|\text{Shell}|}.
   $$
5. **Angular alignment** of the largest component vs the star angle $\theta_s$: compute a circular mean angle of the component weighted by $|S|$ and report $|\theta_{\text{comp}}-\theta_s|$ in degrees.

**Good vs bad.**
- *Good (CNN on star dataset)*: high **Share**, **Enrichment $\gg 1$**, and **small** angular error.  
- *Bad*: Share near area fraction (no enrichment) or large angular errors, similar to OLS/Ridge behaviour.
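
The disc share, enrichment, and angular error can be sketched as follows. The star centre is passed in directly here rather than estimated from a ring‑weighted map, and the function names and toy maps are assumptions for illustration.

```python
import numpy as np

def disc_share_and_enrichment(S, shell, centre, radius=4.0):
    """Share of shell |S| inside a disc at `centre` (row, col), and its enrichment."""
    ys, xs = np.indices(S.shape)
    disc = (ys - centre[0]) ** 2 + (xs - centre[1]) ** 2 <= radius ** 2
    region = disc & shell
    A = np.abs(S)
    share = A[region].sum() / (A[shell].sum() + 1e-12)
    area_frac = region.sum() / max(shell.sum(), 1)
    enrichment = share / area_frac if area_frac > 0 else float("nan")
    return float(share), float(enrichment)

def weighted_circular_mean(S, mask, image_centre):
    """|S|-weighted circular mean angle (radians) of the pixels in `mask`."""
    ys, xs = np.nonzero(mask)
    theta = np.arctan2(ys - image_centre[0], xs - image_centre[1])
    wts = np.abs(S)[ys, xs]
    return float(np.arctan2((wts * np.sin(theta)).sum(), (wts * np.cos(theta)).sum()))

def angular_error_deg(theta_a, theta_b):
    """Smallest absolute difference between two angles, in degrees."""
    d = abs(theta_a - theta_b) % (2 * np.pi)
    return float(np.degrees(min(d, 2 * np.pi - d)))

ys, xs = np.indices((100, 100))
r = np.hypot(ys - 49.5, xs - 49.5)
shell = (r > 18) & (r <= 42)
star = (50, 80)  # assumed star centre (row, col), on the shell

uniform = shell.astype(float)  # isotropic ring attribution
focused = uniform.copy()
focused[(ys - star[0]) ** 2 + (xs - star[1]) ** 2 <= 16] += 100.0  # star-focused map

share_u, enr_u = disc_share_and_enrichment(uniform, shell, star)
share_f, enr_f = disc_share_and_enrichment(focused, shell, star)

theta_star = float(np.arctan2(star[0] - 49.5, star[1] - 49.5))
theta_comp = weighted_circular_mean(focused, shell, (49.5, 49.5))
err_deg = angular_error_deg(theta_comp, theta_star)
```

The uniform ring gives an enrichment of about 1 by construction, which is the reference level the *Bad* case describes.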

---

## 10) Baseline sensitivity for the CNN

Deep SHAP is baseline‑conditional. The analysis compares attributions with a **train‑subset** background vs a **mean** background and reports the **Spearman rank correlation** between $|$SHAP$|$ maps per image.

**Good vs bad.**
- *Good*: moderate to high rank correlation (e.g., median $\gtrsim 0.6$), indicating stability across sensible baselines.  
- *Bad*: near‑zero or highly variable correlations, suggesting the baseline selection dominates the explanation.
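
The per‑image rank correlation can be computed without extra dependencies, as in this sketch. `scipy.stats.spearmanr` provides the same statistic; the NumPy version here simply makes the tie handling explicit. The two maps are random stand‑ins for $|$SHAP$|$ maps under two baselines.

```python
import numpy as np

def rankdata_average(a):
    """Ranks starting at 1, with tied values given their average rank."""
    a = np.asarray(a, dtype=float).ravel()
    order = np.argsort(a, kind="mergesort")
    ranks = np.empty(a.size, dtype=float)
    ranks[order] = np.arange(1, a.size + 1)
    _, inverse, counts = np.unique(a, return_inverse=True, return_counts=True)
    sums = np.bincount(inverse, weights=ranks)
    return (sums / counts)[inverse]

def spearman(a, b):
    """Spearman rank correlation between two arrays of equal size."""
    ra, rb = rankdata_average(a), rankdata_average(b)
    ra = ra - ra.mean()
    rb = rb - rb.mean()
    denom = np.sqrt((ra ** 2).sum() * (rb ** 2).sum())
    return float((ra * rb).sum() / denom) if denom > 0 else float("nan")

rng = np.random.default_rng(0)
map_a = rng.random((20, 20))  # stand-in |SHAP| map under baseline A
map_b = map_a ** 3            # monotone transform: identical ranking
rho_same = spearman(map_a, map_b)
rho_flip = spearman(map_a, -map_a)
```

Because only ranks matter, a baseline that rescales attributions without reordering them still scores 1.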

---

## 11) Polar superpixels / regional analysis

Pixels are binned into $(r,\theta)$ sectors (e.g., 10 radial bands × 16 angular bins). For each sector, the mean $|$SHAP$|$ is recorded. Two descriptors are shown:

- **Peak ring location** (which radial band holds the largest mean attribution).  
- **Anisotropy coefficient of variation** within that band:
  $$
  \text{CV}=\frac{\operatorname{std}_\theta(\text{mean}|S|)}{\operatorname{mean}_\theta(\text{mean}|S|)}.
  $$

**Good vs bad.**
- *Good (star dataset)*: peak on the shell bands and **higher** CV (attribution concentrated into arcs aligned with the star).  
- *Bad*: flat CV near zero (isotropic ring), typical of purely global linear cues.
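
The sector binning and the two descriptors can be sketched as below. The radial bands are equal‑width out to the image corner, which is an assumption (the notebook may bin differently), and the function names are illustrative.

```python
import numpy as np

def polar_sector_means(S, n_r=10, n_theta=16):
    """Mean |S| in each (radial band, angular bin) sector; shape (n_r, n_theta)."""
    h, w = S.shape
    ys, xs = np.indices((h, w))
    dy, dx = ys - (h - 1) / 2.0, xs - (w - 1) / 2.0
    r = np.hypot(dy, dx)
    theta = np.mod(np.arctan2(dy, dx), 2 * np.pi)
    r_idx = np.minimum((r / r.max() * n_r).astype(int), n_r - 1)
    t_idx = np.minimum((theta / (2 * np.pi) * n_theta).astype(int), n_theta - 1)
    flat = (r_idx * n_theta + t_idx).ravel()
    sums = np.bincount(flat, weights=np.abs(S).ravel(), minlength=n_r * n_theta)
    counts = np.bincount(flat, minlength=n_r * n_theta)
    return (sums / np.maximum(counts, 1)).reshape(n_r, n_theta)

def peak_band_and_cv(sector_means):
    """Index of the band with the largest mean attribution, and its angular CV."""
    band = int(np.argmax(sector_means.mean(axis=1)))
    row = sector_means[band]
    return band, float(row.std() / (row.mean() + 1e-12))

ys, xs = np.indices((100, 100))
r = np.hypot(ys - 49.5, xs - 49.5)
shell = (r > 18) & (r <= 42)

isotropic = shell.astype(float)  # uniform ring attribution
with_star = isotropic.copy()
with_star[49:52, 79:82] += 100.0  # compact blob at one angle on the shell

band_iso, cv_iso = peak_band_and_cv(polar_sector_means(isotropic))
band_star, cv_star = peak_band_and_cv(polar_sector_means(with_star))
```

The isotropic ring yields a CV of essentially zero, while the single blob pushes the CV of its band well above that.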

---

## 12) Augmentation‑invariance sanity check (horizontal flips)

Inputs are flipped horizontally, explained, then unflipped back in attribution space. The **Spearman rank correlation** between original and flipped‑then‑unflipped $|$SHAP$|$ maps is reported per image.

**Good vs bad.**
- *Good*: high correlation (the explanation respects the flip symmetry).  
- *Bad*: unstable attributions under symmetry transformations.

---

## 13) Integrated Gradients (IG) as triangulation

To cross‑check Deep SHAP, **Integrated Gradients** is computed from a simple baseline $\tilde x$ (image channel zeroed; coordinate channels intact):
$$
\mathrm{IG}_i(x; \tilde x)=(x_i-\tilde x_i)\int_0^1 \frac{\partial f\big(\tilde x + t(x-\tilde x)\big)}{\partial x_i}\,dt.
$$
A Spearman correlation between $|\mathrm{IG}|$ and $|$SHAP$|$ per image is reported.

**Good vs bad.**
- *Good*: positive, often substantial correlation, indicating agreement between two attribution methods with different assumptions.  
- *Bad*: no correlation; revisit baseline choice and preprocessing consistency.
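
The IG integral is usually approximated by a Riemann sum along the straight path. The sketch below uses a midpoint rule on a tiny tanh network with an analytic gradient; the network, its weights, and the zero baseline are assumptions standing in for the CNN and its image‑zeroed baseline.

```python
import numpy as np

def integrated_gradients(f_grad, x, baseline, steps=64):
    """Midpoint Riemann-sum approximation of IG along the straight path."""
    alphas = (np.arange(steps) + 0.5) / steps
    grads = np.stack([f_grad(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)

rng = np.random.default_rng(0)
W1, w2 = rng.normal(size=(5, 8)), rng.normal(size=5)

def f(x):
    """Toy model: one hidden tanh layer, scalar output."""
    return float(w2 @ np.tanh(W1 @ x))

def f_grad(x):
    """Analytic gradient of f with respect to x."""
    h = np.tanh(W1 @ x)
    return (w2 * (1.0 - h ** 2)) @ W1

x, baseline = rng.normal(size=8), np.zeros(8)
ig = integrated_gradients(f_grad, x, baseline, steps=1000)

# Completeness: attributions sum (approximately) to f(x) - f(baseline).
completeness_gap = ig.sum() - (f(x) - f(baseline))
```

Checking the completeness gap is a cheap way to decide whether `steps` is large enough before comparing IG maps with Deep SHAP maps.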

---

## 14) Interpreting typical outputs on this dataset

- **Linear models (OLS/Ridge).** Expect ring‑shaped attributions. Deletion curves drop, but not catastrophically at 1–5%. Bounding boxes are large; star‑disc enrichment near 1.  
- **CNN (hard/star dataset).** Expect small bounding boxes, large |SHAP| share captured, low angular coverage; high star‑disc enrichment and small angular error. Deletion AOPC usually largest among the three.  
- **CNN (easy/no‑star dataset).** Attributions move to the shell edges and global scale; star‑specific metrics are not applicable or show no enrichment.

Numerically, on a healthy run it is common to observe:  
- IoU@5% vs shell: CNN $>$ Ridge $\gtrsim$ OLS.  
- AOPC (CNN) $>$ AOPC (Ridge) $>$ AOPC (OLS).  
- Star‑disc enrichment: CNN $\gg 1$, OLS/Ridge $\approx 1$.

---

## 15) Practical pitfalls and remedies

- **Preprocessing mismatch** between training and SHAP inputs → nonsensical maps. Always rebuild SHAP inputs and baselines with the *same* background subtraction and normalisation used in training.  
- **Non‑supported ops** (e.g., `sqrt` from Sobel magnitude) can violate strict additivity checks. Using `check_additivity=False` avoids hard failures; manual additivity checks remain possible.  
- **Floating‑point quirks on MPS/CUDA** can surface as non‑finite losses; the training code already defends by reducing the learning rate and skipping the offending step.  
- **Over‑aggressive masking** in deletion curves can change the data distribution drastically; restricting masking to the **image channel** (and retaining $(x,y,r)$) keeps the test closer to “remove visual evidence”.

---

## 16) A short glossary

- **Baseline (background).** The reference input(s) $x'$ relative to which contributions are measured.  
- **Deep SHAP.** DeepExplainer’s approximation of Shapley values for deep networks: DeepLIFT‑style backpropagation rules, averaged over a set of background samples rather than a single reference input.  
- **AOPC.** “Area Over the Perturbation Curve”: here, the area under the R² drop as top‑|SHAP| pixels are progressively deleted; larger values indicate more faithful attributions.  
- **IoU / Precision vs shell.** Overlap metrics between top‑|SHAP| pixels and the known shell region.  
- **Largest weighted component.** The connected subset within top‑|SHAP| that has the greatest total $|$SHAP$|$; used to avoid fragmented boxes.  
- **Enrichment.** Share of $|$SHAP$|$ inside the star disc divided by the star disc’s area fraction of the shell. Values $\gg1$ indicate star focus.
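
To make the enrichment definition concrete, here is a small sketch. It assumes that the |SHAP| share is measured within the shell, which the glossary leaves open, and it places a hypothetical star disc of radius 4 px at 30 px from the centre. The function name `enrichment` and the two synthetic maps are illustrative only.

```python
import numpy as np

def enrichment(abs_shap, shell_mask, disc_mask):
    """(|SHAP| share inside the disc, within the shell) / (disc area fraction of the shell)."""
    disc_in_shell = disc_mask & shell_mask
    shap_share = abs_shap[disc_in_shell].sum() / (abs_shap[shell_mask].sum() + 1e-12)
    area_frac = disc_in_shell.sum() / (shell_mask.sum() + 1e-12)
    return float(shap_share / (area_frac + 1e-12))

H = W = 100
c = (H - 1) / 2.0
ys, xs = np.indices((H, W))
R = np.hypot(ys - c, xs - c)
shell = (R > 18.0) & (R <= 42.0)

# Hypothetical star disc: radius 4 px, centred 30 px to the right of the image centre.
disc = np.hypot(ys - c, xs - (c + 30.0)) <= 4.0

uniform = np.ones((H, W))          # attribution spread evenly over the image
focused = np.full((H, W), 0.01)    # weak attribution everywhere ...
focused[disc] = 1.0                # ... and strong attribution on the star disc

e_uniform = enrichment(uniform, shell, disc)
e_focused = enrichment(focused, shell, disc)
print(f"enrichment uniform = {e_uniform:.3f}, focused = {e_focused:.1f}")
```

A uniform map gives an enrichment of about 1 by construction, which is why values near 1 are read as "no star focus".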

---

**Bottom line.** On this synthetic problem, linear models explain the **global density** via ring‑like weights; the CNN explains **both** global scale and **local star cues**. The attribution tests here (baseline selection, deletion/AOPC, IoU/Precision, bounding boxes, global star capture, flips, and IG triangulation) provide a comprehensive, quantitative argument that the CNN’s predictions are driven by the intended visual evidence.


In [None]:
"""
Deep SHAP — End‑to‑end, commented, and *interpretable* analysis on our three models
-----------------------------------------------------------------------------------

Models explained:
  • OLS (pixels → y), Ridge (StandardScaler + Ridge → y), and CNN (4‑ch input: image + x,y,r; predicts z then inverted to y).
  • CNN is made SHAP‑safe by:
      – a SHAPCompatCNN wrapper that removes nn.Flatten and any in‑place ReLUs,
      – using check_additivity=False (via safe_shap_values) to tolerate unsupported ops (e.g., sqrt in fixed stems).

Changes in this version:
  • Insertion curves removed (kept Deletion/AOPC).
  • FIX: bounding‑box analysis is now robust (no negative kth; handles degenerate masks; always 2‑D indexing).
  • NEW: before boxing we take the **largest weighted connected component** of the top‑|SHAP| pixels within the shell.
"""

# -----------------
# Imports and checks
# -----------------
import copy
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
import torch.nn as nn

from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score


# ----- Device (M1 → MPS if available)
if torch.cuda.is_available():
    DEVICE = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Silence SHAP's "unrecognized nn.Module" warnings for harmless modules (ConcatPool2d / ReflectionPad2d)
import warnings
warnings.filterwarnings(
    "ignore",
    message=r"unrecognized nn\.Module: .*",
    category=UserWarning,
    module="shap.explainers._deep.deep_pytorch"
)


try:
    import shap
except Exception as e:
    raise ImportError("Please install SHAP first:  pip install shap") from e

# Optional Spearman correlation (baseline sensitivity); fall back to Pearson if unavailable
try:
    from scipy.stats import spearmanr
    HAVE_SPEARMAN = True
except Exception:
    HAVE_SPEARMAN = False

# Sanity: ensure earlier cells provided these
required = ["images", "y_true_all", "y_obs_all", "ols", "final_ridge", "model"]
for v in required:
    if v not in globals():
        raise RuntimeError(f"Missing '{v}'. Please run the data, OLS/Ridge, and CNN training cells first.")

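# NOTE: this re-assignment overrides the MPS choice made above, so the SHAP analysis runs on CUDA or CPU only.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"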
model.to(DEVICE).eval()
torch.set_grad_enabled(True)  # in case a previous cell disabled grad

N, H, W = images.shape
CENTER = (H - 1) / 2.0
print(f"[Setup] images: {images.shape} | device: {DEVICE}")

# -----------------------------
# Split indices (SAME recipe as earlier; seed=2025)
# -----------------------------
SEED = 2025
TEST_FRAC = 0.20
VAL_FRAC_WITHIN_TRAINVAL = 0.20  # 20% of 80% → 16% overall

rng = np.random.default_rng(SEED)
idx_all = np.arange(N)

# First split: TrainVal vs Test
idx_trainval, idx_test = train_test_split(idx_all, test_size=TEST_FRAC, random_state=SEED)
# Second split: Train vs Val (within TrainVal)
idx_train, idx_val = train_test_split(idx_trainval, test_size=VAL_FRAC_WITHIN_TRAINVAL, random_state=SEED)

print(f"[Split] Train={len(idx_train)}, Val={len(idx_val)}, Test={len(idx_test)}")

# -----------------------------
# Utility helpers
# -----------------------------
def area_trapezoid(y, x):
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)
    if hasattr(np, "trapezoid"):
        return float(np.trapezoid(y, x))
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])))

def heatmap2d(arr2d, title, cmap="magma", with_cbar=True):
    plt.figure(figsize=(5, 4))
    plt.imshow(arr2d, cmap=cmap)
    if with_cbar:
        plt.colorbar(fraction=0.046, pad=0.04)
    plt.title(title); plt.axis("off"); plt.tight_layout(); plt.show()

def grid_overlays(imgs, shap_maps, title_prefix, ncols=4):
    n = len(imgs)
    ncols = min(ncols, n); nrows = int(np.ceil(n / ncols))
    vmax = np.percentile(np.abs(np.concatenate([s.ravel() for s in shap_maps])), 99) if len(shap_maps) else 1.0
    fig, axes = plt.subplots(nrows, ncols, figsize=(3.2 * ncols, 3.2 * nrows))
    axes = np.atleast_1d(axes).ravel()
    for i in range(n):
        axes[i].imshow(imgs[i], cmap="gray", vmin=0, vmax=1)
        axes[i].imshow(shap_maps[i], cmap="RdBu_r", alpha=0.7, vmin=-vmax, vmax=+vmax)
        axes[i].set_title(f"{title_prefix} — sample {i}", fontsize=9)
        axes[i].axis("off")
    for k in range(n, nrows * ncols): axes[k].axis("off")
    plt.tight_layout(); plt.show()

# -----------------------------
# Radial & angular fields + “region share” metrics
# -----------------------------
ys, xs = np.indices((H, W))
R = np.sqrt((ys - CENTER) ** 2 + (xs - CENTER) ** 2)
theta = (np.arctan2(ys - CENTER, xs - CENTER) + 2*np.pi) % (2*np.pi)  # used by polar & bbox

# Analysis masks sit 2 px outside the nominal data radii (R0 ≈ 20, R1 ≈ 40) on each side.
INNER_R = 18.0
OUTER_R = 42.0
mask_inner = (R <= INNER_R)  # hollow
mask_shell = (R > INNER_R) & (R <= OUTER_R)  # metal shell
mask_outer = (R > OUTER_R)  # outside

def region_shares(abs_shap_map):
    total = abs_shap_map.sum() + 1e-12
    return (
        float((abs_shap_map[mask_inner]).sum() / total),
        float((abs_shap_map[mask_shell]).sum() / total),
        float((abs_shap_map[mask_outer]).sum() / total),
    )

def region_means(abs_shap_map):
    mi = float(abs_shap_map[mask_inner].mean() if mask_inner.sum() else 0.0)
    ms = float(abs_shap_map[mask_shell].mean() if mask_shell.sum() else 0.0)
    mo = float(abs_shap_map[mask_outer].mean() if mask_outer.sum() else 0.0)
    return mi, ms, mo

def top_p_concentration(abs_shap_map, p=0.05):
    A = np.abs(abs_shap_map).ravel()
    k = max(1, int(p * A.size))
    kth = A.size - k  # non‑negative kth
    if kth < 0: kth = 0
    thr = np.partition(A, kth)[kth]
    return float((A[A >= thr]).sum() / (A.sum() + 1e-12))

def radial_profile(abs_shap_map, nbins=60):
    r = R.ravel(); v = abs_shap_map.ravel()
    bins = np.linspace(0, R.max() + 1e-6, nbins + 1)
    prof = np.zeros(nbins, dtype=float)
    for i in range(nbins):
        m = (r >= bins[i]) & (r < bins[i + 1])
        prof[i] = v[m].mean() if np.any(m) else 0.0
    centres = 0.5 * (bins[:-1] + bins[1:])
    return centres, prof

def summarise_region_and_profile(stack, model_name):
    shares = np.array([region_shares(np.abs(s)) for s in stack])
    inner_s, shell_s, outer_s = shares.mean(axis=0)

    means = np.array([region_means(np.abs(s)) for s in stack])
    inner_m, shell_m, outer_m = means.mean(axis=0)

    conc_1 = np.mean([top_p_concentration(np.abs(s), p=0.01) for s in stack])
    conc_5 = np.mean([top_p_concentration(np.abs(s), p=0.05) for s in stack])

    print(f"\n{model_name} — mean |SHAP| region *share* (area‑biased):")
    print(f"  Hollow (r ≤ {INNER_R:.0f}) : {inner_s:6.2%}")
    print(f"  Shell  ({INNER_R:.0f}<r≤{OUTER_R:.0f}): {shell_s:6.2%}")
    print(f"  Outside(r > {OUTER_R:.0f}) : {outer_s:6.2%}")

    print(f"{model_name} — mean |SHAP| *per pixel* (area‑fair):")
    print(f"  Hollow:  {inner_m:.6f}  |  Shell: {shell_m:.6f}  |  Outside: {outer_m:.6f}")
    print(f"{model_name} — concentration: top‑1% captures {conc_1:6.2%}, top‑5% captures {conc_5:6.2%}")

    profs, radii = [], None
    for s in stack:
        rr, p = radial_profile(np.abs(s), nbins=70)
        profs.append(p); radii = rr
    profs = np.stack(profs, axis=0)
    plt.figure(figsize=(6, 4))
    plt.plot(radii, profs.mean(axis=0))
    plt.xlabel("Radius (px)"); plt.ylabel("Mean |SHAP|")
    plt.title(f"Radial profile of |SHAP| — {model_name}\n(look for peaks around the shell boundaries)")
    plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()

def topk_mask(abs_map, k_frac):
    A = np.abs(abs_map).ravel()
    k = max(1, int(k_frac * A.size))
    kth = A.size - k
    if kth < 0: kth = 0
    thr = np.partition(A, kth)[kth]
    return (np.abs(abs_map) >= thr)

def iou_and_precision_vs_shell(stack, k_fracs=(0.01, 0.02, 0.05, 0.10)):
    results = {}; shell = mask_shell
    for kf in k_fracs:
        ious, precs = [], []
        for s in stack:
            m = topk_mask(s, kf).reshape(H, W)
            inter = np.logical_and(m, shell).sum()
            union = np.logical_or(m, shell).sum()
            iou = inter / (union + 1e-12)
            prec = inter / (m.sum() + 1e-12)
            ious.append(iou); precs.append(prec)
        results[kf] = (float(np.mean(ious)), float(np.mean(precs)))
    return results

def print_iou_precision(name, results):
    print(f"\n{name} — Ground‑truth alignment vs shell (IoU@k / Precision@k):")
    for kf, (iou, prec) in results.items():
        print(f"  k={int(kf * 100)}%  IoU={iou:.3f}  |  Precision={prec:.3f}")

# -----------------------------
# A common test subset to *explain*
# -----------------------------
N_EXPLAIN = min(24, len(idx_test))
test_sel_idx = rng.choice(idx_test, size=N_EXPLAIN, replace=False)
test_imgs = images[test_sel_idx]
y_obs_test_sel = y_obs_all[test_sel_idx]

# ==========================================================
# 1) OLS — LinearExplainer on flattened pixels
# ==========================================================
print("\n[OLS] Expect a ring; most |SHAP| in the shell.")
X_flat = images.reshape(N, -1).astype(np.float64)

bg_ols_idx = rng.choice(idx_train, size=min(200, len(idx_train)), replace=False)
X_bg_ols = X_flat[bg_ols_idx]
X_ols_test = X_flat[test_sel_idx]

expl_ols = shap.LinearExplainer(ols, X_bg_ols)  # identity link
shap_ols_flat = expl_ols.shap_values(X_ols_test)
if isinstance(shap_ols_flat, list):  # defensive
    shap_ols_flat = shap_ols_flat[0]
shap_ols_imgs = shap_ols_flat.reshape(-1, H, W)

grid_overlays(test_imgs[:8], shap_ols_imgs[:8], "OLS")
heatmap2d(np.mean(np.abs(shap_ols_imgs), axis=0), "Global mean |SHAP| — OLS")
summarise_region_and_profile(shap_ols_imgs, "OLS")
print_iou_precision("OLS", iou_and_precision_vs_shell(shap_ols_imgs))

# ==========================================================
# 2) Ridge — LinearExplainer on STANDARDISED features
# ==========================================================
print("\n[Ridge] Expect credit to be spread more diffusely; still ring‑biased.")
if not isinstance(final_ridge, Pipeline):
    raise RuntimeError("Expected 'final_ridge' to be a Pipeline(StandardScaler, Ridge).")
scaler = final_ridge.named_steps["scaler"]
ridge  = final_ridge.named_steps["ridge"]

Z_bg   = scaler.transform(X_flat[bg_ols_idx])
Z_test = scaler.transform(X_flat[test_sel_idx])

expl_ridge = shap.LinearExplainer(ridge, Z_bg)
shap_ridge_scaled = expl_ridge.shap_values(Z_test)
if isinstance(shap_ridge_scaled, list):
    shap_ridge_scaled = shap_ridge_scaled[0]
shap_ridge_imgs = shap_ridge_scaled.reshape(-1, H, W)

grid_overlays(test_imgs[:8], shap_ridge_imgs[:8], "Ridge")
heatmap2d(np.mean(np.abs(shap_ridge_imgs), axis=0), "Global mean |SHAP| — Ridge")
summarise_region_and_profile(shap_ridge_imgs, "Ridge")
print_iou_precision("Ridge", iou_and_precision_vs_shell(shap_ridge_imgs))

# ==========================================================
# 3) CNN — DeepExplainer on 4‑channel inputs (image + CoordConv x,y,r)
#      SHAP‑compatible wrapper (no nn.Flatten; no in‑place ops); additivity off
# ==========================================================
print("\n[CNN] Strong edges on inner/outer shell expected; baseline matters — we select it by fidelity.")

# --- preprocessing used for CNN inputs (match the training recipe) ---
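def estimate_background_offset(img: np.ndarray, k: int = 10) -> float:
    """Median intensity of the four k×k corner patches, used as the background level."""
    # Named `corners` so the matplotlib `patches` module imported above is not shadowed.
    corners = [img[:k, :k], img[:k, -k:], img[-k:, :k], img[-k:, -k:]]
    return float(np.median(np.concatenate([p.ravel() for p in corners])))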

def subtract_background(X: np.ndarray) -> np.ndarray:
    Xp = np.empty_like(X)
    for i, img in enumerate(X):
        bg = estimate_background_offset(img, k=10)
        im = img - bg
        im[im < 0.0] = 0.0
        Xp[i] = im
    return Xp

pm_ps_found = False
if "pix_mean" in globals() and "pix_std" in globals():
    pm = float(pix_mean); ps = float(pix_std); pm_ps_found = True
elif "pixel_mean" in globals() and "pixel_std" in globals():
    pm = float(pixel_mean); ps = float(pixel_std); pm_ps_found = True
if not pm_ps_found:
    X_train_bgsub = subtract_background(images[idx_train])
    pm = float(X_train_bgsub.mean()); ps = float(X_train_bgsub.std() + 1e-8)

# CoordConv maps
x_lin = (xs - CENTER) / CENTER
y_lin = (ys - CENTER) / CENTER
r_map = np.sqrt(x_lin ** 2 + y_lin ** 2)
r_map = r_map / (r_map.max() + 1e-12)

def build_cnn_input(img_batch: np.ndarray) -> np.ndarray:
    Xs = subtract_background(img_batch)
    Xn = (Xs - pm) / ps
    n = Xn.shape[0]
    X4 = np.zeros((n, 4, H, W), dtype=np.float32)
    X4[:, 0] = Xn.astype(np.float32)
    X4[:, 1] = x_lin.astype(np.float32)
    X4[:, 2] = y_lin.astype(np.float32)
    X4[:, 3] = r_map.astype(np.float32)
    return X4

X_cnn_test_4 = build_cnn_input(images[test_sel_idx])

# --- SHAP‑compatible wrapper that removes nn.Flatten and copies Linear weights ---
class SHAPCompatCNN(nn.Module):
    def __init__(self, base: nn.Module):
        super().__init__()
        m = copy.deepcopy(base).to(DEVICE).eval()
        for module in m.modules():
            if isinstance(module, nn.ReLU):
                module.inplace = False
        self.features = m.features
        self.gap = m.gap
        # Extract Linear layers from head (expects exactly two)
        lin_layers = [mod for mod in m.head if isinstance(mod, nn.Linear)]
        if len(lin_layers) != 2:
            raise RuntimeError("Expected two Linear layers in model.head.")
        self.fc1 = nn.Linear(lin_layers[0].in_features, lin_layers[0].out_features, bias=True)
        self.act = nn.ReLU(inplace=False)
        drops = [mod for mod in m.head if isinstance(mod, nn.Dropout)]
        self.drop = nn.Dropout(p=drops[0].p if drops else 0.0)
        self.fc2 = nn.Linear(lin_layers[1].in_features, lin_layers[1].out_features, bias=True)
        with torch.no_grad():
            self.fc1.weight.copy_(lin_layers[0].weight); self.fc1.bias.copy_(lin_layers[0].bias)
            self.fc2.weight.copy_(lin_layers[1].weight); self.fc2.bias.copy_(lin_layers[1].bias)

    def forward(self, x):
        x = self.features(x)
        x = self.gap(x)
        x = torch.flatten(x, 1)  # functional flatten (no nn.Flatten module)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        y = self.fc2(x)
        return y.clone()

shap_model = SHAPCompatCNN(model).to(DEVICE).eval()

# --- candidate baselines for CNN SHAP ---
def build_baseline(kind="train_subset", m=64):
    """
    kind ∈ {"train_subset", "mean", "median", "blurred_mean"}.
    Returns a tensor on DEVICE suitable for DeepExplainer background.
    """
    if kind == "train_subset":
        idx = rng.choice(idx_train, size=min(m, len(idx_train)), replace=False)
        X_bg = build_cnn_input(images[idx])
        return torch.from_numpy(X_bg).to(DEVICE)
    if kind in {"mean", "median", "blurred_mean"}:
        X_tr = build_cnn_input(images[idx_train])
        if kind == "mean":
            img_ch = X_tr[:, 0].mean(axis=0, keepdims=True).astype(np.float32)
        elif kind == "median":
            img_ch = np.median(X_tr[:, 0], axis=0, keepdims=True).astype(np.float32)
        else:  # blurred_mean
            mean_img = X_tr[:, 0].mean(axis=0)
            k = 5; pad = k // 2
            tmp = np.pad(mean_img, pad, mode="reflect"); out = np.zeros_like(mean_img)
            for y in range(H):
                for x in range(W):
                    out[y, x] = tmp[y:y + k, x:x + k].mean()
            img_ch = out[None, :, :].astype(np.float32)
        Xb = np.zeros((1, 4, H, W), dtype=np.float32)
        Xb[0, 0] = img_ch[0]; Xb[0, 1] = x_lin; Xb[0, 2] = y_lin; Xb[0, 3] = r_map
        return torch.from_numpy(Xb).to(DEVICE)
    raise ValueError("Unknown baseline kind.")

# --- DeepExplainer helper that disables additivity check (critical fix) ---
def safe_shap_values(explainer, x_tensor):
    try:
        return explainer.shap_values(x_tensor, check_additivity=False)
    except TypeError:
        try:
            import shap.explainers._deep.deep_utils as _du
            _du.TOLERANCE = 1e9
        except Exception:
            pass
        return explainer.shap_values(x_tensor)

# --- Deep SHAP for several baselines; pick the best by fidelity ---
BASELINES = ["train_subset", "mean", "median", "blurred_mean"]
baseline_results = {}
torch.set_grad_enabled(True)
X_test_t = torch.from_numpy(X_cnn_test_4).to(DEVICE).requires_grad_(True)

def cnn_predict_y_from_4(X4):
    """Predict y (not z) for a batch of 4‑channel inputs with the *trained* CNN clone used for SHAP."""
    y_train_obs = y_obs_all[idx_train]
    y_mean = float(np.mean(y_train_obs)); y_std = float(np.std(y_train_obs) + 1e-8)
    def from_z(z): return z * y_std + y_mean
    shap_model.eval()
    with torch.no_grad():
        z = shap_model(torch.from_numpy(X4).to(DEVICE)).cpu().numpy().squeeze(-1)
    return from_z(z)

fractions = [0.01, 0.02, 0.05, 0.10, 0.20]
y_obs_sel = y_obs_test_sel

for bkind in BASELINES:
    bg_t = build_baseline(bkind, m=64 if bkind == "train_subset" else 1)
    expl = shap.DeepExplainer(shap_model, bg_t)
    vals = safe_shap_values(expl, X_test_t)
    sh = vals[0] if isinstance(vals, list) else vals  # (n, 4, H, W)
    sh_img = sh[:, 0, :, :]
    baseline_results[bkind] = {"all": sh, "img": sh_img}

    # IoU/Precision@k
    io = iou_and_precision_vs_shell(sh_img, k_fracs=(0.01, 0.05))
    baseline_results[bkind]["iou_prec"] = io

    # Deletion AOPC
    n, _, h, w = X_cnn_test_4.shape; D = h * w
    order = np.argsort(-np.abs(sh_img.reshape(n, D)), axis=1)
    y0 = cnn_predict_y_from_4(X_cnn_test_4); r2_base = r2_score(y_obs_sel, y0)
    r2_list = []
    for frac in fractions:
        k = max(1, int(frac * D))
        X_mod = X_cnn_test_4.copy()
        for i in range(n):
            idxk = order[i, :k]; rr, cc = idxk // w, idxk % w
            X_mod[i, 0, rr, cc] = 0.0
        y_pred_mod = cnn_predict_y_from_4(X_mod)
        r2_list.append(r2_score(y_obs_sel, y_pred_mod))
    drops = (r2_base - np.array(r2_list))
    aopc = area_trapezoid(drops, np.array(fractions))
    baseline_results[bkind]["aopc"] = float(aopc)
    baseline_results[bkind]["r2_base"] = float(r2_base)

torch.set_grad_enabled(False)

# Pick best baseline by (IoU@5% + IoU@1%) + AOPC (simple normalised score)
def normalise(v):
    v = np.asarray(v, dtype=float)
    if np.ptp(v) < 1e-12:
        return np.ones_like(v)
    return (v - v.min()) / (v.max() - v.min())

scores = []; kinds = []
for k, d in baseline_results.items():
    io1 = d["iou_prec"][0.01][0]; io5 = d["iou_prec"][0.05][0]
    kinds.append(k); scores.append([io1, io5, d["aopc"]])
scores = np.array(scores)
score = normalise(scores[:, 0]) + normalise(scores[:, 1]) + normalise(scores[:, 2])
BEST_BASELINE = kinds[int(np.argmax(score))]
print(f"\n[CNN] Baseline selection by fidelity → chosen: **{BEST_BASELINE}**")
for k in BASELINES:
    io1 = baseline_results[k]["iou_prec"][0.01][0]
    io5 = baseline_results[k]["iou_prec"][0.05][0]
    print(f"  {k:13s}  IoU@1%={io1:.3f}  IoU@5%={io5:.3f}  AOPC={baseline_results[k]['aopc']:.4f}")

# Use the chosen baseline’s SHAP maps from now on
shap_cnn_allch = baseline_results[BEST_BASELINE]["all"]
shap_cnn_imgch = baseline_results[BEST_BASELINE]["img"]

# Light background bootstrapping (variability hints)
print("\n[CNN] Background bootstrapping (variability over train‑subset backgrounds)")
B = 3
boot_shares = []
for b in range(B):
    bg_t = build_baseline("train_subset", m=32)
    torch.set_grad_enabled(True)
    e = shap.DeepExplainer(shap_model, bg_t)
    vals = safe_shap_values(e, X_test_t)
    torch.set_grad_enabled(False)
    sh = vals[0] if isinstance(vals, list) else vals
    sh_img = sh[:, 0]
    shares = np.array([region_shares(np.abs(s)) for s in sh_img])
    boot_shares.append(shares.mean(axis=0))
boot_shares = np.stack(boot_shares, axis=0)
means = boot_shares.mean(axis=0); stds = boot_shares.std(axis=0)
print(f"  Region share mean±sd (Hollow/Shell/Outside): "
      f"{means[0]:.3f}±{stds[0]:.3f} / {means[1]:.3f}±{stds[1]:.3f} / {means[2]:.3f}±{stds[2]:.3f}")

# Local overlays & global maps for CNN (chosen baseline)
grid_overlays(test_imgs[:8], shap_cnn_imgch[:8], f"CNN [{BEST_BASELINE}]")
heatmap2d(np.mean(np.abs(shap_cnn_imgch), axis=0), f"Global mean |SHAP| — CNN [{BEST_BASELINE}]")
summarise_region_and_profile(shap_cnn_imgch, f"CNN [{BEST_BASELINE}]")
print_iou_precision(f"CNN [{BEST_BASELINE}]", iou_and_precision_vs_shell(shap_cnn_imgch))

# -----------------------------
# Deletion curves — measured as R² changes (OLS / Ridge / CNN)
# -----------------------------
print("\n[Deletion curves: what they mean]\n"
      "• Remove top‑|SHAP| pixels and re‑score R² vs observed labels.\n"
      "  A *faithful* explanation targets genuinely important pixels ⇒ R² **drops fast**. We summarise by AOPC.\n")

def plot_deletion_r2(title, fractions, r2_base, r2_list):
    plt.figure(figsize=(5.8, 4))
    plt.plot([f * 100 for f in fractions], [r2_base] * len(fractions), linestyle="--", label="Baseline R²")
    plt.plot([f * 100 for f in fractions], r2_list, marker="o", label="R² after deletion")
    plt.xlabel("Top-|SHAP| pixels removed (%)")
    plt.ylabel("R² vs observed labels (test subset)")
    plt.title(title); plt.grid(True, alpha=0.3); plt.legend(); plt.tight_layout(); plt.show()
    drops = (r2_base - np.array(r2_list))
    aopc = area_trapezoid(drops, np.array(fractions))
    print(f"{title} — AOPC (higher → better fidelity): {aopc:.4f}")
    return aopc

# --- OLS deletion
y0_ols_pred = ols.predict(X_ols_test)
r2_base_ols = r2_score(y_obs_test_sel, y0_ols_pred)
order_ols = np.argsort(-np.abs(shap_ols_flat), axis=1)
r2_list_ols_del = []
mu = X_bg_ols.mean(axis=0)
for frac in fractions:
    k = max(1, int(frac * X_ols_test.shape[1]))
    X_del = X_ols_test.copy()
    for i in range(X_del.shape[0]):
        X_del[i, order_ols[i, :k]] = mu[order_ols[i, :k]]
    r2_list_ols_del.append(r2_score(y_obs_test_sel, ols.predict(X_del)))
aopc_ols = plot_deletion_r2("Deletion curve — OLS (R² drop)", fractions, r2_base_ols, r2_list_ols_del)

# --- Ridge deletion
y0_ridge_pred = final_ridge.predict(X_flat[test_sel_idx])
r2_base_ridge = r2_score(y_obs_test_sel, y0_ridge_pred)
order_ridge = np.argsort(-np.abs(shap_ridge_scaled), axis=1)
r2_list_ridge_del = []
Z_test = scaler.transform(X_flat[test_sel_idx])
for frac in fractions:
    k = max(1, int(frac * Z_test.shape[1]))
    Z_del = Z_test.copy()
    for i in range(Z_del.shape[0]):
        Z_del[i, order_ridge[i, :k]] = 0.0  # 0 == scaled mean
    y_pred_del = (Z_del @ ridge.coef_.ravel() + ridge.intercept_)
    r2_list_ridge_del.append(r2_score(y_obs_test_sel, y_pred_del))
aopc_ridge = plot_deletion_r2("Deletion curve — Ridge (R² drop)", fractions, r2_base_ridge, r2_list_ridge_del)

# --- CNN deletion (image channel only; coords intact)
y_train_obs = y_obs_all[idx_train]
y_mean = float(np.mean(y_train_obs)); y_std = float(np.std(y_train_obs) + 1e-8)
def from_z(z): return z * y_std + y_mean

def cnn_predict_y(_model, X4):
    _model.eval()
    with torch.no_grad():
        z = _model(torch.from_numpy(X4).to(DEVICE)).cpu().numpy().squeeze(-1)
    return from_z(z)

y0_cnn_pred = cnn_predict_y(shap_model, X_cnn_test_4)
r2_base_cnn = r2_score(y_obs_test_sel, y0_cnn_pred)
n, _, h, w = X_cnn_test_4.shape; D = h * w
order_cnn = np.argsort(-np.abs(shap_cnn_imgch.reshape(n, D)), axis=1)

r2_list_cnn_del = []
for frac in fractions:
    k = max(1, int(frac * D))
    X_mod = X_cnn_test_4.copy()
    for i in range(n):
        idx = order_cnn[i, :k]; rr, cc = idx // w, idx % w
        X_mod[i, 0, rr, cc] = 0.0
    y_pred_mod = cnn_predict_y(shap_model, X_mod)
    r2_list_cnn_del.append(r2_score(y_obs_test_sel, y_pred_mod))
aopc_cnn = plot_deletion_r2(f"Deletion curve — CNN [{BEST_BASELINE}] (R² drop)", fractions, r2_base_cnn, r2_list_cnn_del)

# -----------------------------
# Bounding‑box analysis over top‑|SHAP| within the shell
#   • Robust against shape issues (always flattens indices safely).
#   • Takes the **largest weighted connected component** before boxing.
# -----------------------------
def circular_coverage(angles_rad):
    """Minimal angular span (radians) covering all given angles on the circle."""
    if angles_rad.size == 0:
        return 0.0
    a = np.sort(angles_rad)
    diffs = np.diff(np.concatenate([a, a[:1] + 2*np.pi]))
    max_gap = np.max(diffs)
    return float(2*np.pi - max_gap)

def _largest_component(mask2d, weights2d=None):
    """
    8‑connected components on a boolean 2‑D mask.
    Returns the mask of the component with largest total weight (or largest size if weights is None).
    """
    M = np.ascontiguousarray(mask2d.astype(bool))
    H_, W_ = M.shape
    visited = np.zeros_like(M, dtype=bool)
    best_weight = -1.0
    best_comp = None

    # Precompute neighbor offsets (8‑connectivity)
    nbrs = [(-1,-1),(-1,0),(-1,1),(0,-1),(0,1),(1,-1),(1,0),(1,1)]

    # Iterate over all True pixels
    ys_, xs_ = np.where(M)
    for y0, x0 in zip(ys_, xs_):
        if visited[y0, x0]:
            continue
        # depth-first flood fill using an explicit stack
        stack = [(y0, x0)]
        visited[y0, x0] = True
        comp_pixels = [(y0, x0)]
        while stack:
            y, x = stack.pop()
            for dy, dx in nbrs:
                yy, xx = y + dy, x + dx
                if 0 <= yy < H_ and 0 <= xx < W_ and (not visited[yy, xx]) and M[yy, xx]:
                    visited[yy, xx] = True
                    stack.append((yy, xx))
                    comp_pixels.append((yy, xx))
        # weight for this component
        if weights2d is not None:
            w = float(np.sum([weights2d[y, x] for (y, x) in comp_pixels]))
        else:
            w = float(len(comp_pixels))
        if w > best_weight:
            best_weight = w
            best_comp = comp_pixels

    comp_mask = np.zeros_like(M, dtype=bool)
    if best_comp is not None:
        ys_c, xs_c = zip(*best_comp)
        comp_mask[ys_c, xs_c] = True
    return comp_mask

def compute_bbox_metrics(stack, model_name, imgs_for_overlay=None, draw_overlays=True, overlay_n=6,
                         p_top=0.01):
    """
    For each SHAP map in `stack`, within the shell:
      • select the TOP p_top fraction of |SHAP| pixels (non-negative kth),
      • take the **largest weighted connected component** of that mask,
      • compute the tight axis-aligned bounding box,
      • report area fraction of the shell covered by the box,
               fraction of total |SHAP| in shell captured by the box,
               angular coverage (deg) of selected pixels,
               radial width normalised by shell thickness.
    Returns a dict of per-sample arrays and prints robust summarised stats.
    """
    area_fracs = []
    shap_shares = []
    ang_degs    = []
    rad_widths  = []

    shell_area = float(mask_shell.sum())
    shell_thick = (OUTER_R - INNER_R)

    overlays_done = 0
    if imgs_for_overlay is None:
        imgs_for_overlay = []

    for i, S in enumerate(stack):
        # ensure 2‑D map
        S2 = np.asarray(S)
        if S2.ndim != 2:
            # drop singleton axes (e.g. a trailing output dimension), then enforce (H, W);
            # reshape raises if the map does not hold exactly H*W values
            S2 = np.squeeze(S2).reshape(H, W)
        Sabs = np.abs(S2)

        vals = Sabs[mask_shell].ravel()

        if vals.size == 0 or not np.any(vals > 0):
            # skip gracefully
            continue

        # TOP‑p% threshold (non‑negative kth)
        k = max(1, int(np.ceil(p_top * vals.size)))
        kth = vals.size - k
        if kth < 0: kth = 0
        thr = np.partition(vals, kth)[kth]

        # binary mask of top‑p% within shell
        M_shell = (Sabs >= thr) & mask_shell

        # Focus on **largest weighted component** (weights = |SHAP|)
        M_sel = _largest_component(M_shell, weights2d=Sabs)

        # fallback: if component is empty (pathological), take the single max pixel in shell
        if not np.any(M_sel):
            shell_idx = np.column_stack(np.where(mask_shell))
            # argmax over shell pixels
            argmax_flat = np.argmax(Sabs[mask_shell])
            yx = shell_idx[argmax_flat]
            M_sel = np.zeros((H, W), dtype=bool)
            M_sel[yx[0], yx[1]] = True

        # coords of selected pixels (robust way, works for any shape)
        coords = np.column_stack(np.where(M_sel))
        y_min, y_max = int(coords[:, 0].min()), int(coords[:, 0].max())
        x_min, x_max = int(coords[:, 1].min()), int(coords[:, 1].max())

        # area fraction of shell covered by the BOX ∩ shell
        box_mask = np.zeros((H, W), dtype=bool)
        box_mask[y_min:y_max+1, x_min:x_max+1] = True
        box_shell = np.logical_and(box_mask, mask_shell)
        area_frac = float(box_shell.sum() / (shell_area + 1e-12))
        area_fracs.append(area_frac)

        # |SHAP| share captured by the BOX within the shell
        shap_shell_total = float(Sabs[mask_shell].sum() + 1e-12)
        shap_in_box = float(Sabs[box_shell].sum())
        shap_share = shap_in_box / shap_shell_total
        shap_shares.append(shap_share)

        # angular coverage (deg) of SELECTED pixels
        ang = theta[M_sel]
        ang_span = circular_coverage(ang.ravel())
        ang_deg = float(ang_span * 180.0 / np.pi)
        ang_degs.append(ang_deg)

        # radial width normalised by shell thickness (using selected pixels)
        r_sel = R[M_sel]
        rad_w = float((r_sel.max() - r_sel.min()) / (shell_thick + 1e-12))
        rad_widths.append(rad_w)

        # overlays (draw box + highlight selected component)
        if draw_overlays and (overlays_done < overlay_n) and (i < len(imgs_for_overlay)):
            fig, ax = plt.subplots(1, 1, figsize=(3.2, 3.2))
            ax.imshow(imgs_for_overlay[i], cmap="gray", vmin=0, vmax=1)
            rect = patches.Rectangle((x_min, y_min), x_max - x_min + 1, y_max - y_min + 1,
                                     linewidth=1.8, edgecolor='lime', facecolor='none')
            ax.add_patch(rect)
            # faint fill for the selected component
            M_alpha = np.zeros((H, W), dtype=float)
            M_alpha[M_sel] = 0.5
            ax.imshow(M_alpha, cmap="Reds", alpha=0.25, vmin=0, vmax=1)
            ax.set_title(f"{model_name} — bbox on top-{int(p_top*100)}% |SHAP| (largest component)", fontsize=9)
            ax.axis("off")
            plt.tight_layout(); plt.show()
            overlays_done += 1

    # Summaries
    area_fracs = np.array(area_fracs, dtype=float)
    shap_shares = np.array(shap_shares, dtype=float)
    ang_degs    = np.array(ang_degs, dtype=float)
    rad_widths  = np.array(rad_widths, dtype=float)

    def summ(name, arr, unit=""):
        if arr.size == 0:
            print(f"  {name}: n=0")
            return (np.nan, np.nan, np.nan)
        print(f"  {name}: median={np.nanmedian(arr):.3f}{unit}  "
              f"IQR=({np.nanpercentile(arr,25):.3f}{unit}, {np.nanpercentile(arr,75):.3f}{unit})")
        return (float(np.nanmedian(arr)),
                float(np.nanpercentile(arr,25)),
                float(np.nanpercentile(arr,75)))

    print(f"\n[{model_name}] Bounding‑box metrics on top‑|SHAP| within the shell (largest component; p_top={p_top*100:.1f}%):")
    s1 = summ("Area fraction of shell (BOX∩shell / shell)", area_fracs)
    s2 = summ("|SHAP| share captured (BOX∩shell / shell)",   shap_shares)
    s3 = summ("Angular coverage (deg) of selected pixels",     ang_degs, unit="°")
    s4 = summ("Radial width / shell thickness",                rad_widths)

    return {
        "area_frac": area_fracs, "shap_share": shap_shares,
        "angle_deg": ang_degs, "rad_width": rad_widths,
        "summary": {"area_frac": s1, "shap_share": s2, "angle_deg": s3, "rad_width": s4}
    }

# --- Run bounding‑box metrics for each model (top 1% by default)
P_TOP = 0.01
bbox_stats_ols   = compute_bbox_metrics(shap_ols_imgs,   "OLS",   imgs_for_overlay=test_imgs[:8], draw_overlays=False, p_top=P_TOP)
bbox_stats_ridge = compute_bbox_metrics(shap_ridge_imgs, "Ridge", imgs_for_overlay=test_imgs[:8], draw_overlays=False, p_top=P_TOP)
bbox_stats_cnn   = compute_bbox_metrics(shap_cnn_imgch,  f"CNN [{BEST_BASELINE}]",
                                        imgs_for_overlay=test_imgs[:8], draw_overlays=True, p_top=P_TOP)

print("\n[Reading the bounding‑box summaries]")
print("• A small **Area fraction** with a large **|SHAP| share** and **low angular coverage** means the model relies on a tight, local cue.\n"
      "• For the *hard/star* dataset, the CNN should have the **smallest boxes** and **tightest angles** among the three, "
      "while OLS/Ridge remain diffuse.\n"
      "• For the *simple* dataset (no star), the CNN’s boxes will typically hug thin arcs on the inner/outer edges; "
      "angles will be modest, not full‑ring, and radial width < 0.5.\n")

# ================================
# Global star-capture analysis (robust 2-D handling + enrichment)
# ================================
import numpy as np
import matplotlib.pyplot as plt

print("\n[Global star-capture analysis]")

# --- Helpers to guarantee 2-D (H,W) arrays and clean values
def _ensure_hw(x):
    a = np.asarray(x)
    a = np.squeeze(a)
    if a.ndim != 2:
        # Force-shape to (H, W) if a weird singleton remains
        a = a.reshape(H, W)
    # Clean NaN/Inf for safety
    a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
    return a

# Reuse ring geometry already defined: R, theta, mask_shell, INNER_R, OUTER_R
R_mid   = 0.5 * (INNER_R + OUTER_R)
R_sigma = 0.18 * (OUTER_R - INNER_R)
RING_W  = np.exp(-0.5 * ((R - R_mid) / (R_sigma + 1e-8))**2) * mask_shell
shell_area = float(mask_shell.sum())
shell_thick = (OUTER_R - INNER_R)

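def _box_blur3(a):
    # 3x3 mean filter with reflect padding, computed as the average of nine shifted views
    a = _ensure_hw(a).astype(float)
    h, w = a.shape
    tmp = np.pad(a, 1, mode="reflect")
    out = np.zeros_like(a)
    for dy in range(3):
        for dx in range(3):
            out += tmp[dy:dy + h, dx:dx + w]
    return out / 9.0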

def find_star_proxy_center(img):
    # Ring-weighted, high-tail emphasis: keep only above-median intensity on the shell
    a = _ensure_hw(img)                       # cleaned 2-D copy (NaN/Inf replaced by 0)
    z = np.clip(a - float(np.median(a)), 0.0, None)
    m = _box_blur3(z * RING_W)
    yi, xi = np.unravel_index(np.argmax(m), m.shape)
    return int(yi), int(xi)

def disc_mask(yc, xc, radius=4):
    yy, xx = np.ogrid[:H, :W]
    return ((yy - yc)**2 + (xx - xc)**2) <= (radius*radius)

def circular_coverage(angles_rad):
    if angles_rad.size == 0:
        return 0.0
    a = np.sort(angles_rad)
    diffs = np.diff(np.concatenate([a, a[:1] + 2*np.pi]))
    max_gap = np.max(diffs)
    return float(2*np.pi - max_gap)

def circ_mean_angle(angles, weights=None):
    if angles.size == 0:
        return np.nan
    if weights is None:
        weights = np.ones_like(angles, dtype=float)
    c = np.sum(weights * np.cos(angles))
    s = np.sum(weights * np.sin(angles))
    return float(np.arctan2(s, c))

def circ_abs_diff_deg(a, b):
    d = np.abs((a - b + np.pi) % (2*np.pi) - np.pi)
    return float(d * 180.0 / np.pi)

# ---- robust 8-connected component on 2-D boolean masks
def _largest_component(mask2d, weights2d=None):
    M = _ensure_hw(mask2d).astype(bool)
    Wts = _ensure_hw(weights2d) if weights2d is not None else None
    H_, W_ = M.shape
    visited = np.zeros_like(M, dtype=bool)
    best_weight = -1.0
    best_comp = None
    nbrs = [(-1,-1),(-1,0),(-1,1),(0,-1),(0,1),(1,-1),(1,0),(1,1)]
    ys_, xs_ = np.where(M)
    for y0, x0 in zip(ys_, xs_):
        if visited[y0, x0]:
            continue
        stack = [(y0, x0)]
        visited[y0, x0] = True
        comp_pixels = [(y0, x0)]
        while stack:
            y, x = stack.pop()
            for dy, dx in nbrs:
                yy, xx = y + dy, x + dx
                if 0 <= yy < H_ and 0 <= xx < W_ and (not visited[yy, xx]) and M[yy, xx]:
                    visited[yy, xx] = True
                    stack.append((yy, xx))
                    comp_pixels.append((yy, xx))
        if Wts is not None:
            w = float(np.sum([Wts[y, x] for (y, x) in comp_pixels]))
        else:
            w = float(len(comp_pixels))
        if w > best_weight:
            best_weight = w
            best_comp = comp_pixels
    comp_mask = np.zeros_like(M, dtype=bool)
    if best_comp:
        ys_c, xs_c = zip(*best_comp)
        comp_mask[ys_c, xs_c] = True
    return comp_mask

def largest_component_mask(abs_map, top_frac=0.01):
    S = _ensure_hw(abs_map)
    vals = S[mask_shell].ravel()
    if vals.size == 0:
        return np.zeros((H, W), dtype=bool)
    vals = np.nan_to_num(vals, nan=0.0, posinf=0.0, neginf=0.0)
    k = max(1, int(np.ceil(top_frac * vals.size)))
    kth = max(0, vals.size - k)           # non-negative kth
    thr = np.partition(vals, kth)[kth]
    M_shell = (S >= thr) & mask_shell
    return _largest_component(M_shell, weights2d=S)

def global_star_capture(shap_stack, model_name, imgs_subset, disc_radius=4):
    shares, enrichments, ang_errs = [], [], []
    for i in range(len(shap_stack)):
        img = _ensure_hw(imgs_subset[i])
        S   = np.abs(_ensure_hw(shap_stack[i]))

        # (1) star disc
        ys_, xs_ = find_star_proxy_center(img)
        D = disc_mask(ys_, xs_, radius=disc_radius)
        disc_shell = np.logical_and(D, mask_shell)

        # (2) |SHAP| share inside disc (normalised by |SHAP| on shell)
        denom = float(S[mask_shell].sum() + 1e-12)
        numer = float(S[disc_shell].sum())
        share = numer / denom
        shares.append(share)

        # (3) enrichment vs area fraction of disc in the shell
        area_frac_disc = float(disc_shell.sum() / (shell_area + 1e-12))
        enrichments.append(share / (area_frac_disc + 1e-12))

        # (4) angular alignment vs largest |SHAP| component
        M_sel = largest_component_mask(S, top_frac=0.01)
        if np.any(M_sel):
            th_star = float(theta[ys_, xs_])
            ww = S[M_sel]
            th_comp = circ_mean_angle(theta[M_sel], weights=ww / (ww.sum() + 1e-12))
            ang_errs.append(circ_abs_diff_deg(th_star, th_comp))
        else:
            ang_errs.append(np.nan)

    shares = np.array(shares, dtype=float)
    enrichments = np.array(enrichments, dtype=float)
    ang_errs = np.array(ang_errs, dtype=float)

    def summarise(name, arr, unit=""):
        nfin = int(np.sum(np.isfinite(arr)))
        print(f"  {model_name} — {name}: "
              f"median={np.nanmedian(arr):.3f}{unit} | "
              f"IQR=({np.nanpercentile(arr,25):.3f}{unit}, {np.nanpercentile(arr,75):.3f}{unit}) | n={nfin}")

    print(f"\n[{model_name}] Star-disc capture (radius={disc_radius}px), enrichment, and angular alignment")
    summarise("SHAP share inside star disc", shares)
    summarise("enrichment (share / area)",   enrichments)
    summarise("abs angular error vs star (deg)", ang_errs, unit="°")
    return shares, enrichments, ang_errs

# Run on the same subset used for SHAP (test_imgs / shap_*_imgs)
shares_ols,   enrich_ols,   ang_ols   = global_star_capture(shap_ols_imgs,   "OLS",   test_imgs, disc_radius=4)
shares_ridge, enrich_ridge, ang_ridge = global_star_capture(shap_ridge_imgs, "Ridge", test_imgs, disc_radius=4)
shares_cnn,   enrich_cnn,   ang_cnn   = global_star_capture(shap_cnn_imgch,  f"CNN [{BEST_BASELINE}]", test_imgs, disc_radius=4)

# Optional: global comparison plots
plt.figure(figsize=(6,4))
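# boxplot's `labels` keyword is deprecated in recent Matplotlib (renamed `tick_labels`); plt.xticks works across versions
plt.boxplot([shares_ols, shares_ridge, shares_cnn], showmeans=True)
plt.xticks([1, 2, 3], ["OLS", "Ridge", "CNN"])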
plt.ylabel("Star-disc |SHAP| share (fraction of |SHAP| in shell)")
plt.title("Global star capture across explained test subset")
plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()

plt.figure(figsize=(6,4))
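# boxplot's `labels` keyword is deprecated in recent Matplotlib (renamed `tick_labels`); plt.xticks works across versions
plt.boxplot([enrich_ols, enrich_ridge, enrich_cnn], showmeans=True)
plt.xticks([1, 2, 3], ["OLS", "Ridge", "CNN"])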
plt.ylabel("Enrichment (|SHAP| share / area fraction of disc)")
plt.title("Global enrichment of |SHAP| at the star location")
plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()


# -----------------------------
# CNN baseline sensitivity (train subset vs mean baseline)
# -----------------------------
torch.set_grad_enabled(True)
bg_train_t = build_baseline("train_subset", m=64)
bg_mean_t  = build_baseline("mean", m=1)
e_train = shap.DeepExplainer(shap_model, bg_train_t)
e_mean  = shap.DeepExplainer(shap_model, bg_mean_t)
vals_a = safe_shap_values(e_train, X_test_t)
vals_b = safe_shap_values(e_mean, X_test_t)
torch.set_grad_enabled(False)
A = (vals_a[0] if isinstance(vals_a, list) else vals_a)[:, 0]
B = (vals_b[0] if isinstance(vals_b, list) else vals_b)[:, 0]
if HAVE_SPEARMAN:
    corrs = []
    for i in range(A.shape[0]):
        a = np.abs(A[i]).ravel(); b = np.abs(B[i]).ravel()
        if (np.std(a) < 1e-12) or (np.std(b) < 1e-12):
            corrs.append(np.nan)
        else:
            corrs.append(spearmanr(a, b, nan_policy="omit").correlation)
    corrs = np.array(corrs)
    print("\n[Baseline sensitivity — CNN]")
    print(f"  Spearman rank correlation of |SHAP| (train‑baseline vs mean‑baseline): "
          f"median={np.nanmedian(corrs):.3f}, IQR=({np.nanpercentile(corrs, 25):.3f}, {np.nanpercentile(corrs, 75):.3f})")
else:
    print("\n[Baseline sensitivity — CNN]\n  SciPy not available; install SciPy to report Spearman correlation.")

# ==========================================================
# Polar superpixels / regional analysis
# ==========================================================
print("\n[Polar superpixels / regional analysis]")
NB_RADIAL = 10; NB_THETA = 16
rad_edges  = np.linspace(0, R.max() + 1e-6, NB_RADIAL + 1)
theta_edges= np.linspace(0, 2 * np.pi, NB_THETA + 1)
rad_bin = np.digitize(R, rad_edges) - 1
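# Wrap angles to [0, 2π) first: if theta comes from arctan2 (range [-π, π]), negative angles would otherwise fall into bin -1 and be dropped
theta_bin = np.digitize(np.mod(theta, 2 * np.pi), theta_edges) - 1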
rad_bin[rad_bin == NB_RADIAL] = NB_RADIAL - 1
theta_bin[theta_bin == NB_THETA] = NB_THETA - 1

def polar_bin_aggregate(abs_shap_map):
    out = np.zeros((NB_RADIAL, NB_THETA), dtype=np.float64)
    counts = np.zeros_like(out)
    for r in range(NB_RADIAL):
        for t in range(NB_THETA):
            m = (rad_bin == r) & (theta_bin == t)
            if np.any(m):
                out[r, t] = abs_shap_map[m].mean()
                counts[r, t] = m.sum()
    return out, counts

def polar_summary(stack, model_name):
    mats = []
    for s in stack:
        M, _ = polar_bin_aggregate(np.abs(s))
        mats.append(M)
    Mmean = np.mean(mats, axis=0)
    ring_strength = Mmean.mean(axis=1)
    r_idx = int(np.argmax(ring_strength))
    r_lo, r_hi = rad_edges[r_idx], rad_edges[r_idx + 1]
    row = Mmean[r_idx, :]
    cv = float(np.std(row) / (np.mean(row) + 1e-12))

    plt.figure(figsize=(6, 3.8))
    plt.imshow(Mmean, aspect="auto", origin="lower", cmap="magma")
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.yticks(np.arange(NB_RADIAL), [f"{rad_edges[i]:.0f}-{rad_edges[i + 1]:.0f}" for i in range(NB_RADIAL)])
    plt.xticks(np.arange(0, NB_THETA, 4), [f"{int(360 * t / NB_THETA)}°" for t in range(0, NB_THETA, 4)])
    plt.xlabel("Angle (θ)"); plt.ylabel("Radius band (px)")
    plt.title(f"{model_name} — polar mean |SHAP|\n(peak ring ~ {r_lo:.0f}–{r_hi:.0f}px; anisotropy CV={cv:.3f})")
    plt.tight_layout(); plt.show()

    print(f"{model_name} — peak ring: ~{r_lo:.0f}–{r_hi:.0f} px, anisotropy (CV across angle) = {cv:.3f}")
    return (r_idx, (r_lo, r_hi), cv, Mmean)

r_ols, band_ols, cv_ols, M_ols   = polar_summary(shap_ols_imgs,   "OLS")
r_rid, band_rid, cv_rid, M_rid   = polar_summary(shap_ridge_imgs, "Ridge")
r_cnn, band_cnn, cv_cnn, M_cnn   = polar_summary(shap_cnn_imgch,  f"CNN [{BEST_BASELINE}]")

print("\n[Interpretation — polar summaries]\n"
      "• Peak ring should sit on the shell radius; CNN often shows sharper peaks on inner/outer edges.\n"
      "• Anisotropy (CV) near zero ⇒ isotropic ring; higher CV ⇒ focus on arcs (useful for local defects).\n")

# -----------------------------
# CNN regional ablation + effect sizes (optional sanity)
# -----------------------------
print("\n[CNN regional ablation + effect sizes]")
K_REGIONS = 12
flat_scores = M_cnn.ravel()
top_idx = np.argsort(-flat_scores)[:K_REGIONS]
top_regions = [(i // NB_THETA, i % NB_THETA) for i in top_idx]

mask_top = np.zeros((H, W), dtype=bool)
for (rr, tt) in top_regions:
    mask_top |= ((rad_bin == rr) & (theta_bin == tt))

X_cnn_test_4_ablate = X_cnn_test_4.copy()
X_cnn_test_4_ablate[:, 0, :, :] = np.where(mask_top[None, :, :], 0.0, X_cnn_test_4_ablate[:, 0, :, :])

y_pred_cnn_base = y0_cnn_pred
y_pred_cnn_abl  = cnn_predict_y(shap_model, X_cnn_test_4_ablate)

r2_cnn_base = r2_score(y_obs_test_sel, y_pred_cnn_base)
r2_cnn_abl  = r2_score(y_obs_test_sel, y_pred_cnn_abl)
print(f"  CNN R² (baseline on this subset): {r2_cnn_base:.4f}")
print(f"  CNN R² after ablating top‑{K_REGIONS} polar regions: {r2_cnn_abl:.4f}")
print(f"  R² drop: {r2_cnn_base - r2_cnn_abl:.4f} (larger ⇒ those regions matter)\n")

# -----------------------------
# Augmentation‑invariance sanity check (flips)
# -----------------------------
print("\n[Augmentation‑invariance sanity check — CNN SHAP under flips]")
def flip_h(arr):  # horizontal flip
    return arr[..., :, ::-1]

Xt_flip_np = flip_h(X_cnn_test_4.copy()).copy()
Xt_flip = torch.from_numpy(Xt_flip_np).to(DEVICE).requires_grad_(True)
torch.set_grad_enabled(True)
e_best = shap.DeepExplainer(shap_model, build_baseline(BEST_BASELINE, m=(64 if BEST_BASELINE == "train_subset" else 1)))
vals_flip = safe_shap_values(e_best, Xt_flip)
torch.set_grad_enabled(False)
sh_flip = vals_flip[0] if isinstance(vals_flip, list) else vals_flip
sh_flip_img_unflipped = flip_h(sh_flip[:, 0]).copy()

if HAVE_SPEARMAN:
    corrs = []
    for i in range(N_EXPLAIN):
        a = np.abs(shap_cnn_imgch[i]).ravel()
        b = np.abs(sh_flip_img_unflipped[i]).ravel()
        if np.std(a) < 1e-12 or np.std(b) < 1e-12:
            corrs.append(np.nan)
        else:
            corrs.append(spearmanr(a, b, nan_policy="omit").correlation)
    corrs = np.array(corrs)
    print(f"  Flip‑invariance — Spearman(|SHAP|, original vs unflipped‑flip): "
          f"median={np.nanmedian(corrs):.3f}, IQR=({np.nanpercentile(corrs, 25):.3f}, {np.nanpercentile(corrs, 75):.3f})")
else:
    print("  (Install SciPy to compute Spearman rank correlations.)")

# -----------------------------
# Optional: Integrated Gradients (IG) for triangulation (no external libs)
# -----------------------------
print("\n[Integrated Gradients (IG) — optional triangulation]")
def integrated_gradients(model_torch, X_np, baseline_np, steps=32):
    was_enabled = torch.is_grad_enabled()
    try:
        torch.set_grad_enabled(True)
        model_torch.eval()
        X = torch.from_numpy(np.ascontiguousarray(X_np)).to(DEVICE)
        B = torch.from_numpy(np.ascontiguousarray(baseline_np)).to(DEVICE)
        attr = torch.zeros_like(X)
        for s in range(1, steps + 1):
            t = s / steps
            Xi = B + t * (X - B)
            Xi.requires_grad_(True)
            out = model_torch(Xi)
            grads = torch.autograd.grad(out.sum(), Xi, retain_graph=False, create_graph=False)[0]
            attr += grads
        attr = (X - B) * attr / steps
        return attr.detach().cpu().numpy()
    finally:
        torch.set_grad_enabled(was_enabled)

X_base_ig = X_cnn_test_4.copy(); X_base_ig[:, 0] = 0.0
attr_ig = integrated_gradients(shap_model, X_cnn_test_4, X_base_ig, steps=32)[:, 0]

if HAVE_SPEARMAN:
    igcorrs = []
    for i in range(N_EXPLAIN):
        a = np.abs(attr_ig[i]).ravel()
        b = np.abs(shap_cnn_imgch[i]).ravel()
        if np.std(a) < 1e-12 or np.std(b) < 1e-12:
            igcorrs.append(np.nan)
        else:
            igcorrs.append(spearmanr(a, b, nan_policy="omit").correlation)
    igcorrs = np.array(igcorrs)
    print(f"  IG vs Deep SHAP — Spearman(|attr|): median={np.nanmedian(igcorrs):.3f} "
          f"IQR=({np.nanpercentile(igcorrs, 25):.3f}, {np.nanpercentile(igcorrs, 75):.3f})")
else:
    print("  (Install SciPy to compute Spearman rank correlations.)")

# -----------------------------
# Final conclusions
# -----------------------------
print("\n=== Conclusions ===")
print("• OLS: clean ring; Ridge: more diffused ring; CNN: crisp double band on inner/outer edges "
      "with local arcs where defects (or the star) dominate.")
print("• Deletion curves (AOPC) quantify faithfulness; higher AOPC → better.")
print("• Bounding‑box metrics (largest component) — CNN concentrates |SHAP| into **smaller**, more **angularly tight** regions, "
      "capturing a **larger share** of |SHAP| with less area, especially on the hard dataset.")
print("• Baseline matters; this cell auto‑selects one by IoU@k + AOPC. IG triangulation is included as a sanity check.")


# Interpreting the SHAP results on the **simple** hollow‑sphere dataset (no star)

This note interprets the SHAP outputs produced for three models trained on the *simple* dataset:

- **OLS**: a plain linear regressor, $\,\hat y = Xw + b$.
- **Ridge**: a regularised linear regressor, $\,\hat y = Xw + b$ with an $L_2$ penalty on $w$.
- **CNN**: a small convolutional network that receives a 4‑channel input (image + $x, y, r$). It predicts a $z$‑scored target and is mapped back to $y$ at evaluation time.

The dataset contains greyscale images of a hollow metal shell. The label $y$ mainly scales the **whole shell** (global density factor), with small variations from illumination and defects. There is *no* bright local star in this dataset; the only genuine signal is concentrated on the **ring** (inner/outer edges and shell interior).

The SHAP framework decomposes a prediction as
$$
f(x) - \mathbb{E}[f(X)] \;=\; \sum_{j=1}^D \phi_j(x),
$$
where $\phi_j(x)$ is the contribution (Shapley value) of feature $j$ for input $x$. For linear models with a mean baseline, this simplifies to
$$
\phi_j(x) = w_j\,(x_j - \mathbb{E}[X_j]).
$$
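
The linear identity above can be checked numerically. The sketch below uses random data, weights, and an intercept that are purely illustrative (none come from this notebook), and the helper name `linear_shap` is hypothetical. It assumes features are treated as independent, which is the setting in which the closed form holds.

```python
import numpy as np

def linear_shap(w, X_background, x):
    """Per-feature contributions w_j * (x_j - mean_j) for a linear model."""
    mu = X_background.mean(axis=0)
    return w * (x - mu)

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))     # hypothetical background data
w = rng.normal(size=5)            # hypothetical weights
b = 0.3                           # hypothetical intercept
x = X[0]

phi = linear_shap(w, X, x)
f_x = float(x @ w + b)
f_mean = float((X @ w + b).mean())          # E[f(X)] over the background
additivity_gap = abs(phi.sum() - (f_x - f_mean))
print(f"sum(phi) = {phi.sum():.6f}, f(x) - E[f(X)] = {f_x - f_mean:.6f}")
```

The two printed values agree up to floating-point error, which is the additivity property stated in the first equation.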

Below, each figure or metric is explained, followed by what “good” and “bad” patterns look like in this dataset.

---

## 1) Overlay panels: per‑image SHAP on top of the input

**What the panels show.** Each small image overlays a SHAP heatmap on the greyscale input; blue/orange indicate negative/positive contributions to $\hat y$. A strong, thin band on the **shell** is expected because $y$ scales shell intensity.

**Observed patterns.**

- **OLS:** the overlays show tight bands hugging the inner and outer radii. This is the textbook pattern for a linear model when the label is a global scale of the ring: pixels on the ring move $y$ the most.
- **Ridge:** the overlays look more diffuse. Attribution spills into the hollow centre and the outside background. This indicates the model is partly using nuisance correlations (illumination, offset) rather than focusing solely on the ring.
- **CNN:** the overlays show crisp arcs on the ring, often emphasising the inner and outer edges. This indicates that the network has learned to anchor on geometric structure (edges) rather than global brightness alone.

**How to read colours.** When the two edges differ in brightness (for example a dark inner edge and a bright outer edge), OLS often assigns them opposite signs: relative to the mean baseline, raising a pixel's intensity shifts the predicted scale in opposite directions on the two sides of the radial gradient. When assessing *where* the model looks, the *absolute* value matters most.

**Pitfalls.** Visual overlays are suggestive but not quantitative; when features are collinear (adjacent pixels on a ring), contributions are spread and can alternate sign. Use the absolute‑value summaries below to avoid being misled by colour.

---

## 2) Global mean $|\text{SHAP}|$ maps

**What the map shows.** Averaging $|\phi|$ across the explained test subset answers: *“On average, where in the image does the model place credit?”*

**Good vs bad.**

- A **good** map is an annulus: most energy lies on the shell; the hollow interior and the outside are dark.
- A **bad** map shows energy in the background or the hollow centre.

**Observed.**
- **OLS:** a clean annulus.
- **Ridge:** substantial energy in the hollow and outside, implying spurious reliance on global nuisance variation.
- **CNN:** a clean annulus, often with a slightly stronger inner rim, consistent with edge‑based features inside the network.

---

## 3) Radial profiles of $|\text{SHAP}|$

**Definition.** Pixels are binned by radius $r$ from the centre; the plot shows the mean $|\phi|$ per radius bin:
$$
\text{profile}(r) \;=\; \mathbb{E}[\,|\phi(x,y)| \mid \text{radius}(x,y)=r\,].
$$
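
A minimal sketch of this binning, assuming the image centre sits at $\big(\tfrac{H-1}{2},\tfrac{W-1}{2}\big)$ and equal-width radius bins; the function name `radial_profile` and the toy annulus are illustrative and may differ from the notebook's own implementation.

```python
import numpy as np

def radial_profile(abs_map, n_bins=50):
    """Mean of abs_map within equal-width radius bins around the image centre."""
    H, W = abs_map.shape
    cy, cx = (H - 1) / 2.0, (W - 1) / 2.0
    yy, xx = np.mgrid[:H, :W]
    r = np.hypot(yy - cy, xx - cx)
    edges = np.linspace(0.0, r.max() + 1e-6, n_bins + 1)
    idx = np.digitize(r.ravel(), edges) - 1            # bin index per pixel
    sums = np.bincount(idx, weights=abs_map.ravel(), minlength=n_bins)
    counts = np.bincount(idx, minlength=n_bins)
    centres = 0.5 * (edges[:-1] + edges[1:])
    # Empty bins are reported as NaN rather than 0
    profile = np.where(counts > 0, sums / np.maximum(counts, 1), np.nan)
    return centres, profile

# Toy map: attribution only on an annulus between 20 and 40 px.
H = W = 100
yy, xx = np.mgrid[:H, :W]
r = np.hypot(yy - 49.5, xx - 49.5)
toy = ((r > 20) & (r <= 40)).astype(float)
centres, prof = radial_profile(toy)
```

On the toy annulus the profile is zero inside the hollow, one across the shell, and zero again outside; a real $|\phi|$ map would show peaks where attribution concentrates.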

**Expected shape.** Two peaks at the inner and outer shell boundaries (near $20$ px and $40$ px) because $y$ depends on shell intensity.

**Observed.**
- **OLS:** a sharp peak near the inner boundary and a shoulder towards the outer boundary — a very clear “ring” signature.
- **Ridge:** no clear double‑peak; the curve is noisy and relatively flat, reflecting diffuse credit.
- **CNN:** a strong inner peak with a secondary outer peak — the desired pattern for an edge‑aware model.

**Pitfalls.** If the shell mask used for binning is too wide or mis‑centred, peaks blur. Always check geometry constants match the dataset.

---

## 4) Region‑based metrics (area‑biased vs area‑fair)

Two complementary summaries are reported.

- **Region share (area‑biased):** fraction of total $|\phi|$ mass inside each region (hollow / shell / outside).
- **Mean per pixel (area‑fair):** average $|\phi|$ per pixel *within* each region, comparable despite different areas.

**Numbers (provided).**

### OLS
- **Share:** Hollow **3.33%**, Shell **91.49%**, Outside **5.18%**.
- **Per pixel:** Hollow **0.000068**, Shell **0.000424**, Outside **0.000024**.
- **Interpretation:** Almost all attribution sits on the shell (good). Per‑pixel intensity in the shell is ≈ **18×** higher than outside ($0.000424/0.000024$), which is a strong localisation score.

### Ridge
- **Share:** Hollow **11.26%**, Shell **36.84%**, Outside **51.90%**.
- **Per pixel:** Hollow **0.000091**, Shell **0.000068**, Outside **0.000096**.
- **Interpretation:** More than half the mass is **outside**. Per‑pixel intensity is actually larger outside than on the shell (≈ **1.4×**), signalling reliance on background. This is a red flag for interpretability.

### CNN (baseline = `train_subset`)
- **Share:** Hollow **4.86%**, Shell **93.63%**, Outside **1.50%**.
- **Per pixel:** Hollow **0.000065**, Shell **0.000292**, Outside **0.000005**.
- **Interpretation:** Strong focus on the shell (excellent). Per‑pixel intensity in the shell is ≈ **58×** larger than outside, the crispest localisation of the three.

**Pitfalls.** “Share” is influenced by region area: the shell covers many pixels, so shares are naturally large. The per‑pixel metric controls for this and is often the better fairness check.

---

## 5) Concentration of importance (sparsity of $|\text{SHAP}|$)

**Definition.** The **concentration** at top‑$p\%$ is
$$
\text{Conc}_{p\%} \;=\;
\frac{\sum_{j \in \text{top }p\%\text{ by }|\phi|} |\phi_j|}
{\sum_{j} |\phi_j|}.
$$
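
As a hedged illustration of this definition, the sketch below computes the concentration for two extreme toy maps. The function name `top_p_concentration` and both maps are assumptions made for the example, not the notebook's code or data.

```python
import numpy as np

def top_p_concentration(phi_map, p=0.01):
    """Share of total |phi| carried by the top-p fraction of pixels."""
    a = np.abs(np.asarray(phi_map, dtype=float)).ravel()
    total = a.sum()
    if total <= 0:
        return float("nan")
    k = max(1, int(np.ceil(p * a.size)))
    top = np.partition(a, a.size - k)[a.size - k:]   # the k largest values
    return float(top.sum() / total)

uniform = np.ones((100, 100))
conc_uniform = top_p_concentration(uniform, p=0.05)   # flat map: close to p itself

spike = np.zeros((100, 100))
spike[10, 10] = 1.0
conc_spike = top_p_concentration(spike, p=0.01)       # all mass in one pixel
```

A perfectly flat map gives a concentration of about $p$, while a single-pixel spike gives 1; real attribution maps fall between these extremes.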

**Numbers (provided).**
- **OLS:** top‑1% → **29.33%**, top‑5% → **57.68%** (highly concentrated; a small fraction of pixels carries most of the credit).
- **Ridge:** top‑1% → **7.79%**, top‑5% → **23.32%** (diffuse and noisy credit).
- **CNN:** top‑1% → **20.13%**, top‑5% → **51.98%** (concentrated but spread along thin arcs).

**Interpretation.** Higher concentration indicates sharper, more local attributions. Extremely high values can also arise when the model latches onto a single‑pixel artefact, so always corroborate with the region metrics and overlays.

---

## 6) Alignment with ground truth shell (IoU and Precision at top‑$k$)

**Definition.** For the top‑$k\%$ $|\phi|$ mask $M$ and the shell mask $S$,
$$
\text{IoU} \;=\; \frac{|M \cap S|}{|M \cup S|}, \qquad
\text{Precision} \;=\; \frac{|M \cap S|}{|M|}.
$$
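
The sketch below shows one way to compute both quantities and why IoU stays small for sparse masks. The function name `topk_alignment`, the toy shell, and the random attribution values are assumptions for illustration only.

```python
import numpy as np

def topk_alignment(phi_map, shell_mask, k_frac=0.01):
    """IoU and precision of the top-k |phi| mask against a boolean shell mask."""
    a = np.abs(np.asarray(phi_map, dtype=float))
    n_top = max(1, int(np.ceil(k_frac * a.size)))
    # Select exactly n_top pixels by rank, so ties cannot inflate the mask.
    top_idx = np.argpartition(a.ravel(), a.size - n_top)[a.size - n_top:]
    M = np.zeros(a.size, dtype=bool)
    M[top_idx] = True
    M = M.reshape(a.shape)
    inter = np.logical_and(M, shell_mask).sum()
    union = np.logical_or(M, shell_mask).sum()
    iou = inter / union if union else float("nan")
    return float(iou), float(inter / M.sum())

# Toy case: all attribution lies on a 20-40 px annulus.
H = W = 100
yy, xx = np.mgrid[:H, :W]
r = np.hypot(yy - 49.5, xx - 49.5)
shell = (r > 20) & (r <= 40)
rng = np.random.default_rng(0)
phi = shell * (0.1 + rng.random((H, W)))
iou, precision = topk_alignment(phi, shell, k_frac=0.01)
print(f"IoU={iou:.3f}, precision={precision:.3f}")
```

Even with every selected pixel on the shell (precision 1), the IoU at 1% is only a few hundredths, because the mask covers a small fraction of the shell area.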

**Numbers (provided).**

- **OLS:** Precision **0.967** at 1% and **0.984** at 10%; IoU rises from **0.021 → 0.217**.
- **Ridge:** Precision **0.10 → 0.30**, IoU **0.002 → 0.058**.
- **CNN:** Precision **0.985–0.989**, IoU **0.022 → 0.217**.

**Interpretation.**
- Precision near 1.0 means the **top‐ranked** pixels almost all fall on the shell (good). The small IoU is expected because the top‑$k$ mask covers a tiny fraction of the shell area; union is large, intersection is small. Use precision as the main signal here.
- Ridge’s low precision confirms the diffuse overlays: many top‑ranked pixels land **off** the ring.

**Pitfalls.** IoU is not scale‑free: for very sparse masks, a model can be perfectly aligned yet score a small IoU. Report both precision and IoU to avoid misinterpretation.

---

## 7) Deletion curves and AOPC (faithfulness check)

**Procedure.** Iteratively zero out the top‑$p\%$ $|\phi|$ pixels in the **image channel** (keeping coordinate channels intact for the CNN) and re‑score on the same explained subset. The plot shows $R^2$ after deletion. The **AOPC** score is the area under the $R^2$ drop curve:
$$
\text{AOPC} \;=\; \int_0^{p_{\max}} \big(R^2_{\text{baseline}} - R^2_{\text{after deletion at }p}\big)\, dp.
$$
Higher AOPC means that removing the SHAP‑ranked pixels degrades the model more, i.e. the attribution is **more faithful**.

**Numbers (provided).**
- **CNN:** AOPC ≈ **0.1586** for the chosen baseline (`train_subset`).
- **OLS:** (from plot) AOPC ≈ **0.13** — noticeable drop, consistent with focused ring attributions.
- **Ridge:** AOPC ≈ **0.036** — relatively small; the model is less harmed by removing its top‑ranked pixels, consistent with diffuse/noisy SHAP.

**Baseline sensitivity.** For the CNN, alternative backgrounds give very similar fidelity: mean baseline AOPC ≈ **0.1623**, median ≈ **0.1582**, blurred mean ≈ **0.1550**. This closeness is a healthy sign; wildly different AOPCs across baselines suggest unstable explanations.

**Pitfalls.**
- Deletion creates **distribution shift** (images with blacked‑out pixels are not from the data distribution). This is a proxy test, not a proof.
- If the model uses global averages or coordinate channels, deleting image pixels may not fully disrupt its computation; interpret in that context.

---

## 8) Baseline bootstrapping (robustness across CNN backgrounds)

Using several random train‑subset backgrounds for Deep SHAP gives **stable** region shares for the CNN:
- Hollow/Shell/Outside ≈ **0.050±0.003 / 0.935±0.003 / 0.016±0.000**.

Small standard deviations indicate low sensitivity to which training images define the background distribution. This reduces the risk that results are an artefact of a single background choice.

---

## 9) Putting it together — model‑by‑model

**OLS (linear, no regularisation)**
- Visual and quantitative evidence agree: attributions concentrate on the **shell** (share ≈ 91.5%; precision ≈ 0.98).
- High concentration of $|\phi|$ (top‑5% captures ≈ 57.7%) and a radial profile with a sharp inner peak and an outer shoulder.
- Good deletion behaviour (AOPC ≈ 0.13).  
**Take‑home:** OLS is a good explainer on this dataset because the label is essentially a linear global scale of the ring.

**Ridge (linear + $L_2$)**
- Attribution mass drifts off the ring (outside share ≈ 51.9%), precision low (≤ 0.30).
- Radial profile lacks sharp peaks; concentration is low.  
**Take‑home:** The regulariser and standardisation damp useful contrasts and spread credit to nuisance signals, so explanations are less reliable here.

**CNN (edge‑aware, with coordinates)**
- Strong ring localisation (shell share ≈ 93.6%, precision ≈ 0.99) and crisp radial peaks.
- High concentration and the strongest AOPC among the three. Baseline bootstrapping shows stability.  
**Take‑home:** The CNN discovers physically meaningful cues (edges, shell arcs). Despite extra capacity, attribution remains local and faithful on this dataset.

---

## 10) Common traps and how to avoid them

- **Confusing sign with importance.** Use $|\phi|$ to assess *where* the model looks; the sign depends on the baseline and on how intensities map to $y$.
- **Over‑interpreting IoU.** Small IoU at very small $k$ is normal for thin structures; focus on precision and per‑pixel means.
- **Ignoring baselines.** Deep SHAP depends on a background distribution. Check several baselines and report stability (as done above).
- **Assuming deletion is causal.** Deletion scores are helpful but not causal proofs; combine with multiple diagnostics.
- **Geometry mismatch.** If the inner/outer radii used to define the shell masks are off, every radial/IoU metric degrades. Confirm geometry constants before comparing runs.

---

## 11) What “good” looks like on this dataset

- Overlay panels: thin arcs/bands exactly on the shell.
- Global mean $|\text{SHAP}|$: clean annulus.
- Radial profile: two distinct peaks near the known inner/outer radii.
- Region metrics: shell share $\gtrsim 90\%$, shell per‑pixel $|\phi|$ $\gg$ outside.
- Precision@k: $\approx 0.98$–$0.99$ for $k \in \{1,2,5,10\}\%$.
- Concentration: top‑1% $\gtrsim 20\%$, top‑5% $\gtrsim 50\%$.
- Deletion (AOPC): large and smooth $R^2$ drop.

Anything markedly different (e.g., energy in the background, flat radial profile, low precision) indicates reliance on nuisance signals or a mismatch between model and data.



# SHAP interpretations on the **simple (no‑star) dataset** (part 2)

This note interprets the SHAP outputs for OLS, Ridge, and the CNN when trained and explained on the **simple hollow‑sphere** dataset (no star). The aim is to read each visual and metric in context and to separate *global* behaviour from *local* behaviour.

---

## 1) Reading the figures at a glance

- **Overlays (grey image + red/blue SHAP):** red ≈ positive contribution to $\hat y$, blue ≈ negative; opacity encodes magnitude.
- **Global mean $|{\rm SHAP}|$:** average absolute attribution per pixel over the explained subset; reveals *where* the model tends to look.
- **Radial profile:** mean $|{\rm SHAP}|$ as a function of radius; inner and outer shell edges should appear as peaks.
- **Region shares vs shell:** split attributions across *hollow* (inside), *shell*, and *outside*.
- **IoU@k / Precision@k:** alignment of the model’s **top‑$k$%** $|{\rm SHAP}|$ pixels with the ground‑truth **shell mask**.
- **Deletion curve (AOPC):** remove the top‑$|{\rm SHAP}|$ pixels and re‑score $R^2$; faithful explanations produce faster drops.

---

## 2) Definitions (used throughout)

Let $S\in\mathbb{R}^{H\times W}$ be a SHAP map (image channel), and let $\Omega\subset\{1,\dots,H\}\times\{1,\dots,W\}$ be a region (hollow, shell, outside).

**Region share (area‑biased):**
$$
\mathrm{share}_\Omega(S)
=\frac{\sum_{(i,j)\in\Omega}|S_{ij}|}{\sum_{(i,j)}|S_{ij}|}.
$$

**Region mean (area‑fair):**
$$
\mathrm{mean}_\Omega(S)=\frac{1}{|\Omega|}\sum_{(i,j)\in\Omega}|S_{ij}|.
$$
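
To make the contrast between the two region summaries concrete, here is a small sketch on a deliberately flat attribution map. The helper `region_stats` and the region masks (radii 20 and 40 px on a $100\times100$ grid) are illustrative assumptions.

```python
import numpy as np

def region_stats(S, masks):
    """Return {name: (share, per_pixel_mean)} of |S| for each boolean mask."""
    a = np.abs(np.asarray(S, dtype=float))
    total = a.sum() + 1e-12
    return {name: (float(a[m].sum() / total), float(a[m].mean()))
            for name, m in masks.items()}

H = W = 100
yy, xx = np.mgrid[:H, :W]
r = np.hypot(yy - 49.5, xx - 49.5)
masks = {"hollow": r <= 20, "shell": (r > 20) & (r <= 40), "outside": r > 40}

flat = np.ones((H, W))        # uniform attribution everywhere
stats = region_stats(flat, masks)
for name, (share, mean) in stats.items():
    print(f"{name:8s} share={share:.3f}  per-pixel mean={mean:.3f}")
```

With uniform attribution the per-pixel mean is identical in all three regions, yet the outside region still takes roughly half of the share simply because it covers roughly half of the pixels. This is the area bias that the per-pixel mean removes.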

**Top‑$p$ concentration:**
$$
\mathrm{conc}_p(S)=\frac{\sum_{m\in\mathcal{T}_p}|S_m|}{\sum_m |S_m|},
$$
where $\mathcal{T}_p$ contains the top $p$ fraction of pixels by $|S|$.

**Alignment vs shell (IoU / Precision at $k\%$):**
$$
\mathrm{IoU}_k=\frac{|M_k\cap \text{Shell}|}{|M_k\cup \text{Shell}|},\qquad
\mathrm{Prec}_k=\frac{|M_k\cap \text{Shell}|}{|M_k|},
$$
where $M_k$ is the binary mask of the top‑$k\%$ $|S|$.
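The sketch below computes both alignment scores on a nominal shell ($R_0=20$, $R_1=40$) with attribution placed only on the shell. It is an illustration under assumptions: the helper names and the rounding rule for the top-$k\%$ count are hypothetical, and the SHAP map is random toy data.

```python
import numpy as np

def topk_mask(shap_map, k_percent):
    """Boolean mask of the top-k% pixels by |SHAP| (rounding rule is an assumption)."""
    abs_flat = np.abs(shap_map).ravel()
    n_top = max(1, int(round(abs_flat.size * k_percent / 100.0)))
    idx = np.argpartition(abs_flat, -n_top)[-n_top:]  # indices of the n_top largest values
    mask = np.zeros(abs_flat.size, dtype=bool)
    mask[idx] = True
    return mask.reshape(shap_map.shape)

def iou_and_precision(shap_map, shell, k_percent):
    m = topk_mask(shap_map, k_percent)
    inter = np.logical_and(m, shell).sum()
    union = np.logical_or(m, shell).sum()
    return float(inter / union), float(inter / m.sum())

# Nominal shell on a 100x100 grid; toy attribution lives only on the shell.
yy, xx = np.indices((100, 100))
r = np.hypot(xx - 49.5, yy - 49.5)
shell = (r > 20.0) & (r <= 40.0)
rng = np.random.default_rng(0)
S = np.where(shell, 0.1 + rng.random((100, 100)), 0.0)

iou, prec = iou_and_precision(S, shell, k_percent=1)
print(prec)                     # 1.0: every selected pixel is on the shell
print(iou, 100 / shell.sum())   # equal: IoU is capped by mask size / shell size
```

Even with perfect precision the IoU stays small at $k=1\%$, because the mask can cover only a small part of the shell.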

**Deletion AOPC** (area over the perturbation curve, i.e. the area under the performance‑drop curve):
$$
\mathrm{AOPC}
=\int_0^{f_{\max}}\big(R^2_{\text{base}}-R^2(f)\big)\,\mathrm{d}f
\;\;\approx\;\;\sum_\ell \tfrac12\,(d_{\ell}+d_{\ell+1})\,(f_{\ell+1}-f_\ell),
$$
with $d_\ell=R^2_{\text{base}}-R^2(f_\ell)$. Larger is better (faster degradation when removing truly important pixels).
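The trapezoidal sum can be written directly in NumPy. This is a minimal sketch; the function name `deletion_aopc` and the toy deletion curve are assumptions for illustration, not values from the experiments.

```python
import numpy as np

def deletion_aopc(fractions, r2_values, r2_base):
    """Trapezoidal area of the R^2 drop d(f) = r2_base - R^2(f) over deleted fraction f."""
    f = np.asarray(fractions, dtype=float)
    d = r2_base - np.asarray(r2_values, dtype=float)
    return float(np.sum(0.5 * (d[:-1] + d[1:]) * np.diff(f)))

# Toy curve: R^2 falls from 0.9 to 0.3 as 20% of pixels are removed.
f = [0.0, 0.1, 0.2]
r2 = [0.9, 0.5, 0.3]

# Drops are 0, 0.4, 0.6, so the area is 0.5*0.4*0.1 + 0.5*1.0*0.1 = 0.07.
print(round(deletion_aopc(f, r2, r2_base=0.9), 4))  # 0.07
```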

**Bounding‑box selection (top‑$p$% in shell):**
$$
M
=\operatorname{LargestComponent}\Big(\big\{|S|\ge t_p(S\,|\,\text{shell})\big\}\cap\text{Shell}\Big).
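$$
where $t_p(S\,|\,\text{shell})$ is the threshold that keeps the top $p\%$ of $|S|$ values among shell pixels.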

**Star‑disc share & enrichment** for a small disc $D$ placed on the shell (used later as a “local‑cue” probe, even without a real star):
$$
\text{share}=\frac{\sum_{D\cap\text{Shell}}|S|}{\sum_{\text{Shell}}|S|},\qquad
\text{enrich}=\frac{\text{share}}{\frac{|D\cap\text{Shell}|}{|\text{Shell}|}}.
$$
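A hedged sketch of the two disc statistics follows. The disc position and radius, the toy attribution maps, and the name `disc_share_and_enrichment` are all illustrative assumptions.

```python
import numpy as np

def disc_share_and_enrichment(shap_map, shell, disc):
    """Share of shell |SHAP| inside the disc, and that share over the disc's area fraction."""
    abs_map = np.abs(shap_map)
    in_both = disc & shell
    share = abs_map[in_both].sum() / abs_map[shell].sum()
    area_frac = in_both.sum() / shell.sum()
    return float(share), float(share / area_frac)

yy, xx = np.indices((100, 100))
r = np.hypot(xx - 49.5, yy - 49.5)
shell = (r > 20.0) & (r <= 40.0)
disc = np.hypot(xx - 79.5, yy - 49.5) <= 4.0  # small disc at mid-shell (illustrative)

uniform = shell.astype(float)   # |SHAP| spread evenly over the shell
hot = uniform + 9.0 * disc      # ten-fold attribution inside the disc

print(disc_share_and_enrichment(uniform, shell, disc)[1])  # 1.0: disc is not favoured
print(disc_share_and_enrichment(hot, shell, disc)[1])      # well above 1: local hotspot
```

An enrichment of 1 means the disc receives exactly its area-proportional share of attribution.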

---

## 3) OLS — interpretation

**What is expected:** a clear ring since the label is a global scale of ring intensity. OLS is linear in pixels and will put weight mainly on the shell boundaries.

**Numbers**
- **Region share:** Hollow **3.33%**, Shell **91.49%**, Outside **5.18%**.
- **Per‑pixel mean:** Hollow **6.8e‑5**, Shell **4.24e‑4**, Outside **2.4e‑5**.
- **Concentration:** top‑1% **29.33%**, top‑5% **57.68%**.
- **IoU / Precision:** at 1%: IoU **0.021**, Precision **0.967**; at 10%: IoU **0.217**, Precision **0.984**.
- **Deletion AOPC:** **0.1317** (good for a linear model).

**Reading the plots**
- **Overlays / global mean:** a tight band on the ring; small bleed outside due to slight mis‑registration and noise.
- **Radial profile:** a sharp peak near the inner edge and a secondary shoulder approaching the outer edge — edges carry the strongest linear signal.

**Local vs global**
- **Local:** the top‑1% mask forms thin arcs on the ring.
- **Global:** the average map is annular, consistent with a pixelwise linear explanation of a radial shell.

**Pitfalls**
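- High **precision** but **low IoU** at small $k$ is normal: the top‑$k\%$ mask is far smaller than the shell, so the union in the IoU is dominated by shell pixels the mask cannot reach. With the nominal radii ($R_0=20$, $R_1=40$) the shell covers roughly 38% of a $100\times100$ image, so even a top‑1% mask lying entirely on the shell cannot exceed $\mathrm{IoU}\approx 0.027$.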

---

## 4) Ridge — interpretation

**What is expected:** similar ring signal but more **diffuse** because the $\ell_2$ penalty shrinks and spreads coefficients.

**Numbers**
- **Region share:** Hollow **11.26%**, Shell **36.84%**, Outside **51.90%**.
- **Per‑pixel mean:** Hollow **9.1e‑5**, Shell **6.8e‑5**, Outside **9.6e‑5**.
- **Concentration:** top‑1% **7.79%**, top‑5% **23.32%**.
- **IoU / Precision:** at 1%: IoU **0.002**, Precision **0.102**; at 10%: IoU **0.058**, Precision **0.302**.
- **Deletion AOPC:** **0.0356** (weak fidelity).

**Reading the plots**
- **Overlays:** scattered red/blue speckle across the frame; still some ring structure but attribution “escapes” into the background.
- **Radial profile:** flat with only faint peaks → credit is not concentrated on the physical edges.

**Interpretation**
- Ridge stabilises training but the explanation becomes **low‑contrast**: the model assigns small weights broadly, so no small set of pixels dominates performance. This is consistent with the low AOPC.

---

## 5) CNN — interpretation (baseline selected by fidelity)

**What is expected:** $|{\rm SHAP}|$ concentrates along **both** inner and outer edges; attributions appear as short **arcs** because small misalignments and augmentation place edge evidence locally.

**Numbers**
- **Baseline chosen:** `train_subset` (near‑ties with mean/median).
- **Bootstrapping:** region share mean±sd = Hollow **0.050±0.003**, Shell **0.935±0.003**, Outside **0.016±0.000**.
- **Region share:** Hollow **4.86%**, Shell **93.63%**, Outside **1.50%**.
- **Per‑pixel mean:** Hollow **6.5e‑5**, Shell **2.92e‑4**, Outside **5e‑6**.
- **Concentration:** top‑1% **20.13%**, top‑5% **51.98%**.
- **IoU / Precision:** at 1%: IoU **0.022**, Precision **0.985**; at 10%: IoU **0.217**, Precision **0.984**.
- **Deletion AOPC:** **0.1586** (best of the three).

**Reading the plots**
- **Overlays / global mean:** crisp double band at inner/outer edges; local arcs rather than full rings — the network exploits *local* edge evidence plus global context (provided by the CoordConv channels).  
- **Radial profile:** two clear peaks at the expected radii, sharper than OLS/Ridge.

---

## 6) Bounding‑box metrics (top‑1% $|{\rm SHAP}|$ *within the shell*)

The box is fitted to the **largest connected component** of the top‑1% $|{\rm SHAP}|$ in the shell. Reported metrics:

- **Area fraction:** $|\text{BOX}\cap\text{Shell}|/|\text{Shell}|$.
- **$|{\rm SHAP}|$ share captured:** $\sum_{\text{BOX}\cap\text{Shell}}|S| / \sum_{\text{Shell}}|S|$.
- **Angular coverage:** minimal angle span covering the selected pixels (degrees).
- **Radial width / shell thickness.**

**Medians (IQR):**
- **OLS:** area **0.002** (0.002–0.003), share **0.025** (0.020–0.033), angle **9.8°** (6.9°–14.1°), radial **0.057** (0.042–0.066).
- **Ridge:** area **0.003** (0.002–0.004), share **0.013** (0.007–0.016), angle **13.9°** (7.3°–17.1°), radial **0.041** (0.023–0.053).
- **CNN:** area **0.005** (0.003–0.007), share **0.036** (0.022–0.053), angle **11.2°** (8.6°–16.6°), radial **0.093** (0.073–0.127).

**Interpretation:** CNN captures **more** of the shell’s total evidence with a slightly larger box (because it often spans both edges), while Ridge captures **less** despite comparable area.

---

## 7) “Global star‑capture” on a dataset **without** a star

The procedure still finds a **proxy**: a small disc on the shell at the **brightest ring‑weighted blob**.  
Metrics remain valid but should be read as *“does the model favour any local arc on the ring?”*

- **$|{\rm SHAP}|$ in the disc (share):** OLS **0.005**, Ridge **0.008**, CNN **0.012** (medians).
- **Enrichment (share / area):** OLS **0.458**, Ridge **0.697**, CNN **1.132**.
- **Angular error** (disc centre vs largest component): broad IQRs and large medians (e.g., OLS **51°**, CNN **76°**) → no consistent local hotspot, as expected when there is **no true star**.

**Good sign:** enrichment close to **1** with large angular variability — indicates no persistent spurious cue.  
**Bad sign:** consistently **high** enrichment with **small** angular error — would suggest an artefact acting like a fixed defect.

---

## 8) Polar superpixels (regional summary)

Binning the image into $(r,\theta)$ cells gives an angular profile at the **peak ring**.

- **OLS:** peak ring ≈ inner edge; **CV ≈ 0.142** (mild anisotropy).
- **Ridge:** peak ring faint; **CV ≈ 0.105** (nearly isotropic but weak signal).
- **CNN:** peak ring between inner/outer edges; **CV ≈ 0.342** (deliberate focus on *arcs*).

Here, $\mathrm{CV}=\mathrm{std}(m_\theta)/\mathrm{mean}(m_\theta)$ for the angular means $m_\theta$.  
Small CV ⇒ isotropic ring; larger CV ⇒ the model locks onto *segments* (useful if local defects matter).
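The angular CV can be sketched as below, assuming a ring defined by two radii around the image centre and equal-width angular bins. The function name `angular_cv`, the bin count and the toy maps are assumptions for illustration.

```python
import numpy as np

def angular_cv(shap_map, r_lo, r_hi, n_theta=24):
    """CV of the per-angle mean |SHAP| inside the ring r_lo <= r < r_hi."""
    h, w = shap_map.shape
    yy, xx = np.indices((h, w))
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    r = np.hypot(xx - cx, yy - cy)
    theta = np.mod(np.arctan2(yy - cy, xx - cx), 2 * np.pi)
    ring = (r >= r_lo) & (r < r_hi)
    bins = np.minimum((theta[ring] / (2 * np.pi) * n_theta).astype(int), n_theta - 1)
    sums = np.bincount(bins, weights=np.abs(shap_map)[ring], minlength=n_theta)
    counts = np.bincount(bins, minlength=n_theta)
    m_theta = sums / np.maximum(counts, 1)  # angular means
    return float(m_theta.std() / m_theta.mean())

yy, xx = np.indices((100, 100))
isotropic = np.ones((100, 100))            # same attribution at every angle
half_ring = (xx > 49.5).astype(float)      # attribution only on the right half

print(angular_cv(isotropic, 20.0, 40.0))   # 0.0
print(angular_cv(half_ring, 20.0, 40.0))   # about 1.0: strongly segmented
```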

---

## 9) Sanity checks

- **Flip‑invariance (CNN):** Spearman correlation of $|{\rm SHAP}|$ before/after a horizontal flip (then unflipped) — median **0.718**.  
  Strong but not perfect: augmentation and local pooling create small asymmetries.
- **Integrated Gradients vs Deep SHAP:** Spearman between $|{\rm IG}|$ and $|{\rm SHAP}|$ — median **0.485**.  
  Different estimators agree moderately; disagreement tends to appear off‑ring.

---

## 10) Takeaways for the **simple** dataset

- **OLS** provides a tidy, physically plausible annulus; it explains a sizeable portion of performance with a compact set (AOPC **0.1317**).
- **Ridge** spreads credit broadly and is less faithful (AOPC **0.0356**).
- **CNN** concentrates on **edge arcs** and is the most faithful by deletion (AOPC **0.1586**).  
  Local boxes capture more $|{\rm SHAP}|$ with modest angular spans — exactly what an edge‑seeking CNN should do on this dataset.


# Synthetic X‑ray dataset — “Death‑Star” contrast (energy‑neutral star)

**Purpose.** Construct a dataset where the **star** is an unmistakable **local** driver of the target while a **small global anchor** still exists. The star is injected so that, on average **within the shell**, global brightness **does not increase** (energy‑neutral construction). Linear models that rely on per‑image mean or broad ring structure therefore struggle, whereas a CNN can localise the star at **any** angle on the ring and gain a strong predictive advantage.

---

## 1) Image formation and label (high‑level)

Let the base hollow‑sphere density be a function on pixels $(i,j)$:
- inner radius $R_0 \approx 20$, outer radius $R_1 \approx 40$;
- the shell profile fades from the inner to the outer boundary with a shape exponent $\gamma\!\in\![0.95,1.05]$;
- gentle low‑frequency **shading** $S(i,j)$ is applied **multiplicatively in the density domain**.

A **smooth sensor response** converts density $x$ to brightness
$$
g(x) \;=\; 1 - e^{-x/\tau}, \qquad \tau>0,
$$
so images stay in $[0,1)$ without hard clipping and with star highlights preserved.
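A quick numerical check of the sensor curve, using $\tau=1.7$ as set by `TAU_SENSOR` in the generation code. The grid of density values is an arbitrary choice for illustration.

```python
import numpy as np

def g(x, tau=1.7):
    """Smooth sensor response g(x) = 1 - exp(-x / tau) for non-negative density x."""
    return 1.0 - np.exp(-np.maximum(x, 0.0) / tau)

x = np.linspace(0.0, 10.0, 11)   # densities 0, 1, ..., 10
b = g(x)

print(b[0])                          # 0.0: zero density maps to zero brightness
print(bool(np.all(b < 1.0)))         # True: brightness never reaches 1
print(bool(np.all(np.diff(b) > 0)))  # True: monotone increasing
print(bool(np.all(np.diff(b, 2) < 0)))  # True: increments shrink (concave)
```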

The observed label is
$$
y \;\approx\; c_0\,y_{\text{density}}
\;+\; \beta_{\text{edge}}\,E
\;-\; \beta_{\text{crack}}\,C
\;-\; \beta_{\text{void}}\,V
\;-\; \beta_{\text{aniso}}\,A
\;-\; \beta_{\text{int}}\,(\!C\!\cdot\!A\!)
\;+\; \gamma_{\star}\,y_{\star}
\;+\; \text{label noise},
$$
where
- $y_{\text{density}}$ is a **small** global scaling ($\approx\!1\pm2\%$) to keep ordinary least squares (OLS) and Ridge meaningful but not dominant;
- $E,C,V,A$ encode shell edges, cracks, voids and anisotropy from the **pre‑sensor** density field (all at **small** weight);
- $y_{\star}$ is a **post‑sensor**, area‑normalised **star brightness** (defined below) that dominates label variance.

---

## 2) Energy‑neutral star (why linear models struggle)

A five‑arm star is placed **on the shell** at a random angle every time. In the **density domain**, let
$$
\text{add}(i,j) \;=\; \text{STAR\_DENSITY\_ADD} \times \text{star\_map}(i,j),
$$
and let $\mathbb{1}_{\text{shell}}(i,j)$ be the shell mask. Define the **shell mean** of the addition
$$
\mu_{\text{add}} \;=\; \frac{1}{|\text{shell}|}\sum_{(i,j)\in\text{shell}} \text{add}(i,j).
$$
Apply an energy‑neutral **compensation**
$$
\text{comp} \;=\; \text{STAR\_COMPENSATE}\times \mu_{\text{add}},
$$
and form two compensated density fields
$$
\begin{aligned}
\text{dens0c}(i,j)   &= \max\big(\text{dens0}(i,j) - \text{comp}\,\mathbb{1}_{\text{shell}}(i,j),\,0\big),\\[2pt]
\text{dens\_star}(i,j) &= \max\big(\text{dens0}(i,j) + \text{add}(i,j) - \text{comp}\,\mathbb{1}_{\text{shell}}(i,j),\,0\big).
\end{aligned}
$$

Interpretation:
- The **mean density on the shell** is preserved (up to the clamp at zero, shading and the subsequent non‑linear sensor), so the **per‑image mean brightness** carries **little** information about the star.
- Because the star appears at a **random angle**, any single **fixed‑pixel** linear weight map cannot consistently align with it across images. A CNN, however, can learn **local** star detectors that match anywhere on the ring.

Finally, brightness “without star” and “with star” are
$$
I_0 = g\big(\text{dens0c}\times S\big), \qquad I_{\star} = g\big(\text{dens\_star}\times S\big).
$$
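The compensation step can be checked on a toy grid. The sketch below assumes full compensation (`STAR_COMPENSATE = 1`), a square stand-in for the shell and a single-pixel "star"; all of these are simplifications for illustration, not the notebook's geometry.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "shell": a 4x4 block inside an 8x8 image, with density in [0.5, 1.0).
shell = np.zeros((8, 8), dtype=bool)
shell[2:6, 2:6] = True
dens0 = np.where(shell, rng.uniform(0.5, 1.0, size=shell.shape), 0.0)

# Single-pixel star of strength 2.2 (the value of STAR_DENSITY_ADD in the code).
add = np.zeros_like(dens0)
add[3, 3] = 2.2

comp = 1.0 * add[shell].mean()  # STAR_COMPENSATE * shell mean of the addition
dens_star = np.maximum(dens0 + add - comp * shell, 0.0)

print(np.isclose(dens_star[shell].mean(), dens0[shell].mean()))  # True
print(dens_star.max() > dens0.max())                             # True: local peak remains
```

Exact neutrality holds here because no pixel is clamped at zero. In the full generator the base profile falls to zero at the outer boundary, so a few pixels may be clamped and the shell mean is then preserved only approximately.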

---

## 3) Star term in the label (what the model is rewarded to see)

Define the **post‑sensor excess brightness** due to the star
$$
\Delta I \;=\; I_{\star} - I_0 \;\;\ge 0,
$$
and its **area‑normalised** mean over the shell
$$
y_{\star} \;=\; \frac{1}{|\text{shell}|}\sum_{(i,j)\in\text{shell}} \Delta I(i,j).
$$

The noise‑free label is
$$
\boxed{\;
y_{\text{true}}
= c_0\,y_{\text{density}}
+ \beta_{\text{edge}}\,E
- \beta_{\text{crack}}\,C
- \beta_{\text{void}}\,V
- \beta_{\text{aniso}}\,A
- \beta_{\text{int}}\,(C\!\cdot\!A)
+ \gamma_{\star}\,y_{\star}
\;}
$$
and the observed label is $y_{\text{obs}}=y_{\text{true}}+\varepsilon$, with small Gaussian noise. This construction ensures that the label **rewards exactly what the sensor shows**; a CNN that highlights the star in the image will be rewarded by $y_{\star}$.

---

## 4) Why the chosen sensor helps (smooth, no hard clipping)

The response $g(x)=1-e^{-x/\tau}$ is **monotone** and **concave**, producing bright highlights that **do not flatten** at 1.0. This avoids:
- gradient dead‑zones (typical of hard clipping);
- spuriously high counts at pixel value 1.0;
- “cheating” via global exposure changes.

**Diagnostic expectation:** “Pixel mass near 1.0” $\approx 0\%$.

---

## 5) What each diagnostic now means

Let $m$ be the per‑image mean brightness and let $s$ be the stored star‑brightness proxy ($y_{\star}$).

- **Corr$(y_{\text{obs}}, m)$** — should be **modest** (target range 0.25–0.45).  
  A larger value suggests the star is unintentionally changing the global mean (narrow $\text{Y\_DENSITY\_RANGE}$ or increase $\text{STAR\_COMPENSATE}$).

- **Corr$(y_{\text{obs}}, s)$** — should be **high** (aim $\ge 0.65$).  
  Confirms that the label is actually driven by the **post‑sensor** star signal.

- **Incremental $R^2$** from adding the star proxy to a mean‑only model:  
  Fit $y\approx b_0+b_1 m$ (model 1) and $y\approx b_0+b_1 m + b_2 s$ (model 2). Report
  $$
  \Delta R^2 \;=\; R^2_{\text{model 2}} - R^2_{\text{model 1}} \;\;\text{(aim }\ge 0.30\text{)}.
  $$
  This quantifies how much **extra** variance the star explains beyond global brightness.

- **Mean image** — shows the shell but **no star imprint** (the star angle is random and cancels in the average).

- **Label histogram** — typically unimodal with wider spread than the simple/no‑star dataset due to the star term.

---

## 6) Why OLS/Ridge underperform and CNN excels

- **OLS/Ridge:** features are **fixed pixel locations**. A star at arbitrary angles cannot be represented by a **single** static weight pattern; any average weight map dilutes the star across angles. Energy‑neutral construction further removes the “mean brightness” shortcut.
- **CNN:** learns **local** filters whose responses are translation‑equivariant; after pooling (and with the ring priors) the prediction becomes approximately **invariant** to the star’s angular placement. The star branch and concat pooling create a path where small, intense features survive global averaging and influence the output.

**Practical expectation:** Ridge $R^2$ on Test stays modest (often $\ll 0.5$), whereas the CNN achieves a substantially higher $R^2$ by focusing on the star.
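The dilution argument for fixed-pixel models can be illustrated with a toy experiment: a bright blob placed at a random angle on the ring is obvious in each image but almost vanishes in the average. The blob shape, radius, width and sample count below are illustrative assumptions, not the notebook's star generator.

```python
import numpy as np

rng = np.random.default_rng(0)
yy, xx = np.indices((100, 100))
cx = cy = 49.5
r_star, sigma = 30.0, 2.0  # mid-shell radius and blob width (illustrative values)

def blob_image(theta):
    """Unit-peak Gaussian blob centred on the ring at angle theta."""
    bx = cx + r_star * np.cos(theta)
    by = cy + r_star * np.sin(theta)
    return np.exp(-0.5 * ((xx - bx) ** 2 + (yy - by) ** 2) / sigma ** 2)

single = blob_image(0.3)
angles = rng.uniform(0.0, 2 * np.pi, size=500)
mean_img = np.mean([blob_image(t) for t in angles], axis=0)

print(single.max())    # close to 1: the blob is obvious in any one image
print(mean_img.max())  # far smaller: averaging over angles smears it into a faint ring
```

A single static weight map faces the same averaging, which is why it cannot track the blob from image to image.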

---

## 7) Visual guide: what “good” vs “bad” looks like

- **Good (as designed):**
  - Pixel mass near 1.0 $\approx 0\%$.
  - Corr$(y_{\text{obs}}, m)$ in 0.25–0.45.
  - Corr$(y_{\text{obs}}, s)\ge 0.65$.
  - $\Delta R^2 \ge 0.30$.
  - Mean image shows **no star**.
  - Sample images: a crisp, bright star on the shell; background and ring remain realistic.

- **Bad (needs tuning):**
  - Pixel mass near 1.0 $\gg 0\%$ → decrease brightness or increase $\tau$ (sensor less aggressive).
  - Corr$(y_{\text{obs}}, m)$ too high → increase $\text{STAR\_COMPENSATE}$ or narrow $\text{Y\_DENSITY\_RANGE}$; reduce $\text{SHADING\_ALPHA\_RANGE}$ if global illumination dominates.
  - Corr$(y_{\text{obs}}, s)$ too low → increase $\text{STAR\_DENSITY\_ADD}$ or $\gamma_{\star}$; consider slightly sharper arms/core.
  - Mean image displays a faint star ghost → check that the star angle is uniformly random; ensure compensation is on the **shell only**.

---

## 8) Implementation notes tied to the code

1) **Base shell.**  
   The ring image $B(i,j)$ is built from the elliptical radius $r$ using a normalised depth $t=(r-R_0)/(R_1-R_0)$ and a profile $B=(1-t)^{\gamma}$ on the shell, zero elsewhere.

2) **Defects.**  
   Anisotropy, cracks and voids adjust the **density** before the sensor. Their weights in the label are **small** so the star dominates without being the only signal.

3) **Star map.**  
   A 5‑arm pattern with tight Gaussian core/arms is placed at $(r_{\star},\theta_{\star})\approx$ mid‑shell $\pm$ jitter. Multiplying by $\text{STAR\_DENSITY\_ADD}$ yields a strong local addition in the density domain.

4) **Energy‑neutral compensation.**  
   Subtract a constant $\text{comp}$ **only inside the shell** so the shell‑mean density remains unchanged. This keeps the **global mean** anchor small.

5) **Shading in density domain.**  
   Multiplicative shading before $g(\cdot)$ preserves highlight structure and avoids post‑sensor renormalisation artefacts.

6) **Post‑sensor star term.**  
   $y_{\star}$ is computed from $I_{\star}-I_0$ **after** the sensor. The label therefore encourages the network to match what the image actually shows.

---

## 9) Parameter intuition (how each affects difficulty)

- **STAR\_DENSITY\_ADD**: increases local brightness; too small → star is subtle; too large → risk of sensor saturation (watch the pixel‑mass diagnostic).
- **STAR\_COMPENSATE** ($0\ldots 1$): increases energy neutrality; higher values reduce correlation with global mean.
- **TAU\_SENSOR**: controls concavity; larger $\tau$ makes $g$ more linear (flatter highlights), smaller $\tau$ increases pop (brighter star) but approaches saturation.
- **Y\_DENSITY\_RANGE**: narrows/widens the global anchor; keep **tight** to force reliance on the star.
- **SHADING\_ALPHA\_RANGE**: too high re‑introduces a global confound via illumination; keep gentle.

---

## 10) Compact formula summary

- Sensor: $g(x)=1-e^{-x/\tau}$.
- Compensation: $\text{comp}=\text{STAR\_COMPENSATE}\times\frac{1}{|\text{shell}|}\sum_{(i,j)\in\text{shell}}\text{add}(i,j)$.
- Compensated densities:
  $$
  \text{dens0c}=\max(\text{dens0}-\text{comp}\,\mathbb{1}_{\text{shell}},0),\quad
  \text{dens\_star}=\max(\text{dens0}+\text{add}-\text{comp}\,\mathbb{1}_{\text{shell}},0).
  $$
- Brightness: $I_0=g(\text{dens0c}\cdot S)$, $I_{\star}=g(\text{dens\_star}\cdot S)$.
- Star term: $y_{\star}=\frac{1}{|\text{shell}|}\sum_{\text{shell}}\big(I_{\star}-I_0\big)$.
- Label: as boxed in §3.
- Diagnostics: $\text{Corr}(y,m)$, $\text{Corr}(y,s)$, $\Delta R^2$.

---

## 11) What the plots should convey

- **Sample grid:** the star is obvious and always on the shell; angle varies.
- **Histogram:** broader spread than the simple/no‑star set.
- **Mean image:** no star imprint, crisp ring.
- **Printed diagnostics:** low pixel mass at 1, moderate Corr$(y,m)$, high Corr$(y,s)$, healthy $\Delta R^2$.

**Bottom line.** The dataset is deliberately hostile to global, pixel‑fixed linear models and favourable to **local, geometry‑aware** learners. Success looks like **poor Ridge** but **strong CNN**, especially once explanations (SHAP/IG) are examined in later cells.


In [None]:
"""
Synthetic X‑ray dataset of a hollow metal sphere (100×100) — generation & visualisation
---------------------------------------------------------------------------------------

Goal (finalised for "Death‑Star" contrast)
------------------------------------------
Make the **star** an unmistakable *local* driver of `y_true` that a CNN detects easily,
while keeping a **small global linear anchor** (y_density) so OLS/Ridge behave sensibly.
Inject the star in a way that **does not inflate global brightness**
(energy‑neutral within the shell), so pixel‑wise linear models cannot exploit the mean,
whereas a CNN can localise the star anywhere on the ring.

Outputs:
  images      : (N, 100, 100) float32 in [0,1)
  y_true_all  : (N,) noise‑free target
  y_obs_all   : (N,) observed target (y_true + small label noise)

"""

import numpy as np
import matplotlib.pyplot as plt

# ----------------------------
# Reproducibility & settings
# ----------------------------
SEED = 2025
rng = np.random.default_rng(SEED)

IMG_SIZE = 100
N_IMAGES = 1000

# Geometry (nominal)
R0_BASE = 20.0  # inner (hollow) radius
R1_BASE = 40.0  # outer radius
CENTER = (IMG_SIZE - 1) / 2.0

# Illumination & noise (applied in density domain → smooth post‑sensor highlights)
SHADING_ALPHA_RANGE = (0.015, 0.035)  # gentle illumination; low global confound
PIXEL_NOISE_STD = 0.008  # tiny, after sensor response
BACKGROUND_OFFSET_MAX = 0.004

# Sensor response: smooth, monotone, no hard clipping
TAU_SENSOR = 1.7  # smaller → brighter; tuned so star “pops” but stays < 1.0

# Target ranges / noise
Y_DENSITY_RANGE = (0.98, 1.02)  # tight global anchor (keeps linear models "ok" but not great)
Y_LABEL_NOISE_STD = 0.007
Y_TRUE_CLIP = (0.80, 1.80)

# Variability in shape (mild)
RADIUS_JITTER_PX = 2.0
CENTER_JITTER_PX = 1.0
ELLIPTICITY_RANGE = (0.985, 1.015)
PROFILE_GAMMA_RANGE = (0.95, 1.05)

# Star parameters — **very strong local cue**
STAR_ARMS = 5
STAR_CORE_SIGMA_PX = (0.8, 1.2)  # tight core
STAR_ARM_SIGMA_PX = (1.6, 2.2)  # thin arms
STAR_SHARPNESS = 9.0  # angular sharpness
STAR_RADIUS_JITTER = 1.0  # around mid‑shell
STAR_AMP_RANGE = (1.0, 1.8)  # per‑image amplitude variability
STAR_DENSITY_ADD = 2.2  # **additive in density** (pre‑sensor); strong local addition
STAR_COMPENSATE = 1.00  # 0=none, 1=fully subtract star’s mean density on the shell

# Nuisance weights (smaller than star so it stands out)
EDGE_WT = 0.03
CRACK_WT = 0.04
VOID_WT = 0.035
ANISO_WT = 0.03
INTERACT_WT = 0.015
STAR_Y_GAIN = 48.0  # label gain on *post‑sensor* star brightness (strong)


# ----------------------------
# Helper: metrics & plotting
# ----------------------------
def rmse(y_true, y_pred):
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


def print_metrics(split_name, y_true, y_pred):
    from sklearn.metrics import r2_score, mean_absolute_error
    r2 = r2_score(y_true, y_pred)
    mae = mean_absolute_error(y_true, y_pred)
    rms = rmse(y_true, y_pred)
    print(f"[{split_name}] R² = {r2:.4f} | MAE = {mae:.4f} | RMSE = {rms:.4f}")
    return r2, mae, rms


def plot_pred_vs_actual(y_actual, y_pred, title="Predicted vs Actual"):
    plt.figure(figsize=(5, 4))
    plt.scatter(y_actual, y_pred, s=14, alpha=0.75)
    lo = min(float(np.min(y_actual)), float(np.min(y_pred)))
    hi = max(float(np.max(y_actual)), float(np.max(y_pred)))
    pad = 0.02 * (hi - lo) if hi > lo else 0.01
    plt.plot([lo - pad, hi + pad], [lo - pad, hi + pad], linestyle="--", linewidth=1.0)
    plt.xlabel("Actual y")
    plt.ylabel("Predicted y")
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


def show_heatmap(arr2d, title, cmap=None, with_colorbar=True):
    plt.figure(figsize=(5, 4))
    plt.imshow(arr2d, cmap=cmap)
    if with_colorbar:
        plt.colorbar(fraction=0.046, pad=0.04)
    plt.title(title)
    plt.axis("off")
    plt.tight_layout()
    plt.show()


# ----------------------------
# Coordinates (constant fields)
# ----------------------------
y_coords, x_coords = np.indices((IMG_SIZE, IMG_SIZE))
X0 = x_coords - CENTER
Y0 = y_coords - CENTER


# ----------------------------
# Smooth illumination (mean=1), applied in *density* domain
# ----------------------------
def make_low_frequency_shading(alpha, rng):
    XN = X0 / R1_BASE
    YN = Y0 / R1_BASE
    c = rng.normal(0.0, 1.0, size=6)
    field = (c[0] + c[1] * XN + c[2] * YN + c[3] * XN * YN + c[4] * XN * XN + c[5] * YN * YN)
    field = field - field.mean()
    std = field.std() + 1e-12
    field = field / std
    return 1.0 + alpha * field  # multiplicative factor around 1


# ----------------------------
# Geometry helpers
# ----------------------------
def radial_distance(x, y, cx, cy, sx=1.0, sy=1.0):
    dx = (x - cx) / sx
    dy = (y - cy) / sy
    return np.sqrt(dx * dx + dy * dy)


def angle_field(x, y, cx, cy):
    return (np.arctan2(y - cy, x - cx) + 2 * np.pi) % (2 * np.pi)


# ----------------------------
# Defects (moderate)
# ----------------------------
def make_anisotropy_map(re, theta, R0, R1, amp, k_freq, phase):
    r_mid = 0.5 * (R0 + R1)
    sigma_r = 0.30 * (R1 - R0)
    radial_weight = np.exp(-0.5 * ((re - r_mid) / (sigma_r + 1e-8)) ** 2)
    return amp * np.sin(k_freq * theta + phase) * radial_weight


def add_radial_cracks(dx, dy, re, R0, R1, rng):
    C = np.zeros_like(re, dtype=np.float64)
    n_cr = int(np.clip(rng.poisson(lam=0.8), 0, 3))
    for _ in range(n_cr):
        theta0 = rng.uniform(0, 2 * np.pi)
        depth = rng.uniform(0.12, 0.22)
        sigma = rng.uniform(0.8, 1.4)
        s, c = np.sin(theta0), np.cos(theta0)
        dist_perp = np.abs(-s * dx + c * dy)
        band = (re > (R0 + 0.8)) & (re < (R1 - 0.8))
        C += depth * np.exp(-0.5 * (dist_perp / (sigma + 1e-8)) ** 2) * band
    return np.clip(C, 0.0, 0.60)


def add_voids(dx, dy, re, R0, R1, rng):
    V = np.zeros_like(re, dtype=np.float64)
    n_v = int(np.clip(rng.poisson(lam=0.6), 0, 3))
    for _ in range(n_v):
        theta_c = rng.uniform(0, 2 * np.pi)
        r_c = rng.uniform(R0 + 3.0, R1 - 3.0)
        cx = r_c * np.cos(theta_c)
        cy = r_c * np.sin(theta_c)
        depth = rng.uniform(0.18, 0.28)
        rad = rng.uniform(1.4, 3.0)
        dist2 = (dx - cx) ** 2 + (dy - cy) ** 2
        V += depth * np.exp(-0.5 * dist2 / ((rad + 1e-8) ** 2))
    return np.clip(V, 0.0, 0.65)


# ----------------------------
# Star generator (5‑arm) on the shell
# ----------------------------
def make_star_on_shell(dx0, dy0, re, R0, R1, rng):
    r_mid = 0.5 * (R0 + R1)
    theta_star = rng.uniform(0, 2 * np.pi)
    r_star = r_mid + rng.uniform(-STAR_RADIUS_JITTER, +STAR_RADIUS_JITTER)

    sx = r_star * np.cos(theta_star)
    sy = r_star * np.sin(theta_star)

    U = dx0 - sx
    V = dy0 - sy
    rho = np.sqrt(U * U + V * V)
    phi = (np.arctan2(V, U) + 2 * np.pi) % (2 * np.pi)

    core_sig = rng.uniform(*STAR_CORE_SIGMA_PX)
    arm_sig = rng.uniform(*STAR_ARM_SIGMA_PX)
    amp = rng.uniform(*STAR_AMP_RANGE)

    core = np.exp(-0.5 * (rho / (core_sig + 1e-8)) ** 2)
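    # Angular lobes: STAR_ARMS peaks around the star centre, sharpened by the exponent
    angular = np.clip(0.5 * (1.0 + np.cos(STAR_ARMS * phi)), 0.0, 1.0) ** STAR_SHARPNESS
    arms = np.exp(-0.5 * (rho / (arm_sig + 1e-8)) ** 2) * angular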
    star = np.clip(core + 0.9 * arms, 0.0, 1.0)

    band = (re > (R0 + 1.0)) & (re < (R1 - 1.0))
    return amp * star * band  # amplitude‑scaled star map in [0, amp]


# ----------------------------
# Sensor response — smooth, no hard clipping
# ----------------------------
def sensor_response(density_field, tau=TAU_SENSOR):
    # g(x) = 1 - exp(-x/tau)  ∈ [0,1) strictly; monotone; concave.
    return 1.0 - np.exp(-np.maximum(density_field, 0.0) / (tau + 1e-12))


# ----------------------------
# Generate one image + targets
# ----------------------------
def generate_one_image(rng):
    # --- Geometry & base fields
    R0 = R0_BASE + rng.uniform(-RADIUS_JITTER_PX, +RADIUS_JITTER_PX)
    R1 = R1_BASE + rng.uniform(-RADIUS_JITTER_PX, +RADIUS_JITTER_PX)
    if R1 <= R0 + 8.0:
        R1 = R0 + 8.0

    cx = CENTER + rng.uniform(-CENTER_JITTER_PX, +CENTER_JITTER_PX)
    cy = CENTER + rng.uniform(-CENTER_JITTER_PX, +CENTER_JITTER_PX)
    sx = rng.uniform(*ELLIPTICITY_RANGE)
    sy = rng.uniform(*ELLIPTICITY_RANGE)
    gamma = rng.uniform(*PROFILE_GAMMA_RANGE)

    re = radial_distance(x_coords, y_coords, cx, cy, sx=sx, sy=sy)

    base = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.float64)
    shell = (re > R0) & (re <= R1)
    t = np.empty_like(re)
    t[shell] = (re[shell] - R0) / (R1 - R0)
    base[shell] = (1.0 - t[shell]) ** gamma

    dx0 = (x_coords - cx)
    dy0 = (y_coords - cy)
    theta = angle_field(x_coords, y_coords, cx, cy)

    # --- Defects (moderate) in the base density
    A_amp = rng.uniform(-0.05, 0.05)
    A_kfreq = int(rng.integers(1, 4))
    A_phase = rng.uniform(0.0, 2 * np.pi)
    aniso = make_anisotropy_map(re, theta, R0, R1, A_amp, A_kfreq, A_phase)

    C_map = add_radial_cracks(dx0, dy0, re, R0, R1, rng)
    V_map = add_voids(dx0, dy0, re, R0, R1, rng)

    # --- Linear anchor (global density)
    y_density = float(rng.uniform(*Y_DENSITY_RANGE))

    # Density **before** star (structural)
    dens0 = y_density * base * (1.0 + aniso)
    dens0 = dens0 * (1.0 - C_map) * (1.0 - V_map)
    dens0 = np.maximum(dens0, 0.0)

    # --- Star (always present), **additive in density domain**
    star_map = make_star_on_shell(dx0, dy0, re, R0, R1, rng)  # amplitude‑scaled shape
    add = STAR_DENSITY_ADD * star_map  # strong local addition

    # Energy‑neutral adjustment inside the shell: subtract mean(add) on the shell
    shell_mask = shell.astype(np.float64)
    shell_area = float(shell_mask.sum())
    mu_add_shell = float((add * shell_mask).sum() / (shell_area + 1e-12))
    comp = STAR_COMPENSATE * mu_add_shell

    dens0c = np.maximum(dens0 - comp * shell_mask, 0.0)
    dens_star = np.maximum(dens0 + add - comp * shell_mask, 0.0)

    # --- Illumination shading in **density domain**
    alpha = rng.uniform(*SHADING_ALPHA_RANGE)
    shading = make_low_frequency_shading(alpha, rng)  # mean ~1
    dens0_s = dens0c * shading
    dens_star_s = dens_star * shading

    # --- Sensor response (smooth; no hard clipping)
    img_base = sensor_response(dens0_s, tau=TAU_SENSOR)  # “without star” brightness
    img_star = sensor_response(dens_star_s, tau=TAU_SENSOR)  # “with star” brightness

    # Final observed image (with star), tiny offset + noise
    offset = rng.uniform(0.0, BACKGROUND_OFFSET_MAX)
    img = img_star + offset + rng.normal(0.0, PIXEL_NOISE_STD, size=img_star.shape)
    img = np.clip(img, 0.0, 1.0 - 1e-6).astype(np.float32)  # virtually no mass at 1.0

    # --- Label from what the sensor “sees”
    # Star strength (positive): **post‑sensor** excess brightness within the shell
    star_excess = img_star - img_base
    star_brightness = float(star_excess[shell].sum() / (shell_area + 1e-8))  # area‑normalised (shell is already boolean)

    # Edge strength / defects (computed from *pre‑sensor* structural field; small weights)
    gy, gx = np.gradient(dens0)  # axis 0 (rows, y) first, then axis 1 (columns, x); edges reflect shell sharpness
    grad_mag = np.sqrt(gx * gx + gy * gy)
    edge_band = (np.abs(re - R0) <= 1.5) | (np.abs(re - R1) <= 1.5)
    edge_strength_raw = float(grad_mag[edge_band].mean() if np.any(edge_band) else 0.0)
    EDGE_SCALE = 0.06
    edge_term = np.clip(edge_strength_raw / (EDGE_SCALE + 1e-8), 0.0, 1.5)

    crack_raw = float((C_map * base)[shell].sum() / (shell_area + 1e-8))
    void_raw = float((V_map * base)[shell].sum() / (shell_area + 1e-8))
    CRACK_SCALE = 24.0
    VOID_SCALE = 20.0
    crack_term = np.clip(CRACK_SCALE * crack_raw, 0.0, 1.5)
    void_term = np.clip(VOID_SCALE * void_raw, 0.0, 1.5)
    aniso_term = np.clip(abs(A_amp) / 0.10, 0.0, 1.5)
    interact_term = crack_term * aniso_term

    # Compose y_true
    y_true = (1.00 * y_density
              + EDGE_WT * edge_term
              - CRACK_WT * crack_term
              - VOID_WT * void_term
              - ANISO_WT * aniso_term
              - INTERACT_WT * interact_term
              + STAR_Y_GAIN * star_brightness)

    y_true = float(np.clip(y_true, *Y_TRUE_CLIP))
    y_obs = y_true + float(rng.normal(0.0, Y_LABEL_NOISE_STD))
    return img, y_true, y_obs, star_brightness


# ----------------------------
# Generate & summarise
# ----------------------------
def generate_dataset(n_images, rng):
    images = np.empty((n_images, IMG_SIZE, IMG_SIZE), dtype=np.float32)
    y_true = np.empty(n_images, dtype=np.float32)
    y_obs = np.empty(n_images, dtype=np.float32)
    star_b = np.empty(n_images, dtype=np.float32)  # diagnostic: post‑sensor star brightness (our proxy)
    for i in range(n_images):
        img, yt, yo, sb = generate_one_image(rng)
        images[i] = img
        y_true[i] = yt
        y_obs[i] = yo
        star_b[i] = sb
    return images, y_true, y_obs, star_b


images, y_true_all, y_obs_all, _star_brightness_all = generate_dataset(N_IMAGES, rng)

print("Dataset summary:")
print(f"  images.shape = {images.shape} (float32, values in [0,1))")
print(f"  y_true: mean={y_true_all.mean():.4f}, std={y_true_all.std():.4f}, "
      f"min={y_true_all.min():.4f}, max={y_true_all.max():.4f}")
print(f"  y_obs :  mean={y_obs_all.mean():.4f},  std={y_obs_all.std():.4f},  "
      f"min={y_obs_all.min():.4f},  max={y_obs_all.max():.4f}")

# ----------------------------
# Diagnostics (fast, no downstream dependencies)
# ----------------------------
# 1) Pixel mass near 1.0 (smooth sensor → should be ~0)
clip_rate = float(np.mean(images >= (1.0 - 1e-6))) * 100.0

# 2) Correlations with global mean and star proxy
per_img_mean = images.reshape(len(images), -1).mean(axis=1)


def _corr(a, b):
    """Pearson correlation of two 1-D arrays (the epsilon guards against a zero denominator)."""
    a = a - a.mean()
    b = b - b.mean()
    denom = np.sqrt((a * a).sum()) * np.sqrt((b * b).sum()) + 1e-12
    return float((a * b).sum() / denom)


corr_mean = _corr(y_obs_all, per_img_mean)  # global anchor
corr_star = _corr(y_obs_all, _star_brightness_all)  # star drives label

# 3) Incremental R² when adding star proxy to mean‑only linear model
X1 = np.c_[np.ones_like(per_img_mean), per_img_mean]
X2 = np.c_[np.ones_like(per_img_mean), per_img_mean, _star_brightness_all]
b1 = np.linalg.lstsq(X1, y_obs_all, rcond=None)[0]
b2 = np.linalg.lstsq(X2, y_obs_all, rcond=None)[0]
def ss(y, yhat):
    return float(((y - yhat) ** 2).sum())
ss_tot = float(((y_obs_all - y_obs_all.mean()) ** 2).sum() + 1e-12)
r2_1 = 1.0 - ss(y_obs_all, X1 @ b1) / ss_tot
r2_2 = 1.0 - ss(y_obs_all, X2 @ b2) / ss_tot
delta_r2_star = r2_2 - r2_1

print(f"[Diag] Pixel mass near 1.0 = {clip_rate:.2f}%  (smooth sensor → expect ~0.00%)")
print(f"[Diag] Corr(y_obs, per‑image mean brightness) = {corr_mean:.3f}  (global anchor; aim 0.25–0.45)")
print(f"[Diag] Corr(y_obs, star‑brightness proxy)     = {corr_star:.3f}  (star should dominate; aim ≥ 0.65)")
print(f"[Diag] ΔR² from adding star proxy to mean‑only model = {delta_r2_star:.3f}  (aim ≥ 0.30)")


# ----------------------------
# Visual checks
# ----------------------------
def show_image_grid(images, y_true, y_obs, n_show=12, title="Sample simulated images"):
    n_show = int(min(n_show, images.shape[0]))
    idxs = rng.choice(images.shape[0], size=n_show, replace=False)
    cols = int(np.ceil(np.sqrt(n_show)))
    rows = int(np.ceil(n_show / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(1.9 * cols, 1.9 * rows))
    axes = np.atleast_1d(axes).ravel()
    for ax, i in zip(axes, idxs):
        ax.imshow(images[i], cmap="gray", vmin=0.0, vmax=1.0)
        ax.set_title(f"y_true={y_true[i]:.3f}\ny_obs={y_obs[i]:.3f}", fontsize=8)
        ax.axis("off")
    for k in range(n_show, len(axes)):
        axes[k].axis("off")
    plt.suptitle(title, fontsize=12)
    plt.tight_layout()
    plt.show()


show_image_grid(images, y_true_all, y_obs_all, n_show=12,
                title="Synthetic hollow‑sphere X‑ray images (energy‑neutral star; smooth sensor)")

# Label histogram
plt.figure(figsize=(5.5, 4))
plt.hist(y_obs_all, bins=22, edgecolor="black", alpha=0.85)
plt.xlabel("Observed y")
plt.ylabel("Count")
plt.title("Label distribution (y_obs)")
plt.tight_layout()
plt.show()

# Mean image (across the full dataset)
mean_img = images.mean(axis=0)
show_heatmap(mean_img, title="Mean image (all samples)", cmap="gray", with_colorbar=True)

print("\nAll arrays are now in memory: images, y_true_all, y_obs_all.")
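
**Sketch: incremental R² from nested least-squares fits.**
The ΔR² diagnostic above compares an intercept + mean-brightness model with the same model plus the star proxy. The following is a minimal, self-contained illustration of that comparison on toy data. The names `r2_linear`, `incremental_r2`, `mean_feat`, `extra_feat` and the coefficients are hypothetical and are not taken from the dataset generator.

```python
import numpy as np


def r2_linear(X, y):
    """R² of an ordinary least-squares fit of y on the columns of X."""
    beta = np.linalg.lstsq(X, y, rcond=None)[0]
    ss_res = float(((y - X @ beta) ** 2).sum())
    ss_tot = float(((y - y.mean()) ** 2).sum())
    return 1.0 - ss_res / ss_tot


def incremental_r2(base_feat, extra_feat, y):
    """Gain in R² from adding `extra_feat` to an intercept + `base_feat` model."""
    ones = np.ones_like(base_feat)
    r2_base = r2_linear(np.c_[ones, base_feat], y)
    r2_full = r2_linear(np.c_[ones, base_feat, extra_feat], y)
    return r2_base, r2_full, r2_full - r2_base


# Hypothetical toy data: the label depends weakly on one feature and strongly on the other.
demo_rng = np.random.default_rng(0)
mean_feat = demo_rng.normal(size=500)
extra_feat = demo_rng.normal(size=500)
y_demo = 0.3 * mean_feat + 1.0 * extra_feat + 0.1 * demo_rng.normal(size=500)

r2_base, r2_full, delta = incremental_r2(mean_feat, extra_feat, y_demo)
print(f"R² base={r2_base:.3f}, full={r2_full:.3f}, ΔR²={delta:.3f}")
```

Because the two models are nested, ΔR² cannot be negative on the data they were fitted to. A large value indicates only that the added feature explains variance the base feature does not.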


In [None]:
"""
Copy of cell 2...

Linear & Ridge regression on the in‑memory dataset
--------------------------------------------------

What this cell does:

1) Builds Train / Validation / Test splits: 64% / 16% / 20%.
2) Fits an Ordinary Least Squares (OLS) linear regression on flattened pixels → y_obs.
   • Notes why OLS can achieve R²≈1 on the training set when p ≫ n (10,000 features vs ~640 samples).
3) Trains a tuned Ridge regression:
   • Pipeline(StandardScaler, Ridge)
   • Hyper‑parameter alpha selected by 5‑fold CV on the *training* split only.
   • Reports Train & Val performance for the CV‑selected model.
   • Refits the final Ridge on Train+Val with the chosen alpha; reports Test performance.
4) Prints metrics (R², MAE, RMSE) and draws Predicted vs Actual plots.
5) Shows coefficient heatmaps for OLS and final Ridge (mapped back to raw‑pixel space).

"""

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, KFold, GridSearchCV
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Sanity check: make sure Cell 1 was run
if "images" not in globals() or "y_obs_all" not in globals():
    raise RuntimeError("Cell 1 has not been run. Please run Cell 1 to create images/y_obs_all in memory.")

SEED = 2025
TEST_FRAC = 0.20
VAL_FRAC_WITHIN_TRAINVAL = 0.20  # 20% of the 80% → 16% overall
rng = np.random.default_rng(SEED)

N, H, W = images.shape
D = H * W

# Flatten images to (N, D)
X_all = images.reshape(N, -1).astype(np.float64)
y_all = y_obs_all.astype(np.float64)
y_true_all_local = y_true_all.astype(np.float64)  # for noise ceiling

# Split Train+Val vs Test
X_trainval, X_test, y_trainval, y_test, y_true_trainval, y_true_test = train_test_split(
    X_all, y_all, y_true_all_local, test_size=TEST_FRAC, random_state=SEED
)
# Split Train vs Val
X_train, X_val, y_train, y_val, y_true_train, y_true_val = train_test_split(
    X_trainval, y_trainval, y_true_trainval, test_size=VAL_FRAC_WITHIN_TRAINVAL, random_state=SEED
)

print("Split summary:")
print(f"  Train: {X_train.shape[0]} samples, {X_train.shape[1]} features")
print(f"  Val  : {X_val.shape[0]} samples")
print(f"  Test : {X_test.shape[0]} samples")
print(f"  Note: p (features) = {D:,},  n_train = {X_train.shape[0]:,}.  Here, p ≫ n.\n")

# ------------------------------------------------------------
# 1) OLS Linear Regression (may interpolate when p ≫ n)
# ------------------------------------------------------------
ols = LinearRegression()  # scikit‑learn uses an SVD‑based least‑squares solver
ols.fit(X_train, y_train)

# Predictions
yhat_tr_ols = ols.predict(X_train)
yhat_va_ols = ols.predict(X_val)
yhat_te_ols = ols.predict(X_test)

print("OLS performance (against observed y):")
print_metrics("Train (OLS)", y_train, yhat_tr_ols)
print_metrics("Val   (OLS)", y_val,   yhat_va_ols)
print_metrics("Test  (OLS)", y_test,  yhat_te_ols)

# Why can Train R² be ~1?
# With p (=10,000) ≫ n_train (~640), OLS has more free parameters than training samples, so it can
# fit (interpolate) the training labels almost exactly, especially as y is largely a global scaling
# of the image. When the design matrix has full row rank (rank n_train), infinitely many weight
# vectors reach near‑zero training error; the SVD‑based solver returns the minimum‑norm one.
# This is normal for over‑parameterised linear models and says little about generalisation.

# Noise ceiling on Test (best achievable R² against *noisy* labels)
r2_ceiling_test = 1.0 - np.sum((y_true_test - y_test) ** 2) / np.sum((y_test - np.mean(y_test)) ** 2)
print(f"\n[Reference] Test‑set R² noise ceiling (label noise): {r2_ceiling_test:.4f}\n")

# Visual diagnostics — Predicted vs Actual
plot_pred_vs_actual(y_train, yhat_tr_ols, "Predicted vs Actual (Train, OLS)")
plot_pred_vs_actual(y_val,   yhat_va_ols, "Predicted vs Actual (Validation, OLS)")
plot_pred_vs_actual(y_test,  yhat_te_ols, "Predicted vs Actual (Test, OLS)")

# Coefficient heatmap (OLS)
coef_ols = ols.coef_.reshape(H, W)
show_heatmap(coef_ols, title="OLS: learned pixel weights", cmap=None, with_colorbar=True)

# ------------------------------------------------------------
# 2) Ridge Regression (tuned) — Pipeline(StandardScaler, Ridge)
# ------------------------------------------------------------
# Standardise features because Ridge penalises the magnitude of coefficients and should
# not be affected by arbitrary feature scales.
pipe = Pipeline([
    ("scaler", StandardScaler(with_mean=True, with_std=True)),
    ("ridge",  Ridge(random_state=SEED))
])

# Hyper‑parameter grid for alpha (L2 strength)
alphas = np.logspace(-6, 3, 25)  # 1e-6 … 1e3

cv = KFold(n_splits=5, shuffle=True, random_state=SEED)
param_grid = {"ridge__alpha": alphas}

grid = GridSearchCV(
    estimator=pipe,
    param_grid=param_grid,
    scoring="neg_mean_squared_error",
    cv=cv,
    n_jobs=-1,
    verbose=0,
    refit=True  # refit the best on the *training* split
)

print("Tuning Ridge (5‑fold CV on the training split)…")
grid.fit(X_train, y_train)

best_alpha = grid.best_params_["ridge__alpha"]
print(f"Best alpha (CV on Train): {best_alpha:.6g}")

# Evaluate the CV‑selected model on Train & Val
ridge_cv_model = grid.best_estimator_
yhat_tr_ridge = ridge_cv_model.predict(X_train)
yhat_va_ridge = ridge_cv_model.predict(X_val)

print("\nRidge (CV‑selected) performance:")
print_metrics("Train (Ridge CV)", y_train, yhat_tr_ridge)
print_metrics("Val   (Ridge CV)", y_val,   yhat_va_ridge)

# Refit on Train+Val using the chosen alpha, then evaluate on Test.
# (Not optional in practice: the Test plot, the coefficient heatmap, and later cells use final_ridge.)
final_ridge = Pipeline([
    ("scaler", StandardScaler(with_mean=True, with_std=True)),
    ("ridge",  Ridge(alpha=best_alpha, random_state=SEED))
])
final_ridge.fit(np.vstack([X_train, X_val]), np.hstack([y_train, y_val]))

yhat_te_ridge = final_ridge.predict(X_test)
print()  # blank line before the Test row, kept out of the metric label
print_metrics("Test  (Ridge final)", y_test, yhat_te_ridge)

# Visual diagnostics: Predicted vs Actual (Train/Val from the CV‑selected model, Test from the final refit)
plot_pred_vs_actual(y_train, yhat_tr_ridge, "Predicted vs Actual (Train, Ridge CV model)")
plot_pred_vs_actual(y_val,   yhat_va_ridge, "Predicted vs Actual (Validation, Ridge CV model)")
plot_pred_vs_actual(y_test,  yhat_te_ridge, "Predicted vs Actual (Test, Ridge final model)")

# Coefficient heatmap (Ridge, mapped back to raw pixel space)
#  ŷ = intercept + Σ w_scaled_i * (x_i - mean_i)/std_i
#  ⇒ effective raw‑space weight: w_raw_i = w_scaled_i / std_i
#     effective raw intercept:  b_raw = intercept - Σ w_scaled_i * mean_i / std_i
scaler = final_ridge.named_steps["scaler"]
ridge  = final_ridge.named_steps["ridge"]
w_scaled = ridge.coef_.ravel()
w_raw = w_scaled / (scaler.scale_ + 1e-12)
coef_ridge_raw = w_raw.reshape(H, W)
show_heatmap(coef_ridge_raw, title="Ridge (final): pixel weights (raw‑space)", cmap=None, with_colorbar=True)

# Brief takeaways
print("\nTakeaways:")
print("  • With 10,000 features and ~640 training samples (p ≫ n), OLS can nearly interpolate the training labels,")
print("    often giving R²≈1 on Train. That’s expected and is a symptom of high capacity, not necessarily true signal.")
print("  • Ridge regularisation stabilises the solution, improves conditioning, and typically gives better generalisation.")
print("  • The learned weight maps concentrate on the ring, which is exactly where the signal lies.")
print("  • ‘Noise ceiling’ gives the best possible R² on Test when labels contain measurement noise;")
print("    if your model approaches it, you are close to optimal under the present noise level.")
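
**Sketch: checking the raw-space weight mapping.**
The Ridge heatmap relies on the identities `w_raw = w_scaled / std` and `b_raw = intercept - Σ w_scaled * mean / std`. The sketch below checks them numerically using NumPy only. It rests on stated assumptions: the data, `alpha`, and all variable names are hypothetical, and a closed-form ridge solve stands in for the scikit-learn pipeline.

```python
import numpy as np

demo_rng = np.random.default_rng(1)
# Hypothetical features with very different scales (as raw pixels can have).
X = demo_rng.normal(loc=5.0, scale=[1.0, 10.0, 0.1], size=(200, 3))
y = X @ np.array([0.5, -0.02, 3.0]) + 0.05 * demo_rng.normal(size=200)

# Standardise with population statistics (ddof=0).
mean, std = X.mean(axis=0), X.std(axis=0)
Z = (X - mean) / std
alpha = 1.0  # hypothetical L2 strength

# Closed-form ridge on standardised features; the intercept is not penalised,
# and because Z has zero column means it equals the mean of y.
intercept = y.mean()
w_scaled = np.linalg.solve(Z.T @ Z + alpha * np.eye(Z.shape[1]), Z.T @ (y - intercept))

# Map the weights and intercept back to raw-feature space.
w_raw = w_scaled / std
b_raw = intercept - np.sum(w_scaled * mean / std)

# Both parameterisations should give the same predictions on unseen inputs.
X_new = demo_rng.normal(loc=5.0, scale=[1.0, 10.0, 0.1], size=(5, 3))
pred_scaled = ((X_new - mean) / std) @ w_scaled + intercept
pred_raw = X_new @ w_raw + b_raw
print("Predictions agree:", np.allclose(pred_scaled, pred_raw))
```

The mapping changes only how the model is written, not what it predicts. Dividing by `std` inflates the raw-space weights of low-variance pixels, which is worth remembering when reading the heatmap.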

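
**Sketch: what the noise ceiling measures.**
The reference R² printed above scores the noise-free labels as if they were predictions of the noisy ones. The sketch below reproduces that calculation on toy labels and compares it with the population value `1 - σ_noise² / (σ_clean² + σ_noise²)`. The uniform label range and `noise_std` are hypothetical and are not the generator's settings.

```python
import numpy as np

demo_rng = np.random.default_rng(2)
y_clean = demo_rng.uniform(0.5, 1.5, size=2000)   # hypothetical noise-free labels
noise_std = 0.05                                  # hypothetical label-noise level
y_noisy = y_clean + demo_rng.normal(0.0, noise_std, size=2000)

# Empirical ceiling: treat the clean labels as a perfect model's predictions.
ss_res = np.sum((y_clean - y_noisy) ** 2)
ss_tot = np.sum((y_noisy - y_noisy.mean()) ** 2)
r2_ceiling = 1.0 - ss_res / ss_tot

# Population value for comparison; a uniform variable on [a, b] has variance (b - a)² / 12.
var_clean = (1.5 - 0.5) ** 2 / 12.0
r2_theory = 1.0 - noise_std**2 / (var_clean + noise_std**2)
print(f"empirical={r2_ceiling:.4f}, theory={r2_theory:.4f}")
```

On a finite test set the empirical ceiling is itself an estimate, so a model can land slightly above it by chance. It is best read as a reference level rather than a hard bound.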

In [None]:
# ============================================================
# Cell 3 — CNN that learns global density + shell edges + a tiny star
# (SHAP‑compatible, M1‑safe, robust training)
# ============================================================

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# ----- Preconditions
if "images" not in globals() or "y_true_all" not in globals() or "y_obs_all" not in globals():
    raise RuntimeError("Please run Cell 1 (data generation) first.")
torch.set_grad_enabled(True)

# ----- Device (M1 → MPS if available)
if torch.cuda.is_available():
    DEVICE = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# ----- Repro
SEED = 2025
np.random.seed(SEED); torch.manual_seed(SEED)
if DEVICE == "cuda": torch.cuda.manual_seed_all(SEED)

# ----- Split (same recipe as Cell 2)
TEST_FRAC = 0.20
VAL_FRAC_WITHIN_TRAINVAL = 0.20

N, H, W = images.shape
CENTER = (H - 1) / 2.0

X_all = images
y_all = y_obs_all.astype(np.float32)
y_true_all_local = y_true_all.astype(np.float32)

X_trainval, X_test, y_trainval, y_test, y_true_trainval, y_true_test = train_test_split(
    X_all, y_all, y_true_all_local, test_size=TEST_FRAC, random_state=SEED
)
X_train, X_val, y_train, y_val, y_true_train, y_true_val = train_test_split(
    X_trainval, y_trainval, y_true_trainval, test_size=VAL_FRAC_WITHIN_TRAINVAL, random_state=SEED
)

print("Split summary:")
print(f"  Train: {X_train.shape[0]}  | Val: {X_val.shape[0]}  | Test: {X_test.shape[0]}")

# ----- Preprocessing (corner‑median background → train‑stats normalisation)
ys, xs = np.indices((H, W))

def estimate_background_offset(img: np.ndarray, k: int = 10) -> float:
    patches = [img[:k, :k], img[:k, -k:], img[-k:, :k], img[-k:, -k:]]
    return float(np.median(np.concatenate([p.ravel() for p in patches])))

def subtract_background(X: np.ndarray) -> np.ndarray:
    Xp = np.empty_like(X)
    for i, img in enumerate(X):
        bg = estimate_background_offset(img, k=10)
        im = img - bg
        im[im < 0.0] = 0.0
        Xp[i] = im
    return Xp

X_train_s = subtract_background(X_train)
X_val_s   = subtract_background(X_val)
X_test_s  = subtract_background(X_test)

pix_mean = float(X_train_s.mean())
pix_std  = float(X_train_s.std() + 1e-8)

def norm_clamp(Xs):
    Xn = (Xs - pix_mean) / pix_std
    Xn = np.clip(Xn, -8.0, 8.0).astype(np.float32)
    return Xn

X_train_n = norm_clamp(X_train_s)
X_val_n   = norm_clamp(X_val_s)
X_test_n  = norm_clamp(X_test_s)

# CoordConv channels (exported for SHAP)
x_lin = (xs - CENTER) / CENTER
y_lin = (ys - CENTER) / CENTER
r_map = np.sqrt(x_lin**2 + y_lin**2)
r_map = r_map / (r_map.max() + 1e-12)

coord_tensor = torch.from_numpy(np.stack([x_lin, y_lin, r_map], axis=0).astype(np.float32))
coord_hflip  = torch.flip(coord_tensor, dims=[2])
coord_vflip  = torch.flip(coord_tensor, dims=[1])
coord_bflip  = torch.flip(coord_tensor, dims=[1, 2])

# ----- Target scaling (z‑score on TRAIN)
y_mean = float(np.mean(y_train))
y_std  = float(np.std(y_train) + 1e-8)
def to_z(y):   return ((y - y_mean) / y_std).astype(np.float32)
def from_z(z): return z * y_std + y_mean

z_train = to_z(y_train); z_val = to_z(y_val); z_test = to_z(y_test)

# ----- Augmentation (flips; always contiguous → no negative‑stride errors)
def random_flip_image_and_coords(im: np.ndarray):
    hflip = (np.random.rand() < 0.5)
    vflip = (np.random.rand() < 0.5)
    if hflip: im = im[:, ::-1]
    if vflip: im = im[::-1, :]
    im = np.ascontiguousarray(im, dtype=np.float32)
    if   hflip and  vflip: coord_use = coord_bflip
    elif hflip and not vflip: coord_use = coord_hflip
    elif not hflip and vflip: coord_use = coord_vflip
    else: coord_use = coord_tensor
    return im, coord_use

class ImageRegDataset(Dataset):
    def __init__(self, Xn, z, y, train: bool):
        self.Xn = Xn; self.z = z; self.y = y; self.train = train
    def __len__(self): return self.Xn.shape[0]
    def __getitem__(self, i):
        im = self.Xn[i]
        if self.train:
            im, coord_use = random_flip_image_and_coords(im)
        else:
            im = np.ascontiguousarray(im, dtype=np.float32)
            coord_use = coord_tensor
        x_img = torch.from_numpy(im).unsqueeze(0)              # (1,H,W)
        x     = torch.cat([x_img, coord_use], dim=0)           # (4,H,W)
        z     = torch.tensor(self.z[i:i+1], dtype=torch.float32)
        y     = torch.tensor(self.y[i:i+1], dtype=torch.float32)
        return x, z, y

BATCH_SIZE = 128 if DEVICE in ("cuda", "mps") else 64
PIN = (DEVICE == "cuda")
train_loader = DataLoader(ImageRegDataset(X_train_n, z_train, y_train, train=True),
                          batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN, num_workers=0)
val_loader   = DataLoader(ImageRegDataset(X_val_n,   z_val,   y_val,   train=False),
                          batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN, num_workers=0)
test_loader  = DataLoader(ImageRegDataset(X_test_n,  z_test,  y_test,  train=False),
                          batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN, num_workers=0)

# ----- Fixed helper maps (inside the net; external input stays 4‑ch)
class FixedImageStem(nn.Module):
    """Sobel grad‑mag + local contrast from image channel."""
    def __init__(self):
        super().__init__()
        sobel_x = torch.tensor([[1,0,-1],[2,0,-2],[1,0,-1]], dtype=torch.float32).view(1,1,3,3)
        sobel_y = sobel_x.transpose(2,3).contiguous()
        mean3   = torch.full((1,1,3,3), 1.0/9.0, dtype=torch.float32)
        self.conv_sx = nn.Conv2d(1,1,3,padding=1,bias=False)
        self.conv_sy = nn.Conv2d(1,1,3,padding=1,bias=False)
        self.blur3   = nn.Conv2d(1,1,3,padding=1,bias=False)
        with torch.no_grad():
            self.conv_sx.weight.copy_(sobel_x)
            self.conv_sy.weight.copy_(sobel_y)
            self.blur3.weight.copy_(mean3)
        for p in self.parameters(): p.requires_grad = False
    def forward(self, x4):
        img = x4[:, :1]
        gx = self.conv_sx(img); gy = self.conv_sy(img)
        gradmag   = torch.sqrt(torch.clamp(gx*gx + gy*gy, min=1e-12))
        localmean = self.blur3(img)
        highpass  = img - localmean
        return torch.cat([x4, gradmag, highpass], dim=1)  # 4 → 6 ch

# ----- Ring priors (from r channel) + starness branch
class RingPriorsAndStar(nn.Module):
    """
    Inside‑net priors:
      • ring_mid, edge_inner, edge_outer from r channel (Gaussian bands)
      • starness: convs on (img × ring_mid) to produce a focused 'star map'
    Returns: concatenation of [x6, priors(3), star_map(1), global_mean_map(1)] → 6 + 3 + 1 + 1 = 11 channels
    """
    def __init__(self):
        super().__init__()
        # starness sub-net (lightweight, high gain)
        self.s_c1 = nn.Conv2d(1, 16, 3, padding=1)
        self.s_a1 = nn.ReLU(inplace=False)
        self.s_c2 = nn.Conv2d(16, 1, 3, padding=1)
        nn.init.kaiming_normal_(self.s_c1.weight, nonlinearity="relu")
        nn.init.kaiming_normal_(self.s_c2.weight, nonlinearity="relu")
        nn.init.zeros_(self.s_c1.bias); nn.init.zeros_(self.s_c2.bias)

    def forward(self, x6):
        # x6 = [img, x, y, r, gradmag, highpass]
        img = x6[:, :1]
        r   = x6[:, 3:4]
        # ring priors (values tuned to your geometry: inner~0.283, mid~0.435, outer~0.566)
        ring_mid   = torch.exp(-0.5*((r - 0.435)/0.060)**2)
        edge_inner = torch.exp(-0.5*((r - 0.283)/0.040)**2)
        edge_outer = torch.exp(-0.5*((r - 0.566)/0.040)**2)
        # starness on img × ring_mid
        z = img * ring_mid
        z = self.s_a1(self.s_c1(z))
        star_map = torch.relu(self.s_c2(z)) * ring_mid  # keep it on the ring
        # global context map (mean intensity of img)
        gmean = img.mean(dim=(2,3), keepdim=True).expand_as(img)
        return torch.cat([x6, ring_mid, edge_inner, edge_outer, star_map, gmean], dim=1)  # 11 ch

# ----- A small, norm‑free residual block
class ResBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.pad1 = nn.ReflectionPad2d(1); self.c1 = nn.Conv2d(c, c, 3)
        self.act  = nn.ReLU(inplace=False)
        self.pad2 = nn.ReflectionPad2d(1); self.c2 = nn.Conv2d(c, c, 3)
        nn.init.kaiming_normal_(self.c1.weight, nonlinearity="relu")
        nn.init.kaiming_normal_(self.c2.weight, nonlinearity="relu")
        if self.c1.bias is not None: nn.init.zeros_(self.c1.bias)
        if self.c2.bias is not None: nn.init.zeros_(self.c2.bias)
    def forward(self, x):
        y = self.c1(self.pad1(x)); y = self.act(y); y = self.c2(self.pad2(y))
        return self.act(y + x)

class ConcatPool2d(nn.Module):
    def forward(self, x):
        return torch.cat([torch.mean(x, dim=(2,3), keepdim=True),
                          torch.amax(x, dim=(2,3), keepdim=True)], dim=1)

# ----- Features extractor
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem      = FixedImageStem()          # 4 → 6
        self.ring_star = RingPriorsAndStar()       # 6 → 11
        self.pad0 = nn.ReflectionPad2d(1); self.c0 = nn.Conv2d(11, 64, 3); self.a0 = nn.ReLU(inplace=False)
        self.b1 = ResBlock(64); self.pool1 = nn.MaxPool2d(2)     # 100 → 50
        self.pad1 = nn.ReflectionPad2d(1); self.c1 = nn.Conv2d(64, 128, 3); self.a1 = nn.ReLU(inplace=False)
        self.b2 = ResBlock(128); self.pool2 = nn.MaxPool2d(2)    # 50 → 25
        self.pad2 = nn.ReflectionPad2d(1); self.c2 = nn.Conv2d(128, 160, 3); self.a2 = nn.ReLU(inplace=False)
        self.b3 = ResBlock(160)
        for m in [self.c0, self.c1, self.c2]:
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            if m.bias is not None: nn.init.zeros_(m.bias)
    def forward(self, x4):
        x6  = self.stem(x4)
        x11 = self.ring_star(x6)
        y = self.a0(self.c0(self.pad0(x11))); y = self.b1(y); y = self.pool1(y)
        y = self.a1(self.c1(self.pad1(y)));  y = self.b2(y); y = self.pool2(y)
        y = self.a2(self.c2(self.pad2(y)));  y = self.b3(y)
        return y

# ----- Regressor (SHAP‑compatible: .features / .gap / .head ; TWO Linear)
class CNNRegressor(nn.Module):
    def __init__(self, in_ch=4):
        super().__init__()
        self.features = FeatureExtractor()
        self.gap      = ConcatPool2d()             # preserves “max” path for a tiny star
        self.head     = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2*160, 128), nn.ReLU(inplace=False),
            nn.Dropout(p=0.10),
            nn.Linear(128, 1),
        )
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None: nn.init.zeros_(m.bias)
    def forward(self, x):
        x = self.features(x)
        x = self.gap(x)
        x = self.head(x)
        return x.clone()   # SHAP‑safe

model = CNNRegressor(in_ch=4).to(DEVICE)
print(model)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

# ----- Training (robust but simple; M1‑safe)
LR = 6e-4
WEIGHT_DECAY = 2e-4
EPOCHS = 140
PATIENCE = 20
criterion = nn.MSELoss()
optimiser = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=EPOCHS, eta_min=1e-5)

best_val = float("inf"); best_state = None
train_losses, val_losses = [], []; no_improve = 0

def _finite(x): return torch.isfinite(x).all().item()

for epoch in range(1, EPOCHS+1):
    # Train
    model.train()
    loss_sum = 0.0
    for xb, zb, _yb in train_loader:
        xb, zb = xb.to(DEVICE), zb.to(DEVICE)
        optimiser.zero_grad(set_to_none=True)
        pred_z = model(xb)
        loss = criterion(pred_z, zb)
        if not _finite(loss):
            # rare on MPS: back off LR and skip this step
            for g in optimiser.param_groups: g["lr"] = max(g["lr"]*0.5, 1e-5)
            continue
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimiser.step()
        loss_sum += float(loss.item()) * xb.size(0)
    train_loss = loss_sum / len(train_loader.dataset); train_losses.append(train_loss)

    # Validate
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for xb, zb, _yb in val_loader:
            xb, zb = xb.to(DEVICE), zb.to(DEVICE)
            pred_z = model(xb)
            loss_sum += float(criterion(pred_z, zb).item()) * xb.size(0)
    val_loss = loss_sum / len(val_loader.dataset); val_losses.append(val_loss)
    scheduler.step()

    print(f"Epoch {epoch:03d} | Train L={train_loss:.6f} | Val L={val_loss:.6f} | LR={scheduler.get_last_lr()[0]:.2e}")

    if val_loss < best_val - 1e-7:
        best_val = val_loss
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        no_improve = 0
    else:
        no_improve += 1
        if no_improve >= PATIENCE:
            print(f"Early stopping (patience={PATIENCE}).")
            break

# Restore best
if best_state is not None:
    model.load_state_dict(best_state)

# ----- Metrics in y‑space
def metrics_from_loader(name, loader):
    model.eval()
    preds_y, trues_y = [], []
    with torch.no_grad():
        for xb, zb, yb in loader:
            xb = xb.to(DEVICE)
            zhat = model(xb).cpu().numpy().squeeze(-1)
            yhat = from_z(zhat)
            preds_y.append(yhat.astype(np.float32))
            trues_y.append(yb.numpy().squeeze(-1).astype(np.float32))
    y_pred = np.concatenate(preds_y); y_true = np.concatenate(trues_y)
    ss_res = float(((y_true - y_pred)**2).sum())
    ss_tot = float(((y_true - y_true.mean())**2).sum() + 1e-12)
    r2   = 1.0 - ss_res/ss_tot
    mae  = float(np.mean(np.abs(y_true - y_pred)))
    rmse = float(np.sqrt(np.mean((y_true - y_pred)**2)))
    print(f"[{name}] R² = {r2:.4f} | MAE = {mae:.4f} | RMSE = {rmse:.4f}")
    return y_true, y_pred, r2, mae, rmse

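# Note: train_loader shuffles and applies random flips, so the Train row is measured on
# augmented views. Predictions stay paired with their labels, so the metrics remain valid.
print("\nPerformance (against observed y):")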
y_tr, yhat_tr, r2_tr, _, _ = metrics_from_loader("Train",      train_loader)
y_va, yhat_va, r2_va, _, _ = metrics_from_loader("Validation", val_loader)
y_te, yhat_te, r2_te, _, _ = metrics_from_loader("Test",       test_loader)

# Noise ceiling on Test (label noise only)
r2_ceiling_test = 1.0 - float(((y_true_test - y_te)**2).sum()) / float(((y_te - y_te.mean())**2).sum() + 1e-12)
print(f"\n[Reference] Test noise‑ceiling R² (label noise): {r2_ceiling_test:.4f}\n")

# ----- Plots
plt.figure(figsize=(6,4))
plt.plot(train_losses, label="Train loss"); plt.plot(val_losses, label="Val loss")
plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Training curves (CNN — ring prior + starness)")
plt.legend(); plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()

def plot_pred_vs_actual(y_actual, y_pred, title):
    plt.figure(figsize=(5,4))
    plt.scatter(y_actual, y_pred, s=14, alpha=0.75)
    lo = min(float(np.min(y_actual)), float(np.min(y_pred)))
    hi = max(float(np.max(y_actual)), float(np.max(y_pred)))
    pad = 0.02*(hi - lo) if hi > lo else 0.01
    plt.plot([lo - pad, hi + pad], [lo - pad, hi + pad], linestyle="--", linewidth=1.0)
    plt.xlabel("Actual y"); plt.ylabel("Predicted y"); plt.title(title)
    plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()

plot_pred_vs_actual(y_tr, yhat_tr, "Predicted vs Actual (Train, CNN)")
plot_pred_vs_actual(y_va, yhat_va, "Predicted vs Actual (Validation, CNN)")
plot_pred_vs_actual(y_te, yhat_te, "Predicted vs Actual (Test, CNN)")

# Small test panel (original → preprocessed → prediction)
n_show = min(6, len(X_test))
sel = np.random.choice(len(X_test), size=n_show, replace=False)
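# squeeze=False keeps `axes` 2‑D even when n_show == 1, so axes[row, col] indexing always works
fig, axes = plt.subplots(3, n_show, figsize=(2.0*n_show, 5.8), squeeze=False)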
for j, k in enumerate(sel):
    axes[0, j].imshow(X_test[k],   cmap="gray", vmin=0, vmax=1); axes[0, j].set_title("Original");      axes[0, j].axis("off")
    axes[1, j].imshow(X_test_s[k], cmap="gray", vmin=0, vmax=1); axes[1, j].set_title("Bg‑subtracted"); axes[1, j].axis("off")
    axes[2, j].text(0.02, 0.80, f"y_true≈{y_true_test[k]:.3f}", transform=axes[2, j].transAxes)
    axes[2, j].text(0.02, 0.55, f"y_obs ={y_test[k]:.3f}",      transform=axes[2, j].transAxes)
    axes[2, j].text(0.02, 0.30, f"y_hat ={yhat_te[k]:.3f}",     transform=axes[2, j].transAxes)
    axes[2, j].axis("off")
axes[0,0].set_ylabel("Input"); axes[1,0].set_ylabel("Preproc"); axes[2,0].set_ylabel("Labels")
plt.suptitle("CNN predictions on held‑out test samples", fontsize=12)
plt.tight_layout(); plt.show()

print("Takeaways:")
print("  • Ring priors tell the model *where* to look; the starness branch makes a tiny bright cue dominate via max pooling.")
print("  • A global-mean map gives an easy path to the overall scale so capacity focuses on local features.")
print("  • Architecture and training remain SHAP-safe.")
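
**Sketch: why the max path matters for a tiny star.**
`ConcatPool2d` concatenates a global mean and a global max. The NumPy illustration below shows how little a single bright pixel moves the mean of a 100×100 map compared with how much it moves the max. The activation values and the helper `concat_pool` are hypothetical and only mimic the pooling step, not the trained network.

```python
import numpy as np


def concat_pool(feature_map):
    """Return (global mean, global max) of a 2-D map, mimicking a mean+max pooling head."""
    return float(feature_map.mean()), float(feature_map.max())


H = W = 100
plain = np.full((H, W), 0.20)      # hypothetical flat activation map
with_star = plain.copy()
with_star[30, 60] = 0.90           # one bright "star" pixel

mean_plain, max_plain = concat_pool(plain)
mean_star, max_star = concat_pool(with_star)

mean_shift = mean_star - mean_plain   # diluted by the 10,000 pixels in the map
max_shift = max_star - max_plain      # carries the full jump of the single pixel
print(f"mean shift = {mean_shift:.6f}")
print(f"max  shift = {max_shift:.6f}")
```

Under these assumptions the mean changes by the pixel's jump divided by the number of pixels, while the max changes by the full jump. This is the reason a mean-only pooling head would make a one-pixel cue hard to use.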


In [None]:
"""
Deep SHAP — End‑to‑end, commented, and *interpretable* analysis on our three models
-----------------------------------------------------------------------------------

B. Models explained:
  • OLS (pixels → y), Ridge (StandardScaler + Ridge → y), and CNN (4‑ch input: image + x,y,r; predicts z then inverted to y).
  • CNN is made SHAP‑safe by:
      – a SHAPCompatCNN wrapper that removes nn.Flatten and any in‑place ReLUs,
      – using check_additivity=False (via safe_shap_values) to tolerate unsupported ops (e.g., sqrt in fixed stems).

Changes in this version:
  • Insertion curves removed (kept Deletion/AOPC).
  • FIX: bounding‑box analysis is now robust (no negative kth; handles degenerate masks; always 2‑D indexing).
  • NEW: before boxing we take the **largest weighted connected component** of the top‑|SHAP| pixels within the shell.
"""

# -----------------
# Imports and checks
# -----------------
import copy
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
import torch.nn as nn

from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score


# ----- Device (M1 → MPS if available)
if torch.cuda.is_available():
    DEVICE = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Silence SHAP's "unrecognized nn.Module" warnings for harmless modules (ConcatPool2d / ReflectionPad2d)
import warnings
warnings.filterwarnings(
    "ignore",
    message=r"unrecognized nn\.Module: .*",
    category=UserWarning,
    module="shap.explainers._deep.deep_pytorch"
)


try:
    import shap
except Exception as e:
    raise ImportError("Please install SHAP first:  pip install shap") from e

# Optional Spearman correlation (baseline sensitivity); fall back to Pearson if unavailable
try:
    from scipy.stats import spearmanr
    HAVE_SPEARMAN = True
except Exception:
    HAVE_SPEARMAN = False

# Sanity: ensure earlier cells provided these
required = ["images", "y_true_all", "y_obs_all", "ols", "final_ridge", "model"]
for v in required:
    if v not in globals():
        raise RuntimeError(f"Missing '{v}'. Please run the data, OLS/Ridge, and CNN training cells first.")

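# NOTE: this re-assignment overrides the device selected above and drops MPS,
# so on Apple Silicon the SHAP analysis below runs on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"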
model.to(DEVICE).eval()
torch.set_grad_enabled(True)  # in case a previous cell disabled grad

N, H, W = images.shape
CENTER = (H - 1) / 2.0
print(f"[Setup] images: {images.shape} | device: {DEVICE}")

# -----------------------------
# Split indices (SAME recipe as earlier; seed=2025)
# -----------------------------
SEED = 2025
TEST_FRAC = 0.20
VAL_FRAC_WITHIN_TRAINVAL = 0.20  # 20% of 80% → 16% overall

rng = np.random.default_rng(SEED)
idx_all = np.arange(N)

# First split: TrainVal vs Test
idx_trainval, idx_test = train_test_split(idx_all, test_size=TEST_FRAC, random_state=SEED)
# Second split: Train vs Val (within TrainVal)
idx_train, idx_val = train_test_split(idx_trainval, test_size=VAL_FRAC_WITHIN_TRAINVAL, random_state=SEED)

print(f"[Split] Train={len(idx_train)}, Val={len(idx_val)}, Test={len(idx_test)}")

# -----------------------------
# Utility helpers
# -----------------------------
def area_trapezoid(y, x):
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)
    if hasattr(np, "trapezoid"):
        return float(np.trapezoid(y, x))
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])))

def heatmap2d(arr2d, title, cmap="magma", with_cbar=True):
    plt.figure(figsize=(5, 4))
    plt.imshow(arr2d, cmap=cmap)
    if with_cbar:
        plt.colorbar(fraction=0.046, pad=0.04)
    plt.title(title); plt.axis("off"); plt.tight_layout(); plt.show()

def grid_overlays(imgs, shap_maps, title_prefix, ncols=4):
    n = len(imgs)
    ncols = min(ncols, n); nrows = int(np.ceil(n / ncols))
    vmax = np.percentile(np.abs(np.concatenate([s.ravel() for s in shap_maps])), 99) if len(shap_maps) else 1.0
    fig, axes = plt.subplots(nrows, ncols, figsize=(3.2 * ncols, 3.2 * nrows))
    axes = np.atleast_1d(axes).ravel()
    for i in range(n):
        axes[i].imshow(imgs[i], cmap="gray", vmin=0, vmax=1)
        axes[i].imshow(shap_maps[i], cmap="RdBu_r", alpha=0.7, vmin=-vmax, vmax=+vmax)
        axes[i].set_title(f"{title_prefix} — sample {i}", fontsize=9)
        axes[i].axis("off")
    for k in range(n, nrows * ncols): axes[k].axis("off")
    plt.tight_layout(); plt.show()

# -----------------------------
# Radial & angular fields + “region share” metrics
# -----------------------------
ys, xs = np.indices((H, W))
R = np.sqrt((ys - CENTER) ** 2 + (xs - CENTER) ** 2)
theta = (np.arctan2(ys - CENTER, xs - CENTER) + 2*np.pi) % (2*np.pi)  # used by polar & bbox

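# Analysis radii: 2 px wider on each side than the nominal generator radii
# (R0 ≈ 20 px, R1 ≈ 40 px), presumably to absorb the mild elliptical variation.
INNER_R = 18.0
OUTER_R = 42.0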
mask_inner = (R <= INNER_R)  # hollow
mask_shell = (R > INNER_R) & (R <= OUTER_R)  # metal shell
mask_outer = (R > OUTER_R)  # outside

def region_shares(abs_shap_map):
    total = abs_shap_map.sum() + 1e-12
    return (
        float((abs_shap_map[mask_inner]).sum() / total),
        float((abs_shap_map[mask_shell]).sum() / total),
        float((abs_shap_map[mask_outer]).sum() / total),
    )

def region_means(abs_shap_map):
    mi = float(abs_shap_map[mask_inner].mean() if mask_inner.sum() else 0.0)
    ms = float(abs_shap_map[mask_shell].mean() if mask_shell.sum() else 0.0)
    mo = float(abs_shap_map[mask_outer].mean() if mask_outer.sum() else 0.0)
    return mi, ms, mo

def top_p_concentration(abs_shap_map, p=0.05):
    A = np.abs(abs_shap_map).ravel()
    k = max(1, int(p * A.size))
    kth = A.size - k  # non‑negative kth
    if kth < 0: kth = 0
    thr = np.partition(A, kth)[kth]
    return float((A[A >= thr]).sum() / (A.sum() + 1e-12))

def radial_profile(abs_shap_map, nbins=60):
    r = R.ravel(); v = abs_shap_map.ravel()
    bins = np.linspace(0, R.max() + 1e-6, nbins + 1)
    prof = np.zeros(nbins, dtype=float)
    for i in range(nbins):
        m = (r >= bins[i]) & (r < bins[i + 1])
        prof[i] = v[m].mean() if np.any(m) else 0.0
    centres = 0.5 * (bins[:-1] + bins[1:])
    return centres, prof

def summarise_region_and_profile(stack, model_name):
    shares = np.array([region_shares(np.abs(s)) for s in stack])
    inner_s, shell_s, outer_s = shares.mean(axis=0)

    means = np.array([region_means(np.abs(s)) for s in stack])
    inner_m, shell_m, outer_m = means.mean(axis=0)

    conc_1 = np.mean([top_p_concentration(np.abs(s), p=0.01) for s in stack])
    conc_5 = np.mean([top_p_concentration(np.abs(s), p=0.05) for s in stack])

    print(f"\n{model_name} — mean |SHAP| region *share* (area‑biased):")
    print(f"  Hollow (r ≤ {INNER_R:.0f}) : {inner_s:6.2%}")
    print(f"  Shell  ({INNER_R:.0f}<r≤{OUTER_R:.0f}): {shell_s:6.2%}")
    print(f"  Outside(r > {OUTER_R:.0f}) : {outer_s:6.2%}")

    print(f"{model_name} — mean |SHAP| *per pixel* (area‑fair):")
    print(f"  Hollow:  {inner_m:.6f}  |  Shell: {shell_m:.6f}  |  Outside: {outer_m:.6f}")
    print(f"{model_name} — concentration: top‑1% captures {conc_1:6.2%}, top‑5% captures {conc_5:6.2%}")

    profs, radii = [], None
    for s in stack:
        rr, p = radial_profile(np.abs(s), nbins=70)
        profs.append(p); radii = rr
    profs = np.stack(profs, axis=0)
    plt.figure(figsize=(6, 4))
    plt.plot(radii, profs.mean(axis=0))
    plt.xlabel("Radius (px)"); plt.ylabel("Mean |SHAP|")
    plt.title(f"Radial profile of |SHAP| — {model_name}\n(look for peaks around the shell boundaries)")
    plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()

def topk_mask(abs_map, k_frac):
    A = np.abs(abs_map).ravel()
    k = max(1, int(k_frac * A.size))
    kth = A.size - k
    if kth < 0: kth = 0
    thr = np.partition(A, kth)[kth]
    return (np.abs(abs_map) >= thr)

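def iou_and_precision_vs_shell(stack, k_fracs=(0.01, 0.02, 0.05, 0.10)):
    """Mean IoU and precision of each map's top-k |SHAP| mask against the shell mask.

    Note: when the top-k mask is smaller than the shell, IoU cannot exceed
    |mask| / |shell|, so Precision@k is the more informative figure for small k.
    """
    results = {}; shell = mask_shell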
    for kf in k_fracs:
        ious, precs = [], []
        for s in stack:
            m = topk_mask(s, kf).reshape(H, W)
            inter = np.logical_and(m, shell).sum()
            union = np.logical_or(m, shell).sum()
            iou = inter / (union + 1e-12)
            prec = inter / (m.sum() + 1e-12)
            ious.append(iou); precs.append(prec)
        results[kf] = (float(np.mean(ious)), float(np.mean(precs)))
    return results

def print_iou_precision(name, results):
    print(f"\n{name} — Ground‑truth alignment vs shell (IoU@k / Precision@k):")
    for kf, (iou, prec) in results.items():
        print(f"  k={kf:.0%}  IoU={iou:.3f}  |  Precision={prec:.3f}")
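
# --- Illustrative sketch (toy data; assumes a 100x100 image and the 18/42 px radii above) ---
# Why IoU@k stays small even for a perfect explanation: a top-1% mask has only 100
# pixels, while the shell has several thousand. With every selected pixel inside the
# shell, precision is 1.0 but IoU is just |mask| / |shell|. The `_demo_*` names are
# hypothetical and are not used anywhere else.
import numpy as np

_demo_n = 100
_demo_yy, _demo_xx = np.indices((_demo_n, _demo_n))
_demo_c = (_demo_n - 1) / 2.0
_demo_r = np.hypot(_demo_yy - _demo_c, _demo_xx - _demo_c)
_demo_shell = (_demo_r > 18.0) & (_demo_r <= 42.0)

# Place the 100 "most important" pixels (1% of 10,000) entirely inside the shell.
_demo_top = np.zeros_like(_demo_shell)
_demo_top.flat[np.flatnonzero(_demo_shell)[:100]] = True

_demo_inter = np.logical_and(_demo_top, _demo_shell).sum()
_demo_union = np.logical_or(_demo_top, _demo_shell).sum()
_demo_iou = _demo_inter / _demo_union             # equals 100 / |shell|
_demo_precision = _demo_inter / _demo_top.sum()   # equals 1.0 for this mask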

# -----------------------------
# A common test subset to *explain*
# -----------------------------
N_EXPLAIN = min(24, len(idx_test))
test_sel_idx = rng.choice(idx_test, size=N_EXPLAIN, replace=False)
test_imgs = images[test_sel_idx]
y_obs_test_sel = y_obs_all[test_sel_idx]

# ==========================================================
# 1) OLS — LinearExplainer on flattened pixels
# ==========================================================
print("\n[OLS] Expect a ring; most |SHAP| in the shell.")
X_flat = images.reshape(N, -1).astype(np.float64)

bg_ols_idx = rng.choice(idx_train, size=min(200, len(idx_train)), replace=False)
X_bg_ols = X_flat[bg_ols_idx]
X_ols_test = X_flat[test_sel_idx]

expl_ols = shap.LinearExplainer(ols, X_bg_ols)  # identity link
shap_ols_flat = expl_ols.shap_values(X_ols_test)
if isinstance(shap_ols_flat, list):  # defensive
    shap_ols_flat = shap_ols_flat[0]
shap_ols_imgs = shap_ols_flat.reshape(-1, H, W)

grid_overlays(test_imgs[:8], shap_ols_imgs[:8], "OLS")
heatmap2d(np.mean(np.abs(shap_ols_imgs), axis=0), "Global mean |SHAP| — OLS")
summarise_region_and_profile(shap_ols_imgs, "OLS")
print_iou_precision("OLS", iou_and_precision_vs_shell(shap_ols_imgs))
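
# --- Illustrative sketch (toy model, not the notebook's OLS fit) ---
# For a linear model f(x) = w.x + b with features treated as independent, the SHAP
# value of feature i is w_i * (x_i - mean_background_i). This is my understanding of
# what shap.LinearExplainer computes by default; treat it as an assumption. Two
# consequences are checked below: attributions sum to f(x) - f(mean_background), and
# a feature with zero weight gets zero credit. The `_demo_*` names are hypothetical.
import numpy as np

_demo_rng = np.random.default_rng(0)
_demo_w = np.array([0.0, 2.0, -1.0])          # toy weights; feature 0 is irrelevant
_demo_b = 0.5
_demo_bg = _demo_rng.normal(size=(50, 3))     # toy background sample
_demo_x = np.array([1.0, 1.0, 1.0])           # point to explain

_demo_mu = _demo_bg.mean(axis=0)
_demo_phi = _demo_w * (_demo_x - _demo_mu)    # per-feature attributions
_demo_fx = float(_demo_x @ _demo_w + _demo_b)
_demo_base = float(_demo_mu @ _demo_w + _demo_b)
_demo_gap = abs(_demo_phi.sum() - (_demo_fx - _demo_base))  # additivity residual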

# ==========================================================
# 2) Ridge — LinearExplainer on STANDARDISED features
# ==========================================================
print("\n[Ridge] Expect credit to be spread more diffusely; still ring‑biased.")
if not isinstance(final_ridge, Pipeline):
    raise RuntimeError("Expected 'final_ridge' to be a Pipeline(StandardScaler, Ridge).")
scaler = final_ridge.named_steps["scaler"]
ridge  = final_ridge.named_steps["ridge"]

Z_bg   = scaler.transform(X_flat[bg_ols_idx])
Z_test = scaler.transform(X_flat[test_sel_idx])

expl_ridge = shap.LinearExplainer(ridge, Z_bg)
shap_ridge_scaled = expl_ridge.shap_values(Z_test)
if isinstance(shap_ridge_scaled, list):
    shap_ridge_scaled = shap_ridge_scaled[0]
shap_ridge_imgs = shap_ridge_scaled.reshape(-1, H, W)

grid_overlays(test_imgs[:8], shap_ridge_imgs[:8], "Ridge")
heatmap2d(np.mean(np.abs(shap_ridge_imgs), axis=0), "Global mean |SHAP| — Ridge")
summarise_region_and_profile(shap_ridge_imgs, "Ridge")
print_iou_precision("Ridge", iou_and_precision_vs_shell(shap_ridge_imgs))

# ==========================================================
# 3) CNN — DeepExplainer on 4‑channel inputs (image + CoordConv x,y,r)
#      SHAP‑compatible wrapper (no nn.Flatten; no in‑place ops); additivity off
# ==========================================================
print("\n[CNN] Strong edges on inner/outer shell expected; baseline matters — we select it by fidelity.")

# --- preprocessing used for CNN inputs (match the training recipe) ---
def estimate_background_offset(img: np.ndarray, k: int = 10) -> float:
    patches = [img[:k, :k], img[:k, -k:], img[-k:, :k], img[-k:, -k:]]
    return float(np.median(np.concatenate([p.ravel() for p in patches])))

def subtract_background(X: np.ndarray) -> np.ndarray:
    Xp = np.empty_like(X)
    for i, img in enumerate(X):
        bg = estimate_background_offset(img, k=10)
        im = img - bg
        im[im < 0.0] = 0.0
        Xp[i] = im
    return Xp

pm_ps_found = False
if "pix_mean" in globals() and "pix_std" in globals():
    pm = float(pix_mean); ps = float(pix_std); pm_ps_found = True
elif "pixel_mean" in globals() and "pixel_std" in globals():
    pm = float(pixel_mean); ps = float(pixel_std); pm_ps_found = True
if not pm_ps_found:
    X_train_bgsub = subtract_background(images[idx_train])
    pm = float(X_train_bgsub.mean()); ps = float(X_train_bgsub.std() + 1e-8)

# CoordConv maps
x_lin = (xs - CENTER) / CENTER
y_lin = (ys - CENTER) / CENTER
r_map = np.sqrt(x_lin ** 2 + y_lin ** 2)
r_map = r_map / (r_map.max() + 1e-12)

def build_cnn_input(img_batch: np.ndarray) -> np.ndarray:
    Xs = subtract_background(img_batch)
    Xn = (Xs - pm) / ps
    n = Xn.shape[0]
    X4 = np.zeros((n, 4, H, W), dtype=np.float32)
    X4[:, 0] = Xn.astype(np.float32)
    X4[:, 1] = x_lin.astype(np.float32)
    X4[:, 2] = y_lin.astype(np.float32)
    X4[:, 3] = r_map.astype(np.float32)
    return X4

X_cnn_test_4 = build_cnn_input(images[test_sel_idx])

# --- SHAP‑compatible wrapper that removes nn.Flatten and copies Linear weights ---
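class SHAPCompatCNN(nn.Module):
    """
    DeepExplainer-friendly copy of the trained CNN.

    Assumes `base` exposes `.features`, `.gap` and a `.head` containing exactly two
    nn.Linear layers applied as Linear -> ReLU -> Dropout -> Linear; `forward` below
    hard-codes that order, so adapt it if the trained head differs.
    """
    def __init__(self, base: nn.Module):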
        super().__init__()
        m = copy.deepcopy(base).to(DEVICE).eval()
        for module in m.modules():
            if isinstance(module, nn.ReLU):
                module.inplace = False
        self.features = m.features
        self.gap = m.gap
        # Extract Linear layers from head (expects exactly two)
        lin_layers = [mod for mod in m.head if isinstance(mod, nn.Linear)]
        if len(lin_layers) != 2:
            raise RuntimeError("Expected two Linear layers in model.head.")
        self.fc1 = nn.Linear(lin_layers[0].in_features, lin_layers[0].out_features, bias=True)
        self.act = nn.ReLU(inplace=False)
        drops = [mod for mod in m.head if isinstance(mod, nn.Dropout)]
        self.drop = nn.Dropout(p=drops[0].p if drops else 0.0)
        self.fc2 = nn.Linear(lin_layers[1].in_features, lin_layers[1].out_features, bias=True)
        with torch.no_grad():
            self.fc1.weight.copy_(lin_layers[0].weight); self.fc1.bias.copy_(lin_layers[0].bias)
            self.fc2.weight.copy_(lin_layers[1].weight); self.fc2.bias.copy_(lin_layers[1].bias)

    def forward(self, x):
        x = self.features(x)
        x = self.gap(x)
        x = torch.flatten(x, 1)  # functional flatten (no nn.Flatten module)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        y = self.fc2(x)
        return y.clone()

shap_model = SHAPCompatCNN(model).to(DEVICE).eval()

# --- candidate baselines for CNN SHAP ---
def build_baseline(kind="train_subset", m=64):
    """
    kind ∈ {"train_subset", "mean", "median", "blurred_mean"}.
    Returns a tensor on DEVICE suitable for DeepExplainer background.
    """
    if kind == "train_subset":
        idx = rng.choice(idx_train, size=min(m, len(idx_train)), replace=False)
        X_bg = build_cnn_input(images[idx])
        return torch.from_numpy(X_bg).to(DEVICE)
    if kind in {"mean", "median", "blurred_mean"}:
        X_tr = build_cnn_input(images[idx_train])
        if kind == "mean":
            img_ch = X_tr[:, 0].mean(axis=0, keepdims=True).astype(np.float32)
        elif kind == "median":
            img_ch = np.median(X_tr[:, 0], axis=0, keepdims=True).astype(np.float32)
        else:  # blurred_mean
            mean_img = X_tr[:, 0].mean(axis=0)
            k = 5; pad = k // 2
            tmp = np.pad(mean_img, pad, mode="reflect"); out = np.zeros_like(mean_img)
            for y in range(H):
                for x in range(W):
                    out[y, x] = tmp[y:y + k, x:x + k].mean()
            img_ch = out[None, :, :].astype(np.float32)
        Xb = np.zeros((1, 4, H, W), dtype=np.float32)
        Xb[0, 0] = img_ch[0]; Xb[0, 1] = x_lin; Xb[0, 2] = y_lin; Xb[0, 3] = r_map
        return torch.from_numpy(Xb).to(DEVICE)
    raise ValueError("Unknown baseline kind.")

# --- DeepExplainer helper that disables additivity check (critical fix) ---
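def safe_shap_values(explainer, x_tensor):
    """Compute SHAP values with the additivity check disabled where supported."""
    try:
        return explainer.shap_values(x_tensor, check_additivity=False)
    except TypeError:
        # This SHAP release does not accept `check_additivity`. The tolerance patch below
        # is best-effort only: it has no effect if the installed SHAP version does not
        # read a module-level TOLERANCE in this (private) module.
        try:
            import shap.explainers._deep.deep_utils as _du
            _du.TOLERANCE = 1e9
        except Exception:
            pass
        return explainer.shap_values(x_tensor)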

# --- Deep SHAP for several baselines; pick the best by fidelity ---
BASELINES = ["train_subset", "mean", "median", "blurred_mean"]
baseline_results = {}
torch.set_grad_enabled(True)
X_test_t = torch.from_numpy(X_cnn_test_4).to(DEVICE).requires_grad_(True)

def cnn_predict_y_from_4(X4):
    """Predict y (not z) for a batch of 4‑channel inputs with the *trained* CNN clone used for SHAP."""
    y_train_obs = y_obs_all[idx_train]
    y_mean = float(np.mean(y_train_obs)); y_std = float(np.std(y_train_obs) + 1e-8)
    def from_z(z): return z * y_std + y_mean
    shap_model.eval()
    with torch.no_grad():
        z = shap_model(torch.from_numpy(X4).to(DEVICE)).cpu().numpy().squeeze(-1)
    return from_z(z)

fractions = [0.01, 0.02, 0.05, 0.10, 0.20]
y_obs_sel = y_obs_test_sel

for bkind in BASELINES:
    bg_t = build_baseline(bkind, m=64 if bkind == "train_subset" else 1)
    expl = shap.DeepExplainer(shap_model, bg_t)
    vals = safe_shap_values(expl, X_test_t)
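    sh = vals[0] if isinstance(vals, list) else vals  # expected shape: (n, 4, H, W)
    sh = np.asarray(sh)
    if sh.ndim == 5 and sh.shape[-1] == 1:
        # Some SHAP versions may append a trailing output axis for single-output models; drop it.
        sh = sh[..., 0]
    sh_img = sh[:, 0, :, :]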
    baseline_results[bkind] = {"all": sh, "img": sh_img}

    # IoU/Precision@k
    io = iou_and_precision_vs_shell(sh_img, k_fracs=(0.01, 0.05))
    baseline_results[bkind]["iou_prec"] = io

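    # Deletion AOPC: zero the top-|SHAP| pixels of the image channel (0 == the mean
    # intensity `pm` after normalisation), re-predict, and track the drop in R²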
    n, _, h, w = X_cnn_test_4.shape; D = h * w
    order = np.argsort(-np.abs(sh_img.reshape(n, D)), axis=1)
    y0 = cnn_predict_y_from_4(X_cnn_test_4); r2_base = r2_score(y_obs_sel, y0)
    r2_list = []
    for frac in fractions:
        k = max(1, int(frac * D))
        X_mod = X_cnn_test_4.copy()
        for i in range(n):
            idxk = order[i, :k]; rr, cc = idxk // w, idxk % w
            X_mod[i, 0, rr, cc] = 0.0
        y_pred_mod = cnn_predict_y_from_4(X_mod)
        r2_list.append(r2_score(y_obs_sel, y_pred_mod))
    drops = (r2_base - np.array(r2_list))
    aopc = area_trapezoid(drops, np.array(fractions))
    baseline_results[bkind]["aopc"] = float(aopc)
    baseline_results[bkind]["r2_base"] = float(r2_base)

torch.set_grad_enabled(False)

# Pick the baseline with the best combined score: shell alignment (IoU@1% + IoU@5%)
# plus deletion fidelity (AOPC), each min-max normalised across the candidates
def normalise(v):
    v = np.asarray(v, dtype=float)
    if np.ptp(v) < 1e-12:
        return np.ones_like(v)
    return (v - v.min()) / (v.max() - v.min())

scores = []; kinds = []
for k, d in baseline_results.items():
    io1 = d["iou_prec"][0.01][0]; io5 = d["iou_prec"][0.05][0]
    kinds.append(k); scores.append([io1, io5, d["aopc"]])
scores = np.array(scores)
score = normalise(scores[:, 0]) + normalise(scores[:, 1]) + normalise(scores[:, 2])
BEST_BASELINE = kinds[int(np.argmax(score))]
print(f"\n[CNN] Baseline selection by fidelity → chosen: **{BEST_BASELINE}**")
for k in BASELINES:
    io1 = baseline_results[k]["iou_prec"][0.01][0]
    io5 = baseline_results[k]["iou_prec"][0.05][0]
    print(f"  {k:13s}  IoU@1%={io1:.3f}  IoU@5%={io5:.3f}  AOPC={baseline_results[k]['aopc']:.4f}")

# Use the chosen baseline’s SHAP maps from now on
shap_cnn_allch = baseline_results[BEST_BASELINE]["all"]
shap_cnn_imgch = baseline_results[BEST_BASELINE]["img"]

# Light background bootstrapping (variability hints)
print("\n[CNN] Background bootstrapping (variability over train‑subset backgrounds)")
B = 3
boot_shares = []
for b in range(B):
    bg_t = build_baseline("train_subset", m=32)
    torch.set_grad_enabled(True)
    e = shap.DeepExplainer(shap_model, bg_t)
    vals = safe_shap_values(e, X_test_t)
    torch.set_grad_enabled(False)
    sh = vals[0] if isinstance(vals, list) else vals
    sh_img = sh[:, 0]
    shares = np.array([region_shares(np.abs(s)) for s in sh_img])
    boot_shares.append(shares.mean(axis=0))
boot_shares = np.stack(boot_shares, axis=0)
means = boot_shares.mean(axis=0); stds = boot_shares.std(axis=0)
print(f"  Region share mean±sd (Hollow/Shell/Outside): "
      f"{means[0]:.3f}±{stds[0]:.3f} / {means[1]:.3f}±{stds[1]:.3f} / {means[2]:.3f}±{stds[2]:.3f}")

# Local overlays & global maps for CNN (chosen baseline)
grid_overlays(test_imgs[:8], shap_cnn_imgch[:8], f"CNN [{BEST_BASELINE}]")
heatmap2d(np.mean(np.abs(shap_cnn_imgch), axis=0), f"Global mean |SHAP| — CNN [{BEST_BASELINE}]")
summarise_region_and_profile(shap_cnn_imgch, f"CNN [{BEST_BASELINE}]")
print_iou_precision(f"CNN [{BEST_BASELINE}]", iou_and_precision_vs_shell(shap_cnn_imgch))

# -----------------------------
# Deletion curves — measured as R² changes (OLS / Ridge / CNN)
# -----------------------------
print("\n[Deletion curves: what they mean]\n"
      "• Remove top‑|SHAP| pixels and re‑score R² vs true labels.\n"
      "  A *faithful* explanation targets genuinely important pixels ⇒ R² **drops fast**. We summarise by AOPC.\n")

def plot_deletion_r2(title, fractions, r2_base, r2_list):
    plt.figure(figsize=(5.8, 4))
    plt.plot([f * 100 for f in fractions], [r2_base] * len(fractions), linestyle="--", label="Baseline R²")
    plt.plot([f * 100 for f in fractions], r2_list, marker="o", label="R² after deletion")
    plt.xlabel("Top-|SHAP| pixels removed (%)")
    plt.ylabel("R² vs true labels (test subset)")
    plt.title(title); plt.grid(True, alpha=0.3); plt.legend(); plt.tight_layout(); plt.show()
    drops = (r2_base - np.array(r2_list))
    aopc = area_trapezoid(drops, np.array(fractions))
    print(f"{title} — AOPC (higher → better fidelity): {aopc:.4f}")
    return aopc

# --- OLS deletion
y0_ols_pred = ols.predict(X_ols_test)
r2_base_ols = r2_score(y_obs_test_sel, y0_ols_pred)
order_ols = np.argsort(-np.abs(shap_ols_flat), axis=1)
r2_list_ols_del = []
mu = X_bg_ols.mean(axis=0)
for frac in fractions:
    k = max(1, int(frac * X_ols_test.shape[1]))
    X_del = X_ols_test.copy()
    for i in range(X_del.shape[0]):
        X_del[i, order_ols[i, :k]] = mu[order_ols[i, :k]]
    r2_list_ols_del.append(r2_score(y_obs_test_sel, ols.predict(X_del)))
aopc_ols = plot_deletion_r2("Deletion curve — OLS (R² drop)", fractions, r2_base_ols, r2_list_ols_del)

# --- Ridge deletion
y0_ridge_pred = final_ridge.predict(X_flat[test_sel_idx])
r2_base_ridge = r2_score(y_obs_test_sel, y0_ridge_pred)
order_ridge = np.argsort(-np.abs(shap_ridge_scaled), axis=1)
r2_list_ridge_del = []
Z_test = scaler.transform(X_flat[test_sel_idx])
for frac in fractions:
    k = max(1, int(frac * Z_test.shape[1]))
    Z_del = Z_test.copy()
    for i in range(Z_del.shape[0]):
        Z_del[i, order_ridge[i, :k]] = 0.0  # 0 == scaled mean
    y_pred_del = (Z_del @ ridge.coef_.ravel() + ridge.intercept_)
    r2_list_ridge_del.append(r2_score(y_obs_test_sel, y_pred_del))
aopc_ridge = plot_deletion_r2("Deletion curve — Ridge (R² drop)", fractions, r2_base_ridge, r2_list_ridge_del)

# --- CNN deletion (image channel only; coords intact)
y_train_obs = y_obs_all[idx_train]
y_mean = float(np.mean(y_train_obs)); y_std = float(np.std(y_train_obs) + 1e-8)
def from_z(z): return z * y_std + y_mean

def cnn_predict_y(_model, X4):
    _model.eval()
    with torch.no_grad():
        z = _model(torch.from_numpy(X4).to(DEVICE)).cpu().numpy().squeeze(-1)
    return from_z(z)

y0_cnn_pred = cnn_predict_y(shap_model, X_cnn_test_4)
r2_base_cnn = r2_score(y_obs_test_sel, y0_cnn_pred)
n, _, h, w = X_cnn_test_4.shape; D = h * w
order_cnn = np.argsort(-np.abs(shap_cnn_imgch.reshape(n, D)), axis=1)

r2_list_cnn_del = []
for frac in fractions:
    k = max(1, int(frac * D))
    X_mod = X_cnn_test_4.copy()
    for i in range(n):
        idx = order_cnn[i, :k]; rr, cc = idx // w, idx % w
        X_mod[i, 0, rr, cc] = 0.0
    y_pred_mod = cnn_predict_y(shap_model, X_mod)
    r2_list_cnn_del.append(r2_score(y_obs_test_sel, y_pred_mod))
aopc_cnn = plot_deletion_r2(f"Deletion curve — CNN [{BEST_BASELINE}] (R² drop)", fractions, r2_base_cnn, r2_list_cnn_del)

# -----------------------------
# Bounding‑box analysis over top‑|SHAP| within the shell
#   • Robust against shape issues (always flattens indices safely).
#   • Takes the **largest weighted connected component** before boxing.
# -----------------------------
def circular_coverage(angles_rad):
    """Minimal angular span (radians) covering all given angles on the circle."""
    if angles_rad.size == 0:
        return 0.0
    a = np.sort(angles_rad)
    diffs = np.diff(np.concatenate([a, a[:1] + 2*np.pi]))
    max_gap = np.max(diffs)
    return float(2*np.pi - max_gap)
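
# --- Illustrative sketch (toy angles) ---
# Why the span is computed from the largest gap rather than max - min: angles that
# straddle 0°/360° would otherwise look almost full-circle. The four toy angles
# below sit within a 20° arc around 0°. The `_demo_*` names are hypothetical and
# are not used anywhere else.
import numpy as np

_demo_angles = np.deg2rad([350.0, 355.0, 5.0, 10.0])
_demo_sorted = np.sort(_demo_angles)
# Gaps between consecutive angles, including the wrap-around gap back to the first one
_demo_gaps = np.diff(np.concatenate([_demo_sorted, _demo_sorted[:1] + 2 * np.pi]))
_demo_span_deg = float(np.rad2deg(2 * np.pi - _demo_gaps.max()))              # minimal covering arc
_demo_naive_deg = float(np.rad2deg(_demo_sorted.max() - _demo_sorted.min()))   # misleading max - min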

def _largest_component(mask2d, weights2d=None):
    """
    8‑connected components on a boolean 2‑D mask.
    Returns the mask of the component with largest total weight (or largest size if weights is None).
    """
    M = np.ascontiguousarray(mask2d.astype(bool))
    H_, W_ = M.shape
    visited = np.zeros_like(M, dtype=bool)
    best_weight = -1.0
    best_comp = None

    # Precompute neighbor offsets (8‑connectivity)
    nbrs = [(-1,-1),(-1,0),(-1,1),(0,-1),(0,1),(1,-1),(1,0),(1,1)]

    # Iterate over all True pixels
    ys_, xs_ = np.where(M)
    for y0, x0 in zip(ys_, xs_):
        if visited[y0, x0]:
            continue
        # Iterative flood fill (depth-first, explicit stack)
        stack = [(y0, x0)]
        visited[y0, x0] = True
        comp_pixels = [(y0, x0)]
        while stack:
            y, x = stack.pop()
            for dy, dx in nbrs:
                yy, xx = y + dy, x + dx
                if 0 <= yy < H_ and 0 <= xx < W_ and (not visited[yy, xx]) and M[yy, xx]:
                    visited[yy, xx] = True
                    stack.append((yy, xx))
                    comp_pixels.append((yy, xx))
        # weight for this component
        if weights2d is not None:
            w = float(np.sum([weights2d[y, x] for (y, x) in comp_pixels]))
        else:
            w = float(len(comp_pixels))
        if w > best_weight:
            best_weight = w
            best_comp = comp_pixels

    comp_mask = np.zeros_like(M, dtype=bool)
    if best_comp is not None:
        ys_c, xs_c = zip(*best_comp)
        comp_mask[ys_c, xs_c] = True
    return comp_mask

def compute_bbox_metrics(stack, model_name, imgs_for_overlay=None, draw_overlays=True, overlay_n=6,
                         p_top=0.01):
    """
    For each SHAP map in `stack`, within the shell:
      • select the TOP p_top fraction of |SHAP| pixels (non-negative kth),
      • take the **largest weighted connected component** of that mask,
      • compute the tight axis-aligned bounding box,
      • report area fraction of the shell covered by the box,
               fraction of total |SHAP| in shell captured by the box,
               angular coverage (deg) of selected pixels,
               radial width normalised by shell thickness.
    Returns a dict of per-sample arrays and prints robust summarised stats.
    """
    area_fracs = []
    shap_shares = []
    ang_degs    = []
    rad_widths  = []

    shell_area = float(mask_shell.sum())
    shell_thick = (OUTER_R - INNER_R)

    overlays_done = 0
    if imgs_for_overlay is None:
        imgs_for_overlay = []

    for i, S in enumerate(stack):
        # ensure a 2‑D (H, W) map
        S2 = np.asarray(S)
        if S2.ndim != 2:
            # drop singleton axes (e.g. a stray channel/output axis), then enforce (H, W)
            S2 = np.squeeze(S2).reshape(H, W)
        Sabs = np.abs(S2)

        vals = Sabs[mask_shell].ravel()

        if vals.size == 0 or not np.any(vals > 0):
            # skip gracefully
            continue

        # TOP‑p% threshold (non‑negative kth)
        k = max(1, int(np.ceil(p_top * vals.size)))
        kth = vals.size - k
        if kth < 0: kth = 0
        thr = np.partition(vals, kth)[kth]

        # binary mask of top‑p% within shell
        M_shell = (Sabs >= thr) & mask_shell

        # Focus on **largest weighted component** (weights = |SHAP|)
        M_sel = _largest_component(M_shell, weights2d=Sabs)

        # fallback: if component is empty (pathological), take the single max pixel in shell
        if not np.any(M_sel):
            shell_idx = np.column_stack(np.where(mask_shell))
            # argmax over shell pixels
            argmax_flat = np.argmax(Sabs[mask_shell])
            yx = shell_idx[argmax_flat]
            M_sel = np.zeros((H, W), dtype=bool)
            M_sel[yx[0], yx[1]] = True

        # coords of selected pixels (robust way, works for any shape)
        coords = np.column_stack(np.where(M_sel))
        y_min, y_max = int(coords[:, 0].min()), int(coords[:, 0].max())
        x_min, x_max = int(coords[:, 1].min()), int(coords[:, 1].max())

        # area fraction of shell covered by the BOX ∩ shell
        box_mask = np.zeros((H, W), dtype=bool)
        box_mask[y_min:y_max+1, x_min:x_max+1] = True
        box_shell = np.logical_and(box_mask, mask_shell)
        area_frac = float(box_shell.sum() / (shell_area + 1e-12))
        area_fracs.append(area_frac)

        # |SHAP| share captured by the BOX within the shell
        shap_shell_total = float(Sabs[mask_shell].sum() + 1e-12)
        shap_in_box = float(Sabs[box_shell].sum())
        shap_share = shap_in_box / shap_shell_total
        shap_shares.append(shap_share)

        # angular coverage (deg) of SELECTED pixels
        ang = theta[M_sel]
        ang_span = circular_coverage(ang.ravel())
        ang_deg = float(ang_span * 180.0 / np.pi)
        ang_degs.append(ang_deg)

        # radial width normalised by shell thickness (using selected pixels)
        r_sel = R[M_sel]
        rad_w = float((r_sel.max() - r_sel.min()) / (shell_thick + 1e-12))
        rad_widths.append(rad_w)

        # overlays (draw box + highlight selected component)
        if draw_overlays and (overlays_done < overlay_n) and (i < len(imgs_for_overlay)):
            fig, ax = plt.subplots(1, 1, figsize=(3.2, 3.2))
            ax.imshow(imgs_for_overlay[i], cmap="gray", vmin=0, vmax=1)
            rect = patches.Rectangle((x_min, y_min), x_max - x_min + 1, y_max - y_min + 1,
                                     linewidth=1.8, edgecolor='lime', facecolor='none')
            ax.add_patch(rect)
            # faint fill for the selected component
            M_alpha = np.zeros((H, W), dtype=float)
            M_alpha[M_sel] = 0.5
            ax.imshow(M_alpha, cmap="Reds", alpha=0.25, vmin=0, vmax=1)
            ax.set_title(f"{model_name} — bbox on top-{int(p_top*100)}% |SHAP| (largest component)", fontsize=9)
            ax.axis("off")
            plt.tight_layout(); plt.show()
            overlays_done += 1

    # Summaries
    area_fracs = np.array(area_fracs, dtype=float)
    shap_shares = np.array(shap_shares, dtype=float)
    ang_degs    = np.array(ang_degs, dtype=float)
    rad_widths  = np.array(rad_widths, dtype=float)

    def summ(name, arr, unit=""):
        if arr.size == 0:
            print(f"  {name}: n=0")
            return (np.nan, np.nan, np.nan)
        print(f"  {name}: median={np.nanmedian(arr):.3f}{unit}  "
              f"IQR=({np.nanpercentile(arr,25):.3f}{unit}, {np.nanpercentile(arr,75):.3f}{unit})")
        return (float(np.nanmedian(arr)),
                float(np.nanpercentile(arr,25)),
                float(np.nanpercentile(arr,75)))

    print(f"\n[{model_name}] Bounding‑box metrics on top‑|SHAP| within the shell (largest component; p_top={p_top*100:.1f}%):")
    s1 = summ("Area fraction of shell (BOX∩shell / shell)", area_fracs)
    s2 = summ("|SHAP| share captured (BOX∩shell / shell)",   shap_shares)
    s3 = summ("Angular coverage (deg) of selected pixels",     ang_degs, unit="°")
    s4 = summ("Radial width / shell thickness",                rad_widths)

    return {
        "area_frac": area_fracs, "shap_share": shap_shares,
        "angle_deg": ang_degs, "rad_width": rad_widths,
        "summary": {"area_frac": s1, "shap_share": s2, "angle_deg": s3, "rad_width": s4}
    }

# --- Run bounding‑box metrics for each model (top 1% by default)
P_TOP = 0.01
bbox_stats_ols   = compute_bbox_metrics(shap_ols_imgs,   "OLS",   imgs_for_overlay=test_imgs[:8], draw_overlays=False, p_top=P_TOP)
bbox_stats_ridge = compute_bbox_metrics(shap_ridge_imgs, "Ridge", imgs_for_overlay=test_imgs[:8], draw_overlays=False, p_top=P_TOP)
bbox_stats_cnn   = compute_bbox_metrics(shap_cnn_imgch,  f"CNN [{BEST_BASELINE}]",
                                        imgs_for_overlay=test_imgs[:8], draw_overlays=True, p_top=P_TOP)

print("\n[Reading the bounding‑box summaries]")
print("• A small **Area fraction** with a large **|SHAP| share** and **low angular coverage** means the model relies on a tight, local cue.\n"
      "• For the *hard/star* dataset, the CNN should have the **smallest boxes** and **tightest angles** among the three, "
      "while OLS/Ridge remain diffuse.\n"
      "• For the *simple* dataset (no star), the CNN’s boxes will typically hug thin arcs on the inner/outer edges; "
      "angles will be modest, not full‑ring, and radial width < 0.5.\n")

# ================================
# Global star-capture analysis (robust 2-D handling + enrichment)
# ================================
import numpy as np
import matplotlib.pyplot as plt

print("\n[Global star-capture analysis]")

# --- Helpers to guarantee 2-D (H,W) arrays and clean values
def _ensure_hw(x):
    a = np.asarray(x)
    a = np.squeeze(a)
    if a.ndim != 2:
        # Force-shape to (H, W) if a weird singleton remains
        a = a.reshape(H, W)
    # Clean NaN/Inf for safety
    a = np.nan_to_num(a, nan=0.0, posinf=0.0, neginf=0.0)
    return a

# Reuse ring geometry already defined: R, theta, mask_shell, INNER_R, OUTER_R
R_mid   = 0.5 * (INNER_R + OUTER_R)
R_sigma = 0.18 * (OUTER_R - INNER_R)
RING_W  = np.exp(-0.5 * ((R - R_mid) / (R_sigma + 1e-8))**2) * mask_shell
shell_area = float(mask_shell.sum())
shell_thick = (OUTER_R - INNER_R)

def _box_blur3(a):
    a = _ensure_hw(a)
    tmp = np.pad(a, 1, mode="reflect")
    out = np.zeros_like(a, dtype=float)
    for y in range(H):
        for x in range(W):
            out[y, x] = tmp[y:y+3, x:x+3].mean()
    return out

def find_star_proxy_center(img):
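    # Proxy for the star location: the brightest pixel of the median-subtracted,
    # shell-weighted (Gaussian around the mid-radius), 3x3-blurred image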
    z = _ensure_hw(img) - float(np.median(img))
    z[z < 0.0] = 0.0
    m = z * RING_W
    m = _box_blur3(m)
    yi, xi = np.unravel_index(np.argmax(m), m.shape)
    return int(yi), int(xi)

def disc_mask(yc, xc, radius=4):
    yy, xx = np.ogrid[:H, :W]
    return ((yy - yc)**2 + (xx - xc)**2) <= (radius*radius)

# circular_coverage() is already defined in the bounding-box section above and is reused here.

def circ_mean_angle(angles, weights=None):
    if angles.size == 0:
        return np.nan
    if weights is None:
        weights = np.ones_like(angles, dtype=float)
    c = np.sum(weights * np.cos(angles))
    s = np.sum(weights * np.sin(angles))
    return float(np.arctan2(s, c))

def circ_abs_diff_deg(a, b):
    d = np.abs((a - b + np.pi) % (2*np.pi) - np.pi)
    return float(d * 180.0 / np.pi)
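
# --- Illustrative sketch (toy angles) ---
# The two circular helpers above exist because plain arithmetic fails near the
# 0°/360° seam: the arithmetic mean of 350° and 10° is 180°, whereas the circular
# mean is 0°, and the two angles are 20° apart rather than 340°. The `_demo_*`
# names are hypothetical and are not used anywhere else.
import numpy as np

_demo_a = np.deg2rad([350.0, 10.0])
# Circular mean: angle of the summed unit vectors
_demo_circ_mean = float(np.arctan2(np.sin(_demo_a).sum(), np.cos(_demo_a).sum()))
# Wrapped absolute difference, mapped into [0°, 180°]
_demo_diff_deg = float(np.rad2deg(np.abs((_demo_a[0] - _demo_a[1] + np.pi) % (2 * np.pi) - np.pi)))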

# ---- robust 8-connected component on 2-D boolean masks
def _largest_component(mask2d, weights2d=None):
    M = _ensure_hw(mask2d).astype(bool)
    Wts = _ensure_hw(weights2d) if weights2d is not None else None
    H_, W_ = M.shape
    visited = np.zeros_like(M, dtype=bool)
    best_weight = -1.0
    best_comp = None
    nbrs = [(-1,-1),(-1,0),(-1,1),(0,-1),(0,1),(1,-1),(1,0),(1,1)]
    ys_, xs_ = np.where(M)
    for y0, x0 in zip(ys_, xs_):
        if visited[y0, x0]:
            continue
        stack = [(y0, x0)]
        visited[y0, x0] = True
        comp_pixels = [(y0, x0)]
        while stack:
            y, x = stack.pop()
            for dy, dx in nbrs:
                yy, xx = y + dy, x + dx
                if 0 <= yy < H_ and 0 <= xx < W_ and (not visited[yy, xx]) and M[yy, xx]:
                    visited[yy, xx] = True
                    stack.append((yy, xx))
                    comp_pixels.append((yy, xx))
        if Wts is not None:
            w = float(np.sum([Wts[y, x] for (y, x) in comp_pixels]))
        else:
            w = float(len(comp_pixels))
        if w > best_weight:
            best_weight = w
            best_comp = comp_pixels
    comp_mask = np.zeros_like(M, dtype=bool)
    if best_comp:
        ys_c, xs_c = zip(*best_comp)
        comp_mask[ys_c, xs_c] = True
    return comp_mask

def largest_component_mask(abs_map, top_frac=0.01):
    S = _ensure_hw(abs_map)
    vals = S[mask_shell].ravel()
    if vals.size == 0:
        return np.zeros((H, W), dtype=bool)
    vals = np.nan_to_num(vals, nan=0.0, posinf=0.0, neginf=0.0)
    k = max(1, int(np.ceil(top_frac * vals.size)))
    kth = max(0, vals.size - k)           # non-negative kth
    thr = np.partition(vals, kth)[kth]
    M_shell = (S >= thr) & mask_shell
    return _largest_component(M_shell, weights2d=S)

def global_star_capture(shap_stack, model_name, imgs_subset, disc_radius=4):
    shares, enrichments, ang_errs = [], [], []
    for i in range(len(shap_stack)):
        img = _ensure_hw(imgs_subset[i])
        S   = np.abs(_ensure_hw(shap_stack[i]))

        # (1) star disc
        ys_, xs_ = find_star_proxy_center(img)
        D = disc_mask(ys_, xs_, radius=disc_radius)
        disc_shell = np.logical_and(D, mask_shell)

        # (2) |SHAP| share inside disc (normalised by |SHAP| on shell)
        denom = float(S[mask_shell].sum() + 1e-12)
        numer = float(S[disc_shell].sum())
        share = numer / denom
        shares.append(share)

        # (3) enrichment vs area fraction of disc in the shell
        area_frac_disc = float(disc_shell.sum() / (shell_area + 1e-12))
        enrichments.append(share / (area_frac_disc + 1e-12))

        # (4) angular alignment vs largest |SHAP| component
        M_sel = largest_component_mask(S, top_frac=0.01)
        if np.any(M_sel):
            th_star = float(theta[ys_, xs_])
            ww = S[M_sel]
            th_comp = circ_mean_angle(theta[M_sel], weights=ww / (ww.sum() + 1e-12))
            ang_errs.append(circ_abs_diff_deg(th_star, th_comp))
        else:
            ang_errs.append(np.nan)

    shares = np.array(shares, dtype=float)
    enrichments = np.array(enrichments, dtype=float)
    ang_errs = np.array(ang_errs, dtype=float)

    def summarise(name, arr, unit=""):
        nfin = int(np.sum(np.isfinite(arr)))
        print(f"  {model_name} — {name}: "
              f"median={np.nanmedian(arr):.3f}{unit} | "
              f"IQR=({np.nanpercentile(arr,25):.3f}{unit}, {np.nanpercentile(arr,75):.3f}{unit}) | n={nfin}")

    print(f"\n[{model_name}] Star-disc capture (radius={disc_radius}px), enrichment, and angular alignment")
    summarise("SHAP share inside star disc", shares)
    summarise("enrichment (share / area)",   enrichments)
    summarise("abs angular error vs star (deg)", ang_errs, unit="°")
    return shares, enrichments, ang_errs

# Run on the same subset used for SHAP (test_imgs / shap_*_imgs)
shares_ols,   enrich_ols,   ang_ols   = global_star_capture(shap_ols_imgs,   "OLS",   test_imgs, disc_radius=4)
shares_ridge, enrich_ridge, ang_ridge = global_star_capture(shap_ridge_imgs, "Ridge", test_imgs, disc_radius=4)
shares_cnn,   enrich_cnn,   ang_cnn   = global_star_capture(shap_cnn_imgch,  f"CNN [{BEST_BASELINE}]", test_imgs, disc_radius=4)

# Optional: global comparison plots
plt.figure(figsize=(6,4))
plt.boxplot([shares_ols, shares_ridge, shares_cnn], labels=["OLS","Ridge","CNN"], showmeans=True)
plt.ylabel("Star-disc |SHAP| share (fraction of |SHAP| in shell)")
plt.title("Global star capture across explained test subset")
plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()

plt.figure(figsize=(6,4))
plt.boxplot([enrich_ols, enrich_ridge, enrich_cnn], labels=["OLS","Ridge","CNN"], showmeans=True)
plt.ylabel("Enrichment (|SHAP| share / area fraction of disc)")
plt.title("Global enrichment of |SHAP| at the star location")
plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()


# -----------------------------
# CNN baseline sensitivity (train subset vs mean baseline)
# -----------------------------
torch.set_grad_enabled(True)
bg_train_t = build_baseline("train_subset", m=64)
bg_mean_t  = build_baseline("mean", m=1)
e_train = shap.DeepExplainer(shap_model, bg_train_t)
e_mean  = shap.DeepExplainer(shap_model, bg_mean_t)
vals_a = safe_shap_values(e_train, X_test_t)
vals_b = safe_shap_values(e_mean, X_test_t)
torch.set_grad_enabled(False)
A = (vals_a[0] if isinstance(vals_a, list) else vals_a)[:, 0]
B = (vals_b[0] if isinstance(vals_b, list) else vals_b)[:, 0]
if HAVE_SPEARMAN:
    corrs = []
    for i in range(A.shape[0]):
        a = np.abs(A[i]).ravel(); b = np.abs(B[i]).ravel()
        if (np.std(a) < 1e-12) or (np.std(b) < 1e-12):
            corrs.append(np.nan)
        else:
            corrs.append(spearmanr(a, b, nan_policy="omit").correlation)
    corrs = np.array(corrs)
    print("\n[Baseline sensitivity — CNN]")
    print(f"  Spearman rank correlation of |SHAP| (train‑baseline vs mean‑baseline): "
          f"median={np.nanmedian(corrs):.3f}, IQR=({np.nanpercentile(corrs, 25):.3f}, {np.nanpercentile(corrs, 75):.3f})")
else:
    print("\n[Baseline sensitivity — CNN]\n  SciPy not available; install SciPy to report Spearman correlation.")

# ==========================================================
# Polar superpixels / regional analysis
# ==========================================================
print("\n[Polar superpixels / regional analysis]")
NB_RADIAL = 10; NB_THETA = 16
rad_edges  = np.linspace(0, R.max() + 1e-6, NB_RADIAL + 1)
theta_edges= np.linspace(0, 2 * np.pi, NB_THETA + 1)
rad_bin = np.digitize(R, rad_edges) - 1
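# Wrap theta into [0, 2π) first: if it came from np.arctan2 it lies in (-π, π],
# and negative angles would otherwise fall into bin -1 and be silently dropped.
theta_bin = np.digitize(np.mod(theta, 2 * np.pi), theta_edges) - 1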
rad_bin[rad_bin == NB_RADIAL] = NB_RADIAL - 1
theta_bin[theta_bin == NB_THETA] = NB_THETA - 1

def polar_bin_aggregate(abs_shap_map):
    out = np.zeros((NB_RADIAL, NB_THETA), dtype=np.float64)
    counts = np.zeros_like(out)
    for r in range(NB_RADIAL):
        for t in range(NB_THETA):
            m = (rad_bin == r) & (theta_bin == t)
            if np.any(m):
                out[r, t] = abs_shap_map[m].mean()
                counts[r, t] = m.sum()
    return out, counts

def polar_summary(stack, model_name):
    mats = []
    for s in stack:
        M, _ = polar_bin_aggregate(np.abs(s))
        mats.append(M)
    Mmean = np.mean(mats, axis=0)
    ring_strength = Mmean.mean(axis=1)
    r_idx = int(np.argmax(ring_strength))
    r_lo, r_hi = rad_edges[r_idx], rad_edges[r_idx + 1]
    row = Mmean[r_idx, :]
    cv = float(np.std(row) / (np.mean(row) + 1e-12))

    plt.figure(figsize=(6, 3.8))
    plt.imshow(Mmean, aspect="auto", origin="lower", cmap="magma")
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.yticks(np.arange(NB_RADIAL), [f"{rad_edges[i]:.0f}-{rad_edges[i + 1]:.0f}" for i in range(NB_RADIAL)])
    plt.xticks(np.arange(0, NB_THETA, 4), [f"{int(360 * t / NB_THETA)}°" for t in range(0, NB_THETA, 4)])
    plt.xlabel("Angle (θ)"); plt.ylabel("Radius band (px)")
    plt.title(f"{model_name} — polar mean |SHAP|\n(peak ring ~ {r_lo:.0f}–{r_hi:.0f}px; anisotropy CV={cv:.3f})")
    plt.tight_layout(); plt.show()

    print(f"{model_name} — peak ring: ~{r_lo:.0f}–{r_hi:.0f} px, anisotropy (CV across angle) = {cv:.3f}")
    return (r_idx, (r_lo, r_hi), cv, Mmean)

r_ols, band_ols, cv_ols, M_ols   = polar_summary(shap_ols_imgs,   "OLS")
r_rid, band_rid, cv_rid, M_rid   = polar_summary(shap_ridge_imgs, "Ridge")
r_cnn, band_cnn, cv_cnn, M_cnn   = polar_summary(shap_cnn_imgch,  f"CNN [{BEST_BASELINE}]")

print("\n[Interpretation — polar summaries]\n"
      "• Peak ring should sit on the shell radius; CNN often shows sharper peaks on inner/outer edges.\n"
      "• Anisotropy (CV) near zero ⇒ isotropic ring; higher CV ⇒ focus on arcs (useful for local defects).\n")

# -----------------------------
# CNN regional ablation + effect sizes (optional sanity)
# -----------------------------
print("\n[CNN regional ablation + effect sizes]")
K_REGIONS = 12
flat_scores = M_cnn.ravel()
top_idx = np.argsort(-flat_scores)[:K_REGIONS]
top_regions = [(i // NB_THETA, i % NB_THETA) for i in top_idx]

mask_top = np.zeros((H, W), dtype=bool)
for (rr, tt) in top_regions:
    mask_top |= ((rad_bin == rr) & (theta_bin == tt))

X_cnn_test_4_ablate = X_cnn_test_4.copy()
X_cnn_test_4_ablate[:, 0, :, :] = np.where(mask_top[None, :, :], 0.0, X_cnn_test_4_ablate[:, 0, :, :])

y_pred_cnn_base = y0_cnn_pred
y_pred_cnn_abl  = cnn_predict_y(shap_model, X_cnn_test_4_ablate)

r2_cnn_base = r2_score(y_obs_test_sel, y_pred_cnn_base)
r2_cnn_abl  = r2_score(y_obs_test_sel, y_pred_cnn_abl)
print(f"  CNN R² (baseline on this subset): {r2_cnn_base:.4f}")
print(f"  CNN R² after ablating top‑{K_REGIONS} polar regions: {r2_cnn_abl:.4f}")
print(f"  R² drop: {r2_cnn_base - r2_cnn_abl:.4f} (larger ⇒ those regions matter)\n")

# -----------------------------
# Augmentation‑invariance sanity check (flips)
# -----------------------------
print("\n[Augmentation‑invariance sanity check — CNN SHAP under flips]")
def flip_h(arr):  # horizontal flip
    return arr[..., :, ::-1]

Xt_flip_np = flip_h(X_cnn_test_4.copy()).copy()
Xt_flip = torch.from_numpy(Xt_flip_np).to(DEVICE).requires_grad_(True)
torch.set_grad_enabled(True)
e_best = shap.DeepExplainer(shap_model, build_baseline(BEST_BASELINE, m=(64 if BEST_BASELINE == "train_subset" else 1)))
vals_flip = safe_shap_values(e_best, Xt_flip)
torch.set_grad_enabled(False)
sh_flip = vals_flip[0] if isinstance(vals_flip, list) else vals_flip
sh_flip_img_unflipped = flip_h(sh_flip[:, 0]).copy()

if HAVE_SPEARMAN:
    corrs = []
    for i in range(N_EXPLAIN):
        a = np.abs(shap_cnn_imgch[i]).ravel()
        b = np.abs(sh_flip_img_unflipped[i]).ravel()
        if np.std(a) < 1e-12 or np.std(b) < 1e-12:
            corrs.append(np.nan)
        else:
            corrs.append(spearmanr(a, b, nan_policy="omit").correlation)
    corrs = np.array(corrs)
    print(f"  Flip‑invariance — Spearman(|SHAP|, original vs unflipped‑flip): "
          f"median={np.nanmedian(corrs):.3f}, IQR=({np.nanpercentile(corrs, 25):.3f}, {np.nanpercentile(corrs, 75):.3f})")
else:
    print("  (Install SciPy to compute Spearman rank correlations.)")

# -----------------------------
# Optional: Integrated Gradients (IG) for triangulation (no external libs)
# -----------------------------
print("\n[Integrated Gradients (IG) — optional triangulation]")
def integrated_gradients(model_torch, X_np, baseline_np, steps=32):
    was_enabled = torch.is_grad_enabled()
    try:
        torch.set_grad_enabled(True)
        model_torch.eval()
        X = torch.from_numpy(np.ascontiguousarray(X_np)).to(DEVICE)
        B = torch.from_numpy(np.ascontiguousarray(baseline_np)).to(DEVICE)
        attr = torch.zeros_like(X)
        for s in range(1, steps + 1):
            t = s / steps
            Xi = B + t * (X - B)
            Xi.requires_grad_(True)
            out = model_torch(Xi)
            grads = torch.autograd.grad(out.sum(), Xi, retain_graph=False, create_graph=False)[0]
            attr += grads
        attr = (X - B) * attr / steps
        return attr.detach().cpu().numpy()
    finally:
        torch.set_grad_enabled(was_enabled)

X_base_ig = X_cnn_test_4.copy(); X_base_ig[:, 0] = 0.0
attr_ig = integrated_gradients(shap_model, X_cnn_test_4, X_base_ig, steps=32)[:, 0]

if HAVE_SPEARMAN:
    igcorrs = []
    for i in range(N_EXPLAIN):
        a = np.abs(attr_ig[i]).ravel()
        b = np.abs(shap_cnn_imgch[i]).ravel()
        if np.std(a) < 1e-12 or np.std(b) < 1e-12:
            igcorrs.append(np.nan)
        else:
            igcorrs.append(spearmanr(a, b, nan_policy="omit").correlation)
    igcorrs = np.array(igcorrs)
    print(f"  IG vs Deep SHAP — Spearman(|attr|): median={np.nanmedian(igcorrs):.3f} "
          f"IQR=({np.nanpercentile(igcorrs, 25):.3f}, {np.nanpercentile(igcorrs, 75):.3f})")
else:
    print("  (Install SciPy to compute Spearman rank correlations.)")

# -----------------------------
# Final conclusions
# -----------------------------
print("\n=== Conclusions ===")
print("• OLS: clean ring; Ridge: more diffused ring; CNN: crisp double band on inner/outer edges "
      "with local arcs where defects (or the star) dominate.")
print("• Deletion curves (AOPC) quantify faithfulness; higher AOPC → better.")
print("• Bounding‑box metrics (largest component): the CNN packs a much **larger share** of |SHAP| into **small**, "
      "**angularly tight** regions than OLS or Ridge, especially on the hard dataset.")
print("• Baseline matters; this cell auto‑selects one by IoU@k + AOPC. IG triangulation is included as a sanity check.")


# SHAP interpretations — **hard (star) dataset**

This note interprets SHAP outputs for OLS, Ridge, and the CNN on the **“Death‑Star”** dataset where a small, energy‑neutral star on the shell is the dominant *local* driver of the label. Linear baselines remain informative via a mild global anchor; the CNN can localise the star.

---

## 1) What changes vs the simple dataset?

- The star is **local** and can appear **anywhere on the shell**, so a faithful model must be **anisotropic**: short angular arcs should light up.
- Energy neutrality prevents a linear model from exploiting the **global mean**; attribution that leaves the shell or spreads uniformly is a red flag.

---

## 2) Core definitions (as used in the simple note)

- **Region share / mean / concentration / IoU / Precision / AOPC:** identical to the definitions given earlier:
  - $\mathrm{share}_\Omega(S)$, $\mathrm{mean}_\Omega(S)$, $\mathrm{conc}_p(S)$,
  - $\mathrm{IoU}_k$, $\mathrm{Prec}_k$,
  - $\displaystyle \mathrm{AOPC}=\int_0^{f_{\max}}\big(R^2(0)-R^2(f)\big)\,\mathrm{d}f$.
- **Star‑disc share & enrichment:**
  $$
  \text{share}=\frac{\sum_{D\cap\text{Shell}}|S|}{\sum_{\text{Shell}}|S|},\qquad
  \text{enrich}=\frac{\text{share}}{\frac{|D\cap\text{Shell}|}{|\text{Shell}|}}.
  $$
  Here $D$ is a radius‑4 disc centred at a **star proxy** (ring‑weighted bright maximum).
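
The share and enrichment formulas can be sketched on a toy map. This is a minimal illustration, assuming an idealised 20–40 px annulus and a Gaussian bump standing in for the star; `disc_share_enrichment` is a hypothetical helper name, not the notebook's own `disc_mask` / `find_star_proxy_center` pipeline.

```python
import numpy as np

def disc_share_enrichment(abs_shap, shell, centre, radius=4):
    """Return (share, enrichment) of |SHAP| inside a disc, relative to the shell."""
    H, W = abs_shap.shape
    yy, xx = np.mgrid[:H, :W]
    disc = (yy - centre[0]) ** 2 + (xx - centre[1]) ** 2 <= radius ** 2
    disc_shell = disc & shell
    share = abs_shap[disc_shell].sum() / (abs_shap[shell].sum() + 1e-12)
    area_frac = disc_shell.sum() / shell.sum()
    return float(share), float(share / (area_frac + 1e-12))

# Idealised annulus (inner 20 px, outer 40 px) on a 100 x 100 grid (assumption).
H = W = 100
yy, xx = np.mgrid[:H, :W]
r = np.hypot(yy - 49.5, xx - 49.5)
shell = (r > 20) & (r <= 40)

# Faint attribution everywhere on the shell plus a bright bump at the "star".
star = (49, 80)
bump = np.exp(-((yy - star[0]) ** 2 + (xx - star[1]) ** 2) / (2 * 2.0 ** 2))
abs_shap = shell * (0.05 + bump)

share, enrich = disc_share_enrichment(abs_shap, shell, star)
print(f"share={share:.3f}, enrichment={enrich:.1f}")
```

Enrichment divides out the disc's area, so it stays comparable if the disc radius or the shell size changes, whereas the raw share does not.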

---

## 3) OLS — interpretation on the star dataset

**Numbers**
- **Region share:** Hollow **2.37%**, Shell **94.34%**, Outside **3.29%**.
- **Per‑pixel mean:** Hollow **2.4e‑5**, Shell **2.11e‑4**, Outside **7e‑6**.
- **Concentration:** top‑1% **32.46%**, top‑5% **64.60%**.
- **IoU / Precision:** at 1%: IoU **0.022**, Precision **0.982**; at 10%: IoU **0.217**, Precision **0.987**.
- **Deletion AOPC:** **0.0545** (modest).

**Reading the plots**
- **Overlays / global mean:** a very clear ring, much like the simple dataset — OLS stays edge‑driven.  
- **Radial profile:** sharp inner‑edge peak with outer‑edge shoulder.

**Interpretation**
- OLS still behaves plausibly (attribution on the shell), but the deletion curve shows that these pixels explain **less** of the prediction when the star is the true driver. The linear model lacks a mechanism to focus tightly on the small star.

---

## 4) Ridge — interpretation on the star dataset

**Numbers**
- **Region share:** Hollow **12.43%**, Shell **28.52%**, Outside **59.05%**.
- **Per‑pixel mean:** Hollow **5.6e‑5**, Shell **2.9e‑5**, Outside **6.1e‑5**.
- **Concentration:** top‑1% **9.62%**, top‑5% **27.02%**.
- **IoU / Precision:** at 1%: IoU **0.003**, Precision **0.130**; at 10%: IoU **0.026**, Precision **0.142**.
- **Deletion AOPC:** **−0.0115** (negative).

**Reading the plots**
- **Overlays:** diffuse speckle into the background; weak ring signal.
- **Radial profile:** mostly flat; no strong peaks.

**Interpretation**
- The negative AOPC means removing the “important” pixels **does not** hurt performance — sometimes it even helps. This is a hallmark of **unfaithful** attributions: the model’s predictions are not controlled by the regions that the explanation highlights.

---

## 5) CNN — interpretation on the star dataset (baseline by fidelity)

**Baseline selection:** `train_subset` is chosen by a combined score (IoU@k + AOPC). Other baselines are close; the choice is consistent across runs.

**Numbers**
- **Bootstrapping (train‑subset backgrounds):** Hollow **0.014±0.001**, Shell **0.982±0.002**, Outside **0.004±0.000**.
- **Region share:** Hollow **1.47%**, Shell **98.15%**, Outside **0.37%**.
- **Per‑pixel mean:** Hollow **5.1e‑5**, Shell **8.25e‑4**, Outside **3e‑6**.
- **Concentration:** top‑1% **37.58%**, top‑5% **69.52%**.
- **IoU / Precision:** at 1%: IoU **0.022**, Precision **0.996**; at 10%: IoU **0.219**, Precision **0.994**.
- **Deletion AOPC:** **0.2745** (large, and much higher than OLS/Ridge).

**Reading the plots**
- **Overlays / global mean:** short, high‑contrast **arcs** on the ring, frequently where the star sits.  
- **Radial profile:** two pronounced peaks at the inner and outer edges; little mass away from the shell.

**Interpretation**
- High concentration and very large AOPC together indicate that the CNN’s predictions are driven by a **compact set** of shell pixels — consistent with a **local star**.  
- Precision ≈ 1.0 across $k$ confirms that the top attributions rarely leave the shell.

---

## 6) Star‑oriented diagnostics (qualitative on the provided plots)

- **Star‑disc share / enrichment boxplots:** the CNN has by far the **highest** median share and enrichment at the star location; OLS and Ridge are both much lower and close to each other. Because the share is normalised **within the shell**, Ridge’s off‑shell attribution is not penalised by this metric, so Ridge can sit level with or slightly above OLS.  
  A good outcome on this dataset is **CNN ≫ OLS, Ridge**.  
- **Bboxes (largest component, top‑1%):** the CNN’s boxes are **small** and sit on the ring with high $|{\rm SHAP}|$ density inside; OLS boxes enclose thin arcs; Ridge boxes often sit on noise.

**What “good” looks like**
- High enrichment at the star disc, small boxes, and large AOPC.

**What “bad” looks like**
- Enrichment $\approx 1$ or $<1$ for the CNN, boxes drifting off the shell, or negative AOPC.

---

## 7) Local vs global reading on the star dataset

- **Local:** the CNN’s top‑1% pixels form a compact arc aligned with the star (small angular span).  
- **Global:** the mean map remains annular — the model knows *where* interesting things live (the shell) but it only **needs** a small part at test time.

---

## 8) Practical pitfalls

- **Baseline sensitivity (Deep SHAP):** different baselines shift magnitudes. Use the fidelity‑based picker and report the choice.
- **Area bias:** “region share” favours large regions; always pair it with the **per‑pixel mean**.
- **IoU at tiny $k$:** will be small for thin structures even when explanations are correct.
- **Negative AOPC:** treat it as a red flag for explanation faithfulness.
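
The IoU pitfall can be made concrete with a little arithmetic. The sketch below assumes $\mathrm{IoU}_k$ compares the top‑$k$ fraction of pixels with the full shell mask (the earlier note's exact definition is not reproduced here) and uses an idealised 20–40 px annulus, so the numbers are illustrative rather than a reproduction of the tables above.

```python
import numpy as np

H = W = 100
yy, xx = np.mgrid[:H, :W]
r = np.hypot(yy - (H - 1) / 2, xx - (W - 1) / 2)
shell = (r > 20) & (r <= 40)

def iou_precision_at_k(abs_shap, target, k):
    """IoU and precision of the top-k fraction of pixels against a target mask."""
    n_top = max(1, int(round(k * abs_shap.size)))
    flat_idx = np.argsort(abs_shap.ravel())[::-1][:n_top]
    top = np.zeros(abs_shap.size, dtype=bool)
    top[flat_idx] = True
    top = top.reshape(abs_shap.shape)
    inter = np.logical_and(top, target).sum()
    union = np.logical_or(top, target).sum()
    return inter / union, inter / top.sum()

# A "perfect" explanation: all attribution on the shell, none elsewhere.
rng = np.random.default_rng(0)
perfect = np.where(shell, rng.random((H, W)) + 1.0, 0.0)

for k in (0.01, 0.10):
    iou, prec = iou_precision_at_k(perfect, shell, k)
    # Every selected pixel is on the shell, yet IoU cannot exceed k * H * W / |shell|.
    print(f"k={k:.2f}: IoU={iou:.3f}, precision={prec:.3f}, "
          f"IoU ceiling={k * H * W / shell.sum():.3f}")
```

Under this assumed definition, a small IoU at $k=1\%$ mostly reflects how few pixels are selected, which is why precision is the more informative number at tiny $k$.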

---

## 9) Takeaways for the **star** dataset

- **OLS** remains edge‑based but cannot prioritise the small star strongly (AOPC **0.0545**).  
- **Ridge** is unfaithful here (AOPC **−0.0115**).  
- **CNN** shows the intended behaviour: attributions concentrate on **short shell arcs**, deletion hurts substantially (AOPC **0.2745**), and enrichment at the star location is highest among the three.


# SHAP interpretations — harder “Death‑Star” dataset (part 2)

This note interprets the **local** and **global** SHAP evidence for the harder dataset in which a bright 5‑arm *star* is injected on the metallic shell. It focuses on what each diagnostic is measuring, what the numbers imply for **OLS**, **Ridge**, and the **CNN**, and how to read the results safely.



---

## 1) Quick orientation — local vs global attribution

- **Local** views answer *“where, in this specific image, did the model pick up signal?”*  
  Examples here: **bounding‑box** metrics over top‑$1\%$ $|{\rm SHAP}|$ and the **star‑disc** capture at a tiny disc centred on the detected star.

- **Global** views answer *“across many images, what patterns does the model rely on?”*  
  Examples here: region **shares** (hollow / shell / outside), **deletion curves** (AOPC), **polar superpixels** (radius–angle maps), and **baseline sensitivity**.

The two perspectives should agree: if a model relies on a strictly local driver (the star), local metrics will be tight, and global metrics will report heavy concentration near the star’s radius and angle.

---

## 2) Bounding‑box metrics (top‑$1\%$ $|{\rm SHAP}|$ inside the shell)

**What is computed?**  
For each image:

1. Select the top $p=1\%$ of $|{\rm SHAP}|$ **within the shell**.  
2. Keep the **largest 8‑connected component** (by total weight).  
3. Draw the tight axis‑aligned **bounding box** around that component.  
4. Report:
   - **Area fraction** of the shell covered:  
     $$
     \text{AreaFrac} \;=\; \frac{|{\rm Box}\cap \text{Shell}|}{|\text{Shell}|}\,.
     $$
   - **Share of $|{\rm SHAP}|$** captured by that box (within the shell):  
     $$
     \text{Share} \;=\; \frac{\sum_{(i,j)\in{\rm Box}\cap\text{Shell}} |s_{ij}|}{\sum_{(i,j)\in\text{Shell}} |s_{ij}|}\,.
     $$
   - **Angular coverage** of selected pixels (degrees).  
   - **Radial width** normalised by shell thickness.
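
Steps 1–4 can be sketched with `scipy.ndimage.label`. This is an illustration only: the notebook uses its own flood‑fill `_largest_component`, `box_metrics` is a hypothetical name, the shell is an idealised annulus, and the angular‑coverage and radial‑width terms are left out for brevity.

```python
import numpy as np
from scipy import ndimage

def box_metrics(abs_shap, shell, top_frac=0.01):
    """Area fraction and |SHAP| share of the box around the heaviest top-p component."""
    vals = abs_shap[shell]
    k = max(1, int(np.ceil(top_frac * vals.size)))
    thr = np.partition(vals, vals.size - k)[vals.size - k]
    selected = (abs_shap >= thr) & shell

    # 8-connected labelling; label 0 is the background.
    labels, n = ndimage.label(selected, structure=np.ones((3, 3), dtype=int))
    if n == 0:
        return None
    weights = np.bincount(labels.ravel(), weights=abs_shap.ravel(), minlength=n + 1)
    weights[0] = -np.inf                      # never pick the background
    best = int(np.argmax(weights))

    # Tight axis-aligned box around the chosen component, clipped to the shell.
    rows, cols = np.where(labels == best)
    box = np.zeros_like(shell, dtype=bool)
    box[rows.min():rows.max() + 1, cols.min():cols.max() + 1] = True
    box_shell = box & shell

    area_frac = box_shell.sum() / shell.sum()
    share = abs_shap[box_shell].sum() / abs_shap[shell].sum()
    return float(area_frac), float(share)

# Toy map: faint ring plus a Gaussian bump standing in for the star.
H = W = 100
yy, xx = np.mgrid[:H, :W]
r = np.hypot(yy - 49.5, xx - 49.5)
shell = (r > 20) & (r <= 40)
toy = shell * (0.01 + np.exp(-((yy - 49) ** 2 + (xx - 80) ** 2) / (2 * 3.0 ** 2)))

area_frac, share = box_metrics(toy, shell)
print(f"AreaFrac={area_frac:.3f}, Share={share:.3f}")
```

Thresholding with `>=` keeps all pixels tied at the cut‑off, so slightly more than $p\%$ of the shell can be selected when attribution values repeat.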

**Why it matters.**  
Small **AreaFrac** yet large **Share** means a **local cue** explains a lot of the model’s behaviour. Angular coverage tells whether that cue is a short arc or a long sector; radial width indicates whether the emphasis is on a thin edge vs a broader band.

**Medians (hard dataset).**

- **OLS**: Area $0.003$, Share $0.043$, Angle $6.7^\circ$, Radial width $0.090$.  
  **Enrichment of the box** $\approx 0.043 / 0.003 \approx 14\times$. Tight arcs near shell edges dominate the linear model’s credit.
- **Ridge**: Area $0.009$, Share $0.069$, Angle $8.9^\circ$, Radial width $0.226$.  
  Enrichment $\approx 7.7\times$. Credit is more **diffuse**; radial width is larger, consistent with the model spreading mass beyond crisp edges (and even outside the shell; see §4).
- **CNN**: Area $0.011$, Share $0.268$, Angle $11.7^\circ$, Radial width $0.239$.  
  Enrichment $\approx 24\times$. The box captures the star **plus** adjacent inner/outer edges, hence wider radial span but still a small angular sector.
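
The box enrichments quoted above are simply share divided by area. The short check below reproduces them from the rounded medians and, as an added assumption, allows ±0.0005 of rounding slack on the area to show how coarse the ratios are. Note that these are ratios of medians, which need not equal the median of per‑image ratios.

```python
# (area fraction, |SHAP| share) medians quoted above, rounded as printed.
medians = {"OLS": (0.003, 0.043), "Ridge": (0.009, 0.069), "CNN": (0.011, 0.268)}

box_enrichment = {name: share / area for name, (area, share) in medians.items()}

for name, (area, share) in medians.items():
    # Areas are quoted to three decimals, so allow +/- 0.0005 of rounding slack.
    lo, hi = share / (area + 0.0005), share / (area - 0.0005)
    print(f"{name}: enrichment ~ {box_enrichment[name]:.1f}x "
          f"(rounding range {lo:.1f}x to {hi:.1f}x)")
```

The OLS figure is the least certain of the three because its area is quoted to a single significant digit.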

**How to read this.**  
The CNN packs about $27\%$ of shell $|{\rm SHAP}|$ into boxes covering about $1\%$ of the shell. OLS forms even tighter boxes ($0.3\%$ of the shell) but captures only about $4\%$. Ridge has the lowest box enrichment ($\approx 7.7\times$), i.e. it is the most diffuse.

---

## 3) Global star‑capture and enrichment (local cue measured **globally**)

**What is measured?**

1. **Star centre** $(y_{\ast},x_{\ast})$ is estimated directly from the *image* via a ring‑weighted, smoothed peak finder.  
2. A small **disc** of radius $r=4$ px is placed at $(y_{\ast},x_{\ast})$:  
   $$
   D \;=\; \big\{(i,j) : (i-y_{\ast})^2 + (j-x_{\ast})^2 \le r^2\big\}.
   $$
3. **Share inside the disc** (relative to the shell):  
   $$
   \text{Share}_D \;=\; \frac{\sum_{(i,j)\in D\cap\text{Shell}} |s_{ij}|}{\sum_{(i,j)\in\text{Shell}} |s_{ij}|}\,.
   $$
4. **Enrichment** compares this share to the **area fraction** of the disc in the shell:  
   $$
   \text{Enrich} \;=\; \frac{\text{Share}_D}{\,|D\cap\text{Shell}|/|\text{Shell}|\,}\,.
   $$
   For this geometry, the disc occupies only about $1\%$ of the shell, so an enrichment of $25$ means “about $25\times$ more $|{\rm SHAP}|$ than random area would explain.”
5. **Angular alignment** compares the star’s angle $\theta_{\ast}$ with the **circular mean angle** of the largest $|{\rm SHAP}|$ component:  
   $$
   \bar\theta \;=\; \operatorname{atan2}\!\Big(\sum_k w_k\sin\theta_k,\; \sum_k w_k\cos\theta_k\Big),\qquad
   \Delta\theta \;=\; \big|((\theta_{\ast}-\bar\theta+\pi)\bmod 2\pi) - \pi\big|\,.
   $$
   Here $\theta_k$ and $w_k$ are the angle and the normalised $|{\rm SHAP}|$ weight of pixel $k$ in that component.
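
The circular mean matters whenever a component straddles the 0°/360° seam. The sketch below uses made‑up angles to contrast it with a plain arithmetic mean; the two helpers mirror the formulas above, and their names are illustrative.

```python
import numpy as np

def circ_mean(angles, weights):
    """Weighted circular mean in radians, in (-pi, pi]."""
    return np.arctan2(np.sum(weights * np.sin(angles)),
                      np.sum(weights * np.cos(angles)))

def circ_abs_diff_deg(a, b):
    """Smallest absolute difference between two angles (radians in, degrees out)."""
    d = np.abs((a - b + np.pi) % (2 * np.pi) - np.pi)
    return np.degrees(d)

# Component pixels straddling the 0/360 degree seam (hypothetical values).
theta = np.radians([350.0, 355.0, 5.0, 10.0])
w = np.ones_like(theta)

naive = np.degrees(np.average(theta, weights=w))   # arithmetic mean: 180 degrees
circular = np.degrees(circ_mean(theta, w))         # circular mean: about 0 degrees
err = circ_abs_diff_deg(np.radians(2.0), circ_mean(theta, w))
print(f"naive={naive:.1f}, circular={circular:.1f}, error vs 2 deg star={err:.1f}")
```

The arithmetic mean lands on the opposite side of the ring, which would turn a near‑perfect alignment into a 178° error.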

**Why it matters.**  
This couples **where the true local driver sits** with **where attribution mass sits**. It turns a local property (a tiny star) into a global, comparable score.

**Medians (hard dataset).**

- **OLS**: Share $0.056$, Enrich $5.19$, Angle error $\tilde{\Delta\theta}=1.01^\circ$.  
  Linear pixels do respond to the star location, but far less than the CNN.
- **Ridge**: Share $0.070$, Enrich $6.45$, Angle error $0.53^\circ$.  
  Alignment is precise, but the disc share stays small. Because $\text{Share}_D$ is normalised **within the shell**, it does not penalise Ridge’s large off‑shell attribution (see §4).
- **CNN**: Share $0.273$, Enrich $25.19$, Angle error $0.48^\circ$.  
  A quarter of all shell‑attribution is packed into a ~1% disc at the star. This is the expected signature of a **local causal driver**.

**Important control — when there is *no star*.**  
On the simple dataset the disc is just a random place on the ring. The expected share then is near its **area fraction** (≈1%), and enrichment hovers around $1$. Large angle errors are normal because there is no canonical angle. Observing near‑unity enrichment and large angular errors in that case confirms the metric is not spuriously locking onto unrelated structure.
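
A small simulation illustrates this null behaviour. It is a sketch under explicit assumptions, not a rerun of the simple dataset: attribution is structureless noise on an idealised annulus, the disc is dropped at a random shell pixel, and the two angles are drawn independently and uniformly.

```python
import numpy as np

rng = np.random.default_rng(0)
H = W = 100
yy, xx = np.mgrid[:H, :W]
r = np.hypot(yy - 49.5, xx - 49.5)
shell = (r > 20) & (r <= 40)
shell_rows, shell_cols = np.where(shell)

enrich, ang_err = [], []
for _ in range(500):
    attr = rng.random((H, W)) * shell          # structureless |SHAP| on the shell
    j = rng.integers(shell_rows.size)          # disc at a random shell pixel
    disc = (yy - shell_rows[j]) ** 2 + (xx - shell_cols[j]) ** 2 <= 4 ** 2
    ds = disc & shell
    share = attr[ds].sum() / attr[shell].sum()
    enrich.append(share / (ds.sum() / shell.sum()))
    # With no star, the proxy angle and the component angle are unrelated.
    a, b = rng.uniform(0, 2 * np.pi, size=2)
    ang_err.append(np.degrees(abs((a - b + np.pi) % (2 * np.pi) - np.pi)))

print(f"median enrichment    = {np.median(enrich):.2f}")
print(f"median angular error = {np.median(ang_err):.1f} deg")
```

Under these assumptions the median enrichment sits near 1 and the angular error spreads over the whole 0°–180° range, which is the reference point against which the hard‑dataset values should be read.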

---

## 4) Region shares and polar superpixels (global structure)

**Region shares** summarise where $|{\rm SHAP}|$ lives on average. Let $s_{ij}$ denote the SHAP map, and let $\mathcal{H}$ (hollow), $\mathcal{S}$ (shell) and $\mathcal{O}$ (outside) denote the three regions. The mean **per‑pixel** strength in a region $\mathcal{R}$ is

$$
\mu_{\mathcal{R}} \;=\; \frac{1}{|\mathcal{R}|}\sum_{(i,j)\in\mathcal{R}} |s_{ij}| \,.
$$

- **OLS**: shell‑centred pattern, as desired for a pixel‑wise linear model of a ring.  
- **Ridge**: *outside* dominates (e.g. Outside share $59.05\%$; per‑pixel mean outside > shell). This indicates sensitivity to background / normalisation rather than shell physics.  
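- **CNN**: $98.15\%$ of area‑biased share in the shell; the per‑pixel mean in the shell ($8.25\times10^{-4}$) is about $16\times$ that of the hollow ($5.1\times10^{-5}$) and more than two orders of magnitude above the outside ($3\times10^{-6}$).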
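
To see why shares and per‑pixel means must be read together, the sketch below evaluates both on a spatially flat attribution map. The region masks are an idealised 20/40 px geometry and `region_stats` is a hypothetical helper.

```python
import numpy as np

H = W = 100
yy, xx = np.mgrid[:H, :W]
r = np.hypot(yy - 49.5, xx - 49.5)
regions = {"hollow": r <= 20, "shell": (r > 20) & (r <= 40), "outside": r > 40}

def region_stats(abs_shap, regions):
    """Map region name -> (share of total |SHAP|, per-pixel mean |SHAP|)."""
    total = abs_shap.sum() + 1e-12
    return {name: (abs_shap[m].sum() / total, abs_shap[m].mean())
            for name, m in regions.items()}

# A flat map prefers no region per pixel, yet the shares differ by area alone.
flat = np.ones((H, W))
stats = region_stats(flat, regions)
for name, (share, mean) in stats.items():
    print(f"{name:8s} share={share:.3f}  per-pixel mean={mean:.3f}")
```

A model with no spatial preference at all would therefore still show the largest share outside the shell, simply because the outside is the largest region in this geometry.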

**Polar superpixels** bin $|{\rm SHAP}|$ by radius and angle. The **CV across angle** in the peak radial band captures **anisotropy**: 

$$
\mathrm{CV} \;=\; \frac{\sigma_\theta}{\mu_\theta}\,.
$$

- **Peak radius**: OLS at $14$–$21$ px (inner edge); CNN at $28$–$35$ px (towards outer edge); Ridge peaks outside the shell ( $49$–$56$ px ), echoing the region‑share warning.
- **Anisotropy**: CNN CV $0.413$ (focused arcs), OLS CV $0.146$, Ridge CV $0.091$ (nearly isotropic, but largely off‑shell). Focused arcs are the expected footprint of a **local** defect (the star).
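
The polar binning and the anisotropy CV can be sketched in vectorised form with `np.bincount`. This is an illustrative alternative to the notebook's double loop, assuming a 100×100 grid, 10×16 bins and synthetic maps.

```python
import numpy as np

H = W = 100
NB_RADIAL, NB_THETA = 10, 16
yy, xx = np.mgrid[:H, :W]
dy, dx = yy - 49.5, xx - 49.5
R = np.hypot(dy, dx)
theta = np.mod(np.arctan2(dy, dx), 2 * np.pi)          # wrap into [0, 2*pi)

rad_bin = np.minimum((R / (R.max() + 1e-6) * NB_RADIAL).astype(int), NB_RADIAL - 1)
theta_bin = np.minimum((theta / (2 * np.pi) * NB_THETA).astype(int), NB_THETA - 1)
flat_bin = (rad_bin * NB_THETA + theta_bin).ravel()

def polar_mean(abs_shap):
    """Mean |SHAP| per (radius, angle) bin; empty bins are returned as 0."""
    n_bins = NB_RADIAL * NB_THETA
    sums = np.bincount(flat_bin, weights=abs_shap.ravel(), minlength=n_bins)
    counts = np.bincount(flat_bin, minlength=n_bins)
    return (sums / np.maximum(counts, 1)).reshape(NB_RADIAL, NB_THETA)

def peak_ring_cv(polar):
    """Coefficient of variation across angle in the strongest radial band."""
    row = polar[np.argmax(polar.mean(axis=1))]
    return float(np.std(row) / (np.mean(row) + 1e-12))

shell = (R > 20) & (R <= 40)
isotropic = shell.astype(float)                         # uniform ring
arc = isotropic * (1.0 + 4.0 * (theta < np.pi / 4))     # one bright 45-degree arc

cv_iso = peak_ring_cv(polar_mean(isotropic))
cv_arc = peak_ring_cv(polar_mean(arc))
print(f"CV isotropic ring = {cv_iso:.3f}, CV ring with bright arc = {cv_arc:.3f}")
```

With ten radial bins on this grid each band is about 7 px wide, so a reported peak band localises the ring only to that resolution.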

---

## 5) Deletion curves and AOPC (faithfulness check)

Pixels are ranked by $|{\rm SHAP}|$ on the image channel; the top fraction is zeroed; the model is re‑scored against test labels. Let $R^2(f)$ be the score after deleting a fraction $f$ of pixels; the **Area Over the Perturbation Curve** is

$$
\text{AOPC} \;=\; \int_0^{f_{\max}}\big(R^2(0)-R^2(f)\big)\,df\,,
$$

approximated by the trapezoid rule in the plots. Larger AOPC $\Rightarrow$ faster *drop* in $R^2$ $\Rightarrow$ higher **fidelity** of the explanations.
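
A deletion curve and its trapezoid‑rule AOPC can be sketched on a toy linear model. Everything here is an assumption for illustration (the model, the data, the `deletion_aopc` helper and $f_{\max}=0.2$); the resulting values are not comparable with the notebook's numbers.

```python
import numpy as np

def r2(y_true, y_pred):
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

def deletion_aopc(predict, X, y, abs_attr, fractions):
    """Zero the top-|attr| features per sample, re-score, integrate the R^2 drop."""
    order = np.argsort(-abs_attr, axis=1)              # most important first
    scores = []
    for f in fractions:
        Xd = X.copy()
        k = int(round(f * X.shape[1]))
        if k > 0:
            np.put_along_axis(Xd, order[:, :k], 0.0, axis=1)
        scores.append(r2(y, predict(Xd)))
    scores = np.asarray(scores)
    drop = scores[0] - scores                          # R^2(0) - R^2(f)
    # Trapezoid rule written out explicitly.
    aopc = float(np.sum(0.5 * (drop[1:] + drop[:-1]) * np.diff(fractions)))
    return scores, aopc

# Toy linear "model" on flattened pixels; only the first 20 features matter.
rng = np.random.default_rng(0)
n, p = 200, 400
X = rng.random((n, p))
w = np.zeros(p)
w[:20] = 1.0
y = X @ w

def predict(A):
    return A @ w

fractions = np.linspace(0.0, 0.2, 11)                  # fractions[0] must be 0
faithful = np.abs(X * w)                               # credit where the model looks
random_attr = rng.random((n, p))                       # credit assigned at random
_, aopc_faithful = deletion_aopc(predict, X, y, faithful, fractions)
_, aopc_random = deletion_aopc(predict, X, y, random_attr, fractions)
print(f"AOPC faithful={aopc_faithful:.3f}, random={aopc_random:.3f}")
```

Comparing against a random ranking is a useful reference: an explanation is only informative if its AOPC clearly exceeds what random deletion achieves.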

- **OLS**: AOPC $0.0545$ — some faithfulness; deletion hurts performance moderately.
- **Ridge**: AOPC $-0.0115$ — deleting “important” pixels slightly **improves** $R^2$. This is a red flag: attribution seems mis‑placed (consistent with outside‑shell emphasis).
- **CNN**: AOPC $0.2745$ (large drop); the attributions are highly **faithful** to the prediction, which is consistent with, though not proof of, a causal role for the highlighted pixels.

**Pitfall.** Deletion is baseline‑dependent. If the replacement value inadvertently regularises the input (e.g. removes background confounds for Ridge), AOPC can be pessimistic or even negative, without implying the SHAP algorithm is wrong.

---

## 6) Baseline sensitivity, flips, and IG triangulation

- **Deep SHAP baseline:** median Spearman rank correlation $0.759$ between train‑subset and mean baselines indicates **robust** ordering of pixel importance, but not perfect invariance.  
  Baseline choice should be guided by downstream **fidelity** (IoU@k and AOPC), not just convenience.
- **Flip‑invariance:** median Spearman $0.787$ (original vs horizontally flipped‑&‑unflipped explanations) shows the CNN’s attributions move **with** the image, as expected from a spatial model.
- **IG vs Deep SHAP:** median Spearman $0.551$ on $|{\rm attr}|$ indicates moderate agreement on the main loci of evidence; differences are expected because IG integrates gradients along a straight path from a single baseline (here, the input with its image channel set to zero), while Deep SHAP uses a background distribution.
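
All three checks above rest on the same per‑image statistic: a Spearman rank correlation between two $|{\rm attr}|$ maps. The sketch below uses synthetic maps and a hypothetical `rank_agreement` helper to show that the statistic ignores monotone changes in magnitude and only responds to changes in ranking.

```python
import numpy as np
from scipy.stats import spearmanr

def rank_agreement(map_a, map_b):
    """Spearman correlation between two |attribution| maps; NaN if either is constant."""
    a, b = np.abs(map_a).ravel(), np.abs(map_b).ravel()
    if a.std() < 1e-12 or b.std() < 1e-12:
        return float("nan")
    rho, _ = spearmanr(a, b)
    return float(rho)

rng = np.random.default_rng(0)
base = rng.random((100, 100))
rescaled = 3.0 * base ** 2                  # different magnitudes, identical ranking
noisy = base + 0.5 * rng.standard_normal(base.shape)

rho_rescaled = rank_agreement(base, rescaled)
rho_noisy = rank_agreement(base, noisy)
print(f"rescaled: {rho_rescaled:.3f}, noisy: {rho_noisy:.3f}")
```

This is why a baseline that merely rescales attributions leaves the correlation at 1, while values such as $0.55$–$0.79$ indicate that the pixel ordering itself changes between methods or baselines.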

---

## 7) Putting local and global evidence together

- **Local (star‑disc, boxes).** The CNN allocates roughly **a quarter** of its shell attribution to a **1% disc** centred on the star (enrichment $\approx 25\times$) and packs **$26$–$33\%$** of shell attribution into tiny boxes. This is the canonical footprint of a **strong, local causal driver**. OLS notices the star but gives it much less weight; Ridge is distracted by background structure.
- **Global (deletion, region shares, polar maps).** The CNN’s large AOPC and high anisotropy in the correct radial band confirm the local picture globally. Ridge’s negative AOPC and off‑shell peak expose a global mismatch with the task.

---

## 8) Practical reading guide (what “good” looks like)

- **Good local + good global:** high disc enrichment ($\gg 1$), tiny angular error ($<1^\circ$), small boxes with high $|{\rm SHAP}|$ share, large AOPC, peak radius on the shell, high anisotropy when a local cue should matter. → **CNN here.**
- **Looks ring‑like but shallow:** moderate disc enrichment ($\approx 5$–$7$), small angular error, modest box share, modest AOPC. → **OLS here.**
- **Diffuse or off‑shell:** low precision vs shell, negative/flat AOPC, peak radius outside shell. → **Ridge here.**

---

## 9) Common pitfalls and how to avoid them

1. **Confusing area‑biased shares with per‑pixel means.** The “region share” can be dominated by area. Always compare with per‑pixel means to spot off‑shell leakage.  
2. **Threshold sensitivity.** Changing the top‑$p\%$ affects boxes and IoU. Stability across a few $p$ values is a good sanity check.  
3. **Deletion baseline artefacts.** Zeroing pixels can regularise inputs and artificially help a model (Ridge). Report AOPC alongside a qualitative check of deletion images.  
4. **Assuming causality from a single view.** Align local and global evidence before claiming a driver is causal.  
5. **Ignoring the background choice for Deep SHAP.** Baselines shift absolute magnitudes. Prefer baselines that maximise **fidelity** on held‑out images.

---

## 10) Condensed numerical summary (medians)

- **Bounding boxes (Area / Share / Angle / Radial width):**  
  OLS $0.003 / 0.043 / 6.7^\circ / 0.090$ → $\sim14\times$ enrichment.  
  Ridge $0.009 / 0.069 / 8.9^\circ / 0.226$ → $\sim7.7\times$.  
  CNN $0.011 / 0.268 / 11.7^\circ / 0.239$ → $\sim24\times$.

- **Star‑disc (Share / Enrich / Angle error):**  
  OLS $0.056 / 5.19 / 1.01^\circ$; Ridge $0.070 / 6.45 / 0.53^\circ$; CNN $0.273 / 25.19 / 0.48^\circ$.

- **Deletion (AOPC):** OLS $0.0545$; Ridge $-0.0115$; CNN $0.2745$.

- **Polar peaks (radius band; anisotropy CV):**  
  OLS $14$–$21$ px; $0.146$. Ridge $49$–$56$ px; $0.091$. CNN $28$–$35$ px; $0.413$.

---

### Bottom line

The CNN’s explanations show a **coherent story** across all lenses: a **local** driver on the shell dominates the prediction. Linear models track ring structure but either **under‑weight** the star (OLS) or **misallocate** credit to the background (Ridge). The combination of **high disc enrichment**, **small angular error**, **large AOPC**, and **on‑shell polar peaks** is precisely the pattern expected when a model has learned to localise and use the star.
