Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/lance-index/src/vector/ivf/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ impl Transformer for PartitionTransformer {
// If the partition ID column is already present, we don't need to compute it again.
return Ok(batch.clone());
}

let arr =
batch
.column_by_name(&self.input_column)
Expand Down
69 changes: 68 additions & 1 deletion rust/lance-index/src/vector/pq/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,18 @@ impl ProductQuantizationStorage {
distance_type: DistanceType,
transposed: bool,
) -> Result<Self> {
if batch.num_columns() != 2 {
log::warn!(
"PQ storage should have 2 columns, but got {} columns: {}",
batch.num_columns(),
batch.schema(),
);
batch = batch.project(&[
batch.schema().index_of(ROW_ID)?,
batch.schema().index_of(PQ_CODE_COLUMN)?,
])?;
}

let Some(row_ids) = batch.column_by_name(ROW_ID) else {
return Err(Error::Index {
message: "Row ID column not found from PQ storage".to_string(),
Expand Down Expand Up @@ -966,7 +978,7 @@ mod tests {

use super::*;

use arrow_array::Float32Array;
use arrow_array::{Float32Array, UInt32Array};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use lance_arrow::FixedSizeListArrayExt;
use lance_core::datatypes::Schema;
Expand Down Expand Up @@ -1005,6 +1017,40 @@ mod tests {
.unwrap()
}

async fn create_pq_storage_with_extra_column() -> ProductQuantizationStorage {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we still get PQ storage with an extra column from a real workflow? Or is this just generating some kind of invalid input for testing?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just for testing; we shouldn't see any extra column in a real workflow.

let codebook = Float32Array::from_iter_values((0..256 * DIM).map(|_| rand::random()));
let codebook = FixedSizeListArray::try_new_from_values(codebook, DIM as i32).unwrap();
let pq = ProductQuantizer::new(NUM_SUB_VECTORS, 8, DIM, codebook, DistanceType::Dot);

let schema = ArrowSchema::new(vec![
Field::new(
"vec",
DataType::FixedSizeList(
Field::new_list_field(DataType::Float32, true).into(),
DIM as i32,
),
true,
),
ROW_ID_FIELD.clone(),
Field::new("extra", DataType::UInt32, true),
]);
let vectors = Float32Array::from_iter_values((0..TOTAL * DIM).map(|_| rand::random()));
let row_ids = UInt64Array::from_iter_values((0..TOTAL).map(|v| v as u64));
let extra_column = UInt32Array::from_iter_values((0..TOTAL).map(|v| v as u32));
let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap();
let batch = RecordBatch::try_new(
schema.into(),
vec![Arc::new(fsl), Arc::new(row_ids), Arc::new(extra_column)],
)
.unwrap();

StorageBuilder::new("vec".to_owned(), pq.distance_type, pq)
.unwrap()
.assert_num_columns(false)
.build(vec![batch])
.unwrap()
}

#[tokio::test]
async fn test_build_pq_storage() {
let storage = create_pq_storage().await;
Expand Down Expand Up @@ -1062,4 +1108,25 @@ mod tests {
let dist2 = storage.dist_between(v, u);
assert_eq!(dist1, dist2);
}

#[tokio::test]
async fn test_remap_with_extra_column() {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this because some old indices will have this extra column and we need to make sure they are supported?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, we received some feedback about this, so I added this test to make sure old indices work with this fix.

let storage = create_pq_storage_with_extra_column().await;
let mut mapping = HashMap::new();
for i in 0..TOTAL / 2 {
mapping.insert(i as u64, Some((TOTAL + i) as u64));
}
for i in TOTAL / 2..TOTAL {
mapping.insert(i as u64, None);
}
let new_storage = storage.remap(&mapping).unwrap();
assert_eq!(new_storage.len(), TOTAL / 2);
assert_eq!(new_storage.row_ids.len(), TOTAL / 2);
for (i, row_id) in new_storage.row_ids().enumerate() {
assert_eq!(*row_id, (TOTAL + i) as u64);
}
assert_eq!(new_storage.batch.num_columns(), 2);
assert!(new_storage.batch.column_by_name(ROW_ID).is_some());
assert!(new_storage.batch.column_by_name(PQ_CODE_COLUMN).is_some());
}
}
18 changes: 17 additions & 1 deletion rust/lance-index/src/vector/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use arrow_schema::SchemaRef;
use deepsize::DeepSizeOf;
use futures::prelude::stream::TryStreamExt;
use lance_arrow::RecordBatchExt;
use lance_core::{Error, Result};
use lance_core::{Error, Result, ROW_ID};
use lance_encoding::decoder::FilterExpression;
use lance_file::v2::reader::FileReader;
use lance_io::ReadBatchParams;
Expand Down Expand Up @@ -152,6 +152,9 @@ pub struct StorageBuilder<Q: Quantization> {
vector_column: String,
distance_type: DistanceType,
quantizer: Q,

// This flag exists for testing purposes.
assert_num_columns: bool,
}

impl<Q: Quantization> StorageBuilder<Q> {
Expand All @@ -160,9 +163,16 @@ impl<Q: Quantization> StorageBuilder<Q> {
vector_column,
distance_type,
quantizer,
assert_num_columns: true,
})
}

/// Enables or disables the two-column debug assertion checked in `build`.
///
/// The flag starts out as `true`; this setter exists for testing purposes.
/// Consumes the builder and hands it back so calls can be chained.
pub fn assert_num_columns(self, assert_num_columns: bool) -> Self {
    let mut builder = self;
    builder.assert_num_columns = assert_num_columns;
    builder
}

pub fn build(&self, batches: Vec<RecordBatch>) -> Result<Q::Storage> {
let mut batch = concat_batches(batches[0].schema_ref(), batches.iter())?;

Expand All @@ -180,6 +190,12 @@ impl<Q: Quantization> StorageBuilder<Q> {
)?;
}

if self.assert_num_columns {
debug_assert_eq!(batch.num_columns(), 2, "{}", batch.schema());
}
debug_assert!(batch.column_by_name(ROW_ID).is_some());
debug_assert!(batch.column_by_name(self.quantizer.column()).is_some());

let batch = batch.add_metadata(
STORAGE_METADATA_KEY.to_owned(),
self.quantizer.metadata(None)?.to_string(),
Expand Down
8 changes: 5 additions & 3 deletions rust/lance/src/index/vector/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use lance_index::vector::quantizer::{
use lance_index::vector::storage::STORAGE_METADATA_KEY;
use lance_index::vector::v3::shuffler::IvfShufflerReader;
use lance_index::vector::v3::subindex::SubIndexType;
use lance_index::vector::{VectorIndex, LOSS_METADATA_KEY, PQ_CODE_COLUMN};
use lance_index::vector::{VectorIndex, LOSS_METADATA_KEY, PART_ID_COLUMN, PQ_CODE_COLUMN};
use lance_index::{
pb,
vector::{
Expand Down Expand Up @@ -653,8 +653,9 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
original_codes,
codes_num_bytes as i32,
)?;
*batch =
batch.replace_column_by_name(PQ_CODE_COLUMN, Arc::new(original_codes))?;
*batch = batch
.replace_column_by_name(PQ_CODE_COLUMN, Arc::new(original_codes))?
.drop_column(PART_ID_COLUMN)?;
}
}
batches.extend(part_batches);
Expand All @@ -672,6 +673,7 @@ impl<S: IvfSubIndex + 'static, Q: Quantization + 'static> IvfIndexBuilder<S, Q>
.get(LOSS_METADATA_KEY)
.map(|s| s.parse::<f64>().unwrap_or(0.0))
.unwrap_or(0.0);
let batch = batch.drop_column(PART_ID_COLUMN)?;
batches.push(batch);
}
}
Expand Down
8 changes: 6 additions & 2 deletions rust/lance/src/index/vector/ivf/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,8 @@ mod tests {
}
if count >= 10 {
panic!(
"failed to hit the retrain threshold {}",
"failed to hit the retrain threshold {} < {}",
last_avg_loss / original_avg_loss,
AVG_LOSS_RETRAIN_THRESHOLD
);
}
Expand All @@ -1156,7 +1157,7 @@ mod tests {
let ivf_models = get_ivf_models(&dataset).await;
let ivf = &ivf_models[0];
assert_ne!(original_ivf.centroids, ivf.centroids);
if params.metric_type != DistanceType::Hamming {
if ivf.num_partitions() > 1 && params.metric_type != DistanceType::Hamming {
assert_lt!(get_avg_loss(&dataset).await, last_avg_loss);
}
}
Expand Down Expand Up @@ -1211,6 +1212,9 @@ mod tests {
}

#[rstest]
#[case(1, DistanceType::L2, 0.9)]
#[case(1, DistanceType::Cosine, 0.9)]
#[case(1, DistanceType::Dot, 0.85)]
#[case(4, DistanceType::L2, 0.9)]
#[case(4, DistanceType::Cosine, 0.9)]
#[case(4, DistanceType::Dot, 0.85)]
Expand Down