In [1]:
import tensorflow as tf
from tensorflow import keras
import os
import pathlib
import matplotlib.pyplot as plt
from tensorflow.keras.optimizers import Adam

2024-01-09 15:59:39.213082: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-09 15:59:39.213124: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-09 15:59:39.213139: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-09 15:59:39.217776: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# load data
# Images are read from <cwd>/notebooks/data_clean, resized to 64x64 and yielded in
# batches WITHOUT labels (label_mode=None), then scaled from raw pixel values (0-255)
# to [-1, 1] via (d - 127.5) / 127.5.
batch_size = 32

path = os.getcwd()
data_path = os.path.join(path, 'notebooks/data_clean')
root_path = pathlib.Path(data_path)

data = keras.utils.image_dataset_from_directory(
    directory=root_path,
    label_mode=None,
    batch_size=batch_size,
    image_size=(64, 64))

# Normalise in parallel and prefetch so batch preparation overlaps with the training step
# instead of blocking it; the produced values are unchanged.
data = (
    data
    .map(lambda d: (d - 127.5) / 127.5, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)
data

KeyboardInterrupt: 

In [None]:
class GAN(tf.keras.Model):
    """Keras model that trains a generator/discriminator pair adversarially through ``fit``.

    Labelling convention used throughout ``train_step``: generated images -> 1,
    real images -> 0.

    Args:
        discriminator: Keras model scoring an image batch; its predictions are compared
            against (batch, 1) label tensors by ``loss_fn`` (presumably a (batch, 1)
            sigmoid output — TODO confirm against the saved model).
        generator: Keras model mapping a (batch, latent_dim) noise batch to images.
        latent_dim: length of the latent vector sampled for the generator.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Attach one optimizer per network plus the shared loss, and create loss trackers.

        Args:
            d_optimizer: optimizer applied to the discriminator's trainable weights.
            g_optimizer: optimizer applied to the generator's trainable weights.
            loss_fn: callable ``loss_fn(labels, predictions)`` used for both networks.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = tf.keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = tf.keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        # Exposing the trackers here lets Keras reset them between epochs.
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: a batch of real images (the dataset yields images only).

        Returns:
            Dict with the running means of the discriminator and generator losses.
        """
        batch_size = tf.shape(real_images)[0]

        # ---- Discriminator update ----
        # These nested calls happen outside the outer model's __call__, so no training
        # flag is propagated to them; pass training=True explicitly so layers such as
        # BatchNormalization / Dropout (if present in the loaded models) behave as in training.
        seed = tf.random.normal(shape=(batch_size, self.latent_dim))
        generated_images = self.generator(seed, training=True)
        combined_images = tf.concat([generated_images, real_images], axis=0)
        # Fake -> 1, real -> 0.
        labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)
        # Add random noise to the labels (GAN stabilisation trick).
        # NOTE(review): this pushes the fake targets slightly above 1.0 (up to 1.05).
        labels += 0.05 * tf.random.uniform(tf.shape(labels))
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images, training=True)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        # ---- Generator update ----
        seed = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Label the fakes as "real" (0) so the generator is rewarded for fooling the discriminator.
        misleading_labels = tf.zeros((batch_size, 1))
        # Gradients are taken and applied only w.r.t. generator weights; the discriminator
        # is merely used to score the fakes.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(seed, training=True), training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {"d_loss": self.d_loss_metric.result(), "g_loss": self.g_loss_metric.result()}

: 

In [None]:
# load model

# Checkpoints are named [epoch]_generator.h5 and [epoch]_discriminator.h5 in the models
# folder. Resume from the HIGHEST epoch for which BOTH files exist, so the generator and
# discriminator always come from the same point in training. (os.listdir order is
# arbitrary, so "last file seen" is not the same as "latest epoch".)

models_path = pathlib.Path('/tf/notebooks/models')
models = os.listdir(models_path)

generator_paths = {}      # epoch -> generator checkpoint path
discriminator_paths = {}  # epoch -> discriminator checkpoint path

for model in models:
    prefix, sep, _ = model.partition('_')
    if not sep or not prefix.isdecimal():
        continue  # not an "[epoch]_..." file (e.g. checkpoint dirs, stray files)
    epoch = int(prefix)
    if 'generator' in model:
        generator_paths[epoch] = os.path.join(models_path, model)
    elif 'discriminator' in model:
        discriminator_paths[epoch] = os.path.join(models_path, model)

common_epochs = generator_paths.keys() & discriminator_paths.keys()
if not common_epochs:
    raise FileNotFoundError(
        f'No matching [epoch]_generator / [epoch]_discriminator checkpoint pair in {models_path}')

epoch_start = max(common_epochs)
G_model_path = generator_paths[epoch_start]
D_model_path = discriminator_paths[epoch_start]

G_model = tf.keras.models.load_model(G_model_path)
D_model = tf.keras.models.load_model(D_model_path)

# Wrap the loaded networks in the GAN trainer and give each network its own Adam optimizer.
model = GAN(
    generator=G_model,
    discriminator=D_model,
    latent_dim=100,
)

loss_fn = tf.keras.losses.BinaryCrossentropy()
G_optm = Adam(learning_rate=1e-4)
D_optm = Adam(learning_rate=1e-4)
model.compile(
    g_optimizer=G_optm,
    d_optimizer=D_optm,
    loss_fn=loss_fn,
)

: 

In [None]:
# Show the output shapes of the loaded generator, then the discriminator.
for net in (G_model, D_model):
    print(net.output_shape)

: 

In [None]:
# check if working: decode one random latent vector with the loaded generator and show it.
latent_dim = 100
random_noise = tf.random.normal([1, latent_dim])
generated_image = G_model(random_noise, training=False)

# The training data is scaled to [-1, 1], so the generator output is assumed to live in
# that range too (TODO confirm). Rescale to [0, 1] for imshow and show every channel
# instead of only channel 0.
preview = (generated_image[0].numpy() + 1.0) / 2.0
if preview.shape[-1] == 1:
    preview = preview[..., 0]  # imshow rejects (H, W, 1); draw single-channel output as 2-D
plt.imshow(preview.clip(0.0, 1.0))
plt.axis("off");

: 

In [None]:
# Continue adversarial training for `epoch` more epochs on the normalised image dataset.
epoch = 5
model.fit(x=data, epochs=epoch)

In [None]:
# Create random noise to feed to the trained generator (only the first 16 samples are shown).
noise = tf.random.normal([32, 100])
# Generate new images using the trained generator model.
generated_images = G_model(noise, training=False)

# Invert the training normalisation (d - 127.5) / 127.5, i.e. map [-1, 1] back to [0, 255].
# Clip before the uint8 cast below so out-of-range outputs saturate instead of wrapping.
generated_images1 = tf.clip_by_value(generated_images * 127.5 + 127.5, 0.0, 255.0)

plt.figure(figsize=(5, 5))
for i in range(16):
    ax = plt.subplot(4, 4, i + 1)
    ax.imshow(generated_images1[i].numpy().astype("uint8"))
    ax.axis('off')

plt.show()

: 

In [None]:
# The new checkpoint continues the epoch count from the checkpoint that was loaded.
epoch_end = epoch_start + epoch

# Save the 4x4 sample grid. Re-draw it into an explicit figure: the figure from the previous
# cell was already shown (plt.show()) and closed, so a bare plt.savefig() here would write
# a blank image.
grid_fig = plt.figure(figsize=(5, 5))
for i in range(16):
    ax = grid_fig.add_subplot(4, 4, i + 1)
    ax.imshow(generated_images1[i].numpy().astype("uint8"))
    ax.axis('off')
grid_fig.savefig(f'/tf/notebooks/models/{epoch_end}_generated_images.png')
plt.close(grid_fig)  # don't render it again / keep memory bounded

G_model.save(f'/tf/notebooks/models/{epoch_end}_generator.h5')
D_model.save(f'/tf/notebooks/models/{epoch_end}_discriminator.h5')

: 