In [145]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import os

In [146]:
image_size = 64
batch_size = 64
latent_dim = 100
num_epochs = 50
save_dir = '../data'
os.makedirs(save_dir, exist_ok=True)

In [147]:
def preprocess(path):
    img = tf.io.read_file(path)
    img = tf.io.decode_png(img, channels=3)
    
    img = tf.image.resize_with_pad(img, image_size, image_size)
    
    img = tf.cast(img, tf.float32)
    img = (img / 127.5) - 1.0
    
    img.set_shape([image_size, image_size, 3])
    return img

dataset = tf.data.Dataset.list_files('../dataset/*.png').shuffle(
    1000,
).map(
    preprocess, num_parallel_calls=tf.data.AUTOTUNE,
).batch(
    batch_size, drop_remainder=True,
).prefetch(tf.data.AUTOTUNE)

In [148]:
def make_generator():
    model = keras.Sequential([
        layers.Input(shape=(latent_dim,)),
        layers.Reshape((1, 1, latent_dim)),
        layers.Conv2DTranspose(512, 4, strides=1, padding='valid', use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Conv2DTranspose(256, 4, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Conv2DTranspose(128, 4, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Conv2DTranspose(64, 4, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Conv2DTranspose(3, 4, strides=2, padding='same', activation='tanh')
    ])
    return model

In [149]:
def make_discriminator():
    model = keras.Sequential([
        layers.Input(shape=(64, 64, 3)),
        layers.Conv2D(64, 4, strides=2, padding='same'),
        layers.LeakyReLU(0.2),
        layers.Conv2D(128, 4, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.2),
        layers.Conv2D(256, 4, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.2),
        layers.Conv2D(512, 4, strides=2, padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.2),
        layers.Conv2D(1, 4, strides=1, padding='valid'),
        layers.Flatten(),
        layers.Activation('sigmoid')
    ])
    return model

In [150]:
generator = make_generator()
discriminator = make_discriminator()

In [151]:
cross_entropy = keras.losses.BinaryCrossentropy()
generator_optimizer = keras.optimizers.Adam(0.0002, beta_1=0.5)
discriminator_optimizer = keras.optimizers.Adam(0.0002, beta_1=0.5)

In [152]:
fixed_noise = tf.random.normal([25, latent_dim])

In [153]:
@tf.function
def train_step(images):
    noise = tf.random.normal([batch_size, latent_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        real_loss = cross_entropy(tf.ones_like(real_output), real_output)
        fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
        disc_loss = (real_loss + fake_loss) / 2

        gen_loss = cross_entropy(tf.ones_like(fake_output), fake_output)

    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))

    return disc_loss, gen_loss


In [154]:
def generate_and_save_images(model, epoch, test_input):
    predictions = model(test_input, training=False)
    predictions = (predictions + 1) * 127.5

    plt.figure(figsize=(5,5))
    for i in range(25):
        plt.subplot(5,5,i+1)
        plt.imshow(predictions[i].numpy().astype('uint8'))
        plt.axis('off')
    plt.savefig(f'{save_dir}/epoch_{epoch:03d}.png')
    plt.close()

In [155]:
for epoch in range(num_epochs):
    for batch in dataset:
        d_loss, g_loss = train_step(batch)

    print(f'Epoch {epoch+1}, D Loss: {d_loss}, G Loss: {g_loss}')

    generate_and_save_images(generator, epoch + 1, fixed_noise)

Epoch 1, D Loss: 0.0437578409910202, G Loss: 10.256768226623535
Epoch 2, D Loss: 0.18808400630950928, G Loss: 1.789301872253418
Epoch 3, D Loss: 0.06618992984294891, G Loss: 6.82210111618042
Epoch 4, D Loss: 0.24231873452663422, G Loss: 1.4839198589324951
Epoch 5, D Loss: 0.3612321615219116, G Loss: 5.992333889007568
Epoch 6, D Loss: 0.22319740056991577, G Loss: 2.3483009338378906
Epoch 7, D Loss: 0.20988160371780396, G Loss: 2.183026075363159
Epoch 8, D Loss: 0.30152982473373413, G Loss: 1.38065767288208
Epoch 9, D Loss: 0.3443082869052887, G Loss: 1.1181130409240723
Epoch 10, D Loss: 0.4050460755825043, G Loss: 3.2214674949645996
Epoch 11, D Loss: 0.6893696188926697, G Loss: 2.4957332611083984
Epoch 12, D Loss: 0.21532689034938812, G Loss: 2.8754186630249023
Epoch 13, D Loss: 0.147647887468338, G Loss: 2.809462785720825
Epoch 14, D Loss: 0.14291104674339294, G Loss: 6.0212202072143555
Epoch 15, D Loss: 0.1491013467311859, G Loss: 4.106727600097656
Epoch 16, D Loss: 0.5393797755241394

In [156]:
generator.save(f'{save_dir}/generator.h5')
discriminator.save(f'{save_dir}/discriminator.h5')

