In [1]:
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf
from tensorflow.keras import layers
import time
import preprocess

In [117]:
def dcg_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(45*45*64, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((45, 45, 64)))
    assert model.output_shape == (None, 45, 45, 64) # Note: None is the batch size

    model.add(layers.Conv2DTranspose(32, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 45, 45, 32)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(16, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 90,90, 16)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='sigmoid'))
    assert model.output_shape == (None, 180, 180, 3)

    return model

def dcg_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(32, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[180, 180, 3]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model


In [88]:
def main():
    
    return
if __name__ == '__main__':
    main()

In [118]:
images = preprocess.load_images("/git-repos/latent-space-arithmetic/dataset/")

Found 202599 files belonging to 4 classes.


In [119]:
generator = dcg_generator_model()
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

# plt.imshow(generated_image[0, :, :, 0], cmap='gray')

discriminator = dcg_discriminator_model()
decision = discriminator(generated_image)


In [120]:
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)


In [121]:
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16
BATCH_SIZE = 32
# We will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

In [122]:
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

In [126]:
def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()

    for image_batch, label_batch in dataset:
        train_step(image_batch/255.0)
        print(time.time()-start)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
        checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                           epochs,
                           seed)

In [127]:
def generate_and_save_images(model, epoch, test_input):
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4,4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(predictions[i, :, :, 0] * 255)
        plt.axis('off')

#     plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()

In [None]:
train(images, EPOCHS)

0.3062119483947754
0.38497042655944824
0.4737558364868164
0.5585050582885742
0.627354621887207
0.6981334686279297
0.7669482231140137
0.8377599716186523
0.9066097736358643
0.9813742637634277
1.053206205368042
1.1229979991912842
1.1928410530090332
1.2636444568634033
1.3334648609161377
1.4032464027404785
1.4740586280822754
1.5428967475891113
1.6137194633483887
1.685492992401123
1.7543449401855469
1.8241462707519531
1.8939659595489502
1.9657437801361084
2.041539192199707
2.112382173538208
2.181166410446167
2.2509801387786865
2.321814775466919
2.392625331878662
2.4654078483581543
2.5392093658447266
2.623985528945923
2.7107508182525635
2.789541244506836
2.8633768558502197
2.9411582946777344
3.029897928237915
3.1116786003112793
3.187494993209839
3.2612791061401367
3.3480818271636963
3.434814929962158
3.5076558589935303
3.576467514038086
3.6462807655334473
3.7180583477020264
3.790894031524658
3.8656628131866455
3.9385011196136475
4.011273384094238
4.081086874008179
4.157881021499634
4.24567842

In [None]:
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')


In [34]:
train_images.shape

(60000, 28, 28, 1)