diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index a2b62f422be..4e4cae37042 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -2068,7 +2068,13 @@ def create_index( kwargs["metric_type"] = metric index_type = index_type.upper() - valid_index_types = ["IVF_FLAT", "IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ"] + valid_index_types = [ + "IVF_FLAT", + "IVF_PQ", + "IVF_HNSW_FLAT", + "IVF_HNSW_PQ", + "IVF_HNSW_SQ", + ] if index_type not in valid_index_types: raise NotImplementedError( f"Only {valid_index_types} index types supported. Got {index_type}" diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index a1b0d5247e6..4dbbe9beadb 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -561,6 +561,18 @@ def test_create_ivf_hnsw_sq_index(dataset, tmp_path): assert ann_ds.list_indices()[0]["fields"] == ["vector"] +def test_create_ivf_hnsw_flat_index(dataset, tmp_path): + assert not dataset.has_index + ann_ds = lance.write_dataset(dataset.to_table(), tmp_path / "indexed.lance") + ann_ds = ann_ds.create_index( + "vector", + index_type="IVF_HNSW_FLAT", + num_partitions=4, + num_sub_vectors=16, + ) + assert ann_ds.list_indices()[0]["fields"] == ["vector"] + + def test_multivec_ann(indexed_multivec_dataset: lance.LanceDataset): query = np.random.rand(5, 128) results = indexed_multivec_dataset.scanner( diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 20220c53943..8fb0c56a9e8 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -1344,7 +1344,9 @@ impl Dataset { "NGRAM" => IndexType::NGram, "LABEL_LIST" => IndexType::LabelList, "INVERTED" | "FTS" => IndexType::Inverted, - "IVF_FLAT" | "IVF_PQ" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => IndexType::Vector, + "IVF_FLAT" | "IVF_PQ" | "IVF_HNSW_FLAT" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => { + IndexType::Vector + } _ => { return Err(PyValueError::new_err(format!( "Index type '{index_type}' is not supported." @@ -2153,6 +2155,12 @@ fn prepare_vector_index_params( m_type, ivf_params, pq_params, ))), + "IVF_HNSW_FLAT" => Ok(Box::new(VectorIndexParams::ivf_hnsw( + m_type, + ivf_params, + hnsw_params, + ))), + "IVF_HNSW_PQ" => Ok(Box::new(VectorIndexParams::with_ivf_hnsw_pq_params( m_type, ivf_params, diff --git a/rust/lance-index/src/lib.rs b/rust/lance-index/src/lib.rs index 3ed696cc193..4f5b3e75bea 100644 --- a/rust/lance-index/src/lib.rs +++ b/rust/lance-index/src/lib.rs @@ -99,6 +99,7 @@ pub enum IndexType { IvfPq = 103, IvfHnswSq = 104, IvfHnswPq = 105, + IvfHnswFlat = 106, } impl std::fmt::Display for IndexType { @@ -115,6 +116,7 @@ impl std::fmt::Display for IndexType { Self::IvfSq => write!(f, "IVF_SQ"), Self::IvfHnswSq => write!(f, "IVF_HNSW_SQ"), Self::IvfHnswPq => write!(f, "IVF_HNSW_PQ"), + Self::IvfHnswFlat => write!(f, "IVF_HNSW_FLAT"), } } } @@ -136,6 +138,7 @@ impl TryFrom for IndexType { v if v == Self::IvfPq as i32 => Ok(Self::IvfPq), v if v == Self::IvfHnswSq as i32 => Ok(Self::IvfHnswSq), v if v == Self::IvfHnswPq as i32 => Ok(Self::IvfHnswPq), + v if v == Self::IvfHnswFlat as i32 => Ok(Self::IvfHnswFlat), _ => Err(Error::InvalidInput { source: format!("the input value {} is not a valid IndexType", value).into(), location: location!(), @@ -164,6 +167,7 @@ impl IndexType { | Self::IvfPq | Self::IvfHnswSq | Self::IvfHnswPq + | Self::IvfHnswFlat | Self::IvfFlat | Self::IvfSq ) @@ -191,7 +195,8 @@ impl IndexType { | Self::IvfSq | Self::IvfPq | Self::IvfHnswSq - | Self::IvfHnswPq => 1, + | Self::IvfHnswPq + | Self::IvfHnswFlat => 1, } } } diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index bcc81f091d0..39eae0cd9f0 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -1091,6 +1091,18 @@ impl DatasetIndexInternalExt for Dataset { Ok(Arc::new(ivf) as Arc) } + "IVF_HNSW_FLAT" => { + let ivf = IVFIndex::::try_new( + self.object_store.clone(), + self.indices_dir(), + uuid.to_owned(), + Arc::downgrade(&self.session), + fri, + ) + .await?; + Ok(Arc::new(ivf) as Arc) + } + "IVF_HNSW_SQ" => { let ivf = IVFIndex::::try_new( self.object_store.clone(), diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 5b5f79d50c1..8748d3fe5d3 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -145,6 +145,19 @@ impl VectorIndexParams { } } + pub fn ivf_hnsw( + distance_type: DistanceType, + ivf: IvfBuildParams, + hnsw: HnswBuildParams, + ) -> Self { + let stages = vec![StageParams::Ivf(ivf), StageParams::Hnsw(hnsw)]; + Self { + stages, + metric_type: distance_type, + version: IndexFileVersion::V3, + } + } + /// Create index parameters with `IVF`, `PQ` and `HNSW` parameters, respectively. /// This is used for `IVF_HNSW_PQ` index. pub fn with_ivf_hnsw_pq_params( @@ -392,6 +405,21 @@ pub(crate) async fn build_vector_index( }); } } + } else { + // without quantization + IvfIndexBuilder::::new( + dataset.clone(), + column.to_owned(), + dataset.indices_dir().child(uuid), + params.metric_type, + Box::new(shuffler), + Some(ivf_params.clone()), + Some(()), + hnsw_params.clone(), + fri, + )? + .build() + .await?; } } else { return Err(Error::Index { diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 2654eb059f9..7c11795186d 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -447,6 +447,29 @@ pub(crate) async fn optimize_vector_indices_v2( .build() .await?; } + // IVF_HNSW_FLAT + (SubIndexType::Hnsw, QuantizationType::Flat) => { + IvfIndexBuilder::::new( + dataset.clone(), + vector_column.to_owned(), + index_dir, + distance_type, + shuffler, + None, + None, + // TODO: get the HNSW parameters from the existing indices + HnswBuildParams::default(), + fri, + )? + .with_ivf(ivf_model.clone()) + .with_quantizer(quantizer.try_into()?) + .with_existing_indices(indices_to_merge) + .retrain(options.retrain) + .shuffle_data(unindexed) + .await? + .build() + .await?; + } // IVF_HNSW_SQ (SubIndexType::Hnsw, QuantizationType::Scalar) => { IvfIndexBuilder::::new( diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index d89b6cdcbd7..9140d5056f3 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -348,7 +348,7 @@ impl Index for IVFIndex IndexType::IvfSq, (SubIndexType::Hnsw, QuantizationType::Product) => IndexType::IvfHnswPq, (SubIndexType::Hnsw, QuantizationType::Scalar) => IndexType::IvfHnswSq, - _ => IndexType::Vector, + (SubIndexType::Hnsw, QuantizationType::Flat) => IndexType::IvfHnswFlat, } } @@ -1244,6 +1244,26 @@ mod tests { test_optimize_strategy(params).await; } + #[rstest] + #[case(4, DistanceType::L2, 0.9)] + #[case(4, DistanceType::Cosine, 0.9)] + #[case(4, DistanceType::Dot, 0.85)] + #[tokio::test] + async fn test_create_ivf_hnsw_flat( + #[case] nlist: usize, + #[case] distance_type: DistanceType, + #[case] recall_requirement: f32, + ) { + let ivf_params = IvfBuildParams::new(nlist); + let hnsw_params = HnswBuildParams::default(); + let params = VectorIndexParams::ivf_hnsw(distance_type, ivf_params, hnsw_params); + test_index(params.clone(), nlist, recall_requirement, None).await; + if distance_type == DistanceType::Cosine { + test_index_multivec(params.clone(), nlist, recall_requirement).await; + } + test_optimize_strategy(params).await; + } + #[rstest] #[case(4, DistanceType::L2, 0.9)] #[case(4, DistanceType::Cosine, 0.9)]