In [None]:
# Predefined architecture; does not work as-is.
class Generator(tf.keras.Model):
    '''
        Progressively grown generator.

        block_list[0] maps a noise vector to a 4x4x512 feature map; every
        following entry is an upsampling block that doubles the resolution
        (8x8 ... 512x512). to_rgb_list[i] is the toRGB block whose input shape
        matches the output of block_list[i]. Only blocks
        0..stable_block_index are run; the next block can be blended in with
        call(..., fade_in=True, alpha=...).

        Args:
            noise_dim: length of the input noise vector.
    '''
    def __init__(self, noise_dim=NOISE_DIM):
        super(Generator, self).__init__(name='')
        # Initial block: noise -> 4 x 4 x 512
        self.blocks = Sequential([Input(noise_dim),
                                  Dense(4*4*512, kernel_initializer=random_normal, bias_initializer='zeros'),
                                  Reshape((4, 4, 512)),
                                  Conv2D(512, (3, 3), strides=1, padding='same', kernel_initializer=random_normal, bias_initializer='zeros'),
                                  PixelNormalization(),
                                  LeakyReLU(),], name='Initial_Block')

        self.block_list = [self.blocks,
                           upsample_block(input_shape=(4, 4, 512), filters=512, kernel_size=3, strides=1,
                                          padding='same', activation=tf.nn.leaky_relu, name='Up_{}x{}'.format(8, 8)),
                           upsample_block(input_shape=(8, 8, 512), filters=512, kernel_size=3, strides=1,
                                          padding='same', activation=tf.nn.leaky_relu, name='Up_{}x{}'.format(16, 16)),
                           upsample_block(input_shape=(16, 16, 512), filters=512, kernel_size=3, strides=1,
                                          padding='same', activation=tf.nn.leaky_relu, name='Up_{}x{}'.format(32, 32)),
                           upsample_block(input_shape=(32, 32, 512), filters=256, kernel_size=3, strides=1,
                                          padding='same', activation=tf.nn.leaky_relu, name='Up_{}x{}'.format(64, 64)),
                           upsample_block(input_shape=(64, 64, 256), filters=128, kernel_size=3, strides=1,
                                          padding='same', activation=tf.nn.leaky_relu, name='Up_{}x{}'.format(128, 128)),
                           upsample_block(input_shape=(128, 128, 128), filters=64, kernel_size=3, strides=1,
                                          padding='same', activation=tf.nn.leaky_relu, name='Up_{}x{}'.format(256, 256)),
                           upsample_block(input_shape=(256, 256, 64), filters=32, kernel_size=3, strides=1,
                                          padding='same', activation=tf.nn.leaky_relu, name='Up_{}x{}'.format(512, 512)),]

        self.stable_block_index = 0
        self.total_blocks_length = len(self.block_list)

        self.to_rgb_list = [to_rgb_block((4, 4, 512), filters=3, kernel_size=1, strides=1,
                                         padding='same', activation=tf.nn.tanh, name='ToRGB_{}x{}'.format(4, 4)),
                            to_rgb_block((8, 8, 512), filters=3, kernel_size=1, strides=1,
                                         padding='same', activation=tf.nn.tanh, name='ToRGB_{}x{}'.format(8, 8)),
                            to_rgb_block((16, 16, 512), filters=3, kernel_size=1, strides=1,
                                         padding='same', activation=tf.nn.tanh, name='ToRGB_{}x{}'.format(16, 16)),
                            to_rgb_block((32, 32, 512), filters=3, kernel_size=1, strides=1,
                                         padding='same', activation=tf.nn.tanh, name='ToRGB_{}x{}'.format(32, 32)),
                            to_rgb_block((64, 64, 256), filters=3, kernel_size=1, strides=1,
                                         padding='same', activation=tf.nn.tanh, name='ToRGB_{}x{}'.format(64, 64)),
                            to_rgb_block((128, 128, 128), filters=3, kernel_size=1, strides=1,
                                         padding='same', activation=tf.nn.tanh, name='ToRGB_{}x{}'.format(128, 128)),
                            to_rgb_block((256, 256, 64), filters=3, kernel_size=1, strides=1,
                                         padding='same', activation=tf.nn.tanh, name='ToRGB_{}x{}'.format(256, 256)),
                            to_rgb_block((512, 512, 32), filters=3, kernel_size=1, strides=1,
                                         padding='same', activation=tf.nn.tanh, name='ToRGB_{}x{}'.format(512, 512)),]

        # Stateless nearest-neighbour 2x upsampler (no weights) used to bring
        # the stable resolution's RGB image up to the fade-in resolution.
        self.rgb_upsample = UpSampling2D(size=2, interpolation='nearest')

    def call(self, inputs, training=False, fade_in=False, alpha=0):
        '''
            Generate images at the current (or fading-in) resolution.

            Args:
                inputs: batch of noise vectors for the initial block.
                training: part of the Keras call signature; not forwarded
                    explicitly to the sub-blocks.
                fade_in: when True and a higher-resolution block exists, blend
                    the next block's output into the result.
                alpha: blend weight of the new block (0 = previous stage only,
                    1 = new block only).

            Returns:
                The toRGB output of the stable block, or the alpha blend of the
                upsampled stable image and the next block's image.
        '''
        x = inputs
        for block in self.block_list[:self.stable_block_index + 1]:
            x = block(x)
            # Upsample blocks return [features, upsampled_input]; only the
            # features continue along the stable path.
            if isinstance(x, (list, tuple)):
                x = x[0]

        if not (fade_in and self.stable_block_index + 1 < self.total_blocks_length):
            return self.to_rgb_list[self.stable_block_index](x)

        # Fade-in stage.
        fade_in_index = self.stable_block_index + 1
        # Skip path: the already-trained toRGB of the stable resolution applied
        # to the stable features, then upsampled. Its input shape always
        # matches (unlike feeding upsampled stable features to the new toRGB,
        # which breaks when the new block changes the filter count, e.g.
        # 512 -> 256 at 64x64), and alpha=0 reproduces the previous stage.
        old_rgb = self.rgb_upsample(self.to_rgb_list[self.stable_block_index](x))
        new_features = self.block_list[fade_in_index](x)
        if isinstance(new_features, (list, tuple)):
            new_features = new_features[0]
        new_rgb = self.to_rgb_list[fade_in_index](new_features)
        return (1 - alpha) * old_rgb + alpha * new_rgb

    def equalize_learning_rate(self):
        '''
            Rewrite every non-bias weight whose name contains 'conv2d' with
            compute_equal_lr(weight); all other weights are written back
            unchanged.
        '''
        # NOTE(review): compute_equal_lr is defined elsewhere; presumably it
        # applies the He scaling to the kernel — confirm.
        numpy_weights = self.get_weights()
        new_weights = []
        for i, weight in enumerate(self.weights):
            if 'conv2d' in weight.name and 'bias' not in weight.name:
                new_weights.append(compute_equal_lr(numpy_weights[i]))
            else:
                new_weights.append(numpy_weights[i])
        self.set_weights(new_weights)

    def get_current_output_shape(self):
        '''
            Return the per-sample output shape (batch dimension dropped) of the
            current stable block, not including toRGB.
        '''
        output_shapes = self.block_list[self.stable_block_index].get_output_shape_at(-1)
        # Multi-output blocks report a list of shapes; a single-output block
        # reports one shape tuple, so only `list` is unwrapped here.
        if isinstance(output_shapes, list):
            output_shape = output_shapes[0][1:]
        else:
            output_shape = output_shapes[1:]
        return output_shape

    def training_next_block(self):
        '''Advance stable_block_index to the next (higher) resolution, if any.'''
        if self.stable_block_index + 1 < self.total_blocks_length:
            self.stable_block_index += 1
        else:
            print("Already reach the max resolution")

class Discriminator(tf.keras.Model):
    '''
        Progressively grown discriminator.

        block_list[0..6] are downsampling blocks for inputs from 512x512 down to
        8x8; block_list[7] (final_block) maps 4x4x512 features to a single-unit
        Dense output. from_rgb_list[i] is the fromRGB block whose input
        resolution matches block_list[i]. Only blocks stable_block_index..7 are
        run; the next (higher-resolution) block can be blended in with
        call(..., fade_in=True, alpha=...). Training starts at 4x4
        (stable_block_index = 7) and grows by decrementing the index.
    '''
    def __init__(self):
        super(Discriminator, self).__init__(name='')
        # Final block: 4 x 4 x 512 -> score
        self.final_block = Sequential([Input((4, 4, 512)),
                                       MinibatchSTDDEV(),
                                       Conv2D(512, (3, 3), strides=1, padding='same', kernel_initializer=random_normal, bias_initializer='zeros'),
                                       LeakyReLU(),
                                       Conv2D(512, (4, 4), strides=1, padding='valid', kernel_initializer=random_normal, bias_initializer='zeros'),
                                       LeakyReLU(),
                                       Flatten(),
                                       Dense(1, kernel_initializer=random_normal, bias_initializer='zeros')], name='FinalBlock')

        self.block_list = [downsample_block(input_shape=(512, 512, 32), filters=64, kernel_size=3, strides=1,
                                            padding='same', activation=tf.nn.leaky_relu, name='Down_{}x{}'.format(512, 512)),
                           downsample_block(input_shape=(256, 256, 64), filters=128, kernel_size=3, strides=1,
                                            padding='same', activation=tf.nn.leaky_relu, name='Down_{}x{}'.format(256, 256)),
                           downsample_block(input_shape=(128, 128, 128), filters=256, kernel_size=3, strides=1,
                                            padding='same', activation=tf.nn.leaky_relu, name='Down_{}x{}'.format(128, 128)),
                           downsample_block(input_shape=(64, 64, 256), filters=512, kernel_size=3, strides=1,
                                            padding='same', activation=tf.nn.leaky_relu, name='Down_{}x{}'.format(64, 64)),
                           downsample_block(input_shape=(32, 32, 512), filters=512, kernel_size=3, strides=1,
                                            padding='same', activation=tf.nn.leaky_relu, name='Down_{}x{}'.format(32, 32)),
                           downsample_block(input_shape=(16, 16, 512), filters=512, kernel_size=3, strides=1,
                                            padding='same', activation=tf.nn.leaky_relu, name='Down_{}x{}'.format(16, 16)),
                           downsample_block(input_shape=(8, 8, 512), filters=512, kernel_size=3, strides=1,
                                            padding='same', activation=tf.nn.leaky_relu, name='Down_{}x{}'.format(8, 8)),
                           self.final_block]

        # Start at the lowest resolution (the final 4x4 block).
        self.stable_block_index = len(self.block_list) - 1

        # NOTE(review): the fromRGB blocks use tanh; ProGAN's fromRGB normally
        # uses leaky ReLU — confirm this is intended.
        self.from_rgb_list = [
            from_rgb_block(input_shape=(512, 512, 3), filters=32, down_sampled_filters=64, kernel_size=1, strides=1,
                           padding='same', activation=tf.nn.tanh, name='FromRGB_{}x{}'.format(512, 512)),
            from_rgb_block(input_shape=(256, 256, 3), filters=64, down_sampled_filters=128, kernel_size=1, strides=1,
                           padding='same', activation=tf.nn.tanh, name='FromRGB_{}x{}'.format(256, 256)),
            from_rgb_block(input_shape=(128, 128, 3), filters=128, down_sampled_filters=256, kernel_size=1, strides=1,
                           padding='same', activation=tf.nn.tanh, name='FromRGB_{}x{}'.format(128, 128)),
            from_rgb_block(input_shape=(64, 64, 3), filters=256, down_sampled_filters=512, kernel_size=1, strides=1,
                           padding='same', activation=tf.nn.tanh, name='FromRGB_{}x{}'.format(64, 64)),
            from_rgb_block(input_shape=(32, 32, 3), filters=512, down_sampled_filters=512, kernel_size=1, strides=1,
                           padding='same', activation=tf.nn.tanh, name='FromRGB_{}x{}'.format(32, 32)),
            from_rgb_block(input_shape=(16, 16, 3), filters=512, down_sampled_filters=512, kernel_size=1, strides=1,
                           padding='same', activation=tf.nn.tanh, name='FromRGB_{}x{}'.format(16, 16)),
            from_rgb_block(input_shape=(8, 8, 3), filters=512, down_sampled_filters=512, kernel_size=1, strides=1,
                           padding='same', activation=tf.nn.tanh, name='FromRGB_{}x{}'.format(8, 8)),
            from_rgb_block(input_shape=(4, 4, 3), filters=512, down_sampled_filters=512, kernel_size=1, strides=1,
                           padding='same', activation=tf.nn.tanh, name='FromRGB_{}x{}'.format(4, 4))]

    def call(self, inputs, training=False, fade_in=False, alpha=0):
        '''
            Score a batch of RGB images.

            Args:
                inputs: images at the resolution of
                    from_rgb_list[stable_block_index], or of the next
                    higher-resolution fromRGB block when fading in.
                training: part of the Keras call signature; not forwarded
                    explicitly to the sub-blocks.
                fade_in: when True and a higher-resolution block exists, blend
                    that block's output with the fromRGB skip path.
                alpha: blend weight of the new block (0 = skip path only).

            Returns:
                The output of final_block.
        '''
        x = inputs

        if fade_in and (self.stable_block_index - 1 >= 0):
            # Fade-in stage: the fromRGB block returns two outputs; the second
            # (down_x) is presumably a downsampled path matching the stable
            # block's input — confirm against from_rgb_block.
            fade_in_index = self.stable_block_index - 1

            x, down_x = self.from_rgb_list[fade_in_index](x)
            x = self.block_list[fade_in_index](x)

            x = (1 - alpha) * down_x + alpha * x
        else:
            # Stable fromRGB; its second output is not needed here.
            x, _ = self.from_rgb_list[self.stable_block_index](x)
        for block in self.block_list[self.stable_block_index:]:
            x = block(x)
        return x

    def equalize_learning_rate(self):
        '''
            Rewrite every non-bias weight whose name contains 'conv2d' with
            compute_equal_lr(weight); all other weights are written back
            unchanged.
        '''
        numpy_weights = self.get_weights()
        new_weights = []
        for i, weight in enumerate(self.weights):
            if 'conv2d' in weight.name and 'bias' not in weight.name:
                new_weights.append(compute_equal_lr(numpy_weights[i]))
            else:
                new_weights.append(numpy_weights[i])
        self.set_weights(new_weights)

    def get_current_output_shape(self):
        '''
            Return the per-sample output shape (batch dimension dropped) of the
            current stable block.
        '''
        output_shapes = self.block_list[self.stable_block_index].get_output_shape_at(-1)
        # Multi-output blocks report a list of shapes; a single-output block
        # reports one shape tuple, so only `list` is unwrapped here.
        if isinstance(output_shapes, list):
            output_shape = output_shapes[0][1:]
        else:
            output_shape = output_shapes[1:]
        return output_shape

    def training_next_block(self):
        '''Move stable_block_index to the next higher resolution, if any.'''
        if self.stable_block_index - 1 >= 0:
            self.stable_block_index -= 1
        else:
            print("Already reach the max resolution")

    def set_stable_index(self, index):
        '''Set stable_block_index directly (e.g. when resuming training).'''
        self.stable_block_index = index

In [None]:
class EqualRLDense(tf.keras.layers.Layer):
    '''
        Dense layer with equalized learning rate.

        The stored kernel is multiplied at run time by the He constant
        1 / sqrt(fan_in), where fan_in is computed from the kernel shape.

        Args:
            units: number of output units.
            activation: optional activation identifier or callable.
            kernel_initializer: kernel initializer; when None (default) a new
                RandomNormal(mean=0.0, stddev=1.0) is created for this layer.
            bias_initializer: bias initializer.
    '''
    def __init__(self, units, activation=None,
                 kernel_initializer=None, bias_initializer='zeros', **kwargs):
        super(EqualRLDense, self).__init__(**kwargs)
        self.units = units
        # Build a fresh initializer per layer. A default argument instance
        # would be created once and shared by every layer; unseeded Keras
        # initializers can then return identical values on each call.
        if kernel_initializer is None:
            kernel_initializer = RandomNormal(mean=0.0, stddev=1.0)
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.activation = activation

    def build(self, input_shape):
        # [in_features, units]
        self.weight_shape = (input_shape[-1], self.units)
        self.fan_in, self.fan_out = compute_fans(self.weight_shape)
        self.he_constant = tf.Variable(1.0 / np.sqrt(self.fan_in), dtype=tf.float32, trainable=False)

        self.kernel = self.add_weight(name='kernel',
                                      shape=self.weight_shape,
                                      initializer=self.kernel_initializer,
                                      trainable=True)
        self.bias = self.add_weight(name='bias',
                                    shape=(self.weight_shape[-1],),
                                    initializer=self.bias_initializer,
                                    trainable=True)

        super(EqualRLDense, self).build(input_shape)

    def call(self, inputs, training=False):
        '''
            Compute activation(inputs @ (kernel * he_constant) + bias).

            `training` is kept for signature compatibility. The He scaling is
            applied in every mode; applying it only during training made
            inference outputs differ by a factor of sqrt(fan_in).
        '''
        scaled_kernel = tf.multiply(self.kernel, self.he_constant)
        outputs = K.bias_add(K.dot(inputs, scaled_kernel), self.bias)
        if self.activation is not None:
            # Resolve the activation function directly instead of building a
            # new Activation layer on every call.
            outputs = tf.keras.activations.get(self.activation)(outputs)
        return outputs

class EqualRLConv2D(tf.keras.layers.Layer):
    '''
        2D convolution (channels_last) with equalized learning rate.

        The stored kernel is multiplied at run time by the He constant
        1 / sqrt(fan_in), where fan_in is computed from the kernel shape.

        Args:
            filters: number of output channels.
            kernel_size: int for a square kernel, or an (height, width) pair.
            strides: passed through to K.conv2d.
            padding: passed through to K.conv2d ('valid' or 'same').
            activation: optional activation identifier or callable.
            kernel_initializer: kernel initializer; when None (default) a new
                RandomNormal(mean=0.0, stddev=1.0) is created for this layer.
            bias_initializer: bias initializer.
    '''
    def __init__(self, filters, kernel_size=3, strides=1, padding='valid', activation=None,
                 kernel_initializer=None, bias_initializer='zeros', **kwargs):
        super(EqualRLConv2D, self).__init__(**kwargs)
        self.strides = strides
        self.kernel_size = kernel_size
        self.filters = filters
        self.padding = padding
        self.data_format = 'channels_last'
        # Build a fresh initializer per layer. A default argument instance
        # would be created once and shared by every layer; unseeded Keras
        # initializers can then return identical values on each call.
        if kernel_initializer is None:
            kernel_initializer = RandomNormal(mean=0.0, stddev=1.0)
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.activation = activation

    def build(self, input_shape):
        if np.isscalar(self.kernel_size):
            kernel_hw = (self.kernel_size, self.kernel_size)
        else:
            kernel_hw = tuple(self.kernel_size)
        # [filter_height, filter_width, in_channels, out_channels]
        self.weight_shape = kernel_hw + (input_shape[-1], self.filters)
        self.fan_in, self.fan_out = compute_fans(self.weight_shape)

        self.he_constant = tf.Variable(1.0 / np.sqrt(self.fan_in), dtype=tf.float32, trainable=False)

        self.kernel = self.add_weight(name='kernel',
                                      shape=self.weight_shape,
                                      initializer=self.kernel_initializer,
                                      trainable=True)
        self.bias = self.add_weight(name='bias',
                                    shape=(self.weight_shape[-1],),
                                    initializer=self.bias_initializer,
                                    trainable=True)

        super(EqualRLConv2D, self).build(input_shape)

    def call(self, inputs, training=False):
        '''
            Compute activation(conv2d(inputs, kernel * he_constant) + bias).

            `training` is kept for signature compatibility. The He scaling is
            applied in every mode; applying it only during training made
            inference outputs differ by a factor of sqrt(fan_in).
        '''
        scaled_kernel = tf.multiply(self.kernel, self.he_constant)
        outputs = K.conv2d(inputs, scaled_kernel,
                           strides=self.strides,
                           padding=self.padding,
                           data_format=self.data_format)
        outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
        if self.activation is not None:
            # Resolve the activation function directly instead of building a
            # new Activation layer on every call.
            outputs = tf.keras.activations.get(self.activation)(outputs)
        return outputs


def upsample_block(x, in_filters, filters, kernel_size=3, strides=1, padding='valid', activation=tf.nn.leaky_relu, name=''):
    '''
        Nearest-neighbour 2x upsampling followed by two
        EqualRLConv2D -> PixelNormalization -> activation stages.

        `in_filters` is currently unused. `kernel_initializer` is not a
        parameter; it is read from the enclosing (module) scope.

        Returns:
            (features, upsampled): the output of the second stage and the
            upsampled input before any convolution.
    '''
    upsampled = UpSampling2D(size=2, interpolation='nearest')(x)
    features = upsampled
    for suffix in ('_conv2d_1', '_conv2d_2'):
        conv = EqualRLConv2D(filters, kernel_size, strides, padding=padding,
                             kernel_initializer=kernel_initializer,
                             bias_initializer='zeros', name=name + suffix)
        features = conv(features)
        features = PixelNormalization()(features)
        features = Activation(activation)(features)
    return features, upsampled

def downsample_block(x, filters1, filters2, kernel_size=3, strides=1, padding='valid', activation=tf.nn.leaky_relu, name=''):
    '''
        Two EqualRLConv2D -> activation stages followed by 2x average pooling.

        The first stage uses `filters1` output channels and the second
        `filters2`. `kernel_initializer` is not a parameter; it is read from
        the enclosing (module) scope.

        Returns:
            The pooled feature map.
    '''
    out = x
    for stage, stage_filters in enumerate((filters1, filters2), start=1):
        conv = EqualRLConv2D(stage_filters, kernel_size, strides, padding=padding,
                             kernel_initializer=kernel_initializer, bias_initializer='zeros',
                             name=name + '_conv2d_{}'.format(stage))
        out = Activation(activation)(conv(out))
    return AveragePooling2D(pool_size=2)(out)