In [1]:
import os
import os.path as osp
import sys
# Make the parent of the notebook's working directory importable so the local
# `data`, `loss` and `utils` packages resolve. Only insert when it is not
# already first, so re-running this cell does not keep growing sys.path.
PROJECT_ROOT = osp.dirname(osp.abspath('.'))
if sys.path[:1] != [PROJECT_ROOT]:
    sys.path.insert(0, PROJECT_ROOT)

In [2]:
import torch
from torch.utils.data import DataLoader

from data.dataset import YOLODataset, PASCAL_CLASSES
from data.transform import get_yolo_transform
from loss.yolo import bbox_iou

from utils.convert import cells_to_bboxes
from utils.cleanup import non_max_suppression as nms
from utils.display import plot_image

# Re-import modified modules automatically before each cell executes, so edits
# to the project packages imported above take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [3]:
# Root of the PASCAL VOC download (relative to the notebook). The split CSV,
# images and labels all live under it, so changing the location is one edit.
DATA_DIR = "../download/PASCAL_VOC"
CSV_PATH = f"{DATA_DIR}/test.csv"
IMG_DIR = f"{DATA_DIR}/images/"
LABEL_DIR = f"{DATA_DIR}/labels/"

In [4]:
IMG_SIZE = 416
# Per-scale grid sizes (presumably S for an S x S prediction grid), smallest first.
SCALES = [13, 26, 52]
# Anchor (w, h) pairs: one row of three anchors per entry in SCALES -> shape (3, 3, 2).
# The values are < 1 and are later multiplied by the matching SCALES entry;
# presumably they are fractions of the image size (TODO confirm against dataset code).
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]

# Invariants: one anchor set per scale, and every set has the same size so the
# sets can be stacked into a single (num_scales, num_anchors, 2) tensor.
assert len(ANCHORS) == len(SCALES), "need exactly one anchor set per scale"
assert len({len(anchor_set) for anchor_set in ANCHORS}) == 1, "every scale must have the same number of anchors"

In [5]:
# Test-mode preprocessing at the configured input resolution.
transform = get_yolo_transform(img_size=IMG_SIZE, mode='test')

dataset = YOLODataset(
    csv_file=CSV_PATH,
    img_dir=IMG_DIR,
    label_dir=LABEL_DIR,
    anchors=ANCHORS,
    transform=transform,
)

# shuffle=False keeps dataset order, so the same samples come out first on every run.
dataloader = DataLoader(dataset, batch_size=8, shuffle=False)

# Multiply every anchor (w, h) of scale i by SCALES[i]. Viewing SCALES as a
# (num_scales, 1, 1) tensor lets broadcasting cover any number of anchors per
# scale, instead of hard-coding a repeat to exactly 3 anchors x 2 coordinates.
# Result has the same shape as ANCHORS, i.e. (3, 3, 2) here.
scaled_anchors = torch.tensor(ANCHORS) * torch.tensor(SCALES).view(-1, 1, 1)

In [6]:
imgs, targets = next(iter(dataloader))

# Inspect a single scale. One index keeps the target, its grid size and its
# anchors in sync; change it here to look at another scale.
SCALE_IDX = 0

# NOTE(review): no model is used in the visible cells, so `pred` is simply the
# ground-truth target for this scale (a stand-in for a real prediction).
pred = targets[SCALE_IDX]
target = targets[SCALE_IDX]
target_scale = SCALES[SCALE_IDX]
target_anchor = scaled_anchors[SCALE_IDX]

In [7]:
# Decode the target grid of the selected scale into boxes, then flatten to one
# row of 6 values per box (column layout defined by cells_to_bboxes).
bboxes = cells_to_bboxes(
    target,
    scale=target_scale,
    anchors=target_anchor,
    is_preds=False,
).reshape(-1, 6)

In [8]:
# Keep rows whose column 1 (presumably the objectness score) is exactly 1 and
# drop the first two columns, leaving the 4 remaining values of each box.
obj_mask = bboxes[:, 1] == 1.0
filtered = bboxes[obj_mask, 2:]

In [14]:
# Sanity check: the GIoU of every box with itself must be 1. Assert it rather
# than relying on eyeballing the printed tensor, then show the values.
self_giou = bbox_iou(filtered, filtered, mode='giou')
assert torch.allclose(self_giou, torch.ones_like(self_giou)), "GIoU of a box with itself should be 1"
self_giou

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])