# Install

In [None]:
!pip install -q datasets
!pip install -q albumentations

# Env

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Conv2d, BatchNorm2d, Identity, LeakyReLU, Upsample

In [None]:
import random
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from einops import einsum, rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

In [None]:
# Plot styling for the notebook: seaborn's default theme, inline matplotlib
# figures with thicker lines, and SVG/PDF as the inline figure formats.
import seaborn as sns
sns.set()

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
matplotlib.rcParams['lines.linewidth']=2

from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf')

In [None]:
# Seed every RNG the notebook uses so runs are repeatable.
seed=42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# With deterministic algorithms enabled, cuBLAS (CUDA >= 10.2) raises at the
# first GEMM unless a fixed workspace config is set before any CUDA work.
import os
os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')

torch.backends.cudnn.deterministic=True
# cuDNN autotuning picks kernels non-deterministically; the attribute is
# `benchmark` (the former "benmark" spelling silently did nothing).
torch.backends.cudnn.benchmark=False
torch.use_deterministic_algorithms(True)

device='cuda' if torch.cuda.is_available() else 'cpu'
device

In [None]:
from datasets import load_dataset, load_from_disk
from datasets.features import Image as ImageFeature

import albumentations as A

# Global

In [None]:
# Anchor priors as (width, height) fractions of the network input size: the
# Collator divides pixel sizes by image_size before comparing with them.
# One row of three anchors per output scale, in the order of the grid sizes
# [7, 14, 28]; the coarsest grid gets the largest anchors.
anchors=[[[0.28, 0.22], [0.38, 0.48], [0.9, 0.78]],
         [[0.07, 0.15], [0.15, 0.11], [0.14, 0.29]],
         [[0.02, 0.03], [0.04, 0.07], [0.08, 0.06]]]

# Ingestor

In [None]:
import os
import json
from datasets import concatenate_datasets, DatasetDict

class Ingestor:
    """Download the pet dataset, re-split it and cache it (plus label tables) on disk."""

    def __init__(self, save_dir='data', dataset_name='visual-layer/oxford-iiit-pet-vl-enriched'):
        self.save_dir = save_dir
        self.dataset_name = dataset_name

    def create_dataset(self, train_size=0.9):
        """Pool the upstream train/test splits and re-split them.

        ``train_size`` of the pool becomes the training split; the remainder is
        halved into validation and test. The result is stored on
        ``self.dataset``, saved to ``save_dir`` and returned.
        """
        raw = (load_dataset(self.dataset_name)
               .select_columns(['image', 'label_bbox_enriched'])
               .cast_column('image', ImageFeature(mode='RGB')))
        pooled = concatenate_datasets([raw['train'], raw['test']])

        first_split = pooled.train_test_split(train_size=train_size)
        holdout = first_split['test'].train_test_split(train_size=0.5)

        splits = DatasetDict()
        splits['train'] = first_split['train']
        splits['validation'] = holdout['train']
        splits['test'] = holdout['test']

        self.dataset = splits.map(self.preprocess, batched=True, batch_size=1024, remove_columns=['label_bbox_enriched'])
        self.dataset.save_to_disk(self.save_dir)
        return self.dataset

    @staticmethod
    def preprocess(batch):
        """Rename the box column to ``label_bbox`` and replace missing annotations with []."""
        return {'label_bbox': [[] if boxes is None else boxes for boxes in batch['label_bbox_enriched']]}

    def create_label_mapping(self):
        """Write label_count.json, label_to_id.json and id_to_label.json into ``save_dir``.

        Labels are counted over all three splits of ``self.dataset`` (so
        ``create_dataset`` must have run) and ranked by descending frequency;
        ids follow that ranking.
        """
        counts = {}
        for split in ('train', 'validation', 'test'):
            for sample in self.dataset[split]['label_bbox']:
                for box in sample:
                    counts[box['label']] = counts.get(box['label'], 0) + 1

        # sorted() is stable, so equally frequent labels keep first-seen order.
        ranked = dict(sorted(counts.items(), key=lambda item: item[1], reverse=True))
        names = list(ranked)

        outputs = {
            'label_count.json': ranked,
            'label_to_id.json': {name: idx for idx, name in enumerate(names)},
            'id_to_label.json': dict(enumerate(names)),
        }
        for filename, payload in outputs.items():
            with open(self.save_dir + '/' + filename, 'w') as fh:
                json.dump(payload, fh, indent=4)


# Build the on-disk dataset and label tables once; later runs reuse ./data.
if not os.path.exists('data'):
    ingestor = Ingestor()
    ingestor.create_dataset()
    ingestor.create_label_mapping()

In [None]:
# Show how many boxes each label has across all splits.
with open('data/label_count.json', 'r') as fh:
    labels = json.load(fh)

labels

# Data module

In [None]:
import cv2
from torch.utils.data import DataLoader
from albumentations.pytorch import ToTensorV2

class Preprocessor:
    """Albumentations pipelines that bring images to 224x224 and keep COCO boxes in sync.

    ``test_transform``: resize longest side to 224, pad to 224x224 with black,
    scale pixels to [0, 1], convert to a tensor.
    ``train_transform``: same, plus colour jitter and horizontal flip.
    Boxes with less than 40% visibility after augmentation are dropped.
    """

    def __init__(self):
        def resize_and_pad():
            return [
                A.LongestMaxSize(max_size=224),
                A.PadIfNeeded(min_height=224,min_width=224, border_mode=cv2.BORDER_CONSTANT, value=0),
            ]

        def to_tensor():
            return [
                A.Normalize(mean=[0,0,0], std=[1,1,1], max_pixel_value=255),
                ToTensorV2(),
            ]

        def box_params():
            return A.BboxParams(format='coco', min_visibility=0.4, label_fields=['labels'])

        augment = [
            A.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5, p=0.5),
            A.HorizontalFlip(p=0.5),
        ]

        self.test_transform = A.Compose(resize_and_pad() + to_tensor(), bbox_params=box_params())
        self.train_transform = A.Compose(resize_and_pad() + augment + to_tensor(), bbox_params=box_params())

class Collator:
    """Turn dataset samples into an image batch plus one YOLO target tensor per scale.

    Each target has shape (batch, anchors_per_scale, grid, grid, 6) with columns
    [objectness, x_offset_in_cell, y_offset_in_cell, log_w, log_h, class_id],
    where exp(log_w) * anchor_w is the box width as a fraction of image_size.
    Every box is written to all scales, using the best-fitting anchor of each.
    """

    def __init__(self, label_to_id):
        self.anchors=np.array(anchors)
        self.image_size=224
        self.num_scale=3
        self.num_anchors_per_scale=3
        self.grid_sizes=[self.image_size//32, self.image_size//16, self.image_size//8]
        self.num_classes=208
        self.preprocessor=Preprocessor()
        self.transform_fn=self.preprocessor.train_transform
        self.label_to_id=label_to_id

    def set_mode(self, split):
        """Use augmenting transforms for 'train' and deterministic ones for any other split."""
        if split=='train':
            self.transform_fn=self.preprocessor.train_transform
        else:
            self.transform_fn=self.preprocessor.test_transform

    def __call__(self, batch):
        images=[]
        target=[np.zeros([len(batch), self.num_anchors_per_scale, grid_size, grid_size, 6], dtype='float32') for grid_size in self.grid_sizes]
        for batch_idx, sample in enumerate(batch):
            bboxes = [i['bbox'] for i in sample['label_bbox']]
            labels = [i['label'] for i in sample['label_bbox']]
            processed = self.transform_fn(image=sample['image'], bboxes=bboxes, labels=labels)
            images.append(processed['image'])

            for box, label in zip(processed['bboxes'], processed['labels']):
                x_center, y_center, width, height = self.coco_to_yolo_format(box)
                # Anchors are (w, h) fractions of the image, so compare the box
                # shape in the same unit (the centre plays no role in anchor fit).
                relative_wh = (width / self.image_size, height / self.image_size)
                class_id = self.label_to_id[label]

                for scale_idx, grid_size in enumerate(self.grid_sizes):
                    cell_size = self.image_size/grid_size
                    # Clamp so a centre lying exactly on the far edge stays in the grid.
                    i = min(int(y_center/cell_size), grid_size-1)
                    j = min(int(x_center/cell_size), grid_size-1)
                    anchor_idx = int(np.argmax(self.anchor_iou(self.anchors[scale_idx], relative_wh)))
                    anchor_w, anchor_h = self.anchors[scale_idx][anchor_idx]

                    x = (x_center-j*cell_size)/cell_size
                    y = (y_center-i*cell_size)/cell_size

                    # real_w = exp(pred_w)*anchor_w*image_size
                    # real_h = exp(pred_h)*anchor_h*image_size
                    w = np.log(width/anchor_w/self.image_size)
                    h = np.log(height/anchor_h/self.image_size)

                    target[scale_idx][batch_idx, anchor_idx, i, j] = (1, x, y, w, h, class_id)

        images=torch.stack(images, dim=0)
        target=[torch.tensor(i) for i in target]
        return images, target

    def coco_to_yolo_format(self,box):
        """Convert a COCO (x_min, y_min, w, h) box to (x_center, y_center, w, h)."""
        x,y,w, h = box

        x_center = x+w/2
        y_center = y+h/2

        return x_center,y_center,w,h

    @staticmethod
    def anchor_iou(anchors, box):
        """Shape-only IoU between each anchor (..., 2) and ``box`` = (w, h).

        Both boxes are treated as sharing one corner, so only widths/heights matter;
        ``box`` must be in the same unit as the anchors.
        """
        anchors_area = anchors[..., 0]*anchors[..., 1]
        box_area = box[0]*box[1]
        intersection = np.minimum(anchors[..., 0], box[0]) * np.minimum(anchors[..., 1], box[1])
        iou = intersection/(anchors_area+box_area-intersection)
        return iou

class DataModule:
    """Load the cached dataset and label tables and build split-specific DataLoaders."""

    def __init__(self, data_dir='data'):
        self.dataset=load_from_disk(data_dir).with_format("numpy")
        with open(data_dir+'/label_to_id.json') as file:
            self.label_to_id = json.load(file)
        with open(data_dir+'/id_to_label.json') as file:
            self.id_to_label = json.load(file)

        self.collator=Collator(self.label_to_id)

    def get_data_loader(self,split, batch_size=4, shuffle=False):
        """Return a DataLoader over ``split`` (incomplete final batch dropped).

        Each loader gets its own shallow copy of the collator. Collation runs lazily
        while iterating, so a single shared collator would use whichever mode was
        set last; e.g. building a validation loader after the training loader
        used to turn off training augmentation.
        """
        import copy

        collator = copy.copy(self.collator)
        collator.set_mode(split)
        return DataLoader(self.dataset[split], batch_size=batch_size, shuffle=shuffle, drop_last=True, collate_fn=collator)

# Smoke test: build the data module, collate one test batch and print the image
# batch shape followed by one target tensor shape per scale.
# `tmp` and `test_batch` are reused by later cells.
tmp=DataModule('data')
loader=tmp.get_data_loader('test')
test_batch=next(iter(loader))
print(test_batch[0].shape)
for x in test_batch[1]:
    print(x.shape)

# Plot

In [None]:
from PIL import Image, ImageDraw, ImageFont

def draw_boxes(image, annotations):
    """Draw COCO-format ground-truth boxes and their labels onto ``image`` in place.

    Side effects: saves the untouched image as ``input.png`` before drawing and
    the annotated result as ``target.png``. Each label gets a random colour.

    Args:
        image: PIL image to draw on.
        annotations: iterable of dicts with ``'bbox'`` (x_min, y_min, w, h) and ``'label'``.

    Returns:
        The same (now annotated) image object.
    """
    image.save('input.png')
    canvas = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    palette = {}
    for ann in annotations:
        left, top, width, height = ann['bbox']
        rect = [left, top, left + width, top + height]
        name = ann['label']

        if name not in palette:
            palette[name] = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        color = palette[name]
        canvas.rectangle(rect, outline=color, width=2)

        # Filled label background just above the box's top-left corner.
        anchor_xy = (rect[0], rect[1] - 10)
        l, t, r, b = font.getbbox(name)
        canvas.rectangle([anchor_xy, (anchor_xy[0] + (r - l), anchor_xy[1] + (b - t))], fill=color)
        canvas.text(anchor_xy, name, fill=(255, 255, 255), font=font)

    image.save('target.png')
    return image

# Draw the ground truth of the first validation sample; draw_boxes also writes
# input.png (used by the inference cell at the end) and target.png.
data=tmp.dataset['validation'][0]
result_image = draw_boxes(Image.fromarray(data['image']), data['label_bbox'])
result_image

# Architecture

In [None]:
class CNNBlock(nn.Module):
    """Conv2d followed by BatchNorm and LeakyReLU(0.1) when ``use_bn`` is set.

    With ``use_bn=False`` the block is a plain biased convolution with no
    normalisation and no activation.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, use_bn=True):
        super().__init__()
        # The conv bias is redundant when BatchNorm follows (BN has its own shift).
        conv = Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=not use_bn)
        if use_bn:
            norm, act = BatchNorm2d(out_channels), LeakyReLU(negative_slope=0.1)
        else:
            norm, act = Identity(), Identity()
        self.net = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.net(x)

# Shape check: a 3x3 conv with the default padding=0 shrinks 224 -> 222.
tmp=CNNBlock(3,6)
x=torch.rand(7,3,224,224)
tmp(x).shape

In [None]:
class ResidualBlock(nn.Module):
    """Stack of Darknet-53 residual units (1x1 bottleneck then 3x3 conv).

    Each unit maps C -> C/2 -> C channels with BatchNorm + LeakyReLU(0.1) after
    every conv, so spatial size and channel count are preserved. When
    ``use_residual`` is set the unit's output is added to its input.

    Args:
        in_channels: channels C of the input (and output).
        use_residual: add skip connections around each unit.
        num_repeats: number of units; Darknet also uses it to pick route outputs.
    """

    def __init__(self, in_channels, use_residual=True, num_repeats=1):
        super().__init__()
        hidden = in_channels//2
        layers=[]
        for _ in range(num_repeats):
            block=nn.Sequential(
                Conv2d(in_channels, hidden, kernel_size=1, bias=False),
                BatchNorm2d(hidden),
                LeakyReLU(negative_slope=0.1),
                # 3x3 (padding 1) as in the Darknet-53 residual unit; a second 1x1
                # conv would give the block no spatial context at all.
                Conv2d(hidden, in_channels, kernel_size=3, padding=1, bias=False),
                BatchNorm2d(in_channels),
                LeakyReLU(negative_slope=0.1)
            )
            layers.append(block)

        self.net=nn.ModuleList(layers)
        self.use_residual=use_residual
        self.num_repeats=num_repeats

    def forward(self, x):
        for layer in self.net:
            x = x + layer(x) if self.use_residual else layer(x)

        return x

# Shape check: residual blocks keep the input shape (7, 3, 224, 224).
tmp=ResidualBlock(3, num_repeats=3)
x=torch.rand(7,3,224,224)
tmp(x).shape

In [None]:
class ScalePrediction(nn.Module):
    """Detection head for one output scale.

    A 3x3 conv doubles the channels, then a 1x1 conv emits
    ``num_anchors_per_scale * (num_classes + 5)`` channels. These are split into an
    anchor axis and moved last, giving (batch, anchors, H, W, num_classes + 5).
    """

    def __init__(self, in_channels, num_classes, num_anchors_per_scale):
        super().__init__()
        hidden = 2*in_channels
        per_anchor = num_classes+5
        self.net = nn.Sequential(
            Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            BatchNorm2d(hidden),
            LeakyReLU(negative_slope=0.1),
            Conv2d(hidden, per_anchor*num_anchors_per_scale, kernel_size=1),
            # Axis names are labels only: the two spatial axes are (H, W) in order.
            Rearrange('b (a c) w h -> b a w h c', c=per_anchor),
        )

    def forward(self, x):
        return self.net(x)

# Shape check: expect (7, 3, 224, 224, 213) = (batch, anchors, H, W, num_classes + 5).
tmp=ScalePrediction(3,208,3)
x=torch.rand(7,3,224,224)
tmp(x).shape

In [None]:
class Darknet(nn.Module):
    """YOLOv3-style backbone + three-scale head.

    Returns a list of three ScalePrediction outputs, coarsest first (input/32,
    input/16, input/8 grids). Outputs of the ResidualBlocks with
    ``num_repeats == 2`` are kept as route connections and concatenated (LIFO)
    after each upsample.

    Args:
        in_channels: channels of the input image (3 for RGB).
        num_classes: number of object classes per box.
        num_anchors_per_scale: anchors predicted at every grid cell.
    """

    def __init__(self, in_channels, num_classes, num_anchors_per_scale=3):
        super().__init__()
        self.net=nn.ModuleList([
            # Honour the caller's channel count (previously hard-coded to 3).
            CNNBlock(in_channels=in_channels, out_channels=32, kernel_size=3, stride=1, padding=1),
            CNNBlock(in_channels=32, out_channels=64,kernel_size=3, stride=2, padding=1),
            ResidualBlock(in_channels=64, num_repeats=1),

            CNNBlock(in_channels=64, out_channels=128,kernel_size=3, stride=2, padding=1),
            ResidualBlock(in_channels=128, num_repeats=1),

            CNNBlock(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
            ResidualBlock(in_channels=256, num_repeats=2),  # route for the finest scale

            CNNBlock(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1),
            ResidualBlock(in_channels=512, num_repeats=2),  # route for the middle scale

            CNNBlock(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1),
            ResidualBlock(in_channels=1024, num_repeats=1),

            CNNBlock(in_channels=1024, out_channels=512, kernel_size=1),
            CNNBlock(in_channels=512, out_channels=1024, kernel_size=3, padding=1),
            ResidualBlock(in_channels=1024, use_residual=False, num_repeats=1),

            CNNBlock(in_channels=1024, out_channels=512, kernel_size=1),
            ScalePrediction(in_channels=512, num_classes=num_classes, num_anchors_per_scale=num_anchors_per_scale),

            CNNBlock(in_channels=512, out_channels=256, kernel_size=1),
            nn.Upsample(scale_factor=2),

            CNNBlock(in_channels=256+512, out_channels=256, kernel_size=1),
            CNNBlock(in_channels=256, out_channels=512, kernel_size=3, padding=1),
            ResidualBlock(in_channels=512, use_residual=False, num_repeats=1),

            CNNBlock(in_channels=512, out_channels=256, kernel_size=1),
            ScalePrediction(in_channels=256, num_classes=num_classes, num_anchors_per_scale=num_anchors_per_scale),

            CNNBlock(in_channels=256, out_channels=128, kernel_size=1),
            nn.Upsample(scale_factor=2),

            CNNBlock(in_channels=128+256, out_channels=128, kernel_size=1),
            CNNBlock(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            ResidualBlock(in_channels=256, use_residual=False, num_repeats=1),

            CNNBlock(in_channels=256, out_channels=128, kernel_size=1),
            ScalePrediction(in_channels=128, num_classes=num_classes, num_anchors_per_scale=num_anchors_per_scale),
        ])

    def forward(self, x):
        outputs=[]
        route_connections=[]
        for layer in self.net:
            if isinstance(layer, ScalePrediction):
                # Heads branch off without changing the trunk's activation.
                outputs.append(layer(x))
                continue
            x=layer(x)
            # NOTE(review): routes are identified by num_repeats == 2, which is
            # fragile if the residual depths above are ever changed.
            if isinstance(layer, ResidualBlock) and layer.num_repeats==2:
                route_connections.append(x)
            elif isinstance(layer, nn.Upsample):
                x=torch.cat([x, route_connections.pop()], dim=1)

        return outputs

# Shape check: one output per scale, 7x7, 14x14 and 28x28 grids for a 224 input.
tmp=Darknet(3,208)
x=torch.rand(7,3,224,224)
out=tmp(x)
for i in out:
    print(i.shape)

In [None]:
class YoloMimic(nn.Module):
    """Minimal stand-in detector: one strided conv per output scale.

    Produces three tensors shaped (batch, 3 anchors, H, W, num_classes + 5) with
    strides 32, 16 and 8 (7x7, 14x14 and 28x28 grids for a 224 input). The channel
    layout is (anchor, feature) with the anchor outermost, as in ScalePrediction.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes

        self.net_1 = Conv2d(in_channels, (num_classes+5)*3, kernel_size=3, stride=32, padding=1)
        self.net_2 = Conv2d(in_channels, (num_classes+5)*3, kernel_size=3, stride=16, padding=1)
        self.net_3 = Conv2d(in_channels, (num_classes+5)*3, kernel_size=3, stride=8, padding=1)

    def _to_grid(self, out: torch.Tensor) -> torch.Tensor:
        """(B, 3*(C+5), H, W) -> (B, 3, H, W, C+5), keeping each cell's features together.

        A plain reshape would reinterpret NCHW memory and mix values from
        different cells/channels, so split the channel axis and permute instead.
        The grid size is read from the conv output rather than assumed.
        """
        return out.view(out.size(0), 3, self.num_classes + 5, out.size(2), out.size(3)).permute(0, 1, 3, 4, 2).contiguous()

    def forward(self, x):
        x1 = self._to_grid(self.net_1(x))
        x2 = self._to_grid(self.net_2(x))
        x3 = self._to_grid(self.net_3(x))

        return x1, x2, x3

# Shape check for the stand-in model; `out` is decoded by the Postprocess cell below.
tmp=YoloMimic(3,208)
x=torch.rand(7,3,224,224)
out=tmp(x)
for i in out:
    print(i.shape)

# Postprocess

In [None]:
def format_prediction(predictions, anchors, grid_sizes):
    """Decode raw head outputs into one box tensor per (scale, anchor).

    Args:
        predictions: per-scale tensors shaped (batch, num_anchors, grid, grid, num_classes + 5)
            holding raw logits [obj, x, y, w, h, class logits...].
        anchors: indexable as anchors[scale][anchor] -> (w, h) prior.
        grid_sizes: grid size of each scale.

    Returns:
        List of len(scales) * num_anchors tensors shaped (batch, grid, grid, 7) with
        columns [objectness, x, y, w, h, class_id, class_prob]. x/y are in
        grid-cell units (column/row index + sigmoid offset); w/h are
        exp(raw) * anchor, i.e. fractions of the network input.
        NOTE(review): x/y and w/h therefore use different units, which matters for
        any IoU computed on these boxes.
    """
    refined=[]
    for scale_idx, scale in enumerate(predictions):
        grid_size = grid_sizes[scale_idx]
        # Cell column/row indices, built on the prediction's device so CUDA
        # outputs do not hit a CPU/CUDA mismatch.
        cols = torch.arange(grid_size, device=scale.device).view(1, -1).expand(grid_size, -1).unsqueeze(-1)
        rows = torch.arange(grid_size, device=scale.device).view(-1, 1).expand(-1, grid_size).unsqueeze(-1)
        for idx, anchor in enumerate(anchors[scale_idx]):
            pred=scale[:, idx]
            anchor = torch.as_tensor(anchor, dtype=pred.dtype, device=pred.device)

            obj=torch.sigmoid(pred[..., 0:1])
            x=torch.sigmoid(pred[..., 1:2])+cols
            y=torch.sigmoid(pred[..., 2:3])+rows
            wh=torch.exp(pred[..., 3:5])*anchor
            prob, cls=torch.max(pred[..., 5:].softmax(dim=-1), dim=-1, keepdim=True)

            refined.append(torch.cat([obj, x, y, wh, cls.to(pred.dtype), prob], dim=-1))

    return refined

# Decode the YoloMimic outputs from above: 9 tensors (3 scales x 3 anchors).
res=format_prediction(out, torch.tensor(anchors), [7,14,28])
for i in res:
    print(i.shape)

In [None]:
def yolo_iou(box1, box2):
    """IoU of centre-format boxes (x_center, y_center, w, h) in the last axis.

    Inputs broadcast against each other. If either input is a torch tensor the
    computation uses torch ops, so CUDA tensors work and the result is a tensor
    on that device. NumPy ufuncs would need a host copy, which fails for CUDA or
    grad-tracking tensors. Otherwise the inputs are treated as NumPy arrays.
    The union is floored at 1e-9 so degenerate boxes give IoU 0 instead of NaN.
    """
    if torch.is_tensor(box1) or torch.is_tensor(box2):
        ref = box1 if torch.is_tensor(box1) else box2
        box1 = torch.as_tensor(box1, dtype=ref.dtype, device=ref.device)
        box2 = torch.as_tensor(box2, dtype=ref.dtype, device=ref.device)
        maximum, minimum = torch.maximum, torch.minimum

        def floor_at(values, low):
            return values.clamp(min=low)
    else:
        box1 = np.asarray(box1)
        box2 = np.asarray(box2)
        maximum, minimum = np.maximum, np.minimum

        def floor_at(values, low):
            return np.maximum(values, low)

    x_min=maximum(box1[..., 0]-box1[..., 2]/2, box2[..., 0]-box2[..., 2]/2)
    x_max=minimum(box1[..., 0]+box1[..., 2]/2, box2[..., 0]+box2[..., 2]/2)

    y_min=maximum(box1[..., 1]-box1[..., 3]/2, box2[..., 1]-box2[..., 3]/2)
    y_max=minimum(box1[..., 1]+box1[..., 3]/2, box2[..., 1]+box2[..., 3]/2)

    intersection = floor_at(x_max-x_min, 0)*floor_at(y_max-y_min, 0)
    union = box1[..., 2]*box1[..., 3]+box2[..., 2]*box2[..., 3] - intersection

    return intersection / floor_at(union, 1e-9)

def non_max_suppression(predictions, obj_threshold=0.5, iou_threshold=0.5):
    """Class-agnostic greedy NMS over decoded predictions.

    Args:
        predictions: list of (batch, grid, grid, 7) tensors as returned by
            format_prediction; column 0 (objectness) is the ranking score and
            columns 1:5 are the box.
        obj_threshold: boxes with objectness <= this are discarded first.
        iou_threshold: a remaining box is suppressed when its IoU with an already
            selected box is >= this.

    Returns:
        One list per image of the selected box rows, highest objectness first
        (empty list if nothing passes the threshold).
    """
    batch_size = predictions[0].shape[0]
    candidates=[[] for _ in range(batch_size)]
    for pred in predictions:
        for idx, sample in enumerate(pred):
            candidates[idx].extend(sample[sample[..., 0]>obj_threshold])

    result=[]
    for rows in candidates:
        if not rows:
            result.append([])
            continue

        boxes = torch.stack(rows, dim=0)
        boxes = boxes[torch.argsort(boxes[..., 0], descending=True)]

        selected=[]
        while len(boxes)>0:
            best, rest = boxes[0], boxes[1:]
            selected.append(best)
            if len(rest)==0:
                break
            # Drop the selected box explicitly: relying on its self-IoU being >=
            # the threshold loops forever for zero-area boxes (IoU 0 with itself).
            ious=torch.as_tensor(yolo_iou(best[1:5], rest[..., 1:5]), device=rest.device)
            boxes = rest[ious < iou_threshold]

        result.append(selected)

    return result

# Run NMS (default thresholds) on the decoded predictions and print how many
# boxes survive per image.
x=non_max_suppression([i.detach() for i in res])
for i in x:
    print(len(i))

In [None]:
def get_true_boxes(targets, anchors, grid_sizes):
    """Decode Collator targets into per-image lists of ground-truth boxes.

    Args:
        targets: per-scale tensors (batch, num_anchors, grid, grid, 6) with columns
            [obj, x_offset, y_offset, log_w, log_h, class_id].
        anchors: indexable as anchors[scale][anchor] -> (w, h) prior.
        grid_sizes: grid size of each scale.

    Returns:
        A list with one entry per image, each a list of rows
        [obj, x, y, w, h, class_id]. x/y are in grid-cell units (cell index +
        offset) and w/h = exp(log) * anchor, matching format_prediction's layout.
    """
    refined=[]
    for scale_idx, scale in enumerate(targets):
        grid_size = grid_sizes[scale_idx]
        # Built on the targets' device so CUDA batches do not hit a device mismatch.
        cols = torch.arange(grid_size, device=scale.device).view(1, -1).expand(grid_size, -1).unsqueeze(-1)
        rows = torch.arange(grid_size, device=scale.device).view(-1, 1).expand(-1, grid_size).unsqueeze(-1)
        for idx, anchor in enumerate(anchors[scale_idx]):
            target=scale[:, idx]
            anchor = torch.as_tensor(anchor, dtype=target.dtype, device=target.device)

            obj=target[..., 0:1]
            x=target[..., 1:2]+cols
            y=target[..., 2:3]+rows
            wh=torch.exp(target[..., 3:5])*anchor
            cls=target[..., 5:]

            refined.append(torch.cat([obj, x, y, wh, cls],dim=-1))

    result=[[] for _ in range(targets[0].shape[0])]
    for target in refined:
        for idx, sample in enumerate(target):
            result[idx].extend(sample[sample[..., 0]==1])

    return result

# Decode the targets of the smoke-test batch into per-image ground-truth box lists.
get_true_boxes(test_batch[1], torch.tensor(anchors), [7,14,28])

# MAP and MAR

In [None]:
def mimic_output(batch_size, grid_sizes, num_classes):
    """Fake head outputs: uniform [0, 1) tensors shaped (batch, 3, g, g, num_classes + 5).

    One tensor per grid size g. The values are drawn per field (obj, x, y, w, h,
    then the class block), so the RNG stream is consumed in that order.
    """
    def uniform(grid_size, channels):
        return torch.rand(batch_size, 3, grid_size, grid_size, channels)

    return [
        torch.cat([uniform(g, 1) for _ in range(5)] + [uniform(g, num_classes)], dim=-1)
        for g in grid_sizes
    ]

# Random stand-in outputs used to exercise MAPR below.
mimic=mimic_output(4,[7,14,28],208)
for i in mimic:
    print(i.shape)

In [None]:
class MAPR:
    """Accumulate detections over batches and report mean average precision.

    Internal state:
        TP_FP[c]: one ``[confidence, best_iou, first_match]`` entry per predicted
            box of class c. ``first_match`` is True when the box was the first
            (highest-confidence) prediction to claim its best same-class ground truth.
        TP_FN[c]: number of ground-truth boxes of class c.

    At an IoU threshold t a prediction is a TP iff ``first_match`` and
    ``best_iou >= t``; every other prediction is an FP. AP per class is the area
    under the precision/recall curve with predictions ranked by confidence;
    mAP averages over classes that have ground truth.
    """

    def __init__(self, anchors, grid_sizes=None, eps=1e-6):
        self.eps=eps
        self.num_classes=208
        self.anchors=torch.tensor(anchors)
        self.grid_sizes=grid_sizes

        self.iou_thresholds=[0.0, 0.075]

        self.TP_FP=[[] for _ in range(self.num_classes)]
        self.TP_FN=np.zeros([self.num_classes])

    def box_count(self, pred_boxes,true_boxes):
        """Match one batch of detections to ground truth and record the outcome.

        Args:
            pred_boxes: per image, a list of rows [obj, x, y, w, h, class_id, class_prob].
            true_boxes: per image, a list of rows [obj, x, y, w, h, class_id].
        """
        for instance_pred_boxes, instance_true_boxes in zip(pred_boxes, true_boxes):
            preds_list = list(instance_pred_boxes)
            trues_list = list(instance_true_boxes)

            # Every ground-truth box counts towards recall; class id is column 5.
            for true_box in trues_list:
                self.TP_FN[int(true_box[5])]+=1

            # Nothing detected: the ground truth above simply stays unmatched.
            if not preds_list:
                continue

            preds = torch.stack(preds_list, dim=0)
            preds = preds[torch.argsort(preds[:, 0], descending=True)]

            # Detections on an image without ground truth are all false positives.
            if not trues_list:
                for pred_box in preds:
                    self.TP_FP[int(pred_box[5])].append([float(pred_box[0]), -1.0, False])
                continue

            trues = torch.stack(trues_list, dim=0).to(preds.device)
            # Columns 1:5 hold (x, y, w, h); column 0 is objectness.
            ious = yolo_iou(preds[:, None, 1:5], trues[None, :, 1:5])
            ious = torch.as_tensor(ious, dtype=preds.dtype, device=preds.device)
            # Only same-class ground truth can be matched.
            same_class = preds[:, None, 5] == trues[None, :, 5]
            ious = torch.where(same_class, ious, torch.full_like(ious, -1.0))
            best_scores, best_indices = torch.max(ious, dim=1)

            matched=[False]*len(trues_list)
            for pred_box, best_score, best_idx in zip(preds, best_scores.tolist(), best_indices.tolist()):
                first_match = best_score >= 0 and not matched[best_idx]
                if first_match:
                    matched[best_idx]=True
                self.TP_FP[int(pred_box[5])].append([float(pred_box[0]), best_score, first_match])

    def reset(self):
        """Clear all accumulated counts."""
        self.TP_FP=[[] for _ in range(self.num_classes)]
        self.TP_FN=np.zeros([self.num_classes])

    def update(self, preds, targets):
        """Decode raw model outputs and targets for one batch and accumulate them."""
        true_boxes=get_true_boxes(targets, self.anchors, self.grid_sizes)
        preds=format_prediction(preds, self.anchors, self.grid_sizes)
        pred_boxes=non_max_suppression(preds, iou_threshold=0.05, obj_threshold=0.05)

        self.box_count(pred_boxes,true_boxes)

    def compute(self):
        """Return ``{'mAP@<threshold>': value}`` for every threshold in ``iou_thresholds``.

        Classes without ground truth are excluded from the mean; if no class has
        any, the result is 0.0.
        """
        classes_with_truth = np.flatnonzero(self.TP_FN > 0)
        res={}
        for threshold in self.iou_thresholds:
            ap_per_class=[]
            for cls_idx in classes_with_truth:
                entries=sorted(self.TP_FP[cls_idx], key=lambda entry: entry[0], reverse=True)
                TP=torch.tensor([1.0 if first and iou>=threshold else 0.0 for _, iou, first in entries])
                FP=1.0-TP
                TP_cumsum=torch.cumsum(TP, dim=0)
                FP_cumsum=torch.cumsum(FP, dim=0)

                recall=TP_cumsum/(float(self.TP_FN[cls_idx])+self.eps)
                recall=torch.cat([torch.zeros(1), recall])
                precision=TP_cumsum/(TP_cumsum+FP_cumsum+self.eps)
                precision=torch.cat([torch.ones(1), precision])

                ap_per_class.append(torch.trapz(precision, recall).item())

            res['mAP@'+str(threshold)]=sum(ap_per_class)/len(ap_per_class) if ap_per_class else 0.0

        return res

# Exercise MAPR end to end with random outputs against real targets; the
# resulting numbers are meaningless, this only checks that the pipeline runs.
tmp=MAPR(anchors, [7,14,28])
tmp.update(mimic,test_batch[1])
tmp.compute()

# Metric

In [None]:
class Metric:
    """Collect per-batch scalar metrics and keep one averaged value per epoch."""

    def __init__(self):
        self.metric = {}          # name -> list of per-epoch means
        self.current_metric = {}  # name -> values recorded since the last finalize()

    def update(self, batch_metric, preds=None, targets=None):
        """Record one batch's ``{name: value}`` metrics (``preds``/``targets`` are unused)."""
        for name, value in batch_metric.items():
            self.current_metric.setdefault(name, []).append(value)

    def finalize(self):
        """Close the epoch: append each metric's mean to its history and start afresh."""
        for name, values in self.current_metric.items():
            self.metric.setdefault(name, []).append(sum(values) / len(values))

        self.current_metric = {}

    def plot(self):
        """Plot every metric's per-epoch history on one figure."""
        plt.figure(figsize=(10, 5))
        for name, history in self.metric.items():
            plt.plot(history, label=name)

        plt.title("Metrics per Epoch")
        plt.xlabel("Epochs")
        plt.ylabel("Metrics")
        plt.legend()

# Smoke test for Metric: ten fake epochs with one batch each, then plot them.
tmp=Metric()
for i in range(10):
    tmp.update({'loss': i,
                'acc': i+1})
    tmp.finalize()
tmp.plot()

# Loss

In [None]:
class YOLOLoss:
    """YOLO training loss summed over output scales.

    Per scale it combines:
    - BCE on objectness for cells without an object;
    - BCE on objectness for cells with one, targeting the predicted/true box IoU;
    - MSE on the box parameters;
    - cross-entropy on the class logits.
    """

    def __init__(self):
        # (num_scales, anchors_per_scale, 2) priors from the global anchor table.
        self.anchors=torch.tensor(anchors)
        self.lambda_obj=1
        self.lambda_noobj=1
        self.lambda_coord=1
        self.lambda_class=1

        self.mse = torch.nn.MSELoss()
        self.bce = torch.nn.BCEWithLogitsLoss()
        self.ce = torch.nn.CrossEntropyLoss()

    def __call__(self, preds, targets):
        """Sum ``single_loss`` over matching (prediction, target, anchors) scales."""
        loss=0.0
        for pred, target, anchor in zip(preds, targets, self.anchors):
            loss=loss+self.single_loss(pred, target, anchor)

        return loss

    def single_loss(self, pred, target, anchors):
        """Loss for one scale.

        Args:
            pred: raw head output (batch, anchors, grid, grid, 5 + num_classes).
            target: Collator target (batch, anchors, grid, grid, 6) with columns
                [obj, x_offset, y_offset, log_w, log_h, class_id].
            anchors: (anchors, 2) priors for this scale, as image fractions.
        """
        obj = target[..., 0]==1
        noobj = target[..., 0]==0
        loss = pred.new_zeros(())

        # no obj loss
        if noobj.any():
            noobj_loss = self.bce(pred[..., 0][noobj], target[..., 0][noobj])
            loss = loss + self.lambda_noobj*noobj_loss

        # Mean-reduced losses over an empty selection are NaN, so a batch without
        # any boxes contributes only the no-object term.
        if not obj.any():
            return loss

        # obj loss: target objectness = IoU(pred box, true box). Both boxes sit in
        # the same cell, so express them in image fractions: offsets / grid size,
        # sizes decoded from log-space as exp(.) * anchor. Anchors are moved to
        # the prediction's device so CUDA training works.
        anchors = anchors.to(device=pred.device, dtype=pred.dtype).reshape(1, -1, 1, 1, 2)
        grid_size = pred.shape[2]
        pred_xywh = torch.cat([torch.sigmoid(pred[..., 1:3])/grid_size, torch.exp(pred[..., 3:5])*anchors], dim=-1)
        true_xywh = torch.cat([target[..., 1:3]/grid_size, torch.exp(target[..., 3:5])*anchors], dim=-1)
        iou_scores = yolo_iou(pred_xywh[obj].detach(), true_xywh[obj].detach())
        iou_scores = torch.as_tensor(iou_scores, dtype=pred.dtype, device=pred.device)
        obj_loss = self.bce(pred[..., 0][obj], iou_scores*target[..., 0][obj])

        # box loss, in the Collator's parameterisation: sigmoid offsets vs
        # offsets, raw log-sizes vs log-sizes.
        pred_boxes = torch.cat([torch.sigmoid(pred[..., 1:3]), pred[..., 3:5]], dim=-1)
        coord_loss = self.mse(pred_boxes[obj], target[..., 1:5][obj])

        # class loss
        class_loss = self.ce(pred[..., 5:][obj], target[..., 5][obj].long())

        return loss + self.lambda_obj*obj_loss + self.lambda_coord*coord_loss + self.lambda_class*class_loss

# Training

In [None]:
class Trainer:
    """Train YoloMimic on the cached dataset and evaluate it with MAPR."""

    def __init__(self):
        self.data_module = DataModule('data')
        self.criterion=YOLOLoss()
        self.model=YoloMimic(3,208).to(device)
        self.opt=torch.optim.Adam(params = self.model.parameters(), lr=1e-3)
        self.metrics=Metric()

    def train_step(self, inputs, targets):
        """One optimisation step; returns the batch loss as a float."""
        inputs = inputs.to(device)
        targets = [target.to(device) for target in targets]
        preds = self.model(inputs)
        loss= self.criterion(preds, targets)

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        return loss.item()

    def val_step(self, inputs, targets):
        """Loss on one batch without gradient tracking."""
        inputs = inputs.to(device)
        targets = [target.to(device) for target in targets]
        with torch.no_grad():
            preds = self.model(inputs)
            loss= self.criterion(preds, targets)

        return loss.item()

    def evaluate(self, batch_size=32):
        """Compute mAP on the test split; per-batch losses are logged as 'test_loss'."""
        mAP=MAPR(anchors, [7,14,28])
        test_loader = self.data_module.get_data_loader('test', batch_size=batch_size, shuffle=False)
        p_bar = tqdm(test_loader, desc='Evaluation: ')
        self.model.eval()
        for inputs, targets in p_bar:
            inputs = inputs.to(device)
            targets = [target.to(device) for target in targets]
            with torch.no_grad():
                preds = self.model(inputs)
                loss= self.criterion(preds, targets).item()

            # MAPR's decoding helpers create grid offsets and anchors on the CPU,
            # so give them CPU copies rather than (possibly) CUDA tensors.
            mAP.update([pred.cpu() for pred in preds], [target.cpu() for target in targets])
            # Log under its own key so test batches do not pollute train_loss.
            self.metrics.update({'test_loss': loss})
            p_bar.set_postfix(test_loss=loss)

        return mAP.compute()

    def train(self, num_epochs=1, batch_size=32):
        """Train for ``num_epochs``, validating after each epoch and finalising metrics."""
        train_loader = self.data_module.get_data_loader('train', batch_size=batch_size, shuffle=True)
        val_loader = self.data_module.get_data_loader('validation', batch_size=batch_size, shuffle=False)
        for epoch in tqdm(range(num_epochs)):
            self.model.train()
            p_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}: ')
            for inputs, targets in p_bar:
                loss=self.train_step(inputs, targets)

                self.metrics.update({'train_loss': loss})
                p_bar.set_postfix(train_loss=loss)

            self.model.eval()
            p_bar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs}: ')
            for inputs, targets in p_bar:
                loss=self.val_step(inputs, targets)

                self.metrics.update({'val_loss': loss})
                p_bar.set_postfix(val_loss=loss)

            self.metrics.finalize()

    def save(self):
        """Save the pickled model to model.pt and a TorchScript copy to model_scripted.pt."""
        torch.save(self.model, 'model.pt')
        model_scripted = torch.jit.script(self.model)
        model_scripted.save('model_scripted.pt')

# Train with the default settings (1 epoch, batch size 32).
trainer=Trainer()
trainer.train()
# trainer.evaluate()  # uncomment to compute test-set mAP

# Inference

In [None]:
from PIL import Image, ImageDraw, ImageFont

def yolo_to_pascal_voc_format(box):
    """Convert (x_center, y_center, w, h) boxes to (x_min, y_min, x_max, y_max).

    Works on a single box of shape (4,) or a batch (..., 4). The output keeps the
    coordinates in the last axis, so a batch of N boxes gives (N, 4) rather than
    a transposed (4, N) array.
    """
    half_w = box[..., 2]/2
    half_h = box[..., 3]/2
    return np.stack([box[..., 0]-half_w, box[..., 1]-half_h, box[..., 0]+half_w, box[..., 1]+half_h], axis=-1)

class Predictor:
    """Run a trained detector on a PIL image and draw the surviving boxes onto it."""

    def __init__(self, model, data_dir, anchors, grid_sizes):
        self.transform=Preprocessor().test_transform
        self.model=model
        self.anchors=anchors
        self.grid_sizes=grid_sizes
        with open(data_dir+'/id_to_label.json') as file:
            self.id_to_label = json.load(file)

    def inference(self, img):
        """Detect objects in ``img`` (a PIL image) and return it with boxes drawn in place.

        The model is switched to eval mode and fed on its own device.
        """
        # Normalize in the test transform expects 3 channels; PNGs may be RGBA/palette.
        inputs=self.transform(image=np.array(img.convert('RGB')), labels=[])['image'][None]

        first_param = next(self.model.parameters(), None)
        model_device = first_param.device if first_param is not None else torch.device('cpu')
        self.model.eval()
        with torch.no_grad():
            outputs=self.model(inputs.to(model_device))

        # Decode on the CPU (format_prediction's grid offsets are created there).
        decoded=format_prediction([o.cpu() for o in outputs], self.anchors, self.grid_sizes)
        # format_prediction returns x/y in grid-cell units but w/h as fractions of
        # the input; rescale x/y per scale so NMS and drawing see one unit.
        decoded=[self._cells_to_fraction(pred) for pred in decoded]
        boxes=non_max_suppression(decoded)

        # The test transform resizes the longest side to the input size and pads
        # the short side symmetrically, so undo the padding before rescaling.
        input_size=inputs.shape[-1]
        img_w, img_h=img.size
        scale_factor=max(img_w, img_h)/input_size
        pad_left=(input_size-round(img_w/scale_factor))//2
        pad_top=(input_size-round(img_h/scale_factor))//2
        boxes=[self._fraction_to_input_pixels(box, input_size, pad_left, pad_top) for box in boxes[0]]
        return self.draw_boxes(img, boxes, scale_factor)

    @staticmethod
    def _cells_to_fraction(pred):
        """Divide x/y (columns 1-2) of a (batch, grid, grid, 7) tensor by the grid size."""
        grid_size=pred.shape[-2]
        pred=pred.clone()
        pred[..., 1:3]=pred[..., 1:3]/grid_size
        return pred

    @staticmethod
    def _fraction_to_input_pixels(box, input_size, pad_left, pad_top):
        """Map a box row's x/y/w/h from input fractions to unpadded input pixels."""
        box=box.clone()
        box[1:5]=box[1:5]*input_size
        box[1]=box[1]-pad_left
        box[2]=box[2]-pad_top
        return box

    def draw_boxes(self, image, boxes, scale_factor):
        """Draw ``boxes`` on ``image`` in place and return it.

        Each box row is [obj, x_center, y_center, w, h, class_id, class_prob];
        coordinates are multiplied by ``scale_factor`` to map them onto ``image``.
        """
        colors = {}
        draw = ImageDraw.Draw(image)
        font = ImageFont.load_default()
        for box in boxes:
            box=np.array(box)
            label_idx=str(int(box[5]))
            label_prob=box[6]
            label=self.id_to_label[label_idx]
            corners=yolo_to_pascal_voc_format(box[1:5])
            corners=[int(coord*scale_factor) for coord in corners]
            if label not in colors:
                colors[label] = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

            color = colors[label]
            draw.rectangle(corners, outline=color, width=2)

            # Size the label background for the full text (label and probability).
            text = label+': '+str(round(label_prob*100,2))
            text_position = (corners[0], corners[1] - 10)
            text_bbox = font.getbbox(text)
            text_width = text_bbox[2] - text_bbox[0]
            text_height = text_bbox[3] - text_bbox[1]

            draw.rectangle(
                [text_position, (text_position[0] + text_width, text_position[1] + text_height)],
                fill=color
            )
            draw.text(text_position, text, fill=(255, 255, 255), font=font)

        return image

# Run the trained model on input.png (written by draw_boxes in the Plot section)
# and display the annotated image.
predictor=Predictor(trainer.model, 'data', torch.tensor(anchors), [7,14,28])
img=Image.open('input.png')
predictor.inference(img)