# 注册钩子函数（register_forward_hook）

使用 PyTorch 的 `register_forward_hook` 注册前向钩子函数，查看各个层（layer）输入、输出的 shape。

参考链接：
+ https://www.bilibili.com/video/BV1WY411N7fR
+ https://github.com/chunhuizhang/bilibili_vlogs/tree/master/learn_torch/utils

In [1]:
import timm
import torch
from torch import nn


def print_shape(m, i, o):
    """Forward hook that prints a module with its input and output shapes.

    Args:
        m: The module the hook is attached to.
        i: The positional inputs passed to the module; only the first
            one is inspected.
        o: The module's output; it must expose a ``shape`` attribute.

    Returns nothing, so it only reports and hands no value back to the caller.
    """
    first_input_shape = i[0].shape
    output_shape = o.shape
    print(m, first_input_shape, '=>', output_shape)


def get_children(model: nn.Module) -> list:
    """Recursively collect the leaf modules of ``model`` in definition order.

    A leaf is any module whose ``children()`` iterator is empty. The result
    is always a flat list, so callers can iterate it unconditionally, even
    when ``model`` itself is a single layer.

    Args:
        model: The module to flatten.

    Returns:
        A list of leaf modules. If ``model`` has no children the list
        contains ``model`` itself.
    """
    children = list(model.children())
    if not children:
        # Wrap the leaf in a list so that every call returns the same type;
        # this also keeps empty containers (e.g. an empty nn.Sequential),
        # which would vanish if they were passed to list.extend() directly.
        return [model]

    leaves = []
    for child in children:
        # Each recursive call yields a list, so extend() is always valid and
        # no exception handling is needed to tell leaves from containers.
        leaves.extend(get_children(child))
    return leaves


# Name of the timm architecture to inspect; swap in the commented
# alternative below to look at a different network.
model_name = 'vgg11'
# model_name = 'resnet34'
model = timm.create_model(model_name, pretrained=True)

# Flatten the model into its leaf layers and attach the shape-printing
# hook to each of them, so every layer reports during the forward pass.
flatt_children = get_children(model)
for layer in flatt_children:
    layer.register_forward_hook(print_shape)

# Alternative: hook only the direct (top-level) children of the model
# instead of every leaf layer.
# for layer in model.children():
#     layer.register_forward_hook(print_shape)

# Random 4-D input batch: (batch, channel, height, width), following the
# input layout documented for PyTorch's Conv2d.
batch_input = torch.randn(4, 3, 300, 300)

# A single forward pass fires the registered hooks, printing one line per
# hooked layer.
model(batch_input)

Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) torch.Size([4, 3, 300, 300]) => torch.Size([4, 64, 300, 300])
ReLU(inplace=True) torch.Size([4, 64, 300, 300]) => torch.Size([4, 64, 300, 300])
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) torch.Size([4, 64, 300, 300]) => torch.Size([4, 64, 150, 150])
Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) torch.Size([4, 64, 150, 150]) => torch.Size([4, 128, 150, 150])
ReLU(inplace=True) torch.Size([4, 128, 150, 150]) => torch.Size([4, 128, 150, 150])
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) torch.Size([4, 128, 150, 150]) => torch.Size([4, 128, 75, 75])


Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) torch.Size([4, 128, 75, 75]) => torch.Size([4, 256, 75, 75])
ReLU(inplace=True) torch.Size([4, 256, 75, 75]) => torch.Size([4, 256, 75, 75])
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) torch.Size([4, 256, 75, 75]) => torch.Size([4, 256, 75, 75])
ReLU(inplace=True) torch.Size([4, 256, 75, 75]) => torch.Size([4, 256, 75, 75])
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) torch.Size([4, 256, 75, 75]) => torch.Size([4, 256, 37, 37])
Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) torch.Size([4, 256, 37, 37]) => torch.Size([4, 512, 37, 37])
ReLU(inplace=True) torch.Size([4, 512, 37, 37]) => torch.Size([4, 512, 37, 37])


Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) torch.Size([4, 512, 37, 37]) => torch.Size([4, 512, 37, 37])
ReLU(inplace=True) torch.Size([4, 512, 37, 37]) => torch.Size([4, 512, 37, 37])
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) torch.Size([4, 512, 37, 37]) => torch.Size([4, 512, 18, 18])
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) torch.Size([4, 512, 18, 18]) => torch.Size([4, 512, 18, 18])
ReLU(inplace=True) torch.Size([4, 512, 18, 18]) => torch.Size([4, 512, 18, 18])
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) torch.Size([4, 512, 18, 18]) => torch.Size([4, 512, 18, 18])
ReLU(inplace=True) torch.Size([4, 512, 18, 18]) => torch.Size([4, 512, 18, 18])
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) torch.Size([4, 512, 18, 18]) => torch.Size([4, 512, 9, 9])
Conv2d(512, 4096, kernel_size=(7, 7), stride=(1, 1)) torch.Size([4, 512, 9, 9]) => torch.Size([4, 4096, 3, 3])

Conv2d(4096, 4096, kernel_size=(1, 1), stride=(1, 1)) torch.Size([4, 4096, 3, 3]) => torch.Size([4, 4096, 3, 3])
ReLU(inplace=True) torch.Size([4, 4096, 3, 3]) => torch.Size([4, 4096, 3, 3])
AdaptiveAvgPool2d(output_size=1) torch.Size([4, 4096, 3, 3]) => torch.Size([4, 4096, 1, 1])
Flatten(start_dim=1, end_dim=-1) torch.Size([4, 4096, 1, 1]) => torch.Size([4, 4096])
Dropout(p=0.0, inplace=False) torch.Size([4, 4096]) => torch.Size([4, 4096])
Linear(in_features=4096, out_features=1000, bias=True) torch.Size([4, 4096]) => torch.Size([4, 1000])
Identity() torch.Size([4, 1000]) => torch.Size([4, 1000])


tensor([[-1.2878, -0.1586, -0.1201,  ..., -1.0280, -0.0363,  2.2444],
        [-1.2650, -0.2347, -0.1561,  ..., -0.9542,  0.0032,  2.3665],
        [-1.2061, -0.2263, -0.1208,  ..., -1.1237,  0.1227,  2.3543],
        [-1.3775, -0.2838, -0.5017,  ..., -1.0857,  0.1728,  2.5088]],
       grad_fn=<AddmmBackward0>)