In [None]:
# Git 설치 링크 (Git download for Windows): https://git-scm.com/download/win
# Wget 다운로드 링크 (Wget download): https://eternallybored.org/misc/wget/

In [None]:
# Facades 데이터셋 다운로드 코드
# bash download_pix2pix_dataset.sh facades

In [None]:
import os
import numpy as np
import sys
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from models import *
from datasets import *
import torch.nn as nn
import torch.nn.functional as F
import torch

In [None]:
# Training schedule: epochs run over range(start_epoch, end_epochs).
start_epoch = 0
end_epochs = 200

# Dataset is read from ./<dataset_name>; outputs go to images/<dataset_name> and save/<dataset_name>.
dataset_name = 'facades'
batch_size = 1

# Adam optimizer hyper-parameters (learning rate and beta coefficients).
learning_rate = 2e-4
b1 = 0.5
b2 = 0.999

# Images are resized to img_h x img_w before training.
img_h = 256
img_w = 256

# Write a validation sample grid every `sample_interval` batches.
sample_interval = 500

In [None]:
os.makedirs("images/%s" %dataset_name, exist_ok=True)
os.makedirs("save/%s" %dataset_name, exist_ok=True)

cuda = True if torch.cuda.is_available() else False

criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()

lambda_pixel = 100

patch = (1, img_h // 2 ** 4, img_w // 2 ** 4)

generator = GeneratorUNet()
discriminator = Discriminator()

if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion_GAN.cuda()
    criterion_pixelwise.cuda()

generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

optimizer_G = torch.optim.Adam(generator.parameters(),
                               lr=learning_rate, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),
                               lr=learning_rate, betas=(b1, b2))

transforms_ = [
    transforms.Resize((img_h, img_w), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

dataloader = DataLoader(
    ImageDataset("./%s" % dataset_name, transforms_=transforms_),
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

val_dataloader = DataLoader(
    ImageDataset("./%s" % dataset_name, transforms_=transforms_, mode="val"),
    batch_size=10,
    shuffle=True,
    num_workers=1,
)

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [None]:
def sample_images(batches_done):
    """Save a grid of validation samples produced by the current generator.

    Draws one batch from ``val_dataloader`` and writes
    ``images/<dataset_name>/<batches_done>.png``. The input, generated output and
    target are concatenated along dim -2 (image height for NCHW batches). Each
    grid cell therefore shows input / output / target stacked vertically.

    Args:
        batches_done: Global batch counter; used only as the output file name.
    """
    imgs = next(iter(val_dataloader))
    # Same direction as training: generate "A" images from "B" inputs.
    real_A = imgs["B"].type(Tensor)
    real_B = imgs["A"].type(Tensor)
    # Sampling only: skip building an autograd graph, which would waste memory.
    with torch.no_grad():
        fake_B = generator(real_A)
    img_sample = torch.cat((real_A, fake_B, real_B), -2)
    save_image(img_sample, "images/%s/%s.png" % (dataset_name, batches_done), nrow=5, normalize=True)

In [None]:

# pix2pix training: for each batch, one generator update followed by one discriminator update.
for epoch in range(start_epoch, end_epochs):
    for i, batch in enumerate(dataloader):
        # The model learns B -> A: input is batch["B"], target is batch["A"].
        real_A = Variable(batch["B"].type(Tensor))
        real_B = Variable(batch["A"].type(Tensor))

        # Patch-level targets: one label per discriminator output cell, shape (N, *patch).
        valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)

        # ---- Generator step: fool D (MSE to "valid") + lambda_pixel * L1 to the target ----
        optimizer_G.zero_grad()

        fake_B = generator(real_A)
        pred_fake = discriminator(fake_B, real_A)
        loss_GAN = criterion_GAN(pred_fake, valid)
        loss_pixel = criterion_pixelwise(fake_B, real_B)

        loss_G = loss_GAN + lambda_pixel * loss_pixel
        loss_G.backward()

        optimizer_G.step()

        # ---- Discriminator step: real pairs -> "valid", generated pairs -> "fake" ----
        # zero_grad also clears the D gradients that loss_G.backward() produced.
        optimizer_D.zero_grad()

        pred_real = discriminator(real_B, real_A)
        loss_real = criterion_GAN(pred_real, valid)
        # detach() so the discriminator loss does not backpropagate into the generator.
        pred_fake = discriminator(fake_B.detach(), real_A)
        loss_fake = criterion_GAN(pred_fake, fake)

        loss_D = 0.5 * (loss_real + loss_fake)
        loss_D.backward()
        optimizer_D.step()

        batches_done = epoch * len(dataloader) + i

        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f]"
            % (
                epoch,
                end_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_pixel.item(),
                loss_GAN.item(),
            )
        )

        if batches_done % sample_interval == 0:
            sample_images(batches_done)

    # Save weights after the last epoch only. This previously referenced the undefined
    # name `end_epoch`, which raised NameError at the end of the first epoch.
    if epoch == end_epochs - 1:
        torch.save(generator.state_dict(), "save/%s/generator_%d.pth" % (dataset_name, epoch))
        torch.save(discriminator.state_dict(), "save/%s/discriminator_%d.pth" % (dataset_name, epoch))

[Epoch 0/200] [Batch 59/506] [D loss: 0.453430] [G loss: 42.223408, pixel: 0.415385, adv: 0.684910] ETA: 1:10:51.7617031.344002

KeyboardInterrupt: 

In [None]:
# Load the final checkpoint written by the training loop (saved as epoch end_epochs - 1)
# and save one grid with validation inputs stacked above generated outputs.
final_epoch = end_epochs - 1
# map_location="cpu" lets GPU-trained checkpoints load on CPU-only machines;
# load_state_dict then copies the values onto whatever device the model lives on.
generator.load_state_dict(torch.load("save/%s/generator_%d.pth"
                                     % (dataset_name, final_epoch), map_location="cpu"))
discriminator.load_state_dict(torch.load("save/%s/discriminator_%d.pth"
                                         % (dataset_name, final_epoch), map_location="cpu"))
imgs = next(iter(val_dataloader))
real_A = Variable(imgs["B"].type(Tensor))
# Inference only: no autograd graph needed.
with torch.no_grad():
    fake_B = generator(real_A)
img_sample = torch.cat((real_A.data, fake_B.data), -2)
save_image(img_sample, "images/%s/generation.png" % (dataset_name), nrow=5, normalize=True)