In [9]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from util import *

##### Multi-box IoU calculation

In [10]:
# mode - x1y1x2y2, cxcywh
def bbox_iou(b1, b2, mode="x1y1x2y2"):
    """Compute the pairwise IoU between two sets of boxes.

    Args:
        b1: tensor of shape (N, 4).
        b2: tensor of shape (M, 4).
        mode: box encoding of both inputs. "x1y1x2y2" means corner
            coordinates; "cxcywh" means centre x/y followed by width/height.

    Returns:
        Tensor of shape (N, M) where entry [i, j] is the IoU of b1[i] and b2[j].

    Raises:
        ValueError: if ``mode`` is not one of the two supported encodings.
    """
    if mode == "x1y1x2y2":
        b1_x1, b1_y1, b1_x2, b1_y2 = b1[..., 0], b1[..., 1], b1[..., 2], b1[..., 3]
        b2_x1, b2_y1, b2_x2, b2_y2 = b2[..., 0], b2[..., 1], b2[..., 2], b2[..., 3]
    elif mode == "cxcywh":
        # Convert centre/size to corners.
        b1_x1, b1_x2 = b1[..., 0] - b1[..., 2] / 2, b1[..., 0] + b1[..., 2] / 2
        b1_y1, b1_y2 = b1[..., 1] - b1[..., 3] / 2, b1[..., 1] + b1[..., 3] / 2
        b2_x1, b2_x2 = b2[..., 0] - b2[..., 2] / 2, b2[..., 0] + b2[..., 2] / 2
        b2_y1, b2_y2 = b2[..., 1] - b2[..., 3] / 2, b2[..., 1] + b2[..., 3] / 2
    else:
        raise ValueError(
            "mode must be 'x1y1x2y2' or 'cxcywh', got {!r}".format(mode))

    # Broadcast (N, 1) against (1, M) to get every b1/b2 pairing.
    inter_x1 = torch.max(b1_x1.unsqueeze(1), b2_x1.unsqueeze(0))
    inter_y1 = torch.max(b1_y1.unsqueeze(1), b2_y1.unsqueeze(0))
    inter_x2 = torch.min(b1_x2.unsqueeze(1), b2_x2.unsqueeze(0))
    inter_y2 = torch.min(b1_y2.unsqueeze(1), b2_y2.unsqueeze(0))

    # Clamp at zero so disjoint boxes contribute no (negative) intersection.
    inter_area = torch.clamp(inter_x2 - inter_x1, min=0) * torch.clamp(inter_y2 - inter_y1, min=0)
    b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
    b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
    union_area = b1_area.unsqueeze(1) + b2_area.unsqueeze(0) - inter_area

    # The epsilon turns the 0/0 of two zero-area boxes into 0 instead of NaN.
    return inter_area / (union_area + 1e-16)

#### Yolo Loss layer

In [11]:
class YoloLayer(nn.Module):
    """YOLO detection head: decodes raw predictions and, given targets, computes the loss.

    Args:
        anchors: sequence of (width, height) anchor pairs, in input-image pixels
            (they are divided by the stride in ``forward``).
        img_dim: stored on the instance; ``forward`` uses its own ``img_dim``
            argument instead.
        numClass: number of object classes; each box carries ``5 + numClass``
            attributes (4 box terms, objectness, class scores).
    """

    def __init__(self, anchors, img_dim, numClass):
        super().__init__()
        self.anchors = anchors
        self.img_dim = img_dim

        self.numClass = numClass
        self.bbox_attrib = 5 + numClass

        # Per-term loss weights.
        self.lambda_xy = 1
        self.lambda_wh = 1
        self.lambda_conf = 1 #1.0
        self.lambda_cls = 1 #1.0

        # Relative weight of object / no-object cells inside the confidence loss.
        self.obj_scale = 1 #5
        self.noobj_scale = 1 #1

        # Predictions whose IoU with a ground-truth box exceeds this are not
        # penalised as "no object".
        self.ignore_thres = 0.5

        # `reduction=` replaces the deprecated `size_average=` argument:
        # size_average=False -> 'sum', size_average=True -> 'mean'.
        self.mseloss = nn.MSELoss(reduction='sum')
        self.bceloss = nn.BCELoss(reduction='sum')
        self.bceloss_average = nn.BCELoss(reduction='mean')

    def forward(self, x, img_dim, target=None):
        """Decode predictions; return the loss when ``target`` is given.

        Args:
            x: raw head output of shape [B, A * (5 + numClass), H, W].
            img_dim: indexable image size; ``img_dim[1] / H`` is used as the stride.
            target: optional tensor of shape [B, T, 5] (see ``build_target_tensor``).

        Returns:
            Without ``target``: tensor [B, A * H * W, 5 + numClass] holding
            (bx, by, bw, bh) scaled by the stride, objectness and class scores.
            With ``target``: the tuple ``(loss, loss.item(), loss_x, loss_y,
            loss_w, loss_h, loss_conf, loss_cls, nCorrect, nGT)`` where the
            first entry is the differentiable loss and the others are numbers.
        """
        device = x.device
        nB = x.shape[0]
        nA = len(self.anchors)
        nH, nW = x.shape[2], x.shape[3]
        stride = img_dim[1] / nH
        # Anchors expressed in grid cells rather than pixels.
        anchors = torch.FloatTensor(self.anchors) / stride

        # Reshape predictions from [B x [A * (5 + numClass)] x H x W] to [B x A x H x W x (5 + numClass)]
        preds = x.view(nB, nA, self.bbox_attrib, nH, nW).permute(0, 1, 3, 4, 2).contiguous()

        # Raw tx, ty, tw, th; objectness and class scores pass through a sigmoid.
        preds_xy = preds[..., :2]
        preds_wh = preds[..., 2:4]
        preds_conf = preds[..., 4].sigmoid()
        preds_cls = preds[..., 5:].sigmoid()

        # Per-cell (cx, cy) offsets and per-anchor sizes, all [.., H, W, 2].
        mesh_x = torch.arange(nW).repeat(nH, 1).unsqueeze(2)
        mesh_y = torch.arange(nH).repeat(nW, 1).t().unsqueeze(2)
        mesh_xy = torch.cat((mesh_x, mesh_y), 2)
        mesh_anchors = anchors.view(1, nA, 1, 1, 2).repeat(1, 1, nH, nW, 1)

        # pred_boxes holds bx, by, bw, bh in grid units. It is built on the CPU
        # from detached values, so no gradient flows through it.
        pred_boxes = torch.empty(preds[..., :4].shape, dtype=torch.float32)
        pred_boxes[..., :2] = preds_xy.detach().cpu().sigmoid() + mesh_xy.float()  # sig(tx) + cx
        pred_boxes[..., 2:4] = preds_wh.detach().cpu().exp() * mesh_anchors  # exp(tw) * anchor

        if target is not None:
            (obj_mask, noobj_mask, tconf, tcls, tx, ty, tw, th,
             nCorrect, nGT) = self.build_target_tensor(
                pred_boxes, target.detach().cpu(),
                anchors, (nH, nW), self.numClass,
                self.ignore_thres)

            # Targets are assembled on the CPU; move them next to the
            # predictions so the loss also works when x is on another device.
            obj_mask, noobj_mask = obj_mask.to(device), noobj_mask.to(device)
            tcls = tcls.to(device)
            tx, ty, tw, th = tx.to(device), ty.to(device), tw.to(device), th.to(device)
            cls_mask = (obj_mask == 1)

            # Box terms: squared error on the raw outputs, responsible cells only.
            loss_x = self.lambda_xy * self.mseloss(preds_xy[..., 0] * obj_mask, tx * obj_mask) / nB
            loss_y = self.lambda_xy * self.mseloss(preds_xy[..., 1] * obj_mask, ty * obj_mask) / nB
            loss_w = self.lambda_wh * self.mseloss(preds_wh[..., 0] * obj_mask, tw * obj_mask) / nB
            loss_h = self.lambda_wh * self.mseloss(preds_wh[..., 1] * obj_mask, th * obj_mask) / nB

            # Objectness: push responsible cells to 1 and un-ignored cells to 0.
            loss_conf = self.lambda_conf * \
                        ( self.obj_scale * self.bceloss(preds_conf * obj_mask, obj_mask) + \
                          self.noobj_scale * self.bceloss(preds_conf * noobj_mask, noobj_mask * 0) ) / nB
            loss_cls = self.lambda_cls * self.bceloss(preds_cls[cls_mask], tcls[cls_mask]) / nB
            loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls

            return loss, loss.item(), loss_x.item(), loss_y.item(), loss_w.item(), loss_h.item(), \
                   loss_conf.item(), loss_cls.item(), \
                   nCorrect, nGT

        # Return predictions if not training
        out = torch.cat((pred_boxes.to(device) * stride,
                         preds_conf.unsqueeze(4),
                         preds_cls), 4)

        # Reshape predictions from [B x A x H x W x (5 + numClass)] to [B x [A x H x W] x (5 + numClass)]
        # such that predictions at different strides could be concatenated on the same dimension
        out = out.permute(0, 2, 3, 1, 4).contiguous().view(nB, nA*nH*nW, self.bbox_attrib)
        return out

    def build_target_tensor(self, pred_boxes, target, anchors, inp_dim, numClass, ignore_thres):
        """Build per-cell training targets and masks from ground-truth boxes.

        Args:
            pred_boxes: decoded boxes [B, A, H, W, 4] as (cx, cy, w, h) in grid units.
            target: [B, T, 5] tensor; each row is (class, cx, cy, w, h) with the
                box terms scaled by the grid size here. A row summing to zero
                ends that image's list of targets.
            anchors: [A, 2] anchor sizes in grid units.
            inp_dim: (H, W) of the grid.
            numClass: number of classes.
            ignore_thres: IoU above which a prediction is dropped from the
                no-object mask.

        Returns:
            ``(obj_mask, noobj_mask, tconf, tcls, tx, ty, tw, th, nCorrect, nGT)``;
            the masks and tx..th are [B, A, H, W], tcls is [B, A, H, W, numClass],
            nGT counts the targets seen and nCorrect those whose responsible
            prediction has IoU > 0.5 with the target.
        """
        nB = target.shape[0]
        nA = len(anchors)
        nH, nW = inp_dim[0], inp_dim[1]
        nCorrect = 0
        nGT = 0
        target = target.float()

        obj_mask = torch.zeros(nB, nA, nH, nW)
        noobj_mask = torch.ones(nB, nA, nH, nW)
        tconf = torch.zeros(nB, nA, nH, nW)
        tcls = torch.zeros(nB, nA, nH, nW, numClass)
        tx = torch.zeros(nB, nA, nH, nW)
        ty = torch.zeros(nB, nA, nH, nW)
        tw = torch.zeros(nB, nA, nH, nW)
        th = torch.zeros(nB, nA, nH, nW)

        # Anchors as origin-centred boxes; identical for every target.
        anchor_boxes = torch.cat((torch.zeros(nA, 2), anchors), 1)

        for b in range(nB):
            for t in range(target.shape[1]):
                if target[b, t].sum() == 0:
                    break
                nGT += 1

                gx = target[b, t, 1] * nW
                gy = target[b, t, 2] * nH
                gw = target[b, t, 3] * nW
                gh = target[b, t, 4] * nH
                # Clamp the cell index so a centre lying exactly on the
                # right/bottom border (coordinate 1.0) stays inside the grid.
                gi = min(max(int(gx), 0), nW - 1)
                gj = min(max(int(gy), 0), nH - 1)

                gt_box = torch.stack((gx, gy, gw, gh)).unsqueeze(0)

                # Do not train objectness towards 0 for any prediction that
                # already overlaps this ground truth by more than the threshold.
                all_ious, _ = torch.max(
                    bbox_iou(pred_boxes[b].view(-1, 4), gt_box, mode="cxcywh"), 1)
                ignore_idx = (all_ious > ignore_thres).view(nA, nH, nW)
                noobj_mask[b][ignore_idx] = 0

                # Best-fitting anchor: compare shapes only, with both boxes at the origin.
                gt_shape = torch.cat((torch.zeros(1, 2), gt_box[:, 2:]), 1)
                best_anchor = torch.argmax(
                    bbox_iou(anchor_boxes, gt_shape, mode="cxcywh"), 0).item()

                # IoU of the responsible prediction against the ground truth.
                resp_box = pred_boxes[b, best_anchor, gj, gi].view(-1, 4)
                if bbox_iou(gt_box, resp_box, mode="cxcywh").item() > 0.5:
                    nCorrect += 1

                obj_mask[b, best_anchor, gj, gi] = 1
                # NOTE(review): the responsible cell is not cleared from
                # noobj_mask here, so it may receive both an object and a
                # no-object confidence target - confirm this is intended.
                tconf[b, best_anchor, gj, gi] = 1
                tcls[b, best_anchor, gj, gi, int(target[b, t, 0])] = 1

                # Offsets inside the cell, kept below 1 so the inverse sigmoid
                # stays finite when the index above was clamped.
                sig_x = (gx - gi).clamp(min=0, max=1 - 1e-6)
                sig_y = (gy - gj).clamp(min=0, max=1 - 1e-6)
                tx[b, best_anchor, gj, gi] = torch.log(sig_x / (1 - sig_x) + 1e-16)
                ty[b, best_anchor, gj, gi] = torch.log(sig_y / (1 - sig_y) + 1e-16)
                tw[b, best_anchor, gj, gi] = torch.log(gw / anchors[best_anchor, 0] + 1e-16)
                th[b, best_anchor, gj, gi] = torch.log(gh / anchors[best_anchor, 1] + 1e-16)

        return obj_mask, noobj_mask, tconf, tcls, tx, ty, tw, th, nCorrect, nGT

#### Modify YoloNet (from yolo_detect.ipynb)

In [12]:
#import darknet
#from darknet import Darknet, PreDetectionConvGroup, UpsampleGroup, WeightManager

class YoloNet(nn.Module):
    """Feature extractor followed by three YOLO heads at successive scales.

    Args:
        img_dim: stored on the instance and passed to each ``YoloLayer``.
        anchors: flat list ``[w0, h0, w1, h1, ...]``; it is paired into (w, h)
            tuples and split into groups of three, one group per head.
        numClass: number of object classes.

    After a training-mode forward pass, ``self.stats`` holds the values named
    in ``self.stat_keys``.
    """

    def __init__(self, img_dim=None, anchors = [10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326], numClass=80):
        super().__init__()
        self.numClass = numClass
        self.img_dim = img_dim
        self.stat_keys = ['loss', 'loss_x', 'loss_y', 'loss_w', 'loss_h', 'loss_conf', 'loss_cls',
                          'nCorrect', 'nGT', 'recall']

        # Pair up (w, h), group by three, then reverse so the first head gets
        # the last (largest, with the default list) anchors.
        anchors = [(anchors[i], anchors[i+1]) for i in range(0, len(anchors), 2)]
        anchors = [anchors[i:i+3] for i in range(0, len(anchors), 3)][::-1]

        self.feature = Darknet([1,2,8,8,4])
        # Intermediate outputs fetched again in forward() for the two upsample steps.
        self.feature.addCachedOut(61)
        self.feature.addCachedOut(36)

        self.pre_det1 = PreDetectionConvGroup(1024, 512, numClass=self.numClass)
        self.yolo1 = YoloLayer(anchors[0], img_dim, self.numClass)
        # Original note: fetch output from 4th layer backward including yolo layer.
        self.pre_det1.addCachedOut(-3)

        self.up1 = UpsampleGroup(512)
        self.pre_det2 = PreDetectionConvGroup(768, 256, numClass=self.numClass)
        self.yolo2 = YoloLayer(anchors[1], img_dim, self.numClass)
        self.pre_det2.addCachedOut(-3)

        self.up2 = UpsampleGroup(256)
        self.pre_det3 = PreDetectionConvGroup(384, 128, numClass=self.numClass)
        self.yolo3 = YoloLayer(anchors[2], img_dim, self.numClass)

    def forward(self, x, target=None):
        """Run the network.

        Args:
            x: image batch; ``x.shape[3]`` and ``x.shape[2]`` are taken as the
                image width and height.
            target: optional ground-truth tensor forwarded to every YOLO head.

        Returns:
            Without ``target``: the tuple ``(det1, det2, det3)`` of per-head
            detections. With ``target``: the loss summed over the three heads;
            the remaining per-head values are summed into ``self.stats``.
        """
        img_dim = (x.shape[3], x.shape[2])
        # Extract features
        out = self.feature(x)

        # Detection layer 1
        out = self.pre_det1(out)
        det1 = self.yolo1(out, img_dim, target)

        # Upsample 1
        r_head1 = self.pre_det1.getCachedOut(-3)
        r_tail1 = self.feature.getCachedOut(61)
        out = self.up1(r_head1, r_tail1)

        # Detection layer 2
        out = self.pre_det2(out)
        det2 = self.yolo2(out, img_dim, target)

        # Upsample 2
        r_head2 = self.pre_det2.getCachedOut(-3)
        r_tail2 = self.feature.getCachedOut(36)
        out = self.up2(r_head2, r_tail2)

        # Detection layer 3
        out = self.pre_det3(out)
        det3 = self.yolo3(out, img_dim, target)

        if target is None:
            return det1, det2, det3

        # Each head returned a tuple; add the heads together position by
        # position. The first entry is the differentiable loss, the rest are
        # plain numbers (so nCorrect and nGT are totals over all three heads).
        loss, *stat_values = [sum(det) for det in zip(det1, det2, det3)]
        self.stats = dict(zip(self.stat_keys, stat_values))
        self.stats['recall'] = self.stats['nCorrect'] / self.stats['nGT'] if self.stats['nGT'] else 0
        return loss

    # Format : pytorch / darknet
    def saveWeight(self, weights_path, format='pytorch'):
        """Save the weights to ``weights_path``.

        Raises:
            NotImplementedError: for ``format='darknet'``.
            ValueError: for any other unrecognised ``format``.
        """
        if format == 'pytorch':
            torch.save(self.state_dict(), weights_path)
        elif format == 'darknet':
            raise NotImplementedError
        else:
            raise ValueError(
                "format must be 'pytorch' or 'darknet', got {!r}".format(format))

    def loadWeight(self, weights_path, format='pytorch'):
        """Load weights from ``weights_path``.

        Raises:
            ValueError: if ``format`` is neither 'pytorch' nor 'darknet'.
        """
        if format == 'pytorch':
            # map_location keeps every tensor on the CPU regardless of where it
            # was saved. torch.load unpickles the file, so only load
            # checkpoints from a trusted source.
            weights = torch.load(weights_path, map_location=lambda storage, loc: storage)
            self.load_state_dict(weights)
        elif format == 'darknet':
            wm = WeightManager(self)
            wm.loadWeight(weights_path)
        else:
            raise ValueError(
                "format must be 'pytorch' or 'darknet', got {!r}".format(format))

In [13]:
# Smoke test (inference): one random 3x416x416 image. Without a target the
# network returns the three per-head detection tensors.
inp = torch.rand([1,3,416,416])
YNet = YoloNet()
det1, det2, det3 = YNet(inp)

In [14]:
# Each head yields [batch, anchors * grid_h * grid_w, 5 + numClass].
det1.shape,det2.shape,det3.shape

(torch.Size([1, 507, 85]),
 torch.Size([1, 2028, 85]),
 torch.Size([1, 8112, 85]))

In [22]:
# Smoke test (training): passing a target makes the network return the loss.
label = torch.rand([1, 2, 5]) # batch x objects x 5 (class, cx, cy, w, h)
inp = torch.rand([1,3,416,416])
YNet = YoloNet()
loss = YNet(inp, label)

In [23]:
# Loss summed over the three YOLO heads; inputs are random and unseeded, so
# the value differs between runs.
loss

tensor(8092.1797, grad_fn=<AddBackward0>)