In [59]:
# External
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
# Local
from CGAN import CGanNet

In [60]:
# Identifier handed to CGanNet as `net_name`
NET_NAME: str = "cgan_example"

# Values forwarded to CGanNet as `batch_size`, `learning_rate` and `dropout_rate`
BATCH_SIZE: int = 32
LEARNING_RATE: float = 2e-4
DROPOUT_RATE: float = 0.4

# Target image size in pixels: used when resizing the dataset images
# and forwarded to CGanNet as `image_width` / `image_height`
IMAGE_WIDTH: int = 64
IMAGE_HEIGHT: int = 64

In [61]:
# Fetch the 'oxford_flowers102' dataset through TensorFlow Datasets.
# `as_supervised=True` requests (image, label) pairs; `with_info=True`
# additionally returns the dataset metadata object.
# NOTE(review): the 'test' split is bound to `training_set` — confirm this is intentional.
(training_set, validation_set), dataset_info = tfds.load(
    'oxford_flowers102',
    as_supervised=True,
    with_info=True,
    split=['test', 'validation'],
)

# Number of distinct labels, read from the dataset metadata
num_classes = dataset_info.features['label'].num_classes
# Number of examples in the split that is used as the training set
dataset_size = dataset_info.splits['test'].num_examples

# Preprocessing images
def format_image(image, label):
    """Convert an image to float32, divide it by 255 and resize it.

    Args:
        image: Image tensor; assumed to hold pixel values in [0, 255],
            so the result lands in [0, 1] — TODO confirm against the dataset.
        label: Label belonging to the image; returned unchanged.

    Returns:
        Tuple ``(image, label)`` where ``image`` is a float32 tensor with
        spatial size ``(IMAGE_HEIGHT, IMAGE_WIDTH)``.
    """
    image = tf.cast(image, tf.float32) / 255.0
    # Alternative scaling to [-1, 1] (not used): (image - 127.5) / 127.5
    # applied to the un-divided float image.

    # tf.image.resize takes `size` as (new_height, new_width), in that order.
    image = tf.image.resize(image, (IMAGE_HEIGHT, IMAGE_WIDTH))

    return image, label
# Shuffle with a buffer of a quarter of the dataset, then preprocess each example
train_batches = training_set.shuffle(dataset_size // 4).map(format_image)

# Map integer label ids to their string names
get_label_name = dataset_info.features['label'].int2str
labels_strings = {label_id: get_label_name(label_id) for label_id in range(num_classes)}


def _collect_images_and_labels(dataset):
    """Iterate `dataset` exactly once and return (images, labels) as numpy arrays.

    A shuffled tf.data.Dataset is reshuffled on every iteration by default,
    so images and labels must be gathered in the same pass; two separate
    passes would yield differently ordered sequences and pair images with
    the wrong labels.
    """
    images, labels = [], []
    for image, label in dataset:
        images.append(image)
        labels.append(label)
    return np.array(images), np.array(labels)


# Create numpy arrays with images and labels (index i of both arrays
# refers to the same example)
numpy_train_batches_images, numpy_train_batches_labels = _collect_images_and_labels(train_batches)

In [62]:
# Instantiate CGanNet from the notebook configuration, the prepared numpy
# arrays and the label metadata
cgan = CGanNet(
    net_name=NET_NAME,
    batch_size=BATCH_SIZE,
    image_width=IMAGE_WIDTH,
    image_height=IMAGE_HEIGHT,
    learning_rate=LEARNING_RATE,
    dropout_rate=DROPOUT_RATE,
    number_of_channels=3,
    latent_dimension=100,
    training_data=numpy_train_batches_images,
    labels_data=numpy_train_batches_labels,
    number_of_classes=num_classes,
    labels_names=labels_strings,
)

# Define the GAN models: discriminator first, then generator, then the
# combined GAN (call order kept as-is; presumably define_gan relies on the
# other two — verify in CGAN.py)
cgan.define_discriminator()
cgan.define_generator()
cgan.define_gan()

In [None]:
# Load the model
# NOTE(review): presumably restores a previously saved model into `cgan`;
# what is loaded and from where is defined in CGAN.py — confirm.
cgan.load_model()

In [None]:
# Generate images
# NOTE(review): presumably displays generated sample images together with
# their label names — confirm in CGAN.py.
cgan.show_sample_images_with_labels()