Model view

In [4]:
from torch import nn
import torchvision
import torch
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

In [5]:
pretrain_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

in_features = pretrain_model.roi_heads.box_predictor.cls_score.in_features
num_classes = 11  # 10 classes + background
pretrain_model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)


---
# Custom code

## Models

Object Detection, Instance Segmentation

Faster R-CNN ResNet-50 FPN (FPN stands for Feature Pyramid Network)

RetinaNet ResNet-50 FPN

Mask R-CNN ResNet-50 FPN

In [None]:
from torchvision import models  # `models` was not imported in the earlier cells

faster_rcnn50 = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
mask_rcnn50 = models.detection.maskrcnn_resnet50_fpn(pretrained=True)

---

## Customizing the RPN
For torchvision's detection models, only the RPN part is customized here.

The RPN can be customized by creating an AnchorGenerator instance.

In [14]:
from torchvision.models.detection.rpn import AnchorGenerator

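# 5 sizes x 3 aspect ratios = 15 anchors per location on a single feature map
# The trailing commas make each argument a tuple of tuples (one inner tuple per feature map)
custom_rpn = AnchorGenerator(sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),))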
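
A minimal sketch of how sizes and aspect ratios combine into anchor shapes, written in plain Python so it runs without torch. It assumes five sizes and three ratios, treats the aspect ratio as height / width, and keeps the anchor area equal to size squared. The helper `anchor_shapes` is hypothetical and only approximates what `AnchorGenerator` computes internally.

```python
import math

sizes = (32, 64, 128, 256, 512)
aspect_ratios = (0.5, 1.0, 2.0)  # interpreted as height / width


def anchor_shapes(sizes, aspect_ratios):
    """Return (width, height) for every ratio/size pair, keeping area = size**2."""
    shapes = []
    for ratio in aspect_ratios:
        h_ratio = math.sqrt(ratio)
        w_ratio = 1.0 / h_ratio
        for size in sizes:
            shapes.append((size * w_ratio, size * h_ratio))
    return shapes


shapes = anchor_shapes(sizes, aspect_ratios)
print(len(shapes))  # number of anchor shapes per feature-map location
for width, height in shapes:
    print(f"{width:7.1f} x {height:7.1f}")
```

Each feature-map location gets one anchor per (size, ratio) pair, so adding a size or a ratio multiplies the number of proposals the RPN has to score.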

In [9]:
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

---
## Customizing the RoI pooler

torchvision provides operations commonly used in computer vision through its ops module:

nms, roi_pool, roi_align, MultiScaleRoIAlign

The torchvision object detection tutorial in the PyTorch Tutorials uses MultiScaleRoIAlign as the RoI pooler.

In [15]:

# Feature maps are keyed by strings; a backbone that returns a single tensor is stored under '0'
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2)

---

References
<a href="https://pytorch.org/vision/stable/_modules/torchvision/models/detection/faster_rcnn.html">link</a>


In [18]:
num_classes = 11
in_features = model.roi_heads.box_predictor.cls_score.in_features

model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

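# Use only the convolutional feature extractor of MobileNetV2 as the backbone
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
# FasterRCNN needs to know the number of output channels of the backbone
backbone.out_channels = 1280

# 5 sizes x 3 aspect ratios on one feature map; note the trailing commas (tuple of tuples)
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),))

# The single feature map returned by the backbone is keyed by the string '0'
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2)

model = torchvision.models.detection.FasterRCNN(
  backbone,
  num_classes=num_classes,  # 11, as defined above (was hard-coded to 2)
  rpn_anchor_generator=anchor_generator,
  box_roi_pool=roi_pooler
)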