In [16]:
import jax.numpy as jnp
from jax import grad, vmap
from jax.scipy.stats import multivariate_normal
from jax.scipy.special import logsumexp
import jax

from tqdm import tqdm
import numpy as np

# Your logpdf function
def logpdf(x, mean, cov):
    return multivariate_normal.logpdf(x, mean=mean, cov=cov)

# Your logpdf_grad_func function which computes gradient at a point given the point and additional parameters
def logpdf_grad_func(x, mean, cov):
    return grad(logpdf, argnums=0)(x, mean, cov)


def k0(x, y, grd, c2=1.0, beta=0.5):
    d = len(x)

    z = x - y
    r2 = jnp.sum(jnp.square(z))
    base = c2 + r2
    base_beta = base ** (-beta)
    base_beta1 = base_beta / base

    gradlogpx, gradlogpy = grd(x), grd(y)

    coeffk = jnp.dot(gradlogpx, gradlogpy)
    coeffgrad = -2.0 * beta * base_beta1

    kterm = coeffk * base_beta
    gradandgradgradterms = coeffgrad * (
        (jnp.dot(gradlogpy, z) - jnp.dot(gradlogpx, z)) +
        (-d + 2 * (beta + 1) * r2 / base)
    )

    return kterm + gradandgradgradterms

# Now, create batch versions of the k0 and ksd functions using vmap
batch_k0 = vmap(vmap(k0, in_axes=(0, None, None, None, None), out_axes=0), in_axes=(None, 0, None, None, None), out_axes=0)

# The ksd function can now use batch_k0 to compute the KSD in a vectorized manner
def ksd(samples, logpdf_grad_func, weights=None, c2=1.0, beta=0.5):
    N = samples.shape[0]
    if weights is None:
        weights = jnp.ones(N) / N

    ksd_matrix = batch_k0(samples, samples, logpdf_grad_func, c2, beta)
    ksd_value = jnp.sqrt(jnp.sum(ksd_matrix * weights[:, None] * weights[None, :]) / N)
    return ksd_value

# Define the target and proposal distribution parameters
target_mean = jnp.array([2.0, 2.0])
target_cov = 0.5 * jnp.eye(2)
proposal_mean = jnp.array([1., 1.])
proposal_cov = 0.3 * jnp.eye(2)

params_target = (target_mean, target_cov)
params_proposal = (proposal_mean, proposal_cov)

# Partially apply logpdf_grad_func to create a new function with fixed params
logpdf_grad_func_params_fixed = jax.tree_util.Partial(logpdf_grad_func, mean=target_mean, cov=target_cov)

logpdf_grad_func_params_fixed_proposal = jax.tree_util.Partial(logpdf_grad_func, mean=proposal_mean, cov=proposal_cov)


# Generate samples from the target and proposal distributions
np.random.seed(0)
num_samples = 10000
true_samples = np.random.multivariate_normal(target_mean, target_cov, num_samples)
proposal_samples = np.random.multivariate_normal(proposal_mean, proposal_cov, num_samples)

# Calculate the importance weights
log_weights = logpdf(proposal_samples, mean=target_mean, cov=target_cov) - logpdf(proposal_samples, mean=proposal_mean, cov=proposal_cov)

weights = np.exp(log_weights - logsumexp(log_weights))

# Calculate and print the KSD for true and importance weighted samples
ksd_true = ksd(true_samples, logpdf_grad_func_params_fixed)
print("KSD using true samples:", ksd_true)

ksd_importance = ksd(proposal_samples, logpdf_grad_func_params_fixed, weights=weights)
print("KSD using importance weighted samples:", ksd_importance)


KSD using true samples: 0.0002471425
KSD using importance weighted samples: 0.0039374777
