In [237]:
import torch
from torch.utils.data import Dataset
import os
from os.path import join as osp
from PIL import Image
from torchvision.transforms import transforms

img_transforms = transforms.Compose([transforms.Resize((384,1248)), transforms.ToTensor()])

label_map = {
                'Car':                 0,
                'Van':                 1,
                'Truck':               2,
                'Pedestrian':          3,
                'Person_sitting':      4,
                'Cyclist':             5,
                'Tram':                6,
                'Misc':                7,
                'DontCare':            8
            }

class KittiDetection2D(Dataset):

    def __init__(self, root, transforms=None):
        super(KittiDetection2D, self).__init__()
        self.image_dir = osp(root, "image_2")
        self.label_dir = osp(root, "label_2")
        self.image_list = os.listdir(osp(root, "image_2"))
        self.label_list = os.listdir(osp(root, "label_2"))
        self.transforms = transforms
        self.h = 384
        self.w = 1248
        self.S = (11,24)

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        image = Image.open(osp(self.image_dir, self.image_list[index]))
        label_pth = osp(self.label_dir, self.label_list[index])
        target = torch.zeros(11, 24, 14)
        with open(label_pth, 'r') as f:
            for i in f.readlines():
                obj_class, x1, y1, x2, y2 = self._parse_label(i)
                cell, x, y, w, h = self._convert_label(x1, y1, x2, y2)
                target[cell[0], cell[1]] = self._create_vector(obj_class, x, y, w, h)
            
        if self.transforms:
            image = self.transforms(image)

        return image, target

    def _create_vector(self, obj, x, y, h, w):
        obj_vector = torch.zeros(9)
        obj_vector[obj] = 1
        return torch.cat((obj_vector, torch.tensor([1,x,y,h,w])))

    def _parse_label(self, label):
        label = label.split()
        obj_class = label_map[label[0]]
        x1, y1, x2, y2 = int(float(label[4])), int(float(label[5])), int(float(label[6])), int(float(label[7]))
        return obj_class, x1, y1, x2, y2

    def _convert_label(self, x1, y1, x2, y2):
        x = int((x1+x2)/2)/self.w
        y = int((y1+y2)/2)/self.h
        h = int((y2-y1))/self.h
        w = int((x2-x1))/self.w
        cell = int(y*self.S[0]), int(x*self.S[1])
        x = x*self.S[1]-cell[1]
        y = y*self.S[0]-cell[0]
        h = self.S[0]*h
        w = self.S[1]*w
        return cell,x,y,w,h

ds = KittiDetection2D(r"E:\Deep Learning Projects\datasets\kitti_object_detection\Kitti\raw\training", img_transforms)
x, target = ds[0]
x.unsqueeze_(0)
target.unsqueeze_(0)
from model import YOLOv1
from utils import read_yaml, iou

model_cfg = read_yaml('model.yaml')
model = YOLOv1(model_cfg)
pred = model(x)



In [227]:
x = torch.where(target[0,..., 9]==1)
x

(tensor([6]), tensor([14]))

In [240]:
from torch.nn import MSELoss
import torch.nn as nn
loss = MSELoss()
tar_b = target[..., 10:14]
exists_box = target[...,9:10]
print('existbox', exists_box[0,6,14])
pred_b1 = pred[...,10:14]
pred_b2 = pred[..., 15:19]
print(pred_b1[0,6,14], tar_b[0,6,14])
iou_b1 = iou(pred_b1, tar_b).unsqueeze(-1)
iou_b2 = iou(pred_b2, tar_b).unsqueeze(-1)
_, best_box = torch.max(torch.cat((iou_b1,iou_b2), dim=-1), dim=-1)
best_box.unsqueeze_(-1)
pred_b = best_box*pred[..., 15:19] + (1-best_box)*pred[..., 10:14]
print(pred_b1[0,6,14], pred_b[0,6,14])
pred_coord = torch.sqrt(torch.abs(pred_b1[...,2:4]))
target_coord = torch.sqrt(tar_b[...,2:4])
coord_loss = 5*loss(torch.flatten(pred_coord), torch.flatten(target_coord))
print(coord_loss)


existbox tensor([1.])
tensor([-0.0395,  0.0335,  0.0334, -0.0004], grad_fn=<SelectBackward0>) tensor([0.6346, 0.4453, 1.8846, 4.6979])
tensor([-0.0395,  0.0335,  0.0334, -0.0004], grad_fn=<SelectBackward0>) tensor([-0.0395,  0.0335,  0.0334, -0.0004], grad_fn=<SelectBackward0>)
tensor(0.1647, grad_fn=<MulBackward0>)


In [243]:
yololoss = YoloLoss()
yololoss(pred, target)


tensor([[[[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          ...,
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          ...,
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          ...,
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         ...,

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          ...,
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          ...,
          [0., 0., 0., 0.],
          [0., 0., 0., 0.],
          [0., 0., 0., 0.]],

         [[0., 0., 0., 0.],
          [0., 0., 0., 0.],
    

tensor(0.0709, grad_fn=<AddBackward0>)

In [242]:
import torch
from torch import nn
from utils import iou

class YoloLoss(nn.Module):
    def __init__(self, S=(11,24), B=2, C=9, coord=5, noobj=0.5) -> None:
        super(YoloLoss, self).__init__()
        self.mse_loss = nn.MSELoss()
        self.S = S
        self.B = B
        self.C = C
        self.coord = coord
        self.noobj = noobj

    def forward(self, pred, target):

        class_probs = target[..., :self.C]                              # [N, S[0], s[1], C]
        exist_box_identity = target[..., self.C:self.C+1]
        target_box = exist_box_identity * target[..., self.C+1:]                       # [N, S[0], s[1], 4]
        print(target_box.unique())

        iou_b1 = iou(pred[..., 10:14], target[...,10:14]).unsqueeze(-1) # [N, S[0], S[1], 1]
        iou_b2 = iou(pred[...,15:19], target[..., 10:14]).unsqueeze(-1) # [N, S[0], S[1], 1]
        max_iou, best_box = torch.max(torch.cat((iou_b1,iou_b2), dim=-1), dim=-1)            # best_box [N, S[0], S[1]]
        best_box = best_box.unsqueeze(-1)
        pred_best_box = exist_box_identity*(best_box*pred[...,15:19] + (1-best_box)*pred[..., 10:14]) # [N, S[0], s[1], 4]
        
        ## coord loss
        pred_coord = pred_best_box[..., 0:2]
        target_coord = target_box[..., 0:2]
        coord_loss = self.coord*self.mse_loss(torch.flatten(pred_coord, 0, -2), torch.flatten(target_coord, 0, -2))

        ## box loss
        
        pred_h_w = torch.sign(pred_best_box[...,2:4])*torch.sqrt(torch.abs(pred_best_box[..., 2:4])) # [N, S[0], S[1], 2]
        target_h_w = torch.sqrt(target_box[..., 2:4]) #[N, S[0], S[1], 2]
        box_loss = self.coord*self.mse_loss(torch.flatten(pred_h_w, 0, -2), torch.flatten(target_h_w, 0, -2))

        ## object loss
        pred_obj = exist_box_identity*(best_box*pred[...,14:15] + (1-best_box)*pred[..., 9:10])

        object_loss = self.mse_loss(torch.flatten(pred_obj, 0, -2), torch.flatten(exist_box_identity, 0, -2))

        ## no object loss

        noobject_loss = self.mse_loss(torch.flatten((1-exist_box_identity)*pred[...,9:10], 0, -2), torch.flatten((1-exist_box_identity)*target[...,9:10], 0, -2)) + self.mse_loss(torch.flatten((1-exist_box_identity)*pred[...,9:10], 0, -2), torch.flatten((1-exist_box_identity)*target[...,9:10], 0, -2))      ###### CHECK
        noobject_loss = self.noobj*object_loss

        ## class loss
        pred_class = exist_box_identity*pred[..., :9]
        target_class = exist_box_identity*target[..., :9]

        class_loss = self.mse_loss(torch.flatten(pred_class, 0, -2), torch.flatten(target_class, 0, -2))
        print('cord loss', coord_loss)
        # print('box loss', box_loss)
        # print('obj loss', object_loss)
        # print('noobj loss', noobject_loss)
        # print('class loss', class_loss)

        return box_loss + coord_loss + object_loss + noobject_loss + class_loss
        





In [132]:
loss = YoloLoss()
loss(torch.rand(1,11,24,19), torch.rand(1,11,24,14))

tensor(0.7199)

In [103]:
iou_b1 = iou(torch.rand(1,11,24,4), torch.rand(1,11,24,4)).unsqueeze(-1)
iou_b2 = iou(torch.rand(1,11,24,4), torch.rand(1,11,24,4)).unsqueeze(-1)
max_iou, best_box = torch.max(torch.cat((iou_b1,iou_b2), dim=-1), dim=-1)
best_box.unsqueeze_(-1)
pred_box = best_box * torch.rand(1,11,24,4)

In [102]:
torch.rand(1,11,24,4)[..., :5].shape

torch.Size([1, 11, 24, 4])

In [105]:
pred_box.shape

torch.Size([1, 11, 24, 4])

In [109]:
torch.flatten(torch.rand(1,11,24,4), 0,-2).shape

torch.Size([264, 4])