In [126]:
from keras.models import Model
from keras.utils import plot_model
from keras.layers import (
    Conv2D,
    Dense,
    Input,
    Reshape,
    BatchNormalization,
    UpSampling2D,
    LeakyReLU,
    Dropout
)
import keras.backend as K
import numpy as np

def generator(noise_dim=100):
    """Build, compile and return the (experimental) generator model.

    The input is a noise map of shape ``(height, width, noise_dim)`` whose
    spatial size is left unspecified (``None``), so height and width are
    only fixed when the model is called on concrete data.

    Architecture: per-position Dense(1024) -> BatchNormalization -> three
    [Conv2D(stride 2, ReLU) -> BatchNormalization -> UpSampling2D] stages with
    512 / 256 / 128 filters and 8 / 16 / 32 kernels -> Conv2D(3 filters,
    kernel 3, stride 2, tanh).

    Args:
        noise_dim: channel count of the noise input. Defaults to 100, the
            value that was previously hard-coded.

    Returns:
        The compiled keras ``Model``. ``model.summary()`` is printed as a
        side effect.
    """
    input_layer = Input(shape=(None, None, noise_dim))

    def g_cnn_layer(inputs, filter_size, kernel_size, strides_size=2, moment_rate=0.8):
        """Conv2D (ReLU) -> BatchNormalization -> UpSampling2D (default 2x)."""
        # NOTE(review): with the default 'valid' padding the strided conv
        # yields floor((H - kernel_size) / 2) + 1 rows, so even after the 2x
        # upsampling every stage shrinks the map and the noise input must be
        # large for the model to run -- confirm this is the intended design.
        cnn = Conv2D(filters=filter_size, kernel_size=kernel_size,
                     strides=strides_size, activation='relu')(inputs)
        batch_norm = BatchNormalization(momentum=moment_rate)(cnn)
        return UpSampling2D()(batch_norm)

    # On a 4-D input Dense projects only the last (channel) axis, i.e. a
    # per-position noise_dim -> 1024 mapping. No `input_dim` is passed: Keras
    # infers it from the input tensor (the old `input_dim=100` only matters
    # for the first layer of a Sequential model and was ignored here).
    dense_layer = Dense(1024)(input_layer)
    batch_normalization = BatchNormalization()(dense_layer)
    cnn1 = g_cnn_layer(batch_normalization, 512, 8)
    cnn2 = g_cnn_layer(cnn1, 256, 16)
    cnn3 = g_cnn_layer(cnn2, 128, 32)
    assert cnn3.shape[3] == 128
    # 3-channel image head; tanh keeps output values in [-1, 1].
    cnn4 = Conv2D(filters=3, kernel_size=3, strides=2, activation='tanh')(cnn3)
    assert cnn4.shape[3] == 3

    model = Model(inputs=input_layer, outputs=cnn4)
    model.compile(
        optimizer='Adam',
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    model.summary()
    return model


def discriminator(input_shape=(32, 32, 64)):
    """Build, compile and return the (experimental) discriminator model.

    Three [Conv2D -> LeakyReLU -> Dropout -> BatchNormalization] stages
    (128 / 256 / 512 filters, 2x2 kernels, stride 2) are flattened into a
    single sigmoid unit that scores each sample as real (1) or fake (0).

    Args:
        input_shape: shape of one input sample, excluding the batch axis.
            Defaults to the previously hard-coded ``(32, 32, 64)``.
            NOTE(review): the generator in this file emits 3-channel images,
            so this probably should be ``(H, W, 3)`` -- confirm. Height and
            width must be known (not None) because the final feature map is
            flattened into a Dense layer.

    Returns:
        The compiled keras ``Model`` with output shape ``(batch, 1)``.
        ``model.summary()`` is printed as a side effect.
    """
    def d_cnn_layer(inputs, filter_size, kernel_size, stride_size,
                    alpha=0.2, dropout_rate=0.25, momentum=0.8):
        """Conv2D -> LeakyReLU -> Dropout -> BatchNormalization."""
        # Named keyword defaults replace the old mutable list default
        # [0.2, 0.8, 0.25], which fed 0.8 to Dropout and 0.25 to the
        # BatchNormalization momentum. Those two values looked swapped: the
        # generator here uses momentum 0.8, and 0.25 is the usual drop rate.
        cnn = Conv2D(filters=filter_size, kernel_size=kernel_size, strides=stride_size)(inputs)
        activate_func = LeakyReLU(alpha=alpha)(cnn)
        dropout = Dropout(dropout_rate)(activate_func)
        return BatchNormalization(momentum=momentum)(dropout)

    input_layer = Input(shape=tuple(input_shape))
    cnn1 = d_cnn_layer(input_layer, 128, 2, 2)
    cnn2 = d_cnn_layer(cnn1, 256, 2, 2)
    cnn3 = d_cnn_layer(cnn2, 512, 2, 2)
    # Flatten the (H', W', 512) feature map (Reshape((-1,)) is equivalent to
    # Flatten) so the head emits one probability per sample. A Dense layer
    # applied directly to the 4-D map would act per spatial position and
    # give a (batch, H', W', units) output that cannot match binary labels.
    features = Reshape((-1,))(cnn3)
    classification = Dense(1, activation='sigmoid')(features)
    model = Model(inputs=input_layer, outputs=classification)
    model.compile(
        optimizer='Adam',
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    model.summary()
    return model
    
    
#     cnn1 = Conv2D(filters=512, kernel_size=8, strides=2, activation='tanh')(input_layer)
#     shape_checker(cnn1, [4,4,1024])
    

def shape_checker(layer_output, expect_shape):
    """Check that a layer output has the expected shape, ignoring the batch axis.

    The first (batch) dimension is skipped because Keras reports it as
    unknown (``None`` / ``?``) while the graph is being built.

    Args:
        layer_output: a tensor (or any object with a ``.shape`` attribute),
            or a shape sequence itself, e.g. ``(None, 32, 32)``.
        expect_shape: expected dimensions WITHOUT the batch axis,
            e.g. ``[32, 32]`` for an output of shape ``(?, 32, 32)``.

    Returns:
        True when every non-batch dimension matches.

    Raises:
        AssertionError: if the number of non-batch dimensions differs, or any
            dimension does not match. Raised explicitly (not via ``assert``)
            so the check still runs under ``python -O``.
    """
    layer_output_shape = getattr(layer_output, "shape", layer_output)
    actual = list(layer_output_shape)[1:]
    expected = list(expect_shape)
    # Use `==` rather than `!=`: some shape element types (e.g. TF1
    # Dimension) return None for comparisons involving an unknown size, and
    # `all` treats that as falsy, i.e. as a mismatch -- as before.
    matches = len(actual) == len(expected) and all(
        a == e for a, e in zip(actual, expected)
    )
    if not matches:
        raise AssertionError(
            f"not match shape: output {layer_output_shape},  expect: {expect_shape}"
        )
    return True
# Build both models. Each builder prints its own model.summary(); the
# returned models are discarded here (in a notebook cell only the value of
# the final bare expression is displayed).
generator()
discriminator()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_86 (InputLayer)        (None, None, None, 100)   0         
_________________________________________________________________
dense_50 (Dense)             (None, None, None, 1024)  103424    
_________________________________________________________________
batch_normalization_76 (Batc (None, None, None, 1024)  4096      
_________________________________________________________________
conv2d_133 (Conv2D)          (None, None, None, 512)   33554944  
_________________________________________________________________
batch_normalization_77 (Batc (None, None, None, 512)   2048      
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, None, None, 512)   0         
_________________________________________________________________
conv2d_134 (Conv2D)          (None, None, None, 256)   33554688  
__________

<keras.engine.training.Model at 0xbb9f867f0>