In [4]:
from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np

In [5]:
class SummerWinterDataset(Dataset):
    """Unpaired winter/summer image dataset.

    Each item pairs one winter image with one summer image. The two folders
    may contain different numbers of images: the dataset length is the larger
    count and the shorter list is cycled via ``index % len``, so the pairing
    is not one-to-one.

    Args:
        root_winter: Directory containing the winter images.
        root_summer: Directory containing the summer images.
        transform: Optional callable invoked as
            ``transform(image=winter, image0=summer)`` that must return a
            mapping with ``"image"`` and ``"image0"`` keys. Presumably an
            albumentations ``Compose`` built with
            ``additional_targets={"image0": "image"}`` — confirm with caller.

    Raises:
        ValueError: If either directory contains no image files.
    """

    def __init__(self, root_winter, root_summer, transform=None):
        self.root_winter = root_winter
        self.root_summer = root_summer
        self.transform = transform

        # Sorted so the index -> file mapping is deterministic; os.listdir
        # returns entries in arbitrary, filesystem-dependent order.
        self.winter_images = self._list_images(root_winter)
        self.summer_images = self._list_images(root_summer)
        self.winter_len = len(self.winter_images)
        self.summer_len = len(self.summer_images)

        # An empty folder would otherwise surface later as an opaque
        # ZeroDivisionError from ``index % 0`` in __getitem__.
        if self.winter_len == 0 or self.summer_len == 0:
            raise ValueError(
                f"No images found (winter: {self.winter_len} in {root_winter!r}, "
                f"summer: {self.summer_len} in {root_summer!r})"
            )
        self.length_dataset = max(self.winter_len, self.summer_len)

    @staticmethod
    def _list_images(root):
        """Return the sorted names of regular, non-hidden files in *root*."""
        # Skip sub-directories and dot-entries such as ".ipynb_checkpoints"
        # or ".DS_Store", which would make Image.open fail at load time.
        return sorted(
            name
            for name in os.listdir(root)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(root, name))
        )

    @staticmethod
    def _load_rgb(path):
        """Read the image at *path*, convert it to RGB, return it as an array."""
        # The context manager releases the file handle Image.open holds.
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __len__(self):
        """Return the image count of the larger of the two folders."""
        return self.length_dataset

    def __getitem__(self, index):
        """Return the ``(winter, summer)`` image pair for *index*.

        Without a transform both items are numpy arrays from RGB images;
        with one, they are whatever the transform returns under
        ``"image"`` and ``"image0"``.
        """
        winter_path = os.path.join(
            self.root_winter, self.winter_images[index % self.winter_len]
        )
        summer_path = os.path.join(
            self.root_summer, self.summer_images[index % self.summer_len]
        )

        winter_img = self._load_rgb(winter_path)
        summer_img = self._load_rgb(summer_path)

        # Explicit None check: a transform object may define __len__ and be
        # falsy (e.g. an empty composition) yet still be meant to run.
        if self.transform is not None:
            augmentations = self.transform(image=winter_img, image0=summer_img)
            winter_img = augmentations["image"]
            summer_img = augmentations["image0"]

        return winter_img, summer_img
