In [None]:
import numpy as np
import h5py
import matplotlib.pyplot as plt
import corner
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_probability as tfp

# Short aliases for the TFP submodules used throughout the notebook.
tfd, tfb = tfp.distributions, tfp.bijectors

In [None]:
# Fixed seed so "Restart & Run All" reproduces the same training data (and
# therefore the same fit/checkpoint). Note: this seeds only the NumPy draw
# below; TensorFlow weight initialisation is not seeded by this cell.
SEED = 0
rng = np.random.default_rng(SEED)

n_samples = 1000

# Training set: `n_samples` draws from U[0, 1), shaped (n_samples, 1) so that
# each row is one 1-D event for the flow.
data = rng.uniform(size=[n_samples, 1])

In [None]:
def define_nf(event_size=1, hidden_units=(4,), activation='relu'):
    """Build a masked-autoregressive normalizing flow.

    The base distribution is `event_size` i.i.d. standard normals, grouped into
    a single event of shape [event_size] by `tfd.Sample`. This is pushed through
    one `tfb.MaskedAutoregressiveFlow`, whose shift and log-scale come from a
    `tfb.AutoregressiveNetwork` (`params=2`: one shift and one log-scale per
    event dimension).

    Args:
        event_size: Dimensionality of each event. The default of 1 reproduces
            the original 1-D setup. A model wrapping this flow must accept
            inputs of shape [event_size].
        hidden_units: Widths of the autoregressive network's hidden layers.
        activation: Activation used in those hidden layers.

    Returns:
        A `tfd.TransformedDistribution`. No `event_shape` is passed to the
        network, so its variables are presumably created on first use (e.g.
        the first `log_prob` call).

    NOTE(review): with event_size=1 the autoregressive masks leave the shift
    and log-scale independent of the input. The flow then presumably reduces
    to an affine transform of the base Normal, i.e. it can only fit a
    Gaussian. Confirm if the fit to the uniform data looks Gaussian.
    """
    # Built in the same order as before: base distribution first, then the
    # network and bijector.
    base = tfd.Sample(tfd.Normal(loc=0., scale=1.), sample_shape=[event_size])
    made = tfb.AutoregressiveNetwork(
        params=2,
        hidden_units=list(hidden_units),
        activation=activation,
    )
    return tfd.TransformedDistribution(
        distribution=base,
        bijector=tfb.MaskedAutoregressiveFlow(made),
    )

In [None]:
nf = define_nf()

In [None]:
def define_model(nf):
    """Wrap `nf.log_prob` as a Keras functional model.

    The model maps a batch of events to their log-density under `nf`. Fitting
    it with a loss of `-log_prob` trains the flow by maximum likelihood.

    Args:
        nf: A TFP distribution with a statically known event rank. Its event
            shape and dtype define the model's Input (the batch dimension is
            implicit). For the 1-D flow this is shape [1], dtype float32.

    Returns:
        A `tf.keras.Model` whose output is `nf.log_prob(x)` for input `x`.

    Raises:
        ValueError: if `nf.event_shape` has unknown rank (from `as_list`).
    """
    # Derive the input signature from the flow instead of hard-coding it, so
    # the model stays consistent with whatever event size `nf` was built for.
    x = tf.keras.Input(shape=nf.event_shape.as_list(), dtype=nf.dtype)

    return tf.keras.Model(inputs=x, outputs=nf.log_prob(x))

In [None]:
def _negative_log_likelihood(_y_true, log_prob):
    """Keras loss: the model already outputs log p(x), so return -log p(x).

    `_y_true` is the target argument Keras always passes; it is ignored.
    """
    return -log_prob


model = define_model(nf)
model.compile(optimizer=tf.optimizers.Adam(), loss=_negative_log_likelihood)

In [None]:
# Side by side: the flow's own variables vs. the weights Keras tracks for the model.
(nf.variables, model.weights)

In [None]:
# Evaluation grid, a bit wider than the data's [0, 1) range so the tails show.
xx = np.linspace(-.5, 1.5, 100)
# Density of the untrained flow. `px` is kept for the post-training comparison.
px = np.exp(nf.log_prob(xx[:, None]))

fig, ax = plt.subplots()
ax.plot(xx, px)
ax.set(title='Flow density before training', xlabel='x', ylabel='p(x)');

In [None]:
# After every epoch, write the model weights to HDF5 whenever the training loss
# reaches a new minimum.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath='./saving.hdf5',
    monitor='loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
    save_freq='epoch',
    verbose=1,
)
callbacks = [checkpoint]

# The model outputs log p(x) and the loss ignores its target argument, so the
# targets are dummy zeros that only satisfy Keras' (x, y) interface.
result = model.fit(
    x=data,
    y=np.zeros(n_samples),
    epochs=10,
    callbacks=callbacks,
)

In [None]:
# After fitting: compare the flow's variables with the model's weights again.
(nf.variables, model.weights)

In [None]:
# Compare the untrained density (`px` from the earlier plotting cell) with the
# trained flow, evaluated both directly and through the Keras model.
# NOTE: this cell rebinds `px`. Re-running it would plot the trained density
# as the "before" curve.
px_before = px
px = np.exp(nf.log_prob(xx[:, None]))
# Feed the model (batch, 1) inputs to match its Input and the nf.log_prob call
# above, rather than relying on Keras to reshape a 1-D array.
px_ = np.exp(model.predict_on_batch(xx[:, None]))

fig, ax = plt.subplots()
ax.plot(xx, px_before, label='before training')
ax.plot(xx, px, label='trained (nf.log_prob)')
ax.plot(xx, px_, ls='--', label='trained (model.predict_on_batch)')
ax.set(title='Flow density after training', xlabel='x', ylabel='p(x)')
ax.legend();

In [None]:
# Top-level groups stored in the weights checkpoint. Materialise them as a list
# while the file is open, then display it as the cell's rich output instead of
# printing a KeysView repr.
with h5py.File('./saving.hdf5', 'r') as h:
    checkpoint_keys = list(h.keys())
checkpoint_keys

In [None]:
# Start over with a brand-new flow and model (fresh variables), shadowing the
# trained `nf` and `model`, so that loading the checkpoint can be checked.
nf = define_nf()
model = define_model(nf)
(nf.variables, model.weights)

In [None]:
# Restore the checkpointed weights into the fresh model, then show both views
# again to see whether the flow's variables were updated as well.
model.load_weights('./saving.hdf5')
(nf.variables, model.weights)

In [None]:
# `px` (from the post-training comparison cell) holds the density of the flow
# at the end of `fit`. If the checkpoint restored every variable, the reloaded
# curves should overlay it.
# Caveat: the checkpoint holds the lowest-loss epoch (save_best_only=True),
# which is usually but not necessarily the final epoch.
px_trained = px
px = np.exp(nf.log_prob(xx[:, None]))
# (batch, 1) input to match the model's Input and the nf.log_prob call above.
px_ = np.exp(model.predict_on_batch(xx[:, None]))

fig, ax = plt.subplots()
ax.plot(xx, px_trained, lw=10, label='trained (end of fit)')
ax.plot(xx, px, lw=5, label='reloaded (nf.log_prob)')
ax.plot(xx, px_, ls='--', lw=2, label='reloaded (model.predict_on_batch)')
ax.set(title='Density after reloading the checkpoint', xlabel='x', ylabel='p(x)')
ax.legend();