In [None]:
# !pip install -q imageio

In [None]:
import tensorflow as tf
import glob
import imageio
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import numpy as np
import os
import PIL
import pandas as pd
from tensorflow.keras import layers
import time
from keras.datasets.fashion_mnist import load_data

from IPython import display

In [None]:
# Load the Fashion-MNIST training split; the second (test) tuple is discarded.
(train_images, train_labels), (_, _) = load_data()

In [None]:
# Flatten every 28x28 image into a 784-long float32 vector and divide by 255.
train_images = train_images.reshape(train_images.shape[0], 28*28).astype('float32')
train_labels = train_labels.astype('float32')
train_images = train_images/255.0

# One-hot encode the class ids with a single vectorized assignment
# (row index, class index) instead of a per-sample Python loop.
label = np.zeros((train_images.shape[0], 10))
label[np.arange(train_images.shape[0]), train_labels.astype(int)] = 1

# Append the one-hot label to each flattened image: the first 784 columns of a
# row are the pixels, the last 10 columns are the class condition.
train_images = np.concatenate((train_images, label), axis=1)
label[0]

In [None]:
# Shuffle-buffer size and mini-batch size used when building the tf.data pipeline.
BUFFER_SIZE = 60_000
BATCH_SIZE = 256

In [None]:
# Build the input pipeline step by step: slice rows, shuffle, then batch.
train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

In [None]:
# Generator input = random noise vector followed by a one-hot class label.
noise_dim = 100      # length of the noise part
Y_dimension = 10     # length of the one-hot label part
input_dimension = Y_dimension + noise_dim

In [None]:
def make_generator_model():
  """Build the generator network.

  The model takes a vector of length ``input_dimension``, projects it with a
  dense layer, reshapes it to a 7x7x256 feature map and upsamples it with
  transposed convolutions to a (28, 28, 1) output with tanh activation.
  The asserts check the output shape after each stage (None = batch size).

  Returns:
    The assembled ``tf.keras.Sequential`` model.
  """
  net=tf.keras.Sequential()

  # Dense projection followed by a reshape into a 7x7 map with 256 channels.
  net.add(layers.Dense(7*7*256,input_shape=(input_dimension,),use_bias=False))
  net.add(layers.BatchNormalization())
  net.add(layers.LeakyReLU())
  net.add(layers.Reshape((7,7,256)))
  assert net.output_shape==(None,7,7,256)

  # Two intermediate stages: (filters, stride, expected spatial size).
  for filters,stride,side in ((128,1,7),(64,2,14)):
    net.add(layers.Conv2DTranspose(filters,(5,5),strides=(stride,stride),
                                   padding='same',use_bias=False))
    assert net.output_shape==(None,side,side,filters)
    net.add(layers.BatchNormalization())
    net.add(layers.LeakyReLU())

  # Final upsampling to a single-channel 28x28 image.
  net.add(layers.Conv2DTranspose(1,(5,5),strides=(2,2),activation='tanh',
                                 padding='same',use_bias=False))
  assert net.output_shape==(None,28,28,1)

  return net

In [None]:
def make_descriminator_model():
  """Build the discriminator network.

  The model takes a vector of length 28*28+10 (flattened image plus one-hot
  label), maps it through a 128-unit dense layer, reshapes that to a 2x2x32
  map, applies two strided convolutions with dropout and ends in a single
  unit without activation (a raw score).

  Returns:
    The assembled ``tf.keras.Sequential`` model.
  """
  net=tf.keras.Sequential()

  # Dense embedding of the image+label vector; 128 = 2*2*32 for the reshape.
  net.add(layers.Dense(128,input_shape=(28*28+10,),use_bias=True))
  net.add(layers.BatchNormalization())
  net.add(layers.LeakyReLU())
  net.add(layers.Reshape((2,2,32)))
  assert net.output_shape==(None,2,2,32)

  # Two identical conv stages that differ only in the number of filters.
  for filters in (64,128):
    net.add(layers.Conv2D(filters,(5,5),strides=(2,2),padding='same'))
    net.add(layers.LeakyReLU())
    net.add(layers.Dropout(0.3))

  net.add(layers.Flatten())
  net.add(layers.Dense(1))

  return net

In [None]:
# Smoke test: run the untrained generator on one batch of noise + real labels.
generator=make_generator_model()
noise=tf.random.normal([BATCH_SIZE,noise_dim])

# Grab only the first batch of the dataset.
for i in train_dataset.take(1):
  X=i

# The last 10 columns of a dataset row hold the one-hot label.
y=X[:,-10:]
noise=np.concatenate((noise,y),axis=1)

generated_image=generator(noise,training=False)
plt.imshow(generated_image[0,:,:,0],cmap='gray')
print(y[0])

In [None]:
# Flatten the generated images and append the labels so the batch has the
# same row layout as the training data fed to the discriminator.
generated_image=np.concatenate(
    [tf.reshape(generated_image,[BATCH_SIZE,28*28]),y],
    axis=-1)
generated_image.shape

In [None]:
# Smoke test: score the generated batch with the untrained discriminator.
discriminator=make_descriminator_model()
decision=discriminator(generated_image)
print(decision)

In [None]:
# Shared loss object; from_logits=True because it is applied to the raw
# discriminator outputs (the discriminator's last Dense layer has no activation).
cross_entropy=tf.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def discriminator_loss(real_output,fake_output):
  """Return the discriminator loss for one batch.

  Args:
    real_output: discriminator outputs for real samples (target 1).
    fake_output: discriminator outputs for generated samples (target 0).

  Returns:
    The sum of the cross-entropy on the fake and on the real outputs.
  """
  loss_on_fake=cross_entropy(tf.zeros_like(fake_output),fake_output)
  loss_on_real=cross_entropy(tf.ones_like(real_output),real_output)
  return loss_on_fake+loss_on_real

In [None]:
def generator_loss(fake_output):
  """Return the generator loss for one batch.

  Args:
    fake_output: discriminator outputs for generated samples.

  Returns:
    Cross-entropy of ``fake_output`` against an all-ones target, i.e. the
    loss is low when the discriminator scores the generated samples as real.
  """
  wanted=tf.ones_like(fake_output)
  return cross_entropy(wanted,fake_output)

In [None]:
# One Adam optimizer per network (both constructed with 1e-4), since the two
# networks are updated from separate gradients in train_step.
generator_optimizer=tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer=tf.keras.optimizers.Adam(1e-4)

In [None]:
# Checkpoint files are written under checkpoint_dir with the "ckpt" prefix.
checkpoint_dir='./training_checkpoint'
checkpoint_prefix=os.path.join(checkpoint_dir,"ckpt")
# Track both models and both optimizers in a single checkpoint object.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

In [None]:
# Number of training epochs and number of preview images drawn per preview.
EPOCHS=500
num_examples_to_generate=16

# Noise sampled once and reused for every preview during training. It has
# only noise_dim columns; the one-hot label columns are appended by the caller.
seed=tf.random.normal([num_examples_to_generate,noise_dim])

In [None]:
@tf.function
def train_step(images):
  """Run one adversarial update of the generator and the discriminator.

  Args:
    images: 2-D batch whose last 10 columns are the one-hot label and whose
      remaining 28*28 columns are the flattened image.

  The generator is conditioned on the same labels as the real batch; both
  networks are updated once from their own loss.
  """
  # Use the dynamic batch size: `images.shape[0]` is the static shape and is
  # None whenever the batch dimension is not known at trace time.
  batch_size=tf.shape(images)[0]

  y=tf.cast(images[:,-10:],dtype='float32')
  noise=tf.random.normal([batch_size,noise_dim])
  noise=tf.concat([noise,y],axis=1)

  with tf.GradientTape() as gen_tape,tf.GradientTape() as disc_tape:
    generated_images=generator(noise,training=True)

    # Flatten the images and append the labels so fake rows have the same
    # layout as the real rows; -1 lets the batch dimension be inferred.
    generated_images=tf.reshape(generated_images,[-1,28*28])
    generated_images=tf.concat([generated_images,y],axis=-1)

    real_output=discriminator(images,training=True)
    fake_output=discriminator(generated_images,training=True)

    gen_loss=generator_loss(fake_output)
    disc_loss=discriminator_loss(real_output,fake_output)

  gradients_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
  gradient_discriminator=disc_tape.gradient(disc_loss,discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(gradients_generator,generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(gradient_discriminator,discriminator.trainable_variables))

In [None]:
def generate_and_save(model,epoch,test_input):
  """Plot the model's outputs for ``test_input`` on a 4x4 grid.

  Args:
    model: callable model; invoked with ``training=False``.
    epoch: epoch number; the figure is written to disk only when it is a
      multiple of 50.
    test_input: batch of model inputs, one per image in the grid.
  """
  prediction=model(test_input,training=False)
  count=prediction.shape[0]

  plt.figure(figsize=(4,4))
  for slot in range(1,count+1):
    # Subplot positions are 1-based, image indices are 0-based.
    plt.subplot(4,4,slot)
    plt.imshow(prediction[slot-1,:,:,0]*255,cmap='gray')
    plt.axis('off')

  # Save before show(): the figure is written while it is still current.
  save_now=(epoch%50==0)
  if save_now:
    plt.savefig("plot at epoch {}.png".format(epoch))
  plt.show()

In [None]:
def _conditioned_seed(class_ids):
  """Append one-hot labels for ``class_ids`` to the fixed ``seed`` noise.

  Args:
    class_ids: a single class index (applied to every row) or one index per
      row of ``seed``.

  Returns:
    Array with ``noise_dim + Y_dimension`` columns, the generator's input width.
  """
  rows=seed.shape[0]
  one_hot=np.zeros((rows,Y_dimension))
  one_hot[np.arange(rows),class_ids]=1
  return np.concatenate((seed,one_hot),axis=1)


def train(dataset,epochs):
  """Train the GAN for ``epochs`` passes over ``dataset``.

  After each epoch a preview grid is drawn for one class (cycling through the
  classes epoch by epoch), a checkpoint is saved every 15 epochs, and the
  epoch's wall-clock time is printed.

  Args:
    dataset: iterable of batches accepted by ``train_step``.
    epochs: number of passes over ``dataset``.
  """
  for epoch in range(epochs):
    start=time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # display.clear_output(wait=True)
    generate_and_save(generator,epoch+1,_conditioned_seed(epoch%Y_dimension))

    if((epoch+1)%15==0):
      checkpoint.save(file_prefix=checkpoint_prefix)

    # Report the 1-based epoch number, matching the number used for the plots.
    print("Time for epoch {} is {}".format(epoch+1,time.time()-start))

  # display.clear_output(wait=True)
  # Final overview: the bare seed has only noise_dim columns and does not match
  # the generator's input width, so labels are appended here too; the rows
  # cycle through all classes.
  generate_and_save(generator,epochs,_conditioned_seed(np.arange(seed.shape[0])%Y_dimension))

In [None]:
# Run the full training loop (EPOCHS passes over the shuffled, batched dataset).
train(train_dataset,EPOCHS)

In [None]:
# checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
def display_image(epoch):
  """Open the preview figure saved for ``epoch``.

  Args:
    epoch: epoch number embedded in the file name "plot at epoch <epoch>.png".
      Figures are only written by ``generate_and_save`` for epochs that are a
      multiple of 50.

  Returns:
    The opened PIL image.
  """
  # Use the requested epoch instead of a hard-coded one.
  return PIL.Image.open("plot at epoch {}.png".format(epoch))

In [None]:
# Show the saved preview figure; the cell output is the returned image.
display_image(EPOCHS)

In [None]:
# Export the generator in SavedModel format to the directory export_dir.
export_dir='Fashion mnist GAN'
tf.saved_model.save(generator,export_dir)

In [None]:
# Convert the exported SavedModel to a TFLite model.
converter=tf.lite.TFLiteConverter.from_saved_model(export_dir)
tflite_model=converter.convert()

# Write the converted model to disk in binary mode.
with open('fashion_mnist_gan.tflite','wb') as f:
  f.write(tflite_model)

In [None]:
# Offer the TFLite file for download when running inside Google Colab.
# The import itself is what fails outside Colab, so it must sit inside the
# try block; the download only runs when the import succeeded.
try:
  from google.colab import files
except ImportError:
  pass
else:
  files.download('fashion_mnist_gan.tflite')

In [None]:
# Generate one image per class: 10 noise vectors, each paired with a distinct
# one-hot label (the identity matrix has exactly one 1 per row).
input=tf.random.normal([10,100])
label=np.eye(10)
input=tf.concat([input,label],axis=1)
prediction=generator.predict(input)

# Draw the generated images on a 4x4 grid.
plt.figure(figsize=(4,4))
for i in range(len(prediction)):
  plt.subplot(4,4,i+1)
  plt.imshow(prediction[i,:,:,0]*255,cmap='gray')
  plt.axis('off')
plt.show()

In [None]:
# Load the converted model and allocate its tensors before any I/O.
interpreter=tf.lite.Interpreter(model_path='fashion_mnist_gan.tflite')
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed the first sample of `input` to the TFLite model. The shape and dtype
# are taken from the interpreter's own input description instead of being
# hard-coded, and the value is handed over as a numpy array.
input_shape = input_details[0]['shape']
input_data = tf.reshape(input[0], input_shape).numpy().astype(input_details[0]['dtype'])
interpreter.set_tensor(input_details[0]['index'], input_data)

In [None]:
# Run inference and read back the first output tensor.
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])

In [None]:
# Show the single image produced by the TFLite interpreter in the first
# slot of a 4x4 grid.
plt.figure(figsize=(4,4))
plt.subplot(4,4,1)
plt.imshow(output_data[0,:,:,0]*255,cmap='gray')
plt.axis('off')