In [5]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as tf
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt

In [2]:
from tqdm.notebook import tqdm
import PIL

###  Train

### Eval

In [3]:
# Image side length in pixels.
# NOTE(review): not referenced by any cell visible here (the dataset uses CROP_SIZE) — confirm it is still needed.
IMG_SIZE = 128

In [1]:
#Generator Network
# NOTE: this cell originally held only an unfinished `class Generator` header
# (no colon, no body), which is a SyntaxError on Restart & Run All.
# The Generator class itself is defined in the "Generator Network" section further down.

###  Prepare data

In [4]:
# Side length (pixels) of the square low-resolution images produced by TrainDataset.
CROP_SIZE = 128
class TrainDataset(Dataset):
    """Dataset yielding ``(lr, hr)`` image tensor pairs read from a directory.

    ``hr`` is the image converted to a tensor at its native resolution;
    ``lr`` is ``hr`` resized to ``CROP_SIZE x CROP_SIZE`` with bicubic
    interpolation.

    Args:
        dest: Unused; kept so existing callers keep working.
        path: Directory whose entries are all treated as image files.
    """
    def __init__(self, dest, path):
        super(TrainDataset, self).__init__()
        self.image_names = self._list_image_paths(path)
        self.lr_preprocess = tf.Compose([tf.ToPILImage(),tf.Resize((CROP_SIZE,CROP_SIZE), interpolation = PIL.Image.BICUBIC),tf.ToTensor()])
        self.hr_preprocess = tf.Compose([tf.ToTensor()])
        self.len = len(self.image_names)

    @staticmethod
    def _list_image_paths(path):
        """Return the full path of every entry in ``path``, sorted by name."""
        # The directory comes first in os.path.join; sorting makes the index -> file
        # mapping reproducible because os.listdir returns entries in arbitrary order.
        return [os.path.join(path, name) for name in sorted(os.listdir(path))]

    def __len__(self):
        """Number of image files found in the directory."""
        return self.len

    def __getitem__(self, index):
        """Load image ``index`` and return the ``(lr, hr)`` tensor pair."""
        # The context manager closes the file handle once the pixels are read.
        with PIL.Image.open(self.image_names[index]) as img:
            # Force 3 channels so grayscale / RGBA files match the 3-channel networks.
            hr = self.hr_preprocess(img.convert("RGB"))
        lr = self.lr_preprocess(hr)
        return lr,hr

In [33]:
# Scratch check of the two preprocessing pipelines on a single image.
# NOTE(review): `path` is not defined in any visible cell; it must hold the path
# of an image file before this cell is run.
lr_preprocess = tf.Compose([tf.ToPILImage(),tf.Resize((CROP_SIZE,CROP_SIZE), interpolation = PIL.Image.BICUBIC),tf.ToTensor()])
hr_preprocess = tf.Compose([tf.ToTensor()])
img = PIL.Image.open(path)
# Reuse the image opened above instead of opening the same file a second time.
img_hr = hr_preprocess(img)
img_lr = lr_preprocess(img_hr)

# Model

<img src = "https://miro.medium.com/max/4916/1*zsiBj3IL4ALeLgsCeQ3lyA.png">

### Generator Network

In [9]:
#Residual Block
class Residual_Block(nn.Module):
    """Residual block: conv -> BN -> PReLU -> conv -> BN, added back to the input.

    Works on 64-channel feature maps; both 3x3 convolutions use padding 1, so the
    output has the same shape as the input.
    """
    def __init__(self):
        super(Residual_Block, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, padding = 1)
        self.BN1 = nn.BatchNorm2d(num_features = 64)
        # PReLU has a learnable slope, so it must be a registered submodule.
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, padding = 1)
        self.BN2 = nn.BatchNorm2d(num_features = 64)
    def forward(self, x):
        """Return ``x + F(x)`` where ``F`` is the conv/BN/PReLU/conv/BN branch."""
        # Each layer consumes the previous layer's output, not the block input.
        yhat = self.conv1(x)
        yhat = self.BN1(yhat)
        yhat = self.prelu(yhat)
        yhat = self.conv2(yhat)
        yhat = self.BN2(yhat)
        return x + yhat
#Upsample Block
class Upsample_Block(nn.Module):
    """Sub-pixel upsampling block: conv -> PixelShuffle -> PReLU.

    The convolution expands ``channels`` to ``channels * upscale_factor**2`` and
    PixelShuffle folds them back into space, so the output has ``channels``
    channels and height/width multiplied by ``upscale_factor``.

    Args:
        channels: Number of input (and output) feature channels.
        upscale_factor: Spatial upscaling factor of the block.
    """
    def __init__(self, channels = 64, upscale_factor = 4):
        super(Upsample_Block, self).__init__()
        # padding = 1 keeps the 3x3 conv size-preserving, so the block scales the
        # input by exactly `upscale_factor`.
        self.conv1 = nn.Conv2d(in_channels = channels, out_channels = channels * upscale_factor ** 2, kernel_size = 3, stride = 1, padding = 1)
        self.ps = nn.PixelShuffle(upscale_factor = upscale_factor)
        # PReLU has a learnable slope, so it must be a registered submodule.
        self.prelu = nn.PReLU()
    def forward(self, x):
        """Upsample ``x`` of shape (N, channels, H, W) to (N, channels, H*r, W*r)."""
        yhat = self.conv1(x)
        yhat = self.ps(yhat)
        yhat = self.prelu(yhat)
        return yhat
        

In [10]:
#Generator Block
class Generator(nn.Module):
    """Super-resolution generator.

    Pipeline: 9x9 conv + PReLU -> 8 residual blocks -> 3x3 conv + BN, added to the
    pre-residual features (skip connection) -> 2 upsample blocks -> 9x9 conv back
    to 3 channels. All convolutions here are padded to preserve spatial size, so
    the output resolution is set by the upsample blocks alone.
    """
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 64, kernel_size = 9, stride = 1, padding = 4)
        # PReLU has a learnable slope, so it must be a registered submodule.
        self.prelu = nn.PReLU()
        #8 Residual Blocks
        # Sequential (not ModuleList) so the stack can be called directly; the blocks
        # must be instantiated, not passed as classes.
        self.residuals = nn.Sequential(*[Residual_Block() for _ in range(8)])
        self.conv2 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3, stride = 1, padding = 1)
        self.BN1 = nn.BatchNorm2d(num_features = 64)
        # 2 Upsampling Blocks
        self.upsamples = nn.Sequential(*[Upsample_Block() for _ in range(2)])
        # Each upsample block emits 64 channels after PixelShuffle, so the final
        # conv takes 64 input channels.
        self.conv3 = nn.Conv2d(in_channels = 64, out_channels = 3, kernel_size = 9, stride = 1, padding = 4)
    def forward(self, x):
        """Map a low-resolution batch (N, 3, H, W) to an upscaled 3-channel batch."""
        x = self.prelu(self.conv1(x))
        x_res = self.residuals(x)
        x_res = self.conv2(x_res)
        x_res = self.BN1(x_res)
        # Long skip connection around the whole residual stack.
        x = x + x_res
        x = self.upsamples(x)
        yhat = self.conv3(x)
        return yhat

### Discriminator Network

In [38]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing one probability per input image.

    An unpadded 3x3 stem conv is followed by seven conv -> batch-norm -> LeakyReLU
    stages, global average pooling, and two 1x1 convolutions that reduce the
    features to a single value per image.

    Args:
        l: Negative slope used by every LeakyReLU.
    """
    def __init__(self, l=0.2):
        super().__init__()
        # (in_channels, out_channels, stride) of each conv/BN/LeakyReLU stage.
        stage_specs = [
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        ]

        # Stem: the only 3x3 conv without padding and without batch norm.
        layers = [nn.Conv2d(3, 64, kernel_size = 3, stride = 1), nn.LeakyReLU(l)]

        for c_in, c_out, step in stage_specs:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size = 3, stride = step, padding = 1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(l))

        # Head: pool to 1x1, then two 1x1 convs down to a single channel.
        layers.append(nn.AdaptiveAvgPool2d(1))
        layers.append(nn.Conv2d(512, 1024, kernel_size = 1))
        layers.append(nn.LeakyReLU(l))
        layers.append(nn.Conv2d(1024, 1, kernel_size = 1))

        self.net = nn.Sequential(*layers)

    def forward(self,x):
        """Return a 1-D tensor with one sigmoid score per image in the batch."""
        batch = x.shape[0]
        logits = self.net(x)
        return torch.sigmoid(logits).view(batch)

In [39]:
# Random batch of five 3-channel 512x512 inputs for a discriminator smoke test.
x = torch.randn(5, 3, 512, 512)

In [40]:
# Fresh, untrained discriminator with the default LeakyReLU slope.
disc = Discriminator()

In [41]:
# Forward pass on the random batch; yields one score per input image.
y = disc(x)

In [44]:
# Show the scores without autograd history; detach() is the supported
# replacement for the legacy `.data` accessor.
y.detach()

tensor([0.4547, 0.4564, 0.4596, 0.4561, 0.4580])