Defining KL divergence analytically and numerically and checking that results are approximately equal


In [69]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from scipy.stats import multivariate_normal

def build_covariance_matrix(cov_params, dim, dtype=tf.float32):
    """Build a covariance matrix from flattened triangular-factor parameters.

    The parameters are packed into a triangular matrix ``L`` with
    ``tfp.math.fill_triangular`` and the result is ``L @ L^T``, which is
    symmetric positive semi-definite by construction.

    Args:
        cov_params: Flat tensor of triangular-factor entries, in the order
            expected by ``tfp.math.fill_triangular``.
        dim: Accepted for interface compatibility; not used by this function.
        dtype: Accepted for interface compatibility; not used by this function
            (the result keeps the dtype of ``cov_params``).

    Returns:
        The matrix product of the triangular factor with its own transpose.
    """
    triangular_factor = tfp.math.fill_triangular(cov_params)
    # L @ L^T: the factor multiplied by its own transpose.
    return tf.matmul(triangular_factor, triangular_factor, transpose_b=True)

def kl_divergence(mean1, cov1, mean2, cov2):
    """Closed-form KL(N(mean1, cov1) || N(mean2, cov2)) for multivariate Gaussians.

    Uses the standard identity

        KL = 0.5 * ( tr(S2^-1 S1) + (m2 - m1)^T S2^-1 (m2 - m1) - d
                     + log det S2 - log det S1 )

    Args:
        mean1: Mean of the first Gaussian, shape ``(..., d)``.
        cov1: Covariance of the first Gaussian, shape ``(..., d, d)``.
        mean2: Mean of the second Gaussian, shape ``(..., d)``.
        cov2: Covariance of the second Gaussian, shape ``(..., d, d)``.

    Returns:
        A scalar tensor: the KL divergence, averaged over any leading batch
        dimensions (for unbatched inputs this is just the KL value itself).
    """
    inv_cov2 = tf.linalg.inv(cov2)
    trace_term = tf.linalg.trace(tf.matmul(inv_cov2, cov1))

    diff_mean = mean2 - mean1
    # Quadratic form diff^T S2^-1 diff. The matrix-vector product keeps the
    # shapes aligned as (..., d) * (..., d); multiplying a (1, d) row by a
    # (d, 1) column instead would broadcast to (d, d) and give a wrong value
    # whenever the means differ.
    mahalanobis_term = tf.reduce_sum(
        diff_mean * tf.linalg.matvec(inv_cov2, diff_mean), axis=-1
    )

    log_det_cov1 = tf.linalg.logdet(cov1)
    log_det_cov2 = tf.linalg.logdet(cov2)

    # Take the dtype from the computed terms rather than from a module-level
    # `dtype` variable, so the function does not depend on notebook globals.
    n_dims = tf.cast(tf.shape(mean1)[-1], trace_term.dtype)

    kl = 0.5 * (trace_term + mahalanobis_term - n_dims + log_det_cov2 - log_det_cov1)
    return tf.reduce_mean(kl)

# Example problem: two 2-D Gaussians with identical (zero) means that differ
# only in their covariance matrices.
dtype = tf.float32
dim = 2

mean1 = tf.constant([0.0, 0.0], dtype=dtype)
mean2 = mean1  # both distributions share the same mean tensor

# Flattened triangular-factor entries, consumed by build_covariance_matrix.
cov_params1 = tf.constant([1, 0.5, 1], dtype=dtype)
cov_params2 = tf.constant([2, 0.5, 2], dtype=dtype)

cov1, cov2 = (
    build_covariance_matrix(params, dim) for params in (cov_params1, cov_params2)
)


In [70]:
# Time the closed-form KL computation on the example Gaussians defined above.
%time kl_divergence(mean1, cov1, mean2, cov2)


CPU times: user 6.57 ms, sys: 437 μs, total: 7 ms
Wall time: 4.83 ms


<tf.Tensor: shape=(), dtype=float32, numpy=0.64410686>

In [85]:
def kl_divergence_numerical(mu1, sigma1, mu2, sigma2, num_samples=10000,
                            grid_size=1000, grid_limit=10.0):
    """Estimate KL(N(mu1, sigma1) || N(mu2, sigma2)) by a Riemann sum on a 2-D grid.

    The integrand ``p(x) * (log p(x) - log q(x))`` is evaluated on a regular
    ``grid_size x grid_size`` grid covering ``[-grid_limit, grid_limit]^2`` and
    summed, weighted by the area of one grid cell. Only 2-D Gaussians are
    supported, and both densities must have negligible mass outside the grid
    for the estimate to be accurate.

    Args:
        mu1: Mean of the first Gaussian (length-2 array-like).
        sigma1: Covariance of the first Gaussian (2x2 array-like).
        mu2: Mean of the second Gaussian (length-2 array-like).
        sigma2: Covariance of the second Gaussian (2x2 array-like).
        num_samples: Unused; kept so existing callers keep working. The
            estimate is grid-based, not sampling-based.
        grid_size: Number of grid points along each axis.
        grid_limit: Half-width of the square integration domain.

    Returns:
        The numerical estimate of the KL divergence as a NumPy float.
    """
    mu1 = np.asarray(mu1, dtype=float)
    mu2 = np.asarray(mu2, dtype=float)
    sigma1 = np.asarray(sigma1, dtype=float)
    sigma2 = np.asarray(sigma2, dtype=float)

    # Regular grid used for integrating the densities (same axis in x and y).
    axis = np.linspace(-grid_limit, grid_limit, grid_size)
    cell_area = (axis[1] - axis[0]) ** 2
    grid_X, grid_Y = np.meshgrid(axis, axis)
    positions = np.column_stack([grid_X.ravel(), grid_Y.ravel()])

    # Work in log space: log p - log q stays finite even where the densities
    # underflow to 0, whereas p * log(p / q) would give 0 * log(0) = NaN there.
    log_p = multivariate_normal.logpdf(positions, mean=mu1, cov=sigma1)
    log_q = multivariate_normal.logpdf(positions, mean=mu2, cov=sigma2)
    p_x = np.exp(log_p)

    kl_div = np.sum(p_x * (log_p - log_q)) * cell_area

    return kl_div

# Cross-check on the same example inputs: the grid-based estimate should be
# approximately equal to the closed-form value computed earlier.
kl_divergence_numerical(mean1, cov1, mean2, cov2)


0.6441068611198906

In [2]:
import numpy as np

# Example target for a batch of size 32
SEED = 0
rng = np.random.default_rng(SEED)  # seeded so Restart & Run All reproduces the same data

batch_size = 32
mean_dim = 2
cholesky_dim = 2

# Random example data, uniform on [0, 1)
true_mean = rng.random((batch_size, mean_dim))
# NOTE(review): these are dense random matrices, not lower-triangular factors
# with a positive diagonal -- confirm whether downstream code expects real
# Cholesky factors here.
true_cholesky = rng.random((batch_size, cholesky_dim, cholesky_dim))

# Flatten each matrix (row-major) and append it to its mean, giving an array of
# shape (batch_size, mean_dim + cholesky_dim ** 2).
true_cholesky_flat = true_cholesky.reshape(batch_size, -1)
y_true = np.concatenate([true_mean, true_cholesky_flat], axis=-1)
