In [None]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as K
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.keras.callbacks import keras_model_summary
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D, MaxPool2D, Input, Dense, Flatten, Dropout, Concatenate, Layer, LeakyReLU, Reshape, AveragePooling2D, Add

import numpy as np

import matplotlib.pyplot as plt

import cv2

import os
import datetime
from time import perf_counter, sleep
import threading
import traceback

from random import sample

from functools import partial

In [None]:
from progressive_gan import ProgressiveGAN
from utils import ImageGenerator, TensorBoardCallback

In [None]:
# Training data. Set the CATS_TRAIN_DIR environment variable to point at a different
# dataset instead of editing this cell; the default is the original author's local path.
# (Another dataset used previously: r'E:\Workspace\datasets\b\train_1\512')
DATASET_DIR = os.environ.get('CATS_TRAIN_DIR', r'E:\Workspace\datasets\cats\train')
BATCH_SIZE = 16
IMAGE_CHANNELS = 3  # must match image_channels passed to ProgressiveGAN below

# Fail fast with a readable message rather than deep inside the data loader.
if not os.path.isdir(DATASET_DIR):
    raise FileNotFoundError(
        f'Dataset directory not found: {DATASET_DIR!r} (set CATS_TRAIN_DIR to override)')

img_gen = ImageGenerator(DATASET_DIR, batch_size=BATCH_SIZE, image_channels=IMAGE_CHANNELS)

In [None]:
# Model / training hyperparameters.
LATENT_DIM = 128
INITIAL_IMAGE_SIZE = 4
FINAL_IMAGE_SIZE = 128
LEARNING_RATE = 1e-4
# NOTE(review): presumably one entry per growth stage (4, 8, 16, 32, 64, 128 -> 6 entries);
# confirm against ProgressiveGAN.fit before changing the resolutions.
EPOCHS_PER_STEP = [1000, 2000, 3000, 6000, 8000, 10000]

# The discriminator and the combined GAN each get their OWN optimizer instance.
# Sharing one instance couples their iteration counter and slot (accumulator) state,
# and newer Keras optimizers can refuse to update a variable set they were not built for.
# Alternative tried earlier: Adam(learning_rate=5e-6, beta_1=0., beta_2=.99, epsilon=1e-8)
discriminator_optimizer = keras.optimizers.RMSprop(learning_rate=LEARNING_RATE)
gan_optimizer = keras.optimizers.RMSprop(learning_rate=LEARNING_RATE)

progan = ProgressiveGAN(
    latent_dim=LATENT_DIM,
    initial_image_size=INITIAL_IMAGE_SIZE,
    final_image_size=FINAL_IMAGE_SIZE,
    image_channels=3,  # must match the ImageGenerator's image_channels
    discriminator_optimizer=discriminator_optimizer,
    gan_optimizer=gan_optimizer)

tensorboard_callback = TensorBoardCallback(
    './logs', progan, image_generator_preview_save_interval=100, use_tensorboard=False)

progan.fit(
    img_gen,
    epochs_per_step=EPOCHS_PER_STEP,
    discriminator_train_per_gan_train=1,
    tensorboard_callback=tensorboard_callback)

In [None]:
# batch = np.concatenate([img_gen.get_batch(), img_gen.get_batch()], axis=0)
# print(batch.shape)
# print(batch.min())
# print(batch.max())

# plt.figure(figsize=(16, 8))

# plt.imshow(np.vstack([np.hstack([batch[i + 8*j] for i in range(8)]) for j in range(4)])/2 + .5)

# plt.show()

In [None]:

# Preview: push the same latent vectors through every generator stage, so each row
# of the grid shows one stage's rendering of the same N_PREVIEW samples.
N_PREVIEW = 8       # samples per row
PREVIEW_SIZE = 32   # every stage's output is resized (nearest-neighbour) to this square size

latent_noise = progan.sample_latent_space(N_PREVIEW)
n_stages = len(progan.generator)

# NOTE(review): assumes 3-channel generator output; cv2.resize drops a trailing channel
# axis of size 1, so grayscale output would need reshaping before assignment.
generated_images = np.zeros((N_PREVIEW * n_stages, PREVIEW_SIZE, PREVIEW_SIZE, 3))
for i in range(n_stages):
    g = progan.generator[i][0].predict(latent_noise)
    for j in range(N_PREVIEW):
        # Map generator output (presumably in [-1, 1]) to [0, 1]. This is the only
        # rescale; do not rescale again at display time.
        img = (g[j] + 1.) / 2.
        generated_images[N_PREVIEW * i + j] = cv2.resize(
            img, (PREVIEW_SIZE, PREVIEW_SIZE), interpolation=cv2.INTER_NEAREST)

# Unclipped range, useful to spot a generator overshooting [0, 1].
print(generated_images.shape)
print(generated_images.min())
print(generated_images.max())

# One row per generator stage, N_PREVIEW samples per row.
grid = np.vstack([
    np.hstack([generated_images[i + N_PREVIEW * j] for i in range(N_PREVIEW)])
    for j in range(n_stages)
])

plt.figure(figsize=(16, 2 * n_stages))
plt.title('Generator samples per growth stage (one row per stage)')
# Values are already in [0, 1]; clip any overshoot so imshow does not warn/clamp silently.
# 'nearest' keeps low-resolution stages blocky; interpolation=None would fall back to the
# rcParams default and smooth them.
plt.imshow(np.clip(grid, 0., 1.), interpolation='nearest')
plt.axis('off')
plt.show()

In [None]:
# Persist the model from the last generator stage (presumably the final-resolution one).
final_generator = progan.generator[-1][0]
final_generator.save('./model/generator_cats.h5')