# 注册钩子函数（register_forward_hook）

pytorch 注册钩子函数（register_forward_hook）,实现对各个层（layer）输入输出 shape 的查看

参考链接：
+ https://www.bilibili.com/video/BV1WY411N7fR
+ https://github.com/chunhuizhang/bilibili_vlogs/tree/master/learn_torch/utils

In [5]:
import timm
import torch
from torch import nn


def print_shape(m, i, o):
    #m: module, i: input, o: output
    # print(m, i[0].shape, o.shape)
    print(m, i[0].shape, '=>', o.shape)


def get_children(model: nn.Module):
    # get children form model!（递归）
    children = list(model.children())
    flatt_children = []
    if children == []:
        # if model has no children; model is last child! :O
        return model
    else:
       # look for children from children... to the last child!
       for child in children:
            try:
                flatt_children.extend(get_children(child))
            except TypeError:
                flatt_children.append(get_children(child))
    return flatt_children


model_name = 'vgg11'
# model_name = 'resnet34'
model = timm.create_model(model_name, pretrained=True)

flatt_children = get_children(model)
for layer in flatt_children:
    layer.register_forward_hook(print_shape)

# for layer in model.children():
#     layer.register_forward_hook(print_shape)

# 4d: batch*channel*width*height
batch_input = torch.randn(4, 3, 300, 300)

model(batch_input)

model.safetensors:   0%|          | 0.00/531M [00:00<?, ?B/s]

Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) torch.Size([4, 3, 300, 300]) => torch.Size([4, 64, 300, 300])
ReLU(inplace=True) torch.Size([4, 64, 300, 300]) => torch.Size([4, 64, 300, 300])
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) torch.Size([4, 64, 300, 300]) => torch.Size([4, 64, 150, 150])
Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) torch.Size([4, 64, 150, 150]) => torch.Size([4, 128, 150, 150])
ReLU(inplace=True) torch.Size([4, 128, 150, 150]) => torch.Size([4, 128, 150, 150])
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) torch.Size([4, 128, 150, 150]) => torch.Size([4, 128, 75, 75])
Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) torch.Size([4, 128, 75, 75]) => torch.Size([4, 256, 75, 75])
ReLU(inplace=True) torch.Size([4, 256, 75, 75]) => torch.Size([4, 256, 75, 75])
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) torch.Size([4, 256, 75,

tensor([[-1.5516, -0.2116, -0.4501,  ..., -0.8903, -0.0109,  2.6101],
        [-1.2771, -0.0294, -0.0046,  ..., -0.9105,  0.0357,  2.0267],
        [-1.2822, -0.1448, -0.2532,  ..., -0.9623,  0.1757,  2.3927],
        [-1.2680, -0.0981, -0.0421,  ..., -0.8792, -0.0463,  2.1411]],
       grad_fn=<AddmmBackward0>)