In [None]:
import torchvision
from matplotlib import pyplot as plt
import torchvision.transforms.v2 as transforms
import torch
from datasets.encoders import CenternetEncoder
from datasets.dataset import DatasetIterator
from utils.visualizer import get_image_with_bboxes

%load_ext autoreload
%autoreload 2

print("GPU is available: ", torch.cuda.is_available() )

In [None]:
# Side length (pixels) that every image is resized to before entering the model.
input_height = input_width = 256
# Ratio of model input size to model output size ("model output compared to
# model input"): with 256 px inputs and a ratio of 4 the output is 64 px.
down_ratio = 4
# Seed torch's global RNG so random augmentations (e.g. the horizontal flip
# used below) give the same result on every Restart & Run All.
seed = 42
torch.manual_seed(seed);

In [None]:
# Pascal VOC 2007 validation split read from ./VOC. download=False means the
# data must already be on disk. The wrapper makes samples compatible with
# transforms.v2, so boxes and labels can be transformed together with the image.
dataset_val = torchvision.datasets.wrap_dataset_for_transforms_v2(
    torchvision.datasets.VOCDetection(
        root="VOC",
        year="2007",
        image_set="val",
        download=False,
    )
)
print(len(dataset_val))

In [None]:
# Grab the first validation sample: the image and its target dict
# (the cells below read its "boxes" and "labels" entries).
(img, lbl) = dataset_val[0]
(img, lbl)

In [None]:
# Visual sanity check: draw the ground-truth boxes and labels on the raw sample.
image_with_boxes = get_image_with_bboxes(img, lbl["boxes"], lbl["labels"])
fig, ax = plt.subplots()
ax.imshow(image_with_boxes)
ax.set_title("Raw validation sample with ground-truth boxes")
plt.show()

In [None]:
# Single-sample augmentation used to eyeball the encoder: a random horizontal
# flip, then a resize to the model input size. v2 transforms apply the same
# geometric change to the boxes passed alongside the image.
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        # torchvision Resize takes size as (height, width).
        transforms.Resize(size=(input_height, input_width)),
    ]
)
# Re-read the same, untouched sample (index 0, as shown above) instead of
# transforming `img`. `img` is overwritten on the next line, so re-running this
# cell would otherwise augment an already-augmented image while the boxes still
# describe the original one.
img_raw, lbl_raw = dataset_val[0]
img, bboxes, labels = transform(img_raw, lbl_raw["boxes"], lbl_raw["labels"])

In [None]:
# Encode the augmented boxes and labels into the CenterNet training target.
encoder = CenternetEncoder(input_height, input_width)
lbl_encoded = encoder.encode(bboxes, labels)
# Index along the last axis of the encoded target to visualise.
# NOTE(review): presumably one per-class heatmap channel; confirm the channel
# layout in CenternetEncoder.encode before relying on this index.
channel_to_show = 8
fig, ax = plt.subplots()
ax.imshow(lbl_encoded[..., channel_to_show])
ax.set_title(f"Encoded target, channel {channel_to_show}")
plt.show()

In [None]:
# Check that the boxes still line up with the image after flip + resize.
image_with_boxes = get_image_with_bboxes(img, bboxes, labels)
fig, ax = plt.subplots()
ax.imshow(image_with_boxes)
ax.set_title(f"Augmented sample ({input_height}x{input_width}) with boxes")
plt.show()

In [None]:
# Batched training-style pipeline: resize to the model input size, then
# convert to a float32 tensor with values scaled to [0, 1].
encoder = CenternetEncoder(input_height, input_width)
transformations = transforms.Compose(
    [
        # torchvision Resize takes size as (height, width).
        transforms.Resize(size=(input_height, input_width)),
        # Documented replacement for the deprecated v2.ToTensor():
        # convert to an image tensor, then to float32 rescaled to [0, 1].
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ]
)
torch_dataset = DatasetIterator(
    dataset=dataset_val, transformer=transformations, encoder=encoder
)

In [None]:
# Batched view over the validation set; each batch unpacks into
# (inputs, ground-truth targets) in the next cell.
batch_generator = torch_dataset.get_examples(
    batch_size=24,
)

In [None]:
# Pull one batch and inspect the range of input values the pipeline produced.
# NOTE(review): `input` shadows the Python builtin; rename it if no later cell
# depends on this name.
input, gt_data = next(iter(batch_generator))
input.min(), input.max()