# ECAI hands‑on: Exploring Concept-based Explainable AI for Computer Vision

**Authors:** [Mahdi Pourghasem](https://www.linkedin.com/in/mahdi-pourghasem), [Gesina Schwalbe](https://gesina.github.io), [Jae Hee Lee](https://jaeheelee.gitlab.io)

**Goal:** 30–45 min hands‑on Colab that introduces feature attribution, feature visualization, unsupervised concept extraction, and a supervised concept-based explainability (TCAV‑style) experiment. All examples build on the PyTorch/torchvision deep learning framework.

Exercises are provided as textual descriptions and in the code, where they are marked with `EXERCISE`.

---

## Overview

1. **Setup & quick model demo**  
   *(You might need to restart the session after first call to the `!pip` cell!)*
2. **Feature attribution (Grad‑CAM)**

   * Run a classical feature importance method and inspect the heatmaps
3. **Feature visualization (most‑activating images for a filter/neuron)**

   * Collect and show images that maximally activate a chosen filter
4. **Unsupervised concept extraction (PCA / NMF / k‑means on activations)**

   * Extract linear components / clusters as prototypes and show most activating images
5. **Supervised concept extraction (simple TCAV style using synthetic concept)**

   * Create a synthetic concept dataset, train a linear concept classifier in activation space
   
6. **Evaluating Concept-to-output Attribution (TCAV)**

   * Compute the TCAV score for different concept activation vectors.

7. **Wrap up**

---



In [None]:
# Run this first cell in Colab
!pip install -q torch torchvision matplotlib scikit-learn tqdm numpy pandas scikit-image

> Note: !! You might need to restart the session after the installation. !!


---

## 0) Imports & utilities

> Note: We use `torchvision.models.resnet18` with pretrained ImageNet weights and `CIFAR10` (resized) as a small dataset that downloads automatically in Colab. Both can be exchanged for other models or datasets, but make sure to adapt the setup accordingly (transformation, chosen target layer).

In [None]:
# Cell: imports & utilities
import torch, torchvision, cv2, numpy as np, matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from torch.utils.data import DataLoader, Subset
from sklearn.decomposition import PCA, NMF
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import normalize
from tqdm.auto import tqdm
import random, os
from copy import deepcopy
import PIL.Image

# Just for typing
from PIL.Image import Image
from typing import Union

# Helper variables
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Helper functions
def imshow(img: Union[torch.Tensor, list[torch.Tensor]], title=None) -> plt.Figure:
    """
    Plot image (or list of images) in a row with optional titles.

    Plots an image provided as torch tensor with values in [0,1].
    """
    if torch.is_tensor(img) and len(img.size()) > 3: img = list(img)
    imgs: list[torch.Tensor] = img if type(img) == list else [img]
    titles: list = title if type(title) == list else [title]*len(imgs)
    assert len(imgs) == len(titles)

    fig = plt.figure(figsize=(4*len(imgs),4))
    axes = fig.subplots(1,len(imgs), squeeze=False)
    for ax, img, title in zip(axes[0], imgs, titles):
      if torch.is_tensor(img):
        img = img.cpu().numpy().transpose(1,2,0)
        img = np.clip(img, 0, 1)
      ax.imshow(img)
      if title: ax.set_title(title)
      ax.axis('off')
    return fig


def overlay_heatmap(img: np.ndarray, heatmap: np.ndarray, alpha=0.5,
                    colormap=cv2.COLORMAP_JET) -> np.ndarray:
    """
    Overlay a heatmap over an image.

    The image should be a tensor, PIL image, or the numpy array of a PIL image.
    The heatmap should be a numpy array with values in [0,1].
    """
    img_np = np.array(to_pil_image(img) if torch.is_tensor(img) else img)
    heatmap_np = np.array(to_pil_image(heatmap, mode='F')) if torch.is_tensor(heatmap) else heatmap
    # cv2.resize expects the target size as (width, height)
    heatmap_np = cv2.resize(heatmap_np, (img_np.shape[1], img_np.shape[0]))
    heatmap_np = np.uint8(255 * heatmap_np)
    heatmap_np = cv2.applyColorMap(heatmap_np, colormap)
    # OpenCV color maps are BGR, whereas the image is RGB
    heatmap_np = cv2.cvtColor(heatmap_np, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(heatmap_np, alpha, img_np, 1-alpha, 0)
    return overlay

def mask_heatmap(img: np.ndarray, heatmap: np.ndarray, alpha=1.) -> np.ndarray:
    """
    Mask an image according to a heatmap.
    """
    #return overlay_heatmap(img, heatmap, alpha=alpha, colormap=cv2.COLORMAP_BONE)
    img_np = np.array(to_pil_image(img) if torch.is_tensor(img) else img)
    heatmap_np = np.array(to_pil_image(heatmap, mode='F')) if torch.is_tensor(heatmap) else heatmap
    # cv2.resize expects the target size as (width, height)
    heatmap_np = cv2.resize(heatmap_np, (img_np.shape[1], img_np.shape[0]))
    masked_np = np.uint8(np.expand_dims(heatmap_np, axis=-1) * img_np)
    overlay = cv2.addWeighted(masked_np, alpha, img_np, 1-alpha, 0)
    return overlay


def undo_normalize_transform(orig_mean: tuple, orig_std: tuple):
    """
    Create a transformation that undoes a normalization.
    (Needed for visualization of images from the transformed dataset.)
    """
    return transforms.Compose([
        transforms.Normalize(mean = [ 0., 0., 0. ], std = [1/std_i for std_i in orig_std]),
        transforms.Normalize(mean = [ -mean_i for mean_i in orig_mean], std = [ 1., 1., 1. ]),
        ])

def numpy_sigmoid(z): return 1.0/(1.0 + np.exp(-z))

*(If you got an error in the previous step after the first-time run of the `!pip` command, make sure to restart the notebook!)*

## 1) Dataset & model setup

In [None]:
# Cell: dataset and model

# Choose and load the models
from torchvision.models import resnet18, ResNet18_Weights
# Alternatives (check out https://docs.pytorch.org/vision/stable/models.html#classification):
#alexnet, AlexNet_Weights
#VGG16_BN_Weights

# Load pretrained weights and initialize model for inference
weights = ResNet18_Weights.DEFAULT  # ResNet18 predicting 1000 classes of ImageNet
model = resnet18(weights=weights).to(device).eval()


# Choose the right transformation for the model ...
transform = weights.transforms()
unnormalize = undo_normalize_transform(transform.mean, transform.std) # for plotting
# ... and initialize the training and test data as well as the data loader
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
classes = trainset.classes

# (small subsets for speed in the tutorial)
train_subset = Subset(trainset, list(range(2000)))
test_subset = Subset(testset, list(range(1000)))

train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
test_loader  = DataLoader(test_subset, batch_size=32, shuffle=False)


print(f'Device: {device}')
print(f'Model loaded; data ready with '
      f'{len(train_subset)} train and {len(test_subset)} test samples.')

**Exercise:** Run a forward pass on one image and print the top‑3 predicted ImageNet label indices (we won't map them to names in this short tutorial).


In [None]:
# # <<<==== EXERCISE: Uncomment to explore
# # Load and show an example image
# img, label = train_subset[0]
# imshow(unnormalize(img), title=f"{classes[label]} ({label=})");

# # Conduct inference on the image
# out = model(img.unsqueeze(0).to(device))
# print(f"Output dimensions: {out.size()}")

# # Retrieve top 3 values
# topk_vals, topk_indices = torch.topk(out, k=3, dim=-1)
# display(topk_vals)
# display(topk_indices)

**Exercise:** Let's first have a closer look at the module structure of our model by printing it. What is the name of the last convolutional layer?

In [None]:
# # <<<==== EXERCISE: Uncomment to explore
# print(model)

# # Names of submodules are hierarchical, separated by a dot
# # (but mind that their order does not have to be the processing order)
# print([name for name, module in model.named_modules()])

## 2) Feature attribution — Grad‑CAM

We'll compute Grad‑CAM for a single convolutional layer (the last conv layer in ResNet18) and visualize heatmaps.

*What to see here?* Grad-CAM is a feature attribution method. Feature attribution highlights which parts of the input were important for the output. However, such a heatmap does not distinguish between the different features that are present in a highlighted region. Concept-based XAI (C-XAI) drills down further into which distinct features the model has learned. In addition, we will later reuse the gradient backpropagation technique employed here in a similar fashion to calculate the TCAV score.

*How does it work?* Grad-CAM (1) collects the activation maps (convolutional intermediate outputs) of filters in a layer, (2) weights them according to their gradient with respect to the desired output, and (3) returns the weighted sum as heatmap.
Below, we implement a minimal version of the Grad-CAM algorithm.
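
As a toy illustration of the three steps above, the following sketch applies the weighting to hand-written tensors instead of real network activations. The function name `toy_grad_cam` and the example values are assumptions made only for this illustration; the actual tutorial implementation follows in the next cell.

```python
import torch
import torch.nn.functional as F

def toy_grad_cam(acts: torch.Tensor, grads: torch.Tensor) -> torch.Tensor:
    """Grad-CAM weighting for a single image; acts and grads have shape [C, H, W]."""
    # (2) one weight per channel: spatial mean of the gradient
    channel_weights = grads.mean(dim=(1, 2), keepdim=True)   # [C, 1, 1]
    # (3) weighted sum over channels, keeping only positive evidence
    cam = F.relu((channel_weights * acts).sum(dim=0))        # [H, W]
    # scale to [0, 1] for plotting
    return cam / (cam.max() + 1e-8)

# (1) two hypothetical 2x2 activation maps:
# channel 0 fires top-left, channel 1 fires bottom-right
toy_acts = torch.tensor([[[1., 0.], [0., 0.]],
                         [[0., 0.], [0., 2.]]])
# channel 0 supports the target class (positive gradient), channel 1 opposes it
toy_grads = torch.tensor([[[1., 1.], [1., 1.]],
                          [[-1., -1.], [-1., -1.]]])

toy_cam = toy_grad_cam(toy_acts, toy_grads)
print(toy_cam)
```

Only the top-left position is highlighted: the bottom-right activation is stronger, but its channel has a negative weight and is therefore removed by the ReLU.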

In [None]:
# Cell: Grad-CAM helpers
last_conv_layer_name = 'layer4.1.conv2'  # ADAPT for different model


# Activations and gradients get stored in these global variables.
activations = None
gradients = None

# To store the activations (without gradient), we use the pytorch hook mechanism.
# (1) define the hook functions:
def save_activation(module, input, output):
    global activations
    activations = output.detach()

def save_gradient(module, grad_input, grad_output):
    global gradients
    gradients = grad_output[0].detach()

# (2) register the hook on last convolutional layer
target_module = dict(model.named_modules())[last_conv_layer_name]

handle_a = target_module.register_forward_hook(save_activation)
handle_g = target_module.register_full_backward_hook(save_gradient)

def grad_cam(input_tensor, model=model, target_class=None, device=device):
    """
    Simple Grad-CAM implementation for classifiers.

    Assumes activations and gradients are stored via hooks into the variables.
    input_tensor must be of size [Batch, Channel, Height, Width].
    """
    global activations, gradients
    assert len(input_tensor.size()) == 4, f"Invalid input tensor shape {input_tensor.size()}"

    model.eval()       # inference mode: keep the BatchNorm statistics fixed
    model.zero_grad()  # start with zero gradients to not accumulate
    # forward pass (collect activations)
    out = model.to(device)(input_tensor.to(device))
    if target_class is None:
        # choose top predicted class
        target_class = out.argmax(dim=1).item()
    score = out[0, target_class]
    # backward pass (collect gradients)
    score.backward()

    # Average the gradients spatially to obtain one weight per activation map channel
    # dimensions: activations: [Batch, Channel, Height, Width], gradients: [B, C, H, W]
    weights = gradients.mean(dim=(2,3), keepdim=True)  # [B, C, 1, 1]
    # Weighted sum of activation channels (along channel dimension)
    cam = (weights * activations)       # [B, C, H, W]
    cam = cam.sum(dim=1)                # [B,    H, W]
    # Clip to positive values
    cam = F.relu(cam)

    # Normalize to values in [0,1]
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)

    return cam.cpu().numpy()

In [None]:
# Cell: example visualization for one image
img, label = test_subset[0]  # <<====== EXERCISE: test different images
cam = grad_cam(img.unsqueeze(0), model=model).squeeze(0)

overlay = overlay_heatmap(to_pil_image(unnormalize(img)), cam)
imshow([unnormalize(img), cam, overlay],
       [f'input (class={classes[label]})', 'Grad-CAM heatmap', 'Grad-CAM overlay']);


**Exercise (5 min):**

* Try with several images/classes and observe where Grad‑CAM highlights.
* Ask: what does this tell you about the model's focus? What is missing from such explanations?


## 3) Feature visualization — Most‑activating images for a neuron/filter

We pick a single channel in the last convolutional layer activation map and find the images from the test subset that maximize the spatial mean of that channel.

*For details see the [NetDissect paper](https://openaccess.thecvf.com/content_cvpr_2017/papers/Bau_Network_Dissection_Quantifying_CVPR_2017_paper.pdf)*.
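
The selection criterion can be sketched on a hand-made activation tensor. This is a minimal illustration under assumed toy values (the names `toy_acts`, `toy_scores`, and `toy_top` are hypothetical); the cells below apply the same idea to the real activations.

```python
import torch

# Hypothetical activations of 4 images with 3 channels and 2x2 maps: [B, C, H, W]
toy_acts = torch.zeros(4, 3, 2, 2)
toy_acts[2, 1] = 5.0   # image 2 strongly activates channel 1
toy_acts[0, 1] = 1.0   # image 0 weakly activates channel 1

toy_channel = 1
# One score per image: the spatial mean of the chosen channel
toy_scores = toy_acts[:, toy_channel].mean(dim=(1, 2))   # [B]
# Indices of the two highest-scoring images, best first
toy_top = torch.topk(toy_scores, k=2).indices.tolist()
print(toy_top)
```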

In [None]:
# Cell: compute activations for test dataset
def collect_activations(loader, layer_module, device=device):
    act_list = []
    imgs_list = []
    with torch.no_grad():
        for x, _ in tqdm(loader, desc="Activations collected for batches:"):
            x = x.to(device)
            _ = model(x)
            a = activations.clone().cpu()  # dimensions: [B, C, H, W]
            act_list.append(a)
            imgs_list.append(x.cpu())
    act_all = torch.cat(act_list, dim=0)
    imgs_all = torch.cat(imgs_list, dim=0)
    return act_all, imgs_all

acts_all, imgs_all = collect_activations(test_loader, target_module)

In [None]:
# Cell: determine and show top k images per filter

# Choose channel
channel = 0  ## <<<======== EXERCISE: Try different channels
k = 6        ## <<<======== EXERCISE: Try different k


# Use the channel's mean activation as score per image
scores = acts_all[:, channel].mean(dim=(1,2))        # dim: [B]

# Helper function for top-k selection and plotting
def show_topk(imgs: torch.Tensor, concept_scores: Union[np.ndarray, torch.Tensor],
              k=k, suptitle=None):
    """
    Select and show the top k scoring images from imgs.

    imgs should be a Collection of image tensors,
    the concept scores a flat tensor or numpy array of per-image scores.
    """
    # Get indices of k top-scoring images
    k = min(k, len(concept_scores))
    topk_indices = torch.topk(torch.as_tensor(concept_scores), k=k).indices

    # Plot
    topk_imgs   = [unnormalize(imgs[idx]) for idx in topk_indices]
    topk_titles = [f'score={concept_scores[idx]:.2f}, idx={int(idx)}'
                   for idx in topk_indices]
    fig = imshow(topk_imgs, topk_titles)
    if suptitle:
      fig.suptitle(suptitle)

show_topk(imgs_all, scores, k=k,
          suptitle=f'Most activating images for channel {channel}')


**Exercise:** Pick different channels. Do the top images show a coherent concept? Which channels are 'human interpretable'?

---

## 4) Unsupervised concept extraction — PCA / NMF / k‑means

We compute activation vectors for many probing images (for efficiency reasons spatially pooled). Then we can try to extract prototypical vectors via Principal Component Analysis (PCA) / Non-negative Matrix Factorization (NMF) / k-means clustering. To compare and explore, the images that most strongly project on each component / cluster centroid are visualized.

> Note: Due to the spatial pooling, our prototypes have the shape $[C]$ of an activation map pixel. This comes in handy, because we can then either do the prototype comparison against again pooled activations (=per image scores for each concept), or compare activation map pixels (=per activation map pixel scores, resulting in a heatmap as in Grad-CAM). The latter can also be visualized by using the heatmap as a mask.

*For details see the [ICE paper](https://ojs.aaai.org/index.php/AAAI/article/view/17389).*
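
The two ways of using a prototype of shape $[C]$ mentioned in the note can be sketched as follows. This is a minimal example on random stand-in activations, not on the network used in this tutorial; all `toy_*` names and the chosen sizes are assumptions for illustration.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
B, C, H, W = 20, 8, 3, 3
# Random stand-in for activation maps of shape [B, C, H, W]
toy_acts = rng.normal(size=(B, C, H, W)).astype(np.float32)

# Fit on spatially pooled activations: one vector of shape [C] per image
toy_pooled = toy_acts.mean(axis=(2, 3))                  # [B, C]
toy_pca = PCA(n_components=2).fit(toy_pooled)
toy_proto = toy_pca.components_[0]                       # prototype, shape [C]

# (a) per-image scores: project the pooled activations onto the prototype
toy_image_scores = (toy_pooled - toy_pca.mean_) @ toy_proto          # [B]

# (b) per-pixel scores: project every activation map pixel onto the prototype
toy_centered = toy_acts - toy_pca.mean_[None, :, None, None]
toy_pixel_scores = np.einsum('bchw,c->bhw', toy_centered, toy_proto)  # [B, H, W]

print(toy_image_scores.shape, toy_pixel_scores.shape)
```

Because the projection is linear, the spatial mean of the per-pixel scores equals the per-image score, so both views are consistent with each other.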


In [None]:
# Cell: unsupervised concept mining with polysemantic neurons

num_concepts = 6  # <<<== EXERCISE: Use different number of concepts
k = 5

# We here reduce an activation channel to a single value,
# such that each image is represented by a vector of shape [1, C].
pooled = acts_all.mean(dim=(2,3))  # shape [B, C]
##                      ^^^== EXERCISE: Remove the reduction to adapt to concept segmentation
pooled_np = pooled.numpy()


# PCA: each principal component is a concept
pca = PCA(n_components=num_concepts)  # obtain principal components (PC)
pcs = pca.fit_transform(pooled_np) # project to PCs, dim: [B, num_concepts]
# for each PC, find images with largest (positive) projection
for i in range(num_concepts):
    concept_scores = pcs[:, i]

    # pass k explicitly: the default of show_topk was fixed when the function was defined
    show_topk(imgs_all, concept_scores, k=k, suptitle=f'Principal Component {i} (top {k} samples)')


# # <<<== EXERCISE: Try out alternative techniques for concept prototype mining
# # NMF: find non-negative matrices S (concept scores), P (dictionary) such that
# #      pooled_np ~= S @ P
# # Each dictionary entry (row) in P is a concept vector.
# nmf = NMF(n_components=num_concepts, random_state=0, max_iter=1000)
# # !! the fitting requires only non-negative entries in the input !!
# S = nmf.fit_transform(pooled_np * (pooled_np>0)) # project to concept scores, dim: [B, num_concepts]
# P = nmf.components_            # dictionary matrix, dim: [num_concepts, C]
# for i in range(num_concepts):
#     concept_scores = S[:, i]
#     show_topk(imgs_all, concept_scores, k=k, suptitle=f'Dictionary entry {i} (top {k} samples)')


# # <<<== EXERCISE: Try out alternative techniques for concept prototype mining
# # k-means: each cluster is a concept
# km = KMeans(n_clusters=num_concepts, random_state=0).fit(pooled_np)
# concept_labels = km.labels_
# for c in range(num_concepts):
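#     # subselect images assigned to the cluster
#     idxs = np.where(concept_labels==c)[0]
#     # find images closest to centroid
#     centroid = km.cluster_centers_[c]
#     dists = np.linalg.norm(pooled_np[idxs] - centroid, axis=1)
#     # negate the distances: show_topk selects the largest scores, we want the smallest distances
#     # (the idx shown in the titles refers to the position within the cluster subset)
#     show_topk(imgs_all[idxs], -dists, k=k, suptitle=f'kmeans cluster {c} ({k} closest images)')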

This can also be extended to segmentation (optional here).

In [None]:
# ## OPTIONAL: EXTEND TO UNSUPERVISED CONCEPT SEGMENTATION ON CNNs

# num_concepts = 6
# k = 5

# # >>> Comment in to re-activate pooling
# #pooled = acts_all.mean(dim=(2,3))  # shape [B, C]
# #pooled_np = pooled.numpy()
# acts_np = acts_all[:100].numpy()   # shape [100, C, H, W] (subselect for efficiency)
# B, C, H, W = acts_np.shape

# # Helper function for top-k selection and plotting
# def segment_topk(imgs: torch.Tensor, scores: Union[np.ndarray, torch.Tensor],
#                  k=k, suptitle=None):
#     """
#     Select and segment the top k scoring images from imgs.

#     imgs should be a Collection of image tensors,
#     the concept scores a tensor or numpy array of dim [B, H, W].
#     """
#     k = min(k, len(scores))
#     heatmaps = torch.tensor(scores)  # dim [B, H, W] or [B]

#     # Get per-image scores
#     do_segment = heatmaps.dim() == 3
#     scores = heatmaps.mean(dim=(1,2)) if do_segment else heatmaps

#     # Get indices of k top-scoring images
#     topk_indices = torch.topk(scores, k=k).indices

#     # Plot
#     topk_imgs   = [unnormalize(imgs[idx]) for idx in topk_indices]
#     topk_titles = [f'score={scores[idx]:.2f}, idx={int(idx)}'
#                    for idx in topk_indices]
#     # Do per-image masking
#     if do_segment:
#       topk_imgs = [mask_heatmap(topk_img, (heatmaps[idx]>0.5)*1.0, alpha=1)
#                    for idx, topk_img in zip(topk_indices, topk_imgs)]

#     fig = imshow(topk_imgs, topk_titles)

#     if suptitle:
#       fig.suptitle(suptitle)

# # PCA: each principal component is a concept
# pca = PCA(n_components=num_concepts)  # obtain principal components (PC)

# # >>> Comment in to reactivate pooling
# #pca.fit(pooled_np)  # fitting could be done on the pooled version for efficiency
# pca.fit(acts_np.reshape((-1, C)))

# pcs = pca.transform(acts_np.reshape((-1, C)))  # project [B*H*W, C] -> [B*H*W, num_concepts]
# # for each PC, find images with largest (positive) mean projection
# for i in range(num_concepts):
#     concept_scores = pcs[:, i].reshape((B, H, W))
#     # normalize
#     concept_scores = numpy_sigmoid(concept_scores)
#     segment_topk(imgs_all, concept_scores, suptitle=f'Principal Component {i} (top {k} positive projections)')


# # k-means: each cluster is a concept
# km = KMeans(n_clusters=num_concepts, random_state=0)

# # >>> Comment in to reactivate pooling
# #km.fit(pooled_np)  # fitting could be done on the pooled version for efficiency
# km.fit(acts_np.reshape((-1, C)))
# concept_labels = km.labels_
# for c in range(num_concepts):
#     concept_acts = acts_np.reshape((-1, C))

#     # distance of each activation map pixel to the centroid
#     centroid = km.cluster_centers_[c]
#     dists = np.linalg.norm(concept_acts - centroid, axis=1)

#     dists = dists.reshape((-1, H, W))
#     # normalize to [0,1] and invert, such that pixels close to the centroid score high
#     sims = 1 - (dists - dists.min())/(dists.max() - dists.min() + 0.000001)
#     segment_topk(imgs_all, sims, k=k, suptitle=f'kmeans cluster {c} ({k} closest images)')


**Exercise (3 min):** compare PCA vs k‑means prototypes. Which yields more semantically coherent concepts? Discuss strengths/weaknesses.

---

## 5) Supervised concept extraction — simplified CAV

To keep the tutorial self‑contained we create a *synthetic concept* by inserting a small colored patch into images. That way participants can quickly label concept vs random and run a TCAV‑style pipeline.

**Steps implemented below:**

1. Create concept images by copying some test images and stamping a small red square in the corner.
2. Compute activations for concept and random sets.
3. Train a linear classifier on activations to get a concept direction (CAV).

*For details see the [TCAV paper](http://proceedings.mlr.press/v80/kim18d.html)*.

In [None]:
# Cell: create synthetic concept (red patch) and showcase positive and negative samples
num_pos_images = 200

def add_red_patch(img_tensor, patch_size=30, pos=(0,0)):
    x = img_tensor.clone()
    x[:, pos[1]:pos[1]+patch_size, pos[0]:pos[0]+patch_size] = torch.tensor([1.0,-1.0,-1.0]).view(3,1,1)
    return x

# Build concept data:
# (+) positive samples (take first 200 images and add patch)
concept_idx = list(range(num_pos_images))
concept_imgs = imgs_all[concept_idx].clone()
for i in range(len(concept_imgs)):
    concept_imgs[i] = add_red_patch(concept_imgs[i], patch_size=30, pos=(10,10))

# (-) negative samples (the next 200 images, without patch)
random_idx = list(range(200,400))
random_imgs = imgs_all[random_idx].clone()

imshow([unnormalize(img) for img in concept_imgs[:3]], 'concept positive');
imshow([unnormalize(img) for img in random_imgs[:3]], 'concept negative');

In [None]:
# Cell: collect activations for synthetic concept dataset

# function to get pooled activations for a set of images
@torch.no_grad()
def pooled_activations_for_images(img_tensor_batch, batch_size=32):
    pooled_all = []
    model.eval()
    for i in tqdm(range(0, len(img_tensor_batch), batch_size), "Batch"):
        x = img_tensor_batch[i : i+batch_size].to(device)
        _ = model(x)
        a = activations.clone().cpu().mean(dim=(2,3))
        pooled_all.append(a)
    return torch.cat(pooled_all, dim=0).numpy()

# !! to reduce computational cost !! => used pooled activations
concept_acts = pooled_activations_for_images(concept_imgs)
random_acts = pooled_activations_for_images(random_imgs)

**Exercise :**

* Try different patch colors/positions and sizes, and see how TCAV score changes.
* (Optional) Try using a real concept. E.g., pick sample images with label $1$ (='automobile' for CIFAR10) as positive class and random others as negative. What practical challenges arise?

-------------

We can now train the CAV.

The CAV is the weight vector of a linear classifier that separates between the activations of positive and of negative concept samples. It represents the latent vector pointing into the direction of our concept.

> Note: To reduce computational costs, we
> 1. use pooled activations (CAV of dimensionality $[Channels]$) instead of the full activation map (spatially-aware CAV of dimensionality $[Channels, Height, Width]$);
> 2. and use a logistic regressor instead of a linear support vector machine (try out the difference).
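
To see why the classifier weights point into the concept direction, consider the following minimal sketch on synthetic activations. It assumes a hypothetical setup in which the concept shifts only channel 0; the `toy_*` names and all numbers are illustrative and not part of the tutorial pipeline.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
num_samples, num_channels = 100, 5

# Hypothetical pooled activations: the concept only shifts channel 0
toy_negatives = rng.normal(size=(num_samples, num_channels))
toy_positives = rng.normal(size=(num_samples, num_channels))
toy_positives[:, 0] += 3.0

X_toy = np.vstack([toy_positives, toy_negatives])
y_toy = np.array([1] * num_samples + [0] * num_samples)

# The weight vector of the linear classifier serves as the CAV
toy_model = LogisticRegression(max_iter=1000).fit(X_toy, y_toy)
toy_cav = toy_model.coef_.reshape(-1)
toy_cav = toy_cav / np.linalg.norm(toy_cav)   # optional: normalize to unit length

print(np.round(toy_cav, 2))
```

The largest entry of the resulting vector belongs to channel 0, i.e., the direction in which positive and negative samples differ.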

In [None]:
# Cell: train linear concept classifier (concept model)

X = np.vstack([concept_acts, random_acts])
y = np.array([1]*len(concept_acts) + [0]*len(random_acts))
concept_model = LogisticRegression(max_iter=1000).fit(X, y)

# The model weights are the concept activation vector (CAV)
cav = concept_model.coef_.reshape(-1)  # direction vector

## 6) Evaluating Concept-to-output Attribution — TCAV

Now that we have collected several types of concept vectors, we can use these to inspect the processing of the model.
One simple insight is whether the concept's presence has a tendency to positively or negatively contribute to an output class. This can be measured by the TCAV score, which completes the pipeline started above from the [TCAV paper](http://proceedings.mlr.press/v80/kim18d.html):

4. Compute directional derivative of model output for a target class w.r.t. activations, projected onto the CAV, then compute TCAV score as fraction of images with positive directional derivative.


*More details:* To obtain the TCAV score of a CAV, we calculate the directional derivative of the model output with respect to the (pooled) activations along the CAV. If this derivative is positive, moving the activations in the direction of the CAV increases the output score, i.e., the concept contributes positively to the output; if it is negative, the concept contributes negatively. Note that this could also be realized using other feature attribution methods: one obtains an estimate of each activation channel's importance to the output, and then projects it onto the CAV using the dot product.
Counting the number of images from our test data with positive concept-to-output attribution yields the TCAV score of the CAV:
$$ \text{TCAV} = \frac{\left|\{x \in X_{\text{test}} \mid \frac{\partial f_{\text{act}\to\text{out}}}{\partial \text{act}} \cdot \text{CAV}>0 \}\right|}{|X_{\text{test}}|} $$
Note that this evaluation can be done irrespective of where the CAV originates from. We can, thus, also compare our supervised CAV(s) with the unsupervised ones.

*For details see the [TCAV paper](http://proceedings.mlr.press/v80/kim18d.html)*.
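
The formula can be sketched with a toy model part $f_{\text{act}\to\text{out}}$ for which the gradients are easy to verify by hand. The function `toy_head`, the helper `toy_tcav_score`, and the example activations are hypothetical and only serve as illustration; the next cell computes the score for the real network.

```python
import torch

def toy_head(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the model part after the layer: [N, C] -> [N, 1]."""
    return (a ** 2).sum(dim=1, keepdim=True)

def toy_tcav_score(head, acts: torch.Tensor, cav: torch.Tensor, target_class: int) -> float:
    """Fraction of samples whose directional derivative along the CAV is positive."""
    positive_count = 0
    for a in acts:
        a = a.clone().requires_grad_(True)
        score = head(a.unsqueeze(0))[0, target_class]
        # gradient of the class score w.r.t. the activation vector
        (grad,) = torch.autograd.grad(score, a)
        # directional derivative = gradient projected onto the CAV
        if torch.dot(grad, cav) > 0:
            positive_count += 1
    return positive_count / len(acts)

# Gradient of toy_head is 2*a, so the sign of the derivative equals the sign of a[0]
toy_acts = torch.tensor([[1., 0.], [2., 1.], [-1., 0.], [3., -1.]])
toy_cav = torch.tensor([1., 0.])
toy_tcav = toy_tcav_score(toy_head, toy_acts, toy_cav, target_class=0)
print(toy_tcav)  # 0.75: three of four samples have a positive first entry
```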

In [None]:
# Cell: calculate TCAV score for a CAV

# Set the CAV to be measured; dimension (due to pooling): [C]
#cav_tensor: torch.Tensor = torch.tensor(cav)  # supervised CAV
cav_tensor = torch.zeros(cav.shape); cav_tensor[channel] = 1. # unit vector (single filter)
#cav_tensor = pca.components_[-1]  # principal component (numpy array)
#cav_tensor = km.cluster_centers_[0]  # cluster centroid (numpy array)
# ^^^==== EXERCISE: Try out different CAVs (supervised and unsupervised)

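# as_tensor accepts both tensors and numpy arrays; cast to float32 to match the gradients
cav_tensor = torch.as_tensor(cav_tensor, dtype=torch.float32).to(device)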

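# Possibly fix the class to calculate the TCAV score for
# (if None, will always calculate the contribution to the predicted class)
# Mind: the model outputs ImageNet class indices, which do NOT match the CIFAR10
# label indices (e.g., trainset.class_to_idx['dog']); use an ImageNet index instead,
# e.g., 207 ('golden retriever', see weights.meta['categories']).
target_class_idx = None
#                  ^^^==== EXERCISE: Try out different target classes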


# Select test data
num_imgs = 80
target_imgs = imgs_all[400:400 + num_imgs].to(device)


# To collect activations, we again use the pytorch hook mechanism (cf. Grad-CAM).
# This time, we retain the gradient.
# (1) define the hook functions:
activations = None
def save_activation_retain_grad(module, input, output):
    global activations
    activations = output
    output.retain_grad()

# (2) register the hook
target_module = dict(model.named_modules())[last_conv_layer_name]
handle_a = target_module.register_forward_hook(save_activation_retain_grad)


# Count samples where the CAV positively contributes to the target class
positive_count = 0
model.eval()
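# Description for the progress bar (the model predicts ImageNet classes)
target_desc = (f"'{weights.meta['categories'][target_class_idx]}'"
               if target_class_idx is not None else 'predicted')
for x in tqdm(target_imgs, desc=f"Images processed for TCAV calculation (target class {target_desc})"):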

    # forward pass
    out = model(x.unsqueeze(0))
    # choose top predicted class as target
    tclass = out.argmax(dim=1).item() if target_class_idx is None else target_class_idx
    score = out[0, tclass]

    # backward pass: compute gradient of score w.r.t activation
    model.zero_grad()
    score.backward(retain_graph=False)
    act_grad = activations.grad  # [1, C, H, W]

    # !! to reduce computational cost !! => used pooled activations
    pooled_grad = act_grad.mean(dim=(2,3)).squeeze(0)  # [C]
    # directional derivative approx = pooled_grad dot cav
    dd = (pooled_grad * cav_tensor).sum().item()

    # add to the TCAV score
    if dd > 0:
        positive_count += 1
handle_a.remove()

tcav_score = positive_count / len(target_imgs)
print('TCAV score (fraction of images with positive directional derivative):', tcav_score)


---

## 7) Wrap up

Some thoughts for further discussion:

* Compare Grad‑CAM explanations to concept explanations (visualized prototypes and TCAV scores). What are the advantages and limitations of each?
* Compare unsupervised and supervised results regarding the concept-to-output sensitivity. What are differences in interpretability? How would you quantify this?
* How would you improve concept datasets for a real analysis (human labeling, multiple concept examples, negatives, etc.)?

---

