From 2d96e3b4e60d38af4f76387d367691a7151ae8eb Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 26 Mar 2025 13:16:34 +0800 Subject: [PATCH 1/6] fix: schema isn't expected for IVF_PQ Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq/storage.rs | 14 ++++++++++++++ rust/lance-index/src/vector/storage.rs | 6 +++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index 1e12d2ab4b3..9a15f1f090d 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -152,6 +152,20 @@ impl ProductQuantizationStorage { distance_type: DistanceType, transposed: bool, ) -> Result { + debug_assert_eq!(batch.num_columns(), 2); + + if batch.num_columns() != 2 { + log::warn!( + "PQ storage should have 2 columns, but got {} columns: {}", + batch.num_columns(), + batch.schema(), + ); + batch = batch.project(&[ + batch.schema().index_of(ROW_ID)?, + batch.schema().index_of(PQ_CODE_COLUMN)?, + ])?; + } + let Some(row_ids) = batch.column_by_name(ROW_ID) else { return Err(Error::Index { message: "Row ID column not found from PQ storage".to_string(), diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index e08a1079ad0..f5c3ae26247 100644 --- a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -14,7 +14,7 @@ use arrow_schema::SchemaRef; use deepsize::DeepSizeOf; use futures::prelude::stream::TryStreamExt; use lance_arrow::RecordBatchExt; -use lance_core::{Error, Result}; +use lance_core::{Error, Result, ROW_ID}; use lance_encoding::decoder::FilterExpression; use lance_file::v2::reader::FileReader; use lance_io::ReadBatchParams; @@ -180,6 +180,10 @@ impl StorageBuilder { )?; } + debug_assert_eq!(batch.num_columns(), 2); + debug_assert!(batch.column_by_name(ROW_ID).is_some()); + debug_assert!(batch.column_by_name(self.quantizer.column()).is_some()); + let batch = batch.add_metadata( STORAGE_METADATA_KEY.to_owned(), self.quantizer.metadata(None)?.to_string(), From f0808b215cbb167f41f94bfd63a835158575a777 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 26 Mar 2025 13:35:22 +0800 Subject: [PATCH 2/6] add test Signed-off-by: BubbleCal --- rust/lance-index/src/vector/pq/storage.rs | 59 +++++++++++++++++++++-- rust/lance-index/src/vector/storage.rs | 14 +++++- 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index 9a15f1f090d..67be90b5519 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -152,8 +152,6 @@ impl ProductQuantizationStorage { distance_type: DistanceType, transposed: bool, ) -> Result { - debug_assert_eq!(batch.num_columns(), 2); - if batch.num_columns() != 2 { log::warn!( "PQ storage should have 2 columns, but got {} columns: {}", @@ -980,7 +978,7 @@ mod tests { use super::*; - use arrow_array::Float32Array; + use arrow_array::{Float32Array, UInt32Array}; use arrow_schema::{DataType, Field, Schema as ArrowSchema}; use lance_arrow::FixedSizeListArrayExt; use lance_core::datatypes::Schema; @@ -1019,6 +1017,40 @@ mod tests { .unwrap() } + async fn create_pq_storage_with_extra_column() -> ProductQuantizationStorage { + let codebook = Float32Array::from_iter_values((0..256 * DIM).map(|_| rand::random())); + let codebook = FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(); + let pq = ProductQuantizer::new(NUM_SUB_VECTORS, 8, DIM, codebook, DistanceType::Dot); + + let schema = ArrowSchema::new(vec![ + Field::new( + "vec", + DataType::FixedSizeList( + Field::new_list_field(DataType::Float32, true).into(), + DIM as i32, + ), + true, + ), + ROW_ID_FIELD.clone(), + Field::new("extra", DataType::UInt32, true), + ]); + let vectors = Float32Array::from_iter_values((0..TOTAL * DIM).map(|_| rand::random())); + let row_ids = UInt64Array::from_iter_values((0..TOTAL).map(|v| v as u64)); + let extra_column = UInt32Array::from_iter_values((0..TOTAL).map(|v| v as u32)); + let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap(); + let batch = RecordBatch::try_new( + schema.into(), + vec![Arc::new(fsl), Arc::new(row_ids), Arc::new(extra_column)], + ) + .unwrap(); + + StorageBuilder::new("vec".to_owned(), pq.distance_type, pq) + .unwrap() + .assert_num_columns(false) + .build(vec![batch]) + .unwrap() + } + #[tokio::test] async fn test_build_pq_storage() { let storage = create_pq_storage().await; @@ -1076,4 +1108,25 @@ mod tests { let dist2 = storage.dist_between(v, u); assert_eq!(dist1, dist2); } + + #[tokio::test] + async fn test_remap_with_extra_column() { + let storage = create_pq_storage_with_extra_column().await; + let mut mapping = HashMap::new(); + for i in 0..TOTAL / 2 { + mapping.insert(i as u64, Some((TOTAL + i) as u64)); + } + for i in TOTAL / 2..TOTAL { + mapping.insert(i as u64, None); + } + let new_storage = storage.remap(&mapping).unwrap(); + assert_eq!(new_storage.len(), TOTAL / 2); + assert_eq!(new_storage.row_ids.len(), TOTAL / 2); + for (i, row_id) in new_storage.row_ids().enumerate() { + assert_eq!(*row_id, (TOTAL + i) as u64); + } + assert_eq!(new_storage.batch.num_columns(), 2); + assert!(new_storage.batch.column_by_name(ROW_ID).is_some()); + assert!(new_storage.batch.column_by_name(PQ_CODE_COLUMN).is_some()); + } } diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index f5c3ae26247..88015dbeffb 100644 --- a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -152,6 +152,9 @@ pub struct StorageBuilder { vector_column: String, distance_type: DistanceType, quantizer: Q, + + // this is for testing purpose + assert_num_columns: bool, } impl StorageBuilder { @@ -160,9 +163,16 @@ impl StorageBuilder { vector_column, distance_type, quantizer, + assert_num_columns: true, }) } + // this is for testing purpose + pub fn assert_num_columns(mut self, assert_num_columns: bool) -> Self { + self.assert_num_columns = assert_num_columns; + self + } + pub fn build(&self, batches: Vec) -> Result { let mut batch = concat_batches(batches[0].schema_ref(), batches.iter())?; @@ -180,7 +190,9 @@ impl StorageBuilder { )?; } - debug_assert_eq!(batch.num_columns(), 2); + if self.assert_num_columns { + debug_assert_eq!(batch.num_columns(), 2); + } debug_assert!(batch.column_by_name(ROW_ID).is_some()); debug_assert!(batch.column_by_name(self.quantizer.column()).is_some()); From 045f661a67d6570174f264e77ce6ea6d14bb2d60 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 26 Mar 2025 14:16:55 +0800 Subject: [PATCH 3/6] fix Signed-off-by: BubbleCal --- rust/lance-index/src/vector/ivf/transform.rs | 9 +++++++++ rust/lance-index/src/vector/storage.rs | 3 ++- rust/lance/src/index/vector/builder.rs | 8 +++++--- rust/lance/src/index/vector/ivf/v2.rs | 3 +++ 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/rust/lance-index/src/vector/ivf/transform.rs b/rust/lance-index/src/vector/ivf/transform.rs index 7f80b188263..b8239027049 100644 --- a/rust/lance-index/src/vector/ivf/transform.rs +++ b/rust/lance-index/src/vector/ivf/transform.rs @@ -72,6 +72,15 @@ impl Transformer for PartitionTransformer { // If the partition ID column is already present, we don't need to compute it again. return Ok(batch.clone()); } + + if self.centroids.len() == 1 { + // If there is only one centroid, we can skip the computation. + // Just add a column with all zeros. + let part_ids = UInt32Array::from(vec![0; batch.num_rows()]); + let field = Field::new(PART_ID_COLUMN, part_ids.data_type().clone(), true); + return Ok(batch.try_with_column(field, Arc::new(part_ids))?); + } + let arr = batch .column_by_name(&self.input_column) diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index 88015dbeffb..43db6d7aff1 100644 --- a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -22,6 +22,7 @@ use lance_linalg::distance::DistanceType; use prost::Message; use snafu::location; +use crate::vector::PART_ID_COLUMN; use crate::{ pb, vector::{ @@ -191,7 +192,7 @@ impl StorageBuilder { } if self.assert_num_columns { - debug_assert_eq!(batch.num_columns(), 2); + debug_assert_eq!(batch.num_columns(), 2, "{}", batch.schema()); } debug_assert!(batch.column_by_name(ROW_ID).is_some()); debug_assert!(batch.column_by_name(self.quantizer.column()).is_some()); diff --git a/rust/lance/src/index/vector/builder.rs b/rust/lance/src/index/vector/builder.rs index 64bc2c127a4..cc163f124ea 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -26,7 +26,7 @@ use lance_index::vector::quantizer::{ use lance_index::vector::storage::STORAGE_METADATA_KEY; use lance_index::vector::v3::shuffler::IvfShufflerReader; use lance_index::vector::v3::subindex::SubIndexType; -use lance_index::vector::{VectorIndex, LOSS_METADATA_KEY, PQ_CODE_COLUMN}; +use lance_index::vector::{VectorIndex, LOSS_METADATA_KEY, PART_ID_COLUMN, PQ_CODE_COLUMN}; use lance_index::{ pb, vector::{ @@ -653,8 +653,9 @@ impl IvfIndexBuilder original_codes, codes_num_bytes as i32, )?; - *batch = - batch.replace_column_by_name(PQ_CODE_COLUMN, Arc::new(original_codes))?; + *batch = batch + .replace_column_by_name(PQ_CODE_COLUMN, Arc::new(original_codes))? + .drop_column(PART_ID_COLUMN)?; } } batches.extend(part_batches); @@ -672,6 +673,7 @@ impl IvfIndexBuilder .get(LOSS_METADATA_KEY) .map(|s| s.parse::().unwrap_or(0.0)) .unwrap_or(0.0); + let batch = batch.drop_column(PART_ID_COLUMN)?; batches.push(batch); } } diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index df5b9a1a7df..e109c9d4b60 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -1211,6 +1211,9 @@ mod tests { } #[rstest] + #[case(1, DistanceType::L2, 0.9)] + #[case(1, DistanceType::Cosine, 0.9)] + #[case(1, DistanceType::Dot, 0.85)] #[case(4, DistanceType::L2, 0.9)] #[case(4, DistanceType::Cosine, 0.9)] #[case(4, DistanceType::Dot, 0.85)] From b7a27ed59f0c549f32642f03c8576f8c6ba2e4c3 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 26 Mar 2025 14:20:19 +0800 Subject: [PATCH 4/6] fmt Signed-off-by: BubbleCal --- rust/lance-index/src/vector/storage.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index 43db6d7aff1..285994ac756 100644 --- a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -22,7 +22,6 @@ use lance_linalg::distance::DistanceType; use prost::Message; use snafu::location; -use crate::vector::PART_ID_COLUMN; use crate::{ pb, vector::{ From 52a9d7e46fad35085ff9060e153d87f0d65cd36e Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 26 Mar 2025 14:51:00 +0800 Subject: [PATCH 5/6] fix Signed-off-by: BubbleCal --- rust/lance-index/src/vector/ivf/transform.rs | 8 -------- rust/lance/src/index/vector/ivf/v2.rs | 3 ++- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/rust/lance-index/src/vector/ivf/transform.rs b/rust/lance-index/src/vector/ivf/transform.rs index b8239027049..d2f877cee15 100644 --- a/rust/lance-index/src/vector/ivf/transform.rs +++ b/rust/lance-index/src/vector/ivf/transform.rs @@ -73,14 +73,6 @@ impl Transformer for PartitionTransformer { return Ok(batch.clone()); } - if self.centroids.len() == 1 { - // If there is only one centroid, we can skip the computation. - // Just add a column with all zeros. - let part_ids = UInt32Array::from(vec![0; batch.num_rows()]); - let field = Field::new(PART_ID_COLUMN, part_ids.data_type().clone(), true); - return Ok(batch.try_with_column(field, Arc::new(part_ids))?); - } - let arr = batch .column_by_name(&self.input_column) diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index e109c9d4b60..3befff7bd52 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -1131,7 +1131,8 @@ mod tests { } if count >= 10 { panic!( - "failed to hit the retrain threshold {}", + "failed to hit the retrain threshold {} < {}", + last_avg_loss / original_avg_loss, AVG_LOSS_RETRAIN_THRESHOLD ); } From 7d3556e0a1ccc8e7a8a782567261bd08274342de Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 26 Mar 2025 15:11:36 +0800 Subject: [PATCH 6/6] fix Signed-off-by: BubbleCal --- rust/lance/src/index/vector/ivf/v2.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 3befff7bd52..ea7c616bf16 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -1157,7 +1157,7 @@ mod tests { let ivf_models = get_ivf_models(&dataset).await; let ivf = &ivf_models[0]; assert_ne!(original_ivf.centroids, ivf.centroids); - if params.metric_type != DistanceType::Hamming { + if ivf.num_partitions() > 1 && params.metric_type != DistanceType::Hamming { assert_lt!(get_avg_loss(&dataset).await, last_avg_loss); } }