In [1]:
import torchvision
from torchvision.models.detection.backbone_utils import BackboneWithFPN, LastLevelMaxPool
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection import MaskRCNN



# 导入各种模型，不需要特别的大===>[resnet34, resnet50, vgg16, mobilenet_v3_large, efficientnet]
from torchvision.models import resnet34, resnet50, vgg16, mobilenet_v3_large, efficientnet_b0
model_map = {
    "resnet34": resnet34(pretrained=True),
    "resnet50": resnet50(pretrained=True),
    "vgg16": vgg16(pretrained=True).features,
    "mobilenet_v3_large": mobilenet_v3_large(pretrained=True).features,
    "efficientnet_b0": efficientnet_b0(pretrained=True).features,
}




In [2]:
def build_backbone_with_fpn(model_name):
    """
    构建backbone
    :param model_name: model name
    :return: backbone
    """
    # --- mobilenet_v3_large fpn backbone --- #
    model_files = []
    if model_name == "mobilenet_v3_large":
        backbone = model_map[model_name]
        # print(backbone)
        return_layers = {'3': "0",   # stride 8
                '8':'1',  # stride 16
                '12':'2',
                '16':'3'}  # stripe 32
        in_channel_list = [24, 80, 112, 960]

    # --- efficientnet_b0 fpn backbone --- #
    if model_name == "efficientnet_b0":
        backbone = model_map[model_name]
       # print(backbone)
        return_layers = {'2': "0",   # stride 8
                '4':'1',  # stride 16
                '6':'2',
                '8':'3'}  # stripe 32
        in_channel_list = [24, 80, 192, 1280]

    # --- resnet34 fpn backbone --- #
    if model_name == "resnet34":
        backbone = model_map[model_name]
        # print(backbone)
        return_layers = {
                'layer1': "0",   # stride 8',
                'layer2': "1",   # stride 16
                'layer3':'2',  # stride 32
                'layer4':'3'}  # stripe 64
        in_channel_list = [64, 128, 256, 512]

    # --- resnet50 fpn backbone --- #
    if model_name == "resnet50":
        backbone = model_map[model_name]
        # print(backbone)
        return_layers = {
                'layer1': "0",   # stride 8'
                'layer2': "1",   # stride 8
                'layer3':'2',  # stride 16
                'layer4':'3'}  # stripe 32
        in_channel_list = [256, 512, 1024, 2048]

    if model_name == "vgg16":
        backbone = model_map[model_name]
        # print(backbone)
        return_layers = {'4': "0",   # stride 8
                '9':'1',  # stride 16
                '16':'2',
                '30':'3'}  # stripe 32
        in_channel_list = [64, 128, 256, 512]
    
    return backbone, return_layers, in_channel_list

In [6]:
build_backbone_with_fpn(model_name='resnet50')



(ResNet(
   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
   (layer1): Sequential(
     (0): Bottleneck(
       (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
       (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU(inplace=True)
       (downsample): Sequential(
         (0): Conv2d(64, 256, kernel_size=(1,

In [4]:
def create_model(num_classes, model_name):
  backbone, return_layers, in_channels_list = build_backbone_with_fpn(model_name)

  backbone_with_fpn = BackboneWithFPN(backbone,
                        return_layers=return_layers,
                        in_channels_list=in_channels_list,
                        out_channels=256,
                        extra_blocks=LastLevelMaxPool())

  anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
  aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
  anchor_generator = AnchorGenerator(sizes=anchor_sizes,
                        aspect_ratios=aspect_ratios)

  roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2','3'],  # 在哪些特征层上进行RoIAlign pooling
                             output_size=[7, 7],  # RoIAlign pooling输出特征矩阵尺寸
                            sampling_ratio=2)  # 采样率
  mask_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2','3'],
                                output_size = [7,7],       # 在哪些特征层上进行RoIAlign pooling
                               sampling_ratio=2)  # 采样率

  model = MaskRCNN(backbone=backbone_with_fpn,
                       num_classes=num_classes,
                       rpn_anchor_generator=anchor_generator,
                       box_roi_pool=roi_pooler,
                       mask_roi_pool=mask_roi_pooler)

  return model

In [5]:
# import  time
# start = time.time()
model = create_model(2,'mobilenet_v3_large')
# model.eval()
# x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
# predictions = model(x)
# end = time.time()
# print(end - start)  