In [1]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.io import read_image
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import os
import numpy as np

In [5]:
class SRDataset(Dataset):
    """Paired low-/high-resolution image patches for super-resolution.

    Each item is an ``(lr_image, hr_image)`` tuple of float tensors scaled to
    ``[0, 1]``. The HR patch is cropped from the source image (and
    colour-jittered in training mode). The LR patch is then obtained by
    downscaling that *same* HR patch, so both tensors always show identical
    content.

    Args:
        image_dir (str): Train/Valid dataset directory. Every regular,
            non-hidden file directly inside it is loaded as an image.
        image_size (int): Side length of the square high-resolution patch.
        upscale_factor (int): HR/LR size ratio; the LR patch is
            ``image_size // upscale_factor`` pixels square.
        jitter_val (float): Strength forwarded to ``ColorJitter`` for
            brightness, contrast, saturation and hue (training mode only).
        mode (str): ``'Train'`` applies random cropping plus colour jitter
            (data enhancement). Any other value uses a deterministic centre
            crop with no augmentation, for validation/testing.

    Raises:
        ValueError: If ``upscale_factor`` is < 1 or larger than ``image_size``
            (the LR patch would be empty).
    """

    def __init__(self, image_dir, image_size = 64, upscale_factor = 2, jitter_val = 0.2, mode = 'Train') -> None:
        if upscale_factor < 1 or upscale_factor > image_size:
            raise ValueError(
                f"upscale_factor must be in [1, image_size]; got {upscale_factor} for image_size={image_size}"
            )
        self.image_file_names = self._list_image_files(image_dir)
        self.image_size = image_size
        self.upscale_factor = upscale_factor
        self.jitter_val = jitter_val
        # 'Train' -> augmented training data; anything else -> plain eval data.
        self.mode = mode

    @staticmethod
    def _list_image_files(image_dir):
        """Return the sorted paths of regular, non-hidden files in ``image_dir``."""
        # Sorting makes dataset indices reproducible (os.listdir order is
        # arbitrary). Skipping dot-entries and sub-directories (e.g. .DS_Store,
        # .ipynb_checkpoints) keeps read_image from choking on non-images.
        names = sorted(name for name in os.listdir(image_dir) if not name.startswith('.'))
        paths = [os.path.join(image_dir, name) for name in names]
        return [path for path in paths if os.path.isfile(path)]

    def _hr_transform(self):
        """Build the crop (and, for training, augmentation) pipeline for HR patches."""
        # Built from the current attributes on every call, so later changes to
        # e.g. ``self.mode`` still take effect.
        if self.mode == 'Train':
            return transforms.Compose([
                transforms.RandomCrop(self.image_size),
                transforms.ColorJitter(brightness=self.jitter_val, contrast=self.jitter_val,
                                       saturation=self.jitter_val, hue=self.jitter_val),
            ])
        # Validation/test: deterministic crop and no data enhancement.
        return transforms.CenterCrop(self.image_size)

    def __getitem__(self, index):
        """Load image ``index`` and return its ``(lr_image, hr_image)`` pair."""
        # Decode as 3-channel RGB: ColorJitter's saturation/hue steps reject
        # 4-channel (RGBA) tensors, and mixing grayscale with colour files
        # would break batch collation.
        image = read_image(self.image_file_names[index], torchvision.io.ImageReadMode.RGB).float() / 255.
        hr_image = self._hr_transform()(image)
        # Downscale the HR patch itself (instead of re-cropping the source) so
        # the LR input shares the HR target's exact crop and colour jitter.
        lr_size = self.image_size // self.upscale_factor
        lr_image = transforms.Resize(size=(lr_size, lr_size), antialias=True)(hr_image)
        return lr_image, hr_image

    def __len__(self):
        """Number of image files found in ``image_dir``."""
        return len(self.image_file_names)

In [6]:
# Build the train/eval datasets and pull a single training batch as a smoke
# test, saving the first LR/HR pair to disk for visual inspection.
train_path = './train'
test_path = './eval'
train_dataset = SRDataset(train_path)
# Evaluation data must not be augmented: any mode other than 'Train' selects
# SRDataset's deterministic centre-crop pipeline.
test_dataset = SRDataset(test_path, mode='Valid')
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    # NOTE(review): with spawn-based worker start-up (Windows/macOS), a Dataset
    # class defined inside a notebook may fail to unpickle in the workers; set
    # num_workers=0 if the loader hangs or raises.
    num_workers=2,
    drop_last=True,
    # Pinned memory only speeds up host->GPU copies; recent PyTorch versions
    # warn when it is requested without an accelerator.
    pin_memory=torch.cuda.is_available(),
)
print(f" * Dataset contains {len(train_dataset)} image(s).")

# Fetch exactly one batch -- this cell is only a pipeline sanity check.
batch = next(iter(train_dataloader), None)
assert batch is not None, "train_dataloader yielded no batches (fewer images than batch_size with drop_last=True?)"
lr_image, hr_image = batch
# write_png expects a uint8 CHW tensor; clamp and round (rather than truncate)
# when converting from [0, 1] floats.
torchvision.io.write_png(lr_image[0].clamp(0, 1).mul(255).round().byte(), "lr_image.png")
torchvision.io.write_png(hr_image[0].clamp(0, 1).mul(255).round().byte(), "hr_image.png")


 * Dataset contains 301 image(s).
