In [1]:
import torch
import sys

from torch import Tensor

In [2]:
sys.path.append("../src")

In [3]:
from nms.cluster_nms import diou

In [4]:
state = torch.load("model_out.pth")

In [6]:
outs = state['box_outputs']
outs = outs[-1]
outs = outs.permute(0,2,3,1).reshape(outs.size(0), -1, 4)

In [7]:
outs.shape

torch.Size([5, 324, 4])

In [8]:
def diou(box_a: Tensor, box_b: Tensor, gamma=0.9, eps=1e-8) -> Tensor:
    assert box_a.ndim == 3
    assert box_b.ndim == 3
    assert box_a.size(0) == box_b.size(0)
    
    A, B = box_a.size(1), box_b.size(1)
    box_a = box_a.unsqueeze(2).expand(-1, -1, A, -1)
    box_b = box_b.unsqueeze(1).expand(-1, B, -1, -1)
    
    inter_yx0 = torch.max(box_a[..., :2], box_b[..., :2])
    inter_yx1 = torch.min(box_a[..., 2:4], box_b[..., 2:4])
    
    inter_hw = torch.clamp_min_(inter_yx1 - inter_yx0, 0)
    inter_area = torch.prod(inter_hw, dim=-1)
    del inter_hw, inter_yx0, inter_yx1
    
    hw_a = box_a[..., 2:4] - box_a[..., :2]
    hw_b = box_b[..., 2:4] - box_b[..., :2]
    
    area_a = torch.prod(hw_a, dim=-1)
    area_b = torch.prod(hw_b, dim=-1)
    
    union_area = area_a + area_b - inter_area
    iou = inter_area / (union_area + eps)
    del inter_area, union_area, area_a, area_b, hw_a, hw_b
    
    c_a = (box_a[..., :2] + box_a[..., 2:4]) / 2
    c_b = (box_b[..., :2] + box_b[..., 2:4]) / 2
    inter_diag = torch.pow(c_b - c_a, 2).sum(dim=-1)
    
    clos_yx0 = torch.min(box_a[..., :2], box_b[..., :2])
    clos_yx1 = torch.max(box_a[..., 2:4], box_b[..., 2:4])
    clos_hw = torch.clamp_min_(clos_yx1 - clos_yx0, 0)
    clos_diag = torch.pow(clos_hw, 2).sum(dim=-1)
    del clos_yx0, clos_yx1, clos_hw
    
    dist = inter_diag / (clos_diag + eps)
    return iou - dist ** gamma

In [25]:
N, D = outs.shape[:2]
C = 5

# [classes, detections]
cur_scores = torch.rand(C, D)
# [detections]
conf_scores, _ = torch.max(cur_scores, dim=0)

keep = conf_scores >= 0.5
scores = cur_scores[:, keep]
boxes = outs[0, keep, :]

boxes.shape, scores.shape

(torch.Size([310, 4]), torch.Size([5, 310]))

In [81]:
def cc_cluster_diounms(boxes, scores, iou_threshold=0.5, top_k=200):
    assert boxes.ndim == 2
    assert boxes.size(-1) == 4
    
    scores, classes = torch.max(scores, dim=0)
    # scores: [detections]
    _, idx = scores.sort(0, descending=True)
    idx = idx[:top_k]
    top_k_boxes = boxes[idx][None, ...]
    
    # [1, top_k, top_k] -> [top_k, top_k]
    iou = diou(top_k_boxes, top_k_boxes)[0].triu_(diagonal=1)
    best_iou = torch.zeros(top_k)
    
    for i in range(top_k):
        iou0 = iou
        best_iou, _ = torch.max(iou, dim=0)
        keep = (best_iou <= iou_threshold)[:, None].expand_as(iou)
        iou = torch.where(keep, iou, torch.zeros_like(iou))
        
        if (iou == iou0).all():
            print(i)
            break
    
    # filter out boxes, that are too close to each other
    idx = idx[best_iou <= iou_threshold]
    return boxes[idx], scores[idx], classes[idx]

In [82]:
pred_boxes, pred_scores, pred_classes = cc_cluster_diounms(boxes, scores, iou_threshold=0.5)

pred_boxes.shape, pred_scores.shape, pred_classes.shape

0


(torch.Size([200, 4]), torch.Size([200]), torch.Size([200]))