In [21]:
import torch
from torch import nn
import torchvision
from typing import *
from types import *

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

## utils

In [29]:
def param_groups(model: nn.Module) -> List[List[nn.Parameter]]:
    """Split a ResNet+FPN backbone into ordered parameter groups.

    Intended to be bound to the backbone with ``patch_param_groups`` so it can
    be called as ``model.param_groups()``.

    Groups, in order:
      1. the first conv + batch-norm (``model.body.conv1``, ``model.body.bn1``),
      2. one group per child of ``model.body`` whose name starts with
         ``"layer"``, in registration order,
      3. ``model.fpn``.

    Args:
        model: backbone exposing a ``body`` trunk and an ``fpn`` sub-module
            (as produced by ``resnet_fpn_backbone``).

    Returns:
        A list of parameter lists, one per group.
    """
    body = model.body

    layers = [nn.Sequential(body.conv1, body.bn1)]
    # Use named_children() (as resnet_param_groups does) rather than relying on
    # iteration over `body` yielding child names, which only holds for ModuleDicts.
    layers += [
        layer for name, layer in body.named_children() if name.startswith("layer")
    ]
    layers.append(model.fpn)

    _param_groups = [list(layer.parameters()) for layer in layers]
    # NOTE(review): helper defined outside this section; presumably checks that
    # every parameter of `model` is covered by the groups — confirm.
    check_all_model_params_in_groups2(model, _param_groups)

    return _param_groups


def patch_param_groups(model: nn.Module) -> None:
    """Attach the module-level ``param_groups`` to ``model`` as a bound method.

    After this call, ``model.param_groups()`` invokes ``param_groups(model)``.
    ``model`` is modified in place; nothing is returned.
    """
    bound_method = MethodType(param_groups, model)
    setattr(model, "param_groups", bound_method)

In [14]:
def _resnet_features(model: nn.Module, out_channels: int):
    """Turn a torchvision ResNet classifier into a detection backbone.

    The last two children are dropped. For torchvision ResNets these are
    global average pooling (``avgpool``) and the classifier (``fc``). The
    backbone therefore returns the spatial feature map of the last stage.
    Keeping ``avgpool`` would collapse it to 1x1 and leave the detector a
    single anchor location.

    Args:
        model: a torchvision ResNet/ResNeXt whose final two children are
            ``avgpool`` and ``fc``.
        out_channels: channel count of the kept trunk's output. It is stored
            on the result because FasterRCNN expects the backbone to expose
            ``out_channels``.

    Returns:
        An ``nn.Sequential`` of the kept children (original names preserved,
        so ``conv1``/``bn1``/``layer*`` stay addressable), with
        ``out_channels`` set and ``param_groups`` bound as a method.
    """
    from collections import OrderedDict

    # Remove the classification head: global average pool + fully-connected.
    modules = list(model.named_children())[:-2]
    features = nn.Sequential(OrderedDict(modules))

    features.out_channels = out_channels
    features.param_groups = MethodType(resnet_param_groups, features)

    return features


def resnet_param_groups(model: nn.Module) -> List[List[nn.Parameter]]:
    """Split a ResNet feature trunk into ordered parameter groups.

    Groups, in order: the first conv + batch-norm (``model.conv1``,
    ``model.bn1``), then one group per child whose name starts with
    ``"layer"``, in registration order.

    Args:
        model: ResNet-style module exposing ``conv1``, ``bn1`` and ``layer*``
            children (e.g. the ``nn.Sequential`` built by ``_resnet_features``).

    Returns:
        A list of parameter lists, one per group.
    """
    stem = nn.Sequential(model.conv1, model.bn1)
    stages = [
        child for name, child in model.named_children() if name.startswith("layer")
    ]
    layers = [stem, *stages]

    param_groups = [list(layer.parameters()) for layer in layers]
    # NOTE(review): helper defined outside this section; presumably checks that
    # every parameter of `model` is covered by the groups — confirm.
    check_all_model_params_in_groups2(model, param_groups)

    return param_groups


def _torchvision_resnet_features(arch: str, pretrained: bool, out_channels: int):
    """Instantiate ``torchvision.models.<arch>`` and wrap it with ``_resnet_features``."""
    classifier = getattr(torchvision.models, arch)(pretrained=pretrained)
    return _resnet_features(classifier, out_channels=out_channels)


def resnet18(pretrained: bool = True):
    """ResNet-18 feature backbone (512 output channels)."""
    return _torchvision_resnet_features("resnet18", pretrained, 512)


def resnet34(pretrained: bool = True):
    """ResNet-34 feature backbone (512 output channels)."""
    return _torchvision_resnet_features("resnet34", pretrained, 512)


def resnet50(pretrained: bool = True):
    """ResNet-50 feature backbone (2048 output channels)."""
    return _torchvision_resnet_features("resnet50", pretrained, 2048)


def resnet101(pretrained: bool = True):
    """ResNet-101 feature backbone (2048 output channels)."""
    return _torchvision_resnet_features("resnet101", pretrained, 2048)


def resnet152(pretrained: bool = True):
    """ResNet-152 feature backbone (2048 output channels)."""
    return _torchvision_resnet_features("resnet152", pretrained, 2048)


def resnext101_32x8d(pretrained: bool = True):
    """ResNeXt-101 32x8d feature backbone (2048 output channels)."""
    return _torchvision_resnet_features("resnext101_32x8d", pretrained, 2048)

In [24]:
# Smoke test: each plain-ResNet backbone can be wrapped by FasterRCNN.
# `m` is rebound on every iteration and ends up holding the last model built.
for backbone_fn, n_classes in [
    (resnet18, 10),
    (resnet34, 12),
    (resnet50, 1),
    (resnet101, 5),
    (resnet152, 16),
    (resnext101_32x8d, 14),
]:
    m = FasterRCNN(backbone_fn(False), num_classes=n_classes)

In [20]:
def _vgg_features(model: nn.Module):
    """Expose the convolutional trunk of a torchvision VGG as a backbone.

    Returns ``model.features`` itself (not a copy), annotated with
    ``out_channels = 512`` and with ``vgg_param_groups`` bound as its
    ``param_groups`` method.
    """
    trunk = model.features
    trunk.out_channels = 512
    trunk.param_groups = MethodType(vgg_param_groups, trunk)
    return trunk


def vgg_param_groups(model: nn.Module) -> List[List[nn.Parameter]]:
    """Split a VGG ``features`` trunk into four ordered parameter groups.

    The first group is ``model[:4]``. The remaining layers ``model[4:]`` are
    split into 3 contiguous chunks whose sizes differ by at most one, with the
    earlier chunks taking the extra layers. This is the same split
    ``np.array_split`` produces.

    Args:
        model: an indexable ``nn.Sequential`` such as torchvision's VGG
            ``features``.

    Returns:
        A list of parameter lists, one per group.
    """
    n_chunks = 3
    layers = [model[:4]]

    # Pure-Python equivalent of np.array_split(model[4:], 3): numpy is not
    # imported by this notebook, and splitting a Module via numpy is fragile.
    rest = list(model[4:])
    chunk_size, n_larger = divmod(len(rest), n_chunks)
    start = 0
    for i in range(n_chunks):
        stop = start + chunk_size + (1 if i < n_larger else 0)
        layers.append(nn.Sequential(*rest[start:stop]))
        start = stop

    param_groups = [list(layer.parameters()) for layer in layers]
    # NOTE(review): helper defined outside this section; presumably checks that
    # every parameter of `model` is covered by the groups — confirm.
    check_all_model_params_in_groups2(model, param_groups)

    return param_groups


def vgg11(pretrained: bool = True):
    """VGG-11 convolutional trunk as a detection backbone."""
    vgg = torchvision.models.vgg11(pretrained=pretrained)
    return _vgg_features(model=vgg)


def vgg13(pretrained: bool = True):
    """VGG-13 convolutional trunk as a detection backbone."""
    vgg = torchvision.models.vgg13(pretrained=pretrained)
    return _vgg_features(model=vgg)


def vgg16(pretrained: bool = True):
    """VGG-16 convolutional trunk as a detection backbone."""
    vgg = torchvision.models.vgg16(pretrained=pretrained)
    return _vgg_features(model=vgg)


def vgg19(pretrained: bool = True):
    """VGG-19 convolutional trunk as a detection backbone."""
    vgg = torchvision.models.vgg19(pretrained=pretrained)
    return _vgg_features(model=vgg)

In [25]:
# Smoke test: each VGG trunk can be wrapped by FasterRCNN (14 classes).
# `m` ends up holding the last model built.
for vgg_fn in (vgg11, vgg13, vgg16, vgg19):
    m = FasterRCNN(vgg_fn(False), num_classes=14)

In [22]:
def _resnet_fpn(name: str, pretrained: bool = True, **kwargs):
    """Build a torchvision ResNet+FPN backbone and attach ``param_groups``.

    Args:
        name: torchvision architecture name, forwarded as ``backbone_name``.
        pretrained: forwarded to ``resnet_fpn_backbone``.
        **kwargs: any further ``resnet_fpn_backbone`` options.

    Returns:
        The backbone from ``resnet_fpn_backbone``, patched in place by
        ``patch_param_groups``.
    """
    backbone = resnet_fpn_backbone(
        backbone_name=name,
        pretrained=pretrained,
        **kwargs
    )
    patch_param_groups(backbone)
    return backbone

def resnet18_fpn(pretrained: bool = True, **kwargs):
    """ResNet-18 + FPN backbone; extra kwargs go to ``resnet_fpn_backbone``."""
    arch = "resnet18"
    return _resnet_fpn(arch, pretrained=pretrained, **kwargs)


def resnet34_fpn(pretrained: bool = True, **kwargs):
    """ResNet-34 + FPN backbone; extra kwargs go to ``resnet_fpn_backbone``."""
    arch = "resnet34"
    return _resnet_fpn(arch, pretrained=pretrained, **kwargs)


def resnet50_fpn(pretrained: bool = True, **kwargs):
    """ResNet-50 + FPN backbone; extra kwargs go to ``resnet_fpn_backbone``."""
    arch = "resnet50"
    return _resnet_fpn(arch, pretrained=pretrained, **kwargs)


def resnet101_fpn(pretrained: bool = True, **kwargs):
    """ResNet-101 + FPN backbone; extra kwargs go to ``resnet_fpn_backbone``."""
    arch = "resnet101"
    return _resnet_fpn(arch, pretrained=pretrained, **kwargs)


def resnet152_fpn(pretrained: bool = True, **kwargs):
    """ResNet-152 + FPN backbone; extra kwargs go to ``resnet_fpn_backbone``."""
    arch = "resnet152"
    return _resnet_fpn(arch, pretrained=pretrained, **kwargs)


def resnext50_32x4d_fpn(pretrained: bool = True, **kwargs):
    """ResNeXt-50 32x4d + FPN backbone; extra kwargs go to ``resnet_fpn_backbone``."""
    arch = "resnext50_32x4d"
    return _resnet_fpn(arch, pretrained=pretrained, **kwargs)


def resnext101_32x8d_fpn(pretrained: bool = True, **kwargs):
    """ResNeXt-101 32x8d + FPN backbone; extra kwargs go to ``resnet_fpn_backbone``."""
    arch = "resnext101_32x8d"
    return _resnet_fpn(arch, pretrained=pretrained, **kwargs)


def wide_resnet50_2_fpn(pretrained: bool = False, **kwargs):
    """Wide ResNet-50-2 + FPN backbone; extra kwargs go to ``resnet_fpn_backbone``.

    NOTE(review): unlike the other ``*_fpn`` builders, ``pretrained`` defaults
    to ``False`` here — confirm whether that is intentional.
    """
    arch = "wide_resnet50_2"
    return _resnet_fpn(arch, pretrained=pretrained, **kwargs)


def wide_resnet101_2_fpn(pretrained: bool = False, **kwargs):
    """Wide ResNet-101-2 + FPN backbone; extra kwargs go to ``resnet_fpn_backbone``.

    NOTE(review): unlike the other ``*_fpn`` builders, ``pretrained`` defaults
    to ``False`` here — confirm whether that is intentional.
    """
    arch = "wide_resnet101_2"
    return _resnet_fpn(arch, pretrained=pretrained, **kwargs)

In [30]:
# Smoke test: each ResNet+FPN backbone can be wrapped by FasterRCNN (14 classes).
# `m` ends up holding the last model built.
for fpn_fn in (
    resnet18_fpn,
    resnet34_fpn,
    resnet50_fpn,
    resnet101_fpn,
    resnet152_fpn,
    resnext50_32x4d_fpn,
    resnext101_32x8d_fpn,
    wide_resnet50_2_fpn,
    wide_resnet101_2_fpn,
):
    m = FasterRCNN(fpn_fn(False), num_classes=14)