# Import modules

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn

from src.core.utils import describe

## Build the data loader

In [None]:
from src.datasets.loader.build_loader import build_dataloader
from mmcv.utils.config import Config


cfg = Config.fromfile('config/retinanet_x101_64x4d_fpn_1x.py')

# Pull out the config sections that later cells use directly.
train_cfg, test_cfg = cfg.train_cfg, cfg.test_cfg
# This walkthrough iterates over the validation split defined in the config.
dataset_cfg = cfg.data.val

val_loader = build_dataloader(dataset_cfg)
loader = iter(val_loader)

## Build the modules

### Feature extractor: backbone + neck

In [None]:
"""
backbone=dict(
        type='ResNeXt',
        depth=101,
        groups=64,
        base_width=4,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch'),
neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        add_extra_convs=True,
        num_outs=5)
"""
from src.models.builder import build_backbone, build_neck


class FeatureExtractor(nn.Module):
    """Backbone followed by a neck, both built from ``cfg.model``.

    With the config described above this is a ResNeXt backbone feeding an
    FPN neck. ``forward`` prints a summary of the tensors after each stage so
    the shapes can be followed while stepping through the notebook.
    """

    def __init__(self, cfg):
        super().__init__()
        model_cfg = cfg.model
        self.backbone = build_backbone(model_cfg.backbone)
        self.neck = build_neck(model_cfg.neck)

    def forward(self, x):
        """Run ``x`` through backbone then neck and return the neck output."""
        print(f"Raw Image shape : {describe(x)}")

        backbone_out = self.backbone(x)
        print(f"After Resnet Passed: {describe(backbone_out)}")

        pyramid = self.neck(backbone_out)
        print(f"After FPN Passed: {describe(pyramid)}")

        return pyramid

### Head: RetinaHead

In [None]:
"""
bbox_head=dict(
    type='RetinaHead',
    num_classes=81,  # background + 80 (RetinaNet)
    in_channels=256, # (RetinaNet)
    stacked_convs=4,  # number of class/box subnet's conv layers (RetinaNet)
    feat_channels=256,  # num_channels in subnet's conv feature (RetinaNet)
    octave_base_scale=4,  # anchor scale related factor (RetinaNet)
    scales_per_octave=3,  # anchor scale related factor (RetinaNet)
    anchor_ratios=[0.5, 1.0, 2.0],  # anchor scale related factor (RetinaNet)
    anchor_strides=[8, 16, 32, 64, 128],  # stride of anchor, normally stride of feature map. (RetinaNet)
    target_means=[.0, .0, .0, .0],  # regression target mean (RetinaNet)
    target_stds=[1.0, 1.0, 1.0, 1.0]))  # regression target std (RetinaNet)
"""
from src.models.builder import build_head


class Head(nn.Module):
    """Thin wrapper around the detection head built from ``cfg.model.bbox_head``.

    The wrapped module lives under the attribute name ``bbox_head``; later
    cells reach into it as ``bbox_head.bbox_head`` for anchor generators,
    strides and channel counts. (Presumably the name also matches the key
    prefix used in the pretrained checkpoint — confirm if renaming.)
    """

    def __init__(self, cfg):
        super().__init__()
        self.bbox_head = build_head(cfg.model.bbox_head)

    def forward(self, feature):
        """Return the ``(cls_score, bbox_pred)`` pair produced by the wrapped head."""
        outputs = self.bbox_head(feature)
        cls_score, bbox_pred = outputs
        return cls_score, bbox_pred

## Load the pretrained checkpoint

In [None]:
from mmcv.runner import load_checkpoint

# Single source of truth for the pretrained weights: both wrappers below are
# loaded from the same file, so keep the path in one place.
PRETRAINED_CHECKPOINT = 'pretrained/retinanet_x101_64x4d_fpn_1x_pretrained.pth'

# load_checkpoint is called with its default (non-strict) matching, so each
# wrapper presumably picks up only the checkpoint entries whose names match its
# own submodules (backbone/neck vs. bbox_head). Mismatched names would be
# skipped rather than raise — check the logged missing/unexpected keys if the
# outputs look random.
feature_extractor = FeatureExtractor(cfg)
_ = load_checkpoint(feature_extractor, PRETRAINED_CHECKPOINT)

bbox_head = Head(cfg)
_ = load_checkpoint(bbox_head, PRETRAINED_CHECKPOINT)

### Load data

In [None]:
sample = next(loader)
print(sample.keys())

# Each entry of the batch is a wrapper whose `.data` holds the payload; index 0
# is taken for every field (presumably the first/only device's chunk — confirm
# against the loader's collate function).
img, img_metas, gt_bboxes, gt_labels = (
    sample[field].data[0]
    for field in ('img', 'img_meta', 'gt_bboxes', 'gt_labels')
)

### Feature extraction

In [None]:
# Run backbone + neck on the batch; FeatureExtractor.forward also prints the
# intermediate summaries.
feature = feature_extractor(img)
feature_summary = describe(feature)
print(feature_summary)

### Bbox prediction

In [None]:
# Run the detection head on the multi-level features from the neck.
cls_score, bbox_pred = bbox_head(feature)

for head_output in (cls_score, bbox_pred):
    print(describe(head_output))

## Calculate losses

### Define anchor generators

In [None]:
# Initialise the head's anchor generators; the cells below read them back via
# `anchor_generators` (presumably populated by this call).
retina_head = bbox_head.bbox_head
retina_head.init_anchor_generator()

In [None]:
from src.visualization.visualize import draw_base_anchor

# Visualise the base anchors of the generator at index 4 — presumably the
# coarsest pyramid level (stride 128), given the five anchor_strides in the config.
last_level_generator = bbox_head.bbox_head.anchor_generators[4]
draw_base_anchor(last_level_generator, line_size=3)

### Get anchors

In [None]:
from src.quiz.quiz1 import get_anchors

retina_head = bbox_head.bbox_head
# Spatial size (H, W) of every level's classification map drives how many
# anchors are laid out on that level.
featmap_sizes = [level_scores.shape[-2:] for level_scores in cls_score]
anchor_list, valid_flag_list = get_anchors(
    retina_head.anchor_generators,
    retina_head.anchor_strides,
    featmap_sizes,
    img_metas,
)

In [None]:
# Display a summary of the anchors returned by get_anchors (shown as the cell output).
describe(anchor_list)

In [None]:
# Display a summary of the valid flags returned by get_anchors (shown as the cell output).
describe(valid_flag_list)

### Make targets

In [None]:
num_imgs = len(img_metas)
assert len(anchor_list) == len(valid_flag_list) == num_imgs

# anchor_list / valid_flag_list come from get_anchors nested as [image][level].
# Each image's levels are concatenated into one tensor (the per-image form the
# anchor-target step below presumably expects), and the anchor count of every
# level is remembered so the per-image targets can later be split back into
# levels with images_to_levels.
#
# The lists are modified in place, so re-running this cell would otherwise see
# already-flattened tensors: num_level_anchors would be recomputed from the
# wrong structure and torch.cat would be handed a tensor instead of a list.
# Skip the work when the flattening has already happened.
if not torch.is_tensor(anchor_list[0]):
    num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

    for i in range(num_imgs):
        assert len(anchor_list[i]) == len(valid_flag_list[i])
        anchor_list[i] = torch.cat(anchor_list[i])
        valid_flag_list[i] = torch.cat(valid_flag_list[i])

In [None]:
# Bundle the (already flattened) inputs of the anchor-target step and persist
# them with IPython's %store magic, presumably so another notebook can restore
# them with `%store -r anchor_target_variable`.
anchor_target_variable = (anchor_list, valid_flag_list, gt_bboxes, gt_labels, img_metas, train_cfg, bbox_head.bbox_head.cls_out_channels)
%store anchor_target_variable

In [None]:
from src.core import multi_apply
from src.core.anchor import anchor_target_single

# Encode regression targets with the same normalisation the head is configured
# with (bbox_head.target_means / target_stds) instead of repeating the numbers
# here, so the targets cannot silently drift from the config.
head_cfg = cfg.model.bbox_head

# anchor_target_single is applied once per image (multi_apply zips the
# per-image lists); the outputs are regrouped into per-field tuples.
(all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
 pos_inds_list, neg_inds_list) = multi_apply(
    anchor_target_single,
    anchor_list,
    valid_flag_list,
    gt_bboxes,
    gt_labels,
    img_metas,
    target_means=head_cfg.target_means,
    target_stds=head_cfg.target_stds,
    cfg=train_cfg,
    label_channels=bbox_head.bbox_head.cls_out_channels,
    # No sampler: presumably every valid anchor is assigned (RetinaNet relies
    # on the focal loss rather than subsampling) — confirm in anchor_target_single.
    sampling=False,
    # Presumably maps targets back onto the full anchor set, which the level
    # split in images_to_levels (using num_level_anchors) relies on.
    unmap_outputs=True)

In [None]:
for per_image_target in (all_labels, all_label_weights, all_bbox_targets,
                         all_bbox_weights, pos_inds_list, neg_inds_list):
    print(describe(per_image_target))

In [None]:
from src.quiz.quiz2 import images_to_levels

# The targets above are organised per image; regroup each of them so that it
# is organised per pyramid level again.
(labels_list, label_weights_list,
 bbox_targets_list, bbox_weights_list) = (
    images_to_levels(per_image_targets, num_level_anchors)
    for per_image_targets in (all_labels, all_label_weights,
                              all_bbox_targets, all_bbox_weights)
)

In [None]:
for per_level_target in (labels_list, label_weights_list,
                         bbox_targets_list, bbox_weights_list):
    print(describe(per_level_target))

### Compute the losses

In [None]:
"""
positive / negative sample의 개수를 각각 들고있습니다.
loss를 구하는 과정에서 normalize를 거치게 되며, 이는 곧 sample의 개수로 나누는 것을 의미합니다.

RetinaNet은 positive sample만을 사용하므로 사실상 num_total_neg는 무의미합니다.
"""
from src.quiz.quiz3 import loss_single

num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])

losses_cls, losses_reg = multi_apply(
    loss_single,
    cls_score,
    bbox_pred,
    labels_list,
    label_weights_list,
    bbox_targets_list,
    bbox_weights_list,
    num_total_samples=num_total_pos,
    cfg=cfg.train_cfg,
    cls_out_channels=bbox_head.bbox_head.cls_out_channels)

In [None]:
for level_losses in (losses_cls, losses_reg):
    print(describe(level_losses))
    print(level_losses)

## Get results

### Get bboxes

In [None]:
from src.quiz.quiz8 import get_bboxes

# Turn the raw head outputs into detections using the test-time settings;
# the result is iterated below as one (det_bboxes, det_labels) pair per image.
retina_head = bbox_head.bbox_head
bbox_list = get_bboxes(
    cls_score,
    bbox_pred,
    img_metas,
    test_cfg,
    retina_head.anchor_generators,
    retina_head.anchor_strides,
    retina_head.cls_out_channels,
)

### Format the results

In [None]:
from src.core import bbox2result

# Convert each image's (det_bboxes, det_labels) pair with bbox2result.
num_classes = bbox_head.bbox_head.num_classes
bbox_results = [
    bbox2result(boxes, labels, num_classes)
    for boxes, labels in bbox_list
]

### Draw bbox on Image

In [None]:
from src.visualization.show_result import show_result

# Draw the first image's results on the sample. cfg.img_norm_cfg is presumably
# used to undo input normalisation, and 'coco' to pick the class names.
first_image_result = bbox_results[0]
show_result(sample, first_image_result, cfg.img_norm_cfg, 'coco')