In [1]:
import argparse
import os
import numpy as np
import math
import itertools
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from models import *
from datasets import *
from utils import *
import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
start_epoch = 0
end_epoch = 200
dataset = "horse2zebra"
batch_size = 1
learning_rate = 0.0002
b1, b2 = 0.5, 0.999
decay_epoch = 100
channels, img_h, img_w = 3, 256, 256
sample_interval = 100
save_interval = 50
n_residual_blocks = 9
lambda_cyc, lambda_id = 10, 5

In [4]:
os.makedirs("images/%s" % dataset, exist_ok=True)
os.makedirs("saved_models/%s" % dataset, exist_ok=True)

criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

input_shape = (channels, img_h, img_w)
G_AB = GeneratorResNet(input_shape, n_residual_blocks)
G_BA = GeneratorResNet(input_shape, n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)

cuda = torch.cuda.is_available()
if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

if start_epoch != 0:
    G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" 
                                    % (dataset, start_epoch)))
    G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" 
                                    % (dataset, start_epoch)))
    D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" 
                                   % (dataset, start_epoch)))
    D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" 
                                   % (dataset, start_epoch)))
else:
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

In [5]:
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), 
    lr=learning_rate, betas=(b1, b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), 
                                 lr=learning_rate, betas=(b1, b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), 
                                 lr=learning_rate, betas=(b1, b2))

lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(end_epoch, start_epoch, 
                                    decay_epoch).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(end_epoch, start_epoch,
                                      decay_epoch).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(end_epoch, start_epoch,
                                      decay_epoch).step)

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

transforms_ = [
    transforms.Resize(int(img_h * 1.12), Image.BICUBIC),
    transforms.RandomCrop((img_h, img_w)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

dataloader = DataLoader(
    ImageDataset("%s" % dataset, transforms_=transforms_, unaligned=True),
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)
val_dataloader = DataLoader(
    ImageDataset("%s" %dataset, transforms_=transforms_, 
                 unaligned=True, mode="test"),
    batch_size=5,
    shuffle=True,
    num_workers=1,
)

In [4]:
def sample_images(batches_done):
    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()
    real_A = Variable(imgs["A"].type(Tensor))
    fake_B = G_AB(real_A)
    real_B = Variable(imgs["B"].type(Tensor))
    fake_A = G_BA(real_B)
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, "images/%s/%s.png" % (dataset, batches_done)
               , normalize=False)

In [5]:
for epoch in range(start_epoch, end_epoch):
    for i, batch in enumerate(dataloader):
        real_A = Variable(batch["A"].type(Tensor))
        real_B = Variable(batch["B"].type(Tensor))
        valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))),requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))),requires_grad=False)
        G_AB.train()
        G_BA.train()
        optimizer_G.zero_grad()
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)
        loss_identity = (loss_id_A + loss_id_B) / 2
        fake_B = G_AB(real_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
        loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity
        loss_G.backward()
        optimizer_G.step()
        
        optimizer_D_A.zero_grad(
        loss_real = criterion_GAN(D_A(real_A), valid)
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
        loss_D_A = (loss_real + loss_fake) / 2
        loss_D_A.backward()
        optimizer_D_A.step()
            
        optimizer_D_B.zero_grad()
        loss_real = criterion_GAN(D_B(real_B), valid)
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
        loss_D_B = (loss_real + loss_fake) / 2
        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = (loss_D_A + loss_D_B) / 2

        batches_done = epoch * len(dataloader) + i
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f,\
            adv: %f, cycle: %f, identity: %f]"
            % (epoch,end_epoch,i,len(dataloader),
                loss_D.item(),loss_G.item(),loss_GAN.item(),
                loss_cycle.item(),loss_identity.item()))

        if batches_done % sample_interval == 0:
            sample_images(batches_done)

    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    if epoch == (end_epoch-1) or epoch % save_interval == 0:
        torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth"
                   % (dataset, epoch))
        torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth"
                   % (dataset, epoch))
        torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth"
                   % (dataset, epoch))
        torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth"
                   % (dataset, epoch))

  valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)


[Epoch 199/200] [Batch 1333/1334] [D loss: 0.045644] [G loss: 1.873915, adv: 0.937235, cycle: 0.057094, identity: 0.073147] ETA: 0:00:00.366128262089720

In [6]:
# 모델 체크포인트는 아래 링크에서 다운로드할 수 있습니다.
# drive.google.com/drive/folders/10Lh4CUhiYtTtlM7I4MePuAiapG3jl6B2?usp=drive_link

os.makedirs("saved_img", exist_ok=True)

G_AB = GeneratorResNet(input_shape, n_residual_blocks)
G_BA = GeneratorResNet(input_shape, n_residual_blocks)

if torch.cuda.is_available():
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    
path = "saved_models/%s"
G_AB.load_state_dict(torch.load("%s/G_AB_199.pth" % (path)))
G_BA.load_state_dict(torch.load("%s/G_BA_199.pth" % (path)))

imgs = next(iter(val_dataloader))
G_AB.eval()
G_BA.eval()
real_A = Variable(imgs["A"].type(Tensor))
real_B = Variable(imgs["B"].type(Tensor))
fake_A = G_BA(real_B)
fake_B = G_AB(real_A)
real_A = make_grid(real_A, nrow=5, normalize=True)
real_B = make_grid(real_B, nrow=5, normalize=True)
fake_A = make_grid(fake_A, nrow=5, normalize=True)
fake_B = make_grid(fake_B, nrow=5, normalize=True)
image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
save_image(image_grid, "saved_img/gen_img.png", normalize=False)