In [1]:
%load_ext autoreload
%autoreload 2

In [203]:
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import ToTensor, Resize, Compose
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import sys
sys.path.append('../')
import dataset
from torchvision.ops import box_iou

In [3]:
pascal_voc_train = torchvision.datasets.VOCDetection(
    root="../data",
    year="2007",
    image_set="train",
    download=False
)

In [4]:
voc_train = dataset.PascalVOC(pascal_voc=pascal_voc_train)

TRANSFORMING PASCAL VOC


In [204]:
def iou():
    pass

In [701]:
class YOLOv1Loss(nn.Module):
    """
    YOLOv1 Loss
    """
    def __init__(self, S=7, B=2, C=20, lambda_coord=5, lambda_noobj=0.5):
        """
        S: dimension of the S x S grid
        B: number of bounding boxes predicted by network
        C: number of classes
        lambda_coord: penalty for coord loss
        lambda_noobj: penalty for confidence loss when no object is present in target
        """
        super().__init__()
        
        self.S = S
        self.B = B
        self.C = C
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj
        
    def _xywh_to_x1y1x2y2(self, boxes: torch.Tensor) -> torch.Tensor:
        """
        Converts YOLO bounding box format to (x1, y1, x2, y2)
        
        pred: (N, S^2, X, 4)
        
        returns (N, S^2, X, 4)
        """
        x = boxes[..., 0] # (N, S^2, X)
        y = boxes[..., 1]
        w = boxes[..., 2]
        h = boxes[..., 3]
        
        # print(x, x.shape)
        # print(y, y.shape)
        # print(w, w.shape)
        # print(h, h.shape)
        
        x1 = x - w / 2
        y1 = y - h / 2
        x2 = x + w / 2
        y2 = y + h / 2
        
        # print(x1)
        # print(y1)
        # print(x2)
        # print(y2)
        
        x1y1x2y2 = torch.stack((x1, y1, x2, y2), dim=3) # (N, S^2, X) -> (N, S^2, X, 4)
        
        # print(x1y1x2y2, x1y1x2y2.shape)
        
        return x1y1x2y2
        
    def _iou(self, pred, target) -> torch.Tensor:
        """
        Calculates the IOU between B prediction boxes and target boxe
        
        pred: (N, S^2, B, 4)
        target: (N, S^2, 1, 4)
        
        returns (N, S^2, B)
        """
        
        # print("CALCULATING IOUS")
        # sanity check to make sure tensors are of the right shape
        assert len(pred) == len(target)
        N = len(pred)
        assert pred.shape == torch.Size((N, self.S**2, self.B, 4))
        assert target.shape == torch.Size((N, self.S**2, 1, 4))
        
        # print("PASSED SANITY CHECKS")
        
#         print("CONVERTING TO x1y1...")
        pred = self._xywh_to_x1y1x2y2(pred)
        target = self._xywh_to_x1y1x2y2(target)
        
        # print("IOUS")
        # vectorized box ioi        
        def _box_iou(pred, target):
            print(pred[0].shape, target[0].shape)
            pred = pred[0] # (B, 4)
            target = target[0] # (1, 4)
            
            iou = box_iou(pred, target) # (B, 1)
            iou = iou.unsqueeze(0) # (1, B, 1)
            
            # print(iou)
            # print(iou.shape)
            return iou
        
        v_box_iou = torch.vmap(_box_iou, in_dims=1, out_dims=1) # (N, S^2, B, 1)
        
        ious = v_box_iou(pred, target)
        
        # print(ious)
        # print(ious.shape)
        
        ious = ious.squeeze(3) # (N, S^2, B, 1) -> (N, S^2, B)
        
        # print(ious, ious.shape)
        
        return ious
        
        
        
    
    def forward(self, pred, target):
        """
        pred: (N x S x S x (5 * B + C))
        target: (N x S x S x (5 + C))
        """
        
        print("YOLO LOSS")
        
        # check pred and target are in the correct shape
        assert len(pred) == len(target)
        N = len(pred)
        
        # get parameters of YOLO loss
        S = self.S
        B = self.B
        C = self.C
        lambda_coord = self.lambda_coord
        lambda_noobj = self.lambda_noobj
        
        assert pred.shape == torch.Size((N, S, S, 5 * B + C))
        assert target.shape == torch.Size((N, S, S, 5 + C))
        
        # flatten S x S grid into S^2
        pred = pred.view(N, S**2, -1) # (N, S, S, 5 * B + C) -> (N, S^2, 5 * B + C)
        target = target.view(N, S**2, -1) # (N, S, S, 5 + C) -> (N, S^2, 5 + C)
        
        print("flattening S x S to S^2")
        print(pred.shape, target.shape)
        
        # seperate tensor into box + classification
        print("seperating tensor into box + classification")
        pred_bndboxes = pred[:, :, 0:5 * B] # (N, S^2, 5 * B + C) -> (N, S^2, 5 * B)
        target_bndbox = target[:, :, 0:5] # (N, S^2, 5 + C) -> (N, S^2, 5)
        
        print("getting confidence")
        pred_confidences = pred_bndboxes[..., 0:-1:5] # (N, S^2, 5 * B) -> (N, S^2, B)
        target_confidence = target_bndbox[..., 0] # (N, S^2, 5) -> (N, S^2)
        
        print(pred_confidences, pred_confidences.shape)
        print(target_confidence, target_confidence.shape)
        
        print("getting bounding box")
        print(pred_bndboxes, pred_bndboxes.shape)
        
        box_indices = torch.arange(0, 5 * B) % 5 != 0 # mask for every 2nd, 3rd, 4th, and 5th element
        pred_boxes = pred_bndboxes[..., box_indices].view(N, S**2, -1, 4) # (N, S^2, 5 * B) -> (N, S^2, 4 * B) -> (N, S^2, B, 4)
        target_box = target_bndbox[:, :, 1: 5].unsqueeze(2) # (N, S^2, 4)
        
        print("PRED BOXES")
        print(pred_boxes, pred_boxes.shape)
        
        print("TARGET BOX")
        print(target_box, target_box.shape)
        
        
        print("getting classification")
        pred_classification = pred[:, :, 5 * B: 5 * B + C] # (N, S^2, 5 * B + C) -> (N, S^2, C)
        target_classification = target[:, :, 5: 5 + C] # (N, S^2, 5 + C) -> (N, S^2, C)
        
        # print(pred_classification, pred_classification.shape)
        # print(target_classification, target_classification.shape)
        
        # calculate IOU between predicted boxes and target box
        ious = self._iou(pred_boxes, target_box)
        
        # get the max iou -> the box with the highest iou is the one responsible for
        # predicting the bounding box in that given cell
        
        print("IOUS")
        print(ious, ious.shape)
        
        print("MAX IOUS")
        # (N, S^2, 1) ??
        max_ious, indices = torch.max(ious, dim=2)
        
        print(max_ious, indices)
        
        

In [702]:
# S = 1
# B = 2
# C = 2

S = 2
B = 2
C = 2

In [703]:
pred = torch.zeros((S, S, 5 * B + C))
pred[0, 0] = torch.tensor([1, 0.5, 0.5, 0.5, 0.5, 1, 0.5, 0.5, 1/7, 1/7, 0, 0.95])
pred = pred.unsqueeze(0)
pred.shape

torch.Size([1, 2, 2, 12])

In [704]:
target = torch.zeros((S, S, 5 + C))
target[0, 0] = torch.tensor([1, 0.5, 0.5, 1/7, 1/7, 0, 1])
target = target.unsqueeze(0)
target.shape

torch.Size([1, 2, 2, 7])

In [705]:
yolo_loss = YOLOv1Loss(
    S=S,
    B=B,
    C=C
)

yolo_loss

YOLOv1Loss()

In [706]:
yolo_loss(pred, target)

YOLO LOSS
flattening S x S to S^2
torch.Size([1, 4, 12]) torch.Size([1, 4, 7])
seperating tensor into box + classification
getting confidence
tensor([[[1., 1.],
         [0., 0.],
         [0., 0.],
         [0., 0.]]]) torch.Size([1, 4, 2])
tensor([[1., 0., 0., 0.]]) torch.Size([1, 4])
getting bounding box
tensor([[[1.0000, 0.5000, 0.5000, 0.5000, 0.5000, 1.0000, 0.5000, 0.5000,
          0.1429, 0.1429],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]]]) torch.Size([1, 4, 10])
PRED BOXES
tensor([[[[0.5000, 0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.1429, 0.1429]],

         [[0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 

In [107]:
a = torch.zeros((1, 2, 2))

In [108]:
a.shape == torch.Size([1, 2, 2])

True

In [109]:
a.shape

torch.Size([1, 2, 2])