In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
import time
import os
import matplotlib.pyplot as plt

from IPython import display
from PIL import Image
from tqdm import tqdm

from src.preprocessing.image_sampling import create_sample, downscale_images
from src.preprocessing.preprocessing import preprocess
from src.evaluation.losses import generator_loss, discriminator_loss, combined_metric
from src.misc.plotting import plot_loss
from src.misc.saving import generate_and_save_images, save_loss

from src.models.wgan_gp import create_generator, create_discriminator, WGAN_GP

from tensorflow.keras.models import load_model

In [None]:
# Number of images per training batch.
BATCH_SIZE = 32

# Input pipeline: list every file in the downscaled-image folder, run the
# project's `preprocess` on each path, then batch.
# NOTE(review): `tanh=True` presumably scales pixels to [-1, 1] (the preview
# cell maps them back with *0.5 + 0.5) — confirm in src.preprocessing.
# The map runs in parallel and finished batches are prefetched so that
# preprocessing overlaps with training; parallel map keeps deterministic
# element order by default, so the yielded batches are unchanged.
ds = (
    tf.data.Dataset.list_files('./data/10000_images_downscaled/*')
    .map(preprocess(tanh=True), num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Grab the first batch from the pipeline and keep its first image as a
# NumPy array for a quick visual sanity check.
sample = next(iter(ds))
sample_image = sample[0].numpy()

# Shift/scale the image before display.
# NOTE(review): assumes pixel values are in [-1, 1] so the result lands in
# the [0, 1] range imshow expects for floats — confirm against `preprocess`.
plt.title('Sample Image')
plt.imshow(0.5 + sample_image*0.5)

In [None]:
# Size of the latent input; passed to create_generator(input_shape=...) and
# WGAN_GP(latent_dim=...) in the cells below.
LATENT_DIM: int = 100

In [None]:
# Build the two networks: the generator takes a latent input of size
# LATENT_DIM, the discriminator takes 64x64 RGB images.
generator = create_generator(input_shape=LATENT_DIM)
discriminator = create_discriminator(input_shape=(64, 64, 3))

# Both networks use Adam with the same hyperparameters, so they are spelled
# out once and shared; each network still gets its own optimizer instance.
adam_settings = {
    "learning_rate": 1e-4,
    "beta_1": 0.5,
    "beta_2": 0.9,
}

generator_optimizer = tf.keras.optimizers.Adam(**adam_settings)
discriminator_optimizer = tf.keras.optimizers.Adam(**adam_settings)

In [None]:
# Wrap generator and discriminator in the project's WGAN_GP model.
wgan_gp = WGAN_GP(
    generator,
    discriminator,
    latent_dim=LATENT_DIM,
)

# Attach one optimizer and one loss function per network.
wgan_gp.compile(
    gen_optimizer=generator_optimizer,
    disc_optimizer=discriminator_optimizer,
    gen_loss_fn=generator_loss,
    disc_loss_fn=discriminator_loss,
)

In [None]:
# Directory that holds the WGAN-GP training checkpoints.
checkpoint_dir = './src/model_checkpoints/wgan_gp_checkpoints/'
# Bare file-name stem for checkpoint files. CheckpointManager joins this onto
# `directory` itself, so it must NOT include the directory part; passing the
# full prefix would nest the path a second time under `checkpoint_dir`.
checkpoint_name = "wgan_gp_ckpt"
# Full path prefix (directory + stem); same value as before, kept for any
# code that needs the complete prefix.
checkpoint_prefix = checkpoint_dir + checkpoint_name

# Track both networks and both optimizers so training can resume with the
# optimizer state intact.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

# Keep only the most recent checkpoint on disk.
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint,
    directory=checkpoint_dir,
    checkpoint_name=checkpoint_name,
    max_to_keep=1,
)

# Resume from the newest checkpoint recorded in `checkpoint_dir`, if any.
# When none exists, `latest_checkpoint` is None and restore() loads nothing,
# so a first run simply starts from freshly initialised weights.
checkpoint.restore(checkpoint_manager.latest_checkpoint)
    