@@ -107,23 +107,35 @@ pub unsafe fn vector_binarize_query_avx2(vec: &DVectorView<u8>, binary: &mut [u6
107107/// This function is marked unsafe because it requires the AVX intrinsics.
108108#[ cfg( any( target_arch = "x86_64" , target_arch = "x86" ) ) ]
109109#[ target_feature( enable = "avx" ) ]
110- pub unsafe fn min_max_avx ( vec : & DVectorView < f32 > ) -> ( f32 , f32 ) {
110+ pub unsafe fn min_max_residual_avx (
111+ res : & mut DVector < f32 > ,
112+ x : & DVectorView < f32 > ,
113+ y : & DVectorView < f32 > ,
114+ ) -> ( f32 , f32 ) {
111115 use std:: arch:: x86_64:: * ;
112116
113117 let mut min_32x8 = _mm256_set1_ps ( f32:: MAX ) ;
114118 let mut max_32x8 = _mm256_set1_ps ( f32:: MIN ) ;
115- let mut ptr = vec. as_ptr ( ) ;
119+ let mut x_ptr = x. as_ptr ( ) ;
120+ let mut y_ptr = y. as_ptr ( ) ;
121+ let mut res_ptr = res. as_mut_ptr ( ) ;
116122 let mut f32x8 = [ 0.0f32 ; 8 ] ;
117123 let mut min = f32:: MAX ;
118124 let mut max = f32:: MIN ;
119- let length = vec . len ( ) ;
125+ let length = res . len ( ) ;
120126 let rest = length & 0b111 ;
127+ let ( mut x256, mut y256, mut res256) ;
121128
122129 for _ in 0 ..( length / 8 ) {
123- let v = _mm256_loadu_ps ( ptr) ;
124- ptr = ptr. add ( 8 ) ;
125- min_32x8 = _mm256_min_ps ( min_32x8, v) ;
126- max_32x8 = _mm256_max_ps ( max_32x8, v) ;
130+ x256 = _mm256_loadu_ps ( x_ptr) ;
131+ y256 = _mm256_loadu_ps ( y_ptr) ;
132+ res256 = _mm256_sub_ps ( x256, y256) ;
133+ _mm256_storeu_ps ( res_ptr, res256) ;
134+ x_ptr = x_ptr. add ( 8 ) ;
135+ y_ptr = y_ptr. add ( 8 ) ;
136+ res_ptr = res_ptr. add ( 8 ) ;
137+ min_32x8 = _mm256_min_ps ( min_32x8, res256) ;
138+ max_32x8 = _mm256_max_ps ( max_32x8, res256) ;
127139 }
128140 _mm256_storeu_ps ( f32x8. as_mut_ptr ( ) , min_32x8) ;
129141 for & x in f32x8. iter ( ) {
@@ -139,23 +151,30 @@ pub unsafe fn min_max_avx(vec: &DVectorView<f32>) -> (f32, f32) {
139151 }
140152
141153 for _ in 0 ..rest {
142- if * ptr < min {
143- min = * ptr;
154+ * res_ptr = * x_ptr - * y_ptr;
155+ if * res_ptr < min {
156+ min = * res_ptr;
144157 }
145- if * ptr > max {
146- max = * ptr ;
158+ if * res_ptr > max {
159+ max = * res_ptr ;
147160 }
148- ptr = ptr. add ( 1 ) ;
161+ res_ptr = res_ptr. add ( 1 ) ;
162+ x_ptr = x_ptr. add ( 1 ) ;
163+ y_ptr = y_ptr. add ( 1 ) ;
149164 }
150165
151166 ( min, max)
152167}
153168
154169/// Compute the u8 scalar quantization of a f32 vector.
155170///
171+ /// This function doesn't need `bias` because it *round* the f32 to u32 instead of *floor*.
172+ ///
156173/// # Safety
157174///
158175/// This function is marked unsafe because it requires the AVX intrinsics.
176+ #[ cfg( any( target_arch = "x86_64" , target_arch = "x86" ) ) ]
177+ #[ target_feature( enable = "avx2" ) ]
159178pub unsafe fn scalar_quantize_avx2 (
160179 quantized : & mut DVector < u8 > ,
161180 vec : & DVectorView < f32 > ,
0 commit comments