In [1]:
%matplotlib inline
from __future__ import print_function
from PIL import Image
import numpy as np
import tensorflow as tf
import os
import glob
import matplotlib.pyplot as plt

In [2]:
def conv_block(inputs, out_channels, kernel_size=3, strides=2, padding='SAME', name='conv',
               is_training=True):
    """Conv2D -> batch norm -> ReLU, with all variables under their own scope.

    Args:
        inputs: 4-D input tensor; in this notebook it is NHWC (e.g. the
            (1, 84, 84, 1) placeholder).
        out_channels: number of convolution filters.
        kernel_size: convolution kernel size.
        strides: convolution stride.
        padding: 'SAME' or 'VALID'.
        name: variable scope name for this block.
        is_training: forwarded to ``batch_norm``. True (the default, and what the
            block always did before this parameter existed) normalizes with the
            current batch's statistics and updates the moving averages. Pass False
            (or a boolean tensor) at inference time to normalize with the stored
            moving mean/variance instead.

    Returns:
        The ReLU-activated output tensor.
    """
    with tf.variable_scope(name):
        conv = tf.layers.conv2d(inputs, out_channels, kernel_size=kernel_size,
                                strides=strides, padding=padding)
        # updates_collections=None: the moving-average updates run in place with this
        # op, so no explicit UPDATE_OPS control dependency is needed in the train step.
        conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99,
                                            scale=True, center=True,
                                            is_training=is_training)
        return tf.nn.relu(conv)

In [3]:
def encoder(x, h_dim, z_dim, reuse=False, verbose=True):
    """Four stride-2 conv blocks followed by a flatten.

    For an 84x84 input the spatial size goes 84 -> 42 -> 21 -> 10 -> 5, so each
    example is flattened to 5 * 5 * z_dim features.

    Args:
        x: input image tensor (NHWC), e.g. the (1, 84, 84, 1) placeholder below.
        h_dim: channel count of the first three conv blocks.
        z_dim: channel count of the last conv block.
        reuse: passed to ``tf.variable_scope`` so a second call can share weights.
        verbose: if True (default), print the static shape after every block.

    Returns:
        2-D tensor of shape (batch, 5 * 5 * z_dim) for 84x84 inputs.
    """
    with tf.variable_scope('encoder', reuse=reuse):
        net = conv_block(x, h_dim, kernel_size=3, strides=2, padding='SAME', name='conv_1')  # 42x42
        if verbose:
            print(net.shape)
        net = conv_block(net, h_dim, kernel_size=3, strides=2, padding='SAME', name='conv_2')  # 21x21
        if verbose:
            print(net.shape)
        # VALID padding: floor((21 - 3) / 2) + 1 = 10
        net = conv_block(net, h_dim, kernel_size=3, strides=2, padding='VALID', name='conv_3')  # 10x10
        if verbose:
            print(net.shape)
        # SAME padding with stride 2: ceil(10 / 2) = 5
        net = conv_block(net, z_dim, kernel_size=3, strides=2, padding='SAME', name='conv_4')  # 5x5
        if verbose:
            print(net.shape)
        return tf.contrib.layers.flatten(net)

In [4]:
input = tf.placeholder(dtype=tf.float32, shape=[1, 84, 84, 1])

In [5]:
output = encoder(input, 64, 32)

(1, 42, 42, 64)
(1, 21, 21, 64)
(1, 10, 10, 64)
(1, 5, 5, 32)


In [6]:
def decoder(x, h_dim, z_dim, reuse=False, out_channels=3, verbose=True):
    """Project a flat code to a 5x5x64 map and upsample it to 84x84.

    Spatial size: 5 -> 10 -> 21 -> 42 -> 84. Uses ``deconv_block``, which is
    defined in a later cell; that cell must run before this function is called.

    Args:
        x: 2-D code tensor, e.g. the flattened encoder output.
        h_dim: channel count of the four transposed-conv blocks.
        z_dim: not used by the body; the dense projection is hard-coded to
            5 * 5 * 64. NOTE(review): presumably kept for signature symmetry with
            ``encoder`` — confirm whether it should size the projection.
        reuse: passed to ``tf.variable_scope`` so a second call can share weights.
        out_channels: channels of the reconstructed image. Defaults to 3, the
            previously hard-coded value. The encoder input in this notebook has
            1 channel, so a reconstruction loss against it would need
            out_channels=1.
        verbose: if True (default), print the final static shape.

    Returns:
        Tensor of shape (batch, 84, 84, out_channels), squashed to [-1, 1] by tanh.
    """
    with tf.variable_scope('decoder', reuse=reuse):
        net = tf.layers.dense(x, 5 * 5 * 64)
        net = tf.reshape(net, [-1, 5, 5, 64])
        net = deconv_block(net, h_dim, size=4, stride=2, padding='SAME', name='deconv_1')  # 10x10
        # VALID transposed conv: (10 - 1) * 2 + 3 = 21, undoing the encoder's VALID step
        net = deconv_block(net, h_dim, size=3, stride=2, padding='VALID', name='deconv_2')  # 21x21
        net = deconv_block(net, h_dim, size=4, stride=2, padding='SAME', name='deconv_3')  # 42x42
        net = deconv_block(net, h_dim, size=4, stride=2, padding='SAME', name='deconv_4')  # 84x84
        net = tf.layers.conv2d(net, out_channels, 3, padding='SAME')
        if verbose:
            print(net.shape)
        return tf.nn.tanh(net)

In [7]:
def deconv_block(inputs, out_channels, size=3, stride=2, padding='SAME', name='deconv',
                 is_training=True):
    """Transposed Conv2D -> batch norm -> ReLU, with variables under their own scope.

    Args:
        inputs: 4-D input tensor (NHWC in this notebook).
        out_channels: number of transposed-convolution filters.
        size: kernel size.
        stride: upsampling stride.
        padding: 'SAME' (output = input * stride) or 'VALID'
            (output = (input - 1) * stride + size).
        name: variable scope name for this block.
        is_training: forwarded to ``batch_norm``. True (the default, and what the
            block always did before this parameter existed) normalizes with the
            current batch's statistics and updates the moving averages. Pass False
            (or a boolean tensor) at inference time to normalize with the stored
            moving mean/variance instead.

    Returns:
        The ReLU-activated output tensor.
    """
    with tf.variable_scope(name):
        conv = tf.layers.conv2d_transpose(inputs, out_channels, kernel_size=size,
                                          strides=stride, padding=padding)
        # updates_collections=None: the moving-average updates run in place with this
        # op, so no explicit UPDATE_OPS control dependency is needed in the train step.
        conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99,
                                            scale=True, center=True,
                                            is_training=is_training)
        return tf.nn.relu(conv)

In [8]:
# Decode the flattened encoder output back into an image-shaped tensor.
result = decoder(output, h_dim=64, z_dim=64)

(1, 84, 84, 3)


In [9]:
# Static shape of the reconstruction; displayed as the cell's output.
result.get_shape()

TensorShape([Dimension(1), Dimension(84), Dimension(84), Dimension(3)])