---
title: "LeNet-Driven CNN Explainability through ExcitationBP Heat-Maps and Sparse Concepts using XNN + SRAE"
author: "Brian Cervantes Alvarez"
date: "2025-06-06"
date-format: full
format:
  revealjs:
    theme: simple
    slide-number: true
    scrollable: true
html-math-method: mathjax
lightbox: true
crossref:
  fig-title: Fig
  tbl-title: Tbl
  title-delim: "—"
eval: false
---


## Agenda {data-background-color="#D73F09" .smaller} 

1. **Motivation and risk**: why care about explanations?
2. **Convolutional neural network (CNN) architecture** overview  
   * Layers and operations  
   * Training considerations
3. **XNN + SRAE**: concept extraction
4. **Causal localisation** (I-GOS++)
5. **Ethical considerations and future research**
6. **Questions and discussion**


## 1. Motivation and Risk — Why Care About Explanations? {.smaller}

* **Safety-critical decisions**  
  * Autonomous vehicles: distinguish a **red traffic light** from a red billboard  
  * Medical diagnostics: pathologists must identify the specific cells that trigger a cancer flag  
* **Legal and ethical compliance**  
  * The EU AI Act and FDA guidance on Software as a Medical Device (SaMD) set transparency requirements that explanations help meet  
* **Model debugging and improvement**  
  * Detect shortcut learning (e.g., background colour in bird identification)  
  * Reveal dataset bias and temporal drift  

Opaque decision pathways can introduce risk, undermine confidence, and complicate audits.


# 2. CNN Architecture Refresher {data-background-color="#D73F09"}

## Building Blocks of a CNN {.smaller}

:::: {.columns}
::: {.column width="40%"}
1. **Convolution** (+ ReLU)  
2. **Stride / Padding**  
3. **Pooling** (max & average)  
4. **Normalisation** (batch / layer)  
5. **Residual connections**  
6. **Flatten → dense head**  
7. **Regularisation** (dropout)
:::
::: {.column width="60%"}
![Figure 1. High-level CNN pipeline](images/article-hero-cnn.webp)
:::
::::

## Convolution Kernels {.smaller}

![Figure 2. Kernel illustration](images/kernel_filter.webp)

A **kernel** (filter) slides across the image, computing local dot products to detect edges, textures, or colour blobs.  
For an $i\times i$ input and a $k\times k$ kernel (stride 1, no padding), the output spatial size is $i - k + 1$.
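A minimal base-R sketch of the sliding dot product, assuming a square single-channel image and kernel. `conv2d_valid()` is a hypothetical helper for illustration; like most deep-learning libraries it computes a cross-correlation (the kernel is not flipped).

```r
# Minimal "valid" 2D convolution: stride 1, no padding, kernel not flipped
# (cross-correlation, which is what CNN libraries compute in practice).
# conv2d_valid() is a hypothetical helper written for illustration.
conv2d_valid <- function(img, kernel) {
  i <- nrow(img)                 # assumes a square i x i input
  k <- nrow(kernel)              # and a square k x k kernel
  out_size <- i - k + 1          # output spatial size
  out <- matrix(0, out_size, out_size)
  for (row_i in seq_len(out_size)) {
    for (col_j in seq_len(out_size)) {
      patch <- img[row_i:(row_i + k - 1), col_j:(col_j + k - 1)]
      out[row_i, col_j] <- sum(patch * kernel)   # local dot product
    }
  }
  out
}

# A 6 x 6 image whose right half is bright, and a vertical-edge kernel
img <- cbind(matrix(0, 6, 3), matrix(1, 6, 3))
vertical_edge <- matrix(c(-1, 0, 1), nrow = 3, ncol = 3, byrow = TRUE)

out <- conv2d_valid(img, vertical_edge)
dim(out)    # 4 4, because 6 - 3 + 1 = 4
out[1, ]    # 0 3 3 0: only windows straddling the edge respond
```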

## Stride & Padding {.smaller}

:::: {.columns}
::: {.column width="50%"}
![Figure 3. Stride illustration](images/stride.webp)

Stride $s$ controls down-sampling:  
$\left\lfloor \frac{i - k}{s} \right\rfloor + 1$
:::
::: {.column width="50%"}
![Figure 4. Padding illustration](images/padding.webp)

Zero-padding $p$ preserves border context:  
$\left\lfloor \frac{i - k + 2p}{s} \right\rfloor + 1$
:::
::::
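The general formula can be checked numerically. This is a small sketch; `conv_output_size()` is a hypothetical helper, and the trace at the end follows the layer sequence of `create_base_cnn()` in the appendix, assuming Keras's default "valid" padding and pool stride equal to the pool size.

```r
# Spatial output size of a convolution or pooling layer:
# floor((i - k + 2p) / s) + 1 for input size i, kernel k, padding p, stride s
conv_output_size <- function(i, k, s = 1, p = 0) {
  floor((i - k + 2 * p) / s) + 1
}

conv_output_size(28, 3)                # 26: 3x3 kernel, stride 1, no padding
conv_output_size(28, 3, p = 1)         # 28: "same"-size output for a 3x3 kernel
conv_output_size(26, 2, s = 2)         # 13: 2x2 pooling with stride 2
conv_output_size(7, 3, s = 2, p = 1)   # 4

# Trace the appendix backbone: conv 3x3 -> pool 2x2 -> conv 3x3 -> pool 2x2
h <- conv_output_size(28, 3)           # 26
h <- conv_output_size(h, 2, s = 2)     # 13
h <- conv_output_size(h, 3)            # 11
h <- conv_output_size(h, 2, s = 2)     # 5 (the odd trailing row/column is dropped)
h * h * 64                             # 1600 values enter the flatten layer
```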

## Pooling Layers {.smaller}

![Figure 5. Max vs average pooling](images/pooling.webp)

* **Max pooling:** keeps the strongest activation in each window (e.g., the sharpest edge response)  
* **Average pooling:** provides a smoother summary of each window  

Both provide a degree of local translation invariance and reduce downstream computation.
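A base-R sketch of non-overlapping 2×2 pooling (stride 2) on a single channel; `pool2x2()` is a hypothetical helper, and passing `max` or `mean` switches between the two variants.

```r
# Non-overlapping 2 x 2 pooling with stride 2 on one channel.
# pool2x2() is a hypothetical helper; `fun` is max or mean.
pool2x2 <- function(x, fun = max) {
  n_row <- nrow(x) %/% 2          # an odd trailing row/column is dropped,
  n_col <- ncol(x) %/% 2          # as in "valid" pooling (e.g., 11 -> 5)
  out <- matrix(0, n_row, n_col)
  for (i in seq_len(n_row)) {
    for (j in seq_len(n_col)) {
      window <- x[(2 * i - 1):(2 * i), (2 * j - 1):(2 * j)]
      out[i, j] <- fun(window)    # summarize the 2 x 2 window
    }
  }
  out
}

x <- matrix(c(1, 3, 2, 0,
              5, 2, 1, 1,
              0, 0, 4, 8,
              1, 1, 2, 6), nrow = 4, byrow = TRUE)

pool2x2(x, max)    # rows: 5 2 / 1 8
pool2x2(x, mean)   # rows: 2.75 1 / 0.5 5
```

Max pooling keeps the single strongest value in each window, while average pooling blends all four, which is why it reacts less sharply to one bright pixel.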

## Normalization in CNNs {.smaller}

After each convolution, the distribution of activations can shift (“internal covariate shift”), slowing learning. Normalization keeps activations on a stable scale so the network converges faster. For each channel and each mini-batch:

:::: {.columns}
::: {.column width="33.3%"}
Compute the batch mean  
$$
\mu = \frac{1}{m}\sum_{i=1}^m x_i
$$
:::
::: {.column width="33.3%"}
Compute the batch variance  
$$
\sigma^2 = \frac{1}{m}\sum_{i=1}^m (x_i - \mu)^2
$$
:::
::: {.column width="33.3%"}
Standardize and re‐scale  
$$
\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}, \quad
y_i = \gamma\,\hat{x}_i + \beta
$$
:::
::::

$\epsilon$ is a small constant that avoids division by zero, and $\gamma,\beta$ are the learnable scale and shift.

* **Faster convergence:** gradients stay well‐scaled  
* **Milder sensitivity** to initialization and learning rate  
* **Regularization effect:** slight noise from batch estimates helps generalization
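The three formulas above translate directly into a training-mode forward pass. This base-R sketch treats each column of a mini-batch matrix as one channel; `batch_norm()` is a hypothetical helper, and the running averages that batch normalization uses at inference time are not shown.

```r
# Batch normalization forward pass (training mode) for a mini-batch matrix x:
# rows = m examples, columns = channels. Hypothetical helper for illustration.
batch_norm <- function(x, gamma = rep(1, ncol(x)), beta = rep(0, ncol(x)), eps = 1e-5) {
  mu <- colMeans(x)                                 # batch mean per channel
  centered <- sweep(x, 2, mu)                       # x_i - mu
  sigma2 <- colMeans(centered^2)                    # biased (1/m) batch variance
  x_hat <- sweep(centered, 2, sqrt(sigma2 + eps), "/")
  sweep(sweep(x_hat, 2, gamma, "*"), 2, beta, "+")  # y_i = gamma * x_hat_i + beta
}

set.seed(1)
x <- cbind(rnorm(64, mean = 5, sd = 3), rnorm(64, mean = -2, sd = 0.5))
y <- batch_norm(x)
colMeans(y)          # approximately 0 for each channel
sqrt(colMeans(y^2))  # approximately 1 for each channel

y2 <- batch_norm(x, gamma = c(2, 3), beta = c(1, -1))
colMeans(y2)         # approximately 1 and -1: the learnable shift sets the mean
```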


## Residual Connections (ResNets) {.smaller}

Deep networks suffer from vanishing gradients; stacking more layers can *worsen* training accuracy. A shortcut (identity) path lets the block learn a *residual* function:

$$
\mathbf{y} = \mathcal{F}(\mathbf{x}) + \mathbf{x}
$$

* $\mathbf{x}$: input to the block  
* $\mathcal{F}(\mathbf{x})$: the “residual” (e.g., Conv → ReLU → Conv)  

* Enables gradients to flow directly through the identity branch  
* Lets layers fit small corrections ($\mathcal{F}$) rather than full mappings  
* Empirically allows networks of hundreds of layers to train  
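A small base-R sketch of $\mathbf{y} = \mathcal{F}(\mathbf{x}) + \mathbf{x}$. As a simplifying assumption, dense matrix products stand in for the Conv → ReLU → Conv branch; `residual_block()` is a hypothetical helper.

```r
relu <- function(x) pmax(x, 0)   # pmax keeps the matrix dimensions

# Residual block y = F(x) + x, with F = Linear -> ReLU -> Linear
# (a dense stand-in for Conv -> ReLU -> Conv; hypothetical helper)
residual_block <- function(x, W1, W2) {
  f_x <- relu(x %*% W1) %*% W2   # residual branch F(x)
  f_x + x                        # identity shortcut; F(x) and x must have equal shape
}

set.seed(42)
x  <- matrix(rnorm(4 * 8), nrow = 4)      # batch of 4 vectors, 8 features each
W1 <- matrix(rnorm(8 * 8, sd = 0.1), 8)
W2 <- matrix(0, 8, 8)                     # residual weights at zero

y <- residual_block(x, W1, W2)
all.equal(y, x)   # TRUE: with F = 0 the block reduces to the identity
```

The zero-weight case shows why residual blocks are easy to optimize: doing nothing is the default, and training only has to learn a correction on top of it.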


## Flatten & Fully Connected Layers {.smaller}

![Figure 6. Flatten illustration](images/flatten.webp)

The three-dimensional tensor ($H\times W\times K$) is flattened and passed to dense layers, which integrate global context and output class probabilities via softmax.
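A base-R sketch of the flatten → dense → softmax head, using a random $5\times5\times64$ feature map (the shape that reaches the flatten layer in the appendix backbone) and random, untrained weights. Note that R's `as.vector()` flattens in column-major order, whereas Keras flattens row-major; either works as long as the dense weights are laid out to match.

```r
# Numerically stable softmax: subtracting the max does not change the result
softmax <- function(logits) {
  z <- exp(logits - max(logits))
  z / sum(z)
}

set.seed(7)
H <- 5; W <- 5; K <- 64; n_classes <- 10
feature_map <- array(runif(H * W * K), dim = c(H, W, K))   # random stand-in activations

flat <- as.vector(feature_map)                              # length H * W * K = 1600
W_dense <- matrix(rnorm(length(flat) * n_classes, sd = 0.01), ncol = n_classes)
b_dense <- rep(0, n_classes)

logits <- as.vector(flat %*% W_dense) + b_dense   # one score per class
probs  <- softmax(logits)

length(flat)   # 1600
sum(probs)     # 1
```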


## Regularization: Dropout {.smaller}

![Figure 7. Dropout mask](images/dropout.webp)

Randomly zeroes a fraction $p$ of units at each training step to discourage co-adaptation and mitigate overfitting; at inference, all units are kept.
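A base-R sketch of *inverted* dropout, the variant Keras uses: surviving units are scaled by $1/(1-p)$ during training so the expected activation is unchanged and no rescaling is needed at inference. `dropout()` is a hypothetical helper.

```r
# Inverted dropout: zero each unit with probability p during training and rescale
# survivors by 1 / (1 - p); act as the identity at inference. Hypothetical helper.
dropout <- function(x, p = 0.5, training = TRUE) {
  if (!training || p == 0) return(x)
  keep <- runif(length(x)) >= p   # TRUE with probability 1 - p
  x * keep / (1 - p)
}

set.seed(3)
x <- rep(1, 10000)
y <- dropout(x, p = 0.5)

mean(y == 0)   # close to 0.5: about half the units are dropped
mean(y)        # close to 1: the rescaling preserves the expected activation
identical(dropout(x, training = FALSE), x)   # TRUE
```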


# 3. From CNN Architecture to XNN Explanation {data-background-color="#D73F09"}

## What is XNN? {.smaller}

- **XNN** = e**X**planation **N**eural **N**etwork  
- A compact network **attached** to a frozen, pre-trained CNN  
- **Encoder** $E_\theta$: compresses the CNN’s feature vector  
  $\displaystyle \mathbf{z}\in\mathbb{R}^D \;\longmapsto\; \mathbf{e}=E_\theta(\mathbf{z})\in\mathbb{R}^L$  
- **Projection** $v$: maps $\mathbf{e}$ to the CNN’s output (logit) space  
  $\displaystyle \hat{\mathbf{y}}=v^\top\mathbf{e}$  
- Trained with three losses to enforce:  
  1. **Faithfulness** (match original predictions)  
  2. **Sparsity** (keep concepts concise)  
  3. **Orthogonality** (make concepts distinct)  
- **Gives us:** sparse, stable, and interpretable “concept activations” you can visualize and audit  


## What is XNN? {.smaller}

:::: {.columns}
::: {.column width="50%"}
![Figure 8.](images/mainPaper1.png){width=80%}

- The explanation network takes $Z$, squeezes it down to just a few important “concept” numbers (the low-dimensional $E$), and uses those to reproduce the same prediction—so we can see which few concepts the CNN really cares about.


:::
::: {.column width="50%"}
![Figure 9.](images/mainPaper2.png){width=80%}

- First, the explanation network uses a small “encoder” to turn the long feature vector $Z$ into a short concept vector $E$. 
- Then a “decoder” tries to rebuild only the key parts of $Z$ (keeping the focus on a few features), a penalty keeps each concept different from the others, and a simple linear layer $v$ uses those concepts to reproduce the CNN’s original output as closely as possible.
:::
::::


## XNN + SRAE Loss Components {.smaller}

For a batch of $M$ inputs, we train the explanation network by minimizing the composite loss function:

$$
\boxed{
\mathcal{L} \;=\; L_{\mathrm{faith}} \;+\; L_{\mathrm{SR}} \;+\; L_{\mathrm{PT}}
}
$$

balancing  

1. **Faithfulness** $\;L_{\mathrm{faith}}$  
2. **Sparse-Reconstruction** $\;L_{\mathrm{SR}}$  
3. **Orthogonality (Pull-Away Term)** $\;L_{\mathrm{PT}}$  

## Understanding the XNN + SRAE Total Loss

$$
\mathcal{L} \;=\; L_{\mathrm{faith}} \;+\; L_{\mathrm{SR}} \;+\; L_{\mathrm{PT}}
$$

The total loss combines **three goals**: it makes the explanation network’s output match the original CNN’s predictions (faithfulness), forces it to use only a few key internal features (sparsity), and ensures each learned “concept” is different from the others (orthogonality). By minimizing this combined score, we get a small set of simple, distinct concepts that still predict as accurately as the original model.


## (1) Faithfulness Loss $L_{\mathrm{faith}}$ {.smaller}

$$
L_{\mathrm{faith}}
\;=\;
\underbrace{\frac{1}{M}\sum_{i=1}^{M}\sum_{c=1}^{C}
\Bigl(\bar y_{i,c} \;-\;\hat y_{i,c}\Bigr)^{2}}_{\substack{\text{average of how far apart}\\\text{the two sets of scores are}}}
$$

where

* $\bar y_{i,c}$ is the **backbone’s** output score (the CNN’s own number) for class $c$ on image $i$.
* $\hat y_{i,c} = v_c^\top E_\theta\bigl(Z^{(i)}\bigr)$ is the explanation network’s score for class $c$ (the small network’s prediction), where $v_c$ is the $c$th column of $v$.

This loss measures, on average, how different the explanation network’s scores are from the CNN’s scores. By making this number small, we force the small network to copy the CNN’s predictions closely.
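A direct base-R transcription of $L_{\mathrm{faith}}$, assuming both networks’ scores are stored as $M\times C$ matrices; `faithfulness_loss()` is a hypothetical name and the numbers are toy values.

```r
# Faithfulness loss: (1/M) * sum over images i and classes c of (ybar - yhat)^2
# ybar: M x C backbone logits; yhat: M x C explanation-network logits
faithfulness_loss <- function(ybar, yhat) {
  stopifnot(all(dim(ybar) == dim(yhat)))
  sum((ybar - yhat)^2) / nrow(ybar)
}

ybar <- matrix(c(2.0, -1.0,  0.5,
                 0.0,  3.0, -2.0), nrow = 2, byrow = TRUE)
yhat <- matrix(c(1.5, -1.0,  0.5,
                 0.0,  2.0, -2.0), nrow = 2, byrow = TRUE)

faithfulness_loss(ybar, yhat)   # (0.25 + 1) / 2 = 0.625
mean((ybar - yhat)^2)           # 0.625 / 3: averaging over all M x C entries divides by C
```

The second line matters when comparing numbers: the appendix’s `srae_loss()` uses `tf$reduce_mean`, which averages over all $M\times C$ entries and therefore reports the slide’s value divided by $C$.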



## (1) Faithfulness Loss $\;L_{\mathrm{faith}}$

By passing $Z$ through the five-dimensional encoder and then using $v$ to predict $\hat{y}$, we ensure the five concepts can recreate the CNN’s original output. In practice, $\hat{y}$ should closely match the logits (pre-softmax scores) of LeNet-5’s final layer.

> **Faithfulness** should converge toward 0 (perfect matching of the CNN’s logits).


## (2) Sparse-Reconstruction Loss $\;L_{\mathrm{SR}}$ {.smaller}


$$
L_{\mathrm{SR}}
\;=\;
\underbrace{\frac{\beta}{D_z}}_{\substack{\text{trade-off}\\\text{weight}}}
\underbrace{\sum_{k=1}^{D_z}
\log\!\Biggl(
1 \;+\;
q \;\underbrace{\frac{1}{M}\sum_{i=1}^{M}
\bigl(\,\tilde Z^{(i)}_{k} - Z^{(i)}_{k}\bigr)^{2}}_{\substack{\text{average squared}\\\text{reconstruction error}}}
\Biggr)}_{\substack{\text{encourages using only}\\\text{a few important features}}}
$$

where

* $Z^{(i)}_{k}$ is the $k$th coordinate of the backbone activation for image $i$.
* $\tilde Z^{(i)}_{k}$ is that same coordinate rebuilt by the decoder.
* $\beta$ controls how strongly we force sparsity (fewer features).
* $q$ sets where the penalty switches from nearly linear in the reconstruction error (when $q\cdot\text{MSE}\ll 1$) to slow, logarithmic growth.
* $D_z$ is the total number of coordinates in the backbone’s activation.

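Because $\log\bigl(1 + q \times \text{MSE}\bigr)$ is concave, it grows almost linearly for small errors but only logarithmically for large ones. Leaving some coordinates poorly reconstructed is therefore relatively cheap, so the explanation network can spend its limited capacity reconstructing a **few** key coordinates of $Z$ well instead of spreading error evenly across all of them.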
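The effect of the concave penalty can be seen by comparing two reconstructions with the same total squared error: one spreads the error over every coordinate, the other concentrates it on two. This base-R sketch transcribes $L_{\mathrm{SR}}$ directly; `sparse_recon_loss()` is a hypothetical name and the data are simulated.

```r
# Sparse-reconstruction loss: (beta / D_z) * sum_k log(1 + q * MSE_k),
# where MSE_k is the batch-average squared error of coordinate k
sparse_recon_loss <- function(z, z_tilde, beta = 1, q = 1) {
  mse_k <- colMeans((z_tilde - z)^2)       # one MSE per coordinate (length D_z)
  beta / ncol(z) * sum(log1p(q * mse_k))   # log1p(x) = log(1 + x), accurate for small x
}

set.seed(10)
M <- 100; D_z <- 8
z <- matrix(rnorm(M * D_z), M, D_z)

# Same total squared error per example (2), distributed two ways
spread <- z + 0.5                                  # every coordinate off by 0.5
concentrated <- z
concentrated[, 1:2] <- concentrated[, 1:2] + 1     # two coordinates off by 1, rest exact

sparse_recon_loss(z, spread)         # log(1.25), about 0.223
sparse_recon_loss(z, concentrated)   # 2 * log(2) / 8, about 0.173
```

With $q = 1$ the concentrated error is cheaper, which is what pushes the decoder toward reconstructing only some coordinates. With a small $q$ such as 0.01 both values shrink to about 0.0025 and nearly coincide, because $\log(1+qx)\approx qx$ when $qx$ is small.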



## (2) Sparse-Reconstruction Loss $\;L_{\mathrm{SR}}$ 

The decoder $\tilde{\theta}$ is trained to reconstruct only a small subset of the 128 features (with a log‐penalty), forcing each of the five concepts to focus on a few truly important patterns (e.g., “vertical bar” or “curved loop”).

> **Sparsity** should converge toward 0 (only a few features used, making explanations very concise).


## (3) Pull-Away (Orthogonality) Loss $L_{\mathrm{PT}}$ {.smaller}

$$
L_{\mathrm{PT}}
\;=\;
\underbrace{\frac{\eta}{L(L-1)}}_{\substack{\text{strength of}\\\text{orthogonality}}}
\underbrace{\sum_{\substack{l,l'=1\\l\neq l'}}^{L}
\Bigl(\tfrac{h_{l}^{\top}h_{l'}}{\|h_{l}\|\;\|h_{l'}\|}\Bigr)^{2}}_{\substack{\text{encourages each concept}\\\text{to be distinct (non-overlapping)}}}
$$

where

* $h_{l}\in\mathbb{R}^M$ is the vector of activations for concept $l$ across the batch.
* $L$ is the total number of concepts (explanation dimensionality).
* $\eta$ controls how strongly we discourage concepts that fire together.

By penalizing squared cosine similarities between every pair of concept activations, we force each concept to occupy its own “direction” in activation space, yielding a set of distinct, non-redundant explanations.
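A direct base-R transcription of $L_{\mathrm{PT}}$, assuming the concept activations for a batch are stored as an $M\times L$ matrix whose column $l$ is $h_l$ and that no column is entirely zero. `pull_away_loss()` is a hypothetical name and the matrices are toy examples.

```r
# Pull-away term: eta / (L(L-1)) * sum over l != l' of cos^2(h_l, h_l').
# E is M x L; column l holds concept l's activations h_l across the batch.
# Assumes no all-zero column (its norm would be 0).
pull_away_loss <- function(E, eta = 1) {
  L <- ncol(E)
  H <- sweep(E, 2, sqrt(colSums(E^2)), "/")   # unit-normalize each concept column
  cos_sim <- crossprod(H)                     # L x L matrix of cosine similarities
  off_diag <- cos_sim[row(cos_sim) != col(cos_sim)]
  eta / (L * (L - 1)) * sum(off_diag^2)
}

E_orth <- cbind(c(1, 0, 0, 0), c(0, 2, 0, 0), c(0, 0, 0, 3))  # concepts never co-fire
E_dup  <- cbind(c(1, 1, 0, 0), c(2, 2, 0, 0), c(0, 0, 1, 0))  # concepts 1 and 2 co-fire

pull_away_loss(E_orth)   # 0
pull_away_loss(E_dup)    # 2 / 6, about 0.333: the pair (1, 2) has cosine 1, counted twice
```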



## (3) Pull-Away (Orthogonality) Loss $\;L_{\mathrm{PT}}$ 

A pull-away term makes sure each of the five concept activations is different from the others, so you end up with five distinct “reasons” that the CNN used to identify a digit.

> **Orthogonality** should converge toward 0 (each concept is completely distinct from the others).




## How XNN + SRAE Works {.smaller}

1. **Freeze the CNN**  
   - We keep all CNN weights fixed (no more training).  
   - The CNN now acts as a “feature extractor” that outputs a vector $\mathbf{z}\in\mathbb{R}^D$.

2. **Compress to Explanation Space**  
   - Pass $\mathbf{z}$ through a small encoder $E_\theta$ to get $\mathbf{e}\in\mathbb{R}^L$.  
   - Here, $L$ is small (e.g., 5), so $\mathbf{e}$ is easy to inspect.

3. **Reconstruct Key Features**  
   - From $\mathbf{e}$, a tiny decoder $\tilde\theta$ tries to rebuild **only** the important parts of $\mathbf{z}$.  
   - By penalizing reconstruction error with a logarithmic penalty, we force $\mathbf{e}$ to attend to just a **few coordinates** of $\mathbf{z}$.

4. **Match the CNN’s Output**  
   - Use a simple linear layer $v$ on $\mathbf{e}$ to recreate the CNN’s logits.  
   - Training adjusts $E_\theta$, $\tilde\theta$, and $v$ so that the final prediction $\hat{\mathbf{y}}_{\text{XNN}}$ is as close as possible to the original $\bar{\mathbf{y}}$.

5. **Enforce Orthogonality**  
   - While learning, we add a penalty that makes each of the $L$ concept dimensions in $\mathbf{e}$ point in a different direction (no overlap).  
   - This ensures each concept is unique—e.g., one concept might capture “vertical stroke,” another “curved loop,” etc.

6. **Result: Explainable Concepts**  
   - After training, each dimension of $\mathbf{e}$ is intended to act as a **human-interpretable “concept.”**  
   - We can visualize, for example, where in the image the “vertical-stroke” concept fires, or how strongly the “loop” concept activates.  
   - Because $\mathbf{e}$ has only a handful of dimensions, each tied to a few coordinates of $\mathbf{z}$ and pushed apart from the others, it is far easier to inspect than the original $\mathbf{z}$.
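The data flow in steps 2 to 4 can be sketched in base R. This assumes single linear (dense) layers with biases for the encoder, decoder, and projection, mirroring `create_srae()` in the appendix; the parameters are random and `z` is simulated, so the sketch shows shapes and wiring only, not trained behaviour.

```r
# One forward pass of an XNN + SRAE head with linear encoder, decoder and projection.
# Shapes follow the appendix model: D = 128 features, L = 5 concepts, C = 10 classes.
set.seed(123)
M <- 32; D <- 128; L <- 5; C <- 10

# Simulated stand-in for frozen, non-negative (ReLU) backbone features z
z <- matrix(pmax(rnorm(M * D), 0), nrow = M, ncol = D)

# Random, untrained parameters (training would fit them with the three-term loss)
W_enc <- matrix(rnorm(D * L, sd = 0.05), D, L); b_enc <- rep(0, L)   # encoder E_theta
W_dec <- matrix(rnorm(L * D, sd = 0.05), L, D); b_dec <- rep(0, D)   # decoder theta-tilde
V     <- matrix(rnorm(L * C, sd = 0.05), L, C); b_v   <- rep(0, C)   # projection v

add_bias <- function(X, b) sweep(X, 2, b, "+")   # add b to every row of X

e       <- add_bias(z %*% W_enc, b_enc)   # concept activations, M x L
z_tilde <- add_bias(e %*% W_dec, b_dec)   # reconstruction of z, M x D
y_hat   <- add_bias(e %*% V, b_v)         # explanation-network logits, M x C

dim(e)         # 32 5
dim(z_tilde)   # 32 128
dim(y_hat)     # 32 10
```

Because `y_hat` depends on `z` only through the five columns of `e`, every prediction the head makes has to be expressed in terms of those concepts.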


# From Theory to Practice! {data-background-color="#D73F09"}

## Implementing XNN + SRAE for MNIST Dataset {.smaller}

> **LeNet-5 backbone**: Its 128-dimensional feature layer is compact enough for exhaustive hyper-parameter sweeps yet still achieves ≈99 % accuracy, allowing a fully connected surrogate to produce sharper, more faithful heat-maps than a convolution-only XNN.

**Testing concept stability**: We attach a five-dimensional SRAE head to LeNet-5, freeze the CNN, and train only the explanation network with our three-term loss, verifying that high-level concepts remain consistent under small input changes.


# Results: MNIST Case Study {data-background-color="#D73F09"}

## CNN Performance

The frozen backbone (LeNet-5) reaches a test accuracy of **99.01 %**, with validation loss plateauing early—indicating stable generalization.

## Hyper-Parameter Sweep {.smaller}

#### Table 1 — Hyper-parameter Sweep

| Run | β | q | η | $F_{\text{MSE}}$ | $F_{\text{MAE}}$ | Accuracy | $Sparse_{\text{Loss}}$ | $Ortho_{\text{Loss}}$ |
|:---:|:--:|:--:|:--:|:----------------:|:----------------:|:---------:|:--------------------:|:--------------------:|
| 1   | 1.0 | 1.0 | 0.1 | 9.2875   | 2.4538   | 0.9841   | 0.7721   | 0.0028   |
| 2   | 1.0 | 1.0 | 0.5 | 9.2945   | 2.4546   | 0.9842   | 0.7721   | 0.0031   |
| 3   | 1.0 | 1.0 | 1.0 | 9.2914   | 2.4544   | 0.9846   | 0.7722   | 0.0029   |
| 4   | 1.0 | 5.0 | 1.0 | 9.2936   | 2.4541   | 0.9842   | 1.7464   | 0.0028   |
| 5   | 1.0 | 0.1 | 1.0 | 9.2865   | 2.4497   | 0.9840   | 0.1253   | 0.0029   |
| **6** | **1.0** | **0.01** | **1.0** | **9.2945** | **2.4468** | **0.9836** | **0.0136** | **0.0027** |
| 7   | 1.5 | 0.1 | 1.0 | 9.2832   | 2.4494   | 0.9838   | 0.1253   | 0.0026   |

- **Run 6** (β = 1.0, q = 0.01, η = 1.0) strikes the best balance:
  - Sparse-reconstruction loss drops to **0.0136**, two orders of magnitude lower (partly because a smaller $q$ directly shrinks $\log(1 + q\cdot\text{MSE})$)  
  - Faithfulness ($F_{\text{MSE}}$, $F_{\text{MAE}}$) and orthogonality remain virtually unchanged  

- Faithfulness errors vary little across settings, whereas sparsity depends strongly on $q$.  
- Orthogonality loss stays in the $10^{-3}$ range, indicating robust concept separation.

> **$F_{\text{MSE}}$** sums $(\bar y - \hat y)^2$ over all classes and examples and divides by $M$.
> 
> **$F_{\text{MAE}}$** sums $\lvert \bar y - \hat y\rvert$ instead, with the same $1/M$ averaging.

## Concept Specialization {.smaller}

#### Table 2 — Concept Specialization for Run 6

| Digit | Top Feature | Mean Activation | Captured Pattern                       |
|:-----:|:------------:|:---------------:|:----------------------------------------|
| 0     | X3        | 11.757          | Closed circular stroke  
| 1     | X1        | −10.213         | Pure vertical line  
| 2     | X4        | 13.674          | Bottom curve with base  
| 3     | X3        | −11.276         | Open double curve  
| 4     | X1        | −13.920         | Vertical–horizontal junction  
| 5     | X3        | −13.256         | Top bar plus lower curve  
| 6     | X4        | −10.469         | Hook-like bottom curve  
| 7     | X4        | 8.930           | Horizontal top bar  
| 8     | X5        | 12.522          | Dual-loop structure  
| 9     | X2        | 12.821          | Curved top with tail  

- A single concept dominates each digit.  
- **Sign** of activation encodes polarity:  
  - X3 fires **positively** on closed loops (digit 0) but **negatively** on open curves (digits 3, 5).  
  - X1 focuses on vertical strokes; X4 and X5 capture increasingly complex curve–line combinations.
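Table 2 can be produced by averaging each concept’s activation within each digit class and keeping the concept with the largest magnitude. This base-R sketch assumes that selection rule (Table 2 lists negative means as top features, so magnitude rather than signed value appears to be the criterion); `top_concept_per_digit()` is a hypothetical helper, and the activations are simulated, not the study’s.

```r
# For each digit, average every concept's activation over that digit's images and
# keep the concept with the largest absolute mean (assumed selection rule).
top_concept_per_digit <- function(e, labels) {
  means <- apply(e, 2, function(col) tapply(col, labels, mean))  # digits x concepts
  top <- apply(abs(means), 1, which.max)                          # column index per digit
  data.frame(digit = rownames(means),
             top_feature = colnames(e)[top],
             mean_activation = means[cbind(seq_len(nrow(means)), top)])
}

# Simulated activations: three "digits", three concepts
set.seed(5)
labels <- rep(0:2, each = 50)
e <- matrix(rnorm(150 * 3, sd = 0.5), ncol = 3,
            dimnames = list(NULL, c("X1", "X2", "X3")))
e[labels == 0, "X3"] <- e[labels == 0, "X3"] + 10   # digit 0 drives X3 positively
e[labels == 1, "X1"] <- e[labels == 1, "X1"] - 10   # digit 1 drives X1 negatively
e[labels == 2, "X2"] <- e[labels == 2, "X2"] + 8    # digit 2 drives X2 positively

top_concept_per_digit(e, labels)   # X3, X1 (negative mean), X2
```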

##  Concept Diversity {.smaller}

#### Table 3 — Diversity versus Specialization

| Feature | Primary Digit | Score   | Diversity ($\sigma/\mu$) | Interpretation                            |
|:-------:|:--------------:|:-------:|:------------------------:|:------------------------------------------|
| X1      | 4            | 13.920  | 0.504                    | Tightly tuned to vertical intersections  |
| X2      | 9            | 12.821  | 0.543                    | Targets curved-to-vertical transitions   |
| X3      | 5            | 13.256  | 0.794                    | Broadly active on multiple curved shapes |
| X4      | 2            | 13.674  | 0.426                    | Highly specific to bottom-curve motifs   |
| X5      | 8            | 12.522  | 0.658                    | Reserved for complex dual loops           |

- Diversity ratio $\sigma/\mu$ measures how evenly a concept activates across digits.  
- **Lower** values imply stronger focus on the primary digit (e.g., X4).  
- **Higher** ratio (X3) reflects versatility across multiple curved shapes.

## ExcitationBP Heat-maps {.smaller}

![Figure 10: Variants of “4” and their top-5 concept activations](images/plot_a_first_fives_4s.png){width=60%}

- Concept X3 (vertical strokes) and X5 (mid-bar) appear consistently across five “4” variants
- Concept X1 stays negative, confirming vertical lines are handled by X3.

---

## ExcitationBP Heat-maps {.smaller}

![Figure 11: Top-5 concepts for digits 2, 3, 6, 7, 8, 9](images/plot_b_digits_236789.png){width=60%}

- Concept X4 highlights bottom curves for 2 & 6
- Concept X2 fires on the hook of 9; X5 activates on dual loops of 8
- Concept X1 remains mostly dormant on curved digits.

---

## Original Study Heat-maps {.smaller}

![Figure 12: Fig 10 - Reproduced ExcitationBP from Chen et al. for digit 4 (positives and negatives)](images/studyFig10.png){width=60%}

These heat-maps are less visually crisp and show more activation away from the digit strokes. We attribute the improvement in our maps to the fully connected surrogate and stronger regularization.

---

## ExcitationBP Heat-maps {.smaller}

![Figure 13: Digit 5: ExcitationBP for X1–X5](images/plot_digit_5_first5.png){width=60%}

Digit 5 shows negative X3 on the open top curve and positive X4 on the bottom hook; X2 weakly attends to the upper hook, suggesting possible confusion with 9.

---

## ExcitationBP Heat-maps {.smaller}

![Figure 14: Average c-MWP Heat-Maps per Digit (X1–X5)](images/average_cmwp_per_digit_wide.png){height=60%}

Column-wise separation is consistent with each concept being specialized and orthogonal:

- X1 → vertical edges (1, 4)  
- X2 → hook (9)  
- X3 → closed vs open curves (0 vs 3, 5)  
- X4 → bottom curves (2, 6, 7)  
- X5 → dual loops (8)

## Discussion {.smaller}

1. When we made the “q” parameter smaller, the sparse-reconstruction loss fell sharply while faithfulness to the CNN’s outputs stayed essentially unchanged, consistent with the explanation network relying on fewer of the CNN’s internal features.
2. Each digit activated only one or two “concepts” strongly, making it clear which part of the image mattered most.
3. The concepts stayed well separated (orthogonality loss on the order of $10^{-3}$), so there was little overlap in what they detected.
4. One concept (X3) even flipped its sign: it responded positively to closed loops (like the “0”) and negatively to open curves (like the “5”), suggesting it encodes the closed-versus-open shape distinction.
5. The heat-maps are sharper and cleaner than the original study’s, suggesting that a small, simple explanation network can track the CNN’s outputs closely while still giving clear, easy-to-read visual explanations.


## Limitations

- Experiments are confined to MNIST; more complex datasets (e.g., CIFAR-10) will likely require larger concept sets ($L\approx10$–20) and more aggressive dimensionality reduction.  
- Faithfulness on convolution-only surrogates remains unexplored in our implementation.  
- The revived ExcitationBP code is outdated and lacks many quality-of-life features of modern attribution libraries such as Captum, which ship ready-made methods like Grad-CAM.


## Ethical Considerations {.smaller}

* **Transparency & Trust**
  By exposing a small set of human-readable “concepts,” stakeholders (clinicians, regulators, end users) can inspect which learned features contributed to each decision. This reduces blind faith in black‐box models and helps check that the CNN is not relying on spurious cues.

* **Bias Detection & Mitigation**
  Because each concept highlights specific patterns (e.g., “vertical bar,” “curved loop”), it becomes easier to spot and correct cases where a concept might latch onto an unintended correlate—such as background textures or digit‐style biases—before deployment.

* **Safety in High-Stakes Settings**
  In domains like medical imaging or autonomous driving, knowing why a model flagged a sample (e.g., which “cell shapes” triggered a cancer warning) is crucial. A small set of concise concepts can let experts check that the network is focusing on medically relevant features rather than noise.

* **Accountability & Compliance**
  Regulations (EU AI Act, FDA guidelines) increasingly call for transparency. A sparse, orthogonal explanation network can help meet these expectations by providing clear, auditable justifications for each prediction.

> By combining faithfulness (matching the CNN), sparsity (only a few key features), and orthogonality (distinct explanations), XNN + SRAE represents a meaningful step toward responsible AI—balancing accuracy with human‐centered transparency, bias control, and regulatory compliance.



## Further Study {.smaller}

* **Refactor visualization**  
  Modernize ExcitationBP (e.g., reimplementing c-MWP averaging with PyTorch hooks) or replace it with a contemporary alternative such as Grad-CAM++.  
* **Scale-up experiments**  
  Apply XNN + SRAE to Fashion-MNIST or CIFAR-10, increasing $L$ and benchmarking run-time versus faithfulness.  
* **Automate concept selection**  
  Incorporate sparsity-controlled pruning or Bayesian non-parametrics so that $L$ is chosen adaptively rather than manually.  
* **Deploy in real time**  
  Optimize the decoder and attribution pipeline for GPU inference so that concept heat-maps accompany live predictions in practical applications.



## References {.smaller}

1. Li, F., Qi, Z., Khorram, S., *et al.* (2021). *From heatmaps to structured explanations of image classifiers*. arXiv preprint arXiv:2109.06365.  
2. Zhang, Q., Wu, Y. N., & Zhu, S.-C. (2018). *Interpretable convolutional neural networks*. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR).  
3. Doshi-Velez, F., & Kim, B. (2017). *Towards a Rigorous Science of Interpretable Machine Learning*. arXiv preprint arXiv:1702.08608.  
4. Adebayo, J., Gilmer, J., Muelly, M., Goodfellow, I., Hardt, M., & Kim, B. (2018). *Sanity Checks for Saliency Maps*. Advances in Neural Information Processing Systems 31.  
5. LeCun, Y., Bengio, Y., & Hinton, G. (2015). *Deep Learning*. *Nature*, 521(7553), 436-444.  
6. Chattopadhay, A., Sarkar, A., Howlader, P., & Balasubramanian, V. N. (2018). *Grad-CAM++: Generalized Gradient-based Visual Explanations for Deep Convolutional Networks*. Proceedings of the IEEE Winter Conference on Applications of Computer Vision (WACV), 839-847.  
7. Kokhlikyan, N., Miglani, V., Martin, M., Wang, E., Alsallakh, B., Reynolds, J., Melnikov, A., Kliushkina, N., Araya, C., Yan, S., & Reblitz-Richardson, O. (2020). *Captum: A Unified and Generic Model Interpretability Library for PyTorch*. arXiv preprint arXiv:2009.07896.  
8. Greydanus, S. (2019). *excitationbp: Visualizing how deep networks make decisions* (Version 0.1) [Computer software]. GitHub. https://github.com/greydanus/excitationbp  
9. Allaire, J., & Tang, Y. (2024). *tensorflow: R Interface to 'TensorFlow'*. R package version 2.16.0. https://CRAN.R-project.org/package=tensorflow  
10. Allaire, J., & Chollet, F. (2024). *keras: R Interface to 'Keras'*. R package version 2.15.0. https://CRAN.R-project.org/package=keras  
11. Ushey, K., Allaire, J., & Tang, Y. (2024). *reticulate: Interface to 'Python'*. R package version 1.38.0. https://CRAN.R-project.org/package=reticulate  
12. Wickham, H. (2016). *ggplot2: Elegant Graphics for Data Analysis*. Springer-Verlag New York. ISBN 978-3-319-24277-4. https://ggplot2.tidyverse.org  
13. Wickham, H., François, R., Henry, L., Müller, K., & Vaughan, D. (2023). *dplyr: A Grammar of Data Manipulation*. R package version 1.1.4. https://CRAN.R-project.org/package=dplyr  
14. Auguie, B. (2017). *gridExtra: Miscellaneous Functions for "Grid" Graphics*. R package version 2.3. https://CRAN.R-project.org/package=gridExtra  


# Questions? {data-background-color="#D73F09" .smaller}



# Appendix


## Source code (R)

The complete R implementation for training and evaluating the CNN+XNN pipeline is available in the supplementary materials. Key components include:

- `create_base_cnn()`: LeNet-5 architecture definition
- `create_srae()`: SRAE explanation head construction  
- `srae_loss()`: Composite loss function implementation
- `train_srae()`: Custom training loop with three-loss optimization
- `analyze_digit_features()`: Concept activation analysis
- `evaluate_faithfulness()`: Faithfulness metric computation

```markdown
---
title: "XNN_SRAE_Implementation"
author: "Brian Cervantes Alvarez"
date: "06-02-2025"
format: html
---
```

### 1. Setup and Dependencies

```r
# Install any missing packages; purrr is needed by the custom training loop (Section 6)
pkgs <- c("tensorflow", "keras", "reticulate", "ggplot2", "dplyr", "gridExtra", "purrr")
missing_pkgs <- pkgs[!vapply(pkgs, requireNamespace, logical(1), quietly = TRUE)]
if (length(missing_pkgs) > 0) install.packages(missing_pkgs)

library(tensorflow)
library(keras)
library(reticulate)
library(ggplot2)
library(dplyr)
library(gridExtra)

# Install the Python TensorFlow backend (run once)
# install_tensorflow()
```

### 2. Load and Prepare MNIST Data

```r
# Load MNIST dataset
mnist <- dataset_mnist()
x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y

# Normalize to [0,1]
x_train <- x_train / 255
x_test <- x_test / 255

# Reshape for CNN
x_train <- array_reshape(x_train, c(nrow(x_train), 28, 28, 1))
x_test <- array_reshape(x_test, c(nrow(x_test), 28, 28, 1))

# Convert labels to categorical
y_train_cat <- to_categorical(y_train, 10)
y_test_cat <- to_categorical(y_test, 10)

cat("Training data shape:", dim(x_train), "\n")
cat("Test data shape:", dim(x_test), "\n")
```

### 3. Create and Train Base CNN

```r
create_base_cnn <- function() {
  model <- keras_model_sequential() %>%
    layer_conv_2d(filters = 32, kernel_size = c(3, 3), activation = 'relu',
                  input_shape = c(28, 28, 1)) %>%
    layer_max_pooling_2d(pool_size = c(2, 2)) %>%
    layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = 'relu') %>%
    layer_max_pooling_2d(pool_size = c(2, 2)) %>%
    layer_flatten() %>%
    layer_dense(units = 128, activation = 'relu', name = 'feature_layer') %>%
    layer_dense(units = 10, activation = 'softmax', name = 'output_layer')
  
  return(model)
}

# Create and compile base CNN
base_cnn <- create_base_cnn()
base_cnn %>% compile(
  optimizer = 'adam',
  loss = 'categorical_crossentropy',
  metrics = c('accuracy')
)

# Train the base CNN
cat("Training base CNN...\n")
history <- base_cnn %>% fit(
  x_train, y_train_cat,
  epochs = 10,
  batch_size = 128,
  validation_data = list(x_test, y_test_cat),
  verbose = 1
)


# evaluate() returns named loss/metric values; [[ ]] extracts a single number
base_accuracy <- base_cnn %>% evaluate(x_test, y_test_cat, verbose = 0)
cat("Base CNN accuracy:", base_accuracy[["accuracy"]], "\n")
```

### 4. Extract Intermediate Features (Z) and Predictions (ŷ)

```r
# Create feature extractor (Z layer)
z_extractor <- keras_model(
  inputs = base_cnn$input,
  outputs = get_layer(base_cnn, 'feature_layer')$output
)

# Create logits extractor (before softmax)
logits_model <- keras_model_sequential() %>%
  layer_conv_2d(filters = 32, kernel_size = c(3, 3), activation = 'relu',
                input_shape = c(28, 28, 1)) %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = 'relu') %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_flatten() %>%
  layer_dense(units = 128, activation = 'relu') %>%
  layer_dense(units = 10, activation = 'linear')  # No softmax for logits

# Copy the trained weights layer by layer (both models have identical layer order)
for (i in seq_along(base_cnn$layers)) {
  w <- get_weights(base_cnn$layers[[i]])
  if (length(w) > 0) set_weights(logits_model$layers[[i]], w)
}

# Extract features Z and backbone logits (written y-bar on the slides, y_hat_* in this code)
cat("Extracting features...\n")
z_train <- predict(z_extractor, x_train)
z_test <- predict(z_extractor, x_test)
y_hat_train <- predict(logits_model, x_train)
y_hat_test <- predict(logits_model, x_test)

cat("Z features shape:", dim(z_train), "\n")
cat("Y hat shape:", dim(y_hat_train), "\n")
```

### 5. Implement SRAE Model and Loss Function

```r
# Composite SRAE loss: faithfulness + beta * sparse reconstruction + eta * pull-away.
# z_true: frozen-CNN features Z; y_hat_true: frozen-CNN logits (y-bar on the slides);
# e, z_reconstructed, y_pred: the three outputs of the SRAE model.
srae_loss <- function(z_true, y_hat_true, e, z_reconstructed, y_pred, 
                     beta = 1.0, q = 1.0, eta = 1.0) {
  
  # Term 1: faithfulness loss. reduce_mean averages over all M x C entries,
  # i.e., the slide's (1/M) * sum over i and c, divided by C
  faithfulness_loss <- tf$reduce_mean(tf$square(y_pred - y_hat_true))
  
  # Term 2: sparse reconstruction loss with log penalty:
  # per-coordinate MSE over the batch, then the mean of log(1 + q * MSE_k) over D_z
  reconstruction_errors <- tf$square(z_reconstructed - z_true)
  reconstruction_errors_per_dim <- tf$reduce_mean(reconstruction_errors, axis = 0L)
  sparse_reconstruction_loss <- tf$reduce_mean(
    tf$math$log(1.0 + q * reconstruction_errors_per_dim)
  )
  
  # Term 3: pull-away (orthogonality) loss on cosine similarities between concepts
  e_t <- tf$transpose(e)                               # (L, M): one row per concept
  e_feat_norm <- tf$nn$l2_normalize(e_t, axis = 1L)    # unit norm across the batch
  corr_feats <- tf$matmul(e_feat_norm, e_feat_norm, transpose_b = TRUE)  # (L, L)
  
  # Zero the diagonal (self-similarity); unlike a mask built from the global
  # num_xfeatures, this works for any number of concepts
  off_diag <- tf$linalg$set_diag(corr_feats,
                                 tf$zeros_like(tf$linalg$diag_part(corr_feats)))
  
  # Mean over all L^2 entries = (L - 1) / L times the slide's 1 / (L(L - 1)) scaling
  orthogonality_loss <- tf$reduce_mean(tf$square(off_diag))
  
  total_loss <- faithfulness_loss + beta * sparse_reconstruction_loss + eta * orthogonality_loss
  
  return(list(
    total = total_loss,
    faithfulness = faithfulness_loss,
    sparse_reconstruction = sparse_reconstruction_loss,
    orthogonality = orthogonality_loss
  ))
}

# ---- 5.5 Build the SRAE network -------------------------------------------
# z_train has shape (n_samples, 128) because the CNN’s 'feature_layer' has 128 units
input_dim      <- ncol(z_train)      # 128
num_xfeatures  <- 5                  # paper default (can be tuned)
num_classes    <- 10                 # logits for MNIST digits 0‑9

create_srae <- function(input_dim, num_xfeatures = 5, num_classes = 10) {
  
  z_input <- layer_input(shape = input_dim,  name = "z_input")
  
  # e  : low‑dimensional explanation (X‑features)
  e_output <- z_input %>% 
    layer_dense(units = num_xfeatures,
                activation = "linear",
                name = "explanation")
  
  # ŷ  : logit prediction derived **only** from the X‑features
  y_pred <- e_output %>% 
    layer_dense(units = num_classes,
                activation = "linear",
                name = "prediction")
  
  # ẑ  : reconstruction of the original latent vector
  z_recon <- e_output %>% 
    layer_dense(units = input_dim,
                activation = "linear",
                name = "z_reconstructed")
  
  keras_model(inputs  = z_input,
              outputs = list(e_output, z_recon, y_pred))
}

srae_model <- create_srae(input_dim, num_xfeatures, num_classes)

```
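To make the three SRAE loss terms concrete outside TensorFlow, here is a minimal NumPy sketch of the same linear network (z → e → ŷ and ẑ) and its loss. The first part of `srae_loss()` is not shown in this section, so two assumptions are labelled in the code: the faithfulness term is taken to be the logit MSE (the same form used for `val_faith_loss` during training), and `reconstruction_errors` is taken to be the element-wise squared error. The function names are illustrative, not part of the R code.

```python
import numpy as np

def srae_forward(z, W_e, b_e, W_y, b_y, W_r, b_r):
    """Linear SRAE: z -> e (X-features) -> logits y_pred and reconstruction z_hat."""
    e = z @ W_e + b_e        # (batch, k) explanation / X-features
    y_pred = e @ W_y + b_y   # (batch, n_classes) logits computed from e only
    z_hat = e @ W_r + b_r    # (batch, d) reconstruction of the latent vector z
    return e, y_pred, z_hat

def orthogonality_penalty(e):
    """Mean squared off-diagonal correlation between X-feature columns."""
    cols = e.T                                              # (k, batch)
    sq_norm = np.sum(cols ** 2, axis=1, keepdims=True)
    cols = cols / np.sqrt(np.maximum(sq_norm, 1e-12))       # like tf.nn.l2_normalize
    corr = cols @ cols.T                                    # (k, k)
    off_diag = corr * (1.0 - np.eye(e.shape[1]))            # drop the diagonal
    return np.mean(off_diag ** 2)                           # mean over all k*k cells

def srae_loss(z, y_hat, e, y_pred, z_hat, beta=1.0, q=1.0, eta=1.0):
    faith = np.mean((y_pred - y_hat) ** 2)                  # ASSUMED: logit MSE
    err_per_dim = np.mean((z - z_hat) ** 2, axis=0)         # ASSUMED: squared error
    sparse = np.mean(np.log1p(q * err_per_dim))             # log(1 + q * err)
    ortho = orthogonality_penalty(e)
    return faith + beta * sparse + eta * ortho, faith, sparse, ortho

# Toy run with the sizes used above: d = 128 latent units, k = 5, 10 classes
rng = np.random.default_rng(0)
d, k, n_classes, n = 128, 5, 10, 32
shapes = [(d, k), (k,), (k, n_classes), (n_classes,), (k, d), (d,)]
params = [rng.normal(scale=0.1, size=s) for s in shapes]
z = rng.normal(size=(n, d))
y_hat = rng.normal(size=(n, n_classes))
e, y_pred, z_hat = srae_forward(z, *params)
total, faith, sparse, ortho = srae_loss(z, y_hat, e, y_pred, z_hat)
print(f"total={total:.4f} faith={faith:.4f} sparse={sparse:.4f} ortho={ortho:.4f}")
```

The `log(1 + q * err)` form grows slowly for large per-dimension errors. This lets a few latent dimensions stay poorly reconstructed, which is the "sparse" part of the reconstruction term. The orthogonality penalty is zero when the X-feature columns are mutually orthogonal across the batch.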

### 6. Custom Training Loop for SRAE

```r
# Custom training function for SRAE
train_srae <- function(model, z_train, y_hat_train, z_test, y_hat_test,
                      epochs = 100, batch_size = 256, 
                      beta = 1.0, q = 1.0, eta = 1.0) {
  
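  # Legacy Adam (Keras 2, TF >= 2.11): TensorFlow warns that the new optimizer
  # can be slow on some platforms (e.g. Apple silicon) and points to this class.
  # `legacy` is not supported under Keras 3; use tf$keras$optimizers$Adam there.
  optimizer <- tf$keras$optimizers$legacy$Adam(learning_rate = 0.001)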
  
  # Convert data to TensorFlow tensors
  z_train_tensor <- tf$constant(z_train, dtype = tf$float32)
  y_hat_train_tensor <- tf$constant(y_hat_train, dtype = tf$float32)
  z_test_tensor <- tf$constant(z_test, dtype = tf$float32)
  y_hat_test_tensor <- tf$constant(y_hat_test, dtype = tf$float32)
  
  # Training history
  history <- list(
    epoch = c(),
    total_loss = c(),
    faithfulness_loss = c(),
    sparse_loss = c(),
    orthogonality_loss = c(),
    val_faithfulness = c()
  )
  
  n_samples <- nrow(z_train)
  n_batches <- ceiling(n_samples / batch_size)
  
  cat("Starting SRAE training...\n")
  
  # Training step function
  train_step <- function(z_batch, y_hat_batch) {
    with(tf$GradientTape() %as% tape, {
      # Forward pass
      outputs <- model(z_batch, training = TRUE)
      e <- outputs[[1]]  # explanation
      z_reconstructed <- outputs[[2]]  # reconstruction  
      y_pred <- outputs[[3]]  # prediction
      
      # Compute loss
      losses <- srae_loss(z_batch, y_hat_batch, 
                         e, z_reconstructed, y_pred,
                         beta = beta, q = q, eta = eta)
      total_loss <- losses$total
    })
    
    # Compute and apply gradients
    gradients <- tape$gradient(total_loss, model$trainable_variables)
    optimizer$apply_gradients(purrr::transpose(list(gradients, model$trainable_variables)))
    
    return(losses)
  }
  
  for (epoch in 1:epochs) {
    # Shuffle training data
    indices <- sample(1:n_samples)
    
    epoch_losses <- list(total = 0, faith = 0, sparse = 0, ortho = 0)
    
    for (batch in 1:n_batches) {
      start_idx <- (batch - 1) * batch_size + 1
      end_idx <- min(batch * batch_size, n_samples)
      batch_indices <- indices[start_idx:end_idx]
      
      # Get batch data
      z_batch <- tf$gather(z_train_tensor, as.integer(batch_indices - 1))
      y_hat_batch <- tf$gather(y_hat_train_tensor, as.integer(batch_indices - 1))
      
      # Training step
      batch_losses <- train_step(z_batch, y_hat_batch)
      
      # Accumulate losses
      epoch_losses$total <- epoch_losses$total + as.numeric(batch_losses$total)
      epoch_losses$faith <- epoch_losses$faith + as.numeric(batch_losses$faithfulness)
      epoch_losses$sparse <- epoch_losses$sparse + as.numeric(batch_losses$sparse_reconstruction)
      epoch_losses$ortho <- epoch_losses$ortho + as.numeric(batch_losses$orthogonality)
    }
    
    # Average losses over batches
    epoch_losses <- lapply(epoch_losses, function(x) x / n_batches)
    
    # Validation
    val_outputs <- model(z_test_tensor, training = FALSE)
    val_faith_loss <- tf$reduce_mean(tf$square(val_outputs[[3]] - y_hat_test_tensor))
    
    # Store history
    history$epoch <- c(history$epoch, epoch)
    history$total_loss <- c(history$total_loss, epoch_losses$total)
    history$faithfulness_loss <- c(history$faithfulness_loss, epoch_losses$faith)
    history$sparse_loss <- c(history$sparse_loss, epoch_losses$sparse)
    history$orthogonality_loss <- c(history$orthogonality_loss, epoch_losses$ortho)
    history$val_faithfulness <- c(history$val_faithfulness, as.numeric(val_faith_loss))
    
    # Print progress
    if (epoch %% 10 == 0) {
      cat(sprintf("Epoch %d/%d - Total Loss: %.4f, Faith: %.4f, Sparse: %.4f, Ortho: %.4f, Val Faith: %.4f\n",
                  epoch, epochs, epoch_losses$total, epoch_losses$faith, 
                  epoch_losses$sparse, epoch_losses$ortho, as.numeric(val_faith_loss)))
    }
  }
  
  return(history)
}

# Train the SRAE on the CNN latent vectors (z) and the CNN logits (y_hat)
training_history <- train_srae(
  srae_model, z_train, y_hat_train, z_test, y_hat_test,
  epochs = 50, batch_size = 256,
  beta = 1.0, q = 1.0, eta = 1
)
```
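`train_srae()` shuffles the row indices once per epoch, slices them into `ceiling(n / batch_size)` batches, and shifts them to 0-based indices for `tf$gather`. It then divides the summed per-batch losses by `n_batches`. The NumPy sketch below mirrors that schedule. The helper names are hypothetical, and the sample-weighted mean is an optional alternative that the R code does not use.

```python
import math
import numpy as np

def epoch_batches(n_samples, batch_size, rng):
    """Yield 0-based index arrays for one shuffled epoch, as in train_srae()."""
    perm = rng.permutation(n_samples)             # like sample(1:n_samples) - 1
    n_batches = math.ceil(n_samples / batch_size)
    for b in range(n_batches):
        start = b * batch_size
        end = min((b + 1) * batch_size, n_samples)
        yield perm[start:end]                     # the last batch may be smaller

def weighted_epoch_mean(batch_losses, batch_sizes):
    """Epoch loss weighted by batch size (the R loop weights every batch equally)."""
    losses = np.asarray(batch_losses, dtype=float)
    sizes = np.asarray(batch_sizes, dtype=float)
    return float(np.sum(losses * sizes) / np.sum(sizes))

rng = np.random.default_rng(0)
batches = list(epoch_batches(1000, 256, rng))
print([len(b) for b in batches])
```

For example, 60,000 rows with `batch_size = 256` give 235 batches, and the last one holds only 96 rows. An equal-weight average therefore over-weights that final batch slightly. The effect is small, but it explains why the epoch loss can differ a little from a loss computed over the whole training set.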

### 7. Analyze X-Features for Different Digits

```r
# Function to analyze X-features for specific digits
analyze_digit_features <- function(model, digit, x_data, y_data, z_data, n_samples = 10) {
  
  # Get indices for specific digit
  digit_indices <- which(y_data == digit)
  sample_indices <- sample(digit_indices, min(n_samples, length(digit_indices)))
  
  # Extract features for samples
  z_samples <- z_data[sample_indices, , drop = FALSE]
  
  # Get X-features
  outputs <- model(k_constant(z_samples))
  x_features <- as.array(outputs[[1]])  # [[1]] is the 'explanation' output
  
  # Compute mean and std for each X-feature
  feature_stats <- data.frame(
    digit = digit,
    x_feature = paste0("X", seq_len(ncol(x_features))),
    mean_activation = colMeans(x_features),
    std_activation = apply(x_features, 2, sd),
    stringsAsFactors = FALSE
  )
  
  return(list(
    stats = feature_stats,
    individual_features = x_features,
    sample_indices = sample_indices
  ))
}

# Analyze all digits
cat("Analyzing X-features for all digits...\n")
all_digit_analysis <- list()

for (digit in 0:9) {
  analysis <- analyze_digit_features(srae_model, digit, x_test, y_test, z_test)
  all_digit_analysis[[as.character(digit)]] <- analysis
  
  cat(sprintf("\nDigit %d - Mean X-feature activations:\n", digit))
  print(round(analysis$stats$mean_activation, 4))
}

# Combine results for visualization
feature_summary <- do.call(rbind, lapply(all_digit_analysis, function(x) x$stats))
```
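`analyze_digit_features()` summarises each X-feature by its mean and standard deviation over a random sample of 10 images per digit. With so few images, the per-digit means can shift noticeably between runs. The NumPy sketch below computes the same table over all images of each digit. `digit_feature_stats` is a hypothetical helper, and `ddof=1` matches R's `sd()`.

```python
import numpy as np

def digit_feature_stats(x_feats, labels, n_classes=10):
    """Per-digit mean and sample SD of each X-feature (rows = digits)."""
    k = x_feats.shape[1]
    means = np.full((n_classes, k), np.nan)
    sds = np.full((n_classes, k), np.nan)
    for d in range(n_classes):
        rows = x_feats[labels == d]               # all samples of digit d
        if len(rows) > 0:
            means[d] = rows.mean(axis=0)
        if len(rows) > 1:
            sds[d] = rows.std(axis=0, ddof=1)     # ddof=1 -> same as R's sd()
    return means, sds

x = np.array([[1.0, 2.0], [3.0, 4.0], [10.0, 10.0]])
labels = np.array([0, 0, 1])
means, sds = digit_feature_stats(x, labels, n_classes=2)
print(means)
print(sds)
```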


### 8. Visualize X-Feature Activation Patterns

```r
# Create X-feature activation heatmaps and bar plots
plot_xfeature_patterns <- function(feature_summary) {
  
  # Heatmap of X-features by digit
  p1 <- ggplot(feature_summary, aes(x = x_feature, y = factor(digit), fill = mean_activation)) +
    geom_tile(color = "white", linewidth = 0.5) +  # `linewidth` replaces `size` (ggplot2 >= 3.4)
    scale_fill_gradient2(low = "blue", mid = "white", high = "red", 
                        midpoint = 0, name = "Activation") +
    labs(title = "X-Feature Activation Heatmap by Digit",
         x = "X-Feature", y = "Digit") +
    theme_minimal() +
    theme(
      axis.text.x = element_text(angle = 45, hjust = 1, size = 12),
      axis.text.y = element_text(size = 12),
      plot.title = element_text(size = 14, hjust = 0.5),
      legend.title = element_text(size = 12)
    )
  
  # Bar plot for each digit
  p2 <- ggplot(feature_summary, aes(x = x_feature, y = mean_activation, fill = x_feature)) +
    geom_col() +
    facet_wrap(~paste("Digit", digit), scales = "free_y", ncol = 5) +
    scale_fill_brewer(type = "qual", palette = "Set3") +
    labs(title = "X-Feature Activations by Digit",
         x = "X-Feature", y = "Mean Activation") +
    theme_minimal() +
    theme(
      axis.text.x = element_text(angle = 45, hjust = 1),
      legend.position = "none",
      strip.text = element_text(size = 10)
    )
  
  return(list(heatmap = p1, barplot = p2))
}

# Create and display the basic plots
plots <- plot_xfeature_patterns(feature_summary)

cat("Displaying X-Feature Heatmap...\n")
print(plots$heatmap)

cat("Displaying X-Feature Bar Plots...\n")
print(plots$barplot)
```
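The heatmap uses `scale_fill_gradient2(midpoint = 0)`, so white marks an activation of exactly 0. To our understanding, ggplot2 maps values to the colour ramp with `scales::rescale_mid()`, and the sketch below approximates that mapping in NumPy. It assumes the rescaling range is the range of the data, and the function is illustrative rather than a port of the library. It shows that saturation depends on the largest distance from the midpoint. If every activation is positive, only the white-to-red half of the palette is used.

```python
import numpy as np

def rescale_mid(x, mid=0.0):
    """Map values to [0, 1] so that `mid` lands at 0.5 (diverging colour scale)."""
    x = np.asarray(x, dtype=float)
    lo, hi = x.min(), x.max()
    if lo == hi:                                   # constant data: midpoint colour
        return np.full_like(x, 0.5)
    extent = 2.0 * max(abs(lo - mid), abs(hi - mid))
    return (x - mid) / extent + 0.5

print(rescale_mid([-2.0, 0.0, 1.0]))   # negative end reaches 0, zero maps to 0.5
print(rescale_mid([1.0, 2.0]))         # all-positive data stays above 0.5
```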

### 9. Enhanced Heatmap Analysis

```r
# Create enhanced heatmaps with clustering
library(reshape2)
library(viridis)
library(pheatmap)

# Create matrix for heatmap visualization
feature_matrix <- feature_summary %>%
  dplyr::select(digit, x_feature, mean_activation) %>%  # avoid MASS::select clashes
  reshape2::dcast(digit ~ x_feature, value.var = "mean_activation")

rownames(feature_matrix) <- paste("Digit", feature_matrix$digit)
feature_matrix$digit <- NULL

# Enhanced heatmap with clustering (pheatmap is attached above)
cat("Creating enhanced heatmap with clustering...\n")
pheatmap(
  as.matrix(feature_matrix),
  main = "X-Feature Activation Heatmap (Clustered)",
  color = viridis::viridis(100),
  cluster_rows = TRUE,
  cluster_cols = TRUE,
  display_numbers = TRUE,
  number_format = "%.3f",
  fontsize_number = 8,
  cellwidth = 40,
  cellheight = 40
)

# Digit similarity: Pearson correlation between the digits' X-feature profiles
digit_correlation <- cor(t(as.matrix(feature_matrix)))
pheatmap(
  digit_correlation,
  main = "Digit Similarity Based on X-Features",
  color = colorRampPalette(c("blue", "white", "red"))(100),
  breaks = seq(-1, 1, length.out = 101),  # keep white at r = 0
  display_numbers = TRUE,
  number_format = "%.2f",
  fontsize_number = 8,
  cellwidth = 30,
  cellheight = 30
)
```
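The block above pivots the long `feature_summary` table into a digits × X-features matrix with `dcast`, then correlates rows with `cor(t(...))`. A NumPy sketch of the same two steps follows. The helper names are hypothetical, and `np.corrcoef` treats rows as variables by default, which matches `cor(t(m))` in R. Each digit-to-digit correlation is computed from only five values (one per X-feature), so these similarities are coarse. With two features, every correlation is exactly ±1.

```python
import numpy as np

def long_to_wide(digits, features, values):
    """Pivot (digit, x_feature, value) triples into a digits x features matrix."""
    row_keys = sorted(set(digits))
    col_keys = sorted(set(features))              # "X1" < "X2" < ... for k <= 9
    r = {key: i for i, key in enumerate(row_keys)}
    c = {key: j for j, key in enumerate(col_keys)}
    mat = np.full((len(row_keys), len(col_keys)), np.nan)
    for d, f, v in zip(digits, features, values):
        mat[r[d], c[f]] = v
    return row_keys, col_keys, mat

def digit_similarity(mat):
    """Correlation between rows (digits), like cor(t(mat)) in R."""
    return np.corrcoef(mat)

rows, cols, mat = long_to_wide([0, 0, 1, 1], ["X1", "X2", "X1", "X2"],
                               [1.0, 2.0, 2.0, 4.0])
print(rows, cols)
print(digit_similarity(mat))
```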

### 10. Paper-Style X-Feature Visualization

```r
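# GRADIENT x INPUT IN Z-SPACE, PROJECTED ONTO THE IMAGE (BETA)
# An ExcitationBP-style approximation; it does not implement ExcitationBP's
# probabilistic propagation rule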
library(tensorflow)

create_excitationbp_viz <- function(model, x_samples, y_samples, z_samples) {
  
  # Get X-features from the SRAE model (used below to pick examples and scale maps)
  outputs <- model(k_constant(z_samples))
  x_features_all <- as.array(outputs[[1]])  # [[1]] is the 'explanation' output
  
  # Select digits to visualize
  digits_to_show <- c(2, 3, 6, 7, 8, 9)
  
  # Select one example per digit
  selected_examples <- list()
  for (digit in digits_to_show) {
    digit_indices <- which(y_samples == digit)
    if (length(digit_indices) > 0) {
      idx <- sample(digit_indices, 1)
      selected_examples[[as.character(digit)]] <- list(
        image = x_samples[idx, , , 1],
        x_features = x_features_all[idx, ],
        z_features = z_samples[idx, ],
        digit = digit,
        idx = idx
      )
    }
  }
  
  # Gradient x input restricted to positive gradients (an ExcitationBP-style
  # approximation, not ExcitationBP's probabilistic propagation rule)
  excitation_bp <- function(z_input, x_feature_idx) {
    
    # Convert the Z input to a (1, n_z) tensor
    z_tensor <- tf$constant(matrix(z_input, nrow = 1), dtype = tf$float32)
    
    # Forward pass through the SRAE to get the X-feature activation
    with(tf$GradientTape() %as% tape, {
      tape$watch(z_tensor)
      
      model_outputs <- model(z_tensor)
      x_features <- model_outputs[[1]]  # explanation features
      
      # Target: activation of the requested X-feature
      target_activation <- x_features[1, x_feature_idx]
    })
    
    # d(target) / d(z). The 'explanation' layer is linear, so this equals the
    # layer's weight column for this X-feature and is identical for every input
    gradients <- tape$gradient(target_activation, z_tensor)
    gradients_array <- as.array(gradients)
    
    # Keep only positive contributions (parts of Z that excite the X-feature)
    positive_gradients <- pmax(0, gradients_array)
    
    # Weight by the Z values themselves (input * gradient)
    excitation_scores <- z_input * positive_gradients
    
    return(list(
      gradients = gradients_array,
      positive_gradients = positive_gradients,
      excitation_scores = excitation_scores,
      activation = as.numeric(target_activation)
    ))
  }
  
  # Convert Z-space excitation back to image space visualization
  z_to_image_excitation <- function(z_excitation, original_image) {
    
    # Illustrative placement only: 'feature_layer' is a dense layer, so its
    # units have no spatial location. Z scores are tiled onto the image in
    # index order purely to make them visible; do not read this as localisation.
    
    excitation_image <- matrix(0, nrow = 28, ncol = 28)
    n_z_features <- length(z_excitation)
    
    # Block layout: a 7 x 7 grid of 4 x 4 blocks needs 49 Z features
    if (n_z_features >= 49) {
      
      # Rows/cols start at 1, 5, ..., 25, so only the first 49 Z features
      # are shown; the remaining ones are ignored
      z_idx <- 1
      for (i in seq(1, 25, by = 4)) {
        for (j in seq(1, 25, by = 4)) {
          if (z_idx <= n_z_features) {
            # Map Z feature excitation to 4x4 image region
            excitation_image[i:(i+3), j:(j+3)] <- abs(z_excitation[z_idx])
            z_idx <- z_idx + 1
          }
        }
      }
      
    } else {
      # Fallback: distribute Z features across image
      for (z_idx in 1:min(n_z_features, 784)) {
        i <- ((z_idx - 1) %% 28) + 1
        j <- ((z_idx - 1) %/% 28) + 1
        if (i <= 28 && j <= 28) {
          excitation_image[i, j] <- abs(z_excitation[z_idx])
        }
      }
    }
    
    # Enhance excitation map with original image structure
    # Only show excitation where original image has content
    enhanced_excitation <- excitation_image * (original_image > 0.1)
    
    return(enhanced_excitation)
  }
  
  # Generate ExcitationBP visualizations
  cat("Generating ExcitationBP visualizations for X-features...\n")
  
  # Create visualization (previous graphics settings are restored on exit)
  op <- par(mfrow = c(length(digits_to_show), 6),
            mar = c(0.1, 0.1, 2, 0.1),
            bg = "black")
  on.exit(par(op), add = TRUE)
  
  for (digit in digits_to_show) {
    if (!is.null(selected_examples[[as.character(digit)]])) {
      
      img <- selected_examples[[as.character(digit)]]$image
      x_vals <- selected_examples[[as.character(digit)]]$x_features
      z_vals <- selected_examples[[as.character(digit)]]$z_features
      
      # Column 1: Original digit
      image(t(img[28:1, ]), 
            col = gray.colors(256, start = 0, end = 1),
            axes = FALSE,
            main = "Original Image",
            col.main = "white",
            cex.main = 0.8)
      
      # Columns 2-6: ExcitationBP for each X-feature
      for (x_idx in 1:5) {
        x_val <- x_vals[x_idx]
        
        # Compute the score map for this X-feature and this specific input.
        # tryCatch() returns the pattern, so the error fallback is really used
        # (an assignment inside the handler function would stay local to it)
        final_pattern <- tryCatch({
          excitation_result <- excitation_bp(z_vals, x_idx)
          
          # Convert Z-space scores to an image-space visualization
          excitation_image <- z_to_image_excitation(
            excitation_result$excitation_scores, 
            img
          )
          
          # Scale by this input's activation relative to the largest |activation|
          act_scale <- abs(x_val) / max(abs(x_features_all[, x_idx]))
          if (x_val >= 0) {
            # Positive activation: show excitation as bright regions
            pattern <- excitation_image * act_scale
          } else {
            # Negative activation: show as inverted excitation
            pattern <- 1 - excitation_image * act_scale
          }
          
          # Min-max normalize for display (flat maps become mid-grey)
          if (max(pattern) > min(pattern)) {
            pattern <- (pattern - min(pattern)) / (max(pattern) - min(pattern))
          } else {
            pattern <- matrix(0.5, nrow = 28, ncol = 28)
          }
          pattern
          
        }, error = function(e) {
          cat(paste("ExcitationBP error for X-feature", x_idx, ":", e$message, "\n"))
          matrix(0.5, nrow = 28, ncol = 28)  # fallback pattern
        })
        
        # Display ExcitationBP result
        image(t(final_pattern[28:1, ]), 
              col = gray.colors(256, start = 0, end = 1),
              axes = FALSE,
              main = paste0("X", x_idx, ": ", sprintf("%.4f", x_val)),
              col.main = "white",
              cex.main = 0.7)
      }
    }
  }
  
  # Reset background
  par(bg = "white")
  
  return(list(examples = selected_examples))
}

# Simpler, correlation-based alternative
create_correlation_excitationbp <- function(model, x_samples, y_samples, z_samples) {
  
  # Correlate each X-feature with every input pixel across a random subset of
  # the data. This is a data-driven stand-in, not ExcitationBP: one map per
  # X-feature is shared by all inputs and only reweighted per image below
  
  outputs <- model(k_constant(z_samples))
  x_features_all <- as.array(outputs[[1]])
  
  # Compute one correlation map per X-feature on a single shared random subset
  # (at most 2000 samples), so the five maps are directly comparable
  cat("Computing X-feature correlation maps (correlation-based stand-in)...\n")
  correlation_maps <- list()
  n_samples <- min(2000, nrow(x_samples))
  sample_indices <- sample(seq_len(nrow(x_samples)), n_samples)
  
  for (x_idx in 1:5) {
    correlation_map <- matrix(0, nrow = 28, ncol = 28)
    x_activations <- x_features_all[, x_idx]
    
    for (i in 1:28) {
      for (j in 1:28) {
        pixel_values <- x_samples[sample_indices, i, j, 1]
        feature_values <- x_activations[sample_indices]
        
        if (sd(pixel_values) > 0 && sd(feature_values) > 0) {
          correlation_map[i, j] <- cor(pixel_values, feature_values)
        }
      }
    }
    
    correlation_maps[[x_idx]] <- correlation_map
    cat(paste("Computed correlation map for X-feature", x_idx, "\n"))
  }
  
  # Now create input-specific visualizations
  digits_to_show <- c(2, 3, 6, 7, 8, 9)
  
  selected_examples <- list()
  for (digit in digits_to_show) {
    digit_indices <- which(y_samples == digit)
    if (length(digit_indices) > 0) {
      idx <- sample(digit_indices, 1)
      selected_examples[[as.character(digit)]] <- list(
        image = x_samples[idx, , , 1],
        x_features = x_features_all[idx, ],
        digit = digit
      )
    }
  }
  
  # Create visualization (previous graphics settings are restored on exit)
  op <- par(mfrow = c(length(digits_to_show), 6),
            mar = c(0.1, 0.1, 2, 0.1),
            bg = "black")
  on.exit(par(op), add = TRUE)
  
  for (digit in digits_to_show) {
    if (!is.null(selected_examples[[as.character(digit)]])) {
      
      img <- selected_examples[[as.character(digit)]]$image
      x_vals <- selected_examples[[as.character(digit)]]$x_features
      
      # Original image
      image(t(img[28:1, ]), 
            col = gray.colors(256, start = 0, end = 1),
            axes = FALSE,
            main = "Original Image",
            col.main = "white",
            cex.main = 0.8)
      
      # ExcitationBP-style visualizations
      for (x_idx in 1:5) {
        x_val <- x_vals[x_idx]
        
        # Get correlation map for this X-feature
        correlation_map <- correlation_maps[[x_idx]]
        
        # Weight by actual activation and input image
        input_weighted_excitation <- correlation_map * img * sign(x_val)
        
        # Scale by activation strength
        max_activation <- max(abs(x_features_all[, x_idx]))
        if (max_activation > 0) {
          activation_strength <- abs(x_val) / max_activation
          final_pattern <- input_weighted_excitation * activation_strength
        } else {
          final_pattern <- input_weighted_excitation * 0.1
        }
        
        # Normalize for display
        if (max(final_pattern) > min(final_pattern)) {
          final_pattern <- (final_pattern - min(final_pattern)) / 
                          (max(final_pattern) - min(final_pattern))
        } else {
          final_pattern <- matrix(0.5, nrow = 28, ncol = 28)
        }
        
        # Display
        image(t(final_pattern[28:1, ]), 
              col = gray.colors(256, start = 0, end = 1),
              axes = FALSE,
              main = paste0("X", x_idx, ": ", sprintf("%.4f", x_val)),
              col.main = "white",
              cex.main = 0.7)
      }
    }
  }
  
  par(bg = "white")
  return(list(examples = selected_examples, correlation_maps = correlation_maps))
}

# Run ExcitationBP visualization
cat("Creating ExcitationBP-based X-feature visualization...\n")

# Try the correlation-based method first (more stable)
x_viz <- x_test
y_viz <- y_test
z_viz <- z_test
paper_viz_results <- create_correlation_excitationbp(srae_model, x_viz, y_viz, z_viz)

cat("Correlation-based X-feature visualization complete.\n")
cat("Each panel reweights a shared correlation map by the input image and its X-feature activation\n")
```
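Two points about the maps above can be checked in a few lines of NumPy. First, the SRAE `explanation` layer is created with `activation = "linear"`. For such a layer, the gradient that `excitation_bp()` computes is simply the layer's weight column, the same for every input. The "excitation" score is therefore `z * max(0, W[, i])`. Second, the 784-iteration pixel loop in `create_correlation_excitationbp()` can be written as one matrix product. Both functions are illustrative sketches with hypothetical names, not code from the R pipeline.

```python
import numpy as np

def gradient_x_input(z, W_e, feature_idx):
    """Score computed by excitation_bp() when the explanation layer is linear.

    For e = z @ W_e + b_e, d e[feature_idx] / d z = W_e[:, feature_idx] for any z,
    so the score reduces to z * max(0, W_e[:, feature_idx]).
    """
    grad = W_e[:, feature_idx]              # input-independent gradient
    return z * np.maximum(grad, 0.0)        # keep excitatory weights, weight by z

def pixel_feature_correlation(images, feature_values):
    """Pearson correlation of every pixel with one X-feature across samples.

    images: (n, H, W); feature_values: (n,). Zero-variance pixels (or a
    constant feature) get 0, mirroring the sd() > 0 check in the R loop.
    """
    n, h, w = images.shape
    pix = images.reshape(n, h * w).astype(float)
    pix_c = pix - pix.mean(axis=0)                          # centre each pixel
    f = np.asarray(feature_values, dtype=float)
    f_c = f - f.mean()
    num = pix_c.T @ f_c                                     # covariance numerators
    den = np.sqrt((pix_c ** 2).sum(axis=0) * (f_c ** 2).sum())
    corr = np.zeros(h * w)
    ok = den > 0
    corr[ok] = num[ok] / den[ok]
    return corr.reshape(h, w)

rng = np.random.default_rng(1)
imgs = rng.normal(size=(20, 3, 3))
imgs[:, 0, 0] = 0.5                         # a constant (background-like) pixel
f = 2 * imgs[:, 1, 1] + 1                   # feature driven by pixel (1, 1)
print(np.round(pixel_feature_correlation(imgs, f), 3))
```

Because the linear layer's gradient does not depend on the input, the only per-image variation in the Z-space scores comes from `z` itself. That is worth keeping in mind when comparing these panels with the ExcitationBP heat-maps produced by the Python scripts.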

### 11. Analysis and Interpretation

```r
# Analyze the X-feature patterns for insights
analyze_xfeature_insights <- function(paper_viz_results, feature_summary) {
  
  cat("\n=== X-FEATURE ANALYSIS INSIGHTS ===\n")
  
  # Find dominant X-features for each digit
  cat("\nDominant X-features by digit:\n")
  for (digit in 0:9) {
    digit_features <- feature_summary[feature_summary$digit == digit, ]
    if (nrow(digit_features) > 0) {
      top_feature_idx <- which.max(abs(digit_features$mean_activation))
      top_feature <- digit_features$x_feature[top_feature_idx]
      top_value <- digit_features$mean_activation[top_feature_idx]
      
      cat(sprintf("Digit %d: %s (%.3f)\n", digit, top_feature, top_value))
    }
  }
  
  # Analyze X-feature specialization
  cat("\nX-feature specialization analysis:\n")
  for (x_feat in paste0("X", 1:5)) {
    feat_data <- feature_summary[feature_summary$x_feature == x_feat, ]
    max_digit <- feat_data$digit[which.max(abs(feat_data$mean_activation))]
    max_value <- max(abs(feat_data$mean_activation))
    
    cat(sprintf("%s: Most activated by Digit %d (%.3f)\n", x_feat, max_digit, max_value))
  }
  
  # Feature specialization: coefficient of variation of |mean activation|
  # across digits. A feature that responds to only a few digits varies more
  cat("\nX-feature specialization (coefficient of variation; higher = more digit-specific):\n")
  for (x_feat in paste0("X", 1:5)) {
    feat_values <- feature_summary[feature_summary$x_feature == x_feat, "mean_activation"]
    specialization <- sd(abs(feat_values)) / mean(abs(feat_values))
    cat(sprintf("%s: %.3f\n", x_feat, specialization))
  }
  
  return(invisible())
}

# Run the analysis
analyze_xfeature_insights(paper_viz_results, feature_summary)

cat("\n=== VISUALIZATION SUMMARY ===\n")
cat("Generated visualizations:\n")
cat("1. X-Feature activation heatmaps (Section 8)\n")
cat("2. Enhanced clustered heatmaps (Section 9)\n") 
cat("3. Paper-style X-feature patterns (Section 10)\n")
cat("4. Analysis and insights (Section 11)\n")
```
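The per-feature score printed above is the coefficient of variation (sd / mean) of |mean activation| across the ten digits. A toy NumPy check shows how to read it. A feature that fires equally for every digit scores 0, and one that fires for a single digit scores higher. The helper name is hypothetical. It uses `ddof=1` to match R's `sd()` and, like the R code, it is undefined when a feature's mean |activation| is 0.

```python
import numpy as np

def specialization_cv(mean_acts):
    """Coefficient of variation of |mean activation| across digits, per X-feature.

    mean_acts: (n_digits, k) matrix of per-digit mean activations.
    """
    a = np.abs(mean_acts)
    return a.std(axis=0, ddof=1) / a.mean(axis=0)

toy = np.array([[1.0, 4.0],
                [1.0, 0.0],
                [1.0, 0.0],
                [1.0, 0.0]])   # column 0: uniform, column 1: one digit only
print(specialization_cv(toy))
```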


### 12. Evaluate Model Performance

```r
# Evaluate faithfulness (how well SRAE predicts original CNN output)
evaluate_faithfulness <- function(model, z_test, y_hat_test) {
  
  outputs <- model(k_constant(z_test))
  predicted_logits <- as.array(outputs[[3]])  # [[3]] is the 'prediction' output
  
  # Mean squared error between SRAE logits and CNN logits
  mse <- mean((predicted_logits - y_hat_test)^2)
  
  # Top-1 agreement with the CNN's predicted class (fidelity, not accuracy
  # against the true labels)
  predicted_classes <- max.col(predicted_logits, ties.method = "first") - 1
  original_classes <- max.col(y_hat_test, ties.method = "first") - 1
  agreement <- mean(predicted_classes == original_classes)
  
  cat("SRAE Evaluation:\n")
  cat("Faithfulness (MSE):", round(mse, 6), "\n")
  cat("Agreement with CNN predictions:", round(agreement, 4), "\n")
  
  return(list(mse = mse, accuracy = agreement))
}

faithfulness_results <- evaluate_faithfulness(srae_model, z_test, y_hat_test)

# Print summary insights
cat("\n=== SUMMARY INSIGHTS ===\n")
cat("Most strongly activated X-feature (largest |mean activation|) per digit:\n")

# Find which X-features are most important for each digit
for (digit in 0:9) {
  features <- all_digit_analysis[[as.character(digit)]]$stats$mean_activation
  top_feature <- which.max(abs(features))
  cat(sprintf("Digit %d: Most activated by X%d (%.4f)\n", 
              digit, top_feature, features[top_feature]))
}
```
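The two faithfulness numbers compare the SRAE only with the CNN it explains, not with the true digit labels. A NumPy sketch of both metrics follows. `fidelity` is a hypothetical name, and `np.argmax` breaks ties by taking the first maximum, the same as `max.col(..., ties.method = "first")`.

```python
import numpy as np

def fidelity(srae_logits, cnn_logits):
    """Logit MSE and top-1 agreement between SRAE and CNN outputs."""
    mse = float(np.mean((srae_logits - cnn_logits) ** 2))
    agree = float(np.mean(srae_logits.argmax(axis=1) == cnn_logits.argmax(axis=1)))
    return mse, agree

srae_out = np.array([[1.0, 2.0], [3.0, 0.0]])
cnn_out = np.array([[0.0, 3.0], [1.0, 2.0]])
print(fidelity(srae_out, cnn_out))
```

A low MSE implies high agreement, but the converse does not hold. The SRAE can pick the same class while its logits sit far from the CNN's, so both numbers are worth reporting.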



## Source code (Python) - ExcitationBP


### plot_multiple_cMWP.py

```python
#!/usr/bin/env python3
# plot_multiple_cMWP.py

import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from excitationbp import ExcitationBP

# ─────────────────────────────────────────────────────────────────────────────
# 1) Load the two SavedModels exported from R:
# ─────────────────────────────────────────────────────────────────────────────
cnn_path  = os.path.join("saved_models", "base_cnn_for_excitebp")
srae_path = os.path.join("saved_models", "srae_explainer")

print("Loading base CNN from:", cnn_path)
py_cnn = tf.keras.models.load_model(cnn_path)

print("Loading SRAE explainer from:", srae_path)
py_srae = tf.keras.models.load_model(srae_path, compile=False)

# ─────────────────────────────────────────────────────────────────────────────
# 2) Print layer names so we can verify how to stitch:
# ─────────────────────────────────────────────────────────────────────────────
print("\n--- CNN layers ---")
for layer in py_cnn.layers:
    print(f"{layer.name:30s}  {layer.output_shape}")

print("\n--- SRAE layers ---")
for layer in py_srae.layers:
    print(f"{layer.name:30s}  {layer.input_shape} → {layer.output_shape}")

# Sanity check: both models must agree on the 128-unit latent vector
assert py_cnn.get_layer("feature_layer").output_shape[-1] == 128
assert py_srae.get_layer("explanation").input_shape[-1] == 128

# ─────────────────────────────────────────────────────────────────────────────
# 3) “Stitch” the SRAE’s ‘explanation’ Dense onto the CNN:
# ─────────────────────────────────────────────────────────────────────────────
#
#    We copy just the weights of the original `explanation` layer
#    so that our fused model uses a true Dense, not a Lambda. This
#    ensures `ebp.excite(…, "explanation", i)` will find exactly
#    one layer named "explanation" to hook into.
#
orig_ex_layer = py_srae.get_layer("explanation")
new_ex_layer = tf.keras.layers.Dense(
    units      = orig_ex_layer.units,
    activation = orig_ex_layer.activation,
    use_bias   = orig_ex_layer.use_bias,
    name       = "explanation"
)
# Build & transfer weights:
new_ex_layer.build((None, 128))
new_ex_layer.set_weights(orig_ex_layer.get_weights())

# Build fused model: image → base_CNN → feature_layer(128) → new_explanation(5)
image_input    = py_cnn.input                                   # shape=(None,28,28,1)
feature_tensor = py_cnn.get_layer("feature_layer").output       # shape=(None,128)
explanation_tensor = new_ex_layer(feature_tensor)               # shape=(None,5)

full_model = tf.keras.Model(
    inputs  = image_input,
    outputs = explanation_tensor,
    name    = "cnn_srae_fused"
)
print("\nFull fused model summary:")
full_model.summary()

# ─────────────────────────────────────────────────────────────────────────────
# 4) Instantiate ExcitationBP on the fused model:
# ─────────────────────────────────────────────────────────────────────────────
ebp = ExcitationBP(full_model)

# ─────────────────────────────────────────────────────────────────────────────
# 5) Load MNIST and select examples:
# ─────────────────────────────────────────────────────────────────────────────
(_, _), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_test = x_test.astype("float32") / 255.0
x_test = np.expand_dims(x_test, axis=-1)  # shape = (N, 28, 28, 1)

# (a) first five examples of “4”
four_indices = np.where(y_test == 4)[0]
sel4 = four_indices[:5]

# (b) one example each of [2,3,6,7,8,9]
digits_b = [2, 3, 6, 7, 8, 9]
sel_b = []
for d in digits_b:
    idx_list = np.where(y_test == d)[0]
    sel_b.append(idx_list[0])

print("\nPlot (a): indices of the first five 4’s:", sel4.tolist())
print("Plot (b): indices of one example each of [2,3,6,7,8,9]:", sel_b)

# ─────────────────────────────────────────────────────────────────────────────
# 6) Helper to compute all c-MWP (contrastive marginal winning probability)
#    maps for a single index:
# ─────────────────────────────────────────────────────────────────────────────
def compute_cmwp_for_index(idx):
    """
    Given a single MNIST index `idx`, returns:
      - orig_img:  (28×28) array
      - x_feats:   length-5 array of X-feature activations
      - cmwp_maps: list of five (28×28) heatmaps (one per X-feature)
    """
    # 1) Grab the (1,28,28,1) input, run it through the fused model
    img = x_test[idx : idx + 1, ...]                          # shape=(1,28,28,1)
    x_feats = full_model.predict(img, verbose=0).reshape(-1)  # shape=(5,)

    # 2) Compute c-MWP for each X-feature i in [0..4]
    cmwp_maps = []
    for i in range(len(x_feats)):
        hm_tensor = ebp.excite(img, "explanation", i)
        hm = hm_tensor.numpy().squeeze()  # shape=(28,28)
        cmwp_maps.append(hm)
    return img.squeeze(), x_feats, cmwp_maps

# ─────────────────────────────────────────────────────────────────────────────
# 7) Plot (a): “first five 4’s” in a 5×6 grid:
# ─────────────────────────────────────────────────────────────────────────────
num_xfeat = 5
n_rows_a  = len(sel4)
n_cols    = num_xfeat + 1  # one column for the original, plus five heatmaps

plt.figure(figsize=(n_cols * 2, n_rows_a * 2))
plt.suptitle("Plot (a): c-MWP heatmaps for the first five 4’s", fontsize=16, y=0.92)

for row_i, idx in enumerate(sel4):
    orig_img, x_feats, maps = compute_cmwp_for_index(idx)

    # ─── Column 0: Original MNIST “4” ──────────────────────
    ax = plt.subplot(n_rows_a, n_cols, row_i * n_cols + 1)
    # imshow's default origin='upper' already matches MNIST row order: no flip
    plt.imshow(orig_img, cmap="gray", vmin=0, vmax=1)
    ax.set_xticks([]); ax.set_yticks([])
    ax.set_title("Original Digit", color="white", backgroundcolor="black", fontsize=10)

    # ─── Columns 1..5: c-MWP heatmaps for X₁…X₅ ─────────────
    for feat_i in range(num_xfeat):
        hm = maps[feat_i]
        title_str = f"X{feat_i+1}: {x_feats[feat_i]:.4f}"

        ax = plt.subplot(n_rows_a, n_cols, row_i * n_cols + (feat_i + 2))
        # Same orientation as the input image, so no flip here either
        plt.imshow(hm, cmap="viridis", vmin=hm.min(), vmax=hm.max())
        ax.set_xticks([]); ax.set_yticks([])
        ax.set_title(title_str, color="white", backgroundcolor="black", fontsize=8)

plt.tight_layout(rect=[0, 0, 1, 0.90])
os.makedirs("figures", exist_ok=True)
out_a = os.path.join("figures", "plot_a_first_fives_4s.png")
plt.savefig(out_a, dpi=150, bbox_inches="tight")
print(f"\nSaved Plot (a) to {out_a}")
plt.close()

# ─────────────────────────────────────────────────────────────────────────────
# 8) Plot (b): digits [2,3,6,7,8,9] in a 6×6 grid:
# ─────────────────────────────────────────────────────────────────────────────
n_rows_b = len(sel_b)
plt.figure(figsize=(n_cols * 2, n_rows_b * 2))
plt.suptitle("Plot (b): c-MWP heatmaps for digits [2,3,6,7,8,9]", fontsize=16, y=0.92)

for row_i, idx in enumerate(sel_b):
    orig_img, x_feats, maps = compute_cmwp_for_index(idx)
    digit_label = int(y_test[idx])

    # ─── Column 0: Original MNIST digit ─────────────────────
    ax = plt.subplot(n_rows_b, n_cols, row_i * n_cols + 1)
    plt.imshow(orig_img, cmap="gray", vmin=0, vmax=1)  # no flips
    ax.set_xticks([]); ax.set_yticks([])
    ax.set_title(f"Idx={idx}, {digit_label}", color="white", backgroundcolor="black", fontsize=10)

    # ─── Columns 1..5: c-MWP heatmaps for X₁…X₅ ─────────────
    for feat_i in range(num_xfeat):
        hm = maps[feat_i]
        title_str = f"X{feat_i+1}: {x_feats[feat_i]:.4f}"

        ax = plt.subplot(n_rows_b, n_cols, row_i * n_cols + (feat_i + 2))
        plt.imshow(hm, cmap="viridis", vmin=hm.min(), vmax=hm.max())
        ax.set_xticks([]); ax.set_yticks([])
        ax.set_title(title_str, color="white", backgroundcolor="black", fontsize=8)

plt.tight_layout(rect=[0, 0, 1, 0.90])
out_b = os.path.join("figures", "plot_b_digits_236789.png")
plt.savefig(out_b, dpi=150, bbox_inches="tight")
print(f"\nSaved Plot (b) to {out_b}")
plt.close()

print("\nAll plots generated.\n")
```
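The script delegates the heat-maps to an external `excitationbp` module whose source is not shown here; its `excite()` call is used as-is. To show what a c-MWP map computes, here is a minimal NumPy sketch of the marginal-winning-probability (MWP) rule from Zhang et al.'s Excitation Backprop, applied to dense layers. It assumes non-negative (post-ReLU) layer inputs and ignores biases. `contrastive_mwp` is our simplified reading of the contrastive variant: it propagates once from the target unit and once from a "dual" unit with negated top-layer weights, then subtracts the two maps. It is not the module's implementation.

```python
import numpy as np

def mwp_dense(a_child, W, p_parent, eps=1e-12):
    """One Excitation Backprop (MWP) step through a dense layer.

    a_child:  (n_child,) non-negative inputs to the layer (e.g. post-ReLU)
    W:        (n_child, n_parent) weights; biases are ignored
    p_parent: (n_parent,) winning probabilities of the layer's output units
    Returns the (n_child,) winning probabilities of the input units.
    """
    W_pos = np.maximum(W, 0.0)                    # only excitatory connections
    norm = a_child @ W_pos                        # sum_j a_j * w_ji^+ per parent i
    scaled = np.where(norm > eps, p_parent / np.maximum(norm, eps), 0.0)
    return a_child * (W_pos @ scaled)             # a_j * sum_i w_ji^+ P(a_i) / norm_i

def contrastive_mwp(a_hidden, W_hidden, a_feat, W_top, target):
    """Simplified c-MWP for a hidden -> feature -> top stack of dense layers."""
    start = np.eye(W_top.shape[1])[target]        # all probability on the target
    p_feat = mwp_dense(a_feat, W_top, start)
    p_feat_dual = mwp_dense(a_feat, -W_top, start)  # dual unit: negated weights
    return (mwp_dense(a_hidden, W_hidden, p_feat)
            - mwp_dense(a_hidden, W_hidden, p_feat_dual))

rng = np.random.default_rng(0)
a_hidden = np.abs(rng.normal(size=16)) + 0.1
W_hidden = np.abs(rng.normal(size=(16, 8)))
a_feat = a_hidden @ W_hidden                      # positive feature activations
W_top = rng.normal(size=(8, 5))
W_top[0] = 1.0                                    # guarantee mixed signs per unit
W_top[1] = -1.0
p = mwp_dense(a_feat, W_top, np.eye(5)[2])
c = contrastive_mwp(a_hidden, W_hidden, a_feat, W_top, target=2)
print(p.round(3), c.sum())
```

The MWP rule conserves probability mass: when every normaliser is positive, each layer's map sums to the mass arriving from above. This is why the contrastive difference sums to roughly zero, with positive values marking inputs that favour the target unit over its dual.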

### plot_each_digit_first5.py

#!/usr/bin/env python3
# plot_each_digit_first5.py

import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from excitationbp import ExcitationBP

# ─────────────────────────────────────────────────────────────────────────────
# 1) Load the CNN‐up‐to‐Dense(128) and SRAE “explanation” models, then stitch them
# ─────────────────────────────────────────────────────────────────────────────
cnn_path  = os.path.join("saved_models", "base_cnn_for_excitebp")
srae_path = os.path.join("saved_models", "srae_explainer")

print("Loading base CNN from:", cnn_path)
py_cnn = tf.keras.models.load_model(cnn_path)

print("Loading SRAE explainer from:", srae_path)
py_srae = tf.keras.models.load_model(srae_path, compile=False)

# Verify layer names
print("\n--- CNN layers ---")
for layer in py_cnn.layers:
    print(f"{layer.name:30s}  {layer.output_shape}")
print("\n--- SRAE layers ---")
for layer in py_srae.layers:
    print(f"{layer.name:30s}  {layer.input_shape} → {layer.output_shape}")

# We expect py_cnn.get_layer("feature_layer").output_shape[-1] == 128
assert py_cnn.get_layer("feature_layer").output_shape[-1] == 128
# We expect py_srae.get_layer("explanation").input_shape[-1] == 128
assert py_srae.get_layer("explanation").input_shape[-1] == 128

# Build a brand‐new Dense( “explanation” ) layer so that
# it has a real layer name “explanation” (not a Lambda).
orig_expl = py_srae.get_layer("explanation")
new_expl  = tf.keras.layers.Dense(
    units      = orig_expl.units,
    activation = orig_expl.activation,
    use_bias   = orig_expl.use_bias,
    name       = "explanation"
)
new_expl.build((None, 128))
new_expl.set_weights(orig_expl.get_weights())

# Stitch it onto the CNN’s “feature_layer”
image_input       = py_cnn.input
feature_tensor    = py_cnn.get_layer("feature_layer").output
explanation_tensor = new_expl(feature_tensor)

full_model = tf.keras.Model(
    inputs  = image_input,
    outputs = explanation_tensor,
    name    = "cnn_srae_fused"
)
print("\nFull fused model summary:")
full_model.summary()

# Wrap with ExcitationBP
ebp = ExcitationBP(full_model)

# ─────────────────────────────────────────────────────────────────────────────
# 2) Load MNIST test set
# ─────────────────────────────────────────────────────────────────────────────
(_, _), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_test = x_test.astype("float32") / 255.0
x_test = np.expand_dims(x_test, axis=-1)  # shape = (N,28,28,1)

# ─────────────────────────────────────────────────────────────────────────────
# 3) For each digit in [0,1,2,3,5,6,7,8,9], pick the first five indices:
# ─────────────────────────────────────────────────────────────────────────────
digits_to_plot = [0, 1, 2, 3, 5, 6, 7, 8, 9]
first5_dict = {}

for d in digits_to_plot:
    idxs = np.where(y_test == d)[0]
    if len(idxs) < 5:
        raise RuntimeError(f"Found fewer than 5 examples of digit {d} in test set!")
    first5_dict[d] = idxs[:5]

print("\nWill create one figure per digit (first 5 examples each).")
for d in digits_to_plot:
    print(f"  Digit {d}: indices {first5_dict[d].tolist()}")

# ─────────────────────────────────────────────────────────────────────────────
# 4) Helper: given a single test‐index, compute:
#      (a) orig_img   = 28×28 array
#      (b) x_feats    = length‐5 np.ndarray of explanation activations
#      (c) cmwp_maps  = list of five (28×28) c-MWP heatmaps
# ─────────────────────────────────────────────────────────────────────────────
def compute_cmwp_for_index(idx):
    # “inp” shape = (1,28,28,1)
    inp = x_test[idx : idx + 1, ...]
    # (1×5) → reshape to (5,) for X₁…X₅ activations
    x_feats = full_model.predict(inp, verbose=0).reshape(-1)
    cmwp_maps = []
    for i in range(len(x_feats)):
        hm_t  = ebp.excite(inp, "explanation", i)
        hm_np = hm_t.numpy().squeeze()  # → (28,28)
        cmwp_maps.append(hm_np)
    orig_img = inp.squeeze()  # (28,28)
    return orig_img, x_feats, cmwp_maps

# ─────────────────────────────────────────────────────────────────────────────
# 5) Loop over each digit, build a separate figure, and save it:
# ─────────────────────────────────────────────────────────────────────────────
os.makedirs("figures", exist_ok=True)

for d in digits_to_plot:
    five_indices = first5_dict[d]

    # Build a new figure: 5 rows × 6 columns
    plt.figure(figsize=(6 * 2.0, 5 * 2.0))
    plt.suptitle(
        f"c-MWP heatmaps for the first five '{d}'s",
        fontsize = 16,
        y        = 0.92
    )

    # For each of the five examples of digit d:
    for row_i, idx in enumerate(five_indices):
        orig_img, x_feats, cmwp_maps = compute_cmwp_for_index(idx)

        # ─── Column 0 (Original Digit) ─────────────────────────
        ax = plt.subplot(5, 6, row_i * 6 + 1)
        plt.imshow(orig_img, cmap="gray", vmin=0, vmax=1)
        ax.set_xticks([]); ax.set_yticks([])
        # Title = “Original Digit” (no index)
        ax.set_title("Original Digit",
                     color="white",
                     backgroundcolor="black",
                     fontsize=9)

        # ─── Columns 1…5 (c-MWP for X₁…X₅) ───────────────────────
        for feat_i in range(5):
            hm   = cmwp_maps[feat_i]
            val  = x_feats[feat_i]
            title_str = f"X{feat_i+1}: {val:.3f}"

            ax = plt.subplot(5, 6, row_i * 6 + (feat_i + 2))
            plt.imshow(hm, cmap="viridis", vmin=hm.min(), vmax=hm.max())
            ax.set_xticks([]); ax.set_yticks([])
            ax.set_title(title_str,
                         color="white",
                         backgroundcolor="black",
                         fontsize=7)

    plt.tight_layout(rect=[0, 0, 1, 0.90])
    save_name = f"plot_digit_{d}_first5.png"
    save_path = os.path.join("figures", save_name)
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    print(f"Saved figure for digit {d} → {save_path}")
    plt.close()

print("\nAll done. You should now have nine separate PNGs in ./figures/:\n"
      "  plot_digit_0_first5.png\n"
      "  plot_digit_1_first5.png\n"
      "  plot_digit_2_first5.png\n"
      "  plot_digit_3_first5.png\n"
      "  plot_digit_5_first5.png\n"
      "  plot_digit_6_first5.png\n"
      "  plot_digit_7_first5.png\n"
      "  plot_digit_8_first5.png\n"
      "  plot_digit_9_first5.png\n")
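Both scripts delegate the heat-map computation to `ebp.excite(...)` from the project's own `excitationbp` module, whose internals are not shown here. For intuition about what a c-MWP map contains, the following is a minimal NumPy sketch of the published Excitation Backprop rule (marginal winning probability, MWP) and its contrastive variant. It makes simplifying assumptions: a hypothetical two-layer dense network `x → ReLU(x @ W1) → W2` with non-negative inputs, and the contrastive subtraction done at the input rather than at an intermediate layer. The helper names `eb_step`, `mwp`, and `contrastive_mwp` are illustrative only and are not the module's API.

```python
import numpy as np


def eb_step(a_in, W, p_out):
    """One Excitation Backprop step through a dense layer.

    a_in  : (n_in,) non-negative activations feeding the layer
    W     : (n_in, n_out) weight matrix
    p_out : (n_out,) winning probabilities of the layer's output units
    Returns the (n_in,) marginal winning probabilities of the input units.
    """
    W_pos = np.clip(W, 0.0, None)        # keep only excitatory (positive) weights
    contrib = a_in[:, None] * W_pos      # a_j * w+_ji for every (input j, output i) pair
    norm = contrib.sum(axis=0)           # per-output-unit normaliser
    norm[norm == 0] = 1.0                # outputs with no excitatory input pass on nothing
    cond = contrib / norm                # P(a_j | a_i); non-zero columns sum to 1
    return cond @ p_out                  # P(a_j) = sum_i P(a_j | a_i) * P(a_i)


def mwp(x, W1, W2, target):
    """MWP at the input for one output unit of x -> ReLU(x @ W1) -> W2."""
    h = np.maximum(x @ W1, 0.0)          # hidden ReLU activations
    p_top = np.zeros(W2.shape[1])
    p_top[target] = 1.0                  # all probability mass starts on the target unit
    p_h = eb_step(h, W2, p_top)          # hidden-layer winning probabilities
    return eb_step(x, W1, p_h)           # input-level winning probabilities


def contrastive_mwp(x, W1, W2, target):
    """c-MWP: MWP of the target minus MWP of its dual unit.

    The dual unit uses the negated top-layer weights; negating all of W2 is
    equivalent here because only the target column receives probability mass.
    """
    return mwp(x, W1, W2, target) - mwp(x, W1, -W2, target)


if __name__ == "__main__":
    x = np.array([1.0, 2.0])
    W1 = np.eye(2)
    W2 = np.array([[1.0], [1.0]])
    print(contrastive_mwp(x, W1, W2, target=0))  # → approximately [0.333, 0.667]
```

In this toy case the dual unit has no excitatory weights, so it contributes nothing and the contrastive map equals the plain MWP; in a trained network the subtraction suppresses evidence that is shared with the dual unit and sharpens the map around target-specific inputs.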

### average_cmwp_per_digit.py

#!/usr/bin/env python3
# average_cmwp_per_digit.py

import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from excitationbp import ExcitationBP

# ─────────────────────────────────────────────────────────────────────────────
# 1) Load & stitch the CNN + SRAE so that “explanation” is a real Dense layer.
# ─────────────────────────────────────────────────────────────────────────────
cnn_path  = os.path.join("saved_models", "base_cnn_for_excitebp")
srae_path = os.path.join("saved_models", "srae_explainer")

print("Loading base CNN from:", cnn_path)
py_cnn = tf.keras.models.load_model(cnn_path)

print("Loading SRAE explainer from:", srae_path)
py_srae = tf.keras.models.load_model(srae_path, compile=False)

# Rebuild “explanation” as a fresh Dense layer with the same config and weights
# (reading .units/.activation/.use_bias assumes it was restored as Dense, not a Lambda)
orig_expl = py_srae.get_layer("explanation")
new_expl = tf.keras.layers.Dense(
    units      = orig_expl.units,
    activation = orig_expl.activation,
    use_bias   = orig_expl.use_bias,
    name       = "explanation"
)
# Build & copy weights
new_expl.build(input_shape=(None, orig_expl.input_shape[-1]))
new_expl.set_weights(orig_expl.get_weights())

# Re‐stitch onto the CNN’s feature_layer output
image_input        = py_cnn.input                              # (None,28,28,1)
feature_tensor     = py_cnn.get_layer("feature_layer").output  # (None,128)
explanation_tensor = new_expl(feature_tensor)                  # (None,5)

full_model = tf.keras.Model(
    inputs  = image_input,
    outputs = explanation_tensor,
    name    = "cnn_srae_fused"
)

print("\nFull fused model summary:")
full_model.summary()

# Wrap in ExcitationBP
ebp = ExcitationBP(full_model)

# ─────────────────────────────────────────────────────────────────────────────
# 2) Load MNIST test set and group indices by digit
# ─────────────────────────────────────────────────────────────────────────────
(_, _), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_test = x_test.astype("float32") / 255.0
x_test = np.expand_dims(x_test, axis=-1)   # shape = (10000,28,28,1)
y_test = y_test.astype("int")

N_total = x_test.shape[0]
print(f"\nLoaded MNIST test set: {N_total} images.")

digit_indices = {d: np.where(y_test == d)[0] for d in range(10)}
for d in range(10):
    print(f"  → Digit {d}: {len(digit_indices[d])} examples")

# Prepare accumulators: avg_maps[digit][feature_index] = (28×28) float64
avg_maps = [
    [np.zeros((28, 28), dtype=np.float64) for _ in range(5)]
    for _ in range(10)
]
counts = [len(digit_indices[d]) for d in range(10)]

# ─────────────────────────────────────────────────────────────────────────────
# 3) Accumulate c-MWP heatmaps per (digit, feature), then divide by count
# ─────────────────────────────────────────────────────────────────────────────
print("\nAccumulating c-MWP maps, grouped by digit and feature…")
for d in range(10):
    idxs = digit_indices[d]
    c = len(idxs)
    print(f"\n  → Digit {d}: {c} images…")
    for (k, n) in enumerate(idxs):
        inp = x_test[n : n + 1, ...]  # shape = (1,28,28,1)
        # Compute c-MWP for each of the 5 X-features
        for i in range(5):
            hm = ebp.excite(inp, "explanation", i).numpy().squeeze()
            avg_maps[d][i] += hm

        # Each digit has 892 to 1,135 test images, so report progress every 500
        if (k + 1) % 500 == 0 or (k + 1) == c:
            print(f"      • processed {k+1}/{c} for digit {d}")

    # Convert sums into averages
    for i in range(5):
        avg_maps[d][i] /= float(c)

print("\nFinished building per-digit averages (10×5 maps).")

# ─────────────────────────────────────────────────────────────────────────────
# 4) Plot the 10×5 grid of average heatmaps (wider figure, title at top)
# ─────────────────────────────────────────────────────────────────────────────
os.makedirs("figures", exist_ok=True)

fig, axes = plt.subplots(
    nrows    = 10,
    ncols    = 5,
    figsize  = (16, 24),           # Extra width leaves room for the colorbar(s)
    constrained_layout = True
)
fig.suptitle(
    "Average c-MWP Heatmaps per Digit (rows=0–9) & Feature (cols=X₁–X₅)",
    fontsize = 20,
    y        = 0.99                # Push the title to the very top
)

for d in range(10):
    # Compute row-specific vmin/vmax over X₁..X₅ for visibility
    row_min = min(np.min(avg_maps[d][i]) for i in range(5))
    row_max = max(np.max(avg_maps[d][i]) for i in range(5))

    for i in range(5):
        ax = axes[d, i]
        im = ax.imshow(
            avg_maps[d][i],
            cmap = "viridis",
            vmin = row_min,
            vmax = row_max
        )
        ax.set_xticks([])
        ax.set_yticks([])

        if d == 0:
            ax.set_title(f"X{i+1}", fontsize=14, color="white", backgroundcolor="black")
        if i == 0:
            ax.set_ylabel(
                f"Digit {d}",
                fontsize = 14,
                color    = "white",
                backgroundcolor = "black",
                rotation = 0,
                labelpad = 50
            )

    # Each row uses its own color range, so each row gets its own colorbar.
    # (A single shared bar built from the last `im` would only match digit 9.)
    cbar = fig.colorbar(im, ax=axes[d, :].tolist(), orientation="vertical", shrink=0.9)
    cbar.set_label("Avg c-MWP", fontsize=10)
    cbar.ax.tick_params(labelsize=9)

out_path = os.path.join("figures", "average_cmwp_per_digit_wide.png")
plt.savefig(out_path, dpi=180, bbox_inches="tight")
plt.close(fig)

print(f"\nSaved the wide 10×5 grid → {out_path}\n")
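`average_cmwp_per_digit.py` keeps 10 × 5 running sums in nested Python lists and divides by the per-digit count at the end. If the per-image maps are first stacked into one array (about 157 MB for 10,000 × 5 × 28 × 28 float32 values), the grouping step can be vectorised. The sketch below assumes such a `maps` array of shape `(N, F, H, W)` and an integer `labels` vector; `per_class_mean_maps` is a hypothetical helper, not part of the original scripts, and the loop that fills `maps` is not shown.

```python
import numpy as np


def per_class_mean_maps(maps, labels, n_classes):
    """Average heat-maps grouped by class label (hypothetical helper).

    maps      : (N, F, H, W) array, one stack of F heat-maps per image
    labels    : (N,) integer class labels in [0, n_classes)
    n_classes : number of classes (10 for MNIST)
    Returns an (n_classes, F, H, W) array of means; empty classes stay all-zero.
    """
    sums = np.zeros((n_classes,) + maps.shape[1:], dtype=np.float64)
    np.add.at(sums, labels, maps)                       # unbuffered scatter-add by label
    counts = np.bincount(labels, minlength=n_classes)   # number of images per class
    safe = np.maximum(counts, 1)[:, None, None, None]   # avoid 0/0 for empty classes
    return sums / safe
```

`np.add.at` is used instead of `sums[labels] += maps` because the buffered form applies only one update per repeated index, which would silently drop most images of each class. Called as `per_class_mean_maps(maps, y_test, 10)`, it would return the same 10 × 5 grid of means as the nested loop, as a single `(10, 5, 28, 28)` array, up to floating-point rounding.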