# 单发多框检测 SSD

In [1]:
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
import d2l.torch as d2l

### 类别预测层

使用一个保持输入高和宽不变的卷积层（3×3 卷积，padding 为 1），使输出与输入在特征图的高、宽坐标上一一对应

目标类别数量有 num_classes + 1 个类别, 包括了背景

设特征图的高和宽分别为 $h$ 和 $w$，如果以每个单元为中心生成 num_anchors 个锚框，

则索引为 $i(num\_classes + 1) + j$ 的输出通道表示第 $i$ 个锚框对类别 $j$ 的预测（$0 \le j \le num\_classes$）

In [2]:
def cls_predictor(num_inputs, num_anchors, num_classes):
    """Build the class-prediction convolution for one feature-map scale.

    A 3x3 kernel with padding 1 keeps height and width unchanged, so each
    output position lines up with the same input position. Every position
    gets ``num_anchors * (num_classes + 1)`` output channels: for each anchor,
    one score per object class plus one for background.
    """
    out_channels = num_anchors * (num_classes + 1)
    return nn.Conv2d(num_inputs, out_channels, kernel_size=3, padding=1)

### 边界框预测层

和类别预测层不同的是，每个锚框预测 $4$ 个偏移量，而不是 num_classes + 1 个类别

In [3]:
def bbox_predictor(num_inputs, num_anchors):
    """Build the bounding-box offset convolution for one feature-map scale.

    Mirrors ``cls_predictor``: a 3x3 convolution with padding 1 keeps the
    spatial size. Instead of class scores, it predicts 4 offsets per anchor,
    so it has ``num_anchors * 4`` output channels.
    """
    # The layer class is ``nn.Conv2d``; ``nn.conv2d`` does not exist and
    # raised AttributeError.
    return nn.Conv2d(num_inputs, num_anchors * 4, kernel_size=3, padding=1)

来看看输出的形状

In [4]:
def forward(x, block):
    """Apply ``block`` to ``x`` and hand back the result (helper for shape demos)."""
    output = block(x)
    return output
# Demo: class predictions at two scales, both for 10 object classes.
# 20x20 feature map with 8 input channels and 5 anchors per position.
x1 = forward(torch.zeros((2, 8, 20, 20)), cls_predictor(8, 5, 10))
# 10x10 feature map with 16 input channels and 3 anchors per position.
x2 = forward(torch.zeros((2, 16, 10, 10)), cls_predictor(16, 3, 10))
# Shapes differ in both channels (5*11 vs 3*11) and spatial size, which is
# why the predictions are flattened before they are concatenated.
x1.shape, x2.shape

(torch.Size([2, 55, 20, 20]), torch.Size([2, 33, 10, 10]))

### 连接多尺度的预测

先用 permute 把通道维移到最后，再 flatten，这样同一个像素（空间位置）的所有预测在展平后会连续排列在一起

torch.cat 可以沿指定的 dim 维度把多个张量拼接起来

In [5]:
def flatten_pred(pred):
    """Flatten a 4-D NCHW prediction map into shape (batch, h * w * channels).

    The channel axis is moved last before flattening, so all predictions that
    belong to one spatial position end up next to each other.
    """
    channels_last = pred.permute(0, 2, 3, 1)
    return channels_last.flatten(start_dim=1)
def concat_preds(preds):
    """Flatten every per-scale prediction and join them along dimension 1."""
    flattened = list(map(flatten_pred, preds))
    return torch.cat(flattened, dim=1)

把不同尺度的预测分别展平后拼接成一个张量，方便之后统一计算

In [6]:
# Each scale flattens to (batch, h * w * channels); concatenating both
# scales sums those widths along dimension 1.
print(concat_preds([x1]).shape , concat_preds([x2]).shape)
print(concat_preds([x1, x2]).shape)

torch.Size([2, 22000]) torch.Size([2, 3300])
torch.Size([2, 25300])


### 高宽减半块

In [7]:
def down_sample_blk(in_channels, out_channels):
    """Build a block that halves height and width.

    Two (3x3 conv, padding 1) -> BatchNorm -> ReLU stages keep the spatial size.
    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. A final 2x2 max-pool then halves height and width.
    """
    layers = []
    for stage_in in (in_channels, out_channels):
        layers += [
            nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
    layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)

In [8]:
# Demo: one down-sampling block halves height and width (20x20 -> 10x10)
# and changes channels 3 -> 7.
forward(torch.zeros((2, 3, 20, 20)), down_sample_blk(3, 7)).shape

torch.Size([2, 7, 10, 10])

### 基本网络块

用于从输入图像中抽取特征。它由 3 个高宽减半块串联而成，因此特征图的高宽会减半 3 次（例如 256 → 32）

In [9]:
def base_net():
    """Build the feature-extraction backbone.

    Three down-sampling blocks are chained with channels 3 -> 16 -> 32 -> 64.
    Each block halves height and width, so the backbone shrinks the spatial
    size by a factor of 8 overall.
    """
    num_filters = [3, 16, 32, 64]
    stages = [
        down_sample_blk(c_in, c_out)
        for c_in, c_out in zip(num_filters, num_filters[1:])
    ]
    return nn.Sequential(*stages)

In [10]:
# Demo: the backbone halves height and width three times (256 -> 32)
# and raises channels from 3 to 64.
forward(torch.zeros((2, 3, 256, 256)), base_net()).shape

torch.Size([2, 64, 32, 32])

0 : 基本网络块, 将高宽减半三次, 通道从 3 增加到 64 ( 3 层 down_sample_blk )

1 : 高宽减半块, 1 层 down_sample_blk, 通道从 64 增加到 128

2 , 3 : 保持通道数不变的 down_sample_blk

4 : 自适应平均池化层, 把高和宽都降到 1 × 1 (output_size = 1 × 1)

In [11]:
def get_blk(i):
    """Return the module used for stage ``i`` of TinySSD.

    - 0: backbone from ``base_net`` (channels 3 -> 64)
    - 1: down-sampling block, channels 64 -> 128
    - 4: adaptive average pooling to a 1x1 output
    - anything else: down-sampling block keeping 128 channels
    """
    if i == 0:
        return base_net()
    if i == 1:
        return down_sample_blk(64, 128)
    if i == 4:
        # Collapse height and width down to a single position.
        return nn.AdaptiveAvgPool2d((1, 1))
    return down_sample_blk(128, 128)

In [12]:
def blk_forward(X, blk, size, ratio, cls_predictor, bbox_predictor):
    """Run one TinySSD stage and its prediction heads.

    ``blk`` is applied to ``X``. Anchors are generated on the resulting
    feature map with ``d2l.multibox_prior`` using ``size`` and ``ratio``.
    Both prediction heads then run on that same feature map.

    Returns:
        (feature_map, anchors, class_predictions, bbox_predictions)
    """
    feature_map = blk(X)
    # Tuple items are evaluated left to right, so the calls happen in the
    # same order as before: anchors, then class head, then box head.
    return (
        feature_map,
        d2l.multibox_prior(feature_map, sizes=size, ratios=ratio),
        cls_predictor(feature_map),
        bbox_predictor(feature_map),
    )

设置生成 anchor 的参数

把区间 [0.2, 1.05] 均匀分成五段，得到六个端点 0.2, 0.37, 0.54, 0.71, 0.88, 1.05；每组 size 的第一个值依次取前五个端点

每组中的第二个值由 $\sqrt{a_i \times a_{i+1}}$ 计算得到，例如 $\sqrt{0.2 \times 0.37} \approx 0.272$

In [13]:
# Anchor scales for the five feature-map levels, growing from small (local
# detail) to large (more global objects). The first value of each pair comes
# from evenly spaced points in [0.2, 1.05]; the second is sqrt(a_i * a_{i+1}).
sizes = [[0.2, 0.272], [0.37, 0.447], [0.54, 0.619], [0.71, 0.79], [0.88, 0.961]]
# One independent aspect-ratio list per level. `[[1, 2, 0.5]] * 5` would make
# all five entries the same list object, so mutating one level would silently
# change every level.
ratios = [[1, 2, 0.5] for _ in sizes]
# Anchors per position: len(sizes[0]) + len(ratios[0]) - 1. This presumably
# matches how d2l.multibox_prior combines sizes and ratios — confirm there.
num_anchors = len(sizes[0]) + len(ratios[0]) - 1

`setattr(obj, name, value)` 把对象 `obj` 名为 `name` 的属性设置为 `value`；例如 `setattr(self, f'blk_{i}', get_blk(i))` 在 `i = 0` 时等价于 `self.blk_0 = get_blk(0)`

`getattr(obj, name)` 则刚好相反，用于按名字读取属性的值

In [14]:
class TinySSD(nn.Module):
    """Tiny single-shot multibox detector built from five feature stages.

    Stage ``i`` is ``get_blk(i)``. After each stage, a class-prediction head
    and a box-offset head run on that stage's feature map.

    Args:
        num_classes: number of object classes; background adds one more slot.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, num_classes, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        # Output channels of get_blk(i), i.e. input channels of the i-th heads.
        idx_to_in_channels = [64, 128, 128, 128, 128]
        for i, in_channels in enumerate(idx_to_in_channels):
            # setattr gives every submodule a stable name (blk_0, cls_0, bbox_0, ...)
            # so nn.Module registers it and forward() can look it up with getattr.
            setattr(self, f'blk_{i}', get_blk(i))
            setattr(self, f'cls_{i}', cls_predictor(in_channels, num_anchors, num_classes))
            setattr(self, f'bbox_{i}', bbox_predictor(in_channels, num_anchors))

    def forward(self, X):
        """Run all five stages and merge their outputs across scales.

        Returns:
            anchors: per-stage anchors concatenated along dim 1.
            cls_preds: class scores reshaped to (batch, -1, num_classes + 1),
                one row of scores per anchor.
            bbox_preds: box offsets, flattened and concatenated per batch item.
        """
        anchors, cls_preds, bbox_preds = [None] * 5, [None] * 5, [None] * 5
        for i in range(5):
            X, anchors[i], cls_preds[i], bbox_preds[i] = blk_forward(
                X, getattr(self, f'blk_{i}'), sizes[i], ratios[i],
                getattr(self, f'cls_{i}'), getattr(self, f'bbox_{i}'))
        # Merge scales only after the loop has filled every slot. The merge used
        # to sit inside the loop, so it ran on lists still holding None.
        anchors = torch.cat(anchors, dim=1)
        cls_preds = concat_preds(cls_preds)
        # Reshape the concatenated tensor; a Python list has no `.shape`.
        cls_preds = cls_preds.reshape(cls_preds.shape[0], -1, self.num_classes + 1)
        bbox_preds = concat_preds(bbox_preds)
        return anchors, cls_preds, bbox_preds

In [None]:
# Build a detector for a single object class (plus background).
net = TinySSD(num_classes=1)
# Dummy batch: 32 images, 3 channels, 256x256.
X = torch.zeros((32, 3, 256, 256))