In [13]:
import torch
import torch.nn.functional as F
import sys
import homework.utils
import homework.models
import matplotlib.pyplot as plt

In [3]:
def point_in_box(pred, lbl):
    px, py = pred[:, None, 0], pred[:, None, 1]
    x0, y0, x1, y1 = lbl[None, :, 0], lbl[None, :, 1], lbl[None, :, 2], lbl[None, :, 3]
    return (x0 <= px) & (px < x1) & (y0 <= py) & (py < y1)


def point_close(pred, lbl, d=5):
    px, py = pred[:, None, 0], pred[:, None, 1]
    x0, y0, x1, y1 = lbl[None, :, 0], lbl[None, :, 1], lbl[None, :, 2], lbl[None, :, 3]
    return ((x0 + x1 - 1) / 2 - px) ** 2 + ((y0 + y1 - 1) / 2 - py) ** 2 < d ** 2


def box_iou(pred, lbl, t=0.5):
    px, py, pw2, ph2 = pred[:, None, 0], pred[:, None, 1], pred[:, None, 2], pred[:, None, 3]
    px0, px1, py0, py1 = px - pw2, px + pw2, py - ph2, py + ph2
    x0, y0, x1, y1 = lbl[None, :, 0], lbl[None, :, 1], lbl[None, :, 2], lbl[None, :, 3]
    iou = (abs(torch.min(px1, x1) - torch.max(px0, x0)) * abs(torch.min(py1, y1) - torch.max(py0, y0))) / \
          (abs(torch.max(px1, x1) - torch.min(px0, x0)) * abs(torch.max(py1, y1) - torch.min(py0, y0)))
    return iou > t

class PR:
    def __init__(self, min_size=20, is_close=point_in_box):
        self.min_size = min_size
        self.total_det = 0
        self.det = []
        self.is_close = is_close

    def add(self, d, lbl):
        lbl = torch.as_tensor(lbl.astype(float), dtype=torch.float32).view(-1, 4)
        d = torch.as_tensor(d, dtype=torch.float32).view(-1, 5)
        all_pair_is_close = self.is_close(d[:, 1:], lbl)

        # Get the box size and filter out small objects
        sz = abs(lbl[:, 2]-lbl[:, 0]) * abs(lbl[:, 3]-lbl[:, 1])

        # If we have detections find all true positives and count of the rest as false positives
        if len(d):
            detection_used = torch.zeros(len(d))
            # For all large objects
            for i in range(len(lbl)):
                if sz[i] >= self.min_size:
                    # Find a true positive
                    s, j = (d[:, 0] - 1e10 * detection_used - 1e10 * ~all_pair_is_close[:, i]).max(dim=0)
                    if not detection_used[j] and all_pair_is_close[j, i]:
                        detection_used[j] = 1
                        self.det.append((float(s), 1))

            # Mark any detection with a close small ground truth as used (no not count false positives)
            detection_used += all_pair_is_close[:, sz < self.min_size].any(dim=1)

            # All other detections are false positives
            for s in d[detection_used == 0, 0]:
                self.det.append((float(s), 0))

        # Total number of detections, used to count false negatives
        self.total_det += int(torch.sum(sz >= self.min_size))

In [18]:
model = homework.models.Detector()
for img, *gts in homework.utils.DetectionSuperTuxDataset('dense_data/valid', min_size=0):
    heatmap = model(img)
    print(heatmap)
    sys.exit()
    

tensor([[[[ 0.3024,  0.3277,  0.3415,  ...,  0.3171,  0.2630,  0.3041],
          [ 0.4483,  0.4956,  0.3580,  ...,  0.4213,  0.3147,  0.3081],
          [ 0.2651,  0.2549,  0.2448,  ...,  0.3106,  0.2403,  0.3329],
          ...,
          [ 0.3603,  0.3369,  0.3720,  ...,  0.3245,  0.3774,  0.3903],
          [ 0.2775,  0.2994,  0.2631,  ...,  0.2761,  0.2927,  0.3003],
          [ 0.3602,  0.2548,  0.3335,  ...,  0.3204,  0.3519,  0.2813]],

         [[ 0.2891,  0.3541,  0.3203,  ...,  0.2336,  0.3087,  0.2219],
          [ 0.0313,  0.1466,  0.1820,  ...,  0.0532,  0.1770,  0.0779],
          [ 0.2501,  0.3186,  0.2811,  ...,  0.1521,  0.2635,  0.1625],
          ...,
          [ 0.1860,  0.0892,  0.1405,  ...,  0.1431,  0.2034,  0.1012],
          [ 0.2818,  0.2379,  0.2735,  ...,  0.2489,  0.2568,  0.2364],
          [ 0.2121,  0.1728,  0.2072,  ...,  0.2149,  0.2187,  0.1914]],

         [[-0.2969, -0.3344, -0.2911,  ..., -0.1928, -0.1618, -0.2137],
          [-0.2889, -0.3523, -

SystemExit: 