diff --git a/crates/lance-graph-python/src/graph.rs b/crates/lance-graph-python/src/graph.rs index ad6d7531..b7a4337c 100644 --- a/crates/lance-graph-python/src/graph.rs +++ b/crates/lance-graph-python/src/graph.rs @@ -24,9 +24,10 @@ use arrow_schema::Schema; use datafusion::datasource::{DefaultTableSource, MemTable}; use datafusion::execution::context::SessionContext; use lance_graph::{ - ast::DistanceMetric as RustDistanceMetric, CypherQuery as RustCypherQuery, - ExecutionStrategy as RustExecutionStrategy, GraphConfig as RustGraphConfig, - GraphError as RustGraphError, VectorSearch as RustVectorSearch, InMemoryCatalog, + ast::{DistanceMetric as RustDistanceMetric, GraphPattern, ReadingClause}, + CypherQuery as RustCypherQuery, ExecutionStrategy as RustExecutionStrategy, + GraphConfig as RustGraphConfig, GraphError as RustGraphError, InMemoryCatalog, + VectorSearch as RustVectorSearch, }; use pyo3::{ exceptions::{PyNotImplementedError, PyRuntimeError, PyValueError}, @@ -106,6 +107,9 @@ impl From for RustDistanceMetric { #[derive(Clone)] pub struct VectorSearch { inner: RustVectorSearch, + /// Flag to enable vector-first Lance ANN execution path. + /// This is stored separately because RustVectorSearch doesn't have this concept. + use_lance_index: bool, } #[pymethods] @@ -125,6 +129,7 @@ impl VectorSearch { fn new(column: &str) -> Self { Self { inner: RustVectorSearch::new(column), + use_lance_index: false, } } @@ -142,6 +147,7 @@ impl VectorSearch { fn query_vector(&self, vector: Vec) -> Self { Self { inner: self.inner.clone().query_vector(vector), + use_lance_index: self.use_lance_index, } } @@ -159,6 +165,7 @@ impl VectorSearch { fn metric(&self, metric: DistanceMetric) -> Self { Self { inner: self.inner.clone().metric(metric.into()), + use_lance_index: self.use_lance_index, } } @@ -176,6 +183,7 @@ impl VectorSearch { fn top_k(&self, k: usize) -> Self { Self { inner: self.inner.clone().top_k(k), + use_lance_index: self.use_lance_index, } } @@ -193,6 +201,7 @@ impl VectorSearch { fn include_distance(&self, include: bool) -> Self { Self { inner: self.inner.clone().include_distance(include), + use_lance_index: self.use_lance_index, } } @@ -210,6 +219,30 @@ impl VectorSearch { fn distance_column_name(&self, name: &str) -> Self { Self { inner: self.inner.clone().distance_column_name(name), + use_lance_index: self.use_lance_index, + } + } + + /// Use Lance ANN index when datasets are Lance datasets. + /// + /// This enables a vector-first execution path that queries the Lance index + /// and then runs the Cypher query on the top-k results. This can be much faster + /// for large datasets but may change semantics when the Cypher query includes + /// filters or additional constraints. + /// + /// Parameters + /// ---------- + /// enabled : bool + /// If True, use Lance ANN index for vector search when possible. + /// + /// Returns + /// ------- + /// VectorSearch + /// A new builder with the setting applied + fn use_lance_index(&self, enabled: bool) -> Self { + Self { + inner: self.inner.clone(), + use_lance_index: enabled, } } @@ -640,6 +673,8 @@ impl CypherQuery { /// Dictionary mapping table names to Lance datasets or PyArrow tables /// vector_search : VectorSearch /// VectorSearch configuration for reranking + /// (Use VectorSearch.use_lance_index(True) to enable a vector-first + /// execution path when datasets are Lance datasets.) /// /// Returns /// ------- @@ -674,6 +709,12 @@ impl CypherQuery { datasets: &Bound<'_, PyDict>, vector_search: &VectorSearch, ) -> PyResult { + if vector_search.use_lance_index { + if let Some(result) = try_execute_with_lance_index(py, &self.inner, datasets, vector_search)? { + return record_batch_to_python_table(py, &result); + } + } + // Convert datasets to Arrow batches let arrow_datasets = python_datasets_to_batches(datasets)?; @@ -742,27 +783,191 @@ fn json_to_python(py: Python, value: &JsonValue) -> PyResult { } // Helper functions for Arrow conversion + +/// Convert a single Python dataset value to a RecordBatch +fn python_dataset_to_batch(value: &Bound<'_, PyAny>) -> PyResult { + let batch = if is_lance_dataset(value)? { + lance_dataset_to_record_batch(value)? + } else if value.hasattr("to_table")? { + let table = value.call_method0("to_table")?; + python_any_to_record_batch(&table)? + } else { + python_any_to_record_batch(value)? + }; + normalize_record_batch(batch) +} + fn python_datasets_to_batches( datasets: &Bound<'_, PyDict>, +) -> PyResult> { + python_datasets_to_batches_impl(datasets, None) +} + +fn python_datasets_to_batches_with_override( + datasets: &Bound<'_, PyDict>, + override_label: &str, + override_batch: &RecordBatch, +) -> PyResult> { + python_datasets_to_batches_impl(datasets, Some((override_label, override_batch))) +} + +fn python_datasets_to_batches_impl( + datasets: &Bound<'_, PyDict>, + override_entry: Option<(&str, &RecordBatch)>, ) -> PyResult> { let mut arrow_datasets = HashMap::new(); for (key, value) in datasets.iter() { let table_name: String = key.extract()?; - let batch = if is_lance_dataset(&value)? { - // Handle Lance datasets using scan() -> to_pyarrow() pattern that works elsewhere - lance_dataset_to_record_batch(&value)? - } else if value.hasattr("to_table")? { - let table = value.call_method0("to_table")?; - python_any_to_record_batch(&table)? - } else { - python_any_to_record_batch(&value)? - }; - let batch = normalize_record_batch(batch)?; + + // Check if this table should use the override batch + if let Some((override_label, override_batch)) = override_entry { + if table_name == override_label { + arrow_datasets.insert(table_name, override_batch.clone()); + continue; + } + } + + let batch = python_dataset_to_batch(&value)?; arrow_datasets.insert(table_name, batch); } Ok(arrow_datasets) } +fn try_execute_with_lance_index( + py: Python, + query: &RustCypherQuery, + datasets: &Bound<'_, PyDict>, + vector_search: &VectorSearch, +) -> PyResult> { + // Only use vector-first path for simple queries without filters. + // Queries with WITH/WHERE clauses need the standard rerank path to ensure correct semantics. + let ast = query.ast(); + if ast.with_clause.is_some() + || ast.where_clause.is_some() + || ast.post_with_where_clause.is_some() + { + return Ok(None); + } + + let query_vector = match vector_search.inner.get_query_vector() { + Some(vec) => vec.to_vec(), + None => { + return Err(PyValueError::new_err( + "VectorSearch.query_vector is required when use_lance_index is enabled", + )) + } + }; + + let (alias, column) = split_vector_column(vector_search.inner.column()); + let label = resolve_vector_label(query, alias.as_deref())?; + let label = match label { + Some(label) => label, + None => return Ok(None), + }; + + let dataset_value = match datasets.get_item(&label)? { + Some(value) => value, + None => return Ok(None), + }; + + if !is_lance_dataset(&dataset_value)? { + return Ok(None); + } + + let metric_str = match vector_search.inner.get_metric() { + RustDistanceMetric::L2 => "l2", + RustDistanceMetric::Cosine => "cosine", + RustDistanceMetric::Dot => "dot", + }; + + // Build the `nearest` dict for Lance's to_table() ANN query. + // Setting use_index=true tells Lance to use the ANN index if available, + // otherwise it falls back to flat (brute-force) search. + let nearest = PyDict::new(py); + nearest.set_item("column", column)?; + nearest.set_item("k", vector_search.inner.get_top_k())?; + nearest.set_item("q", query_vector)?; + nearest.set_item("metric", metric_str)?; + nearest.set_item("use_index", true)?; + + let kwargs = PyDict::new(py); + kwargs.set_item("nearest", nearest)?; + + let table = dataset_value.call_method("to_table", (), Some(&kwargs))?; + let batch = python_any_to_record_batch(&table)?; + let batch = normalize_record_batch(batch)?; + + let arrow_datasets = python_datasets_to_batches_with_override(datasets, &label, &batch)?; + + let inner_query = query.clone(); + let result = RT + .block_on(Some(py), inner_query.execute(arrow_datasets, None))? + .map_err(graph_error_to_pyerr)?; + + Ok(Some(result)) +} + +fn split_vector_column(column: &str) -> (Option, &str) { + let mut parts = column.splitn(2, '.'); + let first = parts.next().unwrap_or(column); + if let Some(rest) = parts.next() { + (Some(first.to_string()), rest) + } else { + (None, column) + } +} + +fn resolve_vector_label( + query: &RustCypherQuery, + alias: Option<&str>, +) -> PyResult> { + let alias_map = alias_map_from_query(query); + if let Some(alias) = alias { + return Ok(alias_map.get(alias).cloned()); + } + if alias_map.len() == 1 { + return Ok(alias_map.values().next().cloned()); + } + Ok(None) +} + +fn alias_map_from_query(query: &RustCypherQuery) -> HashMap { + let mut map = HashMap::new(); + let ast = query.ast(); + for clause in ast + .reading_clauses + .iter() + .chain(ast.post_with_reading_clauses.iter()) + { + if let ReadingClause::Match(match_clause) = clause { + for pattern in &match_clause.patterns { + collect_aliases_from_pattern(pattern, &mut map); + } + } + } + map +} + +fn collect_aliases_from_pattern(pattern: &GraphPattern, map: &mut HashMap) { + match pattern { + GraphPattern::Node(node) => { + if let (Some(var), Some(label)) = (node.variable.as_ref(), node.labels.first()) { + map.entry(var.clone()).or_insert_with(|| label.clone()); + } + } + GraphPattern::Path(path) => { + if let (Some(var), Some(label)) = (path.start_node.variable.as_ref(), path.start_node.labels.first()) { + map.entry(var.clone()).or_insert_with(|| label.clone()); + } + for segment in &path.segments { + if let (Some(var), Some(label)) = (segment.end_node.variable.as_ref(), segment.end_node.labels.first()) { + map.entry(var.clone()).or_insert_with(|| label.clone()); + } + } + } + } +} + fn normalize_record_batch(batch: RecordBatch) -> PyResult { if batch.schema().metadata().is_empty() { return Ok(batch); @@ -1125,3 +1330,95 @@ pub fn register_graph_module(py: Python, parent_module: &Bound<'_, PyModule>) -> parent_module.add_submodule(&graph_module)?; Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_split_vector_column_with_alias() { + let (alias, column) = split_vector_column("d.embedding"); + assert_eq!(alias, Some("d".to_string())); + assert_eq!(column, "embedding"); + } + + #[test] + fn test_split_vector_column_without_alias() { + let (alias, column) = split_vector_column("embedding"); + assert_eq!(alias, None); + assert_eq!(column, "embedding"); + } + + #[test] + fn test_split_vector_column_with_multiple_dots() { + // Should only split on the first dot + let (alias, column) = split_vector_column("d.nested.embedding"); + assert_eq!(alias, Some("d".to_string())); + assert_eq!(column, "nested.embedding"); + } + + #[test] + fn test_split_vector_column_empty_string() { + let (alias, column) = split_vector_column(""); + assert_eq!(alias, None); + assert_eq!(column, ""); + } + + #[test] + fn test_alias_map_from_simple_node_query() { + let query = RustCypherQuery::new("MATCH (d:Document) RETURN d.name").unwrap(); + let map = alias_map_from_query(&query); + assert_eq!(map.get("d"), Some(&"Document".to_string())); + } + + #[test] + fn test_alias_map_from_multiple_nodes() { + let query = + RustCypherQuery::new("MATCH (p:Person), (d:Document) RETURN p.name, d.title").unwrap(); + let map = alias_map_from_query(&query); + assert_eq!(map.get("p"), Some(&"Person".to_string())); + assert_eq!(map.get("d"), Some(&"Document".to_string())); + } + + #[test] + fn test_alias_map_from_path_query() { + let query = + RustCypherQuery::new("MATCH (p:Person)-[:KNOWS]->(f:Friend) RETURN p.name, f.name") + .unwrap(); + let map = alias_map_from_query(&query); + assert_eq!(map.get("p"), Some(&"Person".to_string())); + assert_eq!(map.get("f"), Some(&"Friend".to_string())); + } + + #[test] + fn test_resolve_vector_label_with_alias() { + let query = RustCypherQuery::new("MATCH (d:Document) RETURN d.name").unwrap(); + let result = resolve_vector_label(&query, Some("d")).unwrap(); + assert_eq!(result, Some("Document".to_string())); + } + + #[test] + fn test_resolve_vector_label_without_alias_single_node() { + let query = RustCypherQuery::new("MATCH (d:Document) RETURN d.name").unwrap(); + // When no alias is provided and there's only one node, should return that label + let result = resolve_vector_label(&query, None).unwrap(); + assert_eq!(result, Some("Document".to_string())); + } + + #[test] + fn test_resolve_vector_label_without_alias_multiple_nodes() { + let query = + RustCypherQuery::new("MATCH (p:Person), (d:Document) RETURN p.name, d.title").unwrap(); + // When no alias is provided and there are multiple nodes, should return None + let result = resolve_vector_label(&query, None).unwrap(); + assert_eq!(result, None); + } + + #[test] + fn test_resolve_vector_label_unknown_alias() { + let query = RustCypherQuery::new("MATCH (d:Document) RETURN d.name").unwrap(); + // When alias doesn't exist in the query, should return None + let result = resolve_vector_label(&query, Some("x")).unwrap(); + assert_eq!(result, None); + } +} diff --git a/crates/lance-graph/src/lance_vector_search.rs b/crates/lance-graph/src/lance_vector_search.rs index 78a0535e..a50ca9a5 100644 --- a/crates/lance-graph/src/lance_vector_search.rs +++ b/crates/lance-graph/src/lance_vector_search.rs @@ -125,6 +125,28 @@ impl VectorSearch { self } + // Getters for accessing internal state (used by Python bindings) + + /// Get the column name + pub fn column(&self) -> &str { + &self.column + } + + /// Get the query vector if set + pub fn get_query_vector(&self) -> Option<&[f32]> { + self.query_vector.as_deref() + } + + /// Get the distance metric + pub fn get_metric(&self) -> &DistanceMetric { + &self.metric + } + + /// Get the top_k value + pub fn get_top_k(&self) -> usize { + self.top_k + } + /// Perform brute-force vector search on a RecordBatch /// /// This method computes distances for all vectors in the batch and returns diff --git a/python/python/tests/test_vector_search.py b/python/python/tests/test_vector_search.py index 0675faa0..386466e4 100644 --- a/python/python/tests/test_vector_search.py +++ b/python/python/tests/test_vector_search.py @@ -181,6 +181,369 @@ def test_execute_with_vector_rerank_basic(vector_env): assert data["d.name"][1] == "Doc2" +@pytest.mark.requires_lance +def test_use_lance_index_missing_query_vector(vector_env, tmp_path): + """Test error when use_lance_index=True but query_vector is not set.""" + config, _, _ = vector_env + + import lance + import numpy as np + + embedding_values = np.array( + [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0]], + dtype=np.float32, + ) + + documents_table = pa.table( + { + "id": [1, 2], + "name": ["Doc1", "Doc2"], + "embedding": pa.FixedSizeListArray.from_arrays( + embedding_values.flatten(), list_size=3 + ), + } + ) + + dataset_path = tmp_path / "Document.lance" + lance.write_dataset(documents_table, dataset_path) + lance_dataset = lance.dataset(str(dataset_path)) + + query = CypherQuery( + "MATCH (d:Document) RETURN d.id, d.name, d.embedding" + ).with_config(config) + + with pytest.raises(ValueError, match="query_vector is required"): + query.execute_with_vector_rerank( + {"Document": lance_dataset}, + VectorSearch("d.embedding") + .metric(DistanceMetric.L2) + .top_k(3) + .use_lance_index(True), # No query_vector set + ) + + +def test_use_lance_index_fallback_non_lance_dataset(vector_env): + """Test use_lance_index=True falls back for non-Lance datasets.""" + config, datasets, _ = vector_env + + query = CypherQuery( + "MATCH (d:Document) RETURN d.id, d.name, d.embedding" + ).with_config(config) + + # Should work fine - falls back to standard rerank for PyArrow table + results = query.execute_with_vector_rerank( + datasets, # PyArrow tables, not Lance datasets + VectorSearch("d.embedding") + .query_vector([1.0, 0.0, 0.0]) + .metric(DistanceMetric.L2) + .top_k(3) + .use_lance_index(True), # Should fallback silently + ) + + data = results.to_pydict() + assert len(data["d.name"]) == 3 + assert data["d.name"][0] == "Doc1" + # _distance column should be present (standard rerank path) + assert "_distance" in data + + +@pytest.mark.requires_lance +def test_use_lance_index_unqualified_column(vector_env, tmp_path): + """Test use_lance_index with unqualified column name (no alias prefix).""" + config, _, _ = vector_env + + import lance + import numpy as np + + embedding_values = np.array( + [ + [1.0, 0.0, 0.0], + [0.9, 0.1, 0.0], + [0.0, 1.0, 0.0], + ], + dtype=np.float32, + ) + + documents_table = pa.table( + { + "id": [1, 2, 3], + "name": ["Doc1", "Doc2", "Doc3"], + "embedding": pa.FixedSizeListArray.from_arrays( + embedding_values.flatten(), list_size=3 + ), + } + ) + + dataset_path = tmp_path / "Document.lance" + lance.write_dataset(documents_table, dataset_path) + lance_dataset = lance.dataset(str(dataset_path)) + + # Use unqualified column name "embedding" instead of "d.embedding" + # This should still work when there's only one node label in the query + query = CypherQuery( + "MATCH (d:Document) RETURN d.id, d.name, d.embedding" + ).with_config(config) + + results = query.execute_with_vector_rerank( + {"Document": lance_dataset}, + VectorSearch("embedding") # No alias prefix + .query_vector([1.0, 0.0, 0.0]) + .metric(DistanceMetric.L2) + .top_k(2) + .use_lance_index(True), + ) + + data = results.to_pydict() + assert len(data["d.name"]) == 2 + assert data["d.name"][0] == "Doc1" + + +def test_use_lance_index_builder_propagation(): + """Test that use_lance_index flag is properly propagated through builder methods.""" + vs = VectorSearch("embedding").use_lance_index(True) + + # Each builder method should preserve the use_lance_index flag + vs2 = vs.query_vector([1.0, 0.0, 0.0]) + vs3 = vs2.metric(DistanceMetric.L2) + vs4 = vs3.top_k(10) + vs5 = vs4.include_distance(True) + vs6 = vs5.distance_column_name("dist") + + # All should still have use_lance_index=True (we verify by using it) + # This is an indirect test - if propagation failed, the final object + # would have use_lance_index=False + # We can't directly inspect the flag, but we can verify the chain works + assert vs6 is not None # Chain completed successfully + + +@pytest.mark.requires_lance +def test_use_lance_index_cosine_metric(vector_env, tmp_path): + """Test use_lance_index with cosine distance metric.""" + config, _, _ = vector_env + + import lance + import numpy as np + + embedding_values = np.array( + [ + [1.0, 0.0, 0.0], + [0.9, 0.1, 0.0], + [0.0, 1.0, 0.0], + ], + dtype=np.float32, + ) + + documents_table = pa.table( + { + "id": [1, 2, 3], + "name": ["Doc1", "Doc2", "Doc3"], + "embedding": pa.FixedSizeListArray.from_arrays( + embedding_values.flatten(), list_size=3 + ), + } + ) + + dataset_path = tmp_path / "Document.lance" + lance.write_dataset(documents_table, dataset_path) + lance_dataset = lance.dataset(str(dataset_path)) + + query = CypherQuery( + "MATCH (d:Document) RETURN d.id, d.name, d.embedding" + ).with_config(config) + + results = query.execute_with_vector_rerank( + {"Document": lance_dataset}, + VectorSearch("d.embedding") + .query_vector([1.0, 0.0, 0.0]) + .metric(DistanceMetric.Cosine) # Using cosine metric + .top_k(2) + .use_lance_index(True), + ) + + data = results.to_pydict() + assert len(data["d.name"]) == 2 + assert data["d.name"][0] == "Doc1" + + +@pytest.mark.requires_lance +def test_use_lance_index_dot_metric(vector_env, tmp_path): + """Test use_lance_index with dot product metric.""" + config, _, _ = vector_env + + import lance + import numpy as np + + embedding_values = np.array( + [ + [1.0, 0.0, 0.0], + [0.9, 0.1, 0.0], + [0.0, 1.0, 0.0], + ], + dtype=np.float32, + ) + + documents_table = pa.table( + { + "id": [1, 2, 3], + "name": ["Doc1", "Doc2", "Doc3"], + "embedding": pa.FixedSizeListArray.from_arrays( + embedding_values.flatten(), list_size=3 + ), + } + ) + + dataset_path = tmp_path / "Document.lance" + lance.write_dataset(documents_table, dataset_path) + lance_dataset = lance.dataset(str(dataset_path)) + + query = CypherQuery( + "MATCH (d:Document) RETURN d.id, d.name, d.embedding" + ).with_config(config) + + results = query.execute_with_vector_rerank( + {"Document": lance_dataset}, + VectorSearch("d.embedding") + .query_vector([1.0, 0.0, 0.0]) + .metric(DistanceMetric.Dot) # Using dot product metric + .top_k(2) + .use_lance_index(True), + ) + + data = results.to_pydict() + assert len(data["d.name"]) == 2 + assert data["d.name"][0] == "Doc1" + + +@pytest.mark.requires_lance +def test_execute_with_vector_rerank_lance_index(vector_env, tmp_path): + """Test vector-first execution using Lance datasets. + + Note: This test does NOT create an actual vector index on the Lance dataset. + Lance will fall back to flat (brute-force) search when use_index=True is set + but no index exists. This test validates: + 1. The code path for the vector-first execution is exercised + 2. Results are correct (matching the standard rerank behavior) + 3. The Lance dataset integration works end-to-end + + To test actual ANN index behavior, create an index with: + lance_dataset.create_index("embedding", index_type="IVF_PQ", ...) + """ + config, _, _ = vector_env + + import lance + import numpy as np + + # Create embeddings with fixed-size list type (required for Lance vector search) + embedding_values = np.array( + [ + [1.0, 0.0, 0.0], + [0.9, 0.1, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [0.5, 0.5, 0.0], + ], + dtype=np.float32, + ) + + documents_table = pa.table( + { + "id": [1, 2, 3, 4, 5], + "name": ["Doc1", "Doc2", "Doc3", "Doc4", "Doc5"], + "category": ["tech", "tech", "science", "tech", "science"], + "embedding": pa.FixedSizeListArray.from_arrays( + embedding_values.flatten(), list_size=3 + ), + } + ) + + dataset_path = tmp_path / "Document.lance" + lance.write_dataset(documents_table, dataset_path) + lance_dataset = lance.dataset(str(dataset_path)) + + query = CypherQuery( + "MATCH (d:Document) RETURN d.id, d.name, d.embedding" + ).with_config(config) + + results = query.execute_with_vector_rerank( + {"Document": lance_dataset}, + VectorSearch("d.embedding") + .query_vector([1.0, 0.0, 0.0]) + .metric(DistanceMetric.L2) + .top_k(3) + .use_lance_index(True), + ) + + data = results.to_pydict() + assert len(data["d.name"]) == 3 + assert data["d.name"][0] == "Doc1" + assert data["d.name"][1] == "Doc2" + + +@pytest.mark.requires_lance +def test_execute_with_vector_rerank_lance_index_fallback_on_where(vector_env, tmp_path): + """Test that use_lance_index falls back to standard rerank with WHERE clause. + + When a Cypher query includes filters (WHERE clause), the vector-first path would + change semantics: it would search ALL vectors first, then apply filters. This could + miss relevant results that match the filter but aren't in the top-k vectors. + + The implementation correctly detects this and falls back to the standard + candidate-then-rerank path. + """ + config, _, _ = vector_env + + import lance + import numpy as np + + # Create embeddings with fixed-size list type (required for Lance vector search) + embedding_values = np.array( + [ + [1.0, 0.0, 0.0], + [0.9, 0.1, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + [0.5, 0.5, 0.0], + ], + dtype=np.float32, + ) + + documents_table = pa.table( + { + "id": [1, 2, 3, 4, 5], + "name": ["Doc1", "Doc2", "Doc3", "Doc4", "Doc5"], + "category": ["tech", "tech", "science", "tech", "science"], + "embedding": pa.FixedSizeListArray.from_arrays( + embedding_values.flatten(), list_size=3 + ), + } + ) + + dataset_path = tmp_path / "Document.lance" + lance.write_dataset(documents_table, dataset_path) + lance_dataset = lance.dataset(str(dataset_path)) + + # Query WITH a WHERE clause - should fall back to standard rerank + query = CypherQuery( + "MATCH (d:Document) WHERE d.category = 'tech' RETURN d.id, d.name, d.embedding" + ).with_config(config) + + results = query.execute_with_vector_rerank( + {"Document": lance_dataset}, + VectorSearch("d.embedding") + .query_vector([1.0, 0.0, 0.0]) + .metric(DistanceMetric.L2) + .top_k(3) + .use_lance_index(True), # This will be ignored due to WHERE clause + ) + + data = results.to_pydict() + # Should only have tech documents (Doc1, Doc2, Doc4), not science docs + assert len(data["d.name"]) == 3 + assert all(name in ["Doc1", "Doc2", "Doc4"] for name in data["d.name"]) + # Doc1 should still be first (closest to [1,0,0]) + assert data["d.name"][0] == "Doc1" + + def test_execute_with_vector_rerank_filtered(vector_env): """Test Cypher filter + vector rerank.""" config, datasets, _ = vector_env