Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ jobs:
sudo apt install -y protobuf-compiler libssl-dev
- name: Run cargo fmt
run: cargo fmt --check
- name: Run clippy
#run: cargo clippy -- --deny "warnings"
run: cargo clippy
- name: Run tests
run: |
cargo build --all-features
Expand Down
262 changes: 234 additions & 28 deletions rust/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ use datafusion::execution::{
context::SessionState,
runtime_env::{RuntimeConfig, RuntimeEnv},
};
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::{
limit::GlobalLimitExec, ExecutionPlan, PhysicalExpr, SendableRecordBatchStream,
filter::FilterExec, limit::GlobalLimitExec, ExecutionPlan, PhysicalExpr,
SendableRecordBatchStream,
};
use datafusion::prelude::*;
use futures::stream::{Stream, StreamExt};
Expand All @@ -36,9 +36,7 @@ use crate::datafusion::physical_expr::column_names_in_expr;
use crate::datatypes::Schema;
use crate::format::Index;
use crate::index::vector::{MetricType, Query};
use crate::io::exec::{
KNNFlatExec, KNNIndexExec, LanceScanExec, LocalTakeExec, ProjectionExec, TakeExec,
};
use crate::io::exec::{KNNFlatExec, KNNIndexExec, LanceScanExec, ProjectionExec, TakeExec};
use crate::utils::sql::parse_sql_filter;
use crate::{Error, Result};

Expand Down Expand Up @@ -273,11 +271,11 @@ impl Scanner {
}
}

let knn_node = self.ann(q, &index); // score, _rowid
let knn_node = self.ann(q, &index)?; // score, _rowid
let with_vector = self.dataset.schema().project(&[&q.column])?;
let knn_node_with_vector = self.take(knn_node, &with_vector)?;
let knn_node = if q.refine_factor.is_some() {
self.flat_knn(knn_node_with_vector, q)
self.flat_knn(knn_node_with_vector, q)?
} else {
knn_node_with_vector
}; // vector, score, _rowid
Expand All @@ -290,7 +288,7 @@ impl Scanner {
let vector_scan_projection =
Arc::new(self.dataset.schema().project(&[&q.column]).unwrap());
let scan_node = self.scan(true, vector_scan_projection);
let knn_node = self.flat_knn(scan_node, q);
let knn_node = self.flat_knn(scan_node, q)?;

let knn_node = filter_expr
.map(|f| self.filter_knn(knn_node.clone(), f))
Expand All @@ -308,7 +306,7 @@ impl Scanner {
)?,
);
let scan = self.scan(true, filter_schema);
self.filter_node(filter, scan, true, None)?
self.filter_node(filter, scan)?
} else {
self.scan(with_row_id, Arc::new(self.projections.clone()))
};
Expand Down Expand Up @@ -346,12 +344,7 @@ impl Scanner {
knn_node,
Arc::new(filter_projection),
)?);
self.filter_node(
filter_expression,
take_node,
false,
Some(Arc::new(self.vector_search_schema()?)),
)
self.filter_node(filter_expression, take_node)
}

/// Create an Execution plan with a scan node
Expand All @@ -367,17 +360,17 @@ impl Scanner {
}

/// Add a knn search node to the input plan
fn flat_knn(&self, input: Arc<dyn ExecutionPlan>, q: &Query) -> Arc<dyn ExecutionPlan> {
Arc::new(KNNFlatExec::new(input, q.clone()))
fn flat_knn(&self, input: Arc<dyn ExecutionPlan>, q: &Query) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(KNNFlatExec::try_new(input, q.clone())?))
}

/// Create an Execution plan to do indexed ANN search
fn ann(&self, q: &Query, index: &Index) -> Arc<dyn ExecutionPlan> {
Arc::new(KNNIndexExec::new(
fn ann(&self, q: &Query, index: &Index) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(KNNIndexExec::try_new(
self.dataset.clone(),
&index.uuid.to_string(),
q,
))
)?))
}

/// Take row indices produced by input plan from the dataset (with projection)
Expand Down Expand Up @@ -406,18 +399,14 @@ impl Scanner {
&self,
filter: Arc<dyn PhysicalExpr>,
input: Arc<dyn ExecutionPlan>,
drop_row_id: bool,
ann_schema: Option<Arc<Schema>>,
) -> Result<Arc<dyn ExecutionPlan>> {
let filter_node = Arc::new(FilterExec::try_new(filter, input)?);
let output_schema = self.scanner_output_schema()?;
Ok(Arc::new(LocalTakeExec::new(
filter_node,
Ok(Arc::new(TakeExec::try_new(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did we need a separate LocalTake before?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was an attempt to make it similar to what C++ does, where it takes the same batch from the input BatchRecords, but we later found out that it is too difficult to implement that in Rust. In the end, LocalTake became almost identical to GlobalTake.

self.dataset.clone(),
filter_node,
output_schema,
ann_schema,
drop_row_id,
)))
)?))
}
}

Expand Down Expand Up @@ -451,16 +440,23 @@ impl Stream for RecordBatchStream {
#[cfg(test)]
mod test {

use std::collections::BTreeSet;
use std::path::PathBuf;

use super::*;

use arrow::array::as_primitive_array;
use arrow::compute::concat_batches;
use arrow_array::{ArrayRef, Int32Array, Int64Array, RecordBatchReader, StringArray};
use arrow::datatypes::Int32Type;
use arrow_array::{
ArrayRef, FixedSizeListArray, Int32Array, Int64Array, RecordBatchReader, StringArray,
};
use arrow_schema::DataType;
use futures::TryStreamExt;
use tempfile::tempdir;

use crate::index::vector::VectorIndexParams;
use crate::index::IndexType;
use crate::{arrow::RecordBatchBuffer, dataset::WriteParams};

#[tokio::test]
Expand Down Expand Up @@ -620,4 +616,214 @@ mod test {
.unwrap();
expected_batches
}

/// Write a small dataset with a 32-dimensional vector column for the KNN tests.
///
/// Five batches of 80 rows each are written to `path`:
/// * `i` holds the running row number (`0..400` across the batches),
/// * `s` holds the string `"s-{i}"`,
/// * `vec` holds the same 80 vectors in every batch; row `k` of a batch
///   contains the values `32 * k .. 32 * k + 32`.
///
/// When `build_index` is true, a vector index named `"idx"` is created on `vec`
/// using `VectorIndexParams::ivf_pq`. The dataset is re-opened from `path`
/// before it is returned.
///
/// Panics if writing, indexing, or opening the dataset fails.
async fn create_vector_dataset(path: &str, build_index: bool) -> Dataset {
    let vector_field = ArrowField::new(
        "vec",
        DataType::FixedSizeList(
            Box::new(ArrowField::new("item", DataType::Float32, true)),
            32,
        ),
        true,
    );
    let schema = Arc::new(ArrowSchema::new(vec![
        ArrowField::new("i", DataType::Int32, true),
        ArrowField::new("s", DataType::Utf8, true),
        vector_field,
    ]));

    let batches = RecordBatchBuffer::new(
        (0..5)
            .map(|batch_idx| {
                // Every batch reuses the same flat buffer of 80 * 32 floats.
                let vector_values: Float32Array = (0..32 * 80).map(|v| v as f32).collect();
                let vectors = FixedSizeListArray::try_new(&vector_values, 32).unwrap();

                // Row numbers covered by this batch: [first, last).
                let first = batch_idx * 80;
                let last = (batch_idx + 1) * 80;
                let i_column = Int32Array::from_iter_values(first..last);
                let s_column =
                    StringArray::from_iter_values((first..last).map(|v| format!("s-{}", v)));

                RecordBatch::try_new(
                    schema.clone(),
                    vec![Arc::new(i_column), Arc::new(s_column), Arc::new(vectors)],
                )
                .unwrap()
            })
            .collect(),
    );

    let mut params = WriteParams::default();
    params.max_rows_per_group = 10;
    let mut reader: Box<dyn RecordBatchReader> = Box::new(batches);

    let dataset = Dataset::write(&mut reader, path, Some(params))
        .await
        .unwrap();

    if build_index {
        let index_params = VectorIndexParams::ivf_pq(2, 8, 2, false, MetricType::L2, 2);
        dataset
            .create_index(
                &["vec"],
                IndexType::Vector,
                Some("idx".to_string()),
                &index_params,
                true,
            )
            .await
            .unwrap();
    }

    // Re-open so the returned handle is read back from `path`.
    Dataset::open(path).await.unwrap()
}

/// KNN search for the vector `32..64` with `k = 5`, run once with an index
/// and once without.
///
/// Row 1 of every batch written by `create_vector_dataset` holds exactly the
/// query vector, so the expected `i` values are 1, 81, 161, 241 and 321.
#[tokio::test]
async fn test_knn_nodes() {
    for &build_index in &[true, false] {
        let test_dir = tempdir().unwrap();
        let test_uri = test_dir.path().to_str().unwrap();
        let dataset = create_vector_dataset(test_uri, build_index).await;

        let mut scan = dataset.scan();
        let key: Float32Array = (32..64).map(|v| v as f32).collect();
        scan.nearest("vec", &key, 5).unwrap();
        scan.refine(5);

        let stream = scan.try_into_stream().await.unwrap();
        let results: Vec<_> = stream.try_collect().await.unwrap();

        // All hits are expected in a single batch.
        assert_eq!(results.len(), 1);
        let batch = &results[0];
        assert_eq!(batch.num_rows(), 5);

        // Every dataset column plus the appended non-nullable `score` column.
        let expected_schema = ArrowSchema::new(vec![
            ArrowField::new("i", DataType::Int32, true),
            ArrowField::new("s", DataType::Utf8, true),
            ArrowField::new(
                "vec",
                DataType::FixedSizeList(
                    Box::new(ArrowField::new("item", DataType::Float32, true)),
                    32,
                ),
                true,
            ),
            ArrowField::new("score", DataType::Float32, false),
        ]);
        assert_eq!(batch.schema().as_ref(), &expected_schema);

        // Compare as sets: the order of rows within the batch is not checked.
        let expected_i: BTreeSet<i32> = vec![1, 81, 161, 241, 321].into_iter().collect();
        let column_i = batch.column_by_name("i").unwrap();
        let ints = as_primitive_array::<Int32Type>(column_i.as_ref());
        let actual_i: BTreeSet<i32> = ints.values().iter().copied().collect();
        assert_eq!(expected_i, actual_i);
    }
}

/// KNN search on the indexed dataset combined with the filter `i > 100`
/// and a projection onto `i`.
///
/// Of the five rows whose vector equals the query (`i` = 1, 81, 161, 241, 321),
/// only the three with `i > 100` are expected to remain.
#[tokio::test]
async fn test_knn_with_filter() {
    let test_dir = tempdir().unwrap();
    let test_uri = test_dir.path().to_str().unwrap();
    let dataset = create_vector_dataset(test_uri, true).await;

    let mut scan = dataset.scan();
    let key: Float32Array = (32..64).map(|v| v as f32).collect();
    scan.nearest("vec", &key, 5).unwrap();
    scan.filter("i > 100").unwrap();
    scan.project(&["i"]).unwrap();
    scan.refine(5);

    let stream = scan.try_into_stream().await.unwrap();
    let results: Vec<_> = stream.try_collect().await.unwrap();

    // All hits are expected in a single batch.
    assert_eq!(results.len(), 1);
    let batch = &results[0];
    assert_eq!(batch.num_rows(), 3);

    // Only `i` was projected, yet the expected output also carries the
    // searched `vec` column and the `score` column.
    let expected_schema = ArrowSchema::new(vec![
        ArrowField::new("i", DataType::Int32, true),
        ArrowField::new(
            "vec",
            DataType::FixedSizeList(
                Box::new(ArrowField::new("item", DataType::Float32, true)),
                32,
            ),
            true,
        ),
        ArrowField::new("score", DataType::Float32, false),
    ]);
    assert_eq!(batch.schema().as_ref(), &expected_schema);

    // Compare as sets: the order of rows within the batch is not checked.
    let expected_i: BTreeSet<i32> = vec![161, 241, 321].into_iter().collect();
    let column_i = batch.column_by_name("i").unwrap();
    let ints = as_primitive_array::<Int32Type>(column_i.as_ref());
    let actual_i: BTreeSet<i32> = ints.values().iter().copied().collect();
    assert_eq!(expected_i, actual_i);
}

/// KNN search on the indexed dataset with a refine factor of 5.
///
/// Queries the vector `32..64` with `k = 5` and expects the five rows whose
/// stored vector equals the query (`i` = 1, 81, 161, 241, 321), returned with
/// every dataset column plus `score`.
#[tokio::test]
async fn test_refine_factor() {
    let test_dir = tempdir().unwrap();
    let test_uri = test_dir.path().to_str().unwrap();

    let dataset = create_vector_dataset(test_uri, true).await;

    let mut scan = dataset.scan();
    let key: Float32Array = (32..64).map(|v| v as f32).collect();
    scan.nearest("vec", &key, 5).unwrap();
    scan.refine(5);

    let stream = scan.try_into_stream().await.unwrap();
    let results: Vec<_> = stream.try_collect().await.unwrap();

    // All hits are expected in a single batch.
    assert_eq!(results.len(), 1);
    let batch = &results[0];
    assert_eq!(batch.num_rows(), 5);

    // Every dataset column plus the appended non-nullable `score` column.
    let expected_schema = ArrowSchema::new(vec![
        ArrowField::new("i", DataType::Int32, true),
        ArrowField::new("s", DataType::Utf8, true),
        ArrowField::new(
            "vec",
            DataType::FixedSizeList(
                Box::new(ArrowField::new("item", DataType::Float32, true)),
                32,
            ),
            true,
        ),
        ArrowField::new("score", DataType::Float32, false),
    ]);
    assert_eq!(batch.schema().as_ref(), &expected_schema);

    // Compare as sets: the order of rows within the batch is not checked.
    let expected_i: BTreeSet<i32> = vec![1, 81, 161, 241, 321].into_iter().collect();
    let column_i = batch.column_by_name("i").unwrap();
    let ints = as_primitive_array::<Int32Type>(column_i.as_ref());
    let actual_i: BTreeSet<i32> = ints.values().iter().copied().collect();
    assert_eq!(expected_i, actual_i);
}
}
2 changes: 1 addition & 1 deletion rust/src/index/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ const MAX_ITERATIONS: usize = 50;
/// Maximum number of iterations for OPQ.
/// See OPQ paper for details.
const MAX_OPQ_ITERATIONS: usize = 100;
const SCORE_COL: &str = "score";
pub(crate) const SCORE_COL: &str = "score";
const INDEX_FILE_NAME: &str = "index.idx";

/// Query parameters for the vector indices
Expand Down
Loading