In [None]:
from IPython.core.interactiveshell import InteractiveShell

# Display the value of every top-level expression in a cell, not only the last one.
# This is why later cells assign unwanted return values to `_`.
InteractiveShell.ast_node_interactivity = "all"

# Multilayer-Perceptron and Images

## Goals

- Train an MLP and understand how it works on images

## Google Colab Check

In [None]:
import sys

# True when the `google.colab` module is already loaded in this interpreter.
IN_COLAB = "google.colab" in sys.modules
print(f"In Colab: {IN_COLAB}")

# On Colab, remind the reader that changes are only kept in a personal copy.
if IN_COLAB:
    try:
        from IPython.display import Markdown, display

        save_hint = Markdown(
            """
> 💾 **Optionally:**  
> Save this notebook to your **personal Google Drive** to persist any changes.
>
> *Go to `File ▸ Save a copy in Drive` before editing.*
            """
        )
        display(save_hint)
    except Exception:
        # Rich rendering is best-effort; fall back to a plain-text hint.
        print(
            "\n💾 Optionally: Save the notebook to your personal Google Drive to persist changes.\n"
        )

We mount Google Drive to store data.

In [None]:
# Only on Colab: mount Google Drive at /content/drive (the Colab data path lives below it).
if IN_COLAB:
    from google.colab import drive

    drive.mount("/content/drive")

## Specify Data Path

**Modify the following paths if necessary.**

That is where your data will be stored.

In [None]:
from pathlib import Path

# Drive-backed folder on Colab, repository-relative data folder everywhere else.
DATA_PATH = (
    Path("/content/drive/MyDrive/cas-dl-module-compvis-part1")
    if IN_COLAB
    else Path("../../data")
)
# Stop early with a readable message when the folder is missing.
assert DATA_PATH.exists(), f"PATH: {DATA_PATH} does not exist."

## Install Lectures Package

Install `dl_cv_lectures` package with all necessary dependencies.

This package provides the environment of the exercises repository, as well as helper and utility modules: [Link](https://github.com/marco-willi/cas-dl-compvis-exercises-hs2025)

The following code installs the package from a local repository (if available), otherwise it installs it from the exercise repository.

In [None]:
import subprocess
import sys
from pathlib import Path

from rich.console import Console

console = Console()


def ensure_dl_cv_lectures():
    """Make sure ``dl_cv_lectures`` is importable, installing it with pip if necessary.

    If the import already works, only a confirmation is printed. Otherwise the
    package is installed with the running interpreter's pip: in editable mode
    from ``/workspace`` when ``/workspace/pyproject.toml`` exists, else from the
    GitHub exercise repository. A failed installation is reported on the
    console and is not re-raised.
    """
    import importlib

    try:
        import dl_cv_lectures  # noqa: F401 -- imported only to probe availability
    except ImportError:
        console.print("[bold yellow]⚠️ dl_cv_lectures not found.[/bold yellow]")
    else:
        console.print(
            "[bold green]✅ dl_cv_lectures installed — all good![/bold green]"
        )
        return

    repo_path = Path("/workspace/pyproject.toml")
    if repo_path.exists():
        console.print("[cyan]📦 Installing from local repository...[/cyan]")
        cmd = [sys.executable, "-m", "pip", "install", "-e", "/workspace"]
    else:
        console.print("[cyan]🌐 Installing from GitHub repository...[/cyan]")
        cmd = [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/marco-willi/cas-dl-compvis-exercises-hs2025",
        ]

    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        console.print(f"[bold red]❌ Installation failed ({e}).[/bold red]")
        return

    # The package was installed while this interpreter was already running:
    # drop the import system's cached directory listings so a following
    # `import dl_cv_lectures` can discover the newly installed files.
    importlib.invalidate_caches()
    console.print("[bold green]✅ Installation successful![/bold green]")


ensure_dl_cv_lectures()

### Load Libraries

Load all libraries and packages used in this exercise.

In [None]:
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchinfo
import torchshow as ts
from torchvision import transforms
from torchvision.transforms.v2 import functional as TF
from torchvision.utils import make_grid

from dl_cv_lectures import visualize
from dl_cv_lectures.classification import train_one_epoch
from dl_cv_lectures.data import pattern

Define a default device for your computations.

In [None]:
# Prefer the GPU when CUDA is usable, otherwise stay on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using: {device}")

## 1) Define Data, Loaders, and MLP

Quickly get data and train an MLP.

We use some pre-defined functionality from the `dl_cv_lectures` package. Feel free to substitute with your own functionality. 

## Data

We use a dataset that consists of inspection images of microscopic components. Some of them are faulty (label=1) and need to be identified and sorted out.

Let's load the data and take a look at it.

First we create a [torch.utils.data.Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset).

In [None]:
# Small dataset (100 samples, side length 16) without a transform; the next cells
# use it only to look at individual samples.
ds_train = pattern.PatternDataset(
    num_samples=100, image_side_length=16, seed=123, max_errors=3, max_x_y_shift=0
)

Let's take a look at the first observation.

In [None]:
# Each dataset item unpacks into an (image, label) pair.
image, label = ds_train[0]
# NOTE(review): presumably a PIL image here (no transform configured), so `.size`
# would be (width, height) -- confirm against PatternDataset.
print(image.size)
image

We collect a few samples, visualize them and try to understand how label and images are related.

In [None]:
def get_images_and_labels_from_ds(
    ds: torch.utils.data.Dataset, num_images_to_fetch: int = 16
) -> tuple[list[torch.Tensor], list]:
    """Fetch the first n samples of a dataset whose items are (image, label) pairs.

    Args:
        ds: Dataset supporting integer indexing; ``ds[i]`` must unpack into
            ``(image, label)``.
        num_images_to_fetch: Number of leading samples to collect.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a list of float32 tensors
        produced by ``TF.to_image(...)`` and divided by 255; ``labels`` holds
        the labels exactly as returned by the dataset.
    """
    images, labels = [], []
    for i in range(num_images_to_fetch):
        # Index the dataset once per sample: image and label then stem from the
        # same draw (relevant for random transforms) and the work is not doubled.
        image, label = ds[i]
        # Convert to a tensor image and rescale by 1/255
        # (assumes 8-bit pixel values -- TODO confirm for non-PIL inputs).
        images.append(TF.to_image(image).to(torch.float32) / 255.0)
        labels.append(label)
    return images, labels

In [None]:
# Collect the first 16 samples and show them in a captioned collage.
images, labels = get_images_and_labels_from_ds(ds_train, num_images_to_fetch=16)
captions = [f"Label: {label}" for label in labels]

fig, ax = visualize.plot_square_collage_with_captions(
    images, captions, global_normalize=True
)
plt.tight_layout()

**Question**: When is a pattern label=0, when is it label=1?

<details>
<summary>Click to reveal answer</summary>

The pattern is labeled as **label=0** when it's **without errors** (a perfect grid / line pattern), and **label=1** when it contains **errors** (off-line pixels, up to `max_errors=3` deviations from the perfect pattern). The dataset simulates a quality inspection task where label=1 indicates a defective/faulty component.

</details>

## MLP Architecture

Next, we define an MLP.

In [None]:
class MLP(nn.Module):
    """Single-hidden-layer perceptron that classifies flattened images."""

    def __init__(
        self,
        num_hidden: int,
        num_classes: int,
        input_size: tuple[int, int, int] = (1, 28, 28),
    ):
        """
        Build the flatten -> hidden -> output stack.

        Args:
            num_hidden (int): Width of the hidden layer.
            num_classes (int): Number of logits produced by the output layer.
            input_size (tuple[int, int, int]): The three image dimensions; their
                product is the number of input features of the hidden layer.
        """
        super().__init__()
        dim_a, dim_b, dim_c = input_size[0], input_size[1], input_size[2]
        num_inputs = dim_a * dim_b * dim_c

        # Collapses everything after the batch dimension into one feature axis.
        self.flatten = nn.Flatten()

        # Fully connected: all input features -> num_hidden neurons.
        self.hidden = nn.Linear(
            in_features=num_inputs, out_features=num_hidden, bias=True
        )

        # Fully connected: hidden neurons -> one logit per class.
        self.output = nn.Linear(in_features=num_hidden, out_features=num_classes)

    def forward(self, x):
        """
        Compute class logits for a batch of images.

        Args:
            x (torch.Tensor): Input tensor (batch_size, channels, height, width).

        Returns:
            torch.Tensor: Logits of shape (batch_size, num_classes); no softmax
            is applied.
        """
        flat = self.flatten(x)  # (batch_size, product of the remaining dims)
        hidden_activations = F.relu(self.hidden(flat))
        return self.output(hidden_activations)

Let's initialize the model and inspect it using `torchinfo`.

In [None]:
# Seed the RNG before constructing the model so the initial weights are reproducible.
torch.manual_seed(123)
net = MLP(num_hidden=16, num_classes=2, input_size=(1, 16, 16))
print(net)
# Layer-by-layer summary for a single-sample batch of shape (1, 1, 16, 16).
print(torchinfo.summary(net, input_size=(1, 1, 16, 16)))

Now we define loss function and optimizer.

In [None]:
# Cross-entropy computed on the raw logits returned by the model.
loss_fn = nn.CrossEntropyLoss()
# Plain SGD over all model parameters, with weight decay as regularisation.
optimizer = torch.optim.SGD(net.parameters(), lr=1e-1, weight_decay=1e-3)

Next, we define the training dataset and dataloader.

In [None]:
# The only preprocessing step: convert each sample to a tensor.
image_transforms = transforms.Compose([transforms.ToTensor()])

# Large training set without spatial shift.
ds_train = pattern.PatternDataset(
    num_samples=100_000,
    seed=123,
    max_errors=3,
    max_x_y_shift=0,
    transform=image_transforms,
)

# Shuffled mini-batches of 128, loaded by 4 worker processes.
dl_train = torch.utils.data.DataLoader(
    dataset=ds_train, batch_size=128, shuffle=True, num_workers=4
)

## Model Training

That's it! We are ready to train our model.

In [None]:
# Train for a fixed number of passes over the training data.
total_epochs = 3
for epoch in range(total_epochs):
    print(f"Starting Epoch: {epoch + 1} / {total_epochs}")
    train_one_epoch(dl_train, net, optimizer, loss_fn, device=device)

## 2) Analyse MLP properties

Now it is getting interesting! 

We want to inspect the weights of the MLP.

We can access the layers of our model by the attributes that we defined.

For example we can access the attribute `hidden` of our `MLP`  object.

In [None]:
# Weight matrix of the hidden layer; nn.Linear stores it as (out_features, in_features).
weights = net.hidden.weight
weights.shape

The weight matrix $\mathbf{W}$ is multiplied with a flattened image $\mathbf{x}$ to produce the layer activations $\mathbf{a}$ (before the ReLU is applied): $\mathbf{a} = \mathbf{x} \mathbf{W}^T + \mathbf{b}$

We can see that the hidden layer has 16 neurons, each of which is connected to all input neurons.

Each row in $\mathbf{W}$ can be visualized as an image. We need to reshape it accordingly.

**Question**: What do you expect to see when visualizing the weights?

<details>
<summary>Click to reveal answer</summary>

You should expect to see **pattern-like structures** in the weight visualizations. Each neuron in the hidden layer learns to detect specific features or patterns in the input images. Since the task involves detecting errors in a grid pattern, the weights should resemble:
- Grid-like structures
- Edge detectors
- Pattern detectors for the expected/unexpected pixel positions

The weights essentially represent what each neuron is "looking for" in the input image. Neurons that activate strongly for specific patterns will have weights that visually resemble those patterns.

</details>

In [None]:
def visualize_weights(
    weights: torch.Tensor, figsize: tuple[int, int] = (12, 12), scale_each: bool = True
):
    """Show every row of a weight matrix as a square tile in one image grid.

    Args:
        weights: 2-D tensor of shape (num_neurons, dim); ``dim`` must be a
            perfect square so that each row can be reshaped into a square tile.
        figsize: Size of the matplotlib figure.
        scale_each: Forwarded to ``make_grid``: if True every tile is normalised
            on its own, otherwise all tiles share one (min, max) range.

    Raises:
        ValueError: If ``weights`` is not 2-D or ``dim`` is not a perfect square.
    """
    if weights.ndim != 2:
        raise ValueError(
            f"Expected a 2-D weight matrix, got shape {tuple(weights.shape)}."
        )

    # Layer parameters track gradients and may live on the GPU; plotting needs
    # neither, so work on a detached CPU copy.
    weights = weights.detach().cpu()

    num_neurons, dim = weights.shape

    # Exact integer square root avoids float rounding; reject non-square rows
    # explicitly instead of failing inside reshape.
    side_length = math.isqrt(dim)
    if side_length * side_length != dim:
        raise ValueError(
            f"Each weight row must contain a square number of values, got {dim}."
        )

    # One single-channel tile per neuron: (num_neurons, 1, side, side).
    tiles = weights.reshape(num_neurons, 1, side_length, side_length)

    # Roughly square grid layout: about sqrt(num_neurons) tiles per row.
    nrow = math.isqrt(num_neurons)
    image_grid = TF.to_pil_image(
        make_grid(tiles, nrow=nrow, normalize=True, scale_each=scale_each)
    )

    _, ax = plt.subplots(figsize=figsize)
    _ = ax.imshow(image_grid, cmap="Greys_r")
    _ = ax.axis("off")
    plt.show()


# scale_each=False: all tiles share one normalisation range, so magnitudes stay
# comparable across neurons.
visualize_weights(weights, scale_each=False)

**Question**: How do you interpret the weights? How does it match with your expectations?

<details>
<summary>Click to reveal answer</summary>

The weights show learned **feature detectors** that each neuron uses to identify patterns in the input. You should observe:

- **Grid-like patterns**: Some neurons have learned to detect the regular grid structure
- **Local patterns**: Weights highlight specific pixel configurations that distinguish label=0 from label=1
- **Complementary features**: Different neurons capture different aspects of the pattern (some may detect presence of grid, others detect absence/errors)

This matches expectations because the MLP learns to decompose the classification task into detecting multiple local features. Each hidden neuron specializes in recognizing a specific pattern aspect, and the output layer combines these features to make the final classification decision.

The visualization reveals that **MLPs learn interpretable features** when trained on structured data like grid patterns.

</details>

**Question**: What happens if we reduce the hidden layer to just 2 neurons?

<details>
<summary>Click to reveal answer</summary>

With only **2 neurons in the hidden layer**, the model has severely limited capacity:

- **Reduced representational power**: Only 2 feature detectors to capture all relevant patterns
- **Simpler learned features**: Each neuron must capture more general/coarse features since there are fewer neurons
- **Potentially lower accuracy**: The model may struggle to distinguish all cases, especially edge cases
- **More visible/interpretable weights**: With just 2 neurons, it's easier to see what each one has learned

The weights will show the **most important/discriminative features** that the model found for separating the two classes. The model is forced to learn a very compact representation, essentially finding the 2 most critical patterns needed for classification.

This demonstrates the importance of **model capacity** - too few neurons leads to underfitting, while too many might overfit.

</details>

In [None]:
# Retrain from scratch with only 2 hidden neurons (same seed, data and optimizer settings).
torch.manual_seed(123)
net = MLP(num_hidden=2, num_classes=2, input_size=(1, 16, 16))
dl_train = torch.utils.data.DataLoader(
    ds_train, batch_size=128, shuffle=True, num_workers=4
)
optimizer = torch.optim.SGD(net.parameters(), lr=1e-1, weight_decay=1e-3)
total_epochs = 5
for epoch in range(0, total_epochs):
    print(f"Starting Epoch: {epoch + 1} / {total_epochs}")
    # Pass the notebook-wide `device`, as the first training loop does, so every
    # experiment trains on the same hardware.
    # NOTE(review): assumes train_one_epoch places model and batches on `device` -- confirm.
    train_one_epoch(dl_train, net, optimizer, loss_fn, device=device)

In [None]:
# Hidden-layer weights of the 2-neuron model, one tile per neuron, jointly normalised.
visualize_weights(net.hidden.weight, scale_each=False)

## 3) What if?

What if we make it a bit more difficult and let the pattern randomly shift spatially?

In [None]:
# Same settings as before but with max_x_y_shift=1; no transform is set because
# the next cell only inspects the samples.
ds_train = pattern.PatternDataset(
    num_samples=100000, seed=123, max_errors=3, max_x_y_shift=1
)

In [None]:
# First 16 samples of the shifted dataset as a captioned collage.
images, labels = get_images_and_labels_from_ds(ds_train, num_images_to_fetch=16)
captions = [f"Label: {label}" for label in labels]

fig, ax = visualize.plot_square_collage_with_captions(
    images, captions, global_normalize=True
)
plt.tight_layout()

**Question**: Is this a more difficult problem? How will the weights differ from the simpler case?

<details>
<summary>Click to reveal answer</summary>

**Yes, this is a significantly more difficult problem!**

**Why it's harder:**
- **Spatial variation**: Patterns can now appear at different positions (up to ±1 pixel shift)
- **Loss of spatial structure**: The MLP cannot exploit spatial relationships - it treats each pixel independently
- **More data needed**: The model needs to learn that the same pattern at different positions has the same meaning

**How weights will differ:**
- **Less crisp patterns**: Weights become more blurred/diffuse since they need to account for multiple positions
- **More redundancy**: Multiple neurons may learn similar but shifted versions of the same pattern
- **Harder to interpret**: The clear grid structures may become less visible as neurons try to be position-invariant

This highlights a fundamental **limitation of MLPs for image data**: they lack **translation invariance**. This is why CNNs with their sliding convolutional kernels are much better suited for image tasks - they naturally handle spatial shifts.

</details>

In [None]:
# Fresh 16-neuron MLP trained on patterns generated with max_x_y_shift=1.
torch.manual_seed(123)
net = MLP(num_hidden=16, num_classes=2, input_size=(1, 16, 16))
ds_train = pattern.PatternDataset(
    num_samples=100000,
    seed=123,
    max_errors=3,
    max_x_y_shift=1,
    transform=image_transforms,
)
dl_train = torch.utils.data.DataLoader(
    ds_train, batch_size=128, shuffle=True, num_workers=4
)
optimizer = torch.optim.SGD(net.parameters(), lr=1e-1, weight_decay=1e-3)
total_epochs = 5
for epoch in range(0, total_epochs):
    print(f"Starting Epoch: {epoch + 1} / {total_epochs}")
    # Pass the notebook-wide `device`, as the first training loop does, so every
    # experiment trains on the same hardware.
    # NOTE(review): assumes train_one_epoch places model and batches on `device` -- confirm.
    train_one_epoch(dl_train, net, optimizer, loss_fn, device=device)

In [None]:
# Hidden-layer weights after training on shifted patterns, jointly normalised.
visualize_weights(net.hidden.weight, scale_each=False)

This is already a bit more difficult to interpret!

## 4) Let's try something more crazy!

We extend the images and place each pattern in a random quadrant.

In [None]:
from dl_cv_lectures.transform import RandomQuadrantPad

torch.manual_seed(123)
# Tensor round trip: pad in tensor space, then convert back to a PIL image so the
# samples have the form get_images_and_labels_from_ds converts and rescales.
quadrant_preview_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        RandomQuadrantPad(),
        transforms.Lambda(TF.to_pil_image),
    ]
)
ds_train = pattern.PatternDataset(
    num_samples=100000,
    seed=123,
    max_errors=3,
    max_x_y_shift=0,
    transform=quadrant_preview_transform,
)

In [None]:
# First 16 quadrant-padded samples as a captioned collage.
images, labels = get_images_and_labels_from_ds(ds_train, num_images_to_fetch=16)
captions = [f"Label: {label}" for label in labels]

fig, ax = visualize.plot_square_collage_with_captions(
    images, captions, global_normalize=True
)
plt.tight_layout()

**Question**: What do you think happens here?

<details>
<summary>Click to reveal answer</summary>

The `RandomQuadrantPad()` transformation is **randomly placing the 16×16 pattern into one of the four quadrants** of a larger 32×32 image:

- **Input**: 16×16 pattern image
- **Output**: 32×32 image with the pattern in top-left, top-right, bottom-left, or bottom-right quadrant
- **Other quadrants**: Filled with zeros (black)

**Purpose:**
This creates **extreme spatial variation** to test the MLP's ability to handle position-dependent features. The pattern can now appear in 4 completely different locations, making the task much harder for an MLP.

**Impact on MLP:**
- The MLP input size increases from 256 (16×16) to 1024 (32×32) pixels
- Many more weights to learn
- The model must learn 4 different "versions" of the same pattern at different positions
- This further demonstrates why **MLPs are not well-suited for spatially-varying image data**

This sets up a comparison with CNNs, which handle such spatial variations much more naturally through their translation-equivariant convolutions.

</details>

In [None]:
# Fresh MLP for the enlarged (1, 32, 32) inputs produced by RandomQuadrantPad.
torch.manual_seed(123)
net = MLP(num_hidden=16, num_classes=2, input_size=(1, 16 * 2, 16 * 2))
ds_train = pattern.PatternDataset(
    num_samples=100000,
    seed=123,
    max_errors=3,
    max_x_y_shift=0,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            RandomQuadrantPad(),
        ]
    ),
)
dl_train = torch.utils.data.DataLoader(
    ds_train, batch_size=128, shuffle=True, num_workers=4
)
optimizer = torch.optim.SGD(net.parameters(), lr=1e-1, weight_decay=1e-3)
total_epochs = 5
for epoch in range(0, total_epochs):
    print(f"Starting Epoch: {epoch + 1} / {total_epochs}")
    # Pass the notebook-wide `device`, as the first training loop does, so every
    # experiment trains on the same hardware.
    # NOTE(review): assumes train_one_epoch places model and batches on `device` -- confirm.
    train_one_epoch(dl_train, net, optimizer, loss_fn, device=device)

In [None]:
# Hidden-layer weights of the quadrant model: each tile now covers the enlarged input.
visualize_weights(net.hidden.weight, scale_each=False)

## 5) What happens if we increase the positional uncertainty?

Let's randomly shift the pattern by up to 3 pixels.

In [None]:
torch.manual_seed(123)
# Tensor round trip: pad in tensor space, then convert back to a PIL image so the
# samples have the form get_images_and_labels_from_ds converts and rescales.
shifted_preview_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        RandomQuadrantPad(),
        transforms.Lambda(TF.to_pil_image),
    ]
)
ds_train = pattern.PatternDataset(
    num_samples=100000,
    seed=123,
    max_errors=3,
    max_x_y_shift=3,
    transform=shifted_preview_transform,
)

# First 16 samples as a captioned collage.
images, labels = get_images_and_labels_from_ds(ds_train, num_images_to_fetch=16)
captions = [f"Label: {label}" for label in labels]

fig, ax = visualize.plot_square_collage_with_captions(
    images, captions, global_normalize=True
)
plt.tight_layout()

We switch to the Adam optimizer, which often converges much faster and requires less careful tuning of the learning rate.

In [None]:
# Fresh MLP on quadrant-padded patterns with max_x_y_shift=3, optimised with
# Adam at its default settings.
torch.manual_seed(123)
net = MLP(num_hidden=16, num_classes=2, input_size=(1, 16 * 2, 16 * 2))
ds_train = pattern.PatternDataset(
    num_samples=100000,
    seed=123,
    max_errors=3,
    max_x_y_shift=3,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            RandomQuadrantPad(),
        ]
    ),
)
dl_train = torch.utils.data.DataLoader(
    ds_train, batch_size=128, shuffle=True, num_workers=4
)
optimizer = torch.optim.Adam(net.parameters())
total_epochs = 5
for epoch in range(0, total_epochs):
    print(f"Starting Epoch: {epoch + 1} / {total_epochs}")
    # Pass the notebook-wide `device`, as the first training loop does, so every
    # experiment trains on the same hardware.
    # NOTE(review): assumes train_one_epoch places model and batches on `device` -- confirm.
    train_one_epoch(dl_train, net, optimizer, loss_fn, device=device)

**Question**: What do you observe?

<details>
<summary>Click to reveal answer</summary>

With the increased positional uncertainty (up to ±3 pixel shift AND random quadrant placement), you should observe:

**Weight visualization insights:**
- **Very blurred/diffuse patterns**: Weights no longer show clear grid structures
- **Spread-out activations**: Each neuron tries to capture patterns across a wider spatial area
- **Loss of interpretability**: Hard to see what each neuron has learned
- **Potentially four repeated patterns**: Some neurons might learn similar features positioned for different quadrants

**Training characteristics:**
- **Slower convergence**: Model takes longer to learn due to increased complexity
- **Potentially lower accuracy**: Even with more neurons, the MLP struggles with this level of spatial variation
- **More training data needed**: The 100,000 samples help, but the problem is fundamentally challenging for MLPs

**Key insight:**
This demonstrates the **fundamental limitation of MLPs for images with spatial variation**. The fully-connected architecture cannot efficiently learn position-invariant features. This motivates the need for **Convolutional Neural Networks (CNNs)**, which are designed to handle spatial variations through:
- Local receptive fields
- Weight sharing
- Translation equivariance

</details>

In [None]:
# Hidden-layer weights of the Adam-trained model (quadrant padding, max_x_y_shift=3).
visualize_weights(net.hidden.weight, scale_each=False)

## 6) A Peek Ahead: What happens if we choose a CNN?

Finally! Let's ditch the MLP.

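The following CNN is deeper but, in the configuration used below (`kernel_size=5`, `num_filters=2`), has roughly 15 times fewer parameters than the MLP on 32×32 inputs (about 1.1k versus about 16.4k).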

In [None]:
import torch
import torch.nn as nn


class SmallCNN(nn.Module):
    """Three convolutional layers, global average pooling and a linear classifier.

    The channel widths are ``num_filters``, ``2 * num_filters`` and
    ``4 * num_filters``; the network expects single-channel input images.
    """

    def __init__(self, kernel_size: int, num_classes=10, num_filters=16):
        super().__init__()

        # kernel_size // 2 pixels of padding per side; with the default stride of 1
        # this keeps height and width unchanged for odd kernel sizes.
        same_padding = kernel_size // 2
        widths = (num_filters, num_filters * 2, num_filters * 4)

        # Stage 1: 1 input channel -> num_filters feature maps.
        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=widths[0],
            kernel_size=kernel_size,
            padding=same_padding,
        )

        # Stage 2: num_filters -> 2 * num_filters feature maps.
        self.conv2 = nn.Conv2d(
            in_channels=widths[0],
            out_channels=widths[1],
            kernel_size=kernel_size,
            padding=same_padding,
        )

        # Stage 3: 2 * num_filters -> 4 * num_filters feature maps.
        self.conv3 = nn.Conv2d(
            in_channels=widths[1],
            out_channels=widths[2],
            kernel_size=kernel_size,
            padding=same_padding,
        )

        # 2x2 max pooling, applied after stages 2 and 3.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Global average pooling: one value per feature map.
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

        # Classifier on the pooled features.
        self.fc = nn.Linear(in_features=widths[2], out_features=num_classes)

    def forward(self, x):
        """Return logits of shape (batch_size, num_classes) for a batch of images."""
        stage1 = F.relu(self.conv1(x))
        stage2 = self.pool(F.relu(self.conv2(stage1)))
        stage3 = self.pool(F.relu(self.conv3(stage2)))

        # (batch_size, 4 * num_filters, 1, 1) -> (batch_size, 4 * num_filters)
        pooled = torch.flatten(self.gap(stage3), 1)

        return self.fc(pooled)

In [None]:
# Instantiate the small CNN only to inspect its layers and parameter count
# for a single (1, 1, 32, 32) input.
net = SmallCNN(kernel_size=5, num_classes=2, num_filters=2)
print(torchinfo.summary(net, input_size=(1, 1, 16 * 2, 16 * 2)))

In [None]:
# Seeded CNN trained on the same data configuration as the last MLP experiment.
torch.manual_seed(123)
net = SmallCNN(kernel_size=5, num_classes=2, num_filters=2)

cnn_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        RandomQuadrantPad(),
    ]
)
ds_train = pattern.PatternDataset(
    num_samples=100000,
    seed=123,
    max_errors=3,
    max_x_y_shift=3,
    transform=cnn_transform,
)
dl_train = torch.utils.data.DataLoader(
    dataset=ds_train, batch_size=128, shuffle=True, num_workers=6
)
optimizer = torch.optim.Adam(net.parameters())

total_epochs = 3
for epoch in range(total_epochs):
    print(f"Starting Epoch: {epoch + 1} / {total_epochs}")
    train_one_epoch(dl_train, net, optimizer, loss_fn)

The CNN uses only two filters in the first layer, which we can visualize:

In [None]:
# Learned kernels of the first convolutional layer (one per output channel).
ts.show(net.conv1.weight)

Now let's visualize what these filters actually detect when applied to input images. We'll look at the activations (feature maps) produced by each filter.

In [None]:
# Build a small dataset of 16 samples with the same settings as the CNN training data
n_sample_images = 16
sample_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        RandomQuadrantPad(),
    ]
)
ds_sample = pattern.PatternDataset(
    num_samples=n_sample_images,
    seed=123,
    max_errors=3,
    max_x_y_shift=3,
    transform=sample_transform,
)

# Stack the images into a single batch tensor; labels stay a plain list.
sample_images = torch.stack([ds_sample[idx][0] for idx in range(n_sample_images)])
sample_labels = [ds_sample[idx][1] for idx in range(n_sample_images)]

# Display the sample images
print(f"Sample images shape: {sample_images.shape}")
ts.show(sample_images)

### Visualize Activations from Filter 1

Apply the first convolutional layer and extract the activations from the first filter.

In [None]:
# Apply conv1 to get all feature maps.
# The batch is moved to whatever device the layer's weights live on, so the cell
# also works when the trained model sits on the GPU; results are brought back
# to the CPU for plotting.
conv1_device = net.conv1.weight.device
with torch.no_grad():
    conv1_output = net.conv1(sample_images.to(conv1_device))  # expected shape: (16, 2, 32, 32)
    conv1_relu = F.relu(conv1_output).cpu()

# Extract activations from first filter (filter 0); the 0:1 slice keeps the channel axis
filter1_activations = conv1_relu[:, 0:1, :, :]  # expected shape: (16, 1, 32, 32)

print(f"Filter 1 activations shape: {filter1_activations.shape}")
ts.show(filter1_activations)

### Visualize Activations from Filter 2

Now let's see what the second filter detects in the same images.

In [None]:
# Extract activations from second filter (channel index 1); the 1:2 slice keeps
# the channel axis so the result stays 4-D.
filter2_activations = conv1_relu[:, 1:2, :, :]  # Shape: (16, 1, 32, 32)

print(f"Filter 2 activations shape: {filter2_activations.shape}")
ts.show(filter2_activations)

### Comparison: Original Images, Filter 1, and Filter 2

Let's create a side-by-side comparison to see what each filter detects in a few example images.

In [None]:
# Side-by-side comparison for the first `num_to_show` sample images:
# one row per image, columns = original / filter 1 / filter 2.
num_to_show = 8

fig, axes = plt.subplots(num_to_show, 3, figsize=(12, num_to_show * 3))

for i in range(num_to_show):
    # (source batch, colormap, title) for the three columns of this row
    panels = (
        (sample_images, "gray", f"Original Image {i + 1}\nLabel: {sample_labels[i]}"),
        (filter1_activations, "viridis", "Filter 1 Activation"),
        (filter2_activations, "viridis", "Filter 2 Activation"),
    )
    for col, (batch, cmap_name, title) in enumerate(panels):
        # `_ =` keeps the returned matplotlib objects out of the cell output
        _ = axes[i, col].imshow(batch[i, 0].cpu(), cmap=cmap_name)
        _ = axes[i, col].set_title(title)
        _ = axes[i, col].axis("off")

plt.tight_layout()
plt.show()

print("\nObservation:")
print(
    "- Bright regions (yellow) in the activation maps show where the filter strongly responds"
)
print("- Dark regions (purple) show where the filter doesn't respond much")
print("- Each filter learns to detect different patterns in the input images")