In [70]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
from keras.datasets.mnist import load_data
from tensorflow.keras.optimizers import Adam
from keras.models import Sequential
import matplotlib.pyplot as plt


In [71]:
def discriminator(in_shape=(784,)):
  model=Sequential()
  model.add(keras.Input(shape=in_shape))
  model.add(keras.layers.Dense(512,activation='leaky_relu'))
  model.add(keras.layers.Dropout(0.3))
  model.add(keras.layers.Dense(64,activation='leaky_relu'))
  model.add(keras.layers.Dropout(0.3))
  model.add(keras.layers.Dense(8,activation='leaky_relu'))
  model.add(keras.layers.Dense(1,activation='sigmoid'))


  model.compile(loss='binary_crossentropy',
                optimizer=Adam(),
                metrics=['accuracy'])


  return model

In [72]:
print(discriminator().summary())

None


In [73]:
def generator(rv_dim=50):
  model=Sequential()
  model.add(keras.Input(shape=(rv_dim,)))
  model.add(keras.layers.Dense(64,activation='leaky_relu'))
  model.add(keras.layers.Dense(256,activation='leaky_relu'))
  model.add(keras.layers.Dropout(0.2))
  model.add(keras.layers.Dense(512,activation='leaky_relu'))
  model.add(keras.layers.Dense(784,activation='tanh'))

  model.compile(loss='mse',
                optimizer=Adam(),
                metrics=['accuracy'])

  return model

In [74]:
print(generator().summary())

None


In [75]:
def gan_model(g_model,d_model):
  model=Sequential()
  model.add(g_model)
  model.add(d_model)

  model.compile(loss='binary_crossentropy',
                optimizer=Adam(),
                metrics=['accuracy'])

  return model

In [76]:
print(gan_model(generator(),discriminator()).summary())

None


In [77]:
from re import X
def load_real_data():
  (X_train,_),(_,_) = load_data()
  X_train=X_train.reshape(-1,784).astype('f8')-127.5
  return X_train/127.5

In [78]:
def generate_real_sample(data,n_samples=100):
  ix=np.random.randint(0,data.shape[0],n_samples)
  X_train=data[ix]
  y=np.ones(shape=(X_train.shape[0],1))
  return X_train,y

In [79]:
def generate_rv(rv_dim,n_samples=100):
  return np.random.randn(n_samples,rv_dim)


In [80]:
def generate_fake_images(g_model,rv_dim,n_samples=100):
  rv=generate_rv(rv_dim,n_samples)
  fimg=g_model.predict(rv)
  y=np.zeros(shape=(n_samples,1))
  return fimg,y

In [81]:
def save_fig(g_model,rv_dim,e):
  n=10
  rv=generate_rv(rv_dim,n*n)
  f_img=g_model.predict(rv)
  for i in range(n*n):
    plt.subplot(n,n,i+1)
    plt.axis('off')
    plt.imshow(f_img[i].reshape((28,28)),interpolation="nearest",cmap='gray')
  filename=f'/content/gan_output/generated_plot_e{e+1}.png'
  plt.savefig(filename)
  plt.close()

In [84]:
def train(g_model,d_model,gan_model,rv_dim,epochs=50,batch_size=250):
  nbatchs=data.shape[0]//batch_size
  half_batch=batch_size//2

  for e in range(epochs):
    for bn in range(nbatchs):
      x_real,y_real=generate_real_sample(data,half_batch)
      x_fake,y_fake=generate_fake_images(g_model,rv_dim,half_batch)

      d_model.trainable=True

      r_loss,_=d_model.train_on_batch(x_real,y_real)
      f_loss,_=d_model.train_on_batch(x_fake,y_fake)

      d_loss=0.5*(r_loss+f_loss)

      d_model.trainable=False

      x_gan=generate_rv(rv_dim,batch_size)
      y=np.ones((batch_size,1))

      g_loss=gan_model.train_on_batch(x_gan,y)

    print(f'Epoch: {e+1},d_loss:{d_loss},g_loss:={g_loss}')
    if (e+1)%10==0:
      save_fig(g_model,rv_dim,e)

In [85]:
keras.utils.disable_interactive_logging()
rv_dim=50
g_model=generator(rv_dim)
d_model=discriminator()
gan=gan_model(g_model,d_model)
data=load_real_data()
train(g_model,d_model,gan,rv_dim)
keras.utils.enable_interactive_logging()

Epoch: 1,d_loss:0.22604407370090485,g_loss:=[array(5.9506006, dtype=float32), array(0.06033333, dtype=float32)]
Epoch: 2,d_loss:0.2178111970424652,g_loss:=[array(5.6518636, dtype=float32), array(0.05220833, dtype=float32)]
Epoch: 3,d_loss:0.22997745871543884,g_loss:=[array(5.1722465, dtype=float32), array(0.04652222, dtype=float32)]
Epoch: 4,d_loss:0.22232642769813538,g_loss:=[array(4.9537067, dtype=float32), array(0.04075, dtype=float32)]
Epoch: 5,d_loss:0.21296554803848267,g_loss:=[array(4.8192654, dtype=float32), array(0.03601333, dtype=float32)]
Epoch: 6,d_loss:0.2059912085533142,g_loss:=[array(4.695736, dtype=float32), array(0.03353889, dtype=float32)]
Epoch: 7,d_loss:0.20306801795959473,g_loss:=[array(4.626872, dtype=float32), array(0.03244524, dtype=float32)]
Epoch: 8,d_loss:0.20568229258060455,g_loss:=[array(4.553083, dtype=float32), array(0.03261875, dtype=float32)]
Epoch: 9,d_loss:0.2080811709165573,g_loss:=[array(4.4581113, dtype=float32), array(0.03289444, dtype=float32)]
E