# Treinamento de uma GAN com MNIST
Este notebook treina uma GAN simples para gerar dígitos manuscritos do MNIST, organizada didaticamente em blocos.

## Importação de Bibliotecas

In [None]:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten, LeakyReLU, BatchNormalization, Conv2DTranspose, Conv2D, Dropout
from tensorflow.keras.models import Sequential
import numpy as np
import matplotlib.pyplot as plt


## Carregamento e Pré-processamento do MNIST

In [None]:

# Load only the MNIST training images; labels and the test split are unused.
(X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
# Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
# Cast to float32 first: uint8 arithmetic with a Python float would promote
# the array to float64, doubling memory and differing from the models'
# float32 weights.
X_train = (X_train.astype("float32") - 127.5) / 127.5
# Add the trailing channel axis expected by Conv2D: (N, 28, 28) -> (N, 28, 28, 1).
X_train = X_train.reshape(-1, 28, 28, 1)


## Definição do Gerador

In [None]:

latent_dim = 100  # length of the noise vector fed to the generator


def build_generator():
    """Build the DCGAN-style generator.

    Maps a noise vector of length ``latent_dim`` to a 28x28x1 image with
    pixels in [-1, 1] (tanh output). A dense projection is reshaped to
    7x7x256, two stride-2 transposed convolutions grow it 7 -> 14 -> 28, and
    a final stride-1 transposed convolution collapses it to one channel.
    """
    def upsample_stage(filters, stride):
        # Transposed conv without bias (the following BatchNormalization
        # re-centers activations), then BatchNorm and LeakyReLU.
        return [
            Conv2DTranspose(filters, (5, 5), strides=(stride, stride), padding='same', use_bias=False),
            BatchNormalization(),
            LeakyReLU(),
        ]

    stack = [
        Dense(7 * 7 * 256, use_bias=False, input_shape=(latent_dim,)),
        BatchNormalization(),
        LeakyReLU(),
        Reshape((7, 7, 256)),
        *upsample_stage(128, 2),
        *upsample_stage(64, 2),
        Conv2DTranspose(1, (5, 5), strides=(1, 1), padding='same', use_bias=False, activation='tanh'),
    ]
    return Sequential(stack)


## Definição do Discriminador

In [None]:

def build_discriminator():
    """Build the convolutional discriminator.

    Takes a 28x28x1 image and returns a single sigmoid probability that the
    image is real. Two stride-2 convolutions (64 then 128 filters), each
    followed by LeakyReLU and 30% dropout, precede a flatten and one dense unit.
    """
    net = Sequential()
    for idx, width in enumerate((64, 128)):
        # Only the first convolution declares the input shape.
        extra = {'input_shape': [28, 28, 1]} if idx == 0 else {}
        net.add(Conv2D(width, (5, 5), strides=(2, 2), padding='same', **extra))
        net.add(LeakyReLU())
        net.add(Dropout(0.3))
    net.add(Flatten())
    net.add(Dense(1, activation='sigmoid'))
    return net

# Instantiate both networks once; the training step functions use these globals.
generator = build_generator()
discriminator = build_discriminator()

# The discriminator ends in a sigmoid, so the loss receives probabilities,
# not logits (from_logits=False is the Keras default, stated here explicitly).
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

# Identical Adam settings for both players: lr 1e-4 and beta_1 0.5
# (lower first-moment decay than Adam's default of 0.9).
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)


## Funções de Treinamento para o Discriminador e o Gerador

In [None]:

@tf.function
def train_discriminator(images):
    """Run one optimization step for the discriminator.

    Args:
        images: a batch of real images, preprocessed like ``X_train``.

    Returns:
        The scalar discriminator loss (real term + fake term) for this batch.
    """
    # Draw one noise vector per real image. Using the global ``batch_size``
    # would mismatch a final partial batch, and its value would be frozen
    # into the traced graph at the first call.
    n_images = tf.shape(images)[0]
    noise = tf.random.normal([n_images, latent_dim])
    # Fakes are produced outside the tape because only the discriminator is
    # updated here. training=True still updates the generator's BatchNorm
    # moving statistics.
    generated_images = generator(noise, training=True)

    with tf.GradientTape() as disc_tape:
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        # One-sided label smoothing: real targets are 0.9 instead of 1.0.
        disc_loss_real = cross_entropy(tf.ones_like(real_output) * 0.9, real_output)
        disc_loss_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
        disc_loss = disc_loss_real + disc_loss_fake

    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    return disc_loss

@tf.function
def train_generator():
    """Run one optimization step for the generator.

    Samples ``batch_size`` noise vectors (a global, read when the function is
    traced), scores the generated images with the discriminator, and updates
    only the generator. The loss is binary cross-entropy against all-ones
    targets, i.e. the generator is rewarded when D labels its images real
    (the non-saturating GAN loss).

    Returns:
        The scalar generator loss for this step.
    """
    z = tf.random.normal([batch_size, latent_dim])
    with tf.GradientTape() as tape:
        fakes = generator(z, training=True)
        # Discriminator in training mode (dropout active); its weights are
        # not updated here.
        verdict = discriminator(fakes, training=True)
        loss = cross_entropy(tf.ones_like(verdict), verdict)
    weights = generator.trainable_variables
    grads = tape.gradient(loss, weights)
    generator_optimizer.apply_gradients(zip(grads, weights))
    return loss


## Funções de Visualização de Imagens Geradas

In [None]:

def generate_and_show_all_images(generator, epoch):
    """Sample 25 images from ``generator`` and plot them on a 5x5 grid.

    Args:
        generator: model mapping ``(n, latent_dim)`` noise to images; assumed
            to output tanh-range values in [-1, 1], like ``build_generator``.
        epoch: zero-based epoch index, shown one-based in the title.
    """
    grid = 5
    z = tf.random.normal([grid * grid, latent_dim])
    # Inference mode: BatchNorm uses its moving statistics.
    samples = generator(z, training=False)
    # Map tanh output from [-1, 1] to [0, 1] for display.
    samples = (samples + 1) / 2.0

    plt.figure(figsize=(grid, grid))
    for slot, picture in enumerate(samples[:, :, :, 0], start=1):
        plt.subplot(grid, grid, slot)
        plt.imshow(picture, cmap='gray')
        plt.axis('off')
    plt.suptitle(f'Epoch {epoch+1}: Exemplos de imagens geradas')
    plt.tight_layout()
    plt.show()

def generate_and_show_filtered_images(generator, discriminator, epoch, threshold=0.5):
    """Plot generated samples that the discriminator scores above ``threshold``.

    Draws 200 noise vectors, generates images, and keeps those whose
    discriminator output exceeds ``threshold``, i.e. images that "fooled" it.

    Args:
        generator: model mapping ``(n, latent_dim)`` noise to images; assumed
            to output tanh-range values in [-1, 1].
        discriminator: model returning one probability per image.
        epoch: zero-based epoch index, shown one-based in the title.
        threshold: minimum discriminator score for an image to be kept.

    Returns:
        The selected images as a NumPy array in the generator's native value
        range (shape ``(k, H, W, C)`` as produced by ``generator``), or an
        empty list when no image passed the threshold. The [0, 1] rescaling
        is applied only to a display copy, so callers can apply their own
        [-1, 1] -> [0, 1] mapping without double-scaling.
    """
    noise = tf.random.normal([200, latent_dim])
    generated_images = generator(noise, training=False)
    predictions = discriminator(generated_images, training=False).numpy().flatten()
    selected_indices = np.where(predictions > threshold)[0]
    selected_images = generated_images.numpy()[selected_indices]

    if len(selected_images) == 0:
        print(f'Epoch {epoch+1}: Nenhuma imagem enganou o Discriminador nesta época.')
        return []

    # Display-only copy mapped from [-1, 1] to [0, 1].
    display_images = (selected_images + 1) / 2.0
    total = len(display_images)
    cols = 5
    # Ceiling division: no trailing empty row when total is a multiple of cols.
    rows = -(-total // cols)
    plt.figure(figsize=(cols*2, rows*2))
    for i in range(total):
        plt.subplot(rows, cols, i+1)
        plt.imshow(display_images[i, :, :, 0], cmap='gray')
        plt.axis('off')
    plt.suptitle(f'Epoch {epoch+1}: Imagens que enganaram o Discriminador (threshold={threshold})')
    plt.tight_layout()
    plt.show()
    return selected_images

def show_all_filtered_images(images, max_images=100):
    """Plot a grid of the images collected across all epochs.

    Args:
        images: sequence of images with shape ``(H, W, 1)`` each, assumed to
            be in the generator's [-1, 1] range (they are mapped to [0, 1]).
        max_images: plot at most this many, keeping the most recent ones.
            Each row of 5 images adds 2 inches of figure height, so an
            unbounded grid accumulated over many epochs can grow beyond what
            the renderer accepts. Pass ``None`` to plot everything.

    Raises:
        ValueError: if ``max_images`` is not ``None`` and is smaller than 1.
    """
    if len(images) == 0:
        print("Nenhuma imagem enganou o Discriminador em nenhuma época.")
        return
    images = np.asarray(images)
    if max_images is not None:
        if max_images < 1:
            raise ValueError("max_images deve ser um inteiro positivo ou None")
        if len(images) > max_images:
            print(f"Mostrando as {max_images} imagens mais recentes de {len(images)}.")
            images = images[len(images) - max_images:]
    images = (images + 1) / 2.0
    total = len(images)
    cols = 5
    # Ceiling division: no trailing empty row when total is a multiple of cols.
    rows = -(-total // cols)
    plt.figure(figsize=(cols*2, rows*2))
    for i in range(total):
        plt.subplot(rows, cols, i+1)
        plt.imshow(images[i, :, :, 0], cmap='gray')
        plt.axis('off')
    plt.suptitle('Todas as imagens que enganaram o Discriminador ao longo das épocas')
    plt.figtext(0.5, 0.01, 'Estas são todas as imagens que em algum momento durante as épocas de treinamento conseguiram enganar o Discriminador (D(x) > threshold).', ha='center', fontsize=9)
    plt.tight_layout()
    plt.show()


## Loop Principal de Treinamento

In [None]:

epochs = 50
batch_size = 128
# Shuffle over the whole training set. With a smaller buffer (e.g. 10000) and
# only ``max_batches_per_epoch`` batches consumed per epoch, the shuffle can
# only ever draw from roughly the first buffer_size + batches * batch_size
# images of the fixed-order source, so most of MNIST would never be seen.
buffer_size = len(X_train)
max_batches_per_epoch = 100

dataset = tf.data.Dataset.from_tensor_slices(X_train).shuffle(buffer_size).batch(batch_size)
filtered_images_all_epochs = []

for epoch in range(epochs):
    # take() caps the epoch length; shuffle() reshuffles on each new pass by default.
    for image_batch in dataset.take(max_batches_per_epoch):
        d_loss = train_discriminator(image_batch)
        # Two generator updates per discriminator update (presumably to keep
        # D from overpowering G).
        g_loss1 = train_generator()
        g_loss2 = train_generator()
    # Reports the losses of the epoch's last batch.
    print(f'Epoch {epoch+1}: D loss: {d_loss.numpy():.4f}, G loss 1: {g_loss1.numpy():.4f}, G loss 2: {g_loss2.numpy():.4f}')
    generate_and_show_all_images(generator, epoch)
    filtered_images = generate_and_show_filtered_images(generator, discriminator, epoch, threshold=0.5)
    # Extending with an empty selection is a no-op.
    filtered_images_all_epochs.extend(filtered_images)

show_all_filtered_images(filtered_images_all_epochs)
