In [1]:
import torch

In [120]:
def bbox_from_keypoints(keypoints):
    """Compute the tight axis-aligned bbox around each sample's 2D keypoints.

    Arguments:
    keypoints -- tensor of 2D joint information (B, J, D); only entries 0 (x)
                 and 1 (y) of the last axis are read

    Returns:
    tensor (B, 4) laid out as [x_min, y_min, x_max, y_max]
    """
    xs = keypoints[:, :, 0]
    ys = keypoints[:, :, 1]

    # Reduce over the joint axis; each reduction yields a (B,) tensor.
    corners = [
        xs.min(dim=1).values,
        ys.min(dim=1).values,
        xs.max(dim=1).values,
        ys.max(dim=1).values,
    ]

    # Stacking the (B,) tensors along a new axis 1 gives the (B, 4) result.
    return torch.stack(corners, dim=1)


In [133]:
# Toy batch: two samples with three (x, y) joints each.
keypoints = torch.FloatTensor(
    [
        [[1, 1], [2, 2], [0, 3]],
        [[1, 1], [1, 2], [5, 5]],
    ]
)
# Track gradients on the input (in-place flag on the leaf tensor).
keypoints.requires_grad_()

bbox = bbox_from_keypoints(keypoints)

In [128]:
def bbox_losses(bbox_pred, bbox_real, eps=1e-7):
    """Compute IoU and GIoU losses between predicted and ground truth bboxes.

    GIoU is implemented according to https://giou.stanford.edu/GIoU.pdf.

    Arguments:
    bbox_pred -- tensor of predicted bboxes given by opposing corners (B, 4);
                 the corners may be listed in either order, they are sorted here
    bbox_real -- tensor of ground truth bboxes given by opposing corners (B, 4);
                 used as-is in (x1, y1, x2, y2) order, corners are NOT reordered
    eps -- lower bound applied to the union and enclosing-box areas before
           dividing, so zero-area (degenerate) boxes yield finite losses
           instead of NaN; areas larger than eps are unaffected

    Returns:
    tuple (iou_loss, giou_loss) of scalar tensors: 1 - mean IoU and
    1 - mean GIoU over the batch
    """

    assert len(bbox_pred.shape) == 2 and \
           len(bbox_real.shape) == 2 and \
           bbox_pred.shape == bbox_real.shape

    x1_real, y1_real, x2_real, y2_real = torch.split(bbox_real, 1, dim=1)

    # Sort the predicted corners so that (x1, y1) <= (x2, y2) element-wise.
    x1_pred = torch.min(bbox_pred[:, [0, 2]], dim=1).values.unsqueeze(1)
    y1_pred = torch.min(bbox_pred[:, [1, 3]], dim=1).values.unsqueeze(1)
    x2_pred = torch.max(bbox_pred[:, [0, 2]], dim=1).values.unsqueeze(1)
    y2_pred = torch.max(bbox_pred[:, [1, 3]], dim=1).values.unsqueeze(1)

    # Smallest box enclosing both the prediction and the ground truth.
    x1_crop = torch.min(x1_pred, x1_real)
    y1_crop = torch.min(y1_pred, y1_real)
    x2_crop = torch.max(x2_pred, x2_real)
    y2_crop = torch.max(y2_pred, y2_real)

    A_real = (x2_real - x1_real) * (y2_real - y1_real)
    A_pred = (x2_pred - x1_pred) * (y2_pred - y1_pred)
    A_crop = (x2_crop - x1_crop) * (y2_crop - y1_crop)

    # Candidate intersection rectangle; it is empty unless both extents are
    # strictly positive, which is what the mask encodes.
    x1_int = torch.max(x1_pred, x1_real)
    y1_int = torch.max(y1_pred, y1_real)
    x2_int = torch.min(x2_pred, x2_real)
    y2_int = torch.min(y2_pred, y2_real)

    overlaps = (x2_int > x1_int) & (y2_int > y1_int)
    # Cast to the working dtype rather than hard-coding float32.
    mask = overlaps.to(A_pred.dtype)

    intersection = (x2_int - x1_int) * (y2_int - y1_int) * mask
    union = A_real + A_pred - intersection

    # Only the denominators are floored: the numerators keep their raw values
    # so a degenerate pair gives 0 / eps = 0 rather than 0 / 0 = NaN.
    iou_per_box = intersection / union.clamp(min=eps)
    giou_per_box = iou_per_box - (A_crop - union) / A_crop.clamp(min=eps)

    iou = torch.mean(iou_per_box)
    iou_general = torch.mean(giou_per_box)

    return 1 - iou, 1 - iou_general


In [131]:
# Predicted boxes, one per row; the second row lists its corners in reverse
# order on purpose.
bbox_pred = torch.FloatTensor(
    [
        [1, 1, 2, 2],
        [2, 2, 1, 1],
        [1, 1, 2, 2],
        [1, 1, 2, 2],
        [1, 1, 2, 2],
    ]
)
bbox_pred.requires_grad_()

# Ground truth boxes, row-aligned with the predictions: identical, identical,
# partially overlapping, touching at a single corner, and fully disjoint.
bbox_real = torch.Tensor(
    [
        [1, 1, 2, 2],
        [1, 1, 2, 2],
        [1.5, 1.5, 2.5, 2.5],
        [2, 2, 3, 3],
        [3, 3, 4, 4],
    ]
)

loss, loss_general = bbox_losses(bbox_pred, bbox_real)