In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import cifar10
from keras.preprocessing import image
import keras.backend as K
import matplotlib.pyplot as plt
import numpy as np
import time
from tqdm import tqdm

In [None]:
# Training hyper-parameters.
batch_size = 16
epoch_count = 50
noise_dim = 100  # length of the generator's latent input vector
n_class = 10
tags = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']
img_size = 32
assert len(tags) == n_class, "tags must name every class"

# Seed NumPy (noise / label sampling) and TensorFlow's global seed so runs are
# repeatable; GPU kernels may still introduce some nondeterminism.
seed = 42
np.random.seed(seed)
tf.random.set_seed(seed)

In [None]:
# CIFAR-10 training split: uint8 images and integer class labels (one per row).
(X_train, y_train), (_, _) = cifar10.load_data()

# Scale pixels to [-1, 1] to match the generator's tanh output. Cast to float32
# first: `uint8 - 127.5` would otherwise promote to float64, doubling the memory
# held by the in-memory dataset below and feeding float64 into float32 models.
X_train = (X_train.astype('float32') - 127.5) / 127.5
assert X_train.min() >= -1.0 and X_train.max() <= 1.0, "pixel scaling went wrong"

dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
# Prefetch lets the input pipeline prepare the next batch while a step runs.
dataset = dataset.shuffle(buffer_size=1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [None]:
# Sanity-check the preprocessing: show one random training image with its class name.
fig, ax = plt.subplots(figsize=(2, 2))
idx = np.random.randint(0, len(X_train))
# scale=True min-max rescales the normalised floats into 0-255 for display.
img = image.array_to_img(X_train[idx], scale=True)
ax.imshow(img)
ax.axis('off')
ax.set_title(tags[y_train[idx][0]])
plt.show()

In [None]:
bce_loss = tf.keras.losses.BinaryCrossentropy()

# NOTE: these helpers take discriminator *outputs* (probabilities, since
# BinaryCrossentropy defaults to from_logits=False), not Keras's (y_true, y_pred)
# pair, so they fit a custom GradientTape loop rather than Model.compile(loss=...).

def discriminator_loss(real, fake):
	"""BCE pushing D(real) towards 1 and D(fake) towards 0.

	Args:
		real: discriminator outputs for real images.
		fake: discriminator outputs for generated images.

	Returns:
		Sum of the two binary cross-entropy terms.
	"""
	targets_real = tf.ones_like(real)
	targets_fake = tf.zeros_like(fake)
	return bce_loss(targets_real, real) + bce_loss(targets_fake, fake)

def generator_loss(preds):
	"""BCE rewarding the generator when D scores its images as real (target 1).

	Args:
		preds: discriminator outputs for generated images.
	"""
	target = tf.ones_like(preds)
	return bce_loss(target, preds)

In [None]:
# Same Adam settings for both networks (lr 2e-4, beta_1 0.5, as used in the DCGAN
# paper); each network gets its own instance so their moment estimates stay separate.
adam_settings = {'learning_rate': 0.0002, 'beta_1': 0.5}
d_optimizer = Adam(**adam_settings)
g_optimizer = Adam(**adam_settings)

In [None]:
def build_generator():
	"""Build the conditional generator: [noise, class label] -> 32x32x3 image in [-1, 1].

	Inputs (in this order): a latent vector of length ``noise_dim`` and an integer
	class id of shape (1,). The label is embedded and projected to an 8x8x1 map,
	concatenated with an 8x8x128 projection of the noise, upsampled twice
	(8 -> 16 -> 32), and turned into 3 channels by a tanh convolution.
	"""
	layers = tf.keras.layers

	# Label branch: id -> 50-d embedding -> 64 units -> 8x8x1 feature map.
	label_input = layers.Input(shape=(1,))
	label_map = layers.Embedding(n_class, 50)(label_input)
	label_map = layers.Dense(8 * 8)(label_map)
	label_map = layers.Reshape((8, 8, 1))(label_map)

	# Noise branch: latent vector -> 8x8x128 feature map.
	noise_input = layers.Input(shape=(noise_dim,))
	x = layers.Dense(128 * 8 * 8)(noise_input)
	x = layers.LeakyReLU(alpha=0.2)(x)
	x = layers.Reshape((8, 8, 128))(x)

	x = layers.Concatenate()([x, label_map])

	# Two stride-2 transposed convolutions: 8x8 -> 16x16 -> 32x32.
	for _ in range(2):
		x = layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(x)
		x = layers.LeakyReLU(alpha=0.2)(x)

	image_out = layers.Conv2D(3, (8, 8), activation='tanh', padding='same')(x)
	return Model([noise_input, label_input], image_out)

g_model = build_generator()
g_model.summary()

In [None]:
def build_discriminator():
	"""Build the conditional discriminator: [image, class label] -> probability the image is real.

	Inputs (in this order): an ``img_size`` x ``img_size`` x 3 image and an integer
	class id of shape (1,). The label is embedded and projected into one extra
	image-sized channel, so the convolutions see 4 input channels.
	"""
	layers = tf.keras.layers

	# Label branch: id -> 50-d embedding -> img_size*img_size units -> 1-channel map.
	label_input = layers.Input(shape=(1,))
	label_map = layers.Embedding(n_class, 50)(label_input)
	label_map = layers.Dense(img_size * img_size)(label_map)
	label_map = layers.Reshape((img_size, img_size, 1))(label_map)

	image_input = layers.Input(shape=(img_size, img_size, 3))
	x = layers.Concatenate()([image_input, label_map])

	# Two stride-2 convolutions, each halving the spatial size.
	for _ in range(2):
		x = layers.Conv2D(128, (3,3), strides=(2,2), padding='same')(x)
		x = layers.LeakyReLU(alpha=0.2)(x)

	x = layers.Flatten()(x)
	x = layers.Dropout(0.4)(x)
	validity = layers.Dense(1, activation='sigmoid')(x)
	return Model([image_input, label_input], validity)

d_model = build_discriminator()
d_model.summary()


In [None]:
def build_gan(generator, discriminator):
	"""Chain generator -> discriminator into the model used for generator updates.

	The same label input feeds both networks, so D judges each generated image
	against the class it was asked to depict.

	Side effect: sets ``discriminator.trainable = False`` so the returned model, once
	compiled, only updates the generator. NOTE(review): in tf.keras the trainable
	state at compile() time is what training uses, so ``discriminator`` itself must
	be compiled while trainable, or it will also be frozen for its own updates.

	Returns:
		Model mapping [noise, label] -> discriminator output for the generated image.
	"""
	discriminator.trainable = False
	noise_in, label_in = generator.input
	validity = discriminator([generator.output, label_in])
	return Model([noise_in, label_in], validity)

gan_model = build_gan(g_model, d_model)
gan_model.summary()

In [None]:
# Keras calls a compiled loss as loss(y_true, y_pred). The GAN helpers defined
# earlier do not fit that contract: generator_loss(preds) raises TypeError (two
# arguments given), and discriminator_loss(real, fake) would score the targets
# themselves and push every prediction towards 0. Binary cross-entropy against the
# 1/0 targets that train_gan passes to train_on_batch is the intended objective.
#
# build_gan() has already set d_model.trainable = False. In tf.keras 2.x the
# trainable state in effect at compile() is what train_on_batch honours, so compile
# D while trainable, then freeze it again before compiling the combined model so
# generator steps leave D's weights alone. (TODO: confirm the same holds on Keras 3.)
d_model.trainable = True
d_model.compile(loss='binary_crossentropy', optimizer=d_optimizer)
d_model.trainable = False
gan_model.compile(loss='binary_crossentropy', optimizer=g_optimizer)

# g_model is only used for inference (predict), which needs no compile; leaving it
# uncompiled keeps g_optimizer owned by gan_model alone.

In [None]:
def train_gan(dataset, epochs, batch_size):
	for epoch in range(epochs):
		for batch in dataset:
			real_images, labels = batch
			batch_size = real_images.shape[0]

			noise = np.random.normal(0, 1, (batch_size, noise_dim))
			gen_labels = np.random.randint(0, n_class, batch_size).reshape(-1, 1)
			gen_images = g_model.predict([noise, gen_labels])

			d_loss_real = d_model.train_on_batch([real_images, labels], np.ones((batch_size, 1)))
			d_loss_fake = d_model.train_on_batch([gen_images, gen_labels], np.zeros((batch_size, 1)))
			d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

			noise = np.random.normal(0, 1, (batch_size, noise_dim))
			valid_y = np.ones((batch_size, 1))
			g_loss = gan_model.train_on_batch([noise, gen_labels], valid_y)

		print(f"{epoch + 1}/{epochs}, D Loss: {d_loss}, G Loss: {g_loss}")
		if (epoch + 1) % 10 == 0:
			save_images(epoch, noise, gen_labels)

train_gan(dataset, epoch_count, batch_size)

In [None]:
def save_images(epoch, noise, labels):
	gen_images = g_model.predict([noise, labels])
	gen_images = 0.5 * gen_images + 0.5  # Rescale images to [0, 1]

	fig, axs = plt.subplots(4, 4, figsize=(4, 4))
	axs = axs.flatten()
	for img, label, ax in zip(gen_images, labels, axs):
		ax.imshow(img)
		ax.axis('off')
		ax.set_title(tags[label[0]])
	plt.show()