Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions diskann-disk/src/data_model/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,14 @@

use crate::data_model::GraphDataType;
use diskann::{graph::AdjacencyList, ANNError, ANNResult};
use diskann_providers::common::AlignedBoxWithSlice;
use hashbrown::{hash_map::Entry::Occupied, HashMap};

use super::FP_VECTOR_MEM_ALIGN;

pub struct Cache<Data: GraphDataType<VectorIdType = u32>> {
// Maintains the mapping of vector_id to index in the global cached nodes list.
mapping: HashMap<Data::VectorIdType, usize>,

// Aligned buffer to store the vectors of cached nodes.
vectors: AlignedBoxWithSlice<Data::VectorDataType>,
// Flat buffer holding `capacity * dimension` vector elements, laid out row-major.
vectors: Vec<Data::VectorDataType>,

// The cached adjacency lists.
adjacency_lists: Vec<AdjacencyList<Data::VectorIdType>>,
Expand All @@ -38,7 +35,7 @@ where
pub fn new(dimension: usize, capacity: usize) -> ANNResult<Self> {
Ok(Self {
mapping: HashMap::new(),
vectors: AlignedBoxWithSlice::new(capacity * dimension, FP_VECTOR_MEM_ALIGN)?,
vectors: vec![Data::VectorDataType::default(); capacity * dimension],
adjacency_lists: Vec::with_capacity(capacity),
associated_data: Vec::with_capacity(capacity),
dimension,
Expand Down
2 changes: 0 additions & 2 deletions diskann-disk/src/data_model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,3 @@ pub use cache::{Cache, CachingStrategy};

pub mod graph_data_types;
pub use graph_data_types::{AdHoc, GraphDataType};

pub const FP_VECTOR_MEM_ALIGN: usize = 32;
75 changes: 36 additions & 39 deletions diskann-disk/src/search/pq/pq_scratch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
*/
//! Aligned allocator

use std::mem::size_of;

use diskann::{error::IntoANNResult, utils::VectorRepr, ANNResult};
use diskann::{error::IntoANNResult, utils::VectorRepr, ANNError, ANNResult};

use diskann_providers::common::AlignedBoxWithSlice;

Expand All @@ -25,11 +23,9 @@ pub struct PQScratch {
/// This is used to store the pq coordinates of the candidate vectors.
pub aligned_pq_coord_scratch: AlignedBoxWithSlice<u8>,

/// Rotated query. It is initialized as the normalized query vector. Use PQTable.PreprocessQuery to rotate it.
pub rotated_query: AlignedBoxWithSlice<f32>,

/// Aligned query float. The query vector is normalized with "norm" and stored here.
pub aligned_query_float: AlignedBoxWithSlice<f32>,
/// Query scratch buffer stored as `f32`. `set` initializes it by copying/converting the
/// raw query values; `PQTable.PreprocessQuery` can then rotate or otherwise preprocess it.
pub rotated_query: Vec<f32>,
}

impl PQScratch {
Expand All @@ -39,7 +35,7 @@ impl PQScratch {
/// Create a new pq scratch
pub fn new(
graph_degree: usize,
aligned_dim: usize,
dim: usize,
num_pq_chunks: usize,
num_centers: usize,
) -> ANNResult<Self> {
Expand All @@ -49,36 +45,41 @@ impl PQScratch {
AlignedBoxWithSlice::new(num_centers * num_pq_chunks, PQScratch::ALIGNED_ALLOC_128)?;
let aligned_dist_scratch =
AlignedBoxWithSlice::new(graph_degree, PQScratch::ALIGNED_ALLOC_128)?;
let aligned_query_float = AlignedBoxWithSlice::new(aligned_dim, 8 * size_of::<f32>())?;
let rotated_query = AlignedBoxWithSlice::new(aligned_dim, 8 * size_of::<f32>())?;
let rotated_query = vec![0.0f32; dim];

Ok(Self {
aligned_pqtable_dist_scratch,
aligned_dist_scratch,
aligned_pq_coord_scratch,
rotated_query,
aligned_query_float,
})
}

/// Set rotated_query and aligned_query_float values
pub fn set<T>(&mut self, dim: usize, query: &[T], norm: f32) -> ANNResult<()>
where
T: VectorRepr + Copy,
{
let query = &T::as_f32(&query[..dim]).into_ann_result()?;

for (d, item) in query.iter().enumerate() {
let query_val = *item;
if (norm - 1.0).abs() > f32::EPSILON {
self.rotated_query[d] = query_val / norm;
self.aligned_query_float[d] = query_val / norm;
} else {
self.rotated_query[d] = query_val;
self.aligned_query_float[d] = query_val;
}
/// Copy `query` into `rotated_query`, converting to `f32`.
///
/// `dim` is the element count in the `T` representation. The decompressed
/// `f32` length returned by `T::as_f32` may differ (e.g. `MinMaxElement`
/// expands to more `f32`s than its raw element count), so the destination
/// slice is sized by that actual length.
///
/// Returns `DimensionMismatchError` if `dim > query.len()` or the
/// decompressed vector does not fit in `rotated_query`.
pub fn set<T: VectorRepr>(&mut self, dim: usize, query: &[T]) -> ANNResult<()> {
if dim > query.len() {
return Err(ANNError::log_dimension_mismatch_error(format!(
"PQScratch::set: expected query of length >= {dim}, got {}",
query.len()
)));
}

let query = T::as_f32(&query[..dim]).into_ann_result()?;
if query.len() > self.rotated_query.len() {
return Err(ANNError::log_dimension_mismatch_error(format!(
"PQScratch::set: decompressed query of length {} does not fit rotated_query buffer of length {}",
query.len(),
self.rotated_query.len()
)));
}
self.rotated_query[..query.len()].copy_from_slice(&query);
Ok(())
}
}
Expand All @@ -94,14 +95,14 @@ mod tests {
#[case(59, 16, 37, 41)] // not multiple of 256
fn test_pq_scratch(
#[case] graph_degree: usize,
#[case] aligned_dim: usize,
#[case] dim: usize,
#[case] num_pq_chunks: usize,
#[case] num_centers: usize,
) {
let mut pq_scratch: PQScratch =
PQScratch::new(graph_degree, aligned_dim, num_pq_chunks, num_centers).unwrap();
PQScratch::new(graph_degree, dim, num_pq_chunks, num_centers).unwrap();

// Check alignment
// Check alignment of the remaining AlignedBoxWithSlice buffers.
assert_eq!(
(pq_scratch.aligned_pqtable_dist_scratch.as_ptr() as usize)
% PQScratch::ALIGNED_ALLOC_128,
Expand All @@ -115,17 +116,13 @@ mod tests {
(pq_scratch.aligned_pq_coord_scratch.as_ptr() as usize) % PQScratch::ALIGNED_ALLOC_128,
0
);
assert_eq!((pq_scratch.rotated_query.as_ptr() as usize) % 32, 0);
assert_eq!((pq_scratch.aligned_query_float.as_ptr() as usize) % 32, 0);

// Test set() method
let query: Vec<u8> = (1..=aligned_dim).map(|i| i as u8).collect();
let norm = 2.0f32;
pq_scratch.set::<u8>(query.len(), &query, norm).unwrap();
let query: Vec<u8> = (1..=dim).map(|i| i as u8).collect();
pq_scratch.set::<u8>(query.len(), &query).unwrap();

(0..query.len()).for_each(|i| {
assert_eq!(pq_scratch.rotated_query[i], query[i] as f32 / norm);
assert_eq!(pq_scratch.aligned_query_float[i], query[i] as f32 / norm);
assert_eq!(pq_scratch.rotated_query[i], query[i] as f32);
});
}
}
8 changes: 3 additions & 5 deletions diskann-disk/src/search/provider/disk_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,11 +632,9 @@ where
},
)?;

scratch.pq_scratch.set(
provider.graph_header.metadata().dims,
query,
1.0_f32, // Normalization factor
)?;
scratch
.pq_scratch
.set(provider.graph_header.metadata().dims, query)?;
let start_vertex_id = provider.graph_header.metadata().medoid as u32;

let timer = Instant::now();
Expand Down
50 changes: 19 additions & 31 deletions diskann-disk/src/search/provider/disk_vertex_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ use std::ptr;
use crate::data_model::GraphDataType;
use byteorder::{ByteOrder, LittleEndian};
use diskann::{ANNError, ANNResult};
use diskann_providers::common::AlignedBoxWithSlice;
use hashbrown::HashMap;

use crate::{
data_model::{GraphHeader, FP_VECTOR_MEM_ALIGN},
data_model::GraphHeader,
search::{provider::disk_sector_graph::DiskSectorGraph, traits::VertexProvider},
utils::aligned_file_reader::traits::AlignedFileReader,
};
Expand All @@ -30,17 +29,18 @@ where
/// Centroid vertex id.
pub centroid_vertex_id: u64,

/// Memory-aligned dimension. In-memory the vectors should be this size.
memory_aligned_dimension: usize,
/// Dimensionality of each fp vector.
dim: usize,

/// the len of fp vector
fp_vector_len: u64,

// sector graph
sector_graph: DiskSectorGraph<AlignedReaderType>,

// Aligned fp vector cache
aligned_vector_buf: AlignedBoxWithSlice<Data::VectorDataType>,
// Flat buffer holding the fp vectors for up to `max_batch_size` loaded nodes,
// laid out as `max_batch_size * dim` elements.
vector_buf: Vec<Data::VectorDataType>,

// The cached adjacency list.
cached_adjacency_list: Vec<Vec<Data::VectorIdType>>,
Expand Down Expand Up @@ -74,10 +74,8 @@ where
vertex_id: &Data::VectorIdType,
) -> ANNResult<&[<Data as GraphDataType>::VectorDataType]> {
match self.loaded_nodes.get(vertex_id) {
Some(local_offset) => Ok(&self.aligned_vector_buf[local_offset.idx
* self.memory_aligned_dimension
..(local_offset.idx * self.memory_aligned_dimension)
+ self.memory_aligned_dimension]),
Some(local_offset) => Ok(&self.vector_buf
[local_offset.idx * self.dim..(local_offset.idx * self.dim) + self.dim]),
None => Err(ANNError::log_get_vertex_data_error(
vertex_id.to_string(),
"Vector".to_string(),
Expand Down Expand Up @@ -126,15 +124,15 @@ where
let fp_vector_buf =
&self.sector_graph.node_disk_buf(idx, *vertex_id)[..self.fp_vector_len as usize];

// memcpy from fp_vector_buf to the aligned buffer..
// The safe condition is met here since the dimension of the vector in fp_vector_buffer is the same with aligned_vector_buffer.
// fp_vector_buf and aligned_vector_buffer.as_mut_ptr() are guaranteed to be non-overlapping.
// memcpy from fp_vector_buf to the vector buffer.
// The safe condition is met here since the dimension of the vector in fp_vector_buffer is
// the same with vector_buf. fp_vector_buf and vector_buf.as_mut_ptr() are guaranteed to be
// non-overlapping.
unsafe {
ptr::copy_nonoverlapping(
fp_vector_buf.as_ptr(),
self.aligned_vector_buf[idx * self.memory_aligned_dimension
..(idx * self.memory_aligned_dimension) + self.memory_aligned_dimension]
.as_mut_ptr() as *mut u8,
self.vector_buf[idx * self.dim..(idx * self.dim) + self.dim].as_mut_ptr()
as *mut u8,
fp_vector_buf.len(),
);
}
Expand Down Expand Up @@ -210,17 +208,14 @@ where
sector_reader: AlignedReaderType,
) -> ANNResult<Self> {
let metadata = header.metadata();
let memory_aligned_dimension = metadata.dims.next_multiple_of(8);
let dim = metadata.dims;
Ok(Self {
centroid_vertex_id: metadata.medoid,
memory_aligned_dimension,
fp_vector_len: (metadata.dims * std::mem::size_of::<Data::VectorDataType>()) as u64,
dim,
fp_vector_len: (dim * std::mem::size_of::<Data::VectorDataType>()) as u64,
sector_graph: DiskSectorGraph::new(sector_reader, header, max_batch_size)?,

aligned_vector_buf: AlignedBoxWithSlice::new(
max_batch_size * memory_aligned_dimension,
FP_VECTOR_MEM_ALIGN,
)?,
vector_buf: vec![Data::VectorDataType::default(); max_batch_size * dim],
cached_adjacency_list: Vec::with_capacity(max_batch_size),
cached_associated_data: Vec::with_capacity(max_batch_size),
loaded_nodes: HashMap::with_capacity(max_batch_size),
Expand All @@ -238,10 +233,7 @@ where
self.cached_adjacency_list.reserve(max_batch_size);
self.cached_associated_data.reserve(max_batch_size);
self.loaded_nodes.reserve(max_batch_size);
self.aligned_vector_buf = AlignedBoxWithSlice::new(
max_batch_size * self.memory_aligned_dimension,
FP_VECTOR_MEM_ALIGN,
)?;
self.vector_buf = vec![Data::VectorDataType::default(); max_batch_size * self.dim];
self.max_batch_size = max_batch_size;
}
Ok(())
Expand All @@ -267,10 +259,6 @@ where
self.cached_adjacency_list.clear();
self.cached_associated_data.clear();
}

pub fn memory_aligned_dimension(&self) -> usize {
self.memory_aligned_dimension
}
}

#[cfg(test)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ impl<Data: GraphDataType<VectorIdType = u32>, ReaderFactory: AlignedReaderFactor

let graph_metadata = self.get_header()?;
let graph_metadata = graph_metadata.metadata();
let memory_aligned_dimension = graph_metadata.dims.next_multiple_of(8);

if num_nodes_to_cache > graph_metadata.num_pts as usize {
info!(
Expand All @@ -146,7 +145,7 @@ impl<Data: GraphDataType<VectorIdType = u32>, ReaderFactory: AlignedReaderFactor
self.cache = Some(Arc::new(self.build_cache_via_bfs(
start_node,
num_nodes_to_cache,
memory_aligned_dimension,
graph_metadata.dims,
)?));
}
CachingStrategy::None => {}
Expand Down
Loading