# Segment 3
## Motivation
In this notebook, we explore neuron-level interpretability in a classic convolutional neural network: InceptionV1. Our guiding questions are:

- What kinds of visual patterns activate individual neurons?

- Are these patterns reflected in real images from a dataset?

- Can the synthetic visualizations help us build intuition about what natural images will activate strongly?

To answer these questions, we will:

1. Select neurons from an intermediate layer (`mixed4a`)

2. Generate activation-maximizing images for those neurons

3. Search a dataset for real images that strongly activate the same neurons

4. Visualize and analyze the resulting patterns

## Background Concepts
### What is a "neuron" in a CNN?

In this notebook, a "neuron" is one channel (feature map) of a convolutional layer, so its response to an image is a 2D map rather than a single number. The patterns a neuron responds to tend to grow more complex with depth: edges and orientations in early layers, textures and shapes in middle layers, and object parts or concepts in later layers.
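To make the channel-as-neuron convention concrete, here is a minimal sketch on a toy convolutional layer rather than InceptionV1. The layer sizes, the input size, and the channel index are arbitrary assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy layer (hypothetical sizes): 3 input channels -> 8 output channels
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)

# One RGB image of size 32x32, shaped (batch, channels, height, width)
image = torch.rand(1, 3, 32, 32)

with torch.no_grad():
    feature_maps = conv(image)  # shape: (1, 8, 32, 32)

# "Neuron 5" in this convention is channel 5 of the layer output:
# a 2D map with one response per spatial position
neuron_index = 5
neuron_map = feature_maps[0, neuron_index]  # shape: (32, 32)

# One summary number per image, here the spatial mean
neuron_score = neuron_map.mean().item()

print(feature_maps.shape, neuron_map.shape, neuron_score)
```

The spatial mean is the same summary the `Analyzer` class uses later to rank dataset images; the full map is what the activation-map plots show.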

### Why `mixed4a`?

InceptionV1's `mixed4a` layer (its name in Lucent's layer naming) sits roughly in the middle of the network, a sweet spot for studying emergent visual features. An earlier layer would mostly detect low-level edges, while a later layer would encode highly abstract object concepts.

## Imports and Reproducibility
We start by importing the required libraries, then fix all random seeds so that the results are reproducible.

In [3]:
import torch
import numpy as np
import random
import os
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
from PIL import Image

from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset

from lucent.optvis import render, param, transform, objectives
from lucent.modelzoo import inceptionv1

In [4]:
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
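If the seeds need to be reset several times (for example before each neuron), the three calls can be wrapped in a helper. The sketch below is an illustration only: `seed_everything` is a hypothetical name, and the cuDNN flags are an optional extra that the notebook itself does not set. Even with these settings, some GPU operations may remain non-deterministic.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int = 42) -> None:
    """Seed Python, NumPy and PyTorch (hypothetical helper, not part of the notebook)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU generator and any CUDA generators
    # Optional: prefer deterministic cuDNN kernels over the fastest ones
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Drawing after re-seeding should reproduce the same values
seed_everything(42)
first = (random.random(), np.random.rand(), torch.rand(1).item())

seed_everything(42)
second = (random.random(), np.random.rand(), torch.rand(1).item())

print(first == second)
```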

## The `Analyzer` Class
To keep the notebook clean and modular, we wrap the full pipeline in an `Analyzer` class. It loads the model, extracts activations, optimizes images, searches the dataset, and visualizes the results.

In [6]:
class Analyzer:
    """
    A utility class for neuron-level interpretability analysis
    in InceptionV1 using activation maximization and dataset search.
    """

    def __init__(self, target_layer, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.model = None
        self.target_layer = target_layer
        self.selected_neuron = None

        # Dictionary used to store activations captured by hooks
        self.activations = {}

        self.load_model()

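        # Preprocessing for dataset images. Lucent's InceptionV1 port expects
        # inputs in roughly [-117, 138] (render_vis applies x * 255 - 117),
        # so we use that scaling instead of torchvision's mean/std normalization.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * 255 - 117)
        ])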

    def load_model(self):
        """
        Loads a pretrained InceptionV1 model
        and sets it to evaluation mode.
        """
        # Use the device chosen in __init__ so that model and inputs always match
        self.model = inceptionv1(pretrained=True).to(self.device).eval()

    def set_selected_neuron(self, neuron: int):
        """
        Sets the neuron (channel index) to be analyzed.
        """
        self.selected_neuron = neuron

    def get_layer_info(self) -> Dict:
        """
        Retrieves the module corresponding to the target layer name.
        """
        target_module = None
        for name, module in self.model.named_modules():
            if name == self.target_layer:
                target_module = module
                break

        if target_module is None:
            raise ValueError(f"Layer {self.target_layer} not found in model")

        return {
            'name': self.target_layer,
            'module': target_module
        }

    # As we've seen in previous segments, activation maximization is a technique used to generate input images that maximize the activation of specific neurons within a neural network. 
    # This is typically done through gradient-based optimization, where we start with a random image and iteratively adjust it to increase the activation of the target neuron. 
    # The resulting images can provide insights into what features or patterns the neuron is responsive to, helping us understand the internal workings of the model.

    def activation_maximization(self, neuron: int, steps: int = 512, lr: float = 0.05) -> np.ndarray:
        """
        Generates an image that maximally activates a given neuron
        using gradient-based optimization.
        """
        self.set_selected_neuron(neuron)

        # Objective: maximize a specific channel in a specific layer
        obj = objectives.channel(self.target_layer, self.selected_neuron)

        # Small transformations prevent the optimization from producing high-frequency artifacts, and encourage more interpretable patterns.
        transform_list = [
            transform.pad(12),
            transform.jitter(8),
            transform.random_scale([0.9, 0.95, 1.05, 1.1]),
            transform.random_rotate([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]),
            transform.jitter(2)
        ]

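        # Using an FFT-based parameterization biases the optimization towards more natural-looking structures.
        # render_vis calls param_f with no arguments, so it must not take `self` or `lr`.
        def param_f():
            return param.image(224, fft=True, decorrelate=True)

        # The learning rate belongs to the optimizer, not to param.image
        def optimizer_f(params):
            return torch.optim.Adam(params, lr=lr)

        result = render.render_vis(
            self.model,
            obj,
            param_f=param_f,
            optimizer=optimizer_f,
            transforms=transform_list,
            thresholds=(steps,),
            show_inline=False
        )

        # render_vis returns one NumPy array per threshold,
        # shaped (batch, height, width, channels) with values in [0, 1]
        optimized_image = result[0]
        return optimized_image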
    
    # To measure neuron responses to real images, we attach a forward hook to the target layer.
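    def setup_activation_hook(self):
        """
        Registers a forward hook to capture activations
        from the target layer during inference.
        """
        # Register only once; repeated calls would otherwise stack duplicate hooks
        if getattr(self, "_hook_handle", None) is not None:
            return

        def hook_fn(module, input, output):
            self.activations[self.target_layer] = output.detach()

        # get_layer_info raises a ValueError if the layer name does not exist
        target_module = self.get_layer_info()['module']
        self._hook_handle = target_module.register_forward_hook(hook_fn)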

    def get_neuron_activation(self, image: torch.Tensor) -> float:
        """
        Computes the mean activation value of the selected neuron
        for a given image.
        """
        with torch.no_grad():
            if len(image.shape) == 3:
                image = image.unsqueeze(0)

            image = image.to(self.device)
            _ = self.model(image)

            layer_activation = self.activations[self.target_layer]
            neuron_activation = layer_activation[0, self.selected_neuron]

            # Average over spatial dimensions
            return neuron_activation.mean().item()
        
    # Instead of a single number, we can also inspect where the neuron fires spatially.
    def get_neuron_activation_map(self, image: torch.Tensor) -> np.ndarray:
        """
        Returns a normalized spatial activation map
        for the selected neuron.
        """
        with torch.no_grad():
            if len(image.shape) == 3:
                image = image.unsqueeze(0)

            image = image.to(self.device)
            _ = self.model(image)

            activation_map = self.activations[self.target_layer][0, self.selected_neuron]
            activation_map = activation_map.cpu().numpy()

            # Normalize for visualization
            if activation_map.max() > activation_map.min():
                activation_map = (activation_map - activation_map.min()) / (
                    activation_map.max() - activation_map.min()
                )

            return activation_map
        
    # We can now search the dataset to find real images that strongly activate the neuron.
    def find_highly_activating_images(
        self,
        num_samples: int = 1000,
        top_k: int = 10,
        split: str = "train",
        dataset_root: str = "data/imagenette2-320"
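    ) -> List[Tuple[Image.Image, float, np.ndarray]]:
        """
        Scans a random subset of the dataset and returns the top_k images
        as (PIL image, mean activation, activation map) tuples.
        """
        self.setup_activation_hook()
        split = "val" if split in ("validation", "valid", "test", "val") else "train"
        split_path = Path(dataset_root) / split

        # Lucent's InceptionV1 port expects inputs in roughly [-117, 138]
        # (render_vis applies x * 255 - 117), so we use the same scaling here
        # to keep dataset activations comparable with the optimized image.
        # Named `preprocess` to avoid shadowing lucent's `transform` module.
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * 255 - 117),
        ])

        dataset = ImageFolder(str(split_path), transform=preprocess)

        if len(dataset) > num_samples:
            indices = random.sample(range(len(dataset)), num_samples)
            dataset = Subset(dataset, indices)

        dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

        image_activations = []

        for image, _ in dataloader:
            activation_value = self.get_neuron_activation(image)
            activation_map = self.get_neuron_activation_map(image)

            # Undo the preprocessing so the image can be displayed and analyzed as RGB
            display_image = transforms.ToPILImage()(((image[0] + 117) / 255).clamp(0, 1))
            image_activations.append((display_image, activation_value, activation_map))

        image_activations.sort(key=lambda x: x[1], reverse=True)
        return image_activations[:top_k]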
    
    # This step performs simple descriptive analysis (color statistics, contrast heuristics and pattern labels)
    def analyze_patterns(self, optimized_image: np.ndarray,
                        top_images: List[Tuple[Image.Image, float, np.ndarray]]) -> Dict:

        analysis = {
            'neuron_id': self.selected_neuron,
            'layer': self.target_layer,
            'optimization_result': optimized_image,
            'top_activating_images': top_images,
            'patterns_detected': [],
            'color_analysis': {},
            'texture_analysis': {}
        }

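        # render_vis returns a float NumPy array of shape (1, H, W, 3) in [0, 1],
        # which ToPILImage does not accept. Convert it (or a (1, 3, H, W) tensor)
        # to an (H, W, 3) uint8 RGB array directly.
        if isinstance(optimized_image, torch.Tensor):
            optimized_image = optimized_image.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
        opt_img_array = (np.clip(np.squeeze(optimized_image), 0, 1) * 255).astype(np.uint8)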
        mean_color = np.mean(opt_img_array, axis=(0, 1))
        std_color = np.std(opt_img_array, axis=(0, 1))

        analysis['color_analysis']['optimized_image'] = {
            'mean_rgb': mean_color.tolist(),
            'std_rgb': std_color.tolist()
        }

        top_image_colors = []
        for img, activation, _ in top_images[:5]:
            img_array = np.array(img)
            img_mean = np.mean(img_array, axis=(0, 1))
            top_image_colors.append(img_mean.tolist())

        analysis['color_analysis']['top_images'] = top_image_colors

        patterns = []
        if np.std(opt_img_array) > 50:
            patterns.append("High contrast/texture")
        if np.mean(mean_color) > 150:
            patterns.append("Bright colors")
        elif np.mean(mean_color) < 100:
            patterns.append("Dark colors")

        analysis['patterns_detected'] = patterns

        return analysis

    # This step creates visualizations in a structured manner.
    def visualize_results(self, analysis: Dict, save_dir: str = 'results'):

        os.makedirs(save_dir, exist_ok=True)

        top_images = analysis['top_activating_images']
        n_images = min(len(top_images), 10)

        # squeeze=False keeps `axes` two-dimensional even when there is only one row
        fig, axes = plt.subplots(n_images + 1, 2, figsize=(15, 4 * (n_images + 1)), squeeze=False)
        fig.suptitle(f'Neuron {self.selected_neuron} Analysis - Layer {self.target_layer}',
                    fontsize=16, fontweight='bold')

        opt_img = analysis['optimization_result']
        if isinstance(opt_img, torch.Tensor):
            opt_img_display = opt_img.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
        else:
            opt_img_display = opt_img.squeeze() if opt_img.ndim == 4 else opt_img
            if opt_img_display.ndim == 3 and opt_img_display.shape[0] == 3:
                opt_img_display = np.transpose(opt_img_display, (1, 2, 0))

        opt_img_display = np.clip(opt_img_display, 0, 1)

        axes[0, 0].imshow(opt_img_display)
        axes[0, 0].set_title('Activation Maximization\nResult', fontweight='bold')
        axes[0, 0].axis('off')
        axes[0, 1].axis('off')

        for i, (img, activation, activation_map) in enumerate(top_images[:n_images]):
            row = i + 1

            axes[row, 0].imshow(img)
            axes[row, 0].set_title(f'Top {i+1} Image\nActivation: {activation:.3f}', fontsize=10)
            axes[row, 0].axis('off')

            im = axes[row, 1].imshow(activation_map, cmap='viridis', interpolation='nearest')
            axes[row, 1].set_title(f'Activation Map\nMax: {activation_map.max():.3f}', fontsize=10)
            axes[row, 1].axis('off')

        plt.tight_layout()
        plt.savefig(f'{save_dir}/neuron_{self.selected_neuron}_analysis_with_activations.png',
                   dpi=300, bbox_inches='tight')
        plt.show()

    # Finally, we tie everything together
    def run_complete_analysis(self,
                              neuron: int = 0,
                              num_samples: int = 1000,
                              top_k: int = 10,
                              optim_steps: int = 512) -> Dict:

        print("\nStep 1: Performing activation maximization...")
        optimized_image = self.activation_maximization(neuron, steps=optim_steps)

        print("\nStep 2: Finding highly activating dataset images...")
        top_images = self.find_highly_activating_images(
            num_samples=num_samples,
            top_k=top_k
        )

        print("\nStep 3: Analyzing patterns and features...")
        analysis = self.analyze_patterns(optimized_image, top_images)

        print("\nStep 4: Creating visualizations...")
        self.visualize_results(analysis, save_dir="activating0")

        return analysis
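The forward-hook mechanism used by `setup_activation_hook` is easier to see in isolation. The following is a minimal sketch on a toy `nn.Sequential` model, not on InceptionV1; the architecture, the layer name `"1"`, and the dictionary key `"layer_1"` are assumptions made for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a pretrained network (hypothetical architecture)
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 6, kernel_size=3, padding=1),
).eval()

captured = {}


def save_output(module, inputs, output):
    # Called right after the hooked module's forward pass
    captured["layer_1"] = output.detach()


# Look the layer up by name, as get_layer_info does; "1" is the ReLU here
target = dict(model.named_modules())["1"]
handle = target.register_forward_hook(save_output)

with torch.no_grad():
    _ = model(torch.rand(1, 3, 16, 16))

captured_shape = tuple(captured["layer_1"].shape)

# Mean activation of channel 2 at the hooked layer
channel_mean = captured["layer_1"][0, 2].mean().item()

# Removing the handle detaches the hook; later forward passes capture nothing
handle.remove()
captured.clear()
with torch.no_grad():
    _ = model(torch.rand(1, 3, 16, 16))
hook_still_active = "layer_1" in captured

print(captured_shape, channel_mean, hook_still_active)
```

Keeping the handle returned by `register_forward_hook` makes it possible to detach the hook later, which matters when the same model is analyzed repeatedly in a loop.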


## Example with one neuron

In [None]:
analyzer = Analyzer("mixed4a")

results = analyzer.run_complete_analysis(
    neuron=1,
    num_samples=500,
    top_k=10,
    optim_steps=512
)
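Lucent hides the optimization loop inside `render_vis`. To build intuition for what step 1 does, here is a stripped-down sketch of activation maximization as plain gradient ascent on the input. It is not Lucent's implementation: it uses a toy, untrained convolution in place of InceptionV1, a sigmoid pixel parameterization in place of the FFT parameterization, and no robustness transforms. All sizes and names are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy, untrained layer standing in for the network up to the target layer
layer = nn.Conv2d(3, 8, kernel_size=5, padding=2).eval()
for p in layer.parameters():
    p.requires_grad_(False)

channel = 3

# Optimize unconstrained logits; the sigmoid keeps pixel values in [0, 1]
logits = torch.zeros(1, 3, 32, 32, requires_grad=True)
optimizer = torch.optim.Adam([logits], lr=0.05)


def channel_activation(pixels: torch.Tensor) -> torch.Tensor:
    # Objective: spatial mean of one channel of the layer output
    return layer(pixels)[0, channel].mean()


initial = channel_activation(torch.sigmoid(logits)).item()

for _ in range(100):
    optimizer.zero_grad()
    # Gradient ascent on the activation = gradient descent on its negative
    loss = -channel_activation(torch.sigmoid(logits))
    loss.backward()
    optimizer.step()

final = channel_activation(torch.sigmoid(logits)).item()
optimized_image = torch.sigmoid(logits).detach()

print(f"activation before: {initial:.4f}, after: {final:.4f}")
```

On a real network this naive loop tends to produce noisy, high-frequency images, which is the motivation for the transforms and the FFT parameterization used in `activation_maximization`.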

## What to Look For in the Results
As you explore different neurons, ask:

- Do the activation-maximization image and the real images share visual motifs?

- Is the neuron sensitive to color, texture, shape, or structure?

- Are activation maps localized or diffuse?
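The last question can also be checked numerically instead of by eye. One possible heuristic, sketched below, is the share of total activation held by the strongest 10% of spatial positions: values near 1 suggest a localized response, values near the chosen fraction suggest a diffuse one. The function name `top_fraction_mass`, the 10% cutoff, and the 14x14 map size are assumptions for illustration, not part of the notebook.

```python
import numpy as np


def top_fraction_mass(activation_map: np.ndarray, fraction: float = 0.1) -> float:
    """Share of total non-negative activation held by the top `fraction` of positions."""
    values = np.clip(activation_map, 0, None).ravel()
    total = values.sum()
    if total == 0:
        return 0.0
    k = max(1, int(round(fraction * values.size)))
    top_values = np.sort(values)[-k:]
    return float(top_values.sum() / total)


# Two synthetic maps: a single small peak vs. a uniform response
localized = np.zeros((14, 14))
localized[6:8, 6:8] = 1.0
diffuse = np.ones((14, 14))

localized_score = top_fraction_mass(localized)
diffuse_score = top_fraction_mass(diffuse)

print(f"localized: {localized_score:.3f}, diffuse: {diffuse_score:.3f}")
```

The function works on the min-max normalized maps returned by `get_neuron_activation_map` as well as on raw activations.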

The next cell finds the 10 most strongly activating images for each of channels 1 through 10 of the `mixed4a` layer. Each neuron is analyzed independently, and the results are stored in a dictionary for inspection.

For each image, try to identify why the neuron activated strongly and where in the image the pattern shows up!

In [None]:
# Lucent names this layer "mixed4a" (no underscore), as in the single-neuron example
analyzer = Analyzer(target_layer="mixed4a")

all_results = {}

for neuron_id in range(1, 11):
    print("=" * 60)
    print(f"Analyzing neuron {neuron_id} in layer {analyzer.target_layer}")
    print("=" * 60)

    analysis = analyzer.run_complete_analysis(
        neuron=neuron_id,
        num_samples=500,
        top_k=10,
        optim_steps=512
    )

    all_results[neuron_id] = analysis
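Once the loop has finished, a compact summary helps decide which neurons to inspect first. The sketch below assumes only the structure produced above, where each entry of `top_activating_images` is an `(image, activation, activation_map)` tuple. `summarize_results` is a hypothetical helper, and the numbers in `fake_results` are made up so that the example runs without the model.

```python
from typing import Dict, List, Tuple

import numpy as np


def summarize_results(all_results: Dict[int, dict]) -> List[Tuple[int, float, float]]:
    """Return (neuron_id, best activation, mean of top activations), strongest first."""
    rows = []
    for neuron_id, analysis in all_results.items():
        activations = [activation for _, activation, _ in analysis["top_activating_images"]]
        if not activations:
            continue
        rows.append((neuron_id, float(max(activations)), float(np.mean(activations))))
    return sorted(rows, key=lambda row: row[1], reverse=True)


# Synthetic stand-in for `all_results` (made-up numbers, for illustration only)
fake_results = {
    1: {"top_activating_images": [(None, 0.8, None), (None, 0.6, None)]},
    2: {"top_activating_images": [(None, 1.5, None), (None, 0.5, None)]},
}

summary = summarize_results(fake_results)
for neuron_id, best, mean_top in summary:
    print(f"neuron {neuron_id}: best={best:.3f}, mean top-k={mean_top:.3f}")
```

Raw activations of different channels are not necessarily on the same scale, so this ranking is a descriptive aid rather than a measure of which neuron matters more.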