In [3]:
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal
from jax.scipy.special import logsumexp
from jax import grad
from utils import old_ksd

import jax

from tqdm import tqdm
import numpy as np

def logpdf(x, mean, cov):
    """Return the Gaussian log-density log N(x | mean, cov).

    ``x`` may be a single point or a stack of points along its leading axis;
    in the latter case one log-density per point is returned.

    Args:
        x: Point(s) at which to evaluate the density.
        mean: Mean vector of the distribution.
        cov: Covariance matrix of the distribution.
    """
    density_log = multivariate_normal.logpdf(x, mean=mean, cov=cov)
    return density_log

def logpdf_grad_func(x, mean, cov):
    """Return the score grad_x log N(x | mean, cov) evaluated at ``x``.

    Only ``x`` is differentiated (``argnums=0``); ``mean`` and ``cov`` are
    passed through unchanged. ``jax.grad`` requires a scalar-valued function,
    so ``x`` must be a single point rather than a batch of points.
    """
    score = grad(logpdf, argnums=0)
    return score(x, mean, cov)


# Target p = N(target_mean, target_cov) and proposal q = N(proposal_mean, proposal_cov).
target_mean = jnp.array([2.0, 2.0])
target_cov = 0.5 * jnp.eye(2)
proposal_mean = jnp.array([1., 1.])
proposal_cov = 0.3 * jnp.eye(2)

# Parameter tuples for each distribution (not used further in this cell).
params_target = (target_mean, target_cov)
params_proposal = (proposal_mean, proposal_cov)

# Score functions x -> grad_x log density(x) with the distribution parameters bound.
# jax.tree_util.Partial (rather than functools.partial) is a pytree, so the bound
# function can be passed as an argument into JAX-transformed functions.
logpdf_grad_func_params_fixed = jax.tree_util.Partial(logpdf_grad_func, mean=target_mean, cov=target_cov)

# Proposal score; not used in this cell.
logpdf_grad_func_params_fixed_proposal = jax.tree_util.Partial(logpdf_grad_func, mean=proposal_mean, cov=proposal_cov)


# Draw samples with NumPy's global RNG; seeding keeps the printed results reproducible.
np.random.seed(0)
num_samples = 10
true_samples = np.random.multivariate_normal(target_mean, target_cov, num_samples)
proposal_samples = np.random.multivariate_normal(proposal_mean, proposal_cov, num_samples)

# Self-normalised importance weights w_i proportional to p(x_i) / q(x_i) for x_i ~ q.
# Computing in log space and subtracting logsumexp avoids overflow/underflow
# and makes the weights sum to 1.
log_weights = logpdf(proposal_samples, mean=target_mean, cov=target_cov) - logpdf(proposal_samples, mean=proposal_mean, cov=proposal_cov)

weights = np.exp(log_weights - logsumexp(log_weights))

# NOTE(review): `ksd` is not imported in this cell (only `old_ksd` is);
# presumably it is defined in an earlier notebook cell - confirm.
# Every call below scores samples against the *target* distribution; the
# importance-weighted calls evaluate the proposal samples reweighted by `weights`.
ksd_true = ksd(true_samples, logpdf_grad_func_params_fixed)
print("KSD using true samples:", ksd_true)

ksd_importance = ksd(proposal_samples, logpdf_grad_func_params_fixed, weights=weights)
print("KSD using importance weighted samples:", ksd_importance)

# Same quantities computed with the reference implementation `old_ksd`; the
# labels differ from the lines above so the two implementations can be told apart.
ksd_true_old = old_ksd(true_samples, logpdf_grad_func_params_fixed)
print("old_ksd using true samples:", ksd_true_old)

ksd_importance_old = old_ksd(proposal_samples, logpdf_grad_func_params_fixed, weights=weights)
print("old_ksd using importance weighted samples:", ksd_importance_old)




KSD using true samples: 0.3673918216793771
KSD using importance weighted samples: 0.43224111555653943
KSD using true samples: 0.3673918216793771
KSD using importance weighted samples: 0.43224111555653943
