In [1]:
import torch
from torch.utils.data import Dataset
import torchvision
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import glob
import torch.nn as nn
import math
from torch.autograd import Variable
import os
from tqdm import tqdm
import pickle

In [2]:
# Source directories; every entry inside them is opened as an image by the
# "Compress" cells below.
TRAIN_DIR = './train/VOC-2012-train/'
VALID_DIR = './train/VOC-2012-valid/'

## Compress train: cache the decoded training images in a single pickle

In [9]:
# Sort for a reproducible image order (os.listdir returns entries in arbitrary
# order) and skip hidden entries / sub-directories such as `.ipynb_checkpoints`,
# which Image.open cannot read.
train_files = sorted(
    name for name in os.listdir(TRAIN_DIR)
    if not name.startswith('.') and os.path.isfile(os.path.join(TRAIN_DIR, name))
)

train_images = []
for file in tqdm(train_files):
    # The context manager closes the file handle even if decoding raises.
    with Image.open(os.path.join(TRAIN_DIR, file)) as image:
        # Force 3-channel RGB so every cached array has the same layout
        # (grayscale / palette / RGBA files would otherwise differ).
        train_images.append(np.array(image.convert('RGB')))

100%|██████████| 3000/3000 [00:22<00:00, 133.22it/s]


In [11]:
# open(..., 'wb') does not create missing parent directories, so make sure the
# output folder exists before writing the cache.
os.makedirs('./compress_data', exist_ok=True)
with open('./compress_data/voc_train.pkl', 'wb') as f:
    pickle.dump(train_images, f)

## Compress valid: cache the decoded validation images in a single pickle

In [13]:
# Sort for a reproducible image order (os.listdir returns entries in arbitrary
# order) and skip hidden entries / sub-directories such as `.ipynb_checkpoints`,
# which Image.open cannot read.
valid_files = sorted(
    name for name in os.listdir(VALID_DIR)
    if not name.startswith('.') and os.path.isfile(os.path.join(VALID_DIR, name))
)

valid_images = []
for file in tqdm(valid_files):
    # The context manager closes the file handle even if decoding raises.
    with Image.open(os.path.join(VALID_DIR, file)) as image:
        # Force 3-channel RGB so every cached array has the same layout
        # (grayscale / palette / RGBA files would otherwise differ).
        valid_images.append(np.array(image.convert('RGB')))

100%|██████████| 100/100 [00:00<00:00, 115.24it/s]


In [14]:
# open(..., 'wb') does not create missing parent directories, so make sure the
# output folder exists before writing the cache.
os.makedirs('./compress_data', exist_ok=True)
with open('./compress_data/voc_valid.pkl', 'wb') as f:
    pickle.dump(valid_images, f)

## Datasets: train and valid

In [3]:
def lr_transformer(crop_size, upscale_factor):
    """Build the low-resolution pipeline.

    The returned ``Compose`` holds a single bicubic ``Resize`` whose square
    target side is ``crop_size // upscale_factor``.
    """
    lr_side = crop_size // upscale_factor
    downscale = transforms.Resize((lr_side, lr_side), interpolation=Image.BICUBIC)
    return transforms.Compose([downscale])
    
def hr_transformer(crop_size):
    """Build the high-resolution pipeline.

    The returned ``Compose`` holds a single ``RandomCrop(crop_size)`` step.
    """
    random_crop = transforms.RandomCrop(crop_size)
    return transforms.Compose([random_crop])
    

In [4]:
class TrainDataset(Dataset):
    """Training pairs for super-resolution, served from a pickled image list.

    Each item is a random square crop of a stored image (the high-resolution
    target) together with a bicubically downscaled copy of that same crop
    (the low-resolution input).

    Args:
        path: Pickle file holding a sequence of image arrays accepted by
            ``Image.fromarray``.
        crop_size: Requested side of the high-resolution crop. It is rounded
            down to the nearest multiple of ``upscale_factor`` so that the
            low-resolution side times ``upscale_factor`` equals the
            high-resolution side exactly.
        upscale_factor: Integer ratio between high- and low-resolution sides.

    Raises:
        ValueError: If ``upscale_factor`` is not positive or ``crop_size`` is
            smaller than ``upscale_factor``.
    """

    def __init__(self, path, crop_size, upscale_factor):
        super(TrainDataset, self).__init__()
        crop_size = self._valid_crop_size(crop_size, upscale_factor)
        lr_size = crop_size // upscale_factor

        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only load caches written by this notebook or another
        # trusted source.
        with open(path, 'rb') as f:
            self.images = pickle.load(f)

        self.hr_transformer = transforms.Compose(
            [
                transforms.RandomCrop(crop_size),
            ]
        )
        self.lr_transformer = transforms.Compose(
            [
                transforms.Resize((lr_size, lr_size), interpolation=Image.BICUBIC),
            ]
        )
        # Built once here instead of on every __getitem__ call.
        self.to_tensor = transforms.ToTensor()

    @staticmethod
    def _valid_crop_size(crop_size, upscale_factor):
        """Return the largest multiple of ``upscale_factor`` that is <= ``crop_size``."""
        if upscale_factor <= 0:
            raise ValueError('upscale_factor must be positive, got %r' % (upscale_factor,))
        if crop_size < upscale_factor:
            raise ValueError(
                'crop_size (%r) must be at least upscale_factor (%r)'
                % (crop_size, upscale_factor)
            )
        return crop_size - crop_size % upscale_factor

    def __getitem__(self, index):
        """Return ``{'lr': tensor, 'hr': tensor}`` for one random crop of image ``index``."""
        # convert('RGB') keeps the channel count uniform so samples can be
        # stacked into a batch even if a cached array is grayscale or RGBA.
        image = Image.fromarray(self.images[index]).convert('RGB')
        hr = self.hr_transformer(image)
        # The low-resolution input is derived from the very same crop.
        lr = self.lr_transformer(hr)

        return {'lr': self.to_tensor(lr), 'hr': self.to_tensor(hr)}

    def __len__(self):
        """Number of cached images."""
        return len(self.images)
        

In [13]:
class ValidDataset(Dataset):
    """Validation triples for super-resolution, served from a pickled image list.

    Each item holds the high-resolution centre crop of a stored image, its
    bicubically downscaled low-resolution version, and that low-resolution
    image bicubically resized back to the crop size (``hr_restore``).

    The crop is a centre crop rather than a random one so that the same index
    always yields the same sample, keeping validation results comparable
    between epochs and runs.

    Args:
        path: Pickle file holding a sequence of image arrays accepted by
            ``Image.fromarray``.
        crop_size: Requested side of the high-resolution crop. It is rounded
            down to the nearest multiple of ``upscale_factor`` so that the
            low-resolution side times ``upscale_factor`` equals the
            high-resolution side exactly.
        upscale_factor: Integer ratio between high- and low-resolution sides.

    Raises:
        ValueError: If ``upscale_factor`` is not positive or ``crop_size`` is
            smaller than ``upscale_factor``.
    """

    def __init__(self, path, crop_size, upscale_factor):
        super(ValidDataset, self).__init__()
        crop_size = self._valid_crop_size(crop_size, upscale_factor)
        lr_size = crop_size // upscale_factor

        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only load caches written by this notebook or another
        # trusted source.
        with open(path, 'rb') as f:
            self.images = pickle.load(f)

        self.hr_transformer = transforms.Compose(
            [
                # Deterministic crop: validation samples must not change
                # from one evaluation to the next.
                transforms.CenterCrop(crop_size),
            ]
        )
        self.lr_transformer = transforms.Compose(
            [
                transforms.Resize((lr_size, lr_size), interpolation=Image.BICUBIC),
            ]
        )
        self.hr_restore_transformer = transforms.Compose(
            [
                transforms.Resize((crop_size, crop_size), interpolation=Image.BICUBIC),
            ]
        )
        # Built once here instead of on every __getitem__ call.
        self.to_tensor = transforms.ToTensor()

    @staticmethod
    def _valid_crop_size(crop_size, upscale_factor):
        """Return the largest multiple of ``upscale_factor`` that is <= ``crop_size``."""
        if upscale_factor <= 0:
            raise ValueError('upscale_factor must be positive, got %r' % (upscale_factor,))
        if crop_size < upscale_factor:
            raise ValueError(
                'crop_size (%r) must be at least upscale_factor (%r)'
                % (crop_size, upscale_factor)
            )
        return crop_size - crop_size % upscale_factor

    def __getitem__(self, index):
        """Return ``{'lr', 'hr', 'hr_restore'}`` tensors for image ``index``."""
        # convert('RGB') keeps the channel count uniform so samples can be
        # stacked into a batch even if a cached array is grayscale or RGBA.
        image = Image.fromarray(self.images[index]).convert('RGB')
        hr = self.hr_transformer(image)
        lr = self.lr_transformer(hr)
        # Bicubic upscaling of the low-resolution image back to the crop size.
        hr_restore = self.hr_restore_transformer(lr)

        return {
            'lr': self.to_tensor(lr),
            'hr': self.to_tensor(hr),
            'hr_restore': self.to_tensor(hr_restore),
        }

    def __len__(self):
        """Number of cached images."""
        return len(self.images)
        

In [6]:
# Training set backed by the cached pickle: 88-pixel crops, 4x downscaling.
traindataset = TrainDataset(
    path='./compress_data/voc_train.pkl',
    crop_size=88,
    upscale_factor=4,
)

In [7]:
# Fetch the sample once: every __getitem__ call draws a new random crop, so
# indexing the dataset twice would pair an 'lr' and an 'hr' taken from
# different crops.
sample = traindataset[0]
lr, hr = sample['lr'], sample['hr']

In [14]:
# Validation set backed by the cached pickle: 88-pixel crops, 4x downscaling.
validdataset = ValidDataset(
    path='./compress_data/voc_valid.pkl',
    crop_size=88,
    upscale_factor=4,
)

In [15]:
# First validation sample: a dict with 'lr', 'hr' and 'hr_restore' tensors.
x = validdataset[0]