In [1]:
import torch
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
import config
import matplotlib.patches as patches
from torch.utils.data import DataLoader

In [2]:
def iou_width_height(boxes1, boxes2):
    """
    IoU of boxes compared by size only, as if both were anchored at the
    same corner (positions are ignored).

    Parameters:
        boxes1 (tensor): [..., 0] is width and [..., 1] is height
        boxes2 (tensor): same layout as boxes1; broadcast against it

    Returns:
        tensor: min(w) * min(h) divided by the union of the two areas
    """
    w1, h1 = boxes1[..., 0], boxes1[..., 1]
    w2, h2 = boxes2[..., 0], boxes2[..., 1]

    overlap = torch.min(w1, w2) * torch.min(h1, h2)
    area1 = w1 * h1
    area2 = w2 * h2
    return overlap / (area1 + area2 - overlap)


def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """
    Video explanation of this function:
    https://youtu.be/XXYG5ZWtjj0

    This function calculates intersection over union (iou) given pred boxes
    and target boxes.

    Parameters:
        boxes_preds (tensor): Predictions of Bounding Boxes (..., 4)
        boxes_labels (tensor): Correct labels of Bounding Boxes (..., 4)
        box_format (str): "midpoint" if boxes are (x, y, w, h) with (x, y) the
            centre, "corners" if boxes are (x1, y1, x2, y2)

    Returns:
        tensor: Intersection over union, shape (..., 1)

    Raises:
        ValueError: if box_format is neither "midpoint" nor "corners"
    """
    if box_format == "midpoint":
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
    elif box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]
    else:
        # Without this branch the corner variables below would be unbound and
        # the caller would get an opaque UnboundLocalError.
        raise ValueError(
            f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
        )

    # Corners of the intersection rectangle.
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) makes disjoint boxes contribute zero intersection.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    # 1e-6 guards against division by zero for degenerate boxes.
    return intersection / (box1_area + box2_area - intersection + 1e-6)


def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """
    Video explanation of this function:
    https://youtu.be/YDkjWEN8jNA

    Greedy per-class Non Max Suppression.

    Boxes with a score at or below `threshold` are dropped. The rest are
    visited from highest to lowest score. Each visited box is kept, and every
    remaining box of the same class whose IoU with it is >= `iou_threshold`
    is discarded.

    Parameters:
        bboxes (list): list of boxes, each [class_pred, prob_score, *coords]
        iou_threshold (float): same-class boxes overlapping a kept box at
            least this much are suppressed
        threshold (float): minimum score (exclusive) for a box to be considered
        box_format (str): "midpoint" or "corners", forwarded to
            intersection_over_union

    Returns:
        list: the kept boxes, in descending score order
    """
    assert type(bboxes) is list

    remaining = sorted(
        (candidate for candidate in bboxes if candidate[1] > threshold),
        key=lambda candidate: candidate[1],
        reverse=True,
    )
    kept = []

    while remaining:
        best, *rest = remaining
        kept.append(best)
        # Boxes of another class are never suppressed; the class check runs
        # first, so the IoU is only computed for same-class pairs.
        remaining = [
            other
            for other in rest
            if other[0] != best[0]
            or intersection_over_union(
                torch.tensor(best[2:]),
                torch.tensor(other[2:]),
                box_format=box_format,
            )
            < iou_threshold
        ]

    return kept

In [3]:
def mean_average_precision(pred_boxes, true_boxes, iou_threshold=0.5, box_format='corners', num_classes=20):
    """
    Mean of the per-class average precision (area under the precision/recall
    curve, integrated with the trapezoidal rule).

    Parameters:
        pred_boxes (list): boxes as [image_idx, class_pred, prob_score, *coords]
        true_boxes (list): boxes as [image_idx, class, score, *coords]; the
            score entry is ignored
        iou_threshold (float): a detection is a true positive only if its best
            IoU with an unmatched ground truth of the same image is > this
        box_format (str): "midpoint" or "corners", forwarded to
            intersection_over_union
        num_classes (int): classes 0 .. num_classes-1 are evaluated

    Returns:
        tensor: mAP over classes that have at least one ground-truth box,
        or tensor(0.) if no class has any ground truth.
    """
    average_precisions = []
    epsilon = 1e-6

    for c in range(num_classes):
        detections = [box for box in pred_boxes if box[1] == c]
        ground_truths = [box for box in true_boxes if box[1] == c]
        total_true_boxes = len(ground_truths)

        # AP is undefined for a class without ground truth; skip it.
        if total_true_boxes == 0:
            continue

        # Group the ground truths by image so each detection only scans its own
        # image, and keep one "already matched" flag per ground-truth box.
        gts_by_image = {}
        for gt in ground_truths:
            gts_by_image.setdefault(gt[0], []).append(gt)
        matched = {img: torch.zeros(len(gts)) for img, gts in gts_by_image.items()}

        # Highest-confidence detections get first claim on ground truths.
        detections.sort(key=lambda box: box[2], reverse=True)
        TP = torch.zeros(len(detections))
        FP = torch.zeros(len(detections))

        for detection_idx, detection in enumerate(detections):
            image_gts = gts_by_image.get(detection[0], [])

            best_iou = 0
            best_gt_idx = None
            for gt_idx, gt in enumerate(image_gts):
                iou = intersection_over_union(
                    torch.tensor(detection[3:]), torch.tensor(gt[3:]), box_format
                )
                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = gt_idx

            # A ground truth can be matched only once; later duplicates are FPs.
            if (
                best_gt_idx is not None
                and best_iou > iou_threshold
                and matched[detection[0]][best_gt_idx] == 0
            ):
                matched[detection[0]][best_gt_idx] = 1
                TP[detection_idx] = 1
            else:
                FP[detection_idx] = 1

        TP_cumsum = TP.cumsum(dim=0)
        FP_cumsum = FP.cumsum(dim=0)
        recalls = TP_cumsum / (total_true_boxes + epsilon)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
        # Start the curve at (recall=0, precision=1); float literals avoid
        # mixing int64 and float tensors in torch.cat.
        precisions = torch.cat([torch.tensor([1.0]), precisions])
        recalls = torch.cat([torch.tensor([0.0]), recalls])

        average_precisions.append(torch.trapz(precisions, recalls))

    if not average_precisions:
        return torch.tensor(0.0)
    return sum(average_precisions) / len(average_precisions)

In [4]:
def plot_image(image, boxes):
    """
    Draw boxes with class labels on top of an image and show the figure.

    Parameters:
        image: anything np.array can turn into an (H, W, C) array
        boxes (iterable): each box is [class, confidence, x, y, w, h], where
            (x, y) is the box centre and x, y, w, h are fractions of the image
            size (they are multiplied by the image width/height)
    """
    class_labels = config.PASCAL_CLASSES
    palette = plt.get_cmap("tab20b")
    colors = [palette(v) for v in np.linspace(0, 1, len(class_labels))]
    pixels = np.array(image)
    img_h, img_w, _ = pixels.shape

    fig, ax = plt.subplots(1)
    ax.imshow(pixels)

    for entry in boxes:
        assert len(entry) == 6 # class, confidence, x, y, w, h
        class_id = int(entry[0])
        cx, cy, bw, bh = entry[2:]
        # matplotlib rectangles are anchored at their upper-left corner.
        left = cx - bw / 2
        top = cy - bh / 2
        color = colors[class_id]

        ax.add_patch(
            patches.Rectangle(
                (left * img_w, top * img_h),
                bw * img_w,
                bh * img_h,
                linewidth=2,
                edgecolor=color,
                facecolor='none',
            )
        )
        ax.text(
            left * img_w,
            top * img_h,
            s=class_labels[class_id],
            color="white",
            verticalalignment="top",
            bbox={"color": color, "pad": 0},
        )

    plt.show()

In [6]:
def cells_to_boxes(predictions, anchors, S, is_pred=True):
    """
    Convert one scale of cell-relative YOLO output (or targets) into a flat
    per-image list of boxes relative to the whole image.

    Parameters:
        predictions (tensor): (N, num_anchors, S, S, C). Channel 0 is
            objectness and 1:5 are x, y, w, h. For predictions, 5: holds the
            class logits; for targets, 5 holds the class index.
        anchors (tensor or nested list): (num_anchors, 2) anchor sizes used to
            scale exp(w, h); only read when is_pred is True
        S (int): grid size of this scale
        is_pred (bool): True for raw model output (sigmoid/exp/argmax are
            applied), False for targets (values are used as-is)

    Returns:
        list: N lists of num_anchors * S * S boxes, each
        [class, score, x, y, w, h] with (x, y) the centre as a fraction of the
        image size.
    """
    batch_size = predictions.shape[0]
    num_anchors = len(anchors)
    # Clone so the activations below do not write into the caller's tensor
    # (predictions[..., 1:5] is a view).
    box_predictions = predictions[..., 1:5].clone()

    if is_pred:
        anchors = torch.as_tensor(anchors, dtype=box_predictions.dtype, device=predictions.device)
        anchors = anchors.reshape(1, num_anchors, 1, 1, 2)
        box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
        box_predictions[..., 2:4] = torch.exp(box_predictions[..., 2:4]) * anchors
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
    else:
        # Slices (not single indices) keep the trailing dim so torch.cat works.
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]

    # cell_indices[n, a, row, col, 0] == col; permuting the two grid axes
    # gives the row index instead.
    cell_indices = (
        torch.arange(S, device=predictions.device)
        .repeat(batch_size, num_anchors, S, 1)
        .unsqueeze(-1)
    )
    x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
    y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
    w_h = 1 / S * box_predictions[..., 2:4]

    converted_boxes = torch.cat(
        [best_class.to(x.dtype), scores.to(x.dtype), x, y, w_h], dim=-1
    ).reshape(batch_size, num_anchors * S * S, 6)

    return converted_boxes.tolist()

In [7]:
def get_evaluation_boxes(loader, model, iou_threshold, anchors, threshold, box_format='midpoint', device='cuda'):
    """
    Run `model` over `loader` and collect NMS-filtered predicted boxes and
    ground-truth boxes, each prefixed with a running image index.

    Assumes model(x) returns 3 scale outputs, each shaped
    (N, num_anchors, S, S, 5 + num_classes), and that labels[2] is the target
    for the same scale as predictions[2] (TODO confirm against the dataset).

    Parameters:
        loader: iterable of (images, labels) batches
        model: network; switched to eval mode here and back to train mode after
        iou_threshold (float): IoU threshold for non_max_suppression
        anchors: per-scale anchors; anchors[i] is scaled by that scale's S
        threshold (float): minimum score for predictions (NMS) and targets
        box_format (str): box format passed to non_max_suppression
        device: device the images are moved to

    Returns:
        tuple(list, list): (all_pred_boxes, all_true_boxes), each box being
        [image_idx, class, score, x, y, w, h].
    """
    model.eval()
    train_idx = 0
    all_pred_boxes = []
    all_true_boxes = []

    try:
        for x, labels in loader:
            x = x.to(device)

            with torch.no_grad():
                predictions = model(x)

            batch_size = x.shape[0]
            boxes = [[] for _ in range(batch_size)]
            for i in range(3):
                S = predictions[i].shape[2]
                # as_tensor accepts nested lists as well as tensors.
                anchor = torch.as_tensor(anchors[i], dtype=torch.float32, device=device) * S
                boxes_scale_i = cells_to_boxes(predictions[i], anchor, S, True)
                for idx, box in enumerate(boxes_scale_i):
                    boxes[idx] += box

            # Ground truth is read from the last scale only; its grid size and
            # anchors are named explicitly instead of reusing loop leftovers.
            last_S = predictions[2].shape[2]
            last_anchor = torch.as_tensor(anchors[2], dtype=torch.float32, device=device) * last_S
            true_boxes = cells_to_boxes(labels[2], last_anchor, last_S, False)

            for idx in range(batch_size):
                nms_bboxes = non_max_suppression(boxes[idx], iou_threshold, threshold, box_format)
                all_pred_boxes.extend([train_idx] + nms_box for nms_box in nms_bboxes)
                all_true_boxes.extend(
                    [train_idx] + box for box in true_boxes[idx] if box[1] > threshold
                )
                train_idx += 1
    finally:
        # Restore training mode even if evaluation raises.
        model.train()

    return all_pred_boxes, all_true_boxes

In [8]:
def check_class_accuracy(model, loader, threshold):
    """
    Print class, no-object and object accuracy of `model` over `loader`.

    Targets are assumed to hold objectness in channel 0 (1 = object,
    0 = no object) and the class index in channel 5 for each of 3 scales.
    Cells whose objectness is neither 0 nor 1 are counted in neither group.

    Parameters:
        model: network; switched to eval mode here and back to train mode after
        loader: iterable of (images, targets) where targets[i] is scale i
        threshold (float): sigmoid(objectness) above this counts as "object"
    """
    model.eval()
    total_class_preds, correct_class = 0, 0
    total_no_obj, correct_no_obj = 0, 0
    total_obj, correct_obj = 0, 0

    try:
        for x, y in loader:
            x = x.to(config.DEVICE)

            with torch.no_grad():
                predictions = model(x)

            for i in range(3):
                # Use a local instead of writing back into y: depending on the
                # collate function, y may be an immutable tuple.
                target = y[i].to(config.DEVICE)
                obj = target[..., 0] == 1
                noobj = target[..., 0] == 0

                correct_class += torch.sum(
                    torch.argmax(predictions[i][..., 5:][obj], dim=-1) == target[..., 5][obj]
                )
                total_class_preds += torch.sum(obj)

                obj_preds = torch.sigmoid(predictions[i][..., 0]) > threshold
                correct_obj += torch.sum(obj_preds[obj] == target[..., 0][obj])
                total_obj += torch.sum(obj)
                correct_no_obj += torch.sum(obj_preds[noobj] == target[..., 0][noobj])
                total_no_obj += torch.sum(noobj)
    finally:
        model.train()

    # ".2f" = two decimal places (the previous ":2f" only set a minimum width).
    print(f"Class accuracy is: {(correct_class/(total_class_preds+1e-16))*100:.2f}%")
    print(f"No obj accuracy is: {(correct_no_obj/(total_no_obj+1e-16))*100:.2f}%")
    print(f"Obj accuracy is: {(correct_obj/(total_obj+1e-16))*100:.2f}%")

In [9]:
def get_mean_std(loader):
    """
    Per-channel mean and standard deviation over every pixel yielded by
    `loader`.

    Each batch is weighted by its number of pixels, so a smaller final batch
    does not skew the result (a plain mean of per-batch means would).

    Parameters:
        loader: iterable of (data, _) with data shaped (N, C, H, W)

    Returns:
        tuple(tensor, tensor): (mean, std), each of shape (C,)

    Raises:
        ValueError: if the loader yields no pixels
    """
    channels_sum, channels_sqr_sum, num_pixels = 0, 0, 0

    for data, _ in loader:
        channels_sum += torch.sum(data, dim=[0, 2, 3])
        channels_sqr_sum += torch.sum(data ** 2, dim=[0, 2, 3])
        num_pixels += data.shape[0] * data.shape[2] * data.shape[3]

    if num_pixels == 0:
        raise ValueError("loader yielded no data; cannot compute mean/std")

    mean = channels_sum / num_pixels
    # E[x^2] - E[x]^2 can dip slightly below zero from rounding; clamp to
    # avoid NaN from the square root.
    variance = (channels_sqr_sum / num_pixels - mean ** 2).clamp(min=0)
    std = variance ** 0.5

    return mean, std

In [10]:
def get_loaders(train_csv, test_csv):
    """
    Build the train and test DataLoaders for YOLODataset.

    Parameters:
        train_csv (str): annotation CSV for the training split
        test_csv (str): annotation CSV for the test split

    Returns:
        tuple(DataLoader, DataLoader): (train_loader, test_loader); only the
        training loader is shuffled.
    """
    # Local import as in the original, presumably to avoid a circular import
    # with the dataset module (TODO confirm).
    from dataset import YOLODataset

    IMAGE_SIZE = config.IMAGE_SIZE
    train_dataset = YOLODataset(train_csv, config.IMG_DIR, config.LABEL_DIR, config.ANCHORS, IMAGE_SIZE, transform=config.train_transforms)
    test_dataset = YOLODataset(test_csv, config.IMG_DIR, config.LABEL_DIR, config.ANCHORS, IMAGE_SIZE, transform=config.test_transforms)
    train_loader = DataLoader(train_dataset, config.BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, config.BATCH_SIZE, shuffle=False)

    # The loaders were previously built and then discarded (implicit None).
    return train_loader, test_loader

In [11]:
def plot_couple_examples(model, loader, thresh, iou_thresh, anchors):
    """
    Plot NMS-filtered predictions for every image of the first batch of
    `loader`.

    Parameters:
        model: network; switched to eval mode here and back to train mode after
        loader: iterable of (images, targets); only the first batch is used
        thresh (float): score threshold for non_max_suppression
        iou_thresh (float): IoU threshold for non_max_suppression
        anchors: per-scale anchors passed unchanged to cells_to_boxes (unlike
            get_evaluation_boxes, they are not multiplied by S here; presumably
            the caller passes already-scaled anchors — TODO confirm)
    """
    model.eval()
    x, _ = next(iter(loader))
    x = x.to(config.DEVICE)
    batch_size = x.shape[0]
    boxes = [[] for _ in range(batch_size)]

    try:
        with torch.no_grad():
            out = model(x)

            for i in range(3):
                S = out[i].shape[2]
                # Put the anchors on the same device as the model output.
                anchor = torch.as_tensor(anchors[i], device=out[i].device)
                boxes_scale_i = cells_to_boxes(out[i], anchor, S, True)
                for idx, box in enumerate(boxes_scale_i):
                    boxes[idx] += box
    finally:
        model.train()

    for i in range(batch_size):
        nms_boxes = non_max_suppression(boxes[i], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint")
        plot_image(x[i].permute(1, 2, 0).detach().cpu(), nms_boxes)