In [None]:
import torchvision
import numpy as np
from matplotlib import pyplot as plt
import torchvision.transforms.v2 as transforms
import torch
from torchsummary import summary
from datasets.dataset import DatasetIterator

from datasets.encoders import CenternetEncoder
from core.model_builder import ModelBuilder
from utils.visualizer import get_image_with_bboxes

%load_ext autoreload
%autoreload 2

print("GPU is available: ", torch.cuda.is_available())

### Load Dataset

In [None]:
input_height = input_width = 256
down_ratio = 4  # model output compared to model input

# Seed torch's RNG so weight initialisation (and any other torch randomness)
# is reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED);

In [None]:
# Pascal VOC 2007 validation split, read from a local copy (no download).
voc_val_raw = torchvision.datasets.VOCDetection(
    root="../tmp/VOC",
    year="2007",
    image_set="val",
    download=False,
)

# Wrap for torchvision.transforms.v2; targets are then read below as
# lbl["boxes"] / lbl["labels"].
dataset_val = torchvision.datasets.wrap_dataset_for_transforms_v2(voc_val_raw)

print(len(dataset_val))

In [None]:
sample_index = 2000
img, lbl = dataset_val[sample_index]
(img, lbl)

In [None]:
# Overlay the dataset's boxes and labels on the raw sample image.
gt_boxes = lbl["boxes"]
gt_labels = lbl["labels"]
image_with_boxes = get_image_with_bboxes(img, gt_boxes, gt_labels)

fig, ax = plt.subplots()
ax.imshow(image_with_boxes)

### Apply Encoding and Transformations

In [None]:
encoder = CenternetEncoder(input_height, input_width)

# transforms.v2.Resize expects size as (height, width) -- same order as the
# encoder above. ToTensor is deprecated in transforms.v2; ToImage followed by
# ToDtype(float32, scale=True) is the documented replacement (float CHW tensor
# scaled to [0, 1]). ToDtype with a single dtype only touches images, so the
# boxes produced by the wrapped dataset are left unchanged.
transform = transforms.Compose(
    [
        transforms.Resize(size=(input_height, input_width)),
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ]
)

torch_dataset = DatasetIterator(
    dataset=dataset_val, transformer=transform, encoder=encoder
)

In [None]:
encoded_sample = torch_dataset[30]
img_, lbl_encoded = encoded_sample
img_.shape, lbl_encoded.numpy().shape

In [None]:
# Move the channel axis last (axes 1, 2, 0) so matplotlib can render the image.
fig, ax = plt.subplots()
ax.imshow(np.transpose(img_.numpy(), (1, 2, 0)))

In [None]:
# Visualise a single channel of the encoded target.
# NOTE(review): channel 7 is presumably one class heatmap -- confirm the
# channel layout against CenternetEncoder.
encoded_target = lbl_encoded.numpy()
fig, ax = plt.subplots()
ax.imshow(encoded_target[:, :, 7])

### Prepare for Training

In [None]:
# Fixed order (no shuffling), batches of 4, loading in the main process.
batch_generator = torch.utils.data.DataLoader(
    dataset=torch_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0,
)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): confirm ModelBuilder's positional order is (width, height).
model = ModelBuilder(input_width, input_height, alpha=0.25).to(device)

# Derive the summary input size from the configured resolution instead of
# hard-coding 256, so changing input_height/input_width stays consistent.
summary(model, input_size=(3, input_height, input_width), batch_size=-1)

In [None]:
lr = 0.01
# Optimise every model parameter with Adam.
parameters = [*model.parameters()]
optimizer = torch.optim.Adam(params=parameters, lr=lr)

In [None]:
# One entry is appended per training epoch by the training cell.
loss_history = list()

In [None]:
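# OPTIONAL (second training phase): after the initial training run, uncomment
# the lines below to lower the learning rate to 0.001, set EPOCHS to 500, and
# re-run the training cell until the model overfits (loss between 0 and 1).
# NOTE: this note originally said "after first 1500 epochs", but EPOCHS in the
# training cell is 300 -- reconcile the intended schedule.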

# for g in optimizer.param_groups:
#     g["lr"] = 0.001

### Train model

In [None]:
EPOCHS = 300
model.train(True)

for epoch in range(EPOCHS):
    print("EPOCH {}:".format(epoch + 1))
    epoch_loss_sum = 0.0
    num_batches = 0

    for img, gt_data in batch_generator:
        img = img.to(device).contiguous()
        # Targets are supervision only; detach() guarantees no gradient flows
        # into them (assigning requires_grad raises on non-leaf tensors).
        gt_data = gt_data.to(device).detach()

        loss_dict = model(img, gt=gt_data)
        loss = loss_dict["loss"]

        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate as a detached tensor so the GPU is synced once per epoch.
        epoch_loss_sum += loss.detach()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("batch_generator yielded no batches; check torch_dataset")

    # Record the mean loss over the whole epoch, not just its last batch.
    epoch_loss = (epoch_loss_sum / num_batches).item()
    loss_history.append(epoch_loss)
    print("  mean loss: {:.4f}".format(epoch_loss))

In [None]:
# Log scale makes late-training improvements visible next to the initial drop.
fig, ax = plt.subplots()
ax.plot(np.log(loss_history))
ax.set_title("Training loss")
ax.set_xlabel("epoch")
ax.set_ylabel("log(loss)")
plt.show()

### Check training results

In [None]:
# Switch to inference mode (nn.Module.eval() is train(False)).
model.train(False)

In [None]:
val_index = 4
img, lbl = dataset_val[val_index]

gt_boxes, gt_labels = lbl["boxes"], lbl["labels"]
image_with_boxes = get_image_with_bboxes(img, gt_boxes, gt_labels)

fig, ax = plt.subplots()
ax.imshow(image_with_boxes)

In [None]:
# Ground-truth class labels of the sample loaded in the previous cell.
lbl["labels"]

In [None]:
# Apply the same resize/tensor transform used for training to this sample.
pred_input = transform(img)

fig, ax = plt.subplots()
ax.imshow(pred_input.permute(1, 2, 0).numpy())

In [None]:
# Add a leading batch dimension of 1 without hard-coding the image size.
batched_input = pred_input.unsqueeze(0).to(device)

# Inference only: disable autograd, and call the module (not .forward) so
# registered hooks still run.
with torch.no_grad():
    pred = model(batched_input).cpu().numpy()
pred.shape

In [None]:
# Peak value of each of the first 20 output channels for the single image.
# NOTE(review): presumably the 20 VOC class heatmaps -- confirm against ModelBuilder.
for class_idx in range(20):
    channel_peak = pred[0, class_idx].max()
    print(channel_peak)

In [None]:
# TODO: take 4 channels (distance from image borders), show bounding boxes

In [None]:
# NOTE(review): channel 8 of the prediction for the single input image; what it
# represents depends on ModelBuilder's output layout -- confirm.
fig, ax = plt.subplots()
ax.imshow(pred[0, 8])

### Save model weights

In [None]:
# Pickles the entire nn.Module (class reference + weights), not just the weights.
# NOTE(review): reloading this file requires the ModelBuilder class to be
# importable and, on recent PyTorch versions, torch.load(..., weights_only=False).
# The state_dict checkpoint saved in the next cell is the more portable artifact.
torch.save(model, "pascal_voc_30img.pt")

In [None]:
# Persist only the learned weights (portable across code changes to the class).
state_dict_path = "pascal_voc_30img_state_dict.pt"
torch.save(model.state_dict(), state_dict_path)

In [None]:
# map_location remaps tensors saved on another device (e.g. a CUDA checkpoint
# opened on a CPU-only machine) onto the current device; weights_only=True
# restricts unpickling to tensors/primitive containers.
state_dict = torch.load(
    "pascal_voc_30img_state_dict.pt", map_location=device, weights_only=True
)
model.load_state_dict(state_dict)

### Draw bounding boxes based on model prediction