# In [2]:
from os import listdir
from os.path import join
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image, ImageFilter
import os


# In [6]:

def is_image_file(filename):
    """Return True if *filename* ends with a .png, .jpg or .jpeg extension.

    The comparison is case-insensitive, so files such as ``IMG_001.JPG``
    are not silently skipped.
    """
    return filename.lower().endswith((".png", ".jpg", ".jpeg"))

def load_img(filepath):
    """Load an image file and return only its luminance (Y) channel.

    The image is converted to the YCbCr colour space and split into its three
    bands; only the Y band is returned, as a single-band PIL image.

    Args:
        filepath (str): Path to the image file.

    Returns:
        PIL.Image.Image: The Y channel of the image.
    """
    # The context manager closes the file handle right away instead of leaving
    # it open until garbage collection. ``convert`` returns a fully loaded,
    # independent image, so the returned band does not need the open file.
    with Image.open(filepath) as img:
        y, _, _ = img.convert('YCbCr').split()
    return y

class DatasetFromFolder(Dataset):
    """Paired dataset of pre-made low-resolution / high-resolution images.

    Each item is the Y (luminance) channel of a low-resolution image and of
    its high-resolution counterpart, both converted with
    ``transforms.ToTensor``. Files are paired by position after sorting the
    filenames found in each directory.
    """

    def __init__(self, low_res_dir, high_res_dir):
        """
        Args:
            low_res_dir (str): Directory containing images that are already low resolution.
            high_res_dir (str): Directory containing the original high-resolution images.

        Raises:
            ValueError: If the two directories contain a different number of images.
        """
        super().__init__()
        # os.listdir() returns entries in arbitrary order, so both lists are
        # sorted to make the index-based low/high pairing deterministic.
        # NOTE(review): this assumes corresponding files sort to the same
        # position in both directories (e.g. identical names) — confirm.
        self.low_res_filenames = sorted(
            join(low_res_dir, x) for x in os.listdir(low_res_dir) if is_image_file(x)
        )
        self.high_res_filenames = sorted(
            join(high_res_dir, x) for x in os.listdir(high_res_dir) if is_image_file(x)
        )

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(self.low_res_filenames) != len(self.high_res_filenames):
            raise ValueError("Low-res and high-res images must match in count.")

        self.input_transform = transforms.ToTensor()
        self.target_transform = transforms.ToTensor()

    def __getitem__(self, index):
        """Return the ``(low_res, high_res)`` tensor pair at *index*."""
        low_res_img = self.input_transform(load_img(self.low_res_filenames[index]))
        high_res_img = self.target_transform(load_img(self.high_res_filenames[index]))
        return low_res_img, high_res_img

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.low_res_filenames)


def is_image_file(filename):
    """Return True if *filename* ends with a .png, .jpg or .jpeg extension.

    The comparison is case-insensitive, so files such as ``IMG_001.JPG``
    are not silently skipped.
    """
    return filename.lower().endswith((".png", ".jpg", ".jpeg"))

def load_img(filepath):
    """Load an image file and return only its luminance (Y) channel.

    The image is converted to the YCbCr colour space and split into its three
    bands; only the Y band is returned, as a single-band PIL image.

    Args:
        filepath (str): Path to the image file.

    Returns:
        PIL.Image.Image: The Y channel of the image.
    """
    # The context manager closes the file handle right away instead of leaving
    # it open until garbage collection. ``convert`` returns a fully loaded,
    # independent image, so the returned band does not need the open file.
    with Image.open(filepath) as img:
        y, _, _ = img.convert('YCbCr').split()  # keep only Y from YCbCr
    return y

# Side length, in pixels, of the square center crop taken from each image;
# DatasetFromFolder rounds it down to a multiple of zoom_factor.
CROP_SIZE: int = 32

class DatasetFromFolder(Dataset):
    """Paired low/high-resolution dataset with a bicubic-upsampled input.

    For each pair, the Y channel of the low-resolution image is center-cropped,
    downsampled by ``zoom_factor`` and bicubically upsampled back to the crop
    size. The Y channel of the high-resolution image is center-cropped to the
    same size and used as the target. Files are paired by position after
    sorting the filenames found in each directory.
    """

    def __init__(self, low_res_dir, high_res_dir, zoom_factor):
        """
        Args:
            low_res_dir (str): Directory containing the low-resolution images.
            high_res_dir (str): Directory containing the high-resolution images.
            zoom_factor (int): Factor by which the input crop is downsampled
                before being upsampled back; must be between 1 and CROP_SIZE.

        Raises:
            ValueError: If zoom_factor is out of range, or if the two
                directories contain a different number of images.
        """
        super().__init__()
        # Fail fast: 0 would divide by zero below, and values outside
        # [1, CROP_SIZE] produce empty or negative crop/resize sizes.
        if not 1 <= zoom_factor <= CROP_SIZE:
            raise ValueError(
                f"zoom_factor must be between 1 and {CROP_SIZE}, got {zoom_factor!r}"
            )

        # os.listdir() returns entries in arbitrary order, so both lists are
        # sorted to make the index-based low/high pairing deterministic.
        # NOTE(review): this assumes corresponding files sort to the same
        # position in both directories (e.g. identical names) — confirm.
        self.low_res_filenames = sorted(
            join(low_res_dir, x) for x in os.listdir(low_res_dir) if is_image_file(x)
        )
        self.high_res_filenames = sorted(
            join(high_res_dir, x) for x in os.listdir(high_res_dir) if is_image_file(x)
        )

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(self.low_res_filenames) != len(self.high_res_filenames):
            raise ValueError("Low-res and high-res images must match in count.")

        # Largest size <= CROP_SIZE divisible by zoom_factor, so the subsampling
        # step shrinks the crop by exactly zoom_factor.
        crop_size = CROP_SIZE - (CROP_SIZE % zoom_factor)
        # NOTE(review): the low-res image is downsampled again here. If
        # low_res_dir already holds downscaled images, they are degraded twice,
        # and their center crop may not align spatially with the high-res
        # target crop — confirm this is intended.
        self.input_transform = transforms.Compose([
            transforms.CenterCrop(crop_size),  # crop the image
            transforms.Resize(crop_size // zoom_factor),  # subsample by zoom_factor
            transforms.Resize(crop_size, interpolation=Image.BICUBIC),  # bicubic upsampling back to crop_size
            transforms.ToTensor(),
        ])

        self.target_transform = transforms.Compose([
            transforms.CenterCrop(crop_size),  # the target keeps its original quality
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        """Return the ``(input, target)`` tensor pair at *index*."""
        low_res_img = self.input_transform(load_img(self.low_res_filenames[index]))
        high_res_img = self.target_transform(load_img(self.high_res_filenames[index]))
        return low_res_img, high_res_img

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.low_res_filenames)


def is_image_file(filename):
    """Return True if *filename* ends with a .png, .jpg or .jpeg extension.

    The comparison is case-insensitive, so files such as ``IMG_001.JPG``
    are not silently skipped.
    """
    return filename.lower().endswith((".png", ".jpg", ".jpeg"))

def load_img(filepath):
    """Load an image file and return only its luminance (Y) channel.

    The image is converted to the YCbCr colour space and split into its three
    bands; only the Y band is returned, as a single-band PIL image.

    Args:
        filepath (str): Path to the image file.

    Returns:
        PIL.Image.Image: The Y channel of the image.
    """
    # The context manager closes the file handle right away instead of leaving
    # it open until garbage collection. ``convert`` returns a fully loaded,
    # independent image, so the returned band does not need the open file.
    with Image.open(filepath) as img:
        y, _, _ = img.convert('YCbCr').split()
    return y

# Side length, in pixels, of the square center crop taken from each image;
# DatasetFromFolder rounds it down to a multiple of zoom_factor.
CROP_SIZE: int = 32

class DatasetFromFolder(Dataset):
    """Single-directory super-resolution dataset that synthesizes its inputs.

    The Y channel of each image is center-cropped to a square whose side is
    ``CROP_SIZE`` rounded down to a multiple of ``zoom_factor``. The target is
    that crop at full quality. The input is the same crop downsampled by
    ``zoom_factor`` and bicubically upsampled back, so input and target have
    the same size.
    """

    def __init__(self, image_dir, zoom_factor, blur_radius=None):
        """
        Args:
            image_dir (str): Directory containing the source images.
            zoom_factor (int): Factor by which the input crop is downsampled
                before being upsampled back; must be between 1 and CROP_SIZE.
            blur_radius (float, optional): If given, a Gaussian blur with this
                radius is applied to the input (not the target) before
                cropping. Defaults to None (no blur).

        Raises:
            ValueError: If zoom_factor is out of range.
        """
        super().__init__()
        # Fail fast: 0 would divide by zero below, and values outside
        # [1, CROP_SIZE] produce empty or negative crop/resize sizes.
        if not 1 <= zoom_factor <= CROP_SIZE:
            raise ValueError(
                f"zoom_factor must be between 1 and {CROP_SIZE}, got {zoom_factor!r}"
            )

        # Sorted so that dataset indices are reproducible across runs and
        # machines (os.listdir order is arbitrary).
        self.image_filenames = sorted(
            join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)
        )
        self.blur_radius = blur_radius

        # Largest size <= CROP_SIZE divisible by zoom_factor, so the subsampling
        # step shrinks the crop by exactly zoom_factor.
        crop_size = CROP_SIZE - (CROP_SIZE % zoom_factor)
        self.input_transform = transforms.Compose([
            transforms.CenterCrop(crop_size),  # crop the image
            transforms.Resize(crop_size // zoom_factor),  # subsample by zoom_factor
            transforms.Resize(crop_size, interpolation=Image.BICUBIC),  # bicubic upsampling back to crop_size
            transforms.ToTensor(),
        ])
        self.target_transform = transforms.Compose([
            transforms.CenterCrop(crop_size),  # the target keeps its original quality
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        """Return the ``(input, target)`` tensor pair at *index*."""
        image = load_img(self.image_filenames[index])

        # The PIL-based transforms and ``filter`` return new images, so input
        # and target can both be derived from ``image`` without copying it.
        degraded = image
        if self.blur_radius is not None:
            # Optional extra degradation of the input only; the target stays sharp.
            degraded = degraded.filter(ImageFilter.GaussianBlur(self.blur_radius))

        return self.input_transform(degraded), self.target_transform(image)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_filenames)
