In [None]:
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [None]:
class VAE:
    """Convolutional variational autoencoder for 28x28x1 images (TensorFlow 1.x graph mode).

    The whole graph is built on construction:
    - an initializable ``tf.data`` pipeline fed through ``image_input``;
    - a conv encoder producing ``z_mean`` / ``z_log_var``;
    - a reparameterised sampler;
    - a transposed-conv decoder;
    - an Adam training op on reconstruction + KL loss.

    Args:
        batch_size: default mini-batch size. It backs a ``placeholder_with_default`` and is
            consumed by ``Dataset.batch`` when the iterator is initialised.
        latent_dim: dimensionality of the latent code ``z``.
    """

    def __init__(self, batch_size=32, latent_dim=2):
        self.latent_dim = latent_dim
        # Cast to int64 because Dataset.batch takes an int64 batch size.
        self.batch_size = tf.cast(tf.placeholder_with_default(batch_size, shape=()), dtype=tf.int64)
        # Spatial size after three 3x3 'valid' convolutions: 28 - 3 * 2 = 22.
        self.convd_size = 22
        # sqrt(22 * 22 * 16) = 88, so dense_size ** 2 equals the flattened encoder output
        # and the decoder can reshape it back to (22, 22, 16).
        self.dense_size = int(np.sqrt(self.convd_size * self.convd_size * 16))

        # Selects batch-norm mode: batch statistics (True) vs. moving averages (False).
        self.is_training = tf.placeholder_with_default(True, shape=())
        self.image_input = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1])
        self.image_batch, self.iterator, _ = self._make_dataset_iterator()
        # Actual number of examples in the current batch (the last batch may be short).
        self.current_batch_size = tf.shape(self.image_batch)[0]
        # Remember how many UPDATE_OPS exist already, so the training op only depends on
        # the batch-norm updates created by *this* model (not by another VAE in the graph).
        self._first_update_op = len(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
        self.z_mean, self.z_log_var = self._encoder()
        self.z = self._sampler()
        self.decoded = self._decoder()

        self.loss, self.optimization, self.reconstruction_loss, self.latent_loss = self._make_loss_opt()

    def _make_dataset_iterator(self):
        """Build a shuffled, batched pipeline over ``image_input``.

        Returns:
            (next-batch tensor, initializable iterator, dataset). The iterator must be
            initialised with ``image_input`` fed before any batch can be drawn.
        """
        dataset = tf.data.Dataset.from_tensor_slices(self.image_input)
        dataset = dataset.shuffle(buffer_size=20000)
        dataset = dataset.batch(batch_size=self.batch_size)

        iterator = dataset.make_initializable_iterator()
        image_batch = iterator.get_next()
        return image_batch, iterator, dataset

    def _encoder(self):
        """Encode ``image_batch`` into the mean and log-variance of q(z|x)."""
        conv_kwargs = {'kernel_size': 3, 'filters': 16, 'padding': 'valid', 'strides': 1, 'activation': tf.nn.relu}
        # Three 3x3 'valid' convolutions shrink 28 -> 26 -> 24 -> 22 (= convd_size).
        x = tf.layers.conv2d(self.image_batch, **conv_kwargs)
        x = tf.layers.batch_normalization(x, training=self.is_training)
        x = tf.layers.conv2d(x, **conv_kwargs)
        x = tf.layers.conv2d(x, **conv_kwargs)
        x = tf.layers.flatten(x)
        x = tf.layers.dense(x, units=self.dense_size, activation=tf.nn.relu)
        z_mean = tf.layers.dense(x, units=self.latent_dim)
        z_log_var = tf.layers.dense(x, units=self.latent_dim)
        return z_mean, z_log_var

    def _sampler(self):
        """Draw z ~ N(z_mean, exp(z_log_var)) with the reparameterisation trick."""
        # Size the noise from z_mean, not from self.batch_size. The final batch of an
        # epoch can be smaller than batch_size, and a fixed-size noise tensor would then
        # fail to broadcast against z_mean.
        self.samples = tf.random_normal(shape=tf.shape(self.z_mean),
                                        mean=0.,
                                        stddev=1.,
                                        dtype=tf.float32)
        # exp(0.5 * log_var) == sqrt(exp(log_var)), i.e. the standard deviation.
        z = self.z_mean + tf.exp(0.5 * self.z_log_var) * self.samples
        return z

    def _decoder(self):
        """Map ``z`` back to a 28x28x1 image, mirroring the encoder's layer sizes."""
        conv_kwargs = {'kernel_size': 3, 'strides': 1, 'activation': tf.nn.relu}
        # Debug aid: logs z_mean, z_log_var and the noise samples every time z is evaluated.
        # This rebinds self.z to the tf.Print output, which carries the same values.
        self.z = tf.Print(self.z, [self.z_mean, self.z_log_var, self.samples])
        x = tf.layers.dense(self.z, units=self.dense_size, activation=tf.nn.relu)
        x = tf.layers.dense(x, units=self.dense_size ** 2, activation=tf.nn.relu)
        # dense_size ** 2 == convd_size * convd_size * 16, so this undoes the encoder's flatten.
        x = tf.reshape(x, shape=[-1, self.convd_size, self.convd_size, 16])
        # Each 3x3 'valid' transposed conv grows the spatial size by 2: 22 -> 24 -> 26 -> 28.
        x = tf.layers.conv2d_transpose(x, filters=16, padding='valid', **conv_kwargs)
        x = tf.layers.conv2d_transpose(x, filters=16, padding='valid', **conv_kwargs)
        # NOTE(review): ReLU on the output layer restricts reconstructions to >= 0.
        # Confirm this matches the pixel range of the training images.
        decoded = tf.layers.conv2d_transpose(x, filters=1, padding='valid', **conv_kwargs)
        return decoded

    def _make_loss_opt(self):
        """Build the VAE loss and its Adam training op.

        Returns:
            (total loss, train op, reconstruction loss, latent/KL loss). The losses are
            averaged over the examples in the batch.
        """
        # Per-example summed squared error, averaged over the batch.
        reconstruction_loss = 0.5 * tf.reduce_sum(tf.squared_difference(self.decoded, self.image_batch), axis=[1, 2, 3])
        reconstruction_loss = tf.reduce_mean(reconstruction_loss)
        # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log_var - mean^2 - exp(log_var)).
        # The leading minus is essential: without it, minimising the loss maximises the KL.
        latent_loss = -0.5 * tf.reduce_sum(1 + self.z_log_var - self.z_mean ** 2 - tf.exp(self.z_log_var), axis=1)
        latent_loss = tf.reduce_mean(latent_loss)
        loss = reconstruction_loss + latent_loss

        # batch_normalization registers its moving-average updates in UPDATE_OPS. They must
        # run with every training step, or the inference-mode statistics never move.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)[self._first_update_op:]
        with tf.control_dependencies(update_ops):
            opt = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)
        return loss, opt, reconstruction_loss, latent_loss

    def train(self, session, images):
        """Run one epoch of optimisation over ``images``.

        Args:
            session: a ``tf.Session`` whose variables have been initialised.
            images: array matching ``image_input`` (N x 28 x 28 x 1).

        Returns:
            (loss, reconstruction_loss, latent_loss) from the last mini-batch.

        Raises:
            ValueError: if ``images`` produces no batches.
        """
        session.run(self.iterator.initializer, feed_dict={self.image_input: images})
        fetches = [self.optimization, self.loss, self.reconstruction_loss, self.latent_loss]

        last = None
        while True:
            try:
                last = session.run(fetches, feed_dict={self.is_training: True})
            except tf.errors.OutOfRangeError:
                break

        if last is None:
            raise ValueError('train() received no images')
        _, loss, reconstruction_loss, latent_loss = last
        return loss, reconstruction_loss, latent_loss

    def predict(self, session, images):
        """Evaluate the losses on ``images`` without updating any weights.

        Batch normalisation runs in inference mode (moving averages). ``z`` is still
        sampled stochastically, so repeated calls can give slightly different values.

        Args:
            session: a ``tf.Session`` whose variables have been initialised.
            images: array matching ``image_input`` (N x 28 x 28 x 1).

        Returns:
            Example-weighted mean (loss, reconstruction_loss, latent_loss) over ``images``.

        Raises:
            ValueError: if ``images`` produces no batches.
        """
        session.run(self.iterator.initializer, feed_dict={self.image_input: images})
        fetches = [self.current_batch_size, self.loss, self.reconstruction_loss, self.latent_loss]

        totals = np.zeros(3)
        count = 0
        while True:
            try:
                n, *batch_losses = session.run(fetches, feed_dict={self.is_training: False})
            except tf.errors.OutOfRangeError:
                break
            # Per-batch losses are batch means; weight them by batch size so a short
            # final batch does not skew the overall average.
            totals += n * np.asarray(batch_losses, dtype=np.float64)
            count += n

        if count == 0:
            raise ValueError('predict() received no images')
        loss, reconstruction_loss, latent_loss = totals / count
        return loss, reconstruction_loss, latent_loss

In [None]:
with np.load('vae-cvae-challenge.npz') as fh:
    # Reshape to (N, 28, 28, 1), the layout VAE.image_input expects.
    images = fh['data_x'].reshape(-1, 28, 28, 1)
    labels = fh['data_y']
print(f'image shape: {images.shape}, labels shape: {labels.shape}')

In [None]:
count_epochs = 1
# Start from a fresh graph, so re-running this cell does not stack another model
# (and its variables) on top of the previous one.
tf.reset_default_graph()
vae = VAE()

# Sampling is checked via the tf.Print already inside VAE._decoder. Its output presumably
# goes to the kernel process's stderr (the terminal that launched Jupyter), not to this cell.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(count_epochs):
        loss, reconstruction_loss, latent_loss = vae.train(sess, images)
        print(f'epoch {epoch + 1}/{count_epochs}: loss={loss:.4f}, '
              f'reconstruction={reconstruction_loss:.4f}, latent={latent_loss:.4f}')
        

In [None]:
# Scratch check: reshape a (10, 20, 5) tensor (1000 elements) into rows of 2.
x = tf.reshape(tf.constant(np.random.rand(10, 20, 5)), shape=tf.constant((-1, 2)))

In [None]:
item = tf.constant(np.random.rand(10, 5, 3, 2))
# Valid axes for a rank-4 tensor are 0..3. The rank itself is one past the last axis,
# so reduce over rank - 1 (equivalently axis=-1).
last_axis = item.shape.ndims - 1
ok = tf.reduce_sum(item, axis=last_axis)
with tf.Session() as sess:
    # sess.run fetches tensors, not TensorShapes: evaluate `ok`, then inspect the array's shape.
    print(sess.run(ok).shape)  # expected: (10, 5, 3)

In [None]:
# Rank (number of dimensions) of `item`.
len(item.get_shape())

In [None]:
def iterate(predicate, images, iterator, session, image_input=None):
    """Run ``predicate`` once per batch until ``iterator`` is exhausted.

    Args:
        predicate: fetch(es) passed to ``session.run`` on every step.
        images: array fed into ``image_input`` to initialise ``iterator``.
        iterator: an initializable ``tf.data`` iterator whose dataset reads ``image_input``.
        session: the ``tf.Session`` to run in.
        image_input: the placeholder ``iterator``'s dataset is built from (for example a
            VAE instance's ``image_input``). When omitted, falls back to a global named
            ``real_images`` (legacy behaviour).

    Returns:
        The result of the last ``session.run(predicate)``, or ``None`` if the iterator
        produced no batches.

    Raises:
        ValueError: if ``image_input`` is omitted and no global ``real_images`` exists.
    """
    if image_input is None:
        image_input = globals().get('real_images')
        if image_input is None:
            raise ValueError('iterate() needs image_input: no global `real_images` is defined')

    session.run(iterator.initializer, feed_dict={image_input: images})
    result = None  # stays None if the iterator is empty
    while True:
        try:
            result = session.run(predicate)
        except tf.errors.OutOfRangeError:
            break
    return result