# Model inspection in PyTorch

> This notebook introduces some functions to explore the internals of DNN models implemented in [PyTorch](https://pytorch.org/).

## 📓 Table of Contents

- [Load model](#Load)
- [Inspect modules](#Inspect-the-modules-of-the-model)
    - [modules()](#.modules())
    - [named_modules()](#.named_modules())
    - [{module-name}](#.{module-name})
    - [children()](#.children())
- [Inspect parameters](#Inspect-the-parameters-of-the-model)
    - [parameters()](#.parameters())
    - [named_parameters()](#named_parameters())
    - [{parameter_name}](#{parameter_name})
    - [state_dict()](#.state-dict())
- [Feature extraction](#Feature-extraction)
    - [hooks](#Hooks)

## Load

In [1]:
%%javascript
IPython.OutputArea.auto_scroll_threshold = 11

<IPython.core.display.Javascript object>

In [2]:
# import ipywidgets as widgets
# import matplotlib.pyplot as plt
# import numpy as np
# import random
# import seaborn as sns
# import skimage.io as io
# import torch


We will first load an example DNN to inspect. Let's import the pretrained [resnet50](https://pytorch.org/vision/stable/models.html#torchvision.models.resnet50) available in torchvision:

In [3]:
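from torchvision.models import resnet50

# Load ResNet-50 with ImageNet weights. Note: torchvision 0.13+ deprecates
# `pretrained=True` in favor of `weights=ResNet50_Weights.IMAGENET1K_V1`.
resnet_model = resnet50(pretrained=True)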

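Let's set the model to evaluation mode. Some layers behave differently during training and inference: in evaluation mode, dropout layers are disabled and batch normalization layers use their stored running statistics instead of the statistics of the current batch. This makes the outputs of the model deterministic for a given input: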

In [4]:
resnet_model.eval();
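
To make the effect of `eval()` concrete, here is a minimal sketch on a hypothetical toy model (`toy_model` is not part of this notebook). It only illustrates that `eval()` and `train()` flip the `training` flag of every submodule, and that dropout becomes a no-op in evaluation mode:

```python
import torch
from torch import nn

# Hypothetical toy model with a dropout layer (illustration only)
toy_model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

# eval() sets training=False on the model and on all of its submodules
toy_model.eval()
eval_flags = [module.training for module in toy_model.modules()]

# With dropout disabled, two forward passes on the same input agree
with torch.no_grad():
    same_in_eval = torch.equal(toy_model(x), toy_model(x))

# train() switches every submodule back to training mode
toy_model.train()
train_flags = [module.training for module in toy_model.modules()]

print(eval_flags, same_in_eval, train_flags)
```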

## Inspect the modules of the model

### `.modules()`

The method [`.modules()`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=modules#torch.nn.Module.modules) is a [generator](https://realpython.com/introduction-to-python-generators/) function that returns an iterator over all the modules of a DNN, including the nested ones.

In [5]:
resnet_model.modules()

<generator object Module.modules at 0x1095bd6d8>

The modules are yielded in depth-first order, following the tree structure of the model. The first element of the sequence is the entire model:

In [6]:
resnet_modules = list(resnet_model.modules())
print(resnet_modules[0])

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

The second element is the first submodule of the model:

In [7]:
print(resnet_modules[1])

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)


If a module contains other submodules, the parent module itself appears first in the sequence:

In [8]:
print(resnet_modules[5])

Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, ke

Its submodules then follow, visited recursively:

In [9]:
print(resnet_modules[6])

Bottleneck(
  (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (downsample): Sequential(
    (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)


In [10]:
print(resnet_modules[7])

Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [11]:
print(resnet_modules[8])

BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
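
The traversal order is easier to see on a small model. The sketch below uses a hypothetical two-level `toy_model` (not part of this notebook) and lists the class name of every module that `.modules()` yields:

```python
from torch import nn

# Hypothetical nested model: a Sequential that contains another Sequential
toy_model = nn.Sequential(
    nn.Sequential(nn.Linear(2, 2), nn.ReLU()),
    nn.Linear(2, 1),
)

# The whole model comes first, then each submodule in depth-first order
module_types = [type(module).__name__ for module in toy_model.modules()]
print(module_types)
# ['Sequential', 'Sequential', 'Linear', 'ReLU', 'Linear']
```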


### `.named_modules()`

We can also use the function `named_modules()` to get the module names along with the module itself:

In [12]:
resnet_named_modules = list(resnet_model.named_modules())
print(resnet_named_modules[1])

('conv1', Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False))


Each element of the output is a tuple, where the first element is the name and the second is the module:

In [13]:
print(resnet_named_modules[1][0])

conv1


In [14]:
print(resnet_named_modules[14][0])

layer1.0.downsample
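
Because `named_modules()` yields both the name and the module, it is convenient for selecting modules by type. A minimal sketch on a hypothetical toy model (the model and variable names are illustrative assumptions):

```python
from torch import nn

# Hypothetical toy model with one nested Sequential block
toy_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 16, kernel_size=3), nn.ReLU()),
)

# Keep only the names of the convolutional layers
conv_names = [
    name
    for name, module in toy_model.named_modules()
    if isinstance(module, nn.Conv2d)
]
print(conv_names)
# ['0', '3.0']

# The first entry is the model itself, whose name is the empty string
root_name = next(iter(toy_model.named_modules()))[0]
```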


### `.{module-name}`

Now that we know the name assigned to every module, we can also access the modules as attributes of the model, using their names:

In [15]:
resnet_model.conv1

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

Be aware that the submodules of a `Sequential` container are named with integers, so they have to be accessed by indexing rather than as attributes. For example, let's inspect `layer1` of the model:

In [16]:
resnet_model.layer1

Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(64, 64, ke

If we wanted to access the first submodule of the first bottleneck layer in `layer1`, running `resnet_model.layer1.0.conv1` would fail with a `SyntaxError`, because `0` is not a valid Python attribute name. Uncomment the cell below to inspect the error it would raise:

In [17]:
#resnet_model.layer1.0.conv1

We need to index the sequential part instead:

In [18]:
resnet_model.layer1[0].conv1

Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
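
If we would rather use the dotted names reported by `named_modules()` directly, two options are to look the name up in a dictionary built from `named_modules()`, or to call `get_submodule()`, which is available in PyTorch 1.9 and later. The sketch below uses a hypothetical toy model and checks that all three routes return the same object:

```python
from torch import nn

# Hypothetical nested toy model (illustration only)
toy_model = nn.Sequential(
    nn.Sequential(nn.Linear(4, 4), nn.ReLU()),
    nn.Linear(4, 2),
)

# Indexing the Sequential containers
by_index = toy_model[0][0]

# Resolving the dotted name (requires PyTorch >= 1.9)
by_name = toy_model.get_submodule("0.0")

# Looking the dotted name up in a dictionary of all named modules
by_lookup = dict(toy_model.named_modules())["0.0"]

print(by_index is by_name, by_index is by_lookup)
```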

### `.children()`

The method `.children()` can be used to access the first level of child modules (also called submodules):

In [19]:
resnet_children = list(resnet_model.children())
print(resnet_children[0])

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)


In [20]:
print(resnet_children[5])

Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(

If a child module contains child modules of its own, these are not yielded separately. Let's see what happens when we access the next element:

In [21]:
print(resnet_children[6])

Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Co

But we can always access these submodules by calling `.children()` on the child module:

In [22]:
second_level_children = list(resnet_children[6].children())
print(second_level_children[0])

Bottleneck(
  (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (downsample): Sequential(
    (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)


In [23]:
third_level_children = list(second_level_children[0].children())
print(third_level_children[0])

Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


We can also use `named_children()` to return both the modules and their names. Let's print the names of the first-level child modules:

In [24]:
# use list comprehension to only store the name of the children modules
resnet_children_names = [name for name, _ in resnet_model.named_children()]
print(resnet_children_names)

['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', 'fc']
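
Calling `named_children()` recursively gives a compact view of the whole module tree. The helper below, `describe_tree`, is a hypothetical function written for this illustration, applied to a small toy model:

```python
from torch import nn

def describe_tree(module, depth=0):
    """Return one indented line per descendant, visiting children recursively."""
    lines = []
    for name, child in module.named_children():
        lines.append(f"{'  ' * depth}{name}: {type(child).__name__}")
        lines.extend(describe_tree(child, depth + 1))
    return lines

# Hypothetical nested toy model (illustration only)
toy_model = nn.Sequential(
    nn.Sequential(nn.Linear(2, 2), nn.ReLU()),
    nn.Linear(2, 1),
)

tree_lines = describe_tree(toy_model)
print("\n".join(tree_lines))

# children() only sees the first level: two modules here
n_first_level = len(list(toy_model.children()))
```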


## Inspect the parameters of the model

### `.parameters()`

In pytorch, the learnable parameters of the model can be accessed by calling the generator function `.parameters()`.

In [25]:
resnet_parameters = list(resnet_model.parameters())
print(f"Number of parameters: {len(resnet_parameters)}")

Number of parameters: 161


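Note that each element is a whole tensor, so 161 is the number of parameter tensors rather than the number of individual weights. Let's inspect the contents of the first parameter: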

In [26]:
print(f"type: {type(resnet_parameters[0])}")
print(f"shape: {resnet_parameters[0].shape}")
print(f"content: {resnet_parameters[0]}")

type: <class 'torch.nn.parameter.Parameter'>
shape: torch.Size([64, 3, 7, 7])
content: Parameter containing:
tensor([[[[ 1.3335e-02,  1.4664e-02, -1.5351e-02,  ..., -4.0896e-02,
           -4.3034e-02, -7.0755e-02],
          [ 4.1205e-03,  5.8477e-03,  1.4948e-02,  ...,  2.2060e-03,
           -2.0912e-02, -3.8517e-02],
          [ 2.2331e-02,  2.3595e-02,  1.6120e-02,  ...,  1.0281e-01,
            6.2641e-02,  5.1977e-02],
          ...,
          [-9.0349e-04,  2.7767e-02, -1.0105e-02,  ..., -1.2722e-01,
           -7.6604e-02,  7.8453e-03],
          [ 3.5894e-03,  4.8006e-02,  6.2051e-02,  ...,  2.4267e-02,
           -3.3662e-02, -1.5709e-02],
          [-8.0029e-02, -3.2238e-02, -1.7808e-02,  ...,  3.5359e-02,
            2.2439e-02,  1.7077e-03]],

         [[-1.8452e-02,  1.1415e-02,  2.3850e-02,  ...,  5.3736e-02,
            4.4022e-02, -9.4675e-03],
          [-7.7273e-03,  1.8890e-02,  6.7981e-02,  ...,  1.5956e-01,
            1.4606e-01,  1.1999e-01],
          [-4.6013

Let's inspect the shapes of the first ten parameters:

In [27]:
parameters_size = [parameter.shape for parameter in resnet_parameters]
parameters_size[:10]

[torch.Size([64, 3, 7, 7]),
 torch.Size([64]),
 torch.Size([64]),
 torch.Size([64, 64, 1, 1]),
 torch.Size([64]),
 torch.Size([64]),
 torch.Size([64, 64, 3, 3]),
 torch.Size([64]),
 torch.Size([64]),
 torch.Size([256, 64, 1, 1])]
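
The length of the list returned by `.parameters()` counts tensors, not individual weights. To count the scalar values, we can sum `numel()` over the parameters. A minimal sketch on a hypothetical toy model:

```python
from torch import nn

# Hypothetical toy model: two linear layers with biases
toy_model = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 2))

# Number of parameter tensors: one weight and one bias per layer
n_tensors = len(list(toy_model.parameters()))

# Number of scalar values: (4*3 + 3) + (3*2 + 2)
n_values = sum(parameter.numel() for parameter in toy_model.parameters())

print(n_tensors, n_values)
# 4 23
```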

### `named_parameters()`

To more easily explore what each parameter represents, we can use `named_parameters()` which outputs both the name of the parameter and its contents. Let's print the name of the first ten parameters and their shape:

In [28]:
resnet_named_parameters = [f"{name}: {parameter.shape}" for name, parameter in resnet_model.named_parameters()]
resnet_named_parameters[:10]

['conv1.weight: torch.Size([64, 3, 7, 7])',
 'bn1.weight: torch.Size([64])',
 'bn1.bias: torch.Size([64])',
 'layer1.0.conv1.weight: torch.Size([64, 64, 1, 1])',
 'layer1.0.bn1.weight: torch.Size([64])',
 'layer1.0.bn1.bias: torch.Size([64])',
 'layer1.0.conv2.weight: torch.Size([64, 64, 3, 3])',
 'layer1.0.bn2.weight: torch.Size([64])',
 'layer1.0.bn2.bias: torch.Size([64])',
 'layer1.0.conv3.weight: torch.Size([256, 64, 1, 1])']

### `{parameter_name}`

If we know the parameter names, we can also access them directly as attributes:

In [29]:
resnet_model.conv1.weight

Parameter containing:
tensor([[[[ 1.3335e-02,  1.4664e-02, -1.5351e-02,  ..., -4.0896e-02,
           -4.3034e-02, -7.0755e-02],
          [ 4.1205e-03,  5.8477e-03,  1.4948e-02,  ...,  2.2060e-03,
           -2.0912e-02, -3.8517e-02],
          [ 2.2331e-02,  2.3595e-02,  1.6120e-02,  ...,  1.0281e-01,
            6.2641e-02,  5.1977e-02],
          ...,
          [-9.0349e-04,  2.7767e-02, -1.0105e-02,  ..., -1.2722e-01,
           -7.6604e-02,  7.8453e-03],
          [ 3.5894e-03,  4.8006e-02,  6.2051e-02,  ...,  2.4267e-02,
           -3.3662e-02, -1.5709e-02],
          [-8.0029e-02, -3.2238e-02, -1.7808e-02,  ...,  3.5359e-02,
            2.2439e-02,  1.7077e-03]],

         [[-1.8452e-02,  1.1415e-02,  2.3850e-02,  ...,  5.3736e-02,
            4.4022e-02, -9.4675e-03],
          [-7.7273e-03,  1.8890e-02,  6.7981e-02,  ...,  1.5956e-01,
            1.4606e-01,  1.1999e-01],
          [-4.6013e-02, -7.6075e-02, -8.9648e-02,  ...,  1.2108e-01,
            1.6705e-01,  1.7619e-01]

The output will be of class `torch.nn.parameter.Parameter`:

In [30]:
type(resnet_model.conv1.weight)

torch.nn.parameter.Parameter

If we only want the tensor storing the values of the parameter, we can obtain it through the `.data` attribute of the object:

In [31]:
parameter_data = resnet_model.conv1.weight.data
print(f"type: {type(parameter_data)}")
print(f"shape: {parameter_data.shape}")

type: <class 'torch.Tensor'>
shape: torch.Size([64, 3, 7, 7])


We can also find out if the computation of the gradient is required for that tensor:

In [32]:
resnet_model.conv1.weight.requires_grad

True
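
The `requires_grad` flag can also be changed, which is a common way to freeze part of a model. The sketch below is an illustration on a hypothetical toy model: it freezes every parameter except those of the last layer, selecting them by name:

```python
from torch import nn

# Hypothetical toy model; the last layer is named "2" inside the Sequential
toy_model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Disable gradient computation for everything except the last layer
for name, parameter in toy_model.named_parameters():
    if not name.startswith("2."):
        parameter.requires_grad_(False)

# List the parameters that would still be updated during training
trainable = [
    name for name, parameter in toy_model.named_parameters()
    if parameter.requires_grad
]
print(trainable)
# ['2.weight', '2.bias']
```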

### `.state dict()`

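PyTorch also provides a more convenient way to access this data: the method `state_dict()`, which returns an ordered dictionary that maps each name to its tensor. Besides the learnable parameters, it also includes persistent buffers, such as the running statistics of the batch normalization layers: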

In [33]:
resnet_state_dict = resnet_model.state_dict()

print(f"type: {type(resnet_state_dict)}")
print(list(resnet_state_dict.keys())[:10])

type: <class 'collections.OrderedDict'>
['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean']


We can then access the values of each parameter through its key:

In [34]:
resnet_state_dict["conv1.weight"]

tensor([[[[ 1.3335e-02,  1.4664e-02, -1.5351e-02,  ..., -4.0896e-02,
           -4.3034e-02, -7.0755e-02],
          [ 4.1205e-03,  5.8477e-03,  1.4948e-02,  ...,  2.2060e-03,
           -2.0912e-02, -3.8517e-02],
          [ 2.2331e-02,  2.3595e-02,  1.6120e-02,  ...,  1.0281e-01,
            6.2641e-02,  5.1977e-02],
          ...,
          [-9.0349e-04,  2.7767e-02, -1.0105e-02,  ..., -1.2722e-01,
           -7.6604e-02,  7.8453e-03],
          [ 3.5894e-03,  4.8006e-02,  6.2051e-02,  ...,  2.4267e-02,
           -3.3662e-02, -1.5709e-02],
          [-8.0029e-02, -3.2238e-02, -1.7808e-02,  ...,  3.5359e-02,
            2.2439e-02,  1.7077e-03]],

         [[-1.8452e-02,  1.1415e-02,  2.3850e-02,  ...,  5.3736e-02,
            4.4022e-02, -9.4675e-03],
          [-7.7273e-03,  1.8890e-02,  6.7981e-02,  ...,  1.5956e-01,
            1.4606e-01,  1.1999e-01],
          [-4.6013e-02, -7.6075e-02, -8.9648e-02,  ...,  1.2108e-01,
            1.6705e-01,  1.7619e-01],
          ...,
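
To illustrate the difference between `state_dict()` and `named_parameters()`, and how a state dict can be copied into another model of the same architecture, here is a minimal sketch. The factory `make_toy_model` and all variable names are hypothetical:

```python
import torch
from torch import nn

def make_toy_model():
    # Hypothetical toy architecture: a convolution followed by batch norm
    return nn.Sequential(
        nn.Conv2d(3, 4, kernel_size=3, bias=False),
        nn.BatchNorm2d(4),
    )

source_model = make_toy_model()

# Keys present in the state dict but not among the learnable parameters
parameter_names = [name for name, _ in source_model.named_parameters()]
state_keys = list(source_model.state_dict().keys())
buffer_only_keys = [key for key in state_keys if key not in parameter_names]
print(buffer_only_keys)

# Copy the state into a freshly initialized model with the same architecture
target_model = make_toy_model()
load_result = target_model.load_state_dict(source_model.state_dict())
weights_match = torch.equal(source_model[0].weight, target_model[0].weight)
print(load_result, weights_match)
```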
     

## Feature extraction

There are many ways to extract the representations of a DNN in response to an input in PyTorch. [This blogpost](https://pytorch.org/blog/FX-feature-extraction-torchvision/) gives a nice overview of the pros and cons of each one. In this notebook we will only explore those methods that don't involve modifying the source code of the model to obtain these internal representations.

### Hooks

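One way to record how the input is transformed across the network is to use the method `register_forward_hook`. A forward hook is a function that PyTorch calls every time the module it is attached to finishes its forward pass.

The hook receives three arguments: the module itself, a tuple with the inputs of the module, and the output of the module. We first define a function that, given a module name, builds a hook that stores the input and output of that module in two dictionaries under that name: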

In [35]:
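from collections import OrderedDict

# Dictionaries that map each module name to its recorded input and output
repr_input = OrderedDict()
repr_output = OrderedDict()

def get_representation(name):
    # Build a forward hook that stores the tensors under the given module name
    def hook(module, inputs, output):
        # `inputs` is a tuple of positional arguments; keep the first one
        repr_input[name] = inputs[0].detach()
        repr_output[name] = output.detach()
    return hook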

We then register a hook on every module of the model:

In [36]:
for name, layer in resnet_model.named_modules():
    layer.register_forward_hook(get_representation(name))
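
`register_forward_hook` returns a handle, and calling `remove()` on it detaches the hook. Keeping the handles is useful when the hooks are only needed for a few forward passes. The sketch below shows the full cycle on a hypothetical toy model; `toy_outputs` and `save_output` are illustrative names, not part of this notebook:

```python
import torch
from torch import nn

# Hypothetical toy model (illustration only)
toy_model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
toy_outputs = {}

def save_output(name):
    # Build a hook that stores the module output under the given name
    def hook(module, inputs, output):
        toy_outputs[name] = output.detach()
    return hook

# Register one hook per submodule (skip the root, whose name is "")
handles = [
    module.register_forward_hook(save_output(name))
    for name, module in toy_model.named_modules()
    if name
]

# A forward pass fills the dictionary
with torch.no_grad():
    toy_model(torch.randn(5, 4))
shapes = {name: tuple(tensor.shape) for name, tensor in toy_outputs.items()}
print(shapes)

# After removing the hooks, forward passes no longer record anything
for handle in handles:
    handle.remove()
toy_outputs.clear()
with torch.no_grad():
    toy_model(torch.randn(5, 4))
print(len(toy_outputs))
```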

In [37]:
from torchvision import transforms

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

preprocessing = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])
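
For reference, `Normalize` subtracts a per-channel mean and divides by a per-channel standard deviation. The sketch below reproduces that arithmetic with plain tensor operations on a random stand-in image (`fake_img` is a hypothetical placeholder for the output of `ToTensor()`), and adds the batch dimension that the model expects:

```python
import torch

# Same per-channel statistics as in the Normalize transform above,
# reshaped to broadcast over a C x H x W image
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Random stand-in for an image already converted to a tensor in [0, 1]
fake_img = torch.rand(3, 224, 224)

# Channel-wise normalization: (x - mean) / std
normalized = (fake_img - mean) / std

# The model expects a batch: N x C x H x W
batch = normalized.unsqueeze(0)
print(tuple(batch.shape))
# (1, 3, 224, 224)
```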

In [38]:
from PIL import Image

img = Image.open("./images/bird_0.jpg")

In [39]:
prepro_img = preprocessing(img)
prepro_img

tensor([[[0.5707, 0.6049, 0.6049,  ..., 0.6563, 0.6392, 0.6392],
         [0.5193, 0.5536, 0.5536,  ..., 0.6392, 0.6049, 0.6049],
         [0.4679, 0.4851, 0.5022,  ..., 0.6049, 0.5707, 0.5707],
         ...,
         [0.5707, 0.5193, 0.5536,  ..., 0.8789, 0.8618, 0.8789],
         [0.6563, 0.5878, 0.5878,  ..., 0.9132, 0.8961, 0.8961],
         [0.7419, 0.6734, 0.6734,  ..., 0.9303, 0.9474, 0.9303]],

        [[0.6954, 0.7304, 0.7304,  ..., 0.8004, 0.7829, 0.7829],
         [0.6429, 0.6779, 0.6779,  ..., 0.7829, 0.7479, 0.7479],
         [0.6078, 0.6254, 0.6429,  ..., 0.7479, 0.7129, 0.7129],
         ...,
         [0.7304, 0.6779, 0.6954,  ..., 0.9230, 0.9055, 0.9230],
         [0.8179, 0.7479, 0.7479,  ..., 0.9405, 0.9230, 0.9230],
         [0.9055, 0.8354, 0.8179,  ..., 0.9580, 0.9755, 0.9580]],

        [[0.8274, 0.8622, 0.8622,  ..., 0.8797, 0.8622, 0.8622],
         [0.7751, 0.8099, 0.8099,  ..., 0.8622, 0.8274, 0.8274],
         [0.7402, 0.7402, 0.7751,  ..., 0.8099, 0.7751, 0.

In [None]:
# Add a batch dimension before the forward pass: the model expects N x C x H x W
#resnet_model(prepro_img.unsqueeze(0))

In [None]:
# for layer in [name for name, _ in resnet_model.named_modules()][1:20]:
#     print(layer)
#     print(f"Input shape: {repr_input[layer].shape}")
#     print(f"Output shape: {repr_output[layer].shape}")

We can check how the representations flow through the model. For example, the first convolution of a bottleneck block and its `downsample` branch should receive the same input:

In [None]:
#torch.equal(repr_input["layer1.0.conv1"], repr_input["layer1.0.downsample"])

The input to the `downsample` branch is the same as the input to the whole block. Similarly, the output of a module should be the input of the next module:

In [None]:
#torch.equal(repr_output["conv1"], repr_input["bn1"])

There are also pre-forward hooks, which run before the forward pass of a module, and backward hooks, which run during the backward pass. Read more [here](https://medium.com/the-owl/using-forward-hooks-to-extract-intermediate-layer-outputs-from-a-pre-trained-model-in-pytorch-1ec17af78712).
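
As a minimal sketch of a pre-forward hook, the example below uses `register_forward_pre_hook` on a hypothetical single layer. The hook receives the module and a tuple with its inputs, before the forward pass is computed; `record_input` and `seen_input_shapes` are illustrative names:

```python
import torch
from torch import nn

# Hypothetical single layer (illustration only)
toy_layer = nn.Linear(4, 2)
seen_input_shapes = []

def record_input(module, inputs):
    # `inputs` is a tuple with the positional arguments of forward();
    # returning None leaves the input unchanged
    seen_input_shapes.append(tuple(inputs[0].shape))

pre_handle = toy_layer.register_forward_pre_hook(record_input)

with torch.no_grad():
    toy_layer(torch.zeros(3, 4))

# Detach the hook once it is no longer needed
pre_handle.remove()
print(seen_input_shapes)
# [(3, 4)]
```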

To understand how to add forward hooks to non-sequential modules only, read [this post](https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/).
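
One possible way to hook only the non-container modules is to treat a module as a leaf when it has no children, which is an assumption about what "non-sequential" means here. A sketch on a hypothetical toy model:

```python
import torch
from torch import nn

# Hypothetical nested toy model (illustration only)
toy_model = nn.Sequential(
    nn.Sequential(nn.Linear(4, 4), nn.ReLU()),
    nn.Linear(4, 2),
)
leaf_outputs = {}

def save_leaf_output(name):
    # Build a hook that stores the module output under the given name
    def hook(module, inputs, output):
        leaf_outputs[name] = output.detach()
    return hook

# Assumption: a module without children is a leaf (not a container)
leaf_handles = []
for name, module in toy_model.named_modules():
    is_leaf = next(module.children(), None) is None
    if is_leaf:
        leaf_handles.append(module.register_forward_hook(save_leaf_output(name)))

with torch.no_grad():
    toy_model(torch.zeros(1, 4))

# Only the leaves were recorded, in the order in which they ran
hooked_names = list(leaf_outputs)
print(hooked_names)

for handle in leaf_handles:
    handle.remove()
```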

## References / Resources

- [What is a state dict in pytorch?](https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html)
- [How can I load my best model as a feature extractor/evaluator?](https://discuss.pytorch.org/t/how-can-l-load-my-best-model-as-a-feature-extractor-evaluator/17254/47)