In [None]:
#IOU

def iou(pred_bboxes, target_bboxes):
    """
    Calculates row-wise bounding box IoUs between two bounding box tensors

    Row i of pred_bboxes is compared with row i of target_bboxes only
    (this is not an all-pairs IoU matrix).

    Parameters:
        pred_bboxes: Tensor containing a batch of predicted bounding boxes
            type:tensor
            shape:[N,4] (more generally [...,4]; leading dims must broadcast with target_bboxes)
            format:[x1,y1,x2,y2]
        target_bboxes: Tensor containing a batch of ground truth bounding boxes
            type:tensor
            shape:[N,4] (more generally [...,4])
            format:[x1,y1,x2,y2]
    Result:
        IoUs: Batch of IoUs, one per row pair
        type: tensor
        shape: 1D tensor of size N (in general, the broadcast leading shape)

    """
    # Index the last axis so a single [4] box or a higher-rank batch works too;
    # for [N,4] inputs this is identical to [:, k].
    box1_x1 = pred_bboxes[..., 0]
    box1_y1 = pred_bboxes[..., 1]
    box1_x2 = pred_bboxes[..., 2]
    box1_y2 = pred_bboxes[..., 3]
    box2_x1 = target_bboxes[..., 0]
    box2_y1 = target_bboxes[..., 1]
    box2_x2 = target_bboxes[..., 2]
    box2_y2 = target_bboxes[..., 3]

    # Corners of the intersection rectangle.
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # Non-overlapping boxes yield a negative width/height; clamp turns that into 0 area.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    # abs() keeps the areas non-negative.
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
    # The small epsilon guards against division by zero (e.g. two zero-area boxes).
    return intersection / (box1_area + box2_area - intersection + 1e-16)



In [None]:
def nms(bboxes, confidence_scores, confidence_threshold, iou_threshold):
    """
    Performs greedy Non Max Suppression

    Boxes are first filtered by confidence, then visited in descending
    confidence order: each visited box is kept, and every remaining box whose
    IoU with it is above iou_threshold is discarded.

    Parameters:
        bboxes: Tensor containing a batch of bounding boxes
            type:tensor
            shape:[N,4]
            format:[x1,y1,x2,y2]
        confidence_scores: Confidence scores for each bounding box
            type:tensor
            shape: 1D tensor of size N
        confidence_threshold: Boxes whose score is not strictly greater than this are discarded
            type:float
        iou_threshold: Boxes whose IoU with an already-chosen, higher-scoring box is
            strictly greater than this are discarded
            type:float

    Result:
        chosen_bboxes: Boxes kept after suppression, ordered by descending confidence
            type:tensor
            shape:[K,4] with K <= N (K may be 0)
            format:[x1,y1,x2,y2]
    """
    import torch  # local import: this cell uses torch.stack / torch.max directly

    # Filter BEFORE sorting so the mask and the boxes share the same row order;
    # applying an original-order mask to sorted boxes would keep the wrong rows.
    keep = confidence_scores > confidence_threshold
    scores = confidence_scores[keep]
    bboxes = bboxes[keep][scores.argsort(descending=True)]

    chosen_bboxes = []
    while len(bboxes) > 0:
        best, others = bboxes[0], bboxes[1:]
        chosen_bboxes.append(best)

        # IoU of the chosen box against every remaining box.
        inter_w = (torch.min(best[2], others[:, 2]) - torch.max(best[0], others[:, 0])).clamp(0)
        inter_h = (torch.min(best[3], others[:, 3]) - torch.max(best[1], others[:, 1])).clamp(0)
        intersection = inter_w * inter_h
        best_area = abs((best[2] - best[0]) * (best[3] - best[1]))
        other_areas = abs((others[:, 2] - others[:, 0]) * (others[:, 3] - others[:, 1]))
        ious = intersection / (best_area + other_areas - intersection + 1e-16)

        # Shrinking the candidate set each iteration guarantees termination.
        bboxes = others[ious <= iou_threshold]

    if not chosen_bboxes:
        # torch.stack cannot stack an empty list; return an empty [0, 4] tensor instead.
        return bboxes.new_empty((0, 4))
    return torch.stack(chosen_bboxes)