# Image-to-Image translation (CycleGAN)
Jun-Yan Zhu et al., Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, ICCV 2017


## Import Libraries
필요한 라이브러리들을 가져옵니다.

In [None]:
%pylab inline

#from google.colab import drive
#drive.mount('/content/drive')

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from datetime import datetime
import os
import glob
from PIL import Image
import tqdm
import time
from itertools import chain
import random

## Custom dataset

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Unpaired two-domain image dataset for CycleGAN.

    Item ``index`` pairs the ``index``-th domain-A image (wrapping around when
    ``index`` exceeds the number of A images) with a uniformly random domain-B
    image, so the two domains need not be aligned or equally sized.

    Args:
        rootA: directory whose files are the domain-A images.
        rootB: directory whose files are the domain-B images. It must exist
            even in ``test`` mode, where it is not read.
        img_size: forwarded to ``transforms.Resize``. For an int, torchvision
            resizes the shorter edge, so non-square inputs stay non-square.
        test: if True, ``__getitem__`` returns only ``{'x_A': tensor}``.

    Raises:
        Exception: if a root does not exist or contains no files.
    """

    def __init__(self, rootA, rootB, img_size=256, test=False):
        self.rootA = rootA
        self.rootB = rootB
        if not os.path.exists(self.rootA):
            raise Exception("[!] {} not exists.".format(rootA))
        if not os.path.exists(self.rootB):
            raise Exception("[!] {} not exists.".format(rootB))

        # normpath drops a trailing separator. Without it,
        # basename('.../trainA/') is '' instead of 'trainA'.
        self.name_A = os.path.basename(os.path.normpath(rootA))
        self.name_B = os.path.basename(os.path.normpath(rootB))

        self.paths_A = self._list_files(self.rootA)
        if len(self.paths_A) == 0:
            raise Exception("No images are found in {}".format(self.rootA))

        self.paths_B = self._list_files(self.rootB)
        if len(self.paths_B) == 0:
            raise Exception("No images are found in {}".format(self.rootB))

        # ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1],
        # the same range as the generator's Tanh output.
        self.transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

        self.lenA = len(self.paths_A)
        self.lenB = len(self.paths_B)
        # One pass covers the larger domain; A indices wrap, B is resampled.
        self.len = max(self.lenA, self.lenB)

        self.test = test

    @staticmethod
    def _list_files(root):
        """Return the regular files directly under ``root``, sorted.

        Sorting makes the order (and therefore the unshuffled test output)
        deterministic, since glob order depends on the filesystem. Directories
        are skipped because they cannot be opened as images.
        """
        return sorted(
            path for path in glob.glob(os.path.join(root, '*'))
            if os.path.isfile(path)
        )

    @staticmethod
    def _load_rgb(path):
        """Load ``path`` as an RGB image and close the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        index_A = index % self.lenA
        image_A = self._load_rgb(self.paths_A[index_A])

        if self.test:
            return {'x_A': self.transform(image_A)}

        # Unpaired sampling: an independent random B image each time, so the
        # A/B pairing changes from epoch to epoch.
        index_B = random.randint(0, self.lenB - 1)
        image_B = self._load_rgb(self.paths_B[index_B])

        return {'x_A': self.transform(image_A), 'x_B': self.transform(image_B)}

    def __len__(self):
        return self.len

## Define models
### Define discriminator

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator returning a flattened map of raw scores.

    Layers, in order:
    - a stride-2 4x4 conv (no normalisation) + LeakyReLU(0.2);
    - one stride-2 4x4 conv + BatchNorm + LeakyReLU(0.2) stage for each
      further entry of ``hidden_dims``;
    - a stride-1, unpadded 4x4 conv producing ``output_channel`` maps.

    No sigmoid is applied, so the scores suit ``nn.MSELoss`` (LSGAN) or
    ``nn.BCEWithLogitsLoss``.

    Args:
        input_channel: channels of the input image.
        output_channel: channels of the final score map.
        hidden_dims: output channels of each downsampling conv; must be non-empty.
    """

    def __init__(self, input_channel, output_channel, hidden_dims):
        super(Discriminator, self).__init__()
        widths = list(hidden_dims)

        stack = [
            nn.Conv2d(input_channel, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stack += [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        stack.append(nn.Conv2d(widths[-1], output_channel, 4, 1, 0, bias=False))

        # Plain list kept for attribute compatibility; the ModuleList registers
        # the parameters (state_dict keys 'layer_module.<i>.*').
        self.layers = stack
        self.layer_module = nn.ModuleList(stack)

    def forward(self, x):
        """Apply every layer and flatten to shape (batch, -1)."""
        h = x
        for block in self.layer_module:
            h = block(h)
        return h.view(h.size(0), -1)

### Define generator

The generator consists of an encoding (downsampling) part followed by a decoding (upsampling) part.


In [None]:
class Generator(nn.Module):
    """Encoder-decoder image-to-image generator (no skip connections).

    Encoder:
    - a stride-2 4x4 conv + LeakyReLU(0.2);
    - one stride-2 conv + BatchNorm + LeakyReLU(0.2) stage per further entry
      of ``conv_dims``.

    Decoder:
    - one stride-2 transposed conv + BatchNorm + ReLU stage per entry of
      ``deconv_dims``;
    - a final stride-2 transposed conv to ``output_channel`` followed by Tanh.

    Output spatial size equals the input's only when
    ``len(deconv_dims) == len(conv_dims) - 1``.

    Args:
        input_channel: channels of the input image.
        output_channel: channels of the generated image.
        conv_dims: output channels of each encoder conv; must be non-empty.
        deconv_dims: output channels of each decoder transposed conv.
    """

    def __init__(self, input_channel, output_channel, conv_dims, deconv_dims):
        super(Generator, self).__init__()
        enc_widths = list(conv_dims)

        down = [
            nn.Conv2d(input_channel, enc_widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out in zip(enc_widths, enc_widths[1:]):
            down.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Keys 'encoder.<i>.*' must stay stable: snapshots are saved from them.
        self.encoder = nn.ModuleList(down)

        dec_widths = [enc_widths[-1]] + list(deconv_dims)
        up = []
        for c_in, c_out in zip(dec_widths, dec_widths[1:]):
            up.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        up.append(nn.ConvTranspose2d(dec_widths[-1], output_channel, 4, 2, 1, bias=False))
        up.append(nn.Tanh())  # output in [-1, 1], matching the dataset normalisation
        self.decoder = nn.ModuleList(up)

    def forward(self, x):
        """Run the encoder and then the decoder on ``x``."""
        h = x
        for stage in [*self.encoder, *self.decoder]:
            h = stage(h)
        return h

## Preparation
학습을 시작하기 위해 필요한 모든것들을 준비합니다.

Parameter를 자유롭게 바꾸면서 학습시켜 보세요.

주의!!!) CycleGAN 부터 학습 시간이 매우 길어집니다.

In [None]:
### Your parameters ###
# Dataset root; the cells below read trainA/, trainB/, testA/ and testB/ from it.
# On Colab (after mounting Google Drive) use instead:
# dataset_dir = '/content/drive/My Drive/StyleTransfer/summer2winter_yosemite/'
dataset_dir = '../datasets/summer2winter_yosemite/'

# models: channel widths of the downsampling convs (used by the generators'
# encoders and by both discriminators) and of the generators' upsampling layers.
# Larger alternative: conv_dims, deconv_dims = [64, 128, 256, 512], [256, 128, 64]
conv_dims, deconv_dims = [64, 128, 256], [128, 64]

# dataloader
img_size = 32  # alternative: 128
batch_size = 128
num_workers = 2

# optimizer (Adam)
lr = 0.0002
beta1 = 0.5
beta2 = 0.999
#weight_decay = 0.0001

# training
max_epoch = 50000

# log step (all counted in epochs)
log_disp_step = 100
img_disp_step = 100
net_save_step = 100

# real & fake labels (the training loop builds its targets with
# torch.ones_like / torch.zeros_like instead of reading these)
real_label = 1
fake_label = 0

# loss weights
lambda_cyc = 10
lambda_idt = 5
lambda_adv = 1

#######################

path_snapshot = 'snapshots'
path_outresult = 'results'

### Your losses here ###
L1 = nn.L1Loss()      # cycle-consistency and identity losses
adv = nn.MSELoss()    # adversarial loss: LSGAN
#adv = nn.BCEWithLogitsLoss()  # adversarial loss: original GAN
########################

## Define dataloader

In [None]:
# prepare dataset
# One Dataset yields unpaired {'x_A', 'x_B'} samples, so a single loader covers
# both domains. (The former per-domain loaders were removed: they called
# Dataset with a single root, which the current constructor does not accept.)
dataset = Dataset(dataset_dir + 'trainA/', dataset_dir + 'trainB/', img_size=img_size)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    drop_last=True,  # drop the incomplete final batch so every step sees batch_size samples
)

## Define network and optimizer
We define two discriminators $D_A, D_B$ and two generators $G, F$.

The discriminator $D_A$ predicts whether the input belongs to the domain $A$.

The output of $D_A$ is given as:
$$D_A(x)$$
where $x\sim A$.

The discriminator $D_B$ predicts whether the input belongs to the domain $B$.

The output of $D_B$ is given as:
$$D_B(x)$$
where $x\sim B$.

The generator $G$ translates the input from domain $A$ to domain $B$.

The generator $F$ translates the input from domain $B$ to domain $A$.

In [None]:
# define models
G = Generator(3, 3, conv_dims, deconv_dims)
F = Generator(3, 3, conv_dims, deconv_dims)

D_A = Discriminator(3, 1, conv_dims)
D_B = Discriminator(3, 1, conv_dims)

G = torch.nn.DataParallel(G.cuda())
F = torch.nn.DataParallel(F.cuda())

D_A = torch.nn.DataParallel(D_A.cuda())
D_B = torch.nn.DataParallel(D_B.cuda())

# define optimizers
optimizer = torch.optim.Adam

optimizer_G = optimizer(
    chain(G.parameters(), F.parameters()),
    lr=lr, betas=(beta1, beta2))#, weight_decay=weight_decay)
optimizer_D = optimizer(
    chain(D_A.parameters(), D_B.parameters()),
    lr=lr, betas=(beta1, beta2))#, weight_decay=weight_decay)

## Training
### Training step
1. Generate fake images
2. Calculate GAN loss for discriminator
3. Update discriminator
4. Calculate GAN loss for generator
5. Calculate cycle consistency loss
6. Update generator

### Adversarial loss (Generating fake images)
Assume two elements $x_A, x_B$ sampled from different domains $A,B$:
$$x_A\sim A,x_B\sim B$$

The fake images $x_A', x_B'$ are generated using two generators $G,F$:

\begin{eqnarray}
x_B'&=&G(x_A)\\
x_A'&=&F(x_B)
\end{eqnarray}

Therefore adversarial loss $\mathcal{L}_{adv}$ is given as:
\begin{eqnarray}
\mathcal{L}_{adv} &= &\mathbb{E}_{x_A}[\log(D_A(x_A))] + \mathbb{E}_{x_B}[\log(1-D_A(F(x_B))) ]\\
&&+ \mathbb{E}_{x_B}[\log(D_B(x_B))] + \mathbb{E}_{x_A}[\log(1-D_B(G(x_A)))]
\end{eqnarray}

For LSGAN, the adversarial loss $\mathcal{L}_{adv}$ is given as:
\begin{eqnarray}
\mathcal{L}_{adv} &= &\mathbb{E}_{x_A}[||D_A(x_A)-1||^2] + \mathbb{E}_{x_B}[||D_A(F(x_B))||^2]\\
&&+ \mathbb{E}_{x_B}[||D_B(x_B)-1||^2] + \mathbb{E}_{x_A}[||D_B(G(x_A))||^2]
\end{eqnarray}

### Cycle-consistency loss
The two generators $G, F$ should act as inverses of each other:
\begin{eqnarray}
G&=&F^{-1}\\
F&=&G^{-1}
\end{eqnarray}

The reconstructed images $\hat{x}_A, \hat{x}_B$ are given as:
\begin{eqnarray}
\hat{x}_A&=&F(x_B')=F(G(x_A))\\
\hat{x}_B&=&G(x_A')=G(F(x_B))
\end{eqnarray}

The cycle-consistency loss $\mathcal{L}_{cyc}$ is given as:
$$\mathcal{L}_{cyc}=\mathbb{E}_{x_A}[||x_A-\hat{x}_A||_1]+\mathbb{E}_{x_B}[||x_B-\hat{x}_B||_1]$$

### Identity loss
We wish each generator to leave its input unchanged when that input already belongs to the generator's output domain. Following Zhu et al. (2017), this identity mapping helps the generators preserve the color composition of the input.

Feeding each generator an image from its own output domain, we want:

\begin{eqnarray}
G(x_B)&\approx&x_B\\
F(x_A)&\approx&x_A
\end{eqnarray}

where $x_A, x_B$ denote input images sampled from domains $A$ and $B$.

The identity loss $\mathcal{L}_{idt}$ is given as:

\begin{eqnarray}
\mathcal{L}_{idt} = \mathbb{E}_{x_B}[||G(x_B)-x_B||_1] + \mathbb{E}_{x_A}[||F(x_A)-x_A||_1]
\end{eqnarray}

### Total loss
The total loss $\mathcal{L}$ is given as:
\begin{eqnarray}
\mathcal{L} = \mathcal{L}_{adv} + \lambda_{cyc}\mathcal{L}_{cyc} + \lambda_{idt}\mathcal{L}_{idt}
\end{eqnarray}

where $\lambda_{cyc}$ and $\lambda_{idt}$ are weights.

The final objective is to find optimal $G^*$, $F^*$, $D_A^*$, and $D_B^*$, which are given as:
\begin{eqnarray}
G^*, F^*, D_A^*, D_B^* = \arg\min_{G,F}\max_{D_A, D_B}\mathcal{L} 
\end{eqnarray}

In [None]:
os.makedirs(path_snapshot, exist_ok=True)

# With drop_last=True, a dataset smaller than one batch yields no batches at all.
# The logging/display code after the inner loop would then hit undefined names.
if len(dataloader) == 0:
    raise ValueError(
        "dataloader yields no batches ({} samples, batch_size={}, drop_last=True)".format(
            len(dataloader.dataset), batch_size))

# Set training mode
G.train()
F.train()
D_A.train()
D_B.train()

pbar = tqdm.tqdm(range(max_epoch), dynamic_ncols=True)
for i in pbar:
    for data in dataloader:
        x_A = data['x_A'].cuda()
        x_B = data['x_B'].cuda()

        ## Update D_A & D_B
        optimizer_D.zero_grad()

        ## Generate fake images: G maps A -> B, F maps B -> A.
        x_B_fake = G(x_A)
        x_A_fake = F(x_B)

        ## Adversarial loss for the two discriminators.
        # Each discriminator runs once per input, and that output also supplies
        # the target shape. Extra forward passes just to get the shape waste
        # compute and, in train mode, update BatchNorm running stats again.
        # Fakes are detached so this loss does not backpropagate into G or F.
        pred_A_real = D_A(x_A)
        pred_A_fake = D_A(x_A_fake.detach())
        pred_B_real = D_B(x_B)
        pred_B_fake = D_B(x_B_fake.detach())

        loss_D_A_real = adv(pred_A_real, torch.ones_like(pred_A_real))
        loss_D_A_fake = adv(pred_A_fake, torch.zeros_like(pred_A_fake))
        loss_D_B_real = adv(pred_B_real, torch.ones_like(pred_B_real))
        loss_D_B_fake = adv(pred_B_fake, torch.zeros_like(pred_B_fake))

        loss_D = 0.5 * (loss_D_A_real + loss_D_A_fake + loss_D_B_real + loss_D_B_fake)
        loss_D.backward()
        optimizer_D.step()

        ## Update G & F
        optimizer_G.zero_grad()

        ## Adversarial loss for the two generators, judged by the just-updated
        ## discriminators.
        pred_B_gen = D_B(x_B_fake)  # G tries to make D_B output "real"
        pred_A_gen = D_A(x_A_fake)  # F tries to make D_A output "real"
        loss_G_B = adv(pred_B_gen, torch.ones_like(pred_B_gen))
        loss_G_A = adv(pred_A_gen, torch.ones_like(pred_A_gen))
        loss_G_adv = lambda_adv * 0.5 * (loss_G_A + loss_G_B)

        ## Cycle-consistency loss: F(G(x_A)) ~ x_A and G(F(x_B)) ~ x_B.
        loss_G_cyc = lambda_cyc * (L1(F(x_B_fake), x_A) + L1(G(x_A_fake), x_B))

        ## Identity loss (Zhu et al., 2017): a generator given an image that is
        # already in its output domain should return it unchanged, i.e.
        # G(x_B) ~ x_B and F(x_A) ~ x_A. Penalising G(x_A) against x_A instead
        # pushes G toward copying its input, directly opposing the translation.
        loss_G_idt = lambda_idt * (L1(G(x_B), x_B) + L1(F(x_A), x_A))

        ## Total loss
        loss_G = loss_G_adv + loss_G_cyc + loss_G_idt
        loss_G.backward()
        optimizer_G.step()

    # Logged losses and displayed images come from the last batch of the epoch.
    if (i + 1) % log_disp_step == 0:
        print("[{}] Epoch {} / {}, D/loss: {:.4f}, G/loss: {:.4f}, G_adv/loss: {:.4f}, G_cyc/loss: {:.4f}, G_idt/loss: {:.4f}".format(
            str(datetime.now())[:-3], i + 1, max_epoch,
            loss_D.item(), loss_G.item(), loss_G_adv.item(), loss_G_cyc.item(), loss_G_idt.item()))

    if (i + 1) % img_disp_step == 0:
        # Grid order: real A, G(x_A), real B, F(x_B) (8 images each);
        # (x + 1) * 0.5 maps [-1, 1] back to [0, 1] for display.
        disp_img = (torch.cat([x_A[:8], x_B_fake[:8], x_B[:8], x_A_fake[:8]]).detach() + 1) * 0.5
        disp_img = torchvision.utils.make_grid(disp_img.cpu())
        disp_img = transforms.ToPILImage()(disp_img)
        #disp_img.save('result_{:06d}.jpg'.format(i + 1))
        gcf().set_size_inches(10, 10)
        imshow(disp_img)
        show()

    if (i + 1) % net_save_step == 0:
        # Save the unwrapped modules; the test cell loads them via `.module`.
        torch.save(G.module.state_dict(), '{}/netG_{:06d}.pth'.format(path_snapshot, i + 1))
        torch.save(F.module.state_dict(), '{}/netF_{:06d}.pth'.format(path_snapshot, i + 1))


## Testing
### Parameters

- ```batch_size```: the batch size used when loading test images (values larger than 1 are not supported)
- ```max_test_images```: the maximum number of test images to translate from each of domains $A$ and $B$
- ```g_load_dir```: path to the checkpoint file of generator $G$
- ```f_load_dir```: path to the checkpoint file of generator $F$

In [None]:
test_results = 'test_results'
os.makedirs(test_results, exist_ok=True)

# parameters
batch_size = 1
max_test_images = 3

g_load_dir = "snapshots/netG_002700.pth"
f_load_dir = "snapshots/netF_002700.pth"

# Set eval mode (BatchNorm then uses its running statistics).
G.eval()
F.eval()
D_A.eval()
D_B.eval()

# Load model. The snapshots hold the unwrapped networks' state_dicts.
G.module.load_state_dict(torch.load(g_load_dir))
F.module.load_state_dict(torch.load(f_load_dir))


def _show_translations(generator, src_dir, other_dir, title):
    """Show up to ``max_test_images`` images from ``src_dir`` next to their
    translation by ``generator``.

    ``Dataset(test=True)`` reads only its first root, so ``src_dir`` supplies
    the inputs. ``other_dir`` is only checked for existence by the constructor.
    """
    test_dataset = Dataset(src_dir, other_dir, img_size=img_size, test=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=False,
    )

    print(title)
    for idx, data in enumerate(test_loader):
        # '>=' shows exactly max_test_images batches ('>' showed one extra).
        if idx >= max_test_images:
            break
        x_src = data['x_A'].cuda()  # the test-mode Dataset names its single image 'x_A'

        with torch.no_grad():
            x_fake = generator(x_src)

        disp_img = (torch.cat([x_src, x_fake]) + 1) * 0.5  # [-1, 1] -> [0, 1]
        disp_img = torchvision.utils.make_grid(disp_img.cpu())
        disp_img = transforms.ToPILImage()(disp_img)
        # To keep results, save with a per-direction prefix, e.g.
        # disp_img.save(os.path.join(test_results, 'A2B_{:06d}.jpg'.format(idx + 1)))
        gcf().set_size_inches(5, 5)
        imshow(disp_img)
        show()


# G translates A -> B; F translates B -> A. (B -> A previously used G by mistake.)
_show_translations(G, dataset_dir + 'testA/', dataset_dir + 'testB/', 'Results of x_A to domain B')
_show_translations(F, dataset_dir + 'testB/', dataset_dir + 'testA/', 'Results of x_B to domain A')
    
