In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from dataset.flame import FlameFOV
from dataset.flame import FlameThermal
from dataset.flame import FlameRGB
from dataset.flame import FlameSatelite
from torchvision.transforms import transforms
from torchvision.utils import draw_bounding_boxes
from torchvision.ops import box_convert
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from helper.image_processing import SquarePadTransform
from helper.utils import collate_fn

In [2]:
# Per-image preprocessing pipeline: resize to a fixed 256x256, then convert to a tensor.
# Steps that are intentionally left disabled for now:
#   - SquarePadTransform() before the resize
#   - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) after ToTensor()
compose = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

In [None]:
# Build the satellite dataset with the preprocessing pipeline defined above.
dataset = FlameSatelite(download=True, transform=compose)

# 80/20 train/test split. The test size is computed as the remainder so the
# two sizes always add up to len(dataset), even when 0.8 * len is fractional.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# Seed the split so that "Restart Kernel -> Run All" assigns the same samples
# to each subset every time; an unseeded random_split reshuffles membership
# on every run.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=split_generator
)

# Settings shared by both loaders. collate_fn is the project helper imported
# from helper.utils; how it assembles a batch is defined there.
loader_kwargs = dict(batch_size=32, num_workers=4, collate_fn=collate_fn)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Smoke-test the training loader: take only the first batch and report the
# shape of each of the two elements it yields.
for images, bboxes in train_loader:
    for label, batch in (
        ("Image batch shape:", images),
        ("Bounding box batch shape:", bboxes),
    ):
        print(label, batch.shape)
    break

In [4]:
# Pull one batch from the training loader for the visualisation below.
# Despite the singular names, `image` and `bbox` each hold a whole batch and
# are indexed per sample later on.
first_batch = next(iter(train_loader))
image, bbox = first_batch

In [5]:
# Global matplotlib setting: use the "tight" bounding box for saved figures.
plt.rcParams.update({"savefig.bbox": "tight"})

def show(imgs):
    """Display one image, or a list of images, side by side in a single row.

    Args:
        imgs: A single image or a ``list`` of images. Any value that is not a
            ``list`` is treated as one image. Each image is detached and
            converted with ``F.to_pil_image`` before being drawn.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    # squeeze=False keeps the axes array 2-D, so row 0 holds one axis per image.
    _, axes_grid = plt.subplots(ncols=len(images), squeeze=False)
    for axis, tensor_img in zip(axes_grid[0], images):
        pil_img = F.to_pil_image(tensor_img.detach())
        axis.imshow(np.asarray(pil_img))
        # Remove ticks and tick labels so only the image itself is visible.
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
# Position, within the fetched batch, of the sample to visualise.
SAMPLE_INDEX = 11

# Convert the sample's boxes from (cx, cy, w, h) to (x1, y1, x2, y2).
# reshape(-1, 4) is a no-op for an (N, 4) tensor and turns a single (4,) box
# into (1, 4), which is the layout draw_bounding_boxes requires.
# NOTE(review): this assumes the boxes are in pixel coordinates of the
# resized 256x256 image — confirm against the dataset / collate_fn.
bbox_xyxy = box_convert(
    bbox[SAMPLE_INDEX].reshape(-1, 4), in_fmt="cxcywh", out_fmt="xyxy"
)

# The image comes out of ToTensor() as a float tensor, but older torchvision
# releases only accept uint8 in draw_bounding_boxes. Converting explicitly
# works on every version and leaves an already-uint8 image unchanged.
sample_image = F.convert_image_dtype(image[SAMPLE_INDEX], torch.uint8)

result = draw_bounding_boxes(sample_image, bbox_xyxy, colors="blue", width=5)
show(result)