In [1]:
%matplotlib inline
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

使用预训练的 `resnet18` 作为基准网络

In [2]:
# Build a ResNet-18 and request its pretrained weights; later cells reuse its
# child modules as the feature extractor.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=` — confirm against the installed torchvision version.
pretrained_net = torchvision.models.resnet18(pretrained=True)

In [8]:
# Display the last three child modules of the pretrained network.
list(pretrained_net.children())[-3:]

[Sequential(
   (0): BasicBlock(
     (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (downsample): Sequential(
       (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
       (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     )
   )
   (1): BasicBlock(
     (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): Batc

将 resnet18 的最后两层（全局平均池化层和全连接层）去掉

In [10]:
# Keep every child module of the pretrained network except the last two and
# chain them into a single Sequential.
net = nn.Sequential(*list(pretrained_net.children())[:-2])

去掉最后两层后，320 x 480 的图片经过网络得到的输出高宽是 10 x 15（即高和宽都缩小为原来的 1/32）

In [13]:
# Random batch holding one 3-channel 320x480 input, used to probe output shapes.
X = torch.rand(size=(1, 3, 320, 480))
# Forward pass through the truncated network; Y is reused by a later cell.
Y = net(X)
# Last expression of the cell, so the notebook displays the shape.
Y.shape

torch.Size([1, 512, 10, 15])

先通过 $1\times 1$ 卷积层将通道数变为 Pascal VOC2012 数据集的类别数（21），再使用卷积核大小为 64、填充为 16、步幅为 32 的转置卷积层，将特征图上采样回原图大小 320 x 480。填充取 16 的原因是：转置卷积的输出尺寸为 $(n-1)s - 2p + k$，要让高宽恰好放大 $s$ 倍（这里 $s = 32$），就需要 $k = 2p + s$，即 $64 = 2 \times 16 + 32$

In [18]:
# 1x1 convolution mapping the 512 feature channels to 21 output channels.
conv = nn.Conv2d(512, 21, kernel_size=1)
# Transposed convolution (kernel 64, padding 16, stride 32) that keeps 21 channels.
tconv = nn.ConvTranspose2d(21, 21, kernel_size=64, padding=16, stride=32)
# Shape check on the feature map Y computed in an earlier cell.
tconv(conv(Y)).shape

torch.Size([1, 21, 320, 480])

在网络中添加层

In [23]:
# Append the 1x1 convolution and the transposed convolution to net, registered
# under the names 'final_conv' and 'transpose_conv'.
net.add_module('final_conv', conv)
net.add_module('transpose_conv', tconv)
# Display the last three child modules of the extended network.
list(net.children())[-3:]

[Sequential(
   (0): BasicBlock(
     (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (downsample): Sequential(
       (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
       (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     )
   )
   (1): BasicBlock(
     (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     (bn2): Batc

采用双线性插值的上采样来初始化转置卷积层。下面的函数生成双线性插值核，用作转置卷积层的初始权重；输入参数是输入通道数、输出通道数和卷积核大小。注意：该函数只负责生成权重，还需要把返回的权重复制到转置卷积层（`transpose_conv`）的权重中，才算完成初始化

In [24]:
def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Build a weight tensor whose matching channel pairs hold a bilinear filter.

    Args:
        in_channels: size of the first dimension of the returned tensor.
        out_channels: size of the second dimension of the returned tensor.
        kernel_size: height and width of the square filter.

    Returns:
        A tensor of shape (in_channels, out_channels, kernel_size, kernel_size).
        Entries at channel pairs (i, i) hold the same 2-D bilinear filter; all
        other entries are zero. The two channel ranges are used as paired
        indices, so this is meant for in_channels == out_channels.
    """
    half = (kernel_size + 1) // 2
    # Odd kernels peak exactly on a grid position, even kernels halfway
    # between the two middle positions.
    peak = half - 1 if kernel_size % 2 == 1 else half - 0.5
    positions = torch.arange(kernel_size)
    rows = positions.reshape(-1, 1)
    cols = positions.reshape(1, -1)
    # 1-D triangular profiles that fall off linearly with distance from the peak.
    row_profile = 1 - torch.abs(rows - peak) / half
    col_profile = 1 - torch.abs(cols - peak) / half
    # Broadcasting the column vector against the row vector gives the 2-D filter.
    bilinear = row_profile * col_profile
    weight = torch.zeros((in_channels, out_channels,
                          kernel_size, kernel_size))
    weight[range(in_channels), range(out_channels), :, :] = bilinear
    return weight

### 正式训练

读取数据集

In [28]:
# Batch size and crop size; the crop size is given as the tuple (320, 480).
batch_size, crop_size = 32, (320, 480)
# Placeholders: the data-loading calls below are commented out, so both
# iterators stay None.
# NOTE(review): predict() reads test_iter.dataset, so it cannot run until a
# real iterator is assigned here.
train_iter, test_iter = None, None
# train_iter, test_iter = torchvision.datasets.VOCSegmentation
# train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size)

In [29]:
def loss(inputs, targets):
    """Cross-entropy loss reduced to one value per leading-dimension entry.

    The unreduced cross-entropy is averaged over dimension 1 twice in a row,
    which removes the two dimensions that follow the leading one.

    Args:
        inputs: prediction scores passed to F.cross_entropy.
            Assumes shape (batch, classes, height, width) — TODO confirm.
        targets: target class indices passed to F.cross_entropy.
            Assumes shape (batch, height, width) — TODO confirm.

    Returns:
        The unreduced loss with its dimensions 1 and 2 averaged away.
    """
    unreduced = F.cross_entropy(inputs, targets, reduction='none')
    averaged_once = unreduced.mean(1)
    return averaged_once.mean(1)

In [30]:
# Training is disabled: every statement in this cell is commented out.
# NOTE(review): this is the only visible place where `devices` is defined, yet
# predict() and label2image() index devices[0] — re-enable this cell (or define
# `devices` elsewhere) before calling them.
# num_epochs, lr, wd, devices = 5, 0.001, 1e-3, d2l.try_all_gpus()
# trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
# d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

In [31]:
def predict(img):
    """Run the network on one image and return the arg-max class index map.

    Relies on the notebook-level names `test_iter`, `net` and `devices`.

    Args:
        img: image accepted by `test_iter.dataset.normalize_image`
            (presumably a channel-first image tensor — TODO confirm).

    Returns:
        A 2-D tensor of class indices, obtained by taking the arg-max over
        dimension 1 of the network output and dropping the batch dimension.
    """
    # Normalize the image and add a leading batch dimension of size one.
    batch = test_iter.dataset.normalize_image(img).unsqueeze(0)
    scores = net(batch.to(devices[0]))
    class_map = scores.argmax(dim=1)
    height, width = class_map.shape[1], class_map.shape[2]
    return class_map.reshape(height, width)

In [32]:
def label2image(pred):
    """Map a tensor of class indices to colours via `d2l.VOC_COLORMAP`.

    Relies on the notebook-level names `d2l` and `devices`.

    Args:
        pred: tensor of class indices; it is cast to long before the lookup.

    Returns:
        The colormap rows selected by `pred`, i.e. a tensor with one extra
        trailing dimension holding the colour of each index.
    """
    # The palette is created on devices[0] so it can be indexed by a tensor
    # living on that device.
    palette = torch.tensor(d2l.VOC_COLORMAP, device=devices[0])
    return palette[pred.long(), :]

In [None]:
# Prediction/visualisation demo, currently disabled: every statement below is
# commented out. As written it would crop the first n test images, colour the
# predictions with label2image(predict(...)) and show image, prediction and
# label side by side.
# voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012')
# test_images, test_labels = d2l.read_voc_images(voc_dir, False)
# n, imgs = 4, []
# for i in range(n):
#     crop_rect = (0, 0, 320, 480)
#     X = torchvision.transforms.functional.crop(test_images[i], *crop_rect)
#     pred = label2image(predict(X))
#     imgs += [X.permute(1,2,0), pred.cpu(),
#              torchvision.transforms.functional.crop(
#                  test_labels[i], *crop_rect).permute(1,2,0)]
# d2l.show_images(imgs[::3] + imgs[1::3] + imgs[2::3], 3, n, scale=2);