In [2]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

In [3]:
scale_tril = tfb.FillScaleTriL()([-0.5, 1.25, 1.])

In [7]:
scale_tril_chain = tfb.Chain([
    tfb.TransformDiagonal(tfb.Chain([
        tfb.Shift(1e-5), 
        tfb.Softplus()
    ])),
    tfb.FillTriangular()
])

In [9]:
p = tfd.MultivariateNormalTriL(loc=0, scale_tril=scale_tril)
# p2 = tfd.MultivariateNormalTriL(loc=0, scale_tril=scale_tril_chain)

In [10]:
# Symmetric and positive definite

# isotropic gaussian with zero mean
q = tfd.MultivariateNormalDiag(loc=[0., 0.])

In [11]:
tfd.kl_divergence(q, p)

<tf.Tensor: shape=(), dtype=float32, numpy=3.0560925>

### KL Divergence as an objective function

In [12]:
q = tfd.MultivariateNormalDiag(
    loc=tf.Variable(tf.random.normal([2])), # isotropic gaussian
    scale_diag=tfp.util.TransformedVariable(
        tf.random.uniform([2]), bijector=tfb.Exp()
    )
)

In [13]:
tfd.kl_divergence(q, p)

<tf.Tensor: shape=(), dtype=float32, numpy=4.878487>

In [14]:
@tf.function
def loss_and_grads(q_dist):
    with tf.GradientTape() as tape:
        loss = tfd.kl_divergence(q_dist, p)
    
    return loss, tape.gradient(loss, q_dist.trainable_variables)

In [16]:
opt = tf.keras.optimizers.Adam()
for i in range(10):
    loss, grads = loss_and_grads(q)
    print(i, loss)
    opt.apply_gradients(zip(grads, q.trainable_variables))

0 tf.Tensor(4.878487, shape=(), dtype=float32)
1 tf.Tensor(4.8662114, shape=(), dtype=float32)
2 tf.Tensor(4.8539586, shape=(), dtype=float32)
3 tf.Tensor(4.8417306, shape=(), dtype=float32)
4 tf.Tensor(4.8295264, shape=(), dtype=float32)
5 tf.Tensor(4.817347, shape=(), dtype=float32)
6 tf.Tensor(4.805195, shape=(), dtype=float32)
7 tf.Tensor(4.7930675, shape=(), dtype=float32)
8 tf.Tensor(4.780965, shape=(), dtype=float32)
9 tf.Tensor(4.7688894, shape=(), dtype=float32)


In [17]:
# target density is a full covariance gausian dist
# diagonal distribution

### ELBO

In [20]:
latent_size = 2
event_shape = (28, 28, 1)

encoder = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(8, (5, 5), strides=2, activation="tanh", input_shape=event_shape),
    tf.keras.layers.Conv2D(8, (5, 5), strides=2, activation="tanh"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation="tanh"),
    tf.keras.layers.Dense(2 * latent_size),
    tfp.layers.DistributionLambda(
        lambda t: tfd.MultivariateNormalDiag(
            loc=t[..., :latent_size],
            scale_diag=tf.math.exp(t[..., latent_size:])
        )
    )
], name="encoder")

Instructions for updating:
Do not pass `graph_parents`.  They will  no longer be used.


In [24]:
decoder = tf.keras.models.Sequential([
    tf.keras.layers.Dense(64, activation="tanh", input_shape=(latent_size,)),
    tf.keras.layers.Dense(128, activation="tanh"),
    tf.keras.layers.Reshape((4, 4, 8)),
    tf.keras.layers.Conv2DTranspose(8, (5, 5), strides=2, output_padding=1, activation="tanh"),
    tf.keras.layers.Conv2DTranspose(8, (5, 5), strides=2, output_padding=1, activation="tanh"),
    tf.keras.layers.Conv2D(1, (3, 3), padding="SAME"),
    tf.keras.layers.Flatten(),
    tfp.layers.IndependentBernoulli(
        event_shape
    )
], name="encoder")

In [25]:
decoder(tf.random.normal([16, latent_size]))

<tfp.distributions.Independent 'encoder_independent_bernoulli_1_IndependentBernoulli_Independentencoder_independent_bernoulli_1_IndependentBernoulli_Bernoulli' batch_shape=[16] event_shape=[28, 28, 1] dtype=float32>

In [26]:
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))
prior

<tfp.distributions.MultivariateNormalDiag 'MultivariateNormalDiag' batch_shape=[] event_shape=[2] dtype=float32>

In [27]:
def loss_fn(x_true, approx_posterior, x_pred, prior_dist):
    return tf.reduce_mean(
        tfd.kl_divergence(
            approx_posterior, prior_dist
        ) - x_pred.log_prob(x_true)
    )

In [28]:
def loss_fn(x_true, approx_posterior, x_pred, prior_dist):
    reconstruction_loss = -x_pred.log_prob(x_true)
    approx_posterior_sample = approx_posterior.sample()
    kl_approx = (approx_posterior.log_prob(approx_posterior_sample) 
        - prior_dist.log_prob(approx_posterior_sample))
    return tf.reduce_mean(
        kl_approx + reconstruction_loss
    )

In [29]:
@tf.function
def get_loss_and_grads(x):
    with tf.GradientTape() as tape:
        approx_posterior = encoder(x)
        approx_posterior_sample = approx_posterior.sample()
        x_pred = decoder(approx_posterior_sample)
        current_loss = loss_fn(x, approx_posterior, x_pred, prior)
        
    grads = tape.gradient(current_loss, encoder.trainable_variables + decoder.trainable_variables)
    
    return current_loss, grads

### KL Divergence Layers

In [37]:
latent_Size = 4
prior = tfd.MultivariateNormalDiag(loc=tf.zeros(latent_size))

encoder = tf.keras.models.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(12,)),
    tf.keras.layers.Dense(tfp.layers.MultivariateNormalTriL.params_size(latent_size)),
    tfp.layers.MultivariateNormalTriL(latent_size),
    tfp.layers.KLDivergenceAddLoss(prior, weight=10, use_exact_kl=False)
])

In [38]:
decoder = tf.keras.models.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(latent_size,)),
    tf.keras.layers.Dense(tfp.layers.IndependentNormal.params_size(12)),
    tfp.layers.IndependentNormal(12)
])

vae = tf.keras.models.Model(inputs=encoder.input, outputs=decoder(encoder.input))



ValueError: Input 0 of layer dense_21 is incompatible with the layer: expected axis -1 of input shape to have value 2 but received input with shape (None, 12)