### Code reference: https://github.com/ydixon/yolo_v3/blob/master/yolo_detect.ipynb
### A nice blog post for understanding the architecture: https://blog.csdn.net/leviopku/article/details/82660381

In [8]:
import numpy as np
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F

#### Basic network building blocks - conv_bn_relu, res_layer

In [9]:
# see https://blog.csdn.net/leviopku/article/details/82660381, resn
class conv_bn_relu(nn.Module):
    """Convolution block: Conv2d -> optional BatchNorm2d -> optional LeakyReLU(0.1).

    Args:
        nin: number of input channels.
        nout: number of output channels.
        kernel: convolution kernel size. Must be an int when ``pad == "SAME"``
            because the padding is computed arithmetically from it.
        stride: convolution stride.
        pad: ``"SAME"`` derives the padding as ``(kernel - 1) // 2``; any other
            value uses ``padding`` exactly as given.
        padding: explicit padding, only used when ``pad != "SAME"``.
        bn: whether to apply batch normalisation after the convolution. The
            convolution only carries a bias term when this is false.
        activation: ``"leakyRelu"`` applies ``LeakyReLU(negative_slope=0.1)``;
            any other value leaves the block without an activation.
    """

    def __init__(self, nin, nout, kernel, stride=1, pad="SAME", padding=0, bn=True, activation="leakyRelu"):
        super().__init__()

        self.activation = activation

        if pad == 'SAME':
            padding = (kernel - 1) // 2

        self.conv = nn.Conv2d(nin, nout, kernel, stride, padding, bias=not bn)
        # The sub-module names (`conv`, `bn`, `relu`) are kept so state_dict keys
        # do not change. A disabled stage is stored as None and skipped in
        # forward(), instead of leaving a bool / a missing attribute to be called.
        self.bn = nn.BatchNorm2d(nout) if bn else None
        if activation == "leakyRelu":
            self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        else:
            self.relu = None

    def forward(self, x):
        """Apply the convolution and whichever optional stages are enabled."""
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
    
class res_layer(nn.Module):
    """Residual block: a 1x1 conv that halves the channels, a 3x3 conv that
    restores them, and an element-wise sum with the block input."""

    def __init__(self, nin):
        super().__init__()

        bottleneck = nin // 2
        self.conv1 = conv_bn_relu(nin, bottleneck, kernel=1)  # e.g. 64 -> 32, 1x1
        self.conv2 = conv_bn_relu(bottleneck, nin, kernel=3)  # e.g. 32 -> 64, 3x3

    def forward(self, x):
        """Return ``x`` plus the output of the two stacked convolutions."""
        branch = self.conv1(x)
        branch = self.conv2(branch)
        # Shortcut is an addition (not a concatenation), so the shape is unchanged.
        return x + branch

#### Map2cfgDict - used to create a mapping that follows the cfg file from pjreddie's repository

In [10]:
def map2cfgDict(mlist):
    """Build an ordered mapping from cfg-style layer numbers to positions in ``mlist``.

    Every module takes one cfg slot, except a ``res_layer``, which takes three:
    the first two map to ``None`` and the third maps to the module's position.
    (Presumably the two ``None`` slots stand for the block's inner conv layers
    in the cfg file - confirm against the cfg.)

    Returns:
        OrderedDict whose keys are consecutive ints starting at 0 and whose
        values are either an index into ``mlist`` or ``None``.
    """
    slots = []
    for position, module in enumerate(mlist):
        if isinstance(module, res_layer):
            # Two slots with no directly addressable output in ``mlist``.
            slots.extend((None, None))
        slots.append(position)

    return OrderedDict(enumerate(slots))

#### UpsampleGroup

In [11]:
# UpsampleGroup: conv + upsample + concat 
# see https://blog.csdn.net/leviopku/article/details/82660381, DBL + upsampling + concat
class UpsampleGroup(nn.Module):
    """1x1 conv that halves the channels, 2x nearest-neighbour upsampling, then
    channel-wise concatenation with a second feature map."""

    def __init__(self, nin):
        super().__init__()
        self.conv = conv_bn_relu(nin, nin // 2, kernel=1)
        self.up = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, route_head, route_tail):
        """Upsample ``route_head`` and concatenate ``route_tail`` behind it on dim 1.

        The result has ``nin // 2`` channels from the upsampled branch followed
        by all channels of ``route_tail``.
        """
        reduced = self.conv(route_head)
        upsampled = self.up(reduced)
        return torch.cat([upsampled, route_tail], dim=1)

#### Darknet53 - Feature extraction

In [12]:
def make_res_stack(nin, num_block):
    """Return a ModuleList holding one stride-2 3x3 conv that doubles the
    channels (``nin`` -> ``2 * nin``) followed by ``num_block`` residual layers
    at the doubled width."""
    width = nin * 2
    layers = [conv_bn_relu(nin, width, 3, stride=2)]
    layers.extend(res_layer(width) for _ in range(num_block))
    return nn.ModuleList(layers)

class Darknet(nn.Module):
    """Feature-extraction backbone: a stem 3x3 conv followed by one residual
    stack (see ``make_res_stack``) per entry of ``blkList``.

    Selected intermediate activations can be kept during ``forward`` by
    registering them beforehand with ``addCachedOut``.

    Args:
        blkList: number of residual layers in each stack; stack ``i`` receives
            ``nout * 2**i`` input channels.
        nout: number of output channels of the stem convolution.
    """

    def __init__(self, blkList, nout=32):
        super().__init__()

        self.mlist = nn.ModuleList()
        self.mlist += [conv_bn_relu(3, nout, 3)]
        for i, nb in enumerate(blkList):
            self.mlist += make_res_stack(nout * (2 ** i), nb)

        # cfg-style layer number -> position in self.mlist (or None).
        self.map2yolocfg = map2cfgDict(self.mlist)
        # position in self.mlist -> activation captured by the last forward().
        self.cachedOutDict = dict()

    def forward(self, x):
        """Run every module in order, storing the outputs that were registered."""
        for i, m in enumerate(self.mlist):
            x = m(x)
            if i in self.cachedOutDict:
                self.cachedOutDict[i] = x
        return x

    def _resolveIdx(self, idx, mode):
        """Translate ``idx`` into a position in ``self.mlist``.

        ``mode == "yolocfg"`` looks ``idx`` up in ``self.map2yolocfg``; any other
        mode treats ``idx`` as a direct position (negative values count from the
        end), matching ``PreDetectionConvGroup``.

        Raises:
            KeyError: the cfg index is not in the mapping.
            ValueError: the cfg index maps to ``None`` and therefore has no
                output of its own that could be cached.
        """
        if mode == "yolocfg":
            resolved = self.map2yolocfg[idx]
            if resolved is None:
                raise ValueError(
                    "cfg layer {} has no separately addressable output".format(idx))
            return resolved
        if idx < 0:
            return len(self.mlist) + idx
        return idx

    def addCachedOut(self, idx, mode="yolocfg"):
        """Register layer ``idx`` so that ``forward`` keeps its output."""
        self.cachedOutDict[self._resolveIdx(idx, mode)] = None

    def getCachedOut(self, idx, mode="yolocfg"):
        """Return the output stored for layer ``idx`` (``None`` before the first
        forward pass; ``KeyError`` if the layer was never registered)."""
        return self.cachedOutDict[self._resolveIdx(idx, mode)]

In [7]:
# Build the backbone with residual-layer counts [1, 2, 8, 8, 4], one count per stack.
net = Darknet([1,2,8,8,4])

In [41]:
# Register cfg layer 61 so that the next forward pass keeps its output.
net.addCachedOut(61)

In [15]:
#net

In [19]:
# NOTE(review): the star import is kept because the names this cell needs from
# the project-local `visualization` module (at least `make_dot`) cannot be
# confirmed from this file - replace with explicit names once verified.
from visualization import *

inputs = torch.randn(1,3,416,416)
net = Darknet([1,2,8,8,4])
# Tensors are passed to the module directly; `Variable` is never imported in
# this file and wrapping is not required.
y = net(inputs)
print(y.shape)

# Build a graph object for the output and open it (presumably a graphviz
# Digraph, given the 'Digraph.gv.pdf' result below - confirm).
g = make_dot(y, net.state_dict())
g.view()

torch.Size([1, 1024, 13, 13])


'Digraph.gv.pdf'

#### PreDetectionConvGroup - conv layers before the yolo detection layer

In [13]:
class PreDetectionConvGroup(nn.Module):
    """Convolution stack that precedes a YOLO detection layer.

    The stack is ``num_conv`` pairs of (1x1 conv to ``nout`` channels, 3x3 conv
    to ``2 * nout`` channels) followed by a plain 1x1 ``nn.Conv2d`` producing
    ``(numClass + 5) * 3`` channels.

    Args:
        nin: number of input channels of the first convolution.
        nout: channel width of the 1x1 convolutions.
        num_conv: number of (1x1, 3x3) pairs.
        numClass: number of object classes.
    """

    def __init__(self, nin, nout, num_conv=3, numClass=80):
        super().__init__()

        self.mlist = nn.ModuleList()

        for _ in range(num_conv):
            self.mlist += [conv_bn_relu(nin, nout, kernel=1)]
            self.mlist += [conv_bn_relu(nout, nout * 2, kernel=3)]
            # Every pair after the first is fed by the previous 3x3 conv.
            nin = nout * 2

        self.mlist += [nn.Conv2d(nin, (numClass + 5) * 3, 1)]  # expand dim to (numClass+5)*3
        self.map2yolocfg = map2cfgDict(self.mlist)
        self.cachedOutDict = dict()

    def forward(self, x):
        """Run every module in order, storing the outputs that were registered."""
        for i, m in enumerate(self.mlist):
            x = m(x)
            if i in self.cachedOutDict:
                self.cachedOutDict[i] = x

        return x

    def _resolveIdx(self, idx, mode):
        """Translate ``idx`` into a position in ``self.mlist`` for the given mode."""
        if mode == "yolocfg":
            return self.getIdxFromYoloIdx(idx)
        if idx < 0:
            # Negative direct indices count back from the end of the list.
            return len(self.mlist) + idx
        return idx

    # mode - yolocfg -- index follows the sequence of the cfg file from
    #                   https://github.com/pjreddie/darknet/blob/master/cfg/yolov3.cfg
    #      - any other value ("normal") -- direct index into mlist
    def addCachedOut(self, idx, mode="yolocfg"):
        """Register layer ``idx`` so that ``forward`` keeps its output."""
        self.cachedOutDict[self._resolveIdx(idx, mode)] = None

    def getCachedOut(self, idx, mode="yolocfg"):
        """Return the output stored for layer ``idx`` (``None`` before the first
        forward pass; ``KeyError`` if the layer was never registered)."""
        return self.cachedOutDict[self._resolveIdx(idx, mode)]

    def getIdxFromYoloIdx(self, idx):
        """Map a cfg-style index to a position in ``self.mlist``.

        Negative values count back from the end of the cfg mapping and are then
        looked up in it like any other cfg index.
        """
        if idx < 0:
            idx = len(self.map2yolocfg) + idx
        return self.map2yolocfg[idx]

In [28]:
# Stand-alone construction of the conv stack that runs before prediction 1; the
# same pattern is reused for the other two predictions.
# see https://blog.csdn.net/qq_37541097/article/details/81214953
# 3 x (1x1 conv, 3x3 conv) + one final 1x1 conv = 7 modules
mlist = nn.ModuleList()
numClass = 80
nin = 1024
nout = 512
for i in range(3):
    mlist += [conv_bn_relu(nin, nout, kernel=1)]
    mlist += [conv_bn_relu(nout, nout*2, kernel=3)]
    # after the first pair, each 1x1 conv is fed by the preceding 3x3 conv
    if i == 0:
        nin = nout*2
        
# final plain conv: (numClass + 5) output channels for each of 3 anchors
mlist += [nn.Conv2d(nin, (numClass+5)*3, 1)]

In [29]:
# Display the module list built in the previous cell.
mlist

ModuleList(
  (0): conv_bn_relu(
    (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): LeakyReLU(negative_slope=0.1, inplace)
  )
  (1): conv_bn_relu(
    (conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): LeakyReLU(negative_slope=0.1, inplace)
  )
  (2): conv_bn_relu(
    (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): LeakyReLU(negative_slope=0.1, inplace)
  )
  (3): conv_bn_relu(
    (conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): LeakyReLU(negative_sl

In [31]:
# Same stack as above, built through the class (1024 input channels, width 512), then displayed.
pre_det1 = PreDetectionConvGroup(1024, 512, numClass=80);pre_det1

PreDetectionConvGroup(
  (mlist): ModuleList(
    (0): conv_bn_relu(
      (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): LeakyReLU(negative_slope=0.1, inplace)
    )
    (1): conv_bn_relu(
      (conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): LeakyReLU(negative_slope=0.1, inplace)
    )
    (2): conv_bn_relu(
      (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): LeakyReLU(negative_slope=0.1, inplace)
    )
    (3): conv_bn_relu(
      (conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, af

In [40]:
# Register the third-from-last layer for caching; addCachedOut has no return
# statement, so this prints None.
print(pre_det1.addCachedOut(-3))

None


#### Yolo Detection Layer

In [14]:
class YoloLayer(nn.Module):
    """Decode the raw output of one detection head into bounding boxes.

    Args:
        anchors: the anchor (width, height) pairs of this head, in input-image
            pixels, e.g. ``[(116, 90), (156, 198), (373, 326)]``.
        img_dim: stored on the instance only; ``forward`` uses the ``img_dim``
            passed to it.
        nClass: number of object classes.
    """

    def __init__(self, anchors, img_dim, nClass):
        super().__init__()

        self.anchors = anchors  # [(116, 90), (156, 198), (373, 326)]
        self.img_dim = img_dim
        self.nClass = nClass
        self.bbox_attrib = nClass + 5  # x, y, w, h, objectness + one score per class

    def forward(self, x, img_dim):
        """Turn ``x`` of shape [B, A * (5 + nClass), H, W] into [B, A*H*W, 5 + nClass].

        Args:
            x: raw head output.
            img_dim: input-image size, either a ``(w, h)`` pair or a single
                number for a square image. Only the height is used to derive
                the stride.

        Returns:
            Tensor on the same device as ``x`` whose last dimension holds
            ``[bx, by, bw, bh, objectness, class scores...]``. The box values
            are scaled by the stride (input-image pixels) and detached from the
            autograd graph; objectness and class scores are sigmoid outputs.
        """
        nB = x.shape[0]         # batch size
        nA = len(self.anchors)  # anchors per cell, e.g. 3
        nH, nW = x.shape[2], x.shape[3]

        # Accept both a (w, h) pair and a bare number for square inputs.
        try:
            img_h = img_dim[1]
        except TypeError:
            img_h = img_dim
        stride = img_h / nH  # e.g. 416 / 13 = 32

        # Anchors expressed in grid cells, e.g. for stride 32:
        # [[3.6250, 2.8125], [4.8750, 6.1875], [11.6562, 10.1875]]
        anchors = torch.as_tensor(self.anchors, dtype=x.dtype, device=x.device) / stride

        # Reshape predictions from [B x [A * (5 + numClass)] x H x W] to [B x A x H x W x (5 + numClass)]
        # e.g. [1, 3, 85, 13, 13] -> [1, 3, 13, 13, 85], see https://www.zhihu.com/question/60321866 for details.
        preds = x.view(nB, nA, self.bbox_attrib, nH, nW).permute(0, 1, 3, 4, 2).contiguous()

        # tx, ty, tw, th
        preds_xy = preds[..., :2]
        preds_wh = preds[..., 2:4]
        preds_conf = preds[..., 4].sigmoid()
        preds_cls = preds[..., 5:].sigmoid()

        # Cell offsets (cx, cy), created on the input's device so the layer
        # works for CPU and GPU tensors alike.
        mesh_x = torch.arange(nW, dtype=x.dtype, device=x.device).repeat(nH, 1)      # H * W
        mesh_y = torch.arange(nH, dtype=x.dtype, device=x.device).repeat(nW, 1).t()  # H * W
        mesh_xy = torch.stack((mesh_x, mesh_y), 2)                                   # H * W * 2

        # bx, by = sigmoid(tx, ty) + (cx, cy);  bw, bh = exp(tw, th) * anchor
        # detach(): http://www.bnikolic.co.uk/blog/pytorch-detach.html
        box_xy = preds_xy.detach().sigmoid() + mesh_xy
        box_wh = preds_wh.detach().exp() * anchors.view(1, nA, 1, 1, 2)

        # [B, A, H, W, 5 + numClass]
        out = torch.cat((box_xy * stride,
                         box_wh * stride,
                         preds_conf.unsqueeze(4),
                         preds_cls), 4)

        # Reshape predictions from [B x A x H x W x (5 + numClass)] to [B x [A x H x W] x (5 + numClass)]
        # such that predictions at different strides could be concatenated on the same dimension
        out = out.permute(0, 2, 3, 1, 4).contiguous().view(nB, nA * nH * nW, self.bbox_attrib)

        return out

In [57]:
# Push a random 1024-channel 13x13 map through the first conv group and print the output shape.
inp = torch.randn([1, 1024, 13, 13])
o = pre_det1(inp)
print(o.shape)

torch.Size([1, 255, 13, 13])


In [69]:
# Anchor groups in the order YoloNet builds them (largest anchors first);
# `anchors` is not defined anywhere else at notebook level.
anchors = [[(116, 90), (156, 198), (373, 326)],
           [(30, 61), (62, 45), (59, 119)],
           [(10, 13), (16, 30), (33, 23)]]
yolo = YoloLayer(anchors[0], 416, 80)
inp = torch.randn([1, 255, 13, 13]).float()
# forward() indexes img_dim, so it is passed as a (w, h) pair rather than a bare int.
p = yolo(inp, (416, 416))
print(p.shape)

torch.Size([1, 507, 85])


In [72]:
# Shape of the activation captured for the layer registered with addCachedOut(-3).
pre_det1.getCachedOut(-3).shape

torch.Size([1, 512, 13, 13])

In [82]:
# Run a fresh backbone on a random 416x416 image and inspect the activation
# cached for cfg layer 61.
inp = torch.randn([1, 3, 416, 416])
feature = Darknet([1,2,8,8,4])
feature.addCachedOut(61)
op = feature(inp)
feature.getCachedOut(61).shape

torch.Size([1, 512, 26, 26])

In [83]:
# Shape check for the first upsample group using random stand-ins for the two routes:
# 512 // 2 upsampled channels + 512 route channels on a 26x26 map.
r_head1 = torch.randn([1, 512, 13, 13])
r_tail1 = torch.randn([1, 512, 26, 26])
up1 = UpsampleGroup(512)
out = up1(r_head1, r_tail1)
out.shape



torch.Size([1, 768, 26, 26])

In [84]:
# `numClass` is passed by keyword: the third positional parameter of
# PreDetectionConvGroup is `num_conv`, so a positional 80 would build 80 conv pairs.
pre_det2 = PreDetectionConvGroup(768, 256, numClass=80)
inp = torch.randn([1, 768, 26, 26])
out = pre_det2(inp)
out.shape

torch.Size([1, 255, 26, 26])

In [85]:
yolo2 = YoloLayer(anchors[1], 416, 80)
# Decode `out` (the head-2 tensor produced by the previous cell). forward()
# indexes img_dim, so it is passed as a (w, h) pair rather than a bare int.
det2 = yolo2(out, (416, 416))
det2.shape # 26*26*3

torch.Size([1, 2028, 85])

In [111]:
# Rebuild the second conv group, register its third-from-last layer for caching
# and run a random input through it.
pre_det2 = PreDetectionConvGroup(768, 256, numClass=80)
pre_det2.addCachedOut(-3)
inp = torch.randn([1, 768, 26, 26])
o = pre_det2(inp)

In [112]:
# Shape of the activation captured by the forward pass in the previous cell.
pre_det2.getCachedOut(-3).shape

torch.Size([1, 256, 26, 26])

In [113]:
# Run a fresh backbone on a random 416x416 image and inspect the activation
# cached for cfg layer 36.
inp = torch.randn([1, 3, 416, 416])
feature = Darknet([1,2,8,8,4])
feature.addCachedOut(36)
op = feature(inp)
feature.getCachedOut(36).shape

torch.Size([1, 256, 52, 52])

In [115]:
# Shape check for the second upsample group using random stand-ins for the two routes:
# 256 // 2 upsampled channels + 256 route channels on a 52x52 map.
r_head2 = torch.randn([1, 256, 26, 26])
r_tail2 = torch.randn([1, 256, 52, 52])
up2 = UpsampleGroup(256)
out = up2(r_head2, r_tail2)
out.shape



torch.Size([1, 384, 52, 52])

In [116]:
# Third conv group: 384 input channels, width 128; run on a random 52x52 map.
pre_det3 = PreDetectionConvGroup(384, 128, numClass=80)
inp = torch.randn([1, 384, 52, 52])
out = pre_det3(inp)
out.shape

torch.Size([1, 255, 52, 52])

In [117]:
yolo3 = YoloLayer(anchors[2], 416, 80)
# Decode `out` (the head-3 tensor produced by the previous cell). forward()
# indexes img_dim, so it is passed as a (w, h) pair rather than a bare int.
det3 = yolo3(out, (416, 416))
det3.shape # 52*52*3

torch.Size([1, 8112, 85])

#### Entire network - putting everything together

In [15]:
class YoloNet(nn.Module):
    """Complete network: Darknet backbone plus three detection heads.

    Args:
        img_dim: stored on the instance and handed to the YoloLayer
            constructors; ``forward`` derives the actual size from its input.
        anchors: flat sequence ``w0, h0, w1, h1, ...`` of anchor sizes. It is
            paired up, split into groups of three and reversed, so the last
            three pairs go to the first head.
        numClass: number of object classes.
    """

    def __init__(self, img_dim=None, anchors=(10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326), numClass=80):
        super().__init__()

        self.numClass = numClass
        self.img_dim = img_dim
        self.stat_keys = ['loss', 'loss_x', 'loss_y', 'loss_w', 'loss_h', 'loss_conf', 'loss_cls',
                          'nCorrect', 'nGT', 'recall']

        # Flat list -> (w, h) pairs -> groups of three, largest anchors first.
        # With the default anchors:
        # [[(116, 90), (156, 198), (373, 326)],
        #  [(30, 61), (62, 45), (59, 119)],
        #  [(10, 13), (16, 30), (33, 23)]]
        anchors = [(anchors[i], anchors[i+1]) for i in range(0, len(anchors), 2)]
        anchors = [anchors[i:i+3] for i in range(0, len(anchors), 3)][::-1]

        self.feature = Darknet([1,2,8,8,4])  # darknet 53
        # Backbone activations consumed by the two upsample groups.
        self.feature.addCachedOut(61)
        self.feature.addCachedOut(36)

        self.pre_det1 = PreDetectionConvGroup(1024, 512, numClass=self.numClass)  # 6 convs + final 1x1 conv
        self.yolo1 = YoloLayer(anchors[0], img_dim, self.numClass)  # [(116, 90), (156, 198), (373, 326)]
        self.pre_det1.addCachedOut(-3)  # third-from-last layer of the group feeds the upsample branch

        self.up1 = UpsampleGroup(512)
        self.pre_det2 = PreDetectionConvGroup(768, 256, numClass=self.numClass)
        self.yolo2 = YoloLayer(anchors[1], img_dim, self.numClass)  # [(30, 61), (62, 45), (59, 119)]
        self.pre_det2.addCachedOut(-3)

        self.up2 = UpsampleGroup(256)
        self.pre_det3 = PreDetectionConvGroup(384, 128, numClass=self.numClass)
        self.yolo3 = YoloLayer(anchors[2], img_dim, self.numClass)  # [(10, 13), (16, 30), (33, 23)]

    def forward(self, x):
        """Return the decoded detections of the three heads as a tuple.

        The shapes in the comments below are for a 1 x 3 x 416 x 416 input and
        80 classes.
        """
        img_dim = (x.shape[3], x.shape[2])  # w, h

        # extract features
        out = self.feature(x)  # [1, 1024, 13, 13]

        # detection layer 1
        out = self.pre_det1(out)         # [1, 255, 13, 13], 255 = 3 * (80 + 5)
        det1 = self.yolo1(out, img_dim)  # [1, 507, 85], 507 = 13 * 13 * 3

        # upsample 1
        r_head1 = self.pre_det1.getCachedOut(-3)  # [1, 512, 13, 13]
        r_tail1 = self.feature.getCachedOut(61)   # [1, 512, 26, 26]
        out = self.up1(r_head1, r_tail1)          # [1, 768, 26, 26], 256 + 512

        # detection layer 2
        out = self.pre_det2(out)         # [1, 255, 26, 26]
        det2 = self.yolo2(out, img_dim)  # [1, 2028, 85], 2028 = 26 * 26 * 3

        # upsample 2
        r_head2 = self.pre_det2.getCachedOut(-3)  # [1, 256, 26, 26]
        r_tail2 = self.feature.getCachedOut(36)   # [1, 256, 52, 52]
        out = self.up2(r_head2, r_tail2)          # [1, 384, 52, 52], 128 + 256

        # detection layer 3
        out = self.pre_det3(out)         # [1, 255, 52, 52]
        det3 = self.yolo3(out, img_dim)  # [1, 8112, 85], 8112 = 52 * 52 * 3

        return det1, det2, det3

![a](yolo_v3.png)

In [16]:
# End-to-end smoke test: run the full network on one random 416x416 image.
inp = torch.rand([1,3,416,416])
YNet = YoloNet()
det1, det2, det3 = YNet(inp)



In [17]:
# Shapes of the three per-head detection tensors.
det1.shape, det2.shape, det3.shape

(torch.Size([1, 507, 85]),
 torch.Size([1, 2028, 85]),
 torch.Size([1, 8112, 85]))

In [18]:
# Concatenate the three heads along the box dimension (dim 1) and show the shape.
detections = torch.cat((det1,det2,det3), 1);detections.shape

torch.Size([1, 10647, 85])

#### IOU and non-max suppression (NMS)

In [19]:
def iou_vectorized(bbox):
    """Return the pairwise IoU matrix of a set of corner-format boxes.

    Args:
        bbox: tensor of shape [N, >=4] whose first four columns are
            ``[x1, y1, x2, y2]`` (left-top and right-bottom corners); any
            further columns are ignored.

    Returns:
        [N, N] tensor where entry ``[i, j]`` is the IoU of box ``i`` and box ``j``.
    """
    x1, y1, x2, y2 = bbox[:, 0], bbox[:, 1], bbox[:, 2], bbox[:, 3]

    # A column vector against a row vector broadcasts to every (i, j) pair.
    left = torch.max(x1.unsqueeze(1), x1)
    top = torch.max(y1.unsqueeze(1), y1)
    right = torch.min(x2.unsqueeze(1), x2)
    bottom = torch.min(y2.unsqueeze(1), y2)

    # Negative extents mean the boxes do not overlap on that axis.
    overlap_w = torch.clamp(right - left, min=0)
    overlap_h = torch.clamp(bottom - top, min=0)
    inter_area = overlap_w * overlap_h                         # [N, N]

    area = (x2 - x1) * (y2 - y1)                               # [N]
    union_area = area.unsqueeze(0) + area.unsqueeze(1) - inter_area

    return inter_area / union_area

In [68]:
# NOTE(review): `iou` is not assigned at notebook level in any visible cell -
# presumably left over from an interactive iou_vectorized(...) call; confirm.
iou.shape

torch.Size([41, 41])

In [27]:
def reduce_row_by_column(inp):
    """Walk a list of (row, col) index pairs and drop the rows of suppressed boxes.

    Pairs are visited in order. For every pair whose two indices differ, all
    pairs whose first index equals that pair's second index are removed, so the
    box named by the second index no longer appears as a row.

    Args:
        inp: integer tensor of shape [M, 2].

    Returns:
        Tensor of shape [K, 2] with the surviving pairs. The result stays 2-D
        even when only one pair (or none) is left.
    """
    i = 0
    while i < inp.shape[0]:
        row_idx, col_idx = inp[i].tolist()
        if row_idx != col_idx:
            # A boolean mask (rather than nonzero().squeeze()) keeps the pair
            # list two-dimensional when a single row survives.
            inp = inp[inp[:, 0] != col_idx]
        i += 1

    return inp

In [26]:
# bbox is expected to be sorted in descending order of the score used for suppression
def nms(bbox, iou, nms_thres):
    """Select the boxes that survive non-maximum suppression.

    Args:
        bbox: [N, D] tensor of boxes, one per row.
        iou: [N, N] pairwise IoU matrix for ``bbox``.
        nms_thres: pairs whose IoU is strictly greater than this are treated as
            overlapping.

    Returns:
        The rows of ``bbox`` whose index still appears as a first index after
        ``reduce_row_by_column`` has pruned the overlapping pairs, in ascending
        index order.
    """
    # (i, j) index pairs with IoU above the threshold.
    overlap_pairs = (iou > nms_thres).nonzero()
    overlap_pairs = reduce_row_by_column(overlap_pairs)

    keep_idx = torch.unique(overlap_pairs[:, 0])
    return bbox[keep_idx]

#### Post-processing - convert predictions from network to bounding boxes (calls IOU/NMS)

In [72]:
def postprocessing(detections, num_classes, obj_conf_thr=0.5, nms_thr=0.4):
    """Convert raw network detections into per-image lists of kept boxes.

    Args:
        detections: tensor of shape [B, N, 5 + num_classes] whose last
            dimension is ``[cx, cy, w, h, obj_conf, class scores...]``.
        num_classes: number of class-score columns to consider.
        obj_conf_thr: boxes whose objectness is not strictly greater than this
            are discarded.
        nms_thr: IoU threshold handed to ``nms``.

    Returns:
        A list with one tensor per image. Each tensor has rows
        ``[x1, y1, x2, y2, obj_conf, class_score, class]``, grouped by class;
        an image without any surviving box yields an empty ``torch.Tensor()``.
        The input tensor is not modified.
    """
    # Zero every box whose objectness does not exceed the threshold. The product
    # is a new tensor, so the caller's `detections` stays untouched.
    obj_conf_filter = (detections[:, :, 4] > obj_conf_thr).float().unsqueeze(2)  # [B, N, 1]
    detections = detections * obj_conf_filter

    # Centre/size -> two corners, assembled with cat instead of writing into a
    # slice of the masked tensor.
    left_top = detections[:, :, :2] - detections[:, :, 2:4] / 2
    right_bottom = left_top + detections[:, :, 2:4]
    detections = torch.cat((left_top, right_bottom, detections[:, :, 4:]), 2)

    results = list()

    for img_det in detections:  # one [N, 5 + num_classes] slice per image
        max_class_score, max_class_idx = torch.max(img_det[:, 5:5 + num_classes], 1)
        # img_det - [x1, y1, x2, y2, obj_conf, class_score, class], shape [N, 7]
        img_det = torch.cat((img_det[:, :5],
                             max_class_score.float().unsqueeze(1),
                             max_class_idx.float().unsqueeze(1)),
                            1)

        # Drop the rows that were zeroed by the objectness filter.
        img_det = img_det[img_det[:, 4] != 0]

        if img_det.shape[0] == 0:
            results.append(torch.Tensor())
            continue

        # Per-class results are collected in a list and concatenated once, so
        # everything stays on the device of `detections`.
        kept = []
        for c in torch.unique(img_det[:, -1]):
            class_img_det = img_det[img_det[:, -1] == c]
            # Sort by objectness score, highest first.
            _, sort_idx = class_img_det[:, 4].sort(descending=True)
            class_img_det = class_img_det[sort_idx]

            iou = iou_vectorized(class_img_det)
            # Alert: there is another Python-level loop inside nms.
            kept.append(nms(class_img_det, iou, nms_thr))

        results.append(torch.cat(kept, 0))

    return results

In [28]:
# Post-process the concatenated detections (80 classes, objectness threshold 0.5)
# and show the first image's result next to the raw shape.
pdet = postprocessing(detections, 80, obj_conf_thr=0.5)
detections.shape, pdet[0]

(torch.Size([1, 10647, 85]),
 tensor([[331.3665, -10.9208, 349.6289,  ...,   0.5518,   0.5647,  60.0000],
         [220.0789,  23.0935, 307.1667,  ...,   0.6094,   0.6346,  60.0000],
         [293.7178, 232.7700, 331.5706,  ...,   0.5858,   0.5942,  60.0000],
         ...,
         [128.0964, 100.9296, 136.8837,  ...,   0.5062,   0.6574,  53.0000],
         [190.9442,  78.4702, 200.4171,  ...,   0.5062,   0.6237,  53.0000],
         [246.5906, 214.9729, 257.4446,  ...,   0.5052,   0.6380,  53.0000]],
        grad_fn=<CatBackward>))

In [29]:
# Number of boxes kept for the first image, then the full per-image result list.
print(pdet[0].shape)
pdet

torch.Size([5513, 7])


[tensor([[331.3665, -10.9208, 349.6289,  ...,   0.5518,   0.5647,  60.0000],
         [220.0789,  23.0935, 307.1667,  ...,   0.6094,   0.6346,  60.0000],
         [293.7178, 232.7700, 331.5706,  ...,   0.5858,   0.5942,  60.0000],
         ...,
         [128.0964, 100.9296, 136.8837,  ...,   0.5062,   0.6574,  53.0000],
         [190.9442,  78.4702, 200.4171,  ...,   0.5062,   0.6237,  53.0000],
         [246.5906, 214.9729, 257.4446,  ...,   0.5052,   0.6380,  53.0000]],
        grad_fn=<CatBackward>)]