In [None]:
%matplotlib inline
import os
import pathlib

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras as K
import tensorflow_probability as tfp
import tensorflow_probability.python.bijectors as tfb
import tensorflow_probability.python.distributions as tfd
import tensorflow_datasets as tfds

# Let TensorFlow grow GPU memory on demand instead of reserving all of it up front.
devices = tf.config.experimental.list_physical_devices(device_type='GPU')
for d in devices:
    tf.config.experimental.set_memory_growth(d, True)

print(devices)

# Seed TF's global RNG so weight initialization and base-distribution sampling
# are repeatable across Restart & Run All (GPU kernels may still be
# nondeterministic).
SEED = 0
tf.random.set_seed(SEED)

# Optional debugging switches — uncomment when diagnosing op placement or NaN/Inf values.
# tf.debugging.set_log_device_placement(True)
# tf.debugging.enable_check_numerics()

# Disable Grappler's layout optimizer: with it enabled, training raised a reshape error.
# TODO: find the root cause and re-enable the layout optimizer.
tf.config.optimizer.set_experimental_options({'layout_optimizer': False})

# Configuration

In [None]:
# TensorBoard logs and weight checkpoints for this run both live under log_dir.
log_dir = pathlib.Path('./logs/realnvp_cifar10_keras_test_run')
ckpt_dir = log_dir.joinpath('checkpoints')

# Dataset

In [None]:
# split=None makes tfds return every split in a dict keyed by split name, which is
# what the `dataset['train']` / `dataset['test']` lookups below rely on.
# as_supervised=True yields (image, label) tuples instead of feature dicts.
dataset, dataset_info = tfds.load(
    'cifar10',
    split=None,
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

def preprocess(dataset):
    """Rescale each image into [0.5, 1) and pass the label through unchanged.

    For 8-bit pixel values, ``x / 256`` lies in [0, 255/256], so the result stays
    strictly inside (0, 1), the output range of the flow's final Sigmoid (whose
    inverse is then finite). The sampling callback's ``calibrate`` applies the
    inverse mapping ``(x - 0.5) / 0.5`` for display.

    Args:
        dataset: a ``tf.data.Dataset`` of ``(image, label)`` pairs.

    Returns:
        The mapped dataset of ``(float32 image, label)`` pairs.
    """
    def _rescale(img, label):
        pixels = tf.cast(img, tf.float32)
        return 0.5 + 0.5 * pixels / 256.0, label

    return dataset.map(_rescale)

dataset['train'] = preprocess(dataset['train'])
dataset['test'] = preprocess(dataset['test'])

h, w, c = dataset_info.features['image'].shape
num_train_examples = dataset_info.splits['train'].num_examples
num_test_examples = dataset_info.splits['test'].num_examples

# Preview one training example. Preprocessing squashes pixels into [0.5, 1), which
# would render washed out, so map back to [0, 1) for display only
# (the same inverse the sampling callback uses).
sample = next(iter(dataset['train']))
sample_label = dataset_info.features['label'].int2str(int(sample[1]))
fig, ax = plt.subplots()
ax.imshow(((sample[0].numpy() - 0.5) / 0.5).squeeze())
ax.set_title('CIFAR-10 train sample: {}'.format(sample_label))
ax.axis('off');

# Flow

In [None]:
class ShiftAndLogScale(tf.Module):
    """Conv net computing ``(shift, log_scale)`` for one RealNVP affine coupling.

    The net stacks three BatchNormalization -> Conv2D (3x3, stride 1, 'same')
    stages. The first two use 256 and 512 ReLU filters; the last is linear with
    ``2 * output_units`` filters. Its output is split in half along the last
    (channel) axis.

    NOTE(review): ``log_scale`` is unbounded; if training diverges, consider
    squashing it (e.g. with tanh). This is left unchanged here on purpose.
    """

    def __init__(self, output_units, name='shift_and_log_scale'):
        super().__init__(name=name)

        self.output_units = output_units

        # (filters, activation) for each BatchNorm -> Conv2D stage, in build order.
        stages = ((256, 'relu'), (512, 'relu'), (output_units * 2, None))
        layers = []
        for filters, activation in stages:
            layers.append(K.layers.BatchNormalization())
            layers.append(K.layers.Conv2D(filters, 3, 1, 'same', activation=activation))
        self.net = K.Sequential(layers)

    @tf.function
    def __call__(self, x, output_units):
        """Map the masked input ``x`` to ``(shift, log_scale)``.

        ``output_units`` is supplied by the caller (presumably tfb.RealNVP, as
        the number of channels it will transform) and must equal the width this
        net was built for.
        """
        assert output_units == self.output_units
        features = self.net(x)
        shift, log_scale = tf.split(features, 2, axis=-1)
        return shift, log_scale


class RealNVP(tfb.Chain):
    """Single-scale RealNVP flow from a flat latent vector to an image in (0, 1).

    ``tfb.Chain`` applies its list right-to-left in the forward direction:
      1. ``Reshape`` the flat vector into an ``(h, w, n_units)`` image, where
         ``h`` and ``w`` are the notebook-level image dimensions;
      2. ``n_layers`` blocks, each an affine ``tfb.RealNVP`` coupling over the
         channel axis followed by ``tfb.BatchNormalization``. Even-indexed
         blocks are wrapped in channel rotations;
      3. ``Sigmoid`` to squash the result into (0, 1).

    Args:
        n_layers: number of coupling blocks.
        n_masked: leading channels each coupling passes through unchanged (and
            conditions on); the other ``n_units - n_masked`` are transformed.
        n_units: number of image channels (3 for RGB).
        name: optional bijector name.
    """

    def __init__(self, n_layers, n_masked, n_units, name=None):
        # Rotate channels right by one: [n-1, 0, 1, ..., n-2] ([2, 0, 1] for RGB).
        rotation = [n_units - 1, *range(n_units - 1)]

        def make_layer(i):
            fn = ShiftAndLogScale(n_units - n_masked)
            chain = [
                tfb.RealNVP(
                    num_masked=n_masked,
                    shift_and_log_scale_fn=fn,
                ),
                tfb.BatchNormalization(),
            ]
            if i % 2 == 0:
                # The SAME rotation is applied on both sides of the coupling, so it
                # is not undone: even blocks rotate the channels by two overall.
                # For the 3-channel / n_masked=2 setup, this cycles which channel
                # sits in the transformed slot, so every channel gets updated
                # across blocks.
                perm = lambda: tfb.Permute(permutation=rotation, axis=-1)
                chain = [perm(), *chain, perm()]
            return tfb.Chain(chain)

        chain = [
            tfb.Sigmoid(),
            *[make_layer(i) for i in range(n_layers)],
            tfb.Reshape((h, w, n_units), (-1,)),
        ]
        super().__init__(chain, name=name)

# Wrapper classes for Keras

In [None]:
class TransformedDistribution(tfd.TransformedDistribution):
    """``tfd.TransformedDistribution`` that also exposes Keras-style weight properties.

    NOTE(review): this presumably exists so that a Keras ``Model`` holding an
    instance as an attribute (as ``LogProb`` does) tracks and trains its
    variables. Confirm against the TF/Keras version in use.
    """

    @property
    def weights(self):
        """All variables of this distribution (same as ``self.variables``)."""
        return self.variables

    @property
    def trainable_weights(self):
        """Trainable variables (same as ``self.trainable_variables``)."""
        return self.trainable_variables

    @property
    def non_trainable_weights(self):
        """Variables whose ``trainable`` attribute is missing or false, as a tuple."""
        return tuple(v for v in self.weights if not getattr(v, 'trainable', False))
    
class LogProb(K.Model):
    """Keras model whose output is the flow's log-density of its input batch."""

    def __init__(self, distribution, bijector):
        super().__init__()
        # Base distribution pushed through the bijector. The attribute name `flow`
        # matters: the sampling callback reads `model.flow`.
        self.flow = TransformedDistribution(distribution=distribution, bijector=bijector)

    def call(self, x):
        """Return ``log p(x)`` under the flow for the batch ``x``."""
        return self.flow.log_prob(x)

class NegativeLogLikelihood(K.losses.Loss):
    """Negative log-likelihood per data dimension, in nats.

    The model output is already ``log p(x)``, so the labels (first argument)
    are ignored.
    """

    def call(self, _, log_prob):
        # Normalize by the number of values per image (h * w * c) so the loss is
        # comparable across image sizes.
        num_dims = h * w * c
        return -(log_prob / num_dims)

# Build model

In [None]:
# Standard-normal base distribution over the flattened image (h * w * c values).
# Use `c` everywhere rather than a literal 3 so it stays consistent with the
# loss and `model.build` below.
distribution = tfd.MultivariateNormalDiag(
    loc=tf.zeros(h * w * c, name='loc'),
    scale_diag=tf.ones(h * w * c, name='scale_diag'),
    name='distribution',
)

# n_units is the number of image channels; each coupling keeps n_masked of them fixed.
bijector = RealNVP(n_layers=10, n_masked=2, n_units=c, name='bijector')

model = LogProb(distribution, bijector)
loss_fn = NegativeLogLikelihood(name='nll')
model.compile(
    optimizer=tf.optimizers.SGD(learning_rate=1e-3),
    loss=loss_fn,
)
model.build((None, h, w, c))
model.summary()

# Training

In [None]:
def sampling_callback(flow, steps_per_epoch, n_samples=9):
    """Build an ``on_epoch_end`` hook that logs samples from the flow as images.

    The base distribution is sampled once, up front. The same latents are pushed
    through the bijector after every epoch, so successive images show how the
    model's mapping of fixed latents evolves during training.

    Args:
        flow: object exposing ``distribution`` (sampled once) and ``bijector``
            (whose ``forward`` is applied each epoch), e.g. ``model.flow``.
        steps_per_epoch: optimizer steps (batches) per epoch; places the images
            on the same step axis as batch-level TensorBoard scalars.
        n_samples: number of images drawn and logged.

    Returns:
        A ``(epoch, logs)`` callable for ``LambdaCallback(on_epoch_end=...)``.
    """
    def calibrate(samples):
        # Undo preprocessing's [0, 1) -> [0.5, 1) squashing. Sigmoid outputs below
        # 0.5 fall outside the data range and are clipped to 0.
        samples = (samples - 0.5) / 0.5
        return tf.clip_by_value(samples, 0.0, 1.0)

    normal_samples = flow.distribution.sample(n_samples)

    def _(epoch, logs):
        samples = calibrate(flow.bijector.forward(normal_samples))
        # `epoch` is 0-based and this hook runs after the epoch's last batch,
        # i.e. after (epoch + 1) * steps_per_epoch optimizer steps.
        # NOTE: tf.summary.image writes nothing unless a default summary writer is
        # active. This relies on the Keras TensorBoard callback providing one
        # during training; verify for your TF version.
        tf.summary.image(
            'samples', samples,
            step=(epoch + 1) * steps_per_epoch,
            max_outputs=n_samples,
        )
    return _
        
batchsize = 64
# Keras counts steps_per_epoch / validation_steps in *batches*, not examples.
# Passing example counts made each "epoch" ~batchsize passes over the data.
steps_per_epoch = num_train_examples // batchsize
validation_steps = (num_test_examples + batchsize - 1) // batchsize  # ceil division
# tfds' shuffle_files only shuffles file order; also shuffle individual examples.
SHUFFLE_BUFFER = 10_000

callbacks = [
    K.callbacks.TensorBoard(
        log_dir=log_dir.as_posix(), update_freq='batch',
        histogram_freq=100,
        write_graph=True, write_images=True,
    ),
    K.callbacks.ModelCheckpoint(
        ckpt_dir.as_posix(), save_weights_only=True, verbose=1,
    ),
    K.callbacks.LambdaCallback(
        on_epoch_end=sampling_callback(model.flow, steps_per_epoch),
    ),
]

train_ds = (
    dataset['train']
    .shuffle(SHUFFLE_BUFFER)
    .batch(batchsize)
    .repeat()
    .prefetch(tf.data.experimental.AUTOTUNE)
)
val_ds = dataset['test'].batch(batchsize).repeat()

history = model.fit(
    train_ds,
    steps_per_epoch=steps_per_epoch,
    epochs=100,
    validation_data=val_ds,
    validation_steps=validation_steps,
    callbacks=callbacks,
)