In [7]:
import os
import tensorflow as tf
from gan.GAN import GanNet
import tensorflow_datasets as tfds
import numpy as np

In [8]:
# Batch size handed to GanNet below.
BATCH_SIZE = 32
# Every image is resized to this square spatial size before training.
IMAGE_WIDTH = IMAGE_HEIGHT = 28

In [9]:
# Quiet TensorFlow's C++ logging (level "2" filters out INFO and WARNING messages).
# NOTE(review): this variable is normally set *before* `import tensorflow`; set here,
# messages already emitted at import time are not suppressed.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# Loop over every GPU:
#   - On a CPU-only machine the list is empty and this is a no-op
#     (indexing [0] raised IndexError there).
#   - TF expects the memory-growth setting to be the same on all GPUs.
physical_devices = tf.config.list_physical_devices("GPU")
for device in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as error:
        # Raised once the GPUs have been initialised, e.g. when this cell is re-run
        # in a live kernel; the configuration already in effect is kept.
        print(f"Could not enable memory growth on {device.name}: {error}")

In [10]:
# Load and prepare dataset

# Split assignment is swapped on purpose (presumably because 'test' is the larger
# split of oxford_flowers102 — confirm):
#   - the 'test' split becomes `training_set`
#   - the 'train' split becomes `validation_set`
# `as_supervised=True` makes each element an (image, label) pair.
(training_set, validation_set), dataset_info = tfds.load(
    name='oxford_flowers102',
    split=['test', 'train'],
    as_supervised=True,
    with_info=True,
)

# Number of examples in the training split; used below to size the shuffle buffer.
# Use the dataset's statically known cardinality when available. This avoids
# decoding every image just to count them. Fall back to a full counting pass
# only when the cardinality is negative (tf.data UNKNOWN = -2 / INFINITE = -1).
num_training_examples = int(training_set.cardinality())
if num_training_examples < 0:
    num_training_examples = sum(1 for _ in training_set)

def format_image(image, label):
    """Resize an image to the training size and scale its pixel values by 1/255.

    Args:
        image: image tensor from the dataset (presumably uint8 in [0, 255] — confirm).
        label: class label, passed through unchanged.

    Returns:
        ``(resized_and_scaled_image, label)``.
    """
    # tf.image.resize expects the target size as (height, width). The previous
    # (width, height) order only worked because both dimensions are 28.
    resized = tf.image.resize(image, (IMAGE_HEIGHT, IMAGE_WIDTH))
    return resized / 255.0, label

# Shuffle with a buffer of a quarter of the split, then resize/scale each image.
# NOTE: despite the name, this dataset is not batched.
train_batches = training_set.shuffle(num_training_examples // 4).map(format_image)

num_classes = dataset_info.features['label'].num_classes

# Materialise images and labels in a SINGLE pass. `shuffle()` reshuffles on every
# iteration by default (reshuffle_each_iteration=True), so iterating twice, once
# for images and once for labels, paired images with the wrong labels.
collected_images, collected_labels = [], []
for image, label in train_batches:
    collected_images.append(image.numpy())
    collected_labels.append(label.numpy())
train_batches_images = np.array(collected_images)
train_batches_labels = np.array(collected_labels)

In [11]:
# Build the GAN wrapper around the preprocessed training images.
gan = GanNet(
    training_data=train_batches_images,
    image_height=IMAGE_HEIGHT,
    image_width=IMAGE_WIDTH,
    number_of_channels=3,  # presumably RGB images — confirm matches train_batches_images
    latent_dimension=100,
    batch_size=BATCH_SIZE,
    batches_per_epoch=40,
)

In [None]:
#gan.clear_files_structure()

In [12]:
# Define the discriminator, the generator and the combined GAN model, in that order.
# Then call create_files_structure().
# NOTE(review): define_gan presumably combines the two networks defined just before
# it, so keep this order. create_files_structure presumably prepares output
# directories/files — confirm in gan.GAN.
gan.define_discriminator()
gan.define_generator()
gan.define_gan()
gan.create_files_structure()

In [None]:
#gan.train(number_of_epochs=10, load_past_model=False)
#gan.plot_loss()

In [None]:
# Train for 100 epochs, then plot the recorded losses.
# NOTE(review): load_past_model=True presumably resumes from a previously saved model.
# On a fresh checkout there may be nothing to resume from; the commented-out cell with
# load_past_model=False looks like the initial run — confirm.
gan.train(number_of_epochs=100, load_past_model=True)
gan.plot_loss()