In [1]:
import tensorflow as tf
import tensorflow_probability
import numpy as np
import tensorflow.keras.layers as layers
from tensorflow.python.framework.ops import EagerTensor
from random import random
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
%matplotlib inline
from tqdm.notebook import tqdm
import tensorflow_gan as tfgan
from tensorflow.compat.v1.train import AdamOptimizer

W0512 10:00:41.073614 140277278357312 module_wrapper.py:138] From /home/student/.local/lib/python3.6/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.



In [2]:
# Training / model hyperparameters.
BATCH_SIZE = 32
NOISE_DIM = 32  # length of the generator's latent input vector

# Estimator bookkeeping: checkpoint every 500 steps, no time-based
# checkpoints, no periodic summaries.
MODEL_DIR = 'models'
RUN_CONFIG = tf.estimator.RunConfig(
    save_checkpoints_steps=500,
    save_checkpoints_secs=None,
    save_summary_steps=None,
)

In [3]:
def add_random_noise(image):
    """Pair an image with a freshly sampled standard-normal noise vector.

    Used as a `tf.data` map function. It returns `(noise, image)`, so an
    Estimator input_fn yields the noise as `features` (generator input) and
    the image as `labels` (real data).
    """
    noise = tf.random.normal(shape=[NOISE_DIM])
    return noise, image

def load_dataset(batch_size):
    """Build an infinite, shuffled CIFAR-10 pipeline of `(noise, image)` batches.

    Args:
        batch_size: number of examples per batch. Incomplete trailing batches
            are dropped, so every batch has exactly `batch_size` rows.

    Returns:
        A `tf.data.Dataset` that repeats forever. Each element is a tuple of
        `(noise [batch_size, NOISE_DIM], images [batch_size, 32, 32, 3])`,
        with the images scaled to [-1, 1].
    """
    (train_images, _), (_, _) = tf.keras.datasets.cifar10.load_data()
    train_images = train_images.reshape([-1, 32, 32, 3]).astype('float32')
    # Map pixels [0, 255] -> [-1, 1] to match the generator's tanh output range.
    train_images = train_images / 127.5 - 1
    dataset = tf.data.Dataset.from_tensor_slices(train_images)
    # Cache the images *before* attaching noise. Caching after the map would
    # store the first epoch's random noise, so every later epoch would pair
    # each image with that same noise vector.
    dataset = dataset.cache()
    dataset = dataset.shuffle(len(train_images))
    dataset = dataset.map(add_random_noise)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat(None)
    # Prefetch last so it overlaps producing the next batch with consuming this one.
    dataset = dataset.prefetch(1)
    return dataset

def plot_sample(images, max_samples=32):
    """Show up to `max_samples` images side by side in a single row.

    Args:
        images: indexable sequence of HxWx3 images with values in [-1, 1]
            (numpy arrays or eager tensors).
        max_samples: maximum number of images to draw (default 32).

    Draws nothing if `images` is empty or `max_samples` is not positive.
    """
    num_samples = min(max_samples, len(images))
    if num_samples <= 0:
        # GridSpec rejects a zero-column grid; there is nothing to show anyway.
        return

    grid = gridspec.GridSpec(1, num_samples)
    grid.update(left=0, bottom=0, top=1, right=1, wspace=0.01, hspace=0.01)
    fig = plt.figure(figsize=[num_samples, 1])

    for col in range(num_samples):
        ax = fig.add_subplot(grid[0, col])
        ax.set_axis_off()
        # Map [-1, 1] back to [0, 1]. Clip so float round-off does not trigger
        # imshow's "Clipping input data" warning for RGB floats.
        pixels = (np.asarray(images[col]) + 1.0) / 2
        ax.imshow(np.clip(pixels, 0.0, 1.0))
    plt.show()


In [8]:
dataset = load_dataset(BATCH_SIZE)
# Sanity-check the input pipeline by plotting the first batch of real images.
for noise, sample in dataset.take(1):
    plot_sample(sample)

def generator_model():
    """Build the Keras generator: a NOISE_DIM noise vector -> 32x32x3 image in [-1, 1].

    Layers:
    - A dense projection straight to a 32x32x256 feature map.
    - Two 5x5 convolutions (128 channels each), each followed by batch norm
      and LeakyReLU.
    - A final 5x5 transposed convolution to 3 channels with tanh.

    Every layer uses stride 1 with 'same' padding, so the spatial size stays
    32x32 throughout.
    """
    z_in = tf.keras.Input(shape=(NOISE_DIM,))

    # Project the noise directly to full resolution; nothing upsamples later.
    x = layers.Dense(32 * 32 * 256)(z_in)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = layers.Reshape((32, 32, 256))(x)

    # With stride 1 this transposed conv behaves like an ordinary conv.
    x = layers.Conv2DTranspose(128, (5, 5), strides=(1, 1),
        padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)

    x = layers.Conv2D(128, (5, 5), strides=(1, 1), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)

    # tanh keeps outputs in [-1, 1], the same range load_dataset scales real images to.
    x = layers.Conv2DTranspose(3, (5, 5), strides=(1, 1),
        padding='same', activation='tanh')(x)

    model = tf.keras.Model(inputs=z_in, outputs=x)
    return model


NameError: name 'load_dataset' is not defined

In [None]:
def discriminator_model():
    """Build the Keras critic: a 32x32x3 image -> 3 unbounded (linear) scores.

    The body is three conv + LeakyReLU + Dropout(0.3) stages:
    - 64 channels, stride 2 -> 16x16
    - 128 channels, stride 1 -> 16x16
    - 64 channels, stride 2 -> 8x8
    These are followed by a flatten and a linear Dense(3) head.
    """
    # Named `image_in` rather than `input` to avoid shadowing the builtin.
    image_in = tf.keras.Input(shape=(32, 32, 3))

    x = layers.Conv2D(64, (5, 5), strides=(2, 2),
        padding='same')(image_in)
    x = layers.LeakyReLU()(x)
    x = layers.Dropout(0.3)(x)

    x = layers.Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
    x = layers.LeakyReLU()(x)
    x = layers.Dropout(0.3)(x)

    x = layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')(x)
    x = layers.LeakyReLU()(x)
    x = layers.Dropout(0.3)(x)

    x = layers.Flatten()(x)
    # NOTE(review): a Wasserstein critic normally emits one scalar per image.
    # With Dense(3), the tfgan Wasserstein losses presumably average over all
    # three outputs; confirm this is intended. Switching to Dense(1) would also
    # invalidate any existing checkpoints in MODEL_DIR.
    out = layers.Dense(3)(x)

    model = tf.keras.Model(inputs=image_in, outputs=out)
    return model

def generator_fn(input):
    """TF-GAN generator_fn: map a batch of noise vectors to generated images.

    It builds a fresh Keras generator on each call. `training=True` keeps
    BatchNormalization on batch statistics in every Estimator mode, including
    prediction.
    """
    return generator_model()(input, training=True)

# One critic per graph. TF-GAN calls discriminator_fn twice, once on real data
# and once on generated data, and relies on variable-scope reuse to share the
# weights. Keras layers ignore tf.compat.v1 variable_scope reuse, so building a
# new model on every call would give the two passes independent critics. Each
# Estimator mode builds its own graph, so the cache is keyed by graph; entries
# are never evicted, which is fine at notebook scale.
_DISCRIMINATOR_BY_GRAPH = {}

# FIXME: (undescribed in the original) possibly the unused `input` argument or
# the weight-sharing issue handled above; confirm and remove.
def discriminator_fn(generated_data, input):
    """TF-GAN discriminator_fn: score a batch of real or generated images.

    Args:
        generated_data: image batch to score (real or generated).
        input: the generator inputs that TF-GAN passes as conditioning.
            Unused here, but kept because TF-GAN calls this with two arguments.

    Returns:
        The critic's output for `generated_data`.
    """
    graph = tf.compat.v1.get_default_graph()
    model = _DISCRIMINATOR_BY_GRAPH.get(graph)
    if model is None:
        model = discriminator_model()
        _DISCRIMINATOR_BY_GRAPH[graph] = model
    return model(generated_data, training=True)

def input_fn():
    """Estimator training input_fn: yields `(noise, real_image)` batches of BATCH_SIZE."""
    return load_dataset(BATCH_SIZE)

import os

tf.compat.v1.disable_eager_execution()

# The original constructed this hook and discarded it, so it never ran.
# A hook only takes effect when passed to train().
profiler_hook = tf.estimator.ProfilerHook(output_dir=MODEL_DIR, save_steps=100, show_memory=False)

# Warm-start only from a checkpoint that actually exists. The Estimator already
# resumes from the latest checkpoint in MODEL_DIR by itself, and consults
# warm_start_from only when MODEL_DIR has no checkpoint. Pointing it at a
# missing file makes a fresh run fail when the checkpoint cannot be read.
WARM_START_CKPT = os.path.join(MODEL_DIR, 'model.ckpt')
if tf.io.gfile.exists(WARM_START_CKPT + '.index') or tf.io.gfile.exists(WARM_START_CKPT):
    warm_start_from = WARM_START_CKPT
else:
    warm_start_from = None

# NOTE(review): the Wasserstein losses run with no weight clipping or gradient
# penalty, so the critic is not Lipschitz-constrained. Confirm this is intended
# (cf. tfgan.losses.wasserstein_gradient_penalty).
estimator = tfgan.estimator.GANEstimator(model_dir=MODEL_DIR,
               generator_fn=generator_fn,
               discriminator_fn=discriminator_fn,
               generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
               discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
               generator_optimizer=AdamOptimizer(1e-3),
               discriminator_optimizer=AdamOptimizer(1e-3),
               warm_start_from=warm_start_from,
               config=RUN_CONFIG)

estimator.train(input_fn, max_steps=500, hooks=[profiler_hook])

def predict_input_fn():
    """Estimator predict input_fn: one [BATCH_SIZE, NOISE_DIM] batch of standard-normal noise."""
    noise_shape = [BATCH_SIZE, NOISE_DIM]
    return tf.random.normal(noise_shape)

prediction_iterable = estimator.predict(predict_input_fn)
try:
    # predict() yields one generated image at a time. Take 32, the most
    # plot_sample will show.
    predictions = np.array([next(prediction_iterable) for _ in range(32)])
finally:
    # Closing the generator exits Estimator.predict's session context now,
    # rather than leaving the session open for as long as the notebook keeps
    # `prediction_iterable` alive.
    prediction_iterable.close()

plot_sample(predictions)
