- 1. 残差块

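ResNet 沿用了 VGG 完整的 3×3 卷积层设计: 残差块中有两个 3×3 卷积层 (填充为 1), 每个卷积层后接一个批量规范化层, 激活函数为 ReLU (第二个 ReLU 在与输入相加之后才应用). 默认情况下步长为 1 且输出通道数等于输入通道数, 此时输出形状与输入相同, 可以直接把输入加到输出上; 如果想改变通道数或使用步长 2 减半高和宽, 则需要设置 use_1x1conv=True, 先用 1×1 卷积把输入变换成所需的形状, 再做相加运算.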

In [3]:
from torch import nn
from torch.nn import BatchNorm2d
from torch.nn import functional as F

#@save
class ResidualBlock(nn.Module):
    """ResNet residual block: two 3x3 conv + batch-norm layers plus a shortcut.

    Main path: conv1 -> bn1 -> ReLU -> conv2 -> bn2.  The shortcut is either
    the input itself (identity) or, when ``use_1x1conv`` is True, a 1x1
    convolution that maps the input to ``num_channels`` channels using the same
    stride as ``conv1``.  The two paths are summed and passed through a ReLU.

    Args:
        input_channels: Number of channels of the input tensor.
        num_channels: Number of output channels of both 3x3 convolutions
            (and therefore of the block).
        use_1x1conv: Whether to project the shortcut with a 1x1 convolution.
            Required whenever ``input_channels != num_channels`` or the stride
            is not 1, because the identity shortcut would then not match the
            shape of the main path's output.
        strides: Stride of the first 3x3 convolution and of the 1x1 shortcut
            convolution (int or pair, as accepted by ``nn.Conv2d``).

    Raises:
        ValueError: If the identity shortcut is requested (``use_1x1conv`` is
            False) but the channel count or the stride differs from what an
            identity mapping can match.
    """

    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        # Components are declared here; how data flows between them is in
        # forward(). Construction order is kept stable so parameter
        # registration order (state_dict keys) and RNG consumption are fixed.

        # 3x3 convs with padding=1: conv1 may downsample via `strides`,
        # conv2 always keeps height/width unchanged.
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels,
                               kernel_size=3, padding=1)

        # One batch-norm layer after each convolution.
        self.bn1 = BatchNorm2d(num_features=num_channels)
        self.bn2 = BatchNorm2d(num_features=num_channels)

        if use_1x1conv:
            # Projection shortcut: the 1x1 conv with the same stride as conv1
            # gives the shortcut the same channels and height/width as the
            # main path, so the two can be added.
            self.conv3 = nn.Conv2d(in_channels=input_channels, out_channels=num_channels,
                                   kernel_size=1, stride=strides)
        else:
            # Identity shortcut: only valid when the main path preserves the
            # input shape. Without this check, input_channels == 1 would be
            # silently broadcast across all num_channels channels in `Y += X`,
            # and other mismatches would fail later with an opaque error.
            # conv1.stride is the normalized (h, w) tuple, so ints and pairs
            # are handled alike.
            if input_channels != num_channels or self.conv1.stride != (1, 1):
                raise ValueError(
                    "identity shortcut requires input_channels == num_channels "
                    f"and stride 1, got input_channels={input_channels}, "
                    f"num_channels={num_channels}, strides={strides}; "
                    "pass use_1x1conv=True to project the shortcut"
                )
            self.conv3 = None

    def forward(self, X):
        """Apply the residual block to ``X`` and return ReLU(main(X) + shortcut(X))."""
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3 is not None:
            # Reshape the shortcut to match the main path's output.
            X = self.conv3(X)
        Y += X

        return F.relu(Y)

import torch

# Identity-shortcut demo: 3 -> 3 channels with stride 1, so the block's output
# shape equals the input shape (printed output: torch.Size([4, 3, 6, 6])).
blk = ResidualBlock(3, 3)
# Random input batch: 4 samples, 3 channels, 6x6 spatial size.
X = torch.randn(4, 3, 6, 6)
print(blk(X).shape)

torch.Size([4, 3, 6, 6])


In [4]:
# Projection-shortcut demo: the 1x1 conv shortcut maps 3 -> 6 channels and
# strides=2 reduces the 6x6 input to 3x3 (printed output: torch.Size([4, 6, 3, 3])).
blk = ResidualBlock(3, 6, use_1x1conv=True, strides=2)
X = torch.randn(4, 3, 6, 6)
print(blk(X).shape)

torch.Size([4, 6, 3, 3])
