In [1]:
# Location of a local copy of the timm sources (presumably a Kaggle input
# dataset used instead of a pip install — TODO confirm).
timm_path = "../input/timm-pytorch-image-models/pytorch-image-models-master"
import sys
# Put that directory on the module search path so `import timm` below can
# resolve to it.
sys.path.append(timm_path)
import timm
import torch

In [2]:
import torch.nn as nn
import torchvision
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection import FasterRCNN
from collections import OrderedDict
from torchvision.models.detection.backbone_utils import BackboneWithFPN
from typing import Callable, Dict, Optional, List, Union
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock
from torch import nn, Tensor
from torchvision.models._utils import IntermediateLayerGetter

In [3]:
class My_BackboneWithFPN(nn.Module):
    """Feature Pyramid Network stacked on selected stages of a backbone.

    The wrapped ``backbone`` must expose ``conv_stem``, ``bn1``, ``act1`` and
    ``blocks`` attributes, because all four are accessed by this class. An
    input is passed through the three stem layers, then through ``blocks``
    (wrapped in an ``IntermediateLayerGetter`` so the outputs named in
    ``return_layers`` are collected), and finally through a
    ``FeaturePyramidNetwork``.

    Args:
        backbone: Module providing the stem layers and ``blocks``.
        return_layers: Maps the name of a child of ``backbone.blocks`` to the
            key under which its output is handed to the FPN.
        in_channels_list: Forwarded to ``FeaturePyramidNetwork`` as
            ``in_channels_list``.
        out_channels: Forwarded to ``FeaturePyramidNetwork`` as
            ``out_channels``; also stored on the instance as
            ``self.out_channels``.
        extra_blocks: Forwarded to ``FeaturePyramidNetwork``. A fresh
            ``LastLevelMaxPool()`` is used when this is ``None``.
    """

    def __init__(
        self,
        backbone: nn.Module,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: Optional[ExtraFPNBlock] = None,
    ) -> None:
        super().__init__()

        # Sub-modules are registered in the order backbone -> back -> body -> fpn.
        self.backbone = backbone
        # `back` is an alias of the backbone's block container, so the same
        # modules are reachable through both `self.backbone.blocks` and `self.back`.
        self.back = self.backbone.blocks
        self.body = IntermediateLayerGetter(self.back, return_layers=return_layers)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=LastLevelMaxPool() if extra_blocks is None else extra_blocks,
        )
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run stem -> selected block outputs -> FPN and return the FPN result."""
        stem_layers = (self.backbone.conv_stem, self.backbone.bn1, self.backbone.act1)
        for stem_layer in stem_layers:
            x = stem_layer(x)
        # `self.body` yields the tapped block outputs; the FPN consumes them.
        return self.fpn(self.body(x))

In [4]:
# Stand-alone EfficientNet-B3 instance created with pretrained=False; `mm` is
# not referenced by any other cell visible in this notebook.
mm = timm.create_model(
    "efficientnet_b3",
    pretrained=False,
    num_classes=0,
    global_pool='',
)

In [5]:
def custom_fpn(backbone, pretrained, in_channels_list=None, out_channels=384, return_layers=None):
    """Create a timm model and wrap it in ``My_BackboneWithFPN``.

    The channel widths and tapped block indices used to be hard-coded even
    though the backbone name is a parameter; they are now optional arguments
    whose defaults reproduce the previous behaviour.

    Args:
        backbone: Model name handed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model`` as ``pretrained``.
        in_channels_list: Channel count of each tapped feature map, in the
            order of ``return_layers``. Defaults to ``[32, 48, 136, 384]``,
            the values this notebook uses with ``"efficientnet_b3"``.
        out_channels: Number of channels of every FPN output (default 384).
        return_layers: Maps the name of a child of the model's ``blocks`` to
            the key of the corresponding FPN input. Defaults to
            ``{'1': '0', '2': '1', '4': '2', '6': '3'}``.

    Returns:
        A ``My_BackboneWithFPN`` instance built around the created model.

    Raises:
        ValueError: If ``in_channels_list`` and ``return_layers`` do not have
            the same number of entries.
    """
    if in_channels_list is None:
        in_channels_list = [32, 48, 136, 384]
    if return_layers is None:
        return_layers = {'1': '0', '2': '1', '4': '2', '6': '3'}

    # Fail fast, before any model is built, when the two descriptions of the
    # tapped feature maps cannot possibly line up.
    if len(in_channels_list) != len(return_layers):
        raise ValueError(
            "in_channels_list has {} entries but return_layers has {}; "
            "they must describe the same feature maps".format(
                len(in_channels_list), len(return_layers)
            )
        )

    # num_classes=0 and global_pool='' are passed through exactly as before.
    feature_model = timm.create_model(
        backbone, pretrained=pretrained, num_classes=0, global_pool=''
    )
    return My_BackboneWithFPN(feature_model, return_layers, in_channels_list, out_channels)

In [6]:
# Smoke test: push one random 1x3x224x224 tensor through the FPN backbone and
# print the key and shape of every feature map it returns.
backbone = custom_fpn("efficientnet_b3", False)
x = torch.rand(1, 3, 224, 224)
output = backbone(x)
print([(level, fmap.shape) for level, fmap in output.items()])

In [7]:
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.backbone_utils import BackboneWithFPN
def get_model(num_classes=2, pretrained=False):
    """Assemble a Faster R-CNN detector on the EfficientNet-B3 FPN backbone.

    The backbone comes from ``custom_fpn("efficientnet_b3", pretrained)``; no
    COCO-pretrained detector is loaded here.

    Args:
        num_classes: Passed to ``FasterRCNN`` as ``num_classes``. Per the
            torchvision documentation this count includes the background
            class, so at least 2 is required for one foreground class.
        pretrained: Forwarded to ``custom_fpn`` (default ``False``, as before).

    Returns:
        The constructed ``FasterRCNN`` module.

    Raises:
        ValueError: If ``num_classes`` is smaller than 2.
    """
    if num_classes < 2:
        raise ValueError(
            "num_classes must be >= 2 (background + at least one object class), "
            "got {}".format(num_classes)
        )

    # Backbone
    fpn_backbone = custom_fpn("efficientnet_b3", pretrained)

    # RPN: five anchor-size groups, one per name in `featmap_names` below,
    # each combined with the same two aspect ratios.
    anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
    aspect_ratios = ((0.5, 1.0),) * len(anchor_sizes)
    anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

    # RoI pooling over the feature maps named here.
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(
        featmap_names=['0', '1', '2', '3', 'pool'],
        output_size=7,
        sampling_ratio=2,
    )

    return FasterRCNN(
        fpn_backbone,
        num_classes=num_classes,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
    )

# Instantiate the Faster R-CNN detector built by get_model().
model = get_model()

In [8]:
# Dummy detection batch (random images, boxes and labels) for a smoke test of
# the detector.
num_images, boxes_per_image = 2, 11
images = torch.rand(num_images, 3, 600, 600)
boxes = torch.rand(num_images, boxes_per_image, 4)
# Columns 2:4 are drawn as offsets; adding the top-left corner turns every
# row into (x1, y1, x2, y2) with x2 >= x1 and y2 >= y1.
boxes[:, :, 2:4] = boxes[:, :, 0:2] + boxes[:, :, 2:4]
# One label row per image (previously 4 rows were drawn for 2 images).
# Label 0 denotes background in torchvision detection models, so ground-truth
# boxes must use foreground labels only; with two classes that is label 1.
labels = torch.randint(1, 2, (num_images, boxes_per_image))
# The detector expects a list of per-image tensors and a list of target dicts.
images = list(images)
targets = [
    {'boxes': image_boxes, 'labels': image_labels}
    for image_boxes, image_labels in zip(boxes, labels)
]

In [9]:
# Reference alternative (stock torchvision detector), left disabled:
#model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)

# Forward pass with targets supplied. `model` is never switched to eval mode
# in the cells shown here, so this presumably returns the training losses
# rather than detections — TODO confirm.
output = model(images, targets)

In [10]:
# Display the value returned by the forward pass.
output