In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
import corner

import tensorflow as tf
import tensorflow_probability as tfp

tfk = tf.keras
tfd = tfp.distributions
tfb = tfp.bijectors

# Seed TensorFlow's and NumPy's global RNGs up front; this is intended to make
# network initialisation and flow sampling repeatable across kernel restarts.
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)

In [None]:
def timer(func, label=None):
    """Call ``func()`` once, print the elapsed time in seconds, and return its result.

    Args:
        func: zero-argument callable to run and time.
        label: optional text printed in front of the elapsed time, so that
            several timings printed by one cell can be told apart. When
            ``None`` (the default) only the bare number is printed.

    Returns:
        Whatever ``func()`` returns.
    """
    # perf_counter is monotonic and high-resolution, meant for measuring
    # intervals; time.time() can jump if the system clock is adjusted.
    start = time.perf_counter()
    result = func()
    elapsed = time.perf_counter() - start

    if label is None:
        print(elapsed)
    else:
        print(f'{label}: {elapsed} s')

    return result

In [None]:
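# Helpers that build the nested `bijector_kwargs` dicts used to pass a
# `conditional_input` to the MAF layers inside a Chain. Background:
# https://github.com/tensorflow/probability/issues/1410
# https://github.com/tensorflow/probability/issues/1006#issuecomment-663141106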

def maf_kwargs(flow, condition):
    """Map each MAF layer of ``flow`` to the kwargs that feed it ``condition``.

    Only the direct members of ``flow.bijector.bijectors`` (the top-level
    Chain) are inspected. A member counts as a MAF layer when its name
    contains the substring ``'maf'``.

    Args:
        flow: distribution whose ``bijector`` has a ``bijectors`` list
            (in this notebook, a TransformedDistribution over a tfb.Chain).
        condition: value passed as ``conditional_input`` to every MAF layer.

    Returns:
        dict ``{layer_name: {'conditional_input': condition}}``, suitable as
        ``bijector_kwargs``.
    """
    per_layer = {}
    for layer in flow.bijector.bijectors:
        if 'maf' not in layer.name:
            continue
        per_layer[layer.name] = {'conditional_input': condition}
    return per_layer


def iaf_kwargs(flow, condition):
    """Build ``bijector_kwargs`` that feed ``condition`` to the MAF layers of an IAF.

    ``flow.bijector.bijectors[-1]`` is assumed to be a ``tfb.Invert`` wrapping
    a ``tfb.Chain`` of MAF layers (as constructed in this notebook). Those MAF
    layers are not direct members of the outer Chain. Their kwargs are
    therefore nested one level deeper, under the name of the ``Invert``
    bijector itself. The actual name is used rather than a hard-coded
    ``'invert'``, so renaming that bijector keeps working.

    Args:
        flow: distribution whose outer Chain ends with the ``Invert`` bijector.
        condition: value passed as ``conditional_input`` to every MAF layer.

    Returns:
        dict ``{invert_name: {maf_name: {'conditional_input': condition}}}``.
        Only inner bijectors whose name contains ``'maf'`` are included.
    """
    inverted = flow.bijector.bijectors[-1]
    inner_layers = inverted.bijector.bijectors

    return {
        inverted.name: {
            b.name: {'conditional_input': condition}
            for b in inner_layers if 'maf' in b.name
            }
        }

In [None]:
# Event size of the modelled variable (base-sample and MADE event shape) and
# size of the conditioning input (MADE conditional_event_shape).
dims, cond_dims = 2, 1

In [None]:
# Base distribution: `dims` i.i.d. standard normals treated as a single event.
distribution = tfd.Sample(
    tfd.Normal(loc=0., scale=1.),
    sample_shape=[dims],
    )

# tfb.Chain applies its list right-to-left in the forward direction, so these
# run after the MAF stack: y = 0.5 * (tanh(z) + 1), which maps each coordinate
# into the open interval (0, 1).
output_bijectors = [
    tfb.Scale(scale=.5),
    tfb.Shift(shift=1.),
    tfb.Tanh(),
    ]

n_layers = 10          # number of stacked MAF layers
hidden_units = [1024]  # hidden layer sizes of every MADE network

# Reversing permutation inserted between consecutive MAF layers. Without it,
# coordinate 0 comes first in the autoregressive order of every layer. It is
# then only shifted/scaled as a function of the conditional input, and its
# whole transform collapses to one affine map of a Gaussian. Alternating the
# order lets each coordinate be conditioned on the other in alternate layers.
# The permutes are named 'permute{i}' (no 'maf' substring), so maf_kwargs and
# iaf_kwargs give them no conditional input.
reverse_order = list(reversed(range(dims)))

bijectors = []
for i in range(n_layers):
    if i > 0:
        bijectors.append(tfb.Permute(permutation=reverse_order, name=f'permute{i}'))
    made = tfb.AutoregressiveNetwork(
        params=2,
        event_shape=[dims],
        conditional=True,
        conditional_event_shape=[cond_dims],
        hidden_units=hidden_units,
        # NOTE(review): random-normal init for both kernels and biases
        # presumably makes the untrained flow visibly non-trivial; confirm intent.
        kernel_initializer='RandomNormal',
        bias_initializer='RandomNormal',
        )
    bijectors.append(tfb.MaskedAutoregressiveFlow(made, name=f'maf{i}'))

# Both flows reuse the very same `bijectors` list, so they share the MADE
# networks. The IAF wraps the stack in Invert, i.e. it runs it in the opposite
# direction. Sampling cost and log_prob cost are therefore expected to swap
# between the two.
maf = tfd.TransformedDistribution(
    distribution=distribution,
    bijector=tfb.Chain(output_bijectors + bijectors),
    )
iaf = tfd.TransformedDistribution(
    distribution=distribution,
    bijector=tfb.Chain(output_bijectors + [tfb.Invert(tfb.Chain(bijectors), name='invert')]),
    )

In [None]:
n = 10_000  # number of samples drawn from each flow in the timing cell

In [None]:
# Condition both flows on the same input and compare them side by side:
# time sampling and log-density evaluation, then corner-plot the samples.
# After the loop, `kw`, `sample` and `lp` hold the values from the last flow
# (IAF), as they did when this cell was written out twice.
condition = [1.]

for flow_name, flow, make_kwargs in [
        ('MAF', maf, maf_kwargs),
        ('IAF', iaf, iaf_kwargs),
        ]:
    kw = make_kwargs(flow, condition)

    # timer() prints only a bare number by default, so label each timing.
    print(f'{flow_name} sample time (s):')
    sample = timer(lambda: flow.sample(n, bijector_kwargs=kw))
    print(f'{flow_name} log_prob time (s):')
    lp = timer(lambda: flow.log_prob(sample, bijector_kwargs=kw))

    fig = corner.corner(sample.numpy())
    fig.suptitle(f'{flow_name} samples, condition={condition}')
    plt.show()