# Time benchmark KDE

In [None]:
from abc import ABC, abstractmethod
from dataclasses import dataclass

import numpy as np
import pandas as pd

import time
from itertools import cycle

from scipy.stats import norm
from scipy.stats import multivariate_normal

from scipy.stats import gaussian_kde
from sklearn.neighbors import KernelDensity
from statsmodels.nonparametric.kde import KDEUnivariate
from statsmodels.nonparametric.kernel_density import KDEMultivariate
from KDEpy.FFTKDE import FFTKDE
from xentropy.kde import Kde

import parallelkdepy as pkde

from tqdm.notebook import trange, tqdm

import matplotlib
import matplotlib.pyplot as plt

In [None]:
# Publication-style Matplotlib defaults applied to every figure created below.
plt.rcParams.update({
    # Figure
    "figure.figsize": (3.5, 2.5),          # Default size in inches (~1-column width in most journals)
    "figure.dpi": 150,                     # Figure resolution in dots per inch
    "figure.autolayout": True,             # Apply tight layout automatically to avoid clipped labels

    # Fonts
    "font.size": 8,                         # 8-10 pt works well for most journals
    "font.family": "sans-serif",                 # Sans-serif text; rendered through LaTeX because text.usetex is enabled below
    "mathtext.fontset": "cm",               # Computer Modern for mathtext (presumably unused while text.usetex is on)

    # Axes
    "axes.linewidth": 0.8,
    "axes.labelsize": 8,
    "axes.titlesize": 8,
    "axes.labelpad": 2,

    # Ticks (drawn inward on all four sides by default)
    "xtick.direction": "in",
    "ytick.direction": "in",
    "xtick.top": True,
    "ytick.right": True,
    "xtick.major.size": 3,
    "xtick.minor.size": 1.5,
    "ytick.major.size": 3,
    "ytick.minor.size": 1.5,
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,

    # Lines
    "lines.linewidth": 1.0,
    "lines.markersize": 4,

    # Legend
    "legend.fontsize": 7,
    "legend.frameon": False,

    # LaTeX (requires a working LaTeX installation)
    "text.usetex": True
})

## Benchmark

In [None]:
def get_distro(n_dims):
    """Return the frozen reference distribution used to generate benchmark data.

    Parameters
    ----------
    n_dims : int
        Dimensionality of the distribution; only 1 and 2 are supported.

    Returns
    -------
    A frozen ``scipy.stats.norm`` (zero mean, unit scale) for 1D, or a frozen
    ``scipy.stats.multivariate_normal`` (zero mean, identity covariance) for 2D.

    Raises
    ------
    ValueError
        If ``n_dims`` is neither 1 nor 2.
    """
    if n_dims not in (1, 2):
        raise ValueError("Only 1D and 2D distro allowed")

    if n_dims == 2:
        return multivariate_normal(mean=np.full(2, 0.0), cov=np.eye(2))
    return norm(loc=0.0, scale=1.0)

In [None]:
def sample_distro(distro, n_samples):
    """Draw ``n_samples`` random variates from ``distro``.

    ``distro`` only needs to expose ``rvs(size=...)``; whatever that call
    returns is handed back unchanged.
    """
    draws = distro.rvs(size=n_samples)
    return draws

In [None]:
def create_grid(n_points, dim, lb=-8, hb=8, device="cpu"):
    """Build a ``pkde.Grid`` that repeats one axis specification over ``dim`` axes.

    Every axis is described by the same ``(lb, hb, n_points)`` tuple, which is
    forwarded to ``pkde.Grid`` together with ``device``.
    """
    axis_spec = (lb, hb, n_points)
    return pkde.Grid(dim * [axis_spec], device=device)

In [None]:
@dataclass
class EstimatorInfo:
    """Identifying metadata of one benchmarked estimator configuration."""

    # Short name of the library that provides the estimator (e.g. "scipy").
    package: str
    # Bandwidth / estimation method string forwarded to that library.
    method: str
    # Dimensionality of the data the estimator is configured for.
    dim: int
    # Device the estimator runs on; defaults to "cpu".
    device: str = "cpu"

In [None]:
class DensityEstimator(ABC):
    """Common interface of all benchmarked density estimators.

    Subclasses wrap one library and implement ``fit`` (store / preprocess the
    samples) and ``evaluate`` (return the density on a grid).
    """

    def __init__(self, info: EstimatorInfo):
        # Concrete implementations set this flag to True at the end of fit().
        self._is_fit = False
        self._info = info

    @property
    def info(self):
        """Metadata describing this estimator configuration."""
        return self._info

    @abstractmethod
    def fit(self, x: np.ndarray):
        """Fit the estimator to the samples ``x``."""

    @abstractmethod
    def evaluate(self, grid) -> np.ndarray:
        """Return the estimated density evaluated on ``grid``."""

    def name(self) -> str:
        """Return the identifier ``package:method:device:<dim>D``."""
        info = self._info
        return f"{info.package}:{info.method}:{info.device}:{info.dim}D"

In [None]:
class ScipyKDE(DensityEstimator):
    """Adapter around ``scipy.stats.gaussian_kde``."""

    def __init__(self, dim: int, bw_method: str):
        info = EstimatorInfo("scipy", bw_method, dim)
        super().__init__(info)

    def fit(self, data: np.ndarray, **kwargs):
        """Build the ``gaussian_kde`` object from ``data`` and return ``self``.

        ``data`` is either a 1-D array (only accepted by 1D estimators) or a
        2-D array whose second axis must match the configured dimension.
        Extra keyword arguments are accepted but ignored.
        """
        if data.ndim <= 1:
            assert self.info.dim == 1
        else:
            assert data.shape[1] == self.info.dim

        self._data = data
        # gaussian_kde receives the transposed samples.
        self._kde = gaussian_kde(data.T, bw_method=self.info.method)
        self._is_fit = True

        return self

    def evaluate(self, grid: pkde.Grid) -> np.ndarray:
        """Evaluate the fitted KDE on ``grid``; the result has ``grid.shape``."""
        axes = grid.to_meshgrid()
        if len(axes) > 1:
            # One row per axis, one column per flattened grid point.
            points = np.vstack([axis.ravel() for axis in axes])
        else:
            points = axes[0]

        return self._kde(points).reshape(grid.shape)

In [None]:
class SklearnKDE(DensityEstimator):
    """Adapter around ``sklearn.neighbors.KernelDensity``."""

    def __init__(self, dim: int, bw_method: str):
        info = EstimatorInfo("sklearn", bw_method, dim)
        super().__init__(info)

    def fit(self, data: np.ndarray):
        """Fit ``KernelDensity`` on ``data`` and return ``self``.

        1-D input is only accepted by 1D estimators and is promoted to a
        single-column 2-D array before fitting.
        """
        if data.ndim <= 1:
            assert self.info.dim == 1
            self._data = data[:, np.newaxis]
        else:
            assert data.shape[1] == self.info.dim
            self._data = data

        estimator = KernelDensity(bandwidth=self.info.method)
        self._kde = estimator.fit(self._data)
        self._is_fit = True

        return self

    def evaluate(self, grid: pkde.Grid) -> np.ndarray:
        """Evaluate the fitted KDE on ``grid``; the result has ``grid.shape``."""
        axes = grid.to_meshgrid()
        if len(axes) > 1:
            # One row per flattened grid point, one column per axis.
            points = np.vstack([axis.ravel() for axis in axes]).T
        else:
            points = axes[0][:, np.newaxis]

        # score_samples output is exponentiated to obtain the density.
        log_density = self._kde.score_samples(points)
        return np.exp(log_density).reshape(grid.shape)

In [None]:
class StatsmodelsKDE1D(DensityEstimator):
    """Adapter around ``statsmodels`` ``KDEUnivariate`` (1D only)."""

    def __init__(self, bw_method: str):
        info = EstimatorInfo("statsmodels", bw_method, 1)
        super().__init__(info)

    def fit(self, data: np.ndarray):
        """Wrap ``data`` in a ``KDEUnivariate`` object and return ``self``.

        The statsmodels fit itself is deferred to ``evaluate`` because it needs
        the grid size.
        """
        if data.ndim <= 1:
            assert self.info.dim == 1
        else:
            assert data.shape[1] == self.info.dim

        self._data = data
        self._kde = KDEUnivariate(data)
        self._is_fit = True

        return self

    def evaluate(self, grid: pkde.Grid) -> np.ndarray:
        """Fit with FFT on ``grid.shape[0]`` points, then evaluate on the grid axis."""
        n_grid = grid.shape[0]
        self._kde.fit(bw=self.info.method, fft=True, gridsize=n_grid)

        axis_points = grid.to_meshgrid()[0]
        return self._kde.evaluate(axis_points)

In [None]:
class StatsmodelsKDE2D(DensityEstimator):
    """Adapter around ``statsmodels`` ``KDEMultivariate`` for 2D continuous data."""

    def __init__(self, bw_method: str):
        super().__init__(EstimatorInfo("statsmodels", bw_method, 2))

    def fit(self, data: np.ndarray):
        """Fit ``KDEMultivariate`` on ``data`` and return ``self``.

        ``data`` must be a 2-D array whose second axis matches the configured
        dimension; both variables are declared continuous (``var_type="cc"``).
        """
        if data.ndim > 1:
            assert data.shape[1] == self.info.dim
        else:
            assert self.info.dim == 1
        self._data = data
        self._kde = KDEMultivariate(data=data, var_type="cc", bw=self.info.method)
        self._is_fit = True

        # Return self so that fit() can be chained, like most sibling estimators.
        return self

    def evaluate(self, grid: pkde.Grid) -> np.ndarray:
        """Evaluate the fitted KDE on ``grid``; the result has ``grid.shape``."""
        grid_mesh = grid.to_meshgrid()
        # Pass the points explicitly as (n_points, n_vars) instead of relying
        # on statsmodels guessing the orientation of a (n_vars, n_points) array.
        grid_points = np.column_stack([x.ravel() for x in grid_mesh])

        density_estimated = self._kde.pdf(grid_points).reshape(grid.shape)

        return density_estimated

In [None]:
class KDEpyKDE(DensityEstimator):
    """Adapter around ``KDEpy`` ``FFTKDE``."""

    def __init__(self, dim: int, bw_method: str):
        super().__init__(EstimatorInfo("kdepy", bw_method, dim))

    def fit(self, data: np.ndarray):
        """Fit ``FFTKDE`` on ``data`` and return ``self``.

        ``data`` is either a 1-D array (only accepted by 1D estimators) or a
        2-D array whose second axis must match the configured dimension.
        """
        if data.ndim > 1:
            assert data.shape[1] == self.info.dim
        else:
            assert self.info.dim == 1
        self._data = data
        self._kde = FFTKDE(bw=self.info.method).fit(data)
        self._is_fit = True

        return self

    def evaluate(self, grid: pkde.Grid) -> np.ndarray:
        """Evaluate the fitted KDE on ``grid``; the result has ``grid.shape``."""
        grid_mesh = grid.to_meshgrid()
        if len(grid_mesh) > 1:
            # KDEpy expects grid points as (n_points, n_dims); the previous
            # (n_dims, n_points) stacking would be read as n_dims observations.
            # NOTE(review): FFTKDE also requires a specific point ordering;
            # confirm that pkde's meshgrid flattening matches it.
            grid_points = np.column_stack([x.ravel() for x in grid_mesh])
        else:
            grid_points = grid_mesh[0]

        density_estimated = self._kde.evaluate(grid_points).reshape(grid.shape)

        return density_estimated

In [None]:
class ParallelKDE(DensityEstimator):
    """Adapter around ``parallelkdepy`` ``DensityEstimation``."""

    def __init__(self, dim:int, bw_method: str, device: str):
        info = EstimatorInfo("parallelkdepy", bw_method, dim, device)
        super().__init__(info)

    def fit(self, data: np.ndarray):
        """Store ``data`` (as a 2-D array) and return ``self``.

        No estimation happens here; the estimator object is created in
        ``evaluate`` because it needs the grid.
        """
        if data.ndim <= 1:
            assert self.info.dim == 1
            self._data = data[:, np.newaxis]
        else:
            assert data.shape[1] == self.info.dim
            self._data = data
        self._is_fit = True

        return self

    def evaluate(self, grid: pkde.Grid) -> np.ndarray:
        """Create the estimator for ``grid``, run the configured method, return the density."""
        estimator = pkde.DensityEstimation(self._data, grid=grid, device=self.info.device)
        self._kde = estimator
        estimator.estimate_density(self.info.method)

        return estimator.get_density()

In [None]:
class XentropyKDE(DensityEstimator):
    """Adapter around ``xentropy.kde.Kde`` (method label ``"botev"``)."""

    def __init__(self, dim:int):
        super().__init__(EstimatorInfo("xentropy", "botev", dim))

    def fit(self, data: np.ndarray):
        """Store ``data`` and return ``self``.

        The ``Kde`` object is only built in ``evaluate`` because it needs the
        grid resolution.
        """
        if data.ndim > 1:
            assert data.shape[1] == self.info.dim
        else:
            assert self.info.dim == 1
        self._data = data
        self._is_fit = True

        # Return self so that fit() can be chained, like most sibling estimators.
        return self

    def evaluate(self, grid: pkde.Grid) -> np.ndarray:
        """Build the ``Kde`` with ``grid.shape[0]`` points and return its ``pdf``.

        NOTE(review): only the resolution of ``grid`` is forwarded, not its
        bounds, so the density is presumably evaluated on xentropy's own
        support rather than on ``grid`` - confirm against the xentropy API.
        """
        self._kde = Kde(data=self._data, resolution=grid.shape[0])
        density_estimated = self._kde.pdf

        return density_estimated

In [None]:
@dataclass
class GridSpec:
    """Evaluation-grid configuration for one benchmark case."""

    # Dimensionality of the grid; only datasets with the same dim are paired with it.
    dim: int
    # Number of grid points per axis (passed to create_grid as n_points).
    n_points: int

In [None]:
@dataclass
class DataSpec:
    """Dataset configuration for one benchmark case."""

    # Dimensionality of the sampled distribution.
    dim: int
    # Number of samples drawn from that distribution.
    n_samples: int

In [None]:
@dataclass
class RunSpec:
    """Repetition settings used by run_benchmark."""

    # Untimed fit+evaluate rounds executed before timing starts.
    n_warmup: int = 1
    # Timed fit+evaluate rounds; each one yields a result row.
    n_runs: int = 10

In [None]:
def time_block(fn):
    """Call ``fn()`` and measure its wall-clock duration.

    Returns
    -------
    tuple
        ``(result, elapsed)`` where ``result`` is the return value of ``fn``
        and ``elapsed`` the duration in seconds from ``time.perf_counter``.
    """
    started = time.perf_counter()
    result = fn()
    elapsed = time.perf_counter() - started

    return result, elapsed

In [None]:
def available_estimators():
    """Instantiate every estimator configuration included in the benchmark.

    The returned list is ordered by package: SciPy, scikit-learn, statsmodels,
    KDEpy, ParallelKDE and finally X-Entropy.
    """
    both_dims = (1, 2)
    estimators = []

    # Scipy estimators
    estimators += [ScipyKDE(dim, "scott") for dim in both_dims]

    # Sklearn estimators
    estimators += [SklearnKDE(dim, "scott") for dim in both_dims]

    # Statsmodels estimators (separate wrappers for 1D and 2D);
    # cross-validation could be added for 2D later, it is really slow
    estimators.append(StatsmodelsKDE1D("scott"))
    estimators.append(StatsmodelsKDE2D("normal_reference"))

    # KDEpy estimators: KDEpy only has bandwidth detection for 1D
    estimators += [KDEpyKDE(1, method) for method in ("scott", "ISJ")]

    # ParallelKDE estimators: every dim x method x device combination
    estimators += [
        ParallelKDE(dim, method, device)
        for dim in both_dims
        for method in ("parallelEstimator", "rotEstimator")
        for device in ("cpu", "cuda")
    ]

    # X-Entropy estimator (1D only)
    estimators.append(XentropyKDE(1))

    return estimators

In [None]:
def run_benchmark(estimators, datasets, grids, run_spec: RunSpec = None, verbose=False):
    """Time ``fit`` and ``evaluate`` of every estimator on matching data/grid pairs.

    Parameters
    ----------
    estimators : iterable of DensityEstimator
        Estimators to benchmark; each is only run on cases of its own ``dim``.
    datasets : iterable of DataSpec
        Sample sizes to draw from the reference distribution.
    grids : iterable of GridSpec
        Grid resolutions; a grid is paired with datasets of the same ``dim``.
    run_spec : RunSpec, optional
        Warm-up / repetition counts; defaults to ``RunSpec()``.
    verbose : bool
        Print one line per estimator and case.

    Returns
    -------
    pandas.DataFrame
        One row per timed run (``ok=True``), plus one row with ``ok=False`` and
        the error message whenever an estimator fails. A failure during grid
        creation, warm-up or a timed run is recorded and the benchmark goes on
        with the next estimator instead of aborting.
    """
    if run_spec is None:
        run_spec = RunSpec()

    rows = []
    for gs in tqdm(grids, desc="Iterating grids..."):
        for ds in tqdm(datasets, desc="Iterating sample sizes..."):
            if ds.dim != gs.dim:
                continue

            dim = ds.dim

            # The same samples are shared by all estimators of this case.
            distro = get_distro(dim)
            samples = sample_distro(distro, ds.n_samples)

            for est in estimators:
                if est.info.dim != dim:
                    continue
                if verbose:
                    print(f"Running {est.name()} with: dimensions: {dim}; grid size: {gs.n_points}; sample size:{ds.n_samples}")

                # Columns shared by success and failure rows.
                base_row = {
                    "package": est.info.package,
                    "method": est.info.method,
                    "device": est.info.device,
                    "dim": est.info.dim,
                    "sample_size": ds.n_samples,
                    "grid_points": gs.n_points,
                }

                # Grid creation and warmups (check if estimations work correctly
                # and compile if needed). Grid creation is guarded too because
                # it depends on the estimator's device.
                try:
                    grid = create_grid(gs.n_points, dim, device=est.info.device)
                    for _ in range(run_spec.n_warmup):
                        est.fit(samples)
                        est.evaluate(grid)
                except Exception as e:
                    print(f"ERROR: {est.name()}")
                    rows.append({**base_row, "ok": False, "error": str(e)})
                    continue

                # Timed runs
                for _ in range(run_spec.n_runs):
                    try:
                        _, fit_t = time_block(lambda: est.fit(samples))
                        _, eval_t = time_block(lambda: est.evaluate(grid))
                    except Exception as e:
                        # Keep the runs measured so far and record the failure.
                        print(f"ERROR: {est.name()}")
                        rows.append({**base_row, "ok": False, "error": str(e)})
                        break
                    rows.append({
                        **base_row,
                        "fit_time_s": fit_t,
                        "eval_time_s": eval_t,
                        "total_time_s": fit_t + eval_t,
                        "ok": True,
                    })

    df = pd.DataFrame(rows)
    # Preferred column order; columns that do not exist are skipped and any
    # unexpected column is appended at the end.
    prefer = ["package", "method", "device", "dim", "sample_size", "grid_points", "fit_time_s", "eval_time_s", "total_time_s", "ok", "error"]
    cols = [c for c in prefer if c in df.columns] + [c for c in df.columns if c not in prefer]

    return df[cols]

### Quick check

In [None]:
# Small smoke-test configuration: one 1D dataset and one 1D grid
# (the 2D entries are kept commented out for optional use).
datasets = [
    DataSpec(1, 1000),
    # DataSpec(2, 10000),
]
grids = [
    GridSpec(1, 500),
    # GridSpec(2, 125),
]
# All estimator configurations; run_benchmark skips those whose dim does not match.
ests = available_estimators()

In [None]:
# Run the smoke test with the default RunSpec, printing one line per estimator.
results_quick = run_benchmark(ests, datasets, grids, verbose=True)

In [None]:
# Display the smoke-test results.
results_quick

### Run benchmark

In [None]:
# Each spec is a (datasets, grids) pair that is unpacked into run_benchmark.

# Sample-size sweep: varying number of samples on one fixed grid per dimension.
specs_samples = (
    [DataSpec(1, m) for m in (100, 1000, 10000, 100000)]
    + [DataSpec(2, m) for m in (1000, 10000, 100000, 1000000)],
    [GridSpec(1, 500), GridSpec(2, 100)],
)

# Grid-size sweep: one fixed dataset per dimension on grids of varying resolution.
specs_grids = (
    [DataSpec(1, 10000), DataSpec(2, 100000)],
    [GridSpec(1, m) for m in (100, 500, 2500)]
    + [GridSpec(2, m) for m in (33, 100, 300)],
)

ests = available_estimators()

#### Samples benchmark

In [None]:
# Run the sample-size sweep (specs_samples unpacks into datasets and grids).
results_samples = run_benchmark(ests, *specs_samples)

In [None]:
# Persist the sample-size sweep without the index column.
# NOTE(review): the plotting cells read "benchmark_samples.csv", not this
# "_1d" file, and the sweep contains 1D and 2D cases - confirm the intended name.
results_samples.to_csv("benchmark_samples_1d.csv", index=False)

#### Grids benchmark

In [None]:
# Run the grid-size sweep (specs_grids unpacks into datasets and grids).
results_grids = run_benchmark(ests, *specs_grids)

In [None]:
# Persist the grid-size sweep without the index column.
# NOTE(review): the plotting cells read "benchmark_grids.csv", not this
# "_1d" file, and the sweep contains 1D and 2D cases - confirm the intended name.
results_grids.to_csv("benchmark_grids_1d.csv", index=False)

## Plots (box plots)

### Sample benchmark

In [None]:
# Load the sample-size sweep results for the box plots.
sbenchmark_results = pd.read_csv("benchmark_samples.csv")

In [None]:
# Per-column {raw value: display name} replacements; values that are not
# listed (e.g. other packages or methods) are left unchanged.
mapping = {
    "package": {"scipy": "SciPy", "sklearn": "scikit-learn", "kdepy": "KDEpy", "parallelkdepy": "ParallelKDE"},
    "device": {"cpu": "CPU", "cuda": "CUDA"},
    "method": {"rotEstimator": "ROT (scott)", "parallelEstimator": "GradePro"},
}
sbenchmark_results = sbenchmark_results.replace(mapping)

In [None]:
# Display the relabelled sample-size results.
sbenchmark_results

In [None]:
# Box plots of runtime relative to the fastest estimator, one panel per dimension.

# Keep only rows that carry a timing: failed estimators are stored without
# total_time_s and would otherwise become NaN group means.
timed = sbenchmark_results.dropna(subset=["total_time_s"])

# Mean runtime per (dim, sample size, estimator), normalised by the fastest
# estimator of the same (dim, sample size); rel_mean == 1 marks the fastest.
g = (
    timed.groupby(["dim", "sample_size", "package", "method", "device"], as_index=False)["total_time_s"]
    .mean()
    .rename(columns={"total_time_s": "mean_total_time"})
)
g["fastest_for_ss"] = g.groupby(["dim", "sample_size"])["mean_total_time"].transform("min")
g["rel_mean"] = g["mean_total_time"] / g["fastest_for_ss"]

dims = sorted(sbenchmark_results["dim"].dropna().unique())
packages = sorted(sbenchmark_results["package"].unique())
devices = sorted(sbenchmark_results["device"].unique())

# Colour encodes the package, line style encodes the device.
base_colors = matplotlib.colormaps["tab10"]
pkg_colors = {p: base_colors(i) for i, p in enumerate(packages)}

ls_options = ["-", "--", ":", "-."]
dev_ls = {d: ls_options[i % len(ls_options)] for i, d in enumerate(devices)}

alpha_box = 0.7
alpha_points = 0.85
point_size = 16

fig, axes = plt.subplots(1, 2, figsize=(5.6, 2.6), sharey=True)

for ax, d in zip(axes, dims):
    gd = g[g["dim"] == d]
    # Order the estimators by their median relative runtime (fastest first).
    med = gd.groupby(["package", "method", "device"])["rel_mean"].median().sort_values()
    order = list(med.index)

    # One array of relative runtimes (one value per sample size) per estimator.
    data = [
        gd[(gd["package"] == p) & (gd["method"] == m) & (gd["device"] == dev)]["rel_mean"].values
        for (p, m, dev) in order
    ]

    bp = ax.boxplot(data, patch_artist=True, showfliers=False)

    # boxplot returns two whiskers and two caps per box, hence the 2*i indexing.
    for i, ((p, m, dev), box) in enumerate(zip(order, bp["boxes"])):
        color = pkg_colors[p]
        ls = dev_ls[dev]
        box.set(facecolor=color, alpha=alpha_box, edgecolor="black", linewidth=1.2, linestyle=ls)
        for w in (bp["whiskers"][2*i], bp["whiskers"][2*i+1]):
            w.set(color="black", alpha=alpha_box, linestyle=ls, linewidth=1.2)
        for c in (bp["caps"][2*i], bp["caps"][2*i+1]):
            c.set(color="black", alpha=alpha_box, linestyle=ls, linewidth=1.0)
        bp["medians"][i].set(color="black", linewidth=1.2, linestyle=ls)

    # Overlay the individual values on top of each box (boxes sit at x = 1, 2, ...).
    for i, ((p, _, dev), vals) in enumerate(zip(order, data), start=1):
        ax.scatter(
            [i]*len(vals), vals,
            s=point_size,
            alpha=alpha_points,
            color=pkg_colors[p],
            edgecolors="black",
            linestyle=dev_ls[dev],
            zorder=3,
        )

    ax.set_xticks(range(1, len(order)+1))
    ax.set_xticklabels([m for (_, m, _) in order], rotation=75, ha="right")

    # Reference line at the fastest estimator (relative time 1).
    ax.axhline(1.0, linestyle="--", linewidth=1)

    ax.set_yscale("log")

    ax.set_title(f"{d}D benchmark")
    if ax is axes[0]:
        ax.set_ylabel("Relative mean time")
    ax.tick_params(axis="x", rotation=75)
    ax.tick_params(which="both", top=False, right=False)

    ax.grid(alpha=0.6)

pkg_handles = [matplotlib.patches.Patch(facecolor=pkg_colors[p], alpha=alpha_box, edgecolor="black", label=p) for p in packages]
dev_handles = [matplotlib.lines.Line2D([0], [0], color="black", linestyle=dev_ls[d], label=d) for d in devices]
fig.legend(handles=pkg_handles, title="Package", loc="upper right", bbox_to_anchor=(1.17, 0.95), frameon=True)
fig.legend(handles=dev_handles, title="Device", loc="lower right", bbox_to_anchor=(1.145, 0.37), frameon=True)

fig.tight_layout()

# fig.savefig("benchmark_samples.pdf", dpi=500, bbox_inches="tight", pad_inches=0.02)

### Grid benchmark

In [None]:
# Load the grid-size sweep results for the box plots.
gbenchmark_results = pd.read_csv("benchmark_grids.csv")

In [None]:
# Per-column {raw value: display name} replacements; values that are not
# listed (e.g. other packages or methods) are left unchanged.
mapping = {
    "package": {"scipy": "SciPy", "sklearn": "scikit-learn", "kdepy": "KDEpy", "parallelkdepy": "ParallelKDE"},
    "device": {"cpu": "CPU", "cuda": "CUDA"},
    "method": {"rotEstimator": "ROT (scott)", "parallelEstimator": "GradePro"},
}
gbenchmark_results = gbenchmark_results.replace(mapping)

In [None]:
# Display the relabelled grid-size results.
gbenchmark_results

In [None]:
# Box plots of runtime relative to the fastest estimator, one panel per dimension.

# Keep only rows that carry a timing: failed estimators are stored without
# total_time_s and would otherwise become NaN group means.
timed = gbenchmark_results.dropna(subset=["total_time_s"])

# Mean runtime per (dim, grid size, estimator), normalised by the fastest
# estimator of the same (dim, grid size); rel_mean == 1 marks the fastest.
g = (
    timed.groupby(["dim", "grid_points", "package", "method", "device"], as_index=False)["total_time_s"]
    .mean()
    .rename(columns={"total_time_s": "mean_total_time"})
)
g["fastest_for_ss"] = g.groupby(["dim", "grid_points"])["mean_total_time"].transform("min")
g["rel_mean"] = g["mean_total_time"] / g["fastest_for_ss"]

dims = sorted(gbenchmark_results["dim"].dropna().unique())
packages = sorted(gbenchmark_results["package"].unique())
devices = sorted(gbenchmark_results["device"].unique())

# Colour encodes the package, line style encodes the device.
base_colors = matplotlib.colormaps["tab10"]
pkg_colors = {p: base_colors(i) for i, p in enumerate(packages)}

ls_options = ["-", "--", ":", "-."]
dev_ls = {d: ls_options[i % len(ls_options)] for i, d in enumerate(devices)}

alpha_box = 0.7
alpha_points = 0.85
point_size = 16

fig, axes = plt.subplots(1, 2, figsize=(5.6, 2.6), sharey=True)

for ax, d in zip(axes, dims):
    gd = g[g["dim"] == d]
    # Order the estimators by their median relative runtime (fastest first).
    med = gd.groupby(["package", "method", "device"])["rel_mean"].median().sort_values()
    order = list(med.index)

    # One array of relative runtimes (one value per grid size) per estimator.
    data = [
        gd[(gd["package"] == p) & (gd["method"] == m) & (gd["device"] == dev)]["rel_mean"].values
        for (p, m, dev) in order
    ]

    bp = ax.boxplot(data, patch_artist=True, showfliers=False)

    # boxplot returns two whiskers and two caps per box, hence the 2*i indexing.
    for i, ((p, m, dev), box) in enumerate(zip(order, bp["boxes"])):
        color = pkg_colors[p]
        ls = dev_ls[dev]
        box.set(facecolor=color, alpha=alpha_box, edgecolor="black", linewidth=1.2, linestyle=ls)
        for w in (bp["whiskers"][2*i], bp["whiskers"][2*i+1]):
            w.set(color="black", alpha=alpha_box, linestyle=ls, linewidth=1.2)
        for c in (bp["caps"][2*i], bp["caps"][2*i+1]):
            c.set(color="black", alpha=alpha_box, linestyle=ls, linewidth=1.0)
        bp["medians"][i].set(color="black", linewidth=1.2, linestyle=ls)

    # Overlay the individual values on top of each box (boxes sit at x = 1, 2, ...).
    for i, ((p, _, dev), vals) in enumerate(zip(order, data), start=1):
        ax.scatter(
            [i]*len(vals), vals,
            s=point_size,
            alpha=alpha_points,
            color=pkg_colors[p],
            edgecolors="black",
            linestyle=dev_ls[dev],
            zorder=3,
        )

    ax.set_xticks(range(1, len(order)+1))
    ax.set_xticklabels([m for (_, m, _) in order], rotation=75, ha="right")

    # Reference line at the fastest estimator (relative time 1).
    ax.axhline(1.0, linestyle="--", linewidth=1)

    ax.set_yscale("log")

    ax.set_title(f"{d}D benchmark")
    if ax is axes[0]:
        ax.set_ylabel("Relative mean time")
    ax.tick_params(axis="x", rotation=75)
    ax.tick_params(which="both", top=False, right=False)

    ax.grid(alpha=0.6)

pkg_handles = [matplotlib.patches.Patch(facecolor=pkg_colors[p], alpha=alpha_box, edgecolor="black", label=p) for p in packages]
dev_handles = [matplotlib.lines.Line2D([0], [0], color="black", linestyle=dev_ls[d], label=d) for d in devices]
fig.legend(handles=pkg_handles, title="Package", loc="upper right", bbox_to_anchor=(1.17, 0.95), frameon=True)
fig.legend(handles=dev_handles, title="Device", loc="lower right", bbox_to_anchor=(1.145, 0.37), frameon=True)

fig.tight_layout()

# fig.savefig("benchmark_grid.pdf", dpi=500, bbox_inches="tight", pad_inches=0.02)

## Plots (bar plots)

### Sample benchmark

In [None]:
# Reload the raw sample-size sweep results for the bar plots.
sbenchmark_results = pd.read_csv("benchmark_samples.csv")

In [None]:
# Per-column {raw value: display name} replacements. Several raw method names
# are merged into the same label ("ROT" or "Plug-In"); unlisted values are
# left unchanged.
mapping = {
    "package": {"scipy": "SciPy", "sklearn": "scikit-learn", "kdepy": "KDEpy", "parallelkdepy": "ParallelKDE", "xentropy": "X-Entropy"},
    "device": {"cpu": "CPU", "cuda": "CUDA"},
    "method": {"rotEstimator": "ROT", "parallelEstimator": "GradePro", "scott": "ROT", "ISJ": "Plug-In", "normal_reference": "ROT", "botev": "Plug-In"},
}
sbenchmark_results = sbenchmark_results.replace(mapping)

In [None]:
# Display the relabelled sample-size results.
sbenchmark_results

In [None]:
# Grouped bar plot: mean runtime of every estimator at the smallest and the
# largest sample size, one panel per dimension.
cmap = matplotlib.colormaps["tab20"]
hatches = ["++", "xx"]
highlight_pkg = "ParallelKDE"  # package whose labels are drawn in red

packages = sbenchmark_results["package"].sort_values().unique()
devices = sbenchmark_results["device"].unique()
dims = sbenchmark_results["dim"].unique()
# Two consecutive tab20 colours per package: index 0 is used for the large
# size, index 1 for the small size.
pkg_colors = {p: [cmap(2*i), cmap(2*i+1)] for i, p in enumerate(packages)}
# Hatch encodes the device (patterns are reused if there are more devices than hatches).
dev_hatches = {d: hatches[i % len(hatches)] for i, d in enumerate(devices)}

fig, axes = plt.subplots(1, 2, figsize=(5.6, 2.6))

for ax, dim in zip(axes, dims):
    dsub = sbenchmark_results[sbenchmark_results["dim"] == dim]
    dsub_agg = dsub.groupby(["package", "method", "device", "sample_size"])["total_time_s"].mean().reset_index()

    # Only the two extreme sample sizes are plotted.
    min_size = dsub_agg["sample_size"].min()
    max_size = dsub_agg["sample_size"].max()
    sel = pd.concat(
        [dsub_agg[dsub_agg["sample_size"] == min_size], dsub_agg[dsub_agg["sample_size"] == max_size]],
        ignore_index=True,
    )

    # Methods of each (package, device), ordered by their mean runtime over both sizes.
    mean_by_pkg_meth_dev = sel.groupby(["package", "method", "device"])["total_time_s"].mean().reset_index()
    methods_by_pkg_dev = (
        mean_by_pkg_meth_dev.sort_values(["package", "total_time_s"])
        .groupby(["package", "device"])["method"]
        .apply(list)
        .to_dict()
    )

    bar_width = 0.38
    group_gap = 0.45
    package_gap = 0.0

    x_positions, x_small, x_large = [], [], []
    heights_small, heights_large = [], []
    pkg_for_bar, dev_for_bar, xtick_labels = [], [], []

    # Lay out one pair of bars (small | large) per estimator.
    cursor = 0.0
    for (pkg, dev), meths in methods_by_pkg_dev.items():
        for m in meths:
            x_positions.append(cursor)
            x_small.append(cursor - bar_width/2)
            x_large.append(cursor + bar_width/2)

            rows_m = sel[(sel["package"] == pkg) & (sel["method"] == m) & (sel["device"] == dev)]
            v_small = rows_m.loc[rows_m["sample_size"] == min_size, "total_time_s"]
            v_large = rows_m.loc[rows_m["sample_size"] == max_size, "total_time_s"]
            # An estimator without a measurement at one extreme gets no bar there.
            heights_small.append(v_small.iloc[0] if not v_small.empty else np.nan)
            heights_large.append(v_large.iloc[0] if not v_large.empty else np.nan)

            pkg_for_bar.append(pkg)
            dev_for_bar.append(dev)
            xtick_labels.append(m)

            cursor += (2*bar_width + group_gap)
        cursor += package_gap

    # Draw all bars once, after the layout is complete (drawing inside the loop
    # above would repaint the already collected bars on every iteration).
    bar_colors_small = [pkg_colors[pb][1] for pb in pkg_for_bar]
    bar_colors_large = [pkg_colors[pb][0] for pb in pkg_for_bar]
    bar_hatches = [dev_hatches[d] for d in dev_for_bar]

    ax.bar(x_small, heights_small, width=bar_width, edgecolor=(0, 0, 0, 0.5), color=bar_colors_small, hatch=bar_hatches)
    ax.bar(x_large, heights_large, width=bar_width, edgecolor=(0, 0, 0, 0.5), color=bar_colors_large, hatch=bar_hatches)

    ax.set_yscale("log")

    ax.set_title(f"{dim}D benchmark")

    ax.set_xticks(x_positions)
    ax.set_xticklabels(xtick_labels, rotation=45)

    # Highlight the tick labels of the bars that belong to highlight_pkg.
    for tick_label, pkg in zip(ax.get_xticklabels(), pkg_for_bar):
        if pkg == highlight_pkg:
            tick_label.set_color("indianred")

    # ax.set_xlabel("Estimator")
    if dim == 1:
        ax.set_ylabel("Runtime [s]")

# Package legend (dark colour of each pair), with highlight_pkg written in red.
color_handles = [matplotlib.patches.Patch(facecolor=shades[0], label=pkg) for pkg, shades in pkg_colors.items()]
pkg_leg = fig.legend(handles=color_handles, title="Packages", loc="upper right", frameon=True, bbox_to_anchor=(1.17, 1.0))
for txt in pkg_leg.get_texts():
    if txt.get_text() == highlight_pkg:
        txt.set_color("indianred")

# Light / dark shade legend for the two sample sizes.
grey_handles = [matplotlib.patches.Patch(facecolor=fc, label=size_label) for fc, size_label in zip(["0.75", "0.25"], ["small", "large"])]
sample_leg = fig.legend(handles=grey_handles, title="Sample size", loc="center right", frameon=True, bbox_to_anchor=(1.145, 0.48))

# Device legend built from the hatches actually assigned above.
hatch_handles = [matplotlib.patches.Patch(facecolor="white", edgecolor="black", hatch=h, label=d) for d, h in dev_hatches.items()]
dev_leg = fig.legend(handles=hatch_handles, title="Device", loc="lower right", frameon=True, bbox_to_anchor=(1.145, 0.17))

# fig.savefig("benchmark_samples.pdf", dpi=500, bbox_inches="tight", pad_inches=0.02)

### Grid benchmark

In [None]:
# Reload the raw grid-size sweep results for the bar plots.
gbenchmark_results = pd.read_csv("benchmark_grids.csv")

In [None]:
# Per-column {raw value: display name} replacements. Several raw method names
# are merged into the same label ("ROT" or "Plug-In"); unlisted values are
# left unchanged.
mapping = {
    "package": {"scipy": "SciPy", "sklearn": "scikit-learn", "kdepy": "KDEpy", "parallelkdepy": "ParallelKDE", "xentropy": "X-Entropy"},
    "device": {"cpu": "CPU", "cuda": "CUDA"},
    "method": {"rotEstimator": "ROT", "parallelEstimator": "GradePro", "scott": "ROT", "normal_reference": "ROT", "ISJ": "Plug-In", "botev": "Plug-In"},
}
gbenchmark_results = gbenchmark_results.replace(mapping)

In [None]:
# Display the relabelled grid-size results.
gbenchmark_results

In [None]:
# Grouped bar plot: mean runtime of every estimator at the smallest and the
# largest grid size, one panel per dimension.
cmap = matplotlib.colormaps["tab20"]
hatches = ["++", "xx"]
highlight_pkg = "ParallelKDE"  # package whose labels are drawn in red

packages = gbenchmark_results["package"].sort_values().unique()
devices = gbenchmark_results["device"].unique()
dims = gbenchmark_results["dim"].unique()
# Two consecutive tab20 colours per package: index 0 is used for the large
# size, index 1 for the small size.
pkg_colors = {p: [cmap(2*i), cmap(2*i+1)] for i, p in enumerate(packages)}
# Hatch encodes the device (patterns are reused if there are more devices than hatches).
dev_hatches = {d: hatches[i % len(hatches)] for i, d in enumerate(devices)}

fig, axes = plt.subplots(1, 2, figsize=(5.6, 2.6))

for ax, dim in zip(axes, dims):
    dsub = gbenchmark_results[gbenchmark_results["dim"] == dim]
    dsub_agg = dsub.groupby(["package", "method", "device", "grid_points"])["total_time_s"].mean().reset_index()

    # Only the two extreme grid sizes are plotted.
    min_size = dsub_agg["grid_points"].min()
    max_size = dsub_agg["grid_points"].max()
    sel = pd.concat(
        [dsub_agg[dsub_agg["grid_points"] == min_size], dsub_agg[dsub_agg["grid_points"] == max_size]],
        ignore_index=True,
    )

    # Methods of each (package, device), ordered by their mean runtime over both sizes.
    mean_by_pkg_meth_dev = sel.groupby(["package", "method", "device"])["total_time_s"].mean().reset_index()
    methods_by_pkg_dev = (
        mean_by_pkg_meth_dev.sort_values(["package", "total_time_s"])
        .groupby(["package", "device"])["method"]
        .apply(list)
        .to_dict()
    )

    bar_width = 0.38
    group_gap = 0.45
    package_gap = 0.0

    x_positions, x_small, x_large = [], [], []
    heights_small, heights_large = [], []
    pkg_for_bar, dev_for_bar, xtick_labels = [], [], []

    # Lay out one pair of bars (small | large) per estimator.
    cursor = 0.0
    for (pkg, dev), meths in methods_by_pkg_dev.items():
        for m in meths:
            x_positions.append(cursor)
            x_small.append(cursor - bar_width/2)
            x_large.append(cursor + bar_width/2)

            rows_m = sel[(sel["package"] == pkg) & (sel["method"] == m) & (sel["device"] == dev)]
            v_small = rows_m.loc[rows_m["grid_points"] == min_size, "total_time_s"]
            v_large = rows_m.loc[rows_m["grid_points"] == max_size, "total_time_s"]
            # An estimator without a measurement at one extreme gets no bar there.
            heights_small.append(v_small.iloc[0] if not v_small.empty else np.nan)
            heights_large.append(v_large.iloc[0] if not v_large.empty else np.nan)

            pkg_for_bar.append(pkg)
            dev_for_bar.append(dev)
            xtick_labels.append(m)

            cursor += (2*bar_width + group_gap)
        cursor += package_gap

    # Draw all bars once, after the layout is complete (drawing inside the loop
    # above would repaint the already collected bars on every iteration).
    bar_colors_small = [pkg_colors[pb][1] for pb in pkg_for_bar]
    bar_colors_large = [pkg_colors[pb][0] for pb in pkg_for_bar]
    bar_hatches = [dev_hatches[d] for d in dev_for_bar]

    ax.bar(x_small, heights_small, width=bar_width, edgecolor=(0, 0, 0, 0.5), color=bar_colors_small, hatch=bar_hatches)
    ax.bar(x_large, heights_large, width=bar_width, edgecolor=(0, 0, 0, 0.5), color=bar_colors_large, hatch=bar_hatches)

    ax.set_yscale("log")

    ax.set_title(f"{dim}D benchmark")

    ax.set_xticks(x_positions)
    ax.set_xticklabels(xtick_labels, rotation=45)

    # Highlight the tick labels of the bars that belong to highlight_pkg.
    for tick_label, pkg in zip(ax.get_xticklabels(), pkg_for_bar):
        if pkg == highlight_pkg:
            tick_label.set_color("indianred")

    # ax.set_xlabel("Estimator")
    if dim == 1:
        ax.set_ylabel("Runtime [s]")

# Package legend (dark colour of each pair), with highlight_pkg written in red.
color_handles = [matplotlib.patches.Patch(facecolor=shades[0], label=pkg) for pkg, shades in pkg_colors.items()]
pkg_leg = fig.legend(handles=color_handles, title="Packages", loc="upper right", frameon=True, bbox_to_anchor=(1.17, 1.0))
for txt in pkg_leg.get_texts():
    if txt.get_text() == highlight_pkg:
        txt.set_color("indianred")

# Light / dark shade legend for the two grid sizes.
grey_handles = [matplotlib.patches.Patch(facecolor=fc, label=size_label) for fc, size_label in zip(["0.75", "0.25"], ["small", "large"])]
sample_leg = fig.legend(handles=grey_handles, title="Grid size", loc="center right", frameon=True, bbox_to_anchor=(1.14, 0.48))

# Device legend built from the hatches actually assigned above.
hatch_handles = [matplotlib.patches.Patch(facecolor="white", edgecolor="black", hatch=h, label=d) for d, h in dev_hatches.items()]
dev_leg = fig.legend(handles=hatch_handles, title="Device", loc="lower right", frameon=True, bbox_to_anchor=(1.145, 0.17))

# fig.savefig("benchmark_grids.pdf", dpi=500, bbox_inches="tight", pad_inches=0.02)