In [None]:
import os, sys
import tensorflow as tf
import numpy as np
import time
import glob
from IPython import display
import PIL
import matplotlib.pyplot as plt


# Record the TensorFlow version this notebook is running against.
print(tf.__version__)

In [None]:
# Ask TensorFlow to grow its GPU memory allocation on demand instead of
# reserving it up front. This enables memory growth on every visible GPU;
# it does not set a hard cap.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
  # Iterating an empty list is a no-op, so no explicit "any GPUs?" check is needed.
  for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
  # Report the failure and carry on with TensorFlow's default allocation.
  print(e)

In [None]:
# Enable mixed precision globally: layers created after this point pick up the
# 'mixed_float16' policy unless they are given an explicit dtype.

from tensorflow.keras import mixed_precision

policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

# Show which dtypes the policy resolves to for computation and for variables.
print('Compute dtype: {}'.format(policy.compute_dtype))
print('Variable dtype: {}'.format(policy.variable_dtype))


In [None]:
# Locate the training set: <train_dir>/Cat/ holds the cat images.

train_dir = os.path.join('.', 'big_imageset/')
train_cats_dir = os.path.join(train_dir, 'Cat/')

# Peek at the first few file names as a sanity check.
train_cat_fnames = os.listdir(train_cats_dir)
print(train_cat_fnames[:10])

# Number of entries in the cat directory; used later to size an epoch.
ncats = len(os.listdir(train_cats_dir))
print('total training cat images:', ncats)
print(ncats)

In [None]:
## Draw some cats
# Preview cell: shows the first 8 training images on a 4x4 subplot grid
# (only the first 8 of the 16 grid slots are filled).
%matplotlib inline

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Parameters for our graph; we'll output images in a 4x4 configuration
nrows = 4
ncols = 4

# Index for iterating over images
pic_index = 0

# Reuse the current pyplot figure and size it to 4 inches per grid cell.
fig = plt.gcf()
fig.set_size_inches(ncols * 4, nrows * 4)

# Advance the window by 8 and build full paths for file names [pic_index-8, pic_index).
pic_index += 8
next_cat_pix = [os.path.join(train_cats_dir, fname) 
                for fname in train_cat_fnames[pic_index-8:pic_index]]

for i, img_path in enumerate(next_cat_pix):
  # Set up subplot; subplot indices start at 1
  sp = plt.subplot(nrows, ncols, i + 1)
  sp.axis('Off') # Don't show axes (or gridlines)

  # Read the image file from disk and draw it into the current subplot.
  img = mpimg.imread(img_path)
  plt.imshow(img)

plt.show()


In [None]:
# Dataset configuration and input pipeline.

BATCH_SIZE = 64        # images per training batch
CHECKPOINT_FREQ = 50   # epochs between checkpoint saves

# One label of 1.0 per cat image (the loader below is created with class_mode=None).
cat_labels = np.ones(ncats)

# The loader multiplies pixel values by 1/127.5, so 8-bit inputs (0..255)
# come out in the range [0, 2].
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./127.5)
cat_generator = train_datagen

# Stream 128x128 RGB images, without labels, from the sub-directories of train_dir.
train_images = cat_generator.flow_from_directory(
    train_dir,
    target_size=(128, 128),
    color_mode='rgb',
    class_mode=None,
    batch_size=BATCH_SIZE,
)

# Whole batches per pass over the cat images; the training loop stops an
# epoch after this many batches.
batch_per_epoch = ncats // BATCH_SIZE
print('Total number of batches:', batch_per_epoch)

# Shape of the first image of the first batch.
print('image shape:', train_images[0][0].shape)

In [None]:

print('Plotting an image')
#print(np.array(train_images[5][1,...]))
#plt.imshow(np.array(train_images[0][0,...,0]), interpolation='none', cmap='gray')
# Show image 5 of batch 7. The loader rescales by 1/127.5, so halving the
# values brings them back to the [0, 1] range imshow expects for float RGB
# (assuming 8-bit source images).
plt.imshow((np.array(train_images[7][5,...])/2))

In [None]:
# Generator network, built with the Keras functional API.

def make_generator_model(input_shape=(200,)):
    """Build the generator: latent vector -> 128x128x3 image in [-1, 1].

    The latent input is batch-normalised, projected by a bias-free Dense layer
    to a 4x4x2048 tensor, then upsampled by five stride-2 transposed
    convolutions (4 -> 8 -> 16 -> 32 -> 64 -> 128). A final 3x3 convolution
    with tanh activation produces the three colour channels. The intermediate
    tensor shape is printed after each stage.

    Args:
        input_shape: shape of one latent vector, without the batch axis.

    Returns:
        A `tf.keras.Model` named "generator".
    """
    layers = tf.keras.layers

    # (filters, kernel_size) for each stride-2 upsampling stage, in order.
    upsampling_stages = [
        (1024, (5, 5)),  # 4 -> 8
        (512, (5, 5)),   # 8 -> 16
        (256, (5, 5)),   # 16 -> 32
        (128, (5, 5)),   # 32 -> 64
        (64, (4, 4)),    # 64 -> 128
    ]

    inputs = layers.Input(input_shape)

    # Project the latent vector and fold it into a 4x4 feature map.
    x = layers.BatchNormalization()(inputs)
    x = layers.Dense(4*4*2048, use_bias=False, input_shape=input_shape)(x)
    x = layers.ReLU()(x)
    x = layers.Reshape((4, 4, 2048))(x)
    print(x.shape)

    # Each stage doubles the spatial resolution.
    for n_filters, kernel in upsampling_stages:
        x = layers.BatchNormalization()(x)
        x = layers.Conv2DTranspose(n_filters, kernel, strides=(2, 2),
                                   padding='same', use_bias=False)(x)
        x = layers.ReLU()(x)
        print(x.shape)

    # 128 -> 128: collapse to RGB; tanh bounds the output to [-1, 1].
    x = layers.Conv2D(3, (3, 3), strides=(1, 1), padding='same',
                      use_bias=False, activation='tanh')(x)
    print(x.shape)

    return tf.keras.Model(inputs=inputs, outputs=x, name="generator")




In [None]:
# Make discriminator model

def make_discriminator_model(input_shape = (128, 128, 3)):
    """Build the discriminator: image -> single real/fake logit.

    Six blocks of BatchNormalization -> Conv2D -> LeakyReLU -> Dropout(0.1)
    reduce the image (the first block keeps the resolution, the remaining five
    halve it), followed by Flatten, Dense(100) + LeakyReLU and a final
    Dense(1). The output is a raw logit (no sigmoid), matching a loss created
    with `from_logits=True`. Tensor shapes are printed while building.

    The final Dense layer is pinned to float32 so that, when a mixed-precision
    global policy is active, the logits (and therefore the loss computed from
    them) are not produced in float16.

    Args:
        input_shape: shape of one input image (height, width, channels),
            without the batch axis.

    Returns:
        A `tf.keras.Model` named "discriminator".
    """
    layers = tf.keras.layers

    # (filters, kernel_size, strides) for each convolutional block, in order.
    conv_blocks = [
        (64, (4, 4), (1, 1)),
        (128, (3, 3), (2, 2)),
        (256, (3, 3), (2, 2)),
        (512, (3, 3), (2, 2)),
        (1024, (2, 2), (2, 2)),
        (2048, (2, 2), (2, 2)),
    ]

    inputs = layers.Input(shape=input_shape)
    print(inputs.shape)

    x = inputs
    for n_filters, kernel, strides in conv_blocks:
        x = layers.BatchNormalization()(x)
        x = layers.Conv2D(n_filters, kernel, strides=strides, padding='same')(x)
        print(x.shape)
        x = layers.LeakyReLU()(x)
        x = layers.Dropout(0.1)(x)

    x = layers.Flatten()(x)
    print(x.shape)
    x = layers.Dense(100)(x)
    print(x.shape)
    x = layers.LeakyReLU()(x)

    # Keep the logit in float32 regardless of the global dtype policy.
    x = layers.Dense(1, dtype='float32')(x)
    print(x.shape)

    return tf.keras.Model(inputs=inputs, outputs=x, name="discriminator")

In [None]:
# Instantiate the generator and push one random latent vector through it.

# A single 50-dimensional latent sample; its width defines the generator's input size.
noise = tf.random.normal(shape=[1, 50])

generator = make_generator_model(input_shape=(noise.shape[1],))

# Inference-mode forward pass (training=False) to get one untrained sample.
generated_image = generator(noise, training=False)


In [None]:
# plot the image
print(generated_image.shape)

# Map the first generated sample from the generator's tanh range [-1, 1] to [0, 1].
# NOTE(review): Z is computed here but never passed to imshow in this cell.
Z = (np.array(generated_image[0,...], dtype=np.float32)+1)/2
#print(Z)

In [None]:
# make the discriminator
discriminator = make_discriminator_model()
# Score the untrained generator's sample; the value is a raw logit
# (the discriminator's last layer has no activation).
decision = discriminator(generated_image)
print(decision)

In [None]:
# Loss functions and optimizers for the GAN.

# The discriminator outputs raw logits, so the loss applies the sigmoid itself.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real, fake):
    """Discriminator loss: real logits should score 1, fake logits should score 0.

    Args:
        real: discriminator logits for real images.
        fake: discriminator logits for generated images.

    Returns:
        Sum of the cross-entropy on the real batch and on the fake batch.
    """
    target_real = tf.ones_like(real)
    target_fake = tf.zeros_like(fake)
    return cross_entropy(target_real, real) + cross_entropy(target_fake, fake)


def generator_loss(fake):
    """Generator loss: cross-entropy of the fake logits against an all-ones target."""
    wanted = tf.ones_like(fake)
    return cross_entropy(wanted, fake)


# Separate Adam optimizers for the two networks, same hyper-parameters for both.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

In [None]:
# Make a place to save checkpoints, restore the latest one if present, and
# work out which epoch training should resume from.

save_dir = './cat_training_checkpoints_x128-extra-dense-out'
os.makedirs(save_dir, exist_ok=True)
save_prefix = os.path.join(save_dir, 'ckpt')
check_point = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                  discriminator_optimizer=discriminator_optimizer,
                                  generator=generator,
                                  discriminator=discriminator)

# Path prefix of the most recent checkpoint (e.g. '<save_dir>/ckpt-3'), or
# None when nothing has been saved yet. restore(None) is a no-op.
latest_ckpt = tf.train.latest_checkpoint(save_dir)
check_point.restore(latest_ckpt)

# Checkpoint index files on disk, listed for information only.
ckpts = glob.glob(os.path.join(save_dir, 'ckpt*.index'))

if latest_ckpt is not None:
    print('Found checkpoints', ckpts)
    # Derive the starting epoch from the checkpoint that was actually restored,
    # so the epoch counter cannot disagree with the restored weights.
    # Checkpoints are numbered 1, 2, ... and one is written every
    # CHECKPOINT_FREQ epochs by the training loop.
    first_epoch = int(latest_ckpt.rsplit('-', 1)[-1]) * CHECKPOINT_FREQ
    print('Starting with epoch', first_epoch)
else:
    # Nothing restored: train from scratch.
    first_epoch = 0

In [None]:
# setup some parameters for training
EPOCHS = 400  # upper bound (exclusive) of the epoch range trained below
noise_dim = 50  # latent vector size; the generator above was built from a [1, 50] noise tensor
num_examples_to_generate = 16  # number of sample images drawn per preview

# Epoch indices still to run, resuming from the epoch implied by the restored checkpoint.
epochs_to_train = list(range(first_epoch, EPOCHS))

# Seed is reused to animate a gif over time
seed = tf.random.normal([num_examples_to_generate, noise_dim], seed=12345)

In [None]:
@tf.function
def train_steps(images):
    """Run one simultaneous generator/discriminator update.

    Args:
        images: a batch of real images, batch axis first. Presumably scaled to
            the generator's tanh output range [-1, 1]; confirm against the caller.

    Returns:
        None. The generator and discriminator weights are updated in place
        through their respective optimizers.
    """
    # Draw exactly as many latent vectors as there are real images, so a short
    # final batch from the data iterator still yields equal real/fake counts.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    # One tape per network: each records the same forward pass but is
    # differentiated against a different loss and variable set.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real = discriminator(images, training=True)
        fake = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake)
        disc_loss = discriminator_loss(real, fake)

    grad_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
    grad_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # NOTE(review): the notebook sets a mixed_float16 global policy but does
    # not use loss scaling in this custom loop; consider LossScaleOptimizer
    # if gradients underflow.
    generator_optimizer.apply_gradients(zip(grad_gen, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(grad_disc, discriminator.trainable_variables))




In [None]:
# Directory that receives one sample-grid image per epoch.
image_dir = '%s/cat_images' % save_dir
os.makedirs(image_dir, exist_ok=True)

def generate_and_save_images(model, epoch, test_input, save=True):
    """Render a grid of generated samples, then save it or score it.

    Args:
        model: generator to sample from; called as `model(test_input, training=False)`.
        epoch: epoch number used in the output file name.
        test_input: batch of latent vectors, one per image to draw.
        save: when True, write the grid to
            `<image_dir>/image_at_epoch<NNNN>.png`. When False, print the
            sigmoid of the discriminator's scores for the samples instead.

    Returns:
        None.
    """
    predictions = model(test_input, training=False)

    n_images = predictions.shape[0]
    # Smallest square grid that fits every sample (4x4 for the usual 16).
    grid = max(1, int(np.ceil(np.sqrt(n_images))))

    fig = plt.figure(figsize=(8,8))

    for i in range(n_images):
        plt.subplot(grid, grid, i+1)
        # Map the generator's tanh output from [-1, 1] to [0, 1] for display.
        plt.imshow((np.array(predictions[i,...]+1, dtype=np.float32)/2))
        plt.axis('off')

    if save:
        fig.savefig('{}/image_at_epoch{:04d}.png'.format(image_dir, epoch))
    else:
        print(tf.math.sigmoid(discriminator(predictions, training=False)))

    # Display the figure now, then release it. This function is called once per
    # epoch, and figures that are never closed accumulate for the whole run.
    plt.show()
    plt.close(fig)

In [None]:
def train(dataset, epochs):
    """Train the GAN for the given epochs.

    For every epoch this consumes `batch_per_epoch` batches from `dataset`,
    renders and saves a sample grid from the fixed `seed`, and writes a
    checkpoint every `CHECKPOINT_FREQ` epochs.

    Args:
        dataset: iterable of image batches. It is expected to yield values in
            [0, 2], as the notebook's loader does with rescale=1/127.5, and
            may be endless: iteration is cut off after `batch_per_epoch`
            batches.
        epochs: iterable of zero-based epoch indices to run.

    Returns:
        None.
    """
    for epoch in epochs:
        start = time.time()

        for i, image_batch in enumerate(dataset):
            # Stop first, so the batch that ends the epoch is not logged as trained.
            if i >= batch_per_epoch:
                break
            if i % 10 == 0:
                print('Training on batch %d' % i)
            # Shift [0, 2] down to [-1, 1] to match the generator's tanh output.
            # (Subtracting -1, as before, pushed real images to [1, 3].)
            train_steps(image_batch - 1.0)

        # produce images for GIF
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch+1, seed)

        if (epoch + 1) % CHECKPOINT_FREQ == 0:
            check_point.save(file_prefix=save_prefix)

        print('Time for epoch {} is {} sec'.format(epoch+1, time.time()-start))

# NOTE(review): these two statements are at module level, so they run when this
# cell is executed (before train() is called), not after the final epoch.
# With save=False the sample grid is not written to disk; the discriminator's
# sigmoid scores for the samples are printed instead.
display.clear_output(wait=True)
generate_and_save_images(generator, EPOCHS, seed, False)


In [None]:
# List the epoch indices that remain after resuming from the restored checkpoint.
print('Will train following epochs:', epochs_to_train)

In [None]:
# Run the training loop over the remaining epochs.
train(train_images, epochs_to_train)

In [None]:
# To make things more understandable, just going to run output through a sigmoid function
print('Checking discriminator performance:')
#print('When running on real cats:')
#print(tf.math.sigmoid(discriminator(train_images[0], training=False)))
print('When running on fake cats:')
# Fresh latent vectors (not the fixed `seed`). With save=False the helper prints
# the discriminator's sigmoid scores for the samples and writes nothing to disk.
new_seed = tf.random.normal([num_examples_to_generate, noise_dim])
generate_and_save_images(generator, EPOCHS, new_seed, save=False)


In [None]:
# Assemble the per-epoch sample grids into an animated GIF.

import glob
import imageio
import PIL

# The frames live under the current run's save_dir, and the GIF is written next
# to them. A separate variable is used so the global `image_dir` that
# generate_and_save_images writes to is not overwritten by this cell.
frames_dir = os.path.join(save_dir, 'cat_images')
anim_file = os.path.join(frames_dir, 'cats.gif')

# Zero-padded epoch numbers in the file names make lexical order chronological.
filenames = sorted(glob.glob(os.path.join(frames_dir, '*.png')))

if not filenames:
  print('No images found in', frames_dir)
else:
  with imageio.get_writer(anim_file, mode='I') as writer:
    for filename in filenames:
      image = imageio.imread(filename)
      # Have each image twice, otherwise gif moves too fast
      writer.append_data(image)
      writer.append_data(image)


In [None]:
#!pip install git+https://github.com/tensorflow/docs

#import tensorflow_docs.vis.embed as embed
#embed.embed_file(anim_file)