In [1]:
import tensorflow as tf
import keras
import numpy as np
import struct
import gzip
# import matplotlib
import matplotlib.pyplot as plt
from array import array

from flatbuffers.packer import float32
from tensorflow.python.data.experimental.ops.distribute import batch_sizes_for_worker

# my project
from module.conf import PROJECT_DIR
# matplotlib.use("QTAgg")
%matplotlib inline

In [2]:
# List every device TensorFlow can see; the recorded output shows one CPU and one GPU.
tf.config.list_physical_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
# Restrict TensorFlow to the first GPU when at least one is present.
# Indexing [0] on an empty list would raise IndexError on CPU-only machines,
# so in that case the device configuration is left untouched.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    tf.config.set_visible_devices(gpus[0], 'GPU')

In [4]:
import keras.datasets
import keras.datasets.cifar
import keras.datasets.cifar10

# Only the training images are needed: the GAN is unsupervised, so labels and
# the test split are discarded.
(x_train, _), (_, _) = keras.datasets.cifar10.load_data()

# Scale pixel values from [0, 255] to [-1, 1] to match the generator's tanh
# output range. float32 is used because it is the dtype Keras computes in by
# default; float64 would only double the memory held by the array.
x_train = (x_train.astype("float32") - 127.5) / 127.5

In [5]:
# dataset = tf.data.Dataset.from_tensor_slices(x_train) #.shuffle(x_train.shape[0]).batch(256)
# The array is passed to fit() as-is; batching is done there via batch_size.
dataset = x_train

In [6]:
class Generator(keras.Model):
    """DCGAN-style generator: maps a noise vector to a 32x32x3 image in [-1, 1].

    The network projects the noise to a 4x4x1024 feature map and upsamples it
    with three stride-2 transposed convolutions (4 -> 8 -> 16 -> 32).

    Args:
        noise_dim: Length of the latent noise vector fed to the network.
        target_shape: Accepted for interface compatibility with the caller but
            not used; the output shape is fixed by the layer stack below.
    """

    def __init__(self, noise_dim: int, target_shape: tuple):
        super(Generator, self).__init__()
        self.model = keras.Sequential(layers=[
            keras.layers.InputLayer(input_shape=(noise_dim,)),

            # Project the noise and fold it into a 4x4 map with 1024 channels.
            # The Dense width is written as 4*4*1024 so it visibly matches the
            # Reshape target (same value as the previous 8*8*256 = 16384).
            keras.layers.Dense(units=4 * 4 * 1024, activation=None, use_bias=False),
            keras.layers.Reshape(target_shape=(4, 4, 1024)),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),

            # 4x4 -> 8x8
            keras.layers.Conv2DTranspose(128, (4, 4), strides=2, padding="same", use_bias=False),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),

            # 8x8 -> 16x16
            keras.layers.Conv2DTranspose(64, (4, 4), strides=2, padding="same", use_bias=False),
            keras.layers.BatchNormalization(),
            # This activation was missing: without it the last two transposed
            # convolutions are separated only by a linear normalisation.
            keras.layers.ReLU(),

            # 16x16 -> 32x32; tanh keeps pixel values in [-1, 1].
            keras.layers.Conv2DTranspose(3, (4, 4), strides=2, padding="same", activation="tanh"),
        ])

    def call(self, inputs, training=False):
        """Generate images for a batch of noise vectors.

        Args:
            inputs: Batch of noise vectors of length ``noise_dim``.
            training: Forwarded to the layer stack (affects BatchNormalization).

        Returns:
            The output of the final tanh transposed convolution.
        """
        return self.model(inputs, training=training)

In [7]:
class Discriminator(keras.Model):
    """Convolutional binary classifier scoring images as real or generated.

    Three stride-2 convolutions with LeakyReLU(0.2) activations feed a single
    sigmoid unit. ``target_shape`` is accepted by the constructor but is not
    used by the layer stack.
    """

    def __init__(self, target_shape: tuple):
        super().__init__()

        stack = keras.Sequential()

        # Stage 1: 64 filters, no normalisation.
        stack.add(keras.layers.Conv2D(filters=64, kernel_size=(4, 4), strides=2, padding="same"))
        stack.add(keras.layers.LeakyReLU(0.2))

        # Stage 2: 128 filters, the only batch-normalised stage.
        stack.add(keras.layers.Conv2D(filters=128, kernel_size=(4, 4), strides=2, padding="same"))
        stack.add(keras.layers.BatchNormalization())
        stack.add(keras.layers.LeakyReLU(0.2))

        # Stage 3: 256 filters.
        stack.add(keras.layers.Conv2D(filters=256, kernel_size=(4, 4), strides=2, padding="same"))
        stack.add(keras.layers.LeakyReLU(0.2))

        # Classification head: flatten, then one sigmoid output.
        stack.add(keras.layers.Flatten())
        stack.add(keras.layers.Dense(1, activation="sigmoid"))

        self.model = stack

    def call(self, inputs, training=False):
        """Run the layer stack on ``inputs`` and return the sigmoid scores."""
        scores = self.model(inputs, training=training)
        return scores


In [8]:
class GAN(keras.Model):
    """Couples a Generator and a Discriminator and trains them adversarially.

    Args:
        noise_dim: Length of the noise vector sampled for every generated image.
        target_shape: Image shape forwarded to the Generator and Discriminator
            constructors.
    """

    # The default is a tuple rather than a list so no mutable object is shared
    # between instances through the default argument.
    def __init__(self, noise_dim=100, target_shape=(32, 32, 3)):
        super(GAN, self).__init__()
        # Optimizers and loss are supplied later through compile().
        self.gen_optimizer = None
        self.disc_optimizer = None
        self.loss_fn = None
        self.noise_dim = noise_dim
        self.target_shape = tuple(target_shape)
        self.generator = Generator(self.noise_dim, self.target_shape)
        self.discriminator = Discriminator(self.target_shape)

    def compile(self, gen_optimizer, disc_optimizer, loss_fn):
        """Store the optimizers and the loss used by ``train_step``.

        Args:
            gen_optimizer: Optimizer applied to the generator's variables.
            disc_optimizer: Optimizer applied to the discriminator's variables.
            loss_fn: Callable ``loss_fn(labels, predictions)``.
        """
        super(GAN, self).compile()
        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer
        self.loss_fn = loss_fn

    @tf.function
    def train_step(self, batch):
        """Run one simultaneous update of the generator and the discriminator.

        Args:
            batch: Batch of real images, or a tuple/list whose first element
                is that batch.

        Returns:
            Dict with the scalar ``gen_loss`` and ``disc_loss`` of this step.
        """
        # Datasets that yield (images,) or (images, labels) deliver a tuple;
        # only the images are used.
        if isinstance(batch, (tuple, list)):
            batch = batch[0]

        batch_size = tf.shape(batch)[0]
        noise = tf.random.normal(shape=(batch_size, self.noise_dim))

        # 1 - real image, 0 - fake image
        real_labels = tf.ones((batch_size, 1))
        fake_labels = tf.zeros((batch_size, 1))

        # Forward pass recorded on two tapes so each network gets gradients
        # of its own loss.
        with tf.GradientTape() as disc_tape, tf.GradientTape() as gen_tape:
            fake_images = self.generator(noise, training=True)

            real_preds = self.discriminator(batch, training=True)
            fake_preds = self.discriminator(fake_images, training=True)

            real_loss = self.loss_fn(real_labels, real_preds)
            fake_loss = self.loss_fn(fake_labels, fake_preds)

            # The generator is scored against the "real" labels: it improves
            # when the discriminator mistakes its images for real ones.
            gen_loss = self.loss_fn(real_labels, fake_preds)
            disc_loss = tf.divide(tf.add(real_loss, fake_loss), 2)

        # Both gradients come from the same forward pass and are applied
        # together.
        gen_grad = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        disc_grad = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)
        self.gen_optimizer.apply_gradients(zip(gen_grad, self.generator.trainable_variables))
        self.disc_optimizer.apply_gradients(zip(disc_grad, self.discriminator.trainable_variables))

        return {"gen_loss": gen_loss, "disc_loss": disc_loss}

In [None]:
import keras.optimizer_v2
import keras.optimizer_v2.adam

# Both networks use Adam with the same hyper-parameters.
adam_settings = dict(learning_rate=0.0002, beta_1=0.9)

gan = GAN(noise_dim=512)
gan.compile(
    gen_optimizer=keras.optimizer_v2.adam.Adam(**adam_settings),
    disc_optimizer=keras.optimizer_v2.adam.Adam(**adam_settings),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

# 200 epochs over the in-memory array, in batches of 128.
gan.fit(dataset, epochs=200, batch_size=128)

In [None]:
# Save the GAN's current weights; the path prefix is built from PROJECT_DIR.
gan.save_weights(PROJECT_DIR + "/data/models/cifar10_gan_200/model")

In [None]:
# gan.load_weights(PROJECT_DIR + "/data/models/fashionmnist_gan_1000_2/model")

In [None]:
import matplotlib.pyplot as plt

def generate_and_show_images(generator, num_examples=10, noise_dim=512, rows=4):
    """Sample images from ``generator`` and display them in a grid.

    Args:
        generator: Callable model taking a ``(num_examples, noise_dim)`` noise
            batch and a ``training`` flag, returning images in [-1, 1].
        num_examples: Number of images to generate and show.
        noise_dim: Length of each noise vector; must match the generator.
        rows: Maximum number of grid rows.

    Raises:
        ValueError: If ``num_examples`` is smaller than 1.
    """
    if num_examples < 1:
        raise ValueError("num_examples must be at least 1")

    # Never use more rows than images, and round the column count up so every
    # image gets a cell (floor division used to drop the remainder).
    rows = max(1, min(rows, num_examples))
    cols = (num_examples + rows - 1) // rows

    noise = tf.random.normal(shape=(num_examples, noise_dim))
    images = generator(noise, training=False)
    images = (images + 1) / 2.0  # rescale from [-1, 1] to [0, 1] for imshow

    # squeeze=False keeps `axes` two-dimensional even for a 1xN or Nx1 grid.
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(10, 5), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < num_examples:
            # No cmap is passed: matplotlib ignores it for RGB data.
            ax.imshow(images[i])
        # Cells beyond num_examples stay blank.
        ax.axis("off")
    plt.show()

# Draw 20 samples from the generator; with 4 rows this fills a 4 x 5 grid.
generate_and_show_images(gan.generator, num_examples=20)