# Prepare

In [None]:
# Install the Python packages used by this notebook.
!pip install pytorch-lightning
# Pin imgaug to 0.4.0: uninstall any existing imgaug first, then install exactly that version.
!pip uninstall -y imgaug
!pip install imgaug==0.4.0
!pip install wcmatch loguru wandb

In [None]:
# Pretrained YOLOv3-tiny weights (saved in the current directory).
!wget https://pjreddie.com/media/files/yolov3-tiny.weights

# Each `!` line runs in its own throw-away subshell, so a standalone
# `!cd images` has no lasting effect. Every command that must run inside
# images/ therefore does its own `cd`.
!mkdir -p images

# Download Images (into images/)
!cd images && wget -c "https://pjreddie.com/media/files/train2014.zip" --header "Referer: pjreddie.com"
!cd images && unzip -q train2014.zip && rm train2014.zip
!cd images && wget -c "https://pjreddie.com/media/files/val2014.zip" --header "Referer: pjreddie.com"
!cd images && unzip -q val2014.zip && rm val2014.zip


# Download COCO Metadata
!wget -c "https://pjreddie.com/media/files/instances_train-val2014.zip" --header "Referer: pjreddie.com"
!wget -c "https://pjreddie.com/media/files/coco/5k.part" --header "Referer: pjreddie.com"
!wget -c "https://pjreddie.com/media/files/coco/trainvalno5k.part" --header "Referer: pjreddie.com"
!wget -c "https://pjreddie.com/media/files/coco/labels.tgz" --header "Referer: pjreddie.com"
!tar xzf labels.tgz
!unzip -q instances_train-val2014.zip

# Set Up Image Lists: prefix every entry of the .part files with $PWD.
# NOTE(review): this presumably expects the .part entries to point into
# ./images/ (hence the downloads above going there) — confirm.
!paste <(awk "{print \"$PWD\"}" <5k.part) 5k.part | tr -d '\t' > 5k.txt
!paste <(awk "{print \"$PWD\"}" <trainvalno5k.part) trainvalno5k.part | tr -d '\t' > trainvalno5k.txt

!rm instances_train-val2014.zip 5k.part trainvalno5k.part labels.tgz
!rm -rf sample_data

# Detection Module

In [None]:
import pytorch_lightning as pl
import torch
from torch import nn
import torch.nn.functional as F

class DetectionLayer(pl.LightningModule):
  """
  YOLO detection head.

  Decodes a raw feature map into per-anchor boxes, objectness and class
  scores. When targets are given it also computes the YOLO training loss and
  stores scalar diagnostics in self.metrics.

  anchors: list of (w, h) anchor sizes in input-image pixels (divided by the
    stride in compute_grid_offsets).
  num_classes: number of class-score channels per anchor.
  apply_focal_loss: use FocalLoss instead of plain BCE for objectness/classes.
  image_dim: network input size in pixels; updated by forward's image_dim.
  """

  def __init__(self, anchors, num_classes, apply_focal_loss, image_dim):
    super(DetectionLayer, self).__init__()
    self.anchors = anchors
    self.apply_focal_loss = apply_focal_loss
    self.num_anchors = len(anchors)
    self.num_classes = num_classes
    self.ignore_thres = 0.5
    self.mse_loss = nn.MSELoss()
    self.bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
    # Plain BCE on sigmoid probabilities for the non-focal path. Built once here
    # instead of being re-created (and re-registered) on every loss call.
    self.bce_prob_loss = nn.BCELoss()
    self.obj_scale = 1
    # The large no-object weight is only used with plain BCE (presumably because
    # focal loss already down-weights easy background cells).
    self.no_obj_scale = 1 if apply_focal_loss else 100
    self.metrics = {}
    self.image_dim = image_dim
    self.grid_size = 0
    # FocalLoss wraps self.bce_loss (a with-logits criterion) and applies its own
    # sigmoid, so it must be fed raw logits, not probabilities.
    self.focal_loss = FocalLoss(self.bce_loss, gamma=1.5, alpha=0.25)

  def compute_bce_loss(self, inputs, targets, apply_focal_loss, logits=None):
    """
    Binary classification loss for objectness / class scores.

    inputs: sigmoid probabilities.
    targets: 0/1 targets with the same shape as inputs.
    apply_focal_loss: use self.focal_loss instead of plain BCE.
    logits: optional pre-sigmoid values matching `inputs`, used by the focal
      path. When omitted they are recovered from `inputs` with torch.logit.
    """
    if apply_focal_loss:
      if logits is None:
        logits = torch.logit(inputs, eps=1e-6)
      return self.focal_loss(logits, targets)
    return self.bce_prob_loss(inputs, targets)

  def compute_grid_offsets(self, grid_size, CUDA=True):
    """
    Cache per-cell x/y offsets, the stride (image_dim / grid_size) and the
    anchors rescaled to grid units for a grid_size x grid_size feature map.
    """
    self.grid_size = grid_size
    g = self.grid_size
    FloatTensor = torch.cuda.FloatTensor if CUDA else torch.FloatTensor
    self.stride = self.image_dim / self.grid_size

    self.grid_x = torch.arange(g).repeat(g, 1).view([1, 1, g, g]).type(FloatTensor)
    self.grid_y = torch.arange(g).repeat(g, 1).t().view([1, 1, g, g]).type(FloatTensor)

    self.scaled_anchors = FloatTensor([(a_w / self.stride, a_h / self.stride) for a_w, a_h in self.anchors])
    self.anchor_w = self.scaled_anchors[:, 0:1].view((1, self.num_anchors, 1, 1))
    self.anchor_h = self.scaled_anchors[:, 1:2].view((1, self.num_anchors, 1, 1))

  def forward(self, x, targets=None, image_dim=None):
    """
    Decode the head output and optionally compute the training loss.

    x: raw head output, viewed as (batch, num_anchors, num_classes + 5, grid, grid).
    targets: optional target rows consumed by build_targets.
    image_dim: input image size in pixels; the previous value is kept when None.

    Returns (output, loss). output is (batch, num_anchors * grid * grid,
    num_classes + 5) with boxes in input pixels. loss is 0 when targets is None.
    """
    FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor

    if image_dim is not None:
      self.image_dim = image_dim
    num_samples = x.size(0)
    grid_size = x.size(2)

    prediction = (x.view(num_samples, self.num_anchors, self.num_classes + 5, grid_size, grid_size).permute(0, 1, 3, 4, 2).contiguous())

    # center x, y
    x = torch.sigmoid(prediction[..., 0])
    y = torch.sigmoid(prediction[..., 1])
    # width, height
    w = prediction[..., 2]
    h = prediction[..., 3]

    # keep the raw logits: the focal-loss path needs them
    conf_logits = prediction[..., 4]
    cls_logits = prediction[..., 5:]
    pred_conf = torch.sigmoid(conf_logits)
    pred_cls = torch.sigmoid(cls_logits)

    # rebuild cached offsets when the grid size changes or the layer moved device
    if grid_size != self.grid_size or self.grid_x.device != x.device:
      self.compute_grid_offsets(grid_size, CUDA=x.is_cuda)

    # Add offset and scale with anchors
    pred_boxes = FloatTensor(prediction[..., :4].shape)
    pred_boxes[..., 0] = x.data + self.grid_x
    pred_boxes[..., 1] = y.data + self.grid_y
    pred_boxes[..., 2] = torch.exp(w.data) * self.anchor_w
    pred_boxes[..., 3] = torch.exp(h.data) * self.anchor_h

    res = (pred_boxes.view(num_samples, -1, 4) * self.stride, pred_conf.view(num_samples, -1, 1), pred_cls.view(num_samples, -1, self.num_classes),)
    output = torch.cat(res, -1)

    if targets is None:
      return output, 0

    iou_scores, class_mask, obj_mask, no_obj_mask, tx, ty, tw, th, tcls, tconf = build_targets(
      pred_boxes=pred_boxes,
      pred_cls=pred_cls,
      target=targets,
      anchors=self.scaled_anchors,
      ignore_thres=self.ignore_thres,
    )

    # loss
    loss_x = self.mse_loss(x[obj_mask], tx[obj_mask])
    loss_y = self.mse_loss(y[obj_mask], ty[obj_mask])
    loss_w = self.mse_loss(w[obj_mask], tw[obj_mask])
    loss_h = self.mse_loss(h[obj_mask], th[obj_mask])
    loss_conf_obj = self.compute_bce_loss(pred_conf[obj_mask], tconf[obj_mask], self.apply_focal_loss, conf_logits[obj_mask])
    loss_conf_no_obj = self.compute_bce_loss(pred_conf[no_obj_mask], tconf[no_obj_mask], self.apply_focal_loss, conf_logits[no_obj_mask])
    loss_conf = self.obj_scale * loss_conf_obj + self.no_obj_scale * loss_conf_no_obj
    loss_cls = self.compute_bce_loss(pred_cls[obj_mask], tcls[obj_mask], self.apply_focal_loss, cls_logits[obj_mask])
    total_loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls

    # metrics
    cls_acc = 100 * class_mask[obj_mask].mean()
    conf_obj = pred_conf[obj_mask].mean()
    conf_no_obj = pred_conf[no_obj_mask].mean()
    conf50 = (pred_conf > 0.5).float()
    iou50 = (iou_scores > 0.5).float()
    iou75 = (iou_scores > 0.75).float()
    detected_mask = conf50 * class_mask
    precision = torch.sum(iou50 * detected_mask) / (conf50.sum() + 1e-16)
    recall50 = torch.sum(iou50 * detected_mask) / (obj_mask.sum() + 1e-16)
    recall75 = torch.sum(iou75 * detected_mask) / (obj_mask.sum() + 1e-16)

    self.metrics = {
      "loss": to_cpu(total_loss).item(),
      "x": to_cpu(loss_x).item(),
      "y": to_cpu(loss_y).item(),
      "w": to_cpu(loss_w).item(),
      "h": to_cpu(loss_h).item(),
      "conf": to_cpu(loss_conf).item(),
      "cls": to_cpu(loss_cls).item(),
      "cls_acc": to_cpu(cls_acc).item(),
      "recall50": to_cpu(recall50).item(),
      "recall75": to_cpu(recall75).item(),
      "precision": to_cpu(precision).item(),
      "conf_obj": to_cpu(conf_obj).item(),
      "conf_no_obj": to_cpu(conf_no_obj).item(),
      "grid_size": grid_size,
    }

    return output, total_loss

In [None]:
class Upsample(pl.LightningModule):
  """
  Module wrapper around F.interpolate, used for the cfg "upsample" layers.

  The original note says nn.Upsample is deprecated. NOTE(review): it is
  F.upsample that is deprecated; nn.Upsample itself is still supported.

  scale_factor: spatial multiplier passed to F.interpolate.
  mode: interpolation mode ("nearest" by default).
  """

  def __init__(self, scale_factor, mode="nearest"):
    super().__init__()
    self.scale_factor, self.mode = scale_factor, mode

  def forward(self, x):
    return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)

In [None]:
class FocalLoss(pl.LightningModule):
  """
  Focal loss wrapper around an element-wise BCE-with-logits criterion.

  Scales each element's loss by alpha_t * (1 - p_t) ** gamma, where p_t is the
  predicted probability of the true label (Lin et al., "Focal Loss for Dense
  Object Detection").

  loss_fcn: criterion applied to (logits, targets). Its original reduction is
    remembered, and then the criterion is switched to reduction="none" so the
    focal factors can be applied per element.
  gamma: focusing exponent.
  alpha: weight of positive targets (negatives get 1 - alpha).
  """

  def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
    super(FocalLoss, self).__init__()
    self.loss_fcn = loss_fcn
    self.gamma = gamma
    self.alpha = alpha
    self.reduction = loss_fcn.reduction
    self.loss_fcn.reduction = "none"  # need per-element losses

  def forward(self, pred, t):
    """
    pred: raw logits (sigmoid is applied here). t: 0/1 targets of the same shape.
    Returns the loss reduced as the wrapped criterion's original reduction.
    """
    loss = self.loss_fcn(pred, t)
    prob = torch.sigmoid(pred)
    # p_t = p for positive targets, 1 - p for negatives
    p_t = t * prob + (1 - t) * (1 - prob)
    alpha_factor = t * self.alpha + (1 - t) * (1 - self.alpha)
    modulating_factor = (1.0 - p_t) ** self.gamma
    loss = loss * alpha_factor * modulating_factor

    if self.reduction == "mean":
      return loss.mean()
    if self.reduction == "sum":
      return loss.sum()
    return loss

In [None]:
from __future__ import division

import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl

class Darknet(pl.LightningModule):
  """
  YOLO (Darknet) network built from a Darknet .cfg file.

  cfg_file: path to the Darknet config, parsed by parse_cfg and turned into
    modules by create_modules.
  apply_focal_loss: forwarded to every DetectionLayer.
  """

  def __init__(self, cfg_file, apply_focal_loss=False):
    super(Darknet, self).__init__()
    self.apply_focal_loss = apply_focal_loss
    self.blocks = parse_cfg(cfg_file)
    # create_modules pops the leading [net] block off self.blocks, so afterwards
    # self.blocks and self.module_list line up one-to-one.
    self.net, self.module_list = create_modules(self.blocks, self.apply_focal_loss)
    self.detection_layers = [layer[0] for layer in self.module_list if isinstance(layer[0], DetectionLayer)]
    self.image_size = int(self.net["height"])
    self.seen = 0
    # 5 x int32 header, the same layout load_weight reads from a weights file
    self.header_info = np.array([0, 0, 0, self.seen, 0], dtype=np.int32)

  def forward(self, x, targets=None):
    """
    Run the network block by block.

    Route blocks concatenate earlier outputs along channels, shortcut blocks add
    the previous output to an earlier one, and yolo blocks decode detections
    (and accumulate their loss when targets are given).

    Returns the concatenated detections (detached, on CPU) when targets is None,
    otherwise (loss, detections).
    """
    image_dim = x.shape[2]
    loss = 0
    layer_outputs, detection_outputs = [], []

    for module_def, module in zip(self.blocks, self.module_list):
      module_type = module_def["type"]

      if module_type in ["convolutional", "upsample", "maxpool"]:
        x = module(x)
      elif module_type == "route":
        x = torch.cat([layer_outputs[int(layer_i)] for layer_i in module_def["layers"].split(",")], 1)
      elif module_type == "shortcut":
        layer_i = int(module_def["from"])
        x = layer_outputs[-1] + layer_outputs[layer_i]
      elif module_type == "yolo":
        x, layer_loss = module[0](x, targets, image_dim)
        loss += layer_loss
        detection_outputs.append(x)

      layer_outputs.append(x)
    detection_outputs = to_cpu(torch.cat(detection_outputs, 1))
    return detection_outputs if targets is None else (loss, detection_outputs)

  @staticmethod
  def _copy_weights(weights, offset, target):
    """Copy the next target.numel() floats of `weights` into `target`; return the new offset."""
    count = target.numel()
    chunk = torch.from_numpy(weights[offset: offset + count]).view_as(target)
    target.data.copy_(chunk)
    return offset + count

  def load_weight(self, file_path):
    """
    Load a Darknet .weights file into the convolutional / batch-norm layers.

    The file starts with a 5 x int32 header (stored in self.header_info; entry 3
    becomes self.seen), followed by float32 parameters consumed in layer order.
    For "darknet53.conv.74" only the first 75 blocks are loaded.
    """
    with open(file_path, "rb") as file:
      header = np.fromfile(file, dtype=np.int32, count=5)
      self.header_info = header
      self.seen = self.header_info[3]
      weights = np.fromfile(file, dtype=np.float32)

    cutoff = 75 if "darknet53.conv.74" in file_path else None

    n = 0
    for i, (module_def, module) in enumerate(zip(self.blocks, self.module_list)):
      if i == cutoff:
        break
      # only convolutional blocks carry weights
      if module_def["type"] != "convolutional":
        continue

      convol_layer = module[0]
      if int(module_def.get("batch_normalize", 0)):
        batch_norm_layer = module[1]
        # file order: bias, weight, running mean, running variance
        for target in (batch_norm_layer.bias, batch_norm_layer.weight,
                       batch_norm_layer.running_mean, batch_norm_layer.running_var):
          n = self._copy_weights(weights, n, target)
      else:
        n = self._copy_weights(weights, n, convol_layer.bias)

      # convolution kernel always follows its bias / batch-norm parameters
      n = self._copy_weights(weights, n, convol_layer.weight)

In [None]:
import torch
from torch.autograd import Variable
import numpy as np
import cv2
import random
import os
import shutil
import fnmatch

def predict_transform(predict, input_dim, anchors, num_classes, CUDA=False):
  """
  Flatten a raw YOLO feature map into one row per predicted box.

  predict: head output indexed as (batch, num_anchors * (num_classes + 5), grid, grid).
  input_dim: network input size in pixels; stride = input_dim // grid.
  anchors: list of (w, h) anchor sizes in input pixels.
  num_classes: number of class scores per box.
  CUDA: kept for backward compatibility and ignored. Helper tensors are created
    on predict's own device.

  Returns a (batch, grid * grid * num_anchors, num_classes + 5) tensor with rows
  (center_x, center_y, width, height, objectness, class scores...). Box values
  are in input pixels; objectness and class scores are sigmoid-ed.
  """
  batch_size = predict.size(0)
  # Take the grid size from the tensor itself: re-deriving it as
  # input_dim // stride breaks when input_dim is not a multiple of the grid.
  grid_size = predict.size(2)
  stride = input_dim // grid_size
  num_anchors = len(anchors)
  bounding_box_attrs = num_classes + 5
  device = predict.device

  predict = predict.view(batch_size, bounding_box_attrs * num_anchors, grid_size ** 2)
  predict = predict.transpose(1, 2).contiguous()
  predict = predict.view(batch_size, grid_size ** 2 * num_anchors, bounding_box_attrs)

  # anchor sizes in grid-cell units
  scaled_anchors = [(a[0] / stride, a[1] / stride) for a in anchors]

  # sigmoid the center x, center y and objectness score
  predict[:, :, 0] = torch.sigmoid(predict[:, :, 0])
  predict[:, :, 1] = torch.sigmoid(predict[:, :, 1])
  predict[:, :, 4] = torch.sigmoid(predict[:, :, 4])

  # add the per-cell (x, y) offsets, repeated once per anchor
  grid = np.arange(grid_size)
  x, y = np.meshgrid(grid, grid)
  x_offset = torch.FloatTensor(x).view(-1, 1).to(device)
  y_offset = torch.FloatTensor(y).view(-1, 1).to(device)
  xy_offset = torch.cat((x_offset, y_offset), 1).repeat(1, num_anchors).view(-1, 2).unsqueeze(0)
  predict[:, :, :2] += xy_offset

  # apply anchors to the box width / height
  anchor_tensor = torch.FloatTensor(scaled_anchors).to(device)
  anchor_tensor = anchor_tensor.repeat(grid_size ** 2, 1).unsqueeze(0)
  predict[:, :, 2:4] = torch.exp(predict[:, :, 2:4]) * anchor_tensor

  # sigmoid the class scores
  predict[:, :, 5: num_classes + 5] = torch.sigmoid(predict[:, :, 5: num_classes + 5])
  # scale box geometry from grid cells back to input pixels
  predict[:, :, :4] *= stride

  return predict

In [None]:
def _cxcywh_to_xyxy(boxes):
    """Convert (center_x, center_y, w, h) rows to (x1, y1, x2, y2) corner rows."""
    half_wh = boxes[:, 2:4] / 2
    return torch.cat((boxes[:, :2] - half_wh, boxes[:, :2] + half_wh), 1)


def build_targets(pred_boxes, pred_cls, target, anchors, ignore_thres):
    """
    Assign ground-truth boxes to anchors / grid cells and build YOLO targets.

    pred_boxes: (nB, nA, nG, nG, 4) predicted (cx, cy, w, h) in grid units.
    pred_cls: (nB, nA, nG, nG, nC) class probabilities.
    target: (nT, 6) rows of (image index, class label, cx, cy, w, h). Box values
      are multiplied by nG, so they are presumably normalised to [0, 1].
    anchors: (nA, 2) anchor (w, h) in grid units.
    ignore_thres: anchors whose shape IoU with a target exceeds this are removed
      from the no-object mask.

    Returns (iou_scores, class_mask, obj_mask, no_obj_mask, tx, ty, tw, th,
    tcls, tconf).
    """
    BoolTensor = torch.cuda.BoolTensor if pred_boxes.is_cuda else torch.BoolTensor
    FloatTensor = torch.cuda.FloatTensor if pred_boxes.is_cuda else torch.FloatTensor

    nB = pred_boxes.size(0)
    nA = pred_boxes.size(1)
    nC = pred_cls.size(-1)
    nG = pred_boxes.size(2)

    # output tensors
    obj_mask = BoolTensor(nB, nA, nG, nG).fill_(0)
    no_obj_mask = BoolTensor(nB, nA, nG, nG).fill_(1)
    class_mask = FloatTensor(nB, nA, nG, nG).fill_(0)
    iou_scores = FloatTensor(nB, nA, nG, nG).fill_(0)
    tx = FloatTensor(nB, nA, nG, nG).fill_(0)
    ty = FloatTensor(nB, nA, nG, nG).fill_(0)
    tw = FloatTensor(nB, nA, nG, nG).fill_(0)
    th = FloatTensor(nB, nA, nG, nG).fill_(0)
    tcls = FloatTensor(nB, nA, nG, nG, nC).fill_(0)

    # convert target boxes to grid units
    target_boxes = target[:, 2:6] * nG
    gxy = target_boxes[:, :2]
    gwh = target_boxes[:, 2:]

    # get anchors with best shape iou
    ious = torch.stack([bounding_box_wh_iou(anchor, gwh) for anchor in anchors])
    _, best_n = ious.max(0)

    # separate target values
    b, target_labels = target[:, :2].long().t()
    gx, gy = gxy.t()
    gw, gh = gwh.t()
    # Cell indices. Clamp so a centre lying exactly on the far edge (coordinate
    # nG after scaling) does not index one past the last cell.
    gi, gj = gxy.long().t()
    gi = gi.clamp(0, nG - 1)
    gj = gj.clamp(0, nG - 1)

    # masks
    obj_mask[b, best_n, gj, gi] = 1
    no_obj_mask[b, best_n, gj, gi] = 0

    # set no obj mask to zero where iou exceeds ignore threshold
    for i, anchor_ious in enumerate(ious.t()):
        no_obj_mask[b[i], anchor_ious > ignore_thres, gj[i], gi[i]] = 0

    # coordinates: offset of the centre inside its cell
    tx[b, best_n, gj, gi] = gx - gx.floor()
    ty[b, best_n, gj, gi] = gy - gy.floor()

    # width and height, as log-ratio to the assigned anchor
    tw[b, best_n, gj, gi] = torch.log(gw / anchors[best_n][:, 0] + 1e-16)
    th[b, best_n, gj, gi] = torch.log(gh / anchors[best_n][:, 1] + 1e-16)

    # one-hot encoding of label
    tcls[b, best_n, gj, gi, target_labels] = 1

    # compute label correctness and iou at best anchor
    class_mask[b, best_n, gj, gi] = (pred_cls[b, best_n, gj, gi].argmax(-1) == target_labels).float()
    # Both box sets are (cx, cy, w, h), but get_bounding_boxes_iou expects corners.
    # NOTE(review): that helper uses the inclusive (+1) pixel convention, which
    # inflates IoU for these grid-unit boxes.
    iou_scores[b, best_n, gj, gi] = get_bounding_boxes_iou(
        _cxcywh_to_xyxy(pred_boxes[b, best_n, gj, gi]), _cxcywh_to_xyxy(target_boxes)
    )

    tconf = obj_mask.float()
    return iou_scores, class_mask, obj_mask, no_obj_mask, tx, ty, tw, th, tcls, tconf

In [None]:
def bounding_box_wh_iou(wh1, wh2):
    """
    IoU of box shapes only: positions are ignored, as if both boxes were aligned.

    wh1: a single (w, h) pair, e.g. one anchor.
    wh2: (N, 2) tensor of (w, h) pairs.
    Returns an (N,) tensor of IoUs. The 1e-16 in the union avoids division by zero.
    """
    cols = wh2.t()
    inter = torch.min(wh1[0], cols[0]) * torch.min(wh1[1], cols[1])
    union = (wh1[0] * wh1[1] + 1e-16) + cols[0] * cols[1] - inter
    return inter / union

In [None]:
def parse_cfg(file):
  """
  Parse a Darknet .cfg file into a list of block dicts.

  Each "[section]" header starts a new dict whose "type" key holds the section
  name. The following "key=value" lines are stored as stripped strings; only
  the first "=" splits key from value. "convolutional" blocks get a default
  "batch_normalize" of 0. Blank or whitespace-only lines and "#" comment lines
  (even when indented) are skipped.

  Raises ValueError for a non-comment line without "=". A key/value line before
  any section header raises IndexError.
  """
  with open(file, 'r') as cfg_file:
    raw_lines = cfg_file.read().split('\n')

  module_defs = []
  for raw in raw_lines:
    line = raw.strip()
    if not line or line.startswith('#'):
      continue

    if line.startswith("["):                 # new block
      block_type = line[1:-1].rstrip()
      module_defs.append({"type": block_type})
      if block_type == "convolutional":
        module_defs[-1]["batch_normalize"] = 0
    else:
      key, value = line.split("=", 1)         # get key-value from line
      module_defs[-1][key.rstrip()] = value.strip()

  return module_defs

In [None]:
def create_modules(module_defs, focal_loss):
  """
  Build one nn.Sequential per parsed cfg block.

  module_defs: list of block dicts from parse_cfg. Its first entry (the [net]
    hyper-parameter block) is popped off IN PLACE, so afterwards module_defs
    lines up one-to-one with the returned module list.
  focal_loss: forwarded to every DetectionLayer.

  Returns (hyperparams, module_list), where hyperparams is the popped [net] dict.
  """
  hyperparams = module_defs.pop(0)
  bn_momentum = float(hyperparams["momentum"])
  module_list = nn.ModuleList()
  # channel count produced by each layer; entry 0 is the image input
  channels = [int(hyperparams["channels"])]

  for idx, block in enumerate(module_defs):
    seq = nn.Sequential()
    kind = block["type"]

    if kind == "convolutional":
      use_bn = int(block["batch_normalize"])
      filters = int(block["filters"])
      ksize = int(block["size"])
      conv = nn.Conv2d(
        in_channels=channels[-1],
        out_channels=filters,
        kernel_size=ksize,
        stride=int(block["stride"]),
        padding=(ksize - 1) // 2,
        bias=not use_bn,
      )
      seq.add_module("conv_{}".format(idx), conv)
      if use_bn:
        seq.add_module("batch_norm_{}".format(idx), nn.BatchNorm2d(filters, momentum=bn_momentum, eps=1e-5))
      if block["activation"] == "leaky":
        seq.add_module("leaky_{}".format(idx), nn.LeakyReLU(0.1))

    elif kind == "maxpool":
      ksize = int(block["size"])
      stride = int(block["stride"])
      if ksize == 2 and stride == 1:
        # pad right/bottom so a 2x2 stride-1 pool keeps the spatial size
        seq.add_module('ZeroPad2d', nn.ZeroPad2d((0, 1, 0, 1)))
      pool = nn.MaxPool2d(kernel_size=ksize, stride=stride, padding=int((ksize - 1) // 2))
      seq.add_module("maxpool_{}".format(idx), pool)

    elif kind == "upsample":
      seq.add_module("upsample_{}".format(idx), Upsample(scale_factor=int(block["stride"]), mode="nearest"))

    elif kind == "route":
      # route output is the channel-wise concatenation of the referenced layers
      referenced = [int(j) for j in block["layers"].split(",")]
      filters = sum([channels[1:][j] for j in referenced])
      seq.add_module("route_{}".format(idx), nn.Sequential())

    elif kind == "shortcut":
      filters = channels[1:][int(block["from"])]
      seq.add_module("shortcut_{}".format(idx), nn.Sequential())

    elif kind == "yolo":
      mask = [int(j) for j in block["mask"].split(",")]
      flat = [int(j) for j in block["anchors"].split(",")]
      all_anchors = [(flat[k], flat[k + 1]) for k in range(0, len(flat), 2)]
      chosen = [all_anchors[k] for k in mask]
      head = DetectionLayer(chosen, int(block["classes"]), focal_loss, int(hyperparams["height"]))
      seq.add_module("Detection_{}".format(idx), head)

    # `filters` carries over unchanged for layers that keep the channel count
    module_list.append(seq)
    channels.append(filters)

  return hyperparams, module_list

In [None]:
def test_input(file_path, img_size):
    """
    Load an image from disk as a (1, 3, H, W) float tensor for a quick test run.

    img_size is passed straight to cv2.resize, i.e. (width, height). Channels are
    flipped (BGR -> RGB for cv2-loaded images) and values are scaled to [0, 1].
    """
    bgr = cv2.resize(cv2.imread(file_path), img_size)
    chw = bgr[:, :, ::-1].transpose((2, 0, 1))
    batch = chw[np.newaxis, :, :, :] / 255.0      # prepend a batch axis
    return Variable(torch.from_numpy(batch).float())

In [None]:
def get_result(prediction, confidence, num_classes, nms_conf=0.4):
  """
  Apply objectness thresholding and per-class non-maximum suppression.

  prediction: (batch, num_boxes, 5 + num_classes) tensor with rows
    (center_x, center_y, width, height, objectness, class scores...).
    The caller's tensor is not modified.
  confidence: objectness threshold; rows not above it are discarded.
  num_classes: number of class-score columns.
  nms_conf: IoU threshold. Within one class, a box whose IoU with a
    higher-objectness box is >= nms_conf is suppressed.

  Returns an (N, 8) tensor of rows
  (batch index, x1, y1, x2, y2, objectness, best class score, best class index),
  or 0 when no detection survives.
  """
  # Work on a copy: the thresholding and box conversion below are in place.
  prediction = prediction.clone()

  # Objectness thresholding: zero every row whose objectness is not above `confidence`.
  conf_mask = (prediction[:, :, 4] > confidence).float().unsqueeze(2)
  prediction *= conf_mask

  # (center_x, center_y, width, height) ->
  # (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
  box = prediction.new(prediction.shape)
  box[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
  box[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
  box[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
  box[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
  prediction[:, :, :4] = box[:, :, :4]

  batch_size = prediction.size(0)
  check = False  # True once `output` holds at least one detection

  # Images can have different numbers of detections, so thresholding and NMS
  # are done one image at a time.
  for i in range(batch_size):
    image_prediction = prediction[i]      # image tensor

    # keep only the best class per box: (score, index) replace the class columns
    max_confidence, max_confidence_score = torch.max(image_prediction[:, 5: num_classes + 5], 1)
    max_confidence = max_confidence.float().unsqueeze(1)
    max_confidence_score = max_confidence_score.float().unsqueeze(1)
    image_prediction = torch.cat((image_prediction[:, :5], max_confidence, max_confidence_score), 1)

    # drop rows zeroed by the objectness threshold
    non_zero = torch.nonzero(image_prediction[:, 4])
    try:
      image_prediction_ = image_prediction[non_zero.squeeze(), :].view(-1, 7)
    except (IndexError, RuntimeError):
      continue

    if image_prediction_.shape[0] == 0:
      continue

    # get various classes detected in image
    image_classes = get_unique(image_prediction_[:, -1])

    for c in image_classes:
      # detections of this one class
      class_mask = image_prediction_ * (image_prediction_[:, -1] == c).float().unsqueeze(1)
      class_mask_index = torch.nonzero(class_mask[:, -2]).squeeze()
      image_prediction_class = image_prediction_[class_mask_index].view(-1, 7)

      # sort by objectness, highest first
      confidence_sorted_index = torch.sort(image_prediction_class[:, 4], descending=True)[1]
      image_prediction_class = image_prediction_class[confidence_sorted_index]
      index = image_prediction_class.size(0)

      for idx in range(index):
        # ious of the current box against all lower-ranked boxes; the list
        # shrinks as boxes are suppressed, so idx can run past its end
        try:
          ious = get_bounding_boxes_iou(image_prediction_class[idx].unsqueeze(0), image_prediction_class[idx + 1:])
        except ValueError:
          break
        except IndexError:
          break

        # zero every later detection with iou >= threshold
        iou_mask = (ious < nms_conf).float().unsqueeze(1)
        image_prediction_class[idx + 1:] *= iou_mask

        # keep only the non-zero entries
        non_zero_index = torch.nonzero(image_prediction_class[:, 4]).squeeze()
        image_prediction_class = image_prediction_class[non_zero_index].view(-1, 7)

      batch_index = image_prediction_class.new(image_prediction_class.size(0), 1).fill_(i)
      s = batch_index, image_prediction_class

      if not check:
        output = torch.cat(s, 1)
        check = True
      else:
        output = torch.cat((output, torch.cat(s, 1)))

  return output if check else 0

In [None]:
def get_unique(tensor):
  """
  Return the sorted unique values of `tensor`, keeping its dtype and device.

  The values are computed with NumPy on the CPU and then copied back.
  """
  uniq = torch.from_numpy(np.unique(tensor.cpu().numpy()))
  out = tensor.new(uniq.shape)
  out.copy_(uniq)
  return out

In [None]:
def get_bounding_boxes_iou(b1, b2):
  """
  IoU between boxes given as (x1, y1, x2, y2) rows.

  b1 and b2 are compared row-by-row with broadcasting: either one row against
  many, or two stacks of equal length. Widths and heights are measured
  inclusively (+1), i.e. the pixel-coordinate convention.
  Returns a 1-D tensor of IoU values.
  """
  # intersection rectangle
  ix1 = torch.max(b1[:, 0], b2[:, 0])
  iy1 = torch.max(b1[:, 1], b2[:, 1])
  ix2 = torch.min(b1[:, 2], b2[:, 2])
  iy2 = torch.min(b1[:, 3], b2[:, 3])
  inter = torch.clamp(ix2 - ix1 + 1, min=0) * torch.clamp(iy2 - iy1 + 1, min=0)

  area1 = (b1[:, 2] - b1[:, 0] + 1) * (b1[:, 3] - b1[:, 1] + 1)
  area2 = (b2[:, 2] - b2[:, 0] + 1) * (b2[:, 3] - b2[:, 1] + 1)
  return inter / (area1 + area2 - inter)

In [None]:
def resize_image(img, input_dim):
    """
    Letterbox-resize `img` into a canvas of size input_dim = (width, height).

    The image is scaled by the largest factor that fits both dimensions (aspect
    ratio preserved, bicubic interpolation) and centred on a canvas filled with
    128. The returned canvas has shape (height, width, 3).
    """
    src_h, src_w = img.shape[0], img.shape[1]
    dst_w, dst_h = input_dim
    scale = min(dst_w / src_w, dst_h / src_h)
    new_w = int(src_w * scale)
    new_h = int(src_h * scale)
    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)

    top = (dst_h - new_h) // 2
    left = (dst_w - new_w) // 2
    canvas = np.full((input_dim[1], input_dim[0], 3), 128)
    canvas[top: top + new_h, left: left + new_w, :] = resized
    return canvas

In [None]:
def pre_image(img, input_dim):
  """
  Turn an image array into a network input tensor.

  The image is letterboxed to input_dim x input_dim with resize_image, its
  channel order is reversed (BGR -> RGB for cv2-loaded images), the layout
  becomes channels-first, values are scaled to [0, 1], and a batch axis is
  prepended: the result has shape (1, 3, input_dim, input_dim).
  """
  boxed = resize_image(img, (input_dim, input_dim))
  chw = boxed[:, :, ::-1].transpose((2, 0, 1)).copy()
  as_float = torch.from_numpy(chw).float()
  return as_float.div(255.0).unsqueeze(0)

In [None]:
def draw_result(x, results, colors, classes):
  """
  Draw one detection (box outline plus filled label) onto its source image.

  x: one detection row; presumably as produced by get_result, i.e.
    (batch index, x1, y1, x2, y2, objectness, class score, class index).
  results: images indexed by x[0]; the selected image is drawn on in place.
  colors: sequence of colours; one is picked at random per call.
  classes: class names indexed by the class id in x[-1].
  Returns the drawn-on image.
  """
  # cv2 drawing functions need plain Python int points; recent OpenCV builds
  # reject tuples of 0-d tensors.
  t1 = tuple(int(v) for v in x[1:3])
  t2 = tuple(int(v) for v in x[3:5])
  img = results[int(x[0])]
  text_font = cv2.FONT_HERSHEY_PLAIN
  cls = int(x[-1])
  color = random.choice(colors)
  label = "{}".format(classes[cls])
  cv2.rectangle(img, t1, t2, color, 1)
  # filled background for the label, anchored at the box's top-left corner
  text_size = cv2.getTextSize(label, text_font, 1, 1)[0]
  t2 = t1[0] + text_size[0] + 3, t1[1] + text_size[1] + 4
  cv2.rectangle(img, t1, t2, color, -1)
  text_pos = t1[0], t1[1] + text_size[1] + 4
  cv2.putText(img, label, text_pos, text_font, 1, [255, 255, 255], 1)
  return img

In [None]:
def load_dataset(file_path):
  """
  Read class names, one per line, from `file_path`.

  Names are returned in file order, so the list index is the class id. One
  trailing newline does not produce an extra empty name, and the last name is
  kept even when the file has no trailing newline.
  """
  with open(file_path, "r") as file:
    names = file.read().split("\n")
  if names[-1] == "":
    names.pop()
  return names

In [None]:
def to_cpu(tensor):
    """Return `tensor` detached from the autograd graph and moved to the CPU."""
    detached = tensor.detach()
    return detached.cpu()

In [None]:
def save_code_files(output_path, root_path):
  """
  Snapshot source/config files under root_path into <output_path>/code.

  Any existing <output_path>/code is deleted first. Only *.py, *.data and *.cfg
  files are copied; directories are walked except those matching the exclude
  patterns (e.g. experiment*, __pycache__, weights, wandb).
  """
  include = ['*.py', '*.data', '*.cfg']
  exclude = ['experiment*', '*.idea', '*__pycache__', 'weights', 'wandb', 'asets']

  def _matching(candidates, patterns):
    # union of fnmatch hits over all patterns
    hits = set()
    for pattern in patterns:
      hits.update(fnmatch.filter(candidates, pattern))
    return hits

  def _ignore(path, names):
    # an excluded directory path: ignore all of its entries
    if path in _matching([path], exclude):
      return names
    wanted = _matching(names, include)
    wanted.update(n for n in names if os.path.isdir(os.path.join(path, n)))
    wanted -= _matching(wanted, exclude)
    return {n for n in names if n not in wanted}

  dst_dir = os.path.join(output_path, "code")
  if os.path.exists(dst_dir):
    shutil.rmtree(dst_dir)
  shutil.copytree(root_path, dst_dir, ignore=_ignore)

## Image detection

In [None]:
from __future__ import division
import time
import torch
from torch.autograd import Variable
import cv2
import argparse
import os
import os.path as osp
import pickle as pkl
import pandas as pd

def parse_arg():
  """
  Parse command-line arguments for the detection module.

  Unknown arguments are ignored (parse_known_args), presumably so the function
  also works when run inside a notebook kernel. Numeric options are converted
  by argparse: --bs is an int, --confidence and --nms are floats.
  """
  parser = argparse.ArgumentParser(description="reYOLO Detection Module")
  parser.add_argument("--images", default="/content/dog-cycle-car.png", type=str, help="Image path or directory containing images to perform detection")
  parser.add_argument("--det", default="det", type=str, help="Image path or directory to store detections")
  parser.add_argument("--bs", default=1, type=int, help="Batch size")
  parser.add_argument("--confidence", default=0.5, type=float, help="Object confidence to filter predictions")
  parser.add_argument("--nms", default=0.4, type=float, help="NMS Threshold")
  parser.add_argument("--cfg", dest="cfg_file", default="/content/yolov3-tiny.cfg", type=str, help="Config file path")
  parser.add_argument("--weights", dest="weights_file", default="/content/yolov3-tiny.weights", type=str, help="Weights file path")
  parser.add_argument("--dataset", default="/content/coco.names", type=str, help="Dataset file path")
  parser.add_argument("--colors", dest="colors_file", default="/content/pallete", type=str, help="Colors file path")

  args, _ = parser.parse_known_args()
  return args

class ImageDetect():
  """
  Run Darknet/YOLO detection on one image or on every file in a directory.

  All settings come from ``parse_arg()``. Each input image is annotated with
  ``draw_result`` and written to ``<det>/detect_<original file name>``; a
  per-image summary and a timing table are printed.
  """

  def __init__(self):
    args = parse_arg()
    self.images = args.images                # image file or directory of images
    self.cfg_file = args.cfg_file
    self.weights_file = args.weights_file
    self.det = args.det                      # output directory for annotated images
    # argparse values may be strings when passed on the command line
    self.batch_size = int(args.bs)
    self.confidence = float(args.confidence)
    self.nms = float(args.nms)
    self.CUDA = torch.cuda.is_available()
    self.classes = load_dataset(args.dataset)
    self.num_classes = len(self.classes)
    self.colors_file = args.colors_file

  def load_network(self):
    """
    Build the Darknet model from ``cfg_file`` and load ``weights_file``.

    Sets ``self.model`` and ``self.input_dim`` (the cfg's network height),
    which must be a multiple of 32 and greater than 32.
    """
    self.model = Darknet(self.cfg_file)
    self.model.load_weight(self.weights_file)
    self.input_dim = int(self.model.net["height"])
    assert self.input_dim % 32 == 0
    assert self.input_dim > 32

  def get_detections(self):
    """
    Detect objects in the configured images and save annotated copies.

    Returns early (rather than calling ``exit()``, which would shut down a
    notebook kernel) when the input path does not exist, when no image could
    be decoded, or when nothing was detected.
    """
    self.load_network()
    if self.CUDA:         # if cuda available
      self.model.cuda()

    self.model.eval()       # set model in evaluation mode
    read_time = time.time()

    try:
      image_list = [osp.join(osp.realpath("."), self.images, img) for img in os.listdir(self.images)]
    except NotADirectoryError:
      image_list = [osp.join(osp.realpath("."), self.images)]
    except FileNotFoundError:
      print("No file or directory with name {}".format(self.images))
      return

    if not os.path.exists(self.det):
      os.makedirs(self.det)

    load_batch_time = time.time()
    # cv2.imread returns None for files it cannot decode; drop those (keeping
    # image_list aligned with the decoded images) so pre_image never gets None.
    loaded = [(path, cv2.imread(path)) for path in image_list]
    loaded = [(path, img) for path, img in loaded if img is not None]
    if not loaded:
      print("No readable images in {}".format(self.images))
      return
    image_list = [path for path, _ in loaded]
    loaded_img_list = [img for _, img in loaded]

    # network-ready tensors, one per image
    img_batches = [pre_image(img, self.input_dim) for img in loaded_img_list]
    # original (width, height) per image, repeated to (w, h, w, h) so the
    # columns line up with the box corner columns rescaled below
    img_dim_list = torch.FloatTensor([(img.shape[1], img.shape[0]) for img in loaded_img_list]).repeat(1, 2)

    # group the per-image tensors into batches of batch_size
    if self.batch_size != 1:
      num_batches = (len(image_list) + self.batch_size - 1) // self.batch_size   # ceil division
      img_batches = [torch.cat(img_batches[i * self.batch_size:(i + 1) * self.batch_size]) for i in range(num_batches)]

    if self.CUDA:
      img_dim_list = img_dim_list.cuda()

    start_detect_loop_time = time.time()

    # detection loop
    output = None
    for i, batch in enumerate(img_batches):
      start = time.time()
      if self.CUDA:
        batch = batch.cuda()
      with torch.no_grad():
        prediction = self.model(Variable(batch))

      prediction = get_result(prediction, self.confidence, self.num_classes, nms_conf=self.nms)

      end = time.time()
      batch_images = image_list[i * self.batch_size:(i + 1) * self.batch_size]
      if isinstance(prediction, int):
        # get_result returned an int: nothing detected in this batch
        for image in batch_images:
          print("{0:20s} predicted in {1:6.3f} seconds".format(image.split("/")[-1], (end - start) / self.batch_size))
          print("{0:20s} {1:s}".format("Objects Detected:", ""))
          print("*********************************************")
        continue

      # column 0 is the index inside the batch; shift it to the index in image_list
      prediction[:, 0] += i * self.batch_size
      output = prediction if output is None else torch.cat((output, prediction))

      for img_num, image in enumerate(batch_images):
        img_id = i * self.batch_size + img_num
        # only the current batch can contain rows for img_id
        objects = [self.classes[int(x[-1])] for x in prediction if int(x[0]) == img_id]
        print("{0:20s} predicted in {1:6.3f} seconds".format(image.split("/")[-1], (end - start) / self.batch_size))
        print("{0:20s} {1:s}".format("Objects Detected:", " ".join(objects)))
        print("*********************************************")

      if self.CUDA:
        torch.cuda.synchronize()

    if output is None:
      print("No detection were made")
      return

    # Map box corners from network-input coordinates back to original image
    # pixels: remove the centring offset, divide by the scale, then clamp.
    img_dim_list = torch.index_select(img_dim_list, 0, output[:, 0].long())
    scale_factor = torch.min(self.input_dim / img_dim_list, 1)[0].view(-1, 1)
    output[:, [1, 3]] -= (self.input_dim - scale_factor * img_dim_list[:, 0].view(-1, 1)) / 2
    output[:, [2, 4]] -= (self.input_dim - scale_factor * img_dim_list[:, 1].view(-1, 1)) / 2
    output[:, 1:5] /= scale_factor

    for i in range(output.shape[0]):
      output[i, [1, 3]] = torch.clamp(output[i, [1, 3]], 0.0, img_dim_list[i, 0])
      output[i, [2, 4]] = torch.clamp(output[i, [2, 4]], 0.0, img_dim_list[i, 1])

    output_recast_time = time.time()
    class_load_time = time.time()
    # NOTE: pickle can execute arbitrary code; only load trusted palette files.
    with open(self.colors_file, "rb") as colors_fh:
      colors = pkl.load(colors_fh)
    draw_time = time.time()

    for x in output:
      draw_result(x, loaded_img_list, colors, self.classes)
    detect_names = ["{}/detect_{}".format(self.det, path.split("/")[-1]) for path in image_list]
    for name, img in zip(detect_names, loaded_img_list):
      cv2.imwrite(name, img)

    end = time.time()
    print("Results")
    print("*********************************************")
    print("{:25s}: {}".format("Task", "Time Taken (in seconds)"))
    print("{:25s}: {:2.3f}".format("Reading", load_batch_time - read_time))
    print("{:25s}: {:2.3f}".format("Loading batch", start_detect_loop_time - load_batch_time))
    print("{:25s}: {:2.3f}".format("Detection (" + str(len(image_list)) +  " images)", output_recast_time - start_detect_loop_time))
    print("{:25s}: {:2.3f}".format("Output processing", class_load_time - output_recast_time))
    print("{:25s}: {:2.3f}".format("Drawing boxes", end - draw_time))
    print("{:25s}: {:2.3f}".format("Average time per img", (end - load_batch_time) / len(image_list)))
    print("{:25s}: {}".format("Result Folder", self.det))

    torch.cuda.empty_cache()

# Run image detection with the parse_arg() defaults (overridable via
# command-line flags); prints per-image results and a timing table.
test = ImageDetect()
test.get_detections()

dog-cycle-car.png    predicted in  0.165 seconds
Objects Detected:    bicycle car truck dog
*********************************************
Results
*********************************************
Task                     : Time Taken (in seconds)
Reading                  : 0.000
Loading batch            : 0.025
Detection (1 images)     : 0.168
Output processing        : 0.000
Drawing boxes            : 0.018
Average time per img     : 0.210
Result Folder            : det


## Video detection

In [None]:
from __future__ import division

import argparse
import time

import cv2
import torch
from torch.autograd import Variable
from google.colab.patches import cv2_imshow

def parse_arg():
  """
  Parse command-line options for the video detection module.

  Unknown arguments (e.g. those injected by a notebook kernel) are ignored.
  Numeric options carry an explicit ``type`` so command-line values are
  converted. ``--source`` defaults to "video" so the ``--video`` file is read
  by default.

  Returns:
    argparse.Namespace with video_file, bs, confidence, nms, cfg_file,
    weights_file, dataset, colors_file and source.
  """

  parser = argparse.ArgumentParser(description="reYOLO Detection Module")
  parser.add_argument("--video", dest="video_file", default="/content/videoplayback.mp4", type=str, help="Video file to perform detection on")
  parser.add_argument("--bs", default=1, type=int, help="Batch size")
  parser.add_argument("--confidence", default=0.5, type=float, help="Object confidence to filter predictions")
  parser.add_argument("--nms", default=0.4, type=float, help="NMS Threshold")
  parser.add_argument("--cfg", dest="cfg_file", default="/content/yolov3.cfg", type=str, help="Config file path")
  parser.add_argument("--weights", dest="weights_file", default="/content/yolov3.weights", type=str, help="Weights file path")
  parser.add_argument("--dataset", default="/content/coco.names", type=str, help="Dataset file path")
  parser.add_argument("--colors", dest="colors_file", default="/content/pallete", type=str, help="Colors file path")
  parser.add_argument("--source", default="video", type=str, help="Video source ('video' to read the --video file)")

  args, _ = parser.parse_known_args()
  return args

class VideoDetect():
  """
  Run Darknet/YOLO detection frame by frame on a video file or webcam stream.

  Settings come from ``parse_arg()``. Each processed frame is annotated with
  ``draw_result`` and shown with ``cv2_imshow``; per-frame timing and the
  running FPS are printed.
  """

  def __init__(self):
    args = parse_arg()
    self.video_file = args.video_file
    # Options given on the command line arrive as strings unless parse_arg
    # declares a type, so convert explicitly (as ImageDetect does).
    self.batch_size = int(args.bs)
    self.confidence = float(args.confidence)
    self.nms = float(args.nms)
    self.cfg_file = args.cfg_file
    self.weights_file = args.weights_file
    self.classes = load_dataset(args.dataset)
    self.num_classes = len(self.classes)
    self.colors_file = args.colors_file
    self.CUDA = torch.cuda.is_available()
    self.source = args.source

  def load_network(self):
    """
    Build the Darknet model from ``cfg_file`` and load ``weights_file``.

    Sets ``self.model`` and ``self.input_dim`` (the cfg's network height),
    which must be a multiple of 32 and greater than 32.
    """
    self.model = Darknet(self.cfg_file)
    self.model.load_weight(self.weights_file)
    self.input_dim = int(self.model.net["height"])
    assert self.input_dim % 32 == 0
    assert self.input_dim > 32

  def get_detections(self):
    """
    Read frames until the source is exhausted (or 'q' is pressed) and show detections.

    ``source`` values "file" and "video" read ``video_file``; any other value
    opens webcam 0. The capture is always released on exit.
    """
    import pickle

    self.load_network()
    if self.CUDA:         # if cuda available
      self.model.cuda()

    self.model.eval()     # set model in evaluation mode

    # draw_result needs the colour palette as self.colors.
    # NOTE: pickle can execute arbitrary code; only load trusted palette files.
    with open(self.colors_file, "rb") as colors_fh:
      self.colors = pickle.load(colors_fh)

    # get video capture from source (file/webcam)
    if self.source in ("file", "video"):
      cap = cv2.VideoCapture(self.video_file)
    else:
      cap = cv2.VideoCapture(0)   # webcam
    assert cap.isOpened(), 'Cannot captutre video source'

    frames = 0
    start = time.time()
    try:
      while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
          break
        frame_start = time.time()

        image = pre_image(frame, self.input_dim)
        # (w, h, w, h) of the original frame, matching the box corner columns
        img_dim = torch.FloatTensor((frame.shape[1], frame.shape[0])).repeat(1, 2)

        if self.CUDA:
          img_dim = img_dim.cuda()
          image = image.cuda()

        with torch.no_grad():
          prediction = self.model(Variable(image))

        prediction = get_result(prediction, self.confidence, self.num_classes, nms_conf=self.nms)
        if isinstance(prediction, int):
          # get_result returned an int: no detections in this frame
          frames += 1
          print("FPS: {:5.4f}".format(frames / (time.time() - start)))
          # cv2.imshow("frame", frame)
          cv2_imshow(frame)
          if cv2.waitKey(1) & 0xFF == ord('q'):    # exit if press q
            break
          continue

        # Map box corners from network-input coordinates back to frame
        # pixels: remove the centring offset, divide by the scale, then clamp.
        img_dim = img_dim.repeat(prediction.size(0), 1)
        scale_factor = torch.min(self.input_dim / img_dim, 1)[0].view(-1, 1)
        prediction[:, [1, 3]] -= (self.input_dim - scale_factor * img_dim[:, 0].view(-1, 1)) / 2
        prediction[:, [2, 4]] -= (self.input_dim - scale_factor * img_dim[:, 1].view(-1, 1)) / 2
        prediction[:, 1: 5] /= scale_factor

        for i in range(prediction.shape[0]):
          prediction[i, [1, 3]] = torch.clamp(prediction[i, [1, 3]], 0.0, img_dim[i, 0])
          prediction[i, [2, 4]] = torch.clamp(prediction[i, [2, 4]], 0.0, img_dim[i, 1])

        # NOTE(review): ImageDetect passes a list of images to draw_result,
        # while a single frame is passed here; confirm draw_result handles this.
        for x in prediction:
          draw_result(x, frame, self.colors, self.classes)
        # cv2.imshow("frame", frame)
        cv2_imshow(frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
          break
        frames += 1
        # per-frame latency
        print("Predicted in {0:6.3f} seconds".format(time.time() - frame_start))
        print("FPS: {:5.2f}".format(frames / (time.time() - start)))
    finally:
      cap.release()

# Run video detection with the parse_arg() defaults (overridable via
# command-line flags); shows annotated frames and prints FPS.
test = VideoDetect()
test.get_detections()

# Training Module

In [None]:
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch
import random
import os
import warnings
import numpy as np
from PIL import Image
from PIL import ImageFile

# Let PIL load image files whose data is truncated instead of raising an
# error while decoding them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

class ImageDataset(Dataset):
  """
  Detection dataset backed by a text file listing one image path per line.

  The label file for each image is the image path with "images" replaced by
  "labels" and the .png/.jpg extension replaced by .txt; it is read as rows of
  5 numbers. A sample that cannot be loaded or transformed is reported and
  returned as ``None``; ``collate_fn`` drops such samples.

  Args:
    images_path: text file with one image path per line.
    image_size: initial square input size; also the centre of the
      multi-scale range.
    max_objects: stored but not used by this class.
    multiscale: if True, ``collate_fn`` picks a new size every tenth batch.
    transform: callable mapping ``(image, boxes)`` to ``(image, targets)``.
    quick: keep only the first 1000 images.
  """

  def __init__(self, images_path, image_size, max_objects=100, multiscale=True, transform=None, quick=False):
    with open(images_path, "r") as file:
      self.image_files = [name.rstrip() for name in file]

    if quick:
      # Slice before deriving label paths so both lists stay the same length.
      self.image_files = self.image_files[:1000]

    self.label_files = [
      path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt")
      for path in self.image_files
    ]

    self.image_size = image_size
    self.max_objects = max_objects
    self.multiscale = multiscale
    # multi-scale sizes range over image_size +/- 96 in steps of 32
    self.min_size = self.image_size - 3 * 32
    self.max_size = self.image_size + 3 * 32
    self.batch_count = 0
    self.transform = transform

  def __getitem__(self, index):
    """
    Return ``(image_path, image, targets)`` for ``index``, or ``None`` on failure.

    Without a transform, ``targets`` is the raw (N, 5) label array.
    """
    image_path = self.image_files[index % len(self.image_files)].rstrip()
    try:
      image = np.array(Image.open(image_path).convert('RGB'), dtype=np.uint8)
    except Exception:
      print(f"Cannot read image '{image_path}'.")
      return

    label_path = self.label_files[index % len(self.image_files)].rstrip()
    try:
      with warnings.catch_warnings():
        # np.loadtxt warns on empty label files; those mean "no boxes"
        warnings.simplefilter("ignore")
        boxes = np.loadtxt(label_path).reshape(-1, 5)
    except Exception:
      print(f"Cannot read label '{label_path}'.")
      return

    targets = boxes
    if self.transform:
      try:
        image, targets = self.transform((image, boxes))
      except Exception:
        print("Cannot apply transform.")
        return

    return image_path, image, targets

  def collate_fn(self, batch):
    """
    Merge ``(path, image, targets)`` samples into one batch.

    ``None`` samples are dropped; with multiscale on, every tenth batch picks
    a new random size from ``range(min_size, max_size + 1, 32)``. Images are
    resized to the current size and stacked, and column 0 of each sample's
    targets is set to its index within the batch before concatenation.
    """
    self.batch_count += 1

    # Drop invalid images
    batch = [data for data in batch if data is not None]

    paths, imgs, targets = list(zip(*batch))

    # Selects new image size every tenth batch
    if self.multiscale and self.batch_count % 10 == 0:
      self.image_size = random.choice(
          range(self.min_size, self.max_size + 1, 32))

    # Resize images to input shape
    imgs = torch.stack([resize(img, self.image_size) for img in imgs])

    # Add sample index to targets
    for i, boxes in enumerate(targets):
      boxes[:, 0] = i
    targets = torch.cat(targets, 0)

    return paths, imgs, targets

  def __len__(self):
    return len(self.image_files)

def resize(image, size):
  """
  Resize a single image tensor to ``size`` with nearest-neighbour sampling.

  A leading batch dimension is added for ``F.interpolate`` and removed again.
  """
  batched = image.unsqueeze(0)
  resized = F.interpolate(batched, size=size, mode="nearest")
  return resized.squeeze(0)

In [None]:
import imgaug.augmenters as iaa
import torch
import numpy as np
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage
import torchvision.transforms as transforms
from dataclasses import dataclass

def xywh2xyxy_np(x):
  """
  Convert boxes from centre/size ``(x, y, w, h)`` to corners ``(x1, y1, x2, y2)``.

  Works on the last axis of a NumPy array; any columns beyond the first
  four are zero in the result.
  """
  corners = np.zeros_like(x)
  half_w = x[..., 2] / 2
  half_h = x[..., 3] / 2
  corners[..., 0] = x[..., 0] - half_w
  corners[..., 2] = x[..., 0] + half_w
  corners[..., 1] = x[..., 1] - half_h
  corners[..., 3] = x[..., 1] + half_h
  return corners

class ImageAugmenter(object):
  """
  Apply an imgaug augmenter to an ``(image, boxes)`` pair.

  ``boxes`` is an (N, 5) array ``[label, x, y, w, h]`` (centre/size form, in
  the same units as the image). The boxes are converted to corner form,
  augmented together with the image, clipped to the image, and converted back
  to ``[label, x, y, w, h]``.

  Args:
    augmentations: an imgaug augmenter called as
      ``augmentations(image=..., bounding_boxes=...)``, or ``None`` (the
      default, replacing a shared mutable ``[]``) to skip augmentation.
  """

  def __init__(self, augmentations=None):
    self.augmentations = augmentations

  def __call__(self, data):
    image, boxes = data
    # Work on a copy, converted from xywh to xyxy
    boxes = np.array(boxes)
    boxes[:, 1:] = xywh2xyxy_np(boxes[:, 1:])

    bounding_boxes = BoundingBoxesOnImage(
      [BoundingBox(*box[1:], label=box[0]) for box in boxes], shape=image.shape)
    if self.augmentations is not None:
      image, bounding_boxes = self.augmentations(image=image, bounding_boxes=bounding_boxes)
    # Remove boxes fully outside the image and clip the rest to its border
    bounding_boxes = bounding_boxes.clip_out_of_image()

    # Back to [label, x_center, y_center, width, height]
    boxes = np.zeros((len(bounding_boxes), 5))
    for i, box in enumerate(bounding_boxes):
      boxes[i] = [
        box.label,
        (box.x1 + box.x2) / 2,
        (box.y1 + box.y2) / 2,
        box.x2 - box.x1,
        box.y2 - box.y1,
      ]

    return image, boxes

class RelativeLabels(object):
  """
  Divide box coordinates by the image size, in place.

  Columns 1 and 3 are scaled by the image width, columns 2 and 4 by the
  height; the same (mutated) box array is returned with the image.
  """

  def __call__(self, data):
    img, labels = data
    height, width, _channels = img.shape
    for columns, extent in (([1, 3], width), ([2, 4], height)):
      labels[:, columns] /= extent
    return img, labels

class AbsoluteLabels(object):
  """
  Multiply box coordinates by the image size, in place.

  Columns 1 and 3 are scaled by the image width, columns 2 and 4 by the
  height; the same (mutated) box array is returned with the image.
  """

  def __call__(self, data):
    img, labels = data
    height, width, _channels = img.shape
    for columns, extent in (([1, 3], width), ([2, 4], height)):
      labels[:, columns] *= extent
    return img, labels

class PadSquare(ImageAugmenter):
  """Pad each image (and shift its boxes) to a 1:1 aspect ratio, centred."""

  def __init__(self):
    pad_to_square = iaa.PadToAspectRatio(1.0, position="center-center")
    self.augmentations = iaa.Sequential([pad_to_square.to_deterministic()])

class ToTensor(object):
  """
  Convert an ``(image, boxes)`` pair into tensors.

  The image goes through torchvision's ``ToTensor``. The box array (5 columns)
  is copied into columns 1-5 of a new (N, 6) float tensor; column 0 stays zero
  and is filled with the sample index later by ``ImageDataset.collate_fn``.
  """

  def __call__(self, data):
    img, labels = data
    convert = transforms.ToTensor()
    img_tensor = convert(img)

    out_targets = torch.zeros((len(labels), 6))
    out_targets[:, 1:] = convert(labels)
    return img_tensor, out_targets

class DefaultAugmenter(ImageAugmenter):
  """Random sharpen, affine shift/scale, brightness, hue and horizontal flip."""

  def __init__(self):
    steps = [
      iaa.Sharpen((0.0, 0.1)),
      iaa.Affine(rotate=(-0, 0), translate_percent=(-0.1, 0.1), scale=(0.8, 1.5)),
      iaa.AddToBrightness((-60, 40)),
      iaa.AddToHue((-10, 10)),
      iaa.Fliplr(0.5),
    ]
    self.augmentations = iaa.Sequential(steps)
@dataclass
class Transform:
  """
  Transform pipelines for ImageDataset.

  ``train`` applies random augmentation (DefaultAugmenter) before padding.
  ``val`` has no random step, so validation loss is comparable across epochs.
  Both convert labels to pixels for the imgaug steps and back to relative
  coordinates before building tensors.
  """
  # Plain class attributes (no annotations), so @dataclass creates no fields;
  # use them as Transform.train / Transform.val.
  train = transforms.Compose([
    AbsoluteLabels(),
    DefaultAugmenter(),
    PadSquare(),
    RelativeLabels(),
    ToTensor(),
  ])

  val = transforms.Compose([
    AbsoluteLabels(),
    PadSquare(),
    RelativeLabels(),
    ToTensor(),
  ])

In [None]:
from __future__ import division

import os
import argparse
from torch.autograd import Variable
import torch
from torch.utils.data import DataLoader
from wcmatch.pathlib import Path
import pytorch_lightning as pl
from datetime import datetime
from loguru import logger
from pytorch_lightning.callbacks import ModelCheckpoint
import wandb
from pytorch_lightning.loggers import WandbLogger

def parse_arg():
  """
  Parse command-line options for the training module.

  Unknown arguments (e.g. those injected by a notebook kernel) are ignored.
  Threshold options carry ``type=float`` so command-line values are not left
  as strings.

  Returns:
    argparse.Namespace with cfg_file, dataset, train_path, valid_path, nms,
    iou, confidence, epochs, cpus, pretrained_weights, multiscale_train, seed.
  """
  parser = argparse.ArgumentParser(description="reYOLO Training Module")
  parser.add_argument("--cfg", dest="cfg_file", type=str, default="/content/yolov3-tiny.cfg", help="Config file path")
  parser.add_argument("--dataset", type=str, default="/content/coco.names", help="Dataset file path")
  parser.add_argument("--train_path", type=str, default="/content/data/trainvalno5k.txt")
  parser.add_argument("--valid_path", type=str, default="/content/data/5k.txt")
  parser.add_argument("--nms", type=float, default=0.4, help="NMS Threshold")
  parser.add_argument("--iou", type=float, default=0.5, help="IoU threshold")
  parser.add_argument("--confidence", type=float, default=0.5, help="Object confidence to filter predictions")
  parser.add_argument("--epochs", type=int, default=4, help="Number of epochs")
  parser.add_argument("--cpus", type=int, default=2, help="Number of cpu threads during batch generation")
  parser.add_argument("--pretrained_weights", default="/content/yolov3-tiny.weights", type=str, help="Checkpoint file path (.weights or .pt)")
  parser.add_argument("--multiscale_train", action="store_true", help="Allow multi-scale training")
  parser.add_argument("--seed", type=int, default=-1)
  args, _ = parser.parse_known_args()
  return args

def load_model(path, device, weights=None):
  """
  Build a Darknet model from cfg ``path`` on ``device`` and optionally load weights.

  Args:
    path: Darknet cfg file.
    device: torch device the model is moved to.
    weights: ``.pt``/``.pth`` files are loaded as a PyTorch state dict;
      any other path is passed to ``Darknet.load_weight``. ``None`` skips
      loading.

  Returns:
    The Darknet model.
  """
  model = Darknet(path).to(device)

  if weights:
    # --pretrained_weights advertises ".weights or .pt", so accept both
    # PyTorch checkpoint extensions here.
    if weights.endswith((".pt", ".pth")):
      model.load_state_dict(torch.load(weights, map_location=device))
    else:
      model.load_weight(weights)

  return model


class DataModule(pl.LightningDataModule):
  """
  Lightning data module around pre-built train/validation datasets.

  Both loaders use the dataset's own ``collate_fn``, pinned memory and
  ``cpus`` worker processes; only the training loader shuffles.
  """

  def __init__(self, train_ds, val_ds, batch_size, cpus):
    super().__init__()
    self.train_ds, self.val_ds = train_ds, val_ds
    self.batch_size = batch_size
    self.cpus = cpus

  def _make_loader(self, dataset, shuffle):
    """Build a DataLoader for ``dataset`` with this module's settings."""
    return DataLoader(
      dataset,
      batch_size=self.batch_size,
      shuffle=shuffle,
      num_workers=self.cpus,
      pin_memory=True,
      collate_fn=dataset.collate_fn,
    )

  def train_dataloader(self):
    return self._make_loader(self.train_ds, shuffle=True)

  def val_dataloader(self):
    return self._make_loader(self.val_ds, shuffle=False)


class Net(pl.LightningModule):
  """
  LightningModule wrapping a Darknet model for training and validation.

  The wrapped model is called as ``model(images, targets)`` and returns
  ``(loss, outputs)``. Optimizer settings are read from ``model.net``
  (presumably the cfg's ``[net]`` section).
  """

  def __init__(self, model, img_size, batch_size, args):
    super().__init__()
    self.model = model
    self.valid_path = args.valid_path
    self.img_size = img_size
    self.batch_size = batch_size
    self.nms = args.nms
    self.conf = args.confidence
    self.iou = args.iou

  def forward(self, x):
    return self.model(x)

  def training_step(self, batch, batch_index):
    """Compute and log the loss for one ``(paths, images, targets)`` batch."""
    images, targets = batch[1:]
    # Use the module's own device, not "cuda if available": the trainer may
    # run the model on CPU even when CUDA is present.
    loss, outputs = self.model(images.to(self.device), targets.to(self.device))
    self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss

  def get_progress_bar_dict(self):
    items = super().get_progress_bar_dict()
    items.pop("v_num", None)
    items.pop("loss", None)
    return items

  def validation_step(self, batch, batch_idx):
    """Compute and log the validation loss for one batch."""
    imgs, targets = batch[1:]
    loss, outputs = self.model(imgs.to(self.device), targets.to(self.device))
    self.log('val_loss',loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss

  def configure_optimizers(self):
    """
    Create Adam (default) or SGD from the model's ``net`` settings.

    Raises:
      ValueError: if ``net["optimizer"]`` is neither "adam" nor "sgd".
    """
    cfg = self.model.net
    optimizer_name = cfg["optimizer"] if "optimizer" in cfg else "adam"
    lr = float(cfg["learning_rate"])
    decay = float(cfg["decay"])

    if optimizer_name == "adam":
      return torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=decay)
    if optimizer_name == "sgd":
      # cfg values are converted explicitly, like learning_rate/decay
      return torch.optim.SGD(self.model.parameters(), lr=lr, weight_decay=decay, momentum=float(cfg["momentum"]))
    raise ValueError(f"Unsupported optimizer '{optimizer_name}' in model config")

if __name__ == "__main__":
  # Training entry point: set up experiment dir + logging, build datasets and
  # the Lightning module, then fit with W&B logging and checkpointing.
  level = "DEBUG"
  experiment_root = 'experiments'
  exp_root_path = Path(experiment_root)
  wandb.login()

  # One timestamp so the run directory and its log file match.
  run_stamp = datetime.now()
  experiment_dir = exp_root_path / "train" / f"exp_{run_stamp}"
  log_file = experiment_dir / f"log_{run_stamp}.log"
  # parents=True: experiments/train may not exist yet.
  experiment_dir.mkdir(parents=True, exist_ok=True)
  logger.opt(record=True).add(log_file, format=" {time:YYYY-MMM HH:mm:ss} {name}:{function}:{line} <lvl>{message}</>", level=level, rotation="5 MB")

  args = parse_arg()
  logger.opt(colors=True).info(args)
  save_code_files(experiment_dir, os.path.abspath(''))

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  darknet = Darknet(args.cfg_file)
  img_size = int(darknet.net["height"])
  batch_size = 6
  classes = load_dataset(args.dataset)

  train_ds = ImageDataset(images_path=args.train_path, multiscale=args.multiscale_train, image_size=img_size, transform=Transform.train)
  val_ds = ImageDataset(images_path=args.valid_path, multiscale=args.multiscale_train, image_size=img_size, transform=Transform.val)

  model = Net(darknet, img_size, batch_size, args)
  model = model.to(device)

  data_module = DataModule(train_ds, val_ds, batch_size, args.cpus)

  wandb_logger = WandbLogger(project="reYOLO", save_dir=experiment_dir, offline=False, name="tiny")
  checkpoint_callback = ModelCheckpoint(dirpath=experiment_dir, mode="min", monitor="val_loss")

  logger.opt(colors=True).info("Start training")
  logger.info(batch_size)
  trainer = pl.Trainer(
    logger=wandb_logger,
    max_epochs=args.epochs,                 # honour --epochs
    # NOTE(review): in PL 1.x batch-size scaling only runs via trainer.tune();
    # confirm for the installed version.
    auto_scale_batch_size='binsearch',
    num_sanity_val_steps=0,
    callbacks=[checkpoint_callback],
    weights_save_path="weights",
  )
  trainer.fit(model, data_module)