In [None]:
import numpy as np
import matplotlib.pyplot as plt
from corner import corner
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_probability as tfp

tfk = tf.keras
tfd = tfp.distributions
tfb = tfp.bijectors

In [None]:
def powerlaw(x, slope, lo, hi):
    """Probability density of the truncated power law p(x) ∝ x**slope on [lo, hi].

    Parameters
    ----------
    x : array_like
        Points at which to evaluate the density.
    slope : float or array_like
        Power-law index. Must not be -1: the normalization divides by
        ``slope + 1``. Broadcasts against ``x``, ``lo`` and ``hi``.
    lo, hi : float or array_like
        Support bounds, ``lo < hi``.

    Returns
    -------
    numpy.ndarray
        The normalized density at ``x``; exactly 0 outside ``[lo, hi]``.
    """
    x = np.asarray(x)
    in_support = (x >= lo) & (x <= hi)
    # Only raise in-support points to the power. For negative x and a
    # non-integer slope, x**slope is NaN, and NaN * 0 would leak NaN outside
    # the support instead of the intended 0.
    x_safe = np.where(in_support, x, hi)
    norm = (slope + 1) / (hi**(slope + 1) - lo**(slope + 1))
    return np.where(in_support, x_safe**slope * norm, 0.0)

def sample_powerlaw(n_samples, slope, lo, hi, rng=None):
    """Draw samples from the truncated power law p(x) ∝ x**slope on [lo, hi].

    Uses inverse-CDF sampling. If u ~ U[0, 1) and s = slope (s != -1), then
    ``(lo**(s+1) + u * (hi**(s+1) - lo**(s+1)))**(1/(s+1))`` follows the
    target density.

    Parameters
    ----------
    n_samples : int
        Length of the trailing axis of draws.
    slope, lo, hi : float or array_like
        Power-law index and support bounds. They are broadcast together with a
        trailing axis of length ``n_samples``; e.g. ``slope[:, None]`` gives
        one row of draws per slope.
    rng : numpy.random.Generator, optional
        Source of randomness. Defaults to the legacy global ``np.random``
        state, so ``np.random.seed`` still controls the draws.

    Returns
    -------
    numpy.ndarray
        Samples with shape ``np.broadcast_shapes(np.shape(slope),
        np.shape(lo), np.shape(hi), (n_samples,))``.
    """
    shape = np.broadcast_shapes(
        np.shape(slope), np.shape(lo), np.shape(hi), (n_samples,))
    uniform = np.random.uniform if rng is None else rng.uniform
    # One independent uniform per output element. A single (n_samples,) draw
    # would be shared by every slope, making the rows deterministic
    # transforms of one another instead of independent samples.
    u = uniform(size=shape)
    a = lo**(slope + 1)
    b = hi**(slope + 1)
    return (a + u * (b - a))**(1 / (slope + 1))

In [None]:
lo = 0
hi = 1
# Evaluate past the support on both sides to show the truncation at lo and hi.
x = np.linspace(lo - .5, hi + .5, 100)

fig, ax = plt.subplots()
for slope in np.linspace(0, 5, 10):
    ax.plot(x, powerlaw(x, slope, lo, hi), label=f'slope={slope:.2f}')
ax.set(xlabel='x', ylabel='density', title='Truncated power-law densities')
ax.legend(fontsize='small');

In [None]:
n_slopes = 10
n_samples = 1000
lo = 0
hi = 1
seed = 0

# Seed the global generators so the training data and the flow's weight
# initialization (built in a later cell) are reproducible.
np.random.seed(seed)
tf.random.set_seed(seed)

# One row of n_samples draws per slope.
slopes = np.linspace(0, 5, n_slopes)
data = sample_powerlaw(n_samples, slopes[:, None], lo, hi)
assert data.shape == (n_slopes, n_samples)
data.shape

In [None]:
# Empirical density of the training data, one step histogram per slope.
# Expects `data` to still be the per-slope 2-D array, i.e. run this before the
# cells that flatten it.
bins = np.linspace(lo, hi, 20)
fig, ax = plt.subplots()
for d in data:
    ax.hist(d, bins=bins, density=True, histtype='step')
ax.set(xlabel='x', ylabel='density', title='Sampled training data (one curve per slope)');

In [None]:
# Pair every sample with its conditioning slope: row i of the result is
# slopes[i] repeated n_samples times. Not idempotent; re-running this cell on
# the expanded array would add another axis.
slopes = slopes[:, np.newaxis].repeat(n_samples, axis=1)
slopes.shape

In [None]:
# Flatten to one entry per (slope, sample) pair. Row-major ravel concatenates
# the per-slope rows in order, so data[i] stays paired with slopes[i].
# ravel is a no-op on already-flat arrays, so re-running this cell is safe; the
# assert catches data and slopes getting out of step (e.g. after a stale re-run).
data = data.ravel()
slopes = slopes.ravel()
assert data.shape == slopes.shape, (data.shape, slopes.shape)
data.shape, slopes.shape

In [None]:
# With a 1-D event there is nothing for tfb.Permute to reorder, so append a
# constant dummy coordinate (always 0): each row becomes (x, 0).
# Sampling still works, but log-densities now refer to the 2-D (x, 0) event;
# the dummy dimension would have to be marginalized out to get p(x | slope).
# NOTE(review): a zero-variance coordinate presumably lets the training
# likelihood grow without bound; adding small noise to it may be worth trying.
data = np.stack([data, np.zeros(data.size)], axis=-1)
data.shape

In [None]:
# Route per-bijector keyword arguments (e.g. a MAF's conditional_input) through
# nested composite bijectors. Background:
# https://github.com/tensorflow/probability/issues/1410
# https://github.com/tensorflow/probability/issues/1006#issuecomment-663141106

import re

def make_bijector_kwargs(bijector, name_to_kwargs):
    """Build the nested kwargs dict expected by a (possibly composite) bijector.

    Args:
        bijector: A bijector. If it exposes a ``bijectors`` attribute it is
            treated as a composite and recursed into.
        name_to_kwargs: Mapping from regex (matched with ``re.match``, i.e.
            anchored at the start of the name) to the kwargs for leaf bijectors
            whose name matches. The first matching pattern in iteration order wins.

    Returns:
        For a composite, a dict keyed by each child's ``name`` holding that
        child's (recursively built) kwargs. For a leaf, the kwargs object of
        the first matching pattern, or a new empty dict if none match.
    """
    if not hasattr(bijector, 'bijectors'):
        return next(
            (kwargs
             for pattern, kwargs in name_to_kwargs.items()
             if re.match(pattern, bijector.name)),
            {},
            )
    return {
        child.name: make_bijector_kwargs(child, name_to_kwargs)
        for child in bijector.bijectors
        }

In [None]:
n_flows = 3
n_layers = 1
n_units = 128
activation = 'relu'

bijectors = []

# Output squashing. tfb.Chain applies its list right-to-left, so this first
# entry runs LAST in the forward (sampling) direction. tanh maps to (-1, 1),
# Shift(1) to (0, 2) and Scale(.5) to (0, 1); the result equals sigmoid(2*z).
# This keeps sampled x inside (0, 1), matching lo=0, hi=1 above. Change it if
# those bounds change. The dummy coordinate passes through unchanged.
bs = [
    tfb.Chain([tfb.Scale(scale=.5), tfb.Shift(shift=1.), tfb.Tanh()]),
    tfb.Identity(),
    ]
blockwise = tfb.Blockwise(bs, block_sizes=[1, 1])
bijectors.append(blockwise)

for i in range(n_flows):

    # Conditional MADE: for each of the 2 event dims it outputs 2 params (the
    # shift and log-scale consumed by the MAF), conditioned on the 1-D slope.
    made = tfb.AutoregressiveNetwork(
        params=2,
        event_shape=(2,),
        conditional=True,
        conditional_event_shape=(1,),
        hidden_units=[n_units]*n_layers,
        activation=activation,
        )
    # The name must match the 'maf.' pattern / f'maf{i}' keys later cells use
    # to route conditional_input to these bijectors.
    maf = tfb.MaskedAutoregressiveFlow(made, name=f'maf{i}')
    bijectors.append(maf)

    # Swap the two coordinates so consecutive MAFs use opposite orderings.
    bijectors.append(tfb.Permute([1, 0]))

bijector = tfb.Chain(bijectors)
# Base distribution: two i.i.d. standard normals.
distribution = tfd.Sample(tfd.Normal(loc=0., scale=1.), sample_shape=[2])
nf = tfd.TransformedDistribution(distribution=distribution, bijector=bijector)

In [None]:
# Wrap the conditional log-density in a Keras model so it can be trained with
# model.fit. Inputs: the 2-D event (x, dummy) and the 1-D conditioning slope.
x = tf.keras.Input(shape=(2,), dtype=tf.float32)
c = tf.keras.Input(shape=(1,), dtype=tf.float32)

# make_bijector_kwargs hands conditional_input=c to every bijector whose name
# matches the regex 'maf.', and an empty dict to all the others.
log_prob = nf.log_prob(
    x,
    bijector_kwargs=make_bijector_kwargs(nf.bijector, {'maf.': {'conditional_input': c}}),
    )

model = tf.keras.Model(inputs=[x, c], outputs=log_prob)

In [None]:
epochs = 20
batch_size = 100
learning_rate = 1e-3

# Maximum likelihood: the model outputs log p(x | c), so the loss is its
# negation. The target argument is a placeholder and is ignored.
model.compile(
    optimizer=tf.optimizers.Adam(learning_rate=learning_rate),
    loss=lambda _, log_prob: -log_prob,
    )

# Size the placeholder targets from the data itself. n_samples is reassigned by
# later cells, so n_samples*n_slopes would break a re-run of this cell.
# The condition is passed explicitly as (N, 1) to match Input(shape=(1,)).
result = model.fit(
    x=[data, slopes.reshape(-1, 1)],
    y=np.zeros(len(data), dtype=np.float32),
    epochs=epochs,
    batch_size=batch_size,
    shuffle=True,
    verbose=1,
    )

In [None]:
fig, ax = plt.subplots()
ax.plot(result.history['loss'])
ax.set(xlabel='epoch', ylabel='loss (mean -log p(x | slope))', title='Training curve');

In [None]:
# Draw from the trained flow conditioned on a single slope value.
slope = 0
# Separate name so the training-set size `n_samples` is not clobbered.
n_draws = 10000

# One float32 conditioning row per draw, matching the flow's float32 dtype.
condition = np.full((n_draws, 1), slope, dtype=np.float32)
bijector_kwargs = {
    f'maf{i}': {'conditional_input': condition}
    for i in range(n_flows)
    }

corner(
    nf.sample(
        n_draws,
        bijector_kwargs=bijector_kwargs,
        ).numpy(),
    labels=['x', 'dummy'],
    );

In [None]:
# Sample by hand: push base-distribution draws through the bijector's forward
# pass. This is equivalent in distribution to nf.sample with the same condition.
slope = 0
# Separate name so the training-set size `n_samples` is not clobbered.
n_draws = 10000

# One float32 conditioning row per draw, matching the flow's float32 dtype.
condition = np.full((n_draws, 1), slope, dtype=np.float32)
bijector_kwargs = {
    f'maf{i}': {'conditional_input': condition}
    for i in range(n_flows)
    }

corner(
    nf.bijector.forward(
        distribution.sample(n_draws),
        **bijector_kwargs,
        ).numpy(),
    labels=['x', 'dummy'],
    );

In [None]:
# Map true power-law samples back to the base space with the inverse flow. If
# the flow fits well, the result should resemble draws from the standard-normal base.
slope = 0
# Separate name so the training-set size `n_samples` is not clobbered.
n_draws = 10000

condition = np.full((n_draws, 1), slope, dtype=np.float32)
# True samples plus the constant dummy coordinate, laid out like the training
# data. They are cast to float32 because the flow's bijectors are float32, and
# TFP bijectors check the input dtype (float64 is likely rejected).
samples = np.stack(
    [sample_powerlaw(n_draws, slope, lo, hi), np.zeros(n_draws)],
    axis=-1,
    ).astype(np.float32)

bijector_kwargs = {
    f'maf{i}': {'conditional_input': condition}
    for i in range(n_flows)
    }

corner(
    nf.bijector.inverse(
        samples,
        **bijector_kwargs,
        ).numpy(),
    labels=['z0', 'z1'],
    );