In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from PIL import Image

import math

### Prepare dataset

In [2]:
BATCH_SIZE = 128


def prepare_image(x):
    x = tf.cast(x['image'], tf.float32)
    x = tf.reshape(x, (-1,))
    return x

ds_builder = tfds.builder('binarized_mnist')
ds_builder.download_and_prepare()
train_ds = ds_builder.as_dataset(split=tfds.Split.TRAIN)
train_ds = train_ds.map(prepare_image)
train_ds = train_ds.cache()
train_ds = train_ds.repeat()
train_ds = train_ds.shuffle(50000)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = iter(tfds.as_numpy(train_ds))

test_ds = ds_builder.as_dataset(split=tfds.Split.TEST)
test_ds = test_ds.map(prepare_image).batch(10000)
test_ds = np.array(list(test_ds)[0])
test_ds = jax.device_put(test_ds)

print(train_ds)
print(test_ds.shape)

<generator object _eager_dataset_iterator at 0x7f033932ecf0>
(10000, 784)


### Create Auto encoder

In [3]:
LATENT_SIZE = 32
LEARNING_RATE = 1e-3

class Encoder(nn.Module):
    latents: int
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        mean_x = nn.Dense(self.latents)(x)
        logvar_x = nn.Dense(self.latents)(x)
        return mean_x, logvar_x
    
class Decoder(nn.Module):
    @nn.compact
    def __call__(self, z):
        z = nn.Dense(256)(z)
        z = nn.relu(z)
        z = nn.Dense(256)(z)
        z = nn.relu(z)
        z = nn.Dense(784)(z)
        return z
    

class VAE(nn.Module):
    latents: int
    
    def setup(self):
        self.encoder = Encoder(self.latents)
        self.decoder = Decoder()
    
    def __call__(self, x, key):
        mean, logvar = self.encoder(x)
        std = jnp.exp(0.5 * logvar)
        eps = jax.random.normal(key, logvar.shape)
        z = mean + std * eps
        recon_x = self.decoder(z)
        return recon_x, mean, logvar
    
    def generate(self, z):
        return nn.sigmoid(self.decoder(z))
    


def create_train_state(model, key):
    @jax.jit
    def init():
        init_data = jnp.ones((BATCH_SIZE, 784), jnp.float32)
        return model.init(key, init_data, key)['params']
    
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=init(),
        tx=optax.adam(LEARNING_RATE))


key = jax.random.PRNGKey(666)
model = VAE(latents=LATENT_SIZE)
state = create_train_state(model, key)

In [4]:
example_batch = next(train_ds)
print(example_batch.shape)

(128, 784)


This is an very cool example of how vmap is used.

for this example logits and labels both have shape `(128, 784)`.

without `jax.vmap`, `binary_cross_entropy_with_logits` could compute the total sum of the matrix `(128, 784)`.

With `jax.vmap`, `binary_cross_entropy_with_logits` would sum the `(, 784)` part and return shape of `(128)`

and we can call `mean()` to get the mean of `(128)` losses.

not sure why kl_divergence is computed like this

it's supposed to be
$
\sum p(x) log(\frac{p(x)}{q(x)}) + (1-p(x)) log(\frac{1-p(x)}{1-q(x)})
$

In [5]:
@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
    logits = nn.log_sigmoid(logits)
    return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))

@jax.vmap
def kl_divergence(mean, logvar):
    return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))


@jax.jit
def train_step(state, batch, z_key):
    def loss_fn(params):
        recon_x, mean, logvar = state.apply_fn({'params': params}, batch, z_key)
        bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
        kld_loss = kl_divergence(mean, logvar)
        kld_loss = kld_loss.mean()
        
        loss = bce_loss + kld_loss
        return loss
    
    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)


key, z_key, eval_key = jax.random.split(key, 3)
test_state = train_step(state, example_batch, z_key)

In [6]:
%timeit train_step(test_state, example_batch, z_key)

477 µs ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


`model.apply()` has a optional parameter `method`.

we can use this `method` to tell jax which method to call instead of the `__call__()`

In [7]:
def compute_metrics(recon_x, x, mean, logvar):
    bce_loss = binary_cross_entropy_with_logits(recon_x, x).mean()
    kld_loss = kl_divergence(mean, logvar).mean()
    return {
        'bce': bce_loss,
        'kld': kld_loss,
        'loss': bce_loss + kld_loss
    }


@jax.jit
def eval(state, batch, z, z_key):
    model = VAE(latents=LATENT_SIZE)
    recon_images, mean, logvar = state.apply_fn({'params': state.params}, batch, z_key)
    comparison = jnp.concatenate([batch[:16].reshape(-1, 28, 28, 1),
                                  recon_images[:16].reshape(-1, 28, 28, 1)])
    generate_images = state.apply_fn({'params': state.params}, z, method=model.generate)
    generate_images = generate_images.reshape(-1, 28, 28, 1)
    metrics = compute_metrics(recon_images, batch, mean, logvar)
    return metrics, comparison, generate_images


z = jax.random.normal(z_key, (64, LATENT_SIZE))
metrics, comparison, sample = eval(state, example_batch, z, eval_key)
print(comparison.shape)
print(sample.shape)

(32, 28, 28, 1)
(64, 28, 28, 1)


copy the following straight from the example.

I am too lazy to implement this. :p

In [8]:
def save_image(ndarray, fp, nrow=8, padding=2, pad_value=0.0, format=None):
    if not (isinstance(ndarray, jnp.ndarray) or
        (isinstance(ndarray, list) and all(isinstance(t, jnp.ndarray) for t in ndarray))):
        raise TypeError('array_like of tensors expected, got {}'.format(type(ndarray)))

    ndarray = jnp.asarray(ndarray)

    if ndarray.ndim == 4 and ndarray.shape[-1] == 1:  # single-channel images
        ndarray = jnp.concatenate((ndarray, ndarray, ndarray), -1)

    # make the mini-batch of images into a grid
    nmaps = ndarray.shape[0]
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(ndarray.shape[1] + padding), int(ndarray.shape[2] + padding)
    num_channels = ndarray.shape[3]
    grid = jnp.full((height * ymaps + padding, width * xmaps + padding, num_channels), pad_value).astype(jnp.float32)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            grid = grid.at[y * height + padding:(y + 1) * height,
                           x * width + padding:(x + 1) * width].set(ndarray[k])
            k = k + 1

    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = jnp.clip(grid * 255.0 + 0.5, 0, 255).astype(jnp.uint8)
    im = Image.fromarray(ndarr.copy())
    im.save(fp, format=format)

In [9]:
!mkdir -p results

In [10]:
NUM_EPOCHS = 30
STEPS_PER_EPOCH = 50000 // BATCH_SIZE

def train_and_evaluate():
    key = jax.random.PRNGKey(666)
    key, z_key = jax.random.split(key)
    z = jax.random.normal(z_key, (64, LATENT_SIZE))
    model = VAE(latents=LATENT_SIZE)
    state = create_train_state(model, key)
    for epoch in range(NUM_EPOCHS):
        for _ in range(STEPS_PER_EPOCH):
            batch = next(train_ds)
            key, z_key = jax.random.split(key)
            state = train_step(state, batch, z_key)
        
        metrics, comparison, sample = eval(state, test_ds, z, eval_key)
        save_image(comparison, f'results/reconstruction_{epoch}.png', nrow=8)
        save_image(sample, f'results/sample_{epoch}.png', nrow=8)
        print('eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format(
            epoch + 1, metrics['loss'], metrics['bce'], metrics['kld']
        ))


train_and_evaluate()

eval epoch: 1, loss: 140.5900, BCE: 122.9232, KLD: 17.6668
eval epoch: 2, loss: 122.3568, BCE: 101.5168, KLD: 20.8400
eval epoch: 3, loss: 114.6378, BCE: 92.2613, KLD: 22.3764
eval epoch: 4, loss: 110.6642, BCE: 88.0380, KLD: 22.6262
eval epoch: 5, loss: 107.6390, BCE: 84.2557, KLD: 23.3833
eval epoch: 6, loss: 106.5050, BCE: 82.6635, KLD: 23.8414
eval epoch: 7, loss: 105.0298, BCE: 80.9643, KLD: 24.0655
eval epoch: 8, loss: 104.3721, BCE: 79.8203, KLD: 24.5518
eval epoch: 9, loss: 103.8689, BCE: 79.1262, KLD: 24.7428
eval epoch: 10, loss: 102.9717, BCE: 78.2321, KLD: 24.7397
eval epoch: 11, loss: 102.5942, BCE: 78.4284, KLD: 24.1658
eval epoch: 12, loss: 102.7179, BCE: 77.3482, KLD: 25.3698
eval epoch: 13, loss: 101.7541, BCE: 77.2554, KLD: 24.4987
eval epoch: 14, loss: 101.6877, BCE: 76.4036, KLD: 25.2841
eval epoch: 15, loss: 101.3727, BCE: 76.4956, KLD: 24.8771
eval epoch: 16, loss: 100.9258, BCE: 75.7253, KLD: 25.2005
eval epoch: 17, loss: 100.7445, BCE: 75.4332, KLD: 25.3113
eval