In [None]:
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.backend as K
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.keras.callbacks import keras_model_summary
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D, MaxPool2D, Input, Dense, Flatten, Dropout, Concatenate, Layer, LeakyReLU, Reshape, AveragePooling2D, Add

import numpy as np

import matplotlib.pyplot as plt

import cv2

import os
import datetime
from time import perf_counter, sleep
import threading

from random import sample

from functools import partial

In [None]:
class WeightedSum(Add):
    '''
    Fade-in merge layer: returns alpha * inputs[0] + (1 - alpha) * inputs[1].

    `alpha` lives in a backend variable so it can be changed while training with
    `K.set_value(layer.alpha, value)` without rebuilding the model.

    Args:
        alpha: initial weight of the first input (default 0.0, i.e. the output equals inputs[1]).
    '''
    def __init__(self, alpha=0.0, **kwargs):
        super(WeightedSum, self).__init__(**kwargs)
        self.alpha = K.variable(alpha, name='ws_alpha')

    # output a weighted sum of inputs
    def _merge_function(self, inputs):
        '''Blend the two inputs using the current value of `alpha`.'''
        output = (self.alpha*inputs[0]) + ((1.0 - self.alpha)*inputs[1])
        return output

    def get_config(self):
        '''
        Layer config including the current `alpha`.

        Keras' default Layer.get_config raises NotImplementedError for layers whose __init__
        takes extra arguments (here `alpha`), so without this override models containing a
        WeightedSum could not be serialized (saved / cloned).
        '''
        config = super(WeightedSum, self).get_config()
        config['alpha'] = float(K.get_value(self.alpha))
        return config

In [None]:
class PixelNormalization(Layer):
    '''
    Pixel-wise feature vector normalization.

    Each feature vector along the last axis is divided by its root-mean-square value;
    1e-8 is added to the mean square before the square root to avoid division by zero.
    '''
    def __init__(self, **kwargs):
        super(PixelNormalization, self).__init__(**kwargs)

    def call(self, inputs):
        # mean square over the last axis, kept as a size-1 axis so it broadcasts back
        mean_square = K.mean(inputs**2.0, axis=-1, keepdims=True) + 1.0e-8
        root_mean_square = K.sqrt(mean_square)
        return inputs / root_mean_square

    def compute_output_shape(self, input_shape):
        # normalization leaves the shape unchanged
        return input_shape

In [None]:
class MinibatchStdev(Layer):
    '''
    Minibatch standard deviation layer.

    Computes the standard deviation of every feature across the batch (axis 0), averages it
    into one scalar, and appends that scalar as one extra constant channel on the last axis.
    '''
    def __init__(self, **kwargs):
        super(MinibatchStdev, self).__init__(**kwargs)

    def call(self, inputs):
        batch_mean = K.mean(inputs, axis=0, keepdims=True)
        # per-feature variance across the batch, with a small stabilizer before the sqrt
        variance = K.mean(K.square(inputs - batch_mean), axis=0, keepdims=True) + 1e-8
        average_stdev = K.mean(K.sqrt(variance), keepdims=True)
        # assumes rank-4 (NHWC) inputs: repeat the scalar into a (dim0, dim1, dim2, 1) map
        input_dims = K.shape(inputs)
        stdev_map = K.tile(average_stdev, [input_dims[0], input_dims[1], input_dims[2], 1])
        return K.concatenate([inputs, stdev_map], axis=-1)

    def compute_output_shape(self, input_shape):
        # one extra channel on the last axis
        output_shape = list(input_shape)
        output_shape[-1] = output_shape[-1] + 1
        return tuple(output_shape)

In [None]:
# TODO add __len__
# TODO shuffle inputs and get them batch by batch instead of random sample
class ImageGenerator(object):
    '''
    Serves random batches of square, center-cropped JPEG images scaled to [-1, 1].

    While the caller consumes one batch, the next one is prepared on a background thread
    (two cache "banks" alternate). A prepared batch whose size no longer matches the size set
    with `set_images_size` is discarded, and get_batch waits for a fresh one.

    Args:
        images_folder_path: folder whose top-level `.jpg` / `.jpeg` files (any letter case)
            are used; sub-folders are not scanned.
        initial_images_size: side length, in pixels, of the returned images.
        batch_size: number of images per batch; must not exceed the number of files found.
        image_channels: number of leading channels of each RGB image to keep.
    '''
    def __init__(self, images_folder_path, initial_images_size=4, batch_size=32, image_channels=3):
        self.__images_folder_path = images_folder_path
        self.__images_size = initial_images_size
        self.__batch_size = batch_size
        self.__image_channels = image_channels

        self.__filenames = []

        # Double-buffer state. Both banks start "ready" but empty (size 0), so the first
        # get_batch() call discards bank 0 and waits for the first real prefetch.
        self.__cached_bank = 0
        self.__cached_batch = [None, None]
        self.__cached_size = [0, 0]
        self.__cached_ready = [True, True]
        # exception raised while preparing each bank; re-raised by get_batch()
        self.__cached_error = [None, None]

        for _, _, fnames in os.walk(self.__images_folder_path):
            for fname in fnames:
                # case-insensitive so that e.g. 'IMG_0001.JPG' is picked up as well
                if os.path.splitext(fname)[1].lower() in ('.jpg', '.jpeg'):
                    self.__filenames.append(fname)
            # only the top-level folder is scanned
            break

        print(f'Loaded {len(self.__filenames)} images.')

    @property
    def batch_size(self):
        return self.__batch_size

    def set_images_size(self, size):
        '''Set the side length of subsequently returned images (already prefetched batches of another size are dropped).'''
        self.__images_size = size

    def get_batch(self):
        '''
        Return the next batch: an array of shape (batch_size, size, size, image_channels)
        with values in [-1, 1], where size is the value last passed to set_images_size.

        Raises:
            Whatever exception occurred while the batch was being prepared (e.g. an unreadable
            image, or fewer images in the folder than batch_size) instead of waiting forever.
        '''
        while True:
            bank = self.__cached_bank
            while not self.__cached_ready[bank]:
                sleep(.01)

            result = self.__cached_batch[bank]
            result_size = self.__cached_size[bank]
            error = self.__cached_error[bank]

            # start preparing the next batch in the other bank while this one is consumed
            next_bank = bank ^ 1
            self.__cached_bank = next_bank
            self.__cached_ready[next_bank] = False
            self.__cached_error[next_bank] = None
            prepare_thread = threading.Thread(target=self.__prepare_cached_batch, args=(next_bank,), daemon=True)
            prepare_thread.start()

            if error is not None:
                raise error

            if result_size == self.__images_size:
                return result
            # wrong size (prepared before set_images_size was called): wait for the fresh batch

    def __prepare_cached_batch(self, bank):
        '''
        Fill cache bank `bank` with a new random batch; runs on a background thread.

        The bank index is passed in (instead of being read from self when the work finishes)
        so the result always lands in the bank get_batch() is waiting on. Any exception is
        stored for get_batch() to re-raise, and the bank is always marked ready so the caller
        can never spin forever.
        '''
        img_size = self.__images_size
        try:
            result = np.zeros((self.__batch_size, img_size, img_size, self.__image_channels))

            fnames = sample(self.__filenames, self.__batch_size)

            for i, fname in enumerate(fnames):
                path = os.path.join(self.__images_folder_path, fname)
                img = cv2.imread(path)
                if img is None:
                    # cv2.imread returns None instead of raising for unreadable / corrupt files
                    raise IOError(f'Could not read image: {path}')
                # OpenCV loads BGR -> reorder to RGB
                img = img[:,:,::-1]
                # central square crop
                min_size = min(img.shape[:2])
                img = img[(img.shape[0] - min_size)//2:(img.shape[0] + min_size)//2,
                          (img.shape[1] - min_size)//2:(img.shape[1] + min_size)//2]
                img = cv2.resize(img, (img_size,)*2).astype(np.float32)
                # per-image min/max stretch to [0, 1]
                img -= img.min()
                img /= (img.max() + 1e-9)
                # [0, 1] -> [-1, 1]
                img *= 2.
                img -= 1.
                if len(img.shape) == 2:
                    img = img[:,:,np.newaxis]
                result[i,] = img[:,:,:self.__image_channels]

            self.__cached_batch[bank] = result
            self.__cached_size[bank] = img_size
        except Exception as error:
            self.__cached_batch[bank] = None
            self.__cached_size[bank] = 0
            self.__cached_error[bank] = error
        finally:
            self.__cached_ready[bank] = True
        

In [None]:
def wasserstein_loss(y_true, y_pred):
    '''
    Wasserstein-style loss: the mean of the element-wise product of labels and critic scores.

    Minimizing it lowers the scores of samples labelled +1 and raises the scores of samples
    labelled -1.
    '''
    label_weighted_scores = y_true * y_pred
    return K.mean(label_weighted_scores)

In [None]:
# TODO add gradient penalty
# TODO add progress bar
# TODO change epoch to use all data
from email.mime import image


class ProgressiveGAN(object):
    '''
    Progressive-growing GAN built from Keras models.

    One "step" is created per image size initial_image_size, 2*initial_image_size, ... up to
    final_image_size. For every step a generator, a discriminator and a stacked GAN
    (generator followed by the discriminator, frozen) are built, each as a pair
    (straight-through model, fade-in model); for the first step the generator and
    discriminator pairs hold the same model twice. Each later step is built on the layers of
    the previous step's straight-through models, so weights carry over as resolution grows.

    Args:
        latent_dim: size of the latent vectors; also the filter count of the first layers.
        initial_image_size: side length of the images of the first step.
        final_image_size: upper bound for the side length of the last step.
        image_channels: channels of the generated / discriminated images.
        gan_optimizer: optimizer name or instance for the stacked GAN models.
        discriminator_optimizer: optimizer name or instance for the discriminators.
            Optimizer instances are copied (same config, fresh state) for every compiled
            model, just as a name produces a new optimizer on every compile.
    '''
    __latent_dim                : int
    __initial_image_size        : int
    __final_image_size          : int
    __image_channels            : int
    __gan_optimizer             : (str | keras.optimizers.Optimizer)
    __discriminator_optimizer   : (str | keras.optimizers.Optimizer)

    __steps                     : list[int]

    __generator                 : keras.Model
    __discriminator             : keras.Model

    __gan                       : keras.Model

    def __init__(self, latent_dim : int =128, initial_image_size : int =4, final_image_size : int =512, image_channels : int =3, gan_optimizer : (str | keras.optimizers.Optimizer) ='adam', discriminator_optimizer : (str | keras.optimizers.Optimizer) ='adam'):
        self.__latent_dim = latent_dim
        self.__initial_image_size = initial_image_size
        self.__final_image_size = final_image_size
        self.__image_channels = image_channels
        self.__gan_optimizer = gan_optimizer
        self.__discriminator_optimizer = discriminator_optimizer

        # image sizes of all steps: initial, 2*initial, 4*initial, ... (<= final)
        self.__steps = []

        image_size = initial_image_size
        while image_size <= final_image_size:
            self.__steps.append(image_size)
            image_size <<= 1

        self.__total_epochs = 0

        self.__generator = None
        self.__discriminator = None
        self.__gan = None

        self.__init_generator()
        self.__init_discriminator()
        self.__init_gan()

        self.__timer = perf_counter()

    @property
    def generator(self):
        '''Per-step list of (straight-through, fade-in) generator models.'''
        return self.__generator

    @property
    def discriminator(self):
        '''Per-step list of (straight-through, fade-in) discriminator models.'''
        return self.__discriminator

    @property
    def gan(self):
        '''Per-step list of (straight-through, fade-in) stacked generator+discriminator models.'''
        return self.__gan

    def sample_latent_space(self, n : int) -> np.ndarray:
        '''
        Draw `n` latent vectors uniformly from the unit hypersphere.

        Returns:
            Array of shape (n, 1, 1, latent_dim) whose vectors have unit L2 norm.
        '''
        normal_sample = np.random.normal(size=(n, 1, 1, self.__latent_dim))

        # normalizing isotropic Gaussian samples gives points uniform on the sphere
        return normal_sample/np.sqrt((normal_sample**2).sum(axis=3))[:,:,:,np.newaxis]

    @staticmethod
    def __fresh_optimizer(optimizer):
        '''
        Return an optimizer for a single `compile` call.

        Names (e.g. 'adam') are returned unchanged; Keras creates a new optimizer from them on
        every compile. Optimizer instances are cloned from their config so every compiled
        model gets its own optimizer state as well: one instance shared by models with
        different variables mixes their state, and newer Keras optimizers reject it.
        '''
        if isinstance(optimizer, str):
            return optimizer
        return optimizer.__class__.from_config(optimizer.get_config())

    def __train_models(self, step : int, fade : bool, image_generator : ImageGenerator, epochs_per_step : int =32, discriminator_train_per_gan_train : int =5, tensorboard_callback=None):
        '''
        Train the models of one step in one phase (fade-in or straight-through).

        Every "epoch" trains the discriminator `discriminator_train_per_gan_train` times on a
        batch of generated images (label -1) and a batch of real images (label +1), then trains
        the generator once through the stacked GAN with label +1. All labels get Gaussian noise
        with standard deviation 0.1. In the fade-in phase the alpha of every WeightedSum layer
        is set to epoch / epochs_per_step before each epoch.

        Args:
            step: index into the list of step image sizes.
            fade: train the fade-in models instead of the straight-through ones.
            image_generator: source of real images, already set to this step's image size.
            epochs_per_step: number of epochs to run.
            discriminator_train_per_gan_train: discriminator updates per generator update.
            tensorboard_callback: optional TensorBoardCallback notified after every epoch.
        '''
        generator       = self.__generator[step][int(fade)]
        discriminator   = self.__discriminator[step][int(fade)]
        gan             = self.__gan[step][int(fade)]

        d_loss_total = .0
        g_loss_total = .0

        # start timing before the first epoch so the elapsed time / ETA cover every epoch
        self.__timer = perf_counter()

        for epoch in range(epochs_per_step):
            # adjust fade in parameter
            if fade:
                for model in (generator, discriminator, gan):
                    for layer in model.layers:
                        if isinstance(layer, WeightedSum):
                            K.set_value(layer.alpha, epoch/epochs_per_step)

            # train discriminator
            d_loss_generated = 0.
            d_loss_real = 0.

            d_accuracy_generated = 0.
            d_accuracy_real = 0.
            for _ in range(discriminator_train_per_gan_train):
                latent_noise = self.sample_latent_space(image_generator.batch_size)

                # predict_on_batch: predict() sets up a whole inference loop per call and is
                # not intended for per-batch use inside a training loop
                generated_images = generator.predict_on_batch(latent_noise)
                real_images = image_generator.get_batch()

                generated_labels = -1. * np.ones((image_generator.batch_size, 1))
                real_labels = np.ones((image_generator.batch_size, 1))

                # label noise
                generated_labels += .1 * np.random.normal(0, 1, generated_labels.shape)
                real_labels += .1 * np.random.normal(0, 1, real_labels.shape)

                loss, accuracy = discriminator.train_on_batch(generated_images, generated_labels)
                d_loss_generated += loss
                d_accuracy_generated += accuracy

                loss, accuracy = discriminator.train_on_batch(real_images, real_labels)
                d_loss_real += loss
                d_accuracy_real += accuracy

            d_loss_generated /= discriminator_train_per_gan_train
            d_loss_real /= discriminator_train_per_gan_train

            d_accuracy_generated /= discriminator_train_per_gan_train
            d_accuracy_real /= discriminator_train_per_gan_train

            d_loss = (d_loss_generated + d_loss_real)/2
            d_accuracy = (d_accuracy_generated + d_accuracy_real)/2

            # train generator
            latent_noise = self.sample_latent_space(image_generator.batch_size)

            misleading_labels = np.ones((image_generator.batch_size, 1))
            misleading_labels += .1 * np.random.normal(0, 1, misleading_labels.shape)

            g_loss, g_accuracy = gan.train_on_batch(latent_noise, misleading_labels)

            d_loss_total += d_loss
            g_loss_total += g_loss

            # the last line of a phase shows the phase averages instead of the last epoch's losses
            if epoch + 1 < epochs_per_step:
                self.__print_fit_progress(self.__steps[step], step, fade, epoch + 1, epochs_per_step, d_loss, g_loss)
            else:
                self.__print_fit_progress(self.__steps[step], step, fade, epoch + 1, epochs_per_step, d_loss_total/epochs_per_step, g_loss_total/epochs_per_step)

            if tensorboard_callback is not None:
                tensorboard_callback.on_epoch_end(
                    epoch, step, fade,
                    {'loss': {      'd_loss' : d_loss,
                                    'g_loss': g_loss},
                     'accuracy': {  'd_accuracy' : d_accuracy,
                                    'g_accuracy': g_accuracy}})

            self.__total_epochs += 1

    def fit(self, image_generator : ImageGenerator, epochs_per_step : (int | list) =32, discriminator_train_per_gan_train=5, tensorboard_callback=None):
        '''
        Train all steps in order: the first step straight-through only, every later step
        first in the fade-in phase and then straight-through.

        Args:
            image_generator: source of real images; its image size is set for every step.
            epochs_per_step: epochs per phase, either one int for all steps or one entry per
                step (used for both phases of that step).
            discriminator_train_per_gan_train: discriminator updates per generator update.
            tensorboard_callback: optional TensorBoardCallback; on_fit_end is called at the end.

        Raises:
            ValueError: if `epochs_per_step` is a sequence with fewer entries than steps.
        '''
        if isinstance(epochs_per_step, int):
            epochs_per_step = [epochs_per_step for _ in range(len(self.__steps))]
        elif len(epochs_per_step) < len(self.__steps):
            # fail now instead of with an IndexError hours into training
            raise ValueError(f'epochs_per_step needs one entry per step ({len(self.__steps)}), got {len(epochs_per_step)}')

        self.__print_fit_progress_header()

        for step, img_size in enumerate(self.__steps):
            image_generator.set_images_size(img_size)

            # the first step has no smaller model to fade in from
            phases = (False,) if step == 0 else (True, False)
            for fade in phases:
                self.__train_models(step=step, fade=fade, image_generator=image_generator, epochs_per_step=epochs_per_step[step], discriminator_train_per_gan_train=discriminator_train_per_gan_train, tensorboard_callback=tensorboard_callback)

        if tensorboard_callback is not None:
            tensorboard_callback.on_fit_end()

    def __print_fit_progress_header(self):
        print('| image size       | step | fade | epoch            | time     | d_loss                           | g_loss                           |', end='')

    @staticmethod
    def __format_time(seconds):
        '''Format a duration as "s.cc" (< 1 min), "m:ss" (< 1 h) or "h:mm:ss".'''
        whole_seconds = int(seconds)
        if seconds < 60:
            return f'{whole_seconds}.{int(seconds*100)%100:02d}'
        if seconds < 60*60:
            return f'{whole_seconds//60}:{whole_seconds%60:02d}'
        # hours are not wrapped (the previous `% 60` turned 61 h into 1 h)
        return f'{whole_seconds//(60*60)}:{(whole_seconds//60)%60:02d}:{whole_seconds%60:02d}'

    def __print_fit_progress(self, img_size, step, fade, epoch, total_epochs, d_loss, g_loss):
        '''
        Overwrite the current console line with the progress of one phase.

        Shows the ETA while the phase runs and the elapsed time after its last epoch, both
        measured from the timer that __train_models resets at the start of the phase.
        '''
        if epoch == 1:
            print()

        epoch_time = perf_counter() - self.__timer

        if epoch == total_epochs:
            # time passed
            time = epoch_time
        else:
            # eta
            time = epoch_time*(total_epochs - epoch)/epoch

        time_str = self.__format_time(time)

        if fade:
            img_size_str = f'{img_size//2} -> {img_size}'
        else:
            img_size_str = f'{img_size}'

        epoch_str = f'{epoch} / {total_epochs}'

        print(f'\r| {img_size_str:>16s} | {step:4d} | {int(fade):4d} | {epoch_str:>16s} | {time_str:>8s} | {d_loss:32f} | {g_loss:32f} |', end='')

    def __init_generator(self):
        '''Build the first-step generator, then one (straight, fade-in) pair per further step.'''
        kernel_initializer = keras.initializers.HeNormal()
        kernel_constraint = keras.constraints.MaxNorm(1.)

        self.__generator = []

        generator_input = x = Input((1, 1, self.__latent_dim))

        # 1x1 latent -> 4x4 feature map
        x = Conv2DTranspose(self.__latent_dim, 4, kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)
        x = PixelNormalization()(x)
        x = LeakyReLU(alpha=.2)(x)

        x = Conv2D(self.__latent_dim, 3, padding='same', kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)
        x = PixelNormalization()(x)
        x = LeakyReLU(alpha=.2)(x)

        output_size = 4

        while output_size < self.__initial_image_size:
            filters = self.__filters_count(output_size)

            x = UpSampling2D()(x)

            x = Conv2DTranspose(filters, 3, padding='same', kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)
            x = PixelNormalization()(x)
            x = LeakyReLU(alpha=.2)(x)

            x = Conv2D(filters, 3, padding='same', kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)
            x = PixelNormalization()(x)
            x = LeakyReLU(alpha=.2)(x)

            output_size <<= 1

        # 1x1 "to image" convolution (linear output)
        x = Conv2D(self.__image_channels, 1, padding='same', kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)

        generator = keras.Model(generator_input, x, name=f'generator_{self.__initial_image_size:}x{self.__initial_image_size:}')

        self.__generator.append([generator, generator])

        for _ in range(1, len(self.__steps)):
            next_generators = self.__add_generator_block(self.__generator[-1][0])
            self.__generator.append(next_generators)

    def __add_generator_block(self, generator : keras.Model):
        '''
        Grow `generator` by one upsampling block.

        Returns:
            (new_generator, new_generator_fade). The fade-in model blends the new block's
            output with the old "to image" layer applied to the upsampled old features.
        '''
        kernel_initializer = keras.initializers.HeNormal()
        kernel_constraint = keras.constraints.MaxNorm(1.)

        # features just before the old "to image" convolution
        prev_block_end = generator.layers[-2].output

        upsampling = x = UpSampling2D()(prev_block_end)

        output_image_size = x.shape[1]
        filters = self.__filters_count(output_image_size)

        x = Conv2D(filters, 3, padding='same', kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)
        x = PixelNormalization()(x)
        x = LeakyReLU(alpha=.2)(x)

        x = Conv2D(filters, 3, padding='same', kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)
        x = PixelNormalization()(x)
        x = LeakyReLU(alpha=.2)(x)

        new_generator_output = Conv2D(self.__image_channels, 1, padding='same', kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)

        new_generator = keras.Model(generator.input, new_generator_output, name=f'generator_{output_image_size}x{output_image_size}')

        # reuse the old "to image" layer (shared weights) on the upsampled old features
        generator_output = generator.layers[-1]
        generator_output_upscaled = generator_output(upsampling)

        combined = WeightedSum()((new_generator_output, generator_output_upscaled))

        new_generator_fade = keras.Model(generator.input, combined, name=f'generator_fade_{output_image_size}x{output_image_size}')

        return (new_generator, new_generator_fade)

    def __init_discriminator(self):
        '''Build and compile the first-step discriminator, then one pair per further step.'''
        kernel_initializer = keras.initializers.HeNormal()
        kernel_constraint = keras.constraints.MaxNorm(1.)

        self.__discriminator = []

        discriminator_input = x = Input((self.__initial_image_size, self.__initial_image_size, self.__image_channels))

        # 1x1 "from image" convolution; layers [1] and [2] are skipped/reused when growing
        x = Conv2D(self.__filters_count(self.__initial_image_size << 1), 1, padding='same', kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)
        x = LeakyReLU(alpha=.2)(x)

        output_size = self.__initial_image_size

        while output_size > 4:
            x = Conv2D(self.__filters_count(output_size << 1), 3, padding='same', kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)
            x = LeakyReLU(alpha=.2)(x)

            x = Conv2D(self.__filters_count(output_size), 3, padding='same', kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)
            x = LeakyReLU(alpha=.2)(x)

            x = AveragePooling2D()(x)

            output_size >>= 1

        x = MinibatchStdev()(x)

        x = Conv2D(self.__filters_count(output_size), 3, padding='same', kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)
        x = LeakyReLU(alpha=.2)(x)

        x = Conv2D(self.__latent_dim, 4, kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)
        x = LeakyReLU(alpha=.2)(x)

        x = Flatten()(x)

        x = Dense(1)(x)

        discriminator = keras.Model(discriminator_input, x, name=f'discriminator_{self.__initial_image_size}x{self.__initial_image_size}')

        discriminator.compile(optimizer=self.__fresh_optimizer(self.__discriminator_optimizer), loss=wasserstein_loss, metrics=['accuracy'])

        self.__discriminator.append((discriminator, discriminator))

        for _ in range(1, len(self.__steps)):
            next_discriminators = self.__add_discriminator_block(self.__discriminator[-1][0])
            self.__discriminator.append(next_discriminators)

    def __add_discriminator_block(self, discriminator : keras.Model):
        '''
        Grow `discriminator` by one block that accepts images twice as large.

        Returns:
            (new_discriminator, new_discriminator_fade), both compiled. The fade-in model
            blends the new block's output with the old "from image" path applied to the
            average-pooled input.
        '''
        kernel_initializer = keras.initializers.HeNormal()
        kernel_constraint = keras.constraints.MaxNorm(1.)

        discriminator_input_shape = discriminator.input.shape
        new_discriminator_input_shape = (discriminator_input_shape[-3] << 1, discriminator_input_shape[-2] << 1, discriminator_input_shape[-1])
        output_image_size = new_discriminator_input_shape[0]

        new_discriminator_input = x = Input(shape=new_discriminator_input_shape)

        x = Conv2D(self.__filters_count(output_image_size << 1), 1, padding='same', kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)
        x = LeakyReLU(alpha=.2)(x)

        x = Conv2D(self.__filters_count(output_image_size << 1), 3, padding='same', kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)
        x = LeakyReLU(alpha=.2)(x)

        x = Conv2D(self.__filters_count(output_image_size), 3, padding='same', kernel_initializer=kernel_initializer, kernel_constraint=kernel_constraint)(x)
        x = LeakyReLU(alpha=.2)(x)

        new_block_end = x = AveragePooling2D()(x)

        # skip the input, 1x1 and activation layers of the old model
        for i in range(3, len(discriminator.layers)):
            x = discriminator.layers[i](x)

        new_discriminator = keras.Model(new_discriminator_input, x, name=f'discriminator_{output_image_size}x{output_image_size}')

        new_discriminator.compile(optimizer=self.__fresh_optimizer(self.__discriminator_optimizer), loss=wasserstein_loss, metrics=['accuracy'])

        x = AveragePooling2D()(new_discriminator_input)
        x = discriminator.layers[1](x)  # 1x1 conv
        x = discriminator.layers[2](x)  # activation

        x = WeightedSum()([new_block_end, x])

        # same as above
        for i in range(3, len(discriminator.layers)):
            x = discriminator.layers[i](x)

        new_discriminator_fade = keras.Model(new_discriminator_input, x, name=f'discriminator_fade_{output_image_size}x{output_image_size}')

        new_discriminator_fade.compile(optimizer=self.__fresh_optimizer(self.__discriminator_optimizer), loss=wasserstein_loss, metrics=['accuracy'])

        return (new_discriminator, new_discriminator_fade)

    def __init_gan(self):
        '''Stack generator + discriminator (set to non-trainable) per step and phase, and compile.'''
        self.__gan = []

        for generators, discriminators in zip(self.__generator, self.__discriminator):
            # straight-through model
            discriminators[0].trainable = False

            gan = keras.Sequential(name=f'gan_{generators[0].output.shape[1]}x{generators[0].output.shape[1]}')
            gan.add(generators[0])
            gan.add(discriminators[0])

            gan.compile(loss=wasserstein_loss, optimizer=self.__fresh_optimizer(self.__gan_optimizer), metrics=['accuracy'])

            # fade-in model
            discriminators[1].trainable = False

            gan_fade = keras.Sequential(name=f'gan_fade_{generators[0].output.shape[1]}x{generators[0].output.shape[1]}')
            gan_fade.add(generators[1])
            gan_fade.add(discriminators[1])

            gan_fade.compile(loss=wasserstein_loss, optimizer=self.__fresh_optimizer(self.__gan_optimizer), metrics=['accuracy'])

            self.__gan.append((gan, gan_fade))

    def __filters_count(self, output_size):
        '''Start from latent_dim filters and halve until output_size * filters < final_image_size * 16.'''
        filters = self.__latent_dim
        while output_size*filters >= self.__final_image_size*16:
            filters //= 2

        return filters

In [None]:
# TODO fix write_model_summary
class TensorBoardCallback(object):
    '''
    Logs ProgressiveGAN training metrics and generator previews.

    - Metrics passed to on_epoch_end are averaged over `metrics_save_interval` epochs and
      written as TensorBoard scalars (one writer per metric sub-name).
    - Every `generator_preview_save_interval` epochs, images generated from 4 fixed latent
      vectors are written as a TensorBoard image summary.
    - Every `video_frame_interval` epochs, a 64x64 frame generated from one fixed latent vector
      is appended to `<preview_dir>/<timestamp>/generator_preview.mp4`; every
      `image_save_interval` epochs that frame is also saved as a JPEG next to it.

    Args:
        logdir: root TensorBoard log folder; a timestamped sub-folder is used.
        model: the ProgressiveGAN being trained. Without it only metrics are logged.
        metrics_save_interval: epochs between scalar summaries.
        generator_preview_save_interval: epochs between image summaries.
        video_frame_interval: epochs between preview video frames.
        image_save_interval: epochs between saved preview JPEGs.
        preview_dir: root folder for the preview video and JPEGs.
    '''
    def __init__(self, logdir : str, model : ProgressiveGAN = None, metrics_save_interval : int =20, generator_preview_save_interval : int =100, video_frame_interval : int =100, image_save_interval : int =1000, preview_dir : str ='gen'):
        self.__datetime_str = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
        self.__preview_dir = os.path.join(preview_dir, self.__datetime_str)
        # makedirs also creates `preview_dir` itself if missing; OpenCV writers can fail
        # silently when the target folder does not exist
        os.makedirs(self.__preview_dir, exist_ok=True)

        self.__logdir = os.path.join(logdir, self.__datetime_str)
        self.__model = model
        self.__metrics_save_interval = metrics_save_interval
        self.__generator_preview_save_interval = generator_preview_save_interval
        self.__video_frame_interval = video_frame_interval
        self.__image_save_interval = image_save_interval

        self.__total_epochs = 0

        self.__writers = {}

        self.__metrics_interval = {}
        self.__metrics_interval_count = 0

        self.__generator_preview_vid = None
        if model is not None:
            self.__generator_preview_vid_latent = model.sample_latent_space(1)
            self.__generator_preview_vid = cv2.VideoWriter(os.path.join(self.__preview_dir, 'generator_preview.mp4'), cv2.VideoWriter_fourcc(*'mp4v'), 60.0, (64, 64))
            self.__preview_latent_noise = model.sample_latent_space(4)

    def on_epoch_end(self, epoch : int, step : int, fade : bool, metrics_dict : dict):
        '''
        Accumulate `metrics_dict` ({metric_name: {sub_name: value}}) and write summaries,
        previews and video frames when their intervals are reached.
        '''
        self.__total_epochs += 1

        for metric_name, metric_dict in metrics_dict.items():
            accumulated = self.__metrics_interval.setdefault(metric_name, {})
            for metric_subname, metric_value in metric_dict.items():
                accumulated[metric_subname] = accumulated.get(metric_subname, .0) + metric_value

        self.__metrics_interval_count += 1

        if self.__total_epochs % self.__metrics_save_interval == 0:
            self.__write_metrics()
            self.__metrics_interval_count = 0
            self.__metrics_interval = {}

        if self.__model is None:
            # previews need a model
            return

        if self.__total_epochs % self.__generator_preview_save_interval == 0:
            self.__write_generator_preview(step, fade)

        save_frame = self.__total_epochs % self.__video_frame_interval == 0
        save_image = self.__total_epochs % self.__image_save_interval == 0
        if save_frame or save_image:
            frame = self.__preview_frame(step, fade)
            if save_frame:
                self.__generator_preview_vid.write(frame)
            if save_image:
                cv2.imwrite(os.path.join(self.__preview_dir, f'{step:02d}_{epoch:06d}_{int(fade)}.jpg'), frame)

    def on_fit_end(self):
        '''Finalize the preview video and close all summary writers.'''
        if self.__generator_preview_vid is not None:
            self.__generator_preview_vid.release()

        for writer in self.__writers.values():
            writer.close()

        self.__writers = {}

    def __get_writer(self, name : str):
        '''Return the summary writer for `<logdir>/<name>`, creating it on first use.'''
        path = os.path.join(self.__logdir, name)
        if path not in self.__writers:
            self.__writers[path] = tf.summary.create_file_writer(path)
        return self.__writers[path]

    def __preview_frame(self, step : int, fade : bool):
        '''Render the fixed video latent vector as a 64x64 uint8 frame in OpenCV's BGR order.'''
        frame = self.__model.generator[step][int(fade)].predict(self.__generator_preview_vid_latent)[0]
        # [-1., 1.] -> [0, 255]
        frame = np.clip((frame + 1.)*127.5, 0, 255).astype(np.uint8)
        frame = cv2.resize(frame, (64, 64), interpolation=cv2.INTER_NEAREST)
        if len(frame.shape) == 2:
            frame = frame[:,:,np.newaxis]
        if frame.shape[2] == 1:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
        elif frame.shape[2] >= 3:
            # the generator is trained on RGB images (ImageGenerator flips cv2's BGR), while
            # VideoWriter / imwrite expect BGR: reorder the first three channels
            frame = np.ascontiguousarray(frame[:,:,2::-1])
        return frame

    def __write_metrics(self):
        '''Write the averaged metrics of the current interval as scalars at the total-epoch step.'''
        for metric_name, metric_dict in self.__metrics_interval.items():
            for loss_name, loss_value in metric_dict.items():
                with self.__get_writer(loss_name).as_default():
                    tf.summary.scalar(metric_name, loss_value/self.__metrics_interval_count, step=self.__total_epochs)

    def __write_generator_preview(self, step : int, fade : bool):
        '''Write images generated from the fixed preview latent vectors as an image summary.'''
        with self.__get_writer('model').as_default():
            preview_generated_images = self.__model.generator[step][int(fade)].predict(self.__preview_latent_noise)
            # tf.summary.image clips float data to [0, 1]; the generator works in [-1, 1]
            preview_generated_images = np.clip((preview_generated_images + 1.) / 2., 0., 1.)
            # logged at the total-epoch step, on the same axis as the scalar metrics
            tf.summary.image('Generator preview', preview_generated_images[:4,], step=self.__total_epochs, max_outputs=4)

    def __del__(self):
        # attributes may be missing if __init__ raised part-way through
        video = getattr(self, '_TensorBoardCallback__generator_preview_vid', None)
        if video is not None:
            video.release()
        for writer in getattr(self, '_TensorBoardCallback__writers', {}).values():
            writer.close()
    

In [None]:
# Training images: a folder of JPEG files. Set the PROGAN_DATASET_DIR environment variable to
# train on another folder (e.g. r'E:\Workspace\datasets\b\train_1\512') without editing this cell.
DATASET_DIR = os.environ.get('PROGAN_DATASET_DIR', r'E:\Workspace\datasets\cats\train')
BATCH_SIZE = 16

img_gen = ImageGenerator(DATASET_DIR, batch_size=BATCH_SIZE, image_channels=3)

In [None]:
def make_optimizer():
    '''Adam settings used for both the discriminators and the stacked GAN.'''
    # alternative tried before: keras.optimizers.RMSprop(learning_rate=5e-5)
    return keras.optimizers.Adam(learning_rate=5e-5, beta_1=0., beta_2=.99, epsilon=1e-8)

# Separate optimizer instances for the discriminator and the GAN: one instance shared by models
# with different variables mixes their state (e.g. the iteration counter used for bias
# correction), and Keras >= 2.11 optimizers raise an error when applied to new variables.
progan = ProgressiveGAN(
    latent_dim=256,
    initial_image_size=4,
    final_image_size=512,
    image_channels=3,
    discriminator_optimizer=make_optimizer(),
    gan_optimizer=make_optimizer())

tensorboard_callback = TensorBoardCallback('./logs', progan)

progan.fit(img_gen, epochs_per_step=400, discriminator_train_per_gan_train=1, tensorboard_callback=tensorboard_callback)

In [None]:
# Sanity check of the input pipeline: show two batches of real training images in [-1, 1].
batch = np.concatenate([img_gen.get_batch(), img_gen.get_batch()], axis=0)
print(batch.shape)
print(batch.min())
print(batch.max())

# Grid with 8 images per row, derived from the actual batch size instead of assuming 32 images;
# images that do not fill a complete row are left out (unless there is only one row).
GRID_COLS = 8
grid_rows = max(1, len(batch) // GRID_COLS)
grid = np.vstack([np.hstack(list(batch[row*GRID_COLS:(row + 1)*GRID_COLS])) for row in range(grid_rows)])

plt.figure(figsize=(16, 8))

# [-1, 1] -> [0, 1] for display
plt.imshow(grid/2 + .5)

plt.show()

In [None]:

# Compare the straight-through generators of all steps on the same latent vectors:
# one row per step, each image upscaled to PREVIEW_SIZE with nearest-neighbour interpolation.
PREVIEW_COUNT = 8
PREVIEW_SIZE = 32

latent_noise = progan.sample_latent_space(PREVIEW_COUNT)

generated_images = np.zeros((PREVIEW_COUNT*len(progan.generator), PREVIEW_SIZE, PREVIEW_SIZE, 3))
for i in range(len(progan.generator)):
    g = progan.generator[i][0].predict(latent_noise)
    for j in range(PREVIEW_COUNT):
        # keep the generator's [-1, 1] scale here; it is mapped to [0, 1] exactly once, when
        # plotting (it used to be mapped twice, compressing the display into [0.5, 1])
        generated_images[PREVIEW_COUNT*i + j,] = cv2.resize(g[j,], (PREVIEW_SIZE, PREVIEW_SIZE), interpolation=cv2.INTER_NEAREST)

print(generated_images.shape)
print(generated_images.min())
print(generated_images.max())

grid = np.vstack([np.hstack([generated_images[i + PREVIEW_COUNT*j] for i in range(PREVIEW_COUNT)]) for j in range(len(progan.generator))])

plt.figure(figsize=(16, 2*len(progan.generator)))

# [-1, 1] -> [0, 1]; clipped because the generator's output layer is unbounded
plt.imshow(np.clip(grid/2 + .5, 0., 1.), interpolation=None)

plt.show()

In [None]:
# Save the final (largest) straight-through generator; the folder must exist for HDF5 saving.
# NOTE: loading it needs the custom layer, e.g.
# keras.models.load_model(path, custom_objects={'PixelNormalization': PixelNormalization})
os.makedirs('./model', exist_ok=True)
progan.generator[-1][0].save('./model/generator_cats.h5')