In [1]:
# List the GPUs visible to this runtime.
!nvidia-smi -L

In [2]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import  image_dataset_from_directory
from tensorflow.keras.layers import (Conv2D, Dense, Flatten, Input, 
                                     Activation, Reshape, Conv2DTranspose, 
                                     BatchNormalization, LeakyReLU, Dropout,
                                     Embedding, Concatenate, UpSampling2D)

from tensorflow.keras.models import Sequential, Model
import matplotlib.pyplot as plt
import datetime
from IPython import display

# --- Configuration -----------------------------------------------------------
z_dim = 100                    # length of the generator's latent noise vector
DEFAULT_PATH = "./Dataset"     # root folder holding one sub-directory per class
img_height = img_width = 128   # images are resized to this square size on load
batch_size = 256               # single definition of the training batch size

# Carregando banco de imagens

In [3]:
#https://www.kaggle.com/code/theyazilimci/alzheimer-prediction-92-acc/notebook
#https://www.kaggle.com/datasets/sachinkumar413/alzheimer-mri-dataset

!wget  "https://github.com/limonheiro/GAN_Classificador/raw/main/alzheimer.zip"

In [4]:
!unzip -q "alzheimer.zip"

In [5]:
from os import listdir
from os.path import join, isdir

# image_dataset_from_directory assigns integer labels in alphanumeric order of
# the class sub-directories, whereas os.listdir returns entries in arbitrary
# order. Sorting makes class_name[label] the name of the class with that label.
class_name = np.array(sorted(f for f in listdir(DEFAULT_PATH) if isdir(join(DEFAULT_PATH, f))))
n_classes = len(class_name)

In [6]:
# Load every image under DEFAULT_PATH as a single-channel tensor resized to
# (img_height, img_width), batched by batch_size. No validation split is taken;
# labels come from the loader's default per-sub-directory inference.
train_data = image_dataset_from_directory(
    DEFAULT_PATH,
    color_mode='grayscale',
    image_size=(img_height, img_width),
    batch_size=batch_size,
    interpolation="nearest",
    seed=123,
    validation_split=None,
    follow_links=False,
)


# val_data = image_dataset_from_directory(DEFAULT_PATH,
#                                         validation_split=0.2,
#                                         subset="validation", 
#                                         seed=123,
#                                         color_mode="rgb",
#                                         interpolation="nearest",
#                                         follow_links=False,
#                                         image_size=(img_height, img_width),
#                                         batch_size=batch_size)

In [7]:
# Sanity check: pull one batch and display its first image.
preview_images, preview_labels = next(iter(train_data.take(1)))
plt.imshow(preview_images[0][:, :, 0], cmap='gray')

# Normalização

In [8]:
def process(image, label):
    """Rescale pixel values to [-1, 1] to match the generator's tanh output.

    Assumes input pixels lie in [0, 255]. The label is passed through unchanged.
    """
    # float32 / Python float stays float32, so a single cast is enough.
    image = tf.cast(image, tf.float32) / 127.5 - 1.0
    return image, label

# NOTE: not idempotent. Re-running this cell normalizes the data a second time,
# so re-run the loading cell before executing it again.
train_data = train_data.map(process, num_parallel_calls=tf.data.AUTOTUNE)

# Variáveis de entrada

In [9]:
# Symbolic inputs: a class label (one value per sample), shared by the
# generator and the discriminator, and the z_dim-long latent noise vector.
cat_label = Input(shape=(1,))
latent_input = Input(shape=(z_dim,))

In [10]:
# Image geometry, derived from the loader settings so the model shapes cannot
# drift out of sync with the images the dataset actually produces.
img_rows = img_height
img_cols = img_width
channels = 1  # grayscale, matching color_mode='grayscale' in the loader
img_shape = (img_rows, img_cols, channels)

# NOTE(review): gf/df are not referenced by any model-building code in this
# notebook; kept in case other cells rely on them.
gf = 64
df = 64

In [11]:
# U-Net style encoder/decoder helpers.
# NOTE(review): conv2d/deconv2d are not called by build_generator below.
def conv2d(layer_in, n_filters, f_size=4, batchnorm=True, strides=2):
  """Encoder (downsampling) block: Conv2D -> optional BatchNorm -> LeakyReLU(0.2).

  Args:
    layer_in: input Keras tensor.
    n_filters: number of convolution filters.
    f_size: convolution kernel size.
    batchnorm: whether to add a BatchNormalization layer.
    strides: convolution stride; 2 halves the spatial size with 'same' padding.

  Returns:
    The block's output Keras tensor.
  """
  g = Conv2D(n_filters,
             kernel_size=f_size,
             strides=strides,
             padding='same')(layer_in)
  if batchnorm:
    # Leave the training flag to Keras. Hard-coding training=False here would
    # keep the layer in inference mode even while the model trains, so it would
    # never use batch statistics or update its moving averages.
    g = BatchNormalization(momentum=0.8)(g)
  g = LeakyReLU(alpha=0.2)(g)
  return g

# Decoder (upsampling) block with a U-Net style skip connection.
def deconv2d(layer_in, skip_in, n_filters, f_size=4, dropout=False):
  """UpSampling2D(x2) -> Conv2D(relu) -> [Dropout(0.5)] -> BatchNorm -> concat(skip_in).

  Args:
    layer_in: Keras tensor to upsample.
    skip_in: encoder tensor concatenated onto the result (channel axis).
    n_filters: number of convolution filters.
    f_size: convolution kernel size.
    dropout: whether to insert Dropout(0.5) after the convolution.

  Returns:
    The concatenated Keras tensor.
  """
  upsampled = UpSampling2D(size=2)(layer_in)
  x = Conv2D(n_filters, kernel_size=f_size, strides=1,
             padding='same', activation='relu')(upsampled)
  x = Dropout(0.5)(x) if dropout else x
  x = BatchNormalization(momentum=0.8)(x)
  return Concatenate()([x, skip_in])

# Gerador

In [12]:
def generator_label_emb(layer_input : Input, num_cat : int = n_classes, embedding_input : int = z_dim) :
  """Turn a class label into an 8x8x1 feature map for the generator.

  The label is embedded as an `embedding_input`-long vector, projected by a
  Dense layer to 64 units and reshaped to (8, 8, 1).
  """
  embedded = Embedding(num_cat, embedding_input)(layer_input)
  projected = Dense(8 * 8 * 1)(embedded)
  return Reshape((8, 8, 1))(projected)

In [13]:
def generator_latent_dim(layer_input : Input, latente_dim : int = z_dim):
  """Map the latent noise vector to an 8x8x256 feature map.

  Dense(8*8*256) -> BatchNorm -> LeakyReLU -> Reshape((8, 8, 256)).
  `latente_dim` is currently unused.
  """
  x = Dense(8 * 8 * 256)(layer_input)
  x = BatchNormalization()(x)
  x = LeakyReLU()(x)
  return Reshape((8, 8, 256))(x)

In [14]:
def Input_generator(cat_label, latent_input):
  """Merge the label and latent branches into the generator's 8x8 input map.

  Returns a symbolic Keras tensor (not a Model) of shape (None, 8, 8, 257):
  the 8x8x1 label map from `generator_label_emb` concatenated on the channel
  axis with the 8x8x256 map from `generator_latent_dim`.
  """
  label_map = generator_label_emb(cat_label)
  latent_map = generator_latent_dim(latent_input)
  return Concatenate()([label_map, latent_map])

In [15]:
def _gen_upsample_block(x, filters, strides):
  """One generator stage: 3x3 Conv2DTranspose('same') -> BatchNorm -> LeakyReLU."""
  x = Conv2DTranspose(filters, 3, strides, 'same')(x)
  x = BatchNormalization()(x)
  return LeakyReLU()(x)


# (filters, strides) for each generator stage. The four stride-2 stages double
# the spatial size 8 -> 16 -> 32 -> 64 -> 128, giving a 128x128 output.
_GEN_STAGES = ((64, 1), (64, 2), (64, 1), (32, 2), (32, 1), (16, 2), (16, 1), (16, 2))


def build_generator():
  """Build the conditional generator: Model([cat_label, latent_input]) -> image.

  The label and noise are merged into an 8x8x257 map by `Input_generator` and
  upsampled by a plain stack of transposed convolutions (no U-Net skip
  connections) to a 128x128x1 image in [-1, 1] via tanh.

  Uses the module-level `cat_label` and `latent_input` Inputs; `cat_label` is
  also used by the discriminator.
  """
  gen_input = Input_generator(cat_label, latent_input)

  # The first stage reads the merged input map directly, at stride 1.
  g = gen_input
  for filters, strides in _GEN_STAGES:
    g = _gen_upsample_block(g, filters, strides)
  g = Dropout(0.3)(g)

  g = Conv2DTranspose(1, 3, 1, 'same')(g)
  g = Activation('tanh')(g)

  return Model(inputs=[cat_label, latent_input], outputs=g)

In [16]:
# Build the generator and render its layer graph with output shapes.
generator = build_generator()
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

In [17]:
# Per-layer output shapes and parameter counts.
generator.summary()

# Discriminador

In [18]:
# Discriminator image shape; reuse img_shape so it always matches the data.
shape_input = img_shape

In [19]:
def discrimimator_label_emb(layer_input : Input, shape_input : tuple = shape_input, num_cat : int = n_classes, embedding_input : int = z_dim) :
  """Turn a class label into an image-shaped map for the discriminator.

  The label is embedded as an `embedding_input`-long vector, projected by a
  Dense layer to prod(shape_input) units and reshaped to `shape_input`, so it
  can be concatenated with the image on the channel axis.

  `num_cat` defaults to `n_classes`, the same vocabulary size as the
  generator's label embedding, so every dataset label is a valid index.
  NOTE: this changes the Embedding weight shape relative to a 10-class table,
  so checkpoints saved with that table will not restore into this model.
  """
  label_embedding = Embedding(num_cat, embedding_input)(layer_input)
  d1 = Dense(shape_input[0] * shape_input[1] * shape_input[2])(label_embedding)
  return Reshape(shape_input)(d1)

In [20]:
def discrimimator_input(shape_input : tuple() = shape_input, latente_dim : int = z_dim):
  """Create the discriminator's image Input with shape `shape_input`.

  `latente_dim` is currently unused.
  """
  return Input(shape_input)

In [21]:
def discriminator():
  """Build the conditional discriminator Model([image, cat_label]) -> logit.

  The label is projected to an image-shaped map, concatenated with the image,
  and passed through five stride-2 Conv2D/LeakyReLU/Dropout(0.3) stages
  (128 -> 64 -> 32 -> 16 -> 8 -> 4 spatially for a 128x128 input). The output
  is one raw, unbounded logit per sample, as required by the
  BinaryCrossentropy(from_logits=True) loss used in training.
  """
  label_map = discrimimator_label_emb(cat_label)
  image_input = discrimimator_input()
  d = Concatenate()([label_map, image_input])

  for filters in (64, 128, 64, 32, 8):
    d = Conv2D(filters, 3, 2, 'same')(d)
    d = LeakyReLU()(d)
    d = Dropout(0.3)(d)

  d = Flatten()(d)
  # No sigmoid before the final Dense: squashing the flattened features would
  # only shrink them, and the output itself must stay a raw logit for the
  # from_logits=True loss.
  d = Dense(1)(d)

  model = Model(inputs=[image_input, cat_label], outputs=d, name="discriminator")
  return model

In [22]:
# NOTE(review): this rebinds `discriminator` from the builder function to the
# built Model, so re-running this cell fails because the name no longer refers
# to the function. Re-run the cell defining discriminator() first.
discriminator = discriminator()
tf.keras.utils.plot_model(discriminator,to_file='discriminator.png', show_shapes=True, dpi=64)

In [23]:
# Per-layer output shapes and parameter counts.
discriminator.summary()

In [24]:
# One Adam optimizer per network, both with learning rate 1e-4.
generator_optimizer = tf.optimizers.Adam(0.0001)
discriminator_optimizer = tf.optimizers.Adam(0.0001)

## Checkpoint

In [25]:
# Base directory for checkpoints and TensorBoard logs (presumably "default path").
default_patch = "./"

In [26]:
# Checkpoints are written under <default_patch>/check/ with the prefix "ckpt".
# os.path.join avoids the doubled separator that "./" + "/check/" produces.
checkpoint_dir = os.path.join(default_patch, 'check', '')
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Track both networks and both optimizers so training can resume exactly.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

In [27]:
# TensorBoard root; each launch writes to its own timestamped run directory,
# e.g. ./logs/fit/20240101-120000/.
log_dir = os.path.join(default_patch, "logs", "")

summary_writer = tf.summary.create_file_writer(
  os.path.join(log_dir, "fit", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"), ''))

In [28]:
# Fixed evaluation inputs for sample_images: one label per class and one noise
# vector each, drawn once so snapshots are comparable across training steps.
rows, cols = n_classes, 1
noises = np.random.normal(0, 1, (rows * cols, z_dim))
num_labels = np.arange(rows)[:, np.newaxis]

In [29]:
def sample_images(model, rows=2, cols=2):
  """Plot the generator's output for the fixed evaluation inputs.

  Feeds the module-level `num_labels` and `noises` to `model`. The images are
  drawn row-major on a rows x cols grid, each titled with its class name and
  mapped from [-1, 1] back to [0, 1] for display.

  Args:
    model: generator, called as model([labels, noise], training=False).
    rows, cols: grid size. Panels beyond the number of generated images are
      left blank; images beyond rows * cols are not shown.
  """
  prediction = model([num_labels, noises], training=False)
  n_images = prediction.shape[0]

  # squeeze=False keeps `axs` 2-D even when rows or cols is 1.
  _, axs = plt.subplots(rows, cols, figsize=(10, 5), squeeze=False)
  for pos_img, ax in enumerate(axs.flat):
    ax.axis('off')
    if pos_img >= n_images:
      continue
    ax.set_title(f'{class_name[pos_img]}')
    ax.imshow(prediction[pos_img][:, :, 0] * 0.5 + 0.5, cmap='gray')
  plt.show()

In [30]:
# Samples from the untrained generator, as a baseline for later snapshots.
sample_images(model=generator)

In [31]:
# Binary cross-entropy on raw logits: the discriminator ends in a Dense(1)
# layer with no activation, so its outputs are unbounded scores.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [32]:
def discrimantor_loss(disc_fake, disc_real, step, add_summary=True):
  """Standard GAN discriminator loss on raw logits.

  Real logits are pushed toward label 1 and generated logits toward label 0.

  Args:
    disc_fake: discriminator logits for generated images.
    disc_real: discriminator logits for real images.
    step: global training step; summaries are written at `step // 1000`, so
      every 1000 training steps share one TensorBoard step.
    add_summary: whether to log the three losses to TensorBoard.

  Returns:
    (total loss, real-image loss, generated-image loss).
  """
  loss_disc_real = loss_object(tf.ones_like(disc_real), disc_real)
  loss_disc_fake = loss_object(tf.zeros_like(disc_fake), disc_fake)
  loss_disc = loss_disc_real + loss_disc_fake

  if add_summary:
    summary_step = step // 1000
    with summary_writer.as_default():
      for tag, value in (('loss_disc_real', loss_disc_real),
                         ('loss_disc_fake', loss_disc_fake),
                         ('loss_disc', loss_disc)):
        tf.summary.scalar(tag, value, step=summary_step)

  return loss_disc, loss_disc_real, loss_disc_fake

In [33]:
def generator_loss(gen_out, step, add_summary=True):
  """Non-saturating generator loss on raw logits.

  Args:
    gen_out: discriminator logits for *generated* images (train_step passes
      `disc_fake`); the generator is rewarded when these are classified as
      real (label 1).
    step: global training step; the summary is written at `step // 1000`.
    add_summary: whether to log the loss to TensorBoard (defaults to True,
      mirroring `discrimantor_loss`).

  Returns:
    Scalar loss tensor.
  """
  loss_gen_out = loss_object(tf.ones_like(gen_out), gen_out)

  if add_summary:
    with summary_writer.as_default():
      tf.summary.scalar('loss_gen_out', loss_gen_out, step = step // 1000)

  return loss_gen_out

## Train

In [34]:
%reload_ext tensorboard

In [35]:
%load_ext tensorboard
%tensorboard --logdir {log_dir}

In [41]:
@tf.function
def train_step(input_image, input_label, generator, discriminator, step):
  """Run one conditional-GAN step: one generator update and one discriminator update.

  Args:
    input_image: batch of real images (normalized to [-1, 1] by `process`).
    input_label: class labels for those images.
    generator, discriminator: the Keras models being trained.
    step: global step tensor, forwarded to the loss functions for logging.
  """
  # Sample noise with TensorFlow ops. A NumPy call inside tf.function runs only
  # while tracing, which would freeze one noise tensor into the graph and reuse
  # it on every step. tf.shape also works when the batch dimension is None
  # (e.g. after a retrace for the smaller final batch of an epoch).
  batch = tf.shape(input_label)[0]
  z_noise = tf.random.normal([batch, z_dim])

  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_out = generator([input_label, z_noise], training=True)

    disc_fake = discriminator([gen_out, input_label], training=True)
    disc_real = discriminator([input_image, input_label], training=True)

    loss_disc, loss_disc_real, loss_disc_fake = discrimantor_loss(disc_fake, disc_real, step)
    loss_gen_out = generator_loss(disc_fake, step)

  generator_gradients = gen_tape.gradient(loss_gen_out, generator.trainable_variables)
  generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))

  # A single Adam update on the summed real + fake loss. Applying the two parts
  # as separate updates would advance the optimizer twice per iteration, with
  # both gradients taken at the same pre-update weights.
  discriminator_gradients = disc_tape.gradient(loss_disc, discriminator.trainable_variables)
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
  

In [42]:
def fit(interations):
  """Train for `interations` steps, cycling through the dataset as needed.

  Every 1000 steps it saves a checkpoint, clears the cell output, prints the
  time for the last 1k steps, plots generator samples and prints the total
  elapsed time.
  """
  start = datetime.datetime.now()
  start_1k = datetime.datetime.now()
  batches = train_data.repeat().take(interations).enumerate()
  for step, (images, labels) in batches:
    train_step(images, labels, generator, discriminator, step)

    if (step + 1) % 1000 != 0:
      continue

    checkpoint.save(file_prefix=checkpoint_prefix)
    display.clear_output(wait=True)
    print(f'Time per 1k steps:{datetime.datetime.now()-start_1k}')
    start_1k = datetime.datetime.now()
    print(f'step: {(step+1)//1000}k')
    sample_images(generator)
    print(f'Total time:{datetime.datetime.now()-start}')

In [43]:
# Train for 50,000 steps; fit() checkpoints and plots samples every 1,000 steps.
fit(50000)

In [44]:
def test_train(model, rows=1, cols=5):
  """Show `cols` freshly sampled generator outputs per class, one figure per class.

  Args:
    model: generator, called as model([labels, noise], training=False).
    rows: unused; kept for backward compatibility.
    cols: number of samples drawn for each class.
  """
  for n in range(len(class_name)):
    labels = np.ones((cols, 1)) * n
    z_dim_lat = np.random.normal(0, 1, (cols, z_dim))

    prediction = model([labels, z_dim_lat], training=False)

    _, axs = plt.subplots(1, cols, figsize=(10, 5), squeeze=False)
    for ax, img in zip(axs[0], prediction):
      ax.axis('off')
      # Use the integer class index directly instead of converting a 1-element
      # array to int, which NumPy deprecates.
      ax.set_title(f'{class_name[n]}')
      ax.imshow(img[:, :, 0] * 0.5 + 0.5, cmap='gray')
  plt.show()

In [45]:
# Qualitative check after training: generated samples for each class.
test_train(generator)

In [46]:
# NOTE(review): scratch cell inspecting a single class name; candidate for removal.
class_name[1]

In [47]:
# NOTE(review): leftover debug output; candidate for removal.
print(1)

In [48]:
# List the saved checkpoint files.
!ls ./check

In [49]:
# Archive ./check recursively into a "checkpoint" zip archive for download.
!zip -r checkpoint ./check

In [51]:
# Show the working directory that the relative paths above resolve against.
!pwd