## ResNet

### 1 Residual Block
A plain network structure (left) and a network structure with a residual connection (right):
![avatar](../resource/residual-block.svg)
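In the residual block on the right, the part inside the dashed box only has to learn the residual mapping $f(x) - x$, where $f(x)$ is the ideal mapping the block should represent. If the ideal mapping is exactly the identity ($f(x) = x$), it suffices to learn zeros for the weight and bias parameters of the upper weighted operation inside the dashed box. If the ideal mapping is merely close to the identity, the residual mapping can easily capture the small fluctuations around it.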
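A minimal sketch of the identity case, assuming a single `nn.Linear` layer (the name `branch` is hypothetical) stands in for the weighted operation in the dashed box; the actual block below uses convolutions. With zero weights and bias the branch outputs zeros, so the block computes ReLU(x), which equals x for nonnegative inputs.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the weighted operation inside the dashed box.
branch = nn.Linear(4, 4)
nn.init.zeros_(branch.weight)
nn.init.zeros_(branch.bias)

def residual_forward(x):
    # Block output = ReLU(branch(x) + x); the branch learns the residual f(x) - x.
    return torch.relu(branch(x) + x)

x = torch.rand(2, 4)  # values in [0, 1), so the final ReLU leaves them unchanged
y = residual_forward(x)
print(torch.equal(y, x))  # True: the block reduces to the identity mapping
```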

```python
import time
import torch
from torch import nn, optim
import torch.nn.functional as F
import my_utils  # the author's local helper module (data loading, training loop, custom layers)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```

```python
class Residual(nn.Module):
    # ResNet follows VGG's design of using only 3x3 convolutional layers. A residual block
    # first has two 3x3 convolutional layers with the same number of output channels, each
    # followed by a batch normalization layer. The input then skips these two convolutions
    # and is added back directly before the final ReLU activation.
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super(Residual, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if use_1x1conv:
            # 1x1 convolution on the shortcut, needed when the block changes the
            # number of channels or (with stride > 1) the height and width
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))

        if self.conv3 is not None:
            X = self.conv3(X)
        return F.relu(Y + X)
```

```python
blk = Residual(3, 3)
X = torch.rand((4, 3, 6, 6))
blk(X).shape
```

```
torch.Size([4, 3, 6, 6])
```

```python
# The 1x1 shortcut convolution changes the channels (3 -> 6) and halves height/width
blk = Residual(3, 6, use_1x1conv=True, stride=2)
X = torch.rand((4, 3, 6, 6))
blk(X).shape
```

```
torch.Size([4, 6, 3, 3])
```
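Why the main path and the shortcut line up in the second example: a convolution or pooling layer with kernel size $k$, padding $p$ and stride $s$ maps height $h$ to $\lfloor (h + 2p - k)/s \rfloor + 1$. A pure-Python check, where `conv_out` is a hypothetical helper name:

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Output height/width of a Conv2d or MaxPool2d layer (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1

h = 6
# Main path: 3x3 conv (padding 1, stride 2), then 3x3 conv (padding 1, stride 1).
main = conv_out(conv_out(h, 3, padding=1, stride=2), 3, padding=1)
# Shortcut: 1x1 conv with stride 2 and no padding.
shortcut = conv_out(h, 1, stride=2)
print(main, shortcut)  # 3 3
```

Without `use_1x1conv=True`, the shortcut would still have shape `(4, 3, 6, 6)` while the main path has shape `(4, 6, 3, 3)`, and `Y + X` would fail.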

### 2 ResNet Model
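The first layers of ResNet are the same as those of GoogLeNet: a 7×7 convolutional layer with 64 output channels and stride 2, followed by a 3×3 max-pooling layer with stride 2. The difference is that ResNet adds a batch normalization layer after the convolutional layer.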
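These first layers shrink a 224×224 input by a factor of 4 before any residual block; batch normalization and ReLU do not change the shape. A quick check with the output-size formula $\lfloor (h + 2p - k)/s \rfloor + 1$, using a hypothetical `conv_out` helper:

```python
def conv_out(size, kernel, padding=0, stride=1):
    # Output height/width of a Conv2d or MaxPool2d layer (dilation 1, floor mode).
    return (size + 2 * padding - kernel) // stride + 1

after_conv = conv_out(224, kernel=7, padding=3, stride=2)        # 7x7 conv, stride 2
after_pool = conv_out(after_conv, kernel=3, padding=1, stride=2)  # 3x3 max pool, stride 2
print(after_conv, after_pool)  # 112 56
```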

```python
net = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
```

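GoogLeNet follows these layers with 4 modules made of Inception blocks. ResNet instead uses 4 modules made of residual blocks, and each module uses several residual blocks with the same number of output channels. The first module has the same number of channels as its input; since the preceding max-pooling layer with stride 2 has already halved the height and width, it does not reduce them further. Each subsequent module doubles the number of channels of the previous module and halves the height and width in its first residual block.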
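A small trace of how channels and spatial size evolve across the 4 modules, assuming a 224×224 input (for which the first layers output 64 channels at 56×56). The `stages` list is an illustrative encoding of the rule just described, not code from the notebook.

```python
# (out_channels, halves_size) for each of the 4 modules: the first module keeps
# 64 channels and the spatial size; the others double the channels and halve
# the height/width in their first residual block.
stages = [(64, False), (128, True), (256, True), (512, True)]

channels, size = 64, 56  # output of the first layers for a 224x224 input
for i, (out_channels, halves) in enumerate(stages, start=1):
    if halves:
        # first block uses a 3x3 conv with padding 1 and stride 2
        size = (size + 2 * 1 - 3) // 2 + 1
    channels = out_channels
    print(f"module {i}: {channels} channels, {size}x{size}")
# module 1: 64 channels, 56x56
# module 2: 128 channels, 28x28
# module 3: 256 channels, 14x14
# module 4: 512 channels, 7x7
```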

```python
def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    if first_block:
        # The first module keeps the number of channels unchanged
        assert in_channels == out_channels
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            # First block of every later module: change channels, halve height/width
            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(out_channels, out_channels))
    return nn.Sequential(*blk)
```

```python
net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
net.add_module("resnet_block2", resnet_block(64, 128, 2))
net.add_module("resnet_block3", resnet_block(128, 256, 2))
net.add_module("resnet_block4", resnet_block(256, 512, 2))
```
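Counting weighted layers explains why this configuration (2 residual blocks per module) is commonly called ResNet-18: one 7×7 convolution, 4 modules × 2 residual blocks × 2 3×3 convolutions, and the fully connected output layer. By convention the 1×1 shortcut convolutions are not counted; the variable names below are illustrative.

```python
num_modules = 4
blocks_per_module = 2  # the num_residuals argument passed to resnet_block above
convs_per_block = 2    # conv1 and conv2 in Residual (1x1 shortcut convs not counted)

# first 7x7 conv + convolutions in the residual modules + final fully connected layer
depth = 1 + num_modules * blocks_per_module * convs_per_block + 1
print(depth)  # 18
```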

```python
net.add_module('global_avg_pool', my_utils.GlobalAvgPool2d())
net.add_module('fc', nn.Sequential(my_utils.FlattenLayer(), nn.Linear(512, 10)))
```
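`my_utils.GlobalAvgPool2d` and `my_utils.FlattenLayer` are the author's helpers and are not shown here. Assuming they average over the full height and width and flatten to `(batch, features)` respectively, the built-in `nn.AdaptiveAvgPool2d(1)` and `nn.Flatten()` should serve the same purpose. The sketch below checks the built-in layers against a manual mean; the input shape is an arbitrary example.

```python
import torch
from torch import nn

torch.manual_seed(0)
X = torch.rand(2, 512, 3, 3)  # e.g. a feature map coming out of resnet_block4

# Built-in version of the output head: global average pooling, flatten, fully connected.
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(512, 10))

pooled = nn.Flatten()(nn.AdaptiveAvgPool2d(1)(X))    # shape (2, 512)
manual = X.mean(dim=(2, 3))                          # average over height and width
print(pooled.shape, torch.allclose(pooled, manual))  # torch.Size([2, 512]) True
print(head(X).shape)                                 # torch.Size([2, 10])
```

Because the pooling adapts to any spatial size, the same head works for 224×224 and 96×96 inputs.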

```python
print(net)
X = torch.rand((1, 1, 224, 224))
for name, layer in net.named_children():
    X = layer(X)
    print(name, ' output shape:\t', X.shape)
```

```
Sequential(
  (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (resnet_block1): Sequential(
    (0): Residual(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Residual(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, a
```

Training on Fashion-MNIST, with the images resized to 96×96 to reduce computation:

```python
batch_size = 256
train_iter, test_iter = my_utils.load_data_fashion_mnist(batch_size, resize=96)

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
my_utils.train_cnn(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
```
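With `resize=96`, the feature maps are much smaller than in the 224×224 shape check above, but global average pooling still reduces the last 512-channel map to 512 features, so `nn.Linear(512, 10)` works unchanged. A pure-Python trace of the height/width, with a hypothetical `conv_out` helper:

```python
def conv_out(size, kernel, padding=0, stride=1):
    # Output height/width of a Conv2d or MaxPool2d layer (dilation 1, floor mode).
    return (size + 2 * padding - kernel) // stride + 1

size = 96
size = conv_out(size, 7, padding=3, stride=2)  # first 7x7 conv -> 48
size = conv_out(size, 3, padding=1, stride=2)  # first max pool -> 24
sizes = [size]                                  # resnet_block1 keeps 24x24
for _ in range(3):                              # resnet_block2..4 each halve height/width
    size = conv_out(size, 3, padding=1, stride=2)
    sizes.append(size)
print(sizes)  # [24, 12, 6, 3]
```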