In [None]:
import torchvision
import numpy as np
from torch.utils import data
from matplotlib import pyplot as plt
import cv2
import torchvision.transforms.v2 as transforms
from collections import defaultdict
import PIL
import torch
from torch.utils import data
import torch.nn as nn
from torchsummary import summary
from collections import OrderedDict
import torch.nn.functional as F
from core.loss_functions import CenternetTTFLoss
from datasets.encoders import CenternetEncoder
from utils.visualizer import get_image_with_bboxes

# Reload imported modules automatically before each cell executes, so edits
# to the project modules imported above are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# Report whether PyTorch can see a CUDA device.
print("GPU is available: ", torch.cuda.is_available() )

In [None]:
# Square network input resolution, used by the resize transforms and the encoder.
input_height = input_width = 256
# Ratio between the model input resolution and the model output resolution.
down_ratio = 4  # model output compared to model input

In [None]:
# Pascal VOC 2007 "val" split read from a local copy (no download), wrapped so
# that it can be fed to the torchvision transforms v2 API; later cells read the
# targets as lbl["boxes"] / lbl["labels"].
dataset_val = torchvision.datasets.wrap_dataset_for_transforms_v2(
    torchvision.datasets.VOCDetection(
        root="~/projects/VOC",
        year="2007",
        image_set="val",
        download=False,
    )
)

# Number of samples in the split.
print(len(dataset_val))

In [None]:
# Pull sample 2000 as an (image, target) pair and echo it as the cell output.
img, lbl = dataset_val[2000]
img, lbl

In [None]:
# Render the sample together with its ground-truth boxes/labels using the
# project helper, then display the returned image.
image_with_boxes = get_image_with_bboxes(img, lbl["boxes"], lbl["labels"])
plt.imshow(image_with_boxes)

In [None]:
# Resize the image and its boxes together to the network input resolution.
# torchvision's Resize takes `size` as (height, width); the original passed
# (width, height), which only worked because the input is square.
transform = transforms.Compose(
    [
        # transforms.RandomHorizontalFlip(p=0.5),
        transforms.Resize(size=(input_height, input_width)),
    ]
)


# NOTE: this rebinds `img` to the resized image, while `lbl` still holds the
# original-resolution target; use `bboxes` / `labels` from here on.
img, bboxes, labels = transform(img, lbl["boxes"], lbl["labels"])

In [None]:
# Encode the resized boxes/labels with the project encoder and show index 2
# of the last axis of the encoded array.
encoder = CenternetEncoder(input_height, input_width)
lbl_encoded = encoder.encode(bboxes, labels)
plt.imshow(lbl_encoded[..., 2])

In [None]:
# Same visual check as before, now on the resized image and resized boxes.
image_with_boxes = get_image_with_bboxes(img, bboxes, labels)
plt.imshow(image_with_boxes)

In [None]:
# Inspect one channel of the encoded target: find every position where it
# reaches its maximum and print what is stored there.
hm_chosen = lbl_encoded[..., 8]

ind_max = np.argwhere(hm_chosen == np.amax(hm_chosen))
for ind in ind_max:
    # `ind` is (row, col) on the encoded grid; scale it by the configured
    # down ratio (previously a hard-coded 4) to get input-image pixels.
    print("rect center:", ind * down_ratio)
    # Channels from index 20 onwards at that grid cell.
    # NOTE(review): presumably 20 is the class count and these are the box
    # coordinates - confirm against CenternetEncoder.
    print("coors", lbl_encoded[ind[0], ind[1], 20:])
    print()

In [None]:
from torch.utils import data


class Dataset(data.Dataset):
    """Map-style dataset that transforms and encodes samples of a wrapped dataset.

    Only the first `subset_size` samples of the wrapped dataset are exposed
    (32 by default, the value that used to be hard-coded), which keeps the
    overfitting experiment in this notebook small.

    Args:
        dataset: indexable dataset returning ``(img, lbl)`` where ``lbl`` has
            ``"boxes"`` and ``"labels"`` entries.
        transformation: callable ``(img, boxes, labels) -> (img, boxes, labels)``.
        encoder: object whose ``encode(boxes, labels)`` returns a numpy array.
        subset_size: maximum number of samples to expose; ``None`` exposes the
            whole wrapped dataset. Clamped to the wrapped dataset's length.
    """

    def __init__(self, dataset, transformation, encoder, subset_size=32):
        self._dataset = dataset
        self._transformation = transformation
        self._encoder = encoder

        total = len(dataset)
        if subset_size is None:
            self._size = total
        else:
            # Never expose more samples than the wrapped dataset holds.
            self._size = max(0, min(subset_size, total))

    def __getitem__(self, index):
        """Return ``(transformed image, encoded target as a torch tensor)``.

        Indices wrap around the exposed subset (``index % len(self)``).

        Raises:
            IndexError: if the dataset is empty.
        """
        if self._size == 0:
            raise IndexError("dataset is empty")

        img, lbl = self._dataset[index % self._size]

        img_, bboxes_, labels_ = self._transformation(
            img, lbl["boxes"], lbl["labels"]
        )

        lbl_encoded = self._encoder.encode(bboxes_, labels_)

        return img_, torch.from_numpy(lbl_encoded)

    def __len__(self):
        """Number of samples exposed by this dataset."""
        return self._size


# Encoder shared by every sample of the training dataset.
encoder = CenternetEncoder(input_height, input_width)


# Resize takes `size` as (height, width); the original passed (width, height),
# which only worked because the input is square. ToTensor then converts the
# image to a tensor (later cells call `.numpy()` on it).
transform = transforms.Compose(
    [
        transforms.Resize(size=(input_height, input_width)),
        transforms.ToTensor(),
    ]
)


torch_dataset = Dataset(
    dataset=dataset_val, transformation=transform, encoder=encoder
)

In [None]:
# Fetch one item from the wrapped dataset and show the image / target shapes.
img_, lbl_encoded = torch_dataset[30]
img_.shape, lbl_encoded.numpy().shape

In [None]:
# Move axis 0 to the end before plotting (imshow wants channels last).
plt.imshow(img_.numpy().transpose(1, 2, 0))

In [None]:
# Show index 7 of the last axis of the encoded target.
plt.imshow(lbl_encoded.numpy()[:, :, 7])

In [None]:
# Batches of 4 samples, served in dataset order (no shuffling) and loaded in
# the main process (num_workers=0).
batch_generator = torch.utils.data.DataLoader(
    torch_dataset, batch_size=4, num_workers=0, shuffle=False
)

In [None]:
# Grab the first batch for inspection; `img` is rebound to the batched images.
iter_batch = iter(batch_generator)
img, gt_data = next(iter_batch)

In [None]:
# Value range of the encoded targets in this batch.
torch.min(gt_data), torch.max(gt_data)

In [None]:
# Collapse the first sample's target by summing over axis 2 and display it.
plt.imshow(gt_data[0].data.numpy().sum(axis=2))

In [None]:
class Backbone(nn.Module):
    """Feature extractor: 13 conv-BN-ReLU blocks with five 2x2 max-pool stages.

    The blocks are grouped 2-2-3-3-3 (a VGG16-like layout); the last block of
    each group ends with a stride-2 max-pool. `forward` returns the five
    feature maps taken right after each pooling step.

    Args:
        alpha: width multiplier applied to the base channel counts
            (64, 128, 256, 512, 512); the products are truncated to int.

    Attributes:
        filters: numpy int array with the channel count of each of the five
            stages (also the channel counts of the tensors `forward` returns).
    """

    def __init__(self, alpha=1.0):
        super().__init__()
        # Running counter used only to give every sub-layer a unique name.
        self.block_num = 1
        self.alpha = alpha
        self.filters = np.array(
            [
                64 * self.alpha,
                128 * self.alpha,
                256 * self.alpha,
                512 * self.alpha,
                512 * self.alpha,
            ]
        ).astype("int")
        s = self.filters
        self.layer1 = self.conv_bn_relu(3, s[0], False)
        self.layer2 = self.conv_bn_relu(s[0], s[0], True)  # stride 2
        self.layer3 = self.conv_bn_relu(s[0], s[1], False)
        self.layer4 = self.conv_bn_relu(s[1], s[1], True)  # stride 4
        self.layer5 = self.conv_bn_relu(s[1], s[2], False)
        self.layer6 = self.conv_bn_relu(s[2], s[2], False)
        self.layer7 = self.conv_bn_relu(s[2], s[2], True)  # stride 8
        self.layer8 = self.conv_bn_relu(s[2], s[3], False)
        self.layer9 = self.conv_bn_relu(s[3], s[3], False)
        self.layer10 = self.conv_bn_relu(s[3], s[3], True)  # stride 16
        # layer11 consumes layer10's output, i.e. s[3] channels. The original
        # declared s[4] here, which only worked because s[3] == s[4].
        self.layer11 = self.conv_bn_relu(s[3], s[4], False)
        self.layer12 = self.conv_bn_relu(s[4], s[4], False)
        self.layer13 = self.conv_bn_relu(s[4], s[4], True)  # stride 32

    def conv_bn_relu(
        self, input_num, output_num, max_pool=False, kernel_size=3
    ):
        """Build one Conv2d -> BatchNorm2d -> ReLU (-> MaxPool2d) block.

        Args:
            input_num: number of input channels.
            output_num: number of output channels.
            max_pool: append a 2x2, stride-2 max-pool when True.
            kernel_size: convolution kernel size; padding is `kernel_size // 2`
                so odd kernels keep the spatial size (1 for the default 3,
                identical to the previously hard-coded value).

        Returns:
            nn.Sequential holding the uniquely named layers.
        """
        block = OrderedDict()
        suffix = str(self.block_num)
        block["conv_" + suffix] = nn.Conv2d(
            input_num,
            output_num,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
        )
        block["bn_" + suffix] = nn.BatchNorm2d(
            output_num, eps=1e-3, momentum=0.01
        )
        block["relu_" + suffix] = nn.ReLU()
        if max_pool:
            block["pool_" + suffix] = nn.MaxPool2d(kernel_size=2, stride=2)
        self.block_num += 1
        return nn.Sequential(block)

    def forward(self, x):
        """Run the backbone.

        Returns:
            Tuple of five feature maps at strides 2, 4, 8, 16 and 32 relative
            to `x`, in that order.
        """
        out = self.layer1(x)
        out_stride_2 = self.layer2(out)
        out = self.layer3(out_stride_2)
        out_stride_4 = self.layer4(out)
        out = self.layer5(out_stride_4)
        out = self.layer6(out)
        out_stride_8 = self.layer7(out)
        out = self.layer8(out_stride_8)
        out = self.layer9(out)
        out_stride_16 = self.layer10(out)
        out = self.layer11(out_stride_16)
        out = self.layer12(out)
        out_stride_32 = self.layer13(out)
        return (
            out_stride_2,
            out_stride_4,
            out_stride_8,
            out_stride_16,
            out_stride_32,
        )

In [None]:
class Head(nn.Module):
    """Upsampling head that fuses backbone features and predicts the outputs.

    Starting from the deepest backbone feature map, each of the three stages
    applies a 3x3 conv block, upsamples by 2 (nearest neighbour) and adds a
    1x1-projected skip connection from the next shallower backbone output.
    Two branches on the shared result then produce `class_number` sigmoid
    heatmap channels and 4 `sizes` channels (conv + batch norm, no activation).

    Args:
        backbone_output_filters: channel counts of the backbone outputs,
            ordered shallowest to deepest.
        class_number: number of heatmap channels.
    """

    def __init__(self, backbone_output_filters, class_number=20):
        super().__init__()
        self.connection_num = 3
        self.class_number = class_number
        self.backbone_output_filters = backbone_output_filters
        self.filters = [128, 64, 32]
        # Input channels of each stage: deepest backbone output, then the
        # previous stage's output.
        head_filters = [self.backbone_output_filters[-1]] + self.filters

        for i in range(len(self.filters)):
            name = f"head_{i+1}"
            setattr(
                self,
                name,
                self.conv_bn_relu(name, head_filters[i], head_filters[i + 1]),
            )
            # create connection with backbone
            if i < self.connection_num:
                name = f"after_{-2-i}"
                setattr(
                    self,
                    name,
                    self.conv_bn_relu(
                        name,
                        self.backbone_output_filters[-2 - i],
                        self.filters[i],
                        1,
                    ),
                )

        self.before_hm = self.conv_bn_relu(
            "before_hm", self.filters[-1], self.filters[-1]
        )
        self.before_sizes = self.conv_bn_relu(
            "before_sizes", self.filters[-1], self.filters[-1]
        )

        self.hm = self.conv_bn_relu(
            "hm", self.filters[-1], self.class_number, 3, "sigmoid"
        )
        # NOTE(review): the layers of this block are named "..._hm" although
        # they belong to the sizes branch. The name is kept on purpose:
        # renaming it would change the state_dict keys ("sizes.conv_hm.*")
        # and break loading of already saved checkpoints.
        self.sizes = self.conv_bn_relu("hm", self.filters[-1], 4, 3, None)

    def conv_bn_relu(
        self, name, input_num, output_num, kernel_size=3, activation="relu"
    ):
        """Build a Conv2d -> BatchNorm2d (-> activation) block.

        Args:
            name: suffix used for the layer names inside the block.
            input_num: number of input channels.
            output_num: number of output channels.
            kernel_size: convolution kernel size; padding is `kernel_size // 2`
                (1 for 3x3, 0 for 1x1 - the same values as before).
            activation: "relu", "sigmoid", or anything else for no activation.

        Returns:
            nn.Sequential holding the named layers.
        """
        block = OrderedDict()
        block["conv_" + name] = nn.Conv2d(
            input_num,
            output_num,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
        )
        block["bn_" + name] = nn.BatchNorm2d(
            output_num, eps=1e-3, momentum=0.01
        )
        if activation == "relu":
            block["relu_" + name] = nn.ReLU()
        elif activation == "sigmoid":
            block["sigmoid_" + name] = nn.Sigmoid()
        return nn.Sequential(block)

    def connect_with_backbone(self, *backbone_out):
        """Fuse the backbone outputs into one shared feature map.

        Args:
            *backbone_out: backbone feature maps ordered shallowest to
                deepest; the last `connection_num + 1` of them are used.

        Returns:
            Feature map with `self.filters[-1]` channels at the resolution of
            the shallowest skip connection that was used.
        """
        # Skip connections, deepest first: backbone_out[-2], [-3], [-4].
        used_out = [backbone_out[-i - 2] for i in range(self.connection_num)]
        x = backbone_out[-1]
        for i in range(len(self.filters)):
            x = getattr(self, "head_{}".format(i + 1))(x)
            x = F.interpolate(x, scale_factor=2, mode="nearest")
            if i < self.connection_num:
                name = f"after_{-2-i}"
                x_ = getattr(self, name)(used_out[i])
                x = torch.add(x, x_)
        return x

    def forward(self, *backbone_out):
        """Predict heatmaps and sizes from the backbone outputs.

        Returns:
            Tensor with `class_number + 4` channels: heatmap channels first,
            followed by the 4 sizes channels.
        """
        # The original passed `self` explicitly on top of the bound call
        # (`self.connect_with_backbone(self, *backbone_out)`), which prepended
        # the module itself to `backbone_out` and only worked because the
        # method indexes from the end.
        self.last_shared_layer = self.connect_with_backbone(*backbone_out)
        x = self.before_hm(self.last_shared_layer)
        hm_out = self.hm(x)

        x = self.before_sizes(self.last_shared_layer)
        sizes_out = self.sizes(x)

        x = torch.cat((hm_out, sizes_out), dim=1)
        return x

In [None]:
class ModelBuilder(nn.Module):
    """
    Glue module: backbone -> head, with the training loss attached.

    Calling the module with only an image batch returns the raw prediction;
    passing `gt` as well returns whatever the loss object computes instead.
    The loss is sized from the notebook-level `input_height` / `input_width`.
    """

    def __init__(self, alpha=1.0, class_number=20):
        super().__init__()
        self.class_number = class_number

        # Sub-modules are registered in the same order as before so that
        # parameter / state_dict ordering is unchanged.
        self.backbone = Backbone(alpha)
        self.head = Head(
            backbone_output_filters=self.backbone.filters,
            class_number=class_number,
        )

        loss_height = input_height // 4
        loss_width = input_width // 4
        self.loss = CenternetTTFLoss(class_number, 4, loss_height, loss_width)

    def forward(self, x, gt=None):
        # Map inputs from [0, 1] to [-1, 1].
        normalized = x / 0.5 - 1.0  # normalization
        features = self.backbone(normalized)
        prediction = self.head(*features)

        if gt is not None:
            return self.loss(gt, prediction)
        return prediction

In [None]:
# Build a quarter-width model on the GPU when one is available, else on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ModelBuilder(alpha=0.25).to(device)
# Summarise at the configured input resolution rather than a hard-coded 256.
summary(model, input_size=(3, input_height, input_width), batch_size=-1)

In [None]:
# Adam over every model parameter with an initial learning rate of 0.01.
lr = 0.01
parameters = list(model.parameters())
optimizer = torch.optim.Adam(parameters, lr=lr)

In [None]:
# One entry is appended per epoch by the training loop. Re-running this cell
# discards the history collected so far.
loss_history = []

In [None]:
# Override the learning rate of every parameter group in place; this replaces
# the value the optimizer was created with.
for g in optimizer.param_groups:
    g["lr"] = 0.001

In [None]:
EPOCHS = 1500
model.train(True)

for epoch in range(EPOCHS):
    print("EPOCH {}:".format(epoch + 1))

    # Stays None when the loader yields no batch, so the bookkeeping below
    # cannot hit an undefined name.
    loss_dict = None

    # Loop-local names: the original bound the batch to `data` (shadowing the
    # imported `torch.utils.data` module) and overwrote the notebook-level
    # `img` / `gt_data`.
    for batch_img, batch_gt in batch_generator:
        batch_img = batch_img.to(device).contiguous()
        # Targets from the DataLoader never require grad, so no flag to reset.
        batch_gt = batch_gt.to(device)

        optimizer.zero_grad()
        loss_dict = model(batch_img, gt=batch_gt)
        loss_dict["loss"].backward()  # compute gradient
        optimizer.step()  # and do the optimize step

    # Record the loss of the LAST batch of the epoch (not an epoch average).
    if loss_dict is not None:
        loss_history.append(loss_dict["loss"].cpu().detach().numpy())
        print(loss_dict["loss"])

In [None]:
def _loss_to_float(entry):
    """Convert one `loss_history` entry to a plain Python float.

    The training loop stores numpy scalars, so the original
    `x["loss"].cpu()...` lookup raised on them. Dict and tensor entries are
    accepted as well, in case the history was filled in a different form.
    """
    if isinstance(entry, dict):
        entry = entry["loss"]
    if torch.is_tensor(entry):
        entry = entry.detach().cpu().numpy()
    return float(entry)


raw_loss = [_loss_to_float(entry) for entry in loss_history]
# Last expression, so the number of recorded epochs is actually displayed.
len(raw_loss)

In [None]:
# Loss curve: one point per recorded epoch.
plt.plot(raw_loss)

In [None]:
# Sample 10 is inside the first 32 samples, i.e. the subset the model trained on.
img, lbl = dataset_val[10]

In [None]:
# Ground-truth boxes/labels of the evaluation sample, for visual reference.
image_with_boxes = get_image_with_bboxes(img, lbl["boxes"], lbl["labels"])
plt.imshow(image_with_boxes)

In [None]:
# Labels present in the evaluation sample.
lbl["labels"]

In [None]:
# Switch to inference mode so BatchNorm uses its running statistics.
model.eval()

In [None]:
# Size of the evaluation image before any resizing.
img.size

In [None]:
# Uses whichever `transform` was defined last; the next cell calls `.numpy()`
# on the result, so it must be the Resize + ToTensor pipeline.
pred_input = transform(img)

In [None]:
# Move axis 0 to the end before plotting (imshow wants channels last).
plt.imshow(pred_input.numpy().transpose(1, 2, 0))

In [None]:
# Add the leading batch dimension without hard-coding the image size
# (equivalent to reshaping to [1, 3, 256, 256] for the current configuration).
reshaped = pred_input.unsqueeze(0)

In [None]:
# Call the module itself (not `.forward`) so registered hooks run, and skip
# autograd bookkeeping since no gradients are needed at inference time.
with torch.no_grad():
    pred = model(reshaped.to(device))

In [None]:
# Move the prediction to host memory as a numpy array for plotting/inspection.
numpy_pred = pred.cpu().detach().numpy()

In [None]:
# Shape of the prediction array; the indexing in later cells treats it as
# (batch, channels, rows, cols).
numpy_pred.shape

In [None]:
# Peak response of every heatmap channel. The channel count comes from the
# model instead of a hard-coded 20; the channels after it are the sizes.
for class_idx in range(model.class_number):
    print(np.max(numpy_pred[0, class_idx, :, :]))

In [None]:
# TODO: take 4 channels (distance from image borders), show bounding boxes

In [None]:
# Predicted heatmap for channel 14 of the first (only) sample in the batch.
plt.imshow(numpy_pred[0, 14, ...])

In [None]:
# Pickles the whole module object. Loading this file later requires the model
# classes defined in this notebook to be importable under the same names; the
# state_dict saved in the next cell is the more portable format.
torch.save(model, "pascal_voc_30img.pt")

In [None]:
# Save only the parameters and buffers (the state_dict).
torch.save(model.state_dict(), "pascal_voc_30img_state_dict.pt")

In [None]:
# Restore the saved weights. `map_location=device` lets a checkpoint written
# on a GPU machine load on a CPU-only one (and vice versa);
# `weights_only=True` restricts unpickling to tensors and plain containers.
model.load_state_dict(
    torch.load(
        "pascal_voc_30img_state_dict.pt",
        map_location=device,
        weights_only=True,
    )
)

In [None]:
from core.postprocess import CenternetPostprocess

In [None]:
# Instantiate the project's post-processing helper with its default settings.
# NOTE(review): presumably this decodes the raw prediction into boxes; it is
# not used in the visible cells, so confirm against core.postprocess.
pp = CenternetPostprocess()