In [1]:
import itertools
import os
import random

import cv2
import numpy as np
import torch.utils.data as data

from PIL import Image, ImageEnhance, ImageOps
from torchvision.transforms import ToTensor


In [2]:
def load_img(filepath):
    """Load the image at ``filepath`` as an in-memory RGB PIL image.

    The file is opened in a ``with`` block so the handle is released once
    ``convert`` has produced an independent in-memory copy. Without it, Pillow
    may keep the file open for multi-frame formats (e.g. animated PNG).
    """
    with Image.open(filepath) as img:
        return img.convert("RGB")


def rescale_img(img, scale):
    """Return ``img`` resized by the factor ``scale`` using bicubic resampling.

    Each side length is multiplied by ``scale`` and truncated with ``int``.
    """
    width, height = img.size
    target_size = (int(width * scale), int(height * scale))
    return img.resize(target_size, resample=Image.BICUBIC)


def get_patch(img_in, img_tar, patch_size, scale, ix=-1, iy=-1):
    """Crop spatially aligned patches from an input image and its scaled target.

    ``img_tar`` is expected to be ``scale`` times larger than ``img_in``. The
    input patch is ``patch_size`` pixels square and the target patch is
    ``scale * patch_size`` pixels square, covering the same scene region.

    Args:
        img_in: Input (low-resolution) PIL image.
        img_tar: Target PIL image, ``scale`` times the size of ``img_in``.
        patch_size: Side length of the input patch, in input pixels.
        scale: Integer scale factor between input and target.
        ix: Vertical (row) offset of the input patch; ``-1`` picks it at random.
        iy: Horizontal (column) offset of the input patch; ``-1`` picks it at random.

    Returns:
        ``(input_patch, target_patch)``.

    Raises:
        ValueError: If a random offset is requested along a side of ``img_in``
            that is shorter than the patch.
    """
    in_width, in_height = img_in.size  # PIL reports (width, height)

    tp = scale * patch_size  # target patch side, in target pixels
    ip = tp // scale  # input patch side, in input pixels

    if ix == -1:
        if in_height < ip:
            raise ValueError(
                f"input height {in_height} is smaller than patch size {ip}"
            )
        ix = random.randrange(0, in_height - ip + 1)
    if iy == -1:
        if in_width < ip:
            raise ValueError(
                f"input width {in_width} is smaller than patch size {ip}"
            )
        iy = random.randrange(0, in_width - ip + 1)

    tx, ty = scale * ix, scale * iy

    # PIL crop boxes are (left, upper, right, lower): iy/ty are horizontal and
    # ix/tx vertical. The input is cropped in input-pixel coordinates and the
    # target in target-pixel coordinates so both patches show the same region.
    patch_in = img_in.crop((iy, ix, iy + ip, ix + ip))
    patch_tar = img_tar.crop((ty, tx, ty + tp, tx + tp))
    return patch_in, patch_tar


def augment(img_in, img_tar, flip_h=True, rot=True):
    """Randomly apply identical flips / rotation to an input-target image pair.

    Each transform fires independently with probability 0.5:

    * ``ImageOps.flip`` (top-bottom), only if ``flip_h``; recorded as ``"flip_h"``.
    * ``ImageOps.mirror`` (left-right), only if ``rot``; recorded as ``"flip_v"``.
    * 180-degree ``rotate``, only if ``rot``; recorded as ``"trans"``.

    NOTE(review): the ``"flip_h"`` / ``"flip_v"`` labels are the reverse of the
    axis each op flips; they are kept because the keys are part of the output.

    Returns:
        ``(img_in, img_tar, info_aug)``, where ``info_aug`` maps each key above
        to whether that transform was applied.
    """
    applied = dict(flip_h=False, flip_v=False, trans=False)

    # The first random number is drawn even when flip_h is False, so the RNG
    # stream consumed does not depend on that flag.
    do_flip = random.random() < 0.5
    if do_flip and flip_h:
        img_in, img_tar = ImageOps.flip(img_in), ImageOps.flip(img_tar)
        applied["flip_h"] = True

    if not rot:
        return img_in, img_tar, applied

    if random.random() < 0.5:
        img_in, img_tar = ImageOps.mirror(img_in), ImageOps.mirror(img_tar)
        applied["flip_v"] = True
    if random.random() < 0.5:
        img_in, img_tar = img_in.rotate(180), img_tar.rotate(180)
        applied["trans"] = True

    return img_in, img_tar, applied


In [3]:
class VOC2007(data.Dataset):
    """Paired low-light / normal-light dataset synthesized from a folder of images.

    Each item loads one image and resizes it so its shorter side is
    ``short_side`` pixels; that is the ``high`` image. A ``low`` image is then
    synthesized from it by randomly dimming color and contrast and applying a
    random darkening curve. ``__getitem__`` returns ``(low, high)``.

    Args:
        img_folder: Directory scanned (non-recursively) for image files.
        patch_size: Stored as ``self.patch_size``; not used by ``__getitem__``.
        upscale_factor: Stored as ``self.upscale_factor``; not used by ``__getitem__``.
        data_augmentation: If truthy, ``augment`` is applied jointly to each pair.
        transform: Optional callable applied separately to each image of a pair.
        short_side: Target length in pixels of the shorter image side.
    """

    def __init__(
        self,
        img_folder,
        patch_size,
        upscale_factor,
        data_augmentation,
        transform=None,
        short_side=384,
    ):
        super(VOC2007, self).__init__()
        self.imgFolder = img_folder
        # Sorted so a given index always maps to the same file; os.listdir
        # returns entries in arbitrary order.
        self.image_filenames = [
            os.path.join(self.imgFolder, name)
            for name in sorted(os.listdir(self.imgFolder))
            if self.is_image_file(name)
        ]
        self.patch_size = patch_size
        self.upscale_factor = upscale_factor
        self.transform = transform
        self.data_augmentation = data_augmentation
        self.short_side = short_side

    def is_image_file(self, filename):
        """Return True if ``filename`` has a bmp/png/jpg/jpeg extension (any case)."""
        return filename.lower().endswith((".bmp", ".png", ".jpg", ".jpeg"))

    def _resize_short_side(self, img):
        """Resize ``img`` (Lanczos) so its shorter side is ``self.short_side`` px.

        The aspect ratio is kept; the longer side is floored. Integer arithmetic
        makes the short side come out exactly ``short_side``. The float form
        ``int(w / (w / 384))`` can truncate it to 383.
        """
        width, height = img.size
        short = min(width, height)
        new_width = int(width * self.short_side // short)
        new_height = int(height * self.short_side // short)
        return img.resize((new_width, new_height), Image.LANCZOS)

    @staticmethod
    def _synthesize_low_light(rgb):
        """Darken a uint8 image array with the random curve ``beta * (alpha * x) ** gamma``.

        ``x`` is the image scaled to [0, 1]. ``beta`` is drawn from [0.5, 1),
        ``alpha`` from [0.9, 1) and ``gamma`` from [1.5, 5). Returns a uint8
        array of the same shape. The curve acts on each element independently,
        so channel order does not matter.
        """
        beta = 0.5 * random.random() + 0.5
        alpha = 0.1 * random.random() + 0.9
        gamma = 3.5 * random.random() + 1.5
        normalized = rgb.astype(np.float64) / 255.0
        darkened = beta * np.power(alpha * normalized, gamma) * 255.0
        return darkened.clip(0, 255).astype(np.uint8)

    def __getitem__(self, index):
        """Return the ``(low, high)`` pair for image ``index``.

        Without ``transform`` both are RGB PIL images of the same size.
        Otherwise each is whatever ``transform`` returns.
        """
        high_image = self._resize_short_side(load_img(self.image_filenames[index]))

        # Color and contrast *dim*: factors drawn uniformly from [0.7, 1.0).
        # ImageEnhance returns new images, so high_image itself is untouched.
        color_dim_factor = 0.3 * random.random() + 0.7
        contrast_dim_factor = 0.3 * random.random() + 0.7
        dimmed = ImageEnhance.Color(high_image).enhance(color_dim_factor)
        dimmed = ImageEnhance.Contrast(dimmed).enhance(contrast_dim_factor)

        low_image = Image.fromarray(self._synthesize_low_light(np.asarray(dimmed)))

        img_in, img_tar = low_image, high_image
        if self.data_augmentation:
            img_in, img_tar, _ = augment(img_in, img_tar)

        # NOTE(review): transform is applied to each image independently; a
        # random transform (e.g. a random crop) would desynchronize the pair.
        if self.transform:
            img_in = self.transform(img_in)
            img_tar = self.transform(img_tar)

        return img_in, img_tar

    def __len__(self):
        """Number of image files found in ``img_folder``."""
        return len(self.image_filenames)




In [4]:
# Iterate through the dataset saving the images
def save_images(dataset):
    # img_in and img_tar are PIL images
    for i, (img_in, img_tar) in enumerate(dataset):
        img_in.save(f"../datasets/test/to_report/low/{i}.png")
        img_tar.save(f"../datasets/test/to_report/high/{i}.png")        
        if i == 15:
            break

In [5]:
# Seed Python's `random` module, the RNG VOC2007 and augment draw from, so the
# synthetic low-light images written to disk are reproducible across runs.
SEED = 0
random.seed(SEED)

VOC_IMAGE_DIR = "../datasets/test/VOC2007/JPEGImages"
dataset = VOC2007(VOC_IMAGE_DIR, patch_size=128, upscale_factor=4, data_augmentation=False)
save_images(dataset)