# Image-to-Image translation (CycleGAN)
Jun-Yan Zhu et al., Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, ICCV 2017


## Import Libraries
필요한 라이브러리들을 가져옵니다.

In [None]:
%pylab inline

#from google.colab import drive
#drive.mount('/content/drive')

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from datetime import datetime
import os
import glob
from PIL import Image
from itertools import chain

## Custom dataset

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Unpaired image folder: every regular file directly under ``root`` is one sample.

    Images are converted to RGB, resized so the shorter edge is ``img_size``,
    center-cropped to ``img_size x img_size`` and rescaled from [0, 1] to
    [-1, 1] (the range of the generator's ``Tanh`` output).

    Args:
        root: Directory containing the images (e.g. ``.../trainA/``).
        img_size: Side length of the square output tensors.

    Raises:
        FileNotFoundError: If ``root`` does not exist or contains no files.
    """

    def __init__(self, root, img_size=256):
        self.root = root
        if not os.path.exists(self.root):
            raise FileNotFoundError("[!] {} not exists.".format(root))

        # normpath drops the trailing separator; otherwise
        # basename('.../trainA/') would be the empty string.
        self.name = os.path.basename(os.path.normpath(root))

        # Sorted so sample indices are stable across runs and platforms
        # (glob order is filesystem dependent); sub-directories are skipped
        # because Image.open cannot read them.
        self.paths = sorted(
            p for p in glob.glob(os.path.join(self.root, '*')) if os.path.isfile(p)
        )
        if not self.paths:
            raise FileNotFoundError("No images are found in {}".format(self.root))

        # Size of the first image on disk (before any transform) as [width, height, 3].
        with Image.open(self.paths[0]) as first:
            self.shape = list(first.size) + [3]

        self.transform = transforms.Compose([
            transforms.Resize(img_size),
            # Resize(int) keeps the aspect ratio; crop so every sample is
            # square and samples can be stacked into one batch.
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        """Return image ``index`` as a (3, img_size, img_size) float tensor in [-1, 1]."""
        # The context manager closes the file; convert() returns a new,
        # fully loaded in-memory image, so it stays valid after closing.
        with Image.open(self.paths[index]) as img:
            image = img.convert('RGB')

        # ToTensor yields values in [0, 1]; map them to [-1, 1].
        return self.transform(image) * 2 - 1

    def __len__(self):
        """Number of image files found under ``root``."""
        return len(self.paths)

## Define models
### Define discriminator

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style convolutional discriminator.

    Every hidden stage is a 4x4, stride-2, padding-1 convolution (halving the
    spatial size) followed by LeakyReLU(0.2); all stages after the first also
    have BatchNorm. A last 4x4, stride-1, unpadded convolution maps to
    ``output_channel`` channels and a Sigmoid squashes the scores into (0, 1).

    Args:
        input_channel: Channels of the input image.
        output_channel: Channels of the final score map.
        hidden_dims: Output channels of each stride-2 stage (must be non-empty).
    """

    def __init__(self, input_channel, output_channel, hidden_dims):
        super(Discriminator, self).__init__()

        in_dim = hidden_dims[0]
        stack = [
            nn.Conv2d(input_channel, in_dim, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        for width in hidden_dims[1:]:
            stack += [
                nn.Conv2d(in_dim, width, 4, 2, 1, bias=False),
                nn.BatchNorm2d(width),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            in_dim = width

        stack += [
            nn.Conv2d(in_dim, output_channel, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]

        # `layers` is the plain Python list; `layer_module` registers the same
        # modules so their parameters are tracked (state-dict keys are
        # `layer_module.<index>.*`).
        self.layers = stack
        self.layer_module = nn.ModuleList(self.layers)

    def forward(self, x):
        """Run ``x`` through all layers and flatten the score map to (batch, -1)."""
        for module in self.layer_module:
            x = module(x)
        return x.view(x.size(0), -1)

### Define generator

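The generator consists of an encoding part (strided convolutions that downsample the input) and a decoding part (transposed convolutions that upsample it back to an image).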


In [None]:
class Generator(nn.Module):
    """Encoder-decoder image-to-image generator (no skip connections).

    Encoder: 4x4 stride-2 convolutions (each halves H and W) with
    LeakyReLU(0.2); BatchNorm on every stage after the first.
    Decoder: 4x4 stride-2 transposed convolutions (each doubles H and W)
    with BatchNorm + ReLU, then one more transposed convolution to
    ``output_channel`` channels followed by ``Tanh`` (output in [-1, 1]).

    The output has the input's spatial size when
    ``len(deconv_dims) == len(conv_dims) - 1`` and H, W are divisible by
    ``2 ** len(conv_dims)``.

    Args:
        input_channel: Channels of the input image.
        output_channel: Channels of the generated image.
        conv_dims: Output channels of each encoder convolution (non-empty).
        deconv_dims: Output channels of each hidden decoder stage.
    """

    def __init__(self, input_channel, output_channel, conv_dims, deconv_dims):
        super(Generator, self).__init__()

        # ----- encoder: downsample -----
        width = conv_dims[0]
        down = [
            nn.Conv2d(input_channel, width, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for next_width in conv_dims[1:]:
            down.extend([
                nn.Conv2d(width, next_width, 4, 2, 1, bias=False),
                nn.BatchNorm2d(next_width),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            width = next_width
        self.encoder = nn.ModuleList(down)

        # ----- decoder: upsample back to image resolution -----
        up = []
        for next_width in deconv_dims:
            up.extend([
                nn.ConvTranspose2d(width, next_width, 4, 2, 1, bias=False),
                nn.BatchNorm2d(next_width),
                nn.ReLU(True),
            ])
            width = next_width
        up.extend([
            nn.ConvTranspose2d(width, output_channel, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.decoder = nn.ModuleList(up)

    def forward(self, x):
        """Encode then decode ``x``; returns an image-shaped tensor in [-1, 1]."""
        h = x
        for stage in (self.encoder, self.decoder):
            for layer in stage:
                h = layer(h)
        return h

## Preparation
학습을 시작하기 위해 필요한 모든것들을 준비합니다.

Parameter를 자유롭게 바꾸면서 학습시켜 보세요.

주의!!!) CycleGAN 부터 학습 시간이 매우 길어집니다. 바로 결과가 나오지 않더라도 참고 기다리셔야 합니다.

In [None]:
### Your parameters ###
# Dataset root; it must contain trainA/, trainB/, testA/ and testB/.
# Keep the trailing '/': the sub-folder names are appended by string concatenation.
# On Colab (with Google Drive mounted) use instead:
# dataset_dir = '/content/drive/My Drive/StyleTransfer/summer2winter_yosemite/'
dataset_dir = '../datasets/summer2winter_yosemite/'

# models: output channels of the stride-2 encoder convolutions and of the
# hidden decoder stages. len(deconv_dims) must equal len(conv_dims) - 1 so the
# generator output has the input's size. conv_dims is also used as the
# discriminators' hidden widths.
conv_dims, deconv_dims = [64, 128, 256, 512], [256, 128, 64]

# dataloader
img_size = 64  # e.g. 128; must be divisible by 2 ** len(conv_dims)
batch_size = 64
# Worker processes per DataLoader (the A and B loaders each get this many).
# The previous value, 64 * 4 = 256, spawned far more processes than CPU cores.
num_workers = min(8, os.cpu_count() or 1)

# optimizer (Adam)
lr = 0.0001
beta1 = 0.5
beta2 = 0.999
weight_decay = 0.0001

# training
max_iter = 50000

# log steps (in iterations)
log_disp_step = 100   # print losses
img_disp_step = 100   # save / show a sample image grid
net_save_step = 1000  # save G and F weights


#######################

# output directories
path_snapshot = 'snapshots'
path_outresult = 'results'

### Your losses here ###
L1 = nn.L1Loss()    # cycle-consistency / identity terms
bce = nn.BCELoss()  # adversarial terms (discriminators end in a Sigmoid)

########################

########################

## Define dataloader

In [None]:
# real & fake labels
real_label = 1
fake_label = 0

# prepare dataset
dataset_A = Dataset(dataset_dir + 'trainA/', img_size=img_size)
dataloader_A = torch.utils.data.DataLoader(dataset_A, 
                                           batch_size=batch_size, 
                                           num_workers=num_workers, 
                                           shuffle=True,
                                           drop_last=True)

dataset_B = Dataset(dataset_dir + 'trainB/', img_size=img_size)
dataloader_B = torch.utils.data.DataLoader(dataset_B, 
                                           batch_size=batch_size, 
                                           num_workers=num_workers, 
                                           shuffle=True,
                                           drop_last=True)



## Define network and optimizer
We define two discriminators $D_A, D_B$ and two generators $G, F$.

The discriminator $D_A$ predicts whether the input belongs to the domain $A$.

For an input image $x$, its output
$$D_A(x)\in(0,1)$$
is the predicted probability that $x$ is a real image from domain $A$.

The discriminator $D_B$ predicts whether the input belongs to the domain $B$.

For an input image $x$, its output
$$D_B(x)\in(0,1)$$
is the predicted probability that $x$ is a real image from domain $B$.

The generator $G$ translates the input from domain $A$ to domain $B$.

The generator $F$ translates the input from domain $B$ to domain $A$.

In [None]:
# define models
G = Generator(3, 3, conv_dims, deconv_dims)
F = Generator(3, 3, conv_dims, deconv_dims)

D_A = Discriminator(3, 1, conv_dims)
D_B = Discriminator(3, 1, conv_dims)

G = torch.nn.DataParallel(G.cuda())
F = torch.nn.DataParallel(F.cuda())

D_A = torch.nn.DataParallel(D_A.cuda())
D_B = torch.nn.DataParallel(D_B.cuda())

# define optimizers
optimizer = torch.optim.Adam

optimizer_G = optimizer(
    chain(G.parameters(), F.parameters()),
    lr=lr, betas=(beta1, beta2), weight_decay=weight_decay)
optimizer_D = optimizer(
    chain(D_A.parameters(), D_B.parameters()),
    lr=lr, betas=(beta1, beta2), weight_decay=weight_decay)

# Training
### Training step
1. Generate fake images
2. Calculate GAN loss for discriminator
3. Update discriminator
4. Calculate GAN loss for generator
5. Calculate cycle consistency loss
6. Update generator

### Generating fake images
Assume two elements $x_A, x_B$ sampled from different domains $A,B$:
$$x_A\in A,x_B\in B$$

The fake images $x_A', x_B'$ are generated using two generators $G,F$:

\begin{eqnarray}
x_B'&=&G(x_A)\\
x_A'&=&F(x_B)
\end{eqnarray}

### Cycle-consistency loss
The two generators $G,F$ should act as inverse functions of each other:
\begin{eqnarray}
G&=&F^{-1}\\
F&=&G^{-1}
\end{eqnarray}

The reconstructed images $\hat{x}_A, \hat{x}_B$ are given as:
\begin{eqnarray}
\hat{x}_A&=&F(x_B')=F(G(x_A))\\
\hat{x}_B&=&G(x_A')=G(F(x_B))
\end{eqnarray}

The cycle-consistency loss $\mathcal{L}_{cyc}$ is given as:
$$\mathcal{L}_{cyc}=\|x_A-\hat{x}_A\|_1+\|x_B-\hat{x}_B\|_1$$

### Identity loss
We want the generators to translate the input to the other domain while preserving its core structure.

The translated images $x_B', x_A'$ are given as:

\begin{eqnarray}
x_B'&=&G(x_A)\\
x_A'&=&F(x_B)
\end{eqnarray}

where $x_A, x_B$ denote the input images sampled from domains $A$ and $B$, respectively.

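The identity loss $\mathcal{L}_{id}$ is given as:

\begin{eqnarray}
\mathcal{L}_{id} = \|x_B'-x_A\|_1 + \|x_A'-x_B\|_1
\end{eqnarray}

Note that this term differs from the identity-mapping loss of the original CycleGAN paper, $\|G(x_B)-x_B\|_1+\|F(x_A)-x_A\|_1$, which feeds each generator an image that already belongs to its target domain.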


In [None]:
os.makedirs(path_snapshot, exist_ok=True)
os.makedirs(path_outresult, exist_ok=True)


def _next_batch(batch_iter, loader):
    """Return ``(iterator, batch_on_gpu)``, starting a new pass over ``loader`` when needed.

    ``batch_iter`` is None before the first call. Only StopIteration (end of
    an epoch) triggers a restart; any other error propagates.
    """
    if batch_iter is not None:
        try:
            return batch_iter, next(batch_iter).cuda()
        except StopIteration:
            pass  # epoch finished -> start a new shuffled pass below
    batch_iter = iter(loader)
    try:
        batch = next(batch_iter)
    except StopIteration:
        raise RuntimeError(
            "DataLoader yields no batches (fewer images than batch_size with drop_last=True?)"
        ) from None
    return batch_iter, batch.cuda()


# Set training mode
G.train()
F.train()
D_A.train()
D_B.train()

iter_A = iter_B = None
for i in range(max_iter):
    iter_A, x_A = _next_batch(iter_A, dataloader_A)
    iter_B, x_B = _next_batch(iter_B, dataloader_B)

    ## Generate fake images
    x_B_fake = G(x_A)  # A -> B
    x_A_fake = F(x_B)  # B -> A

    ## Update D_A & D_B
    optimizer_D.zero_grad()

    # Each discriminator forward runs once; fakes are detached so this loss
    # does not backpropagate into G / F.
    pred_A_real = D_A(x_A)
    pred_A_fake = D_A(x_A_fake.detach())
    pred_B_real = D_B(x_B)
    pred_B_fake = D_B(x_B_fake.detach())

    loss_D_A_real = bce(pred_A_real, torch.ones_like(pred_A_real))
    loss_D_A_fake = bce(pred_A_fake, torch.zeros_like(pred_A_fake))
    loss_D_B_real = bce(pred_B_real, torch.ones_like(pred_B_real))
    loss_D_B_fake = bce(pred_B_fake, torch.zeros_like(pred_B_fake))

    loss_D = loss_D_A_real + loss_D_A_fake + loss_D_B_real + loss_D_B_fake
    loss_D.backward()
    optimizer_D.step()

    ## Update G & F
    # Note: loss_G.backward() also fills D gradients; they are cleared by
    # optimizer_D.zero_grad() before the next discriminator update.
    optimizer_G.zero_grad()

    # Adversarial loss: G and F try to make the updated discriminators answer "real".
    score_B_fake = D_B(x_B_fake)
    loss_G_B = bce(score_B_fake, torch.ones_like(score_B_fake))  # G
    score_A_fake = D_A(x_A_fake)
    loss_G_A = bce(score_A_fake, torch.ones_like(score_A_fake))  # F

    # Cycle-consistency loss: F(G(x_A)) ~ x_A and G(F(x_B)) ~ x_B.
    # x_B_fake / x_A_fake already are G(x_A) / F(x_B) and G, F have not been
    # stepped since, so they are reused instead of re-running the generators
    # (extra forwards would also update BatchNorm running statistics again).
    loss_G_cyc = L1(F(x_B_fake), x_A) + L1(G(x_A_fake), x_B)

    # Identity loss: each translation should stay close to its own input.
    # NOTE(review): the CycleGAN paper's identity loss is
    # L1(G(x_B), x_B) + L1(F(x_A), x_A); confirm which variant is intended.
    loss_G_id = L1(x_B_fake, x_A) + L1(x_A_fake, x_B)

    ## Total loss
    loss_G = loss_G_A + loss_G_B + 10 * loss_G_cyc + 0.5 * loss_G_id
    loss_G.backward()
    optimizer_G.step()

    if (i + 1) % log_disp_step == 0:
        print("[{}] Iter {} / {}, D/loss: {}, G/loss: {}".format(
            str(datetime.now())[:-3], i + 1, max_iter,
            loss_D.item(), loss_G.item()))

    if (i + 1) % img_disp_step == 0:
        # Up to 8 images each of x_A, G(x_A), x_B, F(x_B) (make_grid's default
        # nrow is 8); detached from the graph and mapped from [-1, 1] to [0, 1].
        disp_img = (torch.cat([x_A[:8], x_B_fake[:8], x_B[:8], x_A_fake[:8]]).detach() + 1) * 0.5
        disp_img = torchvision.utils.make_grid(disp_img.cpu())
        disp_img = transforms.ToPILImage()(disp_img)
        disp_img.save(os.path.join(path_outresult, 'result_{:06d}.jpg'.format(i + 1)))
        gcf().set_size_inches(10, 10)
        imshow(disp_img)
        show()

    if (i + 1) % net_save_step == 0:
        # Save the unwrapped modules so the weights load without DataParallel.
        torch.save(G.module.state_dict(), '{}/netG_{:06d}.pth'.format(path_snapshot, i + 1))
        torch.save(F.module.state_dict(), '{}/netF_{:06d}.pth'.format(path_snapshot, i + 1))


In [None]:
test_results = 'test_results'
os.makedirs(test_results, exist_ok=True)

# parameters: generator weights written by the training loop
g_load_dir = "snapshots/netG_001000.pth"
f_load_dir = "snapshots/netF_001000.pth"

# Set eval mode (BatchNorm uses its running statistics)
G.eval()
F.eval()
D_A.eval()
D_B.eval()

# prepare dataset
# shuffle=False: a fixed order makes test_result_XXXXXX.jpg reproducible.
# drop_last=True: each row of a saved grid is exactly batch_size images wide.
dataset_A = Dataset(dataset_dir + 'testA/', img_size=img_size)
dataloader_A = torch.utils.data.DataLoader(dataset_A,
                                           batch_size=batch_size,
                                           num_workers=num_workers,
                                           shuffle=False,
                                           drop_last=True)

dataset_B = Dataset(dataset_dir + 'testB/', img_size=img_size)
dataloader_B = torch.utils.data.DataLoader(dataset_B,
                                           batch_size=batch_size,
                                           num_workers=num_workers,
                                           shuffle=False,
                                           drop_last=True)


# Load model; map_location decouples loading from the device the weights
# were saved on (load_state_dict copies them onto the model's device).
G.module.load_state_dict(torch.load(g_load_dir, map_location='cpu'))
F.module.load_state_dict(torch.load(f_load_dir, map_location='cpu'))

# zip stops at the shorter loader, so every grid holds one batch per domain.
for i, (x_A, x_B) in enumerate(zip(dataloader_A, dataloader_B)):
    x_A, x_B = x_A.cuda(), x_B.cuda()

    with torch.no_grad():
        x_B_fake = G(x_A)
        x_A_fake = F(x_B)

    # Rows: x_A, G(x_A), x_B, F(x_B); values mapped from [-1, 1] to [0, 1].
    disp_img = (torch.cat([x_A, x_B_fake, x_B, x_A_fake]) + 1) * 0.5
    disp_img = torchvision.utils.make_grid(disp_img.cpu(), nrow=x_A.size(0))
    disp_img = transforms.ToPILImage()(disp_img)
    disp_img.save('{}/test_result_{:06d}.jpg'.format(test_results, i + 1))
