In [1]:
import os 
import numpy as np
import tensorflow as tf
from tensorflow.keras import datasets
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense
from tensorflow.keras import optimizers
from tensorflow.keras import metrics
from DataLoader import DataLoader
import matplotlib.pyplot as plt
#matplotlib inline

In [3]:
class VAE(Model):
    """Variational autoencoder composed of an ``Encoder`` and a ``Decoder``.

    The encoder returns ``(mean, var)`` of a diagonal Gaussian q(z|x); ``var``
    must be positive because it is passed through ``sqrt`` and ``log`` here.
    The decoder maps a latent sample ``z`` to per-feature probabilities that
    are scored with a Bernoulli log-likelihood in ``lower_bound``.
    """

    def __init__(self):
        # Keras requires the base initializer to run before any sub-layer is
        # assigned as an attribute; without it the assignments below raise.
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def call(self, x):
        """Encode ``x``, draw one latent sample, and return the decoder output."""
        mean, var = self.encoder(x)
        z = self.reparameterize(mean, var)
        return self.decoder(z)

    def reparameterize(self, mean, var):
        """Sample z ~ N(mean, var) via the reparameterization trick.

        z = mean + sqrt(var) * eps with eps ~ N(0, I), so gradients flow back
        into ``mean`` and ``var``.
        """
        # tf.shape is the runtime shape, so this also works when the batch
        # dimension is unknown (None) at trace time, e.g. inside tf.function.
        eps = tf.random.normal(tf.shape(mean))
        return mean + tf.math.sqrt(var) * eps

    def lower_bound(self, x):
        """Return the batch-averaged evidence lower bound (ELBO) for ``x``.

        ELBO = E_q[log p(x|z)] - KL(q(z|x) || N(0, I)), estimated with a single
        Monte Carlo sample of z. Training should maximize this value (or
        minimize its negative). The reconstruction term is a Bernoulli
        log-likelihood, so ``x`` is presumably scaled to [0, 1] and shaped
        (batch, features) — TODO confirm against the DataLoader.
        """
        mean, var = self.encoder(x)
        # Closed-form KL between N(mean, var) and N(0, I): summed over latent
        # dims, averaged over the batch. var is clipped from below only, so
        # large variances are not truncated inside the log.
        kl = -1/2 * tf.reduce_mean(
            tf.reduce_sum(1 + self._log(var, max=None) - mean**2 - var, axis=1))
        z = self.reparameterize(mean, var)
        y = self.decoder(z)

        # Bernoulli log-likelihood; _log keeps y and 1-y away from log(0).
        reconst = tf.reduce_mean(
            tf.reduce_sum(x * self._log(y) + (1 - x) * self._log(1 - y), axis=1))
        return reconst - kl

    def _log(self, value, min=1.e-10, max=1.0):
        """Numerically safe log: clip ``value`` into [min, max] before ``log``.

        Pass ``max=None`` to clip from below only (no upper bound).
        """
        if max is None:
            return tf.math.log(tf.maximum(value, min))
        return tf.math.log(tf.clip_by_value(value, min, max))

In [5]:
class Encoder(Model):
    """Inference network mapping an input batch to a diagonal Gaussian.

    Two ReLU hidden layers feed two heads:
    - a linear head for the latent mean;
    - a softplus head for the latent variance.

    Softplus keeps the variance positive, which matters because the VAE takes
    its sqrt and log.

    Args:
        latent_dim: Number of units in each head (size of the latent code).
        hidden_units: Width of each of the two hidden layers.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, latent_dim=10, hidden_units=200, **kwargs):
        super().__init__(**kwargs)
        self.l1 = Dense(hidden_units, activation="relu")
        self.l2 = Dense(hidden_units, activation="relu")
        self.l_mean = Dense(latent_dim, activation="linear")
        self.l_var = Dense(latent_dim, activation=tf.nn.softplus)

    def call(self, x):
        """Return ``(mean, var)`` of q(z|x).

        For a 2-D input of shape (batch, features) both outputs have shape
        (batch, latent_dim); the 2-D input layout is an assumption — TODO
        confirm against the DataLoader.
        """
        h = self.l1(x)
        h = self.l2(h)

        mean = self.l_mean(h)
        var = self.l_var(h)
        return mean, var

In [None]:
class Decoder(Model):
    """Generative network mapping latent codes to per-feature probabilities.

    Two ReLU hidden layers are followed by a sigmoid output layer, so every
    output lies in (0, 1) and can be scored as a Bernoulli probability.

    Args:
        output_dim: Number of output features. It must equal the feature
            dimension of the data being reconstructed, because the VAE
            multiplies inputs and outputs elementwise. Defaults to
            784 = 28 * 28 (a flattened MNIST-sized image); presumably the data
            is MNIST — TODO confirm against the DataLoader.
        hidden_units: Width of each of the two hidden layers.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, output_dim=784, hidden_units=200, **kwargs):
        super().__init__(**kwargs)
        self.l1 = Dense(hidden_units, activation="relu")
        self.l2 = Dense(hidden_units, activation="relu")
        # Previously hard-coded to 783, which cannot match a 28x28 image.
        self.out = Dense(output_dim, activation="sigmoid")

    def call(self, x):
        """Decode latent codes ``x`` into probabilities of shape (batch, output_dim)."""
        h = self.l1(x)
        h = self.l2(h)
        return self.out(h)
        