In [37]:
import os
from PIL import Image
import torchvision.transforms as transform
from torch.utils.data import Dataset, DataLoader

In [61]:
# Low-resolution input pipeline: resize to 720p (height 720, width 1280),
# then convert the PIL image to a torch tensor.
lr = transform.Compose([transform.Resize((720, 1280)), transform.ToTensor()])

# High-resolution target pipeline: convert to a torch tensor at native size.
# NOTE(review): there is no resize here, so DataLoader's default collation can
# only batch images that already share one size (1440x2560 in the run below) —
# confirm this holds for the whole training set.
hr = transform.Compose([transform.ToTensor()])

In [66]:
class highres_img_dataset(Dataset):
    """Dataset of (low-res, high-res) image pairs built from one directory of JPEGs.

    Each item comes from a single file: the image is converted to YCbCr, then
    ``lr_transform`` produces the low-resolution sample and ``hr_transform`` the
    high-resolution target.

    Args:
        image_dir: Directory scanned (non-recursively) for files ending in
            ``.jpg`` (case-insensitive).
        lr_transform: Callable applied to the YCbCr PIL image to build the
            low-resolution sample. If ``None``, the untransformed image is used.
        hr_transform: Callable applied to the YCbCr PIL image to build the
            high-resolution target. If ``None``, the untransformed image is used.
    """

    def __init__(self, image_dir, lr_transform=None, hr_transform=None):
        self.image_dir = image_dir
        self.lr_transform = lr_transform
        self.hr_transform = hr_transform
        # Sort so each index maps to the same file on every run; os.listdir
        # returns entries in an arbitrary, filesystem-dependent order.
        self.image_files = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.lower().endswith('.jpg')
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` and return the pair ``(lr_image, hr_image)``."""
        img_path = self.image_files[idx]
        # The context manager releases the file handle immediately; convert()
        # returns a new, fully loaded image, so it remains valid afterwards.
        with Image.open(img_path) as img:
            source = img.convert('YCbCr')

        # Fall back to the untransformed image so both outputs are always bound,
        # even when a transform is not supplied.
        if self.lr_transform is not None:
            lr_image = self.lr_transform(source)
        else:
            lr_image = source

        if self.hr_transform is not None:
            hr_image = self.hr_transform(source)
        else:
            hr_image = source

        return lr_image, hr_image

In [67]:
# Pair every training JPEG (as the high-res target) with its 720p-resized input.
img_dataset = highres_img_dataset(
    image_dir='images/training_set',
    lr_transform=lr,
    hr_transform=hr,
)

In [68]:
# Mini-batches of 16 (lr, hr) pairs, reshuffled on every pass over the data.
dataloader = DataLoader(img_dataset, shuffle=True, batch_size=16)

In [70]:
# Sanity check: pull only the first batch and show the lr / hr tensor shapes.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    lr_batch, hr_batch = first_batch[0], first_batch[1]
    print(lr_batch.shape, hr_batch.shape)

torch.Size([16, 3, 720, 1280]) torch.Size([16, 3, 1440, 2560])
