Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
73b860b
feat: change the KSD metric to use the stein kernel in the blobs benc…
gw265981 Mar 18, 2025
6fc91ca
feat: use gaussian kde as score function and use regularisation for S…
gw265981 Mar 18, 2025
22d9143
fix: improve typing for score function
gw265981 Mar 18, 2025
4c0d806
feat: add KDE bandwidth method as a parameter to SteinThinning
gw265981 Mar 18, 2025
a8ee69c
feat: use scipy gaussian_kde for SteinThinning score function in benc…
gw265981 Mar 18, 2025
41fcd48
Merge branch 'main' into bugfix/investigate-stein-thinning
gw265981 Mar 18, 2025
ed405a3
fix: references to external method
gw265981 Mar 18, 2025
16b610e
feat: update benchmarks save paths
gw265981 Mar 19, 2025
254fb63
feat: save frame plots for pounce benchmark
gw265981 Mar 19, 2025
70dc93e
feat: update benchmarking results
gw265981 Mar 19, 2025
6c78ac9
feat: fix UMAP random_state and reduce UMAP n_components for pounce_b…
gw265981 Mar 19, 2025
822df59
feat: update benchmarking results
gw265981 Mar 19, 2025
10000f9
feat: add time results for blobs benchmark
gw265981 Mar 19, 2025
9b4f97f
test: add test for SteinThinning score function in benchmark_util.py
gw265981 Mar 19, 2025
4954ba6
test: update score function test
gw265981 Mar 19, 2025
21faf7e
docs: add clarification about unsupervised benchmarking
gw265981 Mar 20, 2025
a90dfc9
docs: improve alt text
gw265981 Mar 20, 2025
a4b560f
fix: typo in benchmark alt text
gw265981 Mar 20, 2025
0a34877
docs: add full stops
gw265981 Mar 20, 2025
8350ff7
docs: save path clarification for david_benchmark.py
gw265981 Mar 20, 2025
c893b94
docs: add changes to `CHANGELOG.md`
gw265981 Mar 20, 2025
12614e0
Merge remote-tracking branch 'origin/bugfix/investigate-stein-thinnin…
gw265981 Mar 20, 2025
31bfe2e
fix: typo in benchmark docs page
gw265981 Mar 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added Compress++ coreset reduction algorithm.
(https://github.com/gchq/coreax/issues/934)
- Added `reduce_iterative()` method to Kernel Herding. (https://github.
com/gchq/coreax/pull/983)
- Added `reduce_iterative()` method to Kernel Herding. (https://github.com/gchq/coreax/pull/983)
- Added probabilistic iterative Kernel Herding benchmarking results. (https://github.com/gchq/coreax/pull/983)

### Fixed
Expand All @@ -21,7 +20,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Changed the score function used by Stein Thinning in benchmarking.
(https://github.com/gchq/coreax/pull/1000)
- Fixed the random state for UMAP in benchmarking for reproducibility.
(https://github.com/gchq/coreax/pull/1000)
- Reduced the number of dimensions when applying UMAP in `pounce_benchmark.py`.
(https://github.com/gchq/coreax/pull/1000)

### Removed

Expand Down
42 changes: 30 additions & 12 deletions benchmark/blobs_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@
import json
import os
import time
from typing import Union

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
from jaxtyping import Array, Shaped
from sklearn.datasets import make_blobs

from coreax import Data, SlicedScoreMatching
from coreax import Data
from coreax.benchmark_util import IterativeKernelHerding
from coreax.kernels import (
SquaredExponentialKernel,
Expand Down Expand Up @@ -84,17 +87,32 @@ def setup_stein_kernel(
:param random_seed: An integer seed for the random number generator.
:return: A SteinKernel object.
"""
sliced_score_matcher = SlicedScoreMatching(
jax.random.PRNGKey(random_seed),
jax.random.rademacher,
use_analytic=True,
num_random_vectors=100,
learning_rate=0.001,
num_epochs=50,
)
# Fit a Gaussian kernel density estimator on a subset of points for efficiency
num_data_points = len(dataset)
num_samples_length_scale = min(num_data_points, 1000)
generator = np.random.default_rng(random_seed)
idx = generator.choice(num_data_points, num_samples_length_scale, replace=False)
kde = jsp.stats.gaussian_kde(dataset.data[idx].T)

# Define the score function as the gradient of log density given by the KDE
def score_function(
x: Union[Shaped[Array, " n d"], Shaped[Array, ""], float, int],
) -> Union[Shaped[Array, " n d"], Shaped[Array, " 1 1"]]:
"""
Compute the score function (gradient of log density) for a single point.

:param x: Input point represented as array
:return: Gradient of log probability density at the given point
"""

def logpdf_single(x: Shaped[Array, " d"]) -> Shaped[Array, ""]:
return kde.logpdf(x.reshape(1, -1))[0]

return jax.grad(logpdf_single)(x)

return SteinKernel(
base_kernel=sq_exp_kernel,
score_function=sliced_score_matcher.match(jnp.asarray(dataset.data)),
score_function=score_function,
)


Expand Down Expand Up @@ -142,7 +160,7 @@ def setup_solvers(
SteinThinning(
coreset_size=coreset_size,
kernel=stein_kernel,
regularise=False,
regularise=True,
),
),
(
Expand Down Expand Up @@ -308,7 +326,7 @@ def main() -> None: # pylint: disable=too-many-locals

# Set up metrics
mmd_metric = MMD(kernel=sq_exp_kernel)
ksd_metric = KSD(kernel=sq_exp_kernel)
ksd_metric = KSD(kernel=stein_kernel) # KSD needs a Stein kernel

# Set up weights optimiser
weights_optimiser = MMDWeightsOptimiser(kernel=sq_exp_kernel)
Expand Down
Loading
Loading