# CycleGAN : 이미지를 이미지로 변환하는 GAN 모델
* 이미지 --> 다른 도메인의 이미지로 변환(색조 등) --> 다시 원래 도메인으로 재변환해서 원본과 비교 (사이클 일관성)
* 판별자 2개와 생성자 2개(A->B & B->A)

In [1]:
# keras-contrib provides the InstanceNormalization layer imported below
# (keras_contrib.layers.normalization.instancenormalization).
!pip install git+https://www.github.com/keras-team/keras-contrib.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://www.github.com/keras-team/keras-contrib.git
  Cloning https://www.github.com/keras-team/keras-contrib.git to /tmp/pip-req-build-dqucae5d
  Running command git clone -q https://www.github.com/keras-team/keras-contrib.git /tmp/pip-req-build-dqucae5d
Building wheels for collected packages: keras-contrib
  Building wheel for keras-contrib (setup.py) ... [?25l[?25hdone
  Created wheel for keras-contrib: filename=keras_contrib-2.0.8-py3-none-any.whl size=101077 sha256=3416430f743405b879663d122dba6c0e75c6cdd5e709a4ab5e20ae7a2e73419a
  Stored in directory: /tmp/pip-ephem-wheel-cache-9b9ctzqx/wheels/bb/1f/f2/b57495012683b6b20bbae94a3915ec79753111452d79886abc
Successfully built keras-contrib
Installing collected packages: keras-contrib
Successfully installed keras-contrib-2.0.8


In [2]:
# DataLoader below calls scipy.misc.imread / scipy.misc.imresize, which newer SciPy
# releases no longer provide, so SciPy is pinned to 1.1.0 here. The resolver output
# below shows this pin conflicts with other preinstalled packages (pymc, plotnine, jax, ...).
!pip install scipy==1.1.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting scipy==1.1.0
  Downloading scipy-1.1.0-cp37-cp37m-manylinux1_x86_64.whl (31.2 MB)
[K     |████████████████████████████████| 31.2 MB 1.6 MB/s 
Installing collected packages: scipy
  Attempting uninstall: scipy
    Found existing installation: scipy 1.7.3
    Uninstalling scipy-1.7.3:
      Successfully uninstalled scipy-1.7.3
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
pymc 4.1.4 requires scipy>=1.4.1, but you have scipy 1.1.0 which is incompatible.
plotnine 0.8.0 requires scipy>=1.5.0, but you have scipy 1.1.0 which is incompatible.
jaxlib 0.3.14+cuda11.cudnn805 requires scipy>=1.5, but you have scipy 1.1.0 which is incompatible.
jax 0.3.14 requires scipy>=1.5, but you have scipy 1.1.0 which is incompatible.
aeppl 0.0.33 requires scipy>=1.4.0,

In [5]:
from __future__ import print_function, division
from tensorflow.keras.datasets import mnist
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from tensorflow.keras.layers import (Input, Dense, Reshape, Flatten, Dropout,
                                     Concatenate, BatchNormalization, Activation, ZeroPadding2D,
                                     LeakyReLU, UpSampling2D, Conv2D)
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from glob import glob
import scipy
import datetime
import matplotlib.pyplot as plt
import numpy as np
import sys
import os

In [6]:
# DataLoader --> CycleGAN 학습에 필요한 데이터 로드

class DataLoader():
    """Load unpaired images from two domains (A and B) for CycleGAN training.

    Images are read from ``<data_root>/<dataset_name>/<split>/`` where ``split``
    is ``trainA``, ``trainB``, ``testA`` or ``testB``. Each image is resized to
    ``img_res`` and its [0, 255] pixel values are mapped to [-1, 1].

    Requires a SciPy version that still provides ``scipy.misc.imread`` and
    ``scipy.misc.imresize``.
    """

    def __init__(self, dataset_name, img_res=(128, 128),
                 data_root='/content/drive/MyDrive/com_vision_study/data'):
        """
        Args:
            dataset_name: dataset folder name under ``data_root`` (e.g. 'apple2orange').
            img_res: target (rows, cols) for every loaded image.
            data_root: directory containing the dataset folders; defaults to the
                Google Drive path this notebook was written for.
        """
        self.dataset_name = dataset_name
        self.img_res = img_res
        self.data_root = data_root

    def _glob_split(self, split):
        """Return all file paths in ``<data_root>/<dataset_name>/<split>/``."""
        return glob('%s/%s/%s/*' % (self.data_root, self.dataset_name, split))

    def _resize(self, img):
        """Resize ``img`` to ``self.img_res``; the only caller of scipy.misc.imresize."""
        return scipy.misc.imresize(img, self.img_res)

    def load_data(self, domain, batch_size=1, is_testing=False):
        """Return ``batch_size`` randomly chosen images from one domain.

        Args:
            domain: 'A' or 'B'.
            batch_size: number of images; paths are sampled with replacement.
            is_testing: read ``test<domain>`` instead of ``train<domain>`` and skip
                the random horizontal flip.

        Returns:
            numpy array of the stacked images, scaled to [-1, 1].

        Raises:
            FileNotFoundError: if the split directory contains no files.
        """
        split = ('test%s' if is_testing else 'train%s') % domain
        paths = self._glob_split(split)
        if not paths:
            raise FileNotFoundError('no images found for %s/%s under %s'
                                    % (self.dataset_name, split, self.data_root))

        imgs = []
        for img_path in np.random.choice(paths, size=batch_size):
            img = self._resize(self.imread(img_path))
            # Training-time augmentation: random horizontal (left-right) flip.
            if not is_testing and np.random.random() > 0.5:
                img = np.fliplr(img)
            imgs.append(img)

        return np.array(imgs) / 127.5 - 1

    def load_batch(self, batch_size=1, is_testing=False):
        """Yield ``(imgs_A, imgs_B)`` batches covering one pass over the data.

        Each domain is shuffled independently and truncated to the same number of
        full batches, so the A/B images in a batch are unpaired. Sets
        ``self.n_batches`` to the number of batches this generator yields.

        Args:
            batch_size: images per domain per batch.
            is_testing: read ``testA``/``testB`` and skip the random flip.
        """
        # Same 'test' prefix as load_data: CycleGAN datasets such as apple2orange
        # provide testA/testB folders, not valA/valB.
        split = 'test' if is_testing else 'train'
        path_A = self._glob_split('%sA' % split)
        path_B = self._glob_split('%sB' % split)

        self.n_batches = min(len(path_A), len(path_B)) // batch_size
        total_samples = self.n_batches * batch_size

        # Without replacement: each image appears at most once per pass.
        path_A = np.random.choice(path_A, total_samples, replace=False)
        path_B = np.random.choice(path_B, total_samples, replace=False)

        # Iterate over all n_batches full batches (including the last one).
        for i in range(self.n_batches):
            batch_A = path_A[i * batch_size:(i + 1) * batch_size]
            batch_B = path_B[i * batch_size:(i + 1) * batch_size]
            imgs_A, imgs_B = [], []

            for img_A, img_B in zip(batch_A, batch_B):
                img_A = self._resize(self.imread(img_A))
                img_B = self._resize(self.imread(img_B))

                # One coin flip per pair: both images are flipped together.
                if not is_testing and np.random.random() > 0.5:
                    img_A = np.fliplr(img_A)
                    img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)

            yield np.array(imgs_A) / 127.5 - 1, np.array(imgs_B) / 127.5 - 1

    def imread(self, path):
        """Read the image at ``path`` in RGB mode and return it as a float64 array."""
        # np.float was just an alias of float (float64) and was removed in NumPy 1.24.
        return scipy.misc.imread(path, mode='RGB').astype(np.float64)

In [7]:
# Cycle GAN

class CycleGAN():
    """Base CycleGAN: two discriminators, two generators and the combined model.

    ``build_discriminator`` / ``build_generator`` are not defined in this class;
    the later ``class CycleGAN(CycleGAN)`` cells add them, and they must exist
    before the class is instantiated.
    """

    def __init__(self, dataset_name='apple2orange'):
        """
        Args:
            dataset_name: dataset folder passed to DataLoader (default 'apple2orange').
        """
        # Input image shape (rows, cols, channels).
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        self.dataset_name = dataset_name
        # DataLoader supplies the training / sample images.
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))

        # PatchGAN output size (output shape of D): img_rows / 2**4, i.e. assumes
        # the discriminator downsamples by 2 four times.
        patch = self.img_rows // 2**4
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of the generator / discriminator.
        self.gf = 32
        self.df = 64

        self.lambda_cycle = 10                    # cycle-consistency loss weight
        self.lambda_id = 0.9 * self.lambda_cycle  # identity loss weight

        # Each compiled model gets its own Adam instance. A single optimizer shared
        # by models with different variables fails with the Keras optimizer used
        # since TF 2.11, and would also share one step counter (Adam bias
        # correction) across all three models.

        # 1. Two discriminators, compiled with an MSE (least-squares) adversarial loss.
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()
        self.d_A.compile(loss='mse', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
        self.d_B.compile(loss='mse', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

        # 2. Generators A->B and B->A.
        self.g_AB = self.build_generator()
        self.g_BA = self.build_generator()

        # 3. Inputs of the combined (two-direction) model.
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # 4. Translate each image to the other domain.
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)

        # 5. Translate back to the original domain (cycle).
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)

        # Identity mapping: a generator fed an image already in its target domain.
        img_A_id = self.g_BA(img_A)
        img_B_id = self.g_AB(img_B)

        # Only the generators are trained through the combined model. d_A / d_B
        # were compiled above while trainable, so their own train_on_batch still
        # updates them.
        self.d_A.trainable = False
        self.d_B.trainable = False

        # The discriminators judge the translated images.
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Combined model: inputs [img_A, img_B];
        # outputs [valid_A, valid_B, reconstr_A, reconstr_B, img_A_id, img_B_id].
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[valid_A, valid_B,
                                       reconstr_A, reconstr_B,
                                       img_A_id, img_B_id])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                              loss_weights=[1, 1,
                                            self.lambda_cycle, self.lambda_cycle,
                                            self.lambda_id, self.lambda_id],
                              optimizer=Adam(0.0002, 0.5))

In [None]:
# 위에서 만든 CycleGAN()을 상속시켜 generator 및 discriminator 생성, sample_images 생성, train을 진행할 것

# 1. generator 및 discriminator에서 사용할 conv2d 및 deconv2d(=upsampling) 정의

class CycleGAN(CycleGAN):
    """Adds the down/up-sampling building blocks used by the generator and discriminator."""

    @staticmethod
    def conv2d(layer_input, filters, f_size=4, normalization=True):
        """Downsampling block: Conv2D (stride 2) -> LeakyReLU(0.2) -> optional InstanceNormalization.

        A staticmethod because it takes no ``self``: callers write
        ``self.conv2d(x, filters)``, which without the decorator would bind the
        instance to ``layer_input`` and shift every other argument by one.
        """
        d = Conv2D(filters, kernel_size=f_size,
                   strides=2, padding='same')(layer_input)
        d = LeakyReLU(alpha=0.2)(d)
        if normalization:
            d = InstanceNormalization()(d)
        return d

    @staticmethod
    def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
        """Upsampling block with a U-Net skip connection.

        Steps:
            1) UpSampling2D by 2, then a stride-1 ReLU Conv2D.
            2) Dropout, only if ``dropout_rate`` is non-zero.
            3) InstanceNormalization.
            4) Concatenate (last axis) with ``skip_input``, the downsampling-path
               tensor of matching resolution.
        """
        u = UpSampling2D(size=2)(layer_input)
        u = Conv2D(filters, kernel_size=f_size, strides=1,
                   padding='same', activation='relu')(u)
        if dropout_rate:
            u = Dropout(dropout_rate)(u)
        u = InstanceNormalization()(u)
        u = Concatenate()([u, skip_input])
        return u

In [None]:
# Build Generator & Discriminator

class CycleGAN(CycleGAN):
    def build_generator(self):
        """Build a U-Net generator that translates images between domains.

        Encoder: four stride-2 ``conv2d`` blocks with gf, 2gf, 4gf, 8gf filters.
        Decoder: three ``deconv2d`` blocks (4gf, 2gf, gf filters), each joined to
        the encoder output of the same resolution, then a final 2x upsample and a
        tanh Conv2D back to ``self.channels`` channels.

        With the defaults (128x128x3 input, gf=32) the shapes are
        encoder (64,64,32) -> (32,32,64) -> (16,16,128) -> (8,8,256),
        decoder (16,16,256) -> (32,32,128) -> (64,64,64) -> (128,128,64) -> (128,128,3).
        """
        image_in = Input(shape=self.img_shape)

        # Encoder: keep each stage's output for the skip connections.
        encoded = []
        x = image_in
        for mult in (1, 2, 4, 8):
            x = self.conv2d(x, self.gf * mult)
            encoded.append(x)

        # Decoder: walk the encoder stages back from the second-deepest one.
        for skip, mult in zip(encoded[-2::-1], (4, 2, 1)):
            x = self.deconv2d(x, skip, self.gf * mult)

        x = UpSampling2D(size=2)(x)
        image_out = Conv2D(self.channels, kernel_size=4, strides=1,
                           padding='same', activation='tanh')(x)
        return Model(image_in, image_out)

class CycleGAN(CycleGAN):
    def build_discriminator(self):
        """Build a PatchGAN discriminator that outputs one score per image patch.

        Four stride-2 ``conv2d`` blocks (df, 2df, 4df, 8df filters; the first
        without normalization) followed by a 1-filter Conv2D. With the defaults
        (128x128x3 input, df=64) the output map is (8, 8, 1).
        """
        image_in = Input(shape=self.img_shape)

        features = image_in
        for depth, mult in enumerate((1, 2, 4, 8)):
            features = self.conv2d(features, self.df * mult,
                                   normalization=depth > 0)

        scores = Conv2D(1, kernel_size=4, strides=1, padding='same')(features)
        return Model(image_in, scores)

In [None]:
class CycleGAN(CycleGAN):
    def sample_images(self, epoch, batch_i):
        """Plot and save A->B->A and B->A->B translations of one random test image each.

        Row 0: a test image from domain A, its translation to B and its reconstruction.
        Row 1: the same for a test image from domain B.
        The figure is saved as "<dataset_name>_<epoch>_<batch_i>" in the working
        directory, shown, and then closed.
        """
        r, c = 2, 3

        imgs_A = self.data_loader.load_data(domain="A", batch_size=1, is_testing=True)
        imgs_B = self.data_loader.load_data(domain="B", batch_size=1, is_testing=True)

        # Translate images to the other domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        # Translate back to original domain
        reconstr_A = self.g_BA.predict(fake_B)
        reconstr_B = self.g_AB.predict(fake_A)

        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

        # Rescale images from [-1, 1] to [0, 1]; clip so imshow receives valid
        # float RGB values (it would otherwise clip them itself and warn).
        gen_imgs = np.clip(0.5 * gen_imgs + 0.5, 0.0, 1.0)

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig(f"{self.dataset_name}_{epoch}_{batch_i}")
        plt.show()
        # Release the figure so repeated calls during training do not accumulate
        # open figures on non-inline backends.
        plt.close(fig)

In [None]:
class CycleGAN(CycleGAN):
    def train(self, epochs, batch_size=1, sample_interval=50):
        """Alternate discriminator and generator updates for ``epochs`` passes.

        Args:
            epochs: number of passes over ``self.data_loader.load_batch``.
            batch_size: images per domain per training step.
            sample_interval: every this many batches, print the current losses
                and call ``sample_images``.
        """
        start_time = datetime.datetime.now()

        # Adversarial loss ground truths, one per PatchGAN output cell (1 = real, 0 = fake).
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):

                # ----------------------
                #  Train Discriminators
                # ----------------------

                # Translate images to opposite domain
                fake_B = self.g_AB.predict(imgs_A)
                fake_A = self.g_BA.predict(imgs_B)

                # Train the discriminators (original images = real / translated = Fake)
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                # Total discriminator loss: [mse, accuracy] averaged over d_A and d_B.
                d_loss = 0.5 * np.add(dA_loss, dB_loss)

                # ------------------
                #  Train Generators
                # ------------------

                # Targets: fool both discriminators, reconstruct the inputs after a
                # full cycle, and leave in-domain images unchanged (identity).
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                      [valid, valid,
                                                       imgs_A, imgs_B,
                                                       imgs_A, imgs_B])

                # If at save interval => report losses and plot generated samples
                if batch_i % sample_interval == 0:
                    elapsed_time = datetime.datetime.now() - start_time
                    # Assumes Keras returns [total, then one loss per output in output
                    # order]: adv_A, adv_B, recon_A, recon_B, id_A, id_B.
                    print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] "
                          "[G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s"
                          % (epoch, epochs, batch_i, self.data_loader.n_batches,
                             d_loss[0], 100 * d_loss[1],
                             g_loss[0], np.mean(g_loss[1:3]),
                             np.mean(g_loss[3:5]), np.mean(g_loss[5:7]),
                             elapsed_time))
                    self.sample_images(epoch, batch_i)
In [9]:
# Instantiate whichever `class CycleGAN(CycleGAN)` definition was executed last and
# train it: 100 epochs, 64 images per domain per step, calling sample_images every
# 10 batches (whenever batch_i % 10 == 0).
cycle_gan = CycleGAN()
cycle_gan.train(epochs=100, batch_size=64, sample_interval=10)

Output hidden; open in https://colab.research.google.com to view.

In [8]:
class CycleGAN(CycleGAN):
    """Adds the convolutional building blocks shared by generator and discriminator."""

    @staticmethod
    def conv2d(layer_input, filters, f_size=4, normalization=True):
        """Discriminator layer: Conv2D (stride 2) -> LeakyReLU(0.2) [-> InstanceNormalization].

        The spatial size is halved by the stride-2 convolution.
        """
        out = Conv2D(filters, kernel_size=f_size, strides=2,
                     padding='same')(layer_input)
        out = LeakyReLU(alpha=0.2)(out)
        if not normalization:
            return out
        return InstanceNormalization()(out)

    @staticmethod
    def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
        """Layers used during upsampling.

        UpSampling2D(2) -> Conv2D (stride 1, ReLU) [-> Dropout if dropout_rate]
        -> InstanceNormalization, then concatenated on the last axis with
        ``skip_input``.
        """
        out = UpSampling2D(size=2)(layer_input)
        out = Conv2D(filters, kernel_size=f_size, strides=1,
                     padding='same', activation='relu')(out)
        if dropout_rate:
            out = Dropout(dropout_rate)(out)
        out = InstanceNormalization()(out)
        return Concatenate()([out, skip_input])

class CycleGAN(CycleGAN):
    def build_generator(self):
        """U-Net Generator.

        Four downsampling ``conv2d`` blocks (gf .. 8gf filters), three
        ``deconv2d`` blocks that concatenate the encoder tensor of equal
        resolution, then a 2x upsample and a tanh Conv2D with ``self.channels``
        output channels.
        """
        source = Input(shape=self.img_shape)

        # Encoder path.
        enc1 = self.conv2d(source, self.gf)
        enc2 = self.conv2d(enc1, self.gf * 2)
        enc3 = self.conv2d(enc2, self.gf * 4)
        bottleneck = self.conv2d(enc3, self.gf * 8)

        # Decoder path, each step joined with the encoder stage of matching size.
        dec = bottleneck
        for skip, width in ((enc3, self.gf * 4), (enc2, self.gf * 2), (enc1, self.gf)):
            dec = self.deconv2d(dec, skip, width)

        dec = UpSampling2D(size=2)(dec)
        translated = Conv2D(self.channels, kernel_size=4, strides=1,
                            padding='same', activation='tanh')(dec)
        return Model(source, translated)

class CycleGAN(CycleGAN):
    def build_discriminator(self):
        """PatchGAN discriminator: four stride-2 ``conv2d`` blocks and a 1-filter Conv2D.

        Only the first block skips InstanceNormalization. The result is a
        one-channel map of per-patch scores.
        """
        stages = [
            (self.df, False),
            (self.df * 2, True),
            (self.df * 4, True),
            (self.df * 8, True),
        ]

        image_in = Input(shape=self.img_shape)
        hidden = image_in
        for n_filters, use_norm in stages:
            hidden = self.conv2d(hidden, n_filters, normalization=use_norm)

        patch_scores = Conv2D(1, kernel_size=4, strides=1, padding='same')(hidden)
        return Model(image_in, patch_scores)


class CycleGAN(CycleGAN):
    def sample_images(self, epoch, batch_i):
        """Plot and save A->B->A and B->A->B translations of one random test image each.

        Row 0: a test image from domain A, its translation to B and its reconstruction.
        Row 1: the same for a test image from domain B.
        The figure is saved as "<dataset_name>_<epoch>_<batch_i>" in the working
        directory, shown, and then closed.
        """
        r, c = 2, 3

        imgs_A = self.data_loader.load_data(domain="A", batch_size=1, is_testing=True)
        imgs_B = self.data_loader.load_data(domain="B", batch_size=1, is_testing=True)

        # Translate images to the other domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        # Translate back to original domain
        reconstr_A = self.g_BA.predict(fake_B)
        reconstr_B = self.g_AB.predict(fake_A)

        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

        # Rescale images from [-1, 1] to [0, 1]; clip so imshow receives valid
        # float RGB values (it would otherwise clip them itself and warn).
        gen_imgs = np.clip(0.5 * gen_imgs + 0.5, 0.0, 1.0)

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig(f"{self.dataset_name}_{epoch}_{batch_i}")
        plt.show()
        # Release the figure so repeated calls during training do not accumulate
        # open figures on non-inline backends.
        plt.close(fig)

class CycleGAN(CycleGAN):
    def train(self, epochs, batch_size=1, sample_interval=50):
        """Alternate discriminator and generator updates for ``epochs`` passes.

        Args:
            epochs: number of passes over ``self.data_loader.load_batch``.
            batch_size: images per domain per training step.
            sample_interval: every this many batches, print the current losses
                and call ``sample_images``.
        """
        start_time = datetime.datetime.now()

        # Adversarial loss ground truths, one per PatchGAN output cell (1 = real, 0 = fake).
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):

                # ----------------------
                #  Train Discriminators
                # ----------------------

                # Translate images to opposite domain
                fake_B = self.g_AB.predict(imgs_A)
                fake_A = self.g_BA.predict(imgs_B)

                # Train the discriminators (original images = real / translated = Fake)
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                # Total discriminator loss: [mse, accuracy] averaged over d_A and d_B.
                d_loss = 0.5 * np.add(dA_loss, dB_loss)

                # ------------------
                #  Train Generators
                # ------------------

                # Targets: fool both discriminators, reconstruct the inputs after a
                # full cycle, and leave in-domain images unchanged (identity).
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                      [valid, valid,
                                                       imgs_A, imgs_B,
                                                       imgs_A, imgs_B])

                # If at save interval => report losses and plot generated samples
                if batch_i % sample_interval == 0:
                    elapsed_time = datetime.datetime.now() - start_time
                    # Assumes Keras returns [total, then one loss per output in output
                    # order]: adv_A, adv_B, recon_A, recon_B, id_A, id_B.
                    print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] "
                          "[G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s"
                          % (epoch, epochs, batch_i, self.data_loader.n_batches,
                             d_loss[0], 100 * d_loss[1],
                             g_loss[0], np.mean(g_loss[1:3]),
                             np.mean(g_loss[3:5]), np.mean(g_loss[5:7]),
                             elapsed_time))
                    self.sample_images(epoch, batch_i)