diff --git a/CHANGELOG.md b/CHANGELOG.md index f120e7a82..1fe2b3902 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,8 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added Compress++ coreset reduction algorithm. (https://github.com/gchq/coreax/issues/934) -- Added `reduce_iterative()` method to Kernel Herding. (https://github. - com/gchq/coreax/pull/983) +- Added `reduce_iterative()` method to Kernel Herding. (https://github.com/gchq/coreax/pull/983) - Added probabilistic iterative Kernel Herding benchmarking results. (https://github.com/gchq/coreax/pull/983) ### Fixed @@ -21,7 +20,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- +- Changed the score function used by Stein Thinning in benchmarking. + (https://github.com/gchq/coreax/pull/1000) +- Fixed the random state for UMAP in benchmarking for reproducibility. + (https://github.com/gchq/coreax/pull/1000) +- Reduced the number of dimensions when applying UMAP in `pounce_benchmark.py`. + (https://github.com/gchq/coreax/pull/1000) ### Removed diff --git a/benchmark/blobs_benchmark.py b/benchmark/blobs_benchmark.py index 1566983b3..b317c1965 100644 --- a/benchmark/blobs_benchmark.py +++ b/benchmark/blobs_benchmark.py @@ -31,13 +31,16 @@ import json import os import time +from typing import Union import jax import jax.numpy as jnp +import jax.scipy as jsp import numpy as np +from jaxtyping import Array, Shaped from sklearn.datasets import make_blobs -from coreax import Data, SlicedScoreMatching +from coreax import Data from coreax.benchmark_util import IterativeKernelHerding from coreax.kernels import ( SquaredExponentialKernel, @@ -84,17 +87,32 @@ def setup_stein_kernel( :param random_seed: An integer seed for the random number generator. :return: A SteinKernel object. """ - sliced_score_matcher = SlicedScoreMatching( - jax.random.PRNGKey(random_seed), - jax.random.rademacher, - use_analytic=True, - num_random_vectors=100, - learning_rate=0.001, - num_epochs=50, - ) + # Fit a Gaussian kernel density estimator on a subset of points for efficiency + num_data_points = len(dataset) + num_samples_length_scale = min(num_data_points, 1000) + generator = np.random.default_rng(random_seed) + idx = generator.choice(num_data_points, num_samples_length_scale, replace=False) + kde = jsp.stats.gaussian_kde(dataset.data[idx].T) + + # Define the score function as the gradient of log density given by the KDE + def score_function( + x: Union[Shaped[Array, " n d"], Shaped[Array, ""], float, int], + ) -> Union[Shaped[Array, " n d"], Shaped[Array, " 1 1"]]: + """ + Compute the score function (gradient of log density) for a single point. + + :param x: Input point represented as array + :return: Gradient of log probability density at the given point + """ + + def logpdf_single(x: Shaped[Array, " d"]) -> Shaped[Array, ""]: + return kde.logpdf(x.reshape(1, -1))[0] + + return jax.grad(logpdf_single)(x) + return SteinKernel( base_kernel=sq_exp_kernel, - score_function=sliced_score_matcher.match(jnp.asarray(dataset.data)), + score_function=score_function, ) @@ -142,7 +160,7 @@ def setup_solvers( SteinThinning( coreset_size=coreset_size, kernel=stein_kernel, - regularise=False, + regularise=True, ), ), ( @@ -308,7 +326,7 @@ def main() -> None: # pylint: disable=too-many-locals # Set up metrics mmd_metric = MMD(kernel=sq_exp_kernel) - ksd_metric = KSD(kernel=sq_exp_kernel) + ksd_metric = KSD(kernel=stein_kernel) # KSD needs a Stein kernel # Set up weights optimiser weights_optimiser = MMDWeightsOptimiser(kernel=sq_exp_kernel) diff --git a/benchmark/blobs_benchmark_results.json b/benchmark/blobs_benchmark_results.json index 3e4e6ed95..fa7678624 100644 --- a/benchmark/blobs_benchmark_results.json +++ b/benchmark/blobs_benchmark_results.json @@ -2,261 +2,261 @@ "25": { "KernelHerding": { "Unweighted_MMD": 0.024272788688540457, - "Unweighted_KSD": 0.07254663035273552, + "Unweighted_KSD": 0.08634152710437774, "Weighted_MMD": 0.008470626920461655, - "Weighted_KSD": 0.07226677536964417, - "Time": 4.6005674901998646 + "Weighted_KSD": 0.07446694374084473, + "Time": 4.765064420400449 }, "RandomSample": { "Unweighted_MMD": 0.11142438650131226, - "Unweighted_KSD": 0.07730772346258163, + "Unweighted_KSD": 0.08814063891768456, "Weighted_MMD": 0.011223642434924842, - "Weighted_KSD": 0.07383343055844308, - "Time": 3.495483254200008 + "Weighted_KSD": 0.07585941627621651, + "Time": 3.3727504627993765 }, "RPCholesky": { "Unweighted_MMD": 0.14004717767238617, - "Unweighted_KSD": 0.059306026250123975, + "Unweighted_KSD": 0.07314706221222878, "Weighted_MMD": 0.003688266733661294, - "Weighted_KSD": 0.07196912616491317, - "Time": 4.230013548799798 + "Weighted_KSD": 0.06093888357281685, + "Time": 4.0264426012003245 }, "SteinThinning": { - "Unweighted_MMD": 0.14796174466609954, - "Unweighted_KSD": 0.07581270486116409, - "Weighted_MMD": 0.01757101807743311, - "Weighted_KSD": 0.07423881441354752, - "Time": 4.806701995600088 + "Unweighted_MMD": 0.14493760019540786, + "Unweighted_KSD": 0.08524694442749023, + "Weighted_MMD": 0.06338460296392441, + "Weighted_KSD": 0.08662225008010864, + "Time": 5.611508011200203 }, "KernelThinning": { "Unweighted_MMD": 0.014880230650305748, - "Unweighted_KSD": 0.07227124646306038, + "Unweighted_KSD": 0.07588417455554008, "Weighted_MMD": 0.005388019885867834, - "Weighted_KSD": 0.07246261909604072, - "Time": 27.173368372400184 + "Weighted_KSD": 0.06449438109993935, + "Time": 25.01412621220006 }, "CompressPlusPlus": { "Unweighted_MMD": 0.013212332502007484, - "Unweighted_KSD": 0.07247907817363738, + "Unweighted_KSD": 0.08404513746500016, "Weighted_MMD": 0.007080519571900368, - "Weighted_KSD": 0.07277720794081688, - "Time": 17.304506266200043 + "Weighted_KSD": 0.08123543560504913, + "Time": 16.713567591600075 }, "ProbabilisticIterativeHerding": { "Unweighted_MMD": 0.021128473989665508, - "Unweighted_KSD": 0.0732197754085064, + "Unweighted_KSD": 0.08938155919313431, "Weighted_MMD": 0.007852014992386103, - "Weighted_KSD": 0.07306945249438286, - "Time": 4.6694933262000635 + "Weighted_KSD": 0.08065845519304275, + "Time": 4.702327222399617 }, "IterativeHerding": { "Unweighted_MMD": 0.007051250245422125, - "Unweighted_KSD": 0.07203583419322968, + "Unweighted_KSD": 0.06839882656931877, "Weighted_MMD": 0.005125141562893986, - "Weighted_KSD": 0.07220595926046372, - "Time": 4.062583659599841 + "Weighted_KSD": 0.06586349084973335, + "Time": 3.8252486778001185 }, "CubicProbIterativeHerding": { "Unweighted_MMD": 0.004542805999517441, - "Unweighted_KSD": 0.07216479405760765, + "Unweighted_KSD": 0.08182733058929444, "Weighted_MMD": 0.003512424463406205, - "Weighted_KSD": 0.07236581966280937, - "Time": 4.687457689599796 + "Weighted_KSD": 0.07799033671617508, + "Time": 4.375145751399759 } }, "50": { "KernelHerding": { "Unweighted_MMD": 0.014010918885469436, - "Unweighted_KSD": 0.0722734160721302, + "Unweighted_KSD": 0.05761846974492073, "Weighted_MMD": 0.0031911543337628245, - "Weighted_KSD": 0.07209383249282837, - "Time": 4.1393956179999805 + "Weighted_KSD": 0.052470114827156064, + "Time": 4.036917862999689 }, "RandomSample": { "Unweighted_MMD": 0.10492457151412964, - "Unweighted_KSD": 0.07875456660985947, + "Unweighted_KSD": 0.0798761561512947, "Weighted_MMD": 0.004955455008894205, - "Weighted_KSD": 0.07259993627667427, - "Time": 3.580713712999932 + "Weighted_KSD": 0.06159702017903328, + "Time": 3.27908037680063 }, "RPCholesky": { "Unweighted_MMD": 0.1466503471136093, - "Unweighted_KSD": 0.056694062799215315, + "Unweighted_KSD": 0.06491677835583687, "Weighted_MMD": 0.0015391094610095024, - "Weighted_KSD": 0.0722087174654007, - "Time": 3.8200428860001923 + "Weighted_KSD": 0.054540709406137464, + "Time": 3.7208302791997996 }, "SteinThinning": { - "Unweighted_MMD": 0.13258629888296128, - "Unweighted_KSD": 0.07708697170019149, - "Weighted_MMD": 0.006761046499013901, - "Weighted_KSD": 0.07263452410697938, - "Time": 4.231214966799962 + "Unweighted_MMD": 0.0868241548538208, + "Unweighted_KSD": 0.05509435683488846, + "Weighted_MMD": 0.0135643620043993, + "Weighted_KSD": 0.061474745720624925, + "Time": 4.627325403000578 }, "KernelThinning": { "Unweighted_MMD": 0.006303768884390592, - "Unweighted_KSD": 0.07201230749487877, + "Unweighted_KSD": 0.061569680273532865, "Weighted_MMD": 0.0022462865337729452, - "Weighted_KSD": 0.07222185432910919, - "Time": 15.216021602399996 + "Weighted_KSD": 0.05851251482963562, + "Time": 14.038466937999692 }, "CompressPlusPlus": { "Unweighted_MMD": 0.007616249471902847, - "Unweighted_KSD": 0.07215439230203628, + "Unweighted_KSD": 0.06331060901284218, "Weighted_MMD": 0.0028188966680318117, - "Weighted_KSD": 0.07224903926253319, - "Time": 11.209934081999744 + "Weighted_KSD": 0.05671325549483299, + "Time": 10.396489782000208 }, "ProbabilisticIterativeHerding": { "Unweighted_MMD": 0.015107517503201962, - "Unweighted_KSD": 0.07347788587212563, + "Unweighted_KSD": 0.06883815973997116, "Weighted_MMD": 0.003151226742193103, - "Weighted_KSD": 0.07250117510557175, - "Time": 4.343779678600185 + "Weighted_KSD": 0.06300541535019874, + "Time": 4.108718106000015 }, "IterativeHerding": { "Unweighted_MMD": 0.003708381252363324, - "Unweighted_KSD": 0.07212337255477905, + "Unweighted_KSD": 0.05261607468128204, "Weighted_MMD": 0.002603885461576283, - "Weighted_KSD": 0.07219909951090812, - "Time": 3.6810207548000107 + "Weighted_KSD": 0.0531992956995964, + "Time": 3.5771397664000686 }, "CubicProbIterativeHerding": { "Unweighted_MMD": 0.001733466051518917, - "Unweighted_KSD": 0.07222620248794556, + "Unweighted_KSD": 0.058075586706399916, "Weighted_MMD": 0.001442490390036255, - "Weighted_KSD": 0.07229570895433426, - "Time": 4.199541498000144 + "Weighted_KSD": 0.059921151399612425, + "Time": 4.120307515600507 } }, "100": { "KernelHerding": { "Unweighted_MMD": 0.007909100409597159, - "Unweighted_KSD": 0.07176313027739525, + "Unweighted_KSD": 0.04663899913430214, "Weighted_MMD": 0.0018589411629363894, - "Weighted_KSD": 0.07220481112599372, - "Time": 4.31388007539972 + "Weighted_KSD": 0.051218368113040924, + "Time": 4.235976945400035 }, "RandomSample": { "Unweighted_MMD": 0.05501915663480759, - "Unweighted_KSD": 0.07520547062158585, + "Unweighted_KSD": 0.0618309035897255, "Weighted_MMD": 0.00180354667827487, - "Weighted_KSD": 0.07226956561207772, - "Time": 3.731109356599518 + "Weighted_KSD": 0.05710694566369057, + "Time": 3.1581929947999017 }, "RPCholesky": { "Unweighted_MMD": 0.09764691218733787, - "Unweighted_KSD": 0.062210434675216676, + "Unweighted_KSD": 0.039633375406265256, "Weighted_MMD": 0.0010440661339089275, - "Weighted_KSD": 0.07225104942917823, - "Time": 4.349850091400003 + "Weighted_KSD": 0.05533244013786316, + "Time": 3.8502491777991965 }, "SteinThinning": { - "Unweighted_MMD": 0.13784433156251907, - "Unweighted_KSD": 0.08129674047231675, - "Weighted_MMD": 0.0046910161152482035, - "Weighted_KSD": 0.07230838015675545, - "Time": 4.689982681799847 + "Unweighted_MMD": 0.09307270124554634, + "Unweighted_KSD": 0.03587676882743836, + "Weighted_MMD": 0.006268286239355802, + "Weighted_KSD": 0.05565239489078522, + "Time": 4.740899266999622 }, "KernelThinning": { "Unweighted_MMD": 0.002685086103156209, - "Unweighted_KSD": 0.07206880524754525, + "Unweighted_KSD": 0.05610363930463791, "Weighted_MMD": 0.001265210215933621, - "Weighted_KSD": 0.07226345017552376, - "Time": 10.10230621419978 + "Weighted_KSD": 0.0581892117857933, + "Time": 9.000170669599902 }, "CompressPlusPlus": { "Unweighted_MMD": 0.0029356910847127436, - "Unweighted_KSD": 0.07219576761126519, + "Unweighted_KSD": 0.05573995485901832, "Weighted_MMD": 0.0012260458199307323, - "Weighted_KSD": 0.07228517681360244, - "Time": 9.244769073800308 + "Weighted_KSD": 0.055947765707969666, + "Time": 8.099010541199823 }, "ProbabilisticIterativeHerding": { "Unweighted_MMD": 0.009710153844207526, - "Unweighted_KSD": 0.07278616279363632, + "Unweighted_KSD": 0.062317197769880296, "Weighted_MMD": 0.0018384325550869108, - "Weighted_KSD": 0.07236671000719071, - "Time": 4.425218307400064 + "Weighted_KSD": 0.059105978906154634, + "Time": 4.5184859611999855 }, "IterativeHerding": { "Unweighted_MMD": 0.0022563493344932794, - "Unweighted_KSD": 0.07212945297360421, + "Unweighted_KSD": 0.04880478382110596, "Weighted_MMD": 0.001406662119552493, - "Weighted_KSD": 0.07225525602698327, - "Time": 4.298705347399846 + "Weighted_KSD": 0.05116616040468216, + "Time": 4.1359605843997995 }, "CubicProbIterativeHerding": { "Unweighted_MMD": 0.0008045180700719356, - "Unweighted_KSD": 0.07221448868513107, + "Unweighted_KSD": 0.05193357020616531, "Weighted_MMD": 0.0009792268159799279, - "Weighted_KSD": 0.07225939556956291, - "Time": 4.68569216800006 + "Weighted_KSD": 0.054329156130552295, + "Time": 4.49999562500052 } }, "200": { "KernelHerding": { "Unweighted_MMD": 0.004258563183248043, - "Unweighted_KSD": 0.0720168687403202, + "Unweighted_KSD": 0.047414638102054596, "Weighted_MMD": 0.0011734690284356474, - "Weighted_KSD": 0.0722421571612358, - "Time": 4.809446495800148 + "Weighted_KSD": 0.05488279089331627, + "Time": 4.568870213000264 }, "RandomSample": { "Unweighted_MMD": 0.04152125939726829, - "Unweighted_KSD": 0.07231617346405983, + "Unweighted_KSD": 0.05796747878193855, "Weighted_MMD": 0.000913540180772543, - "Weighted_KSD": 0.0722603291273117, - "Time": 3.7448029847997533 + "Weighted_KSD": 0.055494559556245805, + "Time": 3.401281245000064 }, "RPCholesky": { "Unweighted_MMD": 0.05692300647497177, - "Unweighted_KSD": 0.0671866662800312, + "Unweighted_KSD": 0.04246571436524391, "Weighted_MMD": 0.0008295111590996384, - "Weighted_KSD": 0.07224812433123588, - "Time": 4.360847868199926 + "Weighted_KSD": 0.053956735879182816, + "Time": 4.136736245400243 }, "SteinThinning": { - "Unweighted_MMD": 0.14454428851604462, - "Unweighted_KSD": 0.08556406646966934, - "Weighted_MMD": 0.0028360273223370313, - "Weighted_KSD": 0.07215539738535881, - "Time": 4.83350045979987 + "Unweighted_MMD": 0.10421270728111268, + "Unweighted_KSD": 0.024421826750040055, + "Weighted_MMD": 0.003508183266967535, + "Weighted_KSD": 0.055823469161987306, + "Time": 5.040176883599997 }, "KernelThinning": { "Unweighted_MMD": 0.0015182187082245946, - "Unweighted_KSD": 0.07213710397481918, + "Unweighted_KSD": 0.054005082696676254, "Weighted_MMD": 0.000885988853406161, - "Weighted_KSD": 0.07226478308439255, - "Time": 6.940934421800193 + "Weighted_KSD": 0.05745472684502602, + "Time": 6.787894143599624 }, "CompressPlusPlus": { "Unweighted_MMD": 0.0014102120650932193, - "Unweighted_KSD": 0.07215408384799957, + "Unweighted_KSD": 0.05317859873175621, "Weighted_MMD": 0.0007552313501946629, - "Weighted_KSD": 0.07224038168787957, - "Time": 7.29123429639967 + "Weighted_KSD": 0.05463842824101448, + "Time": 7.406790399999591 }, "ProbabilisticIterativeHerding": { "Unweighted_MMD": 0.006357756908982992, - "Unweighted_KSD": 0.07269964888691902, + "Unweighted_KSD": 0.05834345370531082, "Weighted_MMD": 0.0008730732253752649, - "Weighted_KSD": 0.07227222323417663, - "Time": 4.814415276399814 + "Weighted_KSD": 0.05702040642499924, + "Time": 4.711837356400428 }, "IterativeHerding": { "Unweighted_MMD": 0.0013821582775563, - "Unweighted_KSD": 0.07216034978628158, + "Unweighted_KSD": 0.0500980406999588, "Weighted_MMD": 0.0009947988553903997, - "Weighted_KSD": 0.07224116325378419, - "Time": 4.238990427600038 + "Weighted_KSD": 0.05419353991746902, + "Time": 4.150569619399903 }, "CubicProbIterativeHerding": { "Unweighted_MMD": 0.0005821106024086475, - "Unweighted_KSD": 0.07220486029982567, + "Unweighted_KSD": 0.05276075378060341, "Weighted_MMD": 0.0007064452278427779, - "Weighted_KSD": 0.07225989773869515, - "Time": 4.936300538400064 + "Weighted_KSD": 0.05621158331632614, + "Time": 4.702852152600099 } } } diff --git a/benchmark/blobs_benchmark_visualiser.py b/benchmark/blobs_benchmark_visualiser.py index 3923c639c..b84671b9e 100644 --- a/benchmark/blobs_benchmark_visualiser.py +++ b/benchmark/blobs_benchmark_visualiser.py @@ -93,7 +93,8 @@ def plot_benchmarking_results(data): plt.grid(True, linestyle="--", alpha=0.7) plt.savefig( - f"../examples/benchmarking_images/blobs_{metric}.png", bbox_inches="tight" + f"../examples/benchmarking_images/blobs_{metric.lower()}.png", + bbox_inches="tight", ) diff --git a/benchmark/david_benchmark.py b/benchmark/david_benchmark.py index 7a5e338b1..779736328 100644 --- a/benchmark/david_benchmark.py +++ b/benchmark/david_benchmark.py @@ -28,8 +28,14 @@ Each coreset algorithm is timed to measure and report the time taken for each step. """ -import math import os +import sys + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +# Pylint and Ruff are fighting over the position of the local import +# pylint: disable=wrong-import-position +import math import time from pathlib import Path from typing import Optional @@ -44,15 +50,19 @@ from coreax.benchmark_util import initialise_solvers from examples.david_map_reduce_weighted import downsample_opencv +# pylint: enable=wrong-import-position + MAX_8BIT = 255 -# pylint: disable=too-many-locals, too-many-statements +# pylint: disable=too-many-locals # ruff: noqa: PLR0914, PLR0915 # (disable line too long and too many statements ruff) def benchmark_coreset_algorithms( in_path: Path = Path("../examples/data/david_orig.png"), - out_path: Optional[Path] = Path("david_benchmark_results.png"), + out_path: Optional[Path] = Path( + "../examples/benchmarking_images/david_benchmark_results.png" + ), downsampling_factor: int = 1, ): """ @@ -63,7 +73,8 @@ def benchmark_coreset_algorithms( time is printed. :param in_path: Path to the input image file. - :param out_path: Path to save the output benchmark plot image. + :param out_path: Path to save the output benchmark plot image, relative to the + script's location. :param downsampling_factor: Factor by which to downsample the image. """ # Base directory of the current script @@ -131,13 +142,13 @@ def benchmark_coreset_algorithms( # Save the combined benchmark plot if out_path: plt.figure(figsize=(15, 10)) - plt.subplot(3, 3, 1) + plt.subplot(4, 3, 1) plt.imshow(original_data, cmap="gray") plt.title("Original Image") plt.axis("off") for i, (solver_name, coreset_data) in enumerate(coresets.items(), start=2): - plt.subplot(3, 3, i) + plt.subplot(4, 3, i) plt.scatter( coreset_data[:, 1], -coreset_data[:, 0], diff --git a/benchmark/mnist_benchmark.py b/benchmark/mnist_benchmark.py index 083610e39..4f12ca134 100644 --- a/benchmark/mnist_benchmark.py +++ b/benchmark/mnist_benchmark.py @@ -381,7 +381,7 @@ def density_preserving_umap(x: jnp.ndarray, n_components: int = 16) -> jnp.ndarr x_np = np.array(x) # Initialize UMAP with density-preserving option - umap_model = umap.UMAP(densmap=True, n_components=n_components) + umap_model = umap.UMAP(densmap=True, n_components=n_components, random_state=0) # Fit and transform the data x_umap = umap_model.fit_transform(x_np) diff --git a/benchmark/mnist_benchmark_results.json b/benchmark/mnist_benchmark_results.json index c60184e44..e9c889bf0 100644 --- a/benchmark/mnist_benchmark_results.json +++ b/benchmark/mnist_benchmark_results.json @@ -2,1206 +2,1206 @@ "Random Sample": { "25": { "0": { - "accuracy": 0.49939975142478943, - "time_taken": 31.378794076000304 + "accuracy": 0.47499004006385803, + "time_taken": 22.919005189998643 }, "1": { - "accuracy": 0.4205680191516876, - "time_taken": 15.958629951000148 + "accuracy": 0.5094035863876343, + "time_taken": 24.038130666000143 }, "2": { - "accuracy": 0.4317728579044342, - "time_taken": 16.24605190999955 + "accuracy": 0.4429771900177002, + "time_taken": 25.78509461199974 }, "3": { - "accuracy": 0.4267708659172058, - "time_taken": 16.77121249100037 + "accuracy": 0.5266107320785522, + "time_taken": 21.765049056000862 }, "4": { - "accuracy": 0.41426557302474976, - "time_taken": 20.9866746280004 + "accuracy": 0.4708881676197052, + "time_taken": 25.985052893998727 } }, "50": { "0": { - "accuracy": 0.5915995836257935, - "time_taken": 13.46058772699962 + "accuracy": 0.6092996001243591, + "time_taken": 8.013713274998736 }, "1": { - "accuracy": 0.5806996822357178, - "time_taken": 7.307854431999658 + "accuracy": 0.6098998785018921, + "time_taken": 10.523420129999067 }, "2": { - "accuracy": 0.5749998092651367, - "time_taken": 14.328193377000389 + "accuracy": 0.5316997766494751, + "time_taken": 7.394414693000726 }, "3": { - "accuracy": 0.6124996542930603, - "time_taken": 10.097868826000195 + "accuracy": 0.6038997769355774, + "time_taken": 7.596240197000952 }, "4": { - "accuracy": 0.5182000994682312, - "time_taken": 13.699923365000359 + "accuracy": 0.6276999711990356, + "time_taken": 16.082055785000193 } }, "100": { "0": { - "accuracy": 0.6564000844955444, - "time_taken": 6.6002176759993745 + "accuracy": 0.704299807548523, + "time_taken": 6.024247718998595 }, "1": { - "accuracy": 0.7256999611854553, - "time_taken": 4.555681074999484 + "accuracy": 0.7321001887321472, + "time_taken": 5.7248346590004076 }, "2": { - "accuracy": 0.7220999598503113, - "time_taken": 5.915041250000286 + "accuracy": 0.7243001461029053, + "time_taken": 9.890619637999407 }, "3": { - "accuracy": 0.687000036239624, - "time_taken": 5.589812125000208 + "accuracy": 0.7279003262519836, + "time_taken": 3.8261312839986203 }, "4": { - "accuracy": 0.7198001146316528, - "time_taken": 9.042554675000247 + "accuracy": 0.6853998899459839, + "time_taken": 7.29285894900022 } }, "500": { "0": { - "accuracy": 0.8401442170143127, - "time_taken": 4.9041379580003195 + "accuracy": 0.849459171295166, + "time_taken": 4.445891253999434 }, "1": { - "accuracy": 0.8366386294364929, - "time_taken": 3.9820582559996183 + "accuracy": 0.8390424847602844, + "time_taken": 2.849187047000669 }, "2": { - "accuracy": 0.8532652258872986, - "time_taken": 4.150694207000015 + "accuracy": 0.8586738705635071, + "time_taken": 4.086037507999208 }, "3": { - "accuracy": 0.8519631624221802, - "time_taken": 3.2897759840006984 + "accuracy": 0.8433493971824646, + "time_taken": 3.1370794300000853 }, "4": { - "accuracy": 0.8568710088729858, - "time_taken": 3.2805830689994764 + "accuracy": 0.8508613705635071, + "time_taken": 4.695724795999922 } }, "1000": { "0": { - "accuracy": 0.8800080418586731, - "time_taken": 4.472438157999932 + "accuracy": 0.8806089758872986, + "time_taken": 4.090531890999046 }, "1": { - "accuracy": 0.8753004670143127, - "time_taken": 3.1494866820003153 + "accuracy": 0.8755007982254028, + "time_taken": 2.858250543000395 }, "2": { - "accuracy": 0.8787059187889099, - "time_taken": 3.3807121429999825 + "accuracy": 0.8828125, + "time_taken": 2.871943344000101 }, "3": { - "accuracy": 0.8870192170143127, - "time_taken": 3.851220937999642 + "accuracy": 0.8731971383094788, + "time_taken": 2.942463305000274 }, "4": { - "accuracy": 0.8785056471824646, - "time_taken": 3.1983475079996424 + "accuracy": 0.8818109035491943, + "time_taken": 3.2579264050000347 } }, "5000": { "0": { - "accuracy": 0.9187700152397156, - "time_taken": 4.877470247000019 + "accuracy": 0.9258814454078674, + "time_taken": 4.775947737000024 }, "1": { - "accuracy": 0.9220753312110901, - "time_taken": 4.563210882000021 + "accuracy": 0.9238781929016113, + "time_taken": 3.5404121189985744 }, "2": { - "accuracy": 0.9268830418586731, - "time_taken": 4.27181648900023 + "accuracy": 0.9277844429016113, + "time_taken": 3.586515529999815 }, "3": { - "accuracy": 0.9260817170143127, - "time_taken": 4.163932319000196 + "accuracy": 0.9291867017745972, + "time_taken": 5.097970101000101 }, "4": { - "accuracy": 0.9253806471824646, - "time_taken": 5.401279224999598 + "accuracy": 0.9296875, + "time_taken": 3.7926768680008536 } } }, "RP Cholesky": { "25": { "0": { - "accuracy": 0.5479191541671753, - "time_taken": 30.12485848600045 + "accuracy": 0.47208893299102783, + "time_taken": 14.432026581000173 }, "1": { - "accuracy": 0.5831333994865417, - "time_taken": 17.1290923939996 + "accuracy": 0.5369148850440979, + "time_taken": 16.501476474999436 }, "2": { - "accuracy": 0.506402313709259, - "time_taken": 28.668303671999638 + "accuracy": 0.6021402478218079, + "time_taken": 18.394893914999557 }, "3": { - "accuracy": 0.48569461703300476, - "time_taken": 22.65919453600054 + "accuracy": 0.4675869345664978, + "time_taken": 18.458187942998848 }, "4": { - "accuracy": 0.4908963739871979, - "time_taken": 25.84216402400034 + "accuracy": 0.5124050378799438, + "time_taken": 25.307517440000083 } }, "50": { "0": { - "accuracy": 0.5573997497558594, - "time_taken": 11.868133194999245 + "accuracy": 0.6644995808601379, + "time_taken": 17.061049896999975 }, "1": { - "accuracy": 0.6715995669364929, - "time_taken": 9.038644168000246 + "accuracy": 0.6897998452186584, + "time_taken": 8.493855818998782 }, "2": { - "accuracy": 0.5750996470451355, - "time_taken": 15.650777019999623 + "accuracy": 0.6449994444847107, + "time_taken": 18.410714375999305 }, "3": { - "accuracy": 0.615599513053894, - "time_taken": 17.46370884499993 + "accuracy": 0.5964997410774231, + "time_taken": 10.035908525000195 }, "4": { - "accuracy": 0.6460996270179749, - "time_taken": 12.382691423999859 + "accuracy": 0.641399621963501, + "time_taken": 17.728979279001578 } }, "100": { "0": { - "accuracy": 0.7053002119064331, - "time_taken": 13.038988242000414 + "accuracy": 0.73580002784729, + "time_taken": 12.027434380999694 }, "1": { - "accuracy": 0.6796000599861145, - "time_taken": 5.734187043999555 + "accuracy": 0.7024999856948853, + "time_taken": 6.223794345998613 }, "2": { - "accuracy": 0.6684001684188843, - "time_taken": 10.304006375000426 + "accuracy": 0.6677001118659973, + "time_taken": 7.77737258399975 }, "3": { - "accuracy": 0.7305001020431519, - "time_taken": 6.787468426000487 + "accuracy": 0.6699000597000122, + "time_taken": 7.581072974999188 }, "4": { - "accuracy": 0.6646999716758728, - "time_taken": 8.555311609000455 + "accuracy": 0.7248998284339905, + "time_taken": 8.663397701000576 } }, "500": { "0": { - "accuracy": 0.8393429517745972, - "time_taken": 5.700323695000407 + "accuracy": 0.8440504670143127, + "time_taken": 5.951178929000889 }, "1": { - "accuracy": 0.8234174847602844, - "time_taken": 4.577245696000318 + "accuracy": 0.8293269276618958, + "time_taken": 3.659458459000234 }, "2": { - "accuracy": 0.8246194124221802, - "time_taken": 4.333464565000213 + "accuracy": 0.8071914911270142, + "time_taken": 5.165111156999046 }, "3": { - "accuracy": 0.8459535241127014, - "time_taken": 4.002924071000052 + "accuracy": 0.8140023946762085, + "time_taken": 4.427434246999837 }, "4": { - "accuracy": 0.8322315812110901, - "time_taken": 4.44240554299995 + "accuracy": 0.8303285241127014, + "time_taken": 3.714743741998973 } }, "1000": { "0": { - "accuracy": 0.8704928159713745, - "time_taken": 6.094056519999867 + "accuracy": 0.8708934187889099, + "time_taken": 6.251749818000462 }, "1": { - "accuracy": 0.8677884936332703, - "time_taken": 4.771197477999522 + "accuracy": 0.8620793223381042, + "time_taken": 4.676454960999763 }, "2": { - "accuracy": 0.8700921535491943, - "time_taken": 5.191751323000062 + "accuracy": 0.8752003312110901, + "time_taken": 5.054982719999316 }, "3": { - "accuracy": 0.8765023946762085, - "time_taken": 5.251704577000055 + "accuracy": 0.850661039352417, + "time_taken": 5.026370794999821 }, "4": { - "accuracy": 0.857271671295166, - "time_taken": 5.254659973999878 + "accuracy": 0.8616787195205688, + "time_taken": 4.259451833999265 } }, "5000": { "0": { - "accuracy": 0.9259815812110901, - "time_taken": 31.580425969000316 + "accuracy": 0.9230769276618958, + "time_taken": 30.088407684999765 }, "1": { - "accuracy": 0.9291867017745972, - "time_taken": 30.337975878999714 + "accuracy": 0.931490421295166, + "time_taken": 28.840131775999907 }, "2": { - "accuracy": 0.9300881624221802, - "time_taken": 31.068755215000238 + "accuracy": 0.9295873641967773, + "time_taken": 29.57311406699955 }, "3": { - "accuracy": 0.9288862347602844, - "time_taken": 28.69117661600012 + "accuracy": 0.9299879670143127, + "time_taken": 28.48633675000019 }, "4": { - "accuracy": 0.9255809187889099, - "time_taken": 28.678187123000498 + "accuracy": 0.9237780570983887, + "time_taken": 28.90727943500133 } } }, "Kernel Herding": { "25": { "0": { - "accuracy": 0.36344558000564575, - "time_taken": 27.81852218200038 + "accuracy": 0.4873950481414795, + "time_taken": 15.986496007999449 }, "1": { - "accuracy": 0.40876394510269165, - "time_taken": 28.917832838000322 + "accuracy": 0.4643857777118683, + "time_taken": 28.306637690999196 }, "2": { - "accuracy": 0.37885135412216187, - "time_taken": 33.120471524000095 + "accuracy": 0.44697871804237366, + "time_taken": 18.1760879630001 }, "3": { - "accuracy": 0.3827531635761261, - "time_taken": 16.74550173100033 + "accuracy": 0.42897194623947144, + "time_taken": 12.775741837000169 }, "4": { - "accuracy": 0.3984594941139221, - "time_taken": 31.39536281100027 + "accuracy": 0.42096811532974243, + "time_taken": 13.941555733999849 } }, "50": { "0": { - "accuracy": 0.5153999328613281, - "time_taken": 8.394265891000032 + "accuracy": 0.5863996148109436, + "time_taken": 7.962145343999509 }, "1": { - "accuracy": 0.5643997192382812, - "time_taken": 8.034130296000512 + "accuracy": 0.5681999921798706, + "time_taken": 8.845074394999756 }, "2": { - "accuracy": 0.5117997527122498, - "time_taken": 7.021193917999881 + "accuracy": 0.5861996412277222, + "time_taken": 8.543767514000137 }, "3": { - "accuracy": 0.5069000720977783, - "time_taken": 8.373046021999471 + "accuracy": 0.5385997891426086, + "time_taken": 7.017327412000668 }, "4": { - "accuracy": 0.5413997173309326, - "time_taken": 8.041774192999583 + "accuracy": 0.5630999803543091, + "time_taken": 6.821517419000884 } }, "100": { "0": { - "accuracy": 0.6686000823974609, - "time_taken": 6.880627470000036 + "accuracy": 0.6545001268386841, + "time_taken": 6.125024203000066 }, "1": { - "accuracy": 0.6687001585960388, - "time_taken": 5.998970880000343 + "accuracy": 0.6799001097679138, + "time_taken": 4.6802674750015285 }, "2": { - "accuracy": 0.6764000654220581, - "time_taken": 6.708478034000109 + "accuracy": 0.6616002321243286, + "time_taken": 4.902932041000895 }, "3": { - "accuracy": 0.6682003140449524, - "time_taken": 7.342494532999808 + "accuracy": 0.6335000991821289, + "time_taken": 4.237600578000638 }, "4": { - "accuracy": 0.672999918460846, - "time_taken": 6.854923852999491 + "accuracy": 0.6631001830101013, + "time_taken": 5.592238436000116 } }, "500": { "0": { - "accuracy": 0.7989783883094788, - "time_taken": 5.009305618999861 + "accuracy": 0.8102964758872986, + "time_taken": 4.772365287000866 }, "1": { - "accuracy": 0.7846554517745972, - "time_taken": 3.426840018999428 + "accuracy": 0.8169070482254028, + "time_taken": 3.474500064001404 }, "2": { - "accuracy": 0.7994791865348816, - "time_taken": 3.6989639409994197 + "accuracy": 0.8141025900840759, + "time_taken": 4.07752587700088 }, "3": { - "accuracy": 0.7809495329856873, - "time_taken": 3.7089166180003303 + "accuracy": 0.8239182829856873, + "time_taken": 3.45547284299937 }, "4": { - "accuracy": 0.7839543223381042, - "time_taken": 3.630458699999508 + "accuracy": 0.809495210647583, + "time_taken": 4.055496584000139 } }, "1000": { "0": { - "accuracy": 0.8508613705635071, - "time_taken": 5.865284939999583 + "accuracy": 0.8573718070983887, + "time_taken": 4.763669799000127 }, "1": { - "accuracy": 0.8452523946762085, - "time_taken": 3.5790511420000257 + "accuracy": 0.8505609035491943, + "time_taken": 3.301596471001176 }, "2": { - "accuracy": 0.8441506624221802, - "time_taken": 3.766006290000405 + "accuracy": 0.8493589758872986, + "time_taken": 3.361836451000272 }, "3": { - "accuracy": 0.8497596383094788, - "time_taken": 3.753327657000227 + "accuracy": 0.8515625, + "time_taken": 3.086794232000102 }, "4": { - "accuracy": 0.8400440812110901, - "time_taken": 3.5520159130001048 + "accuracy": 0.8539663553237915, + "time_taken": 3.2005308640000294 } }, "5000": { "0": { - "accuracy": 0.9318910241127014, - "time_taken": 8.529647805000423 + "accuracy": 0.9243789911270142, + "time_taken": 9.635175270001128 }, "1": { - "accuracy": 0.9282852411270142, - "time_taken": 5.212250545999268 + "accuracy": 0.9276843070983887, + "time_taken": 6.37886958099989 }, "2": { - "accuracy": 0.9311898946762085, - "time_taken": 6.773341258000073 + "accuracy": 0.9252804517745972, + "time_taken": 5.1813290219997725 }, "3": { - "accuracy": 0.9261819124221802, - "time_taken": 5.1819647869997425 + "accuracy": 0.9282852411270142, + "time_taken": 5.52222976899975 }, "4": { - "accuracy": 0.9283854365348816, - "time_taken": 6.26639689700005 + "accuracy": 0.9260817170143127, + "time_taken": 5.381160470000395 } } }, "Stein Thinning": { "25": { "0": { - "accuracy": 0.36664658784866333, - "time_taken": 26.155145622999953 + "accuracy": 0.45598235726356506, + "time_taken": 20.06922941299854 }, "1": { - "accuracy": 0.35044023394584656, - "time_taken": 16.872243384000285 + "accuracy": 0.43107256293296814, + "time_taken": 20.058883295998385 }, "2": { - "accuracy": 0.33363354206085205, - "time_taken": 16.71151909799937 + "accuracy": 0.44067633152008057, + "time_taken": 24.775430047000555 }, "3": { - "accuracy": 0.3225291073322296, - "time_taken": 18.860941359000208 + "accuracy": 0.4304724931716919, + "time_taken": 26.971532478999507 }, "4": { - "accuracy": 0.342837393283844, - "time_taken": 15.351801773999796 + "accuracy": 0.44317731261253357, + "time_taken": 28.394798983999863 } }, "50": { "0": { - "accuracy": 0.3817000985145569, - "time_taken": 13.093195344000378 + "accuracy": 0.48150014877319336, + "time_taken": 14.085924653998518 }, "1": { - "accuracy": 0.3882002532482147, - "time_taken": 11.234985342999607 + "accuracy": 0.48899999260902405, + "time_taken": 12.579714716001035 }, "2": { - "accuracy": 0.4043000638484955, - "time_taken": 11.797061467000276 + "accuracy": 0.4916000962257385, + "time_taken": 15.621676943001148 }, "3": { - "accuracy": 0.3527999520301819, - "time_taken": 11.524091536000014 + "accuracy": 0.4845997095108032, + "time_taken": 17.71667785200043 }, "4": { - "accuracy": 0.38370028138160706, - "time_taken": 11.592887677000363 + "accuracy": 0.48520001769065857, + "time_taken": 15.180106365000029 } }, "100": { "0": { - "accuracy": 0.47479987144470215, - "time_taken": 9.806702724999923 + "accuracy": 0.5978000164031982, + "time_taken": 11.944835819000218 }, "1": { - "accuracy": 0.4605000913143158, - "time_taken": 10.19817711199994 + "accuracy": 0.6014001965522766, + "time_taken": 10.989858828999786 }, "2": { - "accuracy": 0.5033999085426331, - "time_taken": 9.12059926999973 + "accuracy": 0.6025000810623169, + "time_taken": 10.250695644999723 }, "3": { - "accuracy": 0.4161998927593231, - "time_taken": 8.590373965999788 + "accuracy": 0.5681000351905823, + "time_taken": 13.388944784999694 }, "4": { - "accuracy": 0.4584999680519104, - "time_taken": 9.745934776000468 + "accuracy": 0.5780000686645508, + "time_taken": 11.565425942999354 } }, "500": { "0": { - "accuracy": 0.5769230723381042, - "time_taken": 8.572717877999821 + "accuracy": 0.7495993971824646, + "time_taken": 15.627634819000377 }, "1": { - "accuracy": 0.549879789352417, - "time_taken": 8.685776891999922 + "accuracy": 0.7189503312110901, + "time_taken": 11.987972888000513 }, "2": { - "accuracy": 0.5654046535491943, - "time_taken": 8.738311274999432 + "accuracy": 0.743990421295166, + "time_taken": 12.925626980000743 }, "3": { - "accuracy": 0.5692107677459717, - "time_taken": 8.89052372200058 + "accuracy": 0.7327724695205688, + "time_taken": 14.741296112999407 }, "4": { - "accuracy": 0.5557892918586731, - "time_taken": 8.744390723999459 + "accuracy": 0.7393830418586731, + "time_taken": 14.22917673199845 } }, "1000": { "0": { - "accuracy": 0.5848357677459717, - "time_taken": 10.872244053999566 + "accuracy": 0.8052884936332703, + "time_taken": 16.97553089900066 }, "1": { - "accuracy": 0.579026460647583, - "time_taken": 10.691493848000391 + "accuracy": 0.7964743971824646, + "time_taken": 16.09774988100071 }, "2": { - "accuracy": 0.5827323794364929, - "time_taken": 10.416130721000627 + "accuracy": 0.8101963400840759, + "time_taken": 15.728706653999325 }, "3": { - "accuracy": 0.5981570482254028, - "time_taken": 10.580251977999978 + "accuracy": 0.7945713400840759, + "time_taken": 15.161316943000202 }, "4": { - "accuracy": 0.5770232677459717, - "time_taken": 9.95398137400025 + "accuracy": 0.798677921295166, + "time_taken": 15.857722766000734 } }, "5000": { "0": { - "accuracy": 0.6365184187889099, - "time_taken": 31.84610059700026 + "accuracy": 0.8912259936332703, + "time_taken": 56.13571467200018 }, "1": { - "accuracy": 0.6400240659713745, - "time_taken": 31.580490523000663 + "accuracy": 0.8883213400840759, + "time_taken": 55.521121544999914 }, "2": { - "accuracy": 0.6353164911270142, - "time_taken": 30.96478199500052 + "accuracy": 0.891526460647583, + "time_taken": 53.86374767200141 }, "3": { - "accuracy": 0.6447315812110901, - "time_taken": 32.943920733000596 + "accuracy": 0.8930288553237915, + "time_taken": 55.25788913999895 }, "4": { - "accuracy": 0.6456330418586731, - "time_taken": 30.633445284999652 + "accuracy": 0.8855168223381042, + "time_taken": 54.47995065100076 } } }, "Kernel Thinning": { "25": { "0": { - "accuracy": 0.5464185476303101, - "time_taken": 110.06621436300065 + "accuracy": 0.49289700388908386, + "time_taken": 93.46730104400012 }, "1": { - "accuracy": 0.46028411388397217, - "time_taken": 25.186751725999784 + "accuracy": 0.4674869775772095, + "time_taken": 15.905900599000233 }, "2": { - "accuracy": 0.43417370319366455, - "time_taken": 19.643712891000177 + "accuracy": 0.45948395133018494, + "time_taken": 24.767529401999127 }, "3": { - "accuracy": 0.3735492527484894, - "time_taken": 19.60971717100074 + "accuracy": 0.5210083723068237, + "time_taken": 19.328794359000312 }, "4": { - "accuracy": 0.4456782639026642, - "time_taken": 20.31466248600009 + "accuracy": 0.4576832354068756, + "time_taken": 22.94440306300021 } }, "50": { "0": { - "accuracy": 0.6653998494148254, - "time_taken": 59.023018455000056 + "accuracy": 0.6489993929862976, + "time_taken": 48.49551541999972 }, "1": { - "accuracy": 0.6325995326042175, - "time_taken": 11.286527774000206 + "accuracy": 0.6273996233940125, + "time_taken": 10.125411686000007 }, "2": { - "accuracy": 0.5777995586395264, - "time_taken": 10.881928441000127 + "accuracy": 0.6058997511863708, + "time_taken": 7.737919924999005 }, "3": { - "accuracy": 0.5715998411178589, - "time_taken": 10.913744210000004 + "accuracy": 0.5623999238014221, + "time_taken": 6.257317322000745 }, "4": { - "accuracy": 0.6640992760658264, - "time_taken": 12.46789212099975 + "accuracy": 0.5890994071960449, + "time_taken": 8.691867642000943 } }, "100": { "0": { - "accuracy": 0.7268001437187195, - "time_taken": 35.0293219160003 + "accuracy": 0.7095999717712402, + "time_taken": 23.469334597999477 }, "1": { - "accuracy": 0.7283000946044922, - "time_taken": 5.59897917499984 + "accuracy": 0.7138001322746277, + "time_taken": 5.669212320000952 }, "2": { - "accuracy": 0.7609001994132996, - "time_taken": 6.843098536999605 + "accuracy": 0.7202000021934509, + "time_taken": 6.4046101799995085 }, "3": { - "accuracy": 0.7039998173713684, - "time_taken": 5.921434724000392 + "accuracy": 0.7141001224517822, + "time_taken": 5.954546537999704 }, "4": { - "accuracy": 0.6804999113082886, - "time_taken": 5.160561370000323 + "accuracy": 0.7061000466346741, + "time_taken": 5.680743742999766 } }, "500": { "0": { - "accuracy": 0.8551682829856873, - "time_taken": 12.34510114800014 + "accuracy": 0.8561698794364929, + "time_taken": 14.068776127000092 }, "1": { - "accuracy": 0.8511618971824646, - "time_taken": 4.445879892999983 + "accuracy": 0.8535656929016113, + "time_taken": 4.094917693999378 }, "2": { - "accuracy": 0.8497596383094788, - "time_taken": 3.9442537409995566 + "accuracy": 0.8541666865348816, + "time_taken": 3.4944589299993822 }, "3": { - "accuracy": 0.8512620329856873, - "time_taken": 4.5083901919997516 + "accuracy": 0.8581730723381042, + "time_taken": 3.896774025999548 }, "4": { - "accuracy": 0.838942289352417, - "time_taken": 3.8187096999999994 + "accuracy": 0.8525640964508057, + "time_taken": 3.939488691999941 } }, "1000": { "0": { - "accuracy": 0.8794070482254028, - "time_taken": 11.011465002000477 + "accuracy": 0.8832131624221802, + "time_taken": 9.80166764599926 }, "1": { - "accuracy": 0.875901460647583, - "time_taken": 3.8713226500003657 + "accuracy": 0.8770031929016113, + "time_taken": 3.624613356998452 }, "2": { - "accuracy": 0.8819110989570618, - "time_taken": 8.961700078999456 + "accuracy": 0.8835136294364929, + "time_taken": 3.490673861000687 }, "3": { - "accuracy": 0.8772035241127014, - "time_taken": 3.7053063889998157 + "accuracy": 0.8834134936332703, + "time_taken": 4.077375505999953 }, "4": { - "accuracy": 0.8855168223381042, - "time_taken": 3.8933140440003626 + "accuracy": 0.8765023946762085, + "time_taken": 3.2147287230000074 } }, "5000": { "0": { - "accuracy": 0.9211738705635071, - "time_taken": 15.707877896000355 + "accuracy": 0.9253806471824646, + "time_taken": 15.89239508799983 }, "1": { - "accuracy": 0.9233773946762085, - "time_taken": 6.091783504999512 + "accuracy": 0.9262820482254028, + "time_taken": 5.632784289999108 }, "2": { - "accuracy": 0.9256811141967773, - "time_taken": 5.776395839000543 + "accuracy": 0.9228765964508057, + "time_taken": 6.14335931000096 }, "3": { - "accuracy": 0.926682710647583, - "time_taken": 6.179523726000298 + "accuracy": 0.927584171295166, + "time_taken": 5.6770064319989615 }, "4": { - "accuracy": 0.926682710647583, - "time_taken": 5.389667761000055 + "accuracy": 0.9281851053237915, + "time_taken": 7.158779569001126 } } }, "Compress++": { "25": { "0": { - "accuracy": 0.48279327154159546, - "time_taken": 63.73896679699919 + "accuracy": 0.4751904010772705, + "time_taken": 58.80561342200053 }, "1": { - "accuracy": 0.49789944291114807, - "time_taken": 19.31496702000004 + "accuracy": 0.5812321305274963, + "time_taken": 22.111687761998837 }, "2": { - "accuracy": 0.46188488602638245, - "time_taken": 19.98056298300071 + "accuracy": 0.4190676510334015, + "time_taken": 26.76624609800092 }, "3": { - "accuracy": 0.46188458800315857, - "time_taken": 24.72102610499951 + "accuracy": 0.4723888635635376, + "time_taken": 18.55126123700029 }, "4": { - "accuracy": 0.4839934706687927, - "time_taken": 38.3823201969999 + "accuracy": 0.4623848795890808, + "time_taken": 15.92773811799998 } }, "50": { "0": { - "accuracy": 0.5958995819091797, - "time_taken": 30.356294965000416 + "accuracy": 0.6083995699882507, + "time_taken": 23.726150198999676 }, "1": { - "accuracy": 0.6557992696762085, - "time_taken": 10.960626758999751 + "accuracy": 0.6278995275497437, + "time_taken": 9.184842091000974 }, "2": { - "accuracy": 0.6099992394447327, - "time_taken": 12.611458585000037 + "accuracy": 0.5989994406700134, + "time_taken": 7.2256520180017105 }, "3": { - "accuracy": 0.6473994255065918, - "time_taken": 10.239416935000008 + "accuracy": 0.6024994254112244, + "time_taken": 11.275976351000281 }, "4": { - "accuracy": 0.5785996317863464, - "time_taken": 9.679247836000286 + "accuracy": 0.6049997806549072, + "time_taken": 9.112975457001085 } }, "100": { "0": { - "accuracy": 0.7106999754905701, - "time_taken": 21.78798379 + "accuracy": 0.7419999241828918, + "time_taken": 13.535142405000443 }, "1": { - "accuracy": 0.7305997014045715, - "time_taken": 5.884108313999604 + "accuracy": 0.6955998539924622, + "time_taken": 5.146802350000144 }, "2": { - "accuracy": 0.7083998918533325, - "time_taken": 5.77458919899982 + "accuracy": 0.7277998924255371, + "time_taken": 7.010694806998799 }, "3": { - "accuracy": 0.6870997548103333, - "time_taken": 5.419124719000138 + "accuracy": 0.7425999045372009, + "time_taken": 5.005660921000526 }, "4": { - "accuracy": 0.7167999148368835, - "time_taken": 5.903611237000405 + "accuracy": 0.7139999270439148, + "time_taken": 5.747074933999102 } }, "500": { "0": { - "accuracy": 0.8560697436332703, - "time_taken": 8.24205664999954 + "accuracy": 0.8392428159713745, + "time_taken": 12.817026272001385 }, "1": { - "accuracy": 0.8601762652397156, - "time_taken": 3.8466831529995034 + "accuracy": 0.8415464758872986, + "time_taken": 3.799437917999967 }, "2": { - "accuracy": 0.846754789352417, - "time_taken": 4.29086993799956 + "accuracy": 0.8514623641967773, + "time_taken": 3.6323357679993933 }, "3": { - "accuracy": 0.8469551205635071, - "time_taken": 3.9742898100003003 + "accuracy": 0.8482572436332703, + "time_taken": 3.938341524999487 }, "4": { - "accuracy": 0.8459535241127014, - "time_taken": 3.465456570999777 + "accuracy": 0.8571714758872986, + "time_taken": 3.3295341939992795 } }, "1000": { "0": { - "accuracy": 0.8879206776618958, - "time_taken": 7.61852171999999 + "accuracy": 0.8842147588729858, + "time_taken": 6.926109034000547 }, "1": { - "accuracy": 0.876802921295166, - "time_taken": 3.580353852999906 + "accuracy": 0.8783053159713745, + "time_taken": 3.3443215120005334 }, "2": { - "accuracy": 0.8792067170143127, - "time_taken": 4.019800891000159 + "accuracy": 0.8814102411270142, + "time_taken": 3.149032872999669 }, "3": { - "accuracy": 0.8861178159713745, - "time_taken": 3.750217070999497 + "accuracy": 0.8711939454078674, + "time_taken": 3.198121670000546 }, "4": { - "accuracy": 0.8814102411270142, - "time_taken": 3.537440410000272 + "accuracy": 0.8781049847602844, + "time_taken": 3.529270744000314 } }, "5000": { "0": { - "accuracy": 0.9228765964508057, - "time_taken": 6.762459595999644 + "accuracy": 0.9259815812110901, + "time_taken": 7.221486595000897 }, "1": { - "accuracy": 0.9269831776618958, - "time_taken": 5.589771165999991 + "accuracy": 0.9256811141967773, + "time_taken": 4.537618502001351 }, "2": { - "accuracy": 0.9271835088729858, - "time_taken": 5.489222301999689 + "accuracy": 0.9207732677459717, + "time_taken": 3.9046958340004494 }, "3": { - "accuracy": 0.9243789911270142, - "time_taken": 5.061679284000093 + "accuracy": 0.9258814454078674, + "time_taken": 4.473094468001364 }, "4": { - "accuracy": 0.9262820482254028, - "time_taken": 5.058138197999142 + "accuracy": 0.9301882982254028, + "time_taken": 5.37894143199992 } } }, "Iterative Probabilistic Herding (constant)": { "25": { "0": { - "accuracy": 0.565226137638092, - "time_taken": 25.58306674800042 + "accuracy": 0.5338136553764343, + "time_taken": 24.73460551700009 }, "1": { - "accuracy": 0.5584231019020081, - "time_taken": 26.9211215360001 + "accuracy": 0.5995398163795471, + "time_taken": 26.695493264000106 }, "2": { - "accuracy": 0.5649256110191345, - "time_taken": 24.513512876999812 + "accuracy": 0.5077033638954163, + "time_taken": 17.516685377000613 }, "3": { - "accuracy": 0.5609245300292969, - "time_taken": 27.941308749999735 + "accuracy": 0.5578231811523438, + "time_taken": 29.453460516999257 }, "4": { - "accuracy": 0.5159065127372742, - "time_taken": 19.32287024500056 + "accuracy": 0.5562224984169006, + "time_taken": 28.26684361900152 } }, "50": { "0": { - "accuracy": 0.6795998811721802, - "time_taken": 12.993868425000073 + "accuracy": 0.6718997955322266, + "time_taken": 14.062168884998755 }, "1": { - "accuracy": 0.6618993282318115, - "time_taken": 13.741371570999945 + "accuracy": 0.6653995513916016, + "time_taken": 8.245754612999008 }, "2": { - "accuracy": 0.7035000324249268, - "time_taken": 12.33730979000029 + "accuracy": 0.6303992867469788, + "time_taken": 10.585270434999984 }, "3": { - "accuracy": 0.6759998798370361, - "time_taken": 10.023791417000211 + "accuracy": 0.6818996071815491, + "time_taken": 9.023693425000602 }, "4": { - "accuracy": 0.6457993388175964, - "time_taken": 11.025239813000553 + "accuracy": 0.7031996846199036, + "time_taken": 10.645060467000803 } }, "100": { "0": { - "accuracy": 0.7630999088287354, - "time_taken": 10.076966365999397 + "accuracy": 0.721299946308136, + "time_taken": 9.759340487000372 }, "1": { - "accuracy": 0.7409002780914307, - "time_taken": 7.087535218999619 + "accuracy": 0.7022000551223755, + "time_taken": 6.031035981000969 }, "2": { - "accuracy": 0.7367997765541077, - "time_taken": 6.279074885999762 + "accuracy": 0.7301000952720642, + "time_taken": 5.960567097999956 }, "3": { - "accuracy": 0.7450000047683716, - "time_taken": 6.497801151000203 + "accuracy": 0.7073997259140015, + "time_taken": 6.447358765999525 }, "4": { - "accuracy": 0.7276996970176697, - "time_taken": 6.045073967999997 + "accuracy": 0.738899827003479, + "time_taken": 6.314079733001563 } }, "500": { "0": { - "accuracy": 0.8546674847602844, - "time_taken": 7.87245468600031 + "accuracy": 0.8675881624221802, + "time_taken": 7.881801559999076 }, "1": { - "accuracy": 0.8464543223381042, - "time_taken": 4.942239801000142 + "accuracy": 0.8488581776618958, + "time_taken": 4.193020180000531 }, "2": { - "accuracy": 0.8433493971824646, - "time_taken": 5.234136874999422 + "accuracy": 0.8550681471824646, + "time_taken": 4.261552747999303 }, "3": { - "accuracy": 0.8550681471824646, - "time_taken": 4.858938104000117 + "accuracy": 0.8592748641967773, + "time_taken": 4.009508700000879 }, "4": { - "accuracy": 0.856370210647583, - "time_taken": 5.168977192999591 + "accuracy": 0.8515625, + "time_taken": 4.392168127998957 } }, "1000": { "0": { - "accuracy": 0.8842147588729858, - "time_taken": 8.376013503999275 + "accuracy": 0.8854166865348816, + "time_taken": 8.312590070001534 }, "1": { - "accuracy": 0.8757011294364929, - "time_taken": 5.392493352999736 + "accuracy": 0.8839142918586731, + "time_taken": 4.755574097000135 }, "2": { - "accuracy": 0.8839142918586731, - "time_taken": 4.918340146000446 + "accuracy": 0.8805088400840759, + "time_taken": 4.353428874999736 }, "3": { - "accuracy": 0.884615421295166, - "time_taken": 4.754172397999355 + "accuracy": 0.8853164911270142, + "time_taken": 4.671408873000473 }, "4": { - "accuracy": 0.8897235989570618, - "time_taken": 5.224267777000023 + "accuracy": 0.8811097741127014, + "time_taken": 4.730935759000204 } }, "5000": { "0": { - "accuracy": 0.9242788553237915, - "time_taken": 21.26693089299988 + "accuracy": 0.9200721383094788, + "time_taken": 18.878691423999044 }, "1": { - "accuracy": 0.9253806471824646, - "time_taken": 14.351450629999817 + "accuracy": 0.9269831776618958, + "time_taken": 14.102433917998496 }, "2": { - "accuracy": 0.9214743971824646, - "time_taken": 14.56304887999977 + "accuracy": 0.9232772588729858, + "time_taken": 14.372946815999967 }, "3": { - "accuracy": 0.9211738705635071, - "time_taken": 15.201094977000139 + "accuracy": 0.9251803159713745, + "time_taken": 14.534260874999745 }, "4": { - "accuracy": 0.9302884936332703, - "time_taken": 15.344466227999874 + "accuracy": 0.922776460647583, + "time_taken": 14.080622875999325 } } }, "Iterative Herding": { "25": { "0": { - "accuracy": 0.469487726688385, - "time_taken": 26.213904901000205 + "accuracy": 0.5183072090148926, + "time_taken": 20.482818079000936 }, "1": { - "accuracy": 0.4717889130115509, - "time_taken": 40.966588446999594 + "accuracy": 0.4997002184391022, + "time_taken": 33.79406960199958 }, "2": { - "accuracy": 0.4678872227668762, - "time_taken": 17.88325883499965 + "accuracy": 0.4737892746925354, + "time_taken": 20.2240928909996 }, "3": { - "accuracy": 0.4661862552165985, - "time_taken": 28.512475307000386 + "accuracy": 0.4853939116001129, + "time_taken": 22.00765271399905 }, "4": { - "accuracy": 0.4657862186431885, - "time_taken": 21.193546985000467 + "accuracy": 0.49249720573425293, + "time_taken": 27.128820004998488 } }, "50": { "0": { - "accuracy": 0.6057994365692139, - "time_taken": 16.804117133000545 + "accuracy": 0.646899402141571, + "time_taken": 10.587521859000844 }, "1": { - "accuracy": 0.5874000191688538, - "time_taken": 10.829441335999945 + "accuracy": 0.6253997087478638, + "time_taken": 11.966102681999473 }, "2": { - "accuracy": 0.5901995897293091, - "time_taken": 14.236112396999488 + "accuracy": 0.6364994645118713, + "time_taken": 10.352756050999233 }, "3": { - "accuracy": 0.5842994451522827, - "time_taken": 10.729420103999473 + "accuracy": 0.598800003528595, + "time_taken": 10.920208049999928 }, "4": { - "accuracy": 0.593899667263031, - "time_taken": 18.04358775899982 + "accuracy": 0.6159996390342712, + "time_taken": 9.585873128000458 } }, "100": { "0": { - "accuracy": 0.7108997106552124, - "time_taken": 6.8043059060000814 + "accuracy": 0.6893000602722168, + "time_taken": 7.230347970000366 }, "1": { - "accuracy": 0.7032998204231262, - "time_taken": 5.164458477000153 + "accuracy": 0.7013002038002014, + "time_taken": 5.632414011000947 }, "2": { - "accuracy": 0.7057000398635864, - "time_taken": 5.732392590000018 + "accuracy": 0.7097997069358826, + "time_taken": 4.813542661999236 }, "3": { - "accuracy": 0.689500093460083, - "time_taken": 5.783705160999489 + "accuracy": 0.7190999984741211, + "time_taken": 6.265559563998977 }, "4": { - "accuracy": 0.676800012588501, - "time_taken": 5.26132266500008 + "accuracy": 0.701799750328064, + "time_taken": 4.968709823000609 } }, "500": { "0": { - "accuracy": 0.8225160241127014, - "time_taken": 6.017370114999721 + "accuracy": 0.8287259936332703, + "time_taken": 6.028517775999717 }, "1": { - "accuracy": 0.8190104365348816, - "time_taken": 4.340642064999884 + "accuracy": 0.8367387652397156, + "time_taken": 4.297074895999685 }, "2": { - "accuracy": 0.8218148946762085, - "time_taken": 4.0769241510006395 + "accuracy": 0.8322315812110901, + "time_taken": 3.9299392890006857 }, "3": { - "accuracy": 0.8159054517745972, - "time_taken": 3.6375408430003517 + "accuracy": 0.8322315812110901, + "time_taken": 4.2436198350005725 }, "4": { - "accuracy": 0.8083934187889099, - "time_taken": 4.066575695999745 + "accuracy": 0.8248197436332703, + "time_taken": 3.769788012999925 } }, "1000": { "0": { - "accuracy": 0.8711939454078674, - "time_taken": 5.933206661999975 + "accuracy": 0.8669871687889099, + "time_taken": 5.508468142999845 }, "1": { - "accuracy": 0.8691906929016113, - "time_taken": 3.854292703000283 + "accuracy": 0.8685897588729858, + "time_taken": 3.7791977529996075 }, "2": { - "accuracy": 0.8723958730697632, - "time_taken": 4.351314984000055 + "accuracy": 0.8708934187889099, + "time_taken": 3.860622815998795 }, "3": { - "accuracy": 0.8621795177459717, - "time_taken": 3.9070940849996987 + "accuracy": 0.8730969429016113, + "time_taken": 3.898559622000903 }, "4": { - "accuracy": 0.8653846383094788, - "time_taken": 4.015361558999757 + "accuracy": 0.8717948794364929, + "time_taken": 4.052534259999447 } }, "5000": { "0": { - "accuracy": 0.927584171295166, - "time_taken": 12.062748391000241 + "accuracy": 0.9273838400840759, + "time_taken": 12.479124105000665 }, "1": { - "accuracy": 0.930588960647583, - "time_taken": 9.747504790999301 + "accuracy": 0.9300881624221802, + "time_taken": 10.870270003999394 }, "2": { - "accuracy": 0.9271835088729858, - "time_taken": 8.544544992999363 + "accuracy": 0.9292868971824646, + "time_taken": 8.873208498998792 }, "3": { - "accuracy": 0.9268830418586731, - "time_taken": 9.586745757999779 + "accuracy": 0.9269831776618958, + "time_taken": 8.184928790999038 }, "4": { - "accuracy": 0.926682710647583, - "time_taken": 8.666436417999648 + "accuracy": 0.9264823794364929, + "time_taken": 8.986154870999599 } } }, "Iterative Probabilistic Herding (cubic)": { "25": { "0": { - "accuracy": 0.5662263631820679, - "time_taken": 26.524820819000524 + "accuracy": 0.5639256834983826, + "time_taken": 19.656280438000977 }, "1": { - "accuracy": 0.5555218458175659, - "time_taken": 22.971091108999644 + "accuracy": 0.5903359651565552, + "time_taken": 23.445653445998687 }, "2": { - "accuracy": 0.5328129529953003, - "time_taken": 20.666025272000297 + "accuracy": 0.5108044147491455, + "time_taken": 21.126809224999306 }, "3": { - "accuracy": 0.5929372310638428, - "time_taken": 18.04689451000013 + "accuracy": 0.5239095091819763, + "time_taken": 17.322403440000926 }, "4": { - "accuracy": 0.5388152003288269, - "time_taken": 27.70253951199993 + "accuracy": 0.4981989562511444, + "time_taken": 13.911915508000675 } }, "50": { "0": { - "accuracy": 0.6779000163078308, - "time_taken": 18.710263177000343 + "accuracy": 0.7299996018409729, + "time_taken": 13.690802703998997 }, "1": { - "accuracy": 0.6948000192642212, - "time_taken": 8.480895946999226 + "accuracy": 0.6956995725631714, + "time_taken": 9.36956431499857 }, "2": { - "accuracy": 0.6437996029853821, - "time_taken": 10.354708187000142 + "accuracy": 0.6731998324394226, + "time_taken": 7.570936809999694 }, "3": { - "accuracy": 0.6868996024131775, - "time_taken": 12.53823441699933 + "accuracy": 0.6613995432853699, + "time_taken": 7.556182764999903 }, "4": { - "accuracy": 0.6839997172355652, - "time_taken": 14.371656973999961 + "accuracy": 0.6961996555328369, + "time_taken": 10.754087191000508 } }, "100": { "0": { - "accuracy": 0.779999852180481, - "time_taken": 9.07973868299996 + "accuracy": 0.7239999175071716, + "time_taken": 9.73435688599966 }, "1": { - "accuracy": 0.7574002742767334, - "time_taken": 6.124219661000097 + "accuracy": 0.7531000971794128, + "time_taken": 4.95020767799906 }, "2": { - "accuracy": 0.745400071144104, - "time_taken": 6.907441441000628 + "accuracy": 0.7430996298789978, + "time_taken": 5.362517735000438 }, "3": { - "accuracy": 0.7519000172615051, - "time_taken": 6.637311533000684 + "accuracy": 0.7399001717567444, + "time_taken": 5.40860261600028 }, "4": { - "accuracy": 0.7315998077392578, - "time_taken": 6.438722759000484 + "accuracy": 0.7192999124526978, + "time_taken": 6.977815176000149 } }, "500": { "0": { - "accuracy": 0.859375, - "time_taken": 8.543036389000008 + "accuracy": 0.850661039352417, + "time_taken": 7.863592741001412 }, "1": { - "accuracy": 0.8518629670143127, - "time_taken": 4.909884863000116 + "accuracy": 0.8597756624221802, + "time_taken": 4.861156242999641 }, "2": { - "accuracy": 0.8504607677459717, - "time_taken": 5.362902012000632 + "accuracy": 0.8591746687889099, + "time_taken": 4.396951866001473 }, "3": { - "accuracy": 0.8542668223381042, - "time_taken": 5.555552781999722 + "accuracy": 0.8482572436332703, + "time_taken": 4.890931624999212 }, "4": { - "accuracy": 0.8544671535491943, - "time_taken": 4.767543308999848 + "accuracy": 0.8622796535491943, + "time_taken": 4.996591860999615 } }, "1000": { "0": { - "accuracy": 0.8869190812110901, - "time_taken": 9.989208918000259 + "accuracy": 0.8874198794364929, + "time_taken": 9.331451460000608 }, "1": { - "accuracy": 0.8870192170143127, - "time_taken": 5.912931406999633 + "accuracy": 0.8836137652397156, + "time_taken": 5.63444943200011 }, "2": { - "accuracy": 0.8907251954078674, - "time_taken": 6.551261347999571 + "accuracy": 0.8883213400840759, + "time_taken": 5.859720279999237 }, "3": { - "accuracy": 0.8823117017745972, - "time_taken": 6.210722205999446 + "accuracy": 0.8873197436332703, + "time_taken": 6.163236370999584 }, "4": { - "accuracy": 0.8866186141967773, - "time_taken": 6.336088992000441 + "accuracy": 0.8865184187889099, + "time_taken": 5.887336122999841 } }, "5000": { "0": { - "accuracy": 0.9264823794364929, - "time_taken": 31.977671014000407 + "accuracy": 0.9237780570983887, + "time_taken": 29.019777364999754 }, "1": { - "accuracy": 0.9283854365348816, - "time_taken": 24.778856055999313 + "accuracy": 0.9285857677459717, + "time_taken": 24.14570309699957 }, "2": { - "accuracy": 0.9249799847602844, - "time_taken": 24.93020751200038 + "accuracy": 0.9238781929016113, + "time_taken": 24.03303279300053 }, "3": { - "accuracy": 0.9258814454078674, - "time_taken": 23.563160845999846 + "accuracy": 0.9291867017745972, + "time_taken": 24.484326515001158 }, "4": { - "accuracy": 0.9239783883094788, - "time_taken": 23.45357099100056 + "accuracy": 0.9248798489570618, + "time_taken": 23.91146829200079 } } } diff --git a/benchmark/mnist_time_results.json b/benchmark/mnist_time_results.json index 4f190d185..79ee6b04f 100644 --- a/benchmark/mnist_time_results.json +++ b/benchmark/mnist_time_results.json @@ -1,335 +1,335 @@ { "Random Sample": { "25": { - "0": 0.647310689000733, - "1": 0.001746535999700427, - "2": 0.0016490729976794682, - "3": 0.0017102239980886225, - "4": 0.0017371739995724056 + "0": 0.6397729159980372, + "1": 0.0017929950008692686, + "2": 0.001671838002948789, + "3": 0.0016534319984202739, + "4": 0.0018852100001822691 }, "50": { - "0": 0.6191554599972733, - "1": 0.0017478970003139693, - "2": 0.001827262000006158, - "3": 0.0016880760013009422, - "4": 0.0017338979996566195 + "0": 0.6058683080009359, + "1": 0.0017296629994234536, + "2": 0.0016193310002563521, + "3": 0.0017611430012038909, + "4": 0.001899507002235623 }, "100": { - "0": 0.6146928409980319, - "1": 0.0015965790007612668, - "2": 0.001910418999614194, - "3": 0.0017455419983889442, - "4": 0.0016654669998388272 + "0": 0.602887089997239, + "1": 0.0017442200005461928, + "2": 0.0017450609993829858, + "3": 0.001634698000998469, + "4": 0.001743682998494478 }, "500": { - "0": 0.6247970969998278, - "1": 0.001661883001361275, - "2": 0.0020321159972809255, - "3": 0.001670051999099087, - "4": 0.0016311489998770412 + "0": 0.610039818999212, + "1": 0.0016881660012586508, + "2": 0.002094398998451652, + "3": 0.0017250429991690908, + "4": 0.0020420569999259897 }, "1000": { - "0": 0.6269827109972539, - "1": 0.0016431939984613564, - "2": 0.0016867839985934552, - "3": 0.0017590919997019228, - "4": 0.0016369789991586003 + "0": 0.6131921340020199, + "1": 0.0015884900021774229, + "2": 0.0017381260004185606, + "3": 0.0016522740006621461, + "4": 0.001942560000316007 } }, "RP Cholesky": { "25": { - "0": 1.5792628960007278, - "1": 0.009678587997768773, - "2": 0.009231661999365315, - "3": 0.008769658001256175, - "4": 0.009058035000634845 + "0": 1.5261813510005595, + "1": 0.008889092001481913, + "2": 0.008908703999622958, + "3": 0.008831459999782965, + "4": 0.00892095800008974 }, "50": { - "0": 1.511850119997689, - "1": 0.017149059000075795, - "2": 0.01731069299785304, - "3": 0.016735487999540055, - "4": 0.017396160001226235 + "0": 1.4747764250023465, + "1": 0.01738219599792501, + "2": 0.017394433001754805, + "3": 0.01717097599976114, + "4": 0.017007919002935523 }, "100": { - "0": 1.5369271290001052, - "1": 0.03710502700050711, - "2": 0.036878171002172166, - "3": 0.03643120099877706, - "4": 0.0368910990000586 + "0": 1.5036276520004321, + "1": 0.03718040099920472, + "2": 0.03720370900191483, + "3": 0.03674045799925807, + "4": 0.03665894100049627 }, "500": { - "0": 2.0553133670000534, - "1": 0.38911631099836086, - "2": 0.3895299329997215, - "3": 0.3905363150006451, - "4": 0.3903145599979325 + "0": 2.0446486809996713, + "1": 0.3870472409980721, + "2": 0.3871140369992645, + "3": 0.38680070500049624, + "4": 0.3865599739983736 }, "1000": { - "0": 2.8241980749990034, - "1": 1.2761649700005364, - "2": 1.2760263800009852, - "3": 1.2775282250004238, - "4": 1.2765133399989281 + "0": 2.799040249999962, + "1": 1.2717059960014012, + "2": 1.271023598001193, + "3": 1.2724880509995273, + "4": 1.272995469000307 } }, "Kernel Herding": { "25": { - "0": 2.6031983829998353, - "1": 0.26574823800183367, - "2": 0.26442587800192996, - "3": 0.26058913100132486, - "4": 0.25983492900195415 + "0": 2.5674014530013665, + "1": 0.26696247899963055, + "2": 0.2704821110019111, + "3": 0.2681096330015862, + "4": 0.26634214899968356 }, "50": { - "0": 1.9571695459999319, - "1": 0.2680527449992951, - "2": 0.26516477200129884, - "3": 0.2671307109994814, - "4": 0.2625435150002886 + "0": 1.9552257959985582, + "1": 0.26597738900090917, + "2": 0.26720760699754464, + "3": 0.26862574299957487, + "4": 0.2677930559984816 }, "100": { - "0": 1.6518758749989502, - "1": 0.2736544399995182, - "2": 0.27683376699860673, - "3": 0.27000973599933786, - "4": 0.26825259999895934 + "0": 1.616994022999279, + "1": 0.27444080199711607, + "2": 0.2744868519985175, + "3": 0.27481753100073547, + "4": 0.27534870299859904 }, "500": { - "0": 1.8148316490005527, - "1": 0.3298745380016044, - "2": 0.3237015800004883, - "3": 0.3160861960022885, - "4": 0.32276979100061 + "0": 1.7750997059993097, + "1": 0.32438827000078163, + "2": 0.32237087300018175, + "3": 0.3228230339991569, + "4": 0.32342054699984146 }, "1000": { - "0": 1.8868193730013445, - "1": 0.38048298899957445, - "2": 0.3765919639990898, - "3": 0.3759127279990935, - "4": 0.37570304299879353 + "0": 1.841568111001834, + "1": 0.3689523279972491, + "2": 0.37112178600000334, + "3": 0.3702217220015882, + "4": 0.3714859100000467 } }, "Stein Thinning": { "25": { - "0": 4.076632080999843, - "1": 4.0155544619992725, - "2": 3.9595811019971734, - "3": 3.7776183349997154, - "4": 3.8290312989993254 + "0": 6.69975286800036, + "1": 5.731123834000755, + "2": 5.567138086000341, + "3": 5.570860274001461, + "4": 5.52378452099947 }, "50": { - "0": 4.454431094000029, - "1": 4.158743233001587, - "2": 4.219952594998176, - "3": 3.934264821000397, - "4": 3.981639203000668 + "0": 6.741240006998851, + "1": 5.822887542999524, + "2": 5.8281003029987914, + "3": 5.787020288000349, + "4": 5.821601327999815 }, "100": { - "0": 4.420032927999273, - "1": 4.288612154003204, - "2": 4.268337712001085, - "3": 4.1098778509986005, - "4": 4.155323740000313 + "0": 6.602837345999433, + "1": 6.141826622002554, + "2": 6.064489849999518, + "3": 6.074834995000856, + "4": 6.022488116999739 }, "500": { - "0": 5.637759639997967, - "1": 10.025119519999862, - "2": 5.700425971001096, - "3": 5.587590609000472, - "4": 5.601147779998428 + "0": 8.96147054299945, + "1": 12.831382200998632, + "2": 8.39390050899965, + "3": 8.426648578002641, + "4": 8.36499944499883 }, "1000": { - "0": 7.587774307998188, - "1": 7.324021808999532, - "2": 7.105633106002642, - "3": 7.04964620499959, - "4": 7.102810305997991 + "0": 11.843904524997924, + "1": 11.518897926998761, + "2": 11.35488407199955, + "3": 11.304080829002487, + "4": 11.301111371001753 } }, "Kernel Thinning": { "25": { - "0": 90.81835862099979, - "1": 0.4871493360005843, - "2": 0.47238080099850777, - "3": 0.4738097219997144, - "4": 0.4770787749985175 + "0": 85.42026628800159, + "1": 0.47016479000012623, + "2": 0.46691343000202323, + "3": 0.46379170400177827, + "4": 0.4644382999977097 }, "50": { - "0": 46.26972372800083, - "1": 0.4933236199976818, - "2": 0.47538955799973337, - "3": 0.479368959000567, - "4": 0.47749236800154904 + "0": 47.04927618200236, + "1": 0.48483248599950457, + "2": 0.4784954530005052, + "3": 0.47844771000018227, + "4": 0.4787960780013236 }, "100": { - "0": 28.12213393500133, - "1": 0.5136646639984974, - "2": 0.491243342999951, - "3": 0.496500807999837, - "4": 0.4881571599980816 + "0": 26.912127160001546, + "1": 0.49925257100039744, + "2": 0.49251512400223874, + "3": 0.4928879849976511, + "4": 0.49184409299778054 }, "500": { - "0": 8.31735664399821, - "1": 0.48968988500200794, - "2": 0.48004734400092275, - "3": 0.47180004799884045, - "4": 0.47870473399962066 + "0": 7.987691559999803, + "1": 0.485307916002057, + "2": 0.4772449490010331, + "3": 0.4768101949994161, + "4": 0.47599558100046124 }, "1000": { - "0": 7.5886145939985, - "1": 0.5649518889986211, - "2": 0.5526147460004722, - "3": 0.549001772000338, - "4": 0.5572626430002856 + "0": 7.304257896001218, + "1": 0.5494356599992898, + "2": 0.5409130939988245, + "3": 0.543283473001793, + "4": 0.5417007770010969 } }, "Compress++": { "25": { - "0": 37.30952391499886, - "1": 0.05918412000028184, - "2": 0.05895868099833024, - "3": 0.058486374000494834, - "4": 0.05868467999971472 + "0": 35.93357312899752, + "1": 0.05818122900018352, + "2": 0.05758214000161388, + "3": 0.057043042997975135, + "4": 0.0573447450005915 }, "50": { - "0": 22.428665295003157, - "1": 0.0645626670011552, - "2": 0.06443759800094995, - "3": 0.06283377699946868, - "4": 0.06370369399883202 + "0": 21.631699821999064, + "1": 0.06311780300165992, + "2": 0.0626806519976526, + "3": 0.0606953850001446, + "4": 0.061265020001883386 }, "100": { - "0": 10.690804505000415, - "1": 0.06473967900092248, - "2": 0.0656564070013701, - "3": 0.06673681599932024, - "4": 0.06645065200063982 + "0": 10.191384666002705, + "1": 0.06473088300117524, + "2": 0.06460130399864283, + "3": 0.06175167599940323, + "4": 0.06374652999875252 }, "500": { - "0": 4.838294804998441, - "1": 0.08471525499771815, - "2": 0.08456600299905404, - "3": 0.08648833999905037, - "4": 0.08701666200067848 + "0": 4.64133475899871, + "1": 0.0834090960015601, + "2": 0.08410113900026772, + "3": 0.0832646980015852, + "4": 0.0844591209970531 }, "1000": { - "0": 3.9391429269999207, - "1": 0.10608179200062295, - "2": 0.10270820200094022, - "3": 0.1046263509997516, - "4": 0.10333565999826533 + "0": 3.834021393999137, + "1": 0.10361743400062551, + "2": 0.1014489750014036, + "3": 0.10026232100062771, + "4": 0.1008677270001499 } }, "Iterative Probabilistic Herding (constant)": { "25": { - "0": 3.2858039469974756, - "1": 0.3433665390002716, - "2": 0.31859386100040865, - "3": 0.31910103900008835, - "4": 0.3209825199992338 + "0": 3.157573227999819, + "1": 0.323332105002919, + "2": 0.32025618300031056, + "3": 0.32019502100229147, + "4": 0.3172955840018403 }, "50": { - "0": 3.3452633079978114, - "1": 0.36137995299941394, - "2": 0.35409107799932826, - "3": 0.3516577080008574, - "4": 0.3545336329989368 + "0": 3.2260004470008425, + "1": 0.3637085570007912, + "2": 0.3580116890007048, + "3": 0.35794998799974564, + "4": 0.3592098299995996 }, "100": { - "0": 3.4772456729988335, - "1": 0.419977902998653, - "2": 0.4046461109974189, - "3": 0.4071592389991565, - "4": 0.40611362399795325 + "0": 3.3186774120003975, + "1": 0.4064853199997742, + "2": 0.4025723509985255, + "3": 0.4018633320010849, + "4": 0.4037423609988764 }, "500": { - "0": 4.121649624001293, - "1": 0.9177702929991938, - "2": 0.9034675580005569, - "3": 0.905480092002108, - "4": 0.9054613059997791 + "0": 4.008981801001937, + "1": 0.9036035079989233, + "2": 0.8982191540017084, + "3": 0.9001957710024726, + "4": 0.8998167860008834 }, "1000": { - "0": 5.015538179999567, - "1": 1.5666278080025222, - "2": 1.555398291999154, - "3": 1.5604459620008129, - "4": 1.5613997779983038 + "0": 4.8883837110006425, + "1": 1.5568918610006222, + "2": 1.5566408380000212, + "3": 1.553357890999905, + "4": 1.555068123001547 } }, "Iterative Herding": { "25": { - "0": 1.8262478720025683, - "1": 0.2904930989971035, - "2": 0.2827128229982918, - "3": 0.2810959639973589, - "4": 0.2824249270015571 + "0": 1.7455461079989618, + "1": 0.27892118800082244, + "2": 0.2753962290007621, + "3": 0.2753615389992774, + "4": 0.2754493289976381 }, "50": { - "0": 1.8156351859979623, - "1": 0.30463867400249, - "2": 0.29666196499965736, - "3": 0.296801904001768, - "4": 0.2964241440022306 + "0": 1.7579592790025345, + "1": 0.3039573330024723, + "2": 0.30241110200222465, + "3": 0.3015530980010226, + "4": 0.30180186400320963 }, "100": { - "0": 1.874870039999223, - "1": 0.334557213998778, - "2": 0.325431799999933, - "3": 0.3274396750020969, - "4": 0.32587794700157247 + "0": 1.8197639319987502, + "1": 0.32099666700014495, + "2": 0.31948224599909736, + "3": 0.32031299200025387, + "4": 0.31887398099934217 }, "500": { - "0": 2.275343524001073, - "1": 0.5682799839996733, - "2": 0.5668382160001784, - "3": 0.5603331589991285, - "4": 0.5630835640004079 + "0": 2.1735239589979756, + "1": 0.5581009879970225, + "2": 0.5562046619998, + "3": 0.5586796050010889, + "4": 0.5535225969979365 }, "1000": { - "0": 2.590097914002399, - "1": 0.8541516059995047, - "2": 0.8420595110001159, - "3": 0.8399910410007578, - "4": 0.8505241100028798 + "0": 2.408495853000204, + "1": 0.8383036519990128, + "2": 0.8373670760011009, + "3": 0.8363624660014466, + "4": 0.8261094190020231 } }, "Iterative Probabilistic Herding (cubic)": { "25": { - "0": 3.205438458997378, - "1": 0.3528836830009823, - "2": 0.34116740399986156, - "3": 0.34536475899949437, - "4": 0.34625313600190566 + "0": 3.0695364310013247, + "1": 0.3450686610012781, + "2": 0.34057602400207543, + "3": 0.3391848300016136, + "4": 0.3390651859990612 }, "50": { - "0": 3.3181408650016238, - "1": 0.4046067210001638, - "2": 0.39647927599799004, - "3": 0.39693685399834067, - "4": 0.3962372599999071 + "0": 3.240672612999333, + "1": 0.4003331630010507, + "2": 0.39608382600272307, + "3": 0.3945449530001497, + "4": 0.39500187399971765 }, "100": { - "0": 3.5156353169986687, - "1": 0.5167118350000237, - "2": 0.51060877800046, - "3": 0.5102541699998255, - "4": 0.509720803998789 + "0": 3.4326532929990208, + "1": 0.5085808479998377, + "2": 0.5066659709991654, + "3": 0.5071053790015867, + "4": 0.4975864910011296 }, "500": { - "0": 4.725280201000714, - "1": 1.5142447790021833, - "2": 1.500269970001682, - "3": 1.5034325390006416, - "4": 1.5038313849981932 + "0": 4.607725097001094, + "1": 1.4990024629987602, + "2": 1.4945151149986486, + "3": 1.4941014079995512, + "4": 1.4937070180021692 }, "1000": { - "0": 6.223259660000622, - "1": 2.8354778319990146, - "2": 2.8134299510020355, - "3": 2.811016760999337, - "4": 2.8105186829998274 + "0": 6.099510264000855, + "1": 2.8083413430031214, + "2": 2.8086630260004313, + "3": 2.807137742998748, + "4": 2.7990312059991993 } } } diff --git a/benchmark/pounce_benchmark.py b/benchmark/pounce_benchmark.py index 77fe8e5dd..04b8367d6 100644 --- a/benchmark/pounce_benchmark.py +++ b/benchmark/pounce_benchmark.py @@ -35,7 +35,9 @@ from coreax.data import Data -def plot_selected_frames(umap_data, selected_indices, action_frames, solver_name): +def plot_selected_frames( + umap_data, selected_indices, action_frames, solver_name, out_dir +): """ Plot the selected frames and action frames on a bar chart. @@ -43,6 +45,7 @@ def plot_selected_frames(umap_data, selected_indices, action_frames, solver_name :param selected_indices: Indices of the selected frames. :param action_frames: Indices of the action frames. :param solver_name: The name of the solver used. + :param out_dir: The output directory for saved plots. """ x = np.arange(len(umap_data)) y = np.zeros(len(umap_data)) @@ -59,12 +62,13 @@ def plot_selected_frames(umap_data, selected_indices, action_frames, solver_name plt.title(f"Selected Frames for {solver_name}", fontsize=24, fontweight="bold") plt.legend() plt.tight_layout() - plt.show() + output_frames_path = out_dir / f"frames_{solver_name.replace(' ', '_')}.png" + plt.savefig(output_frames_path) def benchmark_coreset_algorithms( in_path: Path = Path("../examples/data/pounce/pounce.gif"), - out_dir: Path = Path("pounce"), + out_dir: Path = Path("../examples/benchmarking_images/pounce"), coreset_size: int = 10, ): """ @@ -87,7 +91,7 @@ def benchmark_coreset_algorithms( raw_data = np.asarray(image_data) reshaped_data = raw_data.reshape(raw_data.shape[0], -1) - umap_model = umap.UMAP(densmap=True, n_components=25) + umap_model = umap.UMAP(densmap=True, n_components=10, random_state=0) umap_data = jnp.asarray(umap_model.fit_transform(reshaped_data)) print("umap_data_shape", umap_data.shape) @@ -105,7 +109,7 @@ def benchmark_coreset_algorithms( selected_indices = np.sort(np.asarray(coreset.unweighted_indices)) coreset_frames = raw_data[selected_indices] - output_gif_path = out_dir / f"{solver_name}_coreset.gif" + output_gif_path = out_dir / f"{solver_name.replace(' ', '_')}_coreset.gif" imageio.v3.imwrite(output_gif_path, coreset_frames, loop=0) print(f"Saved {solver_name} coreset GIF to {output_gif_path}") print(f"time taken: {solver_name:<25} {duration:<30.4f}") @@ -115,6 +119,7 @@ def benchmark_coreset_algorithms( selected_indices=selected_indices, action_frames=np.arange(63, 85), solver_name=solver_name, + out_dir=out_dir, ) diff --git a/coreax/benchmark_util.py b/coreax/benchmark_util.py index 0041088db..395e26f74 100644 --- a/coreax/benchmark_util.py +++ b/coreax/benchmark_util.py @@ -23,13 +23,14 @@ from collections.abc import Callable from typing import Optional, TypeVar, Union +import jax import jax.numpy as jnp +import jax.scipy as jsp import numpy as np -from jaxtyping import Array, Float +from jaxtyping import Array, Float, Shaped from coreax import Coresubset, Data, SupervisedData from coreax.kernels import SquaredExponentialKernel, SteinKernel, median_heuristic -from coreax.score_matching import KernelDensityMatching from coreax.solvers import ( CompressPlusPlus, HerdingState, @@ -195,14 +196,30 @@ def _get_stein_solver(_size: int) -> Union[SteinThinning, MapReduce]: :return: A `SteinThinning` solver if `leaf_size` is `None`, otherwise a `MapReduce` solver with `SteinThinning` as the base solver. """ - # Generate small dataset for ScoreMatching for Stein Kernel + kde = jsp.stats.gaussian_kde(train_data_umap.data[idx].T) - score_function = KernelDensityMatching(length_scale=length_scale.item()).match( - train_data_umap[idx] + # Define the score function as the gradient of log density given by the KDE + def score_function( + x: Union[Shaped[Array, " n d"], Shaped[Array, ""], float, int], + ) -> Union[Shaped[Array, " n d"], Shaped[Array, " 1 1"]]: + """ + Compute the score function (gradient of log density) for a single point. + + :param x: Input point represented as array. + :return: Gradient of log probability density at the given point. + """ + + def logpdf_single(x: Shaped[Array, " d"]) -> Shaped[Array, ""]: + return kde.logpdf(x.reshape(1, -1))[0] + + return jax.grad(logpdf_single)(x) + + stein_kernel = SteinKernel( + base_kernel=kernel, + score_function=score_function, ) - stein_kernel = SteinKernel(kernel, score_function) stein_solver = SteinThinning( - coreset_size=_size, kernel=stein_kernel, regularise=False + coreset_size=_size, kernel=stein_kernel, regularise=True ) if leaf_size is None: return stein_solver diff --git a/coreax/solvers/coresubset.py b/coreax/solvers/coresubset.py index 85dcccb4d..408392223 100644 --- a/coreax/solvers/coresubset.py +++ b/coreax/solvers/coresubset.py @@ -443,9 +443,10 @@ class SteinThinning( If :data:`None`, defaults to :math:`1/\text{coreset_size}` following :cite:`benard2023kernel`. :param block_size: Block size passed to - :meth:`~coreax.kernels.ScalarValuedKernel.compute_mean` + :meth:`~coreax.kernels.ScalarValuedKernel.compute_mean`. :param unroll: Unroll parameter passed to - :meth:`~coreax.kernels.ScalarValuedKernel.compute_mean` + :meth:`~coreax.kernels.ScalarValuedKernel.compute_mean`. + :param kde_bw_method: Bandwidth method passed to `jax.scipy.stats.gaussian_kde`. """ kernel: ScalarValuedKernel @@ -455,6 +456,7 @@ class SteinThinning( regulariser_lambda: Optional[float] = None block_size: Optional[Union[int, tuple[Optional[int], Optional[int]]]] = None unroll: Union[int, bool, tuple[Union[int, bool], Union[int, bool]]] = 1 + kde_bw_method: Optional[Union[str, int, Callable]] = None @override def reduce( @@ -477,20 +479,21 @@ def refine( Only the score function, :math:`\nabla \log p(x)`, is provided to the solver. Since the lambda regularisation term relies on the density, :math:`p(x)`, directly, it is estimated using a Gaussian kernel density - estimator. + estimator using `jax.scipy.stats.gaussian_kde`. The bandwidth + method for this can passed as kde_bw_method when initialising + :class:`SteinThinning`. :param coresubset: The coresubset to refine :param solver_state: Solution state information, primarily used to cache expensive intermediate solution step values. - :return: a refined coresubset and relevant intermediate solver state information + :return: A refined coresubset and relevant intermediate solver state + information. """ x, w_x = jtu.tree_leaves(coresubset.pre_coreset_data) kernel = convert_stein_kernel(x, self.kernel, self.score_matching) stein_kernel_diagonal = jax.vmap(self.kernel.compute_elementwise)(x, x) if self.regularise: - # Cannot guarantee that kernel.base_kernel has a 'length_scale' attribute - bandwidth_method = getattr(kernel.base_kernel, "length_scale", None) - kde = jsp.stats.gaussian_kde(x.T, weights=w_x, bw_method=bandwidth_method) + kde = jsp.stats.gaussian_kde(x.T, weights=w_x, bw_method=self.kde_bw_method) if self.regulariser_lambda is None: # Use regularisation parameter suggested in :cite:`benard2023kernel` @@ -814,7 +817,7 @@ def refine( # noqa: PLR0915 requested ``coreset_size``, the extra indices will not be optimised and will be clipped from the return ``coresubset``. - :param coresubset: The coresubset to refine + :param coresubset: The coresubset to refine. :param solver_state: Solution state information, primarily used to cache expensive intermediate solution step values. :return: A refined coresubset and relevant intermediate solver state information diff --git a/documentation/source/benchmark.rst b/documentation/source/benchmark.rst index f2a797c07..8388acbec 100644 --- a/documentation/source/benchmark.rst +++ b/documentation/source/benchmark.rst @@ -5,9 +5,12 @@ In this benchmark, we assess the performance of different coreset algorithms: :class:`~coreax.solvers.KernelHerding`, :class:`~coreax.solvers.SteinThinning`, :class:`~coreax.solvers.RandomSample`, :class:`~coreax.solvers.RPCholesky`, :class:`~coreax.solvers.KernelThinning`, and :class:`~coreax.solvers.CompressPlusPlus`. -Each of these algorithms is evaluated -across four different tests, providing a comparison of their performance and -applicability to various datasets. +Each of these algorithms is evaluated across four different tests, providing a +comparison of their performance and applicability to various datasets. + +This benchmark only evaluates unsupervised coreset algorithms. Hence, the tasks +involve selecting a representative subset of data points without any prior labels +provided. Test 1: Benchmarking Coreset Algorithms on the MNIST Dataset ------------------------------------------------------------ @@ -58,23 +61,31 @@ The benchmarking test showed that the accuracy remained similar regardless of the coreset method used, with only small differences, which could potentially be attributed to the use of these regularisation techniques. - **Results**: The accuracy of the MLP classifier when trained using the full MNIST dataset (60,000 training images) was 97.31%, serving as a baseline for evaluating the performance of the coreset algorithms. - - .. image:: ../../examples/benchmarking_images/mnist_benchmark_accuracy.png - :alt: Benchmark Results for MNIST Coreset Algorithms + :alt: A bar chart showing the accuracy of nine different algorithms + (the 6 mentioned above and 3 variants of iterative probabilistic Kernel Herding) across + six coreset sizes (25, 50, 100, 500, 1000, and 5000). The chart displays increasing + accuracy for all algorithms as coreset size increases, with performance ranging from + about 0.45-0.55 at size 25 to about 0.9 at size 5000. There is + little difference in performance between the algorithms at the 1000-5000 size. No + algorithm can be said to outperform the random sample consistently and Stein + Thinning lags slightly behind all other algorithms across all sizes. **Figure 1**: Accuracy of coreset algorithms on the MNIST dataset. Bar heights represent the average accuracy. Error bars represent the min-max range for accuracy for each coreset size across 5 runs. .. image:: ../../examples/benchmarking_images/mnist_benchmark_time_taken.png - :alt: Time Taken Benchmark Results for MNIST Coreset Algorithms + :alt: A bar chart showing the time taken on the logarithmic scale. + Random sample shows negligible time to run across the sizes, while Stein Thinning + tends to take an order of magnitude more than other algorithms. Among others, + Compress++ tends to be consistently faster than other algorithms while RP + Cholesky is the fastest at small coreset sizes. **Figure 2**: Time taken to generate coreset for each coreset algorithm. Bar heights represent the average time taken. Error bars represent the min-max range for each @@ -117,66 +128,66 @@ For each metric and coreset size, the best performance score is highlighted in b - Time * - KernelHerding - 0.024273 - - 0.072547 + - 0.086342 - 0.008471 - - 0.072267 - - 4.600567 + - 0.074467 + - 4.765064 * - RandomSample - 0.111424 - - 0.077308 + - 0.088141 - 0.011224 - - 0.073833 - - **3.495483** + - 0.075859 + - **3.372750** * - RPCholesky - 0.140047 - - **0.059306** + - 0.073147 - 0.003688 - - **0.071969** - - 4.230014 + - **0.060939** + - 4.026443 * - SteinThinning - - 0.147962 - - 0.075813 - - 0.017571 - - 0.074239 - - 4.806702 + - 0.144938 + - 0.085247 + - 0.063385 + - 0.086622 + - 5.611508 * - KernelThinning - 0.014880 - - 0.072271 + - 0.075884 - 0.005388 - - 0.072463 - - 27.173368 + - 0.064494 + - 25.014126 * - CompressPlusPlus - 0.013212 - - 0.072479 + - 0.084045 - 0.007081 - - 0.072777 - - 17.304506 + - 0.081235 + - 16.713568 * - ProbabilisticIterativeHerding - 0.021128 - - 0.073220 + - 0.089382 - 0.007852 - - 0.073069 - - 4.669493 + - 0.080658 + - 4.702327 * - IterativeHerding - 0.007051 - - 0.072036 + - **0.068399** - 0.005125 - - 0.072206 - - 4.062584 + - 0.065863 + - 3.825249 * - CubicProbIterativeHerding - **0.004543** - - 0.072165 + - 0.081827 - **0.003512** - - 0.072366 - - 4.687458 + - 0.077990 + - 4.375146 .. list-table:: Coreset Size 50 (Original Sample Size 1,024) :header-rows: 1 @@ -190,66 +201,66 @@ For each metric and coreset size, the best performance score is highlighted in b - Time * - KernelHerding - 0.014011 - - 0.072273 + - 0.057618 - 0.003191 - - **0.072094** - - 4.139396 + - **0.052470** + - 4.036918 * - RandomSample - 0.104925 - - 0.078755 + - 0.079876 - 0.004955 - - 0.072600 - - **3.580714** + - 0.061597 + - **3.279080** * - RPCholesky - 0.146650 - - **0.056694** + - 0.064917 - 0.001539 - - 0.072209 - - 3.820043 + - 0.054541 + - 3.720830 * - SteinThinning - - 0.132586 - - 0.077087 - - 0.006761 - - 0.072635 - - 4.231215 + - 0.086824 + - 0.055094 + - 0.013564 + - 0.061475 + - 4.627325 * - KernelThinning - 0.006304 - - 0.072012 + - 0.061570 - 0.002246 - - 0.072222 - - 15.216022 + - 0.058513 + - 14.038467 * - CompressPlusPlus - 0.007616 - - 0.072154 + - 0.063311 - 0.002819 - - 0.072249 - - 11.209934 + - 0.056713 + - 10.396490 * - ProbabilisticIterativeHerding - 0.015108 - - 0.073478 + - 0.068838 - 0.003151 - - 0.072501 - - 4.343780 + - 0.063005 + - 4.108718 * - IterativeHerding - 0.003708 - - 0.072123 + - **0.052616** - 0.002604 - - 0.072199 - - 3.681021 + - 0.053199 + - 3.577140 * - CubicProbIterativeHerding - **0.001733** - - 0.072226 + - 0.058076 - **0.001442** - - 0.072296 - - 4.199541 + - 0.059921 + - 4.120308 .. list-table:: Coreset Size 100 (Original Sample Size 1,024) :header-rows: 1 @@ -263,66 +274,66 @@ For each metric and coreset size, the best performance score is highlighted in b - Time * - KernelHerding - 0.007909 - - 0.071763 + - 0.046639 - 0.001859 - - **0.072205** - - 4.313880 + - 0.051218 + - 4.235977 * - RandomSample - 0.055019 - - 0.075205 + - 0.061831 - 0.001804 - - 0.072270 - - **3.731109** + - 0.057107 + - **3.158193** * - RPCholesky - 0.097647 - - **0.062210** + - 0.039633 - 0.001044 - - 0.072251 - - 4.349850 + - 0.055332 + - 3.850249 * - SteinThinning - - 0.137844 - - 0.081297 - - 0.004691 - - 0.072308 - - 4.689983 + - 0.093073 + - **0.035877** + - 0.006268 + - 0.055652 + - 4.740899 * - KernelThinning - 0.002685 - - 0.072069 + - 0.056104 - 0.001265 - - 0.072263 - - 10.102306 + - 0.058189 + - 9.000171 * - CompressPlusPlus - 0.002936 - - 0.072196 + - 0.055740 - 0.001226 - - 0.072285 - - 9.244769 + - 0.055948 + - 8.099011 * - ProbabilisticIterativeHerding - 0.009710 - - 0.072786 + - 0.062317 - 0.001838 - - 0.072367 - - 4.425218 + - 0.059106 + - 4.518486 * - IterativeHerding - 0.002256 - - 0.072129 + - 0.048805 - 0.001407 - - 0.072255 - - 4.298705 + - **0.051166** + - 4.135961 * - CubicProbIterativeHerding - **0.000805** - - 0.072214 + - 0.051934 - **0.000979** - - 0.072259 - - 4.685692 + - 0.054329 + - 4.499996 .. list-table:: Coreset Size 200 (Original Sample Size 1,024) :header-rows: 1 @@ -336,97 +347,97 @@ For each metric and coreset size, the best performance score is highlighted in b - Time * - KernelHerding - 0.004259 - - 0.072017 + - 0.047415 - 0.001173 - - 0.072242 - - 4.809446 + - 0.054883 + - 4.568870 * - RandomSample - 0.041521 - - 0.072316 + - 0.057967 - 0.000914 - - 0.072260 - - **3.744803** + - 0.055495 + - **3.401281** * - RPCholesky - 0.056923 - - **0.067187** + - 0.042466 - 0.000830 - - 0.072248 - - 4.360848 + - **0.053957** + - 4.136736 * - SteinThinning - - 0.144544 - - 0.085564 - - 0.002836 - - **0.072155** - - 4.833500 + - 0.104213 + - **0.024422** + - 0.003508 + - 0.055823 + - 5.040177 * - KernelThinning - 0.001518 - - 0.072137 + - 0.054005 - 0.000886 - - 0.072265 - - 6.940934 + - 0.057455 + - 6.787894 * - CompressPlusPlus - 0.001410 - - 0.072154 + - 0.053179 - 0.000755 - - 0.072240 - - 7.291234 + - 0.054638 + - 7.406790 * - ProbabilisticIterativeHerding - 0.006358 - - 0.072700 + - 0.058343 - 0.000873 - - 0.072272 - - 4.814415 + - 0.057020 + - 4.711837 * - IterativeHerding - 0.001382 - - 0.072160 + - 0.050098 - 0.000995 - - 0.072241 - - 4.238990 + - 0.054194 + - 4.150570 * - CubicProbIterativeHerding - **0.000582** - - 0.072205 + - 0.052761 - **0.000706** - - 0.072260 - - 4.936301 + - 0.056212 + - 4.702852 **Visualisation**: The results in this table can be visualised as follows: .. image:: ../../examples/benchmarking_images/blobs_unweighted_mmd.png :alt: Line graph visualising the data tables above, plotting unweighted MMD against - coreset size for each of the coreset methods + coreset size for each of the coreset methods. **Figure 3**: Unweighted MMD plotted against coreset size for each coreset method. .. image:: ../../examples/benchmarking_images/blobs_unweighted_ksd.png :alt: Line graph visualising the data tables above, plotting unweighted KSD against - coreset size for each of the coreset methods + coreset size for each of the coreset methods. **Figure 4**: Unweighted KSD plotted against coreset size for each coreset method. .. image:: ../../examples/benchmarking_images/blobs_weighted_mmd.png :alt: Line graph visualising the data tables above, plotting weighted MMD against - coreset size for each of the coreset methods + coreset size for each of the coreset methods. **Figure 5**: Weighted MMD plotted against coreset size for each coreset method. .. image:: ../../examples/benchmarking_images/blobs_weighted_ksd.png :alt: Line graph visualising the data tables above, plotting weighted KSD against - coreset size for each of the coreset methods + coreset size for each of the coreset methods. **Figure 6**: Weighted KSD plotted against coreset size for each coreset method. - .. image:: ../../examples/benchmarking_images/blobs_Time.png + .. image:: ../../examples/benchmarking_images/blobs_time.png :alt: Line graph visualising the data tables above, plotting time taken against - coreset size for each of the coreset methods + coreset size for each of the coreset methods. **Figure 7**: Time taken plotted against coreset size for each coreset method. @@ -449,7 +460,10 @@ from an input image. The process follows these steps: **Results**: The following plot visualises the pixels chosen by each coreset algorithm. .. image:: ../../examples/benchmarking_images/david_benchmark_results.png - :alt: Plot showing pixels chosen from an image by each coreset algorithm + :alt: Plot showing pixels chosen from an image by each coreset algorithm. All + algorithms tend to perform similarly, resulting in a blurred version of the + original image. The only exception is Stein Thinning, which reconstructs only a + few features of the image. **Figure 8**: The original image and pixels selected by each coreset algorithm plotted side-by-side for visual comparison. @@ -523,28 +537,29 @@ The following plots show the frames chosen by each coreset algorithm with action in orange. .. image:: ../../examples/benchmarking_images/pounce/frames_Random_Sample.png - :alt: Plot showing the frames selected by Random Sample + :alt: Plot shows Random Sample selecting 3 action frames. .. image:: ../../examples/benchmarking_images/pounce/frames_RP_Cholesky.png - :alt: Plot showing the frames selected by RP Cholesky + :alt: Plot shows RP Cholesky selecting 3 action frames. .. image:: ../../examples/benchmarking_images/pounce/frames_Stein_Thinning.png - :alt: Plot showing the frames selected by Stein Thinning + :alt: Plot shows Stein Thinning selecting 5 action frames. .. image:: ../../examples/benchmarking_images/pounce/frames_Kernel_Herding.png - :alt: Plot showing the frames selected by Kernel Herding + :alt: Plot shows Kernel Herding selecting 1 action frame. .. image:: ../../examples/benchmarking_images/pounce/frames_Kernel_Thinning.png - :alt: Plot showing the frames selected by Kernel Thinning + :alt: Plot shows Kernel Thinning selecting 2 action frames. .. image:: ../../examples/benchmarking_images/pounce/frames_Compress++.png - :alt: Plot showing the frames selected by Compress++ + :alt: Plot shows Compress++ selecting 2 action frames. .. image:: ../../examples/benchmarking_images/pounce/frames_Iterative_Probabilistic_Herding_(constant).png - :alt: Plot showing the frames selected by Probabilistic Iterative Kernel Herding + :alt: Plot shows Probabilistic Iterative Kernel Herding selecting 2 action frames. .. image:: ../../examples/benchmarking_images/pounce/frames_Iterative_Probabilistic_Herding_(cubic).png - :alt: Plot showing the frames selected by Probabilistic Iterative Kernel Herding with a decaying temperature parameter + :alt: Plot shows Probabilistic Iterative Kernel Herding with a decaying temperature + parameter selecting 2 action frames. Conclusion ---------- diff --git a/examples/benchmarking_images/blobs_Time.png b/examples/benchmarking_images/blobs_Time.png deleted file mode 100644 index b94369536..000000000 --- a/examples/benchmarking_images/blobs_Time.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d586b7f62480499050c0b88546d0bda00bb1582bb86fc7fc5c0879c7d5fc4352 -size 125982 diff --git a/examples/benchmarking_images/blobs_time.png b/examples/benchmarking_images/blobs_time.png new file mode 100644 index 000000000..f02bbfb0f --- /dev/null +++ b/examples/benchmarking_images/blobs_time.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29b7a6efddb7770747ebc5a0e39fcc37c21b35f91bb345a0e03505695f3f8357 +size 123687 diff --git a/examples/benchmarking_images/blobs_time_taken.png b/examples/benchmarking_images/blobs_time_taken.png deleted file mode 100644 index 5ef1a51f3..000000000 --- a/examples/benchmarking_images/blobs_time_taken.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0103f52cc6d9e02de5f9b402e57257e0456cb5411d07825630af844cc6ee94ba -size 115777 diff --git a/examples/benchmarking_images/blobs_unweighted_ksd.png b/examples/benchmarking_images/blobs_unweighted_ksd.png index 0b7353c0c..9ab9af696 100644 --- a/examples/benchmarking_images/blobs_unweighted_ksd.png +++ b/examples/benchmarking_images/blobs_unweighted_ksd.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:64443b574edb3876b485132e2608b474bd35ca238fcea6dafb238f07ade7b3f4 -size 116863 +oid sha256:61dde49cdc23e1e2a9d0e752ef3dcde6e2ed1c78a850d57f8dd8f13e1db0bc36 +size 160110 diff --git a/examples/benchmarking_images/blobs_unweighted_mmd.png b/examples/benchmarking_images/blobs_unweighted_mmd.png index d91322d89..19020294d 100644 --- a/examples/benchmarking_images/blobs_unweighted_mmd.png +++ b/examples/benchmarking_images/blobs_unweighted_mmd.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:6272a5508bb48f16295b4bc50951ed132a9b69e947be1cc6ae00c2c0b7094f40 -size 154623 +oid sha256:1bc52eb8c47785225ec6a1e928ba8306839a8b0e508f724502afd3da0d89d979 +size 154976 diff --git a/examples/benchmarking_images/blobs_weighted_ksd.png b/examples/benchmarking_images/blobs_weighted_ksd.png index 357c8615b..daadf56eb 100644 --- a/examples/benchmarking_images/blobs_weighted_ksd.png +++ b/examples/benchmarking_images/blobs_weighted_ksd.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a55b4a9c8032480c51cf753c1118ce05fff3fb51af0bd5bc0039401d82509fe0 -size 133190 +oid sha256:2dafecd60222cba204ac76347ac20fe53476477babc9149d125fb5af95f09adc +size 163906 diff --git a/examples/benchmarking_images/blobs_weighted_mmd.png b/examples/benchmarking_images/blobs_weighted_mmd.png index 6bd09e12a..99388d702 100644 --- a/examples/benchmarking_images/blobs_weighted_mmd.png +++ b/examples/benchmarking_images/blobs_weighted_mmd.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:021fea28fbd1f54f33379ebb6d5fecf971418a65fcb91992b538ec3775f4b257 -size 182407 +oid sha256:754ae8f70e45e1783e4fca09ee5ccb2577117d0b6e63ab027c265d119a197cab +size 162225 diff --git a/examples/benchmarking_images/david_benchmark_results.png b/examples/benchmarking_images/david_benchmark_results.png index a9f1af388..de1232eb4 100644 --- a/examples/benchmarking_images/david_benchmark_results.png +++ b/examples/benchmarking_images/david_benchmark_results.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c2b008415af23f3d56ae94f34e6e23646aa614886c157e7638a6d891cac06009 -size 768408 +oid sha256:6950c9d7ae31f39b297f2b705c1f4222a1664c6ca8b3bccc61f6ddfa4315cba9 +size 365289 diff --git a/examples/benchmarking_images/mnist_benchmark_accuracy.png b/examples/benchmarking_images/mnist_benchmark_accuracy.png index 745eef54d..5b0a0b0ef 100644 --- a/examples/benchmarking_images/mnist_benchmark_accuracy.png +++ b/examples/benchmarking_images/mnist_benchmark_accuracy.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:07c3f4ec5294b5fda6282f0ba2c96edc898a6c6fae688db1e606350d8962b5be -size 149345 +oid sha256:bf30781b651f677d2d897f1ab670fade7fb39ba9f80fae9e8cf0ad3531c82bac +size 147804 diff --git a/examples/benchmarking_images/mnist_benchmark_time_taken.png b/examples/benchmarking_images/mnist_benchmark_time_taken.png index 4c946a7d4..2b3f8045c 100644 --- a/examples/benchmarking_images/mnist_benchmark_time_taken.png +++ b/examples/benchmarking_images/mnist_benchmark_time_taken.png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:62804eeaec8d2aa350a2b12176ee571967f817ffa88fa426948542ba547bcb59 -size 112821 +oid sha256:d0f1155843d609140beef7558ed6902e3f719f34e1dd9330aa79d23707804957 +size 110784 diff --git a/examples/benchmarking_images/pounce/Compress++_coreset.gif b/examples/benchmarking_images/pounce/Compress++_coreset.gif index 730d1982a..db779774e 100644 Binary files a/examples/benchmarking_images/pounce/Compress++_coreset.gif and b/examples/benchmarking_images/pounce/Compress++_coreset.gif differ diff --git a/examples/benchmarking_images/pounce/Iterative_Herding_coreset.gif b/examples/benchmarking_images/pounce/Iterative_Herding_coreset.gif index b75f84fcf..1fbbbccb1 100644 Binary files a/examples/benchmarking_images/pounce/Iterative_Herding_coreset.gif and b/examples/benchmarking_images/pounce/Iterative_Herding_coreset.gif differ diff --git a/examples/benchmarking_images/pounce/Iterative_Probabilistic_Herding_(constant)_coreset.gif b/examples/benchmarking_images/pounce/Iterative_Probabilistic_Herding_(constant)_coreset.gif index 35c5c0bf6..8ce6d7199 100644 Binary files a/examples/benchmarking_images/pounce/Iterative_Probabilistic_Herding_(constant)_coreset.gif and b/examples/benchmarking_images/pounce/Iterative_Probabilistic_Herding_(constant)_coreset.gif differ diff --git a/examples/benchmarking_images/pounce/Iterative_Probabilistic_Herding_(cubic)_coreset.gif b/examples/benchmarking_images/pounce/Iterative_Probabilistic_Herding_(cubic)_coreset.gif index 7c3feb3d2..c1bc4d682 100644 Binary files a/examples/benchmarking_images/pounce/Iterative_Probabilistic_Herding_(cubic)_coreset.gif and b/examples/benchmarking_images/pounce/Iterative_Probabilistic_Herding_(cubic)_coreset.gif differ diff --git a/examples/benchmarking_images/pounce/Kernel_Herding_coreset.gif b/examples/benchmarking_images/pounce/Kernel_Herding_coreset.gif index 631040a53..562bd4661 100644 Binary files a/examples/benchmarking_images/pounce/Kernel_Herding_coreset.gif and b/examples/benchmarking_images/pounce/Kernel_Herding_coreset.gif differ diff --git a/examples/benchmarking_images/pounce/Kernel_Thinning_coreset.gif b/examples/benchmarking_images/pounce/Kernel_Thinning_coreset.gif index cb47451ed..88584a1b9 100644 Binary files a/examples/benchmarking_images/pounce/Kernel_Thinning_coreset.gif and b/examples/benchmarking_images/pounce/Kernel_Thinning_coreset.gif differ diff --git a/examples/benchmarking_images/pounce/RP_Cholesky_coreset.gif b/examples/benchmarking_images/pounce/RP_Cholesky_coreset.gif index 1bfdfc28d..ce7bd6cf4 100644 Binary files a/examples/benchmarking_images/pounce/RP_Cholesky_coreset.gif and b/examples/benchmarking_images/pounce/RP_Cholesky_coreset.gif differ diff --git a/examples/benchmarking_images/pounce/Random_Sample_coreset.gif b/examples/benchmarking_images/pounce/Random_Sample_coreset.gif index bdb507ca4..3b47c1e7e 100644 Binary files a/examples/benchmarking_images/pounce/Random_Sample_coreset.gif and b/examples/benchmarking_images/pounce/Random_Sample_coreset.gif differ diff --git a/examples/benchmarking_images/pounce/Stein_Thinning_coreset.gif b/examples/benchmarking_images/pounce/Stein_Thinning_coreset.gif index e482f7502..7b89ec72a 100644 Binary files a/examples/benchmarking_images/pounce/Stein_Thinning_coreset.gif and b/examples/benchmarking_images/pounce/Stein_Thinning_coreset.gif differ diff --git a/examples/benchmarking_images/pounce/frames_Compress++.png b/examples/benchmarking_images/pounce/frames_Compress++.png index ad9255ac2..97496f500 100644 Binary files a/examples/benchmarking_images/pounce/frames_Compress++.png and b/examples/benchmarking_images/pounce/frames_Compress++.png differ diff --git a/examples/benchmarking_images/pounce/frames_Iterative_Herding.png b/examples/benchmarking_images/pounce/frames_Iterative_Herding.png index 69f66bcf2..59b717a71 100644 Binary files a/examples/benchmarking_images/pounce/frames_Iterative_Herding.png and b/examples/benchmarking_images/pounce/frames_Iterative_Herding.png differ diff --git a/examples/benchmarking_images/pounce/frames_Iterative_Probabilistic_Herding_(constant).png b/examples/benchmarking_images/pounce/frames_Iterative_Probabilistic_Herding_(constant).png index 5ad0463d7..2480ad883 100644 Binary files a/examples/benchmarking_images/pounce/frames_Iterative_Probabilistic_Herding_(constant).png and b/examples/benchmarking_images/pounce/frames_Iterative_Probabilistic_Herding_(constant).png differ diff --git a/examples/benchmarking_images/pounce/frames_Iterative_Probabilistic_Herding_(cubic).png b/examples/benchmarking_images/pounce/frames_Iterative_Probabilistic_Herding_(cubic).png index 7a90c9e6c..16afa6d65 100644 Binary files a/examples/benchmarking_images/pounce/frames_Iterative_Probabilistic_Herding_(cubic).png and b/examples/benchmarking_images/pounce/frames_Iterative_Probabilistic_Herding_(cubic).png differ diff --git a/examples/benchmarking_images/pounce/frames_Kernel_Herding.png b/examples/benchmarking_images/pounce/frames_Kernel_Herding.png index cf69f333d..4f58e8e01 100644 Binary files a/examples/benchmarking_images/pounce/frames_Kernel_Herding.png and b/examples/benchmarking_images/pounce/frames_Kernel_Herding.png differ diff --git a/examples/benchmarking_images/pounce/frames_Kernel_Thinning.png b/examples/benchmarking_images/pounce/frames_Kernel_Thinning.png index 0a2dd3479..f57c362ae 100644 Binary files a/examples/benchmarking_images/pounce/frames_Kernel_Thinning.png and b/examples/benchmarking_images/pounce/frames_Kernel_Thinning.png differ diff --git a/examples/benchmarking_images/pounce/frames_RP_Cholesky.png b/examples/benchmarking_images/pounce/frames_RP_Cholesky.png index dea885166..0b04779e4 100644 Binary files a/examples/benchmarking_images/pounce/frames_RP_Cholesky.png and b/examples/benchmarking_images/pounce/frames_RP_Cholesky.png differ diff --git a/examples/benchmarking_images/pounce/frames_Random_Sample.png b/examples/benchmarking_images/pounce/frames_Random_Sample.png index 08ec87407..1c21af8c9 100644 Binary files a/examples/benchmarking_images/pounce/frames_Random_Sample.png and b/examples/benchmarking_images/pounce/frames_Random_Sample.png differ diff --git a/examples/benchmarking_images/pounce/frames_Stein_Thinning.png b/examples/benchmarking_images/pounce/frames_Stein_Thinning.png index 1f75830bf..72c8c0f13 100644 Binary files a/examples/benchmarking_images/pounce/frames_Stein_Thinning.png and b/examples/benchmarking_images/pounce/frames_Stein_Thinning.png differ diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py index 89e411906..192b679b5 100644 --- a/tests/unit/test_benchmark.py +++ b/tests/unit/test_benchmark.py @@ -38,6 +38,7 @@ calculate_delta, initialise_solvers, ) +from coreax.coreset import Coresubset from coreax.solvers import ( CompressPlusPlus, KernelHerding, @@ -215,6 +216,11 @@ def test_solver_instances() -> None: solver_instance = solver_function(1) assert isinstance(solver_instance, expected_solver_types_with_leaf[solver_name]) + # For SteinThinning, run reduce to make sure the score function works + stein_solver = solvers_no_leaf["Stein Thinning"](1) + coreset, _ = stein_solver.reduce(mock_data) + assert isinstance(coreset, Coresubset) + @pytest.mark.parametrize("n", [1, 2, 100]) def test_calculate_delta(n):