Skip to content

Commit e5a4af0

Browse files
committed
rewrite sub_to with avx2
Signed-off-by: Keming <kemingy94@gmail.com>
1 parent 28efe09 commit e5a4af0

File tree

2 files changed

+47
-19
lines changed

2 files changed

+47
-19
lines changed

src/rabitq.rs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,19 +118,25 @@ fn min_max_raw(vec: &DVectorView<f32>) -> (f32, f32) {
118118
(min, max)
119119
}
120120

121-
// Interface of `min_max`
122-
fn min_max(vec: &DVectorView<f32>) -> (f32, f32) {
121+
// Interface of `min_max_residual`: stores the residual `x - y` into `res` and
// returns the `(min, max)` of that residual.
122+
fn min_max_residual(
123+
res: &mut DVector<f32>,
124+
x: &DVectorView<f32>,
125+
y: &DVectorView<f32>,
126+
) -> (f32, f32) {
123127
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
124128
{
125129
if is_x86_feature_detected!("avx") {
126-
unsafe { crate::simd::min_max_avx(vec) }
130+
unsafe { crate::simd::min_max_residual_avx(res, x, y) }
127131
} else {
128-
min_max_raw(vec)
132+
x.sub_to(y, res);
133+
min_max_raw(&res.as_view())
129134
}
130135
}
131136
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
132137
{
133-
min_max_raw(vec)
138+
x.sub_to(y, &mut res);
139+
min_max_raw(&res.as_view())
134140
}
135141
}
136142

@@ -333,8 +339,11 @@ impl RaBitQ {
333339
let mut quantized = DVector::<u8>::zeros(self.dim as usize);
334340
let mut binary_vec = vec![0u64; query.len() * THETA_LOG_DIM as usize / 64];
335341
for &(dist, i) in lists[..length].iter() {
336-
y_projected.sub_to(&self.centroids.column(i), &mut residual);
337-
let (lower_bound, upper_bound) = min_max(&residual.as_view());
342+
let (lower_bound, upper_bound) = min_max_residual(
343+
&mut residual,
344+
&y_projected.as_view(),
345+
&self.centroids.column(i),
346+
);
338347
let delta = (upper_bound - lower_bound) / ((1 << THETA_LOG_DIM) as f32 - 1.0);
339348
let one_over_delta = 1.0 / delta;
340349
let scalar_sum = scalar_quantize(

src/simd.rs

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,23 +107,35 @@ pub unsafe fn vector_binarize_query_avx2(vec: &DVectorView<u8>, binary: &mut [u6
107107
/// This function is marked unsafe because it requires the AVX intrinsics.
108108
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
109109
#[target_feature(enable = "avx")]
110-
pub unsafe fn min_max_avx(vec: &DVectorView<f32>) -> (f32, f32) {
110+
pub unsafe fn min_max_residual_avx(
111+
res: &mut DVector<f32>,
112+
x: &DVectorView<f32>,
113+
y: &DVectorView<f32>,
114+
) -> (f32, f32) {
111115
use std::arch::x86_64::*;
112116

113117
let mut min_32x8 = _mm256_set1_ps(f32::MAX);
114118
let mut max_32x8 = _mm256_set1_ps(f32::MIN);
115-
let mut ptr = vec.as_ptr();
119+
let mut x_ptr = x.as_ptr();
120+
let mut y_ptr = y.as_ptr();
121+
let mut res_ptr = res.as_mut_ptr();
116122
let mut f32x8 = [0.0f32; 8];
117123
let mut min = f32::MAX;
118124
let mut max = f32::MIN;
119-
let length = vec.len();
125+
let length = res.len();
120126
let rest = length & 0b111;
127+
let (mut x256, mut y256, mut res256);
121128

122129
for _ in 0..(length / 8) {
123-
let v = _mm256_loadu_ps(ptr);
124-
ptr = ptr.add(8);
125-
min_32x8 = _mm256_min_ps(min_32x8, v);
126-
max_32x8 = _mm256_max_ps(max_32x8, v);
130+
x256 = _mm256_loadu_ps(x_ptr);
131+
y256 = _mm256_loadu_ps(y_ptr);
132+
res256 = _mm256_sub_ps(x256, y256);
133+
_mm256_storeu_ps(res_ptr, res256);
134+
x_ptr = x_ptr.add(8);
135+
y_ptr = y_ptr.add(8);
136+
res_ptr = res_ptr.add(8);
137+
min_32x8 = _mm256_min_ps(min_32x8, res256);
138+
max_32x8 = _mm256_max_ps(max_32x8, res256);
127139
}
128140
_mm256_storeu_ps(f32x8.as_mut_ptr(), min_32x8);
129141
for &x in f32x8.iter() {
@@ -139,23 +151,30 @@ pub unsafe fn min_max_avx(vec: &DVectorView<f32>) -> (f32, f32) {
139151
}
140152

141153
for _ in 0..rest {
142-
if *ptr < min {
143-
min = *ptr;
154+
*res_ptr = *x_ptr - *y_ptr;
155+
if *res_ptr < min {
156+
min = *res_ptr;
144157
}
145-
if *ptr > max {
146-
max = *ptr;
158+
if *res_ptr > max {
159+
max = *res_ptr;
147160
}
148-
ptr = ptr.add(1);
161+
res_ptr = res_ptr.add(1);
162+
x_ptr = x_ptr.add(1);
163+
y_ptr = y_ptr.add(1);
149164
}
150165

151166
(min, max)
152167
}
153168

154169
/// Compute the u8 scalar quantization of a f32 vector.
155170
///
171+
/// This function doesn't need `bias` because it *rounds* the f32 to u32 instead of *flooring* it.
172+
///
156173
/// # Safety
157174
///
158175
/// This function is marked unsafe because it requires the AVX2 intrinsics.
176+
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
177+
#[target_feature(enable = "avx2")]
159178
pub unsafe fn scalar_quantize_avx2(
160179
quantized: &mut DVector<u8>,
161180
vec: &DVectorView<f32>,

0 commit comments

Comments
 (0)