In [1]:
import itertools
import os
import warnings

import cv2
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from torchvision.datasets import MNIST

  warn(f"Failed to load image Python extension: {e}")


In [2]:
plt.style.use("ggplot")
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
CHANNELS, IMG_ROWS, IMG_COLS = 3, 128, 128
IMG_SHAPE = (CHANNELS, IMG_ROWS, IMG_COLS)
Z_DIM = 100  # latent-vector size; not referenced by the CycleGAN cells shown here

# Seed every RNG the notebook draws from (weight initialisation, minibatch and
# sample indices via np.random) so Restart & Run All reproduces the same run.
# torch.manual_seed also seeds the CUDA generators.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
def load_image_folder(img_dir, size=(IMG_ROWS, IMG_COLS)):
    """Load every readable image in ``img_dir`` as an (N, C, H, W) float tensor.

    Each image is resized to ``size`` (rows, cols), converted from OpenCV's BGR
    order to RGB and transposed HWC -> CHW. Pixel values are left in the raw
    0-255 range; scaling is done by the caller.

    Entries that ``cv2.imread`` cannot decode (it returns None, e.g. for a
    stray ``.DS_Store``) are skipped with a warning instead of crashing inside
    ``cv2.resize``. File names are processed in sorted order so the tensor row
    order is reproducible across machines.

    Args:
        img_dir: directory containing the images.
        size: target (rows, cols).

    Returns:
        (images, names): the stacked float32 CPU tensor and the list of file
        names that were actually loaded, in row order.

    Raises:
        FileNotFoundError: if no readable image is found in ``img_dir``.
    """
    rows, cols = size
    images, loaded_names = [], []
    for img_name in sorted(os.listdir(img_dir)):
        img = cv2.imread(f"{img_dir}/{img_name}")
        if img is None:  # cv2.imread signals failure by returning None
            warnings.warn(f"Skipping unreadable image: {img_dir}/{img_name}")
            continue
        img = cv2.resize(img, dsize=(cols, rows))  # OpenCV dsize is (width, height)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        images.append(np.transpose(img, (2, 0, 1)))  # HWC -> CHW
        loaded_names.append(img_name)
    if not images:
        raise FileNotFoundError(f"No readable images found in {img_dir}")
    return torch.from_numpy(np.stack(images)).float(), loaded_names


apple_img_path, orange_img_path = "./apple2orange/trainA", "./apple2orange/trainB"
apple_dataset, apple_img_list = load_image_folder(apple_img_path)
orange_dataset, orange_img_list = load_image_folder(orange_img_path)
num_apple, num_orange = len(apple_img_list), len(orange_img_list)

# Move to the training device and scale pixels from [0, 255] to [-1, 1],
# matching the range of the generator's tanh output.
apple_dataset = (apple_dataset.to(DEVICE) / 127.5) - 1.0
orange_dataset = (orange_dataset.to(DEVICE) / 127.5) - 1.0

In [4]:
class Conv2d(nn.Module):
    """Encoder block: stride-2 convolution -> LeakyReLU -> optional InstanceNorm.

    The convolution uses padding ``(f_size - 1) // 2`` so an even-sized input is
    exactly halved for any kernel size. With the default ``f_size=4`` this is
    padding 1 (e.g. 64x64 -> 32x32). The generator's skip connections rely on
    that halving to line up encoder and decoder resolutions.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        f_size: square kernel size of the convolution.
        normalization: when True, apply ``InstanceNorm2d`` after the activation
            (the discriminator's first block passes False).
    """

    def __init__(self, in_channels, out_channels, f_size=4, normalization=True):
        super(Conv2d, self).__init__()
        # A fixed padding of 1 only halves even inputs for f_size 3 or 4;
        # deriving it from the kernel size keeps the halving for any f_size.
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=f_size, stride=2, padding=(f_size - 1) // 2)
        # InstanceNorm2d defaults to affine=False (0 parameters), so building it
        # unconditionally is free even when normalization is disabled.
        self.norm = nn.InstanceNorm2d(out_channels)
        self.normalization = normalization

    def forward(self, x):
        x = F.leaky_relu(self.conv(x), negative_slope=0.01)
        if self.normalization:
            x = self.norm(x)
        return x


# Shape check: one block maps 32x64x64 to 64x32x32 (channels doubled, spatial halved).
conv2d = Conv2d(in_channels=32, out_channels=64).to(DEVICE)
summary(conv2d, (32, 64, 64), device=DEVICE.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]          32,832
    InstanceNorm2d-2           [-1, 64, 32, 32]               0
Total params: 32,832
Trainable params: 32,832
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.50
Forward/backward pass size (MB): 1.00
Params size (MB): 0.13
Estimated Total Size (MB): 1.63
----------------------------------------------------------------


In [5]:
class DeConv2d(nn.Module):
    """Decoder block: 2x nearest upsample -> conv + ReLU -> optional dropout -> InstanceNorm -> skip concat.

    The convolution is stride 1 with ``padding='same'``, so the spatial size is
    exactly twice that of ``x``. The normalised features are concatenated with
    ``skip_input`` along the channel axis, so the output has
    ``out_channels + skip_input.shape[1]`` channels.

    Args:
        in_channels: channels of the tensor being upsampled.
        out_channels: channels produced by the convolution (before the concat).
        f_size: square kernel size.
        dropout_rate: Dropout2d probability; 0 (the default) skips dropout.
    """

    def __init__(self, in_channels, out_channels, f_size=4, dropout_rate=0):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=f_size,
                              stride=1, padding='same')
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout2d(p=dropout_rate)
        self.norm = nn.InstanceNorm2d(out_channels)

    def forward(self, x, skip_input):
        features = F.relu(self.conv(self.upsample(x)))
        if self.dropout_rate:
            features = self.dropout(features)
        # Skip connection: append the encoder feature map of the same resolution.
        return torch.cat((self.norm(features), skip_input), dim=1)


# Shape check: 256x8x8 is upsampled to 16x16, convolved to 128 channels and
# concatenated with the 128x16x16 skip input. Note that torchsummary multiplies
# the dimensions of all inputs together, so its "Input size (MB)" is inflated here.
deconv2d = DeConv2d(256, 128).to(DEVICE)
summary(deconv2d, [(256, 8, 8), (128, 16, 16)], device=DEVICE.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
          Upsample-1          [-1, 256, 16, 16]               0
            Conv2d-2          [-1, 128, 16, 16]         524,416
    InstanceNorm2d-3          [-1, 128, 16, 16]               0
Total params: 524,416
Trainable params: 524,416
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 2048.00
Forward/backward pass size (MB): 1.00
Params size (MB): 2.00
Estimated Total Size (MB): 2051.00
----------------------------------------------------------------


  return F.conv2d(input, weight, bias, self.stride,


In [11]:
class Generator(nn.Module):
    """U-Net style generator; one instance is created per translation direction.

    Encoder: four ``Conv2d`` blocks, each halving the spatial size
    (128 -> 64 -> 32 -> 16 -> 8 for a 128x128 input) while the channel count
    grows gf -> 2gf -> 4gf -> 8gf.
    Decoder: three ``DeConv2d`` blocks, each doubling the spatial size and
    concatenating the encoder feature map of the same resolution (skip
    connection). A final upsample + convolution + tanh returns to
    ``img_shape[0]`` channels with values in [-1, 1].

    Args:
        img_shape: (channels, rows, cols); only ``img_shape[0]`` is used.
        gf: filter count of the first encoder block.
    """

    def __init__(self, img_shape, gf=32):
        super(Generator, self).__init__()
        # No per-layer .to(DEVICE): the caller moves the whole module, which also
        # covers upsample/conv5/tanh and avoids a mixed-device model.
        # Encoder
        self.conv1 = Conv2d(in_channels=img_shape[0], out_channels=gf)
        self.conv2 = Conv2d(in_channels=gf, out_channels=gf * 2)
        self.conv3 = Conv2d(in_channels=gf * 2, out_channels=gf * 4)
        self.conv4 = Conv2d(in_channels=gf * 4, out_channels=gf * 8)
        # Decoder: each in_channels equals the previous block's output after its skip concat.
        self.deconv1 = DeConv2d(in_channels=gf * 8, out_channels=gf * 4)  # + d3 (4gf) -> 8gf
        self.deconv2 = DeConv2d(in_channels=gf * 8, out_channels=gf * 2)  # + d2 (2gf) -> 4gf
        self.deconv3 = DeConv2d(in_channels=gf * 4, out_channels=gf)      # + d1 (gf)  -> 2gf
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv5 = nn.Conv2d(in_channels=gf * 2, out_channels=img_shape[0], kernel_size=4, stride=1, padding="same")
        self.tanh = nn.Tanh()

    def forward(self, x):
        d1 = self.conv1(x)
        d2 = self.conv2(d1)
        d3 = self.conv3(d2)
        d4 = self.conv4(d3)
        u1 = self.deconv1(d4, d3)
        u2 = self.deconv2(u1, d2)
        u3 = self.deconv3(u2, d1)
        u4 = self.upsample(u3)
        output = self.tanh(self.conv5(u4))
        return output


# Two independent generators: AB translates apples (A) -> oranges (B), BA the reverse.
generator_AB = Generator(img_shape=IMG_SHAPE).to(DEVICE)
generator_BA = Generator(img_shape=IMG_SHAPE).to(DEVICE)
summary(generator_AB, IMG_SHAPE, device=DEVICE.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 64, 64]           1,568
    InstanceNorm2d-2           [-1, 32, 64, 64]               0
            Conv2d-3           [-1, 32, 64, 64]               0
            Conv2d-4           [-1, 64, 32, 32]          32,832
    InstanceNorm2d-5           [-1, 64, 32, 32]               0
            Conv2d-6           [-1, 64, 32, 32]               0
            Conv2d-7          [-1, 128, 16, 16]         131,200
    InstanceNorm2d-8          [-1, 128, 16, 16]               0
            Conv2d-9          [-1, 128, 16, 16]               0
           Conv2d-10            [-1, 256, 8, 8]         524,544
   InstanceNorm2d-11            [-1, 256, 8, 8]               0
           Conv2d-12            [-1, 256, 8, 8]               0
         Upsample-13          [-1, 256, 16, 16]               0
           Conv2d-14          [-1, 128,

In [12]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing a grid of per-patch realness scores.

    Four stride-2 ``Conv2d`` blocks (the first without instance norm) reduce a
    128x128 input to an 8x8 map with ``df * 8`` channels. A final stride-1
    convolution maps that to one score per location. There is no output
    activation; the training cell regresses these scores to 1/0 targets with
    ``nn.MSELoss``.

    Args:
        img_shape: (channels, rows, cols); only ``img_shape[0]`` is used.
        df: filter count of the first block.
    """

    def __init__(self, img_shape, df=64):
        super(Discriminator, self).__init__()
        # No per-layer .to(DEVICE): the caller moves the whole module, which also
        # covers conv5 and avoids a mixed-device model.
        self.conv1 = Conv2d(in_channels=img_shape[0], out_channels=df, normalization=False)
        self.conv2 = Conv2d(in_channels=df, out_channels=df * 2)
        self.conv3 = Conv2d(in_channels=df * 2, out_channels=df * 4)
        self.conv4 = Conv2d(in_channels=df * 4, out_channels=df * 8)
        self.conv5 = nn.Conv2d(in_channels=df * 8, out_channels=1, kernel_size=4, stride=1, padding='same')

    def forward(self, x):
        x = self.conv2(self.conv1(x))
        x = self.conv4(self.conv3(x))
        validity = self.conv5(x)
        return validity
    

# discriminator_A judges apples (domain A), discriminator_B judges oranges (domain B).
discriminator_A = Discriminator(img_shape=IMG_SHAPE).to(DEVICE)
discriminator_B = Discriminator(img_shape=IMG_SHAPE).to(DEVICE)
summary(discriminator_A, IMG_SHAPE, device=DEVICE.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 64, 64]           3,136
            Conv2d-2           [-1, 64, 64, 64]               0
            Conv2d-3          [-1, 128, 32, 32]         131,200
    InstanceNorm2d-4          [-1, 128, 32, 32]               0
            Conv2d-5          [-1, 128, 32, 32]               0
            Conv2d-6          [-1, 256, 16, 16]         524,544
    InstanceNorm2d-7          [-1, 256, 16, 16]               0
            Conv2d-8          [-1, 256, 16, 16]               0
            Conv2d-9            [-1, 512, 8, 8]       2,097,664
   InstanceNorm2d-10            [-1, 512, 8, 8]               0
           Conv2d-11            [-1, 512, 8, 8]               0
           Conv2d-12              [-1, 1, 8, 8]           8,193
Total params: 2,764,737
Trainable params: 2,764,737
Non-trainable params: 0
---------------------------

In [13]:
# Adam hyperparameters shared by all three optimizers.
LEARNING_RATE, ADAM_BETAS = 0.0002, (0.5, 0.999)

optim_dis_A = optim.Adam(discriminator_A.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
optim_dis_B = optim.Adam(discriminator_B.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
# A single optimizer updates both generators, which are trained on one joint loss.
optim_gen = optim.Adam(itertools.chain(generator_AB.parameters(), generator_BA.parameters()),
                       lr=LEARNING_RATE, betas=ADAM_BETAS)
criterion_cycle_loss = nn.L1Loss()  # used for both the cycle-consistency and identity terms
criterion_adversarial_loss = nn.MSELoss()  # least-squares adversarial loss against 1/0 targets

In [9]:
def sample_images(gen_AB, gen_BA, iteration, path="./Chapter09_image"):
    """Save a 2x3 grid of original / translated / reconstructed images.

    Row 1 takes one random apple (domain A) through A -> B -> A; row 2 takes one
    random orange (domain B) through B -> A -> B. The figure is written to
    ``{path}/img_{iteration:03d}``; with no extension, matplotlib uses
    ``rcParams["savefig.format"]`` (PNG by default).

    The generators run in eval mode under ``torch.no_grad()``. Their previous
    train/eval mode is restored afterwards, so calling this from a training
    loop does not silently leave layers such as dropout disabled.

    Args:
        gen_AB: generator translating domain A -> B.
        gen_BA: generator translating domain B -> A.
        iteration: training iteration, used in the file name.
        path: output directory; created (including parents) if missing.

    Reads the module-level ``apple_dataset`` / ``orange_dataset`` tensors, which
    the loading cell scales to [-1, 1].
    """
    os.makedirs(path, exist_ok=True)
    img_A_idx = np.random.randint(low=0, high=apple_dataset.shape[0], size=1)
    img_B_idx = np.random.randint(low=0, high=orange_dataset.shape[0], size=1)
    img_A, img_B = apple_dataset[img_A_idx], orange_dataset[img_B_idx]

    was_training = (gen_AB.training, gen_BA.training)
    gen_AB.eval()
    gen_BA.eval()
    try:
        with torch.no_grad():
            fake_B, fake_A = gen_AB(img_A), gen_BA(img_B)
            reconstr_A, reconstr_B = gen_BA(fake_B), gen_AB(fake_A)
            gen_imgs = torch.cat([img_A, fake_B, reconstr_A, img_B, fake_A, reconstr_B], dim=0)
    finally:
        gen_AB.train(was_training[0])
        gen_BA.train(was_training[1])

    # Map [-1, 1] to [0, 1] for imshow and bring the (N, C, H, W) batch to host memory.
    gen_imgs = (0.5 * gen_imgs + 0.5).cpu().numpy()

    titles = ["Original", "Translated", "Reconstructed"]
    fig, axes = plt.subplots(2, 3, figsize=(6, 4), sharex=True, sharey=True)
    # axes.flat is row-major, matching the order of gen_imgs; each row reuses the titles.
    for ax, img, title in zip(axes.flat, gen_imgs, titles * 2):
        ax.imshow(img.transpose(1, 2, 0))  # CHW -> HWC
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    fig.savefig(f"{path}/img_{iteration:03d}")
    plt.close(fig)

In [14]:
losses, iteration_checkpoints = [], []
iterations = 2000
batch_size = 64
sample_interval = 100
lambda_cycle, lambda_id = 10.0, 9.0  # weights of the cycle-consistency and identity L1 terms

# Least-squares GAN targets (1 = real, 0 = fake), one per discriminator output
# location. Derive that grid from one forward pass instead of hard-coding 8x8,
# which only holds for 128x128 inputs with the current architecture.
with torch.no_grad():
    patch_shape = discriminator_A(apple_dataset[:1]).shape[1:]
valid = torch.ones(batch_size, *patch_shape, device=DEVICE)
fake = torch.zeros(batch_size, *patch_shape, device=DEVICE)

for iteration in range(iterations):
    # sample_images() may leave the generators in eval mode; make sure every
    # step runs in train mode.
    generator_AB.train()
    generator_BA.train()

    idx_apple = np.random.randint(low=0, high=apple_dataset.shape[0], size=batch_size)
    idx_orange = np.random.randint(low=0, high=orange_dataset.shape[0], size=batch_size)
    imgs_A, imgs_B = apple_dataset[idx_apple], orange_dataset[idx_orange]

    # ---- Discriminator updates ----
    # The fakes are only discriminator inputs here, so skip building the generator graph.
    with torch.no_grad():
        fake_A, fake_B = generator_BA(imgs_B), generator_AB(imgs_A)

    optim_dis_A.zero_grad()
    dA_pred_real, dA_pred_fake = discriminator_A(imgs_A), discriminator_A(fake_A)
    dA_loss_real = criterion_adversarial_loss(dA_pred_real, valid)
    dA_loss_fake = criterion_adversarial_loss(dA_pred_fake, fake)
    dA_loss = (dA_loss_real + dA_loss_fake) * 0.5
    dA_loss.backward()
    optim_dis_A.step()

    optim_dis_B.zero_grad()
    dB_pred_real, dB_pred_fake = discriminator_B(imgs_B), discriminator_B(fake_B)
    dB_loss_real = criterion_adversarial_loss(dB_pred_real, valid)
    dB_loss_fake = criterion_adversarial_loss(dB_pred_fake, fake)
    dB_loss = (dB_loss_real + dB_loss_fake) * 0.5
    dB_loss.backward()
    optim_dis_B.step()

    d_loss = (dA_loss + dB_loss).detach()  # logging only

    # ---- Generator update ----
    # This backward also writes gradients into the discriminators' parameters;
    # optim_dis_*.zero_grad() clears them before the next discriminator step.
    optim_gen.zero_grad()
    fake_B = generator_AB(imgs_A)
    fake_A = generator_BA(imgs_B)
    reconstr_A = generator_BA(fake_B)
    reconstr_B = generator_AB(fake_A)
    imgs_A_identical = generator_BA(imgs_A)
    imgs_B_identical = generator_AB(imgs_B)

    valid_A = discriminator_A(fake_A)
    valid_B = discriminator_B(fake_B)

    loss_A_reconstr = criterion_cycle_loss(reconstr_A, imgs_A)
    loss_B_reconstr = criterion_cycle_loss(reconstr_B, imgs_B)

    # The generators are rewarded when the discriminators score their fakes as real.
    loss_A_adversarial = criterion_adversarial_loss(valid_A, valid)
    loss_B_adversarial = criterion_adversarial_loss(valid_B, valid)

    loss_A_identical = criterion_cycle_loss(imgs_A_identical, imgs_A)
    loss_B_identical = criterion_cycle_loss(imgs_B_identical, imgs_B)

    loss_A_total = loss_A_adversarial + lambda_cycle * loss_A_reconstr + lambda_id * loss_A_identical
    loss_B_total = loss_B_adversarial + lambda_cycle * loss_B_reconstr + lambda_id * loss_B_identical

    g_loss = loss_A_total + loss_B_total
    g_loss.backward()
    optim_gen.step()

    if (iteration + 1) % sample_interval == 0:
        losses.append([d_loss.item(), g_loss.item()])
        iteration_checkpoints.append(iteration + 1)
        print(f"{iteration + 1} [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
        sample_images(generator_AB, generator_BA, iteration=iteration + 1)

100 [D loss: 0.1560] [G loss: 11.3459]
200 [D loss: 0.0818] [G loss: 10.7927]
300 [D loss: 0.0673] [G loss: 9.3598]
400 [D loss: 0.2565] [G loss: 8.6609]
500 [D loss: 0.0510] [G loss: 8.8316]
600 [D loss: 0.0355] [G loss: 8.6586]
700 [D loss: 0.0308] [G loss: 8.1440]
800 [D loss: 0.0905] [G loss: 8.0540]
900 [D loss: 0.0307] [G loss: 7.7996]
1000 [D loss: 0.0216] [G loss: 8.1018]
1100 [D loss: 0.0934] [G loss: 7.3984]
1200 [D loss: 0.1553] [G loss: 6.8660]
1300 [D loss: 0.0462] [G loss: 7.6937]
1400 [D loss: 0.0400] [G loss: 7.1914]
1500 [D loss: 0.0415] [G loss: 7.1440]
1600 [D loss: 0.0348] [G loss: 7.0255]
1700 [D loss: 0.0242] [G loss: 6.8890]
1800 [D loss: 0.0618] [G loss: 6.5398]
1900 [D loss: 0.0488] [G loss: 6.5681]
2000 [D loss: 0.0284] [G loss: 6.7743]
