Skip to content

Commit d2d51b0

Browse files
committed
reuse y_binary_vec with reset, replace to_u8 to as u8
Signed-off-by: Keming <kemingy94@gmail.com>
1 parent af39c1c commit d2d51b0

File tree

2 files changed

+20
-24
lines changed

2 files changed

+20
-24
lines changed

src/rabitq.rs

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@ use core::f32;
44
use std::path::Path;
55

66
use log::debug;
7-
use nalgebra::{DMatrix, DVector, DVectorView};
8-
use num_traits::ToPrimitive;
7+
use nalgebra::{DMatrix, DMatrixView, DVector, DVectorView};
98
use serde::{Deserialize, Serialize};
109

1110
use crate::consts::{DEFAULT_X_DOT_PRODUCT, EPSILON, THETA_LOG_DIM, WINDOWS_SIZE};
1211
use crate::metrics::METRICS;
1312
use crate::utils::{gen_random_bias, gen_random_qr_orthogonal, matrix_from_fvecs};
1413

1514
/// Convert the vector to binary format and store in a u64 vector.
16-
fn vector_binarize_u64(vec: &DVector<f32>) -> Vec<u64> {
15+
fn vector_binarize_u64(vec: &DVectorView<f32>) -> Vec<u64> {
1716
let mut binary = vec![0u64; (vec.len() + 63) / 64];
1817
for (i, &v) in vec.iter().enumerate() {
1918
if v > 0.0 {
@@ -25,37 +24,37 @@ fn vector_binarize_u64(vec: &DVector<f32>) -> Vec<u64> {
2524

2625
/// Convert the vector to +1/-1 format.
2726
#[inline]
28-
fn vector_binarize_one(vec: &DVector<f32>) -> DVector<f32> {
27+
fn vector_binarize_one(vec: &DVectorView<f32>) -> DVector<f32> {
2928
DVector::from_fn(vec.len(), |i, _| if vec[i] > 0.0 { 1.0 } else { -1.0 })
3029
}
3130

3231
/// Interface of `vector_binarize_query`
33-
fn vector_binarize_query(vec: &DVector<u8>) -> Vec<u64> {
32+
fn vector_binarize_query(vec: &DVectorView<u8>, binary: &mut [u64]) {
3433
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
3534
{
3635
if is_x86_feature_detected!("avx2") {
37-
unsafe { crate::simd::vector_binarize_query_avx2(&vec.as_view()) }
36+
unsafe {
37+
crate::simd::vector_binarize_query_avx2(&vec.as_view(), binary);
38+
}
3839
} else {
39-
vector_binarize_query_raw(vec)
40+
vector_binarize_query_raw(vec, binary);
4041
}
4142
}
4243
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
4344
{
44-
vector_binarize_query_raw(vec)
45+
vector_binarize_query_raw(vec, binary);
4546
}
4647
}
4748

4849
/// Convert the vector to binary format (one value to multiple bits) and store in a u64 vector.
4950
#[inline]
50-
fn vector_binarize_query_raw(vec: &DVector<u8>) -> Vec<u64> {
51+
fn vector_binarize_query_raw(vec: &DVectorView<u8>, binary: &mut [u64]) {
5152
let length = vec.len();
52-
let mut binary = vec![0u64; length * THETA_LOG_DIM as usize / 64];
5353
for j in 0..THETA_LOG_DIM as usize {
5454
for i in 0..length {
5555
binary[(i + j * length) / 64] |= (((vec[i] >> j) & 1) as u64) << (i % 64);
5656
}
5757
}
58-
binary
5958
}
6059

6160
/// Calculate the dot product of two binary vectors.
@@ -129,17 +128,15 @@ fn quantize_query_vector(
129128
) -> u32 {
130129
let mut sum = 0u32;
131130
for i in 0..vec.len() {
132-
let q = ((vec[i] - lower_bound) * multiplier + bias[i])
133-
.to_u8()
134-
.expect("convert to u8 error");
131+
let q = ((vec[i] - lower_bound) * multiplier + bias[i]) as u8;
135132
quantized[i] = q;
136133
sum += q as u32;
137134
}
138135
sum
139136
}
140137

141138
/// Find the nearest cluster for the given vector.
142-
fn kmeans_nearest_cluster(centroids: &DMatrix<f32>, vec: &DVectorView<f32>) -> usize {
139+
fn kmeans_nearest_cluster(centroids: &DMatrixView<f32>, vec: &DVectorView<f32>) -> usize {
143140
let mut min_dist = f32::MAX;
144141
let mut min_label = 0;
145142
let mut residual = DVector::<f32>::zeros(vec.len());
@@ -198,13 +195,13 @@ impl RaBitQ {
198195
if i % 5000 == 0 {
199196
debug!("\t> preprocessing {}...", i);
200197
}
201-
let min_label = kmeans_nearest_cluster(&centroids, &xp);
198+
let min_label = kmeans_nearest_cluster(&centroids.as_view(), &xp);
202199
labels[min_label].push(i as u32);
203200
let x_c_quantized = xp - centroids.column(min_label);
204201
x_c_distance[i] = x_c_quantized.norm();
205202
x_c_distance_square[i] = x_c_distance[i].powi(2);
206-
x_binary_vec.push(vector_binarize_u64(&x_c_quantized));
207-
x_signed_vec.push(vector_binarize_one(&x_c_quantized));
203+
x_binary_vec.push(vector_binarize_u64(&x_c_quantized.as_view()));
204+
x_signed_vec.push(vector_binarize_one(&x_c_quantized.as_view()));
208205
let norm = x_c_distance[i] * dim_sqrt;
209206
x_dot_product[i] = if norm.is_normal() {
210207
x_c_quantized.dot(&x_signed_vec[i]) / norm
@@ -284,6 +281,7 @@ impl RaBitQ {
284281

285282
let mut rough_distances = Vec::new();
286283
let mut quantized = DVector::<u8>::zeros(self.dim as usize);
284+
let mut binary_vec = vec![0u64; query.len() * THETA_LOG_DIM as usize / 64];
287285
for &(dist, i) in lists[..length].iter() {
288286
y_projected.sub_to(&self.centroids.column(i), &mut residual);
289287
let (lower_bound, upper_bound) = min_max(&residual.as_view());
@@ -296,7 +294,8 @@ impl RaBitQ {
296294
lower_bound,
297295
one_over_delta,
298296
);
299-
let y_binary_vec = vector_binarize_query(&quantized);
297+
binary_vec.iter_mut().for_each(|element| *element = 0);
298+
vector_binarize_query(&quantized.as_view(), &mut binary_vec);
300299
let dist_sqrt = dist.sqrt();
301300
for j in self.offsets[i]..self.offsets[i + 1] {
302301
let ju = j as usize;
@@ -305,7 +304,7 @@ impl RaBitQ {
305304
+ dist
306305
+ lower_bound * self.factor_ppc[ju]
307306
+ (2.0
308-
* asymmetric_binary_dot_product(&self.x_binary_vec[ju], &y_binary_vec)
307+
* asymmetric_binary_dot_product(&self.x_binary_vec[ju], &binary_vec)
309308
as f32
310309
- scalar_sum as f32)
311310
* self.factor_ip[ju]

src/simd.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,11 @@ pub unsafe fn l2_squared_distance_avx2(lhs: &DVectorView<f32>, rhs: &DVectorView
7979
/// This function is marked unsafe because it requires the AVX intrinsics.
8080
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
8181
#[target_feature(enable = "avx2")]
82-
pub unsafe fn vector_binarize_query_avx2(vec: &DVectorView<u8>) -> Vec<u64> {
82+
pub unsafe fn vector_binarize_query_avx2(vec: &DVectorView<u8>, binary: &mut [u64]) {
8383
use std::arch::x86_64::*;
8484

8585
let length = vec.len();
8686
let mut ptr = vec.as_ptr() as *const __m256i;
87-
let mut binary = vec![0u64; length * THETA_LOG_DIM as usize / 64];
8887

8988
for i in (0..length).step_by(32) {
9089
// since it's not guaranteed that the vec is fully-aligned
@@ -99,6 +98,4 @@ pub unsafe fn vector_binarize_query_avx2(vec: &DVectorView<u8>) -> Vec<u64> {
9998
v = _mm256_slli_epi32(v, 1);
10099
}
101100
}
102-
103-
binary
104101
}

0 commit comments

Comments
 (0)