In [3]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 
import tensorflow as tf
import pickle
os.sys.path.append('../../evaluation/')
import metrics

# Load the validation targets and convert them to per-sample Gaussian statistics.
input_path = '../../data/simulated_2d_demo/'
dtype = tf.float32

# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# only point this at files from a trusted source.
target_path = os.path.join(input_path, 'val', 'target_data.pkl')
with open(target_path, 'rb') as target_file:  # context manager closes the handle
    y_val = pickle.load(target_file)

# Each raw target is mapped through metrics.output_to_stats_2d; element [0] is
# used as the mean and element [1] as the covariance below
# (presumably a (mean, covariance) pair -- confirm against the metrics module).
y_val = [metrics.output_to_stats_2d(target) for target in y_val]
y_mean = tf.convert_to_tensor([stats[0] for stats in y_val], dtype=dtype)
y_cov = tf.convert_to_tensor([stats[1] for stats in y_val], dtype=dtype)

def kl_divergence(mean1, cov1, mean2, cov2, verbose=True):
    """Closed-form KL divergence KL(N(mean1, cov1) || N(mean2, cov2)).

    Uses the standard multivariate-Gaussian identity

        KL = 0.5 * (tr(cov2^-1 cov1) + (mean2 - mean1)^T cov2^-1 (mean2 - mean1)
                    - d + log det(cov2) - log det(cov1))

    evaluated independently for every leading (batch) index.

    Args:
        mean1: Tensor of shape (..., d), mean of the first Gaussian.
        cov1: Tensor of shape (..., d, d), covariance of the first Gaussian.
        mean2: Tensor of shape (..., d), mean of the second Gaussian.
        cov2: Tensor of shape (..., d, d), covariance of the second Gaussian.
            It is inverted, so it must be non-singular.
        verbose: When True (default, matching the original behavior) print the
            shapes of the intermediate terms as a debugging aid.

    Returns:
        Tensor of shape (...) with one KL value per batch element. No
        reduction over the batch is performed.
    """
    mean1 = tf.convert_to_tensor(mean1)

    # Cast the dimensionality to the input dtype rather than a hard-coded
    # float32, so float64 inputs do not hit a dtype mismatch in the sum below.
    num_features = tf.cast(tf.shape(mean1)[-1], mean1.dtype)

    # Compute the inverse of cov2 (shared by the trace and Mahalanobis terms)
    inv_cov2 = tf.linalg.inv(cov2)

    # Trace term: trace(inv_cov2 @ cov1)
    trace_term = tf.linalg.trace(tf.linalg.matmul(inv_cov2, cov1))

    # Mahalanobis term: diff^T inv_cov2 diff, kept with a trailing axis of size 1
    diff_mean_expanded = tf.expand_dims(mean2 - mean1, axis=-1)
    mahalanobis_term = tf.reduce_sum(
        tf.linalg.matmul(inv_cov2, diff_mean_expanded) * diff_mean_expanded,
        axis=-2,
    )

    # Log-determinants of both covariances
    log_det_cov1 = tf.linalg.logdet(cov1)
    log_det_cov2 = tf.linalg.logdet(cov2)

    if verbose:
        print(trace_term.shape)
        print(mahalanobis_term.shape)
        print(log_det_cov2.shape)
        print(log_det_cov1.shape)

    # Squeeze only the trailing singleton axis so that batch dimensions of
    # size 1 are preserved.
    kl = 0.5 * (
        trace_term
        + tf.squeeze(mahalanobis_term, axis=-1)
        - num_features
        + log_det_cov2
        - log_det_cov1
    )

    # Per-sample KL divergence (not averaged over the batch)
    return kl

# Sanity check: same covariance for both Gaussians, second mean shifted by +1.
kl_loss = kl_divergence(
    mean1=y_mean,
    cov1=y_cov,
    mean2=y_mean + 1,
    cov2=y_cov,
)


(5,)
(5, 1)
(5,)
(5,)


In [4]:
# Show the KL values stored in kl_loss (bare last expression renders the tensor).
kl_loss


<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([2.1675751, 4.9288616, 1.5267398, 2.1787164, 1.3317399],
      dtype=float32)>

In [8]:
import numpy as np
from scipy.stats import multivariate_normal
def kl_divergence_numerical(mu1, sigma1, mu2, sigma2, num_samples=10000,
                            grid_min=-10.0, grid_max=10.0, grid_size=1000):
    """Numerically integrate KL(N(mu1, sigma1) || N(mu2, sigma2)) on a 2-D grid.

    The integral of p(x) * log(p(x) / q(x)) is approximated by a Riemann sum
    over a regular square grid spanning [grid_min, grid_max] on both axes.

    Args:
        mu1: Length-2 mean of the first Gaussian (p).
        sigma1: 2x2 covariance of the first Gaussian.
        mu2: Length-2 mean of the second Gaussian (q).
        sigma2: 2x2 covariance of the second Gaussian.
        num_samples: Unused; kept only so existing callers keep working.
        grid_min: Lower bound of the integration grid on both axes.
        grid_max: Upper bound of the integration grid on both axes.
        grid_size: Number of grid points per axis (must be at least 2).

    Returns:
        The approximate KL divergence as a NumPy float. The result is only
        accurate when essentially all of p's mass lies inside the grid.

    Raises:
        ValueError: If grid_size is smaller than 2 (no cell width is defined).
    """
    if grid_size < 2:
        raise ValueError("grid_size must be at least 2")

    # define grid for integrating densities
    grid_x = np.linspace(grid_min, grid_max, grid_size)
    grid_y = np.linspace(grid_min, grid_max, grid_size)
    grid_X, grid_Y = np.meshgrid(grid_x, grid_y)
    positions = np.column_stack([grid_X.ravel(), grid_Y.ravel()])

    # Evaluate log-densities instead of densities: far from the means the pdf
    # underflows to 0, and log(p / q) would then be log(0 / 0) = NaN. In log
    # space those points contribute exp(log_p) * finite = 0, as they should.
    log_p = multivariate_normal.logpdf(positions, mean=mu1, cov=sigma1)
    log_q = multivariate_normal.logpdf(positions, mean=mu2, cov=sigma2)

    # compute the KL divergence (Riemann sum times the area of one grid cell)
    integrand = np.exp(log_p) * (log_p - log_q)
    kl_div = np.sum(integrand) * (grid_x[1] - grid_x[0]) * (grid_y[1] - grid_y[0])

    return kl_div

# Cross-check: grid-integrated KL for the first five samples, using the same
# +1 mean shift and shared covariance as the closed-form call.
for sample_idx in range(5):
    sample_mean = y_mean[sample_idx]
    sample_cov = y_cov[sample_idx]
    numerical_kl = kl_divergence_numerical(sample_mean, sample_cov, sample_mean+1, sample_cov)
    print(numerical_kl)


2.1675750647196237
4.9288617090384115
1.5267398630510578
2.1787163020121043
1.3317399449767944
