In [1]:
"""
BIOINFORMATICS: LAB07
@author: Irene Benedetto
"""


import tensorflow as tf
import numpy as np
from tqdm import tqdm
from models import *


def normalize(data, label):
    """Rescale an image tensor and append a trailing channel axis.

    Args:
        data: image tensor; it is cast to float32 and divided by 255.0.
        label: value paired with ``data``; passed through untouched.

    Returns:
        A ``(data, label)`` tuple where ``data`` is the rescaled float32
        tensor with one extra axis of size 1 added in the last position.
    """
    scaled = tf.cast(data, tf.float32) / 255.0
    return tf.expand_dims(scaled, axis=-1), label

In [2]:
# Load MNIST and merge the train and test splits into a single dataset that
# the training loop iterates over.
mnist = tf.keras.datasets.mnist

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

images = tf.concat([train_images, test_images], axis=0)
# Labels come from the label arrays (not the image arrays) so that each
# dataset element is a genuine (image, digit) pair.
labels = tf.concat([train_labels, test_labels], axis=0)
N_EPOCHS = 20
BATCH_SIZE = 64

train_ds = tf.data.Dataset.from_tensor_slices((images, labels))

# Scale pixels and add the trailing channel axis, then shuffle, batch and
# prefetch.
train_ds = train_ds.map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_ds = train_ds.shuffle(buffer_size=1024).batch(BATCH_SIZE).prefetch(BATCH_SIZE)

# Per-image shape without the batch axis, and the number of distinct labels
# found in the training split.
input_shape = list(train_images.shape)[1:]
output_shape = len(np.unique(train_labels))
latent_dim = 128

print(f'Training set: {input_shape}')
print(f'Number of classes: {output_shape}')



Training set: [28, 28]
Number of classes: 10
Training set: [28, 28]
Number of classes: 10


In [3]:
# Loss shared by the discriminator, the generator and the combined GAN.
loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# NOTE(review): the loss is binary and the discriminator is built with
# output_shape=1, yet the metrics are SparseCategoricalAccuracy. How the GAN
# class uses them is not visible here -- confirm they are meaningful.
d_metrics = tf.keras.metrics.SparseCategoricalAccuracy(name='discriminator_accuracy')
g_metrics = tf.keras.metrics.SparseCategoricalAccuracy(name='generator_accuracy')
metrics = tf.keras.metrics.SparseCategoricalAccuracy(name='gan_accuracy')

g_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0003)
d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0003)

# The discriminator input is the image shape plus the channel axis that
# `normalize` appends.
discriminator = Discriminator(input_shape=input_shape + [1], output_shape=1)
discriminator.compile(loss, d_optimizer, d_metrics)

generator = Generator(latent_dim=latent_dim)
generator.compile(loss, g_optimizer, g_metrics)
gan = GAN(generator=generator, discriminator=discriminator, latent_dim=latent_dim, BATCH_SIZE=BATCH_SIZE)

gan.compile(loss, d_optimizer, g_optimizer, metrics)

for epoch in range(N_EPOCHS):

    # The losses are accumulated over every batch, so the values printed
    # below are per-epoch sums, not per-batch averages.
    losses = {"discriminator_loss": 0, "generator_loss": 0}
    print(f'Epoch {epoch + 1}/{N_EPOCHS}')
    for x, y in tqdm(train_ds):
        history = gan.train_step((x, y))
        losses['discriminator_loss'] += history['discriminator_loss']
        losses['generator_loss'] += history['generator_loss']

    print(f"Discriminator loss: {losses['discriminator_loss']}")
    print(f"Generator loss: {losses['generator_loss']}")

    # Sample one latent vector and save the generated image so progress can
    # be inspected epoch by epoch (file names use the 0-based epoch index).
    random_latent_vectors = tf.random.normal(shape=(1, latent_dim))
    generated_images = gan.generator(random_latent_vectors)
    generated_images *= 255

    img = tf.keras.preprocessing.image.array_to_img(generated_images[0])
    img.save(f"synthetic_img_{epoch}.png")



  0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 1/20


100%|██████████| 1094/1094 [00:31<00:00, 34.55it/s]
  0%|          | 1/1094 [00:00<02:44,  6.64it/s]

Discriminator loss: 537.7638549804688
Generator loss: 1287.8514404296875
Epoch 2/20


100%|██████████| 1094/1094 [00:29<00:00, 36.67it/s]
  0%|          | 1/1094 [00:00<02:25,  7.54it/s]

Discriminator loss: 733.9061279296875
Generator loss: 906.4324340820312
Epoch 3/20


100%|██████████| 1094/1094 [00:30<00:00, 36.40it/s]
  0%|          | 1/1094 [00:00<02:19,  7.82it/s]

Discriminator loss: 751.3807373046875
Generator loss: 863.263916015625
Epoch 4/20


100%|██████████| 1094/1094 [00:30<00:00, 36.37it/s]
  0%|          | 1/1094 [00:00<02:04,  8.78it/s]

Discriminator loss: 744.859375
Generator loss: 866.505615234375
Epoch 5/20


100%|██████████| 1094/1094 [00:30<00:00, 36.24it/s]
  0%|          | 1/1094 [00:00<02:12,  8.25it/s]

Discriminator loss: 738.5790405273438
Generator loss: 897.722900390625
Epoch 6/20


100%|██████████| 1094/1094 [00:30<00:00, 36.45it/s]
  0%|          | 1/1094 [00:00<02:01,  9.00it/s]

Discriminator loss: 619.3616333007812
Generator loss: 1146.2047119140625
Epoch 7/20


100%|██████████| 1094/1094 [00:30<00:00, 35.75it/s]
  0%|          | 1/1094 [00:00<02:06,  8.67it/s]

Discriminator loss: 747.6897583007812
Generator loss: 963.7528686523438
Epoch 8/20


100%|██████████| 1094/1094 [00:30<00:00, 36.34it/s]
  0%|          | 1/1094 [00:00<02:14,  8.15it/s]

Discriminator loss: 777.7894287109375
Generator loss: 872.8348999023438
Epoch 9/20


100%|██████████| 1094/1094 [00:29<00:00, 36.59it/s]
  0%|          | 1/1094 [00:00<02:22,  7.67it/s]

Discriminator loss: 732.620361328125
Generator loss: 922.3572998046875
Epoch 10/20


100%|██████████| 1094/1094 [00:29<00:00, 36.72it/s]
  0%|          | 1/1094 [00:00<02:25,  7.51it/s]

Discriminator loss: 742.150390625
Generator loss: 925.9978637695312
Epoch 11/20


100%|██████████| 1094/1094 [00:29<00:00, 36.49it/s]
  0%|          | 1/1094 [00:00<02:14,  8.13it/s]

Discriminator loss: 751.8447875976562
Generator loss: 907.7006225585938
Epoch 12/20


100%|██████████| 1094/1094 [00:29<00:00, 36.62it/s]
  0%|          | 1/1094 [00:00<01:58,  9.24it/s]

Discriminator loss: 713.589111328125
Generator loss: 948.6929931640625
Epoch 13/20


100%|██████████| 1094/1094 [00:29<00:00, 36.52it/s]
  0%|          | 1/1094 [00:00<02:32,  7.18it/s]

Discriminator loss: 729.646240234375
Generator loss: 996.1788940429688
Epoch 14/20


100%|██████████| 1094/1094 [00:29<00:00, 36.68it/s]
  0%|          | 1/1094 [00:00<02:11,  8.28it/s]

Discriminator loss: 783.6210327148438
Generator loss: 860.6781616210938
Epoch 15/20


100%|██████████| 1094/1094 [00:29<00:00, 36.89it/s]
  0%|          | 1/1094 [00:00<02:25,  7.51it/s]

Discriminator loss: 657.26220703125
Generator loss: 1471.9725341796875
Epoch 16/20


100%|██████████| 1094/1094 [00:29<00:00, 36.47it/s]
  0%|          | 1/1094 [00:00<02:06,  8.62it/s]

Discriminator loss: 677.8019409179688
Generator loss: 1546.230712890625
Epoch 17/20


100%|██████████| 1094/1094 [00:30<00:00, 36.05it/s]
  0%|          | 1/1094 [00:00<02:22,  7.69it/s]

Discriminator loss: 695.947998046875
Generator loss: 1004.0736083984375
Epoch 18/20


100%|██████████| 1094/1094 [00:29<00:00, 36.91it/s]
  0%|          | 1/1094 [00:00<02:20,  7.78it/s]

Discriminator loss: 822.3173828125
Generator loss: 839.575927734375
Epoch 19/20


100%|██████████| 1094/1094 [00:30<00:00, 35.80it/s]
  0%|          | 1/1094 [00:00<02:09,  8.46it/s]

Discriminator loss: 751.70458984375
Generator loss: 902.0648803710938
Epoch 20/20


100%|██████████| 1094/1094 [00:29<00:00, 36.88it/s]

Discriminator loss: 743.4371337890625
Generator loss: 958.8997192382812



