diff --git a/buckaroo/pluggable_analysis_framework/stat_pipeline.py b/buckaroo/pluggable_analysis_framework/stat_pipeline.py index 5f4bec8f4..ff9b6297a 100644 --- a/buckaroo/pluggable_analysis_framework/stat_pipeline.py +++ b/buckaroo/pluggable_analysis_framework/stat_pipeline.py @@ -234,13 +234,18 @@ def __init__(self, stat_funcs: list, unit_test: bool = True, record_timings: boo self._unit_test_result = self.unit_test() def process_column(self, column_name: str, column_dtype, raw_series=None, sampled_series=None, raw_dataframe=None, - initial_stats: Optional[Dict[str, Any]] = None) -> Tuple[Dict[str, Any], List[StatError]]: + initial_stats: Optional[Dict[str, Any]] = None, + cost_classes=None) -> Tuple[Dict[str, Any], List[StatError]]: """Process a single column through the stat DAG. 1. Filters stat functions by column dtype - 2. Executes in topological order with Ok/Err accumulator - 3. Returns (plain_dict, errors) + 2. Filters by ``cost_classes`` (default: all costs) — used by + the JS-driven progressive-stats router to run scalars only + 3. Executes in topological order with Ok/Err accumulator + 4. Returns (plain_dict, errors) """ + if cost_classes is None: + cost_classes = {"scalar", "aggregate"} # Build column-specific DAG (filters by dtype) external = set(self.EXTERNAL_KEYS) if initial_stats: @@ -255,6 +260,8 @@ def process_column(self, column_name: str, column_dtype, raw_series=None, sample accumulator[k] = Ok(v) record_timings = self.record_timings for sf in column_funcs: + if sf.cost not in cost_classes: + continue if record_timings: t0 = time.perf_counter() _execute_stat_func(sf, accumulator, column_name, raw_series=raw_series, sampled_series=sampled_series, @@ -270,9 +277,12 @@ def process_column(self, column_name: str, column_dtype, raw_series=None, sample return resolve_accumulator(accumulator, column_name, col_key_to_func) - def process_df(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, List[StatError]]: + def process_df(self, df: pd.DataFrame, debug: bool = False, cost_classes=None) -> Tuple[SDType, List[StatError]]: """Process all columns of a DataFrame. + ``cost_classes`` (default: all) restricts which stat funcs run. + Used by the JS-driven progressive-stats router. + Returns: (summary_dict, all_errors) where summary_dict is SDType-compatible (column_name -> {stat_name -> value}). @@ -292,13 +302,39 @@ def process_df(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, Lis col_result, col_errors = self.process_column(column_name=rewritten_col_name, column_dtype=col_dtype, raw_series=ser, sampled_series=ser, raw_dataframe=df, - initial_stats={'orig_col_name': orig_col_name, 'rewritten_col_name': rewritten_col_name}) + initial_stats={'orig_col_name': orig_col_name, 'rewritten_col_name': rewritten_col_name}, + cost_classes=cost_classes) summary[rewritten_col_name] = col_result all_errors.extend(col_errors) return summary, all_errors + def process_df_scalars(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, List[StatError]]: + """Run only cost=scalar stats — the fast path used by the + JS-driven progressive-stats router. Histograms etc. are + skipped.""" + return self.process_df(df, debug=debug, cost_classes={"scalar"}) + + def process_df_aggregates(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, List[StatError]]: + """Run the full pipeline, return only aggregate-cost stats. + + Aggregates depend on scalar inputs, so the full pipeline runs; + the response is filtered to ship just the aggregate provides + (typically histograms). Caching scalars across this and + ``process_df_scalars`` is a future PR. + """ + summary, errs = self.process_df(df, debug=debug) + agg_provides = {sk.name for sf in self.ordered_stat_funcs + if sf.cost == "aggregate" for sk in sf.provides} + filtered = {col: {k: v for k, v in stats.items() if k in agg_provides} + for col, stats in summary.items()} + agg_func_names = {sf.name for sf in self.ordered_stat_funcs + if sf.cost == "aggregate"} + filtered_errs = [e for e in errs + if e.stat_func is not None and e.stat_func.name in agg_func_names] + return filtered, filtered_errs + def process_df_v1_compat(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, ErrDict]: """Process DataFrame with v1-compatible error format. diff --git a/buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py b/buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py index 54720f9af..b38f05600 100644 --- a/buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py +++ b/buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py @@ -165,7 +165,48 @@ def unit_test(self) -> Tuple[bool, List[StatError]]: finally: self.backend = saved_backend - def process_table(self, table) -> Tuple[SDType, List[StatError]]: + def process_table_scalars(self, table) -> Tuple[SDType, List[StatError]]: + """Run only cost=scalar stats — the fast path. + + Filters out cost=aggregate funcs (histograms, per-column-query + stats) before running the pipeline. Used by the JS-driven + progressive-stats router to ship cheap stats immediately on + state_change. See plans/js-driven-stat-debounce.md. + + Aggregate stats that depend only on scalars (e.g. a downstream + compute that consumes ``histogram``) get their inputs missing + and produce ``Err`` upstream — that's the expected shape; the + consumer should ignore those when asking for scalars only. + """ + return self.process_table(table, cost_classes={"scalar"}) + + def process_table_aggregates(self, table) -> Tuple[SDType, List[StatError]]: + """Run the full pipeline, return only aggregate-cost stats. + + Aggregates typically depend on scalar inputs (e.g. ``histogram`` + consumes ``value_counts``, ``length``), so the full pipeline + has to run. Output is filtered to just aggregate-cost provides + so the response payload is small. Caching of the scalar half + across this and ``process_table_scalars`` is a future PR; here + the compute is whole, only the shipping is filtered. + """ + summary, errs = self.process_table(table) + agg_provides = {sk.name for sf in self.ordered_stat_funcs + if sf.cost == "aggregate" for sk in sf.provides} + filtered = {col: {k: v for k, v in stats.items() if k in agg_provides} + for col, stats in summary.items()} + # Errors are per-stat-func; keep only those from aggregate funcs. + agg_func_names = {sf.name for sf in self.ordered_stat_funcs + if sf.cost == "aggregate"} + filtered_errs = [e for e in errs + if e.stat_func is not None and e.stat_func.name in agg_func_names] + return filtered, filtered_errs + + def process_table(self, table, cost_classes=None) -> Tuple[SDType, List[StatError]]: + """Run the pipeline; ``cost_classes`` filter restricts which + stat funcs execute (default: all costs).""" + if cost_classes is None: + cost_classes = {"scalar", "aggregate"} schema = table.schema() columns = list(table.columns) @@ -188,6 +229,8 @@ def process_table(self, table) -> Tuple[SDType, List[StatError]]: for sf in self.ordered_stat_funcs: if not _is_batch_func(sf): continue + if sf.cost not in cost_classes: + continue xorq_col_param = next(r.name for r in sf.requires if r.type is XorqColumn) for col in columns: col_dtype = schema[col] @@ -253,6 +296,10 @@ def process_table(self, table) -> Tuple[SDType, List[StatError]]: # (typically the batch-phase stats). if sf.provides and all(sk.name in col_accum for sk in sf.provides): continue + # Cost-class filter — the JS-driven router runs scalars + # first, aggregates after a debounce. + if sf.cost not in cost_classes: + continue _execute_stat_func(sf, col_accum, col, raw_series=None, sampled_series=None, raw_dataframe=None, xorq_expr=table, xorq_execute=self._execute) diff --git a/tests/unit/test_paf_v2.py b/tests/unit/test_paf_v2.py index c33a78469..5722b3713 100644 --- a/tests/unit/test_paf_v2.py +++ b/tests/unit/test_paf_v2.py @@ -600,6 +600,54 @@ def test_basic_pipeline(self): assert 'distinct_per' in pipeline.provided_summary_facts_set assert 'length' in pipeline.provided_summary_facts_set + def test_process_df_cost_class_filter_scalars(self): + """Scalars-only path: ``process_df_scalars`` runs only cost=scalar + stats. Aggregate funcs (and stats that depend on them) are + skipped. The fast path the JS router uses for the initial + state_change response.""" + @stat(cost="aggregate") + def expensive_stat(ser: RawSeries) -> int: + return 999 # would be slow in real life + + df = pd.DataFrame({'a': [1, 2, 3, 1, 2]}) + pipeline = StatPipeline([length, expensive_stat], unit_test=False) + summary, errs = pipeline.process_df_scalars(df) + + assert summary['a']['length'] == 5 + # The aggregate stat was filtered out — it did not run. + assert 'expensive_stat' not in summary['a'] + + def test_process_df_cost_class_filter_aggregates(self): + """Aggregates-only path: ``process_df_aggregates`` runs the + pipeline but ships only aggregate provides. Used by the JS + router for the slow follow-up after a debounce.""" + @stat(cost="aggregate") + def expensive_stat(ser: RawSeries) -> int: + return 999 + + df = pd.DataFrame({'a': [1, 2, 3, 1, 2]}) + pipeline = StatPipeline([length, expensive_stat], unit_test=False) + summary, errs = pipeline.process_df_aggregates(df) + + # Only aggregate-cost stat is shipped. + assert summary['a']['expensive_stat'] == 999 + # Scalar 'length' is computed but filtered out of the response. + assert 'length' not in summary['a'] + + def test_process_df_default_runs_all_costs(self): + """Default ``process_df()`` (no ``cost_classes`` arg) runs all + cost classes — back-compat for every existing caller.""" + @stat(cost="aggregate") + def expensive_stat(ser: RawSeries) -> int: + return 999 + + df = pd.DataFrame({'a': [1, 2, 3]}) + pipeline = StatPipeline([length, expensive_stat], unit_test=False) + summary, errs = pipeline.process_df(df) + + assert summary['a']['length'] == 3 + assert summary['a']['expensive_stat'] == 999 + def test_process_column(self): pipeline = StatPipeline([length, distinct_count, distinct_per], unit_test=False) ser = pd.Series([1, 2, 3, 1, 2]) diff --git a/tests/unit/test_xorq_buckaroo_widget.py b/tests/unit/test_xorq_buckaroo_widget.py index e71fd1e86..7fcdab629 100644 --- a/tests/unit/test_xorq_buckaroo_widget.py +++ b/tests/unit/test_xorq_buckaroo_widget.py @@ -60,6 +60,82 @@ def counting(self, q): ) +class TestCostClassFilter: + """Phase 2 of the JS-driven progressive-stats router. The + ``process_table_scalars`` / ``process_table_aggregates`` entry + points filter by ``StatFunc.cost`` so the JS orchestrator can + fetch cheap stats immediately on state_change and slow stats + after a debounce. + """ + + def test_scalars_only_skips_histogram_queries(self): + """``process_table_scalars`` skips the expensive per-column + histogram path entirely. Spy on ``_execute`` and assert the + query count drops vs the full pipeline.""" + from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import XorqStatPipeline + from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2 + + full_queries: list = [] + scalar_queries: list = [] + orig = XorqStatPipeline._execute + + def counting(self, q, sink): + sink.append(q) + return orig(self, q) + + expr = _expr() + + # Full pipeline reference count + XorqStatPipeline._execute = lambda self, q: counting(self, q, full_queries) + try: + p1 = XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False) + p1.process_table(expr) + finally: + XorqStatPipeline._execute = orig + + # Scalars-only count — must be strictly fewer + XorqStatPipeline._execute = lambda self, q: counting(self, q, scalar_queries) + try: + p2 = XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False) + p2.process_table_scalars(expr) + finally: + XorqStatPipeline._execute = orig + + assert len(scalar_queries) < len(full_queries), ( + f"scalars-only must issue fewer queries than full pipeline; " + f"got scalars={len(scalar_queries)} full={len(full_queries)}") + + def test_scalars_only_omits_histogram_key(self): + """``process_table_scalars`` output has no ``histogram`` key for + any column — that stat is cost=aggregate.""" + from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import XorqStatPipeline + from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2 + + pipeline = XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False) + summary, errs = pipeline.process_table_scalars(_expr()) + for col, stats in summary.items(): + assert "histogram" not in stats, ( + f"col {col} has histogram in scalars-only output: " + f"{list(stats.keys())}") + + def test_aggregates_only_ships_only_aggregate_keys(self): + """``process_table_aggregates`` returns only aggregate-cost + stat provides. Scalars are computed (as dependencies) but + filtered out of the response.""" + from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import XorqStatPipeline + from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2 + + pipeline = XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False) + summary, errs = pipeline.process_table_aggregates(_expr()) + for col, stats in summary.items(): + # The boston-style histogram stat is aggregate. + if "histogram" in stats: + # Other obvious scalar keys must NOT be there. + assert "length" not in stats, ( + f"col {col} aggregate response leaks scalar 'length'") + assert "min" not in stats + + class TestInstantiation: def test_smoke(self): XorqBuckarooWidget(_expr())