In [None]:
from __future__ import print_function, division
import scipy
from tensorflow.keras.datasets import mnist
from tensorflow.addons.layers import InstanceNormalization
from tensorflow.keras.layers import *
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
import numpy as np
import os

## DataLoader

In [None]:
import scipy
import imageio
from skimage.transform import resize
from glob import glob
import numpy as np

class DataLoader():
  """Reads images of a two-domain dataset from ``./datasets/<dataset_name>/``.

  Every image is read as RGB, resized to ``img_res`` and rescaled with
  ``x / 127.5 - 1`` (which maps 0..255 pixel values onto -1..1).
  """

  def __init__(self, dataset_name, img_res = (128, 128)):
    """Store the dataset folder name and the target (rows, cols) resolution."""
    self.dataset_name = dataset_name
    self.img_res = img_res

  def load_data(self, domain, batch_size = 1, is_testing = False):
    """Return a random batch of images from one domain.

    Args:
      domain: Domain suffix of the folder name (e.g. ``'A'`` or ``'B'``).
      batch_size: Number of images to sample (with replacement).
      is_testing: Read from ``test<domain>`` instead of ``train<domain>``
        and skip the random horizontal flip.

    Returns:
      A numpy array with ``batch_size`` rescaled images stacked on axis 0.

    Raises:
      ValueError: If the dataset folder contains no files.
    """
    data_type = "train%s" % domain if not is_testing else "test%s" % domain
    path = glob('./datasets/%s/%s/*' % (self.dataset_name, data_type))
    if len(path) == 0:
      raise ValueError('no images found in ./datasets/%s/%s' % (self.dataset_name, data_type))

    batch_images = np.random.choice(path, size = batch_size)

    imgs = []
    for img_path in batch_images:
      img = resize(self.imread(img_path), self.img_res)
      # Random horizontal flip is a training-only augmentation.
      if not is_testing and np.random.random() > 0.5:
        img = np.fliplr(img)
      imgs.append(img)

    # Return the rescaled array (the raw list used to be returned by mistake).
    return np.array(imgs) / 127.5 - 1.

  def load_batch(self, batch_size = 1, is_testing = False):
    """Yield ``(imgs_A, imgs_B)`` batches covering one pass over the data.

    Reads from the ``trainA``/``trainB`` folders, or ``valA``/``valB`` when
    ``is_testing`` is set. ``self.n_batches`` is updated as a side effect.
    Images beyond the last full batch are dropped; nothing is yielded when
    either domain holds fewer than ``batch_size`` images.

    Args:
      batch_size: Number of images per domain in every batch.
      is_testing: Use the validation folders and disable the random flip.

    Yields:
      Tuples of two numpy arrays (domain A batch, domain B batch).
    """
    data_type = 'train' if not is_testing else 'val'
    path_A = glob('./datasets/%s/%sA/*' % (self.dataset_name, data_type))
    path_B = glob('./datasets/%s/%sB/*' % (self.dataset_name, data_type))

    self.n_batches = int(min(len(path_A), len(path_B)) / batch_size)
    total_samples = self.n_batches * batch_size

    # Sample without replacement so no image repeats within the pass.
    path_A = np.random.choice(path_A, total_samples, replace = False)
    path_B = np.random.choice(path_B, total_samples, replace = False)

    # Every batch is full by construction, so iterate over all of them.
    for i in range(self.n_batches):
      batch_A = path_A[i * batch_size : (i+1) * batch_size]
      batch_B = path_B[i * batch_size : (i+1) * batch_size]
      imgs_A, imgs_B = [], []
      for img_A, img_B in zip(batch_A, batch_B):
        img_A = resize(self.imread(img_A), self.img_res)
        img_B = resize(self.imread(img_B), self.img_res)

        # The same flip decision is applied to both images of the pair.
        if not is_testing and np.random.random() > 0.5:
          img_A = np.fliplr(img_A)
          img_B = np.fliplr(img_B)

        imgs_A.append(img_A)
        imgs_B.append(img_B)

      imgs_A = np.array(imgs_A) / 127.5 - 1.
      imgs_B = np.array(imgs_B) / 127.5 - 1.

      yield imgs_A, imgs_B

  def imread(self, path):
    """Read the image at ``path`` as RGB and return it as a float64 array."""
    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    return imageio.imread(path, pilmode = 'RGB').astype(np.float64)

## CycleGAN 정의

In [None]:
class CycleGAN():
  """CycleGAN: two generators, two PatchGAN discriminators and a combined model.

  ``build_generator`` and ``build_discriminator`` are expected to be supplied
  by the later notebook cells that extend this class.
  """

  def __init__(self):
    """Build and compile the discriminators and the combined generator model."""
    # Input shape. The DataLoader reads RGB images, and the U-Net skip
    # connections need sides divisible by 2 ** 4, hence 128 x 128 x 3.
    self.img_rows = 128
    self.img_cols = 128
    self.channels = 3
    self.img_shape = (self.img_rows, self.img_cols, self.channels)

    self.dataset_name = 'apple2orange'
    self.data_loader = DataLoader(dataset_name = self.dataset_name, img_res = (self.img_rows, self.img_cols))

    # Output size of D (PatchGAN): four stride-2 convolutions halve the
    # image side four times.
    patch = int(self.img_rows / 2 ** 4)
    self.disc_patch = (patch, patch, 1)

    # Number of filters in the first layer of G and D.
    self.gf = 32
    self.df = 64

    # Loss weights.
    self.lambda_cycle = 10.0  # cycle-consistency loss
    self.lambda_id = 0.9 * self.lambda_cycle  # identity loss

    # Each compiled model gets its own optimizer instance; recent Keras
    # versions reject one optimizer shared across different variable sets.
    def make_optimizer():
      return Adam(0.0002, 0.5)

    # Build and compile the discriminators.
    self.d_A = self.build_discriminator()
    self.d_B = self.build_discriminator()
    self.d_A.compile(loss = 'mse', optimizer = make_optimizer(), metrics = ['accuracy'])
    self.d_B.compile(loss = 'mse', optimizer = make_optimizer(), metrics = ['accuracy'])

    # Build the generators.
    self.g_AB = self.build_generator()
    self.g_BA = self.build_generator()

    img_A = Input(shape = self.img_shape)
    img_B = Input(shape = self.img_shape)

    # Translate the images to the other domain.
    fake_B = self.g_AB(img_A)
    fake_A = self.g_BA(img_B)

    # Translate them back to the original domain.
    reconstr_A = self.g_BA(fake_B)
    reconstr_B = self.g_AB(fake_A)

    # Identity mapping: feed each generator an image of its target domain.
    img_A_id = self.g_BA(img_A)
    img_B_id = self.g_AB(img_B)

    # Only the generators are trained through the combined model.
    self.d_A.trainable = False
    self.d_B.trainable = False

    # Discriminators judge the translated images.
    valid_A = self.d_A(fake_A)
    valid_B = self.d_B(fake_B)

    self.combined = Model(
      inputs = [img_A, img_B],
      outputs = [valid_A, valid_B, reconstr_A, reconstr_B, img_A_id, img_B_id])
    # Alias for code that still reads the original misspelled attribute.
    self.conbined = self.combined

    self.combined.compile(
      loss = ['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],
      loss_weights = [1, 1, self.lambda_cycle, self.lambda_cycle, self.lambda_id, self.lambda_id],
      optimizer = make_optimizer())


### 정적 메서드

In [None]:
class CycleGAN(CycleGAN):
  @staticmethod
  def conv2d(layer_input, filters, f_size = 4, normalization = True):
    """Downsampling block: stride-2 convolution followed by LeakyReLU.

    Args:
      layer_input: Tensor the block is applied to.
      filters: Number of convolution filters.
      f_size: Convolution kernel size.
      normalization: Append an InstanceNormalization layer when true.

    Returns:
      The output tensor of the block.
    """
    d = Conv2D(filters, kernel_size = f_size, strides = 2, padding = 'same')(layer_input)
    d = LeakyReLU(alpha = 0.2)(d)
    if normalization:
      d = InstanceNormalization()(d)

    return d

  @staticmethod
  def deconv2d(layer_input, skip_input, filters, f_size = 4, dropout_rate = 0):
    """Upsampling block: 2x upsample, stride-1 convolution, skip concatenation.

    Args:
      layer_input: Tensor the block is applied to.
      skip_input: Tensor concatenated with the block output (skip connection).
      filters: Number of convolution filters.
      f_size: Convolution kernel size; defaults to 4 because callers omit it.
      dropout_rate: Dropout rate applied after the convolution; 0 disables it.

    Returns:
      The concatenation of the normalized block output and ``skip_input``.
    """
    u = UpSampling2D(size = 2)(layer_input)
    u = Conv2D(filters, kernel_size = f_size, strides = 1, padding = 'same', activation = 'relu')(u)
    if dropout_rate:
      u = Dropout(dropout_rate)(u)

    u = InstanceNormalization()(u)
    u = Concatenate()([u, skip_input])
    return u
  

## 생성자

In [None]:
class CycleGAN(CycleGAN):
  def build_generator(self):
    """Assemble the U-Net generator and return it as a Keras ``Model``.

    Four ``conv2d`` blocks downsample the input, three ``deconv2d`` blocks
    upsample it again while consuming the encoder outputs as skip inputs,
    and a final upsample + tanh convolution produces ``self.channels`` maps.
    """
    base_filters = self.gf
    source = Input(shape = self.img_shape)

    # Encoder: keep every stage output so the decoder can use it as a skip.
    stages = []
    x = source
    for multiplier in (1, 2, 4, 8):
      x = self.conv2d(x, base_filters * multiplier)
      stages.append(x)

    # Decoder: start from the deepest stage and walk the skips in reverse.
    x = stages.pop()
    for multiplier in (4, 2, 1):
      x = self.deconv2d(x, stages.pop(), base_filters * multiplier)

    upsampled = UpSampling2D(size = 2)(x)
    to_image = Conv2D(self.channels, kernel_size = 4, strides = 1, padding = 'same', activation = 'tanh')

    return Model(source, to_image(upsampled))

## 판별자

In [None]:
class CycleGAN(CycleGAN):
  def build_discriminator(self):
    """Build the discriminator and return it as a Keras ``Model``.

    Four stride-2 ``conv2d`` blocks (the first one without normalization)
    are followed by a single-filter stride-1 convolution whose output map
    is the per-patch validity score.
    """
    img = Input(shape = self.img_shape)

    d1 = self.conv2d(img, self.df, normalization = False)
    d2 = self.conv2d(d1, self.df * 2)
    d3 = self.conv2d(d2, self.df * 4)
    d4 = self.conv2d(d3, self.df * 8)

    # The layer must be applied to d4; Model needs a tensor, not a layer.
    validity = Conv2D(1, kernel_size = 4, strides = 1, padding = 'same')(d4)

    return Model(img, validity)

## 샘플링 함수

In [None]:
class CycleGAN(CycleGAN):
  def sample_images(self, epoch, batch_i):
    """Plot and save one original/translated/reconstructed triple per domain.

    The figure is written to ``images/<dataset_name>/<epoch>_<batch_i>.png``
    (the folder is created when missing) and then shown.

    Args:
      epoch: Current epoch index, used in the file name.
      batch_i: Current batch index, used in the file name.
    """
    r, c = 2, 3

    out_dir = 'images/%s' % self.dataset_name
    os.makedirs(out_dir, exist_ok = True)

    imgs_A = self.data_loader.load_data(domain = 'A', batch_size=1, is_testing=True)
    imgs_B = self.data_loader.load_data(domain = 'B', batch_size=1, is_testing=True)

    # Translate to the other domain.
    fake_B = self.g_AB.predict(imgs_A)
    fake_A = self.g_BA.predict(imgs_B)

    # Translate back to the original domain.
    reconstr_A = self.g_BA.predict(fake_B)
    reconstr_B = self.g_AB.predict(fake_A)

    gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

    # Rescale from [-1, 1] to [0, 1]; clip so imshow never sees out-of-range floats.
    gen_imgs = np.clip(0.5 * gen_imgs + 0.5, 0., 1.)

    titles = ['Original','Translated','Reconstructed']
    # plt.subplots (plural) returns the figure together with the axes grid.
    fig, axs = plt.subplots(r, c)
    cnt = 0

    for i in range(r):
      for j in range(c):
        axs[i, j].imshow(gen_imgs[cnt])
        axs[i, j].set_title(titles[j])
        axs[i, j].axis('off')
        cnt += 1

    fig.savefig('%s/%d_%d.png' % (out_dir, epoch, batch_i))
    plt.show()
    # Release the figure; this method is called repeatedly during training.
    plt.close(fig)

## 학습

In [None]:
class CycleGAN(CycleGAN):
  def train(self, epochs, batch_size = 1, sample_interval = 50):
    """Alternate discriminator and generator updates over the dataset.

    Args:
      epochs: Number of passes over the batches from ``data_loader.load_batch``.
      batch_size: Number of images per domain in every batch.
      sample_interval: Every this many batches, print the current losses
        and call ``sample_images``.
    """
    # Adversarial targets: one label per discriminator output patch.
    valid = np.ones((batch_size,) + self.disc_patch)
    fake = np.zeros((batch_size,) + self.disc_patch)

    # Accept either spelling of the combined-model attribute.
    combined = getattr(self, 'combined', None)
    if combined is None:
      combined = self.conbined

    for epoch in range(epochs):
      for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):
        # Translate the images to the other domain.
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)

        # Train the discriminators on real and translated images.
        dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
        dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
        dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

        dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
        dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
        dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

        d_loss = 0.5 * np.add(dA_loss, dB_loss)

        # Train the generators: adversarial, cycle and identity targets.
        g_loss = combined.train_on_batch(
          [imgs_A, imgs_B],
          [valid, valid, imgs_A, imgs_B, imgs_A, imgs_B])

        if batch_i % sample_interval == 0:
          # d_loss holds [loss, accuracy]; g_loss[0] is the total weighted loss.
          print('[Epoch %d/%d] [Batch %d] [D loss: %f, acc: %3d%%] [G loss: %f]'
                % (epoch, epochs, batch_i, d_loss[0], 100 * d_loss[1], g_loss[0]))
          self.sample_images(epoch, batch_i)

In [None]:
# Build the model and start training; the call uses the same name the instance is bound to.
cycle_gan = CycleGAN()
cycle_gan.train(epochs = 100, batch_size = 64, sample_interval = 10)