In [None]:
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as v2

In [None]:
# Training dataset: yields (image, target) pairs, one target file per image

class ImageDatasetTrain(Dataset):
    """Map-style dataset pairing each image with its target file.

    Every file in ``images_folder`` whose extension equals ``image_ext`` is one
    sample; its target is read from ``<target_folder>/<stem>.txt``.

    Args:
        images_folder: Directory containing the images.
        target_folder: Directory containing one ``.txt`` target file per image,
            named after the image's stem.
        transform: Optional callable applied to the PIL image.
        image_ext: Extension (including the leading dot) of the image files to
            use. Matching is case-sensitive because the same string is used to
            rebuild the image path in ``__getitem__``.
    """

    def __init__(self, images_folder, target_folder, transform=None, image_ext=".jpeg"):
        self.images_folder = images_folder
        self.target_folder = target_folder
        self.transform = transform
        self.image_ext = image_ext
        # splitext strips only the last extension, so "img.v2.jpeg" keeps the
        # stem "img.v2" (split(".")[0] would truncate it to "img"). Filtering on
        # the extension skips stray entries such as ".DS_Store" or
        # ".ipynb_checkpoints", and sorting makes the sample order independent
        # of the filesystem's listing order.
        self.files = sorted(
            stem
            for stem, ext in map(os.path.splitext, os.listdir(images_folder))
            if ext == image_ext
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        ``image`` is the RGB PIL image, passed through ``transform`` when one is
        set. ``target`` is the matching ``.txt`` file parsed by ``np.loadtxt``
        as a 2-D float array (one row per line of the file).
        """
        stem = self.files[idx]
        image_path = os.path.join(self.images_folder, stem + self.image_ext)
        target_path = os.path.join(self.target_folder, stem + ".txt")

        # Decode inside a context manager so the file handle is released right
        # away; convert("RGB") guarantees 3 channels even for grayscale or
        # palette images, so every sample has the same channel layout.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        # ndmin=2 keeps a single-line target file as shape (1, k) instead of
        # (k,), so the target's rank does not depend on its row count.
        target = np.loadtxt(target_path, ndmin=2)

        if self.transform is not None:
            image = self.transform(image)

        return image, target
    
# Test dataset: yields images only (no target files)

class ImageDatasetTest(Dataset):
    """Map-style dataset of images without targets.

    Every file in ``images_folder`` whose extension equals ``image_ext`` is one
    sample. ``self.files`` holds the sorted stems, so index ``idx`` always maps
    to the same file — useful for matching predictions back to image names.

    Args:
        images_folder: Directory containing the images.
        transform: Optional callable applied to the PIL image.
        image_ext: Extension (including the leading dot) of the image files to
            use. Matching is case-sensitive because the same string is used to
            rebuild the image path in ``__getitem__``.
    """

    def __init__(self, images_folder, transform=None, image_ext=".jpeg"):
        self.images_folder = images_folder
        self.transform = transform
        self.image_ext = image_ext
        # splitext strips only the last extension ("img.v2.jpeg" -> "img.v2");
        # the extension filter drops stray entries such as ".DS_Store", and
        # sorting makes the order deterministic across filesystems.
        self.files = sorted(
            stem
            for stem, ext in map(os.path.splitext, os.listdir(images_folder))
            if ext == image_ext
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return the RGB PIL image for sample ``idx``, transformed if set."""
        image_path = os.path.join(self.images_folder, self.files[idx] + self.image_ext)

        # Context manager releases the file handle; convert("RGB") forces a
        # consistent 3-channel layout for grayscale or palette images.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image

In [None]:
# Image preprocessing for YOLO v7 object detection: resize to a fixed square
# input, convert the PIL image to a tensor image, then cast to float32 with
# scale=True (uint8 pixels mapped to [0, 1]).
# NOTE(review): only the image is transformed; if the target files hold
# absolute pixel coordinates they will no longer match the resized image —
# confirm the labels are normalized (relative) coordinates.
IMAGE_SIZE = 416  # square input side; YOLO-style models usually want a multiple of 32

transform = v2.Compose([
    v2.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
])

# Datasets for the train (images + targets) and test (images only) splits,
# both sharing the same image transform.
# NOTE(review): TRAIN_FOLDER, TRAIN_TARGET and TEST_FOLDER are not defined in
# the cells shown here; they must come from an earlier config cell, otherwise
# this cell raises NameError on a fresh kernel.
train_dataset = ImageDatasetTrain(
    images_folder=TRAIN_FOLDER,
    target_folder=TRAIN_TARGET,
    transform=transform,
)
test_dataset = ImageDatasetTest(
    images_folder=TEST_FOLDER,
    transform=transform,
)

# One sample per batch — presumably because targets can have a different
# number of rows per image, which the default collate cannot stack (confirm).
# Only the training loader shuffles; test order stays fixed.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=1)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=1)

# Debugging: report dataset sizes and the shape of the first batch from each
# loader.
# NOTE(review): DEBUG is not defined in the cells shown here; it must be set in
# an earlier config cell.
if DEBUG:
    print("Train dataset length:", len(train_dataset))
    print("Test dataset length:", len(test_dataset))

    # Peek at a single batch with next(iter(...)); the length guards keep an
    # empty dataset from raising StopIteration.
    if len(train_dataset) > 0:
        image, target = next(iter(train_loader))
        print("Image shape:", image.shape)
        print("Target shape:", target.shape)

    if len(test_dataset) > 0:
        image = next(iter(test_loader))
        print("Image shape:", image.shape)