diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 4cbd39fdebb..a3aea17c6b2 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -3304,6 +3304,7 @@ def _create_index_impl( *, target_partition_size: Optional[int] = None, skip_transpose: bool = False, + rq_rotation: Optional[str] = None, require_commit: bool = True, **kwargs, ) -> Index: @@ -3619,6 +3620,9 @@ def _create_index_impl( if skip_transpose: kwargs["skip_transpose"] = True + if rq_rotation is not None: + kwargs["rq_rotation"] = rq_rotation + # Add fragment_ids and index_uuid to kwargs if provided for # distributed indexing if fragment_ids is not None: @@ -3919,6 +3923,7 @@ def create_index_uncommitted( *, target_partition_size: Optional[int] = None, skip_transpose: bool = False, + rq_rotation: Optional[str] = None, **kwargs, ) -> Index: """ @@ -3945,6 +3950,12 @@ def create_index_uncommitted( requirement: - ``fragment_ids`` must be provided + - ``rq_rotation`` (``IVF_RQ`` only): a JSON string produced by + ``lance.lance.indices.build_rq_rotation``. It must be identical across all + workers for their segments to be mergeable, since it pins the RaBitQ + rotation so every segment rotates vectors the same way. If omitted, each + call generates its own random rotation, which is only safe for a single, + non-merged segment. Returns ------- @@ -3974,6 +3985,7 @@ def create_index_uncommitted( index_uuid=index_uuid, target_partition_size=target_partition_size, skip_transpose=skip_transpose, + rq_rotation=rq_rotation, require_commit=False, **kwargs, ) diff --git a/python/python/lance/lance/indices/__init__.pyi b/python/python/lance/lance/indices/__init__.pyi index fc5d03b80bd..8afb1761f86 100644 --- a/python/python/lance/lance/indices/__init__.pyi +++ b/python/python/lance/lance/indices/__init__.pyi @@ -67,6 +67,12 @@ def transform_vectors( pq_codebook: pa.Array, dst_uri: str, ): ... +def build_rq_rotation( + dimension: int, + num_bits: int = 1, + rotation_type: str = "fast", + dtype: str = "float32", +) -> str: ... class IndexSegmentDescription: uuid: str diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 4cf1c4947e3..49ec6c28501 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -3023,6 +3023,51 @@ def test_commit_existing_index_segments_accepts_index_metadata(tmp_path): assert 0 < len(results) <= 5 +def test_distributed_ivf_rq_shared_rotation(tmp_path): + """Two IVF_RQ segments built on separate fragments with one shared RaBitQ rotation + merge into a single committed, queryable index. The shared ``rq_rotation`` (from + ``lance.lance.indices.build_rq_rotation``) is what makes the independently built + segments mergeable.""" + from lance.lance import indices + + dim = 32 + ds = _make_sample_dataset_base( + tmp_path, "dist_rq_merge", n_rows=512, dim=dim, max_rows_per_file=256 + ) + frags = ds.get_fragments() + assert len(frags) == 2 + + ivf_model = IndicesBuilder(ds, "vector").train_ivf( + num_partitions=2, + distance_type="l2", + sample_rate=8, + ) + rq_rotation = indices.build_rq_rotation(dimension=dim, num_bits=1) + base_kwargs = { + "column": "vector", + "index_type": "IVF_RQ", + "num_partitions": 2, + "num_bits": 1, + "ivf_centroids": ivf_model.centroids, + "rq_rotation": rq_rotation, + } + first = ds.create_index_uncommitted( + **base_kwargs, + fragment_ids=[frags[0].fragment_id], + ) + second = ds.create_index_uncommitted( + **base_kwargs, + fragment_ids=[frags[1].fragment_id], + ) + + merged = ds.merge_existing_index_segments([first, second]) + ds = ds.commit_existing_index_segments("vector_idx", "vector", [merged]) + + q = np.random.rand(dim).astype(np.float32) + results = ds.to_table(nearest={"column": "vector", "q": q, "k": 5}) + assert 0 < len(results) <= 5 + + def test_index_segment_builder_builds_vector_segments(tmp_path): ds = _make_sample_dataset_base(tmp_path, "segment_builder_ds", 2000, 128) frags = ds.get_fragments() diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 0f834484f30..3143854d50a 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -20,6 +20,7 @@ use blob::LanceBlobFile; use chrono::{Duration, TimeDelta, Utc}; use futures::{StreamExt, TryFutureExt}; use lance_index::vector::bq::RQBuildParams; +use lance_index::vector::bq::storage::RabitQuantizationMetadata; use log::error; use object_store::path::Path; use pyo3::exceptions::{PyStopIteration, PyTypeError}; @@ -4429,6 +4430,13 @@ fn prepare_vector_index_params( pq_params.codebook = Some(codebook.values().clone()) }; + if let Some(r) = kwargs.get_item("rq_rotation")? { + let json: String = r.extract()?; + let meta: RabitQuantizationMetadata = serde_json::from_str(&json) + .map_err(|e| PyValueError::new_err(format!("Invalid rq_rotation JSON: {e}")))?; + rq_params.rotation = Some(meta); + }; + if let Some(version) = kwargs.get_item("index_file_version")? { let version: String = version.extract()?; index_file_version = IndexFileVersion::try_from(&version) diff --git a/python/src/indices.rs b/python/src/indices.rs index cf93579b867..7a7667fbc18 100644 --- a/python/src/indices.rs +++ b/python/src/indices.rs @@ -358,6 +358,79 @@ fn train_pq_model<'py>( codebook.to_pyarrow(py) } +/// Mint one RaBitQ rotation and return it as a JSON string. +/// +/// Distributed IVF_RQ builds must pin a single rotation across all workers so that +/// independently built per-fragment segments rotate vectors identically and their +/// binary codes remain comparable when merged. A driver calls this once and broadcasts +/// the resulting string to every `create_index_uncommitted(..., rq_rotation=...)` call. +/// +/// Only the "fast" rotation is supported since its sign vector is JSON-serializable, whereas +/// the "matrix" rotation stores a dense matrix in a binary buffer that is dropped by the +/// JSON wire format. `dtype` is accepted for API symmetry but does not affect the fast +/// rotation. +/// +/// # Example (Python) +/// +/// ```python +/// from lance.lance import indices +/// +/// # Mint one rotation and broadcast `rot` to every worker. +/// rot = indices.build_rq_rotation(dimension=128, num_bits=1) +/// seg = ds.create_index_uncommitted( +/// column="vector", +/// index_type="IVF_RQ", +/// num_partitions=256, +/// ivf_centroids=centroids, +/// rq_rotation=rot, +/// fragment_ids=my_fragments, +/// ) +/// ``` +#[pyfunction] +#[pyo3(signature = (dimension, num_bits=1, rotation_type="fast", dtype="float32"))] +pub fn build_rq_rotation( + dimension: usize, + num_bits: u8, + rotation_type: &str, + dtype: &str, +) -> PyResult { + use arrow::datatypes::{Float16Type, Float32Type, Float64Type}; + use lance_index::vector::bq::builder::RabitQuantizer; + use lance_index::vector::bq::RQRotationType; + use lance_index::vector::quantizer::Quantization; + + if !dimension.is_multiple_of(u8::BITS as usize) { + return Err(PyValueError::new_err( + "dimension must be divisible by 8 for IVF_RQ", + )); + } + let rotation = match rotation_type.to_lowercase().as_str() { + "fast" => RQRotationType::Fast, + "matrix" => { + return Err(PyValueError::new_err( + "matrix rotation cannot be serialized to JSON for distributed builds; \ + use rotation_type='fast'", + )); + } + other => { + return Err(PyValueError::new_err(format!( + "unknown rotation_type: {other}; expected 'fast'" + ))); + } + }; + let dim = dimension as i32; + let quantizer = match dtype.to_lowercase().as_str() { + "float16" => RabitQuantizer::new_with_rotation::(num_bits, dim, rotation), + "float32" => RabitQuantizer::new_with_rotation::(num_bits, dim, rotation), + "float64" => RabitQuantizer::new_with_rotation::(num_bits, dim, rotation), + other => { + return Err(PyValueError::new_err(format!("unsupported dtype: {other}"))); + } + }; + serde_json::to_string(&quantizer.metadata(None)) + .map_err(|e| PyValueError::new_err(format!("failed to serialize RQ rotation: {e}"))) +} + #[allow(clippy::too_many_arguments)] async fn do_transform_vectors( dataset: &Dataset, @@ -752,6 +825,7 @@ pub fn register_indices(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { let indices = PyModule::new(py, "indices")?; indices.add_wrapped(wrap_pyfunction!(train_ivf_model))?; indices.add_wrapped(wrap_pyfunction!(train_pq_model))?; + indices.add_wrapped(wrap_pyfunction!(build_rq_rotation))?; indices.add_wrapped(wrap_pyfunction!(transform_vectors))?; indices.add_wrapped(wrap_pyfunction!(shuffle_transformed_vectors))?; indices.add_wrapped(wrap_pyfunction!(load_shuffled_vectors))?; diff --git a/rust/lance-index/src/vector/bq.rs b/rust/lance-index/src/vector/bq.rs index a0a16b22169..62de70f2bf3 100644 --- a/rust/lance-index/src/vector/bq.rs +++ b/rust/lance-index/src/vector/bq.rs @@ -14,6 +14,7 @@ use lance_core::{Error, Result}; use num_traits::Float; use serde::{Deserialize, Serialize}; +use crate::vector::bq::storage::RabitQuantizationMetadata; use crate::vector::quantizer::QuantizerBuildParams; pub mod builder; @@ -100,10 +101,16 @@ impl FromStr for RQRotationType { } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] pub struct RQBuildParams { pub num_bits: u8, pub rotation_type: RQRotationType, + /// Optional pre-built rotation to reuse instead of generating a fresh random one. + /// + /// Distributed `IVF_RQ` builds mint one rotation and broadcast it so every segment + /// rotates vectors identically. This is transient build-time state and is never + /// persisted to the `RabitQuantization` params proto. + pub rotation: Option, } impl RQBuildParams { @@ -111,6 +118,7 @@ impl RQBuildParams { Self { num_bits, rotation_type: RQRotationType::default(), + rotation: None, } } @@ -118,6 +126,7 @@ impl RQBuildParams { Self { num_bits, rotation_type, + rotation: None, } } } @@ -146,6 +155,7 @@ impl Default for RQBuildParams { Self { num_bits: 1, rotation_type: RQRotationType::default(), + rotation: None, } } } diff --git a/rust/lance-index/src/vector/bq/builder.rs b/rust/lance-index/src/vector/bq/builder.rs index 491e14d3af9..c61c35374e1 100644 --- a/rust/lance-index/src/vector/bq/builder.rs +++ b/rust/lance-index/src/vector/bq/builder.rs @@ -22,7 +22,7 @@ use crate::vector::bq::storage::{ use crate::vector::bq::transform::{ADD_FACTORS_FIELD, SCALE_FACTORS_FIELD}; use crate::vector::bq::{ RQBuildParams, RQRotationType, - rotation::{apply_fast_rotation, random_fast_rotation_signs}, + rotation::{apply_fast_rotation, fast_rotation_signs_len, random_fast_rotation_signs}, }; use crate::vector::quantizer::{Quantization, Quantizer, QuantizerBuildParams}; @@ -324,6 +324,46 @@ impl Quantization for RabitQuantizer { )); } + // Reuse a supplied rotation instead of generating a fresh random one. + if let Some(meta) = ¶ms.rotation { + let expected_code_dim = dim * params.num_bits as usize; + if meta.num_bits != params.num_bits || meta.code_dim as usize != expected_code_dim { + return Err(Error::invalid_input(format!( + "supplied RaBitQ rotation does not match build params: rotation \ + num_bits={}, code_dim={}; expected num_bits={}, code_dim={}", + meta.num_bits, meta.code_dim, params.num_bits, expected_code_dim + ))); + } + + match meta.rotation_type { + RQRotationType::Fast => { + let signs = meta.fast_rotation_signs.as_ref().ok_or_else(|| { + Error::invalid_input("supplied fast RaBitQ rotation is missing signs") + })?; + let expected_len = fast_rotation_signs_len(meta.code_dim as usize); + if signs.len() != expected_len { + return Err(Error::invalid_input(format!( + "supplied fast RaBitQ rotation signs length {} does not match \ + expected {} for code_dim={}", + signs.len(), + expected_len, + meta.code_dim + ))); + } + } + RQRotationType::Matrix => { + if meta.rotate_mat.is_none() { + return Err(Error::invalid_input( + "use the fast rotation for distributed builds", + )); + } + } + } + return Ok(Self { + metadata: meta.clone(), + }); + } + let q = match data.as_fixed_size_list().value_type() { DataType::Float16 => Self::new_with_rotation::( params.num_bits, @@ -582,4 +622,122 @@ mod tests { err ); } + + fn sample_fsl(n: usize, dim: usize) -> FixedSizeListArray { + let values: Vec = (0..n * dim).map(|i| ((i * 31 % 17) as f32) - 8.0).collect(); + FixedSizeListArray::try_new_from_values(Float32Array::from(values), dim as i32).unwrap() + } + + fn quantized_codes(q: &RabitQuantizer, data: &FixedSizeListArray) -> Vec { + use arrow::datatypes::UInt8Type; + q.quantize(data) + .unwrap() + .as_fixed_size_list() + .values() + .as_primitive::() + .values() + .to_vec() + } + + #[test] + fn test_shared_fast_rotation_gives_identical_codes() { + let dim = 32; + let seed = RabitQuantizer::new_with_rotation::(1, dim, RQRotationType::Fast); + let json = serde_json::to_string(&seed.metadata(None)).unwrap(); + let meta: RabitQuantizationMetadata = serde_json::from_str(&json).unwrap(); + + let params = RQBuildParams { + num_bits: 1, + rotation_type: RQRotationType::Fast, + rotation: Some(meta), + }; + let data = sample_fsl(8, dim as usize); + let q_a = RabitQuantizer::build(&data, DistanceType::L2, ¶ms).unwrap(); + let q_b = RabitQuantizer::build(&data, DistanceType::L2, ¶ms).unwrap(); + + assert_eq!( + quantized_codes(&q_a, &data), + quantized_codes(&q_b, &data), + "shared rotation must yield identical codes" + ); + } + + #[test] + fn test_unpinned_rotation_gives_different_codes() { + let dim = 32; + let params = RQBuildParams::new(1); + let data = sample_fsl(8, dim as usize); + let q_a = RabitQuantizer::build(&data, DistanceType::L2, ¶ms).unwrap(); + let q_b = RabitQuantizer::build(&data, DistanceType::L2, ¶ms).unwrap(); + + assert_ne!( + quantized_codes(&q_a, &data), + quantized_codes(&q_b, &data), + "independent unpinned rotations must yield different codes" + ); + } + + #[test] + fn test_build_rejects_rotation_with_mismatched_code_dim() { + let seed = RabitQuantizer::new_with_rotation::(1, 16, RQRotationType::Fast); + let params = RQBuildParams { + num_bits: 1, + rotation_type: RQRotationType::Fast, + rotation: Some(seed.metadata(None)), + }; + let data = sample_fsl(4, 32); + let err = RabitQuantizer::build(&data, DistanceType::L2, ¶ms).unwrap_err(); + assert!( + err.to_string().contains("does not match build params"), + "{}", + err + ); + } + + #[test] + fn test_build_rejects_fast_rotation_with_bad_signs_length() { + let dim = 16; + let seed = RabitQuantizer::new_with_rotation::(1, dim, RQRotationType::Fast); + let mut meta = seed.metadata(None); + // Corrupt the signs to the wrong length (valid would be 4 * ceil(16/8) = 8). + meta.fast_rotation_signs = Some(vec![0u8; 7]); + let params = RQBuildParams { + num_bits: 1, + rotation_type: RQRotationType::Fast, + rotation: Some(meta), + }; + let data = sample_fsl(4, dim as usize); + let err = RabitQuantizer::build(&data, DistanceType::L2, ¶ms).unwrap_err(); + assert!(err.to_string().contains("signs length"), "{}", err); + } + + #[test] + fn test_matrix_rotation_lost_through_json_is_rejected() { + let dim = 16; + let seed = RabitQuantizer::new_with_rotation::(1, dim, RQRotationType::Matrix); + let meta = seed.metadata(None); + assert!(meta.rotate_mat.is_some()); + + let json = serde_json::to_string(&meta).unwrap(); + let parsed: RabitQuantizationMetadata = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.rotation_type, RQRotationType::Matrix); + assert!( + parsed.rotate_mat.is_none(), + "matrix is expected to be dropped by JSON serialization" + ); + + let params = RQBuildParams { + num_bits: 1, + rotation_type: RQRotationType::Matrix, + rotation: Some(parsed), + }; + let data = sample_fsl(4, dim as usize); + let err = RabitQuantizer::build(&data, DistanceType::L2, ¶ms).unwrap_err(); + assert!( + err.to_string() + .contains("fast rotation for distributed builds"), + "{}", + err + ); + } } diff --git a/rust/lance-index/src/vector/bq/rotation.rs b/rust/lance-index/src/vector/bq/rotation.rs index de4fbf549f1..2346c772cb3 100644 --- a/rust/lance-index/src/vector/bq/rotation.rs +++ b/rust/lance-index/src/vector/bq/rotation.rs @@ -138,9 +138,13 @@ fn sign_bytes_per_round(dim: usize) -> usize { dim.div_ceil(8) } +pub(crate) fn fast_rotation_signs_len(dim: usize) -> usize { + FAST_ROTATION_ROUNDS * sign_bytes_per_round(dim) +} + pub fn random_fast_rotation_signs(dim: usize) -> Vec { // Each round needs one random sign bit per dimension. - let mut signs = vec![0u8; FAST_ROTATION_ROUNDS * sign_bytes_per_round(dim)]; + let mut signs = vec![0u8; fast_rotation_signs_len(dim)]; rand::rng().fill_bytes(&mut signs); signs } diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 87b32344ec6..526ba62be56 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -1877,6 +1877,7 @@ fn derive_rabit_params(rabit_quantizer: &RabitQuantizer) -> RQBuildParams { RQBuildParams { num_bits: rabit_quantizer.num_bits(), rotation_type: rabit_quantizer.rotation_type(), + rotation: None, } }