In [1]:
import torch
from torch.utils.data import Dataset
import torchvision
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import glob
import torch.nn as nn
import math
from torch.autograd import Variable

In [2]:
class ImageDataset(Dataset):
    def __init__(self, root, hr_shape):
        hr_h, hr_w = hr_shape
        
        self.lr_transformer = transforms.Compose(
                [
                    transforms.Resize((hr_h//4, hr_h//4), Image.BICUBIC),
                    transforms.ToTensor(), 
                    transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
                ]
        )
        
        self.hr_transformer = transforms.Compose(
                [
                    transforms.Resize((hr_h, hr_h), Image.BICUBIC),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
                ]
        )
        
        self.files = sorted(glob.glob(root+'/*.*'))
        
    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)])
        
        lr = self.lr_transformer(img)
        hr = self.hr_transformer(img)
        
        return  {'lr':lr, 'hr':hr}
    def __len__(self):
        return len(self.files)

In [19]:
#voc2012 = ImageDataset('./train/VOC-2012-train/', (88,88))
voc2012_test = ImageDataset('./train/VOC-2012-valid/',(88,88))

In [20]:
len(voc2012_test)

100

In [21]:
voc2012_test = [voc2012_test[i] for i in range(len(voc2012_test))]

In [22]:
with open('./compress_data/test_data.pkl','wb') as f:
    pickle.dump(voc2012_test, f)

In [5]:
dataloader = torch.utils.data.DataLoader(voc2012, 
                                         batch_size=32, 
                                         shuffle=True, 
                                         num_workers=2)

In [6]:
import pickle

In [12]:
with open('./compress_data/train_dataloader.pkl','wb') as f:
    pickle.dump(dataloader, f)

In [8]:
with open('./compress_data/train_dataset.pkl','rb') as f:
    loaded_voc2012 = pickle.load(f) 

In [11]:
loaded_voc2012[0]

{'lr': tensor([[[-0.0588,  0.0667,  0.2078,  ..., -0.1059, -0.7569, -0.6627],
          [ 0.0196,  0.1294,  0.2706,  ..., -0.4118, -0.7098, -0.6784],
          [-0.0039,  0.1216,  0.1686,  ..., -0.2549, -0.7098, -0.6549],
          ...,
          [-0.0118,  0.1451,  0.2392,  ..., -0.9137, -0.9059, -0.8980],
          [-0.3255,  0.0118,  0.1451,  ..., -0.9216, -0.9137, -0.8824],
          [-0.7647, -0.2000,  0.0510,  ..., -0.9373, -0.9216, -0.8824]],
 
         [[-0.2706, -0.2863, -0.1922,  ..., -0.3098, -0.8431, -0.7804],
          [-0.2235, -0.1843, -0.0510,  ..., -0.5765, -0.8039, -0.7804],
          [-0.2157, -0.1686, -0.1529,  ..., -0.4667, -0.7961, -0.7725],
          ...,
          [-0.0353,  0.0824,  0.1608,  ..., -0.8824, -0.8902, -0.8824],
          [-0.3490, -0.0431,  0.0588,  ..., -0.8980, -0.8980, -0.8588],
          [-0.7647, -0.2471, -0.0196,  ..., -0.9137, -0.9059, -0.8667]],
 
         [[-0.3882, -0.4667, -0.4118,  ..., -0.4510, -0.9294, -0.8902],
          [-0.3176, -0

In [6]:
class Residual_Block(nn.Module):
    def __init__(self, in_channels):
        super(Residual_Block, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channels, 0.8),
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channels, 0.8)
        )
    def forward(self, x):
        return self.net(x) + x

In [7]:
class Upsampling_Block(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(Upsampling_Block, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels * up_scale** 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels * up_scale**2),
            nn.PixelShuffle(upscale_factor=up_scale),
            nn.PReLU(),
        )
    def forward(self, x):
        return self.net(x)

In [8]:
class Generator(nn.Module):
    def __init__(self, in_channels = 3, n_residual_blocks = 16, up_scale = 4):
        super(Generator, self).__init__()
        
        self.num_upsample_block = int(math.log(up_scale, 2))
        
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
        )
        
        
        res_blocks = []
        for _ in range(n_residual_blocks):
            res_blocks.append(Residual_Block(in_channels=64))
        self.residual_blocks = nn.Sequential(*res_blocks) 
        
        
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64, 0.8),
        )
        
        
        upsampling = []
        for i in range(self.num_upsample_block):
            upsampling.append(Upsampling_Block(in_channels=64, up_scale=2))          
        self.upsampling = nn.Sequential(*upsampling)
        
        
        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=in_channels, kernel_size=9, stride=1, padding=4),
            nn.Tanh(),
        )
    
    def forward(self, x):
        out1 = self.block1(x)
        
        out = self.residual_blocks(out1)
        
        out2 = self.block2(out)
        
        out = torch.add(out1, out2)
        
        
        out = self.upsampling(out)
        
        out = self.block3(out)
    
        return out
        
        

In [9]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=3, padding=1),
        )
        
        def Discriminator_Block(in_c, out_c, first=False):
            layers = []
            layers.append(nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1))
            
            if first == False:
                layers.append(nn.BatchNorm2d(out_c))
            
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1))
            layers.append(nn.BatchNorm2d(out_c))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            
            return layers
        
        layers = []
        in_c = in_channels
        for i, out_c in enumerate([64, 128, 256, 512]):
            block_layer = Discriminator_Block(in_c, out_c, first=(i==0))
            layers.extend(block_layer)
            in_c = out_c
        
        layers.append(nn.Conv2d(out_c, 1, kernel_size=3, stride=1, padding=1))
         
        self.net = nn.Sequential(*layers)
        
    def forward(self, x):
        out = self.net(x)
        print(out.shape)
        return out
    
        

In [10]:
generator = Generator(in_channels=3, n_residual_blocks=1, up_scale=4)
discriminator = Discriminator(in_channels=3)

In [11]:
optimizer_G = torch.optim.Adam(params=generator.parameters())
optimizer_D = torch.optim.Adam(params=discriminator.parameters())

In [12]:
adv_loss = nn.BCELoss()

In [13]:
def train(epochs):
    for epoch in range(epochs):
        Gloss_epoch = 0.
        Dloss_epoch = 0.
        for batch, data in enumerate(dataloader):
            
            
            lr = Variable(data['lr'])
            hr = Variable(data['hr'])
            batch_size = lr.shape[0]
            
            valid = Variable(torch.Tensor(batch_size, 1).fill_(1.0), requires_grad=False)
            fake = Variable(torch.Tensor(batch_size, 1).fill_(0.0), requires_grad=False)
            
            sr = generator(lr)
            
            # optimize D
            optimizer_D.zero_grad()
            loss_D = adv_loss(discriminator(sr.detach()), fake)+\
                    adv_loss(discriminator(hr), valid)
                
            loss_D.backward()
            optimizer_D.step()
            
            # optimize G
            optimizer_G.zero_grad()
            loss_G = adv_loss(discriminator(sr), valid)
            loss_G.backward()
            optimizer_G.step()
            
            
            Gloss_epoch += loss_G.item()
            Dloss_epoch += loss_D.item()
            
        
        
            break
        print(Gloss_epoch)
        print(Dloss_epoch)
        break

In [None]:
train(10)