### Code reference: https://github.com/ydixon/yolo_v3/blob/master/yolo_detect.ipynb
### A nice blog post for understanding the architecture: https://blog.csdn.net/leviopku/article/details/82660381

In [102]:
import numpy as np
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F

#### Basic network building blocks - conv_bn_relu, res_layer

In [87]:
# see https://blog.csdn.net/leviopku/article/details/82660381, resn
class conv_bn_relu(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and LeakyReLU(0.1).

    Args:
        nin: number of input channels.
        nout: number of output channels.
        kernel: (square) kernel size passed to ``nn.Conv2d``.
        stride: convolution stride.
        pad: when ``"SAME"``, ``padding`` is overridden with
            ``(kernel - 1) // 2``; any other value uses ``padding`` as given.
        padding: explicit padding, only used when ``pad`` is not ``"SAME"``.
        bn: if true, a ``BatchNorm2d`` follows the convolution and the
            convolution is created without a bias term.
        activation: ``"leakyRelu"`` appends ``LeakyReLU(0.1)``; any other
            value leaves the block without an activation.

    Attributes ``bn`` / ``relu`` are ``None`` when the corresponding stage is
    disabled, so ``forward`` can skip them instead of failing.
    """

    def __init__(self, nin, nout, kernel, stride=1, pad="SAME", padding=0, bn=True, activation="leakyRelu"):
        super().__init__()

        self.activation = activation

        if pad == 'SAME':
            padding = (kernel-1)//2

        # The conv bias is dropped whenever batch norm follows.
        self.conv = nn.Conv2d(nin, nout, kernel, stride, padding, bias=not bn)
        self.bn = nn.BatchNorm2d(nout) if bn else None
        if activation == "leakyRelu":
            self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        else:
            self.relu = None

    def forward(self, x):
        """Apply conv, then batch norm and activation when they are enabled."""
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x
    
class res_layer(nn.Module):
    """Residual unit: a 1x1 conv to ``nin // 2`` channels, a 3x3 conv back to
    ``nin`` channels, and the input added to the result."""

    def __init__(self, nin):
        super().__init__()

        squeezed = nin // 2
        self.conv1 = conv_bn_relu(nin, squeezed, kernel=1)  # e.g. 64 -> 32, 1x1
        self.conv2 = conv_bn_relu(squeezed, nin, kernel=3)  # e.g. 32 -> 64, 3x3

    def forward(self, x):
        """Return ``x`` plus the output of the two stacked convolutions."""
        branch = self.conv1(x)
        branch = self.conv2(branch)
        return x + branch

#### Map2cfgDict - used to create a mapping that follows the layer order of the cfg file from pjreddie's repository

In [105]:
def map2cfgDict(mlist):
    """Build an ordered mapping from cfg-style layer index to position in ``mlist``.

    Every module takes one cfg index that maps to its position in ``mlist``.
    A ``res_layer`` additionally takes two preceding cfg indices, which map
    to ``None``.

    Args:
        mlist: iterable of modules.

    Returns:
        ``OrderedDict`` of ``cfg_index -> index into mlist`` (or ``None``).
    """
    cfg_to_pos = OrderedDict()
    cfg_idx = 0
    for pos, module in enumerate(mlist):
        if isinstance(module, res_layer):
            # Two placeholder entries ahead of the res_layer's own entry.
            for _ in range(2):
                cfg_to_pos[cfg_idx] = None
                cfg_idx += 1
        cfg_to_pos[cfg_idx] = pos
        cfg_idx += 1

    return cfg_to_pos

#### UpsampleGroup

In [125]:
# see https://blog.csdn.net/leviopku/article/details/82660381, DBL + upsampling
class UpsampleGroup(nn.Module):
    """1x1 conv that halves the channels, 2x nearest-neighbour upsampling,
    then channel-wise concatenation with a second feature map."""

    def __init__(self, nin):
        super().__init__()
        self.conv = conv_bn_relu(nin, nin//2, kernel=1)
        self.up = nn.Upsample(scale_factor=2, mode="nearest")

    def forward(self, route_head, route_tail):
        """Upsample ``route_head`` and concatenate ``route_tail`` after it on dim 1."""
        reduced = self.conv(route_head)
        enlarged = self.up(reduced)
        return torch.cat((enlarged, route_tail), 1)

#### PreDetectionConvGroup - conv layers before the yolo detection layer

In [109]:
class PreDetectionConvGroup(nn.Module):
    """Stack of conv layers that feeds a YOLO detection layer.

    The stack is ``num_conv`` pairs of (1x1 conv to ``nout`` channels,
    3x3 conv to ``nout * 2`` channels) followed by a plain 1x1 ``nn.Conv2d``
    with ``(numClass + 5) * 3`` output channels.

    Outputs of selected layers can be kept during ``forward`` (see
    ``addCachedOut``) and read back afterwards (see ``getCachedOut``).

    Args:
        nin: input channels of the first conv.
        nout: output channels of each 1x1 conv in the pairs.
        num_conv: number of (1x1, 3x3) pairs.
        numClass: number of object classes.
    """

    def __init__(self, nin, nout, num_conv=3, numClass=80):
        super().__init__()

        self.mlist = nn.ModuleList()

        for i in range(num_conv):
            self.mlist += [conv_bn_relu(nin, nout, kernel=1)]
            self.mlist += [conv_bn_relu(nout, nout*2, kernel=3)]
            if i == 0:
                # After the first pair every following layer sees nout*2 channels.
                nin = nout*2

        self.mlist += [nn.Conv2d(nin, (numClass+5)*3, 1)]
        self.map2yolocfg = map2cfgDict(self.mlist)
        # index into mlist -> last output of that layer (None until forward runs)
        self.cachedOutDict = dict()

    def forward(self, x):
        """Run all layers in order, storing outputs of the registered layers."""
        for i,m in enumerate(self.mlist):
            x = m(x)
            if i in self.cachedOutDict:
                self.cachedOutDict[i] = x

        return x

    def _resolveIdx(self, idx, mode):
        """Translate a caller-supplied index into an index into ``mlist``."""
        if mode == "yolocfg":
            return self.getIdxFromYoloIdx(idx)
        if idx < 0:
            # Negative indices count from the end, like list indexing.
            return len(self.mlist) + idx
        return idx

    #mode - normal  -- direct index to mlist
    #     - yolocfg -- index follows the sequence of the cfg file from https://github.com/pjreddie/darknet/blob/master/cfg/yolov3.cfg
    def addCachedOut(self, idx, mode="yolocfg"):
        """Register layer ``idx`` so that its output is kept by ``forward``.

        Negative ``idx`` counts from the end in both modes.
        """
        self.cachedOutDict[self._resolveIdx(idx, mode)] = None

    def getCachedOut(self, idx, mode="yolocfg"):
        """Return the stored output of layer ``idx``.

        The value is ``None`` if ``forward`` has not run since registration.

        Raises:
            KeyError: if the layer was never registered with ``addCachedOut``.
        """
        return self.cachedOutDict[self._resolveIdx(idx, mode)]

    def getIdxFromYoloIdx(self,idx):
        """Map a cfg-style index (negative counts from the end) to an ``mlist`` index.

        Raises:
            KeyError: if the index is outside the cfg mapping.
        """
        if idx < 0:
            idx = len(self.map2yolocfg) + idx
        return self.map2yolocfg[idx]

#### Darknet53 - Feature extraction

In [100]:
def make_res_stack(nin, num_block):
    """Return a ModuleList holding one stride-2 3x3 conv that doubles the
    channels, followed by ``num_block`` residual layers at the doubled width."""
    width = nin * 2
    layers = [conv_bn_relu(nin, width, 3, stride=2)]
    for _ in range(num_block):
        layers.append(res_layer(width))
    return nn.ModuleList(layers)

class Darknet(nn.Module):
    """Feature extractor: one 3x3 conv followed by residual stacks.

    Each entry of ``blkList`` adds a stack built by ``make_res_stack``; the
    i-th stack receives ``nout * 2**i`` input channels.

    Outputs of selected layers can be kept during ``forward`` (see
    ``addCachedOut``) and read back afterwards (see ``getCachedOut``).

    Args:
        blkList: number of residual layers in each stack.
        nout: output channels of the first conv.
    """

    def __init__(self, blkList, nout=32):
        super().__init__()

        self.mlist = nn.ModuleList()
        self.mlist += [conv_bn_relu(3, nout, 3)]
        for i,nb in enumerate(blkList):
            self.mlist += make_res_stack(nout*(2**i), nb)

        self.map2yolocfg = map2cfgDict(self.mlist)
        # index into mlist -> last output of that layer (None until forward runs)
        self.cachedOutDict = dict()

    def forward(self, x):
        """Run all layers in order, storing outputs of the registered layers."""
        for i,m in enumerate(self.mlist):
            x = m(x)
            if i in self.cachedOutDict:
                self.cachedOutDict[i] = x
        return x

    def _resolveIdx(self, idx, mode):
        """Translate a caller-supplied index into an index into ``mlist``.

        ``mode == "yolocfg"`` looks the index up in ``map2yolocfg``; any other
        mode treats it as a direct index. Negative values count from the end.
        """
        if mode == "yolocfg":
            if idx < 0:
                idx = len(self.map2yolocfg) + idx
            return self.map2yolocfg[idx]
        if idx < 0:
            return len(self.mlist) + idx
        return idx

    def addCachedOut(self, idx, mode="yolocfg"):
        """Register layer ``idx`` so that its output is kept by ``forward``.

        Raises:
            KeyError: in "yolocfg" mode, if ``idx`` is not in the cfg mapping.
        """
        self.cachedOutDict[self._resolveIdx(idx, mode)] = None

    def getCachedOut(self, idx, mode="yolocfg"):
        """Return the stored output of layer ``idx``.

        The value is ``None`` if ``forward`` has not run since registration.

        Raises:
            KeyError: if the layer was never registered with ``addCachedOut``.
        """
        return self.cachedOutDict[self._resolveIdx(idx, mode)]

In [103]:
#net = Darknet([1,2,8,8,4])

In [107]:
# `net` has to exist before it is used; the only assignment in the notebook is
# commented out, so build the backbone here.
net = Darknet([1,2,8,8,4])
net.addCachedOut(61)  # keep the output of cfg layer 61 during forward

#### Yolo Detection Layer

In [69]:
class YoloLayer(nn.Module):
    """Decode a raw detection feature map into box predictions.

    Args:
        anchors: sequence of ``(width, height)`` anchor pairs,
            e.g. ``[(116, 90), (156, 198), (373, 326)]``.
        img_dim: default input image size, used when ``forward`` is called
            without one.
        nClass: number of object classes.
    """

    def __init__(self, anchors, img_dim, nClass):
        super().__init__()

        self.anchors = anchors # [(116, 90), (156, 198), (373, 326)]
        self.img_dim = img_dim
        self.nClass = nClass
        self.bbox_attrib = nClass + 5

    def forward(self, x, img_dim=None):
        """Turn ``x`` into ``[B, A*H*W, 5 + nClass]`` predictions.

        Args:
            x: tensor of shape ``[B, A * (5 + nClass), H, W]``.
            img_dim: input image size. Either a single number (the stride is
                then ``img_dim / H`` in both directions) or a ``(w, h)`` pair
                (strides ``w / W`` and ``h / H``). Defaults to the value given
                to the constructor.

        Returns:
            Tensor on the same device as ``x``. The last dimension holds
            ``bx, by, bw, bh`` in image pixels, then the sigmoid of the
            confidence and of each class score. The box values are computed
            from a detached copy of ``x``.

        Raises:
            ValueError: if no image size is available.
        """
        if img_dim is None:
            img_dim = self.img_dim
        if img_dim is None:
            raise ValueError("img_dim must be passed to forward() or to the constructor")

        nB = x.shape[0] # batch_size
        nA = len(self.anchors) # 3
        nH, nW = x.shape[2], x.shape[3]

        if isinstance(img_dim, (tuple, list)):
            img_w, img_h = img_dim
            stride_w = img_w / nW
            stride_h = img_h / nH
        else:
            stride_w = stride_h = img_dim / nH # 416/13=32

        # Reshape from [B x [A * (5 + numClass)] x H x W] to [B x A x H x W x (5 + numClass)]
        preds = x.view(nB, nA, self.bbox_attrib, nH, nW).permute(0, 1, 3, 4, 2).contiguous()

        # tx, ty, tw, th are decoded without tracking gradients.
        raw_box = preds[..., :4].detach()
        preds_conf = preds[..., 4].sigmoid()
        preds_cls = preds[..., 5:].sigmoid()

        # Cell offsets, created on x's device/dtype and broadcast over [B, A, H, W].
        cx = torch.arange(nW, dtype=x.dtype, device=x.device).view(1, 1, 1, nW)
        cy = torch.arange(nH, dtype=x.dtype, device=x.device).view(1, 1, nH, 1)
        anchors = torch.as_tensor(self.anchors, dtype=x.dtype, device=x.device)
        anchor_w = anchors[:, 0].view(1, nA, 1, 1)
        anchor_h = anchors[:, 1].view(1, nA, 1, 1)

        # bx = (sig(tx) + cx) * stride ; bw = exp(tw) * anchor.
        # Dividing the anchors by the stride and multiplying the result back
        # by the stride cancels out, so the anchors are used in pixels directly.
        bx = (raw_box[..., 0].sigmoid() + cx) * stride_w
        by = (raw_box[..., 1].sigmoid() + cy) * stride_h
        bw = raw_box[..., 2].exp() * anchor_w
        bh = raw_box[..., 3].exp() * anchor_h
        pred_boxes = torch.stack((bx, by, bw, bh), 4)

        out = torch.cat((pred_boxes,
                         preds_conf.unsqueeze(4),
                         preds_cls), 4)

        # Reshape from [B x A x H x W x (5 + numClass)] to [B x [H x W x A] x (5 + numClass)]
        # so that predictions at different strides can be concatenated on the same dimension.
        out = out.permute(0, 2, 3, 1, 4).contiguous().view(nB, nA*nH*nW, self.bbox_attrib)

        return out

#### Entire network - putting everything together

In [121]:
class YoloNet(nn.Module):
    """Complete detector: Darknet feature extractor plus three detection heads.

    Args:
        img_dim: image size handed to the ``YoloLayer`` constructors.
        anchors: flat sequence ``w0, h0, w1, h1, ...`` of anchor sizes. It is
            split into (w, h) pairs, grouped in threes, and the groups are
            used in reverse order (last group for the first head).
        numClass: number of object classes.
    """

    def __init__(self, img_dim=None, anchors=(10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326), numClass=80):
        super().__init__()

        nin = 32
        self.numClass = numClass
        self.img_dim = img_dim
        self.stat_keys = ['loss', 'loss_x', 'loss_y', 'loss_w', 'loss_h', 'loss_conf', 'loss_cls',
                          'nCorrect', 'nGT', 'recall']

        anchors = [(anchors[i], anchors[i+1]) for i in range(0,len(anchors),2)]
        anchors = [anchors[i:i+3] for i in range(0, len(anchors), 3)][::-1]
        # With the default anchors this gives:
        # [[(116, 90), (156, 198), (373, 326)],
        #  [(30, 61), (62, 45), (59, 119)],
        #  [(10, 13), (16, 30), (33, 23)]]

        self.feature = Darknet([1,2,8,8,4]) # darknet 53
        # Backbone outputs that are concatenated into the later heads.
        self.feature.addCachedOut(61)
        self.feature.addCachedOut(36)

        self.pre_det1 = PreDetectionConvGroup(1024, 512, numClass=self.numClass)
        self.yolo1 = YoloLayer(anchors[0], img_dim, self.numClass) # [(116, 90), (156, 198), (373, 326)]
        self.pre_det1.addCachedOut(-3) #Fetch output from 4th layer backward including yolo layer

        self.up1 = UpsampleGroup(512)
        self.pre_det2 = PreDetectionConvGroup(768, 256, numClass=self.numClass)
        self.yolo2 = YoloLayer(anchors[1], img_dim, self.numClass) # [(30, 61), (62, 45), (59, 119)]
        self.pre_det2.addCachedOut(-3)

        self.up2 = UpsampleGroup(256)
        self.pre_det3 = PreDetectionConvGroup(384, 128, numClass=self.numClass)
        self.yolo3 = YoloLayer(anchors[2], img_dim, self.numClass) # [(10, 13), (16, 30), (33, 23)]

    def forward(self, x):
        """Run the network on ``x`` and return the outputs of the three heads.

        Args:
            x: image batch; dimension 2 is used as the image size.

        Returns:
            Tuple ``(det1, det2, det3)`` of the three ``YoloLayer`` outputs.
        """
        # The detection layers divide this by their grid height to get the
        # stride, so it has to be a single number rather than a (w, h) pair.
        img_dim = x.shape[2]

        # extract features
        out = self.feature(x)

        # detection layer 1
        out = self.pre_det1(out)
        det1 = self.yolo1(out, img_dim)

        # upsample 1
        r_head1 = self.pre_det1.getCachedOut(-3)
        r_tail1 = self.feature.getCachedOut(61)
        out = self.up1(r_head1, r_tail1)

        # detection layer 2
        out = self.pre_det2(out)
        det2 = self.yolo2(out, img_dim)

        # upsample 2
        r_head2 = self.pre_det2.getCachedOut(-3)
        r_tail2 = self.feature.getCachedOut(36)
        out = self.up2(r_head2, r_tail2)

        # detection layer 3
        out = self.pre_det3(out)
        det3 = self.yolo3(out, img_dim)

        return det1, det2, det3

In [126]:
yolo = YoloNet()

In [127]:
yolo

YoloNet(
  (feature): Darknet(
    (mlist): ModuleList(
      (0): conv_bn_relu(
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): LeakyReLU(negative_slope=0.1, inplace)
      )
      (1): conv_bn_relu(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): LeakyReLU(negative_slope=0.1, inplace)
      )
      (2): res_layer(
        (conv1): conv_bn_relu(
          (conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): LeakyReLU(negative_slope=0.1, inplace)
        )
        (conv2): conv_bn_relu(
          (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1)