@@ -4,16 +4,15 @@ use core::f32;
44use std:: path:: Path ;
55
66use log:: debug;
7- use nalgebra:: { DMatrix , DVector , DVectorView } ;
8- use num_traits:: ToPrimitive ;
7+ use nalgebra:: { DMatrix , DMatrixView , DVector , DVectorView } ;
98use serde:: { Deserialize , Serialize } ;
109
1110use crate :: consts:: { DEFAULT_X_DOT_PRODUCT , EPSILON , THETA_LOG_DIM , WINDOWS_SIZE } ;
1211use crate :: metrics:: METRICS ;
1312use crate :: utils:: { gen_random_bias, gen_random_qr_orthogonal, matrix_from_fvecs} ;
1413
1514/// Convert the vector to binary format and store in a u64 vector.
16- fn vector_binarize_u64 ( vec : & DVector < f32 > ) -> Vec < u64 > {
15+ fn vector_binarize_u64 ( vec : & DVectorView < f32 > ) -> Vec < u64 > {
1716 let mut binary = vec ! [ 0u64 ; ( vec. len( ) + 63 ) / 64 ] ;
1817 for ( i, & v) in vec. iter ( ) . enumerate ( ) {
1918 if v > 0.0 {
@@ -25,37 +24,37 @@ fn vector_binarize_u64(vec: &DVector<f32>) -> Vec<u64> {
2524
2625/// Convert the vector to +1/-1 format.
2726#[ inline]
28- fn vector_binarize_one ( vec : & DVector < f32 > ) -> DVector < f32 > {
27+ fn vector_binarize_one ( vec : & DVectorView < f32 > ) -> DVector < f32 > {
2928 DVector :: from_fn ( vec. len ( ) , |i, _| if vec[ i] > 0.0 { 1.0 } else { -1.0 } )
3029}
3130
3231/// Interface of `vector_binarize_query`
33- fn vector_binarize_query ( vec : & DVector < u8 > ) -> Vec < u64 > {
32+ fn vector_binarize_query ( vec : & DVectorView < u8 > , binary : & mut [ u64 ] ) {
3433 #[ cfg( any( target_arch = "x86_64" , target_arch = "x86" ) ) ]
3534 {
3635 if is_x86_feature_detected ! ( "avx2" ) {
37- unsafe { crate :: simd:: vector_binarize_query_avx2 ( & vec. as_view ( ) ) }
36+ unsafe {
37+ crate :: simd:: vector_binarize_query_avx2 ( & vec. as_view ( ) , binary) ;
38+ }
3839 } else {
39- vector_binarize_query_raw ( vec)
40+ vector_binarize_query_raw ( vec, binary ) ;
4041 }
4142 }
4243 #[ cfg( not( any( target_arch = "x86_64" , target_arch = "x86" ) ) ) ]
4344 {
44- vector_binarize_query_raw ( vec)
45+ vector_binarize_query_raw ( vec, binary ) ;
4546 }
4647}
4748
4849/// Convert the vector to binary format (one value to multiple bits) and store in a u64 vector.
4950#[ inline]
50- fn vector_binarize_query_raw ( vec : & DVector < u8 > ) -> Vec < u64 > {
51+ fn vector_binarize_query_raw ( vec : & DVectorView < u8 > , binary : & mut [ u64 ] ) {
5152 let length = vec. len ( ) ;
52- let mut binary = vec ! [ 0u64 ; length * THETA_LOG_DIM as usize / 64 ] ;
5353 for j in 0 ..THETA_LOG_DIM as usize {
5454 for i in 0 ..length {
5555 binary[ ( i + j * length) / 64 ] |= ( ( ( vec[ i] >> j) & 1 ) as u64 ) << ( i % 64 ) ;
5656 }
5757 }
58- binary
5958}
6059
6160/// Calculate the dot product of two binary vectors.
@@ -129,17 +128,15 @@ fn quantize_query_vector(
129128) -> u32 {
130129 let mut sum = 0u32 ;
131130 for i in 0 ..vec. len ( ) {
132- let q = ( ( vec[ i] - lower_bound) * multiplier + bias[ i] )
133- . to_u8 ( )
134- . expect ( "convert to u8 error" ) ;
131+ let q = ( ( vec[ i] - lower_bound) * multiplier + bias[ i] ) as u8 ;
135132 quantized[ i] = q;
136133 sum += q as u32 ;
137134 }
138135 sum
139136}
140137
141138/// Find the nearest cluster for the given vector.
142- fn kmeans_nearest_cluster ( centroids : & DMatrix < f32 > , vec : & DVectorView < f32 > ) -> usize {
139+ fn kmeans_nearest_cluster ( centroids : & DMatrixView < f32 > , vec : & DVectorView < f32 > ) -> usize {
143140 let mut min_dist = f32:: MAX ;
144141 let mut min_label = 0 ;
145142 let mut residual = DVector :: < f32 > :: zeros ( vec. len ( ) ) ;
@@ -198,13 +195,13 @@ impl RaBitQ {
198195 if i % 5000 == 0 {
199196 debug ! ( "\t > preprocessing {}..." , i) ;
200197 }
201- let min_label = kmeans_nearest_cluster ( & centroids, & xp) ;
198+ let min_label = kmeans_nearest_cluster ( & centroids. as_view ( ) , & xp) ;
202199 labels[ min_label] . push ( i as u32 ) ;
203200 let x_c_quantized = xp - centroids. column ( min_label) ;
204201 x_c_distance[ i] = x_c_quantized. norm ( ) ;
205202 x_c_distance_square[ i] = x_c_distance[ i] . powi ( 2 ) ;
206- x_binary_vec. push ( vector_binarize_u64 ( & x_c_quantized) ) ;
207- x_signed_vec. push ( vector_binarize_one ( & x_c_quantized) ) ;
203+ x_binary_vec. push ( vector_binarize_u64 ( & x_c_quantized. as_view ( ) ) ) ;
204+ x_signed_vec. push ( vector_binarize_one ( & x_c_quantized. as_view ( ) ) ) ;
208205 let norm = x_c_distance[ i] * dim_sqrt;
209206 x_dot_product[ i] = if norm. is_normal ( ) {
210207 x_c_quantized. dot ( & x_signed_vec[ i] ) / norm
@@ -284,6 +281,7 @@ impl RaBitQ {
284281
285282 let mut rough_distances = Vec :: new ( ) ;
286283 let mut quantized = DVector :: < u8 > :: zeros ( self . dim as usize ) ;
284+ let mut binary_vec = vec ! [ 0u64 ; query. len( ) * THETA_LOG_DIM as usize / 64 ] ;
287285 for & ( dist, i) in lists[ ..length] . iter ( ) {
288286 y_projected. sub_to ( & self . centroids . column ( i) , & mut residual) ;
289287 let ( lower_bound, upper_bound) = min_max ( & residual. as_view ( ) ) ;
@@ -296,7 +294,8 @@ impl RaBitQ {
296294 lower_bound,
297295 one_over_delta,
298296 ) ;
299- let y_binary_vec = vector_binarize_query ( & quantized) ;
297+ binary_vec. iter_mut ( ) . for_each ( |element| * element = 0 ) ;
298+ vector_binarize_query ( & quantized. as_view ( ) , & mut binary_vec) ;
300299 let dist_sqrt = dist. sqrt ( ) ;
301300 for j in self . offsets [ i] ..self . offsets [ i + 1 ] {
302301 let ju = j as usize ;
@@ -305,7 +304,7 @@ impl RaBitQ {
305304 + dist
306305 + lower_bound * self . factor_ppc [ ju]
307306 + ( 2.0
308- * asymmetric_binary_dot_product ( & self . x_binary_vec [ ju] , & y_binary_vec )
307+ * asymmetric_binary_dot_product ( & self . x_binary_vec [ ju] , & binary_vec )
309308 as f32
310309 - scalar_sum as f32 )
311310 * self . factor_ip [ ju]
0 commit comments