Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions buckaroo/pluggable_analysis_framework/stat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,18 @@ def __init__(self, stat_funcs: list, unit_test: bool = True, record_timings: boo
self._unit_test_result = self.unit_test()

def process_column(self, column_name: str, column_dtype, raw_series=None, sampled_series=None, raw_dataframe=None,
initial_stats: Optional[Dict[str, Any]] = None) -> Tuple[Dict[str, Any], List[StatError]]:
initial_stats: Optional[Dict[str, Any]] = None,
cost_classes=None) -> Tuple[Dict[str, Any], List[StatError]]:
"""Process a single column through the stat DAG.

1. Filters stat functions by column dtype
2. Executes in topological order with Ok/Err accumulator
3. Returns (plain_dict, errors)
2. Filters by ``cost_classes`` (default: all costs) — used by
the JS-driven progressive-stats router to run scalars only
3. Executes in topological order with Ok/Err accumulator
4. Returns (plain_dict, errors)
"""
if cost_classes is None:
cost_classes = {"scalar", "aggregate"}
# Build column-specific DAG (filters by dtype)
external = set(self.EXTERNAL_KEYS)
if initial_stats:
Expand All @@ -255,6 +260,8 @@ def process_column(self, column_name: str, column_dtype, raw_series=None, sample
accumulator[k] = Ok(v)
record_timings = self.record_timings
for sf in column_funcs:
if sf.cost not in cost_classes:
continue
if record_timings:
t0 = time.perf_counter()
_execute_stat_func(sf, accumulator, column_name, raw_series=raw_series, sampled_series=sampled_series,
Expand All @@ -270,9 +277,12 @@ def process_column(self, column_name: str, column_dtype, raw_series=None, sample

return resolve_accumulator(accumulator, column_name, col_key_to_func)

def process_df(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, List[StatError]]:
def process_df(self, df: pd.DataFrame, debug: bool = False, cost_classes=None) -> Tuple[SDType, List[StatError]]:
"""Process all columns of a DataFrame.

``cost_classes`` (default: all) restricts which stat funcs run.
Used by the JS-driven progressive-stats router.

Returns:
(summary_dict, all_errors) where summary_dict is SDType-compatible
(column_name -> {stat_name -> value}).
Expand All @@ -292,13 +302,39 @@ def process_df(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, Lis

col_result, col_errors = self.process_column(column_name=rewritten_col_name, column_dtype=col_dtype,
raw_series=ser, sampled_series=ser, raw_dataframe=df,
initial_stats={'orig_col_name': orig_col_name, 'rewritten_col_name': rewritten_col_name})
initial_stats={'orig_col_name': orig_col_name, 'rewritten_col_name': rewritten_col_name},
cost_classes=cost_classes)

summary[rewritten_col_name] = col_result
all_errors.extend(col_errors)

return summary, all_errors

def process_df_scalars(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, List[StatError]]:
    """Process ``df`` running only the ``cost == "scalar"`` stat funcs.

    Thin wrapper around ``process_df`` that pins ``cost_classes`` to
    ``{"scalar"}``, so aggregate-cost funcs (histograms etc.) are not
    run. This is the fast path used by the JS-driven
    progressive-stats router.

    Args:
        df: DataFrame to summarise.
        debug: Forwarded unchanged to ``process_df``.

    Returns:
        The ``(summary_dict, errors)`` pair produced by ``process_df``.
    """
    scalar_only = {"scalar"}
    return self.process_df(df, debug=debug, cost_classes=scalar_only)

def process_df_aggregates(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, List[StatError]]:
    """Compute every stat for ``df`` but return only the aggregate-cost ones.

    ``process_df`` is called without a cost filter, so scalar stats are
    still computed (aggregates depend on scalar inputs). Only the
    returned payload is trimmed, typically down to histograms.
    Sharing the scalar results with ``process_df_scalars`` is left for
    a future PR.

    Args:
        df: DataFrame to summarise.
        debug: Forwarded unchanged to ``process_df``.

    Returns:
        ``(summary, errors)`` where each column's dict keeps only keys
        named in the ``provides`` of a ``cost == "aggregate"`` stat
        func, and ``errors`` keeps only entries whose ``stat_func`` is
        set and is one of those aggregate funcs.
    """
    full_summary, all_errors = self.process_df(df, debug=debug)

    # Collect, in one pass, the aggregate funcs' names and the stat
    # keys they provide.
    aggregate_names = set()
    shipped_keys = set()
    for sf in self.ordered_stat_funcs:
        if sf.cost != "aggregate":
            continue
        aggregate_names.add(sf.name)
        shipped_keys.update(sk.name for sk in sf.provides)

    trimmed = {}
    for col, stats in full_summary.items():
        trimmed[col] = {key: value for key, value in stats.items() if key in shipped_keys}

    kept_errors = []
    for err in all_errors:
        func = err.stat_func
        # Errors without an originating stat func cannot be attributed
        # to an aggregate, so they are dropped.
        if func is not None and func.name in aggregate_names:
            kept_errors.append(err)
    return trimmed, kept_errors

def process_df_v1_compat(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, ErrDict]:
"""Process DataFrame with v1-compatible error format.

Expand Down
49 changes: 48 additions & 1 deletion buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,48 @@ def unit_test(self) -> Tuple[bool, List[StatError]]:
finally:
self.backend = saved_backend

def process_table(self, table) -> Tuple[SDType, List[StatError]]:
def process_table_scalars(self, table) -> Tuple[SDType, List[StatError]]:
    """Run only the ``cost == "scalar"`` stat funcs over ``table``.

    Fast path: delegates to ``process_table`` with ``cost_classes``
    pinned to ``{"scalar"}``, so aggregate-cost funcs (histograms,
    per-column-query stats) are skipped. The JS-driven
    progressive-stats router uses this to ship cheap stats immediately
    on state_change. See plans/js-driven-stat-debounce.md.

    A stat func that is allowed to run but requires the output of a
    skipped aggregate func (e.g. something consuming ``histogram``)
    will not find that input.
    NOTE(review): the previous docstring states such funcs surface as
    ``Err`` entries that scalar-only consumers should ignore — confirm
    against ``_execute_stat_func``.

    Returns:
        The ``(summary_dict, errors)`` pair produced by
        ``process_table``.
    """
    scalar_only = {"scalar"}
    return self.process_table(table, cost_classes=scalar_only)

def process_table_aggregates(self, table) -> Tuple[SDType, List[StatError]]:
    """Compute every stat for ``table`` but return only aggregate-cost ones.

    Aggregates typically depend on scalar inputs (e.g. ``histogram``
    consumes ``value_counts``, ``length``), so ``process_table`` is
    called without a cost filter. The compute is whole; only the
    returned payload is trimmed, to keep the response small. Caching
    the scalar half across this and ``process_table_scalars`` is a
    future PR.

    Returns:
        ``(summary, errors)`` where each column's dict keeps only keys
        named in the ``provides`` of a ``cost == "aggregate"`` stat
        func, and ``errors`` keeps only entries whose ``stat_func`` is
        set and is one of those aggregate funcs.
    """
    full_summary, all_errors = self.process_table(table)

    # Collect, in one pass, the aggregate funcs' names and the stat
    # keys they provide.
    aggregate_names = set()
    shipped_keys = set()
    for sf in self.ordered_stat_funcs:
        if sf.cost != "aggregate":
            continue
        aggregate_names.add(sf.name)
        shipped_keys.update(sk.name for sk in sf.provides)

    trimmed = {}
    for col, stats in full_summary.items():
        trimmed[col] = {key: value for key, value in stats.items() if key in shipped_keys}

    # Errors are per-stat-func; keep only those raised by aggregate funcs.
    kept_errors = []
    for err in all_errors:
        func = err.stat_func
        if func is not None and func.name in aggregate_names:
            kept_errors.append(err)
    return trimmed, kept_errors

def process_table(self, table, cost_classes=None) -> Tuple[SDType, List[StatError]]:
"""Run the pipeline; ``cost_classes`` filter restricts which
stat funcs execute (default: all costs)."""
if cost_classes is None:
cost_classes = {"scalar", "aggregate"}
schema = table.schema()
columns = list(table.columns)

Expand All @@ -188,6 +229,8 @@ def process_table(self, table) -> Tuple[SDType, List[StatError]]:
for sf in self.ordered_stat_funcs:
if not _is_batch_func(sf):
continue
if sf.cost not in cost_classes:
continue
xorq_col_param = next(r.name for r in sf.requires if r.type is XorqColumn)
for col in columns:
col_dtype = schema[col]
Expand Down Expand Up @@ -253,6 +296,10 @@ def process_table(self, table) -> Tuple[SDType, List[StatError]]:
# (typically the batch-phase stats).
if sf.provides and all(sk.name in col_accum for sk in sf.provides):
continue
# Cost-class filter — the JS-driven router runs scalars
# first, aggregates after a debounce.
if sf.cost not in cost_classes:
continue
_execute_stat_func(sf, col_accum, col, raw_series=None, sampled_series=None, raw_dataframe=None,
xorq_expr=table, xorq_execute=self._execute)

Expand Down
48 changes: 48 additions & 0 deletions tests/unit/test_paf_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,54 @@ def test_basic_pipeline(self):
assert 'distinct_per' in pipeline.provided_summary_facts_set
assert 'length' in pipeline.provided_summary_facts_set

def test_process_df_cost_class_filter_scalars(self):
    """``process_df_scalars`` runs cost=scalar stats only.

    With ``length`` and an aggregate-cost stat registered together,
    the scalars-only result must contain ``length`` and must not
    contain the aggregate stat. This is the fast path the JS router
    uses for the initial state_change response.
    """
    @stat(cost="aggregate")
    def expensive_stat(ser: RawSeries) -> int:
        # Stand-in for a stat that would be slow in real life.
        return 999

    pipe = StatPipeline([length, expensive_stat], unit_test=False)
    frame = pd.DataFrame({'a': [1, 2, 3, 1, 2]})
    summary, _errs = pipe.process_df_scalars(frame)

    col_a = summary['a']
    assert col_a['length'] == 5
    # Filtered out by cost class, so the aggregate stat never ran.
    assert 'expensive_stat' not in col_a

def test_process_df_cost_class_filter_aggregates(self):
    """``process_df_aggregates`` ships aggregate provides only.

    The pipeline still runs, but the returned summary is trimmed to
    aggregate-cost stats. The JS router uses this for the slow
    follow-up request after a debounce.
    """
    @stat(cost="aggregate")
    def expensive_stat(ser: RawSeries) -> int:
        return 999

    pipe = StatPipeline([length, expensive_stat], unit_test=False)
    frame = pd.DataFrame({'a': [1, 2, 3, 1, 2]})
    summary, _errs = pipe.process_df_aggregates(frame)

    col_a = summary['a']
    # The aggregate-cost stat is present in the response...
    assert col_a['expensive_stat'] == 999
    # ...while scalar 'length' is computed but trimmed from it.
    assert 'length' not in col_a

def test_process_df_default_runs_all_costs(self):
    """``process_df()`` without ``cost_classes`` runs every cost class.

    Guards backwards compatibility for existing callers that never
    pass the new argument.
    """
    @stat(cost="aggregate")
    def expensive_stat(ser: RawSeries) -> int:
        return 999

    pipe = StatPipeline([length, expensive_stat], unit_test=False)
    frame = pd.DataFrame({'a': [1, 2, 3]})
    summary, _errs = pipe.process_df(frame)

    col_a = summary['a']
    assert col_a['length'] == 3
    assert col_a['expensive_stat'] == 999

def test_process_column(self):
pipeline = StatPipeline([length, distinct_count, distinct_per], unit_test=False)
ser = pd.Series([1, 2, 3, 1, 2])
Expand Down
76 changes: 76 additions & 0 deletions tests/unit/test_xorq_buckaroo_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,82 @@ def counting(self, q):
)


class TestCostClassFilter:
    """Phase 2 of the JS-driven progressive-stats router.

    The ``process_table_scalars`` / ``process_table_aggregates`` entry
    points filter by ``StatFunc.cost`` so the JS orchestrator can fetch
    cheap stats immediately on state_change and slow stats after a
    debounce.
    """

    def test_scalars_only_skips_histogram_queries(self):
        """``process_table_scalars`` issues fewer queries than the full run.

        ``XorqStatPipeline._execute`` is temporarily replaced with a
        recording wrapper; the scalars-only run must hit it strictly
        fewer times than an unfiltered ``process_table`` run.
        """
        from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import XorqStatPipeline
        from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2

        orig = XorqStatPipeline._execute
        expr = _expr()

        def recorded_queries(run):
            """Call ``run(pipeline)`` with ``_execute`` spied; return the queries seen."""
            seen: list = []

            def spy(self, q):
                seen.append(q)
                return orig(self, q)

            # Patch before constructing the pipeline and always restore,
            # so a failure here cannot leak the spy into other tests.
            XorqStatPipeline._execute = spy
            try:
                run(XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False))
            finally:
                XorqStatPipeline._execute = orig
            return seen

        # Full pipeline reference count.
        full_queries = recorded_queries(lambda p: p.process_table(expr))
        # Scalars-only count — must be strictly fewer.
        scalar_queries = recorded_queries(lambda p: p.process_table_scalars(expr))

        assert len(scalar_queries) < len(full_queries), (
            f"scalars-only must issue fewer queries than full pipeline; "
            f"got scalars={len(scalar_queries)} full={len(full_queries)}")

    def test_scalars_only_omits_histogram_key(self):
        """``process_table_scalars`` output has no ``histogram`` key for
        any column — that stat is cost=aggregate."""
        from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import XorqStatPipeline
        from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2

        pipeline = XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False)
        summary, _errs = pipeline.process_table_scalars(_expr())
        for col, stats in summary.items():
            assert "histogram" not in stats, (
                f"col {col} has histogram in scalars-only output: "
                f"{list(stats.keys())}")

    def test_aggregates_only_ships_only_aggregate_keys(self):
        """``process_table_aggregates`` returns only aggregate-cost provides.

        Scalars are computed (as dependencies) but filtered out of the
        response. The per-column checks only run for columns that carry
        a ``histogram``, so the test also requires that at least one
        such column exists; otherwise it would pass without checking
        anything.
        """
        from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import XorqStatPipeline
        from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2

        pipeline = XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False)
        summary, _errs = pipeline.process_table_aggregates(_expr())

        checked_columns = []
        for col, stats in summary.items():
            # The boston-style histogram stat is aggregate.
            if "histogram" not in stats:
                continue
            checked_columns.append(col)
            # Other obvious scalar keys must NOT be there.
            assert "length" not in stats, (
                f"col {col} aggregate response leaks scalar 'length'")
            assert "min" not in stats

        # Guard against a vacuous pass: the assertions above must have
        # run for at least one column.
        assert checked_columns, (
            f"no column produced a 'histogram' in the aggregate response; "
            f"got keys per column: "
            f"{ {col: list(stats.keys()) for col, stats in summary.items()} }")


class TestInstantiation:
def test_smoke(self):
XorqBuckarooWidget(_expr())
Expand Down