Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: tune scalar argmin & argmax #43

Merged
merged 3 commits into from
Mar 23, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 97 additions & 17 deletions src/scalar/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ trait SCALARInit<ScalarDType: Copy + PartialOrd> {

fn _init_max(start_value: ScalarDType) -> ScalarDType;

/// Check if we should allow the initial double update
/// Check if we should allow the updating the value(s) with the first non-NaN value

fn _allow_initial_double_update(start_value: ScalarDType) -> bool;
fn _allow_first_non_nan_update(start_value: ScalarDType) -> bool;

/// Nan check

Expand Down Expand Up @@ -55,9 +55,7 @@ pub trait ScalarArgMinMax<ScalarDType: Copy + PartialOrd> {
/// # Returns
/// The index of the minimum value in the slice.
///
fn argmin(data: &[ScalarDType]) -> usize {
Self::argminmax(data).0 // TODO: seems already to be fairly optimized :exploding_head:
}
fn argmin(data: &[ScalarDType]) -> usize;

/// Get the index of the maximum value in the slice.
///
Expand All @@ -67,9 +65,7 @@ pub trait ScalarArgMinMax<ScalarDType: Copy + PartialOrd> {
/// # Returns
/// The index of the maximum value in the slice.
///
fn argmax(data: &[ScalarDType]) -> usize {
Self::argminmax(data).1 // TODO: is slower :/
}
fn argmax(data: &[ScalarDType]) -> usize;
}

/// Type that implements the [ScalarArgMinMax](crate::ScalarArgMinMax) trait.
Expand Down Expand Up @@ -100,7 +96,7 @@ where
}

#[inline(always)]
fn _allow_initial_double_update(_start_value: ScalarDType) -> bool {
fn _allow_first_non_nan_update(_start_value: ScalarDType) -> bool {
false
}

Expand Down Expand Up @@ -128,7 +124,7 @@ where
}

#[inline(always)]
fn _allow_initial_double_update(_start_value: ScalarDType) -> bool {
fn _allow_first_non_nan_update(_start_value: ScalarDType) -> bool {
false
}

Expand Down Expand Up @@ -164,7 +160,7 @@ where
}

#[inline(always)]
fn _allow_initial_double_update(start_value: ScalarDType) -> bool {
fn _allow_first_non_nan_update(start_value: ScalarDType) -> bool {
start_value.is_nan()
}

Expand All @@ -191,23 +187,23 @@ macro_rules! impl_scalar {
let start_value: $dtype = unsafe { *arr.get_unchecked(0) };
let mut low: $dtype = Self::_init_min(start_value);
let mut high: $dtype = Self::_init_max(start_value);
let mut allow_double_update: bool = Self::_allow_initial_double_update(start_value);
let mut first_non_nan_update: bool = Self::_allow_first_non_nan_update(start_value);
for i in 0..arr.len() {
let v: $dtype = unsafe { *arr.get_unchecked(i) };
if <Self as SCALARInit<$dtype>>::_RETURN_AT_NAN && Self::_nan_check(v) {
// When _RETURN_AT_NAN is true and we encounter a NaN
return (i, i); // -> return the index
}
if allow_double_update {
// If we allow the double update (only for FloatIgnoreNaN)
if !Self::_nan_check(v) { // If the value is not a NaN
if first_non_nan_update {
// If we allow the first non-nan update (only for FloatIgnoreNaN)
if !Self::_nan_check(v) {
// Update the low and high
low = v;
low_index = i;
high = v;
high_index = i;
// And disable the double update
allow_double_update = false;
// And disable the first_non_nan_update update
first_non_nan_update = false;
}
} else if v < low {
low = v;
Expand All @@ -219,6 +215,70 @@ macro_rules! impl_scalar {
}
(low_index, high_index)
}

#[inline(always)]
fn argmin(arr: &[$dtype]) -> usize {
assert!(!arr.is_empty());
let mut low_index: usize = 0;
// It is remarkably faster to iterate over the index and use get_unchecked
// than using .iter().enumerate() (with a fold).
let start_value: $dtype = unsafe { *arr.get_unchecked(0) };
let mut low: $dtype = Self::_init_min(start_value);
let mut first_non_nan_update: bool = Self::_allow_first_non_nan_update(start_value);
for i in 0..arr.len() {
let v: $dtype = unsafe { *arr.get_unchecked(i) };
if <Self as SCALARInit<$dtype>>::_RETURN_AT_NAN && Self::_nan_check(v) {
// When _RETURN_AT_NAN is true and we encounter a NaN
return i; // -> return the index
}
if first_non_nan_update {
// If we allow the first non-nan update (only for FloatIgnoreNaN)
if !Self::_nan_check(v) {
// Update the low
low = v;
low_index = i;
// And disable the first_non_nan_update update
first_non_nan_update = false;
}
} else if v < low {
low = v;
low_index = i;
}
}
low_index
}

#[inline(always)]
fn argmax(arr: &[$dtype]) -> usize {
assert!(!arr.is_empty());
let mut high_index: usize = 0;
// It is remarkably faster to iterate over the index and use get_unchecked
// than using .iter().enumerate() (with a fold).
let start_value: $dtype = unsafe { *arr.get_unchecked(0) };
let mut high: $dtype = Self::_init_max(start_value);
let mut first_non_nan_update: bool = Self::_allow_first_non_nan_update(start_value);
for i in 0..arr.len() {
let v: $dtype = unsafe { *arr.get_unchecked(i) };
if <Self as SCALARInit<$dtype>>::_RETURN_AT_NAN && Self::_nan_check(v) {
// When _RETURN_AT_NAN is true and we encounter a NaN
return i; // -> return the index
}
if first_non_nan_update {
// If we allow the first non-nan update (only for FloatIgnoreNaN)
if !Self::_nan_check(v) {
// Update the high
high = v;
high_index = i;
// And disable the first_non_nan_update update
first_non_nan_update = false;
}
} else if v > high {
high = v;
high_index = i;
}
}
high_index
}
}
)*
};
Expand All @@ -243,6 +303,16 @@ impl ScalarArgMinMax<f16> for SCALAR<FloatReturnNaN> {
fn argminmax(arr: &[f16]) -> (usize, usize) {
scalar_argminmax_f16_return_nan(arr)
}

#[inline(always)]
fn argmin(arr: &[f16]) -> usize {
scalar_argminmax_f16_return_nan(arr).0
}

#[inline(always)]
fn argmax(arr: &[f16]) -> usize {
scalar_argminmax_f16_return_nan(arr).1
}
}

#[cfg(feature = "half")]
Expand All @@ -252,4 +322,14 @@ impl ScalarArgMinMax<f16> for SCALAR<FloatIgnoreNaN> {
fn argminmax(arr: &[f16]) -> (usize, usize) {
scalar_argminmax_f16_return_nan(arr)
}

#[inline(always)]
fn argmin(arr: &[f16]) -> usize {
scalar_argminmax_f16_return_nan(arr).0
}

#[inline(always)]
fn argmax(arr: &[f16]) -> usize {
scalar_argminmax_f16_return_nan(arr).1
}
}