5.12.1 稠密块

In [4]:
import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append('..')
import d2lzh_pytorch as d2l
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def conv_block(in_channels, out_channels):
    """Build the BN -> ReLU -> 3x3 conv unit used inside a dense block.

    A 3x3 kernel with padding=1 keeps height and width unchanged, so the
    unit's output can be concatenated with its input along the channel dim.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    layers = [
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
    ]
    return nn.Sequential(*layers)

In [5]:
class DenseBlock(nn.Module):
    """Stack of conv units whose inputs are the concatenation of all earlier outputs.

    The k-th unit (0-based) receives ``in_channels + k * out_channels``
    channels: the block input plus the outputs of every previous unit.
    It emits ``out_channels`` new channels.

    Args:
        num_convs: number of conv units in the block.
        in_channels: channels of the block input.
        out_channels: channels added by each conv unit.

    Attributes:
        net: ``nn.ModuleList`` of the conv units, in application order.
        out_channels: channels of the block output,
            ``in_channels + num_convs * out_channels``.
    """

    def __init__(self, num_convs, in_channels, out_channels) -> None:
        super().__init__()
        self.net = nn.ModuleList(
            conv_block(in_channels + k * out_channels, out_channels)
            for k in range(num_convs)
        )
        self.out_channels = in_channels + num_convs * out_channels

    def forward(self, x):
        """Apply each unit in turn and append its output to the running features.

        The concatenation is along dim 1 (the channel dim), so the channel
        count grows by ``out_channels`` after every unit.
        """
        features = x
        for unit in self.net:
            features = torch.cat((features, unit(features)), dim=1)
        return features

In [6]:
# Toy check: 2 conv units adding 10 channels each to a 3-channel input should
# give 3 + 2 * 10 = 23 output channels, with the 8x8 spatial size unchanged.
blk = DenseBlock(2, 3, 10)
x = torch.rand(4, 3, 8, 8)
y = blk(x)
assert y.shape == (4, blk.out_channels, 8, 8)
y.shape

5.12.2 过渡层
- 每个稠密块都会带来通道维数的增加，使用过多会提高模型复杂度；
- 过渡层用来控制模型的复杂度：它通过 1×1 卷积层减小通道数，并使用步幅为 2 的平均池化层减半高和宽。

In [7]:
def transition_block(in_channels, out_channels):
    """Build the transition layer placed between dense blocks.

    The layers are BN -> ReLU -> 1x1 conv -> 2x2 average pool with stride 2.
    The 1x1 conv resets the channel count to ``out_channels``. The pool
    halves height and width, rounding down.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels after the 1x1 convolution.

    Returns:
        An ``nn.Sequential`` holding the four layers in that order.
    """
    norm = nn.BatchNorm2d(in_channels)
    act = nn.ReLU()
    squeeze = nn.Conv2d(in_channels, out_channels, kernel_size=1)
    pool = nn.AvgPool2d(kernel_size=2, stride=2)
    return nn.Sequential(norm, act, squeeze, pool)

# Feed the dense-block output y into a transition layer that maps its channels
# to 10. The stride-2 average pool halves the 8x8 spatial size to 4x4.
# Reading in_channels from y avoids hard-coding 23 (= 3 + 2 * 10), which would
# silently break if the dense-block example above were changed.
blk = transition_block(y.shape[1], 10)
blk(y).shape

torch.Size([4, 10, 4, 4])

5.12.3 DenseNet模型

In [8]:
# Stem: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max pool.
# This turns a 1-channel input into 64 channels at 1/4 of the height and width.
# Children are auto-named '0'..'3' by nn.Sequential.
net = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(num_features=64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

In [9]:
# Four dense blocks of 4 conv units each. Every unit adds growth_rate channels;
# the first block starts from the stem's 64 channels.
num_channels, growth_rate = 64, 32
num_convs_in_dense_blocks = [4, 4, 4, 4]
for i, num_convs in enumerate(num_convs_in_dense_blocks):
    dense_blk = DenseBlock(num_convs, num_channels, growth_rate)
    net.add_module("Dense_Block_%d" % i, dense_blk)
    num_channels = dense_blk.out_channels

    # Between consecutive dense blocks (never after the last one), insert a
    # transition layer that halves the channel count.
    is_last_block = i == len(num_convs_in_dense_blocks) - 1
    if is_last_block:
        continue
    half = num_channels // 2
    net.add_module('transition_block_%d' % i, transition_block(num_channels, half))
    num_channels = half

In [10]:
# Classification head: final BN + ReLU, global average pooling, then flatten
# and a linear layer with 10 outputs. Modules are built and registered in the
# same order as listed.
for name, module in (
    ('BN', nn.BatchNorm2d(num_channels)),
    ('relu', nn.ReLU()),
    ('global_avg_pool', d2l.GlobalAvgPool2d()),    # output: (batch, num_channels, 1, 1)
    ('fc', nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10))),
):
    net.add_module(name, module)

In [11]:
# Print each top-level stage's output shape for a dummy single-image batch.
# eval() + no_grad() stop this probe from updating the BatchNorm running
# statistics with random data and from building an unused autograd graph.
# The model's previous train/eval mode is restored afterwards.
X = torch.rand((1, 1, 96, 96))
was_training = net.training
net.eval()
with torch.no_grad():
    for name, layer in net.named_children():
        X = layer(X)
        print(name, 'output shape:\t', X.shape)
net.train(was_training)

0 output shape:	 torch.Size([1, 64, 48, 48])
1 output shape:	 torch.Size([1, 64, 48, 48])
2 output shape:	 torch.Size([1, 64, 48, 48])
3 output shape:	 torch.Size([1, 64, 24, 24])
Dense_Block_0 output shape:	 torch.Size([1, 192, 24, 24])
transition_block_0 output shape:	 torch.Size([1, 96, 12, 12])
Dense_Block_1 output shape:	 torch.Size([1, 224, 12, 12])
transition_block_1 output shape:	 torch.Size([1, 112, 6, 6])
Dense_Block_2 output shape:	 torch.Size([1, 240, 6, 6])
transition_block_2 output shape:	 torch.Size([1, 120, 3, 3])
Dense_Block_3 output shape:	 torch.Size([1, 248, 3, 3])
BN output shape:	 torch.Size([1, 248, 3, 3])
relu output shape:	 torch.Size([1, 248, 3, 3])
global_avg_pool output shape:	 torch.Size([1, 248, 1, 1])
fc output shape:	 torch.Size([1, 10])
