# **SRGAN**:
high resolution 이미지를 low resolution으로 만든 뒤, super resolution으로 원래 이미지를 복원하는 모델이라고 생각하시면 됩니다.

참고 링크:

1) https://github.com/kunalrdeshmukh/SRGAN/blob/master/SRGAN.ipynb

2) https://www.kaggle.com/balraj98/single-image-super-resolution-gan-srgan-pytorch

3) https://github.com/leftthomas/SRGAN/blob/master/data_utils.py

4) https://github.com/deepak112/Keras-SRGAN

# **ESRGAN:**

1) https://blog.naver.com/leeth5225/221645820290

2) https://github.com/xinntao/ESRGAN

In [None]:
import numpy as np
import pandas as pd
import os, math, sys
import glob, itertools
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import vgg19
from torchvision.datasets import CIFAR100 # household furniture dataset을 학습에 활용함
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.datasets as dset
from torchvision.utils import save_image, make_grid

from PIL import Image
# from sklearn.model_selection import train_test_split

import warnings

# Seed every RNG the notebook can draw from. Python's `random` alone does not
# control DataLoader shuffling or weight initialisation (both use torch's RNG),
# nor anything that samples through NumPy.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Keep library warnings out of the notebook output.
warnings.filterwarnings("ignore")

Setting

In [None]:
# ---- Training schedule ----
n_epochs = 200   # total number of training epochs
batch_size = 4   # mini-batch size for training (the test cell builds its own loader with batch_size=1)

# ---- Adam optimiser ----
lr = 2e-4               # learning rate
b1, b2 = 0.9, 0.999     # decay of the first / second order momentum of the gradient

# ---- Data loading ----
n_cpu = 2  # number of DataLoader workers (cpu threads used during batch generation)

# ---- Image geometry ----
# High-resolution target size. Alternatives that were tried:
#   1024 x 1024 -> ran out of CUDA memory
#   32 x 32     -> native CIFAR-100 resolution
hr_height = hr_width = 512
channels = 3  # RGB
hr_shape = (hr_height, hr_width)

# ---- Device ----
cuda = torch.cuda.is_available()

# **Dataset**

아마도 coco train 2017의 chair & sofa를 사용하는게 좋을 것 같습니다.

In [None]:
# Mount Google Drive at /content/drive (Google Colab only); the dataset,
# checkpoint and result paths used below all live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Per-channel (R, G, B) normalization statistics for pre-trained PyTorch models.
# They are applied by the dataset transforms and undone again by denormalize().
mean = np.asarray((0.485, 0.456, 0.406), dtype=float)
std = np.asarray((0.229, 0.224, 0.225), dtype=float)

class ImageDataset(Dataset):
    """Dataset yielding a low-resolution / high-resolution tensor pair per image file.

    Every file is resized to ``hr_shape`` for the high-resolution target and to
    a quarter of that size (per axis) for the low-resolution input. Both
    versions are converted to tensors and normalized with the module-level
    ``mean`` / ``std``.

    Args:
        files: sequence of image file paths.
        hr_shape: ``(height, width)`` of the high-resolution target.
    """

    def __init__(self, files, hr_shape):
        hr_height, hr_width = hr_shape
        # low : high = 1 : 4 along each axis.
        # Resize takes (height, width); the width must come from hr_width so
        # that non-square targets are handled correctly.
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // 4, hr_width // 4), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.files = files

    def __getitem__(self, index):
        # The modulo wraps indices past the end back onto existing files.
        path = self.files[index % len(self.files)]
        with Image.open(path) as img:
            # Force three channels: grayscale, palette or RGBA files would not
            # match the 3-channel normalization statistics. convert() returns
            # a fully loaded copy, so the file can be closed right after.
            img = img.convert("RGB")

        return {"lr": self.lr_transform(img), "hr": self.hr_transform(img)}

    def __len__(self):
        return len(self.files)

# denormalize
def denormalize(tensors):
  """Undo the mean/std normalization of a batch of image tensors.

  The first three channels of ``tensors`` (indexed as ``tensors[:, c]``) are
  rescaled IN PLACE with the module-level ``std`` and ``mean``.

  Returns:
    The rescaled values clamped to [0, 1], the value range of image tensors
    produced by ``ToTensor``. (Clamping to 255 would never clip anything on
    this scale.)
  """
  for c in range(3):
    tensors[:, c].mul_(std[c]).add_(mean[c])
  return torch.clamp(tensors, 0, 1)

## **imaterialist**-furniture


사무실 의자, 나무의자, 소파, 세면대


In [None]:
# Folder holding the training images on the mounted Drive.
dataset_path = "/content/drive/MyDrive/AI604_TeamProject/final_train_data"

# Every "<name>.<ext>" entry directly inside dataset_path becomes one training sample.
train_dataloader = DataLoader(
    ImageDataset(glob.glob(f"{dataset_path}/*.*"), hr_shape=hr_shape),
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
)

# **Improved ESRGAN** Model Define

In [None]:
# Fixed feature extraction with VGG19.
# Per the original author's note, the features are taken before the activation,
# which keeps the brightness of the generated images consistent.
class FeatureExtractor(nn.Module):
  """Frozen VGG19 feature extractor (first 35 layers of ``vgg19().features``).

  It is used only to compute the perceptual loss: gradients flow through it to
  the input image, but its own weights are never updated.
  """

  def __init__(self):
    super(FeatureExtractor, self).__init__()

    vgg19_model = vgg19(pretrained=True)
    self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:35])

    # No optimizer owns these weights, so leaving them trainable only makes
    # every backward pass compute and accumulate gradients nobody reads.
    for param in self.feature_extractor.parameters():
      param.requires_grad = False

  def forward(self, img):
    """Return the VGG19 feature maps of ``img``."""
    return self.feature_extractor(img)

# Changed to 5 layers & batch norm removed.
class ResnetBlock(nn.Module):
  """Densely connected stack of five 3x3 convolutions with a scaled skip connection.

  Every convolution receives the channel-wise concatenation of the block input
  and all earlier convolution outputs and emits ``filters`` channels. The first
  four are followed by LeakyReLU(0.2), the last one is linear. The last output
  is multiplied by ``res_scale`` and added to the block input.
  """

  def __init__(self, filters, res_scale=0.2):
    super().__init__()
    self.res_scale = res_scale

    def make_layer(n_in, activate=True):
      conv = nn.Conv2d(n_in, filters, kernel_size=3, stride=1, padding=1, bias=True)
      if not activate:
        return nn.Sequential(conv)
      return nn.Sequential(conv, nn.LeakyReLU(negative_slope=0.2, inplace=True))

    # The input width grows by `filters` channels per layer (dense concatenation).
    self.conv1 = make_layer(filters)
    self.conv2 = make_layer(2 * filters)
    self.conv3 = make_layer(3 * filters)
    self.conv4 = make_layer(4 * filters)
    self.conv5 = make_layer(5 * filters, activate=False)

    # Plain list used only for iteration in forward(); the layers are already
    # registered as sub-modules through the conv1..conv5 attributes.
    self.blocks = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5]

  def forward(self, x):
    stacked = x
    for layer in self.blocks:
      out = layer(stacked)
      stacked = torch.cat((stacked, out), dim=1)
    return torch.add(out.mul(self.res_scale), x)

class RRDB(nn.Module):
    """Residual wrapper around three ``ResnetBlock`` modules applied in sequence.

    ``forward`` scales the output of the three blocks by ``res_scale`` and adds
    the block input back on top.
    """

    def __init__(self, filters, res_scale=0.2):
        super().__init__()
        self.res_scale = res_scale
        stages = [ResnetBlock(filters) for _ in range(3)]
        self.dense_blocks = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.dense_blocks(x)
        return residual * self.res_scale + x

In [None]:
class Discriminator(nn.Module):
  """Convolutional discriminator producing one logit per image patch.

  Four stages (64, 128, 256, 512 filters) each apply a stride-1 and a stride-2
  3x3 convolution with BatchNorm + LeakyReLU(0.2); the very first convolution
  has no BatchNorm. A final 3x3 convolution maps to a single channel. No
  sigmoid is applied, so the output is a map of raw logits.

  Args:
    input_shape: ``(channels, height, width)`` of the input images.

  Attributes:
    input_shape: the shape passed to the constructor.
    output_shape: ``(1, patch_h, patch_w)``, shape of the logit map for one image.
  """

  def __init__(self, input_shape):
    super(Discriminator, self).__init__()

    self.input_shape = input_shape
    in_channels, in_height, in_width = self.input_shape
    stage_filters = [64, 128, 256, 512]

    # Each stage halves the resolution with one stride-2 / padding-1 / 3x3
    # convolution, which maps a size n to ceil(n / 2). Apply exactly that per
    # stage so output_shape also matches the network for sizes that are not
    # multiples of 16 (plain n / 16 truncation is too small there).
    patch_h, patch_w = int(in_height), int(in_width)
    for _ in stage_filters:
      patch_h, patch_w = (patch_h + 1) // 2, (patch_w + 1) // 2
    self.output_shape = (1, patch_h, patch_w)

    def discriminator_block(in_filters, out_filters, first_block=False):
      layers = []
      layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1, bias = False))
      if not first_block:
          layers.append(nn.BatchNorm2d(out_filters))
      layers.append(nn.LeakyReLU(0.2, inplace=True))
      # Stride-2 convolution: halves height and width.
      layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1, bias =False))
      layers.append(nn.BatchNorm2d(out_filters))
      layers.append(nn.LeakyReLU(0.2, inplace=True))
      return layers

    layers = []
    in_filters = in_channels
    for i, out_filters in enumerate(stage_filters):
      layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
      in_filters = out_filters

    # Dropout was tried here against discriminator overfitting, but it made
    # results much worse and was removed.
    layers.append(nn.Conv2d(in_filters, 1, kernel_size=3, stride=1, padding=1))

    self.model = nn.Sequential(*layers)

  def forward(self, img):
    """Return the patch logit map for a batch of images."""
    return self.model(img)

In [None]:
# As in ESRGAN, the generator has no batch norm and is meant to use RRDB-style blocks.
# (Original note: more layers and blocks improve quality and reduce complexity / memory use.)
class Generator(nn.Module):
  """4x super-resolution generator.

  Pipeline: 3x3 conv -> ``n_residual_blocks`` ``ResnetBlock`` modules -> 3x3
  conv, added to the output of the first conv -> two (conv, LeakyReLU,
  PixelShuffle(2)) stages -> conv, LeakyReLU, conv down to ``out_channels``.
  """

  def __init__(self, in_channels=3, out_channels=3, filters = 64, n_residual_blocks=23):
    super().__init__()

    def conv3x3(n_in, n_out):
      return nn.Conv2d(n_in, n_out, kernel_size=3, stride=1, padding=1)

    # First layer
    self.conv1 = conv3x3(in_channels, filters)

    # Residual blocks
    self.res_blocks = nn.Sequential(*[ResnetBlock(filters) for _ in range(n_residual_blocks)])

    # Second conv layer post residual blocks
    self.conv2 = conv3x3(filters, filters)

    # Upsampling layers: each stage expands to 4x the channels and PixelShuffle
    # folds them into a 2x larger image, so two stages give the 4x scale.
    stages = []
    for _ in range(2):
      stages.extend([
          conv3x3(filters, filters*4),
          nn.LeakyReLU(),
          nn.PixelShuffle(upscale_factor=2),
      ])
    self.upsampling = nn.Sequential(*stages)

    # Final output layer
    self.conv3 = nn.Sequential(
        conv3x3(filters, filters),
        nn.LeakyReLU(),
        conv3x3(filters, out_channels),
    )

  def forward(self, x):
    shallow = self.conv1(x)
    deep = self.conv2(self.res_blocks(shallow))
    # Long skip connection around the residual trunk.
    fused = torch.add(shallow, deep)
    return self.conv3(self.upsampling(fused))

In [None]:
# Build the three networks on the GPU.
netG = Generator(in_channels=channels, filters=64, n_residual_blocks=23).to("cuda")
netD = Discriminator(input_shape=(channels, hr_shape[0], hr_shape[1])).to("cuda")
feature_extractor = FeatureExtractor().to("cuda")

# The feature extractor is put in evaluation mode; as the last expression of
# the cell this also displays the module structure.
feature_extractor.eval()

FeatureExtractor(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3

## **Loss Function** and **Optimizer**




In [None]:
# Adversarial loss on raw logits: BCEWithLogitsLoss applies the sigmoid itself,
# so the discriminator needs no sigmoid layer.
criterion_GAN = nn.BCEWithLogitsLoss().cuda()
# Perceptual (content) loss.
criterion_content = nn.MSELoss().cuda()
# Pixel-wise loss.
criterion_pixel = nn.L1Loss().cuda()

# One Adam optimizer per network, sharing learning rate and betas.
optimizer_G = torch.optim.Adam(
    netG.parameters(),
    lr=lr,
    betas=(b1, b2),
)
optimizer_D = torch.optim.Adam(
    netD.parameters(),
    lr=lr,
    betas=(b1, b2),
)

# **Training**

In [None]:
os.makedirs('/content/drive/MyDrive/AI604_TeamProject/improved_srgan_data/results/', exist_ok=True)
os.makedirs('/content/drive/MyDrive/AI604_TeamProject/improved_srgan_data/checkpoints/', exist_ok=True)

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# Resume training from a saved checkpoint.
# IMPORTANT: make sure the file belongs to the intended run. "no_noise_" marks
# the run whose discriminator is trained without input noise.
checkpoint = torch.load("/content/drive/MyDrive/AI604_TeamProject/improved_srgan_data/checkpoints/no_noise_train_188.pth")

netG.load_state_dict(checkpoint['model_G_state_dict'])
netD.load_state_dict(checkpoint['model_D_state_dict'])
optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
# A checkpoint is written after the last batch of `epoch`, so that epoch is
# already complete: continue with the next one instead of repeating it.
start = checkpoint['epoch'] + 1
# Last recorded loss values of the resumed run (kept for inspection; the loop
# recomputes all of them).
loss_content = checkpoint['loss_content']
loss_pixel = checkpoint['loss_pixel']
loss_G_real = checkpoint['loss_G_real']
loss_G_fake = checkpoint['loss_G_fake']
loss_GAN = checkpoint['loss_GAN']
loss_G = checkpoint['loss_G']
loss_real = checkpoint['loss_real']
loss_fake = checkpoint['loss_fake']
loss_D = checkpoint['loss_D']

# Number of initial batches (25% of one epoch) trained with the pixel loss only.
pix = len(train_dataloader)*0.25
# Save a checkpoint / a sample image every this many epochs.
checkpoint_interval = 1
sample_interval = 1

netG.train()
netD.train()

for epoch in range(start, n_epochs):
  for batch_idx, imgs in enumerate(train_dataloader):

    batches_done = epoch * len(train_dataloader) + batch_idx
    is_last_batch = batch_idx + 1 == len(train_dataloader)

    # Configure model input
    imgs_lr = imgs["lr"].type(Tensor).cuda()  # low-resolution input
    imgs_hr = imgs["hr"].type(Tensor).cuda()  # original high-resolution image

    # Per-patch targets for the adversarial loss.
    valid = torch.ones((imgs_lr.size(0), *netD.output_shape), device=imgs_lr.device)
    fake = torch.zeros_like(valid)

    ########################### Train Generator ################################
    optimizer_G.zero_grad()

    # Generate a high resolution image from low resolution input
    gen_hr = netG(imgs_lr)

    # Pixel-wise content loss
    loss_pixel = criterion_pixel(gen_hr, imgs_hr)

    if batches_done < pix:
      # Warm-up (pixel-wise loss only)
      loss_pixel.backward()
      optimizer_G.step()
      continue

    # Adversarial loss (relativistic average GAN, as introduced in ESRGAN).
    pred_G_real = netD(imgs_hr).detach()
    pred_G_fake = netD(gen_hr)  # prediction on the super-resolved image

    loss_G_real = criterion_GAN(pred_G_real - pred_G_fake.mean(0, keepdim=True), fake)
    loss_G_fake = criterion_GAN(pred_G_fake - pred_G_real.mean(0, keepdim=True), valid)

    loss_GAN = (loss_G_real + loss_G_fake) / 2

    # Perceptual loss (per the original note: keeps the brightness of the
    # generated image consistent).
    gen_features = feature_extractor(gen_hr)
    real_features = feature_extractor(imgs_hr).detach()
    loss_content = criterion_content(gen_features, real_features)

    # Total generator loss. Unlike the standard GAN loss, the relativistic term
    # rates how much more realistic the generated image looks than a real HR
    # image; minimising this weighted sum reportedly gives better results.
    loss_G = loss_content + 0.005 * loss_GAN + 0.01 * loss_pixel

    loss_G.backward()
    optimizer_G.step()

    ########################### Train Discriminator ############################
    optimizer_D.zero_grad()
    # Real images are fed without noise (this has been the setup since epoch
    # 150). To regularise the discriminator against overfitting again, feed
    # `imgs_hr + 0.1 * torch.randn_like(imgs_hr)` instead.
    pred_real = netD(imgs_hr)
    # Detach the generator OUTPUT, not the discriminator prediction: the
    # discriminator must receive gradients through its fake predictions, while
    # nothing should flow back into the generator.
    pred_fake = netD(gen_hr.detach())

    # Relativistic average GAN loss (ESRGAN).
    loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)
    loss_fake = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake)

    # Total loss
    loss_D = (loss_real + loss_fake) / 2

    loss_D.backward()
    optimizer_D.step()

    ############################################################################

    # Save checkpoints at the end of an epoch.
    # (The run with discriminator noise uses 'train_{epoch}.pth' instead.)
    if is_last_batch and (epoch+1) % checkpoint_interval == 0:
      torch.save({
            'epoch': epoch,
            'model_G_state_dict': netG.state_dict(),
            'model_D_state_dict': netD.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
            'loss_content': loss_content,
            'loss_GAN': loss_GAN,
            'loss_pixel': loss_pixel,
            'loss_G_real': loss_G_real,
            'loss_G_fake': loss_G_fake,
            'loss_G': loss_G,
            'loss_real': loss_real,
            'loss_fake': loss_fake,
            'loss_D':loss_D,
            }, os.path.join('/content/drive/MyDrive/AI604_TeamProject/improved_srgan_data/checkpoints/', 'no_noise_train_{:d}.pth'.format(epoch)))

    if (batch_idx + 1) % 400 == 0:
      print('Epoch [%d/%d], Step[%d/%d], lossD: %.4f, lossG: %.4f'
      % (epoch+1, n_epochs, batch_idx+1, len(train_dataloader), loss_D.item(), loss_G.item()))

    # Save sample images at the end of an epoch.
    # (The run with discriminator noise uses 'train_fake-XXX.jpg' instead.)
    if is_last_batch and (epoch+1) % sample_interval == 0:
      imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
      # One row per sample: original HR | upscaled LR | generated SR.
      # torch.cat builds a new tensor, so the in-place denormalize() does not
      # touch the training tensors; detach() keeps the image out of autograd.
      img_grid = denormalize(torch.cat((imgs_hr, imgs_lr, gen_hr.detach()), -1))
      save_image(img_grid, os.path.join('/content/drive/MyDrive/AI604_TeamProject/improved_srgan_data/results/', 'no_noise_train_fake-{:03d}.jpg'.format(epoch +1)), nrow = 1, normalize = False)

## TEST

In [None]:
# Output folder for the super-resolved evaluation samples.
os.makedirs('/content/drive/MyDrive/AI604_TeamProject/evaluation_samples/for sfm/esrgan/3-1/office/', exist_ok=True)

if cuda:
  Tensor = torch.cuda.FloatTensor
else:
  Tensor = torch.Tensor

# Test: restore only the generator weights from the chosen checkpoint.
checkpoint = torch.load("/content/drive/MyDrive/AI604_TeamProject/improved_srgan_data/checkpoints/no_noise_train_199.pth")
generator_weights = checkpoint['model_G_state_dict']
netG.load_state_dict(generator_weights)

# Inference mode for the generator.
netG.eval()

# Folder with the images to upscale. A single file can be tested instead, e.g.
# '/content/drive/MyDrive/AI604_TeamProject/final_train_data/168219_71.jpg'
# (then glob that path directly when building the loader).
test_path = '/content/drive/MyDrive/AI604_TeamProject/evaluation_samples/synsin(256*256)/office'

class TestImageDataset(Dataset):
  """Dataset returning each image file as a normalized tensor at its native size.

  No resizing is applied; the image is converted to a tensor and normalized
  with the module-level ``mean`` / ``std``.

  Args:
    files: sequence of image file paths.
  """

  def __init__(self, files):
    self.transform = transforms.Compose(
        [
          transforms.ToTensor(),
          transforms.Normalize(mean, std),
        ]
    )
    self.files = files

  def __getitem__(self, index):
    # The modulo wraps indices past the end back onto existing files.
    path = self.files[index % len(self.files)]
    with Image.open(path) as img:
      # Force three channels: RGBA, palette or grayscale files would not match
      # the 3-channel normalization statistics. convert() returns a fully
      # loaded copy, so the file can be closed right after.
      img = img.convert("RGB")
    return self.transform(img)

  def __len__(self):
    return len(self.files)

# Prepare input.
# glob returns paths in arbitrary order. Sort them (lexicographically) so the
# running index in each output file name refers to the same input on every run.
test_files = sorted(glob.glob(test_path + "/*.*"))
final_test_dataloader = DataLoader(TestImageDataset(test_files), batch_size=1, shuffle=False, num_workers=0)

for i, img in enumerate(final_test_dataloader):

  test_image = img.type(Tensor)

  # Upsample the image; no gradients are needed for inference.
  with torch.no_grad():
    sr_image = denormalize(netG(test_image)).cpu()

  save_image(sr_image, os.path.join('/content/drive/MyDrive/AI604_TeamProject/evaluation_samples/for sfm/esrgan/3-1/office/', 'synsin_office{:d}.jpg'.format(i+1)))


# **Test** Image low resolution으로 바꾸기

In [None]:
# # make low resolution
# hy_test_path = "/content/drive/MyDrive/AI604_TeamProject/evaluation_samples/origin/장하영"
# hr_shape = (512, 512)

# test_dataloader = DataLoader(ImageDataset(glob.glob(hy_test_path + "/*.*"), hr_shape=hr_shape), batch_size=1, shuffle=False, num_workers=0)

In [None]:
# os.makedirs('/content/drive/MyDrive/AI604_TeamProject/improved_srgan_data/test_results/low_resolution/', exist_ok=True)
# Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

# for i, imgs in enumerate(test_dataloader):
#     imgs_lr = Variable(imgs["lr"].type(Tensor)).cuda()# 낮은 화질의 이미지
#     imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
#     imgs_lr = denormalize(imgs_lr)
#     save_image(imgs_lr, os.path.join('/content/drive/MyDrive/AI604_TeamProject/improved_srgan_data/test_results/low_resolution/', 'hy_lr-{:d}.jpg'.format(i +1)))