In [1]:
import spectrograms as sg
import numpy as np
import time
from dataclasses import dataclass
from typing import Callable, Any
import pandas as pd
from numpy_impls import *
from datetime import datetime
import json
from pathlib import Path
from itertools import product
from scipy.signal import stft as scipy_stft
from scipy.fft import fft as scipy_fft

# Seed NumPy's legacy global RNG so later np.random.* draws in this notebook are repeatable
np.random.seed(0)

# High-resolution monotonic timer
now = time.perf_counter

In [2]:
def get_window(window_type: str, n_fft: int) -> np.ndarray:
    """Build the analysis window used by the NumPy reference implementations.

    Args:
        window_type: One of "rectangular", "hanning", "hamming", "blackman",
            "kaiser_5" or "gaussian_std05".
        n_fft: Number of samples in the window.

    Returns:
        The window as an array of ``n_fft`` samples.

    Raises:
        ValueError: If ``window_type`` is not one of the names above.
    """

    def _gaussian(size):
        # Imported lazily so scipy's window module is only touched when needed.
        from scipy.signal.windows import gaussian

        # The standard deviation is half the window length (in samples).
        return gaussian(size, 0.5 * size)

    # Each entry maps a window name to a one-argument builder taking the length.
    builders = {
        "rectangular": np.ones,
        "hanning": lambda size: hann_window(size),
        "hamming": np.hamming,
        "blackman": np.blackman,
        "kaiser_5": lambda size: np.kaiser(size, 5.0),
        "gaussian_std05": _gaussian,
    }

    build = builders.get(window_type)
    if build is None:
        raise ValueError(f"Unknown window type: {window_type}")
    return build(n_fft)


@dataclass(frozen=True)
class StftSpec:
    """Immutable STFT settings handed to the NumPy reference helpers."""

    sample_rate: int = 16_000
    n_fft: int = 1024
    hop_length: int = 256
    centre: bool = True
    window_type: str = "hanning"

    def window(self) -> np.ndarray:
        """Return the window array for this spec's ``window_type`` and ``n_fft``."""
        kind, length = self.window_type, self.n_fft
        return get_window(kind, length)


def make_fixtures(sample_rate: int):
    """Create the named test signals, each ``sample_rate`` samples long.

    The time axis spans [0, 1) with ``sample_rate`` points. The "noise"
    fixture is drawn from NumPy's global RNG, so it depends on the current
    seed state.

    Returns:
        Dict mapping fixture name to a 1-D array, in the order
        sine_440, sine_3k, noise, chirp, impulse.
    """
    t = np.linspace(0, 1.0, sample_rate, endpoint=False)

    def _tone(freq):
        # freq may be a scalar (pure tone) or an array (time-varying frequency).
        return np.sin(2 * np.pi * freq * t)

    fixtures = {}
    fixtures["sine_440"] = _tone(440)
    fixtures["sine_3k"] = _tone(3_000)
    fixtures["noise"] = np.random.randn(sample_rate)
    fixtures["chirp"] = _tone(100 + 3_000 * t**2)
    # Single unit sample at index 0, zeros elsewhere.
    fixtures["impulse"] = np.eye(1, sample_rate, 0).ravel()
    return fixtures


# Fixtures are generated once at this rate; each fixture array holds SAMPLE_RATE samples.
SAMPLE_RATE = 16_000
FIXTURES = make_fixtures(SAMPLE_RATE)

In [3]:
@dataclass(frozen=True)
class Operator:
    """A named benchmark target: a label paired with the callable to run."""

    # Label of the operator kind (the factories in this notebook pass e.g. "power", "mel")
    name: str
    # Callable invoked with a signal array; its return type is not constrained here
    fn: Callable[[np.ndarray], Any]


def benchmark_fn(fn: Callable[[], Any], *, warmup: int = 10, runs: int = 500):
    """Time a zero-argument callable.

    Args:
        fn: Callable to time; its return value is discarded.
        warmup: Number of untimed calls made before measuring.
        runs: Number of timed calls.

    Returns:
        Dict with "mean", "std", "min" and "max" of the per-call durations
        (as measured by ``now``) plus the full sample array under "raw".
    """
    remaining = warmup
    while remaining > 0:
        fn()
        remaining -= 1

    samples = np.empty(runs, dtype=np.float64)
    for idx in range(runs):
        started = now()
        fn()
        samples[idx] = now() - started

    return dict(
        mean=samples.mean(),
        std=samples.std(),
        min=samples.min(),
        max=samples.max(),
        raw=samples,
    )


def compare_outputs(ref, test):
    """Summarise the element-wise error between a reference and a test output.

    Both inputs are converted with ``np.asarray`` and must share a shape.
    Relative error divides by ``|ref|`` clamped below at 1e-12 to avoid
    division by zero.

    Returns:
        Dict with "abs_max", "abs_mean", "rel_max" and "rel_mean".

    Raises:
        ValueError: If the two shapes differ.
    """
    ref, test = np.asarray(ref), np.asarray(test)
    if ref.shape != test.shape:
        raise ValueError(f"Shape mismatch: {ref.shape} vs {test.shape}")

    abs_err = np.abs(test - ref)
    rel_err = abs_err / np.maximum(np.abs(ref), 1e-12)

    metrics = {}
    metrics["abs_max"] = np.max(abs_err)
    metrics["abs_mean"] = np.mean(abs_err)
    metrics["rel_max"] = np.max(rel_err)
    metrics["rel_mean"] = np.mean(rel_err)
    return metrics

In [4]:
# ==========================================
# Parameter Sweep Configuration
# ==========================================

# Define parameter grids - EXTENSIVE SWEEP
# The STFT grid is crossed with every operator's own scale grid below.
STFT_SWEEP = {
    "n_fft": [128, 256, 512, 1024, 2048, 4096],
    "hop_ratio": [0.125, 0.25, 0.5, 0.75],  # hop_size = n_fft * ratio
    "window_type": [
        "rectangular",
        "hanning",
        "hamming",
        "blackman",
        "kaiser_5",
        "gaussian_std05",
    ],
    "centre": [True, False],
    "sample_rate": [16000],
}

# Note: 288 STFT combinations per operator (6 n_fft x 4 hop ratios x 6 windows x 2 centre
# settings x 1 sample rate); scale grids below multiply this further.

# Mel filterbank grid (keys are passed as keyword arguments to the mel parameter objects)
MEL_SWEEP = {
    "n_mels": [40, 80, 128],
    "f_min": [0.0],
    "f_max": [8000.0],
}

# ERB filterbank grid
ERB_SWEEP = {
    "n_filters": [32, 64],
    "f_min": [50.0],
    "f_max": [8000.0],
}

# Log-frequency grid
LOGHZ_SWEEP = {
    "n_bins": [64, 84, 128],
    "f_min": [55.0],
    "f_max": [7040.0],
}

# Chroma grid
# NOTE(review): the NumPy/SciPy chroma operators in this notebook read only "tuning";
# "norm" appears unused, so both norm values would time the same computation - confirm.
CHROMA_SWEEP = {
    "tuning": [440.0],
    "norm": ["l2", "max"],
}

# dB floor grid
DB_SWEEP = {
    "floor_db": [-80.0, -100.0],
}

print("Parameter sweep configuration loaded")
# The count below multiplies the four multi-valued STFT axes; sample_rate has a single value.
print(
    f"STFT combinations per operator: {len(STFT_SWEEP['n_fft']) * len(STFT_SWEEP['hop_ratio']) * len(STFT_SWEEP['window_type']) * len(STFT_SWEEP['centre'])}"
)

Parameter sweep configuration loaded
STFT combinations per operator: 288


In [5]:
# ==========================================
# Parameter Iterator and Operator Factory
# ==========================================

# Window type mappings
# Sweep window names -> window objects from the `spectrograms` (sg) package.
WINDOW_TYPE_MAP = {
    "rectangular": sg.WindowType.rectangular,
    "hanning": sg.WindowType.hanning,
    "hamming": sg.WindowType.hamming,
    "blackman": sg.WindowType.blackman,
    "kaiser_5": sg.WindowType.kaiser(5.0),
    "gaussian_std05": sg.WindowType.gaussian(0.5),
}

# Scipy window mapping
# Sweep window names -> specs accepted by scipy.signal.get_window (a name, or a
# (name, parameter) tuple for parameterised windows).
# NOTE(review): get_window() in this notebook builds the Gaussian with std = 0.5 * n_fft,
# whereas ("gaussian", 0.5) here passes 0.5 as the std in samples - these look
# inconsistent; confirm which convention sg.WindowType.gaussian(0.5) follows.
SCIPY_WINDOW_MAP = {
    "rectangular": "boxcar",
    "hanning": "hann",
    "hamming": "hamming",
    "blackman": "blackman",
    "kaiser_5": ("kaiser", 5.0),
    "gaussian_std05": ("gaussian", 0.5),
}


def generate_param_combinations(operator_type: str):
    """Enumerate every parameter dict to benchmark for one operator kind.

    Each entry carries the STFT settings ("n_fft", "hop_size", "hop_ratio",
    "window_type", "centre", "sample_rate"). Operators with their own grid
    (db, mel, erb, loghz, chroma) additionally get a "scale_params" dict.
    Mel/ERB/LogHz entries whose ``f_max`` exceeds the Nyquist frequency are
    dropped. An unrecognised operator type yields an empty list.

    Returns:
        List of parameter dicts, STFT axes varying slowest.
    """
    # operator kind -> (grid, keys in the order they are crossed and stored)
    scale_grids = {
        "db": (DB_SWEEP, ("floor_db",)),
        "mel": (MEL_SWEEP, ("n_mels", "f_min", "f_max")),
        "erb": (ERB_SWEEP, ("n_filters", "f_min", "f_max")),
        "loghz": (LOGHZ_SWEEP, ("n_bins", "f_min", "f_max")),
        "chroma": (CHROMA_SWEEP, ("tuning", "norm")),
    }
    nyquist_limited = ("mel", "erb", "loghz")
    is_linear = operator_type in ("power", "magnitude")
    grid_entry = scale_grids.get(operator_type)

    stft_axes = ("n_fft", "hop_ratio", "window_type", "centre", "sample_rate")
    collected = []

    for n_fft, hop_ratio, window_type, centre, sample_rate in product(
        *(STFT_SWEEP[axis] for axis in stft_axes)
    ):
        stft_part = {
            "n_fft": n_fft,
            "hop_size": int(n_fft * hop_ratio),
            "hop_ratio": hop_ratio,  # kept alongside hop_size for reference
            "window_type": window_type,
            "centre": centre,
            "sample_rate": sample_rate,
        }

        if is_linear:
            # Linear operators need nothing beyond the STFT settings.
            collected.append(stft_part)
            continue

        if grid_entry is None:
            continue

        grid, keys = grid_entry
        for values in product(*(grid[key] for key in keys)):
            scale = dict(zip(keys, values))
            # Band-limited scales cannot extend past the Nyquist frequency.
            if operator_type in nyquist_limited and scale["f_max"] > sample_rate / 2:
                continue
            collected.append({**stft_part, "scale_params": scale})

    return collected


def create_operator(operator_type: str, params: dict, implementation: str) -> Operator:
    """Create Operator object from parameter dict."""
    if implementation == "rust":
        return create_rust_operator(operator_type, params)
    elif implementation == "numpy":
        return create_numpy_operator(operator_type, params)
    elif implementation == "scipy":
        return create_scipy_operator(operator_type, params)
    else:
        raise ValueError(f"Unknown implementation: {implementation}")


def create_rust_operator(operator_type: str, params: dict) -> Operator:
    """Create Rust operator from parameters."""
    # Build SpectrogramParams
    window = WINDOW_TYPE_MAP[params["window_type"]]
    stft_params = sg.StftParams(
        n_fft=params["n_fft"],
        hop_size=params["hop_size"],
        window=window,
        centre=params["centre"],
    )
    spec_params = sg.SpectrogramParams(stft_params, sample_rate=params["sample_rate"])

    scale_params = params.get("scale_params", {})

    if operator_type == "power":
        return Operator(
            "power",
            lambda x, sp=spec_params: sg.compute_linear_power_spectrogram(x, sp),
        )
    elif operator_type == "magnitude":
        return Operator(
            "magnitude",
            lambda x, sp=spec_params: sg.compute_linear_magnitude_spectrogram(x, sp),
        )
    elif operator_type == "db":
        db_params = sg.LogParams(floor_db=scale_params["floor_db"])
        return Operator(
            "db",
            lambda x, sp=spec_params, dp=db_params: sg.compute_linear_db_spectrogram(
                x, sp, dp
            ),
        )
    elif operator_type == "mel":
        mel_params = sg.MelParams(**scale_params)
        return Operator(
            "mel",
            lambda x, sp=spec_params, mp=mel_params: sg.compute_mel_power_spectrogram(
                x, sp, mp
            ),
        )
    elif operator_type == "erb":
        erb_params = sg.ErbParams(**scale_params)
        return Operator(
            "erb",
            lambda x, sp=spec_params, ep=erb_params: sg.compute_erb_power_spectrogram(
                x, sp, ep
            ),
        )
    elif operator_type == "loghz":
        loghz_params = sg.LogHzParams(**scale_params)
        return Operator(
            "loghz",
            lambda x,
            sp=spec_params,
            lp=loghz_params: sg.compute_loghz_power_spectrogram(x, sp, lp),
        )
    else:
        raise ValueError(f"Unknown operator type: {operator_type}")


def create_numpy_operator(operator_type: str, params: dict) -> Operator:
    """Create NumPy operator from parameters."""
    # Build StftSpec for NumPy
    spec = StftSpec(
        sample_rate=params["sample_rate"],
        n_fft=params["n_fft"],
        hop_length=params["hop_size"],
        centre=params["centre"],
        window_type=params["window_type"],
    )

    scale_params = params.get("scale_params", {})

    if operator_type == "power":
        return Operator("power", lambda x, s=spec: power_spec_from_signal(x, s))
    elif operator_type == "magnitude":
        return Operator("magnitude", lambda x, s=spec: mag_spec_from_signal(x, s))
    elif operator_type == "db":
        eps = 10 ** (scale_params["floor_db"] / 10.0)  # Convert dB to linear
        return Operator("db", lambda x, s=spec, e=eps: db_spec_from_signal(x, s, eps=e))
    elif operator_type == "mel":
        mel_spec = MelSpec(**scale_params)
        return Operator(
            "mel", lambda x, s=spec, ms=mel_spec: mel_spec_from_signal(x, s, ms)
        )
    elif operator_type == "erb":
        erb_spec = ErbSpec(
            centre_freqs=erb_centers(
                scale_params["f_min"],
                scale_params["f_max"],
                scale_params["n_filters"],
            )
        )
        return Operator(
            "erb", lambda x, s=spec, es=erb_spec: erb_spec_from_signal(x, s, es)
        )
    elif operator_type == "loghz":
        loghz_spec = LogHzSpec(**scale_params)
        return Operator(
            "loghz", lambda x, s=spec, ls=loghz_spec: loghz_spec_from_signal(x, s, ls)
        )
    elif operator_type == "chroma":
        tuning = scale_params["tuning"]
        return Operator(
            "chroma", lambda x, s=spec, t=tuning: chroma_from_signal(x, s, f_ref=t)
        )
    else:
        raise ValueError(f"Unknown operator type: {operator_type}")


def create_scipy_operator(operator_type: str, params: dict) -> Operator:
    """Build the SciPy-based operator for one parameter set.

    "power", "magnitude" and "db" are computed from ``scipy.signal.stft``.
    SciPy has no mel/ERB/log-frequency/chroma front end, so those kinds are
    delegated to ``create_numpy_operator``; this builds exactly the operators
    the former copy-pasted branches built.

    Args:
        operator_type: Operator kind to build.
        params: Parameter dict with "n_fft", "hop_size", "window_type",
            "centre", "sample_rate" and, for scaled operators, "scale_params".

    Raises:
        KeyError: If ``params["window_type"]`` is not in ``SCIPY_WINDOW_MAP``.
        ValueError: If ``operator_type`` is not a known kind.
    """
    from scipy.signal import get_window as scipy_get_window

    # Looked up first so an unknown window fails before any operator is built.
    window_spec = SCIPY_WINDOW_MAP[params["window_type"]]

    scale_params = params.get("scale_params", {})

    def scipy_stft_wrapper(x, n_fft, hop_size, window_spec, centre, sample_rate):
        """Return the complex STFT matrix produced by scipy.signal.stft."""
        # get_window accepts both plain names and (name, parameter) tuples,
        # so no branching on the spec type is needed.
        window = scipy_get_window(window_spec, n_fft)

        # scipy.signal.stft expresses the hop as an overlap and centring as
        # boundary extension plus padding.
        _, _, Zxx = scipy_stft(
            x,
            fs=sample_rate,
            window=window,
            nperseg=n_fft,
            noverlap=n_fft - hop_size,
            boundary="zeros" if centre else None,
            padded=centre,
        )
        return Zxx

    def magnitude_of(x, p, ws):
        """Magnitude of the SciPy STFT for the parameter dict ``p``."""
        return np.abs(
            scipy_stft_wrapper(
                x, p["n_fft"], p["hop_size"], ws, p["centre"], p["sample_rate"]
            )
        )

    if operator_type == "power":
        return Operator(
            "power",
            lambda x, p=params, ws=window_spec: magnitude_of(x, p, ws) ** 2,
        )
    elif operator_type == "magnitude":
        return Operator(
            "magnitude",
            lambda x, p=params, ws=window_spec: magnitude_of(x, p, ws),
        )
    elif operator_type == "db":
        # The dB floor is turned into a linear power floor: 10^(dB / 10).
        eps = 10 ** (scale_params["floor_db"] / 10.0)
        return Operator(
            "db",
            lambda x, p=params, ws=window_spec, e=eps: 10.0
            * np.log10(np.maximum(magnitude_of(x, p, ws) ** 2, e)),
        )
    elif operator_type in ("mel", "erb", "loghz", "chroma"):
        # Not available in SciPy: reuse the NumPy reference implementation.
        return create_numpy_operator(operator_type, params)
    else:
        raise ValueError(f"Unknown operator type: {operator_type}")


# Helper functions for NumPy implementations
def stft_from_signal(x: np.ndarray, spec: StftSpec):
    """Run the NumPy ``stft`` helper on ``x`` using the settings in ``spec``.

    Returns:
        The three values produced by ``stft``, as a tuple
        (transform, frequencies, times).
    """
    stft_kwargs = dict(
        sample_rate=spec.sample_rate,
        n_fft=spec.n_fft,
        hop_length=spec.hop_length,
        window=spec.window(),
        centre=spec.centre,
    )
    transform, bin_freqs, frame_times = stft(x, **stft_kwargs)
    return transform, bin_freqs, frame_times


def power_spec_from_signal(x: np.ndarray, spec: StftSpec) -> np.ndarray:
    """Compute the STFT of ``x`` and pass it through ``power_spectrogram``."""
    transform, _freqs, _times = stft_from_signal(x, spec)
    power = power_spectrogram(transform)
    return power


def mag_spec_from_signal(x: np.ndarray, spec: StftSpec) -> np.ndarray:
    """Compute the STFT of ``x`` and pass it through ``magnitude_spectrogram``."""
    transform, _freqs, _times = stft_from_signal(x, spec)
    magnitude = magnitude_spectrogram(transform)
    return magnitude


def db_spec_from_signal(
    x: np.ndarray, spec: StftSpec, eps: float = 1e-10
) -> np.ndarray:
    """Compute the power spectrogram of ``x`` and convert it with ``db_spectrogram``.

    ``eps`` is forwarded unchanged to ``db_spectrogram``.
    """
    return db_spectrogram(power_spec_from_signal(x, spec), eps=eps)


@dataclass(frozen=True)
class MelSpec:
    n_mels: int = 80
    f_min: float = 0.0
    f_max: float | None = None

    def fmax(self, sample_rate: int) -> float:
        return (sample_rate / 2.0) if self.f_max is None else self.f_max


def mel_spec_from_signal(
    x: np.ndarray, spec: StftSpec, mel_spec: MelSpec
) -> np.ndarray:
    """Compute the power spectrogram of ``x`` and apply a mel filterbank.

    The filterbank is rebuilt on every call from ``spec`` and ``mel_spec``.
    """
    power = power_spec_from_signal(x, spec)
    filterbank_kwargs = dict(
        sample_rate=spec.sample_rate,
        n_fft=spec.n_fft,
        n_mels=mel_spec.n_mels,
        f_min=mel_spec.f_min,
        f_max=mel_spec.fmax(spec.sample_rate),
    )
    filterbank = mel_filterbank(**filterbank_kwargs)
    return mel_spectrogram(power, filterbank)


@dataclass(frozen=True)
class LogHzSpec:
    n_bins: int = 128
    f_min: float = 30.0
    f_max: float | None = None

    def fmax(self, sample_rate: int) -> float:
        return (sample_rate / 2.0) if self.f_max is None else self.f_max


def loghz_spec_from_signal(
    x: np.ndarray, spec: StftSpec, log_spec: LogHzSpec
) -> np.ndarray:
    """Compute the power spectrogram of ``x`` and map it to a log-frequency axis.

    The mapping matrix is rebuilt on every call from ``spec`` and ``log_spec``.
    """
    power = power_spec_from_signal(x, spec)
    matrix_kwargs = dict(
        sample_rate=spec.sample_rate,
        n_fft=spec.n_fft,
        n_bins=log_spec.n_bins,
        f_min=log_spec.f_min,
        f_max=log_spec.fmax(spec.sample_rate),
    )
    mapping = log_frequency_matrix(**matrix_kwargs)
    return logfreq_spectrogram(power, mapping)


@dataclass(frozen=True)
class ErbSpec:
    """ERB filterbank settings for the NumPy reference path."""

    # Filter centre frequencies; the factories in this notebook fill this from erb_centers(...)
    centre_freqs: np.ndarray


def erb_spec_from_signal(
    x: np.ndarray, spec: StftSpec, erb_spec: ErbSpec
) -> np.ndarray:
    """Compute the STFT of ``x`` and pass it to ``erb_spectrogram``.

    The STFT's frequency axis and the centre frequencies stored in
    ``erb_spec`` are forwarded as keyword arguments.
    """
    transform, bin_freqs, _times = stft_from_signal(x, spec)
    centres = erb_spec.centre_freqs
    return erb_spectrogram(transform, freqs=bin_freqs, centre_freqs=centres)


def chroma_from_signal(
    x: np.ndarray, spec: StftSpec, f_ref: float = 440.0
) -> np.ndarray:
    """Compute the power spectrogram of ``x`` and pass it to ``chroma``.

    The frequency axis is the real-FFT bin frequencies for ``spec.n_fft`` at
    ``spec.sample_rate``; ``f_ref`` is forwarded unchanged.
    """
    power = power_spec_from_signal(x, spec)
    sample_period = 1.0 / spec.sample_rate
    bin_freqs = np.fft.rfftfreq(spec.n_fft, d=sample_period)
    return chroma(power, freqs=bin_freqs, f_ref=f_ref)


print("Parameter iterator and operator factory loaded")

Parameter iterator and operator factory loaded


In [6]:
# ==========================================
# Widget Controls
# ==========================================

import ipywidgets as widgets
from IPython.display import display

# Operator selection checkboxes (MFCC and CQT removed)
# Keys are the operator_type strings understood by the factories; ERB and Chroma start unchecked.
operator_widgets = {
    "power": widgets.Checkbox(value=True, description="Linear Power"),
    "magnitude": widgets.Checkbox(value=True, description="Linear Magnitude"),
    "db": widgets.Checkbox(value=True, description="Linear dB"),
    "mel": widgets.Checkbox(value=True, description="Mel Power"),
    "erb": widgets.Checkbox(value=False, description="ERB Power"),
    "loghz": widgets.Checkbox(value=True, description="LogHz Power"),
    "chroma": widgets.Checkbox(value=False, description="Chroma"),
}

# Fixture selection
# All fixtures are selected by default.
fixture_widgets = widgets.SelectMultiple(
    options=list(FIXTURES.keys()),
    value=list(FIXTURES.keys()),
    description="Fixtures:",
)

# Implementation selection (added scipy)
# Keys match the implementation names accepted by create_operator.
impl_widgets = {
    "rust": widgets.Checkbox(value=True, description="Rust (spectrograms)"),
    "numpy": widgets.Checkbox(value=True, description="NumPy"),
    "scipy": widgets.Checkbox(value=True, description="SciPy"),
}

# Run controls
warmup_widget = widgets.IntSlider(value=10, min=1, max=50, description="Warmup:")
runs_widget = widgets.IntSlider(value=100, min=10, max=1000, description="Runs:")
run_button = widgets.Button(description="Run Benchmarks", button_style="success")
# max=100 is only a placeholder; run_parameter_sweep overwrites it with the real total.
progress_widget = widgets.IntProgress(value=0, min=0, max=100, description="Progress:")
status_widget = widgets.Label(value="Ready")

# Layout
display(
    widgets.VBox(
        [
            widgets.HTML("<h3>Operator Selection</h3>"),
            widgets.HBox([*operator_widgets.values()]),
            widgets.HTML("<h3>Fixtures</h3>"),
            fixture_widgets,
            widgets.HTML("<h3>Implementation</h3>"),
            *impl_widgets.values(),
            widgets.HTML("<h3>Benchmark Settings</h3>"),
            warmup_widget,
            runs_widget,
            widgets.HBox([run_button, progress_widget]),
            status_widget,
        ]
    )
)

VBox(children=(HTML(value='<h3>Operator Selection</h3>'), HBox(children=(Checkbox(value=True, description='Lin…

In [10]:
# ==========================================
# Benchmark Runner
# ==========================================


def run_parameter_sweep(
    operators: list,
    fixtures: list,
    implementations: list,
    warmup: int,
    runs: int,
):
    """
    Run benchmarks across parameter sweep.

    For every (operator, parameter set, fixture, implementation) combination
    an operator is created and timed with ``benchmark_fn``. Progress and the
    current combination are pushed to ``progress_widget`` / ``status_widget``.

    A failing combination does not abort the sweep: it is printed and also
    recorded under ``"errors"`` in the returned dict, so failures survive in
    the saved JSON instead of existing only as transient cell output.

    Args:
        operators: Operator type names to sweep.
        fixtures: Keys of ``FIXTURES`` to run against.
        implementations: Backend names accepted by ``create_operator``.
        warmup: Untimed calls per combination.
        runs: Timed calls per combination.

    Returns:
        Dictionary containing all results with timestamp and metadata:
        ``"timestamp"``, ``"metadata"``, ``"results"`` and ``"errors"``.
    """
    timestamp = datetime.now().isoformat()
    all_results = []
    errors = []

    # Build each operator's grid once; it serves both the progress total and
    # the sweep itself (a list of pairs keeps duplicate operators intact).
    sweep_plan = [(op, generate_param_combinations(op)) for op in operators]

    total_combinations = (
        sum(len(combos) for _, combos in sweep_plan)
        * len(fixtures)
        * len(implementations)
    )

    progress = 0
    progress_widget.max = total_combinations

    for operator_type, param_combos in sweep_plan:
        for params in param_combos:
            for fixture_name in fixtures:
                fixture_data = FIXTURES[fixture_name]

                for impl in implementations:
                    status_widget.value = f"{impl} {operator_type} on {fixture_name}"

                    try:
                        # Create operator with params
                        op = create_operator(operator_type, params, impl)

                        # Run benchmark
                        timing = benchmark_fn(
                            lambda: op.fn(fixture_data), warmup=warmup, runs=runs
                        )

                        # Store result (plain floats so the dict is JSON-serialisable)
                        result = {
                            "operator": operator_type,
                            "implementation": impl,
                            "params": params,
                            "fixture": fixture_name,
                            "timing": {
                                "mean": float(timing["mean"]),
                                "std": float(timing["std"]),
                                "min": float(timing["min"]),
                                "max": float(timing["max"]),
                                "median": float(np.median(timing["raw"])),
                            },
                        }
                        all_results.append(result)

                    except Exception as e:
                        # Best-effort sweep: report, remember, and move on.
                        print(f"Error: {impl} {operator_type} {fixture_name}: {e}")
                        errors.append(
                            {
                                "operator": operator_type,
                                "implementation": impl,
                                "params": params,
                                "fixture": fixture_name,
                                "error": f"{type(e).__name__}: {e}",
                            }
                        )

                    progress += 1
                    progress_widget.value = progress

    status_widget.value = "Complete!"

    return {
        "timestamp": timestamp,
        "metadata": {
            "warmup_runs": warmup,
            "benchmark_runs": runs,
        },
        "results": all_results,
        "errors": errors,
    }


def save_results(results: dict, output_dir: str = "benchmark_results"):
    """Save results to JSON file with timestamp.

    The file is named ``results_<timestamp>.json`` where ":" and "." in
    ``results["timestamp"]`` are replaced by "-" so the name is valid on all
    platforms.

    Args:
        results: JSON-serialisable dict containing a "timestamp" string.
        output_dir: Target directory; created (including any missing parent
            directories) if it does not exist.

    Returns:
        Path of the written file.
    """
    target_dir = Path(output_dir)
    # parents=True so a nested output_dir such as "out/run1" also works.
    target_dir.mkdir(parents=True, exist_ok=True)

    timestamp_clean = results["timestamp"].replace(":", "-").replace(".", "-")
    filename = target_dir / f"results_{timestamp_clean}.json"

    with open(filename, "w") as f:
        json.dump(results, f, indent=2)

    print(f"Results saved to: {filename}")
    return filename


# Connect button to runner
def on_run_clicked(b):
    """Button callback: validate the widget selections, run the sweep, save it.

    If any of the three selections is empty, the first missing one is reported
    through ``status_widget`` and nothing is run.
    """
    chosen_operators = [name for name, box in operator_widgets.items() if box.value]
    chosen_fixtures = list(fixture_widgets.value)
    chosen_impls = [name for name, box in impl_widgets.items() if box.value]

    # Checked in this order; the first empty selection wins.
    required = (
        (chosen_operators, "Error: Select at least one operator"),
        (chosen_fixtures, "Error: Select at least one fixture"),
        (chosen_impls, "Error: Select at least one implementation"),
    )
    for selection, complaint in required:
        if not selection:
            status_widget.value = complaint
            return

    sweep_output = run_parameter_sweep(
        chosen_operators,
        chosen_fixtures,
        chosen_impls,
        warmup_widget.value,
        runs_widget.value,
    )
    save_results(sweep_output)


# Register the callback; re-running this cell registers it again on the same button.
run_button.on_click(on_run_clicked)
print("Benchmark runner ready")

Benchmark runner ready


In [13]:
# ==========================================
# Results Loading and Analysis
# ==========================================


def load_results(filename: str) -> dict:
    """Load results from JSON file.

    Prints which file is being read, then returns the parsed content. The
    payload itself is deliberately not printed: a full sweep holds thousands
    of entries and would flood the cell output.

    Args:
        filename: Path of a JSON file written by ``save_results``.

    Returns:
        The parsed JSON document.
    """
    print(f"Loading from {filename}")
    with open(filename, "r") as f:
        return json.load(f)


def list_result_files(output_dir: str = "benchmark_results") -> list:
    """Return the ``results_*.json`` paths in ``output_dir``, newest first.

    Ordering is a reverse sort of the paths themselves, which matches
    newest-first because the file names embed an ISO-style timestamp.
    A missing directory yields an empty list.
    """
    results_dir = Path(output_dir)
    if results_dir.exists():
        found = list(results_dir.glob("results_*.json"))
        found.sort(reverse=True)
        return found
    return []


# Widget to select result file
# The file list is captured once when this cell runs; re-run the cell to pick up new result files.
result_files = list_result_files()
if result_files:
    result_selector = widgets.Dropdown(
        options=[str(f) for f in result_files],
        description="Results:",
    )
    load_button = widgets.Button(description="Load Results", button_style="info")

    display(widgets.HBox([result_selector, load_button]))

    # Stays None until the Load button is clicked.
    LOADED_RESULTS = None

    def on_load_clicked(b):
        """Button callback: load the selected file into the module-level LOADED_RESULTS."""
        global LOADED_RESULTS
        LOADED_RESULTS = load_results(result_selector.value)
        print(
            f"Loaded {len(LOADED_RESULTS['results'])} results from {LOADED_RESULTS['timestamp']}"
        )

    load_button.on_click(on_load_clicked)

else:
    print("No result files found. Run benchmarks first.")
    LOADED_RESULTS = None

HBox(children=(Dropdown(description='Results:', options=('benchmark_results/results_2026-01-28T14-44-47-072778…

In [14]:
# ==========================================
# COMPREHENSIVE SUMMARY TABLE GENERATOR
# ==========================================


def generate_comprehensive_summary(results: dict) -> pd.DataFrame:
    """
    Generate a comprehensive summary table comparing all implementations.

    Returns a DataFrame with:
    - One row per (operator, fixture) combination
    - Columns for each implementation's mean time ("<Impl> (ms)") and the
      standard deviation of the per-parameter-set means ("<Impl> Std"),
      both NaN when that implementation has no rows in the group
    - Speedup columns (Rust vs NumPy, Rust vs SciPy), present only for rows
      where both timings exist

    An empty result set returns an empty frame with just the "Operator" and
    "Fixture" columns instead of raising KeyError from the group-by.
    """
    df = results_to_dataframe(results)

    # No rows means no "operator"/"fixture" columns to group on.
    if df.empty:
        return pd.DataFrame(columns=["Operator", "Fixture"])

    # Group by operator and fixture, averaging across all parameter combinations
    summary_rows = []

    for (operator, fixture), group in df.groupby(["operator", "fixture"]):
        row = {
            "Operator": operator,
            "Fixture": fixture,
        }

        # Mean of per-combination mean times for each implementation
        for impl in ["rust", "numpy", "scipy"]:
            impl_means = group.loc[group["implementation"] == impl, "mean_ms"]
            label = impl.capitalize()
            if len(impl_means) > 0:
                row[f"{label} (ms)"] = impl_means.mean()
                row[f"{label} Std"] = impl_means.std()
            else:
                row[f"{label} (ms)"] = np.nan
                row[f"{label} Std"] = np.nan

        # Speedups are baseline / Rust, so values above 1 mean Rust is faster
        rust_ms = row["Rust (ms)"]
        if not np.isnan(rust_ms):
            for pretty, column in (("NumPy", "Numpy (ms)"), ("SciPy", "Scipy (ms)")):
                if not np.isnan(row[column]):
                    row[f"Speedup vs {pretty}"] = row[column] / rust_ms

        summary_rows.append(row)

    summary_df = pd.DataFrame(summary_rows)

    # Sort by operator then fixture
    return summary_df.sort_values(["Operator", "Fixture"])


def generate_operator_summary(results: dict) -> pd.DataFrame:
    """
    Generate operator-level summary (averaged across all fixtures and parameters).

    Each row holds, per implementation present for that operator, the mean of
    the per-run mean times ("<Impl> (ms)") and their standard deviation
    ("<Impl> Std"), plus "Avg Speedup vs NumPy" / "Avg Speedup vs SciPy"
    (baseline mean / Rust mean) when both sides exist.

    An empty result set returns an empty frame with just the "Operator"
    column instead of raising KeyError from the group-by.
    """
    df = results_to_dataframe(results)

    # No rows means no "operator" column to group on.
    if df.empty:
        return pd.DataFrame(columns=["Operator"])

    summary_rows = []

    for operator, group in df.groupby("operator"):
        row = {"Operator": operator}

        # Implementations without data are simply left out of the row.
        for impl in ["rust", "numpy", "scipy"]:
            impl_means = group.loc[group["implementation"] == impl, "mean_ms"]
            if len(impl_means) > 0:
                row[f"{impl.capitalize()} (ms)"] = impl_means.mean()
                row[f"{impl.capitalize()} Std"] = impl_means.std()

        # Ratio of averaged times, so values above 1 mean Rust is faster
        if "Rust (ms)" in row:
            for pretty, column in (("NumPy", "Numpy (ms)"), ("SciPy", "Scipy (ms)")):
                if column in row:
                    row[f"Avg Speedup vs {pretty}"] = row[column] / row["Rust (ms)"]

        summary_rows.append(row)

    summary_df = pd.DataFrame(summary_rows)
    return summary_df.sort_values("Operator")


def results_to_dataframe(results: dict) -> pd.DataFrame:
    """Flatten the ``results["results"]`` entries into one DataFrame row each.

    Timings are converted from seconds to milliseconds. Parameters become
    ``param_<name>`` columns; entries of a nested dict (such as
    "scale_params") are lifted to the same level under their own key names.
    """

    def _flatten(entry):
        timing = entry["timing"]
        record = {
            "operator": entry["operator"],
            "implementation": entry["implementation"],
            "fixture": entry["fixture"],
            "mean_ms": timing["mean"] * 1000,
            "std_ms": timing["std"] * 1000,
            "min_ms": timing["min"] * 1000,
            "median_ms": timing["median"] * 1000,
        }
        for key, value in entry["params"].items():
            if not isinstance(value, dict):
                record[f"param_{key}"] = value
                continue
            # Nested dict: promote each inner key to its own column.
            for inner_key, inner_value in value.items():
                record[f"param_{inner_key}"] = inner_value
        return record

    return pd.DataFrame([_flatten(entry) for entry in results["results"]])


print("Summary table generators loaded")

Summary table generators loaded


In [54]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from pathlib import Path

# Set the visual style globally
# These calls mutate process-wide seaborn/matplotlib state and affect every later figure.
sns.set_theme(style="darkgrid", context="paper")
mpl.rcParams.update(
    {
        # Use LaTeX for all text rendering
        # NOTE(review): usetex presumably needs a working LaTeX installation on the machine - confirm.
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": ["Computer Modern"],
        # General Font Sizes
        "font.size": 14,  # Global default font size
        "axes.titlesize": 18,  # Size of the plot title
        "axes.labelsize": 16,  # Size of the x and y labels
        "xtick.labelsize": 12,  # Size of the x-axis tick labels
        "ytick.labelsize": 12,  # Size of the y-axis tick labels
        "legend.fontsize": 12,  # Size of the legend text
        "figure.titlesize": 20,  # Size of the figure suptitle
        # Lines and Markers (Optional but recommended for LaTeX style)
        "lines.linewidth": 2.0,
        "lines.markersize": 7,
        # Ensure LaTeX uses the correct math font
        "mathtext.fontset": "cm",
    }
)


def results_to_dataframe(results: dict) -> pd.DataFrame:
    """Build an analysis DataFrame from a loaded benchmark results dict.

    One row is produced per item of ``results["results"]``. Each row carries
    the operator / implementation / fixture labels, the timing statistics
    multiplied by 1000 (``mean_ms``, ``std_ms``, ``min_ms``, ``median_ms``),
    and every parameter as a ``param_<name>`` column. A parameter whose value
    is itself a dict contributes one column per inner key.

    Note: a function with this name is also defined in an earlier cell; this
    definition replaces it when the cell runs.
    """

    def _flatten_params(params: dict) -> dict:
        # Lift one level of nesting; the outer key of a nested dict is dropped.
        flat = {}
        for outer_key, outer_value in params.items():
            if not isinstance(outer_value, dict):
                flat[f"param_{outer_key}"] = outer_value
                continue
            for inner_key, inner_value in outer_value.items():
                flat[f"param_{inner_key}"] = inner_value
        return flat

    def _as_row(entry: dict) -> dict:
        timing = entry["timing"]
        return {
            "operator": entry["operator"],
            "implementation": entry["implementation"],
            "fixture": entry["fixture"],
            "mean_ms": timing["mean"] * 1000,
            "std_ms": timing["std"] * 1000,
            "min_ms": timing["min"] * 1000,
            "median_ms": timing["median"] * 1000,
            **_flatten_params(entry["params"]),
        }

    return pd.DataFrame([_as_row(entry) for entry in results["results"]])


def _save_fig(save_path: str = None):
    """Internal helper to save figures."""
    if save_path:
        path = Path(save_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(path.with_suffix(".pdf"), bbox_inches="tight", dpi=300)
        plt.savefig(path.with_suffix(".png"), bbox_inches="tight", dpi=300)
        print(f"Figure saved to {path}")


def plot_speedup_summary(
    df: pd.DataFrame, baselines: list = ["numpy", "scipy"], save_path: str = None
):
    """
    Plot Rust speedup against multiple baselines across all operators.

    Rust rows are paired with each baseline's rows on operator, fixture and
    every ``param_*`` column; the speedup of a pair is the baseline mean
    runtime divided by the Rust mean runtime. Speedups are then averaged per
    operator and drawn as a grouped bar chart.

    Args:
        df: Benchmark rows with ``operator``, ``fixture``, ``implementation``,
            ``mean_ms`` and ``param_*`` columns.
        baselines: Implementation names to compare Rust against. Baselines
            with no rows in ``df`` are skipped. The default list is only read,
            never mutated.
        save_path: Optional output path forwarded to ``_save_fig``.

    Returns:
        DataFrame with columns ``operator``, ``baseline``, ``mean_speedup``
        and ``std_speedup`` (pandas sample standard deviation). When no
        Rust/baseline pairs exist the frame is empty and nothing is plotted.
    """
    summary_columns = ["operator", "baseline", "mean_speedup", "std_speedup"]
    speedup_data = []
    param_cols = [c for c in df.columns if c.startswith("param_")]
    merge_cols = ["operator", "fixture"] + param_cols

    rust_df = df[df["implementation"] == "rust"]

    for baseline in baselines:
        base_df = df[df["implementation"] == baseline]
        if base_df.empty:
            continue

        # Inner join: only configurations measured for both implementations count.
        merged = rust_df.merge(
            base_df, on=merge_cols, suffixes=("_rust", f"_{baseline}")
        )
        merged["speedup"] = merged[f"mean_ms_{baseline}"] / merged["mean_ms_rust"]

        for operator in merged["operator"].unique():
            op_data = merged[merged["operator"] == operator]["speedup"]
            speedup_data.append(
                {
                    "operator": operator,
                    "baseline": baseline,
                    "mean_speedup": op_data.mean(),
                    "std_speedup": op_data.std(),
                }
            )

    # Fix the column set explicitly so an empty result still has a usable schema.
    plot_df = pd.DataFrame(speedup_data, columns=summary_columns)

    # Without any comparable rows the bar plot would fail on the missing
    # "operator" column; report it the same way plot_speedup_by_fixture does.
    if plot_df.empty:
        print("Insufficient data for comparison.")
        return plot_df

    plt.figure(figsize=(16, 9))
    sns.barplot(
        data=plot_df, x="operator", y="mean_speedup", hue="baseline", edgecolor="black"
    )

    # Bars above this line mean Rust is faster than the baseline.
    plt.axhline(1.0, color="black", linestyle="--", linewidth=1.5, label="Break-even")
    plt.ylabel(r"\textbf{Speedup Factor (Baseline / Rust)}")
    plt.xlabel(r"\textbf{Operation}")
    plt.title(r"\textbf{Performance: Rust vs. NumPy and SciPy}")
    plt.legend(title=None)
    _save_fig(save_path)
    plt.show()
    return plot_df


def plot_param_sensitivity(
    df: pd.DataFrame,
    operator: str,
    param: str,
    save_path: str = None,
    fixture: str = None,
):
    """
    Plot runtime vs parameter value using Seaborn lineplot (handles confidence intervals).

    Args:
        df: Benchmark rows with ``operator``, ``implementation``, ``fixture``,
            ``mean_ms`` and ``param_*`` columns.
        operator: Operator whose rows are plotted.
        param: Parameter name without the ``param_`` prefix (e.g. ``"n_fft"``).
        save_path: Optional output path forwarded to ``_save_fig``.
        fixture: Optional fixture name; when given, only rows for that fixture
            are plotted. ``None`` keeps every fixture.

    Returns:
        None. Prints a message and returns early when the parameter column is
        missing or when no rows match the operator/fixture selection.
    """
    subset = df[df["operator"] == operator].copy()
    param_col = f"param_{param}"

    if param_col not in subset.columns:
        print(f"Parameter {param} not found.")
        return

    if fixture is not None:
        subset = subset[subset["fixture"] == fixture]

    if subset.empty:
        print("Insufficient data for comparison.")
        return

    def _tex_safe(text: str) -> str:
        # Under text.usetex a bare underscore (e.g. "n_fft") is a LaTeX error.
        return text.replace("_", r"\_") if mpl.rcParams["text.usetex"] else text

    plt.figure(figsize=(10, 6))
    sns.lineplot(
        data=subset,
        x=param_col,
        y="mean_ms",
        hue="implementation",
        marker="o",
        linewidth=2.5,
        markersize=8,
    )

    plt.yscale("log")  # Runtimes often vary by orders of magnitude
    selection = operator if fixture is None else f"{operator}, {fixture}"
    plt.title(_tex_safe(f"Sensitivity Analysis: {param} ({selection})"), fontsize=14)
    plt.xlabel(_tex_safe(f"Parameter: {param}"))
    plt.ylabel("Mean Runtime (ms) - Log Scale")

    # These parameters are switched to a base-2 log x-axis.
    if param in ["n_fft", "hop_size", "n_mels", "n_filters", "n_bins"]:
        plt.xscale("log", base=2)

    _save_fig(save_path)
    plt.show()


def plot_speedup_by_fixture(
    df: pd.DataFrame,
    operator: str,
    baselines: list = ["numpy", "scipy"],
    save_path: str = None,
):
    """
    Compare speedup across different audio/signal fixtures.

    For one operator, Rust rows are paired with each baseline's rows on
    fixture and every ``param_*`` column. The per-pair speedup (baseline mean
    runtime divided by Rust mean runtime) is shown as one box per fixture and
    baseline.

    Args:
        df: Benchmark rows with ``operator``, ``fixture``, ``implementation``,
            ``mean_ms`` and ``param_*`` columns.
        operator: Operator whose rows are compared.
        baselines: Implementation names to compare Rust against. Baselines
            without rows for this operator are skipped. The default list is
            only read, never mutated.
        save_path: Optional output path forwarded to ``_save_fig``.

    Returns:
        None. Prints a message and returns early when no baseline has rows
        for the operator.
    """
    subset = df[df["operator"] == operator]
    param_cols = [c for c in df.columns if c.startswith("param_")]
    merge_cols = ["fixture"] + param_cols

    rust_df = subset[subset["implementation"] == "rust"]
    comparison_data = []

    for baseline in baselines:
        base_df = subset[subset["implementation"] == baseline]
        if base_df.empty:
            continue

        merged = rust_df.merge(
            base_df, on=merge_cols, suffixes=("_rust", f"_{baseline}")
        )
        merged["speedup"] = merged[f"mean_ms_{baseline}"] / merged["mean_ms_rust"]
        merged["baseline"] = baseline
        comparison_data.append(merged)

    if not comparison_data:
        print("Insufficient data for comparison.")
        return

    # Each merged frame restarts its index at 0; renumber so the combined
    # frame handed to seaborn has unique index labels.
    plot_df = pd.concat(comparison_data, ignore_index=True)

    title = f"Speedup Distribution by Fixture Type: {operator}"
    if mpl.rcParams["text.usetex"]:
        # Fixture names such as "sine_440" become tick labels; a bare
        # underscore is a LaTeX error, so escape it for display only.
        plot_df["fixture"] = (
            plot_df["fixture"].astype(str).str.replace("_", r"\_", regex=False)
        )
        title = title.replace("_", r"\_")

    plt.figure(figsize=(16, 9))
    sns.boxplot(data=plot_df, x="fixture", y="speedup", hue="baseline", palette="Set2")
    # Values above this line mean Rust is faster than the baseline.
    plt.axhline(1.0, color="black", linestyle="--", alpha=0.6)

    plt.title(title, fontsize=14)
    plt.ylabel("Speedup (x Times Faster)")
    plt.xlabel("Fixture / Signal Type")

    _save_fig(save_path)
    plt.show()

In [55]:
# ==========================================
# Display Summary Tables
# ==========================================
# Shows the operator-level and per-fixture summary tables for the loaded
# benchmark run and writes both to CSV under benchmark_results/.

if LOADED_RESULTS is not None:
    print("=" * 80)
    print("OPERATOR-LEVEL SUMMARY (Averaged across all fixtures and parameters)")
    print("=" * 80)
    operator_summary = generate_operator_summary(LOADED_RESULTS)
    # Rounding is for display only; the stored frame keeps full precision.
    display(operator_summary.round(3))

    print("\n" + "=" * 80)
    print("DETAILED SUMMARY (By Operator and Fixture)")
    print("=" * 80)
    detailed_summary = generate_comprehensive_summary(LOADED_RESULTS)
    display(detailed_summary.round(3))

    # Save summaries to CSV
    # ":" and "." are replaced so the timestamp is safe to embed in a filename.
    timestamp_clean = LOADED_RESULTS["timestamp"].replace(":", "-").replace(".", "-")
    # Make sure the target folder exists before writing into it.
    Path("benchmark_results").mkdir(parents=True, exist_ok=True)
    # Use the %-style format string: it is accepted by every pandas release,
    # whereas a "{:.3f}" string fails on versions that apply `fmt % value`.
    operator_summary.to_csv(
        f"benchmark_results/operator_summary_{timestamp_clean}.csv",
        index=False,
        float_format="%.3f",
    )
    detailed_summary.to_csv(
        f"benchmark_results/detailed_summary_{timestamp_clean}.csv",
        index=False,
        float_format="%.3f",
    )
    print("\nSummaries saved to CSV files")
else:
    print("No results loaded. Run benchmarks or load existing results first.")

OPERATOR-LEVEL SUMMARY (Averaged across all fixtures and parameters)


Unnamed: 0,Operator,Rust (ms),Rust Std,Numpy (ms),Numpy Std,Scipy (ms),Scipy Std,Avg Speedup vs NumPy,Avg Speedup vs SciPy
0,db,0.257,0.165,0.35,0.251,0.451,0.366,1.363,1.755
1,erb,0.601,0.437,3.713,2.703,3.714,2.723,6.178,6.181
2,loghz,0.178,0.149,0.547,0.998,0.534,0.965,3.068,2.996
3,magnitude,0.14,0.089,0.198,0.133,0.319,0.277,1.419,2.287
4,mel,0.18,0.139,0.63,0.851,0.612,0.801,3.506,3.406
5,power,0.126,0.082,0.205,0.141,0.327,0.288,1.63,2.603



DETAILED SUMMARY (By Operator and Fixture)


Unnamed: 0,Operator,Fixture,Rust (ms),Rust Std,Numpy (ms),Numpy Std,Scipy (ms),Scipy Std,Speedup vs NumPy,Speedup vs SciPy
0,db,chirp,0.26,0.167,0.353,0.254,0.454,0.369,1.357,1.743
1,db,impulse,0.251,0.16,0.337,0.243,0.437,0.357,1.345,1.745
2,db,noise,0.266,0.172,0.361,0.259,0.463,0.375,1.359,1.743
3,db,sine_3k,0.252,0.162,0.348,0.25,0.448,0.365,1.383,1.779
4,db,sine_440,0.256,0.164,0.35,0.251,0.451,0.366,1.369,1.764
5,erb,chirp,0.604,0.448,3.746,2.717,3.721,2.693,6.198,6.158
6,erb,impulse,0.6,0.438,3.593,2.673,3.581,2.681,5.986,5.967
7,erb,noise,0.599,0.432,3.747,2.721,3.746,2.709,6.251,6.248
8,erb,sine_3k,0.604,0.439,3.744,2.676,3.755,2.755,6.202,6.219
9,erb,sine_440,0.597,0.428,3.733,2.734,3.768,2.782,6.254,6.313



Summaries saved to CSV files


In [56]:
# ==========================================
# Interactive Analysis Widgets
# ==========================================
# Builds dropdowns/buttons that drive the plotting helpers on the loaded
# benchmark results.

# Check if results are loaded
if LOADED_RESULTS is not None:
    df = results_to_dataframe(LOADED_RESULTS)

    # Print basic stats
    print(f"Loaded {len(df)} result rows")
    print(f"Operators: {sorted(df['operator'].unique())}")
    print(f"Fixtures: {sorted(df['fixture'].unique())}")
    print(f"Implementations: {sorted(df['implementation'].unique())}")

    # Operator selection for analysis
    analysis_operator_widget = widgets.Dropdown(
        options=sorted(df["operator"].unique()),
        description="Operator:",
    )

    # Parameter selection (names shown without the "param_" column prefix)
    param_cols = [c.replace("param_", "") for c in df.columns if c.startswith("param_")]
    analysis_param_widget = widgets.SelectMultiple(
        options=sorted(param_cols),
        value=[param_cols[0]] if param_cols else [],
        description="Parameters:",
        rows=min(len(param_cols), 6),
    )

    # Fixture selection; "All" means no fixture filtering
    analysis_fixture_widget = widgets.Dropdown(
        options=["All"] + sorted(df["fixture"].unique()),
        value="All",
        description="Fixture:",
    )

    # Plot buttons
    plot_speedup_btn = widgets.Button(
        description="Plot Speedup Summary", button_style="info"
    )
    plot_sensitivity_btn = widgets.Button(
        description="Plot Parameter Sensitivity", button_style="info"
    )
    plot_by_fixture_btn = widgets.Button(
        description="Plot Speedup by Fixture", button_style="info"
    )

    import os

    os.makedirs("./imgs", exist_ok=True)

    # Event handlers
    def on_plot_speedup(b):
        """Draw and save the overall Rust-vs-baseline speedup bar chart."""
        plot_speedup_summary(df, save_path="./imgs/average_speedup")

    def on_plot_sensitivity(b):
        """Plot runtime against the selected parameter(s) for the chosen operator."""
        operator = analysis_operator_widget.value
        params = list(analysis_param_widget.value)
        fixture = (
            None
            if analysis_fixture_widget.value == "All"
            else analysis_fixture_widget.value
        )

        if len(params) == 0:
            print("Select at least one parameter")
            return

        if len(params) == 1:
            # Narrow the rows to the chosen fixture here instead of passing a
            # keyword, so the call relies only on the positional
            # (df, operator, param) signature of plot_param_sensitivity.
            selected = df if fixture is None else df[df["fixture"] == fixture]
            plot_param_sensitivity(selected, operator, params[0])
        else:
            # NOTE(review): plot_multi_param_sensitivity is not defined in the
            # visible part of this notebook -- confirm it exists and accepts
            # a `fixture` keyword.
            plot_multi_param_sensitivity(df, operator, params, fixture=fixture)

    def on_plot_by_fixture(b):
        """Plot the per-fixture speedup distribution for the chosen operator."""
        operator = analysis_operator_widget.value
        plot_speedup_by_fixture(df, operator)

    plot_speedup_btn.on_click(on_plot_speedup)
    plot_sensitivity_btn.on_click(on_plot_sensitivity)
    plot_by_fixture_btn.on_click(on_plot_by_fixture)

    # Layout
    display(
        widgets.VBox(
            [
                widgets.HTML("<h3>Analysis Controls</h3>"),
                analysis_operator_widget,
                analysis_param_widget,
                analysis_fixture_widget,
                widgets.HBox(
                    [plot_speedup_btn, plot_sensitivity_btn, plot_by_fixture_btn]
                ),
            ]
        )
    )
else:
    print("No results loaded. Load results first using the widget above.")

Loaded 51840 result rows
Operators: ['db', 'erb', 'loghz', 'magnitude', 'mel', 'power']
Fixtures: ['chirp', 'impulse', 'noise', 'sine_3k', 'sine_440']
Implementations: ['numpy', 'rust', 'scipy']


VBox(children=(HTML(value='<h3>Analysis Controls</h3>'), Dropdown(description='Operator:', options=('db', 'erb…