SRGAN은 high resolution 이미지를 low resolution으로 줄인 뒤, 이를 다시 super resolution으로 복원하도록 학습하는 모델이라고 생각하시면 됩니다.

참고 링크:

1) https://github.com/kunalrdeshmukh/SRGAN/blob/master/SRGAN.ipynb

2) https://www.kaggle.com/balraj98/single-image-super-resolution-gan-srgan-pytorch

3) https://github.com/leftthomas/SRGAN/blob/master/data_utils.py

4) https://github.com/deepak112/Keras-SRGAN

In [None]:
import numpy as np
import pandas as pd
import os, math, sys
import glob, itertools
import argparse, random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import vgg19
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.datasets as dset
from torchvision.utils import save_image, make_grid

import plotly
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm_notebook as tqdm
from sklearn.model_selection import train_test_split

random.seed(42)
import warnings
warnings.filterwarnings("ignore")

Setting

In [None]:
# number of training epochs
n_epochs = 2

# dataset location: paste the path copied from the folder view here
dataset_path = "/content/training_data/img_align_celeba"

# size of the batches
# batch_size = 16  # example value
batch_size = 64  # train
# batch_size = 1  # test

# adam: learning rate
lr = 0.0001
# adam: decay of first order momentum of gradient
b1 = 0.5
# adam: decay of second order momentum of gradient
b2 = 0.999
# number of DataLoader worker processes used during batch generation.
# Capped at the CPUs actually present so small hosts are not oversubscribed
# (PyTorch warns when num_workers exceeds the available cores).
n_cpu = min(8, os.cpu_count() or 1)

# high res. image height
hr_height = 256
# high res. image width
hr_width = 256
# number of image channels = rgb
channels = 3

# True when a CUDA device can be used
cuda = torch.cuda.is_available()
# (height, width) of the high resolution images
hr_shape = (hr_height, hr_width)

# **Dataset- CelebA**
예시입니다.

아마도 coco train 2017의 chair & sofa를 사용하는게 좋을 것 같습니다.

In [None]:
# Normalization parameters for pre-trained PyTorch models
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


class ImageDataset(Dataset):
    """Dataset that turns each image file into a low-res / high-res tensor pair.

    Args:
        files: Sequence of image file paths.
        hr_shape: ``(height, width)`` of the high-resolution target.

    Each item is a dict ``{"lr": tensor, "hr": tensor}``. Both tensors come
    from the same file; the low-resolution one is resized to a quarter of the
    high-resolution height and width. Both are normalized with the
    module-level ``mean`` / ``std``.
    """

    def __init__(self, files, hr_shape):
        hr_height, hr_width = hr_shape
        # Low-resolution side lengths are 1/4 of the high-resolution ones.
        # The width is taken from hr_width so non-square hr_shape is honoured.
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // 4, hr_width // 4), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.files = files

    def __getitem__(self, index):
        path = self.files[index % len(self.files)]
        # Close the file handle promptly and force three channels: the
        # 3-element mean/std in Normalize cannot be applied to grayscale or
        # RGBA images.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)

        return {"lr": img_lr, "hr": img_hr}

    def __len__(self):
        return len(self.files)

In [None]:
!wget https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip
!mkdir training_data/ && unzip celeba.zip -d training_data/

# dataset = dset.ImageFolder(root="training_data")
# train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_cpu)

--2020-11-19 09:42:13--  https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip
Resolving s3-us-west-1.amazonaws.com (s3-us-west-1.amazonaws.com)... 52.219.112.16
Connecting to s3-us-west-1.amazonaws.com (s3-us-west-1.amazonaws.com)|52.219.112.16|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1443490838 (1.3G) [application/zip]
Saving to: ‘celeba.zip.4’


utime(celeba.zip.4): No such file or directory
2020-11-19 09:43:25 (19.5 MB/s) - ‘celeba.zip.4’ saved [1443490838/1443490838]

mkdir: cannot create directory ‘training_data/’: File exists


In [None]:
# Gather every file in the dataset folder in sorted order, then hold out 2% of
# the paths with a fixed seed.
image_paths = sorted(glob.glob(dataset_path + "/*.*"))
train_paths, test_paths = train_test_split(image_paths, test_size=0.02, random_state=42)

train_dataset = ImageDataset(train_paths, hr_shape=hr_shape)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
)
# No test DataLoader is built yet; test_paths keeps the held-out files for it.

# **SRGAN** Model Definition

In [None]:
def conv(ch_in, ch_out, k_size, stride=1, pad=1):
    """Return an ``nn.Sequential`` holding a single ``nn.Conv2d``.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
        k_size: Kernel size passed to ``nn.Conv2d``.
        stride: Convolution stride.
        pad: Padding passed to ``nn.Conv2d``.
    """
    convolution = nn.Conv2d(ch_in, ch_out, k_size, stride, pad)
    return nn.Sequential(convolution)

def deconv(ch_in, ch_out, k_size, stride=2, pad=1, bn=False):
    """Return an ``nn.Sequential`` with a transposed convolution.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
        k_size: Kernel size passed to ``nn.ConvTranspose2d``.
        stride: Stride of the transposed convolution.
        pad: Padding passed to ``nn.ConvTranspose2d``.
        bn: When true, a ``nn.BatchNorm2d(ch_out)`` follows the convolution.
    """
    modules = [nn.ConvTranspose2d(ch_in, ch_out, k_size, stride, pad)]
    if bn:
        modules.append(nn.BatchNorm2d(ch_out))
    return nn.Sequential(*modules)

In [None]:
# Fixed feature extraction using VGG19
class FeatureExtractor(nn.Module):
    """Frozen feature extractor built from the first 18 layers of VGG19.

    The pretrained ``vgg19`` model is loaded and only the first 18 children of
    its ``features`` stack are kept. Their parameters are frozen. Gradients
    still flow through the layers to the input image, but no gradients are
    computed or stored for the VGG weights themselves.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()

        vgg19_model = vgg19(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])

        # No optimizer updates these weights, so tracking their gradients
        # would only cost memory and time.
        for param in self.feature_extractor.parameters():
            param.requires_grad_(False)

    def forward(self, img):
        """Return the truncated-VGG19 feature map computed from ``img``."""
        return self.feature_extractor(img)

class ResnetBlock(nn.Module):
    """Residual block: conv -> BN -> PReLU -> conv -> BN, added to the input.

    Args:
        in_features: Channel count; the block keeps the same number of
            channels, and the 3x3 / stride 1 / padding 1 convolutions keep the
            spatial size, so the skip connection can be added directly.
    """

    def __init__(self, in_features):
        super(ResnetBlock, self).__init__()

        # The second positional argument of nn.BatchNorm2d is `eps`, not
        # `momentum`. Passing 0.8 there adds 0.8 to the variance inside the
        # normalisation. Default batch norm is used instead, as in the other
        # BatchNorm layers of this file.
        self.conv_block = nn.Sequential(
            conv(in_features, in_features, 3, 1, 1),
            nn.BatchNorm2d(in_features),
            nn.PReLU(),
            conv(in_features, in_features, 3, 1, 1),
            nn.BatchNorm2d(in_features),
        )

    def forward(self, x: torch.Tensor):
        """Return ``x`` plus the output of the convolutional branch."""
        return x + self.conv_block(x)

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a one-channel score map.

    Args:
        input_shape: ``(channels, height, width)`` of the images to judge.

    Attributes:
        input_shape: The shape given to the constructor.
        output_shape: ``(1, patch_h, patch_w)``, the shape of the score map
            produced for one image. The training code uses it to build the
            real/fake target tensors.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        block_filters = [64, 128, 256, 512]

        # Each block ends in a 3x3 / stride 2 / padding 1 convolution, which
        # maps a side of n pixels to ceil(n / 2). Applying that once per block
        # gives the true output size, also when the input is not a multiple
        # of 16. For multiples of 16 this equals n / 16.
        patch_h, patch_w = int(in_height), int(in_width)
        for _ in block_filters:
            patch_h, patch_w = (patch_h + 1) // 2, (patch_w + 1) // 2
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, first_block=False):
            """Two convolutions; the second one halves the spatial size."""
            layers = []
            layers.append(conv(in_filters, out_filters, 3, 1, 1))
            # The very first convolution is not followed by batch norm.
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(conv(out_filters, out_filters, 3, 2, 1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate(block_filters):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        # Final size-preserving convolution down to a single score channel.
        layers.append(conv(in_filters, 1, 3, 1, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the score map for ``img``."""
        return self.model(img)

In [None]:
class GeneratorResNet(nn.Module):
    """Residual generator that upsamples its input by a factor of 4.

    Args:
        in_channels: Channels of the low-resolution input.
        out_channels: Channels of the generated image.
        n_residual_blocks: Number of ``ResnetBlock`` modules in the trunk.

    The two upsampling stages each apply ``PixelShuffle(2)``, so height and
    width are multiplied by 4 overall.
    """

    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super(GeneratorResNet, self).__init__()

        # First layer
        self.conv1 = nn.Sequential(conv(in_channels, 64, 9, 1, 4), nn.PReLU())

        # Residual blocks
        self.res_blocks = nn.Sequential(
            *[ResnetBlock(64) for _ in range(n_residual_blocks)]
        )

        # Second conv layer post residual blocks.
        # nn.BatchNorm2d's second positional argument is `eps`; the former 0.8
        # there distorted the normalisation, so default batch norm is used.
        self.conv2 = nn.Sequential(conv(64, 64, 3, 1, 1), nn.BatchNorm2d(64))

        # Upsampling layers: 64 -> 256 channels, then PixelShuffle(2) folds
        # them back to 64 channels at twice the resolution.
        upsampling = []
        for _ in range(2):
            upsampling += [
                nn.Conv2d(64, 256, 3, 1, 1),
                nn.BatchNorm2d(256),
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU(),
            ]
        self.upsampling = nn.Sequential(*upsampling)

        # Final output layer.
        # NOTE(review): Tanh bounds the output to [-1, 1]; confirm this matches
        # the value range of the normalized training targets.
        self.conv3 = nn.Sequential(conv(64, out_channels, 9, 1, 4), nn.Tanh())

    def forward(self, x):
        """Return the 4x upsampled image generated from ``x``."""
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        # Long skip connection around the whole residual trunk.
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        return out

In [None]:
# Build the three networks on the GPU when one is available, otherwise on the
# CPU, instead of calling .cuda() unconditionally.
device = torch.device("cuda" if cuda else "cpu")

netG = GeneratorResNet().to(device)
netD = Discriminator(input_shape=(channels, *hr_shape)).to(device)
feature_extractor = FeatureExtractor().to(device)

# The feature extractor is never trained here; keep it in inference mode.
feature_extractor.eval()

FeatureExtractor(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3

## **Loss Function** and **Optimizer**




In [None]:
# Adversarial term: mean squared error between discriminator output and target map.
criterion_GAN = nn.MSELoss().cuda()
# Content term: L1 distance between feature maps.
criterion_content = nn.L1Loss().cuda()

# One Adam optimizer per network, both with the same learning rate and betas.
optimizer_G, optimizer_D = (
    torch.optim.Adam(network.parameters(), lr=lr, betas=(b1, b2))
    for network in (netG, netD)
)

# **Training**

In [None]:
os.makedirs('./results/images/', exist_ok=True)
# The checkpoint folder is created here, but nothing in this cell writes to it.
os.makedirs('./results/checkpoints/', exist_ok=True)

# `Tensor` is kept in case other cells still refer to it.
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
# One place decides where batches and targets live, driven by the `cuda` flag,
# instead of unconditional .cuda() calls.
device = torch.device("cuda" if cuda else "cpu")

# Per-batch loss history and the running count of samples seen at each entry.
train_gen_losses, train_disc_losses, train_counter = [], [], []

for epoch in range(n_epochs):
    ### Training
    gen_loss, disc_loss = 0, 0
    netG.train()
    netD.train()
    tqdm_bar = tqdm(train_dataloader, desc=f'Training Epoch {epoch} ', total=int(len(train_dataloader)))

    for batch_idx, imgs in enumerate(tqdm_bar):

        # Configure model input
        imgs_lr = imgs["lr"].to(device=device, dtype=torch.float32)  # low-resolution images
        imgs_hr = imgs["hr"].to(device=device, dtype=torch.float32)  # original images

        # Target maps for the discriminator: all ones = real, all zeros = generated.
        n_samples = imgs_lr.size(0)
        valid = torch.ones((n_samples, *netD.output_shape), device=device)
        fake = torch.zeros((n_samples, *netD.output_shape), device=device)

        ########################### Train Generator ################################
        optimizer_G.zero_grad()

        # Generate a high resolution image from low resolution input
        gen_hr = netG(imgs_lr)

        # Adversarial loss: the generator wants its output scored as real.
        fake_hr = netD(gen_hr)  # super resolution
        loss_GAN = criterion_GAN(fake_hr, valid)

        # Content loss: distance between feature maps of generated and real images.
        gen_features = feature_extractor(gen_hr)
        # The target features need no gradient, so skip building their graph.
        with torch.no_grad():
            real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features)

        # Total loss: content loss plus the adversarial term weighted by 1e-3.
        loss_G = loss_content + 1e-3 * loss_GAN
        loss_G.backward()
        optimizer_G.step()

        ########################### Train Discriminator ############################
        optimizer_D.zero_grad()
        # Loss of real and fake images. The generated batch is detached so this
        # step does not backpropagate into the generator.
        loss_real = criterion_GAN(netD(imgs_hr), valid)
        loss_fake = criterion_GAN(netD(gen_hr.detach()), fake)

        # Total loss
        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        optimizer_D.step()

        gen_loss += loss_G.item()
        train_gen_losses.append(loss_G.item())
        disc_loss += loss_D.item()
        train_disc_losses.append(loss_D.item())

        train_counter.append(batch_idx*batch_size + imgs_lr.size(0) + epoch*len(train_dataloader.dataset))
        # Show the running mean of both losses for the current epoch.
        tqdm_bar.set_postfix(gen_loss=gen_loss/(batch_idx+1), disc_loss=disc_loss/(batch_idx+1))

        # The test/validation code has not been written yet.

        # Save image grid with upsampled inputs and SRGAN outputs for roughly
        # 10% of the batches.
        if random.uniform(0, 1) < 0.1:
            # Separate names are used so the batch tensors are not overwritten.
            with torch.no_grad():
                lr_upscaled = nn.functional.interpolate(imgs_lr, scale_factor=4)
                hr_grid = make_grid(imgs_hr, nrow=1, normalize=True)
                sr_grid = make_grid(gen_hr.detach(), nrow=1, normalize=True)
                lr_grid = make_grid(lr_upscaled, nrow=1, normalize=True)
                # Side by side: original high-res | upscaled low-res | generated high-res.
                img_grid = torch.cat((hr_grid, lr_grid, sr_grid), -1)
            # NOTE(review): the file name depends only on batch_idx, so a later
            # epoch overwrites the snapshot of the same batch index. Confirm
            # that this is intended.
            save_image(img_grid, os.path.join('./results/images', 'fake-{:03d}.png'.format(batch_idx)), normalize=False)

HBox(children=(FloatProgress(value=0.0, description='Training Epoch 0 ', max=12410.0, style=ProgressStyle(desc…




HBox(children=(FloatProgress(value=0.0, description='Training Epoch 1 ', max=12410.0, style=ProgressStyle(desc…