From f2164228d7c338301811a39fc0d452c09cfb39ff Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Thu, 14 May 2026 13:00:54 +0000 Subject: [PATCH 1/4] =?UTF-8?q?feat(index):=20aggregate=20pushdown=20?= =?UTF-8?q?=E2=80=94=20parser,=20exec,=20prefilter=20wiring,=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Lands the plan-time and execute-time halves of aggregate pushdown. Not yet wired into the scanner. Plan side (rust/lance-index/src/expression/): - Moves scalar/expression.rs to expression/scalar.rs, paralleling the new expression/aggregate.rs. - AnyAggregateQuery and AggregateQueryParser traits. - AggregateIndexSearch leaf with optional index_name, parsed query, optional per-aggregate filter, and the original SELECT expression. - CountQuery (basic / approx / distinct / approx_distinct) and CountQueryParser. Scalar index trait (rust/lance-index/src/scalar.rs): - ScalarIndex::calculate_aggregate returning a partial-state ArrowScalar; default-error stubs added to btree, bitmap, bloomfilter, inverted, json, label_list, ngram, rtree, zonemap, and the LogicalScalarIndex wrapper. - Re-exports lance_arrow_scalar::ArrowScalar through scalar::. Execute side (rust/lance/src/io/exec/aggregate_index.rs): - AggregateIndexSearchExec emits one partial-state RecordBatch whose schema is the concatenation of state_fields() for each paired AggregateFunctionExpr, so a downstream AggregateExec(Final) consumes it unchanged. - One optional child input — a ScalarIndexExec — supplies a prefilter RowAddrMask. The prefilter load and per-aggregate index loads run in parallel. - Intersects fragment bitmaps across indexed aggregates, materializes the allow list as concrete [0..physical_rows) ranges (avoids the RoaringBitmap::full() inflation in RowAddrTreeMap::Sub), then composes prefilter ∩ fragments_allow − deletion_mask into a single AllowList. 
- Calls calculate_aggregate per indexed aggregate; falls back to counting the combined mask directly when an aggregate is a non-distinct COUNT without an associated index. - Unit tests cover try_new validation, the Full+Partial count helper, and end-to-end execution with no prefilter, an AllowList prefilter, a BlockList prefilter, and deletions. Also includes aggregate-pushdown-research.md surveying how mature query engines structure aggregate pushdown. Co-Authored-By: Claude Opus 4.7 (1M context) --- Cargo.lock | 1 + aggregate-pushdown-research.md | 225 ++ rust/lance-index/Cargo.toml | 1 + rust/lance-index/src/expression.rs | 9 + rust/lance-index/src/expression/aggregate.rs | 205 ++ rust/lance-index/src/expression/scalar.rs | 3106 ++++++++++++++++ rust/lance-index/src/lib.rs | 1 + rust/lance-index/src/scalar.rs | 15 + rust/lance-index/src/scalar/bitmap.rs | 17 +- rust/lance-index/src/scalar/bloomfilter.rs | 14 + rust/lance-index/src/scalar/btree.rs | 15 +- rust/lance-index/src/scalar/expression.rs | 3107 +---------------- rust/lance-index/src/scalar/inverted/index.rs | 14 + rust/lance-index/src/scalar/json.rs | 14 + rust/lance-index/src/scalar/label_list.rs | 15 +- rust/lance-index/src/scalar/ngram.rs | 14 + rust/lance-index/src/scalar/rtree.rs | 14 + rust/lance-index/src/scalar/zonemap.rs | 14 + rust/lance/src/index/scalar_logical.rs | 18 +- rust/lance/src/io/exec.rs | 1 + rust/lance/src/io/exec/aggregate_index.rs | 776 ++++ 21 files changed, 4487 insertions(+), 3109 deletions(-) create mode 100644 aggregate-pushdown-research.md create mode 100644 rust/lance-index/src/expression.rs create mode 100644 rust/lance-index/src/expression/aggregate.rs create mode 100644 rust/lance-index/src/expression/scalar.rs create mode 100644 rust/lance/src/io/exec/aggregate_index.rs diff --git a/Cargo.lock b/Cargo.lock index be629f78146..92b9bc2b919 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4729,6 +4729,7 @@ dependencies = [ "jieba-rs", "jsonb", "lance-arrow", + 
"lance-arrow-scalar", "lance-core", "lance-datafusion", "lance-datagen", diff --git a/aggregate-pushdown-research.md b/aggregate-pushdown-research.md new file mode 100644 index 00000000000..e4084423fce --- /dev/null +++ b/aggregate-pushdown-research.md @@ -0,0 +1,225 @@ +# Aggregate Pushdown in Mature Query Engines + +Background research for the Lance `feat-aggregate-pushdown` work. The motivating use case is `COUNT(DISTINCT col)` directly from a bitmap index, but this report maps the broader design space. + +## 1. Executive Summary + +- **Three distinct families of "aggregate pushdown"** are conflated in vendor docs. Keep them separate when designing Lance's APIs: (a) *metadata-only execution* — answer the aggregate from per-fragment statistics with zero data IO (Snowflake, Iceberg PR #6622, SQL Server segment metadata for `MIN/MAX/COUNT`); (b) *scan-local aggregation* — run the aggregate inside the scan operator over compressed/encoded data, eliminating a separate Aggregate node (SQL Server "Aggregate Pushdown" since 2016, ClickHouse `optimize_aggregation_in_order`); (c) *materialized/pre-aggregated structures* — separate physical artifact that answers many GROUP BYs (ClickHouse projections, Pinot star-tree, SQL Server indexed views, AggregatingMergeTree). +- **`MIN/MAX` is the universally-supported case.** Every engine has a `MinMaxAggPath`-equivalent that either reads endpoints from a sorted index (Postgres) or reads per-segment min/max statistics (everyone else). Lance has min/max page statistics already — turning these into an `Aggregate` rewrite is the lowest-hanging fruit and matches `preprocess_minmax_aggregates` in Postgres almost exactly. +- **`COUNT(*)` from metadata is universally supported but with caveats.** Without a predicate, every engine answers from row counts in fragment/manifest metadata. *With* a predicate, only fragments whose stats *prove* full inclusion or full exclusion can be skipped — partial fragments must still be scanned. 
DataFusion's "Fully Matching / Partially Matching / Not Matching" trichotomy (the limit-pruning blog post, March 2026) is the cleanest articulation. +- **`COUNT(DISTINCT)` from a bitmap index is unusual but legitimate.** Druid's `cardinality` aggregator returns approximate distinct counts directly from per-value bitmaps. *Exact* `COUNT(DISTINCT)` from a bitmap is also trivial — it is the dictionary size after applying the predicate's row mask. Lance's bitmap index already has per-value posting lists, so exact distinct count is the natural fit, not HLL. +- **Partial vs. full pushdown matters at the planner level.** Spark's `SupportsPushDownAggregates.supportCompletePushDown()` is the canonical API: per-fragment partial aggregates with a final reduction step in the engine. This is also how Postgres's partition-wise aggregate and postgres_fdw work (`combinefunc`/`serialfunc`/`deserialfunc`). Lance will likely need the same split because indexes are per-fragment. +- **NULL semantics differ between aggregates and become an issue.** `COUNT(*)` counts rows; `COUNT(col)` skips nulls; `MIN/MAX` skip nulls. Iceberg's PR #6622 distinguishes "stat is null because column is all-null" (legal answer) from "stat is missing" (abort pushdown) via a `hasValue` flag. Lance needs the same distinction. +- **Predicate compatibility is the gating constraint.** A pushed aggregate is only legal if the predicate is *also* fully evaluable from the same metadata — otherwise the count/min/max applies to an over-set of rows. This is the source of most correctness bugs in this area (cf. the Iceberg "Fix aggregate pushdown" thread). +- **GROUP BY pushdown is the hard mode.** SQL Server's "grouped aggregate pushdown" only fires when the grouping key bit-packs into ≤10 bits *and* a runtime "benefit measure" stays above a threshold. Pinot's star-tree solves it with a precomputed index. ClickHouse's projections do too. 
There is no cheap implementation — Lance should defer until non-grouped pushdown is solid. +- **MVCC/visibility is an issue only for transactional engines.** Postgres's index-only-scan has to consult the visibility map; SQL Server's pushdown only applies to "compressed rowgroups" not the delta store. Lance's append-only/versioned model sidesteps this — but the analogue is *deletion vectors / row-level deletes*. Iceberg PR #6622 explicitly disables aggregate pushdown when row-level deletes exist. Lance must do the same when deletion vectors apply. +- **The optimizer integration is consistently a dedicated planner pass, not a generic rule.** Postgres's `preprocess_minmax_aggregates` runs in `grouping_planner` just before `query_planner`. DataFusion's `AggregateStatistics` is a `PhysicalOptimizerRule`. Spark uses a V2 datasource interface (`SupportsPushDownAggregates`). The pattern is consistent: detect the shape, build an alternative path, let the cost model choose. + +--- + +## 2. Taxonomy of Techniques + +``` + ┌─────────────────────────────────────────────┐ + │ Aggregate Pushdown │ + └─────────────────────────────────────────────┘ + │ + ┌──────────────────────────────┼──────────────────────────────┐ + ▼ ▼ ▼ +┌───────────────┐ ┌───────────────┐ ┌───────────────┐ +│ Metadata- │ │ Scan-local │ │ Materialized/ │ +│ only │ │ aggregation │ │ Pre-aggregated│ +│ (no IO) │ │ (closer to │ │ artifacts │ +│ │ │ data) │ │ │ +└───────────────┘ └───────────────┘ └───────────────┘ + │ │ │ + │ MIN/MAX from index endpoints │ Aggregate inside scan │ Indexed views (MSSQL) + │ (Postgres MinMaxAggPath) │ over compressed data │ Projections (ClickHouse) + │ │ (MSSQL agg pushdown) │ AggregatingMergeTree + │ MIN/MAX/COUNT from zone-maps │ │ Star-tree (Pinot) + │ (Iceberg, Snowflake, MSSQL │ SIMD-vectorized agg over │ Materialized views (PG, Snowflake) + │ segments, DuckDB zonemap) │ bit-packed encoded data │ + │ │ │ Roll-up tables (Druid) + │ COUNT(*) from row counts │ Grouped agg pushdown 
│ + │ │ (MSSQL, 2019+) │ + │ COUNT DISTINCT from bitmap │ │ + │ dictionary (Druid) │ │ + │ │ │ + │ HLL distinct from sketches │ │ + │ (Druid hyperUnique, BQ, Snow) │ │ + └──────────────────────────────────┴───────────────────────────────┘ + + Orthogonal axis: partial vs. complete + ┌─────────────────────────────────────────────────────────────────┐ + │ Complete: source returns final answer (single fragment, or │ + │ commutative aggregate over independent fragments │ + │ where source does the reduction itself). │ + │ Partial: source returns per-fragment partial aggregate state; │ + │ engine reduces with combinefunc. │ + │ - Spark: SupportsPushDownAggregates.supportCompletePushDown() │ + │ - Postgres: partial aggregates (combine/serial/deserialfunc) │ + │ - postgres_fdw: per-foreign-server partial aggregation │ + └─────────────────────────────────────────────────────────────────┘ +``` + +--- + +## 3. Per-Engine Sections + +### 3.1 PostgreSQL + +**MIN/MAX via `MinMaxAggPath` — `src/backend/optimizer/plan/planagg.c`.** The function `preprocess_minmax_aggregates(PlannerInfo *root)` is called by `grouping_planner` just before `query_planner`. It checks: only aggregates in target list, single base relation, no `GROUP BY`/window/CTE, single-argument MIN/MAX (recognized via its sort operator from `pg_aggregate` — `fetch_agg_sort_op`), no `DISTINCT`/`ORDER BY`/`FILTER` on the aggregate, no mutable functions, no row-type args. For each matching aggregate, `build_minmax_path` builds an effective `SELECT col FROM t WHERE col IS NOT NULL ORDER BY col [DESC] LIMIT 1` subquery and registers a `MinMaxAggPath` against the `UPPERREL_GROUP_AGG` upper rel. Cost model decides between this and a scan-based `AggPath`. ([planagg.c](https://doxygen.postgresql.org/planagg_8c_source.html), [Cybertec write-up](https://www.cybertec-postgresql.com/en/speeding-up-min-and-max/)) + +**`COUNT(*)` and index-only scans.** Postgres has no `count(*)`-from-index optimization analogous to MIN/MAX. 
The closest is *index-only scan*: a btree scan that skips heap access when the visibility map's all-visible bit is set for the heap page. `EXPLAIN` reports `Heap Fetches: N` for cases where the VM bit was not set. Index-only scans require: index type supports it (btree always; GiST/SP-GiST for some opclasses; GIN never); query references only indexed columns; relevant heap pages are all-visible (requires `VACUUM`). With predicates, btree can do `LooseIndexScan`-like skips, but a full `count(*)` still walks every index entry. ([Index-Only Scans docs](https://www.postgresql.org/docs/current/indexes-index-only-scans.html)) + +**Partition-wise aggregate and FDW pushdown.** Postgres 10 added remote aggregation in `postgres_fdw`; subsequent commits added partition-wise aggregate, which decomposes the top-level Agg into per-partition Aggs that can each be pushed to a foreign server. The plan ends with a final `Aggregate` whose `combinefunc`/`serialfunc`/`deserialfunc` (declared in `CREATE AGGREGATE`) merge the partials. Enabled by `enable_partitionwise_aggregate` GUC. Restrictions: no `DISTINCT`/`ORDER BY` in aggregate, no `HAVING`, not `array_agg`. ([EDB Aggregate Push-down post](https://www.enterprisedb.com/blog/postgresql-aggregate-push-down-postgresfdw), [commit message](https://www.postgresql.org/message-id/E1f30tV-0003rh-27@gemulon.postgresql.org)) + +### 3.2 DuckDB + +DuckDB auto-builds **zonemaps** (per-row-group min/max) for all general-purpose types and uses them for both predicate pushdown and "computing aggregations" (Indexing docs). Row groups are ~122,880 rows. The optimizer pipeline (Filter Pushdown, Join Order, TopN, Expression Rewriter, Filter Pull-up, IN Rewriter, Statistics Propagation, Reorder Filters, Join Filter Pushdown) does not document a dedicated metadata-only-aggregate rule, but Statistics Propagation does fold known constants (e.g., `MIN/MAX` of a column with known range) at plan time. 
The **ART index** is documented as not affecting aggregation/join/sort performance — it is only for point lookups and PK enforcement. ([Indexing](https://duckdb.org/docs/current/guides/performance/indexing), [Optimizers blog](https://duckdb.org/2024/11/14/optimizers)) + +### 3.3 SQL Server (Columnstore) + +**Segment elimination** drops rowgroups whose per-segment min/max prove a predicate cannot match. Numeric/date types since 2014; string/binary/guid since 2022. Each rowgroup also stores row count for instant `COUNT(*)`. ([SQLpassion segment elimination](https://www.sqlpassion.at/archive/2017/01/30/columnstore-segment-elimination/)) + +**Aggregate Pushdown (2016+).** The Aggregate operator is fused into the Columnstore Scan; aggregation runs on compressed/bit-packed data with SIMD. Supports `MIN`, `MAX`, `SUM`, `COUNT`, `COUNT(*)` when input+output fit in 64 bits (int family, money, decimal/numeric with precision ≤18, date/time types). **Not supported**: `DISTINCT`, string columns, virtual columns, delta store rows (which still flow up to the Aggregate node). EXPLAIN exposes `ActualLocallyAggregatedRows`. ([Microsoft post](https://learn.microsoft.com/en-us/archive/blogs/sql_server_team/columnstore-index-performance-sql-server-2016-aggregate-pushdown)) + +**Grouped Aggregate Pushdown (2019+).** Extends to `GROUP BY`. Each output batch (~900 rows) makes a *runtime* choice between fast (pushdown) and slow paths based on a "benefit measure" starting at 100 and decremented when batches contain few rows per key (22% penalty for <8/key, 11% for 8–16/key). Disables entirely when bit-packed grouping key exceeds 10 bits. Pure RLE keys always fast-path. ([Paul White, SQLPerformance](https://sqlperformance.com/2019/04/sql-plan/grouped-aggregate-pushdown)) + +**Indexed Views.** Materialized `SELECT ... GROUP BY` results with synchronous maintenance. Optimizer can use them transparently if `EXPAND VIEWS` is off — purely planner-side pattern match against `SELECT` shape. 
+ +### 3.4 ClickHouse + +**Granule-level min/max + skip indexes.** Default granule is 8192 rows; the primary key (sparse) gives row-range pruning, and explicit `minmax`/`set`/`bloom_filter` skip indexes augment it. The `optimize_use_implicit_projections` and `optimize_use_projections` flags drive the optimizer to consider projections. + +**Projections** (transparent materialized aggregates). When a projection defines `GROUP BY`, the underlying engine becomes `AggregatingMergeTree` and aggregate columns become `AggregateFunction(...)` states. The optimizer "automatically samples the primary keys and chooses a table that can generate the same correct result, but requires the least amount of data to be read." Since 25.5, projections can store only sorting keys + `_part_offset` to act as a pure index. ([Projections docs](https://clickhouse.com/docs/data-modeling/projections)) + +**AggregatingMergeTree.** Stores partial states for aggregations; `min`/`max` need no extra merge cost ("require no extra steps to calculate the final result from the intermediate steps"). The `SimpleAggregateFunction` combinator is an optimized form for aggregates whose state is just the result (`min`, `max`, `sum`, `any`, `anyLast`). ([Altinity KB](https://kb.altinity.com/altinity-kb-queries-and-syntax/simplestateif-or-ifstate-for-simple-aggregate-functions/)) + +### 3.5 Apache Druid + +**Bitmap indexes per dictionary entry.** For each distinct value in a (string) column, Druid stores one Roaring-compressed bitmap of matching rows. Combined with a dictionary mapping string→int. 
([Segments doc](https://druid.apache.org/docs/latest/design/segments/)) + +**`cardinality` and `hyperUnique` aggregators.** `COUNT(DISTINCT)` in SQL is translated to `cardinality`, which returns an *approximate* count via HyperLogLog over the dimension values; `hyperUnique` is the recommended path when you only need the count, not the values — it's stored as an HLL sketch in the segment, so the count is computed by merging sketches across segments, no per-row work. Druid recommends DataSketches (theta/HLL) for new use cases. ([HLL old docs](https://druid.apache.org/docs/latest/querying/hll-old.html), [CALCITE-1670](https://issues.apache.org/jira/browse/CALCITE-1670)) + +For *exact* distinct count, Druid does not push down — it runs a groupBy and counts. The bitmap-per-value structure means exact distinct count *could* be answered as "number of bitmaps in the dictionary whose intersection with the predicate mask is non-empty" — this is exactly the Lance opportunity. + +### 3.6 Apache Pinot — Star-Tree Index + +Pre-aggregated multi-dimensional tree. Each level splits on a dimension; each internal node has a "star" child holding the aggregate with that dimension dropped. The planner pattern-matches a query's `GROUP BY` dimensions and aggregate functions against an available star-tree's schema. Aggregations are *materialized* at build time. Reported gains: "99.76% reduction in latency vs. no Star-Tree Index (6.3 seconds to 15 ms)" and "99.99999% reduction in amount of data scanned." Supports COUNT/SUM/MIN/MAX/etc.; approximate distinct via DataSketches theta/HLL stored as the aggregate value at the node. ([Pinot docs](https://docs.pinot.apache.org/basics/indexing/star-tree-index), [Part 3 blog](https://startree.ai/resources/star-tree-index-in-apache-pinot-part-3-understanding-the-impact-in-real-customer/)) + +### 3.7 Snowflake + +**Micro-partition metadata** stored per partition: column value ranges, distinct counts, and "additional properties." 
Metadata is in the cloud-services layer, queried before any data IO. `count(*)`, `MIN(col)`, `MAX(col)` on a partition-aligned column with no predicate (or with a predicate that aligns with metadata) can return from metadata alone, hence the well-known "instant `COUNT(*)`" on Snowflake. ([Micro-partitions docs](https://docs.snowflake.com/en/user-guide/tables-clustering-micropartitions)) + +**Snowflake Optima (2024-2025).** Dynamically generates *additional* lightweight per-micro-partition metadata for high-frequency "hot" expressions seen in workloads — extending min/max-style pruning to expressions like `LOWER(col) = ...`. ([Optima blog](https://www.snowflake.com/en/engineering-blog/snowflake-optima-metadata-query-pruning/)) + +### 3.8 Parquet / Iceberg + +**Parquet** stores per-row-group and per-page min/max, null count, distinct count (optional, often unset by writers). These drive predicate pushdown but are also enough material for aggregate pushdown. + +**Iceberg PR #6622** (`huaxingao`, merged) implemented `MIN/MAX/COUNT` pushdown through Spark's `SupportsPushDownAggregates`. Key classes: `AggregateEvaluator`, `BoundAggregate` (with `hasValue` to distinguish "all-null column" from "stats missing"), `MaxAggregate`, `MinAggregate`, `CountNonNull`. `SparkScanBuilder` orchestrates. Restrictions explicitly enumerated: +- **No GROUP BY** ("Group by aggregation push down is not supported") +- **No row-level deletes** ("Skipped aggregate pushdown: detected row level deletes") +- **No complex types lacking stats** +- **No truncated string metrics** (default mode truncates strings; can't reason about MIN/MAX) + +Toggle: `spark.sql.iceberg.aggregate-push-down-enabled`. ([PR #6622](https://github.com/apache/iceberg/pull/6622)) + +### 3.9 Spark V2 — `SupportsPushDownAggregates` + +The data-source-side contract used by Iceberg, JDBC, file sources. 
`pushAggregation(Aggregation): boolean` to attempt pushdown; `supportCompletePushDown(Aggregation): boolean` to declare whether the source returns final or partial. If partial, Spark inserts a final Aggregate above the V2 scan with the combine semantics. Filter pushdown happens *first*, then aggregate pushdown — so the data source sees already-filtered fragments. ([Spark JavaDoc](https://spark.apache.org/docs/3.4.3/api/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.html)) + +### 3.10 DataFusion + +Has an `AggregateStatistics` physical optimizer rule that converts `MIN/MAX/COUNT(*)` over a scan with exact statistics into a constant `ProjectionExec` — pure metadata-only execution. Issue [#19938](https://github.com/apache/datafusion/issues/19938) proposes extending min/max statistics to drive group-by *layout* (use a `Vec` indexed by `value - min` when the range is small). The "Limit Pruning" blog (March 2026) describes a clean three-tier model: *Not Matching* / *Partially Matching* / *Fully Matching* row groups, where Fully-Matching groups can satisfy `LIMIT` without row-level filtering — directly applicable to aggregate pushdown: Fully-Matching row groups can contribute exact counts from their row-count statistic. ([Query Optimizer docs](https://datafusion.apache.org/library-user-guide/query-optimizer.html), [Limit Pruning blog](https://datafusion.apache.org/blog/2026/03/20/limit-pruning/)) + +--- + +## 4. Index-Type → Aggregate-Type Matrix + +| Index / Metadata | `COUNT(*)` | `MIN/MAX` | `SUM` | `COUNT(col)` (non-null) | `COUNT(DISTINCT col)` | `GROUP BY` cardinality | +|--- |--- |--- |--- |--- |--- |--- | +| Row count per fragment | Yes (no pred) | No | No | Need null count | No | No | +| Zone map (min/max) | No* | **Yes** | No | No | No | No | +| Null count per fragment | Yes (with above) | No | No | **Yes** (no pred) | No | No | +| Distinct count per frag. 
| No | No | No | No | Approx (upper bound)† | No | +| Btree (ordered) | Walk index | **Yes** O(log n) | Walk index | Walk index | Loose-index scan | Stream-grouped scan | +| Bitmap (one-per-value) | Sum of all bitmaps | **Yes** (first/last value with non-empty bitmap) | No | Bitmap union cardinality | **Yes** (count of values with non-empty bitmap intersected with predicate mask) | **Yes** (cardinality of each bitmap, partition by value) | +| HLL/Theta sketch | No | No | No | No | **Yes** (approximate) | Per-group sketch merge | +| Materialized view / projection / star-tree | Yes | Yes | Yes | Yes | Yes (if pre-aggregated) | **Yes** | + +*`COUNT(*)` from a zone map alone needs row count too — but every engine stores both per fragment, so in practice this is a single lookup. +†Per-fragment distinct counts cannot be summed (overlap); they bound the answer above. + +The bitmap row is the strongest case for Lance. Bitmap-cardinality identities: +``` +COUNT(col) = popcount( OR_v posting[v] ) over predicate-masked rows +COUNT(DISTINCT col)= |{ v : posting[v] AND mask != ∅ }| +COUNT(*) WHERE col=v = popcount( posting[v] AND mask ) +GROUP BY col, COUNT(*) = for v in dict: emit (v, popcount(posting[v] AND mask)) +``` + +--- + +## 5. Planner Integration Patterns + +Three recurring shapes, in order of complexity: + +**(a) Pre-planner rewrite (Postgres pattern).** A dedicated function — `preprocess_minmax_aggregates` — runs *before* the main path enumeration. It builds an alternate path (`MinMaxAggPath`) parallel to the normal Aggregate-over-Scan path. The cost model picks the winner. Pros: keeps the special case out of the general optimizer. Cons: each new shape is a new bespoke function. 
+ +**(b) Physical-optimizer rule (DataFusion `AggregateStatistics`).** A late physical-plan rewrite that inspects the plan tree for `AggregateExec { mode: Final, expr: [Min|Max|Count], input: ScanExec }` and, if the scan can produce exact statistics for those columns, replaces the whole subtree with a `ProjectionExec` of constants. Pros: composes with existing rules. Cons: must reason about partial-vs-final aggregate modes; needs exact (not estimated) statistics. + +**(c) Data-source interface (Spark V2 `SupportsPushDownAggregates`).** The optimizer hands the data source an `Aggregation` description; the source returns whether (and how completely) it can satisfy it. If partial, optimizer inserts a final-stage Aggregate above. Pros: clean separation; the source owns correctness. Cons: API surface is large; partial-aggregate plumbing must be wired. + +**Recommendation for Lance.** Mirror Spark V2's contract at the `Scan` level, but execute the dispatch in DataFusion's physical optimizer (since Lance plans through DataFusion already). The `Scan` would expose `try_pushdown_aggregate(agg, filter) -> Option`. The optimizer rule walks `AggregateExec(final) → AggregateExec(partial) → Scan` patterns and asks the scan whether it can satisfy. Index access lives inside the scan (or its `MetricsProvider`), not in the optimizer. + +--- + +## 6. Correctness Gotchas + +1. **Predicate-must-be-fully-evaluable-by-index.** If the index can evaluate `col = 5` but not `f(col) = 5`, the predicate must be either rejected by the index entirely or split. A pushed aggregate over a partially-filtered set is silently wrong. Iceberg's PR thread had multiple iterations on this. + +2. **NULL handling per aggregate.** `COUNT(*)` counts rows including nulls; `COUNT(col)` and `MIN/MAX` skip nulls. Need both row count and null count per fragment. Iceberg's `BoundAggregate.hasValue` distinguishes "stat exists and column is all-null (legal answer for MIN/MAX = NULL)" from "stat missing → abort." + +3. 
**Row-level deletes / deletion vectors / MVCC.** Stale statistics. Postgres: visibility map. SQL Server: delta rowgroups bypass pushdown. Iceberg: aggregate pushdown disabled if row-level deletes exist on touched files. **Lance equivalent: deletion vectors.** Pushdown must either consult deletion vector population (row count − deleted count) or abort. + +4. **Empty input vs zero.** `COUNT` on zero rows is `0`; `MIN/MAX/SUM` on zero rows is `NULL`. The fast path must return the right type, not silently coerce. + +5. **`COUNT(DISTINCT)` overlap across fragments.** Per-fragment distinct counts cannot be summed. Two paths: (a) merge an exact structure (sorted dictionary or bitmap union) across fragments; (b) merge HLL/theta sketches for approximate answer. Lance bitmap indexes naturally support (a) via posting-list union. + +6. **Truncated/lossy statistics.** Parquet writers commonly truncate string min/max. Iceberg refuses pushdown in this case. Lance should mark such stats as inexact and refuse. + +7. **`MIN/MAX` sort operator vs. aggregate sort order.** Postgres's `fetch_agg_sort_op` looks up the agg's sort operator from `pg_aggregate`. A user-defined min-like aggregate is not eligible unless registered correctly. Lance's analogue: only well-known `MIN`/`MAX` over orderable types qualify; do not try to be clever with user-defined aggregates. + +8. **GROUP BY combined with aggregate pushdown is partial by definition.** Each fragment emits `(group_key, partial_agg)`, and the engine reduces across fragments. The fragment-side dedup is *not* a complete `GROUP BY` — duplicates across fragments are normal and required for correctness. Spark's `SupportsPushDownAggregates` docs: "the data source can still output data with duplicated keys, which is OK as Spark will do GROUP BY key again." + +9. **Aggregate-over-filter ordering.** Spark V2 explicitly pushes filters *before* aggregates. Lance's scan API should follow: aggregate pushdown receives the post-filter view. + +10.
**Approximate vs exact must be explicit in the API.** Calcite Druid translation of `COUNT(DISTINCT)` to `cardinality` was filed as a bug (CALCITE-1670) because users didn't expect approximate semantics. Lance should never silently approximate. + +--- + +## 7. Open Questions / Things I Couldn't Pin Down Authoritatively + +- **DuckDB's exact metadata-only path.** Multiple sources say zonemaps drive "computing aggregations" but I could not find a named optimizer rule (e.g., a `count_star_metadata` rule) in either the optimizer blog or the indexing docs. Need to read `src/optimizer/` in the DuckDB tree directly — start at [`optimizer.cpp`](https://github.com/duckdb/duckdb/blob/main/src/optimizer/optimizer.cpp) and look for statistics-propagation paths that fold to constants. +- **ClickHouse projection selection cost model.** Docs say "the optimizer automatically samples the primary keys" but I did not find a description of the tie-breaking when multiple projections could serve. Likely in `Processors/QueryPlan/Optimizations/optimizeUseAggregateProjection.cpp` in source. +- **Snowflake metadata-only execution rules.** Marketing-level confirmation that COUNT/MIN/MAX from metadata works, but no published planner doc. The Optima blog is the closest thing and is high-level. +- **Pinot star-tree planner matching.** Docs describe the structure but not the matcher. The pattern from the description is "exact match on dimension subset + supported aggregate"; needs source-code confirmation (see `pinot-segment-spi`). +- **Druid exact COUNT(DISTINCT) status.** There is a community "Exact Cardinality Count" extension PR but it is not in core. Mainline path is HLL-approximate. Worth a follow-up: does Druid's bitmap structure make exact distinct count "free enough" that someone proposed a core impl? (The PR exists; review comments would tell us why it didn't merge.) +- **Postgres `count(*)` from index.** I expected a planner rewrite analogous to MinMaxAggPath. 
I couldn't find one — it appears `count(*)` always goes through an actual scan (possibly index-only), never a metadata read. Worth confirming on `pgsql-hackers`; multiple threads have proposed it and been declined for MVCC reasons. +- **Iceberg manifest-only `MIN/MAX` correctness with column nullability.** PR #6622 introduces `hasValue` but I didn't trace whether mixed-null + non-null fragments are merged correctly when *some* fragments have stats and *others* don't. Worth reading the test cases before mirroring the design. + +--- + +### Sources + +- PostgreSQL: [planagg.c source](https://doxygen.postgresql.org/planagg_8c_source.html) · [Cybertec MIN/MAX speedup](https://www.cybertec-postgresql.com/en/speeding-up-min-and-max/) · [Index-Only Scans](https://www.postgresql.org/docs/current/indexes-index-only-scans.html) · [Wiki: Index-only scans](https://wiki.postgresql.org/wiki/Index-only_scans) · [EDB Aggregate Push-down](https://www.enterprisedb.com/blog/postgresql-aggregate-push-down-postgresfdw) · [Partition-wise aggregation commit](https://www.postgresql.org/message-id/E1f30tV-0003rh-27@gemulon.postgresql.org) +- DuckDB: [Indexing](https://duckdb.org/docs/current/guides/performance/indexing) · [Indexes](https://duckdb.org/docs/current/sql/indexes) · [Optimizers blog](https://duckdb.org/2024/11/14/optimizers) · [Row Groups (DeepWiki)](https://deepwiki.com/duckdb/duckdb/7.2-column-storage) +- SQL Server: [Aggregate Pushdown 2016](https://learn.microsoft.com/en-us/archive/blogs/sql_server_team/columnstore-index-performance-sql-server-2016-aggregate-pushdown) · [Grouped Aggregate Pushdown (Paul White)](https://sqlperformance.com/2019/04/sql-plan/grouped-aggregate-pushdown) · [Columnstore Query Performance](https://learn.microsoft.com/en-us/sql/relational-databases/indexes/columnstore-indexes-query-performance) · [ColumnStore Segment Elimination](https://www.sqlpassion.at/archive/2017/01/30/columnstore-segment-elimination/) +- ClickHouse: [Projections 
docs](https://clickhouse.com/docs/data-modeling/projections) · [AggregatingMergeTree (Altinity)](https://kb.altinity.com/engines/mergetree-table-engine-family/aggregatingmergetree/) · [SimpleState combinator](https://kb.altinity.com/altinity-kb-queries-and-syntax/simplestateif-or-ifstate-for-simple-aggregate-functions/) +- Druid: [Segments design](https://druid.apache.org/docs/latest/design/segments/) · [HLL old aggregator](https://druid.apache.org/docs/latest/querying/hll-old.html) · [Aggregations reference](https://druid.apache.org/docs/latest/querying/aggregations/) · [CALCITE-1670](https://issues.apache.org/jira/browse/CALCITE-1670) +- Pinot: [Star-Tree Index docs](https://docs.pinot.apache.org/basics/indexing/star-tree-index) · [Star-Tree Part 3](https://startree.ai/resources/star-tree-index-in-apache-pinot-part-3-understanding-the-impact-in-real-customer/) +- Snowflake: [Micro-partitions and clustering](https://docs.snowflake.com/en/user-guide/tables-clustering-micropartitions) · [Snowflake Optima](https://www.snowflake.com/en/engineering-blog/snowflake-optima-metadata-query-pruning/) · [Pruning paper (arXiv)](https://arxiv.org/html/2504.11540v1) +- Iceberg/Spark: [Iceberg PR #6622 (aggregate pushdown)](https://github.com/apache/iceberg/pull/6622) · [Iceberg statistics (Ryft)](https://www.ryft.io/blog/making-sense-of-apache-iceberg-statistics) · [Spark SupportsPushDownAggregates JavaDoc](https://spark.apache.org/docs/3.4.3/api/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.html) +- DataFusion: [Query Optimizer](https://datafusion.apache.org/library-user-guide/query-optimizer.html) · [Issue #19938 (min/max in grouped aggs)](https://github.com/apache/datafusion/issues/19938) · [Limit Pruning blog (Mar 2026)](https://datafusion.apache.org/blog/2026/03/20/limit-pruning/) · [Optimizing SQL Part 2](https://datafusion.apache.org/blog/2025/06/15/optimizing-sql-dataframes-part-two/) diff --git a/rust/lance-index/Cargo.toml 
b/rust/lance-index/Cargo.toml index af5e0204320..4ecd142e8dc 100644 --- a/rust/lance-index/Cargo.toml +++ b/rust/lance-index/Cargo.toml @@ -38,6 +38,7 @@ itertools.workspace = true jieba-rs = { workspace = true, optional = true } jsonb.workspace = true lance-arrow.workspace = true +lance-arrow-scalar.workspace = true lance-core.workspace = true lance-datafusion.workspace = true lance-encoding.workspace = true diff --git a/rust/lance-index/src/expression.rs b/rust/lance-index/src/expression.rs new file mode 100644 index 00000000000..1b15f77d219 --- /dev/null +++ b/rust/lance-index/src/expression.rs @@ -0,0 +1,9 @@ +//! Plan-time expression parsing for scalar and aggregate index pushdown. +//! +//! Both halves split a user expression into an index-evaluable leaf plus the +//! residual computation: [`scalar`] parses `WHERE` clauses, [`aggregate`] +//! parses `SELECT`-list aggregates. The execute-time consumers live under +//! `lance::io::exec::{scalar_index, aggregate_index}`. + +pub mod aggregate; +pub mod scalar; diff --git a/rust/lance-index/src/expression/aggregate.rs b/rust/lance-index/src/expression/aggregate.rs new file mode 100644 index 00000000000..f8ccf1e7e80 --- /dev/null +++ b/rust/lance-index/src/expression/aggregate.rs @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Aggregate query parsing — the plan-time half of aggregate pushdown. +//! +//! Parallels [`crate::scalar::expression`]: where the scalar side splits a +//! `WHERE` clause into an index search and a post-filter refine, the aggregate +//! side splits a `SELECT` element (typically containing an aggregate) into an +//! [`AggregateIndexSearch`] and a post-projection [`Expr`] that runs on top of +//! the search's output. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow_schema::Field; +use datafusion_expr::Expr; + +use datafusion_expr::expr::AggregateFunction; +use lance_core::Result; + +/// A parsed aggregate query, ready to be evaluated against a scalar index. +/// +/// Implementations carry whatever state the parser produced (e.g. the +/// distinct/approximate flags on [`CountQuery`]); the execute side downcasts +/// via [`AnyAggregateQuery::as_any`] to drive the right computation. The +/// counterpart for `WHERE` filters is [`crate::scalar::AnyQuery`]. +pub trait AnyAggregateQuery: std::fmt::Debug + std::fmt::Display + Any + Send + Sync { + /// Cast the query as `Any` to allow for downcasting to the concrete query type. + fn as_any(&self) -> &dyn Any; + /// The expected schema of the output of this aggregate query. + /// + /// This should be the "partial aggregate" representation. For example, if the query + /// is an AVG aggregate then it should be a struct field with two fields: `sum` and `count`. + fn output(&self) -> &Field; +} + +/// A parser that decides whether an expression (containing one +/// aggregate) can be served by indices. +/// +/// For example, in the query SELECT MAX(score), MIN(error) FROM t WHERE category = '7' +/// this would be called twice, once for MAX(score) and once for MIN(error). The query +/// should indicate whether it can handle a filter or not. If so, and the filter can +/// be satisfied by scalar index search, then the filter will be provided as a bitmap +/// to the aggregate search. If not, then the aggregate query will not be executed +/// if a filter is present. +/// +/// In most cases this will only return a query when the aggregate function is an approximate +/// function, as exact acceleration of aggregates is difficult. +pub trait AggregateQueryParser: std::fmt::Debug + Send + Sync { + /// Parse the given aggregate, returning a query that can be evaluated. + /// + /// This method should not load or search the index.
It is expected we can + /// do the parsing purely from the dataset metadata and the index details. + /// + /// Returns `Some(query)` if the parser recognizes the aggregate and can + /// produce a query for it. Returns `None` if not — the caller is expected + /// to fall back to a normal scan in that case. + fn parse_aggregate( + &self, + aggregate: &AggregateFunction, + ) -> Result<Option<Arc<dyn AnyAggregateQuery>>>; +} + +/// A single aggregate-index search — leaf carrying the parsed query. +/// +/// Parallels [`crate::scalar::expression::ScalarIndexSearch`]. There is no +/// tree variant (no `AggregateIndexExpr`) because v1 only emits a single +/// search per parsed aggregate. If we later need to combine multiple +/// aggregate searches logically we'll introduce one. +#[derive(Debug, Clone)] +pub struct AggregateIndexSearch { + /// The index accelerating the aggregate search + /// + /// Will be None for nullary aggregates such as COUNT(*) + pub index_name: Option<String>, + /// The query that the exec node will evaluate + pub query: Arc<dyn AnyAggregateQuery>, + /// Filter to be applied only to this aggregate + pub filter: Option<Expr>, + /// The original expression that was parsed into this aggregate search + pub original_expr: Expr, +} + +impl std::fmt::Display for AggregateIndexSearch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}@{:?}", + self.query, + self.index_name.as_deref().unwrap_or("*") + ) + } +} + +/// A query for `COUNT`-shaped aggregates — covers `COUNT(*)`, `COUNT(col)`, +/// `COUNT(DISTINCT col)`, and their approximate variants. +/// +/// The combination of `is_approximate` and `is_distinct` selects between the +/// four standard SQL shapes, with the matching constructors below. +#[derive(Debug, Clone, PartialEq)] +pub struct CountQuery { + is_approximate: bool, + is_distinct: bool, + output_field: Field, +} + +impl CountQuery { + /// Exact non-distinct count — `COUNT(*)` or `COUNT(col)`.
+ pub fn basic() -> Self { + Self { + is_approximate: false, + is_distinct: false, + output_field: Field::new("count", arrow_schema::DataType::UInt64, false), + } + } + + /// Approximate non-distinct count — used when the underlying index can + /// only produce an estimate (e.g. via a sketch). + pub fn approx() -> Self { + Self { + is_approximate: true, + is_distinct: false, + output_field: Field::new("count", arrow_schema::DataType::UInt64, false), + } + } + + /// Exact distinct count — `COUNT(DISTINCT col)`. + pub fn distinct() -> Self { + Self { + is_approximate: false, + is_distinct: true, + output_field: Field::new("count_distinct", arrow_schema::DataType::UInt64, false), + } + } + + /// Approximate distinct count — `APPROX_COUNT_DISTINCT(col)` / HLL-style. + pub fn approx_distinct() -> Self { + Self { + is_approximate: true, + is_distinct: true, + output_field: Field::new( + "approx_count_distinct", + arrow_schema::DataType::UInt64, + false, + ), + } + } + + /// `true` if the result is an approximation rather than an exact count. + pub fn is_approximate(&self) -> bool { + self.is_approximate + } + + /// `true` if the count is over distinct values rather than rows. + pub fn is_distinct(&self) -> bool { + self.is_distinct + } +} + +impl std::fmt::Display for CountQuery { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.output_field.name().fmt(f) + } +} + +impl AnyAggregateQuery for CountQuery { + fn as_any(&self) -> &dyn Any { + self + } + fn output(&self) -> &Field { + &self.output_field + } +} + +/// Parser for [`CountQuery`] +#[derive(Debug, Default)] +pub struct CountQueryParser { + #[allow(dead_code)] + index_name: Option<String>, +} + +impl CountQueryParser { + /// Create a parser. `index_name`, when set, identifies a count-supporting + /// scalar index that the produced [`AggregateIndexSearch`] should bind to.
+ pub fn new(index_name: Option<String>) -> Self { + Self { index_name } + } +} + +impl AggregateQueryParser for CountQueryParser { + fn parse_aggregate( + &self, + agg: &AggregateFunction, + ) -> Result<Option<Arc<dyn AnyAggregateQuery>>> { + if agg.func.name() != "count" { + return Ok(None); + } + let query = if agg.params.distinct { + CountQuery::distinct() + } else { + CountQuery::basic() + }; + Ok(Some(Arc::new(query))) + } +} diff --git a/rust/lance-index/src/expression/scalar.rs b/rust/lance-index/src/expression/scalar.rs new file mode 100644 index 00000000000..01d80eaf669 --- /dev/null +++ b/rust/lance-index/src/expression/scalar.rs @@ -0,0 +1,3106 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::{ + ops::Bound, + sync::{Arc, LazyLock}, +}; + +use arrow::array::BinaryBuilder; +use arrow_array::{Array, RecordBatch, UInt32Array}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use async_recursion::async_recursion; +use async_trait::async_trait; +use datafusion_common::ScalarValue; +use datafusion_expr::{ + Between, BinaryExpr, Expr, Operator, ReturnFieldArgs, ScalarUDF, + expr::{InList, Like, ScalarFunction}, +}; +use tokio::try_join; + +use crate::metrics::MetricsCollector; +use crate::scalar::{ + AnyQuery, BloomFilterQuery, LabelListQuery, SargableQuery, ScalarIndex, SearchResult, + TextQuery, TokenQuery, +}; +#[cfg(feature = "geo")] +use crate::scalar::{GeoQuery, RelationQuery}; +use lance_core::{ + Error, Result, + utils::mask::{NullableRowAddrMask, RowAddrMask}, +}; +use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner}; +use roaring::RoaringBitmap; +use tracing::instrument; + +const MAX_DEPTH: usize = 500; + +/// An indexed expression consists of a scalar index query with a post-scan filter +/// +/// When a user wants to filter the data returned by a scan we may be able to use +/// one or more scalar indices to reduce the amount of data we load from the disk.
+/// +/// For example, if a user provides the filter "x = 7", and we have a scalar index +/// on x, then we can possibly identify the exact row that the user desires with our +/// index. A full-table scan can then turn into a take operation fetching the rows +/// desired. This would create an IndexedExpression with a scalar_query but no +/// refine. +/// +/// If the user asked for "type = 'dog' && z = 3" and we had a scalar index on the +/// "type" column then we could convert this to an indexed scan for "type='dog'" +/// followed by an in-memory filter for z=3. This would create an IndexedExpression +/// with both a scalar_query AND a refine. +/// +/// Finally, if the user asked for "z = 3" and we do not have a scalar index on the +/// "z" column then we must fall back to an IndexedExpression with no scalar_query and +/// only a refine. +/// +/// Two IndexedExpressions can be AND'd together. Each part is AND'd together. +/// Two IndexedExpressions cannot be OR'd together unless both are scalar_query only +/// or both are refine only. +/// An IndexedExpression cannot be negated if it has both a refine and a scalar_query. +/// +/// When an operation cannot be performed we fall back to the original expression-only +/// representation. +#[derive(Debug, PartialEq)] +pub struct IndexedExpression { + /// The portion of the query that can be satisfied by scalar indices + pub scalar_query: Option<ScalarIndexExpr>, + /// The portion of the query that cannot be satisfied by scalar indices + pub refine_expr: Option<Expr>, +} + +pub trait ScalarQueryParser: std::fmt::Debug + Send + Sync { + /// Visit a between expression + /// + /// Returns an IndexedExpression if the index can accelerate between expressions + fn visit_between( + &self, + column: &str, + low: &Bound<ScalarValue>, + high: &Bound<ScalarValue>, + ) -> Option<IndexedExpression>; + /// Visit an in list expression + /// + /// Returns an IndexedExpression if the index can accelerate in list expressions + fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression>; + /// Visit
an is bool expression + /// + /// Returns an IndexedExpression if the index can accelerate is bool expressions + fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression>; + /// Visit an is null expression + /// + /// Returns an IndexedExpression if the index can accelerate is null expressions + fn visit_is_null(&self, column: &str) -> Option<IndexedExpression>; + /// Visit a comparison expression + /// + /// Returns an IndexedExpression if the index can accelerate comparison expressions + fn visit_comparison( + &self, + column: &str, + value: &ScalarValue, + op: &Operator, + ) -> Option<IndexedExpression>; + /// Visit a scalar function expression + /// + /// Returns an IndexedExpression if the index can accelerate the given scalar function. + /// For example, an ngram index can accelerate the contains function. + fn visit_scalar_function( + &self, + column: &str, + data_type: &DataType, + func: &ScalarUDF, + args: &[Expr], + ) -> Option<IndexedExpression>; + + /// Visit a LIKE expression + /// + /// Returns an IndexedExpression if the index can accelerate LIKE expressions. + /// For prefix patterns (e.g., "foo%"): + /// - ZoneMaps prune zones based on min/max statistics + /// - BTrees use range query conversion `[prefix, next_prefix)` + /// + /// For patterns with wildcards in the middle (e.g., "foo%bar%"), the leading prefix + /// can still be used for pruning, with the full pattern as a refine expression. + /// + /// # Arguments + /// * `column` - The column name + /// * `like` - The full LIKE expression (for constructing refine_expr if needed) + /// * `pattern` - The LIKE pattern as ScalarValue (e.g., "foo%") + fn visit_like( + &self, + _column: &str, + _like: &Like, + _pattern: &ScalarValue, + ) -> Option<IndexedExpression> { + None + } + + /// Visits a potential reference to a column + /// + /// This function is a little different from the other visitors. It is used to test if a potential + /// column reference is a reference the index handles. + /// + /// Most indexes are designed to run on references to the indexed column.
For example, if a query + /// is "x = 7" and we have a scalar index on "x" then we apply the index to the "x" column reference. + /// + /// However, some indexes are designed to run on projections of the indexed column. For example, + /// if a query is "json_extract(json, '$.name') = 'books'" and we have a JSON index on the "json" column + /// then we apply the index to the projection of the "json" column. + /// + /// This function is used to test if a potential column reference is a reference the index handles. + /// The default implementation matches column references but this can be overridden by indexes that + /// handle projections. + /// + /// The function is also passed in the data type of the column and should return the data type of the + /// reference. Normally this is the same as the input for a direct column reference and possibly something + /// different for a projection. E.g. a JSON column (LargeBinary) might be projected to a string or float + /// + /// Note: higher logic in the expression parser already limits references to either Expr::Column or Expr::ScalarFunction + /// where the first argument is an Expr::Column. If your projection doesn't fit that mold then the + /// expression parser will need to be modified. 
+ fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option<DataType> { + match func { + Expr::Column(_) => Some(data_type.clone()), + _ => None, + } + } +} + +/// A generic parser that wraps multiple scalar query parsers +/// +/// It will search each parser in order and return the first non-None result +#[derive(Debug)] +pub struct MultiQueryParser { + parsers: Vec<Box<dyn ScalarQueryParser>>, +} + +impl MultiQueryParser { + /// Create a new MultiQueryParser with a single parser + pub fn single(parser: Box<dyn ScalarQueryParser>) -> Self { + Self { + parsers: vec![parser], + } + } + + /// Add a new parser to the MultiQueryParser + pub fn add(&mut self, other: Box<dyn ScalarQueryParser>) { + self.parsers.push(other); + } +} + +impl ScalarQueryParser for MultiQueryParser { + fn visit_between( + &self, + column: &str, + low: &Bound<ScalarValue>, + high: &Bound<ScalarValue>, + ) -> Option<IndexedExpression> { + self.parsers + .iter() + .find_map(|parser| parser.visit_between(column, low, high)) + } + fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option<IndexedExpression> { + self.parsers + .iter() + .find_map(|parser| parser.visit_in_list(column, in_list)) + } + fn visit_is_bool(&self, column: &str, value: bool) -> Option<IndexedExpression> { + self.parsers + .iter() + .find_map(|parser| parser.visit_is_bool(column, value)) + } + fn visit_is_null(&self, column: &str) -> Option<IndexedExpression> { + self.parsers + .iter() + .find_map(|parser| parser.visit_is_null(column)) + } + fn visit_comparison( + &self, + column: &str, + value: &ScalarValue, + op: &Operator, + ) -> Option<IndexedExpression> { + self.parsers + .iter() + .find_map(|parser| parser.visit_comparison(column, value, op)) + } + fn visit_scalar_function( + &self, + column: &str, + data_type: &DataType, + func: &ScalarUDF, + args: &[Expr], + ) -> Option<IndexedExpression> { + self.parsers + .iter() + .find_map(|parser| parser.visit_scalar_function(column, data_type, func, args)) + } + fn visit_like( + &self, + column: &str, + like: &Like, + pattern: &ScalarValue, + ) -> Option<IndexedExpression> { + self.parsers + .iter() + .find_map(|parser| parser.visit_like(column, like, pattern)) + } + ///
TODO(low-priority): This is maybe not quite right. We should filter down the list of parsers based + /// on those that consider the reference valid. Instead what we are doing is checking all parsers if any one + /// parser considers the reference valid. + /// + /// This will be a problem if the user creates two indexes (e.g. btree and json) on the same column and those two + /// indexes have different reference schemes. + fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option { + self.parsers + .iter() + .find_map(|parser| parser.is_valid_reference(func, data_type)) + } +} + +/// A parser for indices that handle SARGable queries +#[derive(Debug)] +pub struct SargableQueryParser { + index_name: String, + index_type: String, + needs_recheck: bool, +} + +impl SargableQueryParser { + pub fn new(index_name: String, index_type: String, needs_recheck: bool) -> Self { + Self { + index_name, + index_type, + needs_recheck, + } + } +} + +impl ScalarQueryParser for SargableQueryParser { + fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option { + match func { + Expr::Column(_) => Some(data_type.clone()), + // Also accept get_field expressions for nested field access + Expr::ScalarFunction(udf) if udf.name() == "get_field" => Some(data_type.clone()), + _ => None, + } + } + + fn visit_between( + &self, + column: &str, + low: &Bound, + high: &Bound, + ) -> Option { + if let Bound::Included(val) | Bound::Excluded(val) = low + && val.is_null() + { + return None; + } + if let Bound::Included(val) | Bound::Excluded(val) = high + && val.is_null() + { + return None; + } + let query = SargableQuery::Range(low.clone(), high.clone()); + Some(IndexedExpression::index_query_with_recheck( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(query), + self.needs_recheck, + )) + } + + fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option { + if in_list.iter().any(|val| val.is_null()) { + return None; + } 
+ let query = SargableQuery::IsIn(in_list.to_vec()); + Some(IndexedExpression::index_query_with_recheck( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(query), + self.needs_recheck, + )) + } + + fn visit_is_bool(&self, column: &str, value: bool) -> Option { + Some(IndexedExpression::index_query_with_recheck( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(SargableQuery::Equals(ScalarValue::Boolean(Some(value)))), + self.needs_recheck, + )) + } + + fn visit_is_null(&self, column: &str) -> Option { + Some(IndexedExpression::index_query_with_recheck( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(SargableQuery::IsNull()), + self.needs_recheck, + )) + } + + fn visit_comparison( + &self, + column: &str, + value: &ScalarValue, + op: &Operator, + ) -> Option { + if value.is_null() { + return None; + } + let query = match op { + Operator::Lt => SargableQuery::Range(Bound::Unbounded, Bound::Excluded(value.clone())), + Operator::LtEq => { + SargableQuery::Range(Bound::Unbounded, Bound::Included(value.clone())) + } + Operator::Gt => SargableQuery::Range(Bound::Excluded(value.clone()), Bound::Unbounded), + Operator::GtEq => { + SargableQuery::Range(Bound::Included(value.clone()), Bound::Unbounded) + } + Operator::Eq => SargableQuery::Equals(value.clone()), + // This will be negated by the caller + Operator::NotEq => SargableQuery::Equals(value.clone()), + _ => unreachable!(), + }; + Some(IndexedExpression::index_query_with_recheck( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(query), + self.needs_recheck, + )) + } + + fn visit_scalar_function( + &self, + column: &str, + _data_type: &DataType, + func: &ScalarUDF, + args: &[Expr], + ) -> Option { + // Handle starts_with(col, 'prefix') -> convert to LikePrefix query + if func.name() == "starts_with" && args.len() == 2 { + // Extract the prefix from the second argument + 
let prefix = match &args[1] { + Expr::Literal(ScalarValue::Utf8(Some(s)), _) => ScalarValue::Utf8(Some(s.clone())), + Expr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => { + ScalarValue::LargeUtf8(Some(s.clone())) + } + _ => return None, + }; + + let query = SargableQuery::LikePrefix(prefix); + return Some(IndexedExpression::index_query_with_recheck( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(query), + self.needs_recheck, + )); + } + + None + } + + fn visit_like( + &self, + column: &str, + like: &Like, + pattern: &ScalarValue, + ) -> Option { + // Case-insensitive LIKE (ILIKE) cannot be efficiently pruned with zone maps + if like.case_insensitive { + return None; + } + + // Extract the pattern string + let pattern_str = match pattern { + ScalarValue::Utf8(Some(s)) => s.as_str(), + ScalarValue::LargeUtf8(Some(s)) => s.as_str(), + _ => return None, + }; + + // Try to extract a prefix from the LIKE pattern + let (prefix, needs_refine) = extract_like_leading_prefix(pattern_str, like.escape_char)?; + + // Create the prefix ScalarValue with the same type as the pattern + let prefix_value = match pattern { + ScalarValue::Utf8(_) => ScalarValue::Utf8(Some(prefix)), + ScalarValue::LargeUtf8(_) => ScalarValue::LargeUtf8(Some(prefix)), + _ => return None, + }; + + let query = SargableQuery::LikePrefix(prefix_value); + let scalar_query = Some(ScalarIndexExpr::Query(ScalarIndexSearch { + column: column.to_string(), + index_name: self.index_name.clone(), + index_type: self.index_type.clone(), + query: Arc::new(query), + needs_recheck: self.needs_recheck, + })); + + // If the pattern has wildcards beyond simple prefix, add refine expression + let refine_expr = if needs_refine { + Some(Expr::Like(like.clone())) + } else { + None + }; + + Some(IndexedExpression { + scalar_query, + refine_expr, + }) + } +} + +/// Extract the leading literal prefix from a LIKE pattern. 
+/// +/// Returns `Some((prefix, needs_refine))` where: +/// - `prefix` is the leading literal portion before any wildcards +/// - `needs_refine` is true if the pattern has wildcards beyond a simple trailing `%` +/// +/// Returns `None` if the pattern starts with a wildcard (no leading literal). +/// +/// Examples: +/// - "foo%" -> Some(("foo", false)) - pure prefix, no recheck needed +/// - "foo%bar%" -> Some(("foo", true)) - can use prefix for pruning, needs recheck +/// - "foo_bar%" -> Some(("foo", true)) - _ is a wildcard, needs recheck +/// - "foo\%bar%" with escape '\' -> Some(("foo%bar", false)) - escaped %, pure prefix +/// - "%foo" -> None - starts with wildcard, cannot prune +/// - "foo" -> None - no wildcard at all, use equality instead +fn extract_like_leading_prefix(pattern: &str, escape_char: Option<char>) -> Option<(String, bool)> { + let chars: Vec<char> = pattern.chars().collect(); + let len = chars.len(); + + if len == 0 { + return None; + } + + // DataFusion's starts_with simplification escapes special characters with backslash + // but doesn't set escape_char. Use backslash as default escape character.
+ // Pattern: starts_with(col, 'test_ns$') -> col LIKE 'test\_ns$%' (escape_char: None) + // See: https://github.com/apache/datafusion/issues/XXXX + let effective_escape_char = escape_char.or(Some('\\')); + + // Helper to check if a character at position i is escaped + let is_escaped = |i: usize| -> bool { + if let Some(esc) = effective_escape_char { + if i > 0 && chars[i - 1] == esc { + // Check if the escape char itself is escaped + if i >= 2 && chars[i - 2] == esc { + false // Escape was escaped, so this char is NOT escaped + } else { + true // This char is escaped + } + } else { + false + } + } else { + // No escape character defined - nothing can be escaped + false + } + }; + + // Pattern must contain at least one unescaped wildcard + let has_wildcard = chars.iter().enumerate().any(|(i, &c)| { + if c != '%' && c != '_' { + return false; + } + !is_escaped(i) + }); + + if !has_wildcard { + return None; // No wildcards, should use equality + } + + // Check if pattern starts with an unescaped wildcard + if chars[0] == '%' || chars[0] == '_' { + return None; // Starts with wildcard, cannot prune + } + + // Extract the leading literal prefix (everything before first unescaped wildcard) + let mut prefix = String::new(); + let mut i = 0; + let mut found_wildcard = false; + + while i < len { + let c = chars[i]; + + // Check for escape character (using effective escape char which may be inferred) + if let Some(esc) = effective_escape_char + && c == esc + && i + 1 < len + { + let next = chars[i + 1]; + if next == '%' || next == '_' || next == esc { + // Escaped character - add the literal character + prefix.push(next); + i += 2; + continue; + } + } + + // Check for unescaped wildcard + if c == '%' || c == '_' { + found_wildcard = true; + break; + } + + prefix.push(c); + i += 1; + } + + if prefix.is_empty() { + return None; + } + + // Check if pattern is just a simple prefix (ends with single % and nothing after) + let needs_refine = if found_wildcard && i < len { + // 
Check if we're at a % wildcard + if chars[i] == '%' && i + 1 == len { + // Pattern is "prefix%" - pure prefix match, no refine needed + false + } else { + // Pattern has more after first wildcard, or has _ wildcard + true + } + } else { + // No wildcard found (shouldn't happen due to earlier check) + false + }; + + Some((prefix, needs_refine)) +} + +/// A parser for bloom filter indices that only support equals, is_null, and is_in operations +#[derive(Debug)] +pub struct BloomFilterQueryParser { + index_name: String, + index_type: String, + needs_recheck: bool, +} + +impl BloomFilterQueryParser { + pub fn new(index_name: String, index_type: String, needs_recheck: bool) -> Self { + Self { + index_name, + index_type, + needs_recheck, + } + } +} + +impl ScalarQueryParser for BloomFilterQueryParser { + fn visit_between( + &self, + _: &str, + _: &Bound, + _: &Bound, + ) -> Option { + // Bloom filters don't support range queries + None + } + + fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option { + let query = BloomFilterQuery::IsIn(in_list.to_vec()); + Some(IndexedExpression::index_query_with_recheck( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(query), + self.needs_recheck, + )) + } + + fn visit_is_bool(&self, column: &str, value: bool) -> Option { + Some(IndexedExpression::index_query_with_recheck( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(BloomFilterQuery::Equals(ScalarValue::Boolean(Some(value)))), + self.needs_recheck, + )) + } + + fn visit_is_null(&self, column: &str) -> Option { + Some(IndexedExpression::index_query_with_recheck( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(BloomFilterQuery::IsNull()), + self.needs_recheck, + )) + } + + fn visit_comparison( + &self, + column: &str, + value: &ScalarValue, + op: &Operator, + ) -> Option { + let query = match op { + // Bloom filters only support equality comparisons + 
Operator::Eq => BloomFilterQuery::Equals(value.clone()), + // This will be negated by the caller + Operator::NotEq => BloomFilterQuery::Equals(value.clone()), + // Bloom filters don't support range operations + _ => return None, + }; + Some(IndexedExpression::index_query_with_recheck( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(query), + self.needs_recheck, + )) + } + + fn visit_scalar_function( + &self, + _: &str, + _: &DataType, + _: &ScalarUDF, + _: &[Expr], + ) -> Option { + // Bloom filters don't support scalar functions + None + } +} + +/// A parser for indices that handle label list queries +#[derive(Debug)] +pub struct LabelListQueryParser { + index_name: String, + index_type: String, +} + +impl LabelListQueryParser { + pub fn new(index_name: String, index_type: String) -> Self { + Self { + index_name, + index_type, + } + } +} + +impl ScalarQueryParser for LabelListQueryParser { + fn visit_between( + &self, + _: &str, + _: &Bound, + _: &Bound, + ) -> Option { + None + } + + fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option { + None + } + + fn visit_is_bool(&self, _: &str, _: bool) -> Option { + None + } + + fn visit_is_null(&self, _: &str) -> Option { + None + } + + fn visit_comparison( + &self, + _: &str, + _: &ScalarValue, + _: &Operator, + ) -> Option { + None + } + + fn visit_scalar_function( + &self, + column: &str, + data_type: &DataType, + func: &ScalarUDF, + args: &[Expr], + ) -> Option { + if args.len() != 2 { + return None; + } + // DataFusion normalizes array_contains to array_has + if func.name() == "array_has" { + let inner_type = match data_type { + DataType::List(field) | DataType::LargeList(field) => field.data_type(), + _ => return None, + }; + let scalar = maybe_scalar(&args[1], inner_type)?; + // array_has(..., NULL) returns no matches in datafusion, but the index would + // match rows containing NULL. Fallback to match datafusion behavior. 
+ if scalar.is_null() { + return None; + } + let query = LabelListQuery::HasAnyLabel(vec![scalar]); + return Some(IndexedExpression::index_query( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(query), + )); + } + + let label_list = maybe_scalar(&args[1], data_type)?; + if let ScalarValue::List(list_arr) = label_list { + let list_values = list_arr.values(); + if list_values.is_empty() { + return None; + } + let mut scalars = Vec::with_capacity(list_values.len()); + for idx in 0..list_values.len() { + scalars.push(ScalarValue::try_from_array(list_values.as_ref(), idx).ok()?); + } + if func.name() == "array_has_all" { + let query = LabelListQuery::HasAllLabels(scalars); + Some(IndexedExpression::index_query( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(query), + )) + } else if func.name() == "array_has_any" { + let query = LabelListQuery::HasAnyLabel(scalars); + Some(IndexedExpression::index_query( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(query), + )) + } else { + None + } + } else { + None + } + } +} + +/// A parser for indices that handle string contains queries +#[derive(Debug, Clone)] +pub struct TextQueryParser { + index_name: String, + index_type: String, + needs_recheck: bool, +} + +impl TextQueryParser { + pub fn new(index_name: String, index_type: String, needs_recheck: bool) -> Self { + Self { + index_name, + index_type, + needs_recheck, + } + } +} + +impl ScalarQueryParser for TextQueryParser { + fn visit_between( + &self, + _: &str, + _: &Bound, + _: &Bound, + ) -> Option { + None + } + + fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option { + None + } + + fn visit_is_bool(&self, _: &str, _: bool) -> Option { + None + } + + fn visit_is_null(&self, _: &str) -> Option { + None + } + + fn visit_comparison( + &self, + _: &str, + _: &ScalarValue, + _: &Operator, + ) -> Option { + None + } + + fn visit_scalar_function( + 
&self, + column: &str, + data_type: &DataType, + func: &ScalarUDF, + args: &[Expr], + ) -> Option { + if args.len() != 2 { + return None; + } + let scalar = maybe_scalar(&args[1], data_type)?; + match scalar { + ScalarValue::Utf8(Some(scalar_str)) | ScalarValue::LargeUtf8(Some(scalar_str)) => { + if func.name() == "contains" { + let query = TextQuery::StringContains(scalar_str); + Some(IndexedExpression::index_query_with_recheck( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(query), + self.needs_recheck, + )) + } else { + None + } + } + _ => { + // If the scalar is not a string, we cannot handle it + None + } + } + } +} + +/// A parser for indices that handle queries with the contains_tokens function +#[derive(Debug, Clone)] +pub struct FtsQueryParser { + index_name: String, + index_type: String, +} + +impl FtsQueryParser { + pub fn new(name: String, index_type: String) -> Self { + Self { + index_name: name, + index_type, + } + } +} + +impl ScalarQueryParser for FtsQueryParser { + fn visit_between( + &self, + _: &str, + _: &Bound, + _: &Bound, + ) -> Option { + None + } + + fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option { + None + } + + fn visit_is_bool(&self, _: &str, _: bool) -> Option { + None + } + + fn visit_is_null(&self, _: &str) -> Option { + None + } + + fn visit_comparison( + &self, + _: &str, + _: &ScalarValue, + _: &Operator, + ) -> Option { + None + } + + fn visit_scalar_function( + &self, + column: &str, + data_type: &DataType, + func: &ScalarUDF, + args: &[Expr], + ) -> Option { + if args.len() != 2 { + return None; + } + let scalar = maybe_scalar(&args[1], data_type)?; + if let ScalarValue::Utf8(Some(scalar_str)) = scalar + && func.name() == "contains_tokens" + { + let query = TokenQuery::TokensContains(scalar_str); + return Some(IndexedExpression::index_query( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(query), + )); + } + None + } +} + +/// A parser 
for geo indices that handles spatial queries +#[cfg(feature = "geo")] +#[derive(Debug, Clone)] +pub struct GeoQueryParser { + index_name: String, + index_type: String, +} + +#[cfg(feature = "geo")] +impl GeoQueryParser { + pub fn new(index_name: String, index_type: String) -> Self { + Self { + index_name, + index_type, + } + } +} + +#[cfg(feature = "geo")] +impl ScalarQueryParser for GeoQueryParser { + fn visit_between( + &self, + _: &str, + _: &Bound, + _: &Bound, + ) -> Option { + None + } + + fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option { + None + } + + fn visit_is_bool(&self, _: &str, _: bool) -> Option { + None + } + + fn visit_is_null(&self, column: &str) -> Option { + Some(IndexedExpression::index_query_with_recheck( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(GeoQuery::IsNull), + true, + )) + } + + fn visit_comparison( + &self, + _: &str, + _: &ScalarValue, + _: &Operator, + ) -> Option { + None + } + + fn visit_scalar_function( + &self, + column: &str, + _data_type: &DataType, + func: &ScalarUDF, + args: &[Expr], + ) -> Option { + if (func.name() == "st_intersects" + || func.name() == "st_contains" + || func.name() == "st_within" + || func.name() == "st_touches" + || func.name() == "st_crosses" + || func.name() == "st_overlaps" + || func.name() == "st_covers" + || func.name() == "st_coveredby") + && args.len() == 2 + { + let left_arg = &args[0]; + let right_arg = &args[1]; + return match (left_arg, right_arg) { + (Expr::Literal(left_value, metadata), Expr::Column(_)) => { + let mut field = Field::new("_geo", left_value.data_type(), false); + if let Some(metadata) = metadata { + field = field.with_metadata(metadata.to_hashmap()); + } + let query = GeoQuery::IntersectQuery(RelationQuery { + value: left_value.clone(), + field, + }); + Some(IndexedExpression::index_query_with_recheck( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(query), + true, + )) + } + 
(Expr::Column(_), Expr::Literal(right_value, metadata)) => { + let mut field = Field::new("_geo", right_value.data_type(), false); + if let Some(metadata) = metadata { + field = field.with_metadata(metadata.to_hashmap()); + } + let query = GeoQuery::IntersectQuery(RelationQuery { + value: right_value.clone(), + field, + }); + Some(IndexedExpression::index_query_with_recheck( + column.to_string(), + self.index_name.clone(), + self.index_type.clone(), + Arc::new(query), + true, + )) + } + _ => None, + }; + } + None + } +} + +impl IndexedExpression { + /// Create an expression that only does refine + fn refine_only(refine_expr: Expr) -> Self { + Self { + scalar_query: None, + refine_expr: Some(refine_expr), + } + } + + /// Create an expression that is only an index query + fn index_query( + column: String, + index_name: String, + index_type: String, + query: Arc, + ) -> Self { + Self { + scalar_query: Some(ScalarIndexExpr::Query(ScalarIndexSearch { + column, + index_name, + index_type, + query, + needs_recheck: false, // Default to false, will be set by parser + })), + refine_expr: None, + } + } + + /// Create an expression that is only an index query with explicit needs_recheck + fn index_query_with_recheck( + column: String, + index_name: String, + index_type: String, + query: Arc, + needs_recheck: bool, + ) -> Self { + Self { + scalar_query: Some(ScalarIndexExpr::Query(ScalarIndexSearch { + column, + index_name, + index_type, + query, + needs_recheck, + })), + refine_expr: None, + } + } + + /// Try and negate the expression + /// + /// If the expression contains both an index query and a refine expression then it + /// cannot be negated today and None will be returned (we give up trying to use indices) + fn maybe_not(self) -> Option { + match (self.scalar_query, self.refine_expr) { + (Some(_), Some(_)) => None, + (Some(scalar_query), None) => { + if scalar_query.needs_recheck() { + return None; + } + Some(Self { + scalar_query: 
Some(ScalarIndexExpr::Not(Box::new(scalar_query))), + refine_expr: None, + }) + } + (None, Some(refine_expr)) => Some(Self { + scalar_query: None, + refine_expr: Some(Expr::Not(Box::new(refine_expr))), + }), + (None, None) => panic!("Empty node should not occur"), + } + } + + /// Perform a logical AND of two indexed expressions + /// + /// This is straightforward because we can just AND the individual parts + /// because (A && B) && (C && D) == (A && C) && (B && D) + fn and(self, other: Self) -> Self { + let scalar_query = match (self.scalar_query, other.scalar_query) { + (Some(scalar_query), Some(other_scalar_query)) => Some(ScalarIndexExpr::And( + Box::new(scalar_query), + Box::new(other_scalar_query), + )), + (Some(scalar_query), None) => Some(scalar_query), + (None, Some(scalar_query)) => Some(scalar_query), + (None, None) => None, + }; + let refine_expr = match (self.refine_expr, other.refine_expr) { + (Some(refine_expr), Some(other_refine_expr)) => { + Some(refine_expr.and(other_refine_expr)) + } + (Some(refine_expr), None) => Some(refine_expr), + (None, Some(refine_expr)) => Some(refine_expr), + (None, None) => None, + }; + Self { + scalar_query, + refine_expr, + } + } + + /// Try and perform a logical OR of two indexed expressions + /// + /// This is a bit tricky because something like: + /// (color == 'blue' AND size < 20) OR (color == 'green' AND size < 50) + /// is not equivalent to: + /// (color == 'blue' OR color == 'green') AND (size < 20 OR size < 50) + fn maybe_or(self, other: Self) -> Option { + // If either expression is missing a scalar_query then we need to load all rows from + // the database and so we short-circuit and return None + let scalar_query = self.scalar_query?; + let other_scalar_query = other.scalar_query?; + let scalar_query = Some(ScalarIndexExpr::Or( + Box::new(scalar_query), + Box::new(other_scalar_query), + )); + + let refine_expr = match (self.refine_expr, other.refine_expr) { + // TODO + // + // To handle these cases we need 
a way of going back from a scalar expression query to a logical DF expression (perhaps + // we can store the expression that led to the creation of the query) + // + // For example, imagine we have something like "(color == 'blue' AND size < 20) OR (color == 'green' AND size < 50)" + // + // We can do an indexed load of all rows matching "color == 'blue' OR color == 'green'" but then we need to + // refine that load with the full original expression which, at the moment, we no longer have. + (Some(_), Some(_)) => { + return None; + } + (Some(_), None) => { + return None; + } + (None, Some(_)) => { + return None; + } + (None, None) => None, + }; + Some(Self { + scalar_query, + refine_expr, + }) + } + + fn refine(self, expr: Expr) -> Self { + match self.refine_expr { + Some(refine_expr) => Self { + scalar_query: self.scalar_query, + refine_expr: Some(refine_expr.and(expr)), + }, + None => Self { + scalar_query: self.scalar_query, + refine_expr: Some(expr), + }, + } + } +} + +/// A trait implemented by anything that can load indices by name +/// +/// This is used during the evaluation of an index expression +#[async_trait] +pub trait ScalarIndexLoader: Send + Sync { + /// Load the index with the given name + async fn load_index( + &self, + column: &str, + index_name: &str, + metrics: &dyn MetricsCollector, + ) -> Result>; +} + +/// This represents a search into a scalar index +#[derive(Debug, Clone)] +pub struct ScalarIndexSearch { + /// The column to search (redundant, used for debugging messages) + pub column: String, + /// The name of the index to search + pub index_name: String, + /// The type of the index being searched (e.g. 
"BTree", "Bitmap"), used for display purposes + pub index_type: String, + /// The query to search for + pub query: Arc, + /// If true, the query results are inexact and will need a recheck + pub needs_recheck: bool, +} + +impl PartialEq for ScalarIndexSearch { + fn eq(&self, other: &Self) -> bool { + self.column == other.column + && self.index_name == other.index_name + && self.query.as_ref().eq(other.query.as_ref()) + } +} + +/// This represents a lookup into one or more scalar indices +/// +/// This is a tree of operations because we may need to logically combine or +/// modify the results of scalar lookups +#[derive(Debug, Clone)] +pub enum ScalarIndexExpr { + Not(Box), + And(Box, Box), + Or(Box, Box), + Query(ScalarIndexSearch), +} + +impl PartialEq for ScalarIndexExpr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Not(l0), Self::Not(r0)) => l0 == r0, + (Self::And(l0, l1), Self::And(r0, r1)) => l0 == r0 && l1 == r1, + (Self::Or(l0, l1), Self::Or(r0, r1)) => l0 == r0 && l1 == r1, + (Self::Query(l_search), Self::Query(r_search)) => l_search == r_search, + _ => false, + } + } +} + +impl std::fmt::Display for ScalarIndexExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Not(inner) => write!(f, "NOT({})", inner), + Self::And(lhs, rhs) => write!(f, "AND({},{})", lhs, rhs), + Self::Or(lhs, rhs) => write!(f, "OR({},{})", lhs, rhs), + Self::Query(search) => write!( + f, + "[{}]@{}({})", + search.query.format(&search.column), + search.index_name, + search.index_type + ), + } + } +} + +/// When we evaluate a scalar index query we return a batch with three columns and two rows +/// +/// The first column has the block list and allow list +/// The second column tells if the result is least/exact/more (we repeat the discriminant twice) +/// The third column has the fragments covered bitmap in the first row and null in the second row +pub static INDEX_EXPR_RESULT_SCHEMA: LazyLock = LazyLock::new(|| { 
+ Arc::new(Schema::new(vec![ + Field::new("result".to_string(), DataType::Binary, true), + Field::new("discriminant".to_string(), DataType::UInt32, true), + Field::new("fragments_covered".to_string(), DataType::Binary, true), + ])) +}); + +#[derive(Debug)] +enum NullableIndexExprResult { + Exact(NullableRowAddrMask), + AtMost(NullableRowAddrMask), + AtLeast(NullableRowAddrMask), +} + +impl From for NullableIndexExprResult { + fn from(result: SearchResult) -> Self { + match result { + SearchResult::Exact(mask) => Self::Exact(NullableRowAddrMask::AllowList(mask)), + SearchResult::AtMost(mask) => Self::AtMost(NullableRowAddrMask::AllowList(mask)), + SearchResult::AtLeast(mask) => Self::AtLeast(NullableRowAddrMask::AllowList(mask)), + } + } +} + +impl std::ops::BitAnd for NullableIndexExprResult { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self { + match (self, rhs) { + (Self::Exact(lhs), Self::Exact(rhs)) => Self::Exact(lhs & rhs), + (Self::Exact(lhs), Self::AtMost(rhs)) | (Self::AtMost(lhs), Self::Exact(rhs)) => { + Self::AtMost(lhs & rhs) + } + (Self::Exact(exact), Self::AtLeast(_)) | (Self::AtLeast(_), Self::Exact(exact)) => { + // We could do better here, elements in both lhs and rhs are known + // to be true and don't require a recheck. 
We only need to recheck + // elements in lhs that are not in rhs + Self::AtMost(exact) + } + (Self::AtMost(lhs), Self::AtMost(rhs)) => Self::AtMost(lhs & rhs), + (Self::AtLeast(lhs), Self::AtLeast(rhs)) => Self::AtLeast(lhs & rhs), + (Self::AtMost(most), Self::AtLeast(_)) | (Self::AtLeast(_), Self::AtMost(most)) => { + Self::AtMost(most) + } + } + } +} + +impl std::ops::BitOr for NullableIndexExprResult { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self { + match (self, rhs) { + (Self::Exact(lhs), Self::Exact(rhs)) => Self::Exact(lhs | rhs), + (Self::Exact(lhs), Self::AtMost(rhs)) | (Self::AtMost(rhs), Self::Exact(lhs)) => { + // We could do better here, elements in lhs are known to be true + // and don't require a recheck. We only need to recheck elements + // in rhs that are not in lhs + Self::AtMost(lhs | rhs) + } + (Self::Exact(lhs), Self::AtLeast(rhs)) | (Self::AtLeast(rhs), Self::Exact(lhs)) => { + Self::AtLeast(lhs | rhs) + } + (Self::AtMost(lhs), Self::AtMost(rhs)) => Self::AtMost(lhs | rhs), + (Self::AtLeast(lhs), Self::AtLeast(rhs)) => Self::AtLeast(lhs | rhs), + (Self::AtMost(_), Self::AtLeast(least)) | (Self::AtLeast(least), Self::AtMost(_)) => { + Self::AtLeast(least) + } + } + } +} + +impl NullableIndexExprResult { + pub fn drop_nulls(self) -> IndexExprResult { + match self { + Self::Exact(mask) => IndexExprResult::Exact(mask.drop_nulls()), + Self::AtMost(mask) => IndexExprResult::AtMost(mask.drop_nulls()), + Self::AtLeast(mask) => IndexExprResult::AtLeast(mask.drop_nulls()), + } + } +} + +#[derive(Debug)] +pub enum IndexExprResult { + // The answer is exactly the rows in the allow list minus the rows in the block list + Exact(RowAddrMask), + // The answer is at most the rows in the allow list minus the rows in the block list + // Some of the rows in the allow list may not be in the result and will need to be filtered + // by a recheck. Every row in the block list is definitely not in the result. 
+ AtMost(RowAddrMask), + // The answer is at least the rows in the allow list minus the rows in the block list + // Some of the rows in the block list might be in the result. Every row in the allow list is + // definitely in the result. + AtLeast(RowAddrMask), +} + +impl IndexExprResult { + pub fn row_addr_mask(&self) -> &RowAddrMask { + match self { + Self::Exact(mask) => mask, + Self::AtMost(mask) => mask, + Self::AtLeast(mask) => mask, + } + } + + pub fn discriminant(&self) -> u32 { + match self { + Self::Exact(_) => 0, + Self::AtMost(_) => 1, + Self::AtLeast(_) => 2, + } + } + + pub fn from_parts(mask: RowAddrMask, discriminant: u32) -> Result { + match discriminant { + 0 => Ok(Self::Exact(mask)), + 1 => Ok(Self::AtMost(mask)), + 2 => Ok(Self::AtLeast(mask)), + _ => Err(Error::invalid_input_source( + format!("Invalid IndexExprResult discriminant: {}", discriminant).into(), + )), + } + } + + #[instrument(skip_all)] + pub fn serialize_to_arrow( + &self, + fragments_covered_by_result: &RoaringBitmap, + ) -> Result { + let row_addr_mask = self.row_addr_mask(); + let row_addr_mask_arr = row_addr_mask.into_arrow()?; + let discriminant = self.discriminant(); + let discriminant_arr = + Arc::new(UInt32Array::from(vec![discriminant, discriminant])) as Arc; + let mut fragments_covered_builder = BinaryBuilder::new(); + let fragments_covered_bytes_len = fragments_covered_by_result.serialized_size(); + let mut fragments_covered_bytes = Vec::with_capacity(fragments_covered_bytes_len); + fragments_covered_by_result.serialize_into(&mut fragments_covered_bytes)?; + fragments_covered_builder.append_value(fragments_covered_bytes); + fragments_covered_builder.append_null(); + let fragments_covered_arr = Arc::new(fragments_covered_builder.finish()) as Arc; + Ok(RecordBatch::try_new( + INDEX_EXPR_RESULT_SCHEMA.clone(), + vec![ + Arc::new(row_addr_mask_arr), + Arc::new(discriminant_arr), + Arc::new(fragments_covered_arr), + ], + )?) 
+ } +} + +impl ScalarIndexExpr { + /// Evaluates the scalar index expression + /// + /// This will result in loading one or more scalar indices and searching them + /// + /// TODO: We could potentially try and be smarter about reusing loaded indices for + /// any situations where the session cache has been disabled. + #[async_recursion] + async fn evaluate_impl( + &self, + index_loader: &dyn ScalarIndexLoader, + metrics: &dyn MetricsCollector, + ) -> Result { + match self { + Self::Not(inner) => { + let result = inner.evaluate_impl(index_loader, metrics).await?; + // Flip certainty: NOT(AtMost) → AtLeast, NOT(AtLeast) → AtMost + Ok(match result { + NullableIndexExprResult::Exact(mask) => NullableIndexExprResult::Exact(!mask), + NullableIndexExprResult::AtMost(mask) => { + NullableIndexExprResult::AtLeast(!mask) + } + NullableIndexExprResult::AtLeast(mask) => { + NullableIndexExprResult::AtMost(!mask) + } + }) + } + Self::And(lhs, rhs) => { + let lhs_result = lhs.evaluate_impl(index_loader, metrics); + let rhs_result = rhs.evaluate_impl(index_loader, metrics); + let (lhs_result, rhs_result) = try_join!(lhs_result, rhs_result)?; + Ok(lhs_result & rhs_result) + } + Self::Or(lhs, rhs) => { + let lhs_result = lhs.evaluate_impl(index_loader, metrics); + let rhs_result = rhs.evaluate_impl(index_loader, metrics); + let (lhs_result, rhs_result) = try_join!(lhs_result, rhs_result)?; + Ok(lhs_result | rhs_result) + } + Self::Query(search) => { + let index = index_loader + .load_index(&search.column, &search.index_name, metrics) + .await?; + let search_result = index.search(search.query.as_ref(), metrics).await?; + Ok(search_result.into()) + } + } + } + + #[instrument(level = "debug", skip_all)] + pub async fn evaluate( + &self, + index_loader: &dyn ScalarIndexLoader, + metrics: &dyn MetricsCollector, + ) -> Result { + Ok(self + .evaluate_impl(index_loader, metrics) + .await? 
+ .drop_nulls()) + } + + pub fn to_expr(&self) -> Expr { + match self { + Self::Not(inner) => Expr::Not(inner.to_expr().into()), + Self::And(lhs, rhs) => { + let lhs = lhs.to_expr(); + let rhs = rhs.to_expr(); + lhs.and(rhs) + } + Self::Or(lhs, rhs) => { + let lhs = lhs.to_expr(); + let rhs = rhs.to_expr(); + lhs.or(rhs) + } + Self::Query(search) => search.query.to_expr(search.column.clone()), + } + } + + pub fn needs_recheck(&self) -> bool { + match self { + Self::Not(inner) => inner.needs_recheck(), + Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.needs_recheck() || rhs.needs_recheck(), + Self::Query(search) => search.needs_recheck, + } + } +} + +// Extract a column from the expression, if it is a column, or None +fn maybe_column(expr: &Expr) -> Option<&str> { + match expr { + Expr::Column(col) => Some(&col.name), + _ => None, + } +} + +// Extract the full nested column path from a get_field expression chain +// For example: get_field(get_field(metadata, "status"), "code") -> "metadata.status.code" +fn extract_nested_column_path(expr: &Expr) -> Option { + let mut current_expr = expr; + let mut parts = Vec::new(); + + // Walk up the get_field chain + loop { + match current_expr { + Expr::ScalarFunction(udf) if udf.name() == "get_field" => { + if udf.args.len() != 2 { + return None; + } + // Extract the field name from the second argument + // The Literal now has two fields: ScalarValue and Option + if let Expr::Literal(ScalarValue::Utf8(Some(field_name)), _) = &udf.args[1] { + parts.push(field_name.clone()); + } else { + return None; + } + // Move up to the parent expression + current_expr = &udf.args[0]; + } + Expr::Column(col) => { + // We've reached the base column + parts.push(col.name.clone()); + break; + } + _ => { + return None; + } + } + } + + // Reverse to get the correct order (parent.child.grandchild) + parts.reverse(); + + // Format the path correctly + let field_refs: Vec<&str> = parts.iter().map(|s| s.as_str()).collect(); + 
Some(lance_core::datatypes::format_field_path(&field_refs)) +} + +// Extract a column from the expression, if it is a column, and we have an index for that column, or None +// +// There are three ways to get a column. First, the obvious way, is a +// simple column reference (e.g. x = 7). Second, a more complex way, +// is some kind of projection into a column (e.g. json_extract(json, '$.name')). +// Third is nested field access (e.g. get_field(metadata, "status.code")) +fn maybe_indexed_column<'b>( + expr: &Expr, + index_info: &'b dyn IndexInformationProvider, +) -> Option<(String, DataType, &'b dyn ScalarQueryParser)> { + // First try to extract the full nested column path for get_field expressions + if let Some(nested_path) = extract_nested_column_path(expr) + && let Some((data_type, parser)) = index_info.get_index(&nested_path) + && let Some(data_type) = parser.is_valid_reference(expr, data_type) + { + return Some((nested_path, data_type, parser)); + } + + match expr { + Expr::Column(col) => { + let col = col.name.as_str(); + let (data_type, parser) = index_info.get_index(col)?; + if let Some(data_type) = parser.is_valid_reference(expr, data_type) { + Some((col.to_string(), data_type, parser)) + } else { + None + } + } + Expr::ScalarFunction(udf) => { + if udf.args.is_empty() { + return None; + } + // For non-get_field functions, fall back to old behavior + let col = maybe_column(&udf.args[0])?; + let (data_type, parser) = index_info.get_index(col)?; + if let Some(data_type) = parser.is_valid_reference(expr, data_type) { + Some((col.to_string(), data_type, parser)) + } else { + None + } + } + _ => None, + } +} + +// Extract a literal scalar value from an expression, if it is a literal, or None +fn maybe_scalar(expr: &Expr, expected_type: &DataType) -> Option { + match expr { + Expr::Literal(value, _) => safe_coerce_scalar(value, expected_type), + // Some literals can't be expressed in datafusion's SQL and can only be expressed with + // a cast.
For example, there is no way to express a fixed-size-binary literal (which is + // commonly used for UUID). As a result the expression could look like... + // + // col = arrow_cast(value, 'fixed_size_binary(16)') + // + // In this case we need to extract the value, apply the cast, and then test the casted value + Expr::Cast(cast) => match cast.expr.as_ref() { + Expr::Literal(value, _) => { + let casted = value.cast_to(&cast.data_type).ok()?; + safe_coerce_scalar(&casted, expected_type) + } + _ => None, + }, + Expr::ScalarFunction(scalar_function) => { + if scalar_function.name() == "arrow_cast" { + if scalar_function.args.len() != 2 { + return None; + } + match (&scalar_function.args[0], &scalar_function.args[1]) { + (Expr::Literal(value, _), Expr::Literal(cast_type, _)) => { + let target_type = scalar_function + .func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[ + Arc::new(Field::new("expression", value.data_type(), false)), + Arc::new(Field::new("datatype", cast_type.data_type(), false)), + ], + scalar_arguments: &[Some(value), Some(cast_type)], + }) + .ok()?; + let casted = value.cast_to(target_type.data_type()).ok()?; + safe_coerce_scalar(&casted, expected_type) + } + _ => None, + } + } else { + None + } + } + _ => None, + } +} + +// Extract a list of scalar values from an expression, if it is a list of scalar values, or None +fn maybe_scalar_list(exprs: &Vec, expected_type: &DataType) -> Option> { + let mut scalar_values = Vec::with_capacity(exprs.len()); + for expr in exprs { + match maybe_scalar(expr, expected_type) { + Some(scalar_val) => { + scalar_values.push(scalar_val); + } + None => { + return None; + } + } + } + Some(scalar_values) +} + +fn visit_between( + between: &Between, + index_info: &dyn IndexInformationProvider, +) -> Option { + let (column, col_type, query_parser) = maybe_indexed_column(&between.expr, index_info)?; + let low = maybe_scalar(&between.low, &col_type)?; + let high = maybe_scalar(&between.high, &col_type)?; + + let 
indexed_expr = + query_parser.visit_between(&column, &Bound::Included(low), &Bound::Included(high))?; + + if between.negated { + indexed_expr.maybe_not() + } else { + Some(indexed_expr) + } +} + +fn visit_in_list( + in_list: &InList, + index_info: &dyn IndexInformationProvider, +) -> Option { + let (column, col_type, query_parser) = maybe_indexed_column(&in_list.expr, index_info)?; + let values = maybe_scalar_list(&in_list.list, &col_type)?; + + let indexed_expr = query_parser.visit_in_list(&column, &values)?; + + if in_list.negated { + indexed_expr.maybe_not() + } else { + Some(indexed_expr) + } +} + +fn visit_is_bool( + expr: &Expr, + index_info: &dyn IndexInformationProvider, + value: bool, +) -> Option { + let (column, col_type, query_parser) = maybe_indexed_column(expr, index_info)?; + if col_type != DataType::Boolean { + None + } else { + query_parser.visit_is_bool(&column, value) + } +} + +// A column can be a valid indexed expression if the column is boolean (e.g. 'WHERE on_sale') +fn visit_column( + col: &Expr, + index_info: &dyn IndexInformationProvider, +) -> Option { + let (column, col_type, query_parser) = maybe_indexed_column(col, index_info)?; + if col_type != DataType::Boolean { + None + } else { + query_parser.visit_is_bool(&column, true) + } +} + +fn visit_is_null( + expr: &Expr, + index_info: &dyn IndexInformationProvider, + negated: bool, +) -> Option { + let (column, _, query_parser) = maybe_indexed_column(expr, index_info)?; + let indexed_expr = query_parser.visit_is_null(&column)?; + if negated { + indexed_expr.maybe_not() + } else { + Some(indexed_expr) + } +} + +fn visit_not( + expr: &Expr, + index_info: &dyn IndexInformationProvider, + depth: usize, +) -> Result> { + let node = visit_node(expr, index_info, depth + 1)?; + Ok(node.and_then(|node| node.maybe_not())) +} + +fn visit_comparison( + expr: &BinaryExpr, + index_info: &dyn IndexInformationProvider, +) -> Option { + let left_col = maybe_indexed_column(&expr.left, index_info); + if let 
Some((column, col_type, query_parser)) = left_col { + let scalar = maybe_scalar(&expr.right, &col_type)?; + query_parser.visit_comparison(&column, &scalar, &expr.op) + } else { + // Datafusion's query simplifier will canonicalize expressions and so we shouldn't reach this case. If, for some reason, we + // do reach this case we can handle it in the future by inverting expr.op and swapping the left and right sides + None + } +} + +fn maybe_range( + expr: &BinaryExpr, + index_info: &dyn IndexInformationProvider, +) -> Option { + let left_expr = match expr.left.as_ref() { + Expr::BinaryExpr(binary_expr) => Some(binary_expr), + _ => None, + }?; + let right_expr = match expr.right.as_ref() { + Expr::BinaryExpr(binary_expr) => Some(binary_expr), + _ => None, + }?; + + let (left_col, dt, parser) = maybe_indexed_column(&left_expr.left, index_info)?; + let right_col = maybe_column(&right_expr.left)?; + + if left_col != right_col { + return None; + } + + let left_value = maybe_scalar(&left_expr.right, &dt)?; + let right_value = maybe_scalar(&right_expr.right, &dt)?; + + let (low, high) = match (left_expr.op, right_expr.op) { + // x >= a && x <= b + (Operator::GtEq, Operator::LtEq) => { + (Bound::Included(left_value), Bound::Included(right_value)) + } + // x >= a && x < b + (Operator::GtEq, Operator::Lt) => { + (Bound::Included(left_value), Bound::Excluded(right_value)) + } + // x > a && x <= b + (Operator::Gt, Operator::LtEq) => { + (Bound::Excluded(left_value), Bound::Included(right_value)) + } + // x > a && x < b + (Operator::Gt, Operator::Lt) => (Bound::Excluded(left_value), Bound::Excluded(right_value)), + // x <= a && x >= b + (Operator::LtEq, Operator::GtEq) => { + (Bound::Included(right_value), Bound::Included(left_value)) + } + // x <= a && x > b + (Operator::LtEq, Operator::Gt) => { + (Bound::Excluded(right_value), Bound::Included(left_value)) + } + // x < a && x >= b + (Operator::Lt, Operator::GtEq) => { + (Bound::Included(right_value), Bound::Excluded(left_value)) 
+ } + // x < a && x > b + (Operator::Lt, Operator::Gt) => (Bound::Excluded(right_value), Bound::Excluded(left_value)), + _ => return None, + }; + + parser.visit_between(&left_col, &low, &high) +} + +fn visit_and( + expr: &BinaryExpr, + index_info: &dyn IndexInformationProvider, + depth: usize, +) -> Result> { + // Many scalar indices can efficiently handle a BETWEEN query as a single search and this + // can be much more efficient than two separate range queries. As an optimization we check + // to see if this is a between query and, if so, we handle it as a single query + // + // Note: We can't rely on users writing the SQL BETWEEN operator because: + // * Some users won't realize it's an option or a good idea + // * Datafusion's simplifier will rewrite the BETWEEN operator into two separate range queries + if let Some(range_expr) = maybe_range(expr, index_info) { + return Ok(Some(range_expr)); + } + + let left = visit_node(&expr.left, index_info, depth + 1)?; + let right = visit_node(&expr.right, index_info, depth + 1)?; + Ok(match (left, right) { + (Some(left), Some(right)) => Some(left.and(right)), + (Some(left), None) => Some(left.refine((*expr.right).clone())), + (None, Some(right)) => Some(right.refine((*expr.left).clone())), + (None, None) => None, + }) +} + +fn visit_or( + expr: &BinaryExpr, + index_info: &dyn IndexInformationProvider, + depth: usize, +) -> Result> { + let left = visit_node(&expr.left, index_info, depth + 1)?; + let right = visit_node(&expr.right, index_info, depth + 1)?; + Ok(match (left, right) { + (Some(left), Some(right)) => left.maybe_or(right), + // If one side can use an index and the other side cannot then + // we must abandon the entire thing. For example, consider the + // query "color == 'blue' or size > 10" where color is indexed but + // size is not. It's entirely possible that size > 10 matches every + // row in our database. 
There is nothing we can do except a full scan + (Some(_), None) => None, + (None, Some(_)) => None, + (None, None) => None, + }) +} + +fn visit_binary_expr( + expr: &BinaryExpr, + index_info: &dyn IndexInformationProvider, + depth: usize, +) -> Result> { + match &expr.op { + Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq | Operator::Eq => { + Ok(visit_comparison(expr, index_info)) + } + // visit_comparison will maybe create an Eq query which we negate + Operator::NotEq => Ok(visit_comparison(expr, index_info).and_then(|node| node.maybe_not())), + Operator::And => visit_and(expr, index_info, depth), + Operator::Or => visit_or(expr, index_info, depth), + _ => Ok(None), + } +} + +fn visit_scalar_fn( + scalar_fn: &ScalarFunction, + index_info: &dyn IndexInformationProvider, +) -> Option { + if scalar_fn.args.is_empty() { + return None; + } + let (col, data_type, query_parser) = maybe_indexed_column(&scalar_fn.args[0], index_info)?; + query_parser.visit_scalar_function(&col, &data_type, &scalar_fn.func, &scalar_fn.args) +} + +fn visit_like_expr( + like: &Like, + index_info: &dyn IndexInformationProvider, +) -> Option { + let (column, _, query_parser) = maybe_indexed_column(&like.expr, index_info)?; + + // Extract the pattern as a ScalarValue + let pattern = match like.pattern.as_ref() { + Expr::Literal(scalar, _) => scalar.clone(), + _ => return None, + }; + + query_parser.visit_like(&column, like, &pattern) +} + +fn visit_node( + expr: &Expr, + index_info: &dyn IndexInformationProvider, + depth: usize, +) -> Result> { + if depth >= MAX_DEPTH { + return Err(Error::invalid_input(format!( + "the filter expression is too long, lance limit the max number of conditions to {}", + MAX_DEPTH + ))); + } + match expr { + Expr::Between(between) => Ok(visit_between(between, index_info)), + Expr::Alias(alias) => visit_node(alias.expr.as_ref(), index_info, depth), + Expr::Column(_) => Ok(visit_column(expr, index_info)), + Expr::InList(in_list) => 
Ok(visit_in_list(in_list, index_info)), + Expr::IsFalse(expr) => Ok(visit_is_bool(expr.as_ref(), index_info, false)), + Expr::IsTrue(expr) => Ok(visit_is_bool(expr.as_ref(), index_info, true)), + Expr::IsNull(expr) => Ok(visit_is_null(expr.as_ref(), index_info, false)), + Expr::IsNotNull(expr) => Ok(visit_is_null(expr.as_ref(), index_info, true)), + Expr::Not(expr) => visit_not(expr.as_ref(), index_info, depth), + Expr::BinaryExpr(binary_expr) => visit_binary_expr(binary_expr, index_info, depth), + Expr::ScalarFunction(scalar_fn) => Ok(visit_scalar_fn(scalar_fn, index_info)), + Expr::Like(like) => { + if like.negated { + // NOT LIKE cannot be efficiently pruned with zone maps + Ok(None) + } else { + Ok(visit_like_expr(like, index_info)) + } + } + _ => Ok(None), + } +} + +/// A trait to be used in `apply_scalar_indices` to inform the function which columns are indexed +pub trait IndexInformationProvider { + /// Check if an index exists for `col` and, if so, return the data type of col + /// as well as a query parser that can parse queries for that column + fn get_index(&self, col: &str) -> Option<(&DataType, &dyn ScalarQueryParser)>; +} + +/// Attempt to split a filter expression into a search of scalar indexes and an +/// optional post-search refinement query +pub fn apply_scalar_indices( + expr: Expr, + index_info: &dyn IndexInformationProvider, +) -> Result<IndexedExpression> { + Ok(visit_node(&expr, index_info, 0)?.unwrap_or(IndexedExpression::refine_only(expr))) +} + +#[derive(Clone, Default, Debug)] +pub struct FilterPlan { + pub index_query: Option<ScalarIndexExpr>, + /// True if the index query is guaranteed to return exact results + pub skip_recheck: bool, + pub refine_expr: Option<Expr>, + pub full_expr: Option<Expr>, +} + +impl FilterPlan { + pub fn empty() -> Self { + Self { + index_query: None, + skip_recheck: true, + refine_expr: None, + full_expr: None, + } + } + + pub fn new_refine_only(expr: Expr) -> Self { + Self { + index_query: None, + skip_recheck: true, + refine_expr: Some(expr.clone()), +
full_expr: Some(expr), + } + } + + pub fn is_empty(&self) -> bool { + self.refine_expr.is_none() && self.index_query.is_none() + } + + pub fn all_columns(&self) -> Vec<String> { + self.full_expr + .as_ref() + .map(Planner::column_names_in_expr) + .unwrap_or_default() + } + + pub fn refine_columns(&self) -> Vec<String> { + self.refine_expr + .as_ref() + .map(Planner::column_names_in_expr) + .unwrap_or_default() + } + + /// Return true if this has a refine step, regardless of the status of prefilter + pub fn has_refine(&self) -> bool { + self.refine_expr.is_some() + } + + /// Return true if this has a scalar index query + pub fn has_index_query(&self) -> bool { + self.index_query.is_some() + } + + pub fn has_any_filter(&self) -> bool { + self.refine_expr.is_some() || self.index_query.is_some() + } + + pub fn make_refine_only(&mut self) { + self.index_query = None; + self.refine_expr = self.full_expr.clone(); + } + + /// Return true if there is no refine or recheck of any kind and there is an index query + pub fn is_exact_index_search(&self) -> bool { + self.index_query.is_some() && self.refine_expr.is_none() && self.skip_recheck + } +} + +pub trait PlannerIndexExt { + /// Determine how to apply a provided filter + /// + /// We parse the filter into a logical expression.
We then + /// split the logical expression into a portion that can be + /// satisfied by an index search (of one or more indices) and + /// a refine portion that must be applied after the index search + fn create_filter_plan( + &self, + filter: Expr, + index_info: &dyn IndexInformationProvider, + use_scalar_index: bool, + ) -> Result; +} + +impl PlannerIndexExt for Planner { + fn create_filter_plan( + &self, + filter: Expr, + index_info: &dyn IndexInformationProvider, + use_scalar_index: bool, + ) -> Result { + let logical_expr = self.optimize_expr(filter)?; + if use_scalar_index { + let indexed_expr = apply_scalar_indices(logical_expr.clone(), index_info)?; + let mut skip_recheck = false; + if let Some(scalar_query) = indexed_expr.scalar_query.as_ref() { + skip_recheck = !scalar_query.needs_recheck(); + } + Ok(FilterPlan { + index_query: indexed_expr.scalar_query, + refine_expr: indexed_expr.refine_expr, + full_expr: Some(logical_expr), + skip_recheck, + }) + } else { + Ok(FilterPlan { + index_query: None, + skip_recheck: true, + refine_expr: Some(logical_expr.clone()), + full_expr: Some(logical_expr), + }) + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use arrow_schema::{Field, Schema}; + use chrono::Utc; + use datafusion_common::{Column, DFSchema}; + use datafusion_expr::simplify::SimplifyContext; + use lance_datafusion::exec::{LanceExecutionOptions, get_session_context}; + + use crate::scalar::json::{JsonQuery, JsonQueryParser}; + + use super::*; + + struct ColInfo { + data_type: DataType, + parser: Box, + } + + impl ColInfo { + fn new(data_type: DataType, parser: Box) -> Self { + Self { data_type, parser } + } + } + + struct MockIndexInfoProvider { + indexed_columns: HashMap, + } + + impl MockIndexInfoProvider { + fn new(indexed_columns: Vec<(&str, ColInfo)>) -> Self { + Self { + indexed_columns: HashMap::from_iter( + indexed_columns + .into_iter() + .map(|(s, ty)| (s.to_string(), ty)), + ), + } + } + } + + impl 
IndexInformationProvider for MockIndexInfoProvider { + fn get_index(&self, col: &str) -> Option<(&DataType, &dyn ScalarQueryParser)> { + self.indexed_columns + .get(col) + .map(|col_info| (&col_info.data_type, col_info.parser.as_ref())) + } + } + + fn check( + index_info: &dyn IndexInformationProvider, + expr: &str, + expected: Option<IndexedExpression>, + optimize: bool, + ) { + let schema = Schema::new(vec![ + Field::new("color", DataType::Utf8, false), + Field::new("size", DataType::Float32, false), + Field::new("aisle", DataType::UInt32, false), + Field::new("on_sale", DataType::Boolean, false), + Field::new("price", DataType::Float32, false), + Field::new("json", DataType::LargeBinary, false), + ]); + let df_schema: DFSchema = schema.try_into().unwrap(); + + let ctx = get_session_context(&LanceExecutionOptions::default()); + let state = ctx.state(); + let mut expr = state.create_logical_expr(expr, &df_schema).unwrap(); + if optimize { + let simplify_context = SimplifyContext::default() + .with_schema(Arc::new(df_schema)) + .with_query_execution_start_time(Some(Utc::now())); + let simplifier = + datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context); + expr = simplifier.simplify(expr).unwrap(); + } + + let actual = apply_scalar_indices(expr.clone(), index_info).unwrap(); + if let Some(expected) = expected { + assert_eq!(actual, expected); + } else { + assert!(actual.scalar_query.is_none()); + assert_eq!(actual.refine_expr.unwrap(), expr); + } + } + + fn check_no_index(index_info: &dyn IndexInformationProvider, expr: &str) { + check(index_info, expr, None, false) + } + + fn check_simple( + index_info: &dyn IndexInformationProvider, + expr: &str, + col: &str, + query: impl AnyQuery, + ) { + check( + index_info, + expr, + Some(IndexedExpression::index_query( + col.to_string(), + format!("{}_idx", col), + "BTree".to_string(), + Arc::new(query), + )), + false, + ) + } + + fn check_range( + index_info: &dyn IndexInformationProvider, + expr: &str, + col:
&str, + query: SargableQuery, + ) { + check( + index_info, + expr, + Some(IndexedExpression::index_query( + col.to_string(), + format!("{}_idx", col), + "BTree".to_string(), + Arc::new(query), + )), + true, + ) + } + + fn check_simple_negated( + index_info: &dyn IndexInformationProvider, + expr: &str, + col: &str, + query: SargableQuery, + ) { + check( + index_info, + expr, + Some( + IndexedExpression::index_query( + col.to_string(), + format!("{}_idx", col), + "BTree".to_string(), + Arc::new(query), + ) + .maybe_not() + .unwrap(), + ), + false, + ) + } + + #[test] + fn test_expressions() { + let index_info = MockIndexInfoProvider::new(vec![ + ( + "color", + ColInfo::new( + DataType::Utf8, + Box::new(SargableQueryParser::new( + "color_idx".to_string(), + "BTree".to_string(), + false, + )), + ), + ), + ( + "aisle", + ColInfo::new( + DataType::UInt32, + Box::new(SargableQueryParser::new( + "aisle_idx".to_string(), + "BTree".to_string(), + false, + )), + ), + ), + ( + "on_sale", + ColInfo::new( + DataType::Boolean, + Box::new(SargableQueryParser::new( + "on_sale_idx".to_string(), + "BTree".to_string(), + false, + )), + ), + ), + ( + "price", + ColInfo::new( + DataType::Float32, + Box::new(SargableQueryParser::new( + "price_idx".to_string(), + "BTree".to_string(), + false, + )), + ), + ), + ( + "json", + ColInfo::new( + DataType::LargeBinary, + Box::new(JsonQueryParser::new( + "$.name".to_string(), + Box::new(SargableQueryParser::new( + "json_idx".to_string(), + "BTree".to_string(), + false, + )), + )), + ), + ), + ]); + + check_simple( + &index_info, + "json_extract(json, '$.name') = 'foo'", + "json", + JsonQuery::new( + Arc::new(SargableQuery::Equals(ScalarValue::Utf8(Some( + "foo".to_string(), + )))), + "$.name".to_string(), + ), + ); + + check_no_index(&index_info, "size BETWEEN 5 AND 10"); + // Cast case. 
We will cast 5 (an int64) to Int16 and then coerce to UInt32 + check_simple( + &index_info, + "aisle = arrow_cast(5, 'Int16')", + "aisle", + SargableQuery::Equals(ScalarValue::UInt32(Some(5))), + ); + // 5 different ways of writing BETWEEN (all should be recognized) + check_range( + &index_info, + "aisle BETWEEN 5 AND 10", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + check_range( + &index_info, + "aisle >= 5 AND aisle <= 10", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + + check_range( + &index_info, + "aisle <= 10 AND aisle >= 5", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + + check_range( + &index_info, + "5 <= aisle AND 10 >= aisle", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + + check_range( + &index_info, + "10 >= aisle AND 5 <= aisle", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + check_range( + &index_info, + "aisle <= 10 AND aisle > 5", + "aisle", + SargableQuery::Range( + Bound::Excluded(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + check_range( + &index_info, + "aisle < 10 AND aisle >= 5", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Excluded(ScalarValue::UInt32(Some(10))), + ), + ); + check_simple( + &index_info, + "on_sale IS TRUE", + "on_sale", + SargableQuery::Equals(ScalarValue::Boolean(Some(true))), + ); + check_simple( + &index_info, + "on_sale", + "on_sale", + SargableQuery::Equals(ScalarValue::Boolean(Some(true))), + ); + check_simple_negated( + &index_info, + 
"NOT on_sale", + "on_sale", + SargableQuery::Equals(ScalarValue::Boolean(Some(true))), + ); + check_simple( + &index_info, + "on_sale IS FALSE", + "on_sale", + SargableQuery::Equals(ScalarValue::Boolean(Some(false))), + ); + check_simple_negated( + &index_info, + "aisle NOT BETWEEN 5 AND 10", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(5))), + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + // Small in-list (in-list with 3 or fewer items optimizes into or-chain) + check_simple( + &index_info, + "aisle IN (5, 6, 7)", + "aisle", + SargableQuery::IsIn(vec![ + ScalarValue::UInt32(Some(5)), + ScalarValue::UInt32(Some(6)), + ScalarValue::UInt32(Some(7)), + ]), + ); + check_simple_negated( + &index_info, + "NOT aisle IN (5, 6, 7)", + "aisle", + SargableQuery::IsIn(vec![ + ScalarValue::UInt32(Some(5)), + ScalarValue::UInt32(Some(6)), + ScalarValue::UInt32(Some(7)), + ]), + ); + check_simple_negated( + &index_info, + "aisle NOT IN (5, 6, 7)", + "aisle", + SargableQuery::IsIn(vec![ + ScalarValue::UInt32(Some(5)), + ScalarValue::UInt32(Some(6)), + ScalarValue::UInt32(Some(7)), + ]), + ); + check_simple( + &index_info, + "aisle IN (5, 6, 7, 8, 9)", + "aisle", + SargableQuery::IsIn(vec![ + ScalarValue::UInt32(Some(5)), + ScalarValue::UInt32(Some(6)), + ScalarValue::UInt32(Some(7)), + ScalarValue::UInt32(Some(8)), + ScalarValue::UInt32(Some(9)), + ]), + ); + check_simple_negated( + &index_info, + "NOT aisle IN (5, 6, 7, 8, 9)", + "aisle", + SargableQuery::IsIn(vec![ + ScalarValue::UInt32(Some(5)), + ScalarValue::UInt32(Some(6)), + ScalarValue::UInt32(Some(7)), + ScalarValue::UInt32(Some(8)), + ScalarValue::UInt32(Some(9)), + ]), + ); + check_simple_negated( + &index_info, + "aisle NOT IN (5, 6, 7, 8, 9)", + "aisle", + SargableQuery::IsIn(vec![ + ScalarValue::UInt32(Some(5)), + ScalarValue::UInt32(Some(6)), + ScalarValue::UInt32(Some(7)), + ScalarValue::UInt32(Some(8)), + ScalarValue::UInt32(Some(9)), + ]), + ); + check_simple( + 
&index_info, + "on_sale is false", + "on_sale", + SargableQuery::Equals(ScalarValue::Boolean(Some(false))), + ); + check_simple( + &index_info, + "on_sale is true", + "on_sale", + SargableQuery::Equals(ScalarValue::Boolean(Some(true))), + ); + check_simple( + &index_info, + "aisle < 10", + "aisle", + SargableQuery::Range( + Bound::Unbounded, + Bound::Excluded(ScalarValue::UInt32(Some(10))), + ), + ); + check_simple( + &index_info, + "aisle <= 10", + "aisle", + SargableQuery::Range( + Bound::Unbounded, + Bound::Included(ScalarValue::UInt32(Some(10))), + ), + ); + check_simple( + &index_info, + "aisle > 10", + "aisle", + SargableQuery::Range( + Bound::Excluded(ScalarValue::UInt32(Some(10))), + Bound::Unbounded, + ), + ); + // In the future we can handle this case if we need to. For + // now let's make sure we don't accidentally do the wrong thing + // (we were getting this backwards in the past) + check_no_index(&index_info, "10 > aisle"); + check_simple( + &index_info, + "aisle >= 10", + "aisle", + SargableQuery::Range( + Bound::Included(ScalarValue::UInt32(Some(10))), + Bound::Unbounded, + ), + ); + check_simple( + &index_info, + "aisle = 10", + "aisle", + SargableQuery::Equals(ScalarValue::UInt32(Some(10))), + ); + check_simple_negated( + &index_info, + "aisle <> 10", + "aisle", + SargableQuery::Equals(ScalarValue::UInt32(Some(10))), + ); + // // Common compound case, AND'd clauses + let left = Box::new(ScalarIndexExpr::Query(ScalarIndexSearch { + column: "aisle".to_string(), + index_name: "aisle_idx".to_string(), + index_type: "BTree".to_string(), + query: Arc::new(SargableQuery::Equals(ScalarValue::UInt32(Some(10)))), + needs_recheck: false, + })); + let right = Box::new(ScalarIndexExpr::Query(ScalarIndexSearch { + column: "color".to_string(), + index_name: "color_idx".to_string(), + index_type: "BTree".to_string(), + query: Arc::new(SargableQuery::Equals(ScalarValue::Utf8(Some( + "blue".to_string(), + )))), + needs_recheck: false, + })); + check( + &index_info, 
+ "aisle = 10 AND color = 'blue'", + Some(IndexedExpression { + scalar_query: Some(ScalarIndexExpr::And(left.clone(), right.clone())), + refine_expr: None, + }), + false, + ); + // Compound AND's and not all of them are indexed columns + let refine = Expr::Column(Column::new_unqualified("size")).gt(datafusion_expr::lit(30_i64)); + check( + &index_info, + "aisle = 10 AND color = 'blue' AND size > 30", + Some(IndexedExpression { + scalar_query: Some(ScalarIndexExpr::And(left.clone(), right.clone())), + refine_expr: Some(refine.clone()), + }), + false, + ); + // Compounded OR's where ALL columns are indexed + check( + &index_info, + "aisle = 10 OR color = 'blue'", + Some(IndexedExpression { + scalar_query: Some(ScalarIndexExpr::Or(left.clone(), right.clone())), + refine_expr: None, + }), + false, + ); + // Compounded OR's with one or more unindexed columns + check_no_index(&index_info, "aisle = 10 OR color = 'blue' OR size > 30"); + // AND'd group of OR + check( + &index_info, + "(aisle = 10 OR color = 'blue') AND size > 30", + Some(IndexedExpression { + scalar_query: Some(ScalarIndexExpr::Or(left, right)), + refine_expr: Some(refine), + }), + false, + ); + // Examples of things that are not yet supported but should be supportable someday + + // OR'd group of refined index searches (see IndexedExpression::or for details) + check_no_index( + &index_info, + "(aisle = 10 AND size > 30) OR (color = 'blue' AND size > 20)", + ); + + // Non-normalized arithmetic (can use expression simplification) + check_no_index(&index_info, "aisle + 3 < 10"); + + // Currently we assume that the return of an index search tells us which rows are + // TRUE and all other rows are FALSE. This will need to change but for now it is + // safer to not support the following cases because the return value of non-matched + // rows is NULL and not FALSE. 
+ check_no_index(&index_info, "aisle IN (5, 6, NULL)"); + // OR-list with NULL (in future DF version this will be optimized repr of + // small in-list with NULL so let's get ready for it) + check_no_index(&index_info, "aisle = 5 OR aisle = 6 OR NULL"); + check_no_index(&index_info, "aisle IN (5, 6, 7, 8, NULL)"); + check_no_index(&index_info, "aisle = NULL"); + check_no_index(&index_info, "aisle BETWEEN 5 AND NULL"); + check_no_index(&index_info, "aisle BETWEEN NULL AND 10"); + } + + #[tokio::test] + async fn test_not_flips_certainty() { + use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap}; + + // Test that NOT flips certainty for inexact index results + // This tests the implementation in evaluate_impl for Self::Not + + // Helper function that mimics the NOT logic we just fixed + fn apply_not(result: NullableIndexExprResult) -> NullableIndexExprResult { + match result { + NullableIndexExprResult::Exact(mask) => NullableIndexExprResult::Exact(!mask), + NullableIndexExprResult::AtMost(mask) => NullableIndexExprResult::AtLeast(!mask), + NullableIndexExprResult::AtLeast(mask) => NullableIndexExprResult::AtMost(!mask), + } + } + + // AtMost: superset of matches (e.g., bloom filter says "might be in [1,2]") + let at_most = NullableIndexExprResult::AtMost(NullableRowAddrMask::AllowList( + NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()), + )); + // NOT(AtMost) should be AtLeast (definitely NOT in [1,2], might be elsewhere) + assert!(matches!( + apply_not(at_most), + NullableIndexExprResult::AtLeast(_) + )); + + // AtLeast: subset of matches (e.g., definitely in [1,2], might be more) + let at_least = NullableIndexExprResult::AtLeast(NullableRowAddrMask::AllowList( + NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()), + )); + // NOT(AtLeast) should be AtMost (might NOT be in [1,2], definitely elsewhere) + assert!(matches!( + apply_not(at_least), + NullableIndexExprResult::AtMost(_) + )); + + 
// Exact should stay Exact + let exact = NullableIndexExprResult::Exact(NullableRowAddrMask::AllowList( + NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()), + )); + assert!(matches!( + apply_not(exact), + NullableIndexExprResult::Exact(_) + )); + } + + #[tokio::test] + async fn test_and_or_preserve_certainty() { + use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap}; + + // Test that AND/OR correctly propagate certainty + let make_at_most = || { + NullableIndexExprResult::AtMost(NullableRowAddrMask::AllowList( + NullableRowAddrSet::new( + RowAddrTreeMap::from_iter(&[1, 2, 3]), + RowAddrTreeMap::new(), + ), + )) + }; + + let make_at_least = || { + NullableIndexExprResult::AtLeast(NullableRowAddrMask::AllowList( + NullableRowAddrSet::new( + RowAddrTreeMap::from_iter(&[2, 3, 4]), + RowAddrTreeMap::new(), + ), + )) + }; + + let make_exact = || { + NullableIndexExprResult::Exact(NullableRowAddrMask::AllowList(NullableRowAddrSet::new( + RowAddrTreeMap::from_iter(&[1, 2]), + RowAddrTreeMap::new(), + ))) + }; + + // AtMost & AtMost → AtMost + assert!(matches!( + make_at_most() & make_at_most(), + NullableIndexExprResult::AtMost(_) + )); + + // AtLeast & AtLeast → AtLeast + assert!(matches!( + make_at_least() & make_at_least(), + NullableIndexExprResult::AtLeast(_) + )); + + // AtMost & AtLeast → AtMost (superset remains superset) + assert!(matches!( + make_at_most() & make_at_least(), + NullableIndexExprResult::AtMost(_) + )); + + // AtMost | AtMost → AtMost + assert!(matches!( + make_at_most() | make_at_most(), + NullableIndexExprResult::AtMost(_) + )); + + // AtLeast | AtLeast → AtLeast + assert!(matches!( + make_at_least() | make_at_least(), + NullableIndexExprResult::AtLeast(_) + )); + + // AtMost | AtLeast → AtLeast (subset coverage guaranteed) + assert!(matches!( + make_at_most() | make_at_least(), + NullableIndexExprResult::AtLeast(_) + )); + + // Exact & AtMost → AtMost + assert!(matches!( + make_exact() & 
make_at_most(), + NullableIndexExprResult::AtMost(_) + )); + + // Exact | AtLeast → AtLeast + assert!(matches!( + make_exact() | make_at_least(), + NullableIndexExprResult::AtLeast(_) + )); + } + + #[test] + fn test_extract_like_leading_prefix() { + // Simple prefix patterns (no recheck needed) + assert_eq!( + extract_like_leading_prefix("foo%", None), + Some(("foo".to_string(), false)) + ); + assert_eq!( + extract_like_leading_prefix("abc%", None), + Some(("abc".to_string(), false)) + ); + + // Patterns with wildcards in the middle (need recheck) + assert_eq!( + extract_like_leading_prefix("foo%bar%", None), + Some(("foo".to_string(), true)) + ); + assert_eq!( + extract_like_leading_prefix("foo_bar%", None), + Some(("foo".to_string(), true)) + ); + assert_eq!( + extract_like_leading_prefix("foo%bar", None), + Some(("foo".to_string(), true)) + ); + assert_eq!( + extract_like_leading_prefix("foo_", None), + Some(("foo".to_string(), true)) + ); + + // Not prefix patterns (starts with wildcard) + assert_eq!(extract_like_leading_prefix("%foo", None), None); + assert_eq!(extract_like_leading_prefix("_foo%", None), None); + assert_eq!(extract_like_leading_prefix("%", None), None); + + // No wildcard at all (should use equality) + assert_eq!(extract_like_leading_prefix("foo", None), None); + + // With escape character + assert_eq!( + extract_like_leading_prefix(r"foo\%bar%", Some('\\')), + Some(("foo%bar".to_string(), false)) + ); + assert_eq!( + extract_like_leading_prefix(r"foo\_bar%", Some('\\')), + Some(("foo_bar".to_string(), false)) + ); + assert_eq!( + extract_like_leading_prefix(r"foo\\bar%", Some('\\')), + Some(("foo\\bar".to_string(), false)) + ); + + // Escaped trailing % is not a wildcard (no wildcards) + assert_eq!(extract_like_leading_prefix(r"foo\%", Some('\\')), None); + + // With backslash as default escape (for DataFusion starts_with compatibility): + // "foo\%" means escaped %, no wildcard -> None (should use equality) + 
assert_eq!(extract_like_leading_prefix(r"foo\%", None), None); + // "foo\bar%" - \b is not a valid escape sequence, so \ and b are literals, % is wildcard + assert_eq!( + extract_like_leading_prefix(r"foo\bar%", None), + Some(("foo\\bar".to_string(), false)) + ); + + // Empty pattern + assert_eq!(extract_like_leading_prefix("", None), None); + + // Mixed escaped and unescaped + assert_eq!( + extract_like_leading_prefix(r"foo\%bar%baz%", Some('\\')), + Some(("foo%bar".to_string(), true)) + ); + } + + #[test] + fn test_like_expression_parsing() { + // Test that LIKE expressions are parsed correctly with refine_expr for complex patterns + + let index_info = MockIndexInfoProvider::new(vec![( + "color", + ColInfo::new( + DataType::Utf8, + Box::new(SargableQueryParser::new( + "color_idx".to_string(), + "BTree".to_string(), + false, + )), + ), + )]); + + // Simple prefix pattern: LIKE 'foo%' -> LikePrefix("foo"), no refine_expr + let schema = Schema::new(vec![Field::new("color", DataType::Utf8, false)]); + let df_schema: DFSchema = schema.try_into().unwrap(); + let ctx = get_session_context(&LanceExecutionOptions::default()); + let state = ctx.state(); + + let expr = state + .create_logical_expr("color LIKE 'foo%'", &df_schema) + .unwrap(); + let result = apply_scalar_indices(expr, &index_info).unwrap(); + + assert!(result.scalar_query.is_some(), "Should have scalar_query"); + assert!( + result.refine_expr.is_none(), + "Simple prefix should not need refine_expr" + ); + + // Extract the query and verify it's LikePrefix + if let Some(ScalarIndexExpr::Query(search)) = &result.scalar_query { + let query = search.query.as_any().downcast_ref::<SargableQuery>(); + assert!(query.is_some(), "Query should be SargableQuery"); + match query.unwrap() { + SargableQuery::LikePrefix(prefix) => { + assert_eq!(prefix, &ScalarValue::Utf8(Some("foo".to_string()))); + } + _ => panic!("Expected LikePrefix query"), + } + } else { + panic!("Expected Query variant"); + } + + // Complex pattern: LIKE 'foo%bar%'
-> LikePrefix("foo"), with refine_expr + let expr = state + .create_logical_expr("color LIKE 'foo%bar%'", &df_schema) + .unwrap(); + let result = apply_scalar_indices(expr, &index_info).unwrap(); + + assert!(result.scalar_query.is_some(), "Should have scalar_query"); + assert!( + result.refine_expr.is_some(), + "Complex pattern should have refine_expr" + ); + + // Verify the query is still LikePrefix("foo") + if let Some(ScalarIndexExpr::Query(search)) = &result.scalar_query { + let query = search.query.as_any().downcast_ref::<SargableQuery>(); + assert!(query.is_some(), "Query should be SargableQuery"); + match query.unwrap() { + SargableQuery::LikePrefix(prefix) => { + assert_eq!(prefix, &ScalarValue::Utf8(Some("foo".to_string()))); + } + _ => panic!("Expected LikePrefix query"), + } + } + + // Verify the refine_expr is the original LIKE expression + let refine = result.refine_expr.unwrap(); + match refine { + Expr::Like(like) => { + assert!(!like.negated); + assert!(!like.case_insensitive); + if let Expr::Literal(ScalarValue::Utf8(Some(pattern)), _) = like.pattern.as_ref() { + assert_eq!(pattern, "foo%bar%"); + } else { + panic!("Expected Utf8 literal pattern"); + } + } + _ => panic!("Expected Like expression in refine_expr"), + } + + // Pattern starting with wildcard: LIKE '%foo' -> no index, only refine + let expr = state + .create_logical_expr("color LIKE '%foo'", &df_schema) + .unwrap(); + let result = apply_scalar_indices(expr, &index_info).unwrap(); + + assert!( + result.scalar_query.is_none(), + "Pattern starting with wildcard should not use index" + ); + assert!(result.refine_expr.is_some(), "Should fall back to refine"); + } + + #[test] + fn test_starts_with_with_underscore_after_optimization() { + // Test that starts_with with underscore in prefix works correctly after DataFusion optimization + // DataFusion simplifies starts_with(col, 'test_ns$') to col LIKE 'test_ns$%' + // The underscore in the prefix should NOT be treated as a wildcard!
+ let index_info = MockIndexInfoProvider::new(vec![( + "object_id", + ColInfo::new( + DataType::Utf8, + Box::new(SargableQueryParser::new( + "object_id_idx".to_string(), + "BTree".to_string(), + false, + )), + ), + )]); + + let schema = Schema::new(vec![Field::new("object_id", DataType::Utf8, false)]); + let df_schema: DFSchema = schema.try_into().unwrap(); + let ctx = get_session_context(&LanceExecutionOptions::default()); + let state = ctx.state(); + + // Create the expression with starts_with containing underscore + let expr = state + .create_logical_expr("starts_with(object_id, 'test_ns$')", &df_schema) + .unwrap(); + + // Apply DataFusion simplification (this may convert starts_with to LIKE) + let simplify_context = SimplifyContext::default() + .with_schema(Arc::new(df_schema)) + .with_query_execution_start_time(Some(Utc::now())); + let simplifier = + datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context); + let simplified_expr = simplifier.simplify(expr).unwrap(); + + // Apply scalar indices + let result = apply_scalar_indices(simplified_expr, &index_info).unwrap(); + + // The prefix should be "test_ns$", NOT "test" + // This test documents the current (potentially broken) behavior + if let Some(ScalarIndexExpr::Query(search)) = &result.scalar_query { + let query = search + .query + .as_any() + .downcast_ref::<SargableQuery>() + .unwrap(); + match query { + SargableQuery::LikePrefix(prefix) => { + let prefix_str = match prefix { + ScalarValue::Utf8(Some(s)) => s.clone(), + _ => panic!("Expected Utf8 prefix"), + }; + // Verify the prefix is correctly extracted with underscore as literal + assert_eq!( + prefix_str, "test_ns$", + "Prefix should be 'test_ns$', not 'test' (underscore should not be a wildcard)" + ); + } + _ => panic!("Expected LikePrefix query"), + } + } else { + // If no scalar query, it means the pattern was not recognized + panic!("Expected scalar_query to be present"); + } + } + + #[test] + fn test_starts_with_to_like_conversion()
{ + // Test that starts_with(col, 'prefix') is converted to LikePrefix query + let index_info = MockIndexInfoProvider::new(vec![( + "color", + ColInfo::new( + DataType::Utf8, + Box::new(SargableQueryParser::new( + "color_idx".to_string(), + "BTree".to_string(), + false, + )), + ), + )]); + + let schema = Schema::new(vec![Field::new("color", DataType::Utf8, false)]); + let df_schema: DFSchema = schema.try_into().unwrap(); + let ctx = get_session_context(&LanceExecutionOptions::default()); + let state = ctx.state(); + + // starts_with(color, 'foo') should be converted to LikePrefix("foo") + let expr = state + .create_logical_expr("starts_with(color, 'foo')", &df_schema) + .unwrap(); + let result = apply_scalar_indices(expr, &index_info).unwrap(); + + assert!( + result.scalar_query.is_some(), + "starts_with should use index" + ); + assert!( + result.refine_expr.is_none(), + "Pure prefix starts_with should not need refine_expr" + ); + + // Extract the query and verify it's LikePrefix + if let Some(ScalarIndexExpr::Query(search)) = &result.scalar_query { + let query = search.query.as_any().downcast_ref::<SargableQuery>(); + assert!(query.is_some(), "Query should be SargableQuery"); + match query.unwrap() { + SargableQuery::LikePrefix(prefix) => { + assert_eq!(prefix, &ScalarValue::Utf8(Some("foo".to_string()))); + } + _ => panic!("Expected LikePrefix query"), + } + } else { + panic!("Expected Query variant"); + } + + // Both starts_with and LIKE 'prefix%' should produce the same LikePrefix query + let like_expr = state + .create_logical_expr("color LIKE 'foo%'", &df_schema) + .unwrap(); + let like_result = apply_scalar_indices(like_expr, &index_info).unwrap(); + + // Compare the queries - both should be LikePrefix("foo") + if let ( + Some(ScalarIndexExpr::Query(starts_with_search)), + Some(ScalarIndexExpr::Query(like_search)), + ) = (&result.scalar_query, &like_result.scalar_query) + { + let sw_query = starts_with_search + .query + .as_any() + .downcast_ref::<SargableQuery>() + .unwrap(); + let
like_query = like_search + .query + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + sw_query, like_query, + "starts_with and LIKE 'prefix%' should produce identical queries" + ); + } + } +} diff --git a/rust/lance-index/src/lib.rs b/rust/lance-index/src/lib.rs index 0ed6ddd4e2d..c2a92a2d300 100644 --- a/rust/lance-index/src/lib.rs +++ b/rust/lance-index/src/lib.rs @@ -20,6 +20,7 @@ use roaring::RoaringBitmap; use serde::{Deserialize, Serialize}; use std::convert::TryFrom; +pub mod expression; pub mod frag_reuse; pub mod mem_wal; pub mod metrics; diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index c89d75c11db..c3803db7aba 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -12,6 +12,7 @@ use datafusion::functions::string::contains::ContainsFunc; use datafusion::functions_nested::array_has; use datafusion::physical_plan::SendableRecordBatchStream; use datafusion_common::{Column, scalar::ScalarValue}; +pub use lance_arrow_scalar::ArrowScalar; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use std::pin::Pin; @@ -27,6 +28,7 @@ use lance_io::stream::{RecordBatchStream, RecordBatchStreamAdapter}; use roaring::RoaringBitmap; use serde::Serialize; +use crate::expression::aggregate::AnyAggregateQuery; use crate::metrics::MetricsCollector; use crate::scalar::registry::TrainingCriteria; use crate::{Index, IndexParams, IndexType}; @@ -984,6 +986,19 @@ pub trait ScalarIndex: Send + Sync + std::fmt::Debug + Index + DeepSizeOf { metrics: &dyn MetricsCollector, ) -> Result; + /// Calculates an aggregate value using the index and an optional `filter` + /// + /// The returned value should be a partial aggregate. For example, if calculating + /// the average, the returned value should be the sum of all values and the count of values, + /// returned as a struct scalar. 
+ async fn calculate_aggregate( + &self, + query: &dyn AnyAggregateQuery, + filter: Option, + total_rows: u64, + metrics: &dyn MetricsCollector, + ) -> Result; + /// Returns true if the remap operation is supported fn can_remap(&self) -> bool; diff --git a/rust/lance-index/src/scalar/bitmap.rs b/rust/lance-index/src/scalar/bitmap.rs index 45027cc7b63..a5804b05e29 100644 --- a/rust/lance-index/src/scalar/bitmap.rs +++ b/rust/lance-index/src/scalar/bitmap.rs @@ -19,6 +19,7 @@ use datafusion::physical_plan::SendableRecordBatchStream; use datafusion_common::ScalarValue; use deepsize::DeepSizeOf; use futures::{StreamExt, TryStreamExt, stream}; +use lance_arrow_scalar::ArrowScalar; use lance_core::utils::mask::RowSetOps; use lance_core::{ Error, ROW_ID, Result, @@ -40,7 +41,9 @@ use super::{ BuiltinIndexType, SargableQuery, ScalarIndexParams, SearchResult, btree::OrderableScalarValue, }; use crate::pbold; -use crate::{Index, IndexType, metrics::MetricsCollector}; +use crate::{ + Index, IndexType, expression::aggregate::AnyAggregateQuery, metrics::MetricsCollector, +}; use crate::{ frag_reuse::FragReuseIndex, progress::IndexBuildProgress, @@ -607,6 +610,18 @@ impl ScalarIndex for BitmapIndex { Ok(SearchResult::Exact(selection)) } + async fn calculate_aggregate( + &self, + query: &dyn AnyAggregateQuery, + _filter: Option, + _total_rows: u64, + _metrics: &dyn MetricsCollector, + ) -> Result { + Err(Error::invalid_input(format!( + "this index cannot accelerate the aggregate {query:?}" + ))) + } + fn can_remap(&self) -> bool { true } diff --git a/rust/lance-index/src/scalar/bloomfilter.rs b/rust/lance-index/src/scalar/bloomfilter.rs index b861340b9ee..d491ddb80be 100644 --- a/rust/lance-index/src/scalar/bloomfilter.rs +++ b/rust/lance-index/src/scalar/bloomfilter.rs @@ -7,6 +7,7 @@ //! It is a space-efficient data structure that can be used to test whether an element is a member of a set. //! It's an inexact filter - they may include false positives that require rechecking. 
+use crate::expression::aggregate::AnyAggregateQuery; use crate::scalar::bloomfilter::sbbf::{Sbbf, SbbfBuilder}; use crate::scalar::expression::{BloomFilterQueryParser, ScalarQueryParser}; use crate::scalar::registry::{ @@ -20,6 +21,7 @@ use arrow_array::{Array, UInt64Array}; mod as_bytes; pub mod sbbf; use arrow_schema::{DataType, Field}; +use lance_arrow_scalar::ArrowScalar; use serde::{Deserialize, Serialize}; use std::sync::LazyLock; @@ -439,6 +441,18 @@ impl ScalarIndex for BloomFilterIndex { )) } + async fn calculate_aggregate( + &self, + query: &dyn AnyAggregateQuery, + _filter: Option, + _total_rows: u64, + _metrics: &dyn MetricsCollector, + ) -> Result { + Err(Error::invalid_input(format!( + "this index cannot accelerate the aggregate {query:?}" + ))) + } + async fn update( &self, new_data: SendableRecordBatchStream, diff --git a/rust/lance-index/src/scalar/btree.rs b/rust/lance-index/src/scalar/btree.rs index f2be3241b85..f83a62b3211 100644 --- a/rust/lance-index/src/scalar/btree.rs +++ b/rust/lance-index/src/scalar/btree.rs @@ -15,7 +15,7 @@ use super::{ OldIndexDataFilter, SargableQuery, ScalarIndex, ScalarIndexParams, SearchResult, compute_next_prefix, }; -use crate::{Index, IndexType}; +use crate::{Index, IndexType, expression::aggregate::AnyAggregateQuery}; use crate::{ frag_reuse::FragReuseIndex, progress::{IndexBuildProgress, noop_progress}, @@ -44,6 +44,7 @@ use futures::{ future::BoxFuture, stream::{self}, }; +use lance_arrow_scalar::ArrowScalar; use lance_core::{ Error, ROW_ID, Result, cache::{CacheKey, LanceCache, WeakLanceCache}, @@ -1618,6 +1619,18 @@ impl ScalarIndex for BTreeIndex { Ok(SearchResult::Exact(selection)) } + async fn calculate_aggregate( + &self, + query: &dyn AnyAggregateQuery, + _filter: Option, + _total_rows: u64, + _metrics: &dyn MetricsCollector, + ) -> Result { + Err(Error::invalid_input(format!( + "this index cannot accelerate the aggregate {query:?}" + ))) + } + fn can_remap(&self) -> bool { true } diff --git 
a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index d8f61d7dabc..e73e898ff76 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -1,3105 +1,2 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright The Lance Authors - -use std::{ - ops::Bound, - sync::{Arc, LazyLock}, -}; - -use arrow::array::BinaryBuilder; -use arrow_array::{Array, RecordBatch, UInt32Array}; -use arrow_schema::{DataType, Field, Schema, SchemaRef}; -use async_recursion::async_recursion; -use async_trait::async_trait; -use datafusion_common::ScalarValue; -use datafusion_expr::{ - Between, BinaryExpr, Expr, Operator, ReturnFieldArgs, ScalarUDF, - expr::{InList, Like, ScalarFunction}, -}; -use tokio::try_join; - -use super::{ - AnyQuery, BloomFilterQuery, LabelListQuery, MetricsCollector, SargableQuery, ScalarIndex, - SearchResult, TextQuery, TokenQuery, -}; -#[cfg(feature = "geo")] -use super::{GeoQuery, RelationQuery}; -use lance_core::{ - Error, Result, - utils::mask::{NullableRowAddrMask, RowAddrMask}, -}; -use lance_datafusion::{expr::safe_coerce_scalar, planner::Planner}; -use roaring::RoaringBitmap; -use tracing::instrument; - -const MAX_DEPTH: usize = 500; - -/// An indexed expression consists of a scalar index query with a post-scan filter -/// -/// When a user wants to filter the data returned by a scan we may be able to use -/// one or more scalar indices to reduce the amount of data we load from the disk. -/// -/// For example, if a user provides the filter "x = 7", and we have a scalar index -/// on x, then we can possibly identify the exact row that the user desires with our -/// index. A full-table scan can then turn into a take operation fetching the rows -/// desired. This would create an IndexedExpression with a scalar_query but no -/// refine. 
-/// -/// If the user asked for "type = 'dog' && z = 3" and we had a scalar index on the -/// "type" column then we could convert this to an indexed scan for "type='dog'" -/// followed by an in-memory filter for z=3. This would create an IndexedExpression -/// with both a scalar_query AND a refine. -/// -/// Finally, if the user asked for "z = 3" and we do not have a scalar index on the -/// "z" column then we must fallback to an IndexedExpression with no scalar_query and -/// only a refine. -/// -/// Two IndexedExpressions can be AND'd together. Each part is AND'd together. -/// Two IndexedExpressions cannot be OR'd together unless both are scalar_query only -/// or both are refine only -/// An IndexedExpression cannot be negated if it has both a refine and a scalar_query -/// -/// When an operation cannot be performed we fallback to the original expression-only -/// representation -#[derive(Debug, PartialEq)] -pub struct IndexedExpression { - /// The portion of the query that can be satisfied by scalar indices - pub scalar_query: Option, - /// The portion of the query that cannot be satisfied by scalar indices - pub refine_expr: Option, -} - -pub trait ScalarQueryParser: std::fmt::Debug + Send + Sync { - /// Visit a between expression - /// - /// Returns an IndexedExpression if the index can accelerate between expressions - fn visit_between( - &self, - column: &str, - low: &Bound, - high: &Bound, - ) -> Option; - /// Visit an in list expression - /// - /// Returns an IndexedExpression if the index can accelerate in list expressions - fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option; - /// Visit an is bool expression - /// - /// Returns an IndexedExpression if the index can accelerate is bool expressions - fn visit_is_bool(&self, column: &str, value: bool) -> Option; - /// Visit an is null expression - /// - /// Returns an IndexedExpression if the index can accelerate is null expressions - fn visit_is_null(&self, column: &str) -> Option; - 
/// Visit a comparison expression - /// - /// Returns an IndexedExpression if the index can accelerate comparison expressions - fn visit_comparison( - &self, - column: &str, - value: &ScalarValue, - op: &Operator, - ) -> Option; - /// Visit a scalar function expression - /// - /// Returns an IndexedExpression if the index can accelerate the given scalar function. - /// For example, an ngram index can accelerate the contains function. - fn visit_scalar_function( - &self, - column: &str, - data_type: &DataType, - func: &ScalarUDF, - args: &[Expr], - ) -> Option; - - /// Visit a LIKE expression - /// - /// Returns an IndexedExpression if the index can accelerate LIKE expressions. - /// For prefix patterns (e.g., "foo%"): - /// - ZoneMaps prune zones based on min/max statistics - /// - BTrees use range query conversion `[prefix, next_prefix)` - /// - /// For patterns with wildcards in the middle (e.g., "foo%bar%"), the leading prefix - /// can still be used for pruning, with the full pattern as a refine expression. - /// - /// # Arguments - /// * `column` - The column name - /// * `like` - The full LIKE expression (for constructing refine_expr if needed) - /// * `pattern` - The LIKE pattern as ScalarValue (e.g., "foo%") - fn visit_like( - &self, - _column: &str, - _like: &Like, - _pattern: &ScalarValue, - ) -> Option { - None - } - - /// Visits a potential reference to a column - /// - /// This function is a little different from the other visitors. It is used to test if a potential - /// column reference is a reference the index handles. - /// - /// Most indexes are designed to run on references to the indexed column. For example, if a query - /// is "x = 7" and we have a scalar index on "x" then we apply the index to the "x" column reference. - /// - /// However, some indexes are designed to run on projections of the indexed column. 
For example, - /// if a query is "json_extract(json, '$.name') = 'books'" and we have a JSON index on the "json" column - /// then we apply the index to the projection of the "json" column. - /// - /// This function is used to test if a potential column reference is a reference the index handles. - /// The default implementation matches column references but this can be overridden by indexes that - /// handle projections. - /// - /// The function is also passed in the data type of the column and should return the data type of the - /// reference. Normally this is the same as the input for a direct column reference and possibly something - /// different for a projection. E.g. a JSON column (LargeBinary) might be projected to a string or float - /// - /// Note: higher logic in the expression parser already limits references to either Expr::Column or Expr::ScalarFunction - /// where the first argument is an Expr::Column. If your projection doesn't fit that mold then the - /// expression parser will need to be modified. 
- fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option { - match func { - Expr::Column(_) => Some(data_type.clone()), - _ => None, - } - } -} - -/// A generic parser that wraps multiple scalar query parsers -/// -/// It will search each parser in order and return the first non-None result -#[derive(Debug)] -pub struct MultiQueryParser { - parsers: Vec>, -} - -impl MultiQueryParser { - /// Create a new MultiQueryParser with a single parser - pub fn single(parser: Box) -> Self { - Self { - parsers: vec![parser], - } - } - - /// Add a new parser to the MultiQueryParser - pub fn add(&mut self, other: Box) { - self.parsers.push(other); - } -} - -impl ScalarQueryParser for MultiQueryParser { - fn visit_between( - &self, - column: &str, - low: &Bound, - high: &Bound, - ) -> Option { - self.parsers - .iter() - .find_map(|parser| parser.visit_between(column, low, high)) - } - fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option { - self.parsers - .iter() - .find_map(|parser| parser.visit_in_list(column, in_list)) - } - fn visit_is_bool(&self, column: &str, value: bool) -> Option { - self.parsers - .iter() - .find_map(|parser| parser.visit_is_bool(column, value)) - } - fn visit_is_null(&self, column: &str) -> Option { - self.parsers - .iter() - .find_map(|parser| parser.visit_is_null(column)) - } - fn visit_comparison( - &self, - column: &str, - value: &ScalarValue, - op: &Operator, - ) -> Option { - self.parsers - .iter() - .find_map(|parser| parser.visit_comparison(column, value, op)) - } - fn visit_scalar_function( - &self, - column: &str, - data_type: &DataType, - func: &ScalarUDF, - args: &[Expr], - ) -> Option { - self.parsers - .iter() - .find_map(|parser| parser.visit_scalar_function(column, data_type, func, args)) - } - fn visit_like( - &self, - column: &str, - like: &Like, - pattern: &ScalarValue, - ) -> Option { - self.parsers - .iter() - .find_map(|parser| parser.visit_like(column, like, pattern)) - } - /// 
TODO(low-priority): This is maybe not quite right. We should filter down the list of parsers based - /// on those that consider the reference valid. Instead what we are doing is checking all parsers if any one - /// parser considers the reference valid. - /// - /// This will be a problem if the user creates two indexes (e.g. btree and json) on the same column and those two - /// indexes have different reference schemes. - fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option { - self.parsers - .iter() - .find_map(|parser| parser.is_valid_reference(func, data_type)) - } -} - -/// A parser for indices that handle SARGable queries -#[derive(Debug)] -pub struct SargableQueryParser { - index_name: String, - index_type: String, - needs_recheck: bool, -} - -impl SargableQueryParser { - pub fn new(index_name: String, index_type: String, needs_recheck: bool) -> Self { - Self { - index_name, - index_type, - needs_recheck, - } - } -} - -impl ScalarQueryParser for SargableQueryParser { - fn is_valid_reference(&self, func: &Expr, data_type: &DataType) -> Option { - match func { - Expr::Column(_) => Some(data_type.clone()), - // Also accept get_field expressions for nested field access - Expr::ScalarFunction(udf) if udf.name() == "get_field" => Some(data_type.clone()), - _ => None, - } - } - - fn visit_between( - &self, - column: &str, - low: &Bound, - high: &Bound, - ) -> Option { - if let Bound::Included(val) | Bound::Excluded(val) = low - && val.is_null() - { - return None; - } - if let Bound::Included(val) | Bound::Excluded(val) = high - && val.is_null() - { - return None; - } - let query = SargableQuery::Range(low.clone(), high.clone()); - Some(IndexedExpression::index_query_with_recheck( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(query), - self.needs_recheck, - )) - } - - fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option { - if in_list.iter().any(|val| val.is_null()) { - return None; - } 
- let query = SargableQuery::IsIn(in_list.to_vec()); - Some(IndexedExpression::index_query_with_recheck( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(query), - self.needs_recheck, - )) - } - - fn visit_is_bool(&self, column: &str, value: bool) -> Option { - Some(IndexedExpression::index_query_with_recheck( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(SargableQuery::Equals(ScalarValue::Boolean(Some(value)))), - self.needs_recheck, - )) - } - - fn visit_is_null(&self, column: &str) -> Option { - Some(IndexedExpression::index_query_with_recheck( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(SargableQuery::IsNull()), - self.needs_recheck, - )) - } - - fn visit_comparison( - &self, - column: &str, - value: &ScalarValue, - op: &Operator, - ) -> Option { - if value.is_null() { - return None; - } - let query = match op { - Operator::Lt => SargableQuery::Range(Bound::Unbounded, Bound::Excluded(value.clone())), - Operator::LtEq => { - SargableQuery::Range(Bound::Unbounded, Bound::Included(value.clone())) - } - Operator::Gt => SargableQuery::Range(Bound::Excluded(value.clone()), Bound::Unbounded), - Operator::GtEq => { - SargableQuery::Range(Bound::Included(value.clone()), Bound::Unbounded) - } - Operator::Eq => SargableQuery::Equals(value.clone()), - // This will be negated by the caller - Operator::NotEq => SargableQuery::Equals(value.clone()), - _ => unreachable!(), - }; - Some(IndexedExpression::index_query_with_recheck( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(query), - self.needs_recheck, - )) - } - - fn visit_scalar_function( - &self, - column: &str, - _data_type: &DataType, - func: &ScalarUDF, - args: &[Expr], - ) -> Option { - // Handle starts_with(col, 'prefix') -> convert to LikePrefix query - if func.name() == "starts_with" && args.len() == 2 { - // Extract the prefix from the second argument - 
let prefix = match &args[1] { - Expr::Literal(ScalarValue::Utf8(Some(s)), _) => ScalarValue::Utf8(Some(s.clone())), - Expr::Literal(ScalarValue::LargeUtf8(Some(s)), _) => { - ScalarValue::LargeUtf8(Some(s.clone())) - } - _ => return None, - }; - - let query = SargableQuery::LikePrefix(prefix); - return Some(IndexedExpression::index_query_with_recheck( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(query), - self.needs_recheck, - )); - } - - None - } - - fn visit_like( - &self, - column: &str, - like: &Like, - pattern: &ScalarValue, - ) -> Option { - // Case-insensitive LIKE (ILIKE) cannot be efficiently pruned with zone maps - if like.case_insensitive { - return None; - } - - // Extract the pattern string - let pattern_str = match pattern { - ScalarValue::Utf8(Some(s)) => s.as_str(), - ScalarValue::LargeUtf8(Some(s)) => s.as_str(), - _ => return None, - }; - - // Try to extract a prefix from the LIKE pattern - let (prefix, needs_refine) = extract_like_leading_prefix(pattern_str, like.escape_char)?; - - // Create the prefix ScalarValue with the same type as the pattern - let prefix_value = match pattern { - ScalarValue::Utf8(_) => ScalarValue::Utf8(Some(prefix)), - ScalarValue::LargeUtf8(_) => ScalarValue::LargeUtf8(Some(prefix)), - _ => return None, - }; - - let query = SargableQuery::LikePrefix(prefix_value); - let scalar_query = Some(ScalarIndexExpr::Query(ScalarIndexSearch { - column: column.to_string(), - index_name: self.index_name.clone(), - index_type: self.index_type.clone(), - query: Arc::new(query), - needs_recheck: self.needs_recheck, - })); - - // If the pattern has wildcards beyond simple prefix, add refine expression - let refine_expr = if needs_refine { - Some(Expr::Like(like.clone())) - } else { - None - }; - - Some(IndexedExpression { - scalar_query, - refine_expr, - }) - } -} - -/// Extract the leading literal prefix from a LIKE pattern. 
-/// -/// Returns `Some((prefix, needs_refine))` where: -/// - `prefix` is the leading literal portion before any wildcards -/// - `needs_refine` is true if the pattern has wildcards beyond a simple trailing `%` -/// -/// Returns `None` if the pattern starts with a wildcard (no leading literal). -/// -/// Examples: -/// - "foo%" -> Some(("foo", false)) - pure prefix, no recheck needed -/// - "foo%bar%" -> Some(("foo", true)) - can use prefix for pruning, needs recheck -/// - "foo_bar%" -> Some(("foo", true)) - _ is a wildcard, needs recheck -/// - "foo\%bar%" with escape '\' -> Some(("foo%bar", false)) - escaped %, pure prefix -/// - "%foo" -> None - starts with wildcard, cannot prune -/// - "foo" -> None - no wildcard at all, use equality instead -fn extract_like_leading_prefix(pattern: &str, escape_char: Option) -> Option<(String, bool)> { - let chars: Vec = pattern.chars().collect(); - let len = chars.len(); - - if len == 0 { - return None; - } - - // DataFusion's starts_with simplification escapes special characters with backslash - // but doesn't set escape_char. Use backslash as default escape character. 
- // Pattern: starts_with(col, 'test_ns$') -> col LIKE 'test\_ns$%' (escape_char: None) - // See: https://github.com/apache/datafusion/issues/XXXX - let effective_escape_char = escape_char.or(Some('\\')); - - // Helper to check if a character at position i is escaped - let is_escaped = |i: usize| -> bool { - if let Some(esc) = effective_escape_char { - if i > 0 && chars[i - 1] == esc { - // Check if the escape char itself is escaped - if i >= 2 && chars[i - 2] == esc { - false // Escape was escaped, so this char is NOT escaped - } else { - true // This char is escaped - } - } else { - false - } - } else { - // No escape character defined - nothing can be escaped - false - } - }; - - // Pattern must contain at least one unescaped wildcard - let has_wildcard = chars.iter().enumerate().any(|(i, &c)| { - if c != '%' && c != '_' { - return false; - } - !is_escaped(i) - }); - - if !has_wildcard { - return None; // No wildcards, should use equality - } - - // Check if pattern starts with an unescaped wildcard - if chars[0] == '%' || chars[0] == '_' { - return None; // Starts with wildcard, cannot prune - } - - // Extract the leading literal prefix (everything before first unescaped wildcard) - let mut prefix = String::new(); - let mut i = 0; - let mut found_wildcard = false; - - while i < len { - let c = chars[i]; - - // Check for escape character (using effective escape char which may be inferred) - if let Some(esc) = effective_escape_char - && c == esc - && i + 1 < len - { - let next = chars[i + 1]; - if next == '%' || next == '_' || next == esc { - // Escaped character - add the literal character - prefix.push(next); - i += 2; - continue; - } - } - - // Check for unescaped wildcard - if c == '%' || c == '_' { - found_wildcard = true; - break; - } - - prefix.push(c); - i += 1; - } - - if prefix.is_empty() { - return None; - } - - // Check if pattern is just a simple prefix (ends with single % and nothing after) - let needs_refine = if found_wildcard && i < len { - // 
Check if we're at a % wildcard - if chars[i] == '%' && i + 1 == len { - // Pattern is "prefix%" - pure prefix match, no refine needed - false - } else { - // Pattern has more after first wildcard, or has _ wildcard - true - } - } else { - // No wildcard found (shouldn't happen due to earlier check) - false - }; - - Some((prefix, needs_refine)) -} - -/// A parser for bloom filter indices that only support equals, is_null, and is_in operations -#[derive(Debug)] -pub struct BloomFilterQueryParser { - index_name: String, - index_type: String, - needs_recheck: bool, -} - -impl BloomFilterQueryParser { - pub fn new(index_name: String, index_type: String, needs_recheck: bool) -> Self { - Self { - index_name, - index_type, - needs_recheck, - } - } -} - -impl ScalarQueryParser for BloomFilterQueryParser { - fn visit_between( - &self, - _: &str, - _: &Bound, - _: &Bound, - ) -> Option { - // Bloom filters don't support range queries - None - } - - fn visit_in_list(&self, column: &str, in_list: &[ScalarValue]) -> Option { - let query = BloomFilterQuery::IsIn(in_list.to_vec()); - Some(IndexedExpression::index_query_with_recheck( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(query), - self.needs_recheck, - )) - } - - fn visit_is_bool(&self, column: &str, value: bool) -> Option { - Some(IndexedExpression::index_query_with_recheck( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(BloomFilterQuery::Equals(ScalarValue::Boolean(Some(value)))), - self.needs_recheck, - )) - } - - fn visit_is_null(&self, column: &str) -> Option { - Some(IndexedExpression::index_query_with_recheck( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(BloomFilterQuery::IsNull()), - self.needs_recheck, - )) - } - - fn visit_comparison( - &self, - column: &str, - value: &ScalarValue, - op: &Operator, - ) -> Option { - let query = match op { - // Bloom filters only support equality comparisons - 
Operator::Eq => BloomFilterQuery::Equals(value.clone()), - // This will be negated by the caller - Operator::NotEq => BloomFilterQuery::Equals(value.clone()), - // Bloom filters don't support range operations - _ => return None, - }; - Some(IndexedExpression::index_query_with_recheck( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(query), - self.needs_recheck, - )) - } - - fn visit_scalar_function( - &self, - _: &str, - _: &DataType, - _: &ScalarUDF, - _: &[Expr], - ) -> Option { - // Bloom filters don't support scalar functions - None - } -} - -/// A parser for indices that handle label list queries -#[derive(Debug)] -pub struct LabelListQueryParser { - index_name: String, - index_type: String, -} - -impl LabelListQueryParser { - pub fn new(index_name: String, index_type: String) -> Self { - Self { - index_name, - index_type, - } - } -} - -impl ScalarQueryParser for LabelListQueryParser { - fn visit_between( - &self, - _: &str, - _: &Bound, - _: &Bound, - ) -> Option { - None - } - - fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option { - None - } - - fn visit_is_bool(&self, _: &str, _: bool) -> Option { - None - } - - fn visit_is_null(&self, _: &str) -> Option { - None - } - - fn visit_comparison( - &self, - _: &str, - _: &ScalarValue, - _: &Operator, - ) -> Option { - None - } - - fn visit_scalar_function( - &self, - column: &str, - data_type: &DataType, - func: &ScalarUDF, - args: &[Expr], - ) -> Option { - if args.len() != 2 { - return None; - } - // DataFusion normalizes array_contains to array_has - if func.name() == "array_has" { - let inner_type = match data_type { - DataType::List(field) | DataType::LargeList(field) => field.data_type(), - _ => return None, - }; - let scalar = maybe_scalar(&args[1], inner_type)?; - // array_has(..., NULL) returns no matches in datafusion, but the index would - // match rows containing NULL. Fallback to match datafusion behavior. 
- if scalar.is_null() { - return None; - } - let query = LabelListQuery::HasAnyLabel(vec![scalar]); - return Some(IndexedExpression::index_query( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(query), - )); - } - - let label_list = maybe_scalar(&args[1], data_type)?; - if let ScalarValue::List(list_arr) = label_list { - let list_values = list_arr.values(); - if list_values.is_empty() { - return None; - } - let mut scalars = Vec::with_capacity(list_values.len()); - for idx in 0..list_values.len() { - scalars.push(ScalarValue::try_from_array(list_values.as_ref(), idx).ok()?); - } - if func.name() == "array_has_all" { - let query = LabelListQuery::HasAllLabels(scalars); - Some(IndexedExpression::index_query( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(query), - )) - } else if func.name() == "array_has_any" { - let query = LabelListQuery::HasAnyLabel(scalars); - Some(IndexedExpression::index_query( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(query), - )) - } else { - None - } - } else { - None - } - } -} - -/// A parser for indices that handle string contains queries -#[derive(Debug, Clone)] -pub struct TextQueryParser { - index_name: String, - index_type: String, - needs_recheck: bool, -} - -impl TextQueryParser { - pub fn new(index_name: String, index_type: String, needs_recheck: bool) -> Self { - Self { - index_name, - index_type, - needs_recheck, - } - } -} - -impl ScalarQueryParser for TextQueryParser { - fn visit_between( - &self, - _: &str, - _: &Bound, - _: &Bound, - ) -> Option { - None - } - - fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option { - None - } - - fn visit_is_bool(&self, _: &str, _: bool) -> Option { - None - } - - fn visit_is_null(&self, _: &str) -> Option { - None - } - - fn visit_comparison( - &self, - _: &str, - _: &ScalarValue, - _: &Operator, - ) -> Option { - None - } - - fn visit_scalar_function( - 
&self, - column: &str, - data_type: &DataType, - func: &ScalarUDF, - args: &[Expr], - ) -> Option { - if args.len() != 2 { - return None; - } - let scalar = maybe_scalar(&args[1], data_type)?; - match scalar { - ScalarValue::Utf8(Some(scalar_str)) | ScalarValue::LargeUtf8(Some(scalar_str)) => { - if func.name() == "contains" { - let query = TextQuery::StringContains(scalar_str); - Some(IndexedExpression::index_query_with_recheck( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(query), - self.needs_recheck, - )) - } else { - None - } - } - _ => { - // If the scalar is not a string, we cannot handle it - None - } - } - } -} - -/// A parser for indices that handle queries with the contains_tokens function -#[derive(Debug, Clone)] -pub struct FtsQueryParser { - index_name: String, - index_type: String, -} - -impl FtsQueryParser { - pub fn new(name: String, index_type: String) -> Self { - Self { - index_name: name, - index_type, - } - } -} - -impl ScalarQueryParser for FtsQueryParser { - fn visit_between( - &self, - _: &str, - _: &Bound, - _: &Bound, - ) -> Option { - None - } - - fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option { - None - } - - fn visit_is_bool(&self, _: &str, _: bool) -> Option { - None - } - - fn visit_is_null(&self, _: &str) -> Option { - None - } - - fn visit_comparison( - &self, - _: &str, - _: &ScalarValue, - _: &Operator, - ) -> Option { - None - } - - fn visit_scalar_function( - &self, - column: &str, - data_type: &DataType, - func: &ScalarUDF, - args: &[Expr], - ) -> Option { - if args.len() != 2 { - return None; - } - let scalar = maybe_scalar(&args[1], data_type)?; - if let ScalarValue::Utf8(Some(scalar_str)) = scalar - && func.name() == "contains_tokens" - { - let query = TokenQuery::TokensContains(scalar_str); - return Some(IndexedExpression::index_query( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(query), - )); - } - None - } -} - -/// A parser 
for geo indices that handles spatial queries -#[cfg(feature = "geo")] -#[derive(Debug, Clone)] -pub struct GeoQueryParser { - index_name: String, - index_type: String, -} - -#[cfg(feature = "geo")] -impl GeoQueryParser { - pub fn new(index_name: String, index_type: String) -> Self { - Self { - index_name, - index_type, - } - } -} - -#[cfg(feature = "geo")] -impl ScalarQueryParser for GeoQueryParser { - fn visit_between( - &self, - _: &str, - _: &Bound, - _: &Bound, - ) -> Option { - None - } - - fn visit_in_list(&self, _: &str, _: &[ScalarValue]) -> Option { - None - } - - fn visit_is_bool(&self, _: &str, _: bool) -> Option { - None - } - - fn visit_is_null(&self, column: &str) -> Option { - Some(IndexedExpression::index_query_with_recheck( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(GeoQuery::IsNull), - true, - )) - } - - fn visit_comparison( - &self, - _: &str, - _: &ScalarValue, - _: &Operator, - ) -> Option { - None - } - - fn visit_scalar_function( - &self, - column: &str, - _data_type: &DataType, - func: &ScalarUDF, - args: &[Expr], - ) -> Option { - if (func.name() == "st_intersects" - || func.name() == "st_contains" - || func.name() == "st_within" - || func.name() == "st_touches" - || func.name() == "st_crosses" - || func.name() == "st_overlaps" - || func.name() == "st_covers" - || func.name() == "st_coveredby") - && args.len() == 2 - { - let left_arg = &args[0]; - let right_arg = &args[1]; - return match (left_arg, right_arg) { - (Expr::Literal(left_value, metadata), Expr::Column(_)) => { - let mut field = Field::new("_geo", left_value.data_type(), false); - if let Some(metadata) = metadata { - field = field.with_metadata(metadata.to_hashmap()); - } - let query = GeoQuery::IntersectQuery(RelationQuery { - value: left_value.clone(), - field, - }); - Some(IndexedExpression::index_query_with_recheck( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(query), - true, - )) - } - 
(Expr::Column(_), Expr::Literal(right_value, metadata)) => { - let mut field = Field::new("_geo", right_value.data_type(), false); - if let Some(metadata) = metadata { - field = field.with_metadata(metadata.to_hashmap()); - } - let query = GeoQuery::IntersectQuery(RelationQuery { - value: right_value.clone(), - field, - }); - Some(IndexedExpression::index_query_with_recheck( - column.to_string(), - self.index_name.clone(), - self.index_type.clone(), - Arc::new(query), - true, - )) - } - _ => None, - }; - } - None - } -} - -impl IndexedExpression { - /// Create an expression that only does refine - fn refine_only(refine_expr: Expr) -> Self { - Self { - scalar_query: None, - refine_expr: Some(refine_expr), - } - } - - /// Create an expression that is only an index query - fn index_query( - column: String, - index_name: String, - index_type: String, - query: Arc, - ) -> Self { - Self { - scalar_query: Some(ScalarIndexExpr::Query(ScalarIndexSearch { - column, - index_name, - index_type, - query, - needs_recheck: false, // Default to false, will be set by parser - })), - refine_expr: None, - } - } - - /// Create an expression that is only an index query with explicit needs_recheck - fn index_query_with_recheck( - column: String, - index_name: String, - index_type: String, - query: Arc, - needs_recheck: bool, - ) -> Self { - Self { - scalar_query: Some(ScalarIndexExpr::Query(ScalarIndexSearch { - column, - index_name, - index_type, - query, - needs_recheck, - })), - refine_expr: None, - } - } - - /// Try and negate the expression - /// - /// If the expression contains both an index query and a refine expression then it - /// cannot be negated today and None will be returned (we give up trying to use indices) - fn maybe_not(self) -> Option { - match (self.scalar_query, self.refine_expr) { - (Some(_), Some(_)) => None, - (Some(scalar_query), None) => { - if scalar_query.needs_recheck() { - return None; - } - Some(Self { - scalar_query: 
Some(ScalarIndexExpr::Not(Box::new(scalar_query))), - refine_expr: None, - }) - } - (None, Some(refine_expr)) => Some(Self { - scalar_query: None, - refine_expr: Some(Expr::Not(Box::new(refine_expr))), - }), - (None, None) => panic!("Empty node should not occur"), - } - } - - /// Perform a logical AND of two indexed expressions - /// - /// This is straightforward because we can just AND the individual parts - /// because (A && B) && (C && D) == (A && C) && (B && D) - fn and(self, other: Self) -> Self { - let scalar_query = match (self.scalar_query, other.scalar_query) { - (Some(scalar_query), Some(other_scalar_query)) => Some(ScalarIndexExpr::And( - Box::new(scalar_query), - Box::new(other_scalar_query), - )), - (Some(scalar_query), None) => Some(scalar_query), - (None, Some(scalar_query)) => Some(scalar_query), - (None, None) => None, - }; - let refine_expr = match (self.refine_expr, other.refine_expr) { - (Some(refine_expr), Some(other_refine_expr)) => { - Some(refine_expr.and(other_refine_expr)) - } - (Some(refine_expr), None) => Some(refine_expr), - (None, Some(refine_expr)) => Some(refine_expr), - (None, None) => None, - }; - Self { - scalar_query, - refine_expr, - } - } - - /// Try and perform a logical OR of two indexed expressions - /// - /// This is a bit tricky because something like: - /// (color == 'blue' AND size < 20) OR (color == 'green' AND size < 50) - /// is not equivalent to: - /// (color == 'blue' OR color == 'green') AND (size < 20 OR size < 50) - fn maybe_or(self, other: Self) -> Option { - // If either expression is missing a scalar_query then we need to load all rows from - // the database and so we short-circuit and return None - let scalar_query = self.scalar_query?; - let other_scalar_query = other.scalar_query?; - let scalar_query = Some(ScalarIndexExpr::Or( - Box::new(scalar_query), - Box::new(other_scalar_query), - )); - - let refine_expr = match (self.refine_expr, other.refine_expr) { - // TODO - // - // To handle these cases we need 
a way of going back from a scalar expression query to a logical DF expression (perhaps - // we can store the expression that led to the creation of the query) - // - // For example, imagine we have something like "(color == 'blue' AND size < 20) OR (color == 'green' AND size < 50)" - // - // We can do an indexed load of all rows matching "color == 'blue' OR color == 'green'" but then we need to - // refine that load with the full original expression which, at the moment, we no longer have. - (Some(_), Some(_)) => { - return None; - } - (Some(_), None) => { - return None; - } - (None, Some(_)) => { - return None; - } - (None, None) => None, - }; - Some(Self { - scalar_query, - refine_expr, - }) - } - - fn refine(self, expr: Expr) -> Self { - match self.refine_expr { - Some(refine_expr) => Self { - scalar_query: self.scalar_query, - refine_expr: Some(refine_expr.and(expr)), - }, - None => Self { - scalar_query: self.scalar_query, - refine_expr: Some(expr), - }, - } - } -} - -/// A trait implemented by anything that can load indices by name -/// -/// This is used during the evaluation of an index expression -#[async_trait] -pub trait ScalarIndexLoader: Send + Sync { - /// Load the index with the given name - async fn load_index( - &self, - column: &str, - index_name: &str, - metrics: &dyn MetricsCollector, - ) -> Result>; -} - -/// This represents a search into a scalar index -#[derive(Debug, Clone)] -pub struct ScalarIndexSearch { - /// The column to search (redundant, used for debugging messages) - pub column: String, - /// The name of the index to search - pub index_name: String, - /// The type of the index being searched (e.g. 
"BTree", "Bitmap"), used for display purposes - pub index_type: String, - /// The query to search for - pub query: Arc, - /// If true, the query results are inexact and will need a recheck - pub needs_recheck: bool, -} - -impl PartialEq for ScalarIndexSearch { - fn eq(&self, other: &Self) -> bool { - self.column == other.column - && self.index_name == other.index_name - && self.query.as_ref().eq(other.query.as_ref()) - } -} - -/// This represents a lookup into one or more scalar indices -/// -/// This is a tree of operations because we may need to logically combine or -/// modify the results of scalar lookups -#[derive(Debug, Clone)] -pub enum ScalarIndexExpr { - Not(Box), - And(Box, Box), - Or(Box, Box), - Query(ScalarIndexSearch), -} - -impl PartialEq for ScalarIndexExpr { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::Not(l0), Self::Not(r0)) => l0 == r0, - (Self::And(l0, l1), Self::And(r0, r1)) => l0 == r0 && l1 == r1, - (Self::Or(l0, l1), Self::Or(r0, r1)) => l0 == r0 && l1 == r1, - (Self::Query(l_search), Self::Query(r_search)) => l_search == r_search, - _ => false, - } - } -} - -impl std::fmt::Display for ScalarIndexExpr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Not(inner) => write!(f, "NOT({})", inner), - Self::And(lhs, rhs) => write!(f, "AND({},{})", lhs, rhs), - Self::Or(lhs, rhs) => write!(f, "OR({},{})", lhs, rhs), - Self::Query(search) => write!( - f, - "[{}]@{}({})", - search.query.format(&search.column), - search.index_name, - search.index_type - ), - } - } -} - -/// When we evaluate a scalar index query we return a batch with three columns and two rows -/// -/// The first column has the block list and allow list -/// The second column tells if the result is least/exact/more (we repeat the discriminant twice) -/// The third column has the fragments covered bitmap in the first row and null in the second row -pub static INDEX_EXPR_RESULT_SCHEMA: LazyLock = LazyLock::new(|| { 
-    Arc::new(Schema::new(vec![
-        Field::new("result".to_string(), DataType::Binary, true),
-        Field::new("discriminant".to_string(), DataType::UInt32, true),
-        Field::new("fragments_covered".to_string(), DataType::Binary, true),
-    ]))
-});
-
-#[derive(Debug)]
-enum NullableIndexExprResult {
-    Exact(NullableRowAddrMask),
-    AtMost(NullableRowAddrMask),
-    AtLeast(NullableRowAddrMask),
-}
-
-impl From for NullableIndexExprResult {
-    fn from(result: SearchResult) -> Self {
-        match result {
-            SearchResult::Exact(mask) => Self::Exact(NullableRowAddrMask::AllowList(mask)),
-            SearchResult::AtMost(mask) => Self::AtMost(NullableRowAddrMask::AllowList(mask)),
-            SearchResult::AtLeast(mask) => Self::AtLeast(NullableRowAddrMask::AllowList(mask)),
-        }
-    }
-}
-
-impl std::ops::BitAnd for NullableIndexExprResult {
-    type Output = Self;
-
-    fn bitand(self, rhs: Self) -> Self {
-        match (self, rhs) {
-            (Self::Exact(lhs), Self::Exact(rhs)) => Self::Exact(lhs & rhs),
-            (Self::Exact(lhs), Self::AtMost(rhs)) | (Self::AtMost(lhs), Self::Exact(rhs)) => {
-                Self::AtMost(lhs & rhs)
-            }
-            (Self::Exact(exact), Self::AtLeast(_)) | (Self::AtLeast(_), Self::Exact(exact)) => {
-                // We could do better here, elements in both lhs and rhs are known
-                // to be true and don't require a recheck.
We only need to recheck - // elements in lhs that are not in rhs - Self::AtMost(exact) - } - (Self::AtMost(lhs), Self::AtMost(rhs)) => Self::AtMost(lhs & rhs), - (Self::AtLeast(lhs), Self::AtLeast(rhs)) => Self::AtLeast(lhs & rhs), - (Self::AtMost(most), Self::AtLeast(_)) | (Self::AtLeast(_), Self::AtMost(most)) => { - Self::AtMost(most) - } - } - } -} - -impl std::ops::BitOr for NullableIndexExprResult { - type Output = Self; - - fn bitor(self, rhs: Self) -> Self { - match (self, rhs) { - (Self::Exact(lhs), Self::Exact(rhs)) => Self::Exact(lhs | rhs), - (Self::Exact(lhs), Self::AtMost(rhs)) | (Self::AtMost(rhs), Self::Exact(lhs)) => { - // We could do better here, elements in lhs are known to be true - // and don't require a recheck. We only need to recheck elements - // in rhs that are not in lhs - Self::AtMost(lhs | rhs) - } - (Self::Exact(lhs), Self::AtLeast(rhs)) | (Self::AtLeast(rhs), Self::Exact(lhs)) => { - Self::AtLeast(lhs | rhs) - } - (Self::AtMost(lhs), Self::AtMost(rhs)) => Self::AtMost(lhs | rhs), - (Self::AtLeast(lhs), Self::AtLeast(rhs)) => Self::AtLeast(lhs | rhs), - (Self::AtMost(_), Self::AtLeast(least)) | (Self::AtLeast(least), Self::AtMost(_)) => { - Self::AtLeast(least) - } - } - } -} - -impl NullableIndexExprResult { - pub fn drop_nulls(self) -> IndexExprResult { - match self { - Self::Exact(mask) => IndexExprResult::Exact(mask.drop_nulls()), - Self::AtMost(mask) => IndexExprResult::AtMost(mask.drop_nulls()), - Self::AtLeast(mask) => IndexExprResult::AtLeast(mask.drop_nulls()), - } - } -} - -#[derive(Debug)] -pub enum IndexExprResult { - // The answer is exactly the rows in the allow list minus the rows in the block list - Exact(RowAddrMask), - // The answer is at most the rows in the allow list minus the rows in the block list - // Some of the rows in the allow list may not be in the result and will need to be filtered - // by a recheck. Every row in the block list is definitely not in the result. 
-    AtMost(RowAddrMask),
-    // The answer is at least the rows in the allow list minus the rows in the block list
-    // Some of the rows in the block list might be in the result. Every row in the allow list is
-    // definitely in the result.
-    AtLeast(RowAddrMask),
-}
-
-impl IndexExprResult {
-    pub fn row_addr_mask(&self) -> &RowAddrMask {
-        match self {
-            Self::Exact(mask) => mask,
-            Self::AtMost(mask) => mask,
-            Self::AtLeast(mask) => mask,
-        }
-    }
-
-    pub fn discriminant(&self) -> u32 {
-        match self {
-            Self::Exact(_) => 0,
-            Self::AtMost(_) => 1,
-            Self::AtLeast(_) => 2,
-        }
-    }
-
-    pub fn from_parts(mask: RowAddrMask, discriminant: u32) -> Result {
-        match discriminant {
-            0 => Ok(Self::Exact(mask)),
-            1 => Ok(Self::AtMost(mask)),
-            2 => Ok(Self::AtLeast(mask)),
-            _ => Err(Error::invalid_input_source(
-                format!("Invalid IndexExprResult discriminant: {}", discriminant).into(),
-            )),
-        }
-    }
-
-    #[instrument(skip_all)]
-    pub fn serialize_to_arrow(
-        &self,
-        fragments_covered_by_result: &RoaringBitmap,
-    ) -> Result {
-        let row_addr_mask = self.row_addr_mask();
-        let row_addr_mask_arr = row_addr_mask.into_arrow()?;
-        let discriminant = self.discriminant();
-        let discriminant_arr =
-            Arc::new(UInt32Array::from(vec![discriminant, discriminant])) as Arc;
-        let mut fragments_covered_builder = BinaryBuilder::new();
-        let fragments_covered_bytes_len = fragments_covered_by_result.serialized_size();
-        let mut fragments_covered_bytes = Vec::with_capacity(fragments_covered_bytes_len);
-        fragments_covered_by_result.serialize_into(&mut fragments_covered_bytes)?;
-        fragments_covered_builder.append_value(fragments_covered_bytes);
-        fragments_covered_builder.append_null();
-        let fragments_covered_arr = Arc::new(fragments_covered_builder.finish()) as Arc;
-        Ok(RecordBatch::try_new(
-            INDEX_EXPR_RESULT_SCHEMA.clone(),
-            vec![
-                Arc::new(row_addr_mask_arr),
-                Arc::new(discriminant_arr),
-                Arc::new(fragments_covered_arr),
-            ],
-        )?)
-    }
-}
-
-impl ScalarIndexExpr {
-    /// Evaluates the scalar index expression
-    ///
-    /// This will result in loading one or more scalar indices and searching them
-    ///
-    /// TODO: We could potentially try and be smarter about reusing loaded indices for
-    /// any situations where the session cache has been disabled.
-    #[async_recursion]
-    async fn evaluate_impl(
-        &self,
-        index_loader: &dyn ScalarIndexLoader,
-        metrics: &dyn MetricsCollector,
-    ) -> Result {
-        match self {
-            Self::Not(inner) => {
-                let result = inner.evaluate_impl(index_loader, metrics).await?;
-                // Flip certainty: NOT(AtMost) → AtLeast, NOT(AtLeast) → AtMost
-                Ok(match result {
-                    NullableIndexExprResult::Exact(mask) => NullableIndexExprResult::Exact(!mask),
-                    NullableIndexExprResult::AtMost(mask) => {
-                        NullableIndexExprResult::AtLeast(!mask)
-                    }
-                    NullableIndexExprResult::AtLeast(mask) => {
-                        NullableIndexExprResult::AtMost(!mask)
-                    }
-                })
-            }
-            Self::And(lhs, rhs) => {
-                let lhs_result = lhs.evaluate_impl(index_loader, metrics);
-                let rhs_result = rhs.evaluate_impl(index_loader, metrics);
-                let (lhs_result, rhs_result) = try_join!(lhs_result, rhs_result)?;
-                Ok(lhs_result & rhs_result)
-            }
-            Self::Or(lhs, rhs) => {
-                let lhs_result = lhs.evaluate_impl(index_loader, metrics);
-                let rhs_result = rhs.evaluate_impl(index_loader, metrics);
-                let (lhs_result, rhs_result) = try_join!(lhs_result, rhs_result)?;
-                Ok(lhs_result | rhs_result)
-            }
-            Self::Query(search) => {
-                let index = index_loader
-                    .load_index(&search.column, &search.index_name, metrics)
-                    .await?;
-                let search_result = index.search(search.query.as_ref(), metrics).await?;
-                Ok(search_result.into())
-            }
-        }
-    }
-
-    #[instrument(level = "debug", skip_all)]
-    pub async fn evaluate(
-        &self,
-        index_loader: &dyn ScalarIndexLoader,
-        metrics: &dyn MetricsCollector,
-    ) -> Result {
-        Ok(self
-            .evaluate_impl(index_loader, metrics)
-            .await?
- .drop_nulls()) - } - - pub fn to_expr(&self) -> Expr { - match self { - Self::Not(inner) => Expr::Not(inner.to_expr().into()), - Self::And(lhs, rhs) => { - let lhs = lhs.to_expr(); - let rhs = rhs.to_expr(); - lhs.and(rhs) - } - Self::Or(lhs, rhs) => { - let lhs = lhs.to_expr(); - let rhs = rhs.to_expr(); - lhs.or(rhs) - } - Self::Query(search) => search.query.to_expr(search.column.clone()), - } - } - - pub fn needs_recheck(&self) -> bool { - match self { - Self::Not(inner) => inner.needs_recheck(), - Self::And(lhs, rhs) | Self::Or(lhs, rhs) => lhs.needs_recheck() || rhs.needs_recheck(), - Self::Query(search) => search.needs_recheck, - } - } -} - -// Extract a column from the expression, if it is a column, or None -fn maybe_column(expr: &Expr) -> Option<&str> { - match expr { - Expr::Column(col) => Some(&col.name), - _ => None, - } -} - -// Extract the full nested column path from a get_field expression chain -// For example: get_field(get_field(metadata, "status"), "code") -> "metadata.status.code" -fn extract_nested_column_path(expr: &Expr) -> Option { - let mut current_expr = expr; - let mut parts = Vec::new(); - - // Walk up the get_field chain - loop { - match current_expr { - Expr::ScalarFunction(udf) if udf.name() == "get_field" => { - if udf.args.len() != 2 { - return None; - } - // Extract the field name from the second argument - // The Literal now has two fields: ScalarValue and Option - if let Expr::Literal(ScalarValue::Utf8(Some(field_name)), _) = &udf.args[1] { - parts.push(field_name.clone()); - } else { - return None; - } - // Move up to the parent expression - current_expr = &udf.args[0]; - } - Expr::Column(col) => { - // We've reached the base column - parts.push(col.name.clone()); - break; - } - _ => { - return None; - } - } - } - - // Reverse to get the correct order (parent.child.grandchild) - parts.reverse(); - - // Format the path correctly - let field_refs: Vec<&str> = parts.iter().map(|s| s.as_str()).collect(); - 
Some(lance_core::datatypes::format_field_path(&field_refs)) -} - -// Extract a column from the expression, if it is a column, and we have an index for that column, or None -// -// There's two ways to get a column. First, the obvious way, is a -// simple column reference (e.g. x = 7). Second, a more complex way, -// is some kind of projection into a column (e.g. json_extract(json, '$.name')). -// Third way is nested field access (e.g. get_field(metadata, "status.code")) -fn maybe_indexed_column<'b>( - expr: &Expr, - index_info: &'b dyn IndexInformationProvider, -) -> Option<(String, DataType, &'b dyn ScalarQueryParser)> { - // First try to extract the full nested column path for get_field expressions - if let Some(nested_path) = extract_nested_column_path(expr) - && let Some((data_type, parser)) = index_info.get_index(&nested_path) - && let Some(data_type) = parser.is_valid_reference(expr, data_type) - { - return Some((nested_path, data_type, parser)); - } - - match expr { - Expr::Column(col) => { - let col = col.name.as_str(); - let (data_type, parser) = index_info.get_index(col)?; - if let Some(data_type) = parser.is_valid_reference(expr, data_type) { - Some((col.to_string(), data_type, parser)) - } else { - None - } - } - Expr::ScalarFunction(udf) => { - if udf.args.is_empty() { - return None; - } - // For non-get_field functions, fall back to old behavior - let col = maybe_column(&udf.args[0])?; - let (data_type, parser) = index_info.get_index(col)?; - if let Some(data_type) = parser.is_valid_reference(expr, data_type) { - Some((col.to_string(), data_type, parser)) - } else { - None - } - } - _ => None, - } -} - -// Extract a literal scalar value from an expression, if it is a literal, or None -fn maybe_scalar(expr: &Expr, expected_type: &DataType) -> Option { - match expr { - Expr::Literal(value, _) => safe_coerce_scalar(value, expected_type), - // Some literals can't be expressed in datafusion's SQL and can only be expressed with - // a cast. 
For example, there is no way to express a fixed-size-binary literal (which is - // commonly used for UUID). As a result the expression could look like... - // - // col = arrow_cast(value, 'fixed_size_binary(16)') - // - // In this case we need to extract the value, apply the cast, and then test the casted value - Expr::Cast(cast) => match cast.expr.as_ref() { - Expr::Literal(value, _) => { - let casted = value.cast_to(&cast.data_type).ok()?; - safe_coerce_scalar(&casted, expected_type) - } - _ => None, - }, - Expr::ScalarFunction(scalar_function) => { - if scalar_function.name() == "arrow_cast" { - if scalar_function.args.len() != 2 { - return None; - } - match (&scalar_function.args[0], &scalar_function.args[1]) { - (Expr::Literal(value, _), Expr::Literal(cast_type, _)) => { - let target_type = scalar_function - .func - .return_field_from_args(ReturnFieldArgs { - arg_fields: &[ - Arc::new(Field::new("expression", value.data_type(), false)), - Arc::new(Field::new("datatype", cast_type.data_type(), false)), - ], - scalar_arguments: &[Some(value), Some(cast_type)], - }) - .ok()?; - let casted = value.cast_to(target_type.data_type()).ok()?; - safe_coerce_scalar(&casted, expected_type) - } - _ => None, - } - } else { - None - } - } - _ => None, - } -} - -// Extract a list of scalar values from an expression, if it is a list of scalar values, or None -fn maybe_scalar_list(exprs: &Vec, expected_type: &DataType) -> Option> { - let mut scalar_values = Vec::with_capacity(exprs.len()); - for expr in exprs { - match maybe_scalar(expr, expected_type) { - Some(scalar_val) => { - scalar_values.push(scalar_val); - } - None => { - return None; - } - } - } - Some(scalar_values) -} - -fn visit_between( - between: &Between, - index_info: &dyn IndexInformationProvider, -) -> Option { - let (column, col_type, query_parser) = maybe_indexed_column(&between.expr, index_info)?; - let low = maybe_scalar(&between.low, &col_type)?; - let high = maybe_scalar(&between.high, &col_type)?; - - let 
indexed_expr = - query_parser.visit_between(&column, &Bound::Included(low), &Bound::Included(high))?; - - if between.negated { - indexed_expr.maybe_not() - } else { - Some(indexed_expr) - } -} - -fn visit_in_list( - in_list: &InList, - index_info: &dyn IndexInformationProvider, -) -> Option { - let (column, col_type, query_parser) = maybe_indexed_column(&in_list.expr, index_info)?; - let values = maybe_scalar_list(&in_list.list, &col_type)?; - - let indexed_expr = query_parser.visit_in_list(&column, &values)?; - - if in_list.negated { - indexed_expr.maybe_not() - } else { - Some(indexed_expr) - } -} - -fn visit_is_bool( - expr: &Expr, - index_info: &dyn IndexInformationProvider, - value: bool, -) -> Option { - let (column, col_type, query_parser) = maybe_indexed_column(expr, index_info)?; - if col_type != DataType::Boolean { - None - } else { - query_parser.visit_is_bool(&column, value) - } -} - -// A column can be a valid indexed expression if the column is boolean (e.g. 'WHERE on_sale') -fn visit_column( - col: &Expr, - index_info: &dyn IndexInformationProvider, -) -> Option { - let (column, col_type, query_parser) = maybe_indexed_column(col, index_info)?; - if col_type != DataType::Boolean { - None - } else { - query_parser.visit_is_bool(&column, true) - } -} - -fn visit_is_null( - expr: &Expr, - index_info: &dyn IndexInformationProvider, - negated: bool, -) -> Option { - let (column, _, query_parser) = maybe_indexed_column(expr, index_info)?; - let indexed_expr = query_parser.visit_is_null(&column)?; - if negated { - indexed_expr.maybe_not() - } else { - Some(indexed_expr) - } -} - -fn visit_not( - expr: &Expr, - index_info: &dyn IndexInformationProvider, - depth: usize, -) -> Result> { - let node = visit_node(expr, index_info, depth + 1)?; - Ok(node.and_then(|node| node.maybe_not())) -} - -fn visit_comparison( - expr: &BinaryExpr, - index_info: &dyn IndexInformationProvider, -) -> Option { - let left_col = maybe_indexed_column(&expr.left, index_info); - if let 
Some((column, col_type, query_parser)) = left_col { - let scalar = maybe_scalar(&expr.right, &col_type)?; - query_parser.visit_comparison(&column, &scalar, &expr.op) - } else { - // Datafusion's query simplifier will canonicalize expressions and so we shouldn't reach this case. If, for some reason, we - // do reach this case we can handle it in the future by inverting expr.op and swapping the left and right sides - None - } -} - -fn maybe_range( - expr: &BinaryExpr, - index_info: &dyn IndexInformationProvider, -) -> Option { - let left_expr = match expr.left.as_ref() { - Expr::BinaryExpr(binary_expr) => Some(binary_expr), - _ => None, - }?; - let right_expr = match expr.right.as_ref() { - Expr::BinaryExpr(binary_expr) => Some(binary_expr), - _ => None, - }?; - - let (left_col, dt, parser) = maybe_indexed_column(&left_expr.left, index_info)?; - let right_col = maybe_column(&right_expr.left)?; - - if left_col != right_col { - return None; - } - - let left_value = maybe_scalar(&left_expr.right, &dt)?; - let right_value = maybe_scalar(&right_expr.right, &dt)?; - - let (low, high) = match (left_expr.op, right_expr.op) { - // x >= a && x <= b - (Operator::GtEq, Operator::LtEq) => { - (Bound::Included(left_value), Bound::Included(right_value)) - } - // x >= a && x < b - (Operator::GtEq, Operator::Lt) => { - (Bound::Included(left_value), Bound::Excluded(right_value)) - } - // x > a && x <= b - (Operator::Gt, Operator::LtEq) => { - (Bound::Excluded(left_value), Bound::Included(right_value)) - } - // x > a && x < b - (Operator::Gt, Operator::Lt) => (Bound::Excluded(left_value), Bound::Excluded(right_value)), - // x <= a && x >= b - (Operator::LtEq, Operator::GtEq) => { - (Bound::Included(right_value), Bound::Included(left_value)) - } - // x <= a && x > b - (Operator::LtEq, Operator::Gt) => { - (Bound::Excluded(right_value), Bound::Included(left_value)) - } - // x < a && x >= b - (Operator::Lt, Operator::GtEq) => { - (Bound::Included(right_value), Bound::Excluded(left_value)) 
- } - // x < a && x > b - (Operator::Lt, Operator::Gt) => (Bound::Excluded(right_value), Bound::Excluded(left_value)), - _ => return None, - }; - - parser.visit_between(&left_col, &low, &high) -} - -fn visit_and( - expr: &BinaryExpr, - index_info: &dyn IndexInformationProvider, - depth: usize, -) -> Result> { - // Many scalar indices can efficiently handle a BETWEEN query as a single search and this - // can be much more efficient than two separate range queries. As an optimization we check - // to see if this is a between query and, if so, we handle it as a single query - // - // Note: We can't rely on users writing the SQL BETWEEN operator because: - // * Some users won't realize it's an option or a good idea - // * Datafusion's simplifier will rewrite the BETWEEN operator into two separate range queries - if let Some(range_expr) = maybe_range(expr, index_info) { - return Ok(Some(range_expr)); - } - - let left = visit_node(&expr.left, index_info, depth + 1)?; - let right = visit_node(&expr.right, index_info, depth + 1)?; - Ok(match (left, right) { - (Some(left), Some(right)) => Some(left.and(right)), - (Some(left), None) => Some(left.refine((*expr.right).clone())), - (None, Some(right)) => Some(right.refine((*expr.left).clone())), - (None, None) => None, - }) -} - -fn visit_or( - expr: &BinaryExpr, - index_info: &dyn IndexInformationProvider, - depth: usize, -) -> Result> { - let left = visit_node(&expr.left, index_info, depth + 1)?; - let right = visit_node(&expr.right, index_info, depth + 1)?; - Ok(match (left, right) { - (Some(left), Some(right)) => left.maybe_or(right), - // If one side can use an index and the other side cannot then - // we must abandon the entire thing. For example, consider the - // query "color == 'blue' or size > 10" where color is indexed but - // size is not. It's entirely possible that size > 10 matches every - // row in our database. 
There is nothing we can do except a full scan - (Some(_), None) => None, - (None, Some(_)) => None, - (None, None) => None, - }) -} - -fn visit_binary_expr( - expr: &BinaryExpr, - index_info: &dyn IndexInformationProvider, - depth: usize, -) -> Result> { - match &expr.op { - Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq | Operator::Eq => { - Ok(visit_comparison(expr, index_info)) - } - // visit_comparison will maybe create an Eq query which we negate - Operator::NotEq => Ok(visit_comparison(expr, index_info).and_then(|node| node.maybe_not())), - Operator::And => visit_and(expr, index_info, depth), - Operator::Or => visit_or(expr, index_info, depth), - _ => Ok(None), - } -} - -fn visit_scalar_fn( - scalar_fn: &ScalarFunction, - index_info: &dyn IndexInformationProvider, -) -> Option { - if scalar_fn.args.is_empty() { - return None; - } - let (col, data_type, query_parser) = maybe_indexed_column(&scalar_fn.args[0], index_info)?; - query_parser.visit_scalar_function(&col, &data_type, &scalar_fn.func, &scalar_fn.args) -} - -fn visit_like_expr( - like: &Like, - index_info: &dyn IndexInformationProvider, -) -> Option { - let (column, _, query_parser) = maybe_indexed_column(&like.expr, index_info)?; - - // Extract the pattern as a ScalarValue - let pattern = match like.pattern.as_ref() { - Expr::Literal(scalar, _) => scalar.clone(), - _ => return None, - }; - - query_parser.visit_like(&column, like, &pattern) -} - -fn visit_node( - expr: &Expr, - index_info: &dyn IndexInformationProvider, - depth: usize, -) -> Result> { - if depth >= MAX_DEPTH { - return Err(Error::invalid_input(format!( - "the filter expression is too long, lance limit the max number of conditions to {}", - MAX_DEPTH - ))); - } - match expr { - Expr::Between(between) => Ok(visit_between(between, index_info)), - Expr::Alias(alias) => visit_node(alias.expr.as_ref(), index_info, depth), - Expr::Column(_) => Ok(visit_column(expr, index_info)), - Expr::InList(in_list) => 
Ok(visit_in_list(in_list, index_info)), - Expr::IsFalse(expr) => Ok(visit_is_bool(expr.as_ref(), index_info, false)), - Expr::IsTrue(expr) => Ok(visit_is_bool(expr.as_ref(), index_info, true)), - Expr::IsNull(expr) => Ok(visit_is_null(expr.as_ref(), index_info, false)), - Expr::IsNotNull(expr) => Ok(visit_is_null(expr.as_ref(), index_info, true)), - Expr::Not(expr) => visit_not(expr.as_ref(), index_info, depth), - Expr::BinaryExpr(binary_expr) => visit_binary_expr(binary_expr, index_info, depth), - Expr::ScalarFunction(scalar_fn) => Ok(visit_scalar_fn(scalar_fn, index_info)), - Expr::Like(like) => { - if like.negated { - // NOT LIKE cannot be efficiently pruned with zone maps - Ok(None) - } else { - Ok(visit_like_expr(like, index_info)) - } - } - _ => Ok(None), - } -} - -/// A trait to be used in `apply_scalar_indices` to inform the function which columns are indexeds -pub trait IndexInformationProvider { - /// Check if an index exists for `col` and, if so, return the data type of col - /// as well as a query parser that can parse queries for that column - fn get_index(&self, col: &str) -> Option<(&DataType, &dyn ScalarQueryParser)>; -} - -/// Attempt to split a filter expression into a search of scalar indexes and an -/// optional post-search refinement query -pub fn apply_scalar_indices( - expr: Expr, - index_info: &dyn IndexInformationProvider, -) -> Result { - Ok(visit_node(&expr, index_info, 0)?.unwrap_or(IndexedExpression::refine_only(expr))) -} - -#[derive(Clone, Default, Debug)] -pub struct FilterPlan { - pub index_query: Option, - /// True if the index query is guaranteed to return exact results - pub skip_recheck: bool, - pub refine_expr: Option, - pub full_expr: Option, -} - -impl FilterPlan { - pub fn empty() -> Self { - Self { - index_query: None, - skip_recheck: true, - refine_expr: None, - full_expr: None, - } - } - - pub fn new_refine_only(expr: Expr) -> Self { - Self { - index_query: None, - skip_recheck: true, - refine_expr: Some(expr.clone()), - 
full_expr: Some(expr), - } - } - - pub fn is_empty(&self) -> bool { - self.refine_expr.is_none() && self.index_query.is_none() - } - - pub fn all_columns(&self) -> Vec { - self.full_expr - .as_ref() - .map(Planner::column_names_in_expr) - .unwrap_or_default() - } - - pub fn refine_columns(&self) -> Vec { - self.refine_expr - .as_ref() - .map(Planner::column_names_in_expr) - .unwrap_or_default() - } - - /// Return true if this has a refine step, regardless of the status of prefilter - pub fn has_refine(&self) -> bool { - self.refine_expr.is_some() - } - - /// Return true if this has a scalar index query - pub fn has_index_query(&self) -> bool { - self.index_query.is_some() - } - - pub fn has_any_filter(&self) -> bool { - self.refine_expr.is_some() || self.index_query.is_some() - } - - pub fn make_refine_only(&mut self) { - self.index_query = None; - self.refine_expr = self.full_expr.clone(); - } - - /// Return true if there is no refine or recheck of any kind and there is an index query - pub fn is_exact_index_search(&self) -> bool { - self.index_query.is_some() && self.refine_expr.is_none() && self.skip_recheck - } -} - -pub trait PlannerIndexExt { - /// Determine how to apply a provided filter - /// - /// We parse the filter into a logical expression. 
We then - /// split the logical expression into a portion that can be - /// satisfied by an index search (of one or more indices) and - /// a refine portion that must be applied after the index search - fn create_filter_plan( - &self, - filter: Expr, - index_info: &dyn IndexInformationProvider, - use_scalar_index: bool, - ) -> Result; -} - -impl PlannerIndexExt for Planner { - fn create_filter_plan( - &self, - filter: Expr, - index_info: &dyn IndexInformationProvider, - use_scalar_index: bool, - ) -> Result { - let logical_expr = self.optimize_expr(filter)?; - if use_scalar_index { - let indexed_expr = apply_scalar_indices(logical_expr.clone(), index_info)?; - let mut skip_recheck = false; - if let Some(scalar_query) = indexed_expr.scalar_query.as_ref() { - skip_recheck = !scalar_query.needs_recheck(); - } - Ok(FilterPlan { - index_query: indexed_expr.scalar_query, - refine_expr: indexed_expr.refine_expr, - full_expr: Some(logical_expr), - skip_recheck, - }) - } else { - Ok(FilterPlan { - index_query: None, - skip_recheck: true, - refine_expr: Some(logical_expr.clone()), - full_expr: Some(logical_expr), - }) - } - } -} - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use arrow_schema::{Field, Schema}; - use chrono::Utc; - use datafusion_common::{Column, DFSchema}; - use datafusion_expr::simplify::SimplifyContext; - use lance_datafusion::exec::{LanceExecutionOptions, get_session_context}; - - use crate::scalar::json::{JsonQuery, JsonQueryParser}; - - use super::*; - - struct ColInfo { - data_type: DataType, - parser: Box, - } - - impl ColInfo { - fn new(data_type: DataType, parser: Box) -> Self { - Self { data_type, parser } - } - } - - struct MockIndexInfoProvider { - indexed_columns: HashMap, - } - - impl MockIndexInfoProvider { - fn new(indexed_columns: Vec<(&str, ColInfo)>) -> Self { - Self { - indexed_columns: HashMap::from_iter( - indexed_columns - .into_iter() - .map(|(s, ty)| (s.to_string(), ty)), - ), - } - } - } - - impl 
IndexInformationProvider for MockIndexInfoProvider { - fn get_index(&self, col: &str) -> Option<(&DataType, &dyn ScalarQueryParser)> { - self.indexed_columns - .get(col) - .map(|col_info| (&col_info.data_type, col_info.parser.as_ref())) - } - } - - fn check( - index_info: &dyn IndexInformationProvider, - expr: &str, - expected: Option, - optimize: bool, - ) { - let schema = Schema::new(vec![ - Field::new("color", DataType::Utf8, false), - Field::new("size", DataType::Float32, false), - Field::new("aisle", DataType::UInt32, false), - Field::new("on_sale", DataType::Boolean, false), - Field::new("price", DataType::Float32, false), - Field::new("json", DataType::LargeBinary, false), - ]); - let df_schema: DFSchema = schema.try_into().unwrap(); - - let ctx = get_session_context(&LanceExecutionOptions::default()); - let state = ctx.state(); - let mut expr = state.create_logical_expr(expr, &df_schema).unwrap(); - if optimize { - let simplify_context = SimplifyContext::default() - .with_schema(Arc::new(df_schema)) - .with_query_execution_start_time(Some(Utc::now())); - let simplifier = - datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context); - expr = simplifier.simplify(expr).unwrap(); - } - - let actual = apply_scalar_indices(expr.clone(), index_info).unwrap(); - if let Some(expected) = expected { - assert_eq!(actual, expected); - } else { - assert!(actual.scalar_query.is_none()); - assert_eq!(actual.refine_expr.unwrap(), expr); - } - } - - fn check_no_index(index_info: &dyn IndexInformationProvider, expr: &str) { - check(index_info, expr, None, false) - } - - fn check_simple( - index_info: &dyn IndexInformationProvider, - expr: &str, - col: &str, - query: impl AnyQuery, - ) { - check( - index_info, - expr, - Some(IndexedExpression::index_query( - col.to_string(), - format!("{}_idx", col), - "BTree".to_string(), - Arc::new(query), - )), - false, - ) - } - - fn check_range( - index_info: &dyn IndexInformationProvider, - expr: &str, - col: 
&str, - query: SargableQuery, - ) { - check( - index_info, - expr, - Some(IndexedExpression::index_query( - col.to_string(), - format!("{}_idx", col), - "BTree".to_string(), - Arc::new(query), - )), - true, - ) - } - - fn check_simple_negated( - index_info: &dyn IndexInformationProvider, - expr: &str, - col: &str, - query: SargableQuery, - ) { - check( - index_info, - expr, - Some( - IndexedExpression::index_query( - col.to_string(), - format!("{}_idx", col), - "BTree".to_string(), - Arc::new(query), - ) - .maybe_not() - .unwrap(), - ), - false, - ) - } - - #[test] - fn test_expressions() { - let index_info = MockIndexInfoProvider::new(vec![ - ( - "color", - ColInfo::new( - DataType::Utf8, - Box::new(SargableQueryParser::new( - "color_idx".to_string(), - "BTree".to_string(), - false, - )), - ), - ), - ( - "aisle", - ColInfo::new( - DataType::UInt32, - Box::new(SargableQueryParser::new( - "aisle_idx".to_string(), - "BTree".to_string(), - false, - )), - ), - ), - ( - "on_sale", - ColInfo::new( - DataType::Boolean, - Box::new(SargableQueryParser::new( - "on_sale_idx".to_string(), - "BTree".to_string(), - false, - )), - ), - ), - ( - "price", - ColInfo::new( - DataType::Float32, - Box::new(SargableQueryParser::new( - "price_idx".to_string(), - "BTree".to_string(), - false, - )), - ), - ), - ( - "json", - ColInfo::new( - DataType::LargeBinary, - Box::new(JsonQueryParser::new( - "$.name".to_string(), - Box::new(SargableQueryParser::new( - "json_idx".to_string(), - "BTree".to_string(), - false, - )), - )), - ), - ), - ]); - - check_simple( - &index_info, - "json_extract(json, '$.name') = 'foo'", - "json", - JsonQuery::new( - Arc::new(SargableQuery::Equals(ScalarValue::Utf8(Some( - "foo".to_string(), - )))), - "$.name".to_string(), - ), - ); - - check_no_index(&index_info, "size BETWEEN 5 AND 10"); - // Cast case. 
We will cast 5 (an int64) to Int16 and then coerce to UInt32 - check_simple( - &index_info, - "aisle = arrow_cast(5, 'Int16')", - "aisle", - SargableQuery::Equals(ScalarValue::UInt32(Some(5))), - ); - // 5 different ways of writing BETWEEN (all should be recognized) - check_range( - &index_info, - "aisle BETWEEN 5 AND 10", - "aisle", - SargableQuery::Range( - Bound::Included(ScalarValue::UInt32(Some(5))), - Bound::Included(ScalarValue::UInt32(Some(10))), - ), - ); - check_range( - &index_info, - "aisle >= 5 AND aisle <= 10", - "aisle", - SargableQuery::Range( - Bound::Included(ScalarValue::UInt32(Some(5))), - Bound::Included(ScalarValue::UInt32(Some(10))), - ), - ); - - check_range( - &index_info, - "aisle <= 10 AND aisle >= 5", - "aisle", - SargableQuery::Range( - Bound::Included(ScalarValue::UInt32(Some(5))), - Bound::Included(ScalarValue::UInt32(Some(10))), - ), - ); - - check_range( - &index_info, - "5 <= aisle AND 10 >= aisle", - "aisle", - SargableQuery::Range( - Bound::Included(ScalarValue::UInt32(Some(5))), - Bound::Included(ScalarValue::UInt32(Some(10))), - ), - ); - - check_range( - &index_info, - "10 >= aisle AND 5 <= aisle", - "aisle", - SargableQuery::Range( - Bound::Included(ScalarValue::UInt32(Some(5))), - Bound::Included(ScalarValue::UInt32(Some(10))), - ), - ); - check_range( - &index_info, - "aisle <= 10 AND aisle > 5", - "aisle", - SargableQuery::Range( - Bound::Excluded(ScalarValue::UInt32(Some(5))), - Bound::Included(ScalarValue::UInt32(Some(10))), - ), - ); - check_range( - &index_info, - "aisle < 10 AND aisle >= 5", - "aisle", - SargableQuery::Range( - Bound::Included(ScalarValue::UInt32(Some(5))), - Bound::Excluded(ScalarValue::UInt32(Some(10))), - ), - ); - check_simple( - &index_info, - "on_sale IS TRUE", - "on_sale", - SargableQuery::Equals(ScalarValue::Boolean(Some(true))), - ); - check_simple( - &index_info, - "on_sale", - "on_sale", - SargableQuery::Equals(ScalarValue::Boolean(Some(true))), - ); - check_simple_negated( - &index_info, - 
"NOT on_sale", - "on_sale", - SargableQuery::Equals(ScalarValue::Boolean(Some(true))), - ); - check_simple( - &index_info, - "on_sale IS FALSE", - "on_sale", - SargableQuery::Equals(ScalarValue::Boolean(Some(false))), - ); - check_simple_negated( - &index_info, - "aisle NOT BETWEEN 5 AND 10", - "aisle", - SargableQuery::Range( - Bound::Included(ScalarValue::UInt32(Some(5))), - Bound::Included(ScalarValue::UInt32(Some(10))), - ), - ); - // Small in-list (in-list with 3 or fewer items optimizes into or-chain) - check_simple( - &index_info, - "aisle IN (5, 6, 7)", - "aisle", - SargableQuery::IsIn(vec![ - ScalarValue::UInt32(Some(5)), - ScalarValue::UInt32(Some(6)), - ScalarValue::UInt32(Some(7)), - ]), - ); - check_simple_negated( - &index_info, - "NOT aisle IN (5, 6, 7)", - "aisle", - SargableQuery::IsIn(vec![ - ScalarValue::UInt32(Some(5)), - ScalarValue::UInt32(Some(6)), - ScalarValue::UInt32(Some(7)), - ]), - ); - check_simple_negated( - &index_info, - "aisle NOT IN (5, 6, 7)", - "aisle", - SargableQuery::IsIn(vec![ - ScalarValue::UInt32(Some(5)), - ScalarValue::UInt32(Some(6)), - ScalarValue::UInt32(Some(7)), - ]), - ); - check_simple( - &index_info, - "aisle IN (5, 6, 7, 8, 9)", - "aisle", - SargableQuery::IsIn(vec![ - ScalarValue::UInt32(Some(5)), - ScalarValue::UInt32(Some(6)), - ScalarValue::UInt32(Some(7)), - ScalarValue::UInt32(Some(8)), - ScalarValue::UInt32(Some(9)), - ]), - ); - check_simple_negated( - &index_info, - "NOT aisle IN (5, 6, 7, 8, 9)", - "aisle", - SargableQuery::IsIn(vec![ - ScalarValue::UInt32(Some(5)), - ScalarValue::UInt32(Some(6)), - ScalarValue::UInt32(Some(7)), - ScalarValue::UInt32(Some(8)), - ScalarValue::UInt32(Some(9)), - ]), - ); - check_simple_negated( - &index_info, - "aisle NOT IN (5, 6, 7, 8, 9)", - "aisle", - SargableQuery::IsIn(vec![ - ScalarValue::UInt32(Some(5)), - ScalarValue::UInt32(Some(6)), - ScalarValue::UInt32(Some(7)), - ScalarValue::UInt32(Some(8)), - ScalarValue::UInt32(Some(9)), - ]), - ); - check_simple( - 
&index_info, - "on_sale is false", - "on_sale", - SargableQuery::Equals(ScalarValue::Boolean(Some(false))), - ); - check_simple( - &index_info, - "on_sale is true", - "on_sale", - SargableQuery::Equals(ScalarValue::Boolean(Some(true))), - ); - check_simple( - &index_info, - "aisle < 10", - "aisle", - SargableQuery::Range( - Bound::Unbounded, - Bound::Excluded(ScalarValue::UInt32(Some(10))), - ), - ); - check_simple( - &index_info, - "aisle <= 10", - "aisle", - SargableQuery::Range( - Bound::Unbounded, - Bound::Included(ScalarValue::UInt32(Some(10))), - ), - ); - check_simple( - &index_info, - "aisle > 10", - "aisle", - SargableQuery::Range( - Bound::Excluded(ScalarValue::UInt32(Some(10))), - Bound::Unbounded, - ), - ); - // In the future we can handle this case if we need to. For - // now let's make sure we don't accidentally do the wrong thing - // (we were getting this backwards in the past) - check_no_index(&index_info, "10 > aisle"); - check_simple( - &index_info, - "aisle >= 10", - "aisle", - SargableQuery::Range( - Bound::Included(ScalarValue::UInt32(Some(10))), - Bound::Unbounded, - ), - ); - check_simple( - &index_info, - "aisle = 10", - "aisle", - SargableQuery::Equals(ScalarValue::UInt32(Some(10))), - ); - check_simple_negated( - &index_info, - "aisle <> 10", - "aisle", - SargableQuery::Equals(ScalarValue::UInt32(Some(10))), - ); - // // Common compound case, AND'd clauses - let left = Box::new(ScalarIndexExpr::Query(ScalarIndexSearch { - column: "aisle".to_string(), - index_name: "aisle_idx".to_string(), - index_type: "BTree".to_string(), - query: Arc::new(SargableQuery::Equals(ScalarValue::UInt32(Some(10)))), - needs_recheck: false, - })); - let right = Box::new(ScalarIndexExpr::Query(ScalarIndexSearch { - column: "color".to_string(), - index_name: "color_idx".to_string(), - index_type: "BTree".to_string(), - query: Arc::new(SargableQuery::Equals(ScalarValue::Utf8(Some( - "blue".to_string(), - )))), - needs_recheck: false, - })); - check( - &index_info, 
- "aisle = 10 AND color = 'blue'", - Some(IndexedExpression { - scalar_query: Some(ScalarIndexExpr::And(left.clone(), right.clone())), - refine_expr: None, - }), - false, - ); - // Compound AND's and not all of them are indexed columns - let refine = Expr::Column(Column::new_unqualified("size")).gt(datafusion_expr::lit(30_i64)); - check( - &index_info, - "aisle = 10 AND color = 'blue' AND size > 30", - Some(IndexedExpression { - scalar_query: Some(ScalarIndexExpr::And(left.clone(), right.clone())), - refine_expr: Some(refine.clone()), - }), - false, - ); - // Compounded OR's where ALL columns are indexed - check( - &index_info, - "aisle = 10 OR color = 'blue'", - Some(IndexedExpression { - scalar_query: Some(ScalarIndexExpr::Or(left.clone(), right.clone())), - refine_expr: None, - }), - false, - ); - // Compounded OR's with one or more unindexed columns - check_no_index(&index_info, "aisle = 10 OR color = 'blue' OR size > 30"); - // AND'd group of OR - check( - &index_info, - "(aisle = 10 OR color = 'blue') AND size > 30", - Some(IndexedExpression { - scalar_query: Some(ScalarIndexExpr::Or(left, right)), - refine_expr: Some(refine), - }), - false, - ); - // Examples of things that are not yet supported but should be supportable someday - - // OR'd group of refined index searches (see IndexedExpression::or for details) - check_no_index( - &index_info, - "(aisle = 10 AND size > 30) OR (color = 'blue' AND size > 20)", - ); - - // Non-normalized arithmetic (can use expression simplification) - check_no_index(&index_info, "aisle + 3 < 10"); - - // Currently we assume that the return of an index search tells us which rows are - // TRUE and all other rows are FALSE. This will need to change but for now it is - // safer to not support the following cases because the return value of non-matched - // rows is NULL and not FALSE. 
- check_no_index(&index_info, "aisle IN (5, 6, NULL)"); - // OR-list with NULL (in future DF version this will be optimized repr of - // small in-list with NULL so let's get ready for it) - check_no_index(&index_info, "aisle = 5 OR aisle = 6 OR NULL"); - check_no_index(&index_info, "aisle IN (5, 6, 7, 8, NULL)"); - check_no_index(&index_info, "aisle = NULL"); - check_no_index(&index_info, "aisle BETWEEN 5 AND NULL"); - check_no_index(&index_info, "aisle BETWEEN NULL AND 10"); - } - - #[tokio::test] - async fn test_not_flips_certainty() { - use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap}; - - // Test that NOT flips certainty for inexact index results - // This tests the implementation in evaluate_impl for Self::Not - - // Helper function that mimics the NOT logic we just fixed - fn apply_not(result: NullableIndexExprResult) -> NullableIndexExprResult { - match result { - NullableIndexExprResult::Exact(mask) => NullableIndexExprResult::Exact(!mask), - NullableIndexExprResult::AtMost(mask) => NullableIndexExprResult::AtLeast(!mask), - NullableIndexExprResult::AtLeast(mask) => NullableIndexExprResult::AtMost(!mask), - } - } - - // AtMost: superset of matches (e.g., bloom filter says "might be in [1,2]") - let at_most = NullableIndexExprResult::AtMost(NullableRowAddrMask::AllowList( - NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()), - )); - // NOT(AtMost) should be AtLeast (definitely NOT in [1,2], might be elsewhere) - assert!(matches!( - apply_not(at_most), - NullableIndexExprResult::AtLeast(_) - )); - - // AtLeast: subset of matches (e.g., definitely in [1,2], might be more) - let at_least = NullableIndexExprResult::AtLeast(NullableRowAddrMask::AllowList( - NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()), - )); - // NOT(AtLeast) should be AtMost (might NOT be in [1,2], definitely elsewhere) - assert!(matches!( - apply_not(at_least), - NullableIndexExprResult::AtMost(_) - )); - - 
// Exact should stay Exact - let exact = NullableIndexExprResult::Exact(NullableRowAddrMask::AllowList( - NullableRowAddrSet::new(RowAddrTreeMap::from_iter(&[1, 2]), RowAddrTreeMap::new()), - )); - assert!(matches!( - apply_not(exact), - NullableIndexExprResult::Exact(_) - )); - } - - #[tokio::test] - async fn test_and_or_preserve_certainty() { - use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap}; - - // Test that AND/OR correctly propagate certainty - let make_at_most = || { - NullableIndexExprResult::AtMost(NullableRowAddrMask::AllowList( - NullableRowAddrSet::new( - RowAddrTreeMap::from_iter(&[1, 2, 3]), - RowAddrTreeMap::new(), - ), - )) - }; - - let make_at_least = || { - NullableIndexExprResult::AtLeast(NullableRowAddrMask::AllowList( - NullableRowAddrSet::new( - RowAddrTreeMap::from_iter(&[2, 3, 4]), - RowAddrTreeMap::new(), - ), - )) - }; - - let make_exact = || { - NullableIndexExprResult::Exact(NullableRowAddrMask::AllowList(NullableRowAddrSet::new( - RowAddrTreeMap::from_iter(&[1, 2]), - RowAddrTreeMap::new(), - ))) - }; - - // AtMost & AtMost → AtMost - assert!(matches!( - make_at_most() & make_at_most(), - NullableIndexExprResult::AtMost(_) - )); - - // AtLeast & AtLeast → AtLeast - assert!(matches!( - make_at_least() & make_at_least(), - NullableIndexExprResult::AtLeast(_) - )); - - // AtMost & AtLeast → AtMost (superset remains superset) - assert!(matches!( - make_at_most() & make_at_least(), - NullableIndexExprResult::AtMost(_) - )); - - // AtMost | AtMost → AtMost - assert!(matches!( - make_at_most() | make_at_most(), - NullableIndexExprResult::AtMost(_) - )); - - // AtLeast | AtLeast → AtLeast - assert!(matches!( - make_at_least() | make_at_least(), - NullableIndexExprResult::AtLeast(_) - )); - - // AtMost | AtLeast → AtLeast (subset coverage guaranteed) - assert!(matches!( - make_at_most() | make_at_least(), - NullableIndexExprResult::AtLeast(_) - )); - - // Exact & AtMost → AtMost - assert!(matches!( - make_exact() & 
make_at_most(), - NullableIndexExprResult::AtMost(_) - )); - - // Exact | AtLeast → AtLeast - assert!(matches!( - make_exact() | make_at_least(), - NullableIndexExprResult::AtLeast(_) - )); - } - - #[test] - fn test_extract_like_leading_prefix() { - // Simple prefix patterns (no recheck needed) - assert_eq!( - extract_like_leading_prefix("foo%", None), - Some(("foo".to_string(), false)) - ); - assert_eq!( - extract_like_leading_prefix("abc%", None), - Some(("abc".to_string(), false)) - ); - - // Patterns with wildcards in the middle (need recheck) - assert_eq!( - extract_like_leading_prefix("foo%bar%", None), - Some(("foo".to_string(), true)) - ); - assert_eq!( - extract_like_leading_prefix("foo_bar%", None), - Some(("foo".to_string(), true)) - ); - assert_eq!( - extract_like_leading_prefix("foo%bar", None), - Some(("foo".to_string(), true)) - ); - assert_eq!( - extract_like_leading_prefix("foo_", None), - Some(("foo".to_string(), true)) - ); - - // Not prefix patterns (starts with wildcard) - assert_eq!(extract_like_leading_prefix("%foo", None), None); - assert_eq!(extract_like_leading_prefix("_foo%", None), None); - assert_eq!(extract_like_leading_prefix("%", None), None); - - // No wildcard at all (should use equality) - assert_eq!(extract_like_leading_prefix("foo", None), None); - - // With escape character - assert_eq!( - extract_like_leading_prefix(r"foo\%bar%", Some('\\')), - Some(("foo%bar".to_string(), false)) - ); - assert_eq!( - extract_like_leading_prefix(r"foo\_bar%", Some('\\')), - Some(("foo_bar".to_string(), false)) - ); - assert_eq!( - extract_like_leading_prefix(r"foo\\bar%", Some('\\')), - Some(("foo\\bar".to_string(), false)) - ); - - // Escaped trailing % is not a wildcard (no wildcards) - assert_eq!(extract_like_leading_prefix(r"foo\%", Some('\\')), None); - - // With backslash as default escape (for DataFusion starts_with compatibility): - // "foo\%" means escaped %, no wildcard -> None (should use equality) - 
assert_eq!(extract_like_leading_prefix(r"foo\%", None), None); - // "foo\bar%" - \b is not a valid escape sequence, so \ and b are literals, % is wildcard - assert_eq!( - extract_like_leading_prefix(r"foo\bar%", None), - Some(("foo\\bar".to_string(), false)) - ); - - // Empty pattern - assert_eq!(extract_like_leading_prefix("", None), None); - - // Mixed escaped and unescaped - assert_eq!( - extract_like_leading_prefix(r"foo\%bar%baz%", Some('\\')), - Some(("foo%bar".to_string(), true)) - ); - } - - #[test] - fn test_like_expression_parsing() { - // Test that LIKE expressions are parsed correctly with refine_expr for complex patterns - - let index_info = MockIndexInfoProvider::new(vec![( - "color", - ColInfo::new( - DataType::Utf8, - Box::new(SargableQueryParser::new( - "color_idx".to_string(), - "BTree".to_string(), - false, - )), - ), - )]); - - // Simple prefix pattern: LIKE 'foo%' -> LikePrefix("foo"), no refine_expr - let schema = Schema::new(vec![Field::new("color", DataType::Utf8, false)]); - let df_schema: DFSchema = schema.try_into().unwrap(); - let ctx = get_session_context(&LanceExecutionOptions::default()); - let state = ctx.state(); - - let expr = state - .create_logical_expr("color LIKE 'foo%'", &df_schema) - .unwrap(); - let result = apply_scalar_indices(expr, &index_info).unwrap(); - - assert!(result.scalar_query.is_some(), "Should have scalar_query"); - assert!( - result.refine_expr.is_none(), - "Simple prefix should not need refine_expr" - ); - - // Extract the query and verify it's LikePrefix - if let Some(ScalarIndexExpr::Query(search)) = &result.scalar_query { - let query = search.query.as_any().downcast_ref::(); - assert!(query.is_some(), "Query should be SargableQuery"); - match query.unwrap() { - SargableQuery::LikePrefix(prefix) => { - assert_eq!(prefix, &ScalarValue::Utf8(Some("foo".to_string()))); - } - _ => panic!("Expected LikePrefix query"), - } - } else { - panic!("Expected Query variant"); - } - - // Complex pattern: LIKE 'foo%bar%' 
-> LikePrefix("foo"), with refine_expr - let expr = state - .create_logical_expr("color LIKE 'foo%bar%'", &df_schema) - .unwrap(); - let result = apply_scalar_indices(expr, &index_info).unwrap(); - - assert!(result.scalar_query.is_some(), "Should have scalar_query"); - assert!( - result.refine_expr.is_some(), - "Complex pattern should have refine_expr" - ); - - // Verify the query is still LikePrefix("foo") - if let Some(ScalarIndexExpr::Query(search)) = &result.scalar_query { - let query = search.query.as_any().downcast_ref::(); - assert!(query.is_some(), "Query should be SargableQuery"); - match query.unwrap() { - SargableQuery::LikePrefix(prefix) => { - assert_eq!(prefix, &ScalarValue::Utf8(Some("foo".to_string()))); - } - _ => panic!("Expected LikePrefix query"), - } - } - - // Verify the refine_expr is the original LIKE expression - let refine = result.refine_expr.unwrap(); - match refine { - Expr::Like(like) => { - assert!(!like.negated); - assert!(!like.case_insensitive); - if let Expr::Literal(ScalarValue::Utf8(Some(pattern)), _) = like.pattern.as_ref() { - assert_eq!(pattern, "foo%bar%"); - } else { - panic!("Expected Utf8 literal pattern"); - } - } - _ => panic!("Expected Like expression in refine_expr"), - } - - // Pattern starting with wildcard: LIKE '%foo' -> no index, only refine - let expr = state - .create_logical_expr("color LIKE '%foo'", &df_schema) - .unwrap(); - let result = apply_scalar_indices(expr, &index_info).unwrap(); - - assert!( - result.scalar_query.is_none(), - "Pattern starting with wildcard should not use index" - ); - assert!(result.refine_expr.is_some(), "Should fall back to refine"); - } - - #[test] - fn test_starts_with_with_underscore_after_optimization() { - // Test that starts_with with underscore in prefix works correctly after DataFusion optimization - // DataFusion simplifies starts_with(col, 'test_ns$') to col LIKE 'test_ns$%' - // The underscore in the prefix should NOT be treated as a wildcard! 
- let index_info = MockIndexInfoProvider::new(vec![( - "object_id", - ColInfo::new( - DataType::Utf8, - Box::new(SargableQueryParser::new( - "object_id_idx".to_string(), - "BTree".to_string(), - false, - )), - ), - )]); - - let schema = Schema::new(vec![Field::new("object_id", DataType::Utf8, false)]); - let df_schema: DFSchema = schema.try_into().unwrap(); - let ctx = get_session_context(&LanceExecutionOptions::default()); - let state = ctx.state(); - - // Create the expression with starts_with containing underscore - let expr = state - .create_logical_expr("starts_with(object_id, 'test_ns$')", &df_schema) - .unwrap(); - - // Apply DataFusion simplification (this may convert starts_with to LIKE) - let simplify_context = SimplifyContext::default() - .with_schema(Arc::new(df_schema)) - .with_query_execution_start_time(Some(Utc::now())); - let simplifier = - datafusion::optimizer::simplify_expressions::ExprSimplifier::new(simplify_context); - let simplified_expr = simplifier.simplify(expr).unwrap(); - - // Apply scalar indices - let result = apply_scalar_indices(simplified_expr, &index_info).unwrap(); - - // The prefix should be "test_ns$", NOT "test" - // This test documents the current (potentially broken) behavior - if let Some(ScalarIndexExpr::Query(search)) = &result.scalar_query { - let query = search - .query - .as_any() - .downcast_ref::() - .unwrap(); - match query { - SargableQuery::LikePrefix(prefix) => { - let prefix_str = match prefix { - ScalarValue::Utf8(Some(s)) => s.clone(), - _ => panic!("Expected Utf8 prefix"), - }; - // Verify the prefix is correctly extracted with underscore as literal - assert_eq!( - prefix_str, "test_ns$", - "Prefix should be 'test_ns$', not 'test' (underscore should not be a wildcard)" - ); - } - _ => panic!("Expected LikePrefix query"), - } - } else { - // If no scalar query, it means the pattern was not recognized - panic!("Expected scalar_query to be present"); - } - } - - #[test] - fn test_starts_with_to_like_conversion() 
{ - // Test that starts_with(col, 'prefix') is converted to LikePrefix query - let index_info = MockIndexInfoProvider::new(vec![( - "color", - ColInfo::new( - DataType::Utf8, - Box::new(SargableQueryParser::new( - "color_idx".to_string(), - "BTree".to_string(), - false, - )), - ), - )]); - - let schema = Schema::new(vec![Field::new("color", DataType::Utf8, false)]); - let df_schema: DFSchema = schema.try_into().unwrap(); - let ctx = get_session_context(&LanceExecutionOptions::default()); - let state = ctx.state(); - - // starts_with(color, 'foo') should be converted to LikePrefix("foo") - let expr = state - .create_logical_expr("starts_with(color, 'foo')", &df_schema) - .unwrap(); - let result = apply_scalar_indices(expr, &index_info).unwrap(); - - assert!( - result.scalar_query.is_some(), - "starts_with should use index" - ); - assert!( - result.refine_expr.is_none(), - "Pure prefix starts_with should not need refine_expr" - ); - - // Extract the query and verify it's LikePrefix - if let Some(ScalarIndexExpr::Query(search)) = &result.scalar_query { - let query = search.query.as_any().downcast_ref::(); - assert!(query.is_some(), "Query should be SargableQuery"); - match query.unwrap() { - SargableQuery::LikePrefix(prefix) => { - assert_eq!(prefix, &ScalarValue::Utf8(Some("foo".to_string()))); - } - _ => panic!("Expected LikePrefix query"), - } - } else { - panic!("Expected Query variant"); - } - - // Both starts_with and LIKE 'prefix%' should produce the same LikePrefix query - let like_expr = state - .create_logical_expr("color LIKE 'foo%'", &df_schema) - .unwrap(); - let like_result = apply_scalar_indices(like_expr, &index_info).unwrap(); - - // Compare the queries - both should be LikePrefix("foo") - if let ( - Some(ScalarIndexExpr::Query(starts_with_search)), - Some(ScalarIndexExpr::Query(like_search)), - ) = (&result.scalar_query, &like_result.scalar_query) - { - let sw_query = starts_with_search - .query - .as_any() - .downcast_ref::() - .unwrap(); - let 
like_query = like_search - .query - .as_any() - .downcast_ref::() - .unwrap(); - assert_eq!( - sw_query, like_query, - "starts_with and LIKE 'prefix%' should produce identical queries" - ); - } - } -} +/// The contents of this file moved to `src/expression/scalar.rs`; they are re-exported here so existing `scalar::expression` paths keep resolving. +pub use crate::expression::scalar::*; diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 52e09864c14..b7b209354e2 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -14,6 +14,7 @@ use std::{ time::Instant, }; +use crate::expression::aggregate::AnyAggregateQuery; use crate::metrics::NoOpMetricsCollector; use crate::prefilter::NoFilter; use crate::scalar::registry::{TrainingCriteria, TrainingOrdering}; @@ -39,6 +40,7 @@ use fst::{Automaton, IntoStreamer, Streamer}; use futures::{FutureExt, Stream, StreamExt, TryStreamExt, stream}; use itertools::Itertools; use lance_arrow::{RecordBatchExt, iter_str_array}; +use lance_arrow_scalar::ArrowScalar; use lance_core::cache::{CacheCodec, CacheKey, LanceCache, WeakLanceCache}; use lance_core::error::{DataFusionResult, LanceOptionExt}; use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap}; @@ -951,6 +953,18 @@ impl ScalarIndex for InvertedIndex { } } + async fn calculate_aggregate( + &self, + query: &dyn AnyAggregateQuery, + _filter: Option, + _total_rows: u64, + _metrics: &dyn MetricsCollector, + ) -> Result<ArrowScalar> { + Err(Error::invalid_input(format!( + "this index cannot accelerate the aggregate {query:?}" + ))) + } + fn can_remap(&self) -> bool { true } diff --git a/rust/lance-index/src/scalar/json.rs b/rust/lance-index/src/scalar/json.rs index 86c1204e174..944cb5e5d57 100644 --- a/rust/lance-index/src/scalar/json.rs +++ b/rust/lance-index/src/scalar/json.rs @@ -23,6 +23,7 @@ use datafusion_physical_expr::{ }; use deepsize::DeepSizeOf; use futures::StreamExt; +use lance_arrow_scalar::ArrowScalar; use
lance_datafusion::exec::{LanceExecutionOptions, OneShotExec, get_session_context}; use lance_datafusion::udf::json::JsonbType; use prost::Message; @@ -33,6 +34,7 @@ use lance_core::{Error, ROW_ID, Result, cache::LanceCache, error::LanceOptionExt use crate::{ Index, IndexType, + expression::aggregate::AnyAggregateQuery, frag_reuse::FragReuseIndex, metrics::MetricsCollector, registry::IndexPluginRegistry, @@ -112,6 +114,18 @@ impl ScalarIndex for JsonIndex { .await } + async fn calculate_aggregate( + &self, + query: &dyn AnyAggregateQuery, + _filter: Option, + _total_rows: u64, + _metrics: &dyn MetricsCollector, + ) -> Result<ArrowScalar> { + Err(Error::invalid_input(format!( + "this index cannot accelerate the aggregate {query:?}" + ))) + } + fn can_remap(&self) -> bool { self.target_index.can_remap() } diff --git a/rust/lance-index/src/scalar/label_list.rs b/rust/lance-index/src/scalar/label_list.rs index eae0f8d6054..40d8a514bf7 100644 --- a/rust/lance-index/src/scalar/label_list.rs +++ b/rust/lance-index/src/scalar/label_list.rs @@ -19,6 +19,7 @@ use datafusion::physical_plan::{SendableRecordBatchStream, stream::RecordBatchSt use datafusion_common::ScalarValue; use deepsize::DeepSizeOf; use futures::{StreamExt, TryStream, TryStreamExt, stream::BoxStream}; +use lance_arrow_scalar::ArrowScalar; use lance_core::cache::LanceCache; use lance_core::error::LanceOptionExt; use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap, RowSetOps}; @@ -29,7 +30,6 @@ use tracing::instrument; use super::{AnyQuery, IndexStore, LabelListQuery, ScalarIndex, bitmap::BitmapIndex}; use super::{BuiltinIndexType, SargableQuery, ScalarIndexParams}; use super::{MetricsCollector, SearchResult}; -use crate::frag_reuse::FragReuseIndex; use crate::pbold; use crate::scalar::bitmap::BitmapIndexPlugin; use crate::scalar::expression::{LabelListQueryParser, ScalarQueryParser}; @@ -39,6 +39,7 @@ use crate::scalar::registry::{ }; use crate::scalar::{CreatedIndex, UpdateCriteria}; use crate::{Index,
IndexType}; +use crate::{expression::aggregate::AnyAggregateQuery, frag_reuse::FragReuseIndex}; pub const BITMAP_LOOKUP_NAME: &str = "bitmap_page_lookup.lance"; pub const LABEL_LIST_NULLS_METADATA_KEY: &str = "lance:label_list_nulls"; @@ -207,6 +208,18 @@ impl ScalarIndex for LabelListIndex { Ok(SearchResult::Exact(row_ids)) } + async fn calculate_aggregate( + &self, + query: &dyn AnyAggregateQuery, + _filter: Option, + _total_rows: u64, + _metrics: &dyn MetricsCollector, + ) -> Result<ArrowScalar> { + Err(Error::invalid_input(format!( + "this index cannot accelerate the aggregate {query:?}" + ))) + } + fn can_remap(&self) -> bool { true } diff --git a/rust/lance-index/src/scalar/ngram.rs b/rust/lance-index/src/scalar/ngram.rs index df526bebcc6..461b2a9900a 100644 --- a/rust/lance-index/src/scalar/ngram.rs +++ b/rust/lance-index/src/scalar/ngram.rs @@ -12,6 +12,7 @@ use super::{ AnyQuery, BuiltinIndexType, IndexReader, IndexStore, IndexWriter, MetricsCollector, ScalarIndex, ScalarIndexParams, SearchResult, TextQuery, }; +use crate::expression::aggregate::AnyAggregateQuery; use crate::frag_reuse::FragReuseIndex; use crate::metrics::NoOpMetricsCollector; use crate::pbold; @@ -32,6 +33,7 @@ use datafusion::execution::SendableRecordBatchStream; use deepsize::DeepSizeOf; use futures::{FutureExt, Stream, StreamExt, TryStreamExt, stream}; use lance_arrow::iter_str_array; +use lance_arrow_scalar::ArrowScalar; use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache}; use lance_core::error::LanceOptionExt; use lance_core::utils::address::RowAddress; @@ -479,6 +481,18 @@ impl ScalarIndex for NGramIndex { } } + async fn calculate_aggregate( + &self, + query: &dyn AnyAggregateQuery, + _filter: Option, + _total_rows: u64, + _metrics: &dyn MetricsCollector, + ) -> Result<ArrowScalar> { + Err(Error::invalid_input(format!( + "this index cannot accelerate the aggregate {query:?}" + ))) + } + fn can_remap(&self) -> bool { true } diff --git a/rust/lance-index/src/scalar/rtree.rs
b/rust/lance-index/src/scalar/rtree.rs index adc365e53d2..7fdbf1c039f 100644 --- a/rust/lance-index/src/scalar/rtree.rs +++ b/rust/lance-index/src/scalar/rtree.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use crate::expression::aggregate::AnyAggregateQuery; use crate::frag_reuse::FragReuseIndex; use crate::metrics::{MetricsCollector, NoOpMetricsCollector}; use crate::scalar::expression::{GeoQueryParser, ScalarQueryParser}; @@ -32,6 +33,7 @@ use geoarrow_array::builder::RectBuilder; use geoarrow_array::{GeoArrowArray, GeoArrowArrayAccessor, IntoArrow}; use geoarrow_schema::{Dimension, RectType}; use lance_arrow::RecordBatchExt; +use lance_arrow_scalar::ArrowScalar; use lance_core::cache::{CacheKey, LanceCache, WeakLanceCache}; use lance_core::utils::address::RowAddress; use lance_core::utils::mask::{NullableRowAddrSet, RowAddrTreeMap, RowSetOps}; @@ -545,6 +547,18 @@ impl ScalarIndex for RTreeIndex { } } + async fn calculate_aggregate( + &self, + query: &dyn AnyAggregateQuery, + _filter: Option, + _total_rows: u64, + _metrics: &dyn MetricsCollector, + ) -> Result<ArrowScalar> { + Err(Error::invalid_input(format!( + "this index cannot accelerate the aggregate {query:?}" + ))) + } + fn can_remap(&self) -> bool { false } diff --git a/rust/lance-index/src/scalar/zonemap.rs b/rust/lance-index/src/scalar/zonemap.rs index 45ded3b0db5..df678268a56 100644 --- a/rust/lance-index/src/scalar/zonemap.rs +++ b/rust/lance-index/src/scalar/zonemap.rs @@ -13,6 +13,7 @@ //! //!
use crate::Any; +use crate::expression::aggregate::AnyAggregateQuery; use crate::pbold; use crate::scalar::expression::{SargableQueryParser, ScalarQueryParser}; use crate::scalar::registry::{ @@ -24,6 +25,7 @@ use crate::scalar::{ }; use datafusion::functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; use datafusion_expr::Accumulator; +use lance_arrow_scalar::ArrowScalar; use lance_core::cache::{LanceCache, WeakLanceCache}; use serde::{Deserialize, Serialize}; use std::sync::LazyLock; @@ -564,6 +566,18 @@ impl ScalarIndex for ZoneMapIndex { }) } + async fn calculate_aggregate( + &self, + query: &dyn AnyAggregateQuery, + _filter: Option, + _total_rows: u64, + _metrics: &dyn MetricsCollector, + ) -> Result<ArrowScalar> { + Err(Error::invalid_input(format!( + "this index cannot accelerate the aggregate {query:?}" + ))) + } + fn can_remap(&self) -> bool { false } diff --git a/rust/lance/src/index/scalar_logical.rs b/rust/lance/src/index/scalar_logical.rs index 2cc9dd5cbe4..e1f512f53c2 100644 --- a/rust/lance/src/index/scalar_logical.rs +++ b/rust/lance/src/index/scalar_logical.rs @@ -11,8 +11,11 @@ use deepsize::{Context, DeepSizeOf}; use futures::future::try_join_all; use lance_core::utils::mask::NullableRowAddrSet; use lance_core::{Error, Result}; +use lance_index::expression::aggregate::AnyAggregateQuery; use lance_index::metrics::MetricsCollector; -use lance_index::scalar::{AnyQuery, CreatedIndex, ScalarIndex, SearchResult, UpdateCriteria}; +use lance_index::scalar::{ + AnyQuery, ArrowScalar, CreatedIndex, ScalarIndex, SearchResult, UpdateCriteria, +}; use lance_index::{Index, IndexType}; use lance_table::format::IndexMetadata; use roaring::RoaringBitmap; @@ -132,6 +135,19 @@ impl ScalarIndex for LogicalScalarIndex { combine_search_results(results) } + async fn calculate_aggregate( + &self, + query: &dyn AnyAggregateQuery, + _filter: Option, + _total_rows: u64, + _metrics: &dyn MetricsCollector, + ) -> Result<ArrowScalar> { + Err(Error::invalid_input(format!( +
"LogicalScalarIndex '{}' cannot accelerate the aggregate {query:?}", + self.name + ))) + } + fn can_remap(&self) -> bool { false } diff --git a/rust/lance/src/io/exec.rs b/rust/lance/src/io/exec.rs index 0b93e3c2834..badc301ce59 100644 --- a/rust/lance/src/io/exec.rs +++ b/rust/lance/src/io/exec.rs @@ -5,6 +5,7 @@ //! //! WARNING: Internal API with no stability guarantees. +pub mod aggregate_index; #[cfg(feature = "substrait")] pub mod ann_proto; mod filter; diff --git a/rust/lance/src/io/exec/aggregate_index.rs b/rust/lance/src/io/exec/aggregate_index.rs new file mode 100644 index 00000000000..7da9ff42c7b --- /dev/null +++ b/rust/lance/src/io/exec/aggregate_index.rs @@ -0,0 +1,776 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Execute-time half of aggregate pushdown. +//! +//! [`AggregateIndexSearchExec`] computes partial aggregate state for one or +//! more aggregates by probing scalar indices, without scanning column data. +//! Its output schema matches what `AggregateExec(AggregateMode::Partial)` +//! would produce for the same aggregates, so a downstream `AggregateExec` +//! in `Final`/`FinalPartitioned` mode can combine us unchanged. 
+ +use std::collections::HashMap; +use std::sync::Arc; + +use arrow_array::{Array, BinaryArray, Int64Array, RecordBatch}; +use arrow_schema::{Schema, SchemaRef}; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + execution_plan::{Boundedness, EmissionType}, + metrics::{ExecutionPlanMetricsSet, MetricsSet}, +}; +use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use futures::{StreamExt, TryStreamExt}; +use lance_core::utils::mask::{NullableRowAddrSet, RowAddrMask, RowAddrSelection, RowAddrTreeMap}; +use lance_core::{Error, Result}; +use lance_datafusion::utils::{ExecutionPlanMetricsSetExt, SCALAR_INDEX_SEARCH_TIME_METRIC}; +use lance_index::expression::aggregate::{AggregateIndexSearch, CountQuery}; +use lance_index::scalar::{ScalarIndex, SearchResult}; +use lance_table::format::Fragment; +use roaring::RoaringBitmap; +use tracing::instrument; + +use super::utils::{IndexMetrics, InstrumentedRecordBatchStreamAdapter}; +use crate::Dataset; +use crate::index::DatasetIndexExt; +use crate::index::prefilter::DatasetPreFilter; +use crate::index::scalar_logical::{open_named_scalar_index, scalar_index_fragment_bitmap}; + +/// An execution node that answers a set of aggregates from scalar indices. +/// +/// The node returns a single record batch whose schema is the concatenation +/// of `state_fields()` for each aggregate in `aggregate_funcs`. +/// +/// It optionally has a single child [`super::scalar_index::ScalarIndexExec`] +/// whose output is used as a prefilter for each aggregate. 
+#[derive(Debug)] +pub struct AggregateIndexSearchExec { + dataset: Arc<Dataset>, + aggregates: Vec<Arc<AggregateIndexSearch>>, + aggregate_funcs: Vec<Arc<AggregateFunctionExpr>>, + prefilter_input: Option<Arc<dyn ExecutionPlan>>, + schema: SchemaRef, + properties: Arc<PlanProperties>, + metrics: ExecutionPlanMetricsSet, +} + +impl DisplayAs for AggregateIndexSearchExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let names = self + .aggregates + .iter() + .map(|agg| agg.to_string()) + .collect::<Vec<_>>() + .join(","); + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "AggregateIndexSearch: aggs=[{}]", names) + } + DisplayFormatType::TreeRender => { + write!(f, "AggregateIndexSearch\naggs=[{}]", names) + } + } + } +} + +impl AggregateIndexSearchExec { + /// Build a new node. + /// + /// `aggregates` and `aggregate_funcs` must have the same length: each + /// aggregate query is paired with its DataFusion partial-state spec. + /// `prefilter_input`, if present, must produce a single batch in the + /// scalar-index result schema; that mask is intersected with the + /// aggregate's natural fragment coverage and the active deletion mask. + pub fn try_new( + dataset: Arc<Dataset>, + aggregates: Vec<Arc<AggregateIndexSearch>>, + aggregate_funcs: Vec<Arc<AggregateFunctionExpr>>, + prefilter_input: Option<Arc<dyn ExecutionPlan>>, + ) -> Result<Self> { + if aggregates.len() != aggregate_funcs.len() { + return Err(Error::invalid_input(format!( + "AggregateIndexSearchExec: aggregates ({}) and aggregate_funcs ({}) length mismatch", + aggregates.len(), + aggregate_funcs.len() + ))); + } + + for agg in &aggregates { + if agg.index_name.is_none() { + // The only aggregate we can answer without an associated index + // is a non-distinct COUNT. 
+ let count = agg + .query + .as_any() + .downcast_ref::() + .ok_or_else(|| { + Error::invalid_input(format!( + "AggregateIndexSearchExec: aggregate {} has no associated index but is not a count", + agg + )) + })?; + if count.is_distinct() { + return Err(Error::invalid_input(format!( + "AggregateIndexSearchExec: aggregate {} has no associated index but is a distinct count", + agg + ))); + } + } + } + + let state_fields = aggregate_funcs + .iter() + .map(|agg| agg.state_fields()) + .collect::>>() + .map_err(|e| Error::invalid_input(e.to_string()))? + .into_iter() + .flatten() + .collect::>(); + let state_fields_owned: Vec = + state_fields.iter().map(|f| f.as_ref().clone()).collect(); + let schema: SchemaRef = Arc::new(Schema::new(state_fields_owned)); + + let properties = Arc::new(PlanProperties::new( + EquivalenceProperties::new(schema.clone()), + Partitioning::RoundRobinBatch(1), + EmissionType::Incremental, + Boundedness::Bounded, + )); + + Ok(Self { + dataset, + aggregates, + aggregate_funcs, + prefilter_input, + schema, + properties, + metrics: ExecutionPlanMetricsSet::new(), + }) + } + + /// Drain `prefilter_input` (a [`super::scalar_index::ScalarIndexExec`]) to + /// produce the row-address mask it serialized. + async fn load_prefilter( + prefilter_input: Arc, + context: Arc, + ) -> Result { + let mut stream = prefilter_input.execute(0, context).map_err(Error::from)?; + let batch = stream + .try_next() + .await + .map_err(Error::from)? + .ok_or_else(|| { + Error::internal( + "AggregateIndexSearchExec: prefilter input produced no batches".to_string(), + ) + })?; + // Drain any remaining batches so the upstream sees a clean shutdown. 
+ while stream.try_next().await.map_err(Error::from)?.is_some() {} + + let result_col = batch + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + Error::internal(format!( + "AggregateIndexSearchExec: prefilter result column has type {:?}, expected Binary", + batch.column(0).data_type() + )) + })?; + RowAddrMask::from_arrow(result_col) + } + + /// Look up the column name an index lives on by inspecting manifest metadata. + async fn column_for_index(dataset: &Dataset, index_name: &str) -> Result { + let indices = dataset.load_indices_by_name(index_name).await?; + let index = indices.into_iter().next().ok_or_else(|| { + Error::internal(format!( + "AggregateIndexSearchExec: no index named '{}' found", + index_name + )) + })?; + let field_id = *index.fields.first().ok_or_else(|| { + Error::internal(format!( + "AggregateIndexSearchExec: index '{}' has no field bindings", + index_name + )) + })?; + let field = dataset.schema().field_by_id(field_id).ok_or_else(|| { + Error::internal(format!( + "AggregateIndexSearchExec: index '{}' references unknown field id {}", + index_name, field_id + )) + })?; + Ok(field.name.clone()) + } + + /// Load every backing index referenced by the aggregates and the fragment + /// bitmap each one covers. + /// + /// The returned vectors are aligned with `aggregates`: aggregates without + /// an `index_name` produce `None` in `indices` and contribute no fragment + /// bitmap to the intersection. + async fn load_indices( + dataset: Arc, + aggregates: Vec>, + index_metrics: IndexMetrics, + ) -> Result<(Vec>>, Option)> { + let mut indices = Vec::with_capacity(aggregates.len()); + let mut fragments_intersection: Option = None; + for agg in &aggregates { + match &agg.index_name { + None => indices.push(None), + Some(index_name) => { + let column = Self::column_for_index(&dataset, index_name).await?; + let bitmap = scalar_index_fragment_bitmap(&dataset, &column, index_name) + .await? 
+ .ok_or_else(|| { + Error::internal(format!( + "AggregateIndexSearchExec: index '{}' has no fragment bitmap", + index_name + )) + })?; + fragments_intersection = Some(match fragments_intersection.take() { + None => bitmap, + Some(existing) => existing & bitmap, + }); + let index = + open_named_scalar_index(&dataset, &column, index_name, &index_metrics) + .await?; + indices.push(Some(index)); + } + } + } + Ok((indices, fragments_intersection)) + } + + /// Fold the fragment allow list, the optional prefilter, and the optional + /// deletion mask into a single [`RowAddrMask`] by intersecting them in that order. + /// + /// The result is always an `AllowList` so it can be wrapped in a + /// [`SearchResult::Exact`] for [`ScalarIndex::calculate_aggregate`]. + fn combine_masks( + fragments_allow: RowAddrTreeMap, + prefilter: Option<RowAddrMask>, + deletion_mask: Option<Arc<RowAddrMask>>, + ) -> RowAddrMask { + let base = RowAddrMask::AllowList(fragments_allow); + let after_prefilter = match prefilter { + None => base, + Some(prefilter) => base & prefilter, + }; + match deletion_mask { + None => after_prefilter, + Some(deletion_mask) => after_prefilter & (*deletion_mask).clone(), + } + } + + /// Count the rows selected by `mask`, looking up `Full`-marker fragments + /// in the manifest so we never need to materialize a `RoaringBitmap::full()`. 
+ fn count_from_mask(mask: &RowAddrMask, dataset: &Dataset) -> Result<i64> { + let allow = mask.allow_list().ok_or_else(|| { + Error::internal( + "AggregateIndexSearchExec: combined mask is not an AllowList".to_string(), + ) + })?; + let frag_map: HashMap<u32, &Fragment> = dataset + .fragments() + .iter() + .map(|f| (f.id as u32, f)) + .collect(); + let mut count = 0i64; + for (frag_id, sel) in allow.iter() { + match sel { + RowAddrSelection::Full => { + // The fragment is in the allow list with no deletions + // touching it, so its row count is the physical row count. 
+ let frag = frag_map.get(frag_id).ok_or_else(|| { + Error::internal(format!( + "AggregateIndexSearchExec: fragment {} not found in manifest", + frag_id + )) + })?; + let n = frag.physical_rows.ok_or_else(|| { + Error::internal(format!( + "AggregateIndexSearchExec: physical_rows missing for fragment {}", + frag_id + )) + })?; + count += n as i64; + } + RowAddrSelection::Partial(bitmap) => { + count += bitmap.len() as i64; + } + } + } + Ok(count) + } + + #[instrument(name = "aggregate_index_search", skip_all, level = "debug")] + async fn do_execute( + dataset: Arc, + aggregates: Vec>, + prefilter_input: Option>, + context: Arc, + plan_metrics: ExecutionPlanMetricsSet, + schema: SchemaRef, + ) -> Result { + let index_metrics = IndexMetrics::new(&plan_metrics, 0); + + // Kick off the prefilter load and index loads in parallel. + let prefilter_fut = async { + match prefilter_input { + None => Ok::, Error>(None), + Some(input) => Self::load_prefilter(input, context.clone()).await.map(Some), + } + }; + let indices_fut = async { + let timer = plan_metrics.new_time(SCALAR_INDEX_SEARCH_TIME_METRIC, 0); + let _guard = timer.timer(); + Self::load_indices(dataset.clone(), aggregates.clone(), index_metrics.clone()).await + }; + let (prefilter, (loaded_indices, fragments_intersection)) = + futures::try_join!(prefilter_fut, indices_fut)?; + + // Fall back to all dataset fragments when no aggregate has an index — + // we still need a set of fragments to anchor the deletion mask against. + let fragments_covered = fragments_intersection.unwrap_or_else(|| { + dataset + .fragments() + .iter() + .map(|f| f.id as u32) + .collect::() + }); + + // Build the fragments allow list as concrete `[0..physical_rows)` + // ranges rather than `Full` markers. 
`Full` interacts poorly with + // `BlockList` subtraction — `RowAddrTreeMap::Sub` materializes a + // `RoaringBitmap::full()` (2^32 rows) per fragment when a `Full` entry + // gets a partial block subtracted from it, which inflates counts and + // is expensive. Concrete ranges avoid that path entirely and keep + // `len()` exact at every combine step. + let frag_map: HashMap = dataset + .fragments() + .iter() + .map(|f| (f.id as u32, f)) + .collect(); + let mut fragments_allow = RowAddrTreeMap::new(); + for frag_id in fragments_covered.iter() { + let frag = frag_map.get(&frag_id).ok_or_else(|| { + Error::internal(format!( + "AggregateIndexSearchExec: fragment {} not in manifest", + frag_id + )) + })?; + let physical = frag.physical_rows.ok_or_else(|| { + Error::internal(format!( + "AggregateIndexSearchExec: physical_rows missing for fragment {}", + frag_id + )) + })?; + let mut bitmap = RoaringBitmap::new(); + bitmap.insert_range(0u32..(physical as u32)); + fragments_allow.insert_bitmap(frag_id, bitmap); + } + + // Load the deletion mask for the covered fragments. + let deletion_mask = + match DatasetPreFilter::create_deletion_mask(dataset.clone(), fragments_covered) { + Some(fut) => Some(fut.await?), + None => None, + }; + + // Combine prefilter ∩ fragment-allow − deletion into a single AllowList. + let combined = Self::combine_masks(fragments_allow, prefilter, deletion_mask); + + // Compute partial state, one aggregate at a time. + let total_rows = dataset.count_all_rows().await? 
as u64; + let mut arrays: Vec> = Vec::with_capacity(aggregates.len()); + for (agg, index) in aggregates.iter().zip(loaded_indices.iter()) { + match index { + Some(index) => { + let allow_list = combined.allow_list().cloned().unwrap_or_default(); + let search_result = SearchResult::Exact(NullableRowAddrSet::new( + allow_list, + RowAddrTreeMap::new(), + )); + let scalar = index + .calculate_aggregate( + agg.query.as_ref(), + Some(search_result), + total_rows, + &index_metrics, + ) + .await?; + arrays.push(scalar.as_array().clone()); + } + None => { + // Validated in `try_new`: this can only be non-distinct COUNT. + let count = Self::count_from_mask(&combined, dataset.as_ref())?; + let arr = Arc::new(Int64Array::from(vec![count])) as Arc; + arrays.push(arr); + } + } + } + + Ok(RecordBatch::try_new(schema, arrays)?) + } +} + +impl ExecutionPlan for AggregateIndexSearchExec { + fn name(&self) -> &str { + "AggregateIndexSearchExec" + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn children(&self) -> Vec<&Arc> { + match &self.prefilter_input { + Some(input) => vec![input], + None => vec![], + } + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion::error::Result> { + let prefilter_input = match children.len() { + 0 => None, + 1 => Some(children.into_iter().next().unwrap()), + n => { + return Err(datafusion::error::DataFusionError::Internal(format!( + "AggregateIndexSearchExec accepts 0 or 1 children, got {}", + n + ))); + } + }; + Ok(Arc::new(Self { + dataset: self.dataset.clone(), + aggregates: self.aggregates.clone(), + aggregate_funcs: self.aggregate_funcs.clone(), + prefilter_input, + schema: self.schema.clone(), + properties: self.properties.clone(), + metrics: self.metrics.clone(), + })) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> datafusion::error::Result { + let schema = self.schema.clone(); + let batch_fut = Self::do_execute( + 
self.dataset.clone(), + self.aggregates.clone(), + self.prefilter_input.clone(), + context, + self.metrics.clone(), + schema.clone(), + ); + let stream = futures::stream::iter(vec![batch_fut]) + .then(|fut| async move { fut.await.map_err(|err| err.into()) }) + .boxed(); + Ok(Box::pin(InstrumentedRecordBatchStreamAdapter::new( + schema, + stream, + partition, + &self.metrics, + ))) + } + + fn partition_statistics( + &self, + _partition: Option, + ) -> datafusion::error::Result { + Ok(datafusion::physical_plan::Statistics { + num_rows: datafusion::common::stats::Precision::Exact(1), + ..datafusion::physical_plan::Statistics::new_unknown(&self.schema) + }) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn properties(&self) -> &Arc { + &self.properties + } + + fn supports_limit_pushdown(&self) -> bool { + false + } +} + +#[cfg(test)] +mod tests { + use std::{ops::Bound, sync::Arc}; + + use arrow::datatypes::{Int64Type, UInt64Type}; + use datafusion::common::DFSchema; + use datafusion::execution::TaskContext; + use datafusion::functions_aggregate; + use datafusion::logical_expr::lit; + use datafusion::physical_expr::execution_props::ExecutionProps; + use datafusion::physical_plan::ExecutionPlan; + use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter; + use datafusion::scalar::ScalarValue; + use futures::TryStreamExt; + use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap}; + use lance_core::utils::tempfile::TempStrDir; + use lance_datagen::gen_batch; + use lance_index::IndexType; + use lance_index::expression::aggregate::{AggregateIndexSearch, CountQuery}; + use lance_index::scalar::{ + SargableQuery, ScalarIndexParams, + expression::{ScalarIndexExpr, ScalarIndexSearch}, + }; + + use super::*; + use crate::Dataset; + use crate::index::DatasetIndexExt; + use crate::io::exec::scalar_index::ScalarIndexExec; + use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount}; + + /// Build an 
`AggregateFunctionExpr` matching `COUNT(*)`. + fn count_star_expr(input_schema: &SchemaRef) -> Arc { + let expr = functions_aggregate::count::count(lit(1)); + let df_schema = DFSchema::try_from(input_schema.as_ref().clone()).unwrap(); + let (agg_expr, _filter, _order_by) = create_aggregate_expr_and_maybe_filter( + &expr, + &df_schema, + input_schema.as_ref(), + &ExecutionProps::default(), + ) + .unwrap(); + agg_expr + } + + fn count_search(index_name: Option<&str>) -> Arc { + Arc::new(AggregateIndexSearch { + index_name: index_name.map(str::to_string), + query: Arc::new(CountQuery::basic()), + filter: None, + original_expr: lit(0i64), + }) + } + + struct Fixture { + dataset: Arc, + _tmp: TempStrDir, + } + + /// 4 fragments × 10 rows, ascending `ordered` column with a BTree index. + async fn make_fixture() -> Fixture { + let tmp = TempStrDir::default(); + let mut dataset = gen_batch() + .col("ordered", lance_datagen::array::step::()) + .into_dataset( + tmp.as_str(), + FragmentCount::from(4), + FragmentRowCount::from(10), + ) + .await + .unwrap(); + + dataset + .create_index( + &["ordered"], + IndexType::BTree, + None, + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + + Fixture { + dataset: Arc::new(dataset), + _tmp: tmp, + } + } + + fn input_schema() -> SchemaRef { + Arc::new(Schema::new(vec![arrow_schema::Field::new( + "ordered", + arrow_schema::DataType::UInt64, + false, + )])) + } + + async fn run(plan: AggregateIndexSearchExec) -> i64 { + let stream = plan.execute(0, Arc::new(TaskContext::default())).unwrap(); + let batches: Vec = stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 1); + batches[0] + .column(0) + .as_any() + .downcast_ref::>() + .expect("count partial state should be Int64") + .value(0) + } + + #[tokio::test] + async fn try_new_rejects_length_mismatch() { + let fixture = make_fixture().await; + let schema = input_schema(); + let err = AggregateIndexSearchExec::try_new( + 
fixture.dataset, + vec![count_search(None)], + vec![count_star_expr(&schema), count_star_expr(&schema)], + None, + ) + .unwrap_err(); + assert!(err.to_string().contains("length mismatch"), "{err}"); + } + + #[tokio::test] + async fn try_new_rejects_distinct_count_without_index() { + let fixture = make_fixture().await; + let schema = input_schema(); + let distinct = Arc::new(AggregateIndexSearch { + index_name: None, + query: Arc::new(CountQuery::distinct()), + filter: None, + original_expr: lit(0i64), + }); + let err = AggregateIndexSearchExec::try_new( + fixture.dataset, + vec![distinct], + vec![count_star_expr(&schema)], + None, + ) + .unwrap_err(); + assert!(err.to_string().contains("distinct count"), "{err}"); + } + + #[tokio::test] + async fn count_from_mask_mixes_full_and_partial() { + // Synthesize an AllowList containing one Full-marker fragment and one + // Partial bitmap; verify the Full fragment falls back to physical_rows + // from the manifest and Partial falls back to bitmap.len(). + let fixture = make_fixture().await; + let mut tm = RowAddrTreeMap::new(); + // Fragment 0: full (10 physical rows). + tm.insert_fragment(0); + // Fragment 1: partial with explicit row addrs. 
+ let row_addr_for = |frag_id: u32, offset: u32| ((frag_id as u64) << 32) | offset as u64; + tm.insert(row_addr_for(1, 0)); + tm.insert(row_addr_for(1, 1)); + tm.insert(row_addr_for(1, 2)); + + let mask = RowAddrMask::AllowList(tm); + let count = + AggregateIndexSearchExec::count_from_mask(&mask, fixture.dataset.as_ref()).unwrap(); + assert_eq!(count, 10 + 3); + } + + #[tokio::test] + async fn execute_count_no_prefilter() { + let fixture = make_fixture().await; + let dataset = fixture.dataset.clone(); + let schema = input_schema(); + let plan = AggregateIndexSearchExec::try_new( + dataset.clone(), + vec![count_search(None)], + vec![count_star_expr(&schema)], + None, + ) + .unwrap(); + let count = run(plan).await; + assert_eq!(count, 40); // 4 fragments × 10 rows + } + + #[tokio::test] + async fn execute_count_with_allow_list_prefilter() { + let fixture = make_fixture().await; + let dataset = fixture.dataset.clone(); + let schema = input_schema(); + + // `ordered < 25` matches 25 rows across the four fragments. + let prefilter_expr = ScalarIndexExpr::Query(ScalarIndexSearch { + column: "ordered".to_string(), + index_name: "ordered_idx".to_string(), + index_type: "BTree".to_string(), + query: Arc::new(SargableQuery::Range( + Bound::Unbounded, + Bound::Excluded(ScalarValue::UInt64(Some(25))), + )), + needs_recheck: false, + }); + let prefilter: Arc = + Arc::new(ScalarIndexExec::new(dataset.clone(), prefilter_expr)); + + let plan = AggregateIndexSearchExec::try_new( + dataset.clone(), + vec![count_search(None)], + vec![count_star_expr(&schema)], + Some(prefilter), + ) + .unwrap(); + let count = run(plan).await; + assert_eq!(count, 25); + } + + #[tokio::test] + async fn execute_count_with_block_list_prefilter() { + let fixture = make_fixture().await; + let dataset = fixture.dataset.clone(); + let schema = input_schema(); + + // NOT(ordered < 25) is a block list of those 25 rows — 40 − 25 = 15. 
+ let prefilter_expr = + ScalarIndexExpr::Not(Box::new(ScalarIndexExpr::Query(ScalarIndexSearch { + column: "ordered".to_string(), + index_name: "ordered_idx".to_string(), + index_type: "BTree".to_string(), + query: Arc::new(SargableQuery::Range( + Bound::Unbounded, + Bound::Excluded(ScalarValue::UInt64(Some(25))), + )), + needs_recheck: false, + }))); + let prefilter: Arc = + Arc::new(ScalarIndexExec::new(dataset.clone(), prefilter_expr)); + + let plan = AggregateIndexSearchExec::try_new( + dataset.clone(), + vec![count_search(None)], + vec![count_star_expr(&schema)], + Some(prefilter), + ) + .unwrap(); + let count = run(plan).await; + assert_eq!(count, 15); + } + + #[tokio::test] + async fn execute_count_respects_deletions() { + let fixture = make_fixture().await; + let mut dataset = (*fixture.dataset).clone(); + // Delete the first ten rows of the dataset (which live in fragment 0). + dataset.delete("ordered < 10").await.unwrap(); + let dataset = Arc::new(dataset); + + let schema = input_schema(); + let plan = AggregateIndexSearchExec::try_new( + dataset.clone(), + vec![count_search(None)], + vec![count_star_expr(&schema)], + None, + ) + .unwrap(); + let count = run(plan).await; + assert_eq!(count, 30); + } +} From 42593bd909eec926f5e4906728a3082587a023d3 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Mon, 18 May 2026 21:15:28 +0000 Subject: [PATCH 2/4] feat(index): wire aggregate pushdown into the physical optimizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds AggregateIndexPushdown — a PhysicalOptimizerRule that walks the plan top-down and rewrites COUNT-shaped aggregates into AggregateIndexSearchExec so they're answered from index metadata + the deletion mask + an optional scalar-index prefilter, without scanning column data. 
Recognized shape: AggregateExec(Single, aggs=[COUNT(*)], group_by=[]) └── FilteredReadExec { no refine_filter, full_filter only when index_input is present, no scan range, no with_deleted_rows, no fragment subset } Rewritten to: AggregateExec(Final, aggs=[COUNT(*)], group_by=[]) └── AggregateIndexSearchExec { prefilter_input = index_input } The outer AggregateExec is dropped to AggregateMode::Final because AggregateIndexSearchExec emits one row of partial state. is_count_star is intentionally conservative: function name == "count", not distinct, single non-null Literal argument. Anything else (COUNT(col) with a column ref, DISTINCT, FILTER (WHERE), GROUP BY, residual filter, scan range, with_deleted_rows, fragment subset) leaves the existing scan path untouched. Registered first in get_physical_optimizer so generic rules don't see the rewritten subtree. Tests (4, driving the rule end-to-end through Scanner::create_plan): - rule_fires_on_unfiltered_count_star - rule_fires_when_filter_fully_indexed (BTree filter pushdown) - rule_skips_when_filter_needs_refine (unindexed column residual) - rule_skips_count_with_group_by Existing count_rows tests (3) and aggregate_index exec tests (7) all continue to pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- rust/lance/src/io/exec.rs | 1 + .../src/io/exec/aggregate_index_pushdown.rs | 381 ++++++++++++++++++ rust/lance/src/io/exec/optimizer.rs | 4 + 3 files changed, 386 insertions(+) create mode 100644 rust/lance/src/io/exec/aggregate_index_pushdown.rs diff --git a/rust/lance/src/io/exec.rs b/rust/lance/src/io/exec.rs index badc301ce59..a00a1b459b8 100644 --- a/rust/lance/src/io/exec.rs +++ b/rust/lance/src/io/exec.rs @@ -6,6 +6,7 @@ //! WARNING: Internal API with no stability guarantees. 
pub mod aggregate_index; +mod aggregate_index_pushdown; #[cfg(feature = "substrait")] pub mod ann_proto; mod filter; diff --git a/rust/lance/src/io/exec/aggregate_index_pushdown.rs b/rust/lance/src/io/exec/aggregate_index_pushdown.rs new file mode 100644 index 00000000000..cb0812ef87c --- /dev/null +++ b/rust/lance/src/io/exec/aggregate_index_pushdown.rs @@ -0,0 +1,381 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Physical optimizer rule that rewrites `COUNT`-shaped aggregates into +//! [`AggregateIndexSearchExec`]. +//! +//! Recognized shape: +//! +//! ```text +//! AggregateExec(Single, aggs=[COUNT(*)], group_by=[]) +//! └── FilteredReadExec { full_filter ⊆ index_input, no refine_filter, ... } +//! ``` +//! +//! Rewritten to: +//! +//! ```text +//! AggregateExec(Final, aggs=[COUNT(*)], group_by=[]) +//! └── AggregateIndexSearchExec { prefilter_input = index_input } +//! ``` +//! +//! [`AggregateIndexSearchExec`] emits partial-state, so the outer +//! `AggregateExec(Final)` performs the final combine. + +use std::sync::Arc; + +use datafusion::common::tree_node::{Transformed, TreeNode}; +use datafusion::config::ConfigOptions; +use datafusion::error::Result as DFResult; +use datafusion::logical_expr::lit; +use datafusion::physical_optimizer::PhysicalOptimizerRule; +use datafusion::physical_plan::{ + ExecutionPlan, + aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, +}; +use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr::expressions::Literal; +use lance_index::expression::aggregate::{AggregateIndexSearch, CountQuery}; + +use super::aggregate_index::AggregateIndexSearchExec; +use super::filtered_read::FilteredReadExec; + +/// Physical optimizer rule that pushes `COUNT`-shaped aggregates into +/// [`AggregateIndexSearchExec`], answering them from index metadata + the +/// deletion mask + an optional scalar-index prefilter, without scanning column +/// data. 
+/// +/// Only fires when the shape is verifiably safe; everything outside that +/// envelope (GROUP BY, residual filters, scan ranges, etc.) is left alone for +/// the normal scan path. +#[derive(Debug)] +pub struct AggregateIndexPushdown; + +impl PhysicalOptimizerRule for AggregateIndexPushdown { + fn optimize( + &self, + plan: Arc<dyn ExecutionPlan>, + _config: &ConfigOptions, + ) -> DFResult<Arc<dyn ExecutionPlan>> { + Ok(plan + .transform_down(|plan| { + let Some(agg) = plan.as_any().downcast_ref::<AggregateExec>() else { + return Ok(Transformed::no(plan)); + }; + if let Some(rewritten) = try_rewrite(agg)? { + return Ok(Transformed::yes(rewritten)); + } + Ok(Transformed::no(plan)) + })? + .data) + } + + fn name(&self) -> &str { + "aggregate_index_pushdown" + } + + fn schema_check(&self) -> bool { + true + } +} + +fn try_rewrite(agg: &AggregateExec) -> DFResult<Option<Arc<dyn ExecutionPlan>>> { + // The Lance scanner emits AggregateMode::Single. Other modes mean + // somebody else is already wrapping us in a partial/final pair; leave them + // alone to avoid double-wrapping. + if !matches!(agg.mode(), AggregateMode::Single) { + return Ok(None); + } + if !agg.group_expr().is_empty() { + return Ok(None); + } + if agg.aggr_expr().is_empty() { + return Ok(None); + } + + // Every aggregate must be a `COUNT(<literal>)` shape (i.e. COUNT(*) / + // COUNT(1) / etc.) with no per-aggregate FILTER. Anything that depends on + // a column value can't be answered without scanning that column. + for (af, filter) in agg.aggr_expr().iter().zip(agg.filter_expr().iter()) { + if !is_count_star(af) { + return Ok(None); + } + if filter.is_some() { + return Ok(None); + } + } + + // The input must be a FilteredReadExec whose filter is either absent or + // fully evaluable by a child scalar-index search. 
+ let child = &agg.children()[0]; + let Some(filtered_read) = child.as_any().downcast_ref::() else { + return Ok(None); + }; + + let options = filtered_read.options(); + // A refine filter is a residual the index couldn't fully evaluate — we'd + // need to scan data to apply it, so bail. + if options.refine_filter.is_some() { + return Ok(None); + } + // A full_filter without an index_input means the filter is evaluated by + // re-reading every row; not pushdownable. + if options.full_filter.is_some() && filtered_read.index_input().is_none() { + return Ok(None); + } + // LIMIT/OFFSET would change the count. + if options.scan_range_before_filter.is_some() || options.scan_range_after_filter.is_some() { + return Ok(None); + } + // We rely on the deletion mask being applied; with_deleted_rows changes + // that contract. + if options.with_deleted_rows { + return Ok(None); + } + // We assume the natural fragment coverage of the dataset; a fragment + // subset would require routing it into the exec. + if options.fragments.is_some() { + return Ok(None); + } + + let dataset = filtered_read.dataset().clone(); + let prefilter_input = filtered_read.index_input().cloned(); + let aggregates: Vec> = agg + .aggr_expr() + .iter() + .map(|_| { + Arc::new(AggregateIndexSearch { + index_name: None, + query: Arc::new(CountQuery::basic()), + filter: None, + // `original_expr` is only used for `Display`; the physical + // plan no longer carries the source `Expr`. + original_expr: lit(0i64), + }) + }) + .collect(); + let aggregate_funcs: Vec> = agg.aggr_expr().to_vec(); + + let exec = AggregateIndexSearchExec::try_new( + dataset, + aggregates, + aggregate_funcs, + prefilter_input, + )?; + let exec_schema = exec.schema(); + let exec: Arc = Arc::new(exec); + + // Wrap with AggregateExec(Final) so a downstream consumer that expected + // the original AggregateExec output schema continues to see it. 
+ let null_filters: Vec>> = + (0..agg.aggr_expr().len()).map(|_| None).collect(); + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::default(), + agg.aggr_expr().to_vec(), + null_filters, + exec, + exec_schema, + )?; + Ok(Some(Arc::new(final_agg))) +} + +/// Returns `true` if `af` is `COUNT()` with no DISTINCT. +fn is_count_star(af: &Arc) -> bool { + if af.fun().name() != "count" { + return false; + } + if af.is_distinct() { + return false; + } + let args = af.expressions(); + if args.len() != 1 { + return false; + } + let Some(lit) = args[0].as_any().downcast_ref::() else { + return false; + }; + // `COUNT(NULL)` would always return 0; rule it out so we don't accidentally + // produce a wrong answer if the planner ever lets it through. + !lit.value().is_null() +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::datatypes::{Int64Type, UInt64Type}; + use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; + use datafusion::physical_plan::{ExecutionPlan, displayable}; + use futures::TryStreamExt; + use lance_core::utils::tempfile::TempStrDir; + use lance_datagen::gen_batch; + use lance_index::IndexType; + use lance_index::scalar::ScalarIndexParams; + + use super::*; + use crate::Dataset; + use crate::dataset::scanner::AggregateExpr; + use crate::index::DatasetIndexExt; + use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount}; + + struct Fixture { + dataset: Arc, + _tmp: TempStrDir, + } + + /// 4 fragments × 10 rows, ascending `ordered` column with a BTree index. 
+ async fn make_fixture() -> Fixture { + let tmp = TempStrDir::default(); + let mut dataset = gen_batch() + .col("ordered", lance_datagen::array::step::()) + .into_dataset( + tmp.as_str(), + FragmentCount::from(4), + FragmentRowCount::from(10), + ) + .await + .unwrap(); + dataset + .create_index( + &["ordered"], + IndexType::BTree, + None, + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + Fixture { + dataset: Arc::new(dataset), + _tmp: tmp, + } + } + + /// True if `plan` contains an `AggregateIndexSearchExec` anywhere in its tree. + fn plan_contains_pushdown(plan: &Arc) -> bool { + let mut found = false; + plan.apply(|node| { + if node.as_any().is::() { + found = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + .unwrap(); + found + } + + /// Drive the rule via `Scanner::create_plan` (which registers the rule + /// through `get_physical_optimizer`) and return both the plan and the + /// final count for inspection. + async fn run_count(scanner: &mut crate::dataset::scanner::Scanner) -> (Arc, i64) { + scanner + .aggregate(AggregateExpr::builder().count_star().build()) + .unwrap(); + let plan = scanner.create_plan().await.unwrap(); + let stream = datafusion::physical_plan::execute_stream( + plan.clone(), + Arc::new(datafusion::execution::TaskContext::default()), + ) + .unwrap(); + let batches: Vec<_> = stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 1, "count plan emitted {} batches", batches.len()); + let count = batches[0] + .column(0) + .as_any() + .downcast_ref::>() + .expect("count column should be Int64") + .value(0); + (plan, count) + } + + #[tokio::test] + async fn rule_fires_on_unfiltered_count_star() { + let fixture = make_fixture().await; + let mut scanner = fixture.dataset.scan(); + let (plan, count) = run_count(&mut scanner).await; + assert_eq!(count, 40); + assert!( + plan_contains_pushdown(&plan), + "expected AggregateIndexSearchExec in plan: {}", + 
displayable(plan.as_ref()).indent(true) + ); + } + + #[tokio::test] + async fn rule_fires_when_filter_fully_indexed() { + let fixture = make_fixture().await; + let mut scanner = fixture.dataset.scan(); + scanner.filter("ordered < 25").unwrap(); + let (plan, count) = run_count(&mut scanner).await; + assert_eq!(count, 25); + assert!( + plan_contains_pushdown(&plan), + "expected AggregateIndexSearchExec in plan: {}", + displayable(plan.as_ref()).indent(true) + ); + } + + #[tokio::test] + async fn rule_skips_when_filter_needs_refine() { + // No index on `unindexed`, so the filter must be applied during the + // scan; the rule must not fire. + let tmp = TempStrDir::default(); + let mut dataset = gen_batch() + .col("ordered", lance_datagen::array::step::()) + .col("unindexed", lance_datagen::array::step::()) + .into_dataset( + tmp.as_str(), + FragmentCount::from(4), + FragmentRowCount::from(10), + ) + .await + .unwrap(); + dataset + .create_index( + &["ordered"], + IndexType::BTree, + None, + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + let dataset = Arc::new(dataset); + + let mut scanner = dataset.scan(); + scanner.filter("unindexed > 5").unwrap(); + let (plan, count) = run_count(&mut scanner).await; + // 40 rows total, values are 0..40 across fragments; `> 5` drops 0..6. + // Right answer either way; the point is the rule didn't fire. + assert_eq!(count, 34); + assert!( + !plan_contains_pushdown(&plan), + "rule should not fire with non-indexed filter, got plan: {}", + displayable(plan.as_ref()).indent(true) + ); + } + + #[tokio::test] + async fn rule_skips_count_with_group_by() { + let fixture = make_fixture().await; + // GROUP BY isn't supported by the rule yet — make sure we leave it alone. 
+ let mut scanner = fixture.dataset.scan(); + scanner + .aggregate( + AggregateExpr::builder() + .group_by("ordered") + .count_star() + .build(), + ) + .unwrap(); + let plan = scanner.create_plan().await.unwrap(); + assert!( + !plan_contains_pushdown(&plan), + "rule should not fire for GROUP BY: {}", + displayable(plan.as_ref()).indent(true) + ); + } +} + diff --git a/rust/lance/src/io/exec/optimizer.rs b/rust/lance/src/io/exec/optimizer.rs index f031e10ce19..28cac1f0381 100644 --- a/rust/lance/src/io/exec/optimizer.rs +++ b/rust/lance/src/io/exec/optimizer.rs @@ -171,6 +171,10 @@ impl PhysicalOptimizerRule for SimplifyProjection { pub fn get_physical_optimizer() -> PhysicalOptimizer { PhysicalOptimizer::with_rules(vec![ + // Rewrite COUNT-shaped aggregates into AggregateIndexSearchExec so + // they can be answered without scanning column data. Runs before the + // generic rules so they don't see the rewritten subtree. + Arc::new(crate::io::exec::aggregate_index_pushdown::AggregateIndexPushdown), Arc::new(crate::io::exec::optimizer::CoalesceTake), Arc::new(crate::io::exec::optimizer::SimplifyProjection), // Push down limit into FilteredReadExec and other Execs via with_fetch() From f187fcfcaffca477b157d0de7b691db732e0b59b Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 19 May 2026 23:03:48 +0000 Subject: [PATCH 3/4] Address review feedback and CI failures on aggregate-pushdown PR MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CI fixes: - Add SPDX license headers to expression.rs and scalar/expression.rs. - cargo fmt the rule file (3 spots). - Update test_count_star_single_fragment and test_scanner_count_rows in dataset_aggregate.rs to expect the new AggregateExec(Final) → AggregateIndexSearchExec shape now that the rule fires by default. 
Correctness fixes (both pointed out by automated review on #6831): - Stable row IDs: DatasetPreFilter::create_deletion_mask returns an AllowList in stable-id space when the dataset uses stable row IDs, but AggregateIndexSearchExec builds its fragments-allow list in row-address space. ANDing across mismatched id spaces undercounts silently. Gate the rule on !manifest.uses_stable_row_ids() until the exec can reconcile the two id spaces. - Partial index coverage: when an index is built and then a fragment is appended, the index's fragment bitmap no longer covers the whole dataset. The original rule fired anyway and silently dropped rows in the unindexed fragments. The proper fix needs an async coverage check that's not expressible in a sync PhysicalOptimizerRule; until we plumb that through, narrow the rule to only fire when there is no filter at all (no full_filter, no refine_filter, no index_input). Unfiltered counts remain correct and still benefit from the rewrite. Both narrowings are documented in the module-level doc and the inline `try_rewrite` comments so a follow-up can lift them once the underlying machinery is in place. Repo hygiene: - Drop aggregate-pushdown-research.md from the repo root. It was a one-off survey not referenced by any code or doc. New regression tests in aggregate_index_pushdown.rs: - rule_skips_with_stable_row_ids — toggles enable_stable_row_ids + delete, asserts the rule does not fire and the count is correct. - rule_skips_partial_index_coverage — builds index over 4 fragments, appends a 5th, runs COUNT(*) WHERE indexed_col < N, asserts the rule does not fire and the count includes the appended fragment. - rule_skips_when_filter_present_even_if_indexed — replaces the old rule_fires_when_filter_fully_indexed; documents that the indexed- filter case is deferred. All 13 aggregate_index* tests pass; cargo check --workspace and cargo clippy -p lance -p lance-index --tests --benches -- -D warnings are clean. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- aggregate-pushdown-research.md | 225 ------------------ rust/lance-index/src/expression.rs | 3 + rust/lance-index/src/scalar/expression.rs | 7 +- .../src/dataset/tests/dataset_aggregate.rs | 15 +- .../src/io/exec/aggregate_index_pushdown.rs | 165 +++++++++++-- 5 files changed, 160 insertions(+), 255 deletions(-) delete mode 100644 aggregate-pushdown-research.md diff --git a/aggregate-pushdown-research.md b/aggregate-pushdown-research.md deleted file mode 100644 index e4084423fce..00000000000 --- a/aggregate-pushdown-research.md +++ /dev/null @@ -1,225 +0,0 @@ -# Aggregate Pushdown in Mature Query Engines - -Background research for the Lance `feat-aggregate-pushdown` work. The motivating use case is `COUNT(DISTINCT col)` directly from a bitmap index, but this report maps the broader design space. - -## 1. Executive Summary - -- **Three distinct families of "aggregate pushdown"** are conflated in vendor docs. Keep them separate when designing Lance's APIs: (a) *metadata-only execution* — answer the aggregate from per-fragment statistics with zero data IO (Snowflake, Iceberg PR #6622, SQL Server segment metadata for `MIN/MAX/COUNT`); (b) *scan-local aggregation* — run the aggregate inside the scan operator over compressed/encoded data, eliminating a separate Aggregate node (SQL Server "Aggregate Pushdown" since 2016, ClickHouse `optimize_aggregation_in_order`); (c) *materialized/pre-aggregated structures* — separate physical artifact that answers many GROUP BYs (ClickHouse projections, Pinot star-tree, SQL Server indexed views, AggregatingMergeTree). -- **`MIN/MAX` is the universally-supported case.** Every engine has a `MinMaxAggPath`-equivalent that either reads endpoints from a sorted index (Postgres) or reads per-segment min/max statistics (everyone else). 
Lance has min/max page statistics already — turning these into an `Aggregate` rewrite is the lowest-hanging fruit and matches `preprocess_minmax_aggregates` in Postgres almost exactly. -- **`COUNT(*)` from metadata is universally supported but with caveats.** Without a predicate, every engine answers from row counts in fragment/manifest metadata. *With* a predicate, only fragments whose stats *prove* full inclusion or full exclusion can be skipped — partial fragments must still be scanned. DataFusion's "Fully Matching / Partially Matching / Not Matching" trichotomy (the limit-pruning blog post, March 2026) is the cleanest articulation. -- **`COUNT(DISTINCT)` from a bitmap index is unusual but legitimate.** Druid's `cardinality` aggregator returns approximate distinct counts via HyperLogLog over the dimension values (see §3.5). *Exact* `COUNT(DISTINCT)` from a bitmap is also trivial — it is the number of dictionary values whose posting list still intersects the predicate's row mask. Lance's bitmap index already has per-value posting lists, so exact distinct count is the natural fit, not HLL. -- **Partial vs. full pushdown matters at the planner level.** Spark's `SupportsPushDownAggregates.supportCompletePushDown()` is the canonical API: per-fragment partial aggregates with a final reduction step in the engine. This is also how Postgres's partition-wise aggregate and postgres_fdw work (`combinefunc`/`serialfunc`/`deserialfunc`). Lance will likely need the same split because indexes are per-fragment. -- **NULL semantics differ between aggregates and become an issue.** `COUNT(*)` counts rows; `COUNT(col)` skips nulls; `MIN/MAX` skip nulls. Iceberg's PR #6622 distinguishes "stat is null because column is all-null" (legal answer) from "stat is missing" (abort pushdown) via a `hasValue` flag. Lance needs the same distinction.
-- **Predicate compatibility is the gating constraint.** A pushed aggregate is only legal if the predicate is *also* fully evaluable from the same metadata — otherwise the count/min/max applies to an over-set of rows. This is the source of most correctness bugs in this area (cf. the Iceberg "Fix aggregate pushdown" thread). -- **GROUP BY pushdown is the hard mode.** SQL Server's "grouped aggregate pushdown" only fires when the grouping key bit-packs into ≤10 bits *and* a runtime "benefit measure" stays above a threshold. Pinot's star-tree solves it with a precomputed index. ClickHouse's projections do too. There is no cheap implementation — Lance should defer until non-grouped pushdown is solid. -- **MVCC/visibility is an issue only for transactional engines.** Postgres's index-only-scan has to consult the visibility map; SQL Server's pushdown only applies to "compressed rowgroups" not the delta store. Lance's append-only/versioned model sidesteps this — but the analogue is *deletion vectors / row-level deletes*. Iceberg PR #6622 explicitly disables aggregate pushdown when row-level deletes exist. Lance must do the same when deletion vectors apply. -- **The optimizer integration is consistently a dedicated planner pass, not a generic rule.** Postgres's `preprocess_minmax_aggregates` runs in `grouping_planner` just before `query_planner`. DataFusion's `AggregateStatistics` is a `PhysicalOptimizerRule`. Spark uses a V2 datasource interface (`SupportsPushDownAggregates`). The pattern is consistent: detect the shape, build an alternative path, let the cost model choose. - ---- - -## 2. 
Taxonomy of Techniques - -``` - ┌─────────────────────────────────────────────┐ - │ Aggregate Pushdown │ - └─────────────────────────────────────────────┘ - │ - ┌──────────────────────────────┼──────────────────────────────┐ - ▼ ▼ ▼ -┌───────────────┐ ┌───────────────┐ ┌───────────────┐ -│ Metadata- │ │ Scan-local │ │ Materialized/ │ -│ only │ │ aggregation │ │ Pre-aggregated│ -│ (no IO) │ │ (closer to │ │ artifacts │ -│ │ │ data) │ │ │ -└───────────────┘ └───────────────┘ └───────────────┘ - │ │ │ - │ MIN/MAX from index endpoints │ Aggregate inside scan │ Indexed views (MSSQL) - │ (Postgres MinMaxAggPath) │ over compressed data │ Projections (ClickHouse) - │ │ (MSSQL agg pushdown) │ AggregatingMergeTree - │ MIN/MAX/COUNT from zone-maps │ │ Star-tree (Pinot) - │ (Iceberg, Snowflake, MSSQL │ SIMD-vectorized agg over │ Materialized views (PG, Snowflake) - │ segments, DuckDB zonemap) │ bit-packed encoded data │ - │ │ │ Roll-up tables (Druid) - │ COUNT(*) from row counts │ Grouped agg pushdown │ - │ │ (MSSQL, 2019+) │ - │ COUNT DISTINCT from bitmap │ │ - │ dictionary (Druid) │ │ - │ │ │ - │ HLL distinct from sketches │ │ - │ (Druid hyperUnique, BQ, Snow) │ │ - └──────────────────────────────────┴───────────────────────────────┘ - - Orthogonal axis: partial vs. complete - ┌─────────────────────────────────────────────────────────────────┐ - │ Complete: source returns final answer (single fragment, or │ - │ commutative aggregate over independent fragments │ - │ where source does the reduction itself). │ - │ Partial: source returns per-fragment partial aggregate state; │ - │ engine reduces with combinefunc. │ - │ - Spark: SupportsPushDownAggregates.supportCompletePushDown() │ - │ - Postgres: partial aggregates (combine/serial/deserialfunc) │ - │ - postgres_fdw: per-foreign-server partial aggregation │ - └─────────────────────────────────────────────────────────────────┘ -``` - ---- - -## 3. 
Per-Engine Sections - -### 3.1 PostgreSQL - -**MIN/MAX via `MinMaxAggPath` — `src/backend/optimizer/plan/planagg.c`.** The function `preprocess_minmax_aggregates(PlannerInfo *root)` is called by `grouping_planner` just before `query_planner`. It checks: only aggregates in target list, single base relation, no `GROUP BY`/window/CTE, single-argument MIN/MAX (recognized via its sort operator from `pg_aggregate` — `fetch_agg_sort_op`), no `DISTINCT`/`ORDER BY`/`FILTER` on the aggregate, no mutable functions, no row-type args. For each matching aggregate, `build_minmax_path` builds an effective `SELECT col FROM t WHERE col IS NOT NULL ORDER BY col [DESC] LIMIT 1` subquery and registers a `MinMaxAggPath` against the `UPPERREL_GROUP_AGG` upper rel. Cost model decides between this and a scan-based `AggPath`. ([planagg.c](https://doxygen.postgresql.org/planagg_8c_source.html), [Cybertec write-up](https://www.cybertec-postgresql.com/en/speeding-up-min-and-max/)) - -**`COUNT(*)` and index-only scans.** Postgres has no `count(*)`-from-index optimization analogous to MIN/MAX. The closest is *index-only scan*: a btree scan that skips heap access when the visibility map's all-visible bit is set for the heap page. `EXPLAIN` reports `Heap Fetches: N` for cases where the VM bit was not set. Index-only scans require: index type supports it (btree always; GiST/SP-GiST for some opclasses; GIN never); query references only indexed columns; relevant heap pages are all-visible (requires `VACUUM`). With predicates, btree can do `LooseIndexScan`-like skips, but a full `count(*)` still walks every index entry. ([Index-Only Scans docs](https://www.postgresql.org/docs/current/indexes-index-only-scans.html)) - -**Partition-wise aggregate and FDW pushdown.** Postgres 10 added remote aggregation in `postgres_fdw`; subsequent commits added partition-wise aggregate, which decomposes the top-level Agg into per-partition Aggs that can each be pushed to a foreign server. 
The plan ends with a final `Aggregate` whose `combinefunc`/`serialfunc`/`deserialfunc` (declared in `CREATE AGGREGATE`) merge the partials. Enabled by `enable_partitionwise_aggregate` GUC. Restrictions: no `DISTINCT`/`ORDER BY` in aggregate, no `HAVING`, not `array_agg`. ([EDB Aggregate Push-down post](https://www.enterprisedb.com/blog/postgresql-aggregate-push-down-postgresfdw), [commit message](https://www.postgresql.org/message-id/E1f30tV-0003rh-27@gemulon.postgresql.org)) - -### 3.2 DuckDB - -DuckDB auto-builds **zonemaps** (per-row-group min/max) for all general-purpose types and uses them for both predicate pushdown and "computing aggregations" (Indexing docs). Row groups are ~122,880 rows. The optimizer pipeline (Filter Pushdown, Join Order, TopN, Expression Rewriter, Filter Pull-up, IN Rewriter, Statistics Propagation, Reorder Filters, Join Filter Pushdown) does not document a dedicated metadata-only-aggregate rule, but Statistics Propagation does fold known constants (e.g., `MIN/MAX` of a column with known range) at plan time. The **ART index** is documented as not affecting aggregation/join/sort performance — it is only for point lookups and PK enforcement. ([Indexing](https://duckdb.org/docs/current/guides/performance/indexing), [Optimizers blog](https://duckdb.org/2024/11/14/optimizers)) - -### 3.3 SQL Server (Columnstore) - -**Segment elimination** drops rowgroups whose per-segment min/max prove a predicate cannot match. Numeric/date types since 2014; string/binary/guid since 2022. Each rowgroup also stores row count for instant `COUNT(*)`. ([SQLpassion segment elimination](https://www.sqlpassion.at/archive/2017/01/30/columnstore-segment-elimination/)) - -**Aggregate Pushdown (2016+).** The Aggregate operator is fused into the Columnstore Scan; aggregation runs on compressed/bit-packed data with SIMD. 
Supports `MIN`, `MAX`, `SUM`, `COUNT`, `COUNT(*)` when input+output fit in 64 bits (int family, money, decimal/numeric with precision ≤18, date/time types). **Not supported**: `DISTINCT`, string columns, virtual columns, delta store rows (which still flow up to the Aggregate node). EXPLAIN exposes `ActualLocallyAggregatedRows`. ([Microsoft post](https://learn.microsoft.com/en-us/archive/blogs/sql_server_team/columnstore-index-performance-sql-server-2016-aggregate-pushdown)) - -**Grouped Aggregate Pushdown (2019+).** Extends to `GROUP BY`. Each output batch (~900 rows) makes a *runtime* choice between fast (pushdown) and slow paths based on a "benefit measure" starting at 100 and decremented when batches contain few rows per key (22% penalty for <8/key, 11% for 8–16/key). Disables entirely when bit-packed grouping key exceeds 10 bits. Pure RLE keys always fast-path. ([Paul White, SQLPerformance](https://sqlperformance.com/2019/04/sql-plan/grouped-aggregate-pushdown)) - -**Indexed Views.** Materialized `SELECT ... GROUP BY` results with synchronous maintenance. Optimizer can use them transparently if `EXPAND VIEWS` is off — purely planner-side pattern match against `SELECT` shape. - -### 3.4 ClickHouse - -**Granule-level min/max + skip indexes.** Default granule is 8192 rows; the primary key (sparse) gives row-range pruning, and explicit `minmax`/`set`/`bloom_filter` skip indexes augment it. The `optimize_use_implicit_projections` and `optimize_use_projections` flags drive the optimizer to consider projections. - -**Projections** (transparent materialized aggregates). When a projection defines `GROUP BY`, the underlying engine becomes `AggregatingMergeTree` and aggregate columns become `AggregateFunction(...)` states. The optimizer "automatically samples the primary keys and chooses a table that can generate the same correct result, but requires the least amount of data to be read." 
Since 25.5, projections can store only sorting keys + `_part_offset` to act as a pure index. ([Projections docs](https://clickhouse.com/docs/data-modeling/projections)) - -**AggregatingMergeTree.** Stores partial states for aggregations; `min`/`max` need no extra merge cost ("require no extra steps to calculate the final result from the intermediate steps"). The `SimpleAggregateFunction` combinator is an optimized form for aggregates whose state is just the result (`min`, `max`, `sum`, `any`, `anyLast`). ([Altinity KB](https://kb.altinity.com/altinity-kb-queries-and-syntax/simplestateif-or-ifstate-for-simple-aggregate-functions/)) - -### 3.5 Apache Druid - -**Bitmap indexes per dictionary entry.** For each distinct value in a (string) column, Druid stores one Roaring-compressed bitmap of matching rows. Combined with a dictionary mapping string→int. ([Segments doc](https://druid.apache.org/docs/latest/design/segments/)) - -**`cardinality` and `hyperUnique` aggregators.** `COUNT(DISTINCT)` in SQL is translated to `cardinality`, which returns an *approximate* count via HyperLogLog over the dimension values; `hyperUnique` is the recommended path when you only need the count, not the values — it's stored as an HLL sketch in the segment, so the count is computed by merging sketches across segments, no per-row work. Druid recommends DataSketches (theta/HLL) for new use cases. ([HLL old docs](https://druid.apache.org/docs/latest/querying/hll-old.html), [CALCITE-1670](https://issues.apache.org/jira/browse/CALCITE-1670)) - -For *exact* distinct count, Druid does not push down — it runs a groupBy and counts. The bitmap-per-value structure means exact distinct count *could* be answered as "number of bitmaps in the dictionary whose intersection with the predicate mask is non-empty" — this is exactly the Lance opportunity. - -### 3.6 Apache Pinot — Star-Tree Index - -Pre-aggregated multi-dimensional tree. 
Each level splits on a dimension; each internal node has a "star" child holding the aggregate with that dimension dropped. The planner pattern-matches a query's `GROUP BY` dimensions and aggregate functions against an available star-tree's schema. Aggregations are *materialized* at build time. Reported gains: "99.76% reduction in latency vs. no Star-Tree Index (6.3 seconds to 15 ms)" and "99.99999% reduction in amount of data scanned." Supports COUNT/SUM/MIN/MAX/etc.; approximate distinct via DataSketches theta/HLL stored as the aggregate value at the node. ([Pinot docs](https://docs.pinot.apache.org/basics/indexing/star-tree-index), [Part 3 blog](https://startree.ai/resources/star-tree-index-in-apache-pinot-part-3-understanding-the-impact-in-real-customer/)) - -### 3.7 Snowflake - -**Micro-partition metadata** stored per partition: column value ranges, distinct counts, and "additional properties." Metadata is in the cloud-services layer, queried before any data IO. `count(*)`, `MIN(col)`, `MAX(col)` on a partition-aligned column with no predicate (or with a predicate that aligns with metadata) can return from metadata alone, hence the well-known "instant `COUNT(*)`" on Snowflake. ([Micro-partitions docs](https://docs.snowflake.com/en/user-guide/tables-clustering-micropartitions)) - -**Snowflake Optima (2024-2025).** Dynamically generates *additional* lightweight per-micro-partition metadata for high-frequency "hot" expressions seen in workloads — extending min/max-style pruning to expressions like `LOWER(col) = ...`. ([Optima blog](https://www.snowflake.com/en/engineering-blog/snowflake-optima-metadata-query-pruning/)) - -### 3.8 Parquet / Iceberg - -**Parquet** stores per-row-group and per-page min/max, null count, distinct count (optional, often unset by writers). These drive predicate pushdown but are also enough material for aggregate pushdown. 
- -**Iceberg PR #6622** (`huaxingao`, merged) implemented `MIN/MAX/COUNT` pushdown through Spark's `SupportsPushDownAggregates`. Key classes: `AggregateEvaluator`, `BoundAggregate` (with `hasValue` to distinguish "all-null column" from "stats missing"), `MaxAggregate`, `MinAggregate`, `CountNonNull`. `SparkScanBuilder` orchestrates. Restrictions explicitly enumerated: -- **No GROUP BY** ("Group by aggregation push down is not supported") -- **No row-level deletes** ("Skipped aggregate pushdown: detected row level deletes") -- **No complex types lacking stats** -- **No truncated string metrics** (default mode truncates strings; can't reason about MIN/MAX) - -Toggle: `spark.sql.iceberg.aggregate-push-down-enabled`. ([PR #6622](https://github.com/apache/iceberg/pull/6622)) - -### 3.9 Spark V2 — `SupportsPushDownAggregates` - -The data-source-side contract used by Iceberg, JDBC, file sources. `pushAggregation(Aggregation): boolean` to attempt pushdown; `supportCompletePushDown(Aggregation): boolean` to declare whether the source returns final or partial. If partial, Spark inserts a final Aggregate above the V2 scan with the combine semantics. Filter pushdown happens *first*, then aggregate pushdown — so the data source sees already-filtered fragments. ([Spark JavaDoc](https://spark.apache.org/docs/3.4.3/api/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.html)) - -### 3.10 DataFusion - -Has an `AggregateStatistics` physical optimizer rule that converts `MIN/MAX/COUNT(*)` over a scan with exact statistics into a constant `ProjectionExec` — pure metadata-only execution. Issue [#19938](https://github.com/apache/datafusion/issues/19938) proposes extending min/max statistics to drive group-by *layout* (use a `Vec` indexed by `value - min` when the range is small). 
The "Limit Pruning" blog (March 2026) describes a clean three-tier model: *Not Matching* / *Partially Matching* / *Fully Matching* row groups, where Fully-Matching groups can satisfy `LIMIT` without row-level filtering — directly applicable to aggregate pushdown: Fully-Matching row groups can contribute exact counts from their row-count statistic. ([Query Optimizer docs](https://datafusion.apache.org/library-user-guide/query-optimizer.html), [Limit Pruning blog](https://datafusion.apache.org/blog/2026/03/20/limit-pruning/)) - ---- - -## 4. Index-Type → Aggregate-Type Matrix - -| Index / Metadata | `COUNT(*)` | `MIN/MAX` | `SUM` | `COUNT(col)` (non-null) | `COUNT(DISTINCT col)` | `GROUP BY` cardinality | -|--- |--- |--- |--- |--- |--- |--- | -| Row count per fragment | Yes (no pred) | No | No | Need null count | No | No | -| Zone map (min/max) | No* | **Yes** | No | No | No | No | -| Null count per fragment | Yes (with above) | No | No | **Yes** (no pred) | No | No | -| Distinct count per frag. | No | No | No | No | Approx (upper bound)† | No | -| Btree (ordered) | Walk index | **Yes** O(log n) | Walk index | Walk index | Loose-index scan | Stream-grouped scan | -| Bitmap (one-per-value) | Sum of all bitmaps | **Yes** (first/last value with non-empty bitmap) | No | Bitmap union cardinality | **Yes** (count of values with non-empty bitmap intersected with predicate mask) | **Yes** (cardinality of each bitmap, partition by value) | -| HLL/Theta sketch | No | No | No | No | **Yes** (approximate) | Per-group sketch merge | -| Materialized view / projection / star-tree | Yes | Yes | Yes | Yes | Yes (if pre-aggregated) | **Yes** | - -*`COUNT(*)` from a zone map alone needs row count too — but every engine stores both per fragment, so in practice this is a single lookup. -†Per-fragment distinct counts cannot be summed (overlap); they bound the answer above. - -The bitmap row is the strongest case for Lance. 
Bitmap-cardinality identities: -``` -COUNT(col) = popcount( OR_v posting[v] ) over predicate-masked rows -COUNT(DISTINCT col)= |{ v : posting[v] AND mask != ∅ }| -COUNT(*) WHERE col=v = popcount( posting[v] AND mask ) -GROUP BY col, COUNT(*) = for v in dict: emit (v, popcount(posting[v] AND mask)) -``` - ---- - -## 5. Planner Integration Patterns - -Three recurring shapes, in order of complexity: - -**(a) Pre-planner rewrite (Postgres pattern).** A dedicated function — `preprocess_minmax_aggregates` — runs *before* the main path enumeration. It builds an alternate path (`MinMaxAggPath`) parallel to the normal Aggregate-over-Scan path. The cost model picks the winner. Pros: keeps the special case out of the general optimizer. Cons: each new shape is a new bespoke function. - -**(b) Physical-optimizer rule (DataFusion `AggregateStatistics`).** A late physical-plan rewrite that inspects the plan tree for `AggregateExec { mode: Final, expr: [Min|Max|Count], input: ScanExec }` and, if the scan can produce exact statistics for those columns, replaces the whole subtree with a `ProjectionExec` of constants. Pros: composes with existing rules. Cons: must reason about partial-vs-final aggregate modes; needs exact (not estimated) statistics. - -**(c) Data-source interface (Spark V2 `SupportsPushDownAggregates`).** The optimizer hands the data source an `Aggregation` description; the source returns whether (and how completely) it can satisfy it. If partial, optimizer inserts a final-stage Aggregate above. Pros: clean separation; the source owns correctness. Cons: API surface is large; partial-aggregate plumbing must be wired. - -**Recommendation for Lance.** Mirror Spark V2's contract at the `Scan` level, but execute the dispatch in DataFusion's physical optimizer (since Lance plans through DataFusion already). The `Scan` would expose `try_pushdown_aggregate(agg, filter) -> Option`. 
The optimizer rule walks `AggregateExec(final) → AggregateExec(partial) → Scan` patterns and asks the scan whether it can satisfy. Index access lives inside the scan (or its `MetricsProvider`), not in the optimizer. - ---- - -## 6. Correctness Gotchas - -1. **Predicate-must-be-fully-evaluable-by-index.** If the index can evaluate `col = 5` but not `f(col) = 5`, the predicate must be either rejected by the index entirely or split. A pushed aggregate over a partially-filtered set is silently wrong. Iceberg's PR thread had multiple iterations on this. - -2. **NULL handling per aggregate.** `COUNT(*)` counts rows including nulls; `COUNT(col)` and `MIN/MAX` skip nulls. Need both row count and null count per fragment. Iceberg's `BoundAggregate.hasValue` distinguishes "stat exists and column is all-null (legal answer for MIN/MAX = NULL)" from "stat missing → abort." - -3. **Row-level deletes / deletion vectors / MVCC.** Stale statistics. Postgres: visibility map. SQL Server: delta rowgroups bypass pushdown. Iceberg: aggregate pushdown disabled if row-level deletes exist on touched files. **Lance equivalent: deletion vectors.** Pushdown must either consult deletion vector population (row count − deleted count) or abort. - -4. **Empty input vs zero.** `COUNT` on zero rows is `0`; `MIN/MAX/SUM` on zero rows is `NULL`. The fast path must return the right type, not silently coerce. - -5. **`COUNT(DISTINCT)` overlap across fragments.** Per-fragment distinct counts cannot be summed. Two paths: (a) merge an exact structure (sorted dictionary or bitmap union) across fragments; (b) merge HLL/theta sketches for approximate answer. Lance bitmap indexes naturally support (a) via posting-list union. - -6. **Truncated/lossy statistics.** Parquet writers commonly truncate string min/max. Iceberg refuses pushdown in this case. Lance should mark such stats as inexact and refuse. - -7. **`MIN/MAX` sort operator vs. 
aggregate sort order.** Postgres's `fetch_agg_sort_op` looks up the agg's sort operator from `pg_aggregate`. A user-defined min-like aggregate is not eligible unless registered correctly. Lance's analogue: only well-known `MIN`/`MAX` over orderable types qualify; do not try to be clever with user-defined aggregates. - -8. **GROUP BY combined with aggregate pushdown is partial by definition.** Each fragment emits `(group_key, partial_agg)`, and the engine reduces across fragments. The fragment-side dedup is *not* a complete `GROUP BY` — duplicates across fragments are normal and required for correctness. Spark's `SupportsPushDownAggregates` docs: "the data source can still output data with duplicated keys, which is OK as Spark will do GROUP BY key again." - -9. **Aggregate-over-filter ordering.** Spark V2 explicitly pushes filters *before* aggregates. Lance's scan API should follow: aggregate pushdown receives the post-filter view. - -10. **Approximate vs exact must be explicit in the API.** Calcite's Druid translation of `COUNT(DISTINCT)` to `cardinality` was filed as a bug (CALCITE-1670) because users didn't expect approximate semantics. Lance should never silently approximate. - ---- - -## 7. Open Questions / Things I Couldn't Pin Down Authoritatively - -- **DuckDB's exact metadata-only path.** Multiple sources say zonemaps drive "computing aggregations" but I could not find a named optimizer rule (e.g., a `count_star_metadata` rule) in either the optimizer blog or the indexing docs. Need to read `src/optimizer/` in the DuckDB tree directly — start at [`optimizer.cpp`](https://github.com/duckdb/duckdb/blob/main/src/optimizer/optimizer.cpp) and look for statistics-propagation paths that fold to constants. -- **ClickHouse projection selection cost model.** Docs say "the optimizer automatically samples the primary keys" but I did not find a description of the tie-breaking when multiple projections could serve.
Likely in `Processors/QueryPlan/Optimizations/optimizeUseAggregateProjection.cpp` in source. -- **Snowflake metadata-only execution rules.** Marketing-level confirmation that COUNT/MIN/MAX from metadata works, but no published planner doc. The Optima blog is the closest thing and is high-level. -- **Pinot star-tree planner matching.** Docs describe the structure but not the matcher. The pattern from the description is "exact match on dimension subset + supported aggregate"; needs source-code confirmation (see `pinot-segment-spi`). -- **Druid exact COUNT(DISTINCT) status.** There is a community "Exact Cardinality Count" extension PR but it is not in core. Mainline path is HLL-approximate. Worth a follow-up: does Druid's bitmap structure make exact distinct count "free enough" that someone proposed a core impl? (The PR exists; review comments would tell us why it didn't merge.) -- **Postgres `count(*)` from index.** I expected a planner rewrite analogous to MinMaxAggPath. I couldn't find one — it appears `count(*)` always goes through an actual scan (possibly index-only), never a metadata read. Worth confirming on `pgsql-hackers`; multiple threads have proposed it and been declined for MVCC reasons. -- **Iceberg manifest-only `MIN/MAX` correctness with column nullability.** PR #6622 introduces `hasValue` but I didn't trace whether mixed-null + non-null fragments are merged correctly when *some* fragments have stats and *others* don't. Worth reading the test cases before mirroring the design. 
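
Editor's sketch for the last open question (and gotchas 2 and 4): one plausible way to merge per-fragment `MIN` statistics when some fragments are all-null and others have no statistic at all. The `FragmentMin` and `MergedMin` types and the `merge_min` function are hypothetical names invented for this illustration; they are not Lance or Iceberg APIs, and the three-state rule is an assumption about the desired semantics, not a description of what PR #6622 does.

```rust
/// Hypothetical per-fragment statistic for MIN(col) over an i64 column.
#[derive(Debug, Clone, Copy, PartialEq)]
enum FragmentMin {
    /// No statistic was recorded for this fragment.
    Missing,
    /// A statistic exists and every value in the fragment is NULL.
    AllNull,
    /// A statistic exists with an exact minimum over the non-null values.
    Value(i64),
}

/// Hypothetical result of merging statistics across fragments.
#[derive(Debug, Clone, Copy, PartialEq)]
enum MergedMin {
    /// At least one fragment lacks a statistic; fall back to scanning.
    Abort,
    /// No fragment contributed a non-null value; SQL MIN is NULL.
    Null,
    /// Exact minimum across all fragments.
    Value(i64),
}

/// Merge per-fragment MIN statistics (assumed semantics, see lead-in).
fn merge_min(stats: &[FragmentMin]) -> MergedMin {
    let mut acc: Option<i64> = None;
    for stat in stats {
        match *stat {
            // A single missing statistic makes the metadata answer unsafe.
            FragmentMin::Missing => return MergedMin::Abort,
            // All-null fragments are legal inputs that contribute nothing.
            FragmentMin::AllNull => {}
            FragmentMin::Value(v) => {
                acc = Some(match acc {
                    Some(current) => current.min(v),
                    None => v,
                });
            }
        }
    }
    match acc {
        Some(v) => MergedMin::Value(v),
        // Covers both "zero fragments" and "only all-null fragments".
        None => MergedMin::Null,
    }
}
```

Under these assumptions, an empty input and an all-null input both merge to `MergedMin::Null` rather than to a zero value, while any `Missing` fragment forces `MergedMin::Abort` regardless of what the other fragments report.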
-
----
-
-### Sources
-
-- PostgreSQL: [planagg.c source](https://doxygen.postgresql.org/planagg_8c_source.html) · [Cybertec MIN/MAX speedup](https://www.cybertec-postgresql.com/en/speeding-up-min-and-max/) · [Index-Only Scans](https://www.postgresql.org/docs/current/indexes-index-only-scans.html) · [Wiki: Index-only scans](https://wiki.postgresql.org/wiki/Index-only_scans) · [EDB Aggregate Push-down](https://www.enterprisedb.com/blog/postgresql-aggregate-push-down-postgresfdw) · [Partition-wise aggregation commit](https://www.postgresql.org/message-id/E1f30tV-0003rh-27@gemulon.postgresql.org)
-- DuckDB: [Indexing](https://duckdb.org/docs/current/guides/performance/indexing) · [Indexes](https://duckdb.org/docs/current/sql/indexes) · [Optimizers blog](https://duckdb.org/2024/11/14/optimizers) · [Row Groups (DeepWiki)](https://deepwiki.com/duckdb/duckdb/7.2-column-storage)
-- SQL Server: [Aggregate Pushdown 2016](https://learn.microsoft.com/en-us/archive/blogs/sql_server_team/columnstore-index-performance-sql-server-2016-aggregate-pushdown) · [Grouped Aggregate Pushdown (Paul White)](https://sqlperformance.com/2019/04/sql-plan/grouped-aggregate-pushdown) · [Columnstore Query Performance](https://learn.microsoft.com/en-us/sql/relational-databases/indexes/columnstore-indexes-query-performance) · [ColumnStore Segment Elimination](https://www.sqlpassion.at/archive/2017/01/30/columnstore-segment-elimination/)
-- ClickHouse: [Projections docs](https://clickhouse.com/docs/data-modeling/projections) · [AggregatingMergeTree (Altinity)](https://kb.altinity.com/engines/mergetree-table-engine-family/aggregatingmergetree/) · [SimpleState combinator](https://kb.altinity.com/altinity-kb-queries-and-syntax/simplestateif-or-ifstate-for-simple-aggregate-functions/)
-- Druid: [Segments design](https://druid.apache.org/docs/latest/design/segments/) · [HLL old aggregator](https://druid.apache.org/docs/latest/querying/hll-old.html) · [Aggregations reference](https://druid.apache.org/docs/latest/querying/aggregations/) · [CALCITE-1670](https://issues.apache.org/jira/browse/CALCITE-1670)
-- Pinot: [Star-Tree Index docs](https://docs.pinot.apache.org/basics/indexing/star-tree-index) · [Star-Tree Part 3](https://startree.ai/resources/star-tree-index-in-apache-pinot-part-3-understanding-the-impact-in-real-customer/)
-- Snowflake: [Micro-partitions and clustering](https://docs.snowflake.com/en/user-guide/tables-clustering-micropartitions) · [Snowflake Optima](https://www.snowflake.com/en/engineering-blog/snowflake-optima-metadata-query-pruning/) · [Pruning paper (arXiv)](https://arxiv.org/html/2504.11540v1)
-- Iceberg/Spark: [Iceberg PR #6622 (aggregate pushdown)](https://github.com/apache/iceberg/pull/6622) · [Iceberg statistics (Ryft)](https://www.ryft.io/blog/making-sense-of-apache-iceberg-statistics) · [Spark SupportsPushDownAggregates JavaDoc](https://spark.apache.org/docs/3.4.3/api/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.html)
-- DataFusion: [Query Optimizer](https://datafusion.apache.org/library-user-guide/query-optimizer.html) · [Issue #19938 (min/max in grouped aggs)](https://github.com/apache/datafusion/issues/19938) · [Limit Pruning blog (Mar 2026)](https://datafusion.apache.org/blog/2026/03/20/limit-pruning/) · [Optimizing SQL Part 2](https://datafusion.apache.org/blog/2025/06/15/optimizing-sql-dataframes-part-two/)
diff --git a/rust/lance-index/src/expression.rs b/rust/lance-index/src/expression.rs
index 1b15f77d219..3ff722cab11 100644
--- a/rust/lance-index/src/expression.rs
+++ b/rust/lance-index/src/expression.rs
@@ -1,3 +1,6 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
 //! Plan-time expression parsing for scalar and aggregate index pushdown.
 //!
 //! Both halves split a user expression into an index-evaluable leaf plus the
Both halves split a user expression into an index-evaluable leaf plus the diff --git a/rust/lance-index/src/scalar/expression.rs b/rust/lance-index/src/scalar/expression.rs index e73e898ff76..aa962928e79 100644 --- a/rust/lance-index/src/scalar/expression.rs +++ b/rust/lance-index/src/scalar/expression.rs @@ -1,2 +1,7 @@ -/// This file was moved from `src/scalar/expression.rs` to `src/expression/scalar.rs` +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! This module was moved to [`crate::expression::scalar`]. Kept as a stub +//! re-export to preserve the old `lance_index::scalar::expression::…` paths. + pub use crate::expression::scalar::*; diff --git a/rust/lance/src/dataset/tests/dataset_aggregate.rs b/rust/lance/src/dataset/tests/dataset_aggregate.rs index ef2a90e6315..148e3752f6f 100644 --- a/rust/lance/src/dataset/tests/dataset_aggregate.rs +++ b/rust/lance/src/dataset/tests/dataset_aggregate.rs @@ -268,7 +268,9 @@ async fn test_count_star_single_fragment() { vec![], ); - // Verify COUNT(*) has empty projection optimization + // COUNT(*) is rewritten by AggregateIndexPushdown into a Final aggregate + // over AggregateIndexSearchExec, which answers from manifest metadata + the + // deletion mask instead of scanning column data. 
let mut scanner = ds.scan(); scanner .aggregate(AggregateExpr::substrait(agg_bytes.clone())) @@ -276,8 +278,8 @@ async fn test_count_star_single_fragment() { let plan = scanner.create_plan().await.unwrap(); assert_plan_node_equals( plan, - "AggregateExec: mode=Single, gby=[], aggr=[count(...)] - LanceRead: uri=..., projection=[], num_fragments=1, range_before=None, range_after=None, row_id=false, row_addr=true, full_filter=--, refine_filter=--", + "AggregateExec: mode=Final, gby=[], aggr=[count(...)] + AggregateIndexSearch: aggs=[count@\"*\"]", ) .await .unwrap(); @@ -1204,11 +1206,12 @@ async fn test_scanner_count_rows() { .unwrap(); let plan = scanner.create_plan().await.unwrap(); - // COUNT(*) should have empty projection (optimized to not read any columns) + // COUNT(*) is rewritten by AggregateIndexPushdown into a Final aggregate + // over AggregateIndexSearchExec. assert_plan_node_equals( plan.clone(), - "AggregateExec: mode=Single, gby=[], aggr=[count(Int32(1))] - LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, row_id=false, row_addr=true, full_filter=--, refine_filter=--", + "AggregateExec: mode=Final, gby=[], aggr=[count(Int32(1))] + AggregateIndexSearch: aggs=[count@\"*\"]", ) .await .unwrap(); diff --git a/rust/lance/src/io/exec/aggregate_index_pushdown.rs b/rust/lance/src/io/exec/aggregate_index_pushdown.rs index cb0812ef87c..aed2d795d72 100644 --- a/rust/lance/src/io/exec/aggregate_index_pushdown.rs +++ b/rust/lance/src/io/exec/aggregate_index_pushdown.rs @@ -4,18 +4,24 @@ //! Physical optimizer rule that rewrites `COUNT`-shaped aggregates into //! [`AggregateIndexSearchExec`]. //! +//! v1 only fires for fully unfiltered counts — the simplest provably-safe +//! envelope. Filtered counts are deferred to a follow-up that can validate +//! the index covers every dataset fragment. +//! //! Recognized shape: //! //! ```text //! AggregateExec(Single, aggs=[COUNT(*)], group_by=[]) -//! 
└── FilteredReadExec { full_filter ⊆ index_input, no refine_filter, ... } +//! └── FilteredReadExec { no full_filter, no refine_filter, no index_input, +//! no scan range, no with_deleted_rows, no fragment +//! subset, not stable-row-ids } //! ``` //! //! Rewritten to: //! //! ```text //! AggregateExec(Final, aggs=[COUNT(*)], group_by=[]) -//! └── AggregateIndexSearchExec { prefilter_input = index_input } +//! └── AggregateIndexSearchExec { prefilter_input = None } //! ``` //! //! [`AggregateIndexSearchExec`] emits partial-state, so the outer @@ -104,22 +110,32 @@ fn try_rewrite(agg: &AggregateExec) -> DFResult>> } } - // The input must be a FilteredReadExec whose filter is either absent or - // fully evaluable by a child scalar-index search. + // The input must be a FilteredReadExec we can prove is safe to skip. let child = &agg.children()[0]; let Some(filtered_read) = child.as_any().downcast_ref::() else { return Ok(None); }; - let options = filtered_read.options(); - // A refine filter is a residual the index couldn't fully evaluate — we'd - // need to scan data to apply it, so bail. - if options.refine_filter.is_some() { + // Stable-row-id mode: `DatasetPreFilter::create_deletion_mask` produces an + // AllowList in stable-id space, but `AggregateIndexSearchExec` builds its + // fragments-allow list in row-address space. ANDing across the two yields + // a silently wrong count (rows in fragments > 0 are dropped because their + // stable ids and row addresses share a fragment-id bucket only by accident). + // Until the exec can reconcile the two id spaces, refuse to fire. + if filtered_read.dataset().manifest().uses_stable_row_ids() { return Ok(None); } - // A full_filter without an index_input means the filter is evaluated by - // re-reading every row; not pushdownable. - if options.full_filter.is_some() && filtered_read.index_input().is_none() { + + let options = filtered_read.options(); + // No filter at all is the only case v1 can prove correct. 
With a filter we + // would also need to verify the scalar index covers every dataset fragment + // (otherwise rows in unindexed fragments are silently dropped). That check + // is async and not currently expressible in a sync PhysicalOptimizerRule; + // until we plumb it through, leave the filtered case on the scan path. + if options.full_filter.is_some() + || options.refine_filter.is_some() + || filtered_read.index_input().is_some() + { return Ok(None); } // LIMIT/OFFSET would change the count. @@ -155,12 +171,8 @@ fn try_rewrite(agg: &AggregateExec) -> DFResult>> .collect(); let aggregate_funcs: Vec> = agg.aggr_expr().to_vec(); - let exec = AggregateIndexSearchExec::try_new( - dataset, - aggregates, - aggregate_funcs, - prefilter_input, - )?; + let exec = + AggregateIndexSearchExec::try_new(dataset, aggregates, aggregate_funcs, prefilter_input)?; let exec_schema = exec.schema(); let exec: Arc = Arc::new(exec); @@ -269,7 +281,9 @@ mod tests { /// Drive the rule via `Scanner::create_plan` (which registers the rule /// through `get_physical_optimizer`) and return both the plan and the /// final count for inspection. 
- async fn run_count(scanner: &mut crate::dataset::scanner::Scanner) -> (Arc, i64) { + async fn run_count( + scanner: &mut crate::dataset::scanner::Scanner, + ) -> (Arc, i64) { scanner .aggregate(AggregateExpr::builder().count_star().build()) .unwrap(); @@ -280,7 +294,12 @@ mod tests { ) .unwrap(); let batches: Vec<_> = stream.try_collect().await.unwrap(); - assert_eq!(batches.len(), 1, "count plan emitted {} batches", batches.len()); + assert_eq!( + batches.len(), + 1, + "count plan emitted {} batches", + batches.len() + ); let count = batches[0] .column(0) .as_any() @@ -304,15 +323,116 @@ mod tests { } #[tokio::test] - async fn rule_fires_when_filter_fully_indexed() { + async fn rule_skips_when_filter_present_even_if_indexed() { + // Deferred until the rule can verify the index covers every dataset + // fragment — without that check, an index built before a fragment + // append silently drops rows. See `rule_skips_partial_index_coverage` + // below for the regression scenario this protects against. let fixture = make_fixture().await; let mut scanner = fixture.dataset.scan(); scanner.filter("ordered < 25").unwrap(); let (plan, count) = run_count(&mut scanner).await; assert_eq!(count, 25); assert!( - plan_contains_pushdown(&plan), - "expected AggregateIndexSearchExec in plan: {}", + !plan_contains_pushdown(&plan), + "rule should not fire with any filter in v1, got plan: {}", + displayable(plan.as_ref()).indent(true) + ); + } + + #[tokio::test] + async fn rule_skips_partial_index_coverage() { + // Regression: when an index doesn't cover every dataset fragment + // (here, by appending a fragment after the index was built), the rule + // must not fire — otherwise rows in unindexed fragments are silently + // dropped. Today this is enforced by the blanket "no filter" gate. + use crate::dataset::WriteParams; + let tmp = TempStrDir::default(); + // Build a 4×10 dataset with a BTree index covering all 4 fragments. 
+ let mut dataset = gen_batch() + .col("ordered", lance_datagen::array::step::()) + .into_dataset( + tmp.as_str(), + FragmentCount::from(4), + FragmentRowCount::from(10), + ) + .await + .unwrap(); + dataset + .create_index( + &["ordered"], + IndexType::BTree, + None, + &ScalarIndexParams::default(), + true, + ) + .await + .unwrap(); + // Append a fragment after the index was built — it is unindexed. + let extra = gen_batch() + .col("ordered", lance_datagen::array::step::()) + .into_reader_rows( + lance_datagen::RowCount::from(10), + lance_datagen::BatchCount::from(1), + ); + let dataset = Dataset::write( + extra, + tmp.as_str(), + Some(WriteParams { + mode: crate::dataset::WriteMode::Append, + max_rows_per_file: 10, + ..Default::default() + }), + ) + .await + .unwrap(); + let dataset = Arc::new(dataset); + + let mut scanner = dataset.scan(); + scanner.filter("ordered < 100").unwrap(); + let (plan, count) = run_count(&mut scanner).await; + // 5 fragments × 10 rows, all match `< 100`. + assert_eq!(count, 50); + assert!( + !plan_contains_pushdown(&plan), + "rule must not fire when the index has partial coverage, got plan: {}", + displayable(plan.as_ref()).indent(true) + ); + } + + #[tokio::test] + async fn rule_skips_with_stable_row_ids() { + // Regression: with stable row IDs the deletion mask is built in + // stable-id space while fragments_allow is in row-address space. + // ANDing across the two undercounts; refuse to fire. + use crate::dataset::WriteParams; + let tmp = TempStrDir::default(); + let mut dataset = gen_batch() + .col("ordered", lance_datagen::array::step::()) + .into_dataset_with_params( + tmp.as_str(), + FragmentCount::from(2), + FragmentRowCount::from(10), + Some(WriteParams { + max_rows_per_file: 10, + enable_stable_row_ids: true, + ..Default::default() + }), + ) + .await + .unwrap(); + // Touch a deletion so we exercise the masks that would otherwise + // collide across id spaces. 
+ dataset.delete("ordered = 0").await.unwrap(); + let dataset = Arc::new(dataset); + + let mut scanner = dataset.scan(); + let (plan, count) = run_count(&mut scanner).await; + // 2 × 10 rows, minus the one deletion. + assert_eq!(count, 19); + assert!( + !plan_contains_pushdown(&plan), + "rule must not fire under stable row IDs, got plan: {}", displayable(plan.as_ref()).indent(true) ); } @@ -378,4 +498,3 @@ mod tests { ); } } - From d73f2cd522cd1616199d643c4090ce557658aac3 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Wed, 20 May 2026 16:42:54 +0000 Subject: [PATCH 4/4] Refresh python/Cargo.lock for lance-arrow-scalar dep The foundation commit added `lance-arrow-scalar` as a workspace dependency of `lance-index`. `python/Cargo.lock` was not regenerated, so CI's `--locked` build failed: error: cannot update the lock file /home/runner/work/lance/lance/python/Cargo.lock because --locked was passed to prevent this Run `cargo update` from the python crate dir to add the missing entry. Co-Authored-By: Claude Opus 4.7 (1M context) --- python/Cargo.lock | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/Cargo.lock b/python/Cargo.lock index 3e3d41cc7c0..9778e9e6b01 100644 --- a/python/Cargo.lock +++ b/python/Cargo.lock @@ -4063,6 +4063,19 @@ dependencies = [ "rand 0.9.4", ] +[[package]] +name = "lance-arrow-scalar" +version = "58.0.0" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-row", + "arrow-schema", + "half", +] + [[package]] name = "lance-bitpacking" version = "7.0.0-beta.12" @@ -4272,6 +4285,7 @@ dependencies = [ "jieba-rs", "jsonb", "lance-arrow", + "lance-arrow-scalar", "lance-core", "lance-datafusion", "lance-datagen",