# mAP Implementation with details
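- We will implement Mean Average Precision (mAP) for a pretrained YOLOv3 model (which is publicly available as a PyTorch model). Note that the dummy tensors below use a 7×7-grid layout as in YOLOv1: 30 output channels = 2 boxes × 5 values (x, y, w, h, confidence) + 20 class scores.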

In [None]:
import torch

# Dummy all-zero batch: 2 images on a 7x7 grid.
# output has 30 channels = 2 boxes * 5 values + 20 class scores, the layout that
# MeanAvgPrecision.forward slices; label has 25 channels (presumably 5 box values
# + 20 classes -- TODO confirm).
output, label = (torch.zeros(2, channels, 7, 7) for channels in (30, 25))

In [None]:
from torch import Tensor
from torch.nn import Module

class MeanAvgPrecision(Module):
    """Work-in-progress mean Average Precision (mAP) metric for YOLO-style grid outputs.

    ``output`` is expected as ``[B, 5 * n_boxes + classes, S_h, S_w]``: the first
    ``5 * n_boxes`` channels hold one 5-value group per predicted box (assumed to be
    ``x, y, w, h, confidence``) and the last ``classes`` channels hold the per-class
    scores of every grid cell.

    Args:
        classes: Number of class-score channels at the end of the channel axis
            (defaults to 20, the value the original code hard-coded).
    """

    def __init__(self, classes: int = 20):
        super().__init__()
        self.classes = classes

    def forward(self, output: Tensor, label: Tensor) -> int:
        """Rank every predicted box of every class by confidence.

        Matching against the ground truth is not implemented yet.

        Args:
            output: Predictions ``[B, 5 * n_boxes + classes, S_h, S_w]``.
            label: Ground truth; currently unused (NOTE(review): its channel
                layout still has to be decided).

        Returns:
            Always ``0`` for now. This is a placeholder until the ranked predictions
            are matched against ``label`` and the per-class APs are averaged.

        Raises:
            ValueError: If the number of box channels is not a positive multiple of 5.
        """
        classes = self.classes
        box_channels = output.shape[1] - classes
        if box_channels <= 0 or box_channels % 5 != 0:
            raise ValueError(
                f"expected 5 * n_boxes + {classes} channels, got {output.shape[1]}"
            )
        bboxes = box_channels // 5

        index_table = self.generate_cell_index_table(output)[0]  # [2, S_h, S_w] = (row, col)
        # Class predicted by each grid cell: [B, S_h, S_w].
        cell_class = torch.argmax(output[:, -classes:], dim=1)

        for c in range(classes):
            class_mask = cell_class == c
            # Each image may have a different number of cells predicting class `c`,
            # so the images cannot share one batched tensor; handle them one by one.
            for b in range(output.shape[0]):
                ranked = self._sorted_class_boxes(output[b], index_table, class_mask[b], bboxes)
                for column in range(ranked.shape[1]):
                    # [7] = [x, y, w, h, conf, idx_y, idx_x], highest confidence first.
                    current_bbox = ranked[:, column]  # noqa: F841
                    # TODO: match `current_bbox` against the ground truth in `label`
                    # (e.g. with get_iou_between) and accumulate TP/FP for class `c`.
        return 0

    @staticmethod
    def _sorted_class_boxes(
        cell_output: Tensor, cell_index: Tensor, cell_mask: Tensor, bboxes: int
    ) -> Tensor:
        """Collect the boxes of the masked cells of one image, sorted by confidence.

        Args:
            cell_output: Output of a single image, ``[5 * bboxes + classes, S_h, S_w]``.
            cell_index: ``[2, S_h, S_w]`` (row, col) table from ``generate_cell_index_table``.
            cell_mask: ``[S_h, S_w]`` boolean mask selecting the cells of interest.
            bboxes: Number of boxes predicted per cell.

        Returns:
            ``[7, bboxes * N]`` (``N`` = number of selected cells), with rows
            ``[x, y, w, h, conf, idx_y, idx_x]`` and columns ordered by descending
            confidence (row 4). ``N`` may be 0.
        """
        selected = cell_output[: bboxes * 5, cell_mask]  # [5 * bboxes, N]
        n_cells = selected.shape[1]
        # [bboxes, 5, N] -> [5, bboxes, N] -> [5, bboxes * N]: one column per box,
        # all cells of box 0 first, then all cells of box 1, ...
        # n_cells is passed explicitly because `-1` is ambiguous for 0-element tensors.
        boxes = (
            selected.reshape(bboxes, 5, n_cells)
            .permute(1, 0, 2)
            .reshape(5, bboxes * n_cells)
        )
        # Repeat each cell's (row, col) once per box so the columns line up with `boxes`.
        indices = cell_index[:, cell_mask].to(boxes.dtype).repeat(1, bboxes)
        table = torch.cat([boxes, indices], dim=0)
        order = torch.argsort(table[4], descending=True)
        return table[:, order]

    @staticmethod
    def generate_cell_index_table(output: Tensor) -> Tensor:
        """Return the ``[1, 2, S_h, S_w]`` grid-cell index table for ``output``.

        Channel 0 holds each cell's row (y) index and channel 1 its column (x)
        index, on ``output``'s device. The grid size is taken from the last two
        dimensions of ``output``, which are 7x7 for the demo tensors.
        """
        rows, cols = output.shape[-2], output.shape[-1]
        row_index = torch.arange(rows, device=output.device).view(rows, 1).expand(rows, cols)
        col_index = torch.arange(cols, device=output.device).view(1, cols).expand(rows, cols)
        return torch.stack([row_index, col_index], dim=0).unsqueeze(0)

    @staticmethod
    def _iou_from_centers(
        cx1: Tensor, cy1: Tensor, w1: Tensor, h1: Tensor,
        cx2: Tensor, cy2: Tensor, w2: Tensor, h2: Tensor,
    ) -> Tensor:
        """Compute the element-wise IoU of two box sets given as global centre and size.

        Non-overlapping boxes give 0 because the overlap extent is clamped at 0.
        A zero-area union also gives 0 instead of NaN.
        """
        inter_w = (
            torch.minimum(cx1 + w1 / 2, cx2 + w2 / 2)
            - torch.maximum(cx1 - w1 / 2, cx2 - w2 / 2)
        ).clamp(min=0)
        inter_h = (
            torch.minimum(cy1 + h1 / 2, cy2 + h2 / 2)
            - torch.maximum(cy1 - h1 / 2, cy2 - h2 / 2)
        ).clamp(min=0)
        inter = inter_w * inter_h
        union = w1 * h1 + w2 * h2 - inter
        has_area = union > 0
        safe_union = torch.where(has_area, union, torch.ones_like(union))
        return torch.where(has_area, inter / safe_union, torch.zeros_like(union))

    @staticmethod
    def get_iou_between(
        xywh1: Tensor, index1: Tensor, xywh2: Tensor, index2: Tensor, *, grid_size: int = 7
    ) -> Tensor:
        """Compute the element-wise IoU of boxes that may sit in different grid cells.

        Args:
            xywh1, xywh2: ``[B, 4, I]`` boxes. ``x, y`` are offsets inside their cell;
                ``w, h`` are in the same image-relative units as the
                grid-normalised centres.
            index1, index2: ``[B, 2, I]`` cell indices as ``(y, x)`` = ``(row, col)``.
            grid_size: Number of cells per image side (7 in this notebook).

        Returns:
            ``[B, I]`` IoU values.
        """
        def to_global(xywh: Tensor, index: Tensor):
            return (
                (xywh[:, 0] + index[:, 1]) / grid_size,  # x offset + column
                (xywh[:, 1] + index[:, 0]) / grid_size,  # y offset + row
                xywh[:, 2],
                xywh[:, 3],
            )

        return MeanAvgPrecision._iou_from_centers(
            *to_global(xywh1, index1), *to_global(xywh2, index2)
        )

    @staticmethod
    def get_iou_xywh(input_xywh: Tensor, label_xywh: Tensor) -> Tensor:
        """Compute the IoU of the predicted and label box in each grid cell.

        Args:
            input_xywh, label_xywh: ``[B, >=4, S_h, S_w]``. Channels 0-3 are
                ``x, y`` (offset within the cell) and ``w, h``; ``w, h`` are
                combined directly with the grid-normalised centres, so they are
                treated as image-relative.

        Returns:
            ``[B, S_h, S_w]`` IoU values, 0 for non-overlapping or zero-area boxes.
        """
        rows, cols = input_xywh.shape[-2], input_xywh.shape[-1]
        device, dtype = input_xywh.device, input_xywh.dtype
        # x is offset by the cell's column index and y by its row index.
        col_offset = torch.arange(cols, device=device, dtype=dtype).view(1, 1, cols)
        row_offset = torch.arange(rows, device=device, dtype=dtype).view(1, rows, 1)

        def to_global(xywh: Tensor):
            return (
                (xywh[:, 0] + col_offset) / cols,
                (xywh[:, 1] + row_offset) / rows,
                xywh[:, 2],
                xywh[:, 3],
            )

        return MeanAvgPrecision._iou_from_centers(
            *to_global(input_xywh), *to_global(label_xywh)
        )
        
# Build the metric once and evaluate it on the dummy batch.
mean_avg_precision = MeanAvgPrecision()
map_value = mean_avg_precision(output, label)
print(map_value)

In [None]:
# Scratch check: repeat each row's single value 4 times along dim 1.
torch.Tensor([[1], [2], [3]]).repeat_interleave(4, dim=1)