# Treinar a rede GAN para gerar novos dígitos

Código adaptado de https://github.com/Zackory/Keras-MNIST-GAN/blob/master/mnist_gan.py

In [None]:
import os

# Select the Keras backend BEFORE keras is imported: the backend is read at
# import time, so assigning this variable after the keras imports (as the
# cell originally did) has no effect.
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Reshape, Dense, Dropout, Flatten
from keras.layers import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers import BatchNormalization
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import backend as K
from keras import initializers

In [None]:
# Seed NumPy's global RNG, which drives the noise and batch sampling below.
np.random.seed(1000)

# Load MNIST. Only the training images are used further down; the labels and
# the test split are loaded but not used.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# Rescale uint8 pixels from [0, 255] to [-1, 1], the range of the generator's
# tanh output layer.
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
# Flatten each 28x28 image into a 784-vector. -1 lets NumPy infer the number
# of samples instead of hard-coding 60000.
X_train = X_train.reshape(-1, 784)
# (The optimizer `adam` is created in the next cell together with randomDim;
# the duplicate definition that used to live here was dead code.)

In [None]:
# Length of the random (latent) noise vector fed to the generator.
randomDim = 100
# Adam optimizer with learning rate 2e-4 and beta_1 = 0.5.
adam = Adam(beta_1=0.5, learning_rate=2e-4)

In [None]:
# Generator: maps a noise vector of length randomDim to 784 values (a
# flattened 28x28 image) squashed into [-1, 1] by tanh. Three hidden Dense
# layers of width 256 -> 512 -> 1024, each followed by LeakyReLU(0.2).
generator = Sequential([
    Dense(256, input_dim=randomDim,
          kernel_initializer=initializers.RandomNormal(stddev=0.02)),
    LeakyReLU(0.2),
    Dense(512),
    LeakyReLU(0.2),
    Dense(1024),
    LeakyReLU(0.2),
    Dense(784, activation='tanh'),
])
# In the visible code the generator is updated through `gan.train_on_batch`,
# never through its own train_on_batch; this compile is kept as-is.
generator.compile(loss='binary_crossentropy', optimizer=adam)

In [None]:
# Discriminator: maps a flattened 784-pixel image to a single sigmoid
# probability that the image is real. Hidden Dense layers of width
# 1024 -> 512 -> 256, each followed by LeakyReLU(0.2) and Dropout(0.3).
discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
# Use a dedicated optimizer instance rather than the shared `adam`: with the
# newer Keras optimizers (TF >= 2.11 / Keras 3) an optimizer is built for one
# set of variables, so reusing it for the combined `gan` model (which updates
# the generator's variables) raises "optimizer cannot recognize variable".
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
)

In [None]:
# Freeze the discriminator BEFORE compiling the combined model, so that
# gan.train_on_batch only updates the generator. Keras applies `trainable` at
# compile time; the discriminator was already compiled while trainable, so
# its own train_on_batch still updates its weights.
discriminator.trainable = False
ganInput = Input(shape=(randomDim,))
x = generator(ganInput)        # noise -> fake image
ganOutput = discriminator(x)   # fake image -> "real" probability
gan = Model(inputs=ganInput, outputs=ganOutput)
# Dedicated optimizer instance (same hyper-parameters as `adam`): sharing one
# Adam instance between the discriminator and this model makes newer Keras
# optimizers (TF >= 2.11 / Keras 3) raise when asked to update variables
# they were not built for.
gan.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
)

# Loss history, one entry per epoch, appended by train() and read by plotLoss().
dLosses = []
gLosses = []

In [None]:
def plotLoss(path="/content/drive/MyDrive/Digitos_GAN/Funcao_Perda.png"):
    """Plot the discriminator and generator loss histories and save the figure.

    Reads the module-level ``dLosses`` and ``gLosses`` lists (one entry per
    epoch, appended by ``train``). The figure is intentionally left open so a
    notebook front-end can still display it after the call.

    Args:
        path: Output image file. Its parent directory is created if missing.
            Defaults to the original Google Drive location.
    """
    plt.figure(figsize=(10, 8))
    plt.plot(dLosses, label='Discriminative loss')
    plt.plot(gLosses, label='Generative loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    # savefig does not create directories; make sure the target folder exists.
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    plt.savefig(path)

In [None]:
def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10),
                        directory='/content/drive/MyDrive/Digitos_GAN'):
    """Generate a grid of sample digits and save it as ``digito_gerado_<epoch>.png``.

    Args:
        epoch: Epoch number, used only in the output file name.
        examples: Number of digits to generate; must fit in the grid.
        dim: ``(rows, cols)`` of the subplot grid.
        figsize: Figure size in inches.
        directory: Output folder, created if missing. Defaults to the
            original Google Drive location.

    Raises:
        ValueError: If ``examples`` exceeds ``dim[0] * dim[1]``.
    """
    if examples > dim[0] * dim[1]:
        raise ValueError('examples=%d does not fit in a %dx%d grid'
                         % (examples, dim[0], dim[1]))

    noise = np.random.normal(0, 1, size=[examples, randomDim])
    # verbose=0: this runs once per epoch inside the training loop.
    generatedImages = generator.predict(noise, verbose=0)
    generatedImages = generatedImages.reshape(examples, 28, 28)

    fig = plt.figure(figsize=figsize)
    for i, image in enumerate(generatedImages):
        plt.subplot(dim[0], dim[1], i + 1)
        plt.imshow(image, interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    os.makedirs(directory, exist_ok=True)
    plt.savefig(os.path.join(directory, 'digito_gerado_%d.png' % epoch))
    # Close the figure: this is called every epoch, and open figures would
    # otherwise accumulate in memory (and in the notebook output).
    plt.close(fig)

In [None]:
def saveModels(directory='/content/drive/MyDrive/Digitos_GAN'):
    """Save the generator and discriminator to HDF5 files.

    Writes ``gerador.h5`` and ``discriminador.h5`` into *directory*.

    Args:
        directory: Output folder, created if missing. Defaults to the
            original Google Drive location.
    """
    os.makedirs(directory, exist_ok=True)
    generator.save(os.path.join(directory, 'gerador.h5'))
    discriminator.save(os.path.join(directory, 'discriminador.h5'))

In [None]:
def train(epochs=1, batchSize=128):
    """Train the GAN by alternating discriminator and generator updates.

    Each epoch runs ``X_train.shape[0] // batchSize`` steps. In every step:

    1. The discriminator is trained on ``batchSize`` real images sampled with
       replacement from ``X_train`` (label 0.9, i.e. one-sided label smoothing)
       plus ``batchSize`` images produced by the generator (label 0).
    2. The generator is trained through ``gan`` on fresh noise with target
       label 1, so it learns to make the discriminator call its output real.

    After each epoch the MEAN discriminator and generator loss over that epoch
    is appended to the module-level ``dLosses`` / ``gLosses`` lists, and a grid
    of samples is saved with ``plotGeneratedImages``.

    Args:
        epochs: Number of epochs to run.
        batchSize: Number of real (and of generated) images per step.

    Raises:
        ValueError: If ``batchSize`` is not in ``1..X_train.shape[0]``; with a
            larger batch no full step could run and no loss would exist.
    """
    nSamples = X_train.shape[0]
    if not 0 < batchSize <= nSamples:
        raise ValueError('batchSize must be between 1 and %d, got %r'
                         % (nSamples, batchSize))
    batchCount = nSamples // batchSize
    print('Épocas:', epochs)
    print('Batch size:', batchSize)
    print('Batch por épocas:', batchCount)

    # Targets are identical for every step, so build them once.
    yDis = np.zeros(2 * batchSize)
    yDis[:batchSize] = 0.9   # real images: smoothed label; fakes stay at 0
    yGen = np.ones(batchSize)

    for e in range(1, epochs + 1):
        print('-' * 15, 'Época %d' % e, '-' * 15)
        epochDLosses = []
        epochGLosses = []
        for _ in tqdm(range(batchCount)):
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, nSamples, size=batchSize)]

            generatedImages = generator.predict(noise, verbose=0)
            X = np.concatenate([imageBatch, generatedImages])

            # Kept for parity with the reference implementation. Per the Keras
            # docs, `trainable` is applied when a model is compiled, so these
            # toggles do not change what the already-compiled models update.
            discriminator.trainable = True
            epochDLosses.append(discriminator.train_on_batch(X, yDis))

            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            discriminator.trainable = False
            epochGLosses.append(gan.train_on_batch(noise, yGen))

        # Record the epoch average rather than only the last (noisy) batch.
        dLosses.append(float(np.mean(epochDLosses)))
        gLosses.append(float(np.mean(epochGLosses)))

        plotGeneratedImages(e)

In [None]:
# Train for 100 epochs with 128 real + 128 generated images per step.
train(epochs=100, batchSize=128)

In [None]:
# Plot the loss history recorded by train() and save it as an image.
plotLoss()

In [None]:
# Persist the trained generator and discriminator models.
saveModels()

In [None]:
def gerar_digito(n_ex=1, dim=(1, 10), figsize=(12, 2)):
    """Generate ``n_ex`` new digits from random noise and display them.

    Args:
        n_ex: Number of digits to generate (>= 1).
        dim: Layout hint ``(rows, cols)``. At most ``dim[1]`` digits are
            placed per row; the number of rows is derived from ``n_ex``.
        figsize: Figure size in inches.

    Raises:
        ValueError: If ``n_ex`` is smaller than 1.
    """
    if n_ex < 1:
        raise ValueError('n_ex must be >= 1, got %r' % (n_ex,))

    noise = np.random.normal(0, 1, size=(n_ex, randomDim))
    # One 28x28 image per sample (the previous reshape(28, 28) only worked
    # for n_ex == 1).
    imagens = generator.predict(noise, verbose=0).reshape(n_ex, 28, 28)

    cols = min(n_ex, dim[1])
    rows = -(-n_ex // cols)  # ceiling division
    # Create the figure FIRST so figsize applies to the plotted digits; the
    # original created it after imshow, producing an extra empty figure.
    plt.figure(figsize=figsize)
    for i, imagem in enumerate(imagens):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(imagem, interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Generate and display a single new digit.
gerar_digito(n_ex=1)