In [1]:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.keras.callbacks import keras_model_summary
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D, MaxPool2D, Input, Dense, Flatten, Dropout, Concatenate

import numpy as np

import matplotlib.pyplot as plt

import cv2

import os
import datetime

from random import sample

In [2]:
class Combine(keras.layers.Layer):
    def __init__(self, *args, **kwargs):
        super(Combine, self).__init__(*args, **kwargs)

    def build(self, input_shape):
        self.alpha = self.add_weight(shape=(1,), initializer="zeros", trainable=False)
        
    def get_config(self):
        config = super(Combine, self).get_config().copy()
        return config

    def call(self, inputs):
        return self.alpha*inputs[0] + (1 - self.alpha)*inputs[1]

In [3]:
class MinibatchStdev(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(MinibatchStdev, self).__init__(**kwargs)

    # calculate the mean standard deviation across each pixel coord
    def call(self, inputs):
        mean = keras.backend.mean(inputs, axis=0, keepdims=True)
        mean_sq_diff = keras.backend.mean(keras.backend.square(inputs - mean), axis=0, keepdims=True) + 1e-8
        mean_pix = keras.backend.mean(keras.backend.sqrt(mean_sq_diff), keepdims=True)
        shape = keras.backend.shape(inputs)
        output = keras.backend.tile(mean_pix, [shape[0], shape[1], shape[2], 1])
        return keras.backend.concatenate([inputs, output], axis=-1)

    # define the output shape of the layer
    def compute_output_shape(self, input_shape):
        input_shape = list(input_shape)
        input_shape[-1] += 1
        return tuple(input_shape)

In [4]:
class ImageGenerator(object):
    def __init__(self, images_folder_path, initial_images_size=4):
        self.__images_folder_path = images_folder_path
        self.__images_size = initial_images_size
        self.__filenames = []

        for _, _, fnames in os.walk(self.__images_folder_path):
            for fname in fnames:
                if fname.split('.')[-1] in ('jpg', 'jpeg'):
                    self.__filenames.append(fname)
            break
    
        print(f'Loaded {len(self.__filenames)} images.')
    
    def set_images_size(self, size):
        self.__images_size = size

    def get_batch(self, batch_size=32):
        result = np.zeros((batch_size, self.__images_size, self.__images_size, 3))

        fnames = sample(self.__filenames, batch_size)

        for i in range(batch_size):
            img = cv2.imread(os.path.join(self.__images_folder_path, fnames[i]))[:,:,::-1]
            min_size = min(img.shape[:2])
            img = img[(img.shape[0] - min_size)//2:(img.shape[0] + min_size)//2,
                      (img.shape[1] - min_size)//2:(img.shape[1] + min_size)//2]
            img = cv2.resize(img, (self.__images_size,)*2).astype(np.float32)
            img -= img.min()
            img /= img.max()
            # [0, 1] -> [-1, 1]
            img *= 2.
            img -= 1.
            result[i,] = img
        
        return result

In [5]:
def clone_layer(layer):
    config = layer.get_config()
    del config['name']
    weights = layer.get_weights()
    cloned_layer = type(layer).from_config(config)
    cloned_layer.build(layer.input_shape)
    cloned_layer.set_weights(weights)
    return cloned_layer

In [6]:
class ProgressiveGAN(object):
    __latent_dim            : int
    __initial_image_size    : int
    __final_image_size      : int
    __current_image_size    : int

    __generator             : keras.Model
    __discriminator         : keras.Model

    __gan                   : keras.Model

    def __init__(self, latent_dim : int =128, initial_image_size : int =4, final_image_size : int =512, gan_optimizer : (str | keras.optimizers.Optimizer) ='adam', discriminator_optimizer : (str | keras.optimizers.Optimizer) ='adam'):
        self.__latent_dim = latent_dim
        self.__initial_image_size = initial_image_size
        self.__final_image_size = final_image_size
        self.__gan_optimizer = gan_optimizer
        self.__discriminator_optimizer = discriminator_optimizer

        self.__current_image_size = self.__initial_image_size

        self.__generator = None
        self.__discriminator = None
        self.__gan = None

        self.__init_generator()
        self.__init_discriminator()
        self.__init_gan()
    
    @property
    def generator(self):
        return self.__generator

    @property
    def discriminator(self):
        return self.__discriminator

    @property
    def gan(self):
        return self.__gan

    def fit(self, image_generator : ImageGenerator, epochs_per_step=32, batch_size=32, discriminator_train_per_gan_train=5, tensorboard_callback=None):
        
        step = 0
        epochs_per_half_step = epochs_per_step >> 1

        if tensorboard_callback is not None:
            tensorboard_callback.write_model_summary(step=step)
        
        combines_dropped = True

        while True:
            print(f'img size={self.__current_image_size}  ')

            image_generator.set_images_size(self.__current_image_size)

            for epoch in range(epochs_per_step):
                # increasing image size
                if epoch == epochs_per_step >> 1 and self.__current_image_size < self.__final_image_size:
                    print(f'\nimg size={self.__current_image_size} -> ', end='')

                    cv2.imwrite(f'./gen/{self.__current_image_size}.png', np.clip((generated_images[0] * 255), 0, 255).astype(np.uint8))

                    generator_combine, discriminator_combine = self.__increase_image_size()
                    combines_dropped = False
                    image_generator.set_images_size(self.__current_image_size)

                    print(f'{self.__current_image_size}  ')

                    if tensorboard_callback is not None:
                        tensorboard_callback.write_model_summary(step=step*10 + 1)

                # adjust combine param
                if not combines_dropped:
                    alpha = float(epoch - (epochs_per_step >> 1))/(epochs_per_step >> 1)
                    generator_combine.set_weights([np.array([alpha])])
                    discriminator_combine.set_weights([np.array([alpha])])

                    print(f' alpha={generator_combine.get_weights()[0]}  ', end='')

                # train discriminator
                d_loss = 0.
                for _ in range(discriminator_train_per_gan_train):
                    latent_noise = np.random.normal(0, 1, (batch_size, 1, 1, self.__latent_dim))

                    generated_images = self.__generator.predict(latent_noise)
                    real_images = image_generator.get_batch(batch_size)
                    
                    generated_labels = np.zeros((batch_size, 1))
                    real_labels = np.ones((batch_size, 1))

                    # combined_images = np.concatenate([generated_images, real_images])

                    # labels = np.ones((batch_size << 1, 1))
                    # labels[:batch_size,] = 0

                    generated_labels += 0.1 * np.random.random(generated_labels.shape)
                    generated_labels = np.clip(generated_labels, 0, 1)

                    real_labels += 0.1 * np.random.random(real_labels.shape)
                    real_labels = np.clip(real_labels, 0, 1)
                    
                    d_loss += self.__discriminator.train_on_batch(generated_images, generated_labels)
                    d_loss += self.__discriminator.train_on_batch(real_images, real_labels)
                
                d_loss /= discriminator_train_per_gan_train
                
                # train generator
                latent_noise = np.random.normal(0, 1, (batch_size, 1, 1, self.__latent_dim))

                misleading_targets = np.ones((batch_size, 1))

                g_loss = self.__gan.train_on_batch(latent_noise, misleading_targets)

                print(f'\r{epoch + 1} / {epochs_per_step} d_loss={d_loss} g_loss={g_loss}', end='')

                if tensorboard_callback is not None:
                    tensorboard_callback.write_losses(d_loss, g_loss, step=(epoch + step*epochs_per_step))
            
            print()

            if tensorboard_callback is not None:
                tensorboard_callback.write_generator_preview(step=step)

            # final size reached
            if self.__current_image_size == self.__final_image_size and combines_dropped:
                break
            
            self.__drop_combine_layer()
            combines_dropped = True
            
            if tensorboard_callback is not None:
                tensorboard_callback.write_model_summary(step=step*10)

            step += 1
        
        print()

        cv2.imwrite(f'./gen/{self.__current_image_size}.png', np.clip((generated_images[0] * 255), 0, 255).astype(np.uint8))

    def __init_generator(self) -> keras.Model:
        generator_input = x = Input((1, 1, self.__latent_dim))

        x = Conv2DTranspose(self.__latent_dim, 4, activation='relu')(x)
        x = Conv2D(self.__latent_dim, 3, padding='same', activation='relu')(x)

        output_size = 4

        while output_size < self.__initial_image_size:
            filters = self.__filters_count(output_size)

            x = UpSampling2D()(x)
            x = Conv2D(filters, 3, padding='same', activation='relu')(x)
            x = Conv2D(filters, 3, padding='same', activation='relu')(x)

            output_size <<= 1

        x = Conv2D(3, 1, padding='same', activation='tanh')(x)

        self.__generator = keras.Model(generator_input, x, name='generator')

    def __init_discriminator(self) -> keras.Model:
        discriminator_input = x = Input((self.__initial_image_size, self.__initial_image_size, 3))

        x = Conv2D(self.__filters_count(self.__initial_image_size << 1), 1, padding='same', activation='relu')(x)

        output_size = self.__initial_image_size

        while output_size > 4:
            x = Conv2D(self.__filters_count(output_size << 1), 3, padding='same', activation='relu')(x)
            x = Conv2D(self.__filters_count(output_size), 3, padding='same', activation='relu')(x)
            x = MaxPool2D()(x)

            output_size >>= 1

        x = MinibatchStdev()(x)
        x = Conv2D(self.__filters_count(output_size), 3, padding='same', activation='relu')(x)
        x = Conv2D(self.__latent_dim, 4, activation='relu')(x)
        x = Flatten()(x)
        
        x = Dense(1, activation='sigmoid')(x)

        self.__discriminator = keras.Model(discriminator_input, x, name='discriminator')

        self.__compile_discriminator()
    
    def __init_gan(self):
        gan_input = x = Input((1, 1, self.__latent_dim))

        self.__discriminator.trainable = False
        self.__gan = keras.Model(gan_input, self.__discriminator(self.__generator(gan_input)), name='gan')

        self.__gan.compile(optimizer=keras.optimizers.RMSprop(learning_rate=1e-4, clipvalue=1.0, decay=1e-8), loss='binary_crossentropy')
    
    def __increase_image_size(self):
        self.__current_image_size <<= 1

        # generator

        # new layers
        x = UpSampling2D()(self.__generator.layers[-2].output)
        x = Conv2D(self.__filters_count(self.__current_image_size), 3, padding='same', activation='relu')(x)
        x = Conv2D(self.__filters_count(self.__current_image_size), 3, padding='same', activation='relu')(x)

        x = Conv2D(3, 1, padding='same', activation='tanh')(x)

        # upsample previous output
        prev_step_output = self.__generator.output
        prev_step_rgb = keras.layers.UpSampling2D()(prev_step_output)

        generator_combine = Combine(name='generator_combine')
        x = generator_combine([x, prev_step_rgb])

        self.__generator = keras.Model(self.__generator.input, x, name='generator')

        # discriminator
        discriminator_input = Input((self.__current_image_size, self.__current_image_size, 3))

        prev_step_downsample = MaxPool2D()(discriminator_input)

        prev_step_rgb = self.__discriminator.layers[1](prev_step_downsample)

        x = Conv2D(self.__filters_count(self.__current_image_size << 1), 1, padding='same', activation='relu')(x)

        x = Conv2D(self.__filters_count(self.__current_image_size << 1), 3, padding='same', activation='relu')(discriminator_input)
        x = Conv2D(self.__filters_count(self.__current_image_size), 3, padding='same', activation='relu')(x)
        new_step_output = MaxPool2D()(x)

        discriminator_combine = Combine(name='discriminator_combine')
        x = discriminator_combine([new_step_output, prev_step_rgb])
        
        for layer in self.__discriminator.layers[2:]:
            x = clone_layer(layer)(x)

        self.__discriminator = keras.Model(discriminator_input, x, name='discriminator')

        self.__compile_discriminator()

        # gan
        self.__init_gan()

        return generator_combine, discriminator_combine
    
    def __drop_combine_layer(self):
        # generator
        layer = self.__generator.layers[0]

        layers = []
        i = -1
        s = self.__initial_image_size
        while s < self.__current_image_size:
            s <<= 1
            i += 1

        while True:
            layer = layer._outbound_nodes[0].layer
            layers.append(layer)
            if len(set(map(lambda x: x.layer.name, layer._outbound_nodes))) != 1:
                break
                # if i == 0:
                #     break
                # i -= 1

        outs = list(set(layer._outbound_nodes))

        layer_a = outs[0].layer
        layer_b = outs[1].layer

        branch_a = [layer_a]
        branch_b = [layer_b]
        
        while len(layer_a._outbound_nodes) > 0:
            layer_a = layer_a._outbound_nodes[0].layer
            branch_a.append(layer_a)
        
        while len(layer_b._outbound_nodes) > 0:
            layer_b = layer_b._outbound_nodes[0].layer
            branch_b.append(layer_b)

        if len(branch_a) > len(branch_b):
            layers.extend(branch_a)
            
        else:
            layers.extend(branch_b)

        generator_input = x = keras.layers.Input(shape=self.__generator.input.shape[1:])
        for layer in layers:
            if 'combine' in layer.name:
                break
            x = clone_layer(layer)(x)

        self.__generator = keras.Model(generator_input, x, name='generator')

        # discirminator
        layer = self.__discriminator.layers[0]

        layer_a = layer._outbound_nodes[0].layer
        layer_b = layer._outbound_nodes[1].layer

        branch_a = [layer_a]
        branch_b = [layer_b]
        
        while len(layer_a._outbound_nodes) > 0:
            layer_a = layer_a._outbound_nodes[0].layer
            branch_a.append(layer_a)
        
        while len(layer_b._outbound_nodes) > 0:
            layer_b = layer_b._outbound_nodes[0].layer
            branch_b.append(layer_b)

        if len(branch_a) > len(branch_b):
            layers = branch_a
            
        else:
            layers = branch_b

        discriminator_input = x = keras.layers.Input(self.__discriminator.input.shape[1:])
        for layer in layers:
            if 'combine' in layer.name:
                continue
            x = clone_layer(layer)(x)

        self.__discriminator = keras.Model(discriminator_input, x, name='discriminator')

        self.__compile_discriminator()
        
        # gan
        self.__init_gan()
    
    def __compile_discriminator(self):
        self.__discriminator.compile(optimizer=keras.optimizers.RMSprop(learning_rate=1e-4, clipvalue=1.0, decay=1e-8), loss='binary_crossentropy')

    def __filters_count(self, output_size):
        filters = self.__latent_dim
        while output_size*filters >= self.__final_image_size*16:
            filters //= 2
        
        return filters

In [7]:
class TensorBoardCallback(object):
    def __init__(self, logdir : str, model : ProgressiveGAN = None):
        self.__logdir = logdir

        self.__discriminator_writer = tf.summary.create_file_writer(os.path.join(logdir, datetime.datetime.now().strftime('%Y%m%d-%H%M%S'), 'discriminator'))
        self.__gan_writer = tf.summary.create_file_writer(os.path.join(logdir, datetime.datetime.now().strftime('%Y%m%d-%H%M%S'), 'gan'))
        self.__model_writer = tf.summary.create_file_writer(os.path.join(logdir, datetime.datetime.now().strftime('%Y%m%d-%H%M%S'), 'model'))

        self.__model = model
        if model is not None:
            self.__preview_latent_noise = np.random.normal(0, 1, (4, *model.generator.input.shape[1:]))
    
    def write_losses(self, d_loss, g_loss, step):
        with self.__discriminator_writer.as_default():
            tf.summary.scalar('loss', d_loss, step=step)
        with self.__gan_writer.as_default():
            tf.summary.scalar('loss', -g_loss, step=step)
    
    def write_generator_preview(self, step):
        with self.__gan_writer.as_default():
            preview_generated_images = self.__model.generator.predict(self.__preview_latent_noise)
            tf.summary.image('Generator preview', preview_generated_images[:4,], step=step, max_outputs=4)
    
    def write_model_summary(self, step):
        with self.__model_writer.as_default():
            with summary_ops_v2.record_if(True):
                if self.__model.gan._is_graph_network:
                    keras_model_summary('gan', self.__model.gan, step=step)

                if self.__model.generator._is_graph_network:
                    keras_model_summary('generator', self.__model.generator, step=step)

                if self.__model.discriminator._is_graph_network:
                    keras_model_summary('discriminator', self.__model.discriminator, step=step)
                    
                gan_train_fn = self.__model.gan.train_tf_function
                if hasattr(gan_train_fn, 'function_spec'):
                    summary_ops_v2.graph(gan_train_fn._concrete_stateful_fn.graph)

    def __del__(self):
        self.__discriminator_writer.close()
        self.__gan_writer.close()
        self.__model_writer.close()


In [8]:
img_gen = ImageGenerator(r'E:\Workspace\datasets\cats') # (r'E:\Workspace\datasets\b\train_1\512') # 

Loaded 9993 images.


In [9]:
progan = ProgressiveGAN(latent_dim=128, initial_image_size=4, final_image_size=512)

tensorboard_callback = TensorBoardCallback('./logs', progan)

progan.fit(img_gen, epochs_per_step=400, batch_size=32, discriminator_train_per_gan_train=4, tensorboard_callback=tensorboard_callback)

img size=4  
200 / 400 d_loss=0.21217384043666243 g_loss=3.4416749477386475
img size=4 -> 8  
400 / 400 d_loss=0.26689958949873466 g_loss=2.642033100128174 alpha=[0.995]   
img size=8  
200 / 400 d_loss=19.49173406461549 g_loss=5.8443516692818775e-099
img size=8 -> 16  
400 / 400 d_loss=25.836608615367055 g_loss=40.052955627441406lpha=[0.995]  ]     
img size=16  
107 / 400 d_loss=37.891128182411194 g_loss=7.695806214513823e-109

In [None]:
batch = img_gen.get_batch()
print(batch.shape)
print(batch.min())
print(batch.max())

plt.figure(figsize=(64, 8))

plt.imshow(np.vstack([np.hstack([batch[i + 8*j] for i in range(8)]) for j in range(4)]))

plt.show()

In [None]:

latent_noise = np.random.normal(0, 1, (32, 1, 1, 512))

generated_images = progan.generator.predict(latent_noise)
print(generated_images.min())
print(generated_images.max())

plt.figure(figsize=(64, 8))

plt.imshow(np.vstack([np.hstack([generated_images[i + 8*j] for i in range(8)]) for j in range(4)]))

plt.show()