In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from keras.layers import Dense, Input
from keras.layers import Conv2D, Flatten, Lambda
from keras.layers import Reshape, Conv2DTranspose
from keras.models import Model
from keras.datasets import mnist
from keras.losses import mse, binary_crossentropy
from keras.utils import plot_model
from keras import backend as K

import numpy as np
import matplotlib.pyplot as plt
import argparse
import os


# reparameterization trick
# instead of sampling from Q(z|X), sample eps = N(0,I)
# then z = z_mean + sqrt(var)*eps
def sampling(args):
    """Draw a latent sample with the reparameterization trick.

    # Arguments
        args (sequence of two tensors): `(z_mean, z_log_var)`, the mean and
            the log-variance of Q(z|X).

    # Returns
        z (tensor): `z_mean + exp(0.5 * z_log_var) * epsilon`, where
            `epsilon` comes from `K.random_normal`.
    """
    mean, log_var = args
    # the noise has one row per batch entry and one column per latent dimension
    noise_shape = (K.shape(mean)[0], K.int_shape(mean)[1])
    # by default, random_normal has mean=0 and std=1.0
    noise = K.random_normal(shape=noise_shape)
    # exp(0.5 * log(var)) == sqrt(var), i.e. the standard deviation
    std = K.exp(0.5 * log_var)
    return mean + std * noise


# MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

print(x_train.shape)
image_size = x_train.shape[1]

# append a trailing channel axis, cast to float32 and divide by 255
x_train = x_train.reshape(-1, image_size, image_size, 1).astype('float32') / 255
x_test = x_test.reshape(-1, image_size, image_size, 1).astype('float32') / 255

print(x_train.shape)


(60000, 28, 28)


(60000, 28, 28, 1)


In [5]:
# network and training hyperparameters
input_shape = (image_size, image_size, 1)
batch_size = 128
kernel_size = 3
# starting filter count; doubled before every encoder convolution below
filters = 16
latent_dim = 2
# NOTE: the vae.fit(...) call later in this file hard-codes epochs=1 and does
# not read this value
epochs = 30

# build encoder model: two stride-2 convolutions, with the filter count going
# 16 -> 32 -> 64 (`filters` is left at 64 and reused by the decoder)
inputs = Input(shape=input_shape, name='encoder_input')
x = inputs
for i in range(2):
    filters *= 2
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               activation='relu',
               strides=2,
               padding='same')(x)

# shape info needed to build decoder model
# (shape[1:] sizes the decoder's first Dense and Reshape layers)
shape = K.int_shape(x)

# generate latent vector Q(z|X)
x = Flatten()(x)
x = Dense(16, activation='relu')(x)
# two parallel linear heads: the mean and the log-variance of Q(z|X)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

# use reparameterization trick to push the sampling out as input
# note that "output_shape" isn't necessary with the TensorFlow backend
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

# instantiate encoder model
# it exposes three outputs, in this order: z_mean, z_log_var and the sampled z
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()
# writes a diagram of the encoder to vae_cnn_encoder.png
plot_model(encoder, to_file='vae_cnn_encoder.png', show_shapes=True)

# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
# expand the latent vector to the size of the encoder's last feature map
# (`shape`), then restore its (height, width, channels) layout
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)

# mirror the encoder with two stride-2 transposed convolutions.
# `filters` enters this loop with the value left by the encoder loop (64) and
# is halved after each layer, so the layers use 64 then 32 filters.
for i in range(2):
    x = Conv2DTranspose(filters=filters,
                        kernel_size=kernel_size,
                        activation='relu',
                        strides=2,
                        padding='same')(x)
    filters //= 2

# single-channel output with a sigmoid activation
outputs = Conv2DTranspose(filters=1,
                          kernel_size=kernel_size,
                          activation='sigmoid',
                          padding='same',
                          name='decoder_output')(x)

# instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()
# writes a diagram of the decoder to vae_cnn_decoder.png
plot_model(decoder, to_file='vae_cnn_decoder.png', show_shapes=True)

# instantiate VAE model: the decoder consumes the sampled z, which is the
# third output of the encoder
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs=inputs, outputs=outputs, name='vae')

# kept together so the plotting helper can receive both sub-models
models = encoder, decoder

# VAE loss = mse_loss or xent_loss + kl_loss
# reconstruction term: MSE over the flattened tensors, scaled by the number of
# pixels per image
reconstruction_loss = mse(K.flatten(inputs), K.flatten(outputs)) * (image_size * image_size)

# KL term: -0.5 * sum(1 + log_var - mean^2 - exp(log_var)) over the latent axis
kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var),
                       axis=-1)

vae_loss = K.mean(reconstruction_loss + kl_loss)

# the loss is attached to the model itself, so compile() gets no `loss` argument
vae.add_loss(vae_loss)
vae.compile(optimizer='rmsprop')
vae.summary()
plot_model(vae, to_file='vae_cnn.png', show_shapes=True)


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 14, 14, 32)   320         encoder_input[0][0]              
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 7, 7, 64)     18496       conv2d_3[0][0]                   
__________________________________________________________________________________________________
flatten_2 (Flatten)             (None, 3136)         0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
dense_3 (D

CIAO:(None, 7, 7, 64)


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
z_sampling (InputLayer)      (None, 2)                 0         
_________________________________________________________________
dense_4 (Dense)              (None, 3136)              9408      
_________________________________________________________________
reshape_2 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 14, 14, 64)        36928     
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 28, 28, 32)        18464     
_________________________________________________________________
decoder_output (Conv2DTransp (None, 28, 28, 1)         289       
Total params: 65,089
Trainable params: 65,089
Non-trainable params: 0
_________________________________________________________________


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder_input (InputLayer)   (None, 28, 28, 1)         0         
_________________________________________________________________
encoder (Model)              [(None, 2), (None, 2), (N 69076     
_________________________________________________________________
decoder (Model)              (None, 28, 28, 1)         65089     
Total params: 134,165
Trainable params: 134,165
Non-trainable params: 0
_________________________________________________________________


In [6]:

# train the VAE; no target array is passed because the loss was attached to the
# model with vae.add_loss(...)
# NOTE: a single epoch is hard-coded here; the `epochs` setting defined with the
# other hyperparameters is not used by this call
vae.fit(x_train, batch_size=batch_size, epochs=1)
# vae.save_weights('vae_cnn_mnist.h5')


Epoch 1/1


  128/60000 [..............................] - ETA: 1:15:47 - loss: 181.6040

  256/60000 [..............................] - ETA: 49:13 - loss: 178.9046  

  384/60000 [..............................] - ETA: 37:37 - loss: 175.2211

  512/60000 [..............................] - ETA: 32:11 - loss: 168.2553

  640/60000 [..............................] - ETA: 28:47 - loss: 198.2654

  768/60000 [..............................] - ETA: 27:42 - loss: 184.3693

  896/60000 [..............................] - ETA: 27:12 - loss: 171.8195

 1024/60000 [..............................] - ETA: 25:53 - loss: 161.7295

 1152/60000 [..............................] - ETA: 24:43 - loss: 153.1900

 1280/60000 [..............................] - ETA: 23:46 - loss: 145.9027

 1408/60000 [..............................] - ETA: 22:35 - loss: 139.8841

 1536/60000 [..............................] - ETA: 21:30 - loss: 134.4687

 1664/60000 [..............................] - ETA: 21:51 - loss: 130.1917

 1792/60000 [..............................] - ETA: 20:57 - loss: 126.3128

 1920/60000 [..............................] - ETA: 20:07 - loss: 122.7139

 2048/60000 [>.............................] - ETA: 19:28 - loss: 119.4924

 2176/60000 [>.............................] - ETA: 18:55 - loss: 116.3851

 2304/60000 [>.............................] - ETA: 18:23 - loss: 113.6533

 2432/60000 [>.............................] - ETA: 17:52 - loss: 111.1076

 2560/60000 [>.............................] - ETA: 17:26 - loss: 108.7736

 2688/60000 [>.............................] - ETA: 17:32 - loss: 106.5947

 2816/60000 [>.............................] - ETA: 17:09 - loss: 104.6438

 2944/60000 [>.............................] - ETA: 16:49 - loss: 102.7156

 3072/60000 [>.............................] - ETA: 16:31 - loss: 101.1435

 3200/60000 [>.............................] - ETA: 16:10 - loss: 99.6399 

 3328/60000 [>.............................] - ETA: 15:51 - loss: 98.2126

 3456/60000 [>.............................] - ETA: 15:33 - loss: 96.7175

 3584/60000 [>.............................] - ETA: 15:19 - loss: 95.4948

 3712/60000 [>.............................] - ETA: 15:03 - loss: 94.3063

 3840/60000 [>.............................] - ETA: 14:49 - loss: 93.1142

 3968/60000 [>.............................] - ETA: 14:38 - loss: 92.0421

 4096/60000 [=>............................] - ETA: 14:24 - loss: 90.9967

 4224/60000 [=>............................] - ETA: 14:14 - loss: 89.8774

 4352/60000 [=>............................] - ETA: 14:05 - loss: 88.8987

 4480/60000 [=>............................] - ETA: 13:53 - loss: 87.8879

 4608/60000 [=>............................] - ETA: 13:43 - loss: 87.0332

 4736/60000 [=>............................] - ETA: 13:32 - loss: 86.2531

 4864/60000 [=>............................] - ETA: 13:22 - loss: 85.5081

 4992/60000 [=>............................] - ETA: 13:12 - loss: 84.7651

 5120/60000 [=>............................] - ETA: 13:02 - loss: 84.0921

 5248/60000 [=>............................] - ETA: 12:52 - loss: 83.3443

 5376/60000 [=>............................] - ETA: 12:43 - loss: 82.7177

 5504/60000 [=>............................] - ETA: 12:34 - loss: 82.1548

 5632/60000 [=>............................] - ETA: 12:22 - loss: 81.5509

 5760/60000 [=>............................] - ETA: 12:16 - loss: 80.9752

 5888/60000 [=>............................] - ETA: 12:04 - loss: 80.3670

 6016/60000 [==>...........................] - ETA: 11:56 - loss: 79.8007

 6144/60000 [==>...........................] - ETA: 11:50 - loss: 79.2561

 6272/60000 [==>...........................] - ETA: 11:37 - loss: 78.7170

 6400/60000 [==>...........................] - ETA: 11:26 - loss: 78.2422

 6528/60000 [==>...........................] - ETA: 11:13 - loss: 77.7733

 6656/60000 [==>...........................] - ETA: 11:02 - loss: 77.3376

 6784/60000 [==>...........................] - ETA: 10:51 - loss: 76.9062

 6912/60000 [==>...........................] - ETA: 10:41 - loss: 76.5277

 7040/60000 [==>...........................] - ETA: 10:30 - loss: 76.0977

 7168/60000 [==>...........................] - ETA: 10:20 - loss: 75.7122

 7296/60000 [==>...........................] - ETA: 10:10 - loss: 75.3774

 7424/60000 [==>...........................] - ETA: 10:00 - loss: 75.0379

 7552/60000 [==>...........................] - ETA: 9:51 - loss: 74.6927 

 7680/60000 [==>...........................] - ETA: 9:42 - loss: 74.3194

 7808/60000 [==>...........................] - ETA: 9:33 - loss: 74.0200

 7936/60000 [==>...........................] - ETA: 9:25 - loss: 73.7470

 8064/60000 [===>..........................] - ETA: 9:20 - loss: 73.4098

 8192/60000 [===>..........................] - ETA: 9:14 - loss: 73.0993

 8320/60000 [===>..........................] - ETA: 9:09 - loss: 72.8132

 8448/60000 [===>..........................] - ETA: 9:04 - loss: 72.5028

 8576/60000 [===>..........................] - ETA: 8:56 - loss: 72.2286

 8704/60000 [===>..........................] - ETA: 8:49 - loss: 71.9540

 8832/60000 [===>..........................] - ETA: 8:42 - loss: 71.7317

 8960/60000 [===>..........................] - ETA: 8:35 - loss: 71.4825

 9088/60000 [===>..........................] - ETA: 8:29 - loss: 71.2304

 9216/60000 [===>..........................] - ETA: 8:23 - loss: 71.0256

 9344/60000 [===>..........................] - ETA: 8:17 - loss: 70.7949

 9472/60000 [===>..........................] - ETA: 8:11 - loss: 70.5554

 9600/60000 [===>..........................] - ETA: 8:05 - loss: 70.3392

 9728/60000 [===>..........................] - ETA: 8:00 - loss: 70.1304

 9856/60000 [===>..........................] - ETA: 7:54 - loss: 69.8808

 9984/60000 [===>..........................] - ETA: 7:49 - loss: 69.6792

10112/60000 [====>.........................] - ETA: 7:44 - loss: 69.4856

10240/60000 [====>.........................] - ETA: 7:38 - loss: 69.2986

10368/60000 [====>.........................] - ETA: 7:44 - loss: 69.0807

10496/60000 [====>.........................] - ETA: 7:48 - loss: 68.9140

10624/60000 [====>.........................] - ETA: 7:50 - loss: 68.7130

10752/60000 [====>.........................] - ETA: 7:53 - loss: 68.5427

10880/60000 [====>.........................] - ETA: 7:54 - loss: 68.3617

11008/60000 [====>.........................] - ETA: 7:52 - loss: 68.1759

11136/60000 [====>.........................] - ETA: 7:54 - loss: 67.9941

11264/60000 [====>.........................] - ETA: 8:00 - loss: 67.8176

11392/60000 [====>.........................] - ETA: 8:00 - loss: 67.6602

11520/60000 [====>.........................] - ETA: 7:59 - loss: 67.5122

11648/60000 [====>.........................] - ETA: 7:57 - loss: 67.3433

11776/60000 [====>.........................] - ETA: 8:04 - loss: 67.1822

11904/60000 [====>.........................] - ETA: 8:01 - loss: 67.0437

12032/60000 [=====>........................] - ETA: 7:59 - loss: 66.8873

12160/60000 [=====>........................] - ETA: 7:56 - loss: 66.7678

12288/60000 [=====>........................] - ETA: 7:54 - loss: 66.6259

12416/60000 [=====>........................] - ETA: 8:02 - loss: 66.4819

12544/60000 [=====>........................] - ETA: 8:10 - loss: 66.3406

12672/60000 [=====>........................] - ETA: 8:08 - loss: 66.2022

12800/60000 [=====>........................] - ETA: 8:06 - loss: 66.0795

12928/60000 [=====>........................] - ETA: 8:08 - loss: 65.9591

13056/60000 [=====>........................] - ETA: 8:10 - loss: 65.8370

13184/60000 [=====>........................] - ETA: 8:16 - loss: 65.7111

13312/60000 [=====>........................] - ETA: 8:18 - loss: 65.6068

13440/60000 [=====>........................] - ETA: 8:21 - loss: 65.5133

13568/60000 [=====>........................] - ETA: 8:23 - loss: 65.4086

13696/60000 [=====>........................] - ETA: 8:20 - loss: 65.3058

13824/60000 [=====>........................] - ETA: 8:18 - loss: 65.1839

13952/60000 [=====>........................] - ETA: 8:20 - loss: 65.0722







































































































































































































































































































































































































































































































































































































































































































































































<keras.callbacks.History at 0x1f6a89039b0>

In [13]:

def plot_results(models,
                 data,
                 batch_size=128,
                 model_name="vae_mnist"):
    """Plot the encoded means of a labelled dataset in the latent space.

    The first two latent coordinates of `z_mean` are drawn as a scatter plot
    colored by label. The figure is saved as `<model_name>/vae_mean.png` and
    then shown.

    # Arguments
        models (tuple): `(encoder, decoder)` models. Only the encoder is
            used; the decoder is accepted so callers can keep passing the pair.
        data (tuple): `(x_test, y_test)`, the inputs to encode and the labels
            used to color the points.
        batch_size (int): batch size passed to `encoder.predict`.
        model_name (string): directory that receives the figure; it is
            created if it does not exist.
    """
    encoder, _ = models
    x_test, y_test = data
    os.makedirs(model_name, exist_ok=True)
    filename = os.path.join(model_name, "vae_mean.png")

    # the encoder predicts (z_mean, z_log_var, z); only the means are plotted
    z_mean, _, _ = encoder.predict(x_test,
                                   batch_size=batch_size)

    # display a 2D plot of the digit classes in the latent space
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=y_test)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.savefig(filename)
    plt.show()
    
# plot the latent-space means of the test set; the figure goes to ./vae_cnn
data = x_test, y_test
plot_results(models=models,
             data=data,
             batch_size=batch_size,
             model_name="vae_cnn")
