# In [None]:
import keras
import numpy as np
from keras import layers
from keras.models import Model
from keras import backend as K

# Model / training hyper-parameters (consumed by the VAE constructor below).
img_shape, latent_dim = (28, 28, 1), 2
batch_size, epochs = 150, 30
optimizer = 'adam'  # handed to Model.compile

class VAE(keras.layers.Layer):
    """Convolutional variational autoencoder (Gaussian encoder, Bernoulli decoder).

    The object plays two roles:

    * ``model()`` / ``train()`` build a functional Keras graph
      (encoder -> reparameterised sampling -> decoder) and compile it.
    * As a Keras ``Layer`` it is attached to the end of that graph (see
      ``call``), so that the VAE loss is registered with ``add_loss``.
      This is why the model is compiled with ``loss=None``.

    Args:
        img_shape: Input image shape, e.g. ``(28, 28, 1)``.
        latent_dim: Dimensionality of the latent code ``z``.
        batch_size: Batch size used by ``train(fit=True)``.
        epochs: Number of epochs used by ``train(fit=True)``.
        optimizer: Optimizer (name or instance) passed to ``Model.compile``.
        x_train: Training images passed to ``Model.fit``.
        x_test: Validation images passed to ``Model.fit``.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, img_shape, latent_dim, batch_size, epochs, optimizer,
                 x_train, x_test, **kwargs):
        # A Keras Layer must be initialised before any attribute is assigned.
        super().__init__(**kwargs)
        self.img_shape = img_shape
        self.latent_dim = latent_dim
        self.batch_size = batch_size
        self.epochs = epochs
        self.optimizer = optimizer
        self.x_train = x_train
        self.x_test = x_test
        self.input_layer = layers.Input(shape=self.img_shape)

    def model(self):
        """Build encoder, sampler and decoder on top of ``self.input_layer``.

        Side effects: sets ``self.shape_before_dense``, ``self.z_mean`` and
        ``self.z_log_var`` (symbolic tensors used by the loss).

        Returns:
            The symbolic reconstruction tensor. It has ``img_shape[-1]``
            channels and the input's spatial size when height/width are even:
            the encoder downsamples once by 2 and the decoder upsamples once
            by 2.
        """
        i = self.input_layer

        # Gaussian encoder q(z|x).
        c = layers.Conv2D(32, 3, padding='same', activation='relu')(i)
        c = layers.Conv2D(64, 3, padding='same', activation='relu', strides=(2, 2))(c)
        c = layers.Conv2D(64, 3, padding='same', activation='relu')(c)
        e = layers.Conv2D(64, 3, padding='same', activation='relu')(c)
        # Remembered so the decoder can reshape back to this feature map.
        self.shape_before_dense = K.int_shape(e)

        x = layers.Flatten()(e)
        x = layers.Dense(32, activation='relu')(x)

        # Parameters of q(z|x).
        self.z_mean = layers.Dense(self.latent_dim)(x)
        self.z_log_var = layers.Dense(self.latent_dim)(x)

        # Reparameterised sample z ~ q(z|x).
        z = layers.Lambda(self.sampling)([self.z_mean, self.z_log_var])

        # Bernoulli decoder p(x|z). It consumes the sampled z directly: a fresh
        # Input here would disconnect the decoder from the encoder graph.
        x = layers.Dense(32, activation='relu')(z)
        x = layers.Dense(int(np.prod(self.shape_before_dense[1:])), activation='relu')(x)
        x = layers.Reshape(self.shape_before_dense[1:])(x)
        x = layers.Conv2DTranspose(32, 3, padding='same', activation='relu', strides=(2, 2))(x)
        z_decoded = layers.Conv2D(self.img_shape[-1], 3, padding='same', activation='sigmoid')(x)

        return z_decoded

    def sampling(self, args):
        """Draw z ~ N(z_mean, exp(z_log_var)) with the reparameterisation trick.

        Sampling from q(z|x) is not differentiable by itself. Writing the
        sample as ``z = mean + sigma * eps`` with ``eps ~ N(0, I)`` isolates the
        randomness in ``eps``, so gradients can flow through ``mean`` and
        ``sigma`` without changing the distribution of ``z``.

        Args:
            args: ``[z_mean, z_log_var]`` tensors of identical shape.

        Returns:
            A tensor with the same shape as ``z_mean``.
        """
        z_mean, z_log_var = args
        epsilon = K.random_normal(shape=K.shape(z_mean), mean=0., stddev=1.)
        # Convert log-variance to standard deviation: sigma = exp(0.5 * log sigma^2).
        return z_mean + K.exp(0.5 * z_log_var) * epsilon

    def vae_loss(self, x, z_decoded, z_mean=None, z_log_var=None):
        """Reconstruction cross-entropy plus a weighted KL term.

        Args:
            x: Input images.
            z_decoded: Reconstructed images (sigmoid outputs).
            z_mean: Latent means. Defaults to ``self.z_mean`` (set by ``model``).
            z_log_var: Latent log-variances. Defaults to ``self.z_log_var``.

        Returns:
            A scalar loss tensor.
        """
        if z_mean is None:
            z_mean = self.z_mean
        if z_log_var is None:
            z_log_var = self.z_log_var

        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)
        xent_loss = keras.metrics.binary_crossentropy(x, z_decoded)

        # KL(q(z|x) || N(0, I)) per sample, scaled by the original 5e-4 weight.
        kl_loss = -5e-4 * K.mean(
            1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return K.mean(xent_loss + kl_loss)

    def call(self, inputs):
        """Register the VAE loss and pass the input images through unchanged.

        Args:
            inputs: ``[x, z_decoded]`` or ``[x, z_decoded, z_mean, z_log_var]``.
                Passing the latent tensors explicitly keeps the loss computation
                inside this layer's inputs rather than closing over outer tensors.

        Returns:
            ``inputs[0]`` unchanged.
        """
        x = inputs[0]
        z_decoded = inputs[1]
        z_mean = inputs[2] if len(inputs) > 2 else None
        z_log_var = inputs[3] if len(inputs) > 3 else None
        loss = self.vae_loss(x, z_decoded, z_mean, z_log_var)
        self.add_loss(loss)
        return x

    def train(self, fit=False):
        """Build, compile and summarise the VAE; optionally fit it.

        Args:
            fit: When True, run ``Model.fit`` on ``self.x_train``, validating on
                ``self.x_test``. Defaults to False, which matches the previous
                behaviour of only building and compiling.

        Returns:
            The compiled ``keras.models.Model``. Its output is the identity of
            its input: the training signal comes entirely from ``add_loss``.
        """
        x = self.input_layer
        z_decoded = self.model()

        # Attach this layer at the end of the graph so that call() registers the
        # loss. Without it, compile(loss=None) would have nothing to optimise.
        y = self([x, z_decoded, self.z_mean, self.z_log_var])

        model = Model(x, y)
        model.compile(optimizer=self.optimizer, loss=None)
        model.summary()

        if fit:
            model.fit(x=self.x_train, y=None, shuffle=True, epochs=self.epochs,
                      batch_size=self.batch_size,
                      validation_data=(self.x_test, None))
        return model

vae = VAE(img_shape=img_shape, latent_dim=latent_dim, batch_size=batch_size,
          epochs=epochs, optimizer=optimizer, x_train=x_train, x_test=x_test)
vae.train()