In [None]:
import tensorflow as tf
import keras
import numpy as np
import struct
import gzip
# import matplotlib
import matplotlib.pyplot as plt
from array import array

from flatbuffers.packer import float32
from tensorflow.python.data.experimental.ops.distribute import batch_sizes_for_worker

# my project
from module.conf import PROJECT_DIR
# matplotlib.use("QTAgg")
%matplotlib inline

In [None]:
# Show every physical device TensorFlow can see (no device-type filter).
tf.config.list_physical_devices(device_type=None)

In [None]:
# Restrict TensorFlow to the first GPU, if there is one. On a CPU-only machine
# the GPU list is empty, so indexing [0] unconditionally would raise IndexError.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.set_visible_devices(gpus[0], 'GPU')
    except RuntimeError as err:
        # Visible devices can only be changed before the GPUs are initialized
        # (e.g. this cell re-run after training started); report and continue.
        print(err)

In [None]:
# Load the CIFAR-10 training images. Labels and the test split are unused by this unconditional GAN.
(x_train, _), (_, _) = keras.datasets.cifar10.load_data()
# Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output range.
# float32 (not float64) matches Keras' default compute dtype: it halves memory and
# avoids a float64 -> float32 cast on every batch fed to the models.
x_train = (x_train.astype("float32") - 127.5) / 127.5

In [None]:
BATCH_SIZE = 256

# Input pipeline over the normalized training images:
#  - shuffle with a buffer as large as the dataset, which gives a full uniform
#    shuffle (reshuffled every epoch by default);
#  - batch; the final batch is smaller (50000 % 256 == 80 for CIFAR-10);
#  - prefetch so the next batch is prepared while the current train step runs.
dataset = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(x_train.shape[0])
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
class Generator(keras.Model):
    """DCGAN-style generator mapping a noise vector to a 32x32x3 image in [-1, 1].

    The noise is projected to a 4x4x256 feature map. It is then upsampled 2x
    three times with stride-2 transposed convolutions (4 -> 8 -> 16 -> 32).
    The last layer uses tanh, so outputs lie in [-1, 1].

    Args:
        noise_dim: length of the input noise vector. It is stored on the model
            for reference only; the Dense layer infers its input size on first call.
    """

    def __init__(self, noise_dim=100):
        super().__init__()
        self.noise_dim = noise_dim

        # Projection: noise -> 4*4*256 features -> (4, 4, 256) feature map.
        self.dense = keras.layers.Dense(4 * 4 * 256, use_bias=False)
        self.bn1 = keras.layers.BatchNormalization()
        self.reshape = keras.layers.Reshape((4, 4, 256))

        # Upsample 4x4 -> 8x8 with 128 channels.
        self.conv1 = keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same", use_bias=False)
        self.bn2 = keras.layers.BatchNormalization()

        # Upsample 8x8 -> 16x16 with 64 channels.
        self.conv2 = keras.layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding="same", use_bias=False)
        self.bn3 = keras.layers.BatchNormalization()

        # Upsample 16x16 -> 32x32 with 3 output channels; tanh bounds values to [-1, 1].
        self.conv3 = keras.layers.Conv2DTranspose(3, (4, 4), strides=(2, 2), padding="same", activation="tanh")

    def call(self, inputs, training=True):
        """Generate images from a batch of noise vectors.

        `training` is forwarded to the BatchNormalization layers (batch statistics
        vs. moving averages).
        """
        projected = self.bn1(self.dense(inputs), training=training)
        features = self.reshape(tf.nn.relu(projected))

        # Two identical "transposed conv -> batch norm -> ReLU" stages.
        for upsample, norm in ((self.conv1, self.bn2), (self.conv2, self.bn3)):
            features = tf.nn.relu(norm(upsample(features), training=training))

        return self.conv3(features)

In [None]:
class Discriminator(keras.Model):
    """DCGAN-style discriminator: two strided conv blocks and a linear scoring head.

    Each conv block is Conv2D (stride 2) -> LeakyReLU(0.2) -> Dropout(0.3).
    The head flattens the features and emits one unnormalized score (logit)
    per image. Higher means "more likely real".

    The output is a raw logit, NOT a probability. It is meant to be paired with
    a loss that applies the sigmoid itself, e.g.
    keras.losses.BinaryCrossentropy(from_logits=True). Apply tf.sigmoid to get
    probabilities.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = keras.layers.Conv2D(64, (4, 4), strides=(2, 2), padding="same")
        self.lrelu1 = keras.layers.LeakyReLU(alpha=0.2)
        self.dropout1 = keras.layers.Dropout(0.3)

        self.conv2 = keras.layers.Conv2D(128, (4, 4), strides=(2, 2), padding="same")
        self.lrelu2 = keras.layers.LeakyReLU(alpha=0.2)
        self.dropout2 = keras.layers.Dropout(0.3)

        self.flatten = keras.layers.Flatten()
        # No activation here: the training loss is built with from_logits=True and
        # applies the sigmoid itself. A sigmoid here would be applied twice,
        # compressing scores into (0.5, 0.73) and badly weakening the gradients.
        self.fc = keras.layers.Dense(1)

    def call(self, inputs, training=True):
        """Score a batch of images; returns logits of shape (batch, 1).

        `training` is forwarded to the Dropout layers.
        """
        x = self.conv1(inputs)
        x = self.lrelu1(x)
        x = self.dropout1(x, training=training)

        x = self.conv2(x)
        x = self.lrelu2(x)
        x = self.dropout2(x, training=training)

        x = self.flatten(x)
        return self.fc(x)

In [None]:
# Length of the noise (latent) vector fed to the generator.
noise_dim = 100

# The two adversarial networks.
generator = Generator(noise_dim=noise_dim)
discriminator = Discriminator()

# One Adam optimizer per network, with the lr=2e-4 / beta_1=0.5 settings
# commonly used for DCGAN training.
gen_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
disc_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

# Binary cross-entropy with from_logits=True: the loss applies the sigmoid itself,
# so the discriminator must output raw logits.
# NOTE(review): verify the Discriminator's final layer has no sigmoid activation.
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
@tf.function
def train_step(real_images):
    """Run one adversarial update: one discriminator step, then one generator step.

    Args:
        real_images: a batch of real images. They are expected to be scaled like
            the generator's tanh output, i.e. to [-1, 1].

    Returns:
        (gen_loss, disc_loss): scalar loss tensors for this batch.
    """
    # Draw exactly one noise vector per real image, so the discriminator sees a
    # balanced real/fake batch. tf.shape (dynamic) keeps this correct for the
    # smaller final batch of an epoch too. The old hard-coded 32 paired
    # 256 real images with 32 fakes.
    batch_size = tf.shape(real_images)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    # Train the discriminator: label real images 1 and generated images 0.
    with tf.GradientTape() as disc_tape:
        fake_images = generator(noise, training=True)
        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(fake_images, training=True)

        real_loss = cross_entropy(tf.ones_like(real_output), real_output)
        fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
        disc_loss = real_loss + fake_loss

    grads_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    disc_optimizer.apply_gradients(zip(grads_disc, discriminator.trainable_variables))

    # Train the generator against the just-updated discriminator: it is rewarded
    # when its images are scored as real (label 1).
    with tf.GradientTape() as gen_tape:
        fake_images = generator(noise, training=True)
        fake_output = discriminator(fake_images, training=True)
        gen_loss = cross_entropy(tf.ones_like(fake_output), fake_output)

    grads_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gen_optimizer.apply_gradients(zip(grads_gen, generator.trainable_variables))

    return gen_loss, disc_loss

In [None]:
def train(dataset, epochs, sample_every=10):
    """Train the GAN for `epochs` passes over `dataset`.

    After each epoch, prints the mean generator and discriminator loss over that
    epoch's batches. Every `sample_every` epochs it also saves and shows a sample
    grid via generate_and_save_images.

    Args:
        dataset: iterable of real-image batches accepted by train_step.
        epochs: number of passes over `dataset`.
        sample_every: epoch interval between sample grids (default 10).
            Pass 0 or None to disable sampling.

    Raises:
        ValueError: if `dataset` yields no batches in an epoch. Previously this
            surfaced as an UnboundLocalError on `gen_loss`.
    """
    for epoch in range(epochs):
        gen_total, disc_total, num_batches = 0.0, 0.0, 0
        for real_images in dataset:
            gen_loss, disc_loss = train_step(real_images)
            # Keep the running sums as tensors; convert to Python floats only
            # once per epoch rather than on every step.
            gen_total += gen_loss
            disc_total += disc_loss
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataset yielded no batches; nothing to train on")

        # Report epoch means: the last batch alone is noisy and, for CIFAR-10,
        # is also a smaller partial batch.
        print(
            f"Epoch {epoch+1}, Generator Loss (epoch mean): {float(gen_total) / num_batches}, "
            f"Discriminator Loss (epoch mean): {float(disc_total) / num_batches}"
        )

        if sample_every and (epoch + 1) % sample_every == 0:
            generate_and_save_images(generator, epoch+1)

def generate_and_save_images(model, epoch):
    """Sample 16 images from `model`, save them as a 4x4 grid PNG, and display it.

    Args:
        model: generator called as model(noise, training=False) on noise of shape
            (16, noise_dim). Its outputs are assumed to lie in [-1, 1] (tanh),
            which is what the rescaling below expects. TODO confirm for other models.
        epoch: epoch number; used only in the output filename.
    """
    latent = tf.random.normal([16, noise_dim])
    samples = model(latent, training=False)

    fig, grid = plt.subplots(4, 4, figsize=(4, 4))
    for ax, image in zip(grid.flat, samples):
        # Map [-1, 1] back to [0, 1], the float range imshow expects for RGB.
        ax.imshow((image + 1) / 2)
        ax.axis("off")

    plt.savefig(f"gan_cifar10_epoch_{epoch}.png")
    plt.show()

# Launch GAN training: 50 full passes over the CIFAR-10 training pipeline.
train(dataset=dataset, epochs=50)