In [1]:
import os
import matplotlib.pyplot as plt
import tensorflow as tf
from gan.GAN import GanNet
import tensorflow_datasets as tfds
import numpy as np

In [2]:
# Notebook configuration — values a reader may want to tune.
NET_NAME = 'alfa'  # name handed to GanNet
BATCH_SIZE = 32    # images per training batch
# Target spatial size every image is resized to before training.
IMAGE_WIDTH, IMAGE_HEIGHT = 32, 32

In [3]:
# Quieten TensorFlow's C++ logging (level "2" hides INFO and WARNING messages).
# NOTE(review): TensorFlow generally reads this variable when it is first
# imported, and `import tensorflow` already ran in the first cell, so this line
# likely has little or no effect here — move it above the import to be sure.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up
# front. Applied to every visible GPU (TensorFlow expects the memory-growth
# setting to be consistent across GPUs), and skipped entirely on CPU-only
# machines, where indexing [0] into the empty list used to raise IndexError.
physical_devices = tf.config.list_physical_devices("GPU")
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

In [None]:
# Load Oxford Flowers 102 as (image, label) tuples together with its metadata.
# NOTE(review): the 'test' split is used as the *training* set — presumably
# because it is the largest split of this dataset; confirm this is intended.
# `validation_set` is loaded but not used further in this notebook.
loaded_splits, dataset_info = tfds.load(
    'oxford_flowers102',
    split=['test', 'validation'],
    as_supervised=True,
    with_info=True,
)
training_set, validation_set = loaded_splits

# Size of the split used for training, plus the number of flower classes.
dataset_size = dataset_info.splits['test'].num_examples
print(f'Dataset size is: {dataset_size}')

label_feature = dataset_info.features['label']
num_classes = label_feature.num_classes
print(f'Number of different images class labels: {num_classes}')

In [5]:
# Preprocessing images
def format_image(image, label):
    """Scale an image to float32 in [0, 1] and resize it to the training size.

    Args:
        image: image tensor from the dataset; presumably uint8 pixels in
            0..255 as produced by TFDS (confirm).
        label: class label, passed through unchanged.

    Returns:
        Tuple ``(image, label)`` where ``image`` is float32 of spatial size
        (IMAGE_HEIGHT, IMAGE_WIDTH).
    """
    image = tf.cast(image, tf.float32) / 255.0
    # NOTE(review): pixels are scaled to [0, 1]. If GanNet's generator emits
    # values in [-1, 1] (e.g. a tanh output), scale with
    # (image - 127.5) / 127.5 instead — confirm against gan.GAN.
    # tf.image.resize expects size as (new_height, new_width).
    image = tf.image.resize(image, (IMAGE_HEIGHT, IMAGE_WIDTH))
    return image, label


train_batches = training_set.shuffle(dataset_size // 4).map(format_image)

get_label_name = dataset_info.features['label'].int2str
labels_strings = {label_id: get_label_name(label_id) for label_id in range(num_classes)}

# Materialise the dataset in a SINGLE pass. `shuffle` reshuffles on every
# iteration by default, so iterating once for images and again for labels
# would yield two different orders and pair images with the wrong labels.
train_image_list = []
train_label_list = []
for image_tensor, label_tensor in train_batches:
    train_image_list.append(image_tensor.numpy())
    train_label_list.append(label_tensor.numpy())

numpy_train_batches_images = np.array(train_image_list)
numpy_train_batches_labels = np.array(train_label_list)

In [None]:
# Show a grid of training images titled "(class id): class name" as a visual
# sanity check of preprocessing and image/label pairing.
rows = 5
cols = 5

fig, axs = plt.subplots(rows, cols, figsize=(15, 15))

# axs.flat walks the grid row by row, matching the image index order.
for cnt, ax in enumerate(axs.flat):
    ax.axis('off')
    # Leave trailing cells blank if there are fewer images than grid slots
    # (previously this raised IndexError).
    if cnt >= len(numpy_train_batches_images):
        continue
    # imshow expects float RGB data in [0, 1]; clip guards against stray values.
    image = np.clip(numpy_train_batches_images[cnt], 0, 1)
    label = numpy_train_batches_labels[cnt]
    label_str = labels_strings[int(label)]
    ax.set_title(f'({label}): {label_str}')
    # cmap only matters if the image squeezes down to a single channel;
    # it is ignored for RGB data.
    ax.imshow(np.squeeze(image), cmap='gray')

fig.set_facecolor('white')
plt.show()

In [7]:
# Wrap the preprocessed training images in the project's GAN helper.
# NOTE(review): batches_per_epoch is hard-coded to 40, i.e. 40 * BATCH_SIZE
# images per "epoch", independent of len(numpy_train_batches_images) —
# confirm how GanNet draws its batches.
gan = GanNet(
    net_name=NET_NAME,
    batch_size=BATCH_SIZE,
    batches_per_epoch=40,
    image_width=IMAGE_WIDTH,
    image_height=IMAGE_HEIGHT,
    number_of_channels=3,    # RGB input
    latent_dimension=100,    # size of the generator's noise vector (presumably)
    training_data=numpy_train_batches_images,
)

In [8]:
# Build the sub-models in the original order: discriminator, generator, and
# finally the combined GAN (presumably built from the first two — confirm).
for build_step in (gan.define_discriminator, gan.define_generator, gan.define_gan):
    build_step()

In [None]:
# First training run: 10 epochs without loading any previously saved model,
# then plot the recorded losses.
gan.train(
    number_of_epochs=10,
    load_past_model=False,
)
gan.plot_loss()

In [None]:
# Second training run: 10 more epochs with load_past_model=True, which
# presumably restores the model saved by the previous run (confirm in GanNet),
# then re-plot the losses.
gan.train(
    number_of_epochs=10,
    load_past_model=True,
)
gan.plot_loss()

In [11]:
# Visualise the model after training; what is shown is defined by
# GanNet.visualize_model (not visible in this notebook).
gan.visualize_model()