In [1]:
import tensorflow as tf
from mnist import load_mnist, evaluate

tf.random.set_seed(42)

tf.keras.backend.clear_session()

print(tf.__version__)

2.3.0


In [2]:
# configurations

IMAGE_SIZE = (16, 16)
BINARIZE = True
AUTOENCODER_TYPE = ('vanilla', 'variational')[1]
NUM_DATA = 100000

In [3]:
(x_train, _), _ = load_mnist(image_size=IMAGE_SIZE, binarize=BINARIZE)
x_train = (x_train + 1) / 2

In [4]:
def binarize(x, threshold):
    """Returns 1 if x > threshold else 0, element-wisely.

    Parameters
    ----------
    x : tensor
    threshold : float

    Returns
    -------
    tensor
        The same shape and dtype as x.
    """
    y = tf.where(x > threshold, 1, 0)
    y = tf.cast(y, x.dtype)
    return y


def softly_binarize(x, threshold):
    """Returns 1 if x > threshold else 0, element-wisely, with the gradients
    :math:`\partial f_i / \partial x_j = \delta_{i j}`, i.e. an unit Jacobian.

    Parameters
    ----------
    x : tensor
    threshold : float

    Returns
    -------
    tensor
        The same shape and dtype as x.
    """

    def identity(dy):
        return dy

    @tf.custom_gradient
    def fn(x):
        y = binarize(x, threshold)
        return y, identity

    return fn(x)


class SoftBinarization(tf.keras.layers.Layer):
    """Do nothing when training."""

    def __init__(self, threshold, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold

    def call(self, x, training=None):
        if training:
            return x
        return softly_binarize(x, self.threshold)

In [5]:
class FeedForwardNetwork(tf.keras.layers.Layer):

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation

    def build(self, batch_input_shape):
        layers = [tf.keras.layers.Dense(n, 'relu') for n in self.units[:-1]]
        layers.append(tf.keras.layers.Dense(self.units[-1], self.activation))
        self._ffn = tf.keras.Sequential(layers)
        self._ffn.build(batch_input_shape)
        super().build(batch_input_shape)

    def call(self, x):
        y = self._ffn(x)
        return y

In [6]:
class LatentBernoulliVanillaAutoencoder(tf.keras.layers.Layer):

    def __init__(self, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation

        self._encoder = FeedForwardNetwork(units, 'sigmoid')

    def build(self, batch_input_shape):
        ambient_dim = batch_input_shape[-1]
        units = self.units[::-1][1:] + [ambient_dim]  # symmetric structure
        self._decoder = FeedForwardNetwork(units, self.activation)
        super().build(batch_input_shape)

    def encode(self, x):
        z = self._encoder(x)
        z = softly_binarize(z, 0.5)
        return z

    def decode(self, z):
        x = self._decoder(z)
        return x

    def call(self, x):
        z = self.encode(x)
        x_recon = self.decode(z)
        return x_recon

In [7]:
class LatentBernoulliVariationalAutoencoder(tf.keras.layers.Layer):
    """
    Notes
    -----
    Useful relations:
        log(sigmoid(x)) = x - softplus(x)
        log(1 - sigmoid(x)) = - softplus(x)
        log(sigmoid(x)) - log(1 - sigmoid(x)) = x

    References
    ----------
    1. https://davidstutz.de/bernoulli-variational-auto-encoder-in-torch/
    """

    def __init__(self, units, activation=None, use_prior=False, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.activation = activation
        self.use_prior = use_prior

        self._encoder = FeedForwardNetwork(units, name='encoder')

    def build(self, batch_input_shape):
        ambient_dim = batch_input_shape[-1]
        units = self.units[::-1][1:] + [ambient_dim]  # symmetric structure
        self._decoder = FeedForwardNetwork(units, self.activation, name='decoder')
        super().build(batch_input_shape)

    def call(self, x, training=None):
        logits = self._encoder(x)
        z = self._reparam_trick(logits)
        x_recon = self._decoder(z)
        if training and self.use_prior:
            self.add_loss(self._kl_div(z, logits))
        return x_recon

    @staticmethod
    def _reparam_trick(logits):
        """
        s ~ uniform(0, 1)
        a = s / (1 - s) * p / (1 - p)
        z = 1 if log(a) > 0 else 0
        => z ~ bernoulli(p)
        """
        eps = 1e-8
        s = tf.random.uniform(logits.shape[1:], minval=eps, maxval=1-eps)
        s = s[tf.newaxis, ...]
        a = tf.math.log(s) - tf.math.log(1 - s) + logits
        z = softly_binarize(tf.nn.sigmoid(a), 0.5)
        return z

    @staticmethod
    def _kl_div(z, logits):
        """KL-divergence between Q-distribution and latent prior."""
        log_z = log_sigmoid(logits)
        log_1mz = log_1m_sigmoid(logits)
        return tf.reduce_mean(
            tf.where(z > 0.5, log_z, log_1mz)
        )


def log_sigmoid(x):
    return x - tf.nn.softplus(x)


def log_1m_sigmoid(x):
    return - tf.nn.softplus(x)

In [8]:
if AUTOENCODER_TYPE == 'vanilla':
    layers = [
        LatentBernoulliVanillaAutoencoder([128], 'sigmoid'),
    ]
elif AUTOENCODER_TYPE == 'variational':
    layers = [
        LatentBernoulliVariationalAutoencoder([128], 'sigmoid'),
    ]
else:
    raise ValueError()
if BINARIZE:
    layers.append(SoftBinarization(0.5))
ae = tf.keras.Sequential(layers)
ae.compile(loss='binary_crossentropy', optimizer='adam')

In [9]:
ds = tf.data.Dataset.from_tensor_slices((x_train[:NUM_DATA], x_train[:NUM_DATA]))
ds = ds.shuffle(10000).repeat(20).batch(128)
ae.fit(ds)



<tensorflow.python.keras.callbacks.History at 0x7fafd25136d0>

In [10]:
evaluate(ae, x_train[:1000])

0.9731875