In [28]:
import os
from pathlib import Path
from typing import Optional, Union

import pandas as pd
from datasets import load_dataset
from PIL import Image
from torch.utils.data import Dataset, DataLoader

from dataset import RetroGamesHelper
from config_file import config

In [29]:
class CustomDataset(Dataset):
    """Screenshot/caption dataset with optional k-fold splitting.

    The requested split is loaded with Hugging Face ``load_dataset``. When
    BOTH ``fold_index`` and ``k_folds`` are given, the split is replaced by
    the training part returned by ``RetroGamesHelper.get_fold`` and the
    held-out part is stored on ``validation_dataset``.

    Each item is a dict ``{"pixel_values": image, "prompt": caption}``, where
    ``image`` is an RGB PIL image (or whatever ``transform`` returns for it).
    """

    def __init__(
        self,
        dataset_path: Union[str, Path],
        dataset_split: str = "train",
        fold_index: Optional[int] = None,
        k_folds: Optional[int] = None,
        transform=None,
    ):
        """
        Args:
            dataset_path: Dataset root directory (``str`` or ``Path``).
            dataset_split: Split to load, e.g. ``"train"``. It is also used to
                point ``RetroGamesHelper`` at ``<root>/<split>`` and
                ``<root>/<split>.csv``.
            fold_index: Fold to select. Ignored unless ``k_folds`` is also given.
            k_folds: Number of folds. Ignored unless ``fold_index`` is also given.
            transform: Optional callable applied to each RGB PIL image.
        """
        # Normalise once and use the normalised path everywhere below, so
        # plain-string roots work (``str / str`` would raise TypeError).
        self.dataset_path = Path(dataset_path)
        self.dataset = load_dataset(str(dataset_path))[dataset_split]
        self.retro_helper = RetroGamesHelper(
            self.dataset_path / dataset_split,
            self.dataset_path / f"{dataset_split}.csv",
        )
        # Held-out part of the selected fold; stays None without a k-fold split.
        self.validation_dataset = None

        # Both values are required; supplying only one silently keeps the
        # full split.
        if fold_index is not None and k_folds is not None:
            # NOTE(review): get_fold is assumed to return (train, val); the
            # DataFrame branch in __getitem__ suggests these are pandas
            # DataFrames — confirm against RetroGamesHelper.
            train, val = self.retro_helper.get_fold(fold_index, k_folds)
            self.dataset = train
            self.validation_dataset = val

        self.transform = transform

    def __len__(self):
        """Number of samples in the active (possibly fold-restricted) split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": image, "prompt": caption}`` for row ``idx``.

        Each row must provide ``"file_name"`` (resolved relative to
        ``dataset_path``) and ``"caption"``.
        """
        # A pandas DataFrame (e.g. a fold subset) needs positional ``iloc``;
        # the Hugging Face split is indexed directly.
        if isinstance(self.dataset, pd.DataFrame):
            screenshot = self.dataset.iloc[idx]
        else:
            screenshot = self.dataset[idx]
        filepath = self.dataset_path / screenshot["file_name"]
        prompt = screenshot["caption"]

        # The context manager closes the file deterministically. convert()
        # returns a new, fully loaded image that does not depend on the open file.
        with Image.open(filepath) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        return {"pixel_values": image, "prompt": prompt}

In [30]:
# Cross-validation setup: train on every fold of the "train" split except
# the held-out fold below.
dataset_split = "train"
dataset_path = config.DATASET_PATH
k_folds = 5
fold_index = 3

dataset = CustomDataset(
    dataset_path,
    dataset_split,
    fold_index=fold_index,
    k_folds=k_folds,
)

In [31]:
# Sanity check: fetch one sample and display its image/prompt pair.
first_sample = dataset[0]
first_sample

{'pixel_values': <PIL.Image.Image image mode=RGB size=512x512>,
 'prompt': 'a screenshot from a video game shows a man in a suit'}