In [4]:
'''
    Summarize SSD module from ssd.py
'''

from torchinfo import summary
import torch
from torch.nn import Conv2d, Sequential, ModuleList, BatchNorm2d
from torch import nn
from models.mb3_ssd.vision.nn.mobilenet_v2 import InvertedResidual
from models.mb3_ssd.vision.nn.mobilenet_v3 import MobileNetV3 as MobileNetV3Base
from models.mb3_ssd.vision.ssd.ssd import SSD, GraphPath
from models.mb3_ssd.vision.utils.box_utils import SSDSpec, SSDBoxSizes, generate_ssd_priors
import numpy as np

class MobileNetV3Config(object):
    def __init__(self):
        self.image_size = 300
        self.image_mean = np.array([127, 127, 127])  # RGB layout
        self.image_std = 128.0
        self.iou_threshold = 0.45
        self.center_variance = 0.1
        self.size_variance = 0.2

        self.specs = [
            SSDSpec(19, 16, SSDBoxSizes(60, 105), [2, 3]),
            SSDSpec(10, 32, SSDBoxSizes(105, 150), [2, 3]),
            SSDSpec(5, 64, SSDBoxSizes(150, 195), [2, 3]),
            SSDSpec(3, 100, SSDBoxSizes(195, 240), [2, 3]),
            SSDSpec(2, 150, SSDBoxSizes(240, 285), [2, 3]),
            SSDSpec(1, 300, SSDBoxSizes(285, 330), [2, 3])
        ]


        self.priors = generate_ssd_priors(self.specs, self.image_size)
        
config = MobileNetV3Config()

# testing parameter
# base_net = MobileNetV3(size='small', width_mult=1.0, classifier=False).features
base_net = MobileNetV3Base().features
width_mult = 1.0
num_classes = 3
is_test = False

def SeperableConv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, onnx_compatible=False):
    """Replace Conv2d with a depthwise Conv2d and Pointwise Conv2d.
    """
    ReLU = nn.ReLU if onnx_compatible else nn.ReLU6
    return Sequential(
        Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
               groups=in_channels, stride=stride, padding=padding),
        BatchNorm2d(in_channels),
        ReLU(),
        Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
    )

source_layer_indexes = [GraphPath(11, 'conv'),20,]

extras = ModuleList([
    InvertedResidual(1280, 512, stride=2, expand_ratio=0.2),
    InvertedResidual(512, 256, stride=2, expand_ratio=0.25),
    InvertedResidual(256, 256, stride=2, expand_ratio=0.5),
    InvertedResidual(256, 64, stride=2, expand_ratio=0.25)
])

regression_headers = ModuleList([
    SeperableConv2d(in_channels=round(288 * width_mult), out_channels=6 * 4,
                    kernel_size=3, padding=1, onnx_compatible=False),
    SeperableConv2d(in_channels=1280, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
    SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
    SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
    SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
    Conv2d(in_channels=64, out_channels=6 * 4, kernel_size=1),
])

classification_headers = ModuleList([
    SeperableConv2d(in_channels=round(288 * width_mult), out_channels=6 * num_classes, kernel_size=3, padding=1),
    SeperableConv2d(in_channels=1280, out_channels=6 * num_classes, kernel_size=3, padding=1),
    SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1),
    SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
    SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
    Conv2d(in_channels=64, out_channels=6 * num_classes, kernel_size=1),
])

ssd_model = SSD(num_classes, base_net, source_layer_indexes,
           extras, classification_headers, regression_headers, is_test=is_test, config=config)

ssd_model.eval()
# if is_test:
#     confidences, boxes = ssd_model(torch.zeros(1, 3, 224, 224))
#     print('confidences:', confidences.shape)
#     print('boxes:', boxes.shape)
# else:
#     confidences, locations = ssd_model(torch.zeros(1, 3, 224, 224))
#     print('confidences:', confidences.shape)
#     print('locations:', locations.shape)

summary(base_net, (1, 3, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               --                        --
├─Conv2d: 1-1                            [1, 16, 112, 112]         448
├─BatchNorm2d: 1-2                       [1, 16, 112, 112]         32
├─h_swish: 1-3                           [1, 16, 112, 112]         --
├─MobileBlock: 1-4                       [1, 16, 56, 56]           --
│    └─Sequential: 2-1                   [1, 16, 112, 112]         --
│    │    └─Conv2d: 3-1                  [1, 16, 112, 112]         256
│    │    └─BatchNorm2d: 3-2             [1, 16, 112, 112]         32
│    │    └─ReLU: 3-3                    [1, 16, 112, 112]         --
│    └─Sequential: 2-2                   [1, 16, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 16, 56, 56]           160
│    │    └─BatchNorm2d: 3-5             [1, 16, 56, 56]           32
│    └─SqueezeBlock: 2-3                 [1, 16, 56, 56]           --
│    │    └─

In [5]:
summary(ssd_model, (1, 3, 224, 224))

Layer (type:depth-idx)                        Output Shape              Param #
SSD                                           --                        --
├─ModuleList: 1-1                             --                        --
├─ModuleList: 1-2                             --                        --
├─ModuleList: 1-3                             --                        --
├─ModuleList: 1-4                             --                        --
├─Sequential: 1                               --                        --
│    └─Conv2d: 2-1                            [1, 16, 112, 112]         448
│    └─BatchNorm2d: 2-2                       [1, 16, 112, 112]         32
│    └─h_swish: 2-3                           [1, 16, 112, 112]         --
│    └─MobileBlock: 2-4                       [1, 16, 56, 56]           --
│    │    └─Sequential: 3-1                   [1, 16, 112, 112]         288
│    │    └─Sequential: 3-2                   [1, 16, 56, 56]           192
│    │    └─Squee