In [None]:
import os

import numpy as np
import keras.api as keras
from keras.api.layers import Dense, Conv2D, Input, Lambda, UpSampling2D, Dropout, Flatten, Activation, LeakyReLU
from keras.api import Model
import tensorflow as tf
import matplotlib.pyplot as plt

2025-03-30 13:30:22.173907: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743352222.189529  124202 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743352222.194121  124202 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-30 13:30:22.210404: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
# ---- Configuration ----
image_size = (128, 128)   # (height, width) passed to the image loader
channel = 1               # not referenced in the cells shown — TODO confirm it is still needed
BUFFER_SIZE = 60000       # not referenced in the cells shown — TODO confirm it is still needed
BATCH_SIZE = 32
latent_dim_size = 128     # length of the noise vector fed to the generator
dataset_path = "../../data/train/"

# Seed Python, NumPy and the backend RNGs so weight init and noise sampling are reproducible.
SEED = 1234
keras.utils.set_random_seed(SEED)

# Per-epoch mean losses, appended to by the training loop.
complete_hist = {
    'loss_dis': [],
    'loss_gen': [],
}

In [None]:
def normalize(image):
    """Convert an image tensor to float32 and divide by 255.

    For 8-bit pixel data this maps values into the [0, 1] range.
    """
    as_float = tf.cast(image, dtype=tf.float32)
    return as_float / 255.0

def loss_discriminador(real_output, fake_output):
    """Discriminator loss: BCE pushing real scores toward 1 and fake scores toward 0.

    Args:
        real_output: Discriminator probabilities for real images.
        fake_output: Discriminator probabilities for generated images.

    Returns:
        Sum of the two Keras BinaryCrossentropy terms (default settings, i.e.
        inputs are probabilities rather than logits).
    """
    bce = keras.losses.BinaryCrossentropy()
    real_term = bce(tf.ones_like(real_output), real_output)
    fake_term = bce(tf.zeros_like(fake_output), fake_output)
    return real_term + fake_term

def loss_gerador(fake_output):
    """Generator loss: BCE rewarding fakes that the discriminator scores as real (1).

    Args:
        fake_output: Discriminator probabilities for generated images.

    Returns:
        Keras BinaryCrossentropy (default settings) against an all-ones target.
    """
    bce = keras.losses.BinaryCrossentropy()
    targets = tf.ones_like(fake_output)
    return bce(targets, fake_output)

In [None]:
# Load the grayscale training images (no labels), resized to `image_size`.
raw_train_ds = keras.preprocessing.image_dataset_from_directory(
    dataset_path,
    label_mode=None,
    color_mode='grayscale',
    image_size=image_size,
    shuffle=True,
    seed=1234,
    batch_size=BATCH_SIZE,
)

# Rescale pixels to [0, 1]. Parallel map + prefetch let input loading overlap
# with training; both keep element order and dataset cardinality unchanged.
train_ds = (
    raw_train_ds
    .map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

# Sanity check: preview the first normalized image of the first batch.
first_batch = next(iter(train_ds))
assert tuple(first_batch.shape[1:]) == (*image_size, 1), first_batch.shape
plt.imshow(first_batch[0], cmap=plt.cm.gray)
plt.axis('off')
plt.show()

In [None]:
# Adaptive Instance Normalization (AdaIN)
def ada_in(x, epsilon=1e-8):
    """Instance-normalize a feature map and re-modulate it with a style vector.

    Each channel of ``content`` is normalized over axes 1 and 2 and then
    modulated by the style:

        out = scale * (content - mean_c) / sqrt(var_c + epsilon) + bias

    The previous version took spatial moments of the style after reshaping it
    to (batch, 1, 1, C). Its variance was therefore always 0, and dividing by
    it produced inf/NaN. It also applied the content std where the style scale
    belongs.

    Args:
        x: Pair ``[style, content]``. ``content`` is a rank-4, channels-last
            feature map with C channels. ``style`` is a per-sample vector. With
            2*C entries it is split into (scale, bias); with C entries it is
            used as the scale and the bias is 0.
        epsilon: Added to the content variance to avoid division by zero.

    Returns:
        The modulated feature map; the batch axis broadcasts between
        ``style`` and ``content``.

    Raises:
        ValueError: If the style length is neither C nor 2*C.
    """
    style, content = x
    channels = content.shape[-1]
    style_len = style.shape[-1]
    # (batch, style_len) -> (batch, 1, 1, style_len) so it broadcasts over H and W.
    style = tf.reshape(style, (-1, 1, 1, style_len))
    if style_len == 2 * channels:
        scale, bias = tf.split(style, 2, axis=-1)
    elif style_len == channels:
        scale, bias = style, 0.0
    else:
        raise ValueError(
            f"style must have {channels} or {2 * channels} channels, got {style_len}"
        )

    mean_c, var_c = tf.nn.moments(content, axes=[1, 2], keepdims=True)
    normalized = (content - mean_c) * tf.math.rsqrt(var_c + epsilon)
    return scale * normalized + bias

# Mapping Network (f)
def mapping_network(latent_dim, w_dim):
    """Build the MLP that maps a latent vector z to the intermediate latent w.

    Eight 512-unit ReLU Dense layers are followed by a linear Dense projection
    to ``w_dim``.

    Args:
        latent_dim: Length of the input vector z.
        w_dim: Length of the output vector w.

    Returns:
        keras Model named "MappingNetwork".
    """
    z = Input(shape=(latent_dim,))
    hidden = z
    for _ in range(8):
        hidden = Dense(512, activation="relu")(hidden)
    w = Dense(w_dim)(hidden)
    return Model(z, w, name="MappingNetwork")

# Synthesis Network (g)
def synthesis_network(w_dim, image_size=64):
    """Build the style-modulated synthesis (generator) network g.

    The network starts from a learned 4x4x512 feature map and applies a 3x3 conv
    followed by ``ada_in``. It then repeatedly doubles the spatial size
    (UpSampling2D + 3x3 conv + ``ada_in``), halving the filter count each time.
    Every ``ada_in`` gets its own style, projected from w by a separate Dense layer.

    Args:
        w_dim: Length of the input vector w.
        image_size: Output side length. The loop runs ``log2(image_size) - 2``
            times starting from 4x4, so a power of two is expected.

    Returns:
        keras Model named "SynthesisNetwork" mapping w to an
        (image_size, image_size, 1) image with sigmoid outputs.
    """
    inputs = Input(shape=(w_dim,))

    # Learned constant input. A bias-free Dense layer applied to a tensor of ones
    # yields the same 4*4*512 vector (its kernel) for every sample. Unlike a raw
    # tf.Variable, it is a tracked model weight that receives gradients. The old
    # Variable was fed to the conv eagerly, so its output became a frozen constant.
    ones = keras.ops.ones_like(inputs[:, :1])
    x = Dense(
        4 * 4 * 512,
        use_bias=False,
        kernel_initializer=keras.initializers.RandomNormal(stddev=1.0),
    )(ones)
    x = keras.ops.reshape(x, (-1, 4, 4, 512))

    # First conv + AdaIN at 4x4.
    style = Dense(512)(inputs)
    x = Conv2D(512, (3, 3), padding="same", activation="relu")(x)
    x = Lambda(ada_in, tuple(x.shape[1:]))([style, x])

    # Cascade: upsample x2, conv, and AdaIN with a fresh per-layer style.
    for i in range(2, int(np.log2(image_size))):
        filters = 512 // (2 ** i)
        x = UpSampling2D()(x)
        x = Conv2D(filters, (3, 3), padding="same", activation="relu")(x)
        style = Dense(filters)(inputs)
        x = Lambda(ada_in, tuple(x.shape[1:]))([style, x])

    # Final 1x1 conv to a single sigmoid image channel.
    outputs = Conv2D(1, (1, 1), activation="sigmoid")(x)

    return Model(inputs, outputs, name="SynthesisNetwork")

def create_discriminator(input_shape=(128, 128, 1), base_filters=32, num_blocks=6, dropout_rate=0.4):
    """Build the convolutional real/fake discriminator.

    Each of ``num_blocks`` stages applies a 4x4, stride-2, 'same'-padded Conv2D
    and a LeakyReLU. The filter count starts at ``base_filters`` and doubles
    every stage. The flattened features then go through Dropout, Dense(1) and a
    sigmoid. The defaults reproduce the original hard-coded architecture.

    Args:
        input_shape: Per-sample image shape (height, width, channels).
        base_filters: Filters in the first conv stage.
        num_blocks: Number of stride-2 conv stages.
        dropout_rate: Dropout rate applied before the final Dense layer.

    Returns:
        keras Model named 'Discriminator' that outputs one sigmoid score per image.
    """
    image_in = Input(shape=input_shape)
    x = image_in
    for stage in range(num_blocks):
        x = Conv2D(base_filters * 2 ** stage, (4, 4), strides=(2, 2), padding='same')(x)
        x = LeakyReLU()(x)

    x = Flatten()(x)
    x = Dropout(dropout_rate)(x)
    x = Dense(1)(x)
    score = Activation('sigmoid')(x)

    return keras.Model(image_in, score, name='Discriminator')

# Network dimensions. These come from the config cell instead of being
# re-declared, so this cell no longer overwrites `image_size` (the
# (height, width) tuple the data loader uses) with a bare int.
latent_dim = latent_dim_size
w_dim = 512
synthesis_size = image_size[0]  # assumes square images (height == width) — TODO confirm

# Build the sub-networks.
mapping = mapping_network(latent_dim, w_dim)
synthesis = synthesis_network(w_dim, synthesis_size)

# Full StyleGAN pipeline: z -> mapping -> w -> synthesis -> image.
latent_input = Input(shape=(latent_dim,))
w = mapping(latent_input)
image_output = synthesis(w)

gen = Model(latent_input, image_output, name="StyleGAN")
dis = create_discriminator()

# Generated images are fed directly to the discriminator, so the per-sample shapes must agree.
assert tuple(gen.outputs[0].shape[1:]) == tuple(dis.inputs[0].shape[1:]), (
    f"generator output {gen.outputs[0].shape} does not match "
    f"discriminator input {dis.inputs[0].shape}"
)

# The discriminator learns at half the generator's rate.
gen_opt = keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)
dis_opt = keras.optimizers.Adam(learning_rate=1e-4 / 2, beta_1=0.5)



In [None]:
def train_step():
    """Run one epoch of simultaneous generator/discriminator updates over `train_ds`.

    For each real batch, a noise batch of the same size is drawn and passed
    through the generator. Both networks are evaluated under separate gradient
    tapes, and each optimizer takes one step.

    Returns:
        (mean generator loss, mean discriminator loss) over the epoch's batches,
        as float32 scalar tensors.

    Raises:
        ValueError: If `train_ds` yields no batches. Previously this silently
            returned NaN.
    """
    gen_loss_sum = 0.0
    dis_loss_sum = 0.0
    steps = 0
    for batch in train_ds:
        # Match the fake batch to the real one. The last batch of an epoch can be
        # smaller than BATCH_SIZE, and the latent size comes from the config
        # instead of a literal.
        noise = tf.random.normal((tf.shape(batch)[0], latent_dim_size))

        with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
            fake_imgs = gen(noise, training=True)
            real_scores = dis(batch, training=True)
            fake_scores = dis(fake_imgs, training=True)

            gen_loss_iter = loss_gerador(fake_scores)
            dis_loss_iter = loss_discriminador(real_scores, fake_scores)

        gen_grads = gen_tape.gradient(gen_loss_iter, gen.trainable_variables)
        gen_opt.apply_gradients(zip(gen_grads, gen.trainable_variables))

        dis_grads = dis_tape.gradient(dis_loss_iter, dis.trainable_variables)
        dis_opt.apply_gradients(zip(dis_grads, dis.trainable_variables))

        gen_loss_sum += gen_loss_iter
        dis_loss_sum += dis_loss_iter
        steps += 1

    if steps == 0:
        raise ValueError("train_ds yielded no batches; check dataset_path.")

    # Average over the batches actually seen. Unlike len(train_ds), this also
    # works when the dataset's cardinality is unknown.
    n_steps = tf.cast(steps, tf.float32)
    return gen_loss_sum / n_steps, dis_loss_sum / n_steps

In [None]:
EPOCHS = 5000
EPOCH_SAMPLE = 10        # save weights, print losses and a sample grid every N epochs
CHECKPOINT_EVERY = 400   # save full .keras models every N epochs
n = 5                    # the sample grid is n x n generated images

WEIGHTS_DIR = 'models/weights'
MODELS_DIR = 'models'
SAMPLES_DIR = '../../imgs_fake'
# Create the output folders up front; otherwise the first save fails on a fresh checkout.
for out_dir in (WEIGHTS_DIR, MODELS_DIR, SAMPLES_DIR):
    os.makedirs(out_dir, exist_ok=True)

for i in range(EPOCHS):

    # Train one epoch and record the mean losses as plain Python floats.
    loss_gen, loss_dis = train_step()
    loss_gen, loss_dis = float(loss_gen), float(loss_dis)
    complete_hist['loss_gen'].append(loss_gen)
    complete_hist['loss_dis'].append(loss_dis)

    if i % EPOCH_SAMPLE == 0:
        gen.save_weights(f'{WEIGHTS_DIR}/gen_{i}.weights.h5')
        dis.save_weights(f'{WEIGHTS_DIR}/dis_{i}.weights.h5')
        print(f'Ep = {i} | Loss_gen = {loss_gen:.4f}; Loss_dis = {loss_dis:.4f}')

        # Save an n x n grid of freshly generated samples.
        noise = tf.random.normal((n**2, latent_dim_size))
        img_fake = gen(noise)
        fig, axes = plt.subplots(n, n, figsize=(1, 1))
        fig.subplots_adjust(wspace=0, hspace=0)
        for ax, img in zip(axes.ravel(), img_fake):
            ax.imshow(img, cmap='gray')
            ax.set_axis_off()
        fig.tight_layout(pad=0)
        fig.savefig(f'{SAMPLES_DIR}/fig{i}.png', dpi=1000)
        plt.close(fig)

    if i % CHECKPOINT_EVERY == 0:
        gen.save(f'{MODELS_DIR}/gen_model_style_{i}.keras')
        dis.save(f'{MODELS_DIR}/dis_model_style_{i}.keras')

    # Redraw the full loss history every epoch.
    fig, ax = plt.subplots()
    ax.semilogy(np.array(complete_hist['loss_gen']), label=f'GEN = {loss_gen:.4f}', color='r')
    ax.semilogy(np.array(complete_hist['loss_dis']), label=f'DIS = {loss_dis:.4f}', color='k')
    ax.set_xlabel('epoch')
    ax.set_ylabel('mean loss')
    ax.legend()
    ax.grid(True, which='minor')
    fig.savefig('loss.png')
    plt.close(fig)


print('==================== COMPLETE ====================')