diff --git a/diskann-disk/src/data_model/cache.rs b/diskann-disk/src/data_model/cache.rs index f51ef452c..e0b18d00d 100644 --- a/diskann-disk/src/data_model/cache.rs +++ b/diskann-disk/src/data_model/cache.rs @@ -5,17 +5,14 @@ use crate::data_model::GraphDataType; use diskann::{graph::AdjacencyList, ANNError, ANNResult}; -use diskann_providers::common::AlignedBoxWithSlice; use hashbrown::{hash_map::Entry::Occupied, HashMap}; -use super::FP_VECTOR_MEM_ALIGN; - pub struct Cache> { // Maintains the mapping of vector_id to index in the global cached nodes list. mapping: HashMap, - // Aligned buffer to store the vectors of cached nodes. - vectors: AlignedBoxWithSlice, + // Flat buffer holding `capacity * dimension` vector elements, laid out row-major. + vectors: Vec, // The cached adjacency lists. adjacency_lists: Vec>, @@ -38,7 +35,7 @@ where pub fn new(dimension: usize, capacity: usize) -> ANNResult { Ok(Self { mapping: HashMap::new(), - vectors: AlignedBoxWithSlice::new(capacity * dimension, FP_VECTOR_MEM_ALIGN)?, + vectors: vec![Data::VectorDataType::default(); capacity * dimension], adjacency_lists: Vec::with_capacity(capacity), associated_data: Vec::with_capacity(capacity), dimension, diff --git a/diskann-disk/src/data_model/mod.rs b/diskann-disk/src/data_model/mod.rs index e5da97585..b47f89832 100644 --- a/diskann-disk/src/data_model/mod.rs +++ b/diskann-disk/src/data_model/mod.rs @@ -17,5 +17,3 @@ pub use cache::{Cache, CachingStrategy}; pub mod graph_data_types; pub use graph_data_types::{AdHoc, GraphDataType}; - -pub const FP_VECTOR_MEM_ALIGN: usize = 32; diff --git a/diskann-disk/src/search/pq/pq_scratch.rs b/diskann-disk/src/search/pq/pq_scratch.rs index 24d005af3..49a6fa2f9 100644 --- a/diskann-disk/src/search/pq/pq_scratch.rs +++ b/diskann-disk/src/search/pq/pq_scratch.rs @@ -4,9 +4,7 @@ */ //! Aligned allocator -use std::mem::size_of; - -use diskann::{error::IntoANNResult, utils::VectorRepr, ANNResult}; +use diskann::{error::IntoANNResult, utils::VectorRepr, ANNError, ANNResult}; use diskann_providers::common::AlignedBoxWithSlice; @@ -25,11 +23,9 @@ pub struct PQScratch { /// This is used to store the pq coordinates of the candidate vectors. pub aligned_pq_coord_scratch: AlignedBoxWithSlice, - /// Rotated query. It is initialized as the normalized query vector. Use PQTable.PreprocessQuery to rotate it. - pub rotated_query: AlignedBoxWithSlice, - - /// Aligned query float. The query vector is normalized with "norm" and stored here. - pub aligned_query_float: AlignedBoxWithSlice, + /// Query scratch buffer stored as `f32`. `set` initializes it by copying/converting the + /// raw query values; `PQTable.PreprocessQuery` can then rotate or otherwise preprocess it. + pub rotated_query: Vec, } impl PQScratch { @@ -39,7 +35,7 @@ impl PQScratch { /// Create a new pq scratch pub fn new( graph_degree: usize, - aligned_dim: usize, + dim: usize, num_pq_chunks: usize, num_centers: usize, ) -> ANNResult { @@ -49,36 +45,41 @@ impl PQScratch { AlignedBoxWithSlice::new(num_centers * num_pq_chunks, PQScratch::ALIGNED_ALLOC_128)?; let aligned_dist_scratch = AlignedBoxWithSlice::new(graph_degree, PQScratch::ALIGNED_ALLOC_128)?; - let aligned_query_float = AlignedBoxWithSlice::new(aligned_dim, 8 * size_of::())?; - let rotated_query = AlignedBoxWithSlice::new(aligned_dim, 8 * size_of::())?; + let rotated_query = vec![0.0f32; dim]; Ok(Self { aligned_pqtable_dist_scratch, aligned_dist_scratch, aligned_pq_coord_scratch, rotated_query, - aligned_query_float, }) } - /// Set rotated_query and aligned_query_float values - pub fn set(&mut self, dim: usize, query: &[T], norm: f32) -> ANNResult<()> - where - T: VectorRepr + Copy, - { - let query = &T::as_f32(&query[..dim]).into_ann_result()?; - - for (d, item) in query.iter().enumerate() { - let query_val = *item; - if (norm - 1.0).abs() > f32::EPSILON { - self.rotated_query[d] = query_val / norm; - self.aligned_query_float[d] = query_val / norm; - } else { - self.rotated_query[d] = query_val; - self.aligned_query_float[d] = query_val; - } + /// Copy `query` into `rotated_query`, converting to `f32`. + /// + /// `dim` is the element count in the `T` representation. The decompressed + /// `f32` length returned by `T::as_f32` may differ (e.g. `MinMaxElement` + /// expands to more `f32`s than its raw element count), so the destination + /// slice is sized by that actual length. + /// + /// Returns `DimensionMismatchError` if `dim > query.len()` or the + /// decompressed vector does not fit in `rotated_query`. + pub fn set(&mut self, dim: usize, query: &[T]) -> ANNResult<()> { + if dim > query.len() { + return Err(ANNError::log_dimension_mismatch_error(format!( + "PQScratch::set: expected query of length >= {dim}, got {}", + query.len() + ))); } - + let query = T::as_f32(&query[..dim]).into_ann_result()?; + if query.len() > self.rotated_query.len() { + return Err(ANNError::log_dimension_mismatch_error(format!( + "PQScratch::set: decompressed query of length {} does not fit rotated_query buffer of length {}", + query.len(), + self.rotated_query.len() + ))); + } + self.rotated_query[..query.len()].copy_from_slice(&query); Ok(()) } } @@ -94,14 +95,14 @@ mod tests { #[case(59, 16, 37, 41)] // not multiple of 256 fn test_pq_scratch( #[case] graph_degree: usize, - #[case] aligned_dim: usize, + #[case] dim: usize, #[case] num_pq_chunks: usize, #[case] num_centers: usize, ) { let mut pq_scratch: PQScratch = - PQScratch::new(graph_degree, aligned_dim, num_pq_chunks, num_centers).unwrap(); + PQScratch::new(graph_degree, dim, num_pq_chunks, num_centers).unwrap(); - // Check alignment + // Check alignment of the remaining AlignedBoxWithSlice buffers. assert_eq!( (pq_scratch.aligned_pqtable_dist_scratch.as_ptr() as usize) % PQScratch::ALIGNED_ALLOC_128, @@ -115,17 +116,13 @@ mod tests { (pq_scratch.aligned_pq_coord_scratch.as_ptr() as usize) % PQScratch::ALIGNED_ALLOC_128, 0 ); - assert_eq!((pq_scratch.rotated_query.as_ptr() as usize) % 32, 0); - assert_eq!((pq_scratch.aligned_query_float.as_ptr() as usize) % 32, 0); // Test set() method - let query: Vec = (1..=aligned_dim).map(|i| i as u8).collect(); - let norm = 2.0f32; - pq_scratch.set::(query.len(), &query, norm).unwrap(); + let query: Vec = (1..=dim).map(|i| i as u8).collect(); + pq_scratch.set::(query.len(), &query).unwrap(); (0..query.len()).for_each(|i| { - assert_eq!(pq_scratch.rotated_query[i], query[i] as f32 / norm); - assert_eq!(pq_scratch.aligned_query_float[i], query[i] as f32 / norm); + assert_eq!(pq_scratch.rotated_query[i], query[i] as f32); }); } } diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 7bef9e6e4..076ffbcb6 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -632,11 +632,9 @@ where }, )?; - scratch.pq_scratch.set( - provider.graph_header.metadata().dims, - query, - 1.0_f32, // Normalization factor - )?; + scratch + .pq_scratch + .set(provider.graph_header.metadata().dims, query)?; let start_vertex_id = provider.graph_header.metadata().medoid as u32; let timer = Instant::now(); diff --git a/diskann-disk/src/search/provider/disk_vertex_provider.rs b/diskann-disk/src/search/provider/disk_vertex_provider.rs index b7096393e..f83759533 100644 --- a/diskann-disk/src/search/provider/disk_vertex_provider.rs +++ b/diskann-disk/src/search/provider/disk_vertex_provider.rs @@ -8,11 +8,10 @@ use std::ptr; use crate::data_model::GraphDataType; use byteorder::{ByteOrder, LittleEndian}; use diskann::{ANNError, ANNResult}; -use diskann_providers::common::AlignedBoxWithSlice; use hashbrown::HashMap; use crate::{ - data_model::{GraphHeader, FP_VECTOR_MEM_ALIGN}, + data_model::GraphHeader, search::{provider::disk_sector_graph::DiskSectorGraph, traits::VertexProvider}, utils::aligned_file_reader::traits::AlignedFileReader, }; @@ -30,8 +29,8 @@ where /// Centroid vertex id. pub centroid_vertex_id: u64, - /// Memory-aligned dimension. In-memory the vectors should be this size. - memory_aligned_dimension: usize, + /// Dimensionality of each fp vector. + dim: usize, /// the len of fp vector fp_vector_len: u64, @@ -39,8 +38,9 @@ where // sector graph sector_graph: DiskSectorGraph, - // Aligned fp vector cache - aligned_vector_buf: AlignedBoxWithSlice, + // Flat buffer holding the fp vectors for up to `max_batch_size` loaded nodes, + // laid out as `max_batch_size * dim` elements. + vector_buf: Vec, // The cached adjacency list. cached_adjacency_list: Vec>, @@ -74,10 +74,8 @@ where vertex_id: &Data::VectorIdType, ) -> ANNResult<&[::VectorDataType]> { match self.loaded_nodes.get(vertex_id) { - Some(local_offset) => Ok(&self.aligned_vector_buf[local_offset.idx - * self.memory_aligned_dimension - ..(local_offset.idx * self.memory_aligned_dimension) - + self.memory_aligned_dimension]), + Some(local_offset) => Ok(&self.vector_buf + [local_offset.idx * self.dim..(local_offset.idx * self.dim) + self.dim]), None => Err(ANNError::log_get_vertex_data_error( vertex_id.to_string(), "Vector".to_string(), @@ -126,15 +124,15 @@ where let fp_vector_buf = &self.sector_graph.node_disk_buf(idx, *vertex_id)[..self.fp_vector_len as usize]; - // memcpy from fp_vector_buf to the aligned buffer.. - // The safe condition is met here since the dimension of the vector in fp_vector_buffer is the same with aligned_vector_buffer. - // fp_vector_buf and aligned_vector_buffer.as_mut_ptr() are guaranteed to be non-overlapping. + // memcpy from fp_vector_buf to the vector buffer. + // The safe condition is met here since the dimension of the vector in fp_vector_buffer is + // the same with vector_buf. fp_vector_buf and vector_buf.as_mut_ptr() are guaranteed to be + // non-overlapping. unsafe { ptr::copy_nonoverlapping( fp_vector_buf.as_ptr(), - self.aligned_vector_buf[idx * self.memory_aligned_dimension - ..(idx * self.memory_aligned_dimension) + self.memory_aligned_dimension] - .as_mut_ptr() as *mut u8, + self.vector_buf[idx * self.dim..(idx * self.dim) + self.dim].as_mut_ptr() + as *mut u8, fp_vector_buf.len(), ); } @@ -210,17 +208,14 @@ where sector_reader: AlignedReaderType, ) -> ANNResult { let metadata = header.metadata(); - let memory_aligned_dimension = metadata.dims.next_multiple_of(8); + let dim = metadata.dims; Ok(Self { centroid_vertex_id: metadata.medoid, - memory_aligned_dimension, - fp_vector_len: (metadata.dims * std::mem::size_of::()) as u64, + dim, + fp_vector_len: (dim * std::mem::size_of::()) as u64, sector_graph: DiskSectorGraph::new(sector_reader, header, max_batch_size)?, - aligned_vector_buf: AlignedBoxWithSlice::new( - max_batch_size * memory_aligned_dimension, - FP_VECTOR_MEM_ALIGN, - )?, + vector_buf: vec![Data::VectorDataType::default(); max_batch_size * dim], cached_adjacency_list: Vec::with_capacity(max_batch_size), cached_associated_data: Vec::with_capacity(max_batch_size), loaded_nodes: HashMap::with_capacity(max_batch_size), @@ -238,10 +233,7 @@ where self.cached_adjacency_list.reserve(max_batch_size); self.cached_associated_data.reserve(max_batch_size); self.loaded_nodes.reserve(max_batch_size); - self.aligned_vector_buf = AlignedBoxWithSlice::new( - max_batch_size * self.memory_aligned_dimension, - FP_VECTOR_MEM_ALIGN, - )?; + self.vector_buf = vec![Data::VectorDataType::default(); max_batch_size * self.dim]; self.max_batch_size = max_batch_size; } Ok(()) @@ -267,10 +259,6 @@ where self.cached_adjacency_list.clear(); self.cached_associated_data.clear(); } - - pub fn memory_aligned_dimension(&self) -> usize { - self.memory_aligned_dimension - } } #[cfg(test)] diff --git a/diskann-disk/src/search/provider/disk_vertex_provider_factory.rs b/diskann-disk/src/search/provider/disk_vertex_provider_factory.rs index bc05e4b8e..116608610 100644 --- a/diskann-disk/src/search/provider/disk_vertex_provider_factory.rs +++ b/diskann-disk/src/search/provider/disk_vertex_provider_factory.rs @@ -132,7 +132,6 @@ impl, ReaderFactory: AlignedReaderFactor let graph_metadata = self.get_header()?; let graph_metadata = graph_metadata.metadata(); - let memory_aligned_dimension = graph_metadata.dims.next_multiple_of(8); if num_nodes_to_cache > graph_metadata.num_pts as usize { info!( @@ -146,7 +145,7 @@ impl, ReaderFactory: AlignedReaderFactor self.cache = Some(Arc::new(self.build_cache_via_bfs( start_node, num_nodes_to_cache, - memory_aligned_dimension, + graph_metadata.dims, )?)); } CachingStrategy::None => {}