<a href="https://colab.research.google.com/github/halilagin/gen-adv-net/blob/master/gan_dummy_educational.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install dill
import dill
from pathlib import Path
import os
root = Path("/content/drive/My Drive/root/colab/")
output_dir=Path(str(root.absolute())+"/gan-dummy-educational-output")
os.makedirs(str(output_dir), exist_ok=True)

session_file = str(root.absolute())+"/gan-dummy-educational.db"
dill.load_session(session_file)

#dill.dump_session(session_file)





Using TensorFlow backend.


In [2]:
import time

import tensorflow as tf
import numpy as np
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import keras
%matplotlib inline
tf.__version__

'2.2.0'

In [3]:
# --- Training configuration ---
# Total scalars per training sample: each dataset holds TRAINING_SIZE // 2 2-D points.
TRAINING_SIZE = 256
BUFFER_SIZE = 200     # tf.data shuffle buffer size
BATCH_SIZE = 64       # dataset batch size and number of noise vectors per train step
EPOCHS = 2000         # number of passes over the training dataset
noise_dim = 100       # length of the generator's input noise vector


In [75]:
# Four toy 2-D datasets. Each is a float32 tensor of shape (TRAINING_SIZE // 2, 2),
# i.e. TRAINING_SIZE scalars in total.

# The four corners [[1,1],[-1,1],[-1,-1],[1,-1]], repeated TRAINING_SIZE // 8 times.
# dtype is set explicitly: the integer literals alone would give an int32 tensor,
# unlike the other (float32) datasets.
fourpoints_set = tf.tile(
    tf.constant([[1, 1], [-1, 1], [-1, -1], [1, -1]], dtype=tf.float32),
    [TRAINING_SIZE // 8, 1])

# Points on the unit circle centred at (0, 0). linspace includes both -pi and pi,
# so the first and last points coincide.
pi_ = tf.reshape(tf.linspace(-np.pi, np.pi, TRAINING_SIZE // 2), [TRAINING_SIZE // 2, 1])
circle_set = tf.concat([tf.cos(pi_), tf.sin(pi_)], -1)

# Outline of the square with corners [[1,1],[-1,1],[-1,-1],[1,-1]], traced clockwise:
# top edge left->right, right edge top->bottom, bottom edge right->left, left edge bottom->top.
line_size = TRAINING_SIZE // 8
square_upline_lr = tf.stack([tf.linspace(-1., 1., line_size), tf.constant([1.0] * line_size)], axis=-1)
square_rightline_tb = tf.stack([tf.constant([1.0] * line_size), tf.linspace(1., -1., line_size)], axis=-1)
square_botline_rl = tf.stack([tf.linspace(1., -1., line_size), tf.constant([-1.0] * line_size)], axis=-1)
square_leftline_bt = tf.stack([tf.constant([-1.0] * line_size), tf.linspace(-1., 1., line_size)], axis=-1)
square_set = tf.concat([square_upline_lr, square_rightline_tb, square_botline_rl, square_leftline_bt], axis=0)

# Parabola y = x**2 sampled at evenly spaced x in [-1, 1].
x2 = tf.reshape(tf.linspace(-1., 1., TRAINING_SIZE // 2), [TRAINING_SIZE // 2, 1])
x2_set = tf.concat([x2, x2 ** 2], -1)


datasets = {
    "fourpoints": fourpoints_set,
    "circle": circle_set,
    "square": square_set,
    "x2": x2_set
}

train_dataset_name = "circle"
training_set = datasets[train_dataset_name]

# The models work on 16x16x1 tensors, so the whole point set is packed into ONE
# training "image". The dataset therefore has a single element: shuffle is a no-op
# and batch(BATCH_SIZE) yields one batch of size 1.
assert int(tf.size(training_set)) == 16 * 16, \
    "training set must hold exactly 16*16 values (TRAINING_SIZE = 256) to fit the 16x16 models"
train_dataset = tf.data.Dataset.from_tensor_slices(
    tf.reshape(training_set, [1, 16, 16, 1])).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)


In [76]:

def make_generator_model():
    """Build the generator: a ``(batch, noise_dim)`` noise tensor -> ``(batch, 16, 16, 1)``.

    The architecture has three stages:

    - a Dense projection reshaped into a 4x4x64 feature map;
    - a stride-2 transposed convolution from 4x4 to 8x8;
    - a second stride-2 transposed convolution from 8x8 to 16x16.

    The last layer uses ``tanh``, so every output value lies in [-1, 1].
    """
    gen = tf.keras.Sequential()

    # Stage 1: project the noise vector and reshape it into a 4x4x64 feature map.
    gen.add(layers.Dense(4 * 4 * 64, use_bias=False, input_shape=(noise_dim,)))
    gen.add(layers.BatchNormalization())
    gen.add(layers.LeakyReLU())
    gen.add(layers.Reshape((4, 4, 64)))
    assert gen.output_shape == (None, 4, 4, 64)  # None is the batch dimension

    # Stage 2: upsample 4x4 -> 8x8 with 32 channels.
    gen.add(layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert gen.output_shape == (None, 8, 8, 32)
    gen.add(layers.BatchNormalization())
    gen.add(layers.LeakyReLU())

    # Stage 3: upsample 8x8 -> 16x16 into a single tanh-squashed channel.
    gen.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same',
                                   use_bias=False, activation='tanh'))
    assert gen.output_shape == (None, 16, 16, 1)

    return gen

def make_discriminator_model():
    """Build the discriminator: ``(batch, 16, 16, 1)`` input -> one score per sample.

    There are two stride-2 conv stages (64 then 128 filters), each followed by
    LeakyReLU and 30% dropout. The result is flattened into a Dense(1) head with
    no activation, so the scores are unbounded.
    """
    disc = tf.keras.Sequential()
    for stage, n_filters in enumerate((64, 128)):
        # Only the first layer declares the input shape.
        extra = {'input_shape': [16, 16, 1]} if stage == 0 else {}
        disc.add(layers.Conv2D(n_filters, (5, 5), strides=(2, 2), padding='same', **extra))
        disc.add(layers.LeakyReLU())
        disc.add(layers.Dropout(0.3))

    disc.add(layers.Flatten())
    disc.add(layers.Dense(1))
    return disc

# Loss object shared by discriminator_loss and generator_loss.
# Despite the historical name `cross_entropy`, this is a mean-squared-error loss.
# Applied to the discriminator's raw Dense(1) scores with targets 1 (real) and
# 0 (fake), it gives a least-squares GAN objective.
# The standard-GAN alternative is tf.keras.losses.BinaryCrossentropy(from_logits=True).
cross_entropy = tf.keras.losses.MeanSquaredError()

def discriminator_loss(real_output, fake_output):
    """Discriminator objective: pull real-sample scores towards 1 and fake-sample scores towards 0.

    Both terms use the module-level ``cross_entropy`` loss object; returns their sum.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

def generator_loss(fake_output):
    """Generator objective: make the discriminator score fakes as real (target 1)."""
    targets = tf.ones_like(fake_output)
    return cross_entropy(targets, fake_output)

# One Adam optimizer per network, both with learning rate 1e-4.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

# Newly constructed networks (built left to right: generator first).
generator, discriminator = make_generator_model(), make_discriminator_model()

# Checkpoints bundle both networks together with their optimizer state.
# NOTE(review): this is a relative path (the kernel's working directory), not under
# output_dir, so checkpoints may not outlive the Colab VM. Confirm this is intended.
checkpoint_dir = './gan_dummy_training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

# Fixed batch of latent vectors, passed to every progress snapshot so frames from
# different epochs come from the same noise inputs. (Noise, not an RNG seed.)
seed = tf.random.normal(shape=[BATCH_SIZE, noise_dim])

@tf.function
def train_step(images):
    """Run one simultaneous adversarial update of the generator and the discriminator.

    Args:
        images: a batch of real samples, fed to the discriminator as is.

    NOTE(review): the fake batch always has BATCH_SIZE samples, while the real
    batch size comes from the dataset. With the single-image dataset used in this
    notebook, that is 1 real vs BATCH_SIZE fake samples per step. Confirm intended.
    """
    latent = tf.random.normal([BATCH_SIZE, noise_dim])

    # Two tapes so each network's gradient is taken w.r.t. its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fakes = generator(latent, training=True)

        scores_real = discriminator(images, training=True)
        scores_fake = discriminator(fakes, training=True)

        g_loss = generator_loss(scores_fake)
        d_loss = discriminator_loss(scores_real, scores_fake)

    g_vars = generator.trainable_variables
    d_vars = discriminator.trainable_variables
    g_grads = gen_tape.gradient(g_loss, g_vars)
    d_grads = disc_tape.gradient(d_loss, d_vars)

    generator_optimizer.apply_gradients(zip(g_grads, g_vars))
    discriminator_optimizer.apply_gradients(zip(d_grads, d_vars))

def generate_and_save_images(model, epoch, test_input):
    """Plot the model's first generated sample as a 2-D point cloud and save it as a PNG.

    The first output sample (H x W x C) is reinterpreted as ``H * W // 2`` (x, y)
    points. This assumes a single channel, as produced by make_generator_model.
    The points are drawn in gray over the current ``training_set``, and the figure is
    written to ``{output_dir}/{train_dataset_name}_image_at_epoch_{epoch:04d}.png``.

    Args:
        model: generator called in inference mode (``training=False``).
        epoch: integer used in the output file name.
        test_input: noise batch fed to ``model``.
    """
    samples = model(test_input, training=False)
    n_points = samples.shape[1] * samples.shape[2] // 2
    points = tf.reshape(samples[0], [n_points, 2])

    fig, ax = plt.subplots()
    ax.axis('equal')
    ax.set_xlim([-1.1, 1.1])
    ax.set_ylim([-1.1, 1.1])
    ax.scatter(points[:, 0], points[:, 1], c='gray')      # generated points
    ax.scatter(training_set[:, 0], training_set[:, 1])    # real points, default colour
    ax.axis('off')
    fig.savefig(str(output_dir) + '/{0}_image_at_epoch_{1:04d}.png'.format(train_dataset_name, epoch))
    plt.close(fig)  # avoid accumulating open figures across many epochs

def train(dataset, epochs, save_every=40):
    """Train the GAN for ``epochs`` passes over ``dataset``, saving progress snapshots.

    A snapshot image is saved before training (epoch 0). After every
    ``save_every``-th epoch, a snapshot image and a model checkpoint are also saved,
    and the epoch's wall time is printed. If ``epochs`` is not a multiple of
    ``save_every``, one last snapshot is saved for epoch ``epochs``, so the final
    state is always captured under its true epoch number.

    Args:
        dataset: iterable of real-image batches, each passed to ``train_step``.
        epochs: number of passes over ``dataset``.
        save_every: positive snapshot/checkpoint interval in epochs (default 40).
    """
    generate_and_save_images(generator, 0, seed)
    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch)

        completed = epoch + 1
        if completed % save_every == 0:
            generate_and_save_images(generator, completed, seed)
            checkpoint.save(file_prefix=checkpoint_prefix)
            # The timing includes the snapshot and checkpoint just written.
            print('Time for epoch {} is {} sec'.format(completed, time.time() - start))

    # Label the final snapshot with the number of epochs actually completed, and
    # skip it when the loop above already saved that epoch.
    if epochs % save_every != 0:
        generate_and_save_images(generator, epochs, seed)
  

In [77]:
train(train_dataset, EPOCHS)

Time for epoch 40 is 0.2227933406829834 sec
Time for epoch 80 is 0.22194647789001465 sec
Time for epoch 120 is 0.22320866584777832 sec
Time for epoch 160 is 0.2252638339996338 sec
Time for epoch 200 is 0.23028230667114258 sec
Time for epoch 240 is 0.21278715133666992 sec
Time for epoch 280 is 0.22409915924072266 sec
Time for epoch 320 is 0.21778202056884766 sec
Time for epoch 360 is 0.22771048545837402 sec
Time for epoch 400 is 0.21193671226501465 sec
Time for epoch 440 is 0.2294473648071289 sec
Time for epoch 480 is 0.2232227325439453 sec
Time for epoch 520 is 0.226304292678833 sec
Time for epoch 560 is 0.22312092781066895 sec
Time for epoch 600 is 0.21706390380859375 sec
Time for epoch 640 is 0.21593618392944336 sec
Time for epoch 680 is 0.21481943130493164 sec
Time for epoch 720 is 0.21309566497802734 sec
Time for epoch 760 is 0.21375274658203125 sec
Time for epoch 800 is 0.21391725540161133 sec
Time for epoch 840 is 0.21647286415100098 sec
Time for epoch 880 is 0.218536376953125 se

In [112]:
import imageio
def create_movie(dsname, step=40):
    """Assemble the saved per-epoch snapshots of ``dsname`` into an animated GIF (2 fps).

    Frames are read from ``{output_dir}/{dsname}_image_at_epoch_XXXX.png`` for
    epochs 0, step, 2*step, ... not exceeding ``EPOCHS``.

    Args:
        dsname: file-name prefix of the frames (a dataset name or "000_merge").
        step: epoch interval between saved snapshots (40 in this notebook).

    Returns:
        Path of the written GIF, ``{output_dir}/{dsname}.gif``.
    """
    gifpath = str(output_dir) + "/{0}.gif".format(dsname)
    with imageio.get_writer(gifpath, mode='I', fps=2) as writer:
        # Stop at EPOCHS: going up to EPOCHS + step would ask for a frame that was
        # never saved whenever EPOCHS is not a multiple of step.
        for epoch in range(0, EPOCHS + 1, step):
            filename = str(output_dir) + '/{0}_image_at_epoch_{1:04d}.png'.format(dsname, epoch)
            writer.append_data(imageio.imread(filename))
    return gifpath

# Dataset names in sorted order: ['circle', 'fourpoints', 'square', 'x2'].
# merge_gan_progress (defined in a later cell) reads this list to fill its 2x2 grid.
dsnames = sorted(datasets)

# NOTE(review): the "000_merge" frames read here are written by merge_gan_progress in
# a later cell, so on a fresh Restart & Run All this call runs before they exist.
gifpath = create_movie("000_merge")
gifpath

'/content/drive/My Drive/root/colab/gan-dummy-educational-output/000_merge.gif'

In [113]:
!ls /content/drive/My\ Drive/root/colab/gan-dummy-educational-output/000_merge.gif

'/content/drive/My Drive/root/colab/gan-dummy-educational-output/000_merge.gif'


In [111]:
import matplotlib.image as mpimg

def merge_gan_progress(step=40):
    """Combine the per-dataset snapshots of each saved epoch into one 2x2 figure per epoch.

    For each epoch 0, step, 2*step, ... not exceeding ``EPOCHS``, the snapshot PNG
    of each name in the global ``dsnames`` fills one panel, row-major, using at most
    four names. The result is saved as
    ``{output_dir}/000_merge_image_at_epoch_XXXX.png``.

    Args:
        step: epoch interval between saved snapshots (40 in this notebook).
    """

    def merge_one_epoch(epoch, dsnames):
        # Render and save the 2x2 panel for a single epoch.
        fig, axes = plt.subplots(2, 2, figsize=(8, 8))
        fig.text(0.5, 0.95, 'Approximating to four functions with Generative Adversarial Networks', transform=fig.transFigure, horizontalalignment='center', fontsize="x-large")
        fig.text(0.5, 0.91, "Epoch:{0}".format(epoch), transform=fig.transFigure, horizontalalignment='center', fontsize="x-large")

        axes = axes.ravel()  # top-left, top-right, bottom-left, bottom-right
        for ax_ in axes:
            ax_.set_aspect(aspect="equal", adjustable="box", share=False)

        # zip() stops at the shorter sequence: fewer than four names leaves blank
        # panels instead of raising IndexError.
        for ax_, dsname in zip(axes, dsnames):
            image_filepath = str(output_dir) + '/{0}_image_at_epoch_{1:04d}.png'.format(dsname, epoch)
            ax_.set_xticks([])
            ax_.set_yticks([])
            ax_.imshow(mpimg.imread(image_filepath))

        merge_image = str(output_dir) + '/000_merge_image_at_epoch_{0:04d}.png'.format(epoch)
        fig.savefig(merge_image)
        plt.close(fig)

    # Stop at EPOCHS: going up to EPOCHS + step would ask for snapshots that were
    # never saved whenever EPOCHS is not a multiple of step.
    for epoch in range(0, EPOCHS + 1, step):
        merge_one_epoch(epoch, dsnames)


merge_gan_progress()