# Implementation of Matching Algorithm.

In [1]:
from importlib.util import find_spec
if find_spec("model") is None:
    import sys
    sys.path.append('..')

In [5]:
from typing import List
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

In [6]:
from layers.wrappers import nonzero_tuple

## Matching Strategy.
paper: [Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks](https://arxiv.org/abs/1506.01497)

In [11]:
class Matcher:
    """
    This class assigns to each prediction "element" (e.g., a box) a
    ground-truth element. Each predicted element with have exactly zero
    or one matches. Each ground-truth element may be matched to zero or
    more predicted elements.
    
    The match is determined by the MxN `match_quality_matrix`, that
    characterizes how well each (ground-truth, prediction) pair match.
    
    i.e. In the case of boxes we can use the IOU between pairs. 
    
    The matcher returns:
        1. A vector of length N containing the index of the ground-truth
           element m in [0, M) that matches to prediction n in [0, N).
           
        2. A vector of length N containing the labels for each prediction. 
    """
    
    def __init__(
        self, thresholds: List[float], labels: List[int], allow_low_quality_matches: bool = False
    ):
        """
        Args:
            thresholds: a list of thresholds used to stratify predictions
                into levels.
            labels: a list of values to lable predictions belonging at each level.
                A label can be one of {-1, 0, 1} signifying {ignore, negative class, positive class},
                respectively.
            allow_low_quality_matches: if True, produce additional matches for predictions
                with maximum match quality lower than high_threshold.
                See set_low_quality_matches_ for more details.
        """
        thresholds = thresholds[:]
        assert thresholds[0] > 0
        thresholds.insert(0, -float("inf"))
        thresholds.append(float("inf"))
        
        # Currently torchscript does not support all pos generator.
        assert all([low <= high] for (low, high) in zip(thresholds[:-1], thresholds[1:]))
        assert all([l in [-1, 0, 1] for l in labels])
        assert len(labels) == len(thresholds) - 1
        
        self.thresholds = thresholds
        self.labels = labels
        self.allow_low_quality_matches = allow_low_quality_matches
        
    def __call__(self, match_quality_matrix: Tensor):
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor containing the pairwise quality between M 
                ground-truth elements and N predicted elements. All elements must be >= 0
                (due to the use of `torch.nonzero` in `set_low_quality_matches_` methods.)
        Returns:
            matches (Tensor[int64]): a vector of length N, where matches[i] is a matched 
                ground-truth index in [0, M).
            match_labels (Tensor[int8]): a vector of length N where match_labels[i] indicates
                whether a prediction is a true or false positive or ignored.
        """
        assert match_quality_matrix.dim() == 2
        if match_quality_matrix.numel() == 0:
            default_matches = match_quality_matrix.new_full(
                (match_quality_matrix.size(1),), 0, dtype=torch.int64
            )
            
            # No gt boxes exits. So set labels to `self.labels[0]` which is usally background.
            # To ignore instead make labels=[-1, 0, -1, 1].
            default_match_labels = match_quality_matrix.new_full(
                (match_quality_matrix.size(1), ), self.labels[0], dtype=torch.int8
            )
            
            return default_matches, default_match_labels
        
        assert torch.all(match_quality_matrix >= 0)
        matched_vals, matches = match_quality_matrix.max(dim=0)
        match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)
        
        for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]):
            low_high = (matched_vals >= low) & (matched_vals < high)
            match_labels[low_high] = l
            
        if self.allow_low_quality_matches:
            self.set_low_quality_matches_(match_labels, match_quality_matrix)
            
        return matches, match_labels
    
    def set_low_quality_matches_(self, match_labels, match_quality_matrix):
        """
        Produce additional matches for predictions that have only low_quality matches.
        
        Specifically, for each ground-truth label element find the set of predictions that have
        maximum overlap with it and set them to match ground-truth if unmatched previously.
        
        This function implements the RPN assignment case (i) in Sec. 3.1.2 of
        :paper:`Faster R-CNN`.
        """
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
        _, pred_inds_with_highest_quality = nonzero_tuple(
            match_quality_matrix == highest_quality_foreach_gt[:, None]
        )
        
        match_labels[pred_inds_with_highest_quality] = 1


In [12]:
matcher = Matcher([0.4, 0.5], [-1, 0, 1])

## Test Matcher

In [2]:
import torch
import random

In [3]:
from model.backbone.retina_meta import RetinaNetFPN50, RetinaNetHead
from model.backbone.resnet import ResNet50
from model.anchor_generator import AnchorBoxGenerator
from model.matcher import Matcher
from utils.box_utils import pairwise_iou, cat_boxes
from utils.shape_utils import permute_to_N_HWA_K

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
device

device(type='cuda')

In [20]:
data = torch.randn((16, 3, 512, 512)).to(device)
target_boxes = [torch.randn((random.randint(1, 7), 4)).to(device) for i in range(16)]

In [7]:
len(target_boxes), target_boxes[1].shape

(16, torch.Size([5, 4]))

In [22]:
num_classes = 20
num_anchors = 9
total_anchors = (
    64 * 64 * num_anchors
    + 32 * 32 * num_anchors
    + 16 * 16 * num_anchors
    + 8 * 8 * num_anchors
    + 4 * 4 * num_anchors
)

In [9]:
backbone = ResNet50().to(device)
fpn_backbone = RetinaNetFPN50().to(device)
head = RetinaNetHead(num_classes).to(device)
anchor_gen = AnchorBoxGenerator(
    sizes=[32., 64., 128., 256., 512.],
    aspect_ratios=[0.5, 1., 2.],
    scales=[1., 2 ** (1 / 3), 2 ** (2 / 3)],
    strides=[2, 2, 2, 2, 2]
).to(device)

In [10]:
_, C3, C4, C5 = backbone(data)
P3, P4, P5, P6, P7 = fpn_backbone(C3, C4, C5)
pred_logits, pred_bboxes = head(P3, P4, P5, P6, P7)
all_anchors = anchor_gen([P3, P4, P5, P6, P7])

In [11]:
all_anchors[0].shape

torch.Size([36864, 4])

In [12]:
reshaped_logits = [
      permute_to_N_HWA_K(pred_logits[k], num_classes) for k in pred_logits
  ]

reshaped_bboxes = [permute_to_N_HWA_K(pred_bboxes[k], 4) for k in pred_bboxes]


In [15]:
anchor_matcher = Matcher([0.4, 0.5], [-1, 0, 1])

In [73]:
thresholds = [0.4, 0.5]
labels = [-1, 0, 1]

In [75]:
for (l, low, high) in zip(labels, thresholds[:-1], thresholds[1:]):
    print(f"l = {l}; low = {low}; high = {high}")

l = -1; low = 0.4; high = 0.5


In [76]:
assert torch.all(match_qual >= 0)

In [77]:
matched_vals, matches = match_qual.max(dim=0)

In [81]:
matches

tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0')

In [90]:
target_boxes[0]

tensor([[-0.6480, -0.6852,  1.2240,  0.4158],
        [ 0.2791, -0.7394, -1.7727, -1.5212],
        [ 0.6535,  0.7886,  0.1682,  1.2263],
        [ 0.3984,  1.8456, -1.0409,  2.3854]], device='cuda:0')

In [105]:
anchors[5433].view(1, 4)

tensor([[19.0812,  1.0406, 90.9188, 36.9594]], device='cuda:0')

In [107]:
target_boxes[0][2].view(1, 4)

tensor([[0.6535, 0.7886, 0.1682, 1.2263]], device='cuda:0')

In [103]:
pairwise_iou(anchors[5433].view(1, 4), target_boxes[0][3].view(1, 4))

tensor([[0.]], device='cuda:0')

In [86]:
torch.where(matches == 0)

(tensor([    0,     1,     2,  ..., 49101, 49102, 49103], device='cuda:0'),)

In [None]:
match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8)

In [18]:
anchors = cat_boxes(all_anchors).to(device)

In [21]:
match_qual = pairwise_iou(target_boxes[0], anchors)

In [24]:
assert match_qual.shape == (target_boxes[0].size(0), total_anchors)

In [27]:
anchor_matcher(match_qual)

(tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0'),
 tensor([-1, -1, -1,  ..., -1, -1, -1], device='cuda:0', dtype=torch.int8))

In [28]:
matched_idxs, anchor_labels = anchor_matcher(match_qual)

In [31]:
assert matched_idxs.shape == (total_anchors, )
assert anchor_labels.shape == (total_anchors, )

In [45]:
torch.where(anchor_labels == -1)

(tensor([    0,     1,     2,  ..., 49101, 49102, 49103], device='cuda:0'),)

In [67]:
(torch.where(matched_idxs == 3)[0].shape)[0] + (torch.where(matched_idxs == 0)[0].shape)[0]

49104

In [62]:
torch.where(anchor_labels == -1)[0].shape

torch.Size([49104])

In [46]:
torch.where(anchor_labels == 0)

(tensor([], device='cuda:0', dtype=torch.int64),)

In [47]:
torch.where(anchor_labels == 1)

(tensor([], device='cuda:0', dtype=torch.int64),)

In [52]:
target_boxes[0][matched_idxs].shape

torch.Size([49104, 4])

In [50]:
total_anchors

49104

In [68]:
m = torch.randn((5, 10))

In [72]:
max_vals, max_idxs = m.max(dim=0)
max_vals.shape, max_idxs.shape

(torch.Size([10]), torch.Size([10]))