In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

In [2]:
class Residual(nn.Module):
    def __init__(self, input_channels, num_channels, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, padding=1)
        
        if strides > 1 or input_channels != num_channels: 
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)

In [3]:
b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

In [4]:
def resnet_block(input_channels, num_channels, num_residuals,
                 first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk

In [5]:
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))

In [6]:
ResNet = nn.Sequential(b1, b2, b3, b4, b5,
                    nn.AdaptiveAvgPool2d((1,1)),
                    nn.Flatten(), nn.Linear(512, 1000))

In [8]:
X = torch.rand(size=(1, 3, 224, 224))
for layer in ResNet:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)

Sequential output shape:	 torch.Size([1, 64, 56, 56])
Sequential output shape:	 torch.Size([1, 64, 56, 56])
Sequential output shape:	 torch.Size([1, 128, 28, 28])
Sequential output shape:	 torch.Size([1, 256, 14, 14])
Sequential output shape:	 torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 512, 1, 1])
Flatten output shape:	 torch.Size([1, 512])
Linear output shape:	 torch.Size([1, 1000])


In [7]:
rec = []
def ParamNum(l):
    num = 0
    if isinstance(l, nn.Sequential):
        for l2 in l:
            num += ParamNum(l2)
    else:
        if isinstance(l, Residual):
            l_1 = l.conv1
            l_2 = l.conv2
            num = l_1.in_channels * l_1.out_channels * l_1.kernel_size[0] * l_1.kernel_size[1] + l_2.in_channels * l_2.out_channels * l_2.kernel_size[0] * l_2.kernel_size[1] 
            expr = f'{l_1.in_channels}*{l_1.out_channels}*{l_1.kernel_size[0]}*{l_1.kernel_size[1]}+{l_2.in_channels}*{l_2.out_channels}*{l_2.kernel_size[0]}*{l_2.kernel_size[1]}'
        elif isinstance(l, nn.Conv2d):
            num = l.in_channels * l.out_channels * l.kernel_size[0] * l.kernel_size[1]
            expr = f'{l.in_channels}*{l.out_channels}*{l.kernel_size[0]}*{l.kernel_size[1]}'
        elif isinstance(l, nn.Linear):
            num = l.in_features * l.out_features
            expr = f'{l.in_features}*{l.out_features}'
        if num > 0:
            rec.append([l.__class__.__name__, expr, num])
    return num

num_sum = ParamNum(ResNet)
print(f'{num_sum} = {round(num_sum / 1000 / 1000, 1)} M')
rec

11506880 = 11.5 M


[['Conv2d', '3*64*7*7', 9408],
 ['Residual', '64*64*3*3+64*64*3*3', 73728],
 ['Residual', '64*64*3*3+64*64*3*3', 73728],
 ['Residual', '64*128*3*3+128*128*3*3', 221184],
 ['Residual', '128*128*3*3+128*128*3*3', 294912],
 ['Residual', '128*256*3*3+256*256*3*3', 884736],
 ['Residual', '256*256*3*3+256*256*3*3', 1179648],
 ['Residual', '256*512*3*3+512*512*3*3', 3538944],
 ['Residual', '512*512*3*3+512*512*3*3', 4718592],
 ['Linear', '512*1000', 512000]]