1. Region Proposal Network (RPN)
2. RPN loss functions
3. Region of Interest (RoI) Pooling
4. RoI loss functions

In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import itertools

In [3]:
# Dummy all-zero input: a batch of one 3-channel 800x800 image.
image = torch.zeros(1, 3, 800, 800, dtype=torch.float32)

# Two ground-truth boxes, one per row, laid out as [y1, x1, y2, x2].
bbox = torch.tensor(
    [[20, 30, 400, 500],
     [300, 400, 500, 600]],
    dtype=torch.float32,
)
# Integer labels, presumably one class id per row of `bbox`.
labels = torch.tensor([6, 8], dtype=torch.long)

# Down-sampling factor: a 1x1 cell of the feature map covers 16x16 in the image.
sub_sample = 16

In [20]:
# Build the backbone from VGG16, requesting pretrained weights.
model = torchvision.models.vgg16(pretrained=True)

# Flatten the convolutional part of VGG16 into a plain list of layers and keep
# only the first 30 of them.
fe = [*model.features.children()]
req_features = fe[:30]

faster_rcnn_feature = nn.Sequential(*req_features)

# Sanity check: run the dummy image through the truncated backbone. The printed
# shape shows an 800x800 input mapped to a 50x50 feature map.
sample_output = faster_rcnn_feature(image)
print(sample_output.shape)

torch.Size([1, 512, 50, 50])


> Region Proposal Network

In [17]:
# RPN head hyper-parameters.
in_channels = 512   # channels of the incoming feature map
mid_channels = 512  # channels produced by the 3x3 sliding-window conv
n_anchor = 9        # Number of anchors at each location in the feature map

# 3x3 conv with padding 1 and stride 1, so the spatial size is preserved.
conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)

# Two sibling 1x1 heads: 4 box-regression values and 2 class scores per anchor.
reg_layer = nn.Conv2d(mid_channels, n_anchor * 4, kernel_size=1, stride=1, padding=0)
cls_layer = nn.Conv2d(mid_channels, n_anchor * 2, kernel_size=1, stride=1, padding=0)

In [19]:
# Initialise every RPN layer the same way: weights drawn from a normal
# distribution with mean 0 and std 0.01, biases set to zero.
# The layers are visited in the order conv1 (sliding layer) -> reg_layer ->
# cls_layer, so random numbers are consumed in the same order as before.
# `nn.init.*_` modifies the parameters in place without going through `.data`,
# and the loop leaves no trailing expression, so no tensor is echoed as output.
for rpn_layer in (conv1, reg_layer, cls_layer):
    nn.init.normal_(rpn_layer.weight, mean=0.0, std=0.01)
    nn.init.zeros_(rpn_layer.bias)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [29]:
# Shared 3x3 sliding-window conv followed by a ReLU. Without the non-linearity,
# conv1 and the 1x1 heads below would compose into a single linear map.
x = F.relu(conv1(sample_output))

# Sibling heads: per-anchor box regression values and per-anchor class scores.
pred_anchor_locs = reg_layer(x)
pred_cls_scores = cls_layer(x)

# Expected here: (1, n_anchor*4, 50, 50) and (1, n_anchor*2, 50, 50).
print(pred_anchor_locs.shape, pred_cls_scores.shape)

torch.Size([1, 36, 50, 50]) torch.Size([1, 18, 50, 50])


In [30]:
# NOTE: this cell overwrites `pred_anchor_locs` and `pred_cls_scores` with their
# reshaped versions, so it must be run exactly once after the forward-pass cell.
# All sizes are read from the tensors themselves instead of being hard-coded,
# so the cell works for any batch size and feature-map size.

# (N, A*4, H, W) -> (N, H, W, A*4) -> (N, H*W*A, 4): one row of 4 values per anchor.
pred_anchor_locs = pred_anchor_locs.permute(0, 2, 3, 1).contiguous()
pred_anchor_locs = pred_anchor_locs.view(pred_anchor_locs.size(0), -1, 4)
print(pred_anchor_locs.shape)

# (N, A*2, H, W) -> (N, H, W, A*2)
pred_cls_scores = pred_cls_scores.permute(0, 2, 3, 1).contiguous()
print(pred_cls_scores.shape)

# Group the channel axis into pairs of scores and keep index 1 of every pair
# as the objectness score -> (N, H*W*A).
objectness_scores = pred_cls_scores.view(pred_cls_scores.size(0), -1, 2)[:, :, 1].contiguous()
print(objectness_scores.shape)

# (N, H, W, A*2) -> (N, H*W*A, 2): one pair of class scores per anchor.
pred_cls_scores = pred_cls_scores.view(pred_cls_scores.size(0), -1, 2)
print(pred_cls_scores.shape)

torch.Size([1, 22500, 4])
torch.Size([1, 50, 50, 18])
torch.Size([1, 22500])
torch.Size([1, 22500, 2])


## Anchor boxes

In [5]:
# Anchor shapes are the cross product of these aspect ratios and sizes.
ratios = [0.5, 1, 2]
anchor_sizes = [8, 16, 32]

# Count one anchor per (ratio, size) combination.
anchor_number = len([(ratio, size) for ratio in ratios for size in anchor_sizes])
print(f'Total anchor #: {anchor_number}')

Total anchor #: 9


In [6]:
# Centre of a single feature-map cell in image coordinates. These two values are
# kept for reference; the template below is centred on the origin and does not
# use them.
ctr_y = sub_sample / 2
ctr_x = sub_sample / 2

# One row per (ratio, size) combination, each row being [y1, x1, y2, x2]
# relative to the anchor centre. The row count is derived from the two lists
# rather than hard-coded.
anchors_template = np.zeros((len(ratios) * len(anchor_sizes), 4))

for i, ratio in enumerate(ratios):
    for j, size in enumerate(anchor_sizes):
        # h / w == ratio and h * w == (size * sub_sample) ** 2
        h = size * np.sqrt(ratio) * sub_sample
        w = size / np.sqrt(ratio) * sub_sample

        anchor = [-h / 2, -w / 2, h / 2, w / 2]
        # The inner loop runs over anchor_sizes, so the row stride must be
        # len(anchor_sizes). Using len(ratios) only works when both lists
        # happen to have the same length.
        anchors_template[i * len(anchor_sizes) + j] = anchor

In [7]:
# Side length of the feature map, derived from the input image and the
# down-sampling factor instead of the literals 800 and 16.
# Assumes a square input (height == width), as the rest of the notebook does.
feature_map_size = image.shape[-1] // sub_sample

# Centre coordinate of every feature-map cell along one axis, in image pixels:
# sub_sample/2, sub_sample/2 + sub_sample, ... (8, 24, ..., 792 for 800 / 16).
ctr_x_all = np.arange(sub_sample // 2, feature_map_size * sub_sample, sub_sample)
ctr_y_all = np.arange(sub_sample // 2, feature_map_size * sub_sample, sub_sample)

print(f'ctr_x_all:\n{ctr_x_all}')

# ctr[y, x] = (centre_y, centre_x) of cell (y, x). The grid is filled by
# broadcasting the two 1-D coordinate arrays instead of a Python double loop.
ctr = np.zeros((feature_map_size, feature_map_size, 2), dtype=np.float32)
ctr[:, :, 0] = ctr_y_all[:, np.newaxis]
ctr[:, :, 1] = ctr_x_all[np.newaxis, :]

print(f'ctr_all:\n {ctr}')

ctr_x_all:
[  8  24  40  56  72  88 104 120 136 152 168 184 200 216 232 248 264 280
 296 312 328 344 360 376 392 408 424 440 456 472 488 504 520 536 552 568
 584 600 616 632 648 664 680 696 712 728 744 760 776 792]
ctr_all:
 [[[  8.   8.]
  [  8.  24.]
  [  8.  40.]
  ...
  [  8. 760.]
  [  8. 776.]
  [  8. 792.]]

 [[ 24.   8.]
  [ 24.  24.]
  [ 24.  40.]
  ...
  [ 24. 760.]
  [ 24. 776.]
  [ 24. 792.]]

 [[ 40.   8.]
  [ 40.  24.]
  [ 40.  40.]
  ...
  [ 40. 760.]
  [ 40. 776.]
  [ 40. 792.]]

 ...

 [[760.   8.]
  [760.  24.]
  [760.  40.]
  ...
  [760. 760.]
  [760. 776.]
  [760. 792.]]

 [[776.   8.]
  [776.  24.]
  [776.  40.]
  ...
  [776. 760.]
  [776. 776.]
  [776. 792.]]

 [[792.   8.]
  [792.  24.]
  [792.  40.]
  ...
  [792. 760.]
  [792. 776.]
  [792. 792.]]]


In [8]:
# anchors -> (H/16, W/16, A, 4), where A is the number of template anchors (9 here).
# Neighbouring cell centres in `ctr` are 16 pixels apart, i.e. the stride is 16.
#
# Each template row [y1, x1, y2, x2] is viewed as two (y, x) corner points, and
# the cell centre (cy, cx) is added to both corners. Broadcasting does this for
# every cell at once:
#   ctr[:, :, None, None, :]        -> (H', W', 1, 1, 2)
#   anchors_template as (A, 2, 2)   -> (A, 2, 2)
#   sum                             -> (H', W', A, 2, 2) -> (H', W', A, 4)
anchors = (
    ctr[:, :, np.newaxis, np.newaxis, :] + anchors_template.reshape(-1, 2, 2)
).reshape(feature_map_size, feature_map_size, -1, 4)

In [16]:
# Flatten to one anchor per row; columns are [y1, x1, y2, x2].
anchors = anchors.reshape(-1, 4)

# An anchor is kept only if it lies completely inside the 800x800 image:
# both top-left coordinates are >= 0 and both bottom-right coordinates are <= 800.
top_left_inside = (anchors[:, :2] >= 0).all(axis=1)
bottom_right_inside = (anchors[:, 2:] <= 800).all(axis=1)

# Row indices (into the flattened `anchors`) of the anchors that pass the test.
index_inside = np.where(top_left_inside & bottom_right_inside)[0]
print(index_inside.shape)

(8940,)
