Implementation of jax.scipy.stats.gaussian_kde
dfm committed Jun 28, 2022
1 parent 6835dc1 commit 0788d57
Showing 5 changed files with 489 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Changes
  * `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These
    classes have been deprecated since v0.3.1 ({jax-issue}`#11248`).
  * Added {class}`jax.scipy.stats.gaussian_kde` ({jax-issue}`#11237`).

## jaxlib 0.3.15 (Unreleased)

15 changes: 15 additions & 0 deletions docs/jax.scipy.rst
@@ -319,3 +319,18 @@ jax.scipy.stats.uniform

   logpdf
   pdf

jax.scipy.stats.gaussian_kde
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. currentmodule:: jax.scipy.stats
.. autosummary::
   :toctree: _autosummary

   gaussian_kde
   gaussian_kde.evaluate
   gaussian_kde.integrate_gaussian
   gaussian_kde.integrate_box_1d
   gaussian_kde.integrate_kde
   gaussian_kde.resample
   gaussian_kde.pdf
   gaussian_kde.logpdf
270 changes: 270 additions & 0 deletions jax/_src/scipy/stats/kde.py
@@ -0,0 +1,270 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from functools import partial
from typing import Any

import numpy as np
import scipy.stats as osp_stats

import jax.numpy as jnp
from jax import jit, lax, random, vmap
from jax._src.numpy.lax_numpy import _check_arraylike, _promote_dtypes_inexact
from jax._src.numpy.util import _wraps
from jax._src.tree_util import register_pytree_node_class
from jax.scipy import linalg, special


@_wraps(osp_stats.gaussian_kde, update_doc=False)
@register_pytree_node_class
@dataclass(frozen=True, init=False)
class gaussian_kde:
  neff: Any
  dataset: Any
  weights: Any
  covariance: Any
  inv_cov: Any

  def __init__(self, dataset, bw_method=None, weights=None):
    _check_arraylike("gaussian_kde", dataset)
    dataset = jnp.atleast_2d(dataset)
    if jnp.issubdtype(lax.dtype(dataset), jnp.complexfloating):
      raise NotImplementedError("gaussian_kde does not support complex data")
    if not dataset.size > 1:
      raise ValueError("`dataset` input should have multiple elements.")

    d, n = dataset.shape
    if weights is not None:
      _check_arraylike("gaussian_kde", weights)
      dataset, weights = _promote_dtypes_inexact(dataset, weights)
      weights = jnp.atleast_1d(weights)
      weights /= jnp.sum(weights)
      if weights.ndim != 1:
        raise ValueError("`weights` input should be one-dimensional.")
      if len(weights) != n:
        raise ValueError("`weights` input should be of length n")
    else:
      dataset, = _promote_dtypes_inexact(dataset)
      weights = jnp.full(n, 1.0 / n, dtype=dataset.dtype)

    self._setattr("dataset", dataset)
    self._setattr("weights", weights)
    neff = self._setattr("neff", 1 / jnp.sum(weights**2))

    bw_method = "scott" if bw_method is None else bw_method
    if bw_method == "scott":
      factor = jnp.power(neff, -1. / (d + 4))
    elif bw_method == "silverman":
      factor = jnp.power(neff * (d + 2) / 4.0, -1. / (d + 4))
    elif jnp.isscalar(bw_method) and not isinstance(bw_method, str):
      factor = bw_method
    elif callable(bw_method):
      factor = bw_method(self)
    else:
      raise ValueError(
          "`bw_method` should be 'scott', 'silverman', a scalar, or a callable."
      )

    data_covariance = jnp.atleast_2d(
        jnp.cov(dataset, rowvar=1, bias=False, aweights=weights))
    data_inv_cov = jnp.linalg.inv(data_covariance)
    covariance = data_covariance * factor**2
    inv_cov = data_inv_cov / factor**2
    self._setattr("covariance", covariance)
    self._setattr("inv_cov", inv_cov)

  def _setattr(self, name, value):
    # Frozen dataclasses don't support setting attributes so we have to
    # overload that operation here as they do in the dataclass implementation
    object.__setattr__(self, name, value)
    return value

  def tree_flatten(self):
    return ((self.neff, self.dataset, self.weights, self.covariance,
             self.inv_cov), None)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    del aux_data
    kde = cls.__new__(cls)
    kde._setattr("neff", children[0])
    kde._setattr("dataset", children[1])
    kde._setattr("weights", children[2])
    kde._setattr("covariance", children[3])
    kde._setattr("inv_cov", children[4])
    return kde

  @property
  def d(self):
    return self.dataset.shape[0]

  @property
  def n(self):
    return self.dataset.shape[1]

  @_wraps(osp_stats.gaussian_kde.evaluate, update_doc=False)
  def evaluate(self, points):
    _check_arraylike("evaluate", points)
    points = self._reshape_points(points)
    result = _gaussian_kernel_eval(False, self.dataset.T, self.weights[:, None],
                                   points.T, self.inv_cov)
    return result[:, 0]

  @_wraps(osp_stats.gaussian_kde.__call__, update_doc=False)
  def __call__(self, points):
    return self.evaluate(points)

  @_wraps(osp_stats.gaussian_kde.integrate_gaussian, update_doc=False)
  def integrate_gaussian(self, mean, cov):
    mean = jnp.atleast_1d(jnp.squeeze(mean))
    cov = jnp.atleast_2d(cov)

    if mean.shape != (self.d,):
      raise ValueError("mean does not have dimension {}".format(self.d))
    if cov.shape != (self.d, self.d):
      raise ValueError("covariance does not have dimension {}".format(self.d))

    chol = linalg.cho_factor(self.covariance + cov)
    norm = jnp.sqrt(2 * np.pi)**self.d * jnp.prod(jnp.diag(chol[0]))
    norm = 1.0 / norm
    return _gaussian_kernel_convolve(chol, norm, self.dataset, self.weights,
                                     mean)

  @_wraps(osp_stats.gaussian_kde.integrate_box_1d, update_doc=False)
  def integrate_box_1d(self, low, high):
    if self.d != 1:
      raise ValueError("integrate_box_1d() only handles 1D pdfs")
    if jnp.ndim(low) != 0 or jnp.ndim(high) != 0:
      raise ValueError(
          "the limits of integration in integrate_box_1d must be scalars")
    sigma = jnp.squeeze(jnp.sqrt(self.covariance))
    low = jnp.squeeze((low - self.dataset) / sigma)
    high = jnp.squeeze((high - self.dataset) / sigma)
    return jnp.sum(self.weights * (special.ndtr(high) - special.ndtr(low)))

  @_wraps(osp_stats.gaussian_kde.integrate_kde, update_doc=False)
  def integrate_kde(self, other):
    if other.d != self.d:
      raise ValueError("KDEs are not the same dimensionality")

    chol = linalg.cho_factor(self.covariance + other.covariance)
    norm = jnp.sqrt(2 * np.pi)**self.d * jnp.prod(jnp.diag(chol[0]))
    norm = 1.0 / norm

    sm, lg = (self, other) if self.n < other.n else (other, self)
    result = vmap(partial(_gaussian_kernel_convolve, chol, norm, lg.dataset,
                          lg.weights),
                  in_axes=1)(sm.dataset)
    return jnp.sum(result * sm.weights)

  def resample(self, key, shape=()):
    r"""Randomly sample a dataset from the estimated pdf

    Args:
      key: a PRNG key used as the random key.
      shape: optional, a tuple of nonnegative integers specifying the result
        batch shape; that is, the prefix of the result shape excluding the last
        axis.

    Returns:
      The resampled dataset as an array with shape `(d,) + shape`.
    """
    ind_key, eps_key = random.split(key)
    ind = random.choice(ind_key, self.n, shape=shape, p=self.weights)
    eps = random.multivariate_normal(eps_key,
                                     jnp.zeros(self.d, self.covariance.dtype),
                                     self.covariance,
                                     shape=shape,
                                     dtype=self.dataset.dtype).T
    return self.dataset[:, ind] + eps

  @_wraps(osp_stats.gaussian_kde.pdf, update_doc=False)
  def pdf(self, x):
    return self.evaluate(x)

  @_wraps(osp_stats.gaussian_kde.logpdf, update_doc=False)
  def logpdf(self, x):
    _check_arraylike("logpdf", x)
    x = self._reshape_points(x)
    result = _gaussian_kernel_eval(True, self.dataset.T, self.weights[:, None],
                                   x.T, self.inv_cov)
    return result[:, 0]

  def integrate_box(self, low_bounds, high_bounds, maxpts=None):
    """This method is not implemented in the JAX interface."""
    del low_bounds, high_bounds, maxpts
    raise NotImplementedError(
        "only 1D box integrations are supported; use `integrate_box_1d`")

  def set_bandwidth(self, bw_method=None):
    """This method is not implemented in the JAX interface."""
    del bw_method
    raise NotImplementedError(
        "dynamically changing the bandwidth method is not supported")

  def _reshape_points(self, points):
    if jnp.issubdtype(lax.dtype(points), jnp.complexfloating):
      raise NotImplementedError(
          "gaussian_kde does not support complex coordinates")
    points = jnp.atleast_2d(points)
    d, m = points.shape
    if d != self.d:
      if d == 1 and m == self.d:
        points = jnp.reshape(points, (self.d, 1))
      else:
        raise ValueError(
            "points have dimension {}, dataset has dimension {}".format(
                d, self.d))
    return points


def _gaussian_kernel_convolve(chol, norm, target, weights, mean):
  diff = target - mean[:, None]
  alpha = linalg.cho_solve(chol, diff)
  arg = 0.5 * jnp.sum(diff * alpha, axis=0)
  return norm * jnp.sum(jnp.exp(-arg) * weights)


@partial(jit, static_argnums=0)
def _gaussian_kernel_eval(in_log, points, values, xi, precision):
  points, values, xi, precision = _promote_dtypes_inexact(
      points, values, xi, precision)
  d = points.shape[1]

  if xi.shape[1] != d:
    raise ValueError("points and xi must have same trailing dim")
  if precision.shape != (d, d):
    raise ValueError("precision matrix must match data dims")

  whitening = linalg.cholesky(precision, lower=True)
  points = jnp.dot(points, whitening)
  xi = jnp.dot(xi, whitening)
  log_norm = jnp.sum(jnp.log(
      jnp.diag(whitening))) - 0.5 * d * jnp.log(2 * np.pi)

  def kernel(x_test, x_train, y_train):
    arg = log_norm - 0.5 * jnp.sum(jnp.square(x_train - x_test))
    if in_log:
      return jnp.log(y_train) + arg
    else:
      return y_train * jnp.exp(arg)

  reduce = special.logsumexp if in_log else jnp.sum
  reduced_kernel = lambda x: reduce(vmap(kernel, in_axes=(None, 0, 0))
                                    (x, points, values),
                                    axis=0)
  mapped_kernel = vmap(reduced_kernel)

  return mapped_kernel(xi)
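
For context, a minimal usage sketch of the class implemented above (not part of the commit); the dataset, query points, and variable names are made up for illustration:

import jax.numpy as jnp
from jax import random
from jax.scipy.stats import gaussian_kde

# Illustrative data: 100 two-dimensional samples, stored as (d, n) like scipy.
key = random.PRNGKey(0)
dataset = random.normal(key, (2, 100))

kde = gaussian_kde(dataset)         # Scott's rule bandwidth by default
points = jnp.zeros((2, 3))          # three query points of dimension d = 2
density = kde.evaluate(points)      # same result as kde(points) or kde.pdf(points)
log_density = kde.logpdf(points)
samples = kde.resample(random.PRNGKey(1), shape=(5,))  # array of shape (2, 5)
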
1 change: 1 addition & 0 deletions jax/scipy/stats/__init__.py
@@ -31,3 +31,4 @@
from jax.scipy.stats import chi2 as chi2
from jax.scipy.stats import betabinom as betabinom
from jax.scipy.stats import gennorm as gennorm
from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
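
Because `gaussian_kde` is decorated with `register_pytree_node_class` and defines `tree_flatten`/`tree_unflatten`, an instance can be passed as an argument through JAX transformations. A small illustrative sketch (the function and data here are hypothetical, not part of the commit):

import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import gaussian_kde

kde = gaussian_kde(random.normal(random.PRNGKey(0), (1, 50)))

# The kde instance is flattened into its array leaves and traced like any pytree.
@jax.jit
def total_log_density(kde, x):
  return jnp.sum(kde.logpdf(x))

total_log_density(kde, jnp.linspace(-3.0, 3.0, 10))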
