diff --git a/buckaroo/customizations/pd_stats_v2.py b/buckaroo/customizations/pd_stats_v2.py index be1961bb9..59b0feb52 100644 --- a/buckaroo/customizations/pd_stats_v2.py +++ b/buckaroo/customizations/pd_stats_v2.py @@ -176,7 +176,7 @@ def vc_nth(pos): HistogramSeriesResult = TypedDict('HistogramSeriesResult', {'histogram_args': dict, 'histogram_bins': list}) -@stat() +@stat(cost="aggregate") def histogram_series(ser: RawSeries) -> HistogramSeriesResult: """Compute histogram args from raw series (numeric path).""" if not pd.api.types.is_numeric_dtype(ser): @@ -210,7 +210,7 @@ def histogram_series(ser: RawSeries) -> HistogramSeriesResult: } -@stat() +@stat(cost="aggregate") def histogram(value_counts: pd.Series, nan_per: float, is_numeric: bool, length: int, min: Any, max: Any, histogram_args: dict) -> list: """Compute histogram from summary stats and histogram args.""" diff --git a/buckaroo/customizations/pl_stats_v2.py b/buckaroo/customizations/pl_stats_v2.py index 6a74c1525..15f8cf7c8 100644 --- a/buckaroo/customizations/pl_stats_v2.py +++ b/buckaroo/customizations/pl_stats_v2.py @@ -113,7 +113,7 @@ def pl_numeric_stats(ser: RawSeries) -> NumericStatsResult: # Histogram Series (polars series API) # ============================================================ -@stat() +@stat(cost="aggregate") def pl_histogram_series(ser: RawSeries) -> HistogramSeriesResult: """Compute histogram args from raw polars series (numeric path).""" if not ser.dtype.is_numeric(): diff --git a/buckaroo/customizations/xorq_stats_v2.py b/buckaroo/customizations/xorq_stats_v2.py index b267aea55..148c228fa 100644 --- a/buckaroo/customizations/xorq_stats_v2.py +++ b/buckaroo/customizations/xorq_stats_v2.py @@ -255,7 +255,7 @@ def _categorical_histogram(execute: Callable[[Any], pd.DataFrame], expr: Any, co return out -@stat(default=[]) +@stat(default=[], cost="aggregate") def histogram(expr: XorqExpr, execute: XorqExecute, orig_col_name: str, is_numeric: bool, is_bool: bool, length: int, distinct_count: int, min: float, max: float) -> list: """10-bucket numeric histogram or top-10 categorical histogram. diff --git a/buckaroo/pluggable_analysis_framework/stat_func.py b/buckaroo/pluggable_analysis_framework/stat_func.py index 8e97a3f28..cd6a01a40 100644 --- a/buckaroo/pluggable_analysis_framework/stat_func.py +++ b/buckaroo/pluggable_analysis_framework/stat_func.py @@ -132,6 +132,9 @@ def __repr__(self): # StatFunc — a registered stat computation # --------------------------------------------------------------------------- +VALID_COSTS = ("scalar", "aggregate") + + @dataclass class StatFunc: """A registered stat computation. @@ -145,6 +148,11 @@ class StatFunc: column_filter: optional predicate on column dtype quiet: suppress error reporting default: fallback value on failure (MISSING = no fallback) + cost: cost class — ``"scalar"`` (cheap, ships in the initial + state_change response) or ``"aggregate"`` (slow path that the + JS orchestrator fetches via a separate round-trip after a + debounce). Histograms, value_counts and other per-column + queries belong in ``"aggregate"``. """ name: str func: Callable @@ -154,6 +162,7 @@ class StatFunc: column_filter: Optional[Callable] = None quiet: bool = False default: Any = field(default_factory=lambda: MISSING) + cost: str = "scalar" spread_dict_result: bool = False # v1 compat: spread all dict keys into accumulator v1_computed: bool = False # v1 compat: pass full accumulator as single dict arg @@ -250,7 +259,7 @@ def _get_requires_from_params(sig: inspect.Signature, hints: dict) -> tuple: # @stat decorator # --------------------------------------------------------------------------- -def stat(column_filter=None, quiet=False, default=MISSING): +def stat(column_filter=None, quiet=False, default=MISSING, cost="scalar"): """Decorator that converts a function into a StatFunc. The function signature IS the contract: @@ -263,6 +272,12 @@ def stat(column_filter=None, quiet=False, default=MISSING): key the rest of the DAG expects. Use ``MultipleProvides`` (a TypedDict alias) when one function should write several keys. + ``cost=`` declares the compute-cost class. ``"scalar"`` (default) + means cheap — ships in the initial state_change response. ``"aggregate"`` + means slow (histograms, value_counts, per-column queries) — the JS + orchestrator fetches these via a separate round-trip after a debounce. + See plans/js-driven-stat-debounce.md. + Usage:: @stat() @@ -277,6 +292,10 @@ def mean(ser: RawSeries) -> float: def safe_ratio(a: int, b: int) -> float: return a / b + @stat(cost="aggregate") + def histogram(...) -> list: + ... + class TypingResult(MultipleProvides): is_numeric: bool is_integer: bool @@ -285,6 +304,10 @@ class TypingResult(MultipleProvides): def typing_stats(dtype: str) -> TypingResult: ... """ + if cost not in VALID_COSTS: + raise ValueError( + f"@stat(cost={cost!r}): invalid cost class. " + f"Must be one of {VALID_COSTS}.") def decorator(func): sig = inspect.signature(func) try: @@ -298,7 +321,8 @@ def decorator(func): provides_keys = _get_provides_from_return_type(func.__name__, return_type) stat_func = StatFunc(name=func.__name__, func=func, requires=requires, provides=provides_keys, - needs_raw=needs_raw, column_filter=column_filter, quiet=quiet, default=default) + needs_raw=needs_raw, column_filter=column_filter, quiet=quiet, default=default, + cost=cost) # Attach metadata to the function so pipeline can find it func._stat_func = stat_func diff --git a/tests/unit/test_paf_v2.py b/tests/unit/test_paf_v2.py index 7e8cbc1d6..c33a78469 100644 --- a/tests/unit/test_paf_v2.py +++ b/tests/unit/test_paf_v2.py @@ -176,6 +176,49 @@ def test_no_default(self): sf = distinct_per._stat_func assert sf.default is MISSING + def test_stat_default_cost_is_scalar(self): + # Default cost class is "scalar" — the bulk of stats are cheap. + # Only known-expensive ones opt in to "aggregate". + sf = length._stat_func + assert sf.cost == "scalar" + + def test_stat_explicit_cost_aggregate(self): + @stat(cost="aggregate") + def big_compute(ser: RawSeries) -> float: + return float(ser.mean()) + + assert big_compute._stat_func.cost == "aggregate" + + def test_stat_invalid_cost_rejected(self): + # Only "scalar" and "aggregate" are recognised. A typo should + # fail loud at decoration time — silently dropping an invalid + # cost would leak into the cost-class router as an unscheduled + # stat group. + import pytest as _pt + with _pt.raises(ValueError, match="cost"): + @stat(cost="bigly") + def bad(ser: RawSeries) -> float: + return float(ser.mean()) + + def test_known_expensive_stats_marked_aggregate(self): + """The expensive built-in stat funcs (histogram producers across + all three engines) are tagged ``cost="aggregate"`` so a + downstream router can schedule them on the slow path.""" + from buckaroo.customizations.pd_stats_v2 import histogram as pd_histogram + from buckaroo.customizations.pd_stats_v2 import histogram_series as pd_hs + from buckaroo.customizations.pl_stats_v2 import pl_histogram_series + + assert pd_histogram._stat_func.cost == "aggregate" + assert pd_hs._stat_func.cost == "aggregate" + assert pl_histogram_series._stat_func.cost == "aggregate" + + # xorq histogram (optional dep — skip if not installed) + try: + from buckaroo.customizations.xorq_stats_v2 import histogram as xq_histogram + except ImportError: + return + assert xq_histogram._stat_func.cost == "aggregate" + class _MultiSizeStats(MultipleProvides): row_count: int