In [7]:
import os
import numpy as np
import math
import itertools
import sys
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torch.autograd import Variable
from models import *
from datasets import *
import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
# Output folders: sample image grids go to image/, checkpoints to model/.
# exist_ok=True makes re-running this cell harmless.
os.makedirs("image", exist_ok=True)
os.makedirs("model", exist_ok=True)

# Training schedule.
n_epochs = 100
decay_epoch = 50  # NOTE(review): not referenced by any cell here; no LR decay is applied.

# Data: dataset_name is the sub-folder of ./data read by the dataloader.
dataset_name = "img_align_celeba"
batch_size = 16

# High-resolution target size in pixels (height, width).
hr_h = hr_w = 256

# Write a sample image grid every `sample_interval` batches.
sample_interval = 1000

In [3]:
cuda = torch.cuda.is_available()

hr_shape = (hr_h, hr_w)

# Networks: the generator, a discriminator over 3-channel HR images, and a
# feature extractor used only to compute the content (perceptual) loss.
generator = GeneratorResNet()
discriminator = Discriminator(input_shape=(3, *hr_shape))

# The feature extractor is a fixed loss network. It is kept in eval mode and
# its weights are frozen so backward() does not compute (and, since they are
# never zeroed, endlessly accumulate) gradients for parameters that no
# optimizer updates. Gradients still flow through it back to the generator output.
feature_extractor = FeatureExtractor()
feature_extractor.eval()
for param in feature_extractor.parameters():
    param.requires_grad_(False)

# criterion_GAN: MSE against 1/0 targets (least-squares adversarial loss).
# criterion_content: L1 distance between extracted feature maps.
criterion_GAN = torch.nn.MSELoss()
criterion_content = torch.nn.L1Loss()

if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    feature_extractor = feature_extractor.cuda()
    criterion_GAN = criterion_GAN.cuda()
    criterion_content = criterion_content.cuda()

optimizer_G = torch.optim.Adam(generator.parameters(),
                               lr=1e-4, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                               lr=1e-4, betas=(0.5, 0.999))

# Tensor type used by later cells to move batches onto the active device.
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

dataloader = DataLoader(
    ImageDataset("./data/%s" % dataset_name, hr_shape=hr_shape),
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

# Display whether training will run on the GPU.
cuda


In [None]:
# SRGAN training. Each batch does one generator step (content loss on
# feature-extractor outputs + 1e-3 * adversarial loss), then one discriminator
# step (MSE against 1/0 targets). It relies on names from the setup cell:
# generator, discriminator, feature_extractor, criterion_GAN,
# criterion_content, optimizer_G, optimizer_D, dataloader, Tensor.
# NOTE(review): decay_epoch is configured but no LR schedule uses it.

# Re-enter training mode in case an evaluation cell left the networks in eval().
generator.train()
discriminator.train()

for epoch in range(n_epochs):
    for i, imgs in enumerate(dataloader):

        imgs_lr = imgs["lr"].type(Tensor)
        imgs_hr = imgs["hr"].type(Tensor)

        # Adversarial targets shaped like the discriminator's output.
        target_shape = (imgs_lr.size(0), *discriminator.output_shape)
        valid = torch.ones(target_shape, device=imgs_lr.device)
        fake = torch.zeros(target_shape, device=imgs_lr.device)

        # ---- Generator step ----
        optimizer_G.zero_grad()

        gen_hr = generator(imgs_lr)
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr)

        loss_GAN = criterion_GAN(discriminator(gen_hr), valid)
        loss_content = criterion_content(gen_features, real_features.detach())
        loss_G = loss_content + 1e-3 * loss_GAN
        loss_G.backward()
        optimizer_G.step()

        # ---- Discriminator step ----
        # zero_grad also discards the discriminator gradients produced by loss_G.backward().
        optimizer_D.zero_grad()

        loss_real = criterion_GAN(discriminator(imgs_hr), valid)
        loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)
        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        optimizer_D.step()

        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            # Visualization only: no autograd graph is needed. Use separate names
            # so the training tensors are not overwritten with image grids.
            with torch.no_grad():
                lr_upsampled = nn.functional.interpolate(imgs_lr, scale_factor=4)
                sr_grid = make_grid(gen_hr.detach(), nrow=1, normalize=True)
                lr_grid = make_grid(lr_upsampled, nrow=1, normalize=True)
                img_grid = torch.cat((lr_grid, sr_grid), -1)
            save_image(img_grid, "image/%d.png" % batches_done, normalize=False)

# Name the final checkpoints by the number of completed epochs. With
# n_epochs = 100 this writes G_100.pth, the file the inference cell loads.
# (Using the leaked loop variable gave G_99.pth and raised NameError when n_epochs == 0.)
torch.save(generator.state_dict(), "model/G_%d.pth" % n_epochs)
torch.save(discriminator.state_dict(), "model/D_%d.pth" % n_epochs)

In [None]:
# Pretrained model checkpoint download link (사전학습 모델 체크포인트 다운로드 링크):
# https://drive.google.com/file/d/1GnATGVD6Aba4g7DE9Ohc9B7_Iz1KRi0i/view

In [10]:
# Qualitative check: load a saved generator and, for the first `num_test`
# batches, write a grid showing the upsampled LR input (nearest-neighbour,
# interpolate's default) next to the generator output.
load_epoch = 100
num_test = 10

# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only machine.
# load_state_dict then copies the weights onto the generator's current device.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
generator.load_state_dict(
    torch.load("model/G_%d.pth" % load_epoch, map_location="cpu"))

# Run inference in eval mode without autograd. In train mode, normalization
# layers (if GeneratorResNet has any; TODO confirm in models.py) would use
# batch statistics and overwrite their running averages. Restore the previous
# mode afterwards so a later training run is unaffected.
was_training = generator.training
generator.eval()
try:
    with torch.no_grad():
        for i, imgs in enumerate(itertools.islice(dataloader, num_test)):
            imgs_lr = imgs["lr"].type(Tensor)
            gen_hr = generator(imgs_lr)

            lr_upsampled = nn.functional.interpolate(imgs_lr, scale_factor=4)
            lr_grid = make_grid(lr_upsampled, nrow=1, normalize=True)
            sr_grid = make_grid(gen_hr, nrow=1, normalize=True)
            img_grid = torch.cat((lr_grid, sr_grid), -1)
            save_image(img_grid, "image/result_%d.png" % i, normalize=False)
finally:
    generator.train(was_training)
