In [1]:
"""
Utility to inspect and visualize SHAP values saved in NPZ format.
"""

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

In [2]:
# Path (relative to this notebook's working directory) of the saved SHAP
# archive. Later cells also derive the output directory for the generated PNG
# files from this path's parent.
npz_path = "../reports/shap/resnet18/shap_values_first_batch.npz"
# Open the .npz archive; its members are looked up by key (`data[key]`) and
# listed via `data.files` in the cells below.
data = np.load(npz_path)

In [3]:
# List every array stored in the archive together with its shape.
print("Array shapes:")
for array_name in data.files:
    array_shape = data[array_name].shape
    print(f"  {array_name}: {array_shape}")
print()

Array shapes:
  values: (16, 224, 224, 3, 101)
  base_values: (16, 101)



In [4]:
# Pull both arrays out of the archive into plain variables that the
# remaining cells work with.
values, base_values = data['values'], data['base_values']

values_dtype = values.dtype
base_values_dtype = base_values.dtype
print(f"Values dtype: {values_dtype}")
print(f"Base values dtype: {base_values_dtype}")
print()

Values dtype: float64
Base values dtype: float64



In [5]:
# Interpretation
# Print a human-readable summary of what each axis of `values` holds. As the
# labels below state, the axes are read as
# (image, height, width, color channel, output class).
# The class count is taken from the data so the explanatory sentence cannot
# drift out of sync with the array that was actually loaded.
n_classes = values.shape[4]

print("=" * 60)
print("INTERPRETATION")
print("=" * 60)
print(f"Number of images analyzed: {values.shape[0]}")
print(f"Image dimensions: {values.shape[1]}x{values.shape[2]} (HxW)")
print(f"Number of color channels: {values.shape[3]}")
print(f"Number of output classes: {n_classes}")
print()
print("The 'values' array shows the contribution of each pixel")
print(f"to each of the {n_classes} food classes.")
print()
print("Positive values = pixel pushes prediction toward that class")
print("Negative values = pixel pushes prediction away from that class")
print()

INTERPRETATION
Number of images analyzed: 16
Image dimensions: 224x224 (HxW)
Number of color channels: 3
Number of output classes: 101

The 'values' array shows the contribution of each pixel
to each of the 101 food classes.

Positive values = pixel pushes prediction toward that class
Negative values = pixel pushes prediction away from that class



In [6]:
# Summarise the base values: overall shape plus a short preview taken from
# the first row.
print(f"Base values shape: {base_values.shape}")
print("Base values are the model's 'default' predictions")
first_five = base_values[0][:5]
print(f"Example base value for first image: {first_five}... (first 5 classes)")
print()

# Select the image that the following cells analyse in detail.
img_idx = 0

banner = "=" * 60
print(banner)
print("EXAMPLE: First Image Analysis")
print(banner)
print(f"Image {img_idx}:")

Base values shape: (16, 101)
Base values are the model's 'default' predictions
Example base value for first image: [0.00081181 0.00127885 0.00172922 0.00089205 0.00166602]... (first 5 classes)

EXAMPLE: First Image Analysis
Image 0:


In [7]:
# Collapse the selected image's attributions over its first three axes
# (H, W, C) so that a single total remains per class, then pick the class
# with the largest total.
spatial_attribution = np.sum(values[img_idx], axis=(0, 1, 2))
top_class = spatial_attribution.argmax()
# Class indices ordered from largest to smallest total, truncated to five.
top_five = np.argsort(spatial_attribution)[::-1][:5]
print(f"  Most influenced class: {top_class}")
print(f"  Attribution value: {spatial_attribution[top_class]:.4f}")
print(f"  Top 5 classes by attribution: {top_five}")
print()

# Per-element statistics of the attributions belonging to that class.
class_attrs = values[img_idx, ..., top_class]
print(f"Attribution value range for class {top_class}:")
attr_min, attr_max, attr_mean = class_attrs.min(), class_attrs.max(), class_attrs.mean()
print(f"  Min: {attr_min:.6f}")
print(f"  Max: {attr_max:.6f}")
print(f"  Mean: {attr_mean:.6f}")
print()

  Most influenced class: 6
  Attribution value: 0.8398
  Top 5 classes by attribution: [ 6 85 89 36 56]

Attribution value range for class 6:
  Min: -0.000002
  Max: 0.000031
  Mean: 0.000006



In [8]:
# Create visualization
# Builds a 2x2 figure for the selected image (`img_idx`) and saves it next to
# the NPZ file:
#   [0, 0] signed attributions for `top_class`
#   [0, 1] magnitude of those attributions
#   [1, 0] positive attributions summed over channels and classes
#   [1, 1] negative attributions summed over channels and classes
print("Creating visualization...")
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Signed attributions for the class with the largest summed attribution
# (`top_class`), averaged across the channel axis. Note this is the most
# attributed class, which is not necessarily the model's predicted class.
top_attr_map = values[img_idx, :, :, :, top_class].mean(axis=2)
abs_attr_map = np.abs(top_attr_map)

# A diverging colormap is only readable when its neutral midpoint sits at
# zero. Autoscaling to the data's min/max would shift the midpoint, making
# small positive attributions look negative. Use symmetric limits instead;
# non-finite entries are ignored and an all-zero map falls back to 1.0.
attr_limit = float(abs_attr_map[np.isfinite(abs_attr_map)].max(initial=0.0)) or 1.0
im1 = axes[0, 0].imshow(
    top_attr_map,
    cmap='RdBu_r',
    aspect='auto',
    vmin=-attr_limit,
    vmax=attr_limit,
)
axes[0, 0].set_title(f'Attributions for Class {top_class}\n(Most Influential Class)')
plt.colorbar(im1, ax=axes[0, 0])

# Show absolute attributions (magnitude of influence)
im2 = axes[0, 1].imshow(abs_attr_map, cmap='hot', aspect='auto')
axes[0, 1].set_title('Absolute Attribution Magnitude')
plt.colorbar(im2, ax=axes[0, 1])

# Show total positive attributions across all classes
total_positive = np.maximum(values[img_idx], 0).sum(axis=(2, 3))  # Sum over channels and classes
im3 = axes[1, 0].imshow(total_positive, cmap='Reds', aspect='auto')
axes[1, 0].set_title('Total Positive Attributions\n(All Classes)')
plt.colorbar(im3, ax=axes[1, 0])

# Show total negative attributions across all classes
total_negative = np.minimum(values[img_idx], 0).sum(axis=(2, 3))  # Sum over channels and classes
im4 = axes[1, 1].imshow(total_negative, cmap='Blues_r', aspect='auto')
axes[1, 1].set_title('Total Negative Attributions\n(All Classes)')
plt.colorbar(im4, ax=axes[1, 1])

fig.tight_layout()
output_path = Path(npz_path).parent / "analysis_visualization.png"
fig.savefig(output_path, dpi=150, bbox_inches='tight')
print(f"Saved visualization to: {output_path}")
# Close this specific figure so it is released and not rendered inline.
plt.close(fig)

# Create a summary heatmap showing top classes
# Rows are the 20 classes with the largest attribution summed over all
# images; columns are the images. The figure is saved next to the NPZ file.
print("\nCreating class importance heatmap...")
fig, ax = plt.subplots(figsize=(10, 6))

# Total attribution for each (image, class) pair. A single reduction over
# the three middle axes replaces the per-image / per-class Python loop; the
# result has shape (n_images, n_classes). Totals may differ from the looped
# version in the last floating-point digits because of summation order.
class_importance = values.sum(axis=(1, 2, 3))

# Show top 20 classes, ordered from largest to smallest overall total
top_classes = np.argsort(class_importance.sum(axis=0))[-20:][::-1]
heatmap_data = class_importance[:, top_classes].T

# Keep the diverging colormap centred on zero so red/blue correspond to
# positive/negative totals. Non-finite entries are ignored and an all-zero
# matrix falls back to a limit of 1.0.
heatmap_abs = np.abs(heatmap_data)
heatmap_limit = float(heatmap_abs[np.isfinite(heatmap_abs)].max(initial=0.0)) or 1.0
im = ax.imshow(
    heatmap_data,
    cmap='RdBu_r',
    aspect='auto',
    vmin=-heatmap_limit,
    vmax=heatmap_limit,
)
ax.set_xlabel('Image Index')
ax.set_ylabel('Class Index')
ax.set_title('SHAP Attribution Heatmap\n(Top 20 Most Influential Classes)')
ax.set_yticks(range(len(top_classes)))
ax.set_yticklabels(top_classes)
plt.colorbar(im, ax=ax, label='Total Attribution')

fig.tight_layout()
heatmap_path = Path(npz_path).parent / "class_importance_heatmap.png"
fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved heatmap to: {heatmap_path}")
# Close this specific figure so it is released and not rendered inline.
plt.close(fig)

print("\nDone! Check the generated PNG files for visualizations.")

Creating visualization...
Saved visualization to: ../reports/shap/resnet18/analysis_visualization.png

Creating class importance heatmap...
Saved heatmap to: ../reports/shap/resnet18/class_importance_heatmap.png

Done! Check the generated PNG files for visualizations.
