# Pix2pix GAN

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt

## 데이터 준비

In [None]:
# Download the facades dataset archive and extract it in the working
# directory; tar's verbose file listing is appended to log.txt rather than
# printed in the notebook output.
!wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz -O ./facades.tar.gz
!tar -xvzf facades.tar.gz >> log.txt

In [None]:
# Sanity check: read one training file and display it unmodified.
path = '/content/facades'
image = tf.io.read_file(f'{path}/train/100.jpg')
image = tf.io.decode_jpeg(image)
print(image.shape)  # (height, width, channels) of the decoded JPEG

plt.imshow(image)
plt.show()

In [None]:
# Split one training file down the middle of its width and show both halves.
path = '/content/facades'
image = tf.io.read_file(f'{path}/train/100.jpg')
image = tf.io.decode_jpeg(image)

# Axis 1 is the width: the right half becomes `input`, the left half `real`.
# NOTE(review): `input` shadows the Python builtin of the same name.
w = tf.shape(image)[1] // 2
input, real = image[:, w:, :], image[:, :w, :]

# Left panel: input half, right panel: target half.
plt.subplot(121)
plt.imshow(input)
plt.subplot(122)
plt.imshow(real)
plt.show()

In [None]:
def preprocess_train(file):
    """Load one training sample and augment it with random jitter.

    Args:
        file: Path of the image file; forwarded unchanged to ``load``.

    Returns:
        The ``(input, real)`` pair returned by ``random_jitter``.
    """
    return random_jitter(*load(file))

def preprocess_test(file):
    """Load one test sample without any augmentation.

    Args:
        file: Path of the image file; forwarded unchanged to ``load``.

    Returns:
        The ``(input, real)`` pair returned by ``load``.
    """
    return load(file)

def load(path):
    """Read a side-by-side JPEG and split it into an ``(input, real)`` pair.

    Args:
        path: Path of the JPEG file to read.

    Returns:
        A tuple ``(input, real)`` of float32 tensors scaled to [-1, 1]:
        the right half and the left half of the image width, respectively.
    """
    # Force three channels. The default (channels=0) keeps whatever the file
    # stores, so a grayscale JPEG would decode to one channel and break the
    # 3-channel crop in random_jitter and the 3-channel model inputs. It also
    # gives the channel dimension a static size inside tf.data pipelines.
    image = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
    # Map uint8 [0, 255] to float32 [-1, 1].
    image = (tf.cast(image, tf.float32) - 127.5) / 127.5
    # Axis 1 is the width; split it in half.
    w = tf.shape(image)[1] // 2
    return image[:, w:, :], image[:, :w, :]

def random_jitter(input, real):
    """Apply the same random resize-crop-flip augmentation to both images.

    Both images are resized to 286x286 with nearest-neighbour interpolation,
    cropped to 256x256 with a shared random window, and mirrored left-right
    together when a uniform random draw exceeds 0.5.

    Args:
        input: Input-domain image tensor with 3 channels.
        real: Target-domain image tensor with 3 channels.

    Returns:
        The augmented ``(input, real)`` pair, each 256x256x3.
    """
    enlarged = [
        tf.image.resize(img, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        for img in (input, real)
    ]

    # Stacking first means one crop offset is drawn for both images, so the
    # pair stays pixel-aligned.
    cropped = tf.image.random_crop(tf.stack(enlarged, axis=0), size=[2, 256, 256, 3])
    input, real = cropped[0], cropped[1]

    mirror = tf.random.uniform(()) > 0.5
    if mirror:
        input = tf.image.flip_left_right(input)
        real = tf.image.flip_left_right(real)

    return input, real

In [None]:
# Paired pipelines: augmented + shuffled mini-batches of 4 for training,
# plain batches of 5 for testing.
path = '/content/facades'
train = (
    tf.data.Dataset.list_files(f'{path}/train/*.jpg')
    .map(preprocess_train)
    .shuffle(400)
    .batch(4)
)
test = (
    tf.data.Dataset.list_files(f"{path}/test/*.jpg")
    .map(preprocess_test)
    .batch(5)
)

In [None]:
import matplotlib.pyplot as plt

# Pull one training batch and show its first pair.
input, real = next(iter(train))
print(input.shape, real.shape)

# `* 0.5 + 0.5` maps the [-1, 1] pixel range back to [0, 1] for imshow.
plt.subplot(121)
plt.imshow(input[0] * 0.5 + 0.5)

plt.subplot(122)
plt.imshow(real[0] * 0.5 + 0.5)

plt.show()

## 모델 생성

In [None]:
def downsample(x, filters):
    """Halve the spatial size: stride-2 4x4 conv -> batch norm -> ReLU.

    Args:
        x: Input Keras tensor.
        filters: Number of output channels of the convolution.

    Returns:
        The output tensor of the three-layer stage.
    """
    layers = tf.keras.layers
    stages = (
        layers.Conv2D(filters, kernel_size=4, strides=2, padding='same'),
        layers.BatchNormalization(),
        layers.Activation('relu'),
    )
    h = x
    for stage in stages:
        h = stage(h)
    return h

def upsample(x, filters, dropout=False):
    """Double the spatial size: stride-2 4x4 transposed conv -> batch norm -> ReLU.

    Args:
        x: Input Keras tensor.
        filters: Number of output channels of the transposed convolution.
        dropout: When true, insert ``Dropout(0.5)`` between the batch
            normalisation and the activation.

    Returns:
        The output tensor of the stage.
    """
    layers = tf.keras.layers
    stages = [
        layers.Conv2DTranspose(filters, kernel_size=4, strides=2, padding='same'),
        layers.BatchNormalization(),
    ]
    if dropout:
        stages.append(layers.Dropout(0.5))
    stages.append(layers.Activation('relu'))

    h = x
    for stage in stages:
        h = stage(h)
    return h

In [None]:
def make_generator():
    """Build the U-Net style generator for 256x256x3 images.

    Eight stride-2 encoder stages shrink the input to a 1x1x512 bottleneck;
    seven decoder stages upsample it again, each concatenated with the encoder
    output of matching resolution, and a final transposed convolution with
    tanh produces the 256x256x3 output.

    Returns:
        A ``tf.keras.Model`` mapping a 256x256x3 image to a 256x256x3 image.
    """
    layers = tf.keras.layers
    x = layers.Input(shape=[256, 256, 3])

    # Encoder: the first stage has no batch norm; the rest use downsample().
    skips = [layers.Conv2D(64, 4, 2, 'same', activation="relu")(x)]
    for filters in (128, 256, 512, 512, 512, 512, 512):
        skips.append(downsample(skips[-1], filters))

    # Decoder: (filters, use_dropout) per stage, paired with the encoder
    # outputs from deepest-but-one back to the first.
    decoder_plan = (
        (512, True),
        (512, True),
        (512, True),
        (512, False),
        (256, False),
        (128, False),
        (64, False),
    )
    h = skips[-1]  # bottleneck, (1, 1, 512)
    for (filters, use_dropout), skip in zip(decoder_plan, reversed(skips[:-1])):
        h = upsample(h, filters, dropout=use_dropout)
        h = layers.Concatenate()([h, skip])

    y = layers.Conv2DTranspose(3, 4, 2, padding='same', activation='tanh')(h)

    return tf.keras.Model(x, y)

In [None]:
def make_discriminator():
    """Build the conditional patch discriminator used by Pix2pix.

    The two input images are concatenated on the channel axis and reduced by
    three stride-2 stages (256 -> 128 -> 64 -> 32). Two further stride-1
    'same' convolutions keep the 32x32 grid, so the output is a
    (32, 32, 1) map with no final activation.

    Returns:
        A ``tf.keras.Model`` taking ``[input_image, real_image]`` (each
        256x256x3) and returning a (32, 32, 1) score map.
    """
    layers = tf.keras.layers

    x1 = layers.Input(shape=[256, 256, 3], name='input_image')
    x2 = layers.Input(shape=[256, 256, 3], name='real_image')
    h = layers.concatenate([x1, x2])  # (256, 256, 6)

    h = layers.Conv2D(64, kernel_size=4, strides=2, padding='same', activation="relu")(h)  # (128, 128, 64)
    for filters in (128, 256):
        h = downsample(h, filters)
    # h is now (32, 32, 256); the remaining convolutions use stride 1.
    h = layers.Conv2D(512, kernel_size=4, padding="same")(h)
    h = layers.BatchNormalization()(h)
    h = layers.Activation("relu")(h)
    y = layers.Conv2D(1, kernel_size=4, padding="same")(h)  # (32, 32, 1)

    return tf.keras.Model([x1, x2], y)

In [None]:
class Pix2pix(tf.keras.Model):
    """Keras model wrapping a generator/discriminator pair with a custom train step."""

    def __init__(self, generator, discriminator):
        """Store both networks and create their optimizers, loss and metrics.

        Args:
            generator: Model mapping an input image to a generated image.
            discriminator: Model scoring an ``[input, candidate]`` image pair.
        """
        super().__init__()
        # compile() is called without arguments; optimizers and losses are
        # handled manually in train_step.
        self.compile()

        self.generator = generator
        self.discriminator = discriminator

        self.g_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
        self.d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
        # The discriminator emits raw scores, hence from_logits=True.
        self.crossentropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

        # Running means of the two losses, reported from train_step.
        self.d_loss = tf.keras.metrics.Mean(name="d_loss")
        self.g_loss = tf.keras.metrics.Mean(name="g_loss")

    def discriminator_loss(self, y_real, y_generate):
        """Cross-entropy of real scores against ones plus generated scores against zeros."""
        loss_on_real = self.crossentropy(tf.ones_like(y_real), y_real)
        loss_on_generated = self.crossentropy(tf.zeros_like(y_generate), y_generate)
        return loss_on_real + loss_on_generated

    def generator_loss(self, y_generate, generate, real):
        """Adversarial cross-entropy against ones plus 100x the mean absolute error.

        Args:
            y_generate: Discriminator scores for the generated image.
            generate: Generated image.
            real: Target image the generated image is compared with.
        """
        adversarial = self.crossentropy(tf.ones_like(y_generate), y_generate)
        mean_abs_error = tf.reduce_mean(tf.abs(real - generate))
        return adversarial + (100 * mean_abs_error)

    def update_metrics(self, g_loss, d_loss):
        """Feed the latest generator and discriminator losses into the running means."""
        self.g_loss.update_state(g_loss)
        self.d_loss.update_state(d_loss)

    def train_step(self, dataset):
        """Run one simultaneous generator/discriminator update on a batch.

        Args:
            dataset: An ``(input, real)`` batch pair.

        Returns:
            Dict with the running means under ``"d_loss"`` and ``"g_loss"``.
        """
        source, target = dataset

        # One tape per network; both record the same forward pass.
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            fake = self.generator(source, training=True)

            score_real = self.discriminator([source, target], training=True)
            score_fake = self.discriminator([source, fake], training=True)

            g_loss = self.generator_loss(score_fake, fake, target)
            d_loss = self.discriminator_loss(score_real, score_fake)

        g_vars = self.generator.trainable_variables
        d_vars = self.discriminator.trainable_variables

        # Both gradients are computed before either optimizer changes weights.
        g_grads = g_tape.gradient(g_loss, g_vars)
        d_grads = d_tape.gradient(d_loss, d_vars)

        self.g_optimizer.apply_gradients(zip(g_grads, self.generator.trainable_variables))
        self.d_optimizer.apply_gradients(zip(d_grads, self.discriminator.trainable_variables))

        self.update_metrics(g_loss, d_loss)
        return {
            "d_loss": self.d_loss.result(),
            "g_loss": self.g_loss.result(),
        }

## 모델 훈련

In [None]:
# Reset Keras' global state, then build fresh networks and wrap them in the
# Pix2pix training model.
tf.keras.backend.clear_session()

generator = make_generator()
discriminator = make_discriminator()
gan = Pix2pix(generator, discriminator) 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython import display

class Monitor(tf.keras.callbacks.Callback):
    """Callback that shows generator output on test samples after each epoch.

    Relies on the notebook-level ``test`` dataset, ``plt`` and ``display``.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Clear the cell output and plot input / target / generated for 3 samples."""
        display.clear_output(wait=True)
        source, target = next(iter(test))
        generated = self.model.generator(source)

        for i in range(3):
            plt.figure(figsize=(10, 10))

            # Panels left to right: input, ground truth, generator output;
            # `* 0.5 + 0.5` maps [-1, 1] to [0, 1] for display.
            for position, batch in enumerate((source, target, generated), start=1):
                plt.subplot(1, 3, position)
                plt.imshow(batch[i] * 0.5 + 0.5)

            plt.show()

# Train for 50 epochs; Monitor plots sample outputs at the end of each epoch.
gan.fit(train, epochs=50, callbacks=[Monitor()])

## 결과 확인

In [None]:
# For five test batches, show the first sample's input, ground truth and
# generator prediction side by side.
for input, real in test.take(5):
    print(input.shape, real.shape)

    pred = generator.predict(input)
    plt.figure(figsize=(15, 15))

    display_list = [input[0], real[0], pred[0]]
    title = ['Input Image', 'Ground Truth', 'Predicted Image']

    for i, (caption, picture) in enumerate(zip(title, display_list), start=1):
        plt.subplot(1, 3, i)
        plt.title(caption)
        # Map [-1, 1] back to [0, 1] for display.
        plt.imshow(picture * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

# CycleGAN

## 데이터 준비 - unpaired
- Pix2pix의 preprocess 함수를 그대로 이용. 

In [None]:
# Reuse the paired preprocessing, but emit the training data as one large
# batch (up to 400 samples) so the pairing can be broken below.
path = '/content/facades'
files_train = tf.data.Dataset.list_files(f'{path}/train/*.jpg')
train = files_train.map(preprocess_train).shuffle(400).batch(400)
files_test = tf.data.Dataset.list_files(f"{path}/test/*.jpg")
test = files_test.map(preprocess_test).batch(5)

# Unpairing: take the first batch, then shuffle domain B along the batch axis
# so trainA[i] and trainB[i] no longer come from the same file.
# NOTE(review): only the first batch is used, so this covers the whole
# training split only if it holds at most 400 files — confirm.
# The random jitter is applied once here and then frozen, because the
# tensors are materialised before the new dataset is built.
trainA, trainB = next(iter(train))
trainB = tf.random.shuffle(trainB)
train = tf.data.Dataset.from_tensor_slices((trainA.numpy(), trainB.numpy()))
train = train.shuffle(400).batch(4)

In [None]:
import matplotlib.pyplot as plt

# Show the first (A, B) pair of one unpaired training batch.
input, real = next(iter(train))

# `* 0.5 + 0.5` maps the [-1, 1] pixel range back to [0, 1] for imshow.
plt.subplot(121)
plt.imshow(input[0] * 0.5 + 0.5)

plt.subplot(122)
plt.imshow(real[0] * 0.5 + 0.5)

plt.show()

## 모델 생성 
- generator는 Pix2pix에서 사용한 make_generator 함수를 그대로 이용한다. 

In [None]:
def make_discriminator():
    """Build the single-input patch discriminator used by CycleGAN.

    Three stride-2 stages reduce the image 256 -> 128 -> 64 -> 32. Two
    further stride-1 'same' convolutions keep the 32x32 grid, so the output
    is a (32, 32, 1) map with no final activation.

    Returns:
        A ``tf.keras.Model`` mapping a 256x256x3 image to a (32, 32, 1)
        score map.
    """
    layers = tf.keras.layers

    x = layers.Input(shape=[256, 256, 3], name='input_image')

    h = layers.Conv2D(64, kernel_size=4, strides=2, padding='same', activation="relu")(x)  # (128, 128, 64)
    for filters in (128, 256):
        h = downsample(h, filters)
    # h is now (32, 32, 256); the remaining convolutions use stride 1.
    h = layers.Conv2D(512, kernel_size=4, padding="same")(h)
    h = layers.BatchNormalization()(h)
    h = layers.Activation("relu")(h)
    y = layers.Conv2D(1, kernel_size=4, padding="same")(h)  # (32, 32, 1)

    return tf.keras.Model(x, y)
    

In [None]:
class CycleGAN(tf.keras.Model):
    """Keras model training two generators (A->B, B->A) and two discriminators."""

    def __init__(self, genAB, genBA, discA, discB):
        """Store the four networks and create one Adam optimizer for each.

        Args:
            genAB: Generator translating domain A images to domain B.
            genBA: Generator translating domain B images to domain A.
            discA: Discriminator scoring domain A images.
            discB: Discriminator scoring domain B images.
        """
        super().__init__()
        # compile() is called without arguments; optimizers and losses are
        # handled manually in train_step.
        self.compile()

        # Generators
        self.genAB = genAB
        self.genBA = genBA
        self.gAB_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.gBA_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

        # Discriminators
        self.discA = discA
        self.discB = discB
        self.dA_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        self.dB_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

        # The discriminators emit raw scores, hence from_logits=True.
        self.crossentropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        self.d_loss_metric = tf.keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = tf.keras.metrics.Mean(name="g_loss")

    def discriminator_loss(self, y_real, y_generate):
        """Half the sum of cross-entropy on real (vs ones) and generated (vs zeros) scores."""
        loss_on_real = self.crossentropy(tf.ones_like(y_real), y_real)
        loss_on_generated = self.crossentropy(tf.zeros_like(y_generate), y_generate)
        return (loss_on_real + loss_on_generated) * 0.5

    def generator_loss(self, y_generate):
        """Cross-entropy of the discriminator's scores for generated images against ones."""
        return self.crossentropy(tf.ones_like(y_generate), y_generate)

    def cycle_loss(self, real, cycle):
        """Ten times the mean absolute error between an image and its round trip."""
        mean_abs_error = tf.reduce_mean(tf.abs(real - cycle))
        return 10 * mean_abs_error

    def identity_loss(self, real, same):
        """Five times (10 * 0.5) the mean absolute error between ``real`` and ``same``."""
        mean_abs_error = tf.reduce_mean(tf.abs(real - same))
        return 10 * 0.5 * mean_abs_error

    def train_step(self, dataset):
        """Run one joint update of both generators and both discriminators.

        Args:
            dataset: A ``(realA, realB)`` batch pair.

        Returns:
            Dict with the running means under ``"d_loss"`` and ``"g_loss"``.
        """
        realA, realB = dataset

        # Persistent tape: four separate gradients are taken from one forward pass.
        with tf.GradientTape(persistent=True) as tape:
            # Translations and round trips.
            fakeB = self.genAB(realA, training=True)
            cycledA = self.genBA(fakeB, training=True)
            fakeA = self.genBA(realB, training=True)
            cycledB = self.genAB(fakeA, training=True)
            # Identity mappings: each generator fed an image already in its
            # output domain.
            sameA = self.genBA(realA, training=True)
            sameB = self.genAB(realB, training=True)

            score_realA = self.discA(realA, training=True)
            score_realB = self.discB(realB, training=True)
            score_fakeA = self.discA(fakeA, training=True)
            score_fakeB = self.discB(fakeB, training=True)

            # Total generator loss = adversarial + cycle (both directions) + identity.
            total_cycle_loss = self.cycle_loss(realA, cycledA) + self.cycle_loss(realB, cycledB)
            gAB_loss = self.generator_loss(score_fakeB) + total_cycle_loss + self.identity_loss(realB, sameB)
            gBA_loss = self.generator_loss(score_fakeA) + total_cycle_loss + self.identity_loss(realA, sameA)
            dA_loss = self.discriminator_loss(score_realA, score_fakeA)
            dB_loss = self.discriminator_loss(score_realB, score_fakeB)

        updates = (
            (gAB_loss, self.genAB, self.gAB_optimizer),
            (gBA_loss, self.genBA, self.gBA_optimizer),
            (dA_loss, self.discA, self.dA_optimizer),
            (dB_loss, self.discB, self.dB_optimizer),
        )

        # Compute all four gradients before any optimizer changes weights.
        all_gradients = [
            tape.gradient(loss, network.trainable_variables)
            for loss, network, _ in updates
        ]
        for gradients, (_, network, optimizer) in zip(all_gradients, updates):
            optimizer.apply_gradients(zip(gradients, network.trainable_variables))

        self.d_loss_metric.update_state(dA_loss + dB_loss)
        self.g_loss_metric.update_state(gAB_loss + gBA_loss)

        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }

## 모델 학습

In [None]:
# Reset Keras' global state, then build two generators and two
# discriminators and wrap them in the CycleGAN training model.
# NOTE(review): make_discriminator is defined twice in this notebook; this
# cell presumably expects the single-input version to be the one in scope.
tf.keras.backend.clear_session()

genAB = make_generator()
genBA = make_generator()
discA = make_discriminator()
discB = make_discriminator()
gan = CycleGAN(genAB, genBA, discA, discB) 

In [None]:
import matplotlib.pyplot as plt

# Run both generators on the batch currently held in `input` / `real`
# (defined by an earlier cell) and show the first sample of each.
g_real = genAB(input)
g_input = genBA(real)

# Top row: A image and genAB output; bottom row: B image and genBA output.
# The generator outputs are scaled by an extra factor of 8 before display —
# presumably to make the low-contrast output of untrained generators
# visible; TODO confirm.
plt.figure(figsize=(8, 8))
plt.subplot(221)
plt.imshow(input[0] * 0.5 + 0.5)
plt.subplot(222)
plt.imshow(g_real[0] * 0.5 * 8 + 0.5)
plt.subplot(223)
plt.imshow(real[0] * 0.5 + 0.5)
plt.subplot(224)
plt.imshow(g_input[0] * 0.5 * 8 + 0.5)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython import display

class Monitor(tf.keras.callbacks.Callback):
    """Callback that shows genAB output on test samples after each epoch.

    Relies on the notebook-level ``test`` dataset, ``plt`` and ``display``.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Clear the cell output and plot A / B / genAB(A) for 3 test samples."""
        display.clear_output(wait=True)
        testA, testB = next(iter(test))
        generated = self.model.genAB(testA)

        for i in range(3):
            # Panels left to right: A image, B image, genAB output;
            # `* 0.5 + 0.5` maps [-1, 1] to [0, 1] for display.
            for position, batch in enumerate((testA, testB, generated), start=1):
                plt.subplot(1, 3, position)
                plt.imshow(batch[i] * 0.5 + 0.5)

            plt.show()

# Train for 20 epochs; Monitor plots sample outputs at the end of each epoch.
gan.fit(train, epochs=20, callbacks=[Monitor()])

## 결과 확인

In [None]:
# For five test batches, show the first sample's A image, B image and the
# genAB translation side by side.
for testA, testB in test.take(5):
    generated = genAB(testA)
    plt.figure(figsize=(10, 10))

    display_list = [testA[0], testB[0], generated[0]]
    title = ['Input Image', 'Target Image', 'Predicted Image']

    for i, (caption, picture) in enumerate(zip(title, display_list), start=1):
        plt.subplot(1, 3, i)
        plt.title(caption)
        # Map [-1, 1] back to [0, 1] for display.
        plt.imshow(picture * 0.5 + 0.5)
        plt.axis('off')
    plt.show()