In [1]:
import os
import sys
os.chdir('/home/zongchen/mmd_flow_cubature/')
sys.path.append('/home/zongchen/mmd_flow_cubature/')
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

from mmd_flow.distributions import Distribution
from mmd_flow.kernels import gaussian_kernel, laplace_kernel
from mmd_flow.mmd import mmd_fixed_target
from mmd_flow.gradient_flow import gradient_flow
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
bandwidth = 1.0
kernel = gaussian_kernel(bandwidth)
covariances = jnp.load('data/mog_covs.npy')
means = jnp.load('data/mog_means.npy')
# means = jnp.zeros((20, 2))
# covariances = jnp.array([jnp.eye(2) for _ in range(20)])
k = 20
weights = jnp.ones(k) / k
distribution = Distribution(kernel=kernel, means=means, covariances=covariances, integrand_name='neg_exp', weights=weights)
mmd_func = mmd_fixed_target(None, kernel, distribution)


In [10]:
rng_key = jax.random.PRNGKey(0)
d = 10
mean = jax.random.normal(rng_key, shape=(d, ))
# mean = jnp.load('data/mog_means.npy')[0, :]

rng_key, _ = jax.random.split(rng_key)
A = jax.random.normal(rng_key, shape=(d, d))
cov = jnp.dot(A, A.T) + d * jnp.eye(d)  # Ensures positive definiteness
# cov = jnp.load('data/mog_covs.npy')[0, :, :]
kernel = gaussian_kernel(5.0)
rng_key, _ = jax.random.split(rng_key)
Y = jax.random.normal(rng_key, shape=(10, d))

closed_form_kme = kernel.mean_embedding(Y, mean, cov)
print(closed_form_kme)
rng_key, _ = jax.random.split(rng_key)
samples = jax.random.multivariate_normal(rng_key, mean, cov, shape=[100000,]) 
gram_matrix = kernel.make_distance_matrix(Y, samples)
print(gram_matrix.mean(1))

[0.05081233 0.04268435 0.05223575 0.05016228 0.04853699 0.05211426
 0.04489118 0.05417713 0.05106026 0.04841987]
[0.05046493 0.04264479 0.05197224 0.05010664 0.04831337 0.05190297
 0.04474058 0.05393293 0.05089311 0.04820461]


In [6]:
def kme_double_RBF_diff_Gaussian(mu_1, mu_2, Sigma_1, Sigma_2, l):
    """
    Computes the double integral a gaussian kernel with lengthscale l, with two different Gaussians.
    
    Args:
        mu_1, mu_2: (D,) 
        Sigma_1, Sigma_2: (D, D)
        l : scalar

    Returns:
        A scalar: the value of the integral.
    """
    D = mu_1.shape[0]
    l_ = l ** 2
    Lambda = jnp.eye(D) * l_
    sum_ = Sigma_1 + Sigma_2 + Lambda
    part_1 = jnp.sqrt(jnp.linalg.det(Lambda) / jnp.linalg.det(sum_))
    sum_inv = jnp.linalg.inv(sum_)
    # Compute exponent: - (1/2) * mu^T * (Σ1 + Σ2 + Lambda)⁻¹ * Γ⁻¹ * mu
    exp_term = -0.5 * ((mu_1 - mu_2).T @ sum_inv @ (mu_1 - mu_2))
    exp_value = jnp.exp(exp_term)
    result = part_1 * exp_value
    return result


D = 3  # Dimension
mu_1 = jnp.array([1.0, -0.5, 0.3])
mu_2 = jnp.array([0.5, 0.2, -0.1])
Sigma_1 = jnp.array([[1.0, 0.2, 0.1], [0.2, 1.5, 0.3], [0.1, 0.3, 2.0]])
Sigma_2 = jnp.array([[1.2, 0.1, 0.0], [0.1, 1.3, 0.2], [0.0, 0.2, 1.1]])

l = 0.1 # Kernel bandwidth
sample_size = 1000  # Monte Carlo sample sizes
# Compute closed-form solution
closed_form_value = kme_double_RBF_diff_Gaussian(mu_1, mu_2, Sigma_1, Sigma_2, l)
print(closed_form_value)
rng_key = jax.random.PRNGKey(0)
# Generate samples from N(mu, Sigma)
L_1 = jnp.linalg.cholesky(Sigma_1)  # Cholesky decomposition
L_2 = jnp.linalg.cholesky(Sigma_2)  # Cholesky decomposition
rng_key, subkey = jax.random.split(rng_key)
z_1 = jax.random.normal(rng_key, shape=(sample_size, D))  # Standard normal samples
rng_key, subkey = jax.random.split(rng_key)
z_2 = jax.random.normal(rng_key, shape=(sample_size, D))  # Standard normal samples
samples_1 = mu_1 + z_1 @ L_1.T  # Transform to N(mu, Sigma)
samples_2 = mu_2 + z_2 @ L_2.T  # Transform to N(mu, Sigma)
kernel = gaussian_kernel(l)
K = kernel.make_distance_matrix(samples_1, samples_2)
print(K.mean())  # Monte Carlo mean


0.0001892119908346623
0.00019394465995004295


In [7]:
from jax.scipy.special import erf, erfc
from jax.scipy.stats import norm

@jax.jit
def kme_Matern_12_Gaussian_1d(l, y):
    """
    The implementation of the kernel mean embedding of the Matern one half kernel with Gaussian distribution
    Only in one dimension, and the Gaussian distribution is N(0, 1)
    
    Args:
        y: (M, )
        l: scalar

    Returns:
        kernel mean embedding: (M, )
    """
    # part1 = jnp.exp((1 - 2 * l * y) / (2 * l ** 2)) * (1 + erf((-1 + l * y) / (jnp.sqrt(2) * l)))
    # part2 = jnp.exp((1 + 2 * l * y) / (2 * l ** 2)) * erfc((1 / l + y) / jnp.sqrt(2))

    # return (part1 + part2) / 2
    term1 = jnp.exp((1 + 2 * l * (y)) / (2 * l**2)) * norm.cdf((- 1 / l - y))
    term2 = jnp.exp((1 - 2 * l * (y)) / (2 * l**2)) * norm.cdf((y - 1 / l))
    return term1 + term2

@jax.jit
def kme_Matern_12_Gaussian(l, y):
    """
    The implementation of the kernel mean embedding of the Matern one half kernel with Gaussian distribution
    Only in one dimension, and the Gaussian distribution is N(0, 1)
    
    Args:
        y: (M, D)
        l: (D, )

    Returns:
        kernel mean embedding: (M, )
    """
    high_d_map = jax.vmap(kme_Matern_12_Gaussian_1d, in_axes=(0, 0))
    kme_all_d = high_d_map(l, y.T)
    return jnp.prod(kme_all_d, axis=0)

In [8]:
D = 2  # Dimension
# mu_1 = jnp.array([1.0, -0.5, 0.3])
# mu_2 = jnp.array([0.5, 0.2, -0.1])
# Sigma_1 = jnp.array([[1.0, 0.2, 0.1], [0.2, 1.5, 0.3], [0.1, 0.3, 2.0]])
# Sigma_2 = jnp.array([[1.2, 0.1, 0.0], [0.1, 1.3, 0.2], [0.0, 0.2, 1.1]])
mu = jnp.zeros(D)
Sigma = jnp.eye(D)
y = jnp.array([[1.0, -0.5], [0.5, 0.2]])
l = 0.1 # Kernel bandwidth
sample_size = 100000  # Monte Carlo sample sizes
# Compute closed-form solution
closed_form_value = kme_Matern_12_Gaussian(l * jnp.ones(D), y)
print(closed_form_value)
closed_form_value_1 = kme_Matern_12_Gaussian(l * jnp.ones(1), y[:, 0][:, None])
print(closed_form_value_1)
closed_form_value_2 = kme_Matern_12_Gaussian(l * jnp.ones(1), y[:, 1][:, None])
print(closed_form_value_2)


rng_key = jax.random.PRNGKey(0)
# Generate samples from N(mu, Sigma)
L = jnp.linalg.cholesky(Sigma)  # Cholesky decomposition
rng_key, subkey = jax.random.split(rng_key)
z = jax.random.normal(rng_key, shape=(sample_size, D))  # Standard normal samples
samples = mu + z @ L.T  # Transform to N(mu, Sigma)
kernel = laplace_kernel(l)
K = kernel.make_distance_matrix(samples, y)
print(K.mean(0))  # Monte Carlo mean

K1 = kernel.make_distance_matrix(samples[:, 0][:, None], y[:, 0][:, None])
print(K1.mean(0))  # Monte Carlo mean

K2 = kernel.make_distance_matrix(samples[:, 1][:, None], y[:, 1][:, None])
print(K2.mean(0))  # Monte Carlo mean

[0.00338191 0.0054154 ]


[0.04838518 0.06989565]
[0.06989565 0.07747836]
[0.00340957 0.00531105]
[0.04870412 0.06914738]
[0.06971382 0.07698251]


In [9]:
closed_form_value_1 * closed_form_value_2

Array([0.00338191, 0.0054154 ], dtype=float64)

In [10]:
print(kme_Matern_12_Gaussian_1d(l, y[:, 0]))
K1 = kernel.make_distance_matrix(samples[:, 0][:, None], y[:, 0][:, None])
print(K1.mean(0)) 

samples = jax.random.normal(rng_key, shape=(sample_size, 1)) 
K1 = kernel.make_distance_matrix(samples, y[:, 0][:, None])
print(K1.mean(0))

[0.04838518 0.06989565]
[0.04870412 0.06914738]
[0.04903114 0.06921551]


In [11]:
y = jnp.array([[0.0, 0.2]]).T
l = 1.0 # Kernel bandwidth
sample_size = 100000  # Monte Carlo sample sizes
# Compute closed-form solution
closed_form_value = kme_Matern_12_Gaussian_1d(l, y[:, 0])
print(closed_form_value)

z = jax.random.normal(rng_key, shape=(sample_size, 1))  # Standard normal samples
kernel = laplace_kernel(l)
K = kernel.make_distance_matrix(z, y)
print(K.mean(0))  # Monte Carlo mean


[0.52315658 0.51769674]
[0.52215263 0.51703836]


In [12]:
from scipy.spatial.distance import cdist
dists = cdist(z, y, metric='cityblock')  # |x - y| for Laplace kernel
kernel_vals = np.exp(-dists / l)
print(kernel_vals.mean(0))  # Monte Carlo mean

[0.52215263 0.51703836]
