In [None]:
import tensorflow as tf
import tensorflow.keras as keras
from keras.layers import Conv2D, Conv2DTranspose, UpSampling2D, MaxPool2D, Input, Dense

import numpy as np

import matplotlib.pyplot as plt

import cv2

import os

from random import sample

In [None]:
class Combine(keras.layers.Layer):
    """Blend two equally-shaped tensors as ``alpha * inputs[0] + (1 - alpha) * inputs[1]``.

    ``alpha`` is a non-trainable weight of shape (1,) initialised to 0, so a freshly
    built layer simply passes ``inputs[1]`` through. Presumably ``alpha`` is meant to be
    ramped towards 1 from outside (e.g. a ProGAN fade-in) -- TODO confirm once a
    training loop uses this layer.
    """

    def __init__(self, **kwargs):
        # Forward base-layer kwargs (name, dtype, trainable, ...) so that
        # ``Combine.from_config(layer.get_config())`` can rebuild the layer;
        # without this, deserialisation fails with a TypeError.
        super(Combine, self).__init__(**kwargs)

    def build(self, input_shape):
        # add_weight registers alpha as a (non-trainable) layer weight, so it is
        # tracked and saved with the model; the optimizer never updates it.
        self.alpha = self.add_weight(name='alpha', shape=(1,), initializer='zeros', trainable=False)
        super(Combine, self).build(input_shape)

    def get_config(self):
        # The layer has no constructor arguments beyond the base-layer ones.
        return super(Combine, self).get_config()

    def call(self, inputs):
        """Return the alpha-weighted blend of ``inputs[0]`` and ``inputs[1]``."""
        return self.alpha * inputs[0] + (1 - self.alpha) * inputs[1]

In [None]:
class ImageGenerator(object):
    """Serve random batches of centre-cropped, resized RGB images from one folder.

    Only files directly inside ``images_folder_path`` (no recursion) whose extension is
    ``jpg`` or ``jpeg`` (any letter case) are used.
    """

    # Lower-case file extensions (without the dot) treated as images.
    _EXTENSIONS = ('jpg', 'jpeg')

    def __init__(self, images_folder_path, initial_images_size=4, batch_size=16):
        """
        Args:
            images_folder_path: directory scanned (non-recursively) for JPEG files.
            initial_images_size: side length in pixels of the square images returned.
            batch_size: number of images per batch.
        """
        self.__images_folder_path = images_folder_path
        self.__images_size = initial_images_size
        self.__batch_size = batch_size

        # The first os.walk entry lists only the top-level directory; a missing folder
        # yields nothing, which leaves the file list empty.
        _, _, fnames = next(os.walk(self.__images_folder_path), (None, None, []))
        # Sorted so the file order (and hence seeded sampling) is deterministic.
        self.__filenames = sorted(
            fname for fname in fnames
            if fname.rsplit('.', 1)[-1].lower() in self._EXTENSIONS
        )

        print(f'Loaded {len(self.__filenames)} images.')

    def set_images_size(self, size):
        """Set the side length (pixels) of images returned by subsequent batches."""
        self.__images_size = size

    def get_batch(self):
        """Return ``batch_size`` distinct random images.

        Each image is centre-cropped to a square, resized to the current image size,
        converted from OpenCV's BGR order to RGB and scaled by 1/255.

        Returns:
            np.ndarray of shape (batch_size, size, size, 3).

        Raises:
            ValueError: if the folder holds fewer images than ``batch_size``.
            IOError: if OpenCV cannot read one of the listed files.
        """
        if len(self.__filenames) < self.__batch_size:
            raise ValueError(
                f'Need at least {self.__batch_size} images for a batch, '
                f'found {len(self.__filenames)} in {self.__images_folder_path!r}.')

        size = self.__images_size
        result = np.zeros((self.__batch_size, size, size, 3))

        for i, fname in enumerate(sample(self.__filenames, self.__batch_size)):
            path = os.path.join(self.__images_folder_path, fname)
            img = cv2.imread(path)
            if img is None:
                # cv2.imread reports unreadable/corrupt files by returning None.
                raise IOError(f'Could not read image {path!r}.')
            img = img[:, :, ::-1]  # BGR -> RGB

            # Centre-crop to the largest square that fits.
            height, width = img.shape[:2]
            side = min(height, width)
            top, left = (height - side) // 2, (width - side) // 2
            img = img[top:top + side, left:left + side]

            img = cv2.resize(img, (size, size))
            result[i] = img.astype(np.float32) / 255

        return result

In [None]:
class ProgressiveGAN(object):
    """Progressive-growing GAN built from Keras functional models.

    Builds a generator mapping a (1, 1, latent_dim) input to an RGB image of
    ``initial_image_size`` pixels per side, a discriminator scoring such images, and a
    combined ``gan`` model (discriminator applied to generator output). Calling
    ``increase_image_size_transition_step`` and then ``increase_image_size_final_step``
    doubles the working resolution while reusing the layers already built.

    NOTE(review): the fade-in between resolutions (the ``__input_combine_layer`` /
    ``__output_combine_layer`` placeholders) is not implemented, and ``fit`` is a stub.
    """

    def __init__(self, latent_dim=128, initial_image_size=4, final_image_size=512):
        """
        Args:
            latent_dim: channel count of the latent input and of the 4x4 base blocks.
            initial_image_size: starting image side length; expected to be 4 * 2**k.
            final_image_size: target side length, used only to size filter counts.
        """
        self.__latent_dim = latent_dim
        self.__initial_image_size = initial_image_size
        self.__final_image_size = final_image_size

        self.__current_image_size = self.__initial_image_size

        # Graph anchors that the growth steps attach new blocks to.
        self.__generator_input = None
        self.__generator_output_pre_rgb = None
        self.__discriminator_input_post_rgb = None
        self.__discriminator_output = None

        # Placeholders for the (not yet implemented) fade-in blending layers.
        self.__input_combine_layer = None
        self.__output_combine_layer = None

        self.__to_rgb_output = None
        self.__from_rgb_input = None

        self.__generator = self.__init_generator()
        self.__discriminator = self.__init_discriminator()
        self.__gan = self.__init_gan()

    @property
    def generator(self):
        """Current generator model (rebuilt by ``increase_image_size_final_step``)."""
        return self.__generator

    @property
    def discriminator(self):
        """Current discriminator model (rebuilt by ``increase_image_size_final_step``)."""
        return self.__discriminator

    @property
    def gan(self):
        """Combined model: latent input -> generator -> discriminator."""
        return self.__gan

    def fit(self, image_generator, lr=.01):
        """Training loop -- not implemented yet; currently does nothing (TODO)."""
        pass

    def __init_generator(self):
        """Build the generator for ``initial_image_size`` and record its anchors.

        A transposed 4x4 convolution maps the 1x1 latent input to 4x4; each further
        block upsamples by 2 and applies two 3x3 convolutions until the initial size
        is reached, followed by a 1x1 projection to 3 channels.
        """
        self.__generator_input = x = Input((1, 1, self.__latent_dim))

        x = Conv2DTranspose(self.__latent_dim, 4, activation='relu')(x)
        x = Conv2D(self.__latent_dim, 3, padding='same', activation='relu')(x)

        output_size = 4
        while output_size < self.__initial_image_size:
            output_size <<= 1
            # Size each block by the resolution it produces, matching
            # increase_image_size_final_step and the discriminator.
            filters = self.__filters_count(output_size)

            x = UpSampling2D()(x)
            x = Conv2D(filters, 3, padding='same', activation='relu')(x)
            x = Conv2D(filters, 3, padding='same', activation='relu')(x)

        # Assigned after the loop so it is also set when initial_image_size == 4
        # (otherwise the growth step would upsample None).
        self.__generator_output_pre_rgb = x

        self.__to_rgb_output = Conv2D(3, 1, padding='same', activation='relu')(x)

        return keras.Model(self.__generator_input, self.__to_rgb_output, name='generator')

    def __init_discriminator(self):
        """Build the discriminator for ``initial_image_size`` and record its anchors.

        Layer order matters: layers[0] is the input and layers[1] the 1x1 "fromRGB"
        convolution; ``increase_image_size_final_step`` reuses everything from layers[2]
        on. Blocks (two 3x3 convs + 2x max-pool) run while the map is larger than 4x4,
        so a 4x4 input is never pooled below the final valid 4x4 convolution.
        """
        size = self.__initial_image_size
        discriminator_input = x = Input((size, size, 3))

        self.__from_rgb_input = x = Conv2D(self.__filters_count(size), 1, padding='same', activation='relu')(x)

        self.__discriminator_input_post_rgb = None
        output_size = size
        filters = self.__filters_count(output_size)

        while output_size > 4:
            filters = self.__filters_count(output_size)

            x = Conv2D(filters, 3, padding='same', activation='relu')(x)
            if self.__discriminator_input_post_rgb is None:
                self.__discriminator_input_post_rgb = x
            x = Conv2D(filters, 3, padding='same', activation='relu')(x)
            x = MaxPool2D()(x)

            output_size >>= 1

        x = Conv2D(filters, 3, padding='same', activation='relu')(x)
        if self.__discriminator_input_post_rgb is None:
            self.__discriminator_input_post_rgb = x
        # Valid 4x4 convolution collapses the 4x4 map to 1x1.
        x = Conv2D(self.__latent_dim, 4, activation='relu')(x)
        self.__discriminator_output = x = Dense(1, activation='sigmoid')(x)

        return keras.Model(discriminator_input, self.__discriminator_output, name='discriminator')

    def __init_gan(self):
        """Chain the current generator and discriminator into one model."""
        gan_input = Input((1, 1, self.__latent_dim))

        return keras.Model(gan_input, self.__discriminator(self.__generator(gan_input)), name='gan')

    def increase_image_size_transition_step(self):
        """Double the target resolution (the fade-in transition itself is TODO)."""
        self.__current_image_size <<= 1

    def increase_image_size_final_step(self):
        """Grow generator and discriminator to ``current_image_size`` and rebuild the GAN.

        The generator gets an upsample + two-conv block on top of its previous pre-RGB
        features and a fresh 1x1 RGB projection. The discriminator gets a new fromRGB
        conv and block in front of the previous discriminator's layers (minus its input
        and fromRGB layers).
        """
        filters = self.__filters_count(self.__current_image_size)

        # generator
        x = UpSampling2D()(self.__generator_output_pre_rgb)
        x = Conv2D(filters, 3, padding='same', activation='relu')(x)
        self.__generator_output_pre_rgb = x = Conv2D(filters, 3, padding='same', activation='relu')(x)

        self.__to_rgb_output = Conv2D(3, 1, padding='same', activation='relu')(x)

        self.__generator = keras.Model(self.__generator_input, self.__to_rgb_output, name='generator')

        # discriminator
        # The reused layers were built on the old fromRGB output, so the new block must
        # emit exactly that many channels (it differs from `filters` once counts halve).
        previous_filters = int(self.__from_rgb_input.shape[-1])

        discriminator_input = x = Input((self.__current_image_size, self.__current_image_size, 3))

        self.__from_rgb_input = x = Conv2D(filters, 1, padding='same', activation='relu')(x)

        self.__discriminator_input_post_rgb = x = Conv2D(filters, 3, padding='same', activation='relu')(x)
        x = Conv2D(previous_filters, 3, padding='same', activation='relu')(x)
        x = MaxPool2D()(x)

        # Skip the old input layer [0] and old fromRGB conv [1]; reuse the rest.
        for layer in self.__discriminator.layers[2:]:
            x = layer(x)

        self.__discriminator = keras.Model(discriminator_input, x, name='discriminator')

        self.__discriminator.summary(line_length=100, expand_nested=True)

        self.__gan = self.__init_gan()

    def __filters_count(self, output_size):
        """Return the integer filter count for blocks operating at ``output_size``.

        Starts from ``latent_dim`` and halves (integer division, so Keras receives an
        int) while ``output_size * filters >= final_image_size * 16``, never going
        below 1.
        """
        filters = self.__latent_dim
        while filters > 1 and output_size * filters >= self.__final_image_size * 16:
            filters //= 2

        return filters

In [None]:
# Build the GAN at 8x8, grow it once (8x8 -> 16x16) and compare the two summaries.
summary_options = {'line_length': 100, 'expand_nested': True}

progan = ProgressiveGAN(initial_image_size=8)
progan.gan.summary(**summary_options)

progan.increase_image_size_transition_step()
progan.increase_image_size_final_step()
progan.gan.summary(**summary_options)

In [None]:
# Dataset folder; set the CATS_DATASET_DIR environment variable to point elsewhere
# instead of editing this machine-specific default.
CATS_DATASET_DIR = os.environ.get('CATS_DATASET_DIR', r'E:\Workspace\datasets\cats')
img_gen = ImageGenerator(CATS_DATASET_DIR, batch_size=32)

In [None]:
# Preview images at 128x128 pixels.
preview_image_size = 128
img_gen.set_images_size(preview_image_size)

In [None]:
# Show one random batch as a grid (8 images per row; any remainder is not shown).
batch = img_gen.get_batch()

grid_cols = min(8, len(batch))
grid_rows = len(batch) // grid_cols
grid = np.vstack([np.hstack(batch[r * grid_cols:(r + 1) * grid_cols]) for r in range(grid_rows)])

# Figure aspect follows the grid so the images are not shrunk into a thin strip.
fig, ax = plt.subplots(figsize=(2 * grid_cols, 2 * grid_rows))
ax.imshow(grid)
ax.set_axis_off()
ax.set_title('Random training batch')

plt.show()

In [None]:
# Load one sample image (BGR, as returned by OpenCV); the folder can be overridden
# with the CATS_DATASET_DIR environment variable.
img_path = os.path.join(os.environ.get('CATS_DATASET_DIR', r'E:\Workspace\datasets\cats'), '00000001_005.jpg')
img = cv2.imread(img_path)
# cv2.imread returns None instead of raising when the file is missing or unreadable.
assert img is not None, f'cv2 could not read {img_path}'

In [None]:
# Resample the image to 64x64. ndarray.resize would only truncate/zero-fill the raw
# buffer in place, producing garbage rather than a downscaled image.
img = cv2.resize(img, (64, 64))

In [None]:
# (height, width, channels) of the current image.
np.shape(img)