feat: migrate IVF_PQ indices when vector column is casted #2102

Closed · wants to merge 5 commits
22 changes: 13 additions & 9 deletions protos/transaction.proto
@@ -87,6 +87,17 @@ message Transaction {
repeated IndexMetadata removed_indices = 2;
}

// During a rewrite or a column migration an index may be replaced. We only
// serialize the ids (plus the new field ids when the indexed column changes),
// since these operations should not change the other index parameters.
message RewrittenIndex {
// The id of the index that will be replaced
UUID old_id = 1;
// The id of the new index
UUID new_id = 2;
// The new field ids it is based on. May be empty if unchanged.
repeated int32 new_field_ids = 3;
}

// An operation that rewrites but does not change the data in the table. These
// kinds of operations just rearrange data.
message Rewrite {
@@ -103,15 +114,6 @@ message Transaction {
// These fragment IDs are not yet assigned.
repeated DataFragment new_fragments = 2;

// During a rewrite an index may be rewritten. We only serialize the UUID
// since a rewrite should not change the other index parameters.
message RewrittenIndex {
// The id of the index that will be replaced
UUID old_id = 1;
// the id of the new index
UUID new_id = 2;
}

// A group of rewrite files that are all part of the same rewrite.
message RewriteGroup {
// The old fragment that is being replaced
@@ -139,6 +141,8 @@ message Transaction {
repeated DataFragment fragments = 1;
// The new schema
repeated lance.file.Field schema = 2;
// Migrated index ids
repeated RewrittenIndex rewritten_indices = 4;
// Schema metadata.
map<string, bytes> schema_metadata = 3;
}
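
The Rust-side mirror of this message (see rust/lance/src/dataset/transaction.rs below) carries the same three fields. A minimal sketch, with hypothetical ids, of the entry a column migration records for each rewritten index:

use uuid::Uuid;

// Hypothetical values, for illustration only.
let migration = RewrittenIndex {
old_id: Uuid::new_v4(), // the index being replaced
new_id: Uuid::new_v4(), // the freshly written replacement
new_field_ids: vec![5], // field id of the casted column; empty if unchanged
};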
6 changes: 5 additions & 1 deletion python/src/dataset.rs
@@ -241,7 +241,11 @@ impl Operation {
fn merge(fragments: Vec<FragmentMetadata>, schema: PyArrowType<ArrowSchema>) -> PyResult<Self> {
let schema = convert_schema(&schema.0)?;
let fragments = into_fragments(fragments);
let op = LanceOperation::Merge { fragments, schema };
let op = LanceOperation::Merge {
fragments,
schema,
rewritten_indices: vec![],
};
Ok(Self(op))
}

3 changes: 3 additions & 0 deletions rust/lance-index/src/lib.rs
@@ -52,6 +52,9 @@ pub trait Index: Send + Sync {
/// Cast to [Any].
fn as_any(&self) -> &dyn Any;

/// Cast to mut [Any]
fn as_mut_any(&mut self) -> &mut dyn Any;

/// Cast to [Index]
fn as_index(self: Arc<Self>) -> Arc<dyn Index>;

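For context: migrating an index in place requires mutable access to the concrete type behind a `dyn Index` trait object, which is what `as_mut_any` provides. A minimal sketch of the downcast pattern (the helper function is hypothetical and not part of this diff; FlatIndex is one of the implementors updated below):

use std::any::Any;

// Hypothetical helper illustrating the intended use of `as_mut_any`.
fn mutate_index(index: &mut dyn Index) {
if let Some(flat) = index.as_mut_any().downcast_mut::<FlatIndex>() {
// mutate the concrete index in place here
let _ = flat;
}
}
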
8 changes: 5 additions & 3 deletions rust/lance-index/src/scalar/btree.rs
@@ -35,9 +35,7 @@ use datafusion_physical_expr::{
PhysicalSortExpr,
};
use futures::{
future::BoxFuture,
stream::{self},
FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
future::BoxFuture, stream, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
};
use lance_core::{Error, Result};
use lance_datafusion::{
@@ -786,6 +784,10 @@ impl Index for BTreeIndex {
self
}

fn as_mut_any(&mut self) -> &mut dyn Any {
self
}

fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
self
}
4 changes: 4 additions & 0 deletions rust/lance-index/src/scalar/flat.rs
@@ -139,6 +139,10 @@ impl Index for FlatIndex {
self
}

fn as_mut_any(&mut self) -> &mut dyn Any {
self
}

fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
self
}
6 changes: 6 additions & 0 deletions rust/lance-index/src/vector/pq.rs
@@ -67,6 +67,8 @@ pub trait ProductQuantizer: Send + Sync + std::fmt::Debug {

fn dimension(&self) -> usize;

fn metric_type(&self) -> MetricType;
[Review comment from a Contributor]
would it be better to use DistanceType here?
cc @westonpace


// TODO: move to pub(crate) once the refactor of lance::index to lance-index is done.
fn codebook_as_fsl(&self) -> FixedSizeListArray;

@@ -436,6 +438,10 @@ impl<T: ArrowFloatType + Dot + L2 + 'static> ProductQuantizer for ProductQuantiz
self.dimension
}

fn metric_type(&self) -> MetricType {
self.metric_type
}

fn codebook_as_fsl(&self) -> FixedSizeListArray {
FixedSizeListArray::try_new_from_values(
self.codebook.as_ref().clone(),
53 changes: 48 additions & 5 deletions rust/lance/src/dataset.rs
@@ -35,6 +35,7 @@ use lance_arrow::SchemaExt;
use lance_core::datatypes::{Field, SchemaCompareOptions};
use lance_datafusion::utils::reader_to_stream;
use lance_file::datatypes::populate_schema_dictionary;
use lance_index::DatasetIndexExt;
use lance_io::object_store::{ObjectStore, ObjectStoreParams};
use lance_io::object_writer::ObjectWriter;
use lance_io::traits::WriteExt;
@@ -70,10 +71,11 @@ use self::cleanup::RemovalStats;
use self::feature_flags::{apply_feature_flags, can_read_dataset, can_write_dataset};
use self::fragment::FileFragment;
use self::scanner::{DatasetRecordBatchStream, Scanner};
use self::transaction::{Operation, Transaction};
use self::transaction::{Operation, RewrittenIndex, Transaction};
use self::write::write_fragments_internal;
use crate::datatypes::Schema;
use crate::error::box_error;
use crate::index::cast_index;
use crate::io::commit::{commit_new_dataset, commit_transaction};
use crate::io::exec::Planner;
use crate::session::Session;
@@ -866,6 +868,7 @@ impl Dataset {
Operation::Merge {
fragments: updated_fragments,
schema: new_schema,
rewritten_indices: vec![],
},
None,
);
@@ -1643,7 +1646,11 @@ impl Dataset {
let fragments = self
.add_columns_impl(read_columns, mapper, result_checkpoint, None)
.await?;
let operation = Operation::Merge { fragments, schema };
let operation = Operation::Merge {
fragments,
schema,
rewritten_indices: vec![],
};
let transaction = Transaction::new(self.manifest.version, operation, None);
let new_manifest = commit_transaction(
self,
@@ -1825,6 +1832,8 @@ impl Dataset {
// This schema contains the exact field ids we want to write the new fields with.
let new_col_schema = new_schema.project_by_ids(&new_ids);

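// The mapper closure below takes ownership of `cast_fields`, so keep a
// copy for the index-migration pass that runs after the rewrite.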
let cast_fields_copy = cast_fields.clone();

let mapper = move |batch: &RecordBatch| {
let mut fields = Vec::with_capacity(cast_fields.len());
let mut columns = Vec::with_capacity(batch.num_columns());
Expand Down Expand Up @@ -1856,11 +1865,35 @@ impl Dataset {
)
.await?;

// Also need to cast the indices, if possible.
let vector_indices = self.load_indices().await?;
let mut rewritten_indices = Vec::new();
for (from_field, to_field) in cast_fields_copy.iter() {
let affected_indices = vector_indices
.iter()
.filter(|index| index.fields.len() == 1 && index.fields[0] == from_field.id)
.collect::<Vec<_>>();
for index in affected_indices {
let res = cast_index(self, &index.uuid, from_field, to_field).await;
match res {
Ok(new_id) => rewritten_indices.push(RewrittenIndex {
old_id: index.uuid,
new_id,
new_field_ids: vec![to_field.id],
}),
// If it's not yet supported, we just skip it.
Err(Error::NotSupported { .. }) => continue,
Err(e) => return Err(e),
}
}
}

Transaction::new(
self.manifest.version,
Operation::Merge {
schema: new_schema,
fragments,
rewritten_indices,
},
None,
)
@@ -4672,7 +4705,7 @@ mod tests {
assert_eq!(f.files.len(), 3);
});

// Cast vector column, should not keep index (TODO: keep it)
// Cast vector column, should keep the index
dataset
.alter_columns(&[
ColumnAlteration::new("vec".into()).cast_to(DataType::FixedSizeList(
@@ -4702,9 +4735,9 @@
]);
assert_eq!(&ArrowSchema::from(dataset.schema()), &expected_schema);

// We currently lose the index when casting a column
// We keep vector indices when casting a column
let indices = dataset.load_indices().await?;
assert_eq!(indices.len(), 0);
assert_eq!(indices.len(), 1);

// Each fragment gains a file with the new columns
dataset.fragments().iter().for_each(|f| {
@@ -4734,6 +4767,16 @@
let actual_data = dataset.scan().try_into_batch().await?;
assert_eq!(actual_data, expected_data);

let query_data = dataset
.scan()
.nearest("vec", &vec![0.0f32; 128].into(), 10)
.unwrap()
.nprobs(10)
.try_into_batch()
.await
.unwrap();
assert_eq!(query_data.num_rows(), 10);

Ok(())
}
}
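
Taken together, the flow the updated test exercises is: alter the indexed column's type, then run an ANN query against the migrated index. A condensed sketch mirroring the test (the Float16 target element type is an assumption, since the exact cast target is elided in this view):

// Cast the indexed vector column (assumed float32 -> float16, dim 128).
dataset
.alter_columns(&[ColumnAlteration::new("vec".into()).cast_to(
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::Float16, true)),
128,
),
)])
.await?;

// The IVF_PQ index was migrated along with the cast, so ANN search still works.
let hits = dataset
.scan()
.nearest("vec", &vec![0.0f32; 128].into(), 10)?
.nprobs(10)
.try_into_batch()
.await?;
assert_eq!(hits.num_rows(), 10);

If a particular index cannot be migrated, the loop in alter_columns treats Error::NotSupported as a skip, so the schema change itself still succeeds.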
1 change: 1 addition & 0 deletions rust/lance/src/dataset/fragment.rs
@@ -1559,6 +1559,7 @@ mod tests {
Operation::Merge {
schema,
fragments: vec![frag],
rewritten_indices: vec![],
},
Some(dataset.manifest.version),
None,
1 change: 1 addition & 0 deletions rust/lance/src/dataset/optimize.rs
@@ -893,6 +893,7 @@ pub async fn commit_compaction(
.map(|rewritten| RewrittenIndex {
old_id: rewritten.original,
new_id: rewritten.new,
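// Compaction does not change which fields an index covers.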
new_field_ids: vec![],
})
.collect();

42 changes: 37 additions & 5 deletions rust/lance/src/dataset/transaction.rs
@@ -127,6 +127,7 @@ pub enum Operation {
Merge {
fragments: Vec<Fragment>,
schema: Schema,
rewritten_indices: Vec<RewrittenIndex>,
},
/// Restore an old version of the database
Restore { version: u64 },
@@ -154,6 +155,7 @@ pub enum Operation {
pub struct RewrittenIndex {
pub old_id: Uuid,
pub new_id: Uuid,
pub new_field_ids: Vec<i32>,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -487,9 +489,22 @@ impl Transaction {
Operation::ReserveFragments { .. } => {
final_fragments.extend(maybe_existing_fragments?.clone());
}
Operation::Merge { ref fragments, .. } => {
Operation::Merge {
ref fragments,
rewritten_indices,
..
} => {
final_fragments.extend(fragments.clone());

for rewritten in rewritten_indices {
let index = final_indices
.iter_mut()
.find(|idx| idx.uuid == rewritten.old_id)
.unwrap();
index.uuid = rewritten.new_id;
// An empty list means the indexed fields are unchanged (see the proto
// comment), so only overwrite when new field ids were provided.
if !rewritten.new_field_ids.is_empty() {
index.fields = rewritten.new_field_ids.clone();
}
}

// Some fields that have indices may have been removed, so we should
// remove those indices as well.
Self::retain_relevant_indices(&mut final_indices, &schema)
@@ -736,11 +751,16 @@ impl TryFrom<&pb::Transaction> for Transaction {
},
Some(pb::transaction::Operation::Merge(pb::transaction::Merge {
fragments,
rewritten_indices,
schema,
schema_metadata: _schema_metadata, // TODO: handle metadata
})) => Operation::Merge {
fragments: fragments.iter().map(Fragment::from).collect(),
schema: Schema::from(&Fields(schema.clone())),
rewritten_indices: rewritten_indices
.iter()
.map(RewrittenIndex::try_from)
.collect::<Result<_>>()?,
},
Some(pb::transaction::Operation::Restore(pb::transaction::Restore { version })) => {
Operation::Restore { version: *version }
@@ -779,10 +799,10 @@
}
}

impl TryFrom<&pb::transaction::rewrite::RewrittenIndex> for RewrittenIndex {
impl TryFrom<&pb::transaction::RewrittenIndex> for RewrittenIndex {
type Error = Error;

fn try_from(message: &pb::transaction::rewrite::RewrittenIndex) -> Result<Self> {
fn try_from(message: &pb::transaction::RewrittenIndex) -> Result<Self> {
Ok(Self {
old_id: message
.old_id
@@ -800,6 +820,7 @@ impl TryFrom<&pb::transaction::rewrite::RewrittenIndex> for RewrittenIndex {
message: "required field (new_id) missing from message".to_string(),
location: location!(),
})??,
new_field_ids: message.new_field_ids.clone(),
})
}
}
@@ -868,9 +889,17 @@ impl From<&Transaction> for pb::Transaction {
new_indices: new_indices.iter().map(IndexMetadata::from).collect(),
removed_indices: removed_indices.iter().map(IndexMetadata::from).collect(),
}),
Operation::Merge { fragments, schema } => {
Operation::Merge {
fragments,
schema,
rewritten_indices,
} => {
pb::transaction::Operation::Merge(pb::transaction::Merge {
fragments: fragments.iter().map(pb::DataFragment::from).collect(),
rewritten_indices: rewritten_indices
.iter()
.map(|rewritten| rewritten.into())
.collect(),
schema: Fields::from(schema).0,
schema_metadata: Default::default(), // TODO: handle metadata
})
@@ -906,11 +935,12 @@
}
}

impl From<&RewrittenIndex> for pb::transaction::rewrite::RewrittenIndex {
impl From<&RewrittenIndex> for pb::transaction::RewrittenIndex {
fn from(value: &RewrittenIndex) -> Self {
Self {
old_id: Some((&value.old_id).into()),
new_id: Some((&value.new_id).into()),
new_field_ids: value.new_field_ids.clone(),
}
}
}
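
With both directions defined, the intended invariant is a lossless round trip between the Rust struct and its protobuf message. A quick sketch (values are hypothetical):

let original = RewrittenIndex {
old_id: Uuid::new_v4(),
new_id: Uuid::new_v4(),
new_field_ids: vec![5],
};
// Rust -> protobuf -> Rust should preserve every field.
let proto: pb::transaction::RewrittenIndex = (&original).into();
let decoded = RewrittenIndex::try_from(&proto)?;
assert_eq!(decoded.old_id, original.old_id);
assert_eq!(decoded.new_id, original.new_id);
assert_eq!(decoded.new_field_ids, original.new_field_ids);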
@@ -965,6 +995,7 @@ mod tests {
Operation::Merge {
fragments: vec![fragment0.clone(), fragment2.clone()],
schema: Schema::default(),
rewritten_indices: vec![],
},
Operation::Overwrite {
fragments: vec![fragment0.clone(), fragment2.clone()],
@@ -1059,6 +1090,7 @@
Operation::Merge {
fragments: vec![fragment0.clone(), fragment2.clone()],
schema: Schema::default(),
rewritten_indices: vec![],
},
// Merge conflicts with everything except CreateIndex and ReserveFragments.
[true, false, true, true, true, true, false, true],