In [2]:
from __future__ import print_function, absolute_import, division

import tensorflow as tf

# Short alias for TensorFlow's contrib eager-execution module.
tfe = tf.contrib.eager
# Switch TensorFlow into eager mode so ops run immediately and return values.
# NOTE(review): presumably this has to run before any other TF op is created -- confirm.
tf.enable_eager_execution()

import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
from IPython import display

# Load MNIST data

In [8]:
# 1. Load data (labels are discarded; only the images are used below)
(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()

# 2. Reshape into n-by-28-by-28-by-1 (explicit channel axis) and cast to float32
train_images = train_images.reshape((len(train_images), 28, 28, 1)).astype('float32')
test_images = test_images.reshape((len(test_images), 28, 28, 1)).astype('float32')

# 3. Normalize the intensity into [0, 1]
train_images = train_images / 255.
test_images = test_images / 255.

# 4. Binarize: intensities >= 0.5 become 1, everything below becomes 0
train_images = (train_images >= 0.5).astype('float32')
test_images = (test_images >= 0.5).astype('float32')


(60000, 28, 28) (10000, 28, 28)


In [12]:
# Shuffle-buffer size for the training split; equals the 60000 training images.
TRAIN_BUFF = 60000
# Number of images per mini-batch.
BATCH_SIZE = 100
# Shuffle-buffer size for the test split; equals the 10000 test images.
TEST_BUFF = 10000

# Create Dataset

In [13]:
# Slice each image array into per-example elements, shuffle them with the
# configured buffer size, then group them into batches of BATCH_SIZE.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(buffer_size=TRAIN_BUFF)
    .batch(batch_size=BATCH_SIZE)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices(test_images)
    .shuffle(buffer_size=TEST_BUFF)
    .batch(batch_size=BATCH_SIZE)
)

# Build a convolutional VAE model

In [35]:
class CVAE(tf.keras.Model):
    """Convolutional variational autoencoder for 28x28x1 images.

    The model holds two sub-networks:

    * ``inference_net`` (encoder): maps an image batch to ``2 * latent_dim``
      values per example, interpreted as the mean and log-variance of the
      approximate posterior over the latent code.
    * ``generative_net`` (decoder): maps a latent code back to per-pixel
      logits with the same spatial layout as the input image.
    """

    def __init__(self, latent_dim):
        """Build the encoder and decoder networks.

        Args:
            latent_dim: Size of the latent code ``z``.
        """
        super(CVAE, self).__init__()
        self.latent_dim = latent_dim

        # Inference net: two strided convolutions, then a dense layer that
        # emits mean and log-variance concatenated along the last axis.
        self.inference_net = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(32, kernel_size=(3, 3), strides=(2, 2), activation='relu'),
            tf.keras.layers.Conv2D(64, kernel_size=(3, 3), strides=(2, 2), activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(2 * latent_dim)
        ])

        # Generative net: project the latent code to a 7x7x32 feature map,
        # upsample twice with stride-2 transposed convolutions, and finish
        # with a single-channel layer that has no activation (raw logits).
        self.generative_net = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(7*7*32, activation='relu'),
            tf.keras.layers.Reshape(target_shape=(7, 7, 32)),
            tf.keras.layers.Conv2DTranspose(64, kernel_size=(3, 3), strides=(2, 2),
                                            padding='SAME', activation='relu'),
            tf.keras.layers.Conv2DTranspose(32, kernel_size=(3, 3), strides=(2, 2),
                                            padding='SAME', activation='relu'),
            tf.keras.layers.Conv2DTranspose(1, kernel_size=(3, 3), strides=(1, 1), padding='SAME')
        ])

    # The methods below must live at class level. When nested inside
    # __init__ they are only local functions and never become attributes
    # of the model, so calls such as model.encode(x) would fail.

    def sample(self, eps=None):
        """Decode latent codes into pixel probabilities.

        Args:
            eps: Optional batch of latent codes. When omitted, 100 codes are
                drawn from a standard normal of width ``latent_dim``.

        Returns:
            The decoder output passed through a sigmoid.
        """
        # Identity check: `==` against a tensor is not a reliable None test.
        if eps is None:
            eps = tf.random_normal(shape=(100, self.latent_dim))
        return self.decode(eps, apply_sigmoid=True)

    def encode(self, x):
        """Run the inference net and split its output into two halves.

        Args:
            x: Batch of images accepted by ``inference_net``.

        Returns:
            A ``(mean, logvar)`` tuple, each holding half of the encoder
            output along axis 1.
        """
        mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Draw a latent sample as ``mean + eps * exp(logvar / 2)``.

        Args:
            mean: Mean of the latent distribution.
            logvar: Log-variance of the latent distribution.

        Returns:
            A tensor shaped like ``mean`` with ``eps`` drawn from a standard
            normal.
        """
        eps = tf.random_normal(shape=mean.shape)
        return eps * tf.exp(logvar * 0.5) + mean

    def decode(self, z, apply_sigmoid=False):
        """Run the generative net on latent codes.

        Args:
            z: Batch of latent codes.
            apply_sigmoid: When True, return sigmoid probabilities instead of
                raw logits.

        Returns:
            Decoder logits, or their sigmoid when ``apply_sigmoid`` is set.
        """
        logits = self.generative_net(z)
        if apply_sigmoid:
            return tf.sigmoid(logits)
        return logits

# Define loss and optimizer

In [None]:
# NOTE(review): placeholder cell that was never executed. The call passes no
# arguments; tf.losses.sigmoid_cross_entropy presumably requires at least the
# labels and logits, so running it as-is is expected to raise a TypeError.
# TODO: replace with the actual loss / optimizer definition.
tf.losses.sigmoid_cross_entropy()

In [33]:
# Scratch check of tf.split: a random (5, 8) tensor is cut into two equal
# parts along axis 1, the same call CVAE.encode applies to the encoder output.
x = tf.random_normal(shape=(5, 8))
print(x)
tf.split(value=x, num_or_size_splits=2, axis=1)

tf.Tensor(
[[ 0.04974003  1.879085   -1.6373792  -0.8047549  -1.6680679  -1.7090809
   0.1803183   1.4117349 ]
 [ 0.8699547  -0.879452    1.5379285   0.9367092  -1.2001451  -1.6544592
  -1.884405   -0.3977359 ]
 [ 0.041895   -0.0760288  -0.08455824 -1.1129953  -1.2623453   0.7881292
   1.0732952  -0.65369636]
 [-0.66194594 -1.2186763   1.9461879   0.8346287  -1.3878553   0.613098
  -1.9747087  -2.2258372 ]
 [-0.5341543   1.1518692  -1.5466375   0.8781498   0.46751708  0.10178968
   0.8567726  -0.5413133 ]], shape=(5, 8), dtype=float32)


[<tf.Tensor: id=391, shape=(5, 4), dtype=float32, numpy=
 array([[ 0.04974003,  1.879085  , -1.6373792 , -0.8047549 ],
        [ 0.8699547 , -0.879452  ,  1.5379285 ,  0.9367092 ],
        [ 0.041895  , -0.0760288 , -0.08455824, -1.1129953 ],
        [-0.66194594, -1.2186763 ,  1.9461879 ,  0.8346287 ],
        [-0.5341543 ,  1.1518692 , -1.5466375 ,  0.8781498 ]],
       dtype=float32)>, <tf.Tensor: id=392, shape=(5, 4), dtype=float32, numpy=
 array([[-1.6680679 , -1.7090809 ,  0.1803183 ,  1.4117349 ],
        [-1.2001451 , -1.6544592 , -1.884405  , -0.3977359 ],
        [-1.2623453 ,  0.7881292 ,  1.0732952 , -0.65369636],
        [-1.3878553 ,  0.613098  , -1.9747087 , -2.2258372 ],
        [ 0.46751708,  0.10178968,  0.8567726 , -0.5413133 ]],
       dtype=float32)>]