In [1]:
import tensorflow as tf
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from IPython import display

In [2]:

# Shuffle buffer covers the full Fashion-MNIST training split.
BUFFER_SIZE = 60000
BATCH_SIZE = 256

# Only the training split (images + integer class labels) is kept.
(train_images, train_labels), (_, _) = tf.keras.datasets.fashion_mnist.load_data()

# Add an explicit channel axis and switch to float32 ...
num_train = train_images.shape[0]
train_images = train_images.reshape(num_train, 28, 28, 1).astype('float32')
# ... then rescale pixels from [0, 255] to [-1, 1], matching the generator's tanh output.
train_images = (train_images - 127.5) / 127.5

# Pair each image with its label, shuffle, and batch.
image_label_pairs = (train_images, train_labels)
train_dataset = (
    tf.data.Dataset.from_tensor_slices(image_label_pairs)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

In [3]:
def make_generator_model(latent_dim, n_classes=10):
    """Build a class-conditional DCGAN generator.

    The class label is embedded and projected to a 7x7 single-channel map.
    The noise vector is projected to a 7x7x256 feature map. The two are
    concatenated along the channel axis and upsampled to a 28x28x1 image.

    Args:
        latent_dim: Length of the input noise vector.
        n_classes: Number of distinct class labels. Labels are fed through
            ``Embedding(n_classes, ...)``, so they must be integers in
            ``[0, n_classes)``.

    Returns:
        A ``tf.keras.Model`` with inputs ``[noise, label]``: noise of shape
        ``(batch, latent_dim)`` and label of shape ``(batch, 1)``. It outputs
        images of shape ``(batch, 28, 28, 1)`` with tanh activation, i.e.
        values in ``[-1, 1]``.
    """
    # Label branch: class id -> 50-d embedding -> 7x7 map (one extra channel).
    input_label = layers.Input(shape=(1,))
    label_emb = layers.Embedding(n_classes, 50)(input_label)
    label_emb = layers.Dense(7*7, use_bias=False)(label_emb)
    label_emb = layers.Reshape((7, 7, 1))(label_emb)

    # Noise branch: latent vector -> 7*7*256 features.
    input_noise = layers.Input(shape=(latent_dim,))
    noise_emb = layers.Dense(7*7*256, use_bias=False)(input_noise)
    noise_emb = layers.BatchNormalization()(noise_emb)
    noise_emb = layers.LeakyReLU()(noise_emb)
    # The Dense output is flat (batch, 12544). Reshape it to a spatial map so
    # it can be concatenated with the (batch, 7, 7, 1) label map.
    noise_emb = layers.Reshape((7, 7, 256))(noise_emb)

    # Stack label map and noise features on the channel axis: (batch, 7, 7, 257).
    input_merged = layers.Concatenate()([label_emb, noise_emb])

    # 7x7 -> 7x7 (stride 1), 128 channels.
    gen = layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)(input_merged)
    gen = layers.BatchNormalization()(gen)
    gen = layers.LeakyReLU()(gen)

    # 7x7 -> 14x14, 64 channels.
    gen = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(gen)
    gen = layers.BatchNormalization()(gen)
    gen = layers.LeakyReLU()(gen)

    # 14x14 -> 28x28x1. tanh matches the [-1, 1] pixel normalization of the training data.
    out = layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')(gen)

    model = tf.keras.Model([input_noise, input_label], out)
    return model

In [2]:
def make_discriminator(n_classes=10):
    """Build a class-conditional DCGAN discriminator.

    The class label is embedded and projected to a 28x28 single-channel map.
    That map is stacked onto the input image as an extra channel, so the
    convolutional stack sees the image and its claimed class together.

    Args:
        n_classes: Number of distinct class labels. Labels are fed through
            ``Embedding(n_classes, ...)``, so they must be integers in
            ``[0, n_classes)``.

    Returns:
        A ``tf.keras.Model`` with inputs ``[image, label]``: image of shape
        ``(batch, 28, 28, 1)`` and label of shape ``(batch, 1)``. It outputs
        a ``(batch, 1)`` sigmoid score in ``(0, 1)``.
    """
    # Label branch: class id -> 50-d embedding -> 28x28 map (one extra channel).
    input_label = layers.Input(shape=(1,))
    label_emb = layers.Embedding(n_classes, 50)(input_label)
    label_emb = layers.Dense(28*28, use_bias=False)(label_emb)
    label_emb = layers.Reshape((28, 28, 1))(label_emb)

    # The image is used as-is. A Dense layer on this 4-D tensor would act per
    # pixel along the channel axis and inflate it to thousands of channels.
    input_img = layers.Input(shape=(28, 28, 1))

    # Stack label map and image on the channel axis: (batch, 28, 28, 2).
    input_merged = layers.Concatenate()([label_emb, input_img])

    # 28x28 -> 14x14, 64 filters.
    disc = layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')(input_merged)
    disc = layers.LeakyReLU()(disc)
    disc = layers.Dropout(0.3)(disc)

    # 14x14 -> 7x7, 128 filters.
    disc = layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')(disc)
    disc = layers.LeakyReLU()(disc)
    disc = layers.Dropout(0.3)(disc)

    disc = layers.Flatten()(disc)

    out = layers.Dense(1, activation='sigmoid')(disc)

    # Inputs are the image and its label (there is no noise input in the discriminator).
    model = tf.keras.Model([input_img, input_label], out)
    return model