In [0]:
%matplotlib inline
%tensorflow_version 2.x

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import typing
import math
import tqdm
import os
import itertools
import uuid
import copy
import subprocess
import imageio
import IPython


from IPython import display
from IPython.display import Image
from pathlib import Path
from enum import IntEnum

In [0]:
class DSprites:
    """
    Loader and latent-space sampler for the dSprites dataset.

    A significant portion of this class is taken as-is from the manual
    https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_reloading_example.ipynb
    """

    class Latents(IntEnum):
        """Column index of each generative factor inside a latent vector."""
        COLOUR, SHAPE, SCALE, ORIENTATION, XPOS, YPOS = range(6)

    _URL = "https://github.com/deepmind/dsprites-dataset/raw/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"

    def __init__(self, path=".", download=False):
        """
        :param path: existing directory that holds (or will receive) the archive
        :param download: if True and the archive is missing, fetch it with `wget`
        :raises ValueError: the archive is missing and `download` is False
        :raises subprocess.CalledProcessError: `wget` exited with an error
        """
        self._filename = 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'
        path = Path(path).absolute()
        assert path.exists(), f"Directory {path} does not exist."
        data = path / self._filename
        if not data.exists():
            if not download:
                raise ValueError("Can't find dataset, use download=True to download.")
            try:
                # check=True so that a failed download is not silently ignored
                subprocess.run(["wget", "-O", str(data), self._URL], check=True)
            except (subprocess.CalledProcessError, OSError):
                # `wget -O` creates the target file before transferring; remove
                # a partial file so that the next run does not mistake it for
                # the dataset
                if data.exists():
                    data.unlink()
                raise

        # NOTE: allow_pickle=True is needed because 'metadata' is stored as a
        # pickled object; only load archives from a trusted source.
        # The context manager closes the underlying file once the arrays we
        # need have been read.
        with np.load(str(data), encoding='bytes', allow_pickle=True) as archive:
            imgs = archive['imgs']
            metadata_raw = archive['metadata'][()]
            # for example: array([ 0,  0,  2, 37, 15, 22])
            # i.e. the relative (normalised) change in latent factors
            self._latents_classes = archive['latents_classes']
            # for example: array([1., 1., 0.7 , 5.96097068, 0.48387097, 0.70967742])
            # i.e. the actual latent values used to generate the image
            self._latents_values = archive['latents_values']

        # keys arrive as bytes because of encoding='bytes'; normalise to str
        self._metadata = {
            (k.decode() if isinstance(k, bytes) else k): v
            for k, v in metadata_raw.items()
        }

        # NOTE: can't cast now because our notebook runs out of RAM. cast in map instead
        # imgs = imgs.reshape(-1, 64, 64, 1).astype(np.float32)
        self._imgs = imgs.reshape(-1, 64, 64, 1)

        # specification: the number of varying "degrees" of change in each
        # dimension corresponding to an independent generative factor
        # array([ 1,        3,      6,           40,  32,  32])
        #       colour, shape,  scale,  orientation,   X,   Y
        self._latents_sizes = self._metadata['latents_sizes']

        # for easy conversion from latent vector to indices later (see to_idx)
        # essentially: array([737280, 245760,  40960,   1024,     32,      1])
        self._latents_bases = np.r_[self._latents_sizes[::-1].cumprod()[::-1][1:], 1]

    def latent_size(self, latent: 'DSprites.Latents') -> int:
        """
        :param latent: of type DSprites.Latents (an enum class)
        :return: the number of distinct values of `latent`, i.e. valid class
                 indices are 0 .. latent_size - 1
        """
        return self._latents_sizes[latent.value]

    def to_idx(self, latents: np.ndarray) -> np.ndarray:
        """
        convert latent vector(s) into index/indices that can then be used to
        index the actual image(s) in self._imgs

        :param latents: latent class indices, one vector per row
        :return: integer index (or array of indices, one per row)
        """
        return np.dot(latents, self._latents_bases).astype(int)

    def sample_latent(self, n: int=1, fixed: 'DSprites.Latents'=None) -> np.ndarray:
        """
        randomly samples `n` latent vectors

        :param n: number samples
        :param fixed: if not `None`, then in all samples, this latent is kept
                     fixed based on a random draw. The rest of the latents are
                     random.
        :return: an `np.ndarray` of shape nx6 (float dtype, integer-valued)
        """
        samples = np.zeros((n, self._latents_sizes.shape[0]))
        for i, lat_size in enumerate(self._latents_sizes):
            samples[:, i] = np.random.randint(lat_size, size=n)
        # `is not None` rather than truthiness: Latents.COLOUR has value 0
        if fixed is not None:
            # one scalar draw, broadcast to every sample
            samples[:, fixed] = np.random.randint(0, self.latent_size(fixed))
        return samples

    @property
    def imgs(self) -> np.ndarray:
        """All images, shaped (N, 64, 64, 1), in the dtype stored in the archive."""
        return self._imgs

    def subset(self, size=50_000) -> np.ndarray:
        """
        returns a subset of the images. (Workaround for memory constraints)
        :param size: number of samples to return (drawn without replacement)
        """
        return self._imgs[np.random.choice(self._imgs.shape[0], size=size, replace=False)]

def make_grid(tensor: np.array, nrow: int=8, padding: int=2, pad_value: int=0) -> np.array:
    """
    Tile a batch of images onto a single padded canvas.

    adapted from: https://pytorch.org/docs/stable/_modules/torchvision/utils.html#make_grid
    :param tensor: batch of images as a 4-d np.array (n, height, width, channels)
    :param nrow: number of images placed side by side in each grid row
    :param padding: gap (in pixels) between images and around the border
    :param pad_value: value the gaps are filled with
    :return: the grid with singleton axes squeezed away; a batch of one image
             is returned as-is (squeezed), without padding.
    """
    count = tensor.shape[0]
    if count == 1:
        return tensor.squeeze()

    # grid layout: `cols` images per row, as many rows as needed
    cols = min(nrow, count)
    rows = int(math.ceil(float(count) / cols))
    # each cell is one image plus the padding that precedes it
    cell_h = int(tensor.shape[1] + padding)
    cell_w = int(tensor.shape[2] + padding)
    canvas = np.full(
        (cell_h * rows + padding, cell_w * cols + padding, tensor.shape[3]),
        pad_value,
        dtype=tensor.dtype,
    )
    for position in range(count):
        r, c = divmod(position, cols)
        top = r * cell_h + padding
        left = c * cell_w + padding
        canvas[top:top + cell_h - padding, left:left + cell_w - padding, :] = tensor[position]
    return canvas.squeeze()

def imshow(img: np.array, title: str='', ax: plt.Axes=None):
    """
    Draw `img` in greyscale without axis ticks.

    :param img: image array handed to `Axes.imshow`
    :param title: optional title; nothing is set when empty
    :param ax: axes to draw on; when omitted a new 15x15 figure with a single
               subplot is created
    """
    if not ax:
        ax = plt.figure(figsize=(15, 15)).add_subplot(111)
    ax.imshow(img, cmap='gray', interpolation='nearest')
    # hide tick marks on both axes
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks(())
    if title:
        ax.set_title(title)

In [0]:
# Load dSprites from the current directory, downloading the archive if it is missing.
ds = DSprites(download=True)

In [0]:
# Show a grid of 32 images picked via randomly sampled latent vectors.
imshow(make_grid(ds.imgs[ds.to_idx(ds.sample_latent(32))]))

In [0]:
# Same as above, but every sample shares one randomly drawn SHAPE value.
random_latents = ds.sample_latent(32, fixed=DSprites.Latents.SHAPE)
imshow(make_grid(ds.imgs[ds.to_idx(random_latents)]))

In [0]:
def criterion(x, x_recon, mean, logvar, beta=1.0, lhd='bernoulli'):
    """
    Per-sample training loss: reconstruction term plus `beta` times the KL term.

    x - original image
    x_recon - LOGITS! depending on `lhd`, it'll either be activated by sigmoid
              or a normal distribution
    mean, logvar - parameters of the approximate posterior
    beta - weight of the KL term
    lhd - 'bernoulli' or 'normal' (case-insensitive); 'normal' is not
          implemented yet and raises NotImplementedError

    Both terms are summed over the whole batch and the total is divided by
    the batch size `x.shape[0]`.
    """
    # KL divergence between N(mean, exp(logvar)) and the standard normal
    kl = -0.5 * tf.reduce_sum(1 + logvar - mean**2 - tf.exp(logvar))

    likelihood = lhd.lower()
    if likelihood == 'normal':
        # TODO!
        raise NotImplementedError
    if likelihood != 'bernoulli':
        raise ValueError(f"Expected lhd to be one of bernoulli or normal, got {lhd}.")

    # Bernoulli negative log-likelihood computed directly from the logits
    lh = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=x_recon, labels=x))
    return (lh + beta * kl) / x.shape[0]

def get_step_function():
    """
    Build a fresh `tf.function`-compiled training step.

    The returned callable takes (model, x, optimiser, beta, lhd), runs one
    encode/decode pass under a gradient tape, applies the gradients to the
    model's trainable variables and returns the loss.
    """
    def step(model, x, optimiser, beta, lhd):
        with tf.GradientTape() as tape:
            z, mean, logvar = model.encode(x)
            loss = criterion(x, model.decode(z), mean, logvar, beta, lhd)
        variables = model.trainable_variables
        grads = tape.gradient(loss, variables)
        optimiser.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    return tf.function(step)
    
def train_model(model, epochs, train, optimiser, beta, lhd='bernoulli'):
    """
    Train `model` for `epochs` passes over `train`.

    :param model: object exposing `encode`, `decode` and `trainable_variables`
    :param epochs: number of passes over `train`
    :param train: iterable of input batches
    :param optimiser: optimiser whose `apply_gradients` is called every step
    :param beta: KL weight forwarded to `criterion`
    :param lhd: likelihood name forwarded to `criterion`
    :return: list with the mean batch loss of every epoch
    """
    # redefine step function because it's a compiled graph (bug)
    # see: https://github.com/tensorflow/tensorflow/issues/27120
    step = get_step_function()
    progress = tqdm.trange(epochs)
    history = []
    for _ in progress:
        batch_losses = [step(model, batch, optimiser, beta, lhd) for batch in train]
        history.append(np.mean(batch_losses))
        # show the latest epoch loss next to the progress bar
        progress.set_postfix(loss=history[-1])
    return history

# Build a model according to the specification (Table 1, Higgins et al.)
* **Input**  4096 (flattened 64x64x1).
* **Encoder**  FC 1200, 1200. ReLU activation.
* **Latents**  10
* **Decoder**  FC 1200, 1200, 1200, 4096. Tanh activation. Bernoulli.
* **Optimiser** Adagrad 1e-2

In [0]:
class VAE(tf.keras.Model):
    """
    Fully connected variational auto-encoder for 64x64x1 images.

    The encoder outputs `2 * latent_dim` values per image, which `encode`
    splits into the mean and log-variance of the approximate posterior.
    The decoder maps a latent vector back to 64x64x1 LOGITS (no sigmoid).
    """

    def __init__(self, latent_dim=10):
        """
        :param latent_dim: size of the latent vector
        """
        super(VAE, self).__init__()
        self._latent_dim = latent_dim
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer((64, 64, 1)),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(1200, activation='relu'),
            tf.keras.layers.Dense(1200, activation='relu'),
            # first half: mean, second half: log-variance (see `encode`)
            tf.keras.layers.Dense(latent_dim * 2)
        ])
        self.decoder = tf.keras.Sequential([
            # the input shape must be a tuple; a bare int is not accepted by
            # every Keras release
            tf.keras.layers.InputLayer((latent_dim,)),
            tf.keras.layers.Dense(1200, activation='tanh'),
            tf.keras.layers.Dense(1200, activation='tanh'),
            tf.keras.layers.Dense(1200, activation='tanh'),
            tf.keras.layers.Dense(4096),
            tf.keras.layers.Reshape((64, 64, 1))
        ])

    def call(self, x):
        """Not supported: use `encode` and `decode` explicitly."""
        raise NotImplementedError

    def encode(self, x):
        """
        :param x: batch of images
        :return: tuple (z, mean, logvar) where z is a reparameterised sample
                 drawn from N(mean, exp(logvar))
        """
        h = self.encoder(x)
        mean, logvar = tf.split(h, num_or_size_splits=2, axis=1)
        return self.reparameterise(mean, logvar), mean, logvar

    def decode(self, z):
        """
        :param z: batch of latent vectors
        :return: reconstruction LOGITS shaped (batch, 64, 64, 1)
        """
        return self.decoder(z)

    @staticmethod
    def reparameterise(mean, logvar):
        """
        Draw `mean + sigma * eps` with `eps ~ N(0, 1)` (reparameterisation trick).
        """
        # tf.shape (dynamic) instead of mean.shape (static) so that this also
        # works inside a traced graph whose batch dimension is unknown
        eps = tf.random.normal(tf.shape(mean), mean=0.0, stddev=1.0)
        # log sig^2 = 2 log sig => exp(1/2 log sig^2) = exp(log sig) = sig
        return mean + tf.exp(logvar * 0.5) * eps

# Hyper-parameters shared by the cells below.
BETA = 4            # KL weight used for the `vae4` model
EPOCHS = 25         # number of passes over the training subset
LATENT_DIM = 10     # size of the latent vector
BATCH_SIZE = 32     # images per training batch
TR_SIZE = 400_000   # number of images drawn for the training subset

In [0]:
def cast_dtype(x):
    """Return `x` converted to float32 (applied per element in the input pipeline)."""
    as_float = tf.cast(x, dtype=tf.float32)
    return as_float

# Train the beta=BETA model on a random subset of TR_SIZE images.
vae4 = VAE(latent_dim=LATENT_DIM)
train = tf.data.Dataset.from_tensor_slices(ds.subset(size=TR_SIZE))
# Cast lazily inside the pipeline (the full array is not cast up front to
# save RAM), shuffle with a 1024-element buffer, then batch.
train = (train
         .map(cast_dtype)
         .shuffle(2**10)
         .batch(BATCH_SIZE))
losses4 = train_model(vae4, EPOCHS,
                      train, tf.keras.optimizers.Adagrad(learning_rate=1e-2),
                      beta=BETA, lhd='bernoulli')
# Plot the mean loss of every epoch.
plt.plot(losses4)
plt.title("Average epoch loss")
plt.xlabel("Epoch")
plt.ylabel("ELBO loss")
plt.show()

In [0]:
# Compare 32 random images with their vae4 reconstructions, side by side.
# `x` is reused by later cells.
# NOTE(review): unlike the training data, `x` is not passed through
# cast_dtype before being encoded — confirm the encoder accepts the raw
# image dtype.
x = ds.subset(size=32)
z, *_= vae4.encode(x)
# the decoder returns logits; sigmoid maps them to pixel intensities
x_recon = tf.nn.sigmoid(vae4.decode(z)).numpy()
fig = plt.figure(figsize=(20, 20))
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)
imshow(make_grid(x), title='Original', ax=ax1)
imshow(make_grid(x_recon), title='Reconstructed', ax=ax2)
fig.subplots_adjust(wspace=0.01)
plt.show()

In [0]:
# Persist the trained vae4 weights; the file name records EPOCHS and TR_SIZE.
weight_file = f'./vae4/vae4_e{EPOCHS}_{TR_SIZE}'
vae4.save_weights(weight_file)
# to load:
# vae4 = VAE(LATENT_DIM)
# vae4.load_weights(weight_file)

In [0]:
# Train a second model with beta=1 on the same `train` pipeline and plot its
# loss curve together with the beta=4 one.
vae1 = VAE(latent_dim=LATENT_DIM)
losses1 = train_model(vae1, EPOCHS,
                      train, tf.keras.optimizers.Adagrad(learning_rate=1e-2),
                      beta=1, lhd='bernoulli')
plt.plot(losses1, label='beta=1')
plt.plot(losses4, label='beta=4')
plt.title("Average epoch loss")
plt.xlabel("Epoch")
plt.ylabel("ELBO loss")
plt.legend()
plt.show()

In [0]:
# Persist the trained vae1 weights (note: this rebinds `weight_file`).
weight_file = f'./vae1/vae1_e{EPOCHS}_{TR_SIZE}'
vae1.save_weights(weight_file)

In [0]:
# Reconstruct the same 32 images `x` (defined in an earlier cell) with vae1.
z, *_= vae1.encode(x)
# the decoder returns logits; sigmoid maps them to pixel intensities
x_recon = tf.nn.sigmoid(vae1.decode(z)).numpy()
fig = plt.figure(figsize=(20, 20))
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)
imshow(make_grid(x), title='Original', ax=ax1)
imshow(make_grid(x_recon), title='Reconstructed', ax=ax2)
fig.subplots_adjust(wspace=0.01)
plt.show()

# Visualising latents

In [0]:
class LatentVisualiser:
    """
    Renders a latent traversal as an animated GIF.

    Starting from a batch of latent vectors, one latent dimension is shifted
    by a constant step per frame; each frame shows the decoded batch as an
    image grid.
    """

    def __init__(self, model: tf.keras.Model, dim: int=0,
                 init_latent: tf.Tensor=None, width: int=8,
                 height: int=8, n_samples: int=32):
        """
        :param model: model exposing `decode(z)` that returns logits
        :param dim: index of the latent dimension to traverse
        :param init_latent: starting latents (batch x latent size); when
                            `None`, `n_samples` vectors are drawn from N(0, 1)
        :param width: figure width passed to matplotlib
        :param height: figure height passed to matplotlib
        :param n_samples: batch size of the random starting latents; only
                          used when `init_latent` is `None`
        """
        if init_latent is not None:
            self._rnd = tf.identity(init_latent)
        else:
            self._rnd = tf.convert_to_tensor(
                np.random.normal(0, 1, (n_samples, LATENT_DIM)), dtype=tf.float32)
        # validate against the latents actually used, not a global constant
        assert dim < int(self._rnd.shape[-1])
        self._dim = dim
        self._filenames = []
        self._model = model
        self._width = width
        self._height = height

    def __call__(self, iters: int, step: float, save: str=None):
        """
        Render `iters` frames, shifting the chosen dimension by `step` per frame.

        :param iters: number of frames to render
        :param step: amount added to the traversed dimension per frame
        :param save: output file name (must end in ".gif"); defaults to
                     "_temp.gif"
        :return: the name of the written GIF
        """
        output_filename = save if save else "_temp.gif"
        # fail before rendering anything if the target name is unusable
        assert output_filename.endswith(".gif")

        # start from a clean slate so that frames rendered by an earlier call
        # do not end up in this GIF
        self._filenames = []
        crnd = self._rnd.numpy().copy()
        for i in range(1, iters + 1):
            crnd[:, self._dim] += step
            recon = tf.nn.sigmoid(self._model.decode(tf.convert_to_tensor(crnd))).numpy()
            fig = plt.figure(figsize=(self._width, self._height))
            try:
                ax = fig.add_subplot(111)
                # hacky!: to prevent similar filenames
                filename = f'{uuid.uuid4().hex}.png'
                # the title shows the cumulative offset applied so far
                imshow(make_grid(recon), title=f"Dimension {self._dim}: {i * step:0.4f}", ax=ax)
                fig.savefig(filename, bbox_inches='tight')
            finally:
                # close this figure explicitly (not just "the current one")
                plt.close(fig)
            self._filenames.append(filename)
        self._generate_fig(output_filename)
        # with open(output_filename,'rb') as f:
        #     display.display(Image(data=f.read(), format='png'))
        return output_filename

    def _generate_fig(self, output_filename):
        """Assemble the frames rendered by the last call into one GIF."""
        assert len(self._filenames)
        # taken from https://www.tensorflow.org/tutorials/generative/cvae#generate_a_gif_of_all_the_saved_images
        with imageio.get_writer(output_filename, mode='I') as writer:
            last = -1
            for i, filename in enumerate(self._filenames):
                # a frame is kept only when round(2 * sqrt(i)) advances, so
                # the sampling gets sparser for later frames
                frame = 2*(i**0.5)
                if round(frame) <= round(last):
                    continue
                last = frame
                writer.append_data(imageio.imread(filename))
            # always finish on the very last rendered frame
            writer.append_data(imageio.imread(self._filenames[-1]))

    @staticmethod
    def combine_gifs(filenames: typing.List[str], output: str, cols: int) -> str:
        """
        Tile several equally long GIFs into one animated grid.

        :param filenames: input GIFs, laid out row by row; their count must be
                          a non-zero multiple of `cols` and all must have the
                          same number of frames
        :param output: name of the GIF to write
        :param cols: number of GIFs per grid row
        :return: `output`
        """
        assert len(filenames) > 0, "combine_gifs needs at least one input file"
        assert len(filenames) % cols == 0
        rows = len(filenames) // cols

        readers = []
        try:
            for name in filenames:
                readers.append(imageio.get_reader(name))
            frames = readers[0].get_length()
            assert all(r.get_length() == frames for r in readers)

            with imageio.get_writer(output, mode='I') as writer:
                for _ in range(frames):
                    # next frame of every input, in the given order
                    tiles = [r.get_next_data() for r in readers]
                    grid_rows = [np.hstack(tiles[row * cols:(row + 1) * cols])
                                 for row in range(rows)]
                    writer.append_data(np.vstack(grid_rows))
        finally:
            # release the file handles held by the readers
            for r in readers:
                r.close()
        return output

In [0]:
# Traverse every latent dimension of vae4 in the positive direction
# (ITERS frames, STEP per frame) and tile the per-dimension GIFs 2 per row.
ITERS = 30
STEP = 0.25

# rnd = tf.convert_to_tensor(np.random.normal(0, 1, size=(32, LATENT_DIM)), dtype=tf.float32)
# start from the encodings of the images `x` defined in an earlier cell
rnd, *_ = vae4.encode(x)

gifs = []
for dim in range(0, LATENT_DIM):
    # model with beta=4
    gifs.append(LatentVisualiser(vae4, dim=dim, init_latent=rnd)(iters=ITERS, step=STEP, save=f"beta4_dim{dim}.gif"))

LatentVisualiser.combine_gifs(gifs, "beta4.gif", 2)
# NOTE(review): GIF bytes are handed to Image with format='png' — presumably
# the front end detects the real type; confirm.
with open("beta4.gif",'rb') as f:
    display.display(Image(data=f.read(), format='png'))
        

In [0]:
# Same traversal as the previous cell but in the negative direction; reuses
# the starting latents `rnd` computed there.
ITERS = 30
STEP = -0.25

gifs = []
for dim in range(0, LATENT_DIM):
    # model with beta=4
    gifs.append(LatentVisualiser(vae4, dim=dim, init_latent=rnd)(iters=ITERS, step=STEP, save=f"beta4_dim{dim}_n.gif"))

LatentVisualiser.combine_gifs(gifs, "beta4_n.gif", 2)
# NOTE(review): GIF bytes are handed to Image with format='png' — presumably
# the front end detects the real type; confirm.
with open("beta4_n.gif",'rb') as f:
    display.display(Image(data=f.read(), format='png'))