<div style="display: flex; justify-content: space-between; align-items: center;">
    <div style="text-align: left; flex: 4">
        <strong>Author:</strong> Amirhossein Heydari | 
        📧 <a href="mailto:amirhosseinheydari78@gmail.com">amirhosseinheydari78@gmail.com</a> | 
        🐙 <a href="https://github.com/mr-pylin/pytorch-workshop" target="_blank" rel="noopener">github.com/mr-pylin</a>
    </div>
    <div style="text-align: right; flex: 1;">
        <a href="https://pytorch.org/" target="_blank" rel="noopener noreferrer">
            <img src="../assets/images/pytorch/logo/pytorch-logo-dark.svg" 
                 alt="PyTorch Logo"
                 style="max-height: 48px; width: auto; background-color: #ffffff; border-radius: 8px;">
        </a>
    </div>
</div>
<hr>


**Table of contents**<a id='toc0_'></a>    
- [Dependencies](#toc1_)    
- [Model Creation](#toc2_)    
  - [Create from Scratch](#toc2_1_)    
  - [Pre-defined](#toc2_2_)    
- [PyTorch Model Components](#toc3_)    
  - [Model Children](#toc3_1_)    
  - [Model Modules](#toc3_2_)    
  - [Model Parameters](#toc3_3_)    
- [Feature Extraction](#toc4_)    
  - [Method 1: Modify the Model (Forward Truncation)](#toc4_1_)    
    - [Top-level Feature Extraction](#toc4_1_1_)    
    - [Hierarchical Feature Extraction](#toc4_1_2_)    
  - [Method 2: Using `nn.Module` hooks](#toc4_2_)    
    - [Custom Function](#toc4_2_1_)    
    - [Custom Class](#toc4_2_2_)    
  - [Method 3: Using `torchvision.models.feature_extraction` API (Recommended)](#toc4_3_)    
    - [A Complete Example](#toc4_3_1_)    
      - [Load Dataset](#toc4_3_1_1_)    
      - [Load Pre-trained Model](#toc4_3_1_2_)    
      - [Extract Feature Maps](#toc4_3_1_3_)    
        - [layer1 Feature Maps](#toc4_3_1_3_1_)    
        - [layer4.1.conv2 Feature Maps](#toc4_3_1_3_2_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

# <a id='toc1_'></a>[Dependencies](#toc0_)


In [None]:
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms import v2

In [None]:
# disable automatic figure display (plt.show() required)
# this ensures consistency with .py scripts and gives full control over when plots appear
plt.ioff()

In [None]:
# set a seed for deterministic results
seed = 42
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# check if cuda is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# log
device

In [None]:
DATASET_DIR = "../datasets"

# <a id='toc2_'></a>[Model Creation](#toc0_)


## <a id='toc2_1_'></a>[Create from Scratch](#toc0_)


In [None]:
class CustomModel(nn.Module):

    def __init__(self):
        super().__init__()

        # first block
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        # second block
        self.block2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
        )

        # classifier
        self.classifier = nn.Linear(32 * 8 * 8, 10)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

In [None]:
model_1 = CustomModel()

In [None]:
model_1

## <a id='toc2_2_'></a>[Pre-defined](#toc0_)


In [None]:
model_2 = resnet50(weights=None)

In [None]:
model_2

# <a id='toc3_'></a>[PyTorch Model Components](#toc0_)


## <a id='toc3_1_'></a>[Model Children](#toc0_)

- Refers to the **immediate submodules** of a PyTorch `nn.Module`.
- Accessed via `nn.Module.children()` or `nn.Module.named_children()`.
- Returns only the top-level submodules, not the modules nested inside them.
- Useful when you want to iterate over the main blocks of a model without recursing into them.
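A common use of `children()` is truncating a model by dropping its last child (for example, the classification head). The minimal sketch below assumes the model's `forward()` applies its children strictly in order, which holds for `nn.Sequential` but is not guaranteed for arbitrary models (e.g. ones that call functional ops such as `torch.flatten` or add residual connections inside `forward()`). The layer sizes are arbitrary.

```python
import torch
from torch import nn

# a small model whose forward() simply applies its children in order
# (this assumption is what makes children-based truncation valid)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

# keep every immediate child except the last one (the classifier head);
# the modules are shared with `model`, not copied
backbone = nn.Sequential(*list(model.children())[:-1])

features = backbone(torch.randn(2, 3, 32, 32))
print(features.shape)  # torch.Size([2, 8])
```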


In [None]:
for i, (name, child) in enumerate(model_1.named_children(), start=1):
    print(f"child: {i}")
    print(f"{name}: {child}")
    print("-" * 50)

## <a id='toc3_2_'></a>[Model Modules](#toc0_)

- Refers to all **submodules recursively**, including children, grandchildren, etc.
- Accessed via `nn.Module.modules()` or `nn.Module.named_modules()`.
- Includes the root module itself as the first element.
- Useful for global inspection of the full model hierarchy, e.g., for hooks or feature extraction.
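The difference between the two traversals is easiest to see on a small nested model (the layer sizes below are arbitrary). `named_children()` stops at the first level, while `named_modules()` walks the whole tree in pre-order, starting with the root under the empty name `""`. The dotted names it yields are the same ones used to address nested layers during feature extraction.

```python
from torch import nn

# a nested model: the outer Sequential holds an inner Sequential and a Linear
model = nn.Sequential(
    nn.Sequential(nn.Linear(4, 8), nn.ReLU()),
    nn.Linear(8, 2),
)

child_names = [name for name, _ in model.named_children()]
module_names = [name for name, _ in model.named_modules()]

print(child_names)   # ['0', '1']
print(module_names)  # ['', '0', '0.0', '0.1', '1']
```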


In [None]:
for i, (name, module) in enumerate(model_1.named_modules(), start=1):
    print(f"module: {i}")
    if name:
        print(f"{name}: {module}")
    else:
        print(module)
    print("-" * 50)

## <a id='toc3_3_'></a>[Model Parameters](#toc0_)

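- Refers to the **learnable tensors** (`nn.Parameter` objects) registered in the model, such as weights and biases.
- Accessed via `nn.Module.parameters()` or `nn.Module.named_parameters()`.
- `parameters()` yields the tensors themselves; `named_parameters()` yields `(name, tensor)` pairs.
- Buffers such as BatchNorm running statistics are not parameters; they are listed by `nn.Module.named_buffers()`.
- Useful for optimization, freezing layers, or inspecting parameter shapes.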
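Freezing layers is a typical use of `named_parameters()`. A minimal sketch on an arbitrary toy `nn.Sequential`, assuming you want to train only the last `Linear` layer (child `"2"`): parameters are frozen with `requires_grad_(False)` and only the trainable ones are handed to the optimizer.

```python
import torch
from torch import nn

model = nn.Sequential(
    nn.Linear(4, 8),  # child "0": backbone layer to freeze
    nn.ReLU(),        # child "1": no parameters
    nn.Linear(8, 2),  # child "2": head to keep trainable
)

# freeze every parameter that does not belong to the head
for name, param in model.named_parameters():
    if not name.startswith("2."):
        param.requires_grad_(False)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['2.weight', '2.bias']

# pass only the trainable parameters to the optimizer
optimizer = torch.optim.SGD((p for p in model.parameters() if p.requires_grad), lr=0.1)

num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_trainable)  # 18 (8 * 2 weights + 2 biases)
```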


In [None]:
for i, (name, param) in enumerate(model_1.named_parameters(), start=1):
    print(f"parameter: {i}")
    print(f"{name}: {param.shape}")
    print("-" * 50)

# <a id='toc4_'></a>[Feature Extraction](#toc0_)

- Each layer in a neural network transforms its input into a new representation.
- The output of any intermediate layer is called a **feature representation** (or activation).

**Why Do We Extract Features?**:

- Transfer Learning
  - Extracted features from a pretrained model serve as a general-purpose representation of the input.
  - You can reuse these features and train a new classifier on top instead of retraining the entire network.
  - This reduces training time, data requirements, and overfitting.
- Representation Learning
  - Intermediate features act as embeddings that encode semantic information about the inputs.
  - By extracting them, you can evaluate how well the model separates classes or captures meaningful structure.
  - This is useful for clustering, similarity analysis, and embedding evaluation.
- Debugging Models
  - Inspecting extracted features helps detect internal issues such as dead neurons, feature collapse, or saturation.
  - These problems may not be visible from loss or accuracy alone but can prevent proper learning.
  - Feature analysis provides insight into the model's internal behavior.
- Backdoor / Security Research
  - Backdoor triggers alter internal feature representations to force misclassification.
  - Extracting features allows you to detect abnormal representation shifts, clustering anomalies, and vulnerable layers.
  - This is useful for analyzing and defending against backdoor attacks.
- Knowledge Distillation
  - In feature-based knowledge distillation, the student model learns to reproduce the teacher's internal feature representations.
  - Extracted features act as supervision signals that transfer learned knowledge.
  - This can help a smaller, more efficient student approach the teacher's performance.
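The transfer-learning workflow described above (reuse frozen features, train a new classifier on top) can be sketched as a linear probe. This is a minimal sketch, assuming a small randomly initialised `backbone` as a hypothetical stand-in for a pretrained network and random tensors in place of real data; it performs a single optimization step.

```python
import torch
from torch import nn

torch.manual_seed(0)

# hypothetical frozen "backbone": stands in for a pretrained feature extractor
backbone = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
backbone.requires_grad_(False)
backbone.eval()

# new trainable classifier on top of the 16-dim features
head = nn.Linear(16, 10)
optimizer = torch.optim.SGD(head.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# random stand-in batch
x = torch.randn(8, 3, 32, 32)
y = torch.randint(0, 10, (8,))

# features are computed without building an autograd graph
with torch.no_grad():
    feats = backbone(x)

# one training step that only updates the head
loss = loss_fn(head(feats), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```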


## <a id='toc4_1_'></a>[Method 1: Modify the Model (Forward Truncation)](#toc0_)


### <a id='toc4_1_1_'></a>[Top-level Feature Extraction](#toc0_)


In [None]:
class TopLevelFeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers_to_extract: list[str]):
        super().__init__()
        self.model = model
        self.layers_to_extract = set(layers_to_extract)

    def forward(self, x):
        features = {}
        for name, child in self.model.named_children():
            if name == "classifier":
                break
            x = child(x)
            if name in self.layers_to_extract:
                features[name] = x
        return features

In [None]:
top_level_extractor = TopLevelFeatureExtractor(model_1, layers_to_extract=["block1", "block2"])
top_level_extractor

In [None]:
features = top_level_extractor(torch.randn(1, 3, 16, 16))

# log
for k, v in features.items():
    print(f"{k}: {v.shape}")

### <a id='toc4_1_2_'></a>[Hierarchical Feature Extraction](#toc0_)


In [None]:
class HierarchicalFeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers_to_extract: list[str]):
        super().__init__()
        self.model = model
        self.layers_to_extract = set(layers_to_extract)

    def forward(self, x):
        features = {}

        def _forward(module: nn.Module, input_x: torch.Tensor, prefix=""):
            out = input_x
            for name, child in module.named_children():
                full_name = f"{prefix}.{name}" if prefix else name

                # stop if we reach the classifier
                if full_name == "classifier":
                    return out

                # recursively process nested children first
                if list(child.children()):
                    out = _forward(child, out, prefix=full_name)
                else:
                    out = child(out)  # only apply leaf modules

                # save features if requested
                if full_name in self.layers_to_extract:
                    features[full_name] = out

            return out

        _forward(self.model, x)
        return features

In [None]:
hierarchical_extractor = HierarchicalFeatureExtractor(model_1, layers_to_extract=["block1", "block2"])
hierarchical_extractor

In [None]:
features = hierarchical_extractor(torch.randn(1, 3, 16, 16))

# log
for k, v in features.items():
    print(f"{k}: {v.shape}")
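Both extractors above re-run the model by calling its child modules in sequence instead of calling its `forward()`. That is only correct when `forward()` is exactly that sequence. The sketch below uses a hypothetical `ResidualBlock` to show how this breaks for a skip connection, which is why these classes would not return correct features for a ResNet; the hook-based and `torchvision` methods run the model's real `forward()` instead.

```python
import torch
from torch import nn


class ResidualBlock(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        # the skip connection lives in forward(), not in any child module
        return self.relu(self.conv(x)) + x


torch.manual_seed(0)
block = ResidualBlock(4)
x = torch.randn(1, 4, 8, 8)

true_out = block(x)

# what forward truncation does: apply the children one after another
replayed = x
for child in block.children():
    replayed = child(replayed)

print(torch.allclose(true_out, replayed))  # False: the "+ x" was lost
```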

## <a id='toc4_2_'></a>[Method 2: Using `nn.Module` hooks](#toc0_)


### <a id='toc4_2_1_'></a>[Custom Function](#toc0_)


In [None]:
# dictionary to store features
features_dict = {}

In [None]:
# hook function that saves features by layer name
def hook_fn(name):
    def fn(module, input, output):
        features_dict[name] = output

    return fn

In [None]:
# register hooks on multiple layers
hook1 = model_1.block1.register_forward_hook(hook_fn("block1"))
hook2 = model_1.block2.register_forward_hook(hook_fn("block2"))

In [None]:
_ = model_1(torch.randn(1, 3, 8, 8))

In [None]:
# remove hooks
hook1.remove()
hook2.remove()

In [None]:
# log
for k, v in features_dict.items():
    print(f"{k}: {v.shape}")
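The hook above stores `output` as-is, so each saved tensor stays attached to the autograd graph of the forward pass that produced it. A forward hook also receives the module's inputs, which helps when the representation you want comes from a functional call inside `forward()` (such as `torch.flatten` in `CustomModel`) and therefore has no module of its own. A minimal sketch covering both points, using a hypothetical `TinyNet`:

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical model whose flatten step is a functional call inside forward()."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.classifier = nn.Linear(4 * 8 * 8, 10)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = torch.flatten(x, start_dim=1)  # no module here, so nothing to hook directly
        return self.classifier(x)


captured = {}


def save_io(name):
    def fn(module, inputs, output):
        # `inputs` is a tuple of the positional arguments passed to the module;
        # detach() so the stored tensors do not keep the autograd graph alive
        captured[f"{name}_in"] = inputs[0].detach()
        captured[f"{name}_out"] = output.detach()

    return fn


model = TinyNet()
handle = model.classifier.register_forward_hook(save_io("classifier"))
loss = model(torch.randn(2, 3, 8, 8)).sum()
loss.backward()  # the hook returns None, so outputs and gradients are unchanged
handle.remove()

print(captured["classifier_in"].shape)  # torch.Size([2, 256])
print(captured["classifier_out"].shape)  # torch.Size([2, 10])
print(captured["classifier_in"].requires_grad)  # False
```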

### <a id='toc4_2_2_'></a>[Custom Class](#toc0_)


In [None]:
class FeatureHook:
    def __init__(self, modules: dict[str, nn.Module]):
        self.modules = modules
        self.features = {}
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        for name, module in self.modules.items():
            hook = module.register_forward_hook(self._make_hook(name))
            self.hooks.append(hook)

    def _make_hook(self, name):
        """
        Create a hook function for a specific layer.

        Args:
            name (str): The key under which the feature output of this layer
                will be stored in `self.features`.

        Returns:
            function: A hook function compatible with PyTorch's `register_forward_hook`.
        """

        def hook_fn(module, input, output):
            self.features[name] = output

        return hook_fn

    def remove(self):
        """
        Remove all registered hooks to prevent memory leaks.
        After calling this method, `self.hooks` is cleared.
        """
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

In [None]:
hook = FeatureHook(
    modules={
        "block1": model_1.block1,
        "block2": model_1.block2,
    },
)

In [None]:
_ = model_1(torch.randn(1, 3, 8, 8))

In [None]:
# log
for k, v in hook.features.items():
    print(f"{k}: {v.shape}")

In [None]:
hook.remove()
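Calling `remove()` manually is easy to skip, for instance when an exception is raised between registering the hooks and removing them. One option, sketched below as a hypothetical `FeatureHookContext`, is to implement the context-manager protocol so the hooks are removed automatically when the `with` block exits. Unlike `FeatureHook`, this sketch also detaches the stored outputs.

```python
import torch
from torch import nn


class FeatureHookContext:
    """Hypothetical variant of FeatureHook that removes its hooks on exit."""

    def __init__(self, modules: dict[str, nn.Module]):
        self.modules = modules
        self.features = {}
        self.handles = []

    def __enter__(self):
        for name, module in self.modules.items():
            self.handles.append(module.register_forward_hook(self._make_hook(name)))
        return self

    def __exit__(self, exc_type, exc, tb):
        # runs even if the forward pass raises, so no hook is left behind
        for handle in self.handles:
            handle.remove()
        self.handles = []
        return False  # do not suppress exceptions

    def _make_hook(self, name):
        def hook_fn(module, inputs, output):
            self.features[name] = output.detach()

        return hook_fn


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

with FeatureHookContext({"fc1": model[0], "fc2": model[2]}) as fh:
    model(torch.randn(5, 4))

print({k: tuple(v.shape) for k, v in fh.features.items()})  # {'fc1': (5, 8), 'fc2': (5, 2)}
```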

## <a id='toc4_3_'></a>[Method 3: Using `torchvision.models.feature_extraction` API (Recommended)](#toc0_)


In [None]:
return_nodes = {"layer4.1.bn3": "feature1", "avgpool": "feature2"}
feature_extractor = create_feature_extractor(model_2, return_nodes=return_nodes)

In [None]:
features = feature_extractor(torch.randn(1, 3, 224, 224))

# log
for k, v in features.items():
    print(f"{k}: {v.shape}")

### <a id='toc4_3_1_'></a>[A Complete Example](#toc0_)


#### <a id='toc4_3_1_1_'></a>[Load Dataset](#toc0_)


In [None]:
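transform = v2.Compose(
    [
        v2.ToImage(),
        v2.Resize((224, 224)),
        v2.ToDtype(dtype=torch.float32, scale=True),
        # ImageNet statistics, matching the pre-trained ResNet-50 weights used below
        v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)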

In [None]:
testset = CIFAR10(DATASET_DIR, train=False, transform=transform, download=False)
testloader = DataLoader(testset, batch_size=1, shuffle=False)

In [None]:
x, y = next(iter(testloader))

In [None]:
# plot
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(8, 4), layout="compressed")
fig.suptitle(f"Label: {y.item()} ({testset.classes[y.item()]})")
axs[0].imshow(testset.data[0])
axs[0].axis("off")
axs[0].set_title("Before Transform")
axs[1].imshow(x.detach().cpu()[0].permute(1, 2, 0).clamp(0, 1))
axs[1].axis("off")
axs[1].set_title("After Transform")
plt.show()

#### <a id='toc4_3_1_2_'></a>[Load Pre-trained Model](#toc0_)


In [None]:
# switch to eval mode so BatchNorm layers use their stored running statistics
pre_trained_resnet50 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).eval()

In [None]:
pre_trained_resnet50
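A freshly constructed model starts in training mode, where `BatchNorm` layers normalize with the statistics of the current batch (here a single image) and update their running statistics. The minimal sketch below, on a standalone `nn.BatchNorm2d` with arbitrary input statistics, shows why feature extraction from a pre-trained model should normally run in eval mode.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
x = torch.randn(1, 4, 7, 7) * 3 + 5  # batch statistics far from the initial running stats

# train mode: normalizes with this batch's statistics and updates the running stats
bn.train()
out_train = bn(x)
running_mean_after_train = bn.running_mean.clone()

# eval mode: normalizes with the stored running statistics and updates nothing
bn.eval()
with torch.no_grad():
    out_eval = bn(x)

print(torch.allclose(out_train, out_eval))  # False
print(torch.equal(bn.running_mean, running_mean_after_train))  # True
```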

#### <a id='toc4_3_1_3_'></a>[Extract Feature Maps](#toc0_)


In [None]:
return_nodes = {
    "layer1": "feature_maps_1",
    "layer4.1.conv2": "feature_maps_2",
}

In [None]:
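feature_extractor = create_feature_extractor(pre_trained_resnet50, return_nodes)
feature_extractor.eval()

# inference only: no autograd graph is needed for the extracted features
with torch.no_grad():
    features = feature_extractor(x)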
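For the representation-learning uses listed earlier (clustering, similarity analysis), a spatial feature map is often reduced to one vector per image. A minimal sketch, assuming global average pooling as the reduction and a random tensor shaped like `features["feature_maps_1"]` for two images (256 channels at 56x56 for a 224x224 input):

```python
import torch
import torch.nn.functional as F

# random stand-in for layer1 feature maps of two images: (N, C, H, W)
fmaps = torch.randn(2, 256, 56, 56)

# global average pooling over the spatial dims gives one C-dim vector per image
embeddings = fmaps.mean(dim=(2, 3))
print(embeddings.shape)  # torch.Size([2, 256])

# cosine similarity between the two images' embeddings, in [-1, 1]
similarity = F.cosine_similarity(embeddings[0:1], embeddings[1:2], dim=1)
print(similarity.shape)  # torch.Size([1])
```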

##### <a id='toc4_3_1_3_1_'></a>[layer1 Feature Maps](#toc0_)


In [None]:
# plot
nrows = ncols = int(features["feature_maps_1"].shape[1] ** 0.5)
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols, nrows), layout="compressed")
fig.suptitle("pre_trained_resnet50.layer1(x)")
for row in range(nrows):
    for col in range(ncols):
        axs[row, col].imshow(features["feature_maps_1"][0, row * ncols + col].detach().cpu(), cmap="gray")
        axs[row, col].axis("off")
        axs[row, col].set(title=row * ncols + col)
plt.show()
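The cell above derives a square grid from the channel count, while the next cell hard-codes a 16x32 grid for 512 channels. A hypothetical helper could choose the most square grid that fits the channel count exactly, which reproduces both layouts; `grid_shape` and `plot_feature_maps` are illustrative names, not part of any library.

```python
import math

import matplotlib.pyplot as plt
import torch


def grid_shape(num_channels: int) -> tuple[int, int]:
    """Return the most square (nrows, ncols) grid with nrows * ncols == num_channels."""
    for nrows in range(math.isqrt(num_channels), 0, -1):
        if num_channels % nrows == 0:
            return nrows, num_channels // nrows
    raise ValueError("num_channels must be positive")


def plot_feature_maps(fmap: torch.Tensor, title: str):
    """Plot every channel of a (C, H, W) feature map in a grid and return the figure."""
    nrows, ncols = grid_shape(fmap.shape[0])
    fig, axs = plt.subplots(nrows, ncols, figsize=(ncols, nrows), layout="compressed", squeeze=False)
    fig.suptitle(title)
    for idx, ax in enumerate(axs.flat):
        ax.imshow(fmap[idx].detach().cpu(), cmap="gray")
        ax.axis("off")
        ax.set_title(str(idx))
    return fig


print(grid_shape(256), grid_shape(512))  # (16, 16) (16, 32)
```

With this helper, each plotting cell reduces to a call such as `plot_feature_maps(features["feature_maps_1"][0], "pre_trained_resnet50.layer1(x)")` followed by `plt.show()`.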

##### <a id='toc4_3_1_3_2_'></a>[layer4.1.conv2 Feature Maps](#toc0_)


In [None]:
# plot
nrows = 16
ncols = 32
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols, nrows), layout="compressed")
fig.suptitle("pre_trained_resnet50.layer4[1].conv2(x)")
for row in range(nrows):
    for col in range(ncols):
        axs[row, col].imshow(features["feature_maps_2"][0, row * ncols + col].detach().cpu(), cmap="gray")
        axs[row, col].axis("off")
        axs[row, col].set(title=row * ncols + col)
plt.show()