In [None]:
# Mount Google Drive at /content/drive so that paths under it are reachable
# from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [1]:
# Install the Kaggle API token: copy kaggle.json from the working directory
# into ~/.kaggle and restrict it to owner read/write (mode 600).
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [2]:
# Download the Pascal VOC dataset archive from Kaggle into the working directory
# (about 4.3 GB according to the recorded output).
!kaggle datasets download -d aladdinpersson/pascal-voc-dataset-used-in-yolov3-video

Downloading pascal-voc-dataset-used-in-yolov3-video.zip to /content
100% 4.30G/4.31G [03:40<00:00, 21.3MB/s]
100% 4.31G/4.31G [03:40<00:00, 21.0MB/s]


In [3]:
# Extract the downloaded archive quietly (-qq) into the current working directory.
!unzip  -qq /content/pascal-voc-dataset-used-in-yolov3-video.zip

In [None]:
# https://wikidocs.net/181720
# https://github.com/RichardMinsooGo-ML/TF2.0-Yolov3-Yolov4-image/blob/main/yolo_core/models.py

In [1]:
# YOLOv4 model

import torch
import torch.nn as nn
import math

class DarknetConv2D(nn.Module):
    """Convolution optionally followed by batch normalisation and an activation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size used when ``downsample`` is False. A
            down-sampling layer always uses a 3x3 kernel.
        stride: Accepted for call-site compatibility only; the effective stride
            is derived from ``downsample`` (2 when True, 1 otherwise).
        downsample: When True, the input is zero-padded by one pixel on the
            left and top and convolved with a 3x3, stride-2, unpadded kernel.
            Otherwise a stride-1 convolution with "same" padding is used.
        bn_act: When True, apply BatchNorm and the activation after the
            convolution (the convolution then has no bias). When False, return
            the biased convolution output unchanged.
        act: ``"mish"`` or ``"leaky"``; only consulted when ``bn_act`` is True.

    Raises:
        ValueError: If ``bn_act`` is True and ``act`` is not a supported name.
            Previously such a layer silently returned ``None`` from ``forward``.
    """

    _ACTIVATIONS = ("mish", "leaky")

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, downsample=False, bn_act=True, act="mish"):
        super().__init__()

        if bn_act and act not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {act!r}; expected one of {self._ACTIVATIONS}"
            )

        if downsample:
            kernel_size = 3
            stride = 2
            padding = "valid"
        else:
            stride = 1
            padding = "same"

        self.conv = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            bias=not bn_act
        )

        self.downsample = downsample
        # The BatchNorm and both activations are always registered, even when
        # unused, so the state_dict layout does not depend on the flags.
        self.bn = nn.BatchNorm2d(out_channels)
        self.mish = nn.Mish()
        self.leaky = nn.LeakyReLU(0.1)
        self.use_bn_act = bn_act
        self.act = act

    def forward(self, x):
        if self.downsample:
            # (left, right, top, bottom): pad one pixel on the left and top only.
            x = torch.nn.functional.pad(x, (1, 0, 1, 0))

        x = self.conv(x)
        if not self.use_bn_act:
            return x

        x = self.bn(x)
        if self.act == "mish":
            return self.mish(x)
        if self.act == "leaky":
            return self.leaky(x)
        # Only reachable if ``self.act`` was changed after construction.
        raise ValueError(
            f"Unsupported activation {self.act!r}; expected one of {self._ACTIVATIONS}"
        )



class CSPResBlock(nn.Module):
    """Cross-stage-partial residual stage built from ``DarknetConv2D`` layers.

    The input is reduced to ``in_channels // 2`` channels by a 1x1 convolution.
    That tensor is kept as the bypass route and is also sent through
    ``num_repeats`` residual units (1x1 conv then 3x3 conv, added back to their
    input). The residual branch passes through another 1x1 convolution, is
    concatenated with the bypass route, and is fused by a final 1x1 convolution
    back to ``in_channels`` channels.

    Args:
        in_channels: Number of input (and output) channels.
        num_repeats: Number of residual units in the dense branch.
    """

    def __init__(self, in_channels, num_repeats=1):
        super().__init__()

        half = in_channels // 2
        self.split1x1 = DarknetConv2D(in_channels=in_channels, out_channels=half, kernel_size=1)
        self.res1x1 = DarknetConv2D(in_channels=half, out_channels=half, kernel_size=1)
        self.concat1x1 = DarknetConv2D(in_channels=in_channels, out_channels=in_channels, kernel_size=1)
        self.num_repeats = num_repeats

        self.DenseBlock = nn.ModuleList()
        for _ in range(num_repeats):
            DenseLayer = nn.ModuleList()
            DenseLayer.append(DarknetConv2D(in_channels=half, out_channels=half, kernel_size=1))
            DenseLayer.append(DarknetConv2D(in_channels=half, out_channels=half, kernel_size=3))
            self.DenseBlock.append(DenseLayer)

    def forward(self, x):
        # The bypass route and the residual branch both start from the same
        # split convolution, so run it once and share the result. Calling it
        # twice produced identical tensors while doubling the compute and
        # updating the BatchNorm running statistics twice per step.
        # NOTE(review): the reference CSP design uses two separate 1x1 convs for
        # the two branches; adding one would change the checkpoint layout.
        x = self.split1x1(x)
        route = x

        for module in self.DenseBlock:
            h = x
            for res in module:
                h = res(h)
            # Out-of-place add: ``route`` must keep the un-modified split output.
            x = x + h

        x = self.res1x1(x)
        x = torch.cat([x, route], dim=1)
        x = self.concat1x1(x)

        return x



class SPP(nn.Module):
    """Spatial pyramid pooling: stack the input with three max-pooled copies.

    The pools use kernels 5, 9 and 13 with stride 1 and padding ``kernel // 2``,
    so every branch keeps the input's spatial size and the output has four
    times the input channels.
    """

    def __init__(self):
        super().__init__()

        self.maxpool5 = nn.MaxPool2d(5, stride=1, padding=2)
        self.maxpool9 = nn.MaxPool2d(9, stride=1, padding=4)
        self.maxpool13 = nn.MaxPool2d(13, stride=1, padding=6)

    def forward(self, x):
        # Channel order: identity, 5x5 pool, 9x9 pool, 13x13 pool.
        pools = (self.maxpool5, self.maxpool9, self.maxpool13)
        branches = [x]
        branches.extend(pool(x) for pool in pools)
        return torch.cat(branches, dim=1)



class ScalePrediction(nn.Module):
    """Detection head for one scale.

    A 3x3 convolution doubles the channel count, then a plain 1x1 convolution
    (bias, no BatchNorm/activation) emits ``3 * (num_classes + 5)`` channels.

    Args:
        in_channels: Channels of the incoming feature map.
        num_classes: Number of object classes.

    Returns (from ``forward``):
        Tensor of shape ``[B, 3, N, M, num_classes + 5]`` where ``N`` and ``M``
        are the input's height and width.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        widened = in_channels * 2
        self.conv = DarknetConv2D(in_channels=in_channels, out_channels=widened, kernel_size=3, act="leaky")
        self.ScalePred = DarknetConv2D(
            in_channels=widened,
            out_channels=3 * (num_classes + 5),
            kernel_size=1,
            bn_act=False,
            act="leaky",
        )
        self.num_classes = num_classes

    def forward(self, x):
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]

        # raw: [B, 3 * (num_classes + 5), N, M]
        raw = self.ScalePred(self.conv(x))
        # Split the channel axis into (3, num_classes + 5) ...
        per_anchor = raw.reshape(batch, 3, self.num_classes + 5, height, width)
        # ... and move the per-box values to the last axis.
        return per_anchor.permute(0, 1, 3, 4, 2)



class CSPDarknet53(nn.Module):
    """Backbone: stem conv, five down-sample + CSP stages, then a conv/SPP tail.

    Args:
        in_channels: Number of channels of the input image tensor.

    Returns (from ``forward``):
        A tuple ``(route_1, route_2, output)``: the outputs of the two
        8-repeat CSP stages (256 and 512 channels) followed by the output of
        the last layer (512 channels).
    """

    def __init__(self, in_channels):
        super().__init__()

        self.in_channels = in_channels

        stem = [DarknetConv2D(in_channels=in_channels, out_channels=32, kernel_size=3)]

        # (stage width, residual repeats); the two 8-repeat stages are the
        # ones tapped as Route_1 and Route_2 in ``forward``.
        stages = []
        width_in = 32
        for width_out, repeats in ((64, 1), (128, 2), (256, 8), (512, 8), (1024, 4)):
            stages.append(
                DarknetConv2D(in_channels=width_in, out_channels=width_out, kernel_size=3, downsample=True)
            )
            stages.append(CSPResBlock(in_channels=width_out, num_repeats=repeats))
            width_in = width_out

        tail = [
            DarknetConv2D(in_channels=1024, out_channels=512, kernel_size=1, act="leaky"),
            DarknetConv2D(in_channels=512, out_channels=1024, kernel_size=3, act="leaky"),
            DarknetConv2D(in_channels=1024, out_channels=512, kernel_size=1, act="leaky"),
            SPP(),  # 512 -> 2048 channels
            DarknetConv2D(in_channels=2048, out_channels=512, kernel_size=1, act="leaky"),
            DarknetConv2D(in_channels=512, out_channels=1024, kernel_size=3, act="leaky"),
            DarknetConv2D(in_channels=1024, out_channels=512, kernel_size=1, act="leaky"),  # output
        ]

        self.layers = nn.ModuleList(stem + stages + tail)

    def forward(self, x):
        taps = []

        for layer in self.layers:
            x = layer(x)

            # Routes are identified by being an 8-repeat CSP stage.
            is_route_stage = isinstance(layer, CSPResBlock) and layer.num_repeats == 8
            if is_route_stage:
                taps.append(x)

        # The final feature map is always the last element.
        taps.append(x)

        return tuple(taps)



class Conv5(nn.Module):
    """Five-step 1x1 / 3x3 / 1x1 / 3x3 / 1x1 convolution sequence.

    Only two convolution modules are created; ``forward`` applies the same 1x1
    module three times and the same 3x3 module twice, so their weights and
    BatchNorm statistics are shared across those applications. The output has
    ``in_channels // 2`` channels.

    Args:
        in_channels: Channels of the incoming tensor.
        up: Flag stored on the module; not used by ``forward`` itself (callers
            may inspect it).
    """

    def __init__(self, in_channels, up=True):
        super().__init__()

        squeezed = in_channels // 2
        self.in_channels = in_channels
        self.up = up
        self.conv1x1 = DarknetConv2D(in_channels=in_channels, out_channels=squeezed, kernel_size=1, act="leaky")
        self.conv3x3 = DarknetConv2D(in_channels=squeezed, out_channels=in_channels, kernel_size=3, act="leaky")

    def forward(self, x):
        sequence = (self.conv1x1, self.conv3x3, self.conv1x1, self.conv3x3, self.conv1x1)
        for step in sequence:
            x = step(x)
        return x



class YOLOv4(nn.Module):
    """YOLOv4-style detector: CSPDarknet53 backbone, up/down neck, three heads.

    Args:
        in_channels: Channels of the input image tensor.
        num_classes: Number of object classes predicted by each head.

    Returns (from ``forward``):
        A list of three tensors ordered from the coarsest grid to the finest,
        each shaped ``[B, 3, N, N, num_classes + 5]``. For a 416x416 input the
        grids are 13x13, 26x26 and 52x52.
    """

    def __init__(self, in_channels=3, num_classes=80):
        super().__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.CSPDarknet53 = CSPDarknet53(in_channels)
        # 1x1 reductions applied to the two backbone routes before concatenation.
        self.route2conv = DarknetConv2D(in_channels=512, out_channels=256, kernel_size=1, act="leaky")
        self.route1conv = DarknetConv2D(in_channels=256, out_channels=128, kernel_size=1, act="leaky")

        neck_and_heads = [
            # top-down path
            DarknetConv2D(in_channels=512, out_channels=256, kernel_size=1, act="leaky"),
            nn.Upsample(scale_factor=2, mode='nearest'),
            Conv5(in_channels=512, up=True),  # after concat
            DarknetConv2D(in_channels=256, out_channels=128, kernel_size=1, act="leaky"),
            nn.Upsample(scale_factor=2, mode='nearest'),
            Conv5(in_channels=256, up=True),  # after concat
            ScalePrediction(in_channels=128, num_classes=num_classes),  # sbbox 52x52
            # bottom-up path
            DarknetConv2D(in_channels=128, out_channels=256, kernel_size=3, downsample=True, act="leaky"),
            Conv5(in_channels=512, up=False),
            ScalePrediction(in_channels=256, num_classes=num_classes),  # mbbox 26x26
            DarknetConv2D(in_channels=256, out_channels=512, kernel_size=3, downsample=True, act="leaky"),
            Conv5(in_channels=1024, up=False),
            ScalePrediction(in_channels=512, num_classes=num_classes),  # lbbox 13x13
        ]
        self.layers = nn.ModuleList(neck_and_heads)

    def forward(self, x):
        shallow, middle, deepest = self.CSPDarknet53(x)

        # Stack of backbone routes consumed (last first) after each upsample.
        middle = self.route2conv(middle)
        shallow = self.route1conv(shallow)
        lateral = [shallow, middle]

        # Stack of top-down features consumed (last first) after each downsample.
        skips = [deepest]
        predictions = []

        x = deepest
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                # A head only reads ``x``; the running feature map is untouched.
                predictions.append(layer(x))
                continue

            x = layer(x)

            if isinstance(layer, nn.Upsample):
                x = torch.cat([lateral.pop(), x], dim=1)

            if isinstance(layer, Conv5) and layer.in_channels == 512 and layer.up:
                # Remember the mid-scale top-down feature for the bottom-up path.
                skips.append(x)

            if isinstance(layer, DarknetConv2D) and layer.downsample:
                x = torch.cat([x, skips.pop()], dim=1)

        # Heads fire finest-first; callers expect coarsest-first.
        predictions[0], predictions[2] = predictions[2], predictions[0]
        return predictions

In [3]:
# model test
# import config
# Smoke test: push a random batch of two images through YOLOv4 and check that
# the three outputs have shape (batch, 3, S, S, num_classes + 5) with
# S = IMAGE_SIZE // 32, // 16 and // 8, in that order.
num_classes = 20
IMAGE_SIZE = 416
model = YOLOv4(num_classes=num_classes)

x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
out = model(x)

assert out[0].shape == (2, 3, IMAGE_SIZE//32, IMAGE_SIZE//32, num_classes + 5)
assert out[1].shape == (2, 3, IMAGE_SIZE//16, IMAGE_SIZE//16, num_classes + 5)
assert out[2].shape == (2, 3, IMAGE_SIZE//8, IMAGE_SIZE//8, num_classes + 5)
print("Success!")

Success!


In [2]:
import config
import torch
import torch.optim as optim

# from yolov4 import YOLOv4
from tqdm import tqdm
from utils import (
    mean_average_precision,
    non_max_suppression,
    cells_to_bboxes,
    get_evaluation_bboxes,
    save_checkpoint,
    load_checkpoint,
    check_class_accuracy,
    get_loaders,
    plot_couple_examples,
    plot_image
)
from loss import YoloLoss
import warnings
warnings.filterwarnings("ignore")

torch.backends.cudnn.benchmark = True


# Echo the anchor table and learning rate read from config so the run's
# settings are visible in the cell output.
print(config.ANCHORS)
print(config.LEARNING_RATE)


def train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors, box_loss="MSE"):
    """Run one pass over ``train_loader`` with mixed-precision training.

    Args:
        train_loader: Iterable yielding ``(images, targets)`` where ``targets``
            is indexable with one entry per prediction scale (indices 0..2).
        model: Network returning three outputs, one per scale.
        optimizer: Optimizer stepped through ``scaler``.
        loss_fn: Callable ``loss_fn(pred, target, anchors, box_loss=...)``.
        scaler: Gradient scaler providing ``scale``/``step``/``update``.
        scaled_anchors: Indexable with the anchors for each of the three scales.
        box_loss: Forwarded unchanged to ``loss_fn``.

    The running mean of the batch losses is shown on the progress bar; nothing
    is returned.
    """
    progress = tqdm(train_loader, leave=True)
    batch_losses = []

    for images, targets in progress:
        images = images.to(config.DEVICE)
        target_l = targets[0].to(config.DEVICE)
        target_m = targets[1].to(config.DEVICE)
        target_s = targets[2].to(config.DEVICE)

        with torch.cuda.amp.autocast():
            preds = model(images)
            loss = loss_fn(preds[0], target_l, scaled_anchors[0], box_loss=box_loss)
            loss = loss + loss_fn(preds[1], target_m, scaled_anchors[1], box_loss=box_loss)
            loss = loss + loss_fn(preds[2], target_s, scaled_anchors[2], box_loss=box_loss)

        batch_losses.append(loss.item())

        # Backward pass and optimizer step go through the scaler.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update progress bar
        progress.set_postfix(loss=sum(batch_losses) / len(batch_losses))



def main():
    """Train YOLOv4 for 200 epochs, saving a checkpoint per epoch when enabled.

    Every tenth epoch the test loader is evaluated and the mean average
    precision is printed, after which the model is put back in training mode.
    """
    model = YOLOv4(num_classes=config.NUM_CLASSES).to(config.DEVICE)
    # model = YOLOv4(backbone=CSPDarknet53(pretrained=True)).to(config.DEVICE)
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY,
    )
    loss_fn = YoloLoss()
    scaler = torch.cuda.amp.GradScaler()

    train_csv = config.DATASET + "/train.csv"
    test_csv = config.DATASET + "/test.csv"
    train_loader, test_loader, train_eval_loader = get_loaders(
        train_csv_path=train_csv, test_csv_path=test_csv
    )

    # Resuming is hard-wired off here; config.LOAD_MODEL is not consulted.
    resume_from_checkpoint = False
    if resume_from_checkpoint:
        load_checkpoint(
            config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE
        )

    # Multiply each scale's anchors by that scale's grid size from config.S.
    grid_sizes = torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    scaled_anchors = (torch.tensor(config.ANCHORS) * grid_sizes).to(config.DEVICE)

    checkpoint_path = "/content/drive/MyDrive/yolov3/checkpoint.pth.tar"
    num_epochs = 200
    eval_every = 10

    for epoch in range(num_epochs):
        train_fn(train_loader, model, optimizer, loss_fn, scaler, scaled_anchors, box_loss="MSE")

        if config.SAVE_MODEL:
            save_checkpoint(model, optimizer, filename=checkpoint_path)

        # Evaluate after epochs 10, 20, 30, ... (1-based).
        if (epoch + 1) % eval_every != 0:
            continue

        pred_boxes, true_boxes = get_evaluation_bboxes(
            test_loader,
            model,
            iou_threshold=0.45,
            anchors=config.ANCHORS,
            threshold=0.5,
            iou_mode="IoU",
        )
        mapval = mean_average_precision(
            pred_boxes,
            true_boxes,
            iou_threshold=config.MAP_IOU_THRESH,
            box_format="midpoint",
            num_classes=config.NUM_CLASSES,
        )
        print(f"MAP: {mapval.item()}")
        model.train()


# Start training when this cell is executed directly.
if __name__ == "__main__":
    main()



[[(0.3, 0.23), (0.41, 0.52), (0.99, 0.87)], [(0.07, 0.16), (0.16, 0.11), (0.15, 0.31)], [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)]]
0.001


100%|██████████| 518/518 [06:20<00:00,  1.36it/s, loss=16.1]


=> Saving checkpoint


100%|██████████| 518/518 [06:12<00:00,  1.39it/s, loss=12.9]


=> Saving checkpoint


100%|██████████| 518/518 [06:12<00:00,  1.39it/s, loss=12]


=> Saving checkpoint


100%|██████████| 518/518 [06:13<00:00,  1.39it/s, loss=11.3]


=> Saving checkpoint


100%|██████████| 518/518 [06:14<00:00,  1.38it/s, loss=10.8]


=> Saving checkpoint


100%|██████████| 518/518 [06:15<00:00,  1.38it/s, loss=10.3]


=> Saving checkpoint


100%|██████████| 518/518 [06:15<00:00,  1.38it/s, loss=10]


=> Saving checkpoint


100%|██████████| 518/518 [06:16<00:00,  1.38it/s, loss=9.73]


=> Saving checkpoint


100%|██████████| 518/518 [06:16<00:00,  1.38it/s, loss=9.43]


=> Saving checkpoint


100%|██████████| 518/518 [06:16<00:00,  1.38it/s, loss=9.19]


=> Saving checkpoint


100%|██████████| 155/155 [04:22<00:00,  1.70s/it]


MAP: 0.014559321105480194


  3%|▎         | 13/518 [00:11<07:09,  1.18it/s, loss=9.85]


KeyboardInterrupt: ignored

In [None]:
# Rebuild the model, optimizer, loss, scaler and data loaders at notebook level
# (the same calls main() makes), restore a checkpoint, and plot a few
# predictions from the test loader.
model = YOLOv4(num_classes=config.NUM_CLASSES).to(config.DEVICE)
optimizer = optim.Adam(
    model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY
)
loss_fn = YoloLoss()
scaler = torch.cuda.amp.GradScaler()

train_loader, test_loader, train_eval_loader = get_loaders(
    train_csv_path=config.DATASET + "/train.csv", test_csv_path=config.DATASET + "/test.csv"
)

# Loading is forced on here regardless of config.LOAD_MODEL.
# NOTE(review): main() saves to "/content/drive/MyDrive/yolov3/checkpoint.pth.tar"
# while this loads config.CHECKPOINT_FILE; confirm both name the same file.
if True: # config.LOAD_MODEL
    load_checkpoint(
        config.CHECKPOINT_FILE, model, optimizer, config.LEARNING_RATE
    )

# Multiply each scale's anchors by that scale's grid size from config.S.
scaled_anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
).to(config.DEVICE)

# NOTE(review): 0.2 and 0.45 are presumably the confidence and NMS IoU
# thresholds; confirm against utils.plot_couple_examples.
plot_couple_examples(model, test_loader, 0.2, 0.45, scaled_anchors , iou_mode="IoU")

In [None]:
import torch
import torch.nn.functional as F

# Create example input data (batch size, channels, height, width)
batch_size, in_channels, height, width = 1, 3, 32, 32
input_data = torch.randn(batch_size, in_channels, height, width)

# Create a down-sampling DarknetConv2D layer and a CSP residual block
darknet_conv = DarknetConv2D(in_channels=3, out_channels=64, downsample=True, bn_act=True, act="mish")
res = CSPResBlock(in_channels=64, num_repeats=1)

# Run both layers and print the resulting shapes
output_data1 = darknet_conv(input_data)
output_data2 = res(output_data1)
print("Input shape:", input_data.shape)
print("Output1 shape:", output_data1.shape)
print("Output2 shape:", output_data2.shape)


# SPP check: the channel count should become four times the input's
batch_size, in_channels, height, width = 1, 512, 32, 32
input_data = torch.randn(batch_size, in_channels, height, width)
spp = SPP()
output_data = spp(input_data)

print("SPP Input shape:", input_data.shape)
print("SPP Output shape:", output_data.shape)


# ScalePrediction check with 80 classes
# NOTE(review): the recorded output for this cell shows a 4-D result, whereas
# ScalePrediction.forward in this notebook returns a 5-D tensor; the recorded
# output looks stale, so re-run to confirm.
batch_size, in_channels, height, width = 1, 512, 52, 52
input_data = torch.randn(batch_size, in_channels, height, width)
sp = ScalePrediction(in_channels=512, num_classes=80)
output_data = sp(input_data)

print("ScalePrediction Input shape:", input_data.shape)
print("ScalePrediction Output shape:", output_data.shape)


Input shape: torch.Size([1, 3, 32, 32])
Output1 shape: torch.Size([1, 64, 16, 16])
Output2 shape: torch.Size([1, 64, 16, 16])
SPP Input shape: torch.Size([1, 512, 32, 32])
SPP Output shape: torch.Size([1, 2048, 32, 32])
ScalePrediction Input shape: torch.Size([1, 512, 52, 52])
ScalePrediction Output shape: torch.Size([1, 255, 52, 52])


In [None]:
# Shape check for the backbone on a 416x416 batch of two images.
# This needs the CSPDarknet53 definition that takes ``in_channels``; the later
# redefinition in this notebook takes ``num_classes``/``pretrained`` instead.
batch_size, in_channels, height, width = 2, 3, 416, 416
input_data = torch.randn(batch_size, in_channels, height, width)
darknet = CSPDarknet53(in_channels=3)
# The backbone returns (route_1, route_2, final output).
route1, route2, ouput = darknet(input_data)

print("CSPDarknet53 Route1:", route1.shape)
print("CSPDarknet53 Route2:", route2.shape)
print("CSPDarknet53 Output:", ouput.shape)

CSPDarknet53 Route1: torch.Size([2, 256, 52, 52])
CSPDarknet53 Route2: torch.Size([2, 512, 26, 26])
CSPDarknet53 Output: torch.Size([2, 512, 13, 13])


In [3]:
# Flat list of anchor dimensions.
# NOTE(review): presumably (width, height) pairs in pixels; confirm the source.
anchor_values = [12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401]

# Compute the maximum value
max_value = max(anchor_values)

# Normalize by dividing every value by the maximum (the largest becomes 1.0)
normalized_anchors = [value / max_value for value in anchor_values]

print(normalized_anchors)

[0.026143790849673203, 0.034858387799564274, 0.04139433551198257, 0.0784313725490196, 0.08714596949891068, 0.06100217864923747, 0.0784313725490196, 0.16339869281045752, 0.1655773420479303, 0.11982570806100218, 0.1568627450980392, 0.31808278867102396, 0.3093681917211329, 0.23965141612200436, 0.41830065359477125, 0.5294117647058824, 1.0, 0.8736383442265795]


In [4]:
# model test
# import config
# Smoke test for the YOLOv4 variant that takes a ``backbone`` module; it needs
# the later CSPDarknet53/YOLOv4 definitions of this notebook to be in scope.
# Note: CSPDarknet53.__init__ accepts ``pretrained`` but never reads it, so no
# weights are loaded by this call.
num_classes = 20
IMAGE_SIZE = 416
model = YOLOv4(backbone=CSPDarknet53(pretrained=True))

x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
out = model(x)

# Expected shapes: (batch, 3, S, S, num_classes + 5) for S = 13, 26, 52.
assert out[0].shape == (2, 3, IMAGE_SIZE//32, IMAGE_SIZE//32, num_classes + 5)
assert out[1].shape == (2, 3, IMAGE_SIZE//16, IMAGE_SIZE//16, num_classes + 5)
assert out[2].shape == (2, 3, IMAGE_SIZE//8, IMAGE_SIZE//8, num_classes + 5)
print("Success!")

num_params :  63012949
Success!


In [2]:
import torch
# import wget
import os
import torch.nn as nn

import numpy as np
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Residual unit: 1x1 conv to ``hidden_channel``, 3x3 conv back, plus skip.

    Each convolution is bias-free and followed by BatchNorm and Mish.

    Args:
        in_channel: Channels of the input (and output).
        hidden_channel: Channels between the two convolutions; defaults to
            ``in_channel`` when omitted.
    """

    def __init__(self, in_channel, hidden_channel=None):
        super().__init__()
        width = in_channel if hidden_channel is None else hidden_channel

        layers = [
            nn.Conv2d(in_channel, width, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(width),
            Mish(),
            nn.Conv2d(width, in_channel, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channel),
            Mish(),
        ]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        out = self.features(x)
        # In-place skip connection onto the branch output.
        out += x
        return out


class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``, applied element-wise."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.tanh(F.softplus(x))
        return x * gate


class CSPBlock(nn.Module):
    """Cross-stage-partial block with two 1x1 branches and a residual stack.

    The input goes through two separate 1x1 conv/BN/Mish branches. The second
    branch passes through ``num_blocks`` ``ResidualBlock`` units, receives an
    extra skip from its own input, and a 1x1 transition. Both branches are then
    concatenated and fused back to ``in_channel`` channels.

    Args:
        in_channel: Channels of the input (and output).
        is_first: When True the branches keep ``in_channel`` channels and the
            residual units use ``in_channel // 2`` hidden channels. Otherwise
            the branches use ``in_channel // 2`` channels throughout.
        num_blocks: Number of residual units in the second branch.
    """

    @staticmethod
    def _conv_bn_mish(c_in, c_out):
        """Bias-free 1x1 convolution followed by BatchNorm and Mish."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(c_out),
            Mish(),
        )

    def __init__(self, in_channel, is_first=False, num_blocks=1):
        super().__init__()

        # Decide the widths up front and build every sub-module exactly once.
        # The previous code built the default variant and then rebuilt and
        # overwrote all five sub-modules when ``is_first`` was set.
        if is_first:
            branch = in_channel
            res_hidden = in_channel // 2
        else:
            branch = in_channel // 2
            res_hidden = None  # ResidualBlock then uses its input width

        # Attribute order is kept so module traversal order and state_dict
        # keys are unchanged.
        self.part1_conv = self._conv_bn_mish(in_channel, branch)
        self.part2_conv = self._conv_bn_mish(in_channel, branch)
        self.features = nn.Sequential(
            *[ResidualBlock(in_channel=branch, hidden_channel=res_hidden) for _ in range(num_blocks)]
        )
        self.transition1_conv = self._conv_bn_mish(branch, branch)
        # The fusion conv consumes the concatenation of both branches.
        self.transition2_conv = self._conv_bn_mish(2 * branch, in_channel)

    def forward(self, x):
        # split features
        part1 = self.part1_conv(x)
        part2 = self.part2_conv(x)

        # residual
        residual = part2
        part2 = self.features(part2)
        part2 += residual
        part2 = self.transition1_conv(part2)

        x = self.transition2_conv(torch.cat([part1, part2], dim=1))
        return x


class CSPDarknet53(nn.Module):
    """CSPDarknet53 classifier whose three feature stages can be reused as a backbone.

    ``features1``, ``features2`` and ``features3`` end at overall strides 8, 16
    and 32 with 256, 512 and 1024 channels. ``forward`` runs all three, then
    global average pooling and a linear classifier.

    Args:
        num_classes: Output size of the final linear layer.
        pretrained: Accepted for call-site compatibility but not used by this
            class; call ``load_darknet_weights`` explicitly to load weights.
    """

    def __init__(self, num_classes=20, pretrained=False):
        super().__init__()

        self.num_classes = num_classes
        self.features1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            Mish(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            Mish(),
            CSPBlock(64, is_first=True, num_blocks=1),
            nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            Mish(),
            CSPBlock(128, num_blocks=2),
            nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            Mish(),
            CSPBlock(256, num_blocks=8),
        )

        self.features2 = nn.Sequential(
            nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            Mish(),
            CSPBlock(512, num_blocks=8),
        )

        self.features3 = nn.Sequential(
            nn.Conv2d(512, 1024, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(1024),
            Mish(),
            CSPBlock(1024, num_blocks=4),
        )

        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(1024, self.num_classes)
        self.init_layer()

    def count_parameters(self):
        """Return the total number of parameter elements in the module."""
        return sum(p.numel() for p in self.parameters())

    def init_layer(self):
        """Apply Xavier-uniform initialisation to every convolution weight.

        ``modules()`` walks the whole module tree. The previous ``children()``
        loop only saw the top-level ``Sequential``/pool/``Linear`` containers,
        so it never reached any ``nn.Conv2d``.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        x = self.features1(x)
        x = self.features2(x)
        x = self.features3(x)
        x = self.gap(x)
        # Flatten the pooled 1024-channel map to (batch, 1024) for the classifier.
        x = x.view(-1, 1024)
        x = self.fc(x)
        return x

    def load_darknet_weights(self, weights_path):
        """Parses and loads the weights stored in 'weights_path'.

        The file is read as five int32 header values followed by float32
        weights. For every BatchNorm layer, in module traversal order, the
        weights are consumed as: BN bias, BN weight, running mean, running
        variance, then the weights of the most recently seen convolution.
        Layers without a BatchNorm (for example ``fc``) are not loaded.
        """

        # Open the weights file
        with open(weights_path, "rb") as f:
            header = np.fromfile(f, dtype=np.int32, count=5)  # First five are header values
            self.header_info = header  # Needed to write header when saving weights
            self.seen = header[3]  # number of images seen during training
            weights = np.fromfile(f, dtype=np.float32)  # The rest are weights

        ptr = 0  # read offset into ``weights``
        conv_layer = None
        # refer to https://github.com/eriklindernoren/PyTorch-YOLOv3/blob/master/models.py
        for i, module in enumerate(self.modules()):
            if isinstance(module, nn.Conv2d):
                # Remember the convolution; its weights are read after its BatchNorm.
                conv_layer = module
            if isinstance(module, nn.BatchNorm2d):
                bn_layer = module
                num_b = bn_layer.bias.numel()  # Number of biases

                # Bias
                bn_b = torch.from_numpy(weights[ptr: ptr + num_b]).view_as(bn_layer.bias)
                bn_layer.bias.data.copy_(bn_b)
                ptr += num_b

                # Weight
                bn_w = torch.from_numpy(weights[ptr: ptr + num_b]).view_as(bn_layer.weight)
                bn_layer.weight.data.copy_(bn_w)
                ptr += num_b

                # Running Mean
                bn_rm = torch.from_numpy(weights[ptr: ptr + num_b]).view_as(bn_layer.running_mean)
                bn_layer.running_mean.data.copy_(bn_rm)
                ptr += num_b

                # Running Var
                bn_rv = torch.from_numpy(weights[ptr: ptr + num_b]).view_as(bn_layer.running_var)
                bn_layer.running_var.data.copy_(bn_rv)
                ptr += num_b
            else:
                continue
            # Load conv. weights
            num_w = conv_layer.weight.numel()
            conv_w = torch.from_numpy(weights[ptr: ptr + num_w]).view_as(conv_layer.weight)
            conv_layer.weight.data.copy_(conv_w)
            ptr += num_w


# https://github.com/ultralytics/yolov5/blob/850970e081687df6427898948a27df37ab4de5d3/models/common.py#L139
class SPPNet(nn.Module):
    """Spatial pyramid pooling neck for a 1024-channel feature map.

    A 1x1 conv reduces the input to 512 channels. That map is concatenated
    with its 5x5, 9x9 and 13x13 stride-1 max-pooled versions (2048 channels)
    and passed through a second 1x1 conv that keeps 2048 channels.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            Mish(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(2048),
            Mish(),
        )

        # padding = kernel // 2 keeps the spatial size at stride 1.
        self.maxpool5 = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
        self.maxpool9 = nn.MaxPool2d(kernel_size=9, stride=1, padding=4)
        self.maxpool13 = nn.MaxPool2d(kernel_size=13, stride=1, padding=6)

    def forward(self, x):
        reduced = self.conv1(x)
        pyramid = [reduced]
        for pool in (self.maxpool5, self.maxpool9, self.maxpool13):
            pyramid.append(pool(reduced))
        return self.conv2(torch.cat(pyramid, dim=1))


class PANet(nn.Module):
    """Path-aggregation neck: a top-down pass followed by a bottom-up pass.

    ``forward(P5, P4, P3)`` takes the 2048-, 512- and 256-channel feature maps
    (coarsest first) and returns ``[U5, U4, U3]`` with 512, 256 and 128
    channels, each at the spatial size of the corresponding input.
    Every convolution is bias-free and followed by BatchNorm and Mish.
    """

    @staticmethod
    def _unit(c_in, c_out, kernel, stride=1):
        """Return [conv, BatchNorm, Mish] with padding ``kernel // 2``."""
        return [
            nn.Conv2d(c_in, c_out, kernel, stride=stride, padding=kernel // 2, bias=False),
            nn.BatchNorm2d(c_out),
            Mish(),
        ]

    @classmethod
    def _five(cls, c_in, c_mid):
        """1x1 / 3x3 / 1x1 / 3x3 / 1x1 stack ending with ``c_mid`` channels."""
        c_wide = 2 * c_mid
        return nn.Sequential(
            *cls._unit(c_in, c_mid, 1),
            *cls._unit(c_mid, c_wide, 3),
            *cls._unit(c_wide, c_mid, 1),
            *cls._unit(c_mid, c_wide, 3),
            *cls._unit(c_wide, c_mid, 1),
        )

    def __init__(self):
        super(PANet, self).__init__()

        # Reduce the SPP output: 2048 -> 512 -> 1024 -> 512.
        self.p52d5 = nn.Sequential(
            *self._unit(2048, 512, 1),
            *self._unit(512, 1024, 3),
            *self._unit(1024, 512, 1),
        )

        # Lateral 1x1 reductions of the backbone maps.
        self.p42p4_ = nn.Sequential(*self._unit(512, 256, 1))
        self.p32p3_ = nn.Sequential(*self._unit(256, 128, 1))

        # Fusion stacks used on the top-down path.
        self.d5_p4_2d4 = self._five(512, 256)
        self.d4_p3_2d3 = self._five(256, 128)

        # 1x1 reduction + 2x upsample feeding the next finer level.
        self.d52d5_ = nn.Sequential(*self._unit(512, 256, 1), nn.Upsample(scale_factor=2))
        self.d42d4_ = nn.Sequential(*self._unit(256, 128, 1), nn.Upsample(scale_factor=2))

        # Stride-2 3x3 convs feeding the next coarser level.
        self.u32u3_ = nn.Sequential(*self._unit(128, 256, 3, stride=2))
        self.u42u4_ = nn.Sequential(*self._unit(256, 512, 3, stride=2))

        # Fusion stacks used on the bottom-up path.
        self.d4u3_2u4 = self._five(512, 256)
        self.d5u4_2u5 = self._five(1024, 512)

    def forward(self, P5, P4, P3):
        # Shape notes are for a 416x416 network input.
        # Top-down path.
        top5 = self.p52d5(P5)                                    # [B, 512, 13, 13]
        top5_up = self.d52d5_(top5)                              # [B, 256, 26, 26]
        lat4 = self.p42p4_(P4)                                   # [B, 256, 26, 26]
        top4 = self.d5_p4_2d4(torch.cat([top5_up, lat4], dim=1))  # [B, 256, 26, 26]
        top4_up = self.d42d4_(top4)                              # [B, 128, 52, 52]
        lat3 = self.p32p3_(P3)                                   # [B, 128, 52, 52]
        top3 = self.d4_p3_2d3(torch.cat([top4_up, lat3], dim=1))  # [B, 128, 52, 52]

        # Bottom-up path; the finest output is the top-down result itself.
        out3 = top3                                              # [B, 128, 52, 52]
        out3_down = self.u32u3_(out3)
        out4 = self.d4u3_2u4(torch.cat([top4, out3_down], dim=1))  # [B, 256, 26, 26]
        out4_down = self.u42u4_(out4)                            # [B, 512, 13, 13]
        out5 = self.d5u4_2u5(torch.cat([top5, out4_down], dim=1))  # [B, 512, 13, 13]

        return [out5, out4, out3]


class YOLOv4(nn.Module):
    """YOLOv4 detector: backbone -> SPP -> PANet -> three prediction heads.

    Args:
        backbone: module exposing ``features1``, ``features2`` and
            ``features3``; they are applied in sequence and their outputs
            feed the neck. The heads below expect the neck to deliver
            128 / 256 / 512 channels for the fine / middle / coarse scales.
        num_classes: number of object classes; every anchor predicts
            ``1 + 4 + num_classes`` values.

    ``forward`` returns ``[p_l, p_m, p_s]`` (coarse, middle, fine), each
    shaped ``[B, NUM_ANCHORS, H, W, 5 + num_classes]`` where ``H, W`` is the
    grid size of that scale.
    """

    # Anchor boxes predicted per grid cell at every scale.
    NUM_ANCHORS = 3

    def __init__(self, backbone, num_classes=20):
        super().__init__()
        self.num_classes = num_classes
        self.backbone = backbone
        self.SPP = SPPNet()
        self.PANet = PANet()

        # Attribute names, layer layout and creation order (s, m, l) are kept
        # so state_dict keys and seeded initialisation stay the same.
        self.pred_s = self._make_head(128)
        self.pred_m = self._make_head(256)
        self.pred_l = self._make_head(512)

        print("num_params : ", self.count_parameters())

    def _make_head(self, in_channels):
        """Build one head: 3x3 conv doubling the channels + BN + Mish, then a 1x1 prediction conv."""
        hidden_channels = in_channels * 2
        out_channels = self.NUM_ANCHORS * (1 + 4 + self.num_classes)
        return nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(hidden_channels),
            Mish(),
            nn.Conv2d(hidden_channels, out_channels, 1, stride=1, padding=0),
        )

    def count_parameters(self):
        """Return the number of trainable (``requires_grad``) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _split_anchors(self, raw):
        """Reshape ``[B, A * (5 + C), H, W]`` head output to ``[B, A, H, W, 5 + C]``.

        The grid size is read from the tensor and the per-anchor size from
        ``self.num_classes`` instead of being hard-coded, so the model works
        for any class count and any input resolution the backbone accepts.
        """
        batch, _, height, width = raw.shape
        per_anchor = 1 + 4 + self.num_classes
        return raw.reshape(batch, self.NUM_ANCHORS, per_anchor, height, width).permute(0, 1, 3, 4, 2)

    def forward(self, x):
        """Run detection and return ``[p_l, p_m, p_s]`` (coarse to fine)."""
        # Shape notes are the original annotations for a 416x416 input.
        P3 = self.backbone.features1(x)   # [B, 256, 52, 52]
        P4 = self.backbone.features2(P3)  # [B, 512, 26, 26]
        P5 = self.backbone.features3(P4)  # [B, 1024, 13, 13]

        P5 = self.SPP(P5)
        U5, U4, U3 = self.PANet(P5, P4, P3)

        p_l = self._split_anchors(self.pred_l(U5))  # [B, 3, 13, 13, 5 + C]
        p_m = self._split_anchors(self.pred_m(U4))  # [B, 3, 26, 26, 5 + C]
        p_s = self._split_anchors(self.pred_s(U3))  # [B, 3, 52, 52, 5 + C]

        return [p_l, p_m, p_s]

# Scratch note: a bare string literal holding a reshape/permute expression that
# reads the spatial size from x.shape[2] / x.shape[3]. Evaluating it does
# nothing except echo the text as this cell's output.
".reshape(x.shape[0], 3, 20 + 5, x.shape[2], x.shape[3]).permute(0, 1, 3, 4, 2)"

'.reshape(x.shape[0], 3, 20 + 5, x.shape[2], x.shape[3]).permute(0, 1, 3, 4, 2)'