In [7]:
import os
import tensorflow as tf
from gan.GAN import GanNet
import tensorflow_datasets as tfds
import numpy as np

In [8]:
# Shared settings for the data pipeline and the GAN.
BATCH_SIZE = 32
IMAGE_WIDTH, IMAGE_HEIGHT = 28, 28  # every image is resized to this resolution

In [9]:
# Reduce TensorFlow's C++ log noise (level 2 hides INFO and WARNING messages).
# NOTE(review): tensorflow is already imported at this point, so import-time
# messages are not affected; move this line above the imports cell if that noise
# should be silenced too.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# Applied to every visible GPU: indexing [0] crashes on CPU-only machines, and
# TensorFlow rejects setups where memory growth differs between GPUs.
physical_devices = tf.config.list_physical_devices("GPU")
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [10]:
# Load and prepare dataset.
# The 'test' split is used as the GAN's training data (the run log reports 6149
# examples for it, presumably chosen because it is the largest split); 'train'
# is bound to `validation_set`.
TRAINING_SPLIT = 'test'

(training_set, validation_set), dataset_info = tfds.load(
    'oxford_flowers102',
    split=[TRAINING_SPLIT, 'train'],
    with_info=True,
    as_supervised=True,
)

# Take the example count from the dataset metadata rather than iterating (and
# decoding) every image just to count them.
num_training_examples = dataset_info.splits[TRAINING_SPLIT].num_examples

def format_image(image, label):
    """Resize an image to the model's input size and divide its pixels by 255.

    Args:
        image: image tensor from the dataset (presumably uint8 pixels, so the
            division maps values into [0, 1] -- confirm against the dataset).
        label: class label, returned unchanged.

    Returns:
        ``(image, label)`` with the image resized to IMAGE_HEIGHT x IMAGE_WIDTH.
    """
    # tf.image.resize takes its target size as (new_height, new_width).
    image = tf.image.resize(image, (IMAGE_HEIGHT, IMAGE_WIDTH)) / 255.0
    return image, label

# Shuffle and preprocess the training split, then turn it into NumPy arrays.
train_batches = training_set.shuffle(num_training_examples // 4).map(format_image)

num_classes = dataset_info.features['label'].num_classes

# Iterate the dataset exactly once. `shuffle` reshuffles on every pass by
# default, so building images and labels from two separate passes would pair
# each image with a label from a different permutation.
train_examples = list(train_batches.as_numpy_iterator())
train_batches_images = np.array([image for image, _ in train_examples])
train_batches_labels = np.array([label for _, label in train_examples])
del train_examples  # the arrays above hold copies; drop the intermediate list

In [11]:
# Wrap the preprocessed image array in the project's GAN helper.
gan = GanNet(
    batch_size=BATCH_SIZE,
    batches_per_epoch=40,  # NOTE(review): 40 * 32 samples is well under the 6149 in the split -- confirm intended
    image_width=IMAGE_WIDTH,
    image_height=IMAGE_HEIGHT,
    number_of_channels=3,  # presumably RGB input -- matches the dataset images
    latent_dimension=100,
    training_data=train_batches_images,
)

In [None]:
#gan.clear_files_structure()

In [12]:
# Build the models, then the file layout GanNet writes to. Order is preserved
# (define_gan presumably combines the two previously defined sub-models).
for setup_step in (
    "define_discriminator",
    "define_generator",
    "define_gan",
    "create_files_structure",
):
    getattr(gan, setup_step)()

In [13]:
#gan.train(number_of_epochs=10, load_past_model=False)
#gan.plot_loss()

Dataset size: 6149
Batches per epoch: 40
----> Epoch: 0

D_real_loss: 0.4008719325065613 D_fake_loss: 0.8263919949531555 G_loss: 0.7210378646850586
D_real_acc: 1.0 D_fake_acc: 0.0
----> Epoch: 1

D_real_loss: 0.49151521921157837 D_fake_loss: 0.8079332113265991 G_loss: 0.6575926542282104
D_real_acc: 0.90625 D_fake_acc: 0.15625
----> Epoch: 2

D_real_loss: 0.785713791847229 D_fake_loss: 0.6918242573738098 G_loss: 0.8974084854125977
D_real_acc: 0.15625 D_fake_acc: 0.5625
----> Epoch: 3

D_real_loss: 0.5814696550369263 D_fake_loss: 0.5355560779571533 G_loss: 1.09686279296875
D_real_acc: 0.6875 D_fake_acc: 0.9375
----> Epoch: 4

D_real_loss: 1.0108349323272705 D_fake_loss: 0.4092980921268463 G_loss: 1.319187879562378
D_real_acc: 0.09375 D_fake_acc: 1.0
----> Epoch: 5

D_real_loss: 0.8394730091094971 D_fake_loss: 0.4008246660232544 G_loss: 1.400928258895874
D_real_acc: 0.3125 D_fake_acc: 1.0
----> Epoch: 6

D_real_loss: 0.548281729221344 D_fake_loss: 0.9264053106307983 G_loss: 0.700375795364



In [None]:
# Resume from the last saved epoch (the log shows it being read from `.epoch`)
# and keep training; whether 100 is a total or an additional epoch count depends
# on GanNet.train -- TODO confirm. Then plot the loss history.
gan.train(load_past_model=True, number_of_epochs=100)
gan.plot_loss()

Dataset size: 6149
Batches per epoch: 40
----> Load epoch number: 9 from file .epoch
----> Epoch: 10