In [1]:
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

In [2]:
# Load MobileNetV2 with pretrained (ImageNet) weights and keep only its
# convolutional feature extractor (`.features`); the classification head is dropped.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; confirm against the installed version before upgrading.
backbone = (
    torchvision.models
    .mobilenet_v2(pretrained=True)
    .features
)

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to C:\Users\ironc/.cache\torch\checkpoints\mobilenet_v2-b0353104.pth
100%|█████████████████████████████████████████████████████████████████████████████| 13.6M/13.6M [00:10<00:00, 1.32MB/s]


In [4]:
# Show which container class the feature extractor is (displayed as the cell output).
backbone.__class__

torch.nn.modules.container.Sequential

In [8]:
# Inspect the final layer of the backbone: its Conv2d output channel count is
# what FasterRCNN must be told via `backbone.out_channels` below.
backbone[len(backbone) - 1]

ConvBNReLU(
  (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU6(inplace=True)
)

In [9]:
# FasterRCNN needs to know the number of output channels of the backbone.
# Read it from the final 1x1 Conv2d (index 0 of the last ConvBNReLU block shown
# above) instead of hard-coding it, so it stays in sync with the loaded model.
# For mobilenet_v2 this evaluates to 1280.
backbone.out_channels = backbone[-1][0].out_channels

In [10]:
# Make the RPN generate 5 x 3 = 15 anchors per spatial location:
# 5 different sizes combined with 3 different aspect ratios.
# AnchorGenerator expects Tuple[Tuple[int]] for `sizes` and Tuple[Tuple[float]] for
# `aspect_ratios`, with one inner tuple per feature map, because each feature map
# could use its own sizes and aspect ratios. This backbone returns a single feature
# map, so each argument is a 1-tuple wrapping the values for that map.
# (torchvision reads a flat `sizes` tuple as one size per feature map, i.e. five
# maps, which does not match this single-map backbone.)
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))

In [11]:
# Define which feature maps are used for region-of-interest cropping, and the
# size of each crop after rescaling (7x7, with 2 sampling points per bin).
# This backbone returns a single Tensor; FasterRCNN wraps such an output in an
# OrderedDict under the string key '0', so featmap_names must be ['0'].
# More generally, the backbone may return an OrderedDict[str, Tensor], and
# featmap_names chooses which of those feature maps to pool from.
# NOTE(review): very old torchvision releases keyed the single map with the int 0;
# ['0'] matches current torchvision, so confirm against the installed version.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                output_size=7,
                                                sampling_ratio=2)

In [12]:
# Assemble the detector from the pieces defined above: the MobileNetV2 feature
# extractor, the custom RPN anchor generator, and the RoI pooler.
# torchvision's `num_classes` counts the background class, so 2 means one object
# class plus background.
detector_kwargs = dict(
    backbone=backbone,
    num_classes=2,
    rpn_anchor_generator=anchor_generator,
    box_roi_pool=roi_pooler,
)
model = FasterRCNN(**detector_kwargs)

In [13]:
# Confirm the assembled object's class (displayed as the cell output).
model.__class__

torchvision.models.detection.faster_rcnn.FasterRCNN