Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions buckaroo/customizations/pd_stats_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def vc_nth(pos):
HistogramSeriesResult = TypedDict('HistogramSeriesResult', {'histogram_args': dict, 'histogram_bins': list})


@stat()
@stat(cost="aggregate")
def histogram_series(ser: RawSeries) -> HistogramSeriesResult:
"""Compute histogram args from raw series (numeric path)."""
if not pd.api.types.is_numeric_dtype(ser):
Expand Down Expand Up @@ -210,7 +210,7 @@ def histogram_series(ser: RawSeries) -> HistogramSeriesResult:
}


@stat()
@stat(cost="aggregate")
def histogram(value_counts: pd.Series, nan_per: float, is_numeric: bool, length: int, min: Any, max: Any,
histogram_args: dict) -> list:
"""Compute histogram from summary stats and histogram args."""
Expand Down
2 changes: 1 addition & 1 deletion buckaroo/customizations/pl_stats_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def pl_numeric_stats(ser: RawSeries) -> NumericStatsResult:
# Histogram Series (polars series API)
# ============================================================

@stat()
@stat(cost="aggregate")
def pl_histogram_series(ser: RawSeries) -> HistogramSeriesResult:
"""Compute histogram args from raw polars series (numeric path)."""
if not ser.dtype.is_numeric():
Expand Down
2 changes: 1 addition & 1 deletion buckaroo/customizations/xorq_stats_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _categorical_histogram(execute: Callable[[Any], pd.DataFrame], expr: Any, co
return out


@stat(default=[])
@stat(default=[], cost="aggregate")
def histogram(expr: XorqExpr, execute: XorqExecute, orig_col_name: str, is_numeric: bool, is_bool: bool, length: int,
distinct_count: int, min: float, max: float) -> list:
"""10-bucket numeric histogram or top-10 categorical histogram.
Expand Down
28 changes: 26 additions & 2 deletions buckaroo/pluggable_analysis_framework/stat_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ def __repr__(self):
# StatFunc — a registered stat computation
# ---------------------------------------------------------------------------

# Cost classes accepted by ``@stat(cost=...)``. The decorator checks its
# ``cost`` argument against this tuple and raises ValueError for any other
# value.
VALID_COSTS = ("scalar", "aggregate")


@dataclass
class StatFunc:
"""A registered stat computation.
Expand All @@ -145,6 +148,11 @@ class StatFunc:
column_filter: optional predicate on column dtype
quiet: suppress error reporting
default: fallback value on failure (MISSING = no fallback)
cost: cost class — ``"scalar"`` (cheap, ships in the initial
state_change response) or ``"aggregate"`` (slow path that the
JS orchestrator fetches via a separate round-trip after a
debounce). Histograms, value_counts and other per-column
queries belong in ``"aggregate"``.
"""
name: str
func: Callable
Expand All @@ -154,6 +162,7 @@ class StatFunc:
column_filter: Optional[Callable] = None
quiet: bool = False
default: Any = field(default_factory=lambda: MISSING)
cost: str = "scalar"
spread_dict_result: bool = False # v1 compat: spread all dict keys into accumulator
v1_computed: bool = False # v1 compat: pass full accumulator as single dict arg

Expand Down Expand Up @@ -250,7 +259,7 @@ def _get_requires_from_params(sig: inspect.Signature, hints: dict) -> tuple:
# @stat decorator
# ---------------------------------------------------------------------------

def stat(column_filter=None, quiet=False, default=MISSING):
def stat(column_filter=None, quiet=False, default=MISSING, cost="scalar"):
"""Decorator that converts a function into a StatFunc.

The function signature IS the contract:
Expand All @@ -263,6 +272,12 @@ def stat(column_filter=None, quiet=False, default=MISSING):
key the rest of the DAG expects. Use ``MultipleProvides`` (a TypedDict
alias) when one function should write several keys.

``cost=`` declares the compute-cost class. ``"scalar"`` (default)
means cheap — ships in the initial state_change response. ``"aggregate"``
means slow (histograms, value_counts, per-column queries) — the JS
orchestrator fetches these via a separate round-trip after a debounce.
See plans/js-driven-stat-debounce.md.

Usage::

@stat()
Expand All @@ -277,6 +292,10 @@ def mean(ser: RawSeries) -> float:
def safe_ratio(a: int, b: int) -> float:
return a / b

@stat(cost="aggregate")
def histogram(...) -> list:
...

class TypingResult(MultipleProvides):
is_numeric: bool
is_integer: bool
Expand All @@ -285,6 +304,10 @@ class TypingResult(MultipleProvides):
def typing_stats(dtype: str) -> TypingResult:
...
"""
if cost not in VALID_COSTS:
raise ValueError(
f"@stat(cost={cost!r}): invalid cost class. "
f"Must be one of {VALID_COSTS}.")
def decorator(func):
sig = inspect.signature(func)
try:
Expand All @@ -298,7 +321,8 @@ def decorator(func):
provides_keys = _get_provides_from_return_type(func.__name__, return_type)

stat_func = StatFunc(name=func.__name__, func=func, requires=requires, provides=provides_keys,
needs_raw=needs_raw, column_filter=column_filter, quiet=quiet, default=default)
needs_raw=needs_raw, column_filter=column_filter, quiet=quiet, default=default,
cost=cost)

# Attach metadata to the function so pipeline can find it
func._stat_func = stat_func
Expand Down
43 changes: 43 additions & 0 deletions tests/unit/test_paf_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,49 @@ def test_no_default(self):
sf = distinct_per._stat_func
assert sf.default is MISSING

def test_stat_default_cost_is_scalar(self):
    """The cost class reported by the ``length`` stat is ``"scalar"``.

    Most stats are cheap, so "scalar" is the default; only stats known
    to be expensive opt in to "aggregate".
    """
    # NOTE(review): relies on ``length`` (defined elsewhere in this test
    # module) being declared without an explicit ``cost=`` — confirm.
    assert length._stat_func.cost == "scalar"

def test_stat_explicit_cost_aggregate(self):
    """``cost="aggregate"`` passed to ``stat`` is recorded on the StatFunc."""
    def big_compute(ser: RawSeries) -> float:
        return float(ser.mean())

    # Apply the decorator by hand; this is what ``@stat(cost="aggregate")``
    # does, and the returned object is what the name would be bound to.
    decorated = stat(cost="aggregate")(big_compute)

    assert decorated._stat_func.cost == "aggregate"

def test_stat_invalid_cost_rejected(self):
    """An unrecognised cost class raises ValueError at decoration time.

    Only "scalar" and "aggregate" are recognised. A typo must fail loudly:
    silently accepting it would leak into the cost-class router as an
    unscheduled stat group.
    """
    import pytest as _pt

    def bad(ser: RawSeries) -> float:
        return float(ser.mean())

    with _pt.raises(ValueError, match="cost"):
        # The decorator factory validates ``cost`` itself, so the error
        # surfaces before ``bad`` is ever wrapped.
        stat(cost="bigly")(bad)

def test_known_expensive_stats_marked_aggregate(self):
    """Built-in histogram producers are tagged ``cost="aggregate"``.

    Covers the pandas and polars histogram stats, plus the xorq one when
    that optional module can be imported, so a downstream router can
    schedule them on the slow path.
    """
    from buckaroo.customizations.pd_stats_v2 import histogram as pd_histogram
    from buckaroo.customizations.pd_stats_v2 import histogram_series as pd_hs
    from buckaroo.customizations.pl_stats_v2 import pl_histogram_series

    for expensive_stat in (pd_histogram, pd_hs, pl_histogram_series):
        assert expensive_stat._stat_func.cost == "aggregate"

    # xorq is an optional dependency: when its stats module cannot be
    # imported there is nothing further to check.
    try:
        from buckaroo.customizations.xorq_stats_v2 import histogram as xq_histogram
    except ImportError:
        xq_histogram = None

    if xq_histogram is not None:
        assert xq_histogram._stat_func.cost == "aggregate"


class _MultiSizeStats(MultipleProvides):
row_count: int
Expand Down
Loading