diff --git a/rust/lance-index/src/vector/ivf/transform.rs b/rust/lance-index/src/vector/ivf/transform.rs index 7f80b188263..d2f877cee15 100644 --- a/rust/lance-index/src/vector/ivf/transform.rs +++ b/rust/lance-index/src/vector/ivf/transform.rs @@ -72,6 +72,7 @@ impl Transformer for PartitionTransformer { // If the partition ID column is already present, we don't need to compute it again. return Ok(batch.clone()); } + let arr = batch .column_by_name(&self.input_column) diff --git a/rust/lance-index/src/vector/pq/storage.rs b/rust/lance-index/src/vector/pq/storage.rs index 1e12d2ab4b3..67be90b5519 100644 --- a/rust/lance-index/src/vector/pq/storage.rs +++ b/rust/lance-index/src/vector/pq/storage.rs @@ -152,6 +152,18 @@ impl ProductQuantizationStorage { distance_type: DistanceType, transposed: bool, ) -> Result { + if batch.num_columns() != 2 { + log::warn!( + "PQ storage should have 2 columns, but got {} columns: {}", + batch.num_columns(), + batch.schema(), + ); + batch = batch.project(&[ + batch.schema().index_of(ROW_ID)?, + batch.schema().index_of(PQ_CODE_COLUMN)?, + ])?; + } + let Some(row_ids) = batch.column_by_name(ROW_ID) else { return Err(Error::Index { message: "Row ID column not found from PQ storage".to_string(), @@ -966,7 +978,7 @@ mod tests { use super::*; - use arrow_array::Float32Array; + use arrow_array::{Float32Array, UInt32Array}; use arrow_schema::{DataType, Field, Schema as ArrowSchema}; use lance_arrow::FixedSizeListArrayExt; use lance_core::datatypes::Schema; @@ -1005,6 +1017,40 @@ mod tests { .unwrap() } + async fn create_pq_storage_with_extra_column() -> ProductQuantizationStorage { + let codebook = Float32Array::from_iter_values((0..256 * DIM).map(|_| rand::random())); + let codebook = FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap(); + let pq = ProductQuantizer::new(NUM_SUB_VECTORS, 8, DIM, codebook, DistanceType::Dot); + + let schema = ArrowSchema::new(vec![ + Field::new( + "vec", + DataType::FixedSizeList( + Field::new_list_field(DataType::Float32, true).into(), + DIM as i32, + ), + true, + ), + ROW_ID_FIELD.clone(), + Field::new("extra", DataType::UInt32, true), + ]); + let vectors = Float32Array::from_iter_values((0..TOTAL * DIM).map(|_| rand::random())); + let row_ids = UInt64Array::from_iter_values((0..TOTAL).map(|v| v as u64)); + let extra_column = UInt32Array::from_iter_values((0..TOTAL).map(|v| v as u32)); + let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap(); + let batch = RecordBatch::try_new( + schema.into(), + vec![Arc::new(fsl), Arc::new(row_ids), Arc::new(extra_column)], + ) + .unwrap(); + + StorageBuilder::new("vec".to_owned(), pq.distance_type, pq) + .unwrap() + .assert_num_columns(false) + .build(vec![batch]) + .unwrap() + } + #[tokio::test] async fn test_build_pq_storage() { let storage = create_pq_storage().await; @@ -1062,4 +1108,25 @@ mod tests { let dist2 = storage.dist_between(v, u); assert_eq!(dist1, dist2); } + + #[tokio::test] + async fn test_remap_with_extra_column() { + let storage = create_pq_storage_with_extra_column().await; + let mut mapping = HashMap::new(); + for i in 0..TOTAL / 2 { + mapping.insert(i as u64, Some((TOTAL + i) as u64)); + } + for i in TOTAL / 2..TOTAL { + mapping.insert(i as u64, None); + } + let new_storage = storage.remap(&mapping).unwrap(); + assert_eq!(new_storage.len(), TOTAL / 2); + assert_eq!(new_storage.row_ids.len(), TOTAL / 2); + for (i, row_id) in new_storage.row_ids().enumerate() { + assert_eq!(*row_id, (TOTAL + i) as u64); + } + assert_eq!(new_storage.batch.num_columns(), 2); + assert!(new_storage.batch.column_by_name(ROW_ID).is_some()); + assert!(new_storage.batch.column_by_name(PQ_CODE_COLUMN).is_some()); + } } diff --git a/rust/lance-index/src/vector/storage.rs b/rust/lance-index/src/vector/storage.rs index e08a1079ad0..285994ac756 100644 --- a/rust/lance-index/src/vector/storage.rs +++ b/rust/lance-index/src/vector/storage.rs @@ -14,7 +14,7 @@ use arrow_schema::SchemaRef; use deepsize::DeepSizeOf; use futures::prelude::stream::TryStreamExt; use lance_arrow::RecordBatchExt; -use lance_core::{Error, Result}; +use lance_core::{Error, Result, ROW_ID}; use lance_encoding::decoder::FilterExpression; use lance_file::v2::reader::FileReader; use lance_io::ReadBatchParams; @@ -152,6 +152,9 @@ pub struct StorageBuilder { vector_column: String, distance_type: DistanceType, quantizer: Q, + + // this is for testing purpose + assert_num_columns: bool, } impl StorageBuilder { @@ -160,9 +163,16 @@ impl StorageBuilder { vector_column, distance_type, quantizer, + assert_num_columns: true, }) } + // this is for testing purpose + pub fn assert_num_columns(mut self, assert_num_columns: bool) -> Self { + self.assert_num_columns = assert_num_columns; + self + } + pub fn build(&self, batches: Vec) -> Result { let mut batch = concat_batches(batches[0].schema_ref(), batches.iter())?; @@ -180,6 +190,12 @@ impl StorageBuilder { )?; } + if self.assert_num_columns { + debug_assert_eq!(batch.num_columns(), 2, "{}", batch.schema()); + } + debug_assert!(batch.column_by_name(ROW_ID).is_some()); + debug_assert!(batch.column_by_name(self.quantizer.column()).is_some()); + let batch = batch.add_metadata( STORAGE_METADATA_KEY.to_owned(), self.quantizer.metadata(None)?.to_string(), diff --git a/rust/lance/src/index/vector/builder.rs b/rust/lance/src/index/vector/builder.rs index 64bc2c127a4..cc163f124ea 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -26,7 +26,7 @@ use lance_index::vector::quantizer::{ use lance_index::vector::storage::STORAGE_METADATA_KEY; use lance_index::vector::v3::shuffler::IvfShufflerReader; use lance_index::vector::v3::subindex::SubIndexType; -use lance_index::vector::{VectorIndex, LOSS_METADATA_KEY, PQ_CODE_COLUMN}; +use lance_index::vector::{VectorIndex, LOSS_METADATA_KEY, PART_ID_COLUMN, PQ_CODE_COLUMN}; use lance_index::{ pb, vector::{ @@ -653,8 +653,9 @@ impl IvfIndexBuilder original_codes, codes_num_bytes as i32, )?; - *batch = - batch.replace_column_by_name(PQ_CODE_COLUMN, Arc::new(original_codes))?; + *batch = batch + .replace_column_by_name(PQ_CODE_COLUMN, Arc::new(original_codes))? + .drop_column(PART_ID_COLUMN)?; } } batches.extend(part_batches); @@ -672,6 +673,7 @@ impl IvfIndexBuilder .get(LOSS_METADATA_KEY) .map(|s| s.parse::().unwrap_or(0.0)) .unwrap_or(0.0); + let batch = batch.drop_column(PART_ID_COLUMN)?; batches.push(batch); } } diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index df5b9a1a7df..ea7c616bf16 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -1131,7 +1131,8 @@ mod tests { } if count >= 10 { panic!( - "failed to hit the retrain threshold {}", + "failed to hit the retrain threshold {} < {}", + last_avg_loss / original_avg_loss, AVG_LOSS_RETRAIN_THRESHOLD ); } @@ -1156,7 +1157,7 @@ mod tests { let ivf_models = get_ivf_models(&dataset).await; let ivf = &ivf_models[0]; assert_ne!(original_ivf.centroids, ivf.centroids); - if params.metric_type != DistanceType::Hamming { + if ivf.num_partitions() > 1 && params.metric_type != DistanceType::Hamming { assert_lt!(get_avg_loss(&dataset).await, last_avg_loss); } } @@ -1211,6 +1212,9 @@ mod tests { } #[rstest] + #[case(1, DistanceType::L2, 0.9)] + #[case(1, DistanceType::Cosine, 0.9)] + #[case(1, DistanceType::Dot, 0.85)] #[case(4, DistanceType::L2, 0.9)] #[case(4, DistanceType::Cosine, 0.9)] #[case(4, DistanceType::Dot, 0.85)]