# Library imports

In [None]:
import tensorflow as tf
import keras.api as keras
from keras.api import layers
from keras.api.layers import Conv2D, LeakyReLU, Flatten, Dropout, Dense, Softmax, BatchNormalization, UpSampling2D
import matplotlib.pyplot as plt
import numpy as np

# Parameters

In [None]:
# Input images are 128x128 with a single (grayscale) channel.
image_size = (128, 128)
channel = 1
# NOTE(review): BUFFER_SIZE is not referenced by the code visible here.
BUFFER_SIZE = 60_000
BATCH_SIZE = 32
latent_dim_size = 128  # length of the generator's input noise vector
dataset_path = "../../data/train/"
# Per-epoch mean losses, appended to by the training loop.
complete_hist = {key: [] for key in ('loss_dis', 'loss_gen')}

# Custom layers

In [None]:
class miniBatch(keras.layers.Layer):
    """Minibatch-discrimination layer.

    Projects each sample to ``num_kernels`` vectors of length ``kernel_dim``,
    measures how similar each sample is to every sample in the same batch,
    and appends those ``num_kernels`` similarity scores to the input features.

    Input:  2-D tensor of shape (B, F).
    Output: 2-D tensor of shape (B, F + num_kernels).

    Args:
        num_kernels: number of similarity features appended per sample (K).
        kernel_dim: length of each projected vector (D).
        **kwargs: standard Keras layer arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, num_kernels, kernel_dim, **kwargs):
        # Forward **kwargs so Keras can rebuild the layer from its config
        # (it passes name/trainable/dtype back in when loading a saved model).
        super().__init__(**kwargs)
        self.num_kernels = num_kernels
        self.kernel_dim = kernel_dim

    def build(self, input_shape):
        # Projection tensor T of shape (F, K * D).
        self.T = self.add_weight(
            shape=(input_shape[-1], self.num_kernels * self.kernel_dim),
            initializer='random_normal',
            trainable=True,
        )

    def call(self, x):
        # (B, F) @ (F, K*D) -> (B, K*D) -> (B, K, D)
        M = tf.matmul(x, self.T)
        M = tf.reshape(M, (-1, self.num_kernels, self.kernel_dim))
        # Broadcast over all sample pairs: (B, 1, K, D) vs (1, B, K, D) -> (B, B, K, D)
        M_T = tf.expand_dims(M, 1)
        M = tf.expand_dims(M, 0)
        diff = tf.abs(M - M_T)
        # Mean absolute difference over D, mapped to a similarity in (0, 1]: (B, B, K)
        exp_diff = tf.exp(-tf.reduce_mean(diff, -1))
        # Sum over the batch axis -> (B, K). A sample's similarity with itself
        # (exp(0) == 1) is included in the sum.
        miniBatch_features = tf.reduce_sum(exp_diff, 1)
        return tf.concat([x, miniBatch_features], -1)

    def compute_output_shape(self, input_shape):
        # Explicit output shape: the input features plus K similarity scores.
        return (input_shape[0], input_shape[1] + self.num_kernels)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({'num_kernels': self.num_kernels, 'kernel_dim': self.kernel_dim})
        return config
    
class SelfAttention(keras.layers.Layer):
    """Self-attention over the spatial positions of a feature map.

    For an input of shape (B, H, W, C), each of the H*W positions attends to
    all H*W positions. The attended values are added back to the input as a
    residual connection, so ``filters`` must equal C.

    NOTE(review): neither create_generator nor create_discriminator in this
    file uses this layer. The attention output is added to ``x`` without
    a learnable scale; confirm that is intended.

    Args:
        filters: channel count of the input (and of the value projection).
        **kwargs: standard Keras layer arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, filters, **kwargs):
        # Forward **kwargs so Keras can rebuild the layer from its config.
        super().__init__(**kwargs)
        self.filters = filters
        # 1x1 convolutions: queries/keys with filters // 8 channels, values with filters.
        self.query_conv = Conv2D(filters // 8, kernel_size=1)
        self.key_conv = Conv2D(filters // 8, kernel_size=1)
        self.value_conv = Conv2D(filters, kernel_size=1)
        self.softmax = Softmax(axis=-1)

    def call(self, x):
        channels = x.shape[-1]
        if channels is not None and channels != self.filters:
            raise ValueError(
                f"SelfAttention(filters={self.filters}) needs an input with "
                f"{self.filters} channels for the residual sum, got {channels}."
            )
        shape = tf.shape(x)
        batch, height, width = shape[0], shape[1], shape[2]
        positions = height * width

        Q = tf.reshape(self.query_conv(x), (batch, positions, -1))  # [B, HW, C/8]
        K = tf.reshape(self.key_conv(x), (batch, positions, -1))    # [B, HW, C/8]
        V = tf.reshape(self.value_conv(x), (batch, positions, -1))  # [B, HW, C]

        # Scores Q @ K^T -> [B, HW, HW]. K must be transposed here; reshaping
        # it directly to [B, C/8, HW] would scramble positions and channels.
        attention_map = self.softmax(tf.matmul(Q, K, transpose_b=True))

        attention_output = tf.matmul(attention_map, V)  # [B, HW, C]
        attention_output = tf.reshape(attention_output, (batch, height, width, self.filters))

        return attention_output + x

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({'filters': self.filters})
        return config

# Functions

In [None]:
def normalize(image):
    """Cast ``image`` to float32 and divide by 255.

    Presumably maps 8-bit pixel values in [0, 255] to [0, 1].
    """
    as_float = tf.cast(image, tf.float32)
    return as_float / 255.0

def loss_discriminador(real_output, fake_output):
    """Discriminator loss: BCE on real scores plus BCE on fake scores.

    Real samples use a smoothed target. One random value in [0.89, 0.99) is
    drawn per call and shared by every real sample. Fake samples keep the
    hard target 0.

    Args:
        real_output: discriminator probabilities for real images.
        fake_output: discriminator probabilities for generated images.
    """
    bce = keras.losses.BinaryCrossentropy()
    smoothed_real = tf.ones_like(real_output) * tf.random.uniform((1,), .89, .99)
    real_term = bce(smoothed_real, real_output)
    fake_term = bce(tf.zeros_like(fake_output), fake_output)
    return real_term + fake_term

def loss_gerador(fake_output):
    """Generator loss: BCE that pushes the discriminator's fake scores toward 1."""
    target = tf.ones_like(fake_output)
    criterion = keras.losses.BinaryCrossentropy()
    return criterion(target, fake_output)



def create_generator():
    """Build the generator network.

    Maps a noise vector of length ``latent_dim_size`` to an image with
    ``channel`` channels and values in [0, 1] (sigmoid output), the same range
    ``normalize`` produces for real images.

    The spatial size is fixed by construction: a 4x4 seed doubled five times
    gives 128x128. NOTE(review): this must match ``image_size``.

    Returns:
        An uncompiled ``keras.Model``.
    """
    seed_filters = 1024

    inputs = layers.Input((latent_dim_size,))
    x = layers.Dense(4 * 4 * seed_filters)(inputs)
    x = layers.LeakyReLU()(x)
    x = layers.Reshape((4, 4, seed_filters))(x)
    x = BatchNormalization()(x)

    # Five x2 upsampling stages: 4 -> 8 -> 16 -> 32 -> 64 -> 128.
    for _ in range(5):
        x = UpSampling2D()(x)
        x = layers.Conv2DTranspose(512, (3, 3), strides=1, padding='same', use_bias=False)(x)
        x = layers.LeakyReLU(.2)(x)
        x = layers.BatchNormalization()(x)

    x = Conv2D(channel, 3, 1, 'same')(x)
    outputs = layers.Activation('sigmoid')(x)

    return keras.Model(inputs, outputs)

def create_discriminator():
    """Build the discriminator network.

    Six stride-2 4x4 convolutions (32, 64, ..., 1024 filters) downsample the
    ``image_size`` x ``channel`` input. The result is flattened, augmented
    with minibatch-discrimination features, and mapped to one sigmoid
    probability per image.

    Returns:
        An uncompiled ``keras.Model``.
    """
    inputs = layers.Input((*image_size, channel))
    x = inputs
    for i in range(6):
        x = layers.Conv2D(32 * 2**i, (4, 4), strides=(2, 2), padding='same')(x)
        x = layers.LeakyReLU()(x)

    x = layers.Flatten()(x)
    # 100 similarity kernels of dimension 5 let the discriminator compare
    # samples within a batch.
    x = miniBatch(100, 5)(x)
    x = layers.Dropout(0.4)(x)
    x = layers.Dense(1)(x)
    outputs = layers.Activation('sigmoid')(x)

    return keras.Model(inputs, outputs)

# Load dataset

In [None]:
# Unlabeled grayscale images, resized to image_size, shuffled with a fixed
# seed and batched.
train_ds = keras.preprocessing.image_dataset_from_directory(
    dataset_path,
    label_mode=None,
    color_mode='grayscale',
    image_size=image_size,
    shuffle=True,
    seed=1234,
    batch_size=BATCH_SIZE,
)
# Rescale pixels with `normalize`. The map runs in parallel and batches are
# prefetched, so input preprocessing overlaps with training.
train_ds = (
    train_ds
    .map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

# Sanity check: display the first image of one batch.
for batch in train_ds.take(1):
    image = batch[0]
    plt.imshow(image, cmap=plt.cm.gray)
    plt.show()

# Create models

In [None]:
# Build both networks once. train_step and the training loop use these
# module-level objects directly.
gen = create_generator()
dis = create_discriminator()

# Configure optimizers

In [None]:
# Adam with beta_1=0.5 for both networks. The discriminator learns at half
# the generator's rate (5e-5 vs 1e-4).
gen_opt = keras.optimizers.Adam(beta_1=0.5, learning_rate=1e-4)
dis_opt = keras.optimizers.Adam(beta_1=0.5, learning_rate=5e-5)

# Training Function

In [None]:
@tf.function
def train_step():
    """Run one epoch over ``train_ds``, updating both networks on every batch.

    Returns:
        A tuple (mean generator loss, mean discriminator loss), each averaged
        over the epoch's batches as a float32 scalar tensor.
    """
    gen_loss = tf.constant(0.0)
    dis_loss = tf.constant(0.0)
    # Count batches in the loop instead of calling len(train_ds), which needs
    # a known cardinality and autograph's special handling inside tf.function.
    num_batches = tf.constant(0.0)

    for batch in train_ds:
        # Size the fake batch like the real one, so the last (possibly smaller)
        # batch is still balanced between real and fake samples.
        noise = tf.random.normal((tf.shape(batch)[0], latent_dim_size))

        with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
            fake_imgs = gen(noise, training=True)
            true_labels = dis(batch, training=True)
            fake_labels = dis(fake_imgs, training=True)

            gen_loss_iter = loss_gerador(fake_labels)
            dis_loss_iter = loss_discriminador(true_labels, fake_labels)

        # Take both gradients from the same forward pass before either
        # optimizer touches any weights.
        gen_grads = gen_tape.gradient(gen_loss_iter, gen.trainable_variables)
        dis_grads = dis_tape.gradient(dis_loss_iter, dis.trainable_variables)
        gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))
        dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))

        gen_loss += gen_loss_iter
        dis_loss += dis_loss_iter
        num_batches += 1.0

    return gen_loss / num_batches, dis_loss / num_batches

# Training loop

In [None]:
import os

EPOCHS = 5000
EPOCH_SAMPLE = 10
n = 5

# Create the output directories written below up front, so a missing folder
# doesn't crash the run after the first epoch has already been trained.
os.makedirs('models/weights', exist_ok=True)
os.makedirs('../../imgs_fake', exist_ok=True)

for i in range(EPOCHS):

    # Loss history: store plain Python floats rather than scalar tensors.
    loss_gen, loss_dis = (float(loss) for loss in train_step())
    complete_hist['loss_gen'].append(loss_gen)
    complete_hist['loss_dis'].append(loss_dis)

    # Every EPOCH_SAMPLE epochs: checkpoint weights, log losses, save samples.
    if i % EPOCH_SAMPLE == 0:
        gen.save_weights(f'models/weights/gen_{i}.weights.h5')
        dis.save_weights(f'models/weights/dis_{i}.weights.h5')
        print(f'Ep = {i} | Loss_gen = {loss_gen:.4f}; Loss_dis = {loss_dis:.4f}')
        # Save an n x n grid of images generated from fresh noise.
        noise = tf.random.normal((n**2, latent_dim_size))
        img_fake = gen(noise)
        fig, ax = plt.subplots(n, n, figsize=(1, 1))
        fig.subplots_adjust(wspace=0, hspace=0)
        for axis, img in zip(ax.ravel(), img_fake):
            axis.imshow(img, cmap='gray')
            axis.set_axis_off()
        fig.tight_layout(pad=0)
        fig.savefig(f'../../imgs_fake/fig{i}.png', dpi=1000)
        plt.close(fig)

    # Full-model snapshots every 400 epochs.
    if i % 400 == 0:
        gen.save(f'models/gen_model_{i}.keras')
        dis.save(f'models/dis_model_{i}.keras')

    # Redraw the whole loss history (log scale) every epoch.
    fig_loss, ax_loss = plt.subplots()
    ax_loss.semilogy(complete_hist['loss_gen'], label=f'GEN = {loss_gen:.4f}', color='r')
    ax_loss.semilogy(complete_hist['loss_dis'], label=f'DIS = {loss_dis:.4f}', color='k')
    ax_loss.legend()
    ax_loss.grid(True, 'minor')
    fig_loss.savefig('loss.png')
    plt.close(fig_loss)


print('==================== COMPLETE ====================')

In [None]:
# Save the final generator and discriminator as full .keras models.
# NOTE(review): these go to ../../models/, while the periodic snapshots in the
# training loop go to models/ (relative to the working directory); confirm
# which location is intended.
# NOTE(review): dis contains the custom miniBatch layer, so loading it later
# presumably needs custom_objects={'miniBatch': miniBatch}.
gen.save('../../models/gen_2_860.keras')
dis.save('../../models/dis_2_860.keras')