Vamos a usar este notebook como ejemplo de redes GAN. Estas se componen de dos redes: una generadora y otra discriminadora. La generadora tratará de crear instancias de datos falsos que tratará de "colar" entre los datos reales provenientes del dataset que se le pasarán a la discriminadora. Usaremos el datast de MNIST.

In [0]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import keras
%tensorflow_version 1.x
from keras.layers import Dense, Dropout, Input
from keras.models import Model,Sequential
from keras.datasets import mnist
from tqdm import tqdm
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam

In [0]:
def load_data():
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_train = (x_train.astype(np.float32) - 127.5)/127.5
  
  # convert shape of x_train from (60000, 28, 28) to (60000, 784) 
  # 784 columns per row
  x_train = x_train.reshape(60000, 784)
  return (x_train, y_train, x_test, y_test)

In [0]:
(X_train, y_train,X_test, y_test)=load_data()
print(X_train.shape)

(60000, 784)


In [0]:
def create_generator():
  generator=Sequential()
  generator.add(Dense(units=256,input_dim=100))
  generator.add(LeakyReLU(0.2))
  
  generator.add(Dense(units=512))
  generator.add(LeakyReLU(0.2))
  
  generator.add(Dense(units=1024))
  generator.add(LeakyReLU(0.2))
  
  generator.add(Dense(units=784, activation='tanh'))
  
  generator.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
  return generator

In [0]:
g = create_generator()
g.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_9 (Dense)              (None, 256)               25856     
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_10 (Dense)             (None, 512)               131584    
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_11 (Dense)             (None, 1024)              525312    
_________________________________________________________________
leaky_re_lu_9 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
dense_12 (Dense)             (None, 784)              

In [0]:
def create_discriminator():
  discriminator=Sequential()
  discriminator.add(Dense(units=1024,input_dim=784))
  discriminator.add(LeakyReLU(0.2))
  discriminator.add(Dropout(0.3))
      
  
  discriminator.add(Dense(units=512))
  discriminator.add(LeakyReLU(0.2))
  discriminator.add(Dropout(0.3))
      
  discriminator.add(Dense(units=256))
  discriminator.add(LeakyReLU(0.2))
  
  discriminator.add(Dense(units=1, activation='sigmoid'))
  
  discriminator.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
  return discriminator

In [0]:
d = create_discriminator()
d.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_13 (Dense)             (None, 1024)              803840    
_________________________________________________________________
leaky_re_lu_10 (LeakyReLU)   (None, 1024)              0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 1024)              0         
_________________________________________________________________
dense_14 (Dense)             (None, 512)               524800    
_________________________________________________________________
leaky_re_lu_11 (LeakyReLU)   (None, 512)               0         
_________________________________________________________________
dropout_4 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_15 (Dense)             (None, 256)              

In [0]:
def create_gan(discriminator, generator):
  discriminator.trainable=False
  gan_input = Input(shape=(100,))
  x = generator(gan_input)
  gan_output= discriminator(x)
  gan= Model(inputs=gan_input, outputs=gan_output)
  gan.compile(loss='binary_crossentropy', optimizer='adam')
  return gan

In [0]:
gan = create_gan(d, g)
gan.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 100)               0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 784)               1486352   
_________________________________________________________________
sequential_4 (Sequential)    (None, 1)                 1460225   
Total params: 2,946,577
Trainable params: 1,486,352
Non-trainable params: 1,460,225
_________________________________________________________________


# plot_generated_images
Imprimimos las imágenes generadas para ver cómo va evolucionando la generación.

In [0]:
def plot_generated_images(epoch, generator, examples=100, dim=(10,10), figsize=(10,10)):
  noise= np.random.normal(loc=0, scale=1, size=[examples, 100])
  generated_images = generator.predict(noise)
  generated_images = generated_images.reshape(100,28,28)
  plt.figure(figsize=figsize)
  for i in range(generated_images.shape[0]):
    plt.subplot(dim[0], dim[1], i+1)
    plt.imshow(generated_images[i], interpolation='nearest')
    plt.axis('off')
  plt.tight_layout()
  plt.savefig('gan_generated_image %d.png' %epoch)

# train_gan
Entrenamos la GAN. Para ello primero cargamos los datos de entrenamiento y test del dataset. A continuación creamos un generador, un discriminador y, con ellos, la GAN.

## En cada instante 
Generamos ruido aleatorio para inicializar el generador. Este crea dígitos falsos de MNIST. Con estos dígitos falsos y otros verdaderos tomados aleatoriamente del dataset, creamos un "batch" o conjunto de imágenes. Pasaremos estos batches al discriminador. También crearemos etiquetas para los datos reales y generados.
Luego pre-entrenaremos al discriminador con datos tanto falsos como reales antes de iniciar la GAN. Así nos aseguramos de que el modelo funcione correctamente con ambos tipos de datos.
A continuación tomamos las imágenes generadas y las hacemos pasar por datos auténticos.
Para entrenar la GAN debemos entrenar al discriminador, y luego congelar sus pesos y entrenar el resto del modelo.

Finalmente, cada 20 instantes mostramos las imágenes generadas.

In [0]:
def train_gan(epochs=1, batch_size=128):
    
  # Loading data
  (X_train, y_train, X_test, y_test) = load_data()
  batch_count = X_train.shape[0] / batch_size
  
  # Create GAN
  generator = create_generator()
  discriminator = create_discriminator()
  gan = create_gan(discriminator, generator)
  
  for e in range(1,epochs+1 ):
    print("Epoch %d" %e)
    for _ in tqdm(range(batch_size)): # We use tqdm to show a progress bar of the for loop
      # Generate random noise as an input to initialize the generator
      noise = np.random.normal(0,1, [batch_size, 100])
      
      # Generate fake MNIST images from noised input
      generated_images = generator.predict(noise)
      
      # Get a random set of real images
      image_batch = X_train[np.random.randint(low=0, high=X_train.shape[0], size=batch_size)]
      
      # Construct different batches of real and fake data 
      X = np.concatenate([image_batch, generated_images])
      
      # Labels for generated and real data
      y_dis = np.zeros(2 * batch_size)
      y_dis[:batch_size] = 0.9
      
      # Pre train discriminator on  fake and real data  before starting the gan. 
      discriminator.trainable = True
      discriminator.train_on_batch(X, y_dis)
      
      # Tricking the noised input of the Generator as real data
      noise = np.random.normal(0,1, [batch_size, 100])
      y_gen = np.ones(batch_size)
      
      # During the training of gan, 
      # the weights of discriminator should be fixed. 
      # We can enforce that by setting the trainable flag
      discriminator.trainable = False
      
      # training the GAN by alternating the training of the Discriminator 
      # and training the chained GAN model with Discriminator’s weights freezed.
      gan.train_on_batch(noise, y_gen)
          
      if e == 1 or e % 20 == 0:
        plot_generated_images(e, generator)

In [0]:
train_gan(400,128)

  0%|          | 0/128 [00:00<?, ?it/s]

Epoch 1










  """
100%|██████████| 128/128 [06:56<00:00,  2.66s/it]
  1%|          | 1/128 [00:00<00:21,  6.04it/s]

Epoch 2


100%|██████████| 128/128 [00:23<00:00,  5.60it/s]
  1%|          | 1/128 [00:00<00:22,  5.73it/s]

Epoch 3


100%|██████████| 128/128 [00:21<00:00,  6.08it/s]
  1%|          | 1/128 [00:00<00:22,  5.64it/s]

Epoch 4


100%|██████████| 128/128 [00:20<00:00,  6.16it/s]
  1%|          | 1/128 [00:00<00:19,  6.36it/s]

Epoch 5


100%|██████████| 128/128 [00:20<00:00,  6.04it/s]
  1%|          | 1/128 [00:00<00:21,  5.88it/s]

Epoch 6


100%|██████████| 128/128 [00:21<00:00,  5.82it/s]
  1%|          | 1/128 [00:00<00:20,  6.12it/s]

Epoch 7


100%|██████████| 128/128 [00:21<00:00,  6.01it/s]
  1%|          | 1/128 [00:00<00:20,  6.19it/s]

Epoch 8


100%|██████████| 128/128 [00:20<00:00,  5.88it/s]
  1%|          | 1/128 [00:00<00:22,  5.61it/s]

Epoch 9


100%|██████████| 128/128 [00:20<00:00,  6.10it/s]
  1%|          | 1/128 [00:00<00:20,  6.21it/s]

Epoch 10


100%|██████████| 128/128 [00:20<00:00,  6.17it/s]
  1%|          | 1/128 [00:00<00:19,  6.57it/s]

Epoch 11


100%|██████████| 128/128 [00:20<00:00,  6.02it/s]
  1%|          | 1/128 [00:00<00:20,  6.11it/s]

Epoch 12


100%|██████████| 128/128 [00:21<00:00,  5.99it/s]
  1%|          | 1/128 [00:00<00:20,  6.20it/s]

Epoch 13


100%|██████████| 128/128 [00:21<00:00,  6.07it/s]
  1%|          | 1/128 [00:00<00:21,  5.89it/s]

Epoch 14


100%|██████████| 128/128 [00:20<00:00,  6.19it/s]
  1%|          | 1/128 [00:00<00:20,  6.30it/s]

Epoch 15


100%|██████████| 128/128 [00:21<00:00,  6.08it/s]
  1%|          | 1/128 [00:00<00:21,  6.00it/s]

Epoch 16


100%|██████████| 128/128 [00:21<00:00,  6.16it/s]
  1%|          | 1/128 [00:00<00:21,  5.78it/s]

Epoch 17


100%|██████████| 128/128 [00:21<00:00,  6.24it/s]
  1%|          | 1/128 [00:00<00:20,  6.06it/s]

Epoch 18


100%|██████████| 128/128 [00:20<00:00,  6.03it/s]
  1%|          | 1/128 [00:00<00:20,  6.05it/s]

Epoch 19


100%|██████████| 128/128 [00:20<00:00,  6.18it/s]
  0%|          | 0/128 [00:00<?, ?it/s]

Epoch 20


100%|██████████| 128/128 [06:45<00:00,  2.56s/it]
  1%|          | 1/128 [00:00<00:22,  5.76it/s]

Epoch 21


100%|██████████| 128/128 [00:24<00:00,  6.00it/s]
  1%|          | 1/128 [00:00<00:21,  5.99it/s]

Epoch 22


100%|██████████| 128/128 [00:20<00:00,  6.16it/s]
  1%|          | 1/128 [00:00<00:20,  6.29it/s]

Epoch 23


100%|██████████| 128/128 [00:20<00:00,  5.92it/s]
  1%|          | 1/128 [00:00<00:21,  5.88it/s]

Epoch 24


100%|██████████| 128/128 [00:20<00:00,  6.43it/s]
  1%|          | 1/128 [00:00<00:21,  6.04it/s]

Epoch 25


100%|██████████| 128/128 [00:19<00:00,  6.53it/s]
  1%|          | 1/128 [00:00<00:19,  6.52it/s]

Epoch 26


100%|██████████| 128/128 [00:19<00:00,  6.68it/s]
  1%|          | 1/128 [00:00<00:19,  6.61it/s]

Epoch 27


100%|██████████| 128/128 [00:19<00:00,  6.34it/s]
  1%|          | 1/128 [00:00<00:19,  6.43it/s]

Epoch 28


100%|██████████| 128/128 [00:20<00:00,  6.49it/s]
  1%|          | 1/128 [00:00<00:19,  6.44it/s]

Epoch 29


100%|██████████| 128/128 [00:19<00:00,  6.48it/s]
  1%|          | 1/128 [00:00<00:19,  6.67it/s]

Epoch 30


100%|██████████| 128/128 [00:19<00:00,  6.61it/s]
  1%|          | 1/128 [00:00<00:18,  6.76it/s]

Epoch 31


100%|██████████| 128/128 [00:19<00:00,  6.69it/s]
  1%|          | 1/128 [00:00<00:18,  7.00it/s]

Epoch 32


100%|██████████| 128/128 [00:19<00:00,  6.68it/s]
  1%|          | 1/128 [00:00<00:19,  6.62it/s]

Epoch 33


100%|██████████| 128/128 [00:19<00:00,  6.38it/s]
  1%|          | 1/128 [00:00<00:18,  6.84it/s]

Epoch 34


100%|██████████| 128/128 [00:19<00:00,  6.77it/s]
  1%|          | 1/128 [00:00<00:18,  6.83it/s]

Epoch 35


100%|██████████| 128/128 [00:19<00:00,  6.26it/s]
  1%|          | 1/128 [00:00<00:20,  6.06it/s]

Epoch 36


100%|██████████| 128/128 [00:19<00:00,  6.62it/s]
  1%|          | 1/128 [00:00<00:20,  6.32it/s]

Epoch 37


100%|██████████| 128/128 [00:19<00:00,  6.66it/s]
  1%|          | 1/128 [00:00<00:18,  6.73it/s]

Epoch 38


100%|██████████| 128/128 [00:19<00:00,  6.36it/s]
  1%|          | 1/128 [00:00<00:19,  6.66it/s]

Epoch 39


100%|██████████| 128/128 [00:19<00:00,  6.51it/s]
  0%|          | 0/128 [00:00<?, ?it/s]

Epoch 40


 73%|███████▎  | 93/128 [04:37<01:31,  2.62s/it]