In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from lib.utils import non_max_suppression, intersection_over_union, load_checkpoint, cells_to_bboxes
from lib.utils import intersection_over_union as iou
#from lib.YOLOV3 import YOLOv3 as YOLO
from albumentations.pytorch import ToTensorV2
import albumentations as A
from lib import config as C
import cv2
from torch.utils.data import DataLoader
from PIL import Image, ImageFile, ImageDraw, ImageFont
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from lib.utils import (
    cells_to_bboxes,
    iou_width_height as iou,
    non_max_suppression as nms,
    plot_image
)
from matplotlib import pyplot

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# Layer-by-layer description of the network, read in order by
# YOLOv3._create_conv_layers. Each entry is one of:
#   (out_channels, kernel_size, stride) -> one CNNBlock (padding 1 for 3x3, else 0)
#   ["B", num_repeats]                  -> one ResidualBlock with that many repeats
#   "S"                                 -> prediction branch: ResidualBlock without skip,
#                                          1x1 CNNBlock halving the channels, ScalePrediction
#   "U"                                 -> nn.Upsample(scale_factor=2); the builder then
#                                          triples its channel count because forward()
#                                          concatenates a saved route connection
# The ["B", 8] entries are the ones whose outputs forward() keeps as route connections.
config_yolo = [
    (32, 3, 1),
    (64, 3, 2),
    ["B", 1],
    (128, 3, 2),
    ["B", 2],
    (256, 3, 2),
    ["B", 8],
    (512, 3, 2),
    ["B", 8],
    (1024, 3, 2),
    ["B", 4],  # To this point is Darknet-53
    (512, 1, 1),
    (1024, 3, 1),
    "S",
    (256, 1, 1),
    "U",
    (256, 1, 1),
    (512, 3, 1),
    "S",
    (128, 1, 1),
    "U",
    (128, 1, 1),
    (256, 3, 1),
    "S",
]


class CNNBlock(nn.Module):
    """Conv2d, optionally followed by BatchNorm2d and LeakyReLU(0.1).

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        bn_act: When True the block computes ``leaky(bn(conv(x)))`` and the
            convolution has no bias; when False the block is the bare
            convolution (with bias).
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (kernel_size, stride,
            padding, ...).
    """

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        # The conv bias is only enabled when no batch norm follows it.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            bias=not bn_act,
            **kwargs,
        )
        # bn/leaky are registered even when bn_act is False, so the set of
        # state_dict keys does not depend on the flag.
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)
        self.use_bn_act = bn_act

    def forward(self, x):
        """Apply the convolution and, if enabled, batch norm + activation."""
        features = self.conv(x)
        if not self.use_bn_act:
            return features
        return self.leaky(self.bn(features))


class ResidualBlock(nn.Module):
    """Stack of ``num_repeats`` bottleneck units (1x1 squeeze, 3x3 expand).

    Args:
        channels: Channel count of the input; every unit maps
            ``channels -> channels // 2 -> channels``.
        use_residual: When True each unit's output is added to its input;
            when False the units are simply applied one after another.
        num_repeats: How many units to stack. Also stored on the instance,
            where YOLOv3.forward reads it to pick route connections.
    """

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        squeezed = channels // 2
        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    CNNBlock(channels, squeezed, kernel_size=1),
                    CNNBlock(squeezed, channels, kernel_size=3, padding=1),
                )
                for _ in range(num_repeats)
            ]
        )
        self.use_residual = use_residual
        self.num_repeats = num_repeats

    def forward(self, x):
        """Run every unit in order, with or without the skip connection."""
        for unit in self.layers:
            transformed = unit(x)
            x = x + transformed if self.use_residual else transformed
        return x


class ScalePrediction(nn.Module):
    """Prediction head for one output scale.

    A 3x3 CNNBlock doubles the channels, then a bias-only 1x1 CNNBlock
    (``bn_act=False``) produces ``3 * (num_classes + 5)`` channels.
    NOTE(review): the 3 is presumably the number of anchors per cell and the
    5 the objectness score plus four box values -- confirm against the loss.

    Args:
        in_channels: Channel count of the incoming feature map.
        num_classes: Number of object classes.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        widened = 2 * in_channels
        head_channels = (num_classes + 5) * 3
        self.pred = nn.Sequential(
            CNNBlock(in_channels, widened, kernel_size=3, padding=1),
            CNNBlock(widened, head_channels, bn_act=False, kernel_size=1),
        )
        self.num_classes = num_classes

    def forward(self, x):
        """Return predictions shaped ``(N, 3, H, W, num_classes + 5)``."""
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        raw = self.pred(x)
        # Split the channel axis into (3, num_classes + 5), then move the
        # per-slot vector to the last axis.
        grouped = raw.reshape(batch, 3, self.num_classes + 5, height, width)
        return grouped.permute(0, 1, 3, 4, 2)


class YOLOv3(nn.Module):
    """YOLOv3 network built from ``config_yolo``, with optional feature-map dumps.

    Besides computing the predictions, ``forward`` can print the size of every
    intermediate activation and save a grid of its first channels as a PNG.
    Both are on by default, so existing callers keep the behaviour this
    notebook relies on. Pass ``feature_map_dir=None`` and ``verbose=False``
    for a side-effect-free forward pass.

    Args:
        in_channels: Channel count of the input image tensor.
        num_classes: Number of classes each ScalePrediction head predicts.
        feature_map_dir: Directory that receives ``layer_<k>.png`` for the
            k-th non-prediction layer; ``None`` disables the dump.
        verbose: When True, print ``x.size()`` after every non-prediction layer.
    """

    def __init__(
        self,
        in_channels=3,
        num_classes=80,
        feature_map_dir="lib/datasets/test_images/test",
        verbose=True,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        # Plain attributes (not parameters/buffers): state_dict is unaffected.
        self.feature_map_dir = feature_map_dir
        self.verbose = verbose
        self.layers = self._create_conv_layers()

    def forward(self, x):
        """Run the network and return one prediction tensor per scale.

        Returns:
            A list with the output of every ScalePrediction layer, in the
            order the layers appear in ``self.layers``.
        """
        outputs = []  # one entry per ScalePrediction head
        route_connections = []
        layer_index = 0  # 1-based count of non-prediction layers, used in file names

        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                # Heads branch off: they do not modify the running activation.
                outputs.append(layer(x))
                continue

            x = layer(x)
            layer_index += 1
            if self.verbose:
                print(x.size())
            if self.feature_map_dir is not None:
                self._save_feature_maps(x, layer_index)

            if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                # Saved for concatenation after the matching upsample.
                route_connections.append(x)
            elif isinstance(layer, nn.Upsample):
                # Most recently saved route is consumed first.
                x = torch.cat([x, route_connections.pop()], dim=1)

        return outputs

    def _save_feature_maps(self, x, layer_index, rows=4, cols=8):
        """Save the first ``rows * cols`` channels of sample 0 as a grayscale grid."""
        import os

        os.makedirs(self.feature_map_dir, exist_ok=True)
        # detach(): Tensor.numpy() refuses tensors that require grad, which made
        # the forward pass unusable outside torch.no_grad().
        maps = x.detach().to("cpu").numpy()
        # Never index past the channel axis for narrow layers.
        shown = min(rows * cols, maps.shape[1])

        fig = pyplot.figure(figsize=(40, 20))
        for channel in range(shown):
            ax = fig.add_subplot(rows, cols, channel + 1)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.imshow(maps[0, channel, :, :], cmap="gray")
        fig.savefig(
            os.path.join(self.feature_map_dir, "layer_" + str(layer_index) + ".png")
        )
        # Close explicitly: one figure is created per layer per forward pass.
        pyplot.close(fig)

    def _create_conv_layers(self):
        """Translate ``config_yolo`` into an ``nn.ModuleList``.

        Tuples become CNNBlocks, ``["B", n]`` lists become ResidualBlocks,
        ``"S"`` adds a prediction branch and ``"U"`` an upsampling layer.
        ``in_channels`` tracks the channel count of the running activation.
        """
        layers = nn.ModuleList()
        in_channels = self.in_channels

        for module in config_yolo:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels

            elif isinstance(module, list):
                num_repeats = module[1]
                layers.append(ResidualBlock(in_channels, num_repeats=num_repeats))

            elif isinstance(module, str):
                if module == "S":
                    layers.extend(
                        [
                            ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                            CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                            ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                        ]
                    )
                    # The 1x1 CNNBlock above halves the running channel count.
                    in_channels = in_channels // 2

                elif module == "U":
                    layers.append(nn.Upsample(scale_factor=2))
                    # forward() concatenates a route connection with twice as
                    # many channels as the upsampled tensor, hence x3.
                    in_channels = in_channels * 3

        return layers

In [9]:
def get_eval_boxes(x, model, anchors, iou_threshold, threshold, device="cuda"):
    """Run ``model`` on a single image tensor and return its NMS-filtered boxes.

    Args:
        x: One un-batched image tensor; a leading batch axis of size 1 is
            added before the forward pass.
        model: Network returning one prediction tensor per scale.
        anchors: Per-scale anchor collections, paired with the model outputs
            in order. Each is scaled by that output's grid size before decoding.
        iou_threshold: Passed to ``non_max_suppression``.
        threshold: Passed to ``non_max_suppression``.
        device: Device the input and the anchors are moved to; it must be the
            device the model lives on.

    Returns:
        A list of boxes. Each is the list returned by ``non_max_suppression``
        prefixed with a running counter that is incremented once per kept box.

    The model is put in eval mode for the forward pass and then returned to
    whatever mode it was in before the call, even if the forward pass raises.
    """
    was_training = model.training
    model.eval()
    try:
        batch = x.to(device).unsqueeze(0)
        with torch.no_grad():
            predictions = model(batch)
    finally:
        # Restore rather than force train(): callers that were evaluating
        # must not silently end up with a model in training mode.
        model.train(was_training)

    batch_size = batch.shape[0]
    bboxes = [[] for _ in range(batch_size)]
    for scale_predictions, scale_anchors in zip(predictions, anchors):
        S = scale_predictions.shape[2]  # grid size of this scale
        anchor = torch.tensor([*scale_anchors]).to(device) * S
        boxes_scale_i = cells_to_bboxes(
            scale_predictions, anchor, S=S, is_preds=True
        )
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box

    all_pred_boxes = []
    train_idx = 0
    for idx in range(batch_size):
        nms_boxes = non_max_suppression(
            bboxes[idx],
            iou_threshold=iou_threshold,
            threshold=threshold,
            box_format="midpoint",
        )
        for nms_box in nms_boxes:
            all_pred_boxes.append([train_idx] + nms_box)
            # NOTE(review): the counter advances per box, not per image, so
            # the prefix is a box id here -- confirm that is what callers want.
            train_idx += 1

    return all_pred_boxes

In [10]:
# Inference-time preprocessing: resize so the longest side is C.IMAGE_SIZE,
# pad to a C.IMAGE_SIZE square with a constant border, normalise with
# mean 0 / std 1 / max_pixel_value 255, and convert to a tensor.
test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=C.IMAGE_SIZE),
        A.PadIfNeeded(
            min_height=int(C.IMAGE_SIZE),
            min_width=int(C.IMAGE_SIZE),
            border_mode=cv2.BORDER_CONSTANT,
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255, ),
        ToTensorV2(),
    ],
)
# NOTE(review): `path`, `S` and `ids` are not used by any cell visible here;
# presumably later cells read them -- confirm before removing.
path = "lib/datasets/LISA/LISA/"
# C.IMAGE_SIZE divided by 32, 16 and 8 (one value per prediction scale).
S=[C.IMAGE_SIZE // 32, C.IMAGE_SIZE // 16, C.IMAGE_SIZE // 8]
anchors = C.ANCHORS
check = "lib/models/checkpoint_test.pth.tar"
# Uses the YOLOv3 class defined in this notebook (the lib import is commented
# out in the import cell), so the forward pass also dumps feature maps.
model = YOLOv3(num_classes=C.NUM_CLASSES).to(C.DEVICE)
# The optimizer is created only because load_checkpoint takes one as an argument.
optimizer = optim.Adam(
        model.parameters(), lr=C.LEARNING_RATE, weight_decay=C.WEIGHT_DECAY
    )
# load_checkpoint comes from lib.utils; presumably it restores the saved
# weights/optimizer state from `check` -- verify against its definition.
load_checkpoint(
        check, model, optimizer, C.LEARNING_RATE
    )
ids = pd.read_csv("lib/datasets/ids.csv")

=> Loading checkpoint


In [12]:
# Run the detector on one test image. The forward pass inside get_eval_boxes
# also prints every activation size and writes the feature-map PNGs.
test_image = "lib/datasets/test_images/stop_test.jpg"
# Context manager so the image file handle is closed as soon as it is decoded.
with Image.open(test_image) as pil_image:
    image = np.array(pil_image.convert("RGB"))
augmentations = test_transforms(image=image)
image1 = augmentations["image"]
# Pass the device explicitly: the model was moved to C.DEVICE, while
# get_eval_boxes defaults to "cuda" and would otherwise put the input on a
# different device whenever C.DEVICE is not CUDA.
bboxes = get_eval_boxes(
    image1,
    model,
    C.ANCHORS,
    iou_threshold=C.NMS_IOU_THRESH,
    threshold=C.CONF_THRESHOLD,
    device=C.DEVICE,
)

torch.Size([1, 32, 416, 416])
torch.Size([1, 64, 208, 208])
torch.Size([1, 64, 208, 208])
torch.Size([1, 128, 104, 104])
torch.Size([1, 128, 104, 104])
torch.Size([1, 256, 52, 52])
torch.Size([1, 256, 52, 52])
torch.Size([1, 512, 26, 26])
torch.Size([1, 512, 26, 26])
torch.Size([1, 1024, 13, 13])
torch.Size([1, 1024, 13, 13])
torch.Size([1, 512, 13, 13])
torch.Size([1, 1024, 13, 13])
torch.Size([1, 1024, 13, 13])
torch.Size([1, 512, 13, 13])
torch.Size([1, 256, 13, 13])
torch.Size([1, 256, 26, 26])
torch.Size([1, 256, 26, 26])
torch.Size([1, 512, 26, 26])
torch.Size([1, 512, 26, 26])
torch.Size([1, 256, 26, 26])
torch.Size([1, 128, 26, 26])
torch.Size([1, 128, 52, 52])
torch.Size([1, 128, 52, 52])
torch.Size([1, 256, 52, 52])
torch.Size([1, 256, 52, 52])
torch.Size([1, 128, 52, 52])
