Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Improved performance of comparison with SIMD feature flag (2x-3.5x) #305

Merged
merged 3 commits into from
Aug 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benches/comparison_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ where

fn bench_op_scalar<T>(arr_a: &PrimitiveArray<T>, value_b: T, op: Operator)
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
primitive_compare_scalar(
criterion::black_box(arr_a),
Expand Down
3 changes: 3 additions & 0 deletions src/compute/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ mod boolean;
mod primitive;
mod utf8;

mod simd;
pub use simd::{Simd8, Simd8Lanes};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operator {
Lt,
Expand Down
122 changes: 52 additions & 70 deletions src/compute/comparison/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,45 +23,33 @@ use crate::{
error::{ArrowError, Result},
};

use super::simd::{Simd8, Simd8Lanes};
use super::{super::utils::combine_validities, Operator};

pub(crate) fn compare_values_op<T, F>(lhs: &[T], rhs: &[T], op: F) -> MutableBitmap
where
T: NativeType,
F: Fn(T, T) -> bool,
T: NativeType + Simd8,
F: Fn(T::Simd, T::Simd) -> u8,
{
assert_eq!(lhs.len(), rhs.len());
let mut values = MutableBuffer::from_len_zeroed((lhs.len() + 7) / 8);

let lhs_chunks_iter = lhs.chunks_exact(8);
let lhs_remainder = lhs_chunks_iter.remainder();
let rhs_chunks_iter = rhs.chunks_exact(8);
let rhs_remainder = rhs_chunks_iter.remainder();

let chunks = lhs.len() / 8;

values[..chunks]
.iter_mut()
.zip(lhs_chunks_iter)
.zip(rhs_chunks_iter)
.for_each(|((byte, lhs), rhs)| {
lhs.iter()
.zip(rhs.iter())
.enumerate()
.for_each(|(i, (&lhs, &rhs))| {
*byte |= if op(lhs, rhs) { 1 << i } else { 0 };
});
});
let mut values = MutableBuffer::with_capacity((lhs.len() + 7) / 8);
let iterator = lhs_chunks_iter.zip(rhs_chunks_iter).map(|(lhs, rhs)| {
let lhs = T::Simd::from_chunk(lhs);
let rhs = T::Simd::from_chunk(rhs);
op(lhs, rhs)
});
values.extend_from_trusted_len_iter(iterator);

if !lhs_remainder.is_empty() {
let last = &mut values[chunks];
lhs_remainder
.iter()
.zip(rhs_remainder.iter())
.enumerate()
.for_each(|(i, (&lhs, &rhs))| {
*last |= if op(lhs, rhs) { 1 << i } else { 0 };
});
let lhs = T::Simd::from_incomplete_chunk(lhs_remainder, T::default());
let rhs = T::Simd::from_incomplete_chunk(rhs_remainder, T::default());
values.push(op(lhs, rhs))
};
MutableBitmap::from_buffer(values, lhs.len())
}
Expand All @@ -70,8 +58,8 @@ where
/// comparison function.
fn compare_op<T, F>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>, op: F) -> Result<BooleanArray>
where
T: NativeType,
F: Fn(T, T) -> bool,
T: NativeType + Simd8,
F: Fn(T::Simd, T::Simd) -> u8,
{
if lhs.len() != rhs.len() {
return Err(ArrowError::InvalidArgumentError(
Expand All @@ -90,31 +78,25 @@ where
/// a specified comparison function.
pub fn compare_op_scalar<T, F>(lhs: &PrimitiveArray<T>, rhs: T, op: F) -> Result<BooleanArray>
where
T: NativeType,
F: Fn(T, T) -> bool,
T: NativeType + Simd8,
F: Fn(T::Simd, T::Simd) -> u8,
{
let validity = lhs.validity().clone();

let mut values = MutableBuffer::from_len_zeroed((lhs.len() + 7) / 8);
let rhs = T::Simd::from_chunk(&[rhs; 8]);

let lhs_chunks_iter = lhs.values().chunks_exact(8);
let lhs_remainder = lhs_chunks_iter.remainder();
let chunks = lhs.len() / 8;

values[..chunks]
.iter_mut()
.zip(lhs_chunks_iter)
.for_each(|(byte, chunk)| {
chunk.iter().enumerate().for_each(|(i, &c_i)| {
*byte |= if op(c_i, rhs) { 1 << i } else { 0 };
});
});
let mut values = MutableBuffer::with_capacity((lhs.len() + 7) / 8);
let iterator = lhs_chunks_iter.map(|lhs| {
let lhs = T::Simd::from_chunk(lhs);
op(lhs, rhs)
});
values.extend_from_trusted_len_iter(iterator);

if !lhs_remainder.is_empty() {
let last = &mut values[chunks];
lhs_remainder.iter().enumerate().for_each(|(i, &lhs)| {
*last |= if op(lhs, rhs) { 1 << i } else { 0 };
});
let lhs = T::Simd::from_incomplete_chunk(lhs_remainder, T::default());
values.push(op(lhs, rhs))
};

Ok(BooleanArray::from_data(
Expand All @@ -126,105 +108,105 @@ where
/// Perform `lhs == rhs` operation on two arrays.
pub fn eq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: NativeType,
T: NativeType + Simd8,
{
compare_op(lhs, rhs, |a, b| a == b)
compare_op(lhs, rhs, |a, b| a.eq(b))
}

/// Perform `left == right` operation on an array and a scalar value.
pub fn eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a == b)
compare_op_scalar(lhs, rhs, |a, b| a.eq(b))
}

/// Perform `left != right` operation on two arrays.
pub fn neq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: NativeType,
T: NativeType + Simd8,
{
compare_op(lhs, rhs, |a, b| a != b)
compare_op(lhs, rhs, |a, b| a.neq(b))
}

/// Perform `left != right` operation on an array and a scalar value.
pub fn neq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a != b)
compare_op_scalar(lhs, rhs, |a, b| a.neq(b))
}

/// Perform `left < right` operation on two arrays.
pub fn lt<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op(lhs, rhs, |a, b| a < b)
compare_op(lhs, rhs, |a, b| a.lt(b))
}

/// Perform `left < right` operation on an array and a scalar value.
pub fn lt_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a < b)
compare_op_scalar(lhs, rhs, |a, b| a.lt(b))
}

/// Perform `left <= right` operation on two arrays.
pub fn lt_eq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op(lhs, rhs, |a, b| a <= b)
compare_op(lhs, rhs, |a, b| a.lt_eq(b))
}

/// Perform `left <= right` operation on an array and a scalar value.
/// Null values are less than non-null values.
pub fn lt_eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a <= b)
compare_op_scalar(lhs, rhs, |a, b| a.lt_eq(b))
}

/// Perform `left > right` operation on two arrays. Non-null values are greater than null
/// values.
pub fn gt<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op(lhs, rhs, |a, b| a > b)
compare_op(lhs, rhs, |a, b| a.gt(b))
}

/// Perform `left > right` operation on an array and a scalar value.
/// Non-null values are greater than null values.
pub fn gt_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a > b)
compare_op_scalar(lhs, rhs, |a, b| a.gt(b))
}

/// Perform `left >= right` operation on two arrays. Non-null values are greater than null
/// values.
pub fn gt_eq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op(lhs, rhs, |a, b| a >= b)
compare_op(lhs, rhs, |a, b| a.gt_eq(b))
}

/// Perform `left >= right` operation on an array and a scalar value.
/// Non-null values are greater than null values.
pub fn gt_eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> Result<BooleanArray>
where
T: NativeType + std::cmp::PartialOrd,
T: NativeType + Simd8,
{
compare_op_scalar(lhs, rhs, |a, b| a >= b)
compare_op_scalar(lhs, rhs, |a, b| a.gt_eq(b))
}

pub fn compare<T: NativeType + std::cmp::PartialOrd>(
pub fn compare<T: NativeType + Simd8>(
lhs: &PrimitiveArray<T>,
rhs: &PrimitiveArray<T>,
op: Operator,
Expand All @@ -239,7 +221,7 @@ pub fn compare<T: NativeType + std::cmp::PartialOrd>(
}
}

pub fn compare_scalar<T: NativeType + std::cmp::PartialOrd>(
pub fn compare_scalar<T: NativeType + Simd8>(
lhs: &PrimitiveArray<T>,
rhs: T,
op: Operator,
Expand Down
91 changes: 91 additions & 0 deletions src/compute/comparison/simd/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use crate::types::NativeType;

/// A [`NativeType`] that has an 8-lane representation.
///
/// The associated type [`Simd8::Simd`] is the container for those 8 lanes and
/// provides the lane-wise comparisons declared by [`Simd8Lanes`].
pub trait Simd8: NativeType {
/// The 8-lane container for `Self`; it must implement [`Simd8Lanes<Self>`].
type Simd: Simd8Lanes<Self>;
}

/// Lane-wise operations over a group of 8 values of `T`.
///
/// Every comparison returns a `u8`. In the array-backed implementation generated by
/// `simd8_native!`, bit `i` of that byte is set when the comparison holds for lane `i`.
pub trait Simd8Lanes<T>: Copy {
/// Builds the lanes from a slice. The array-backed implementation generated by
/// `simd8_native!` panics unless `v` holds exactly 8 elements.
fn from_chunk(v: &[T]) -> Self;
/// Builds the lanes from a slice that may hold fewer than 8 elements. In the
/// array-backed implementation, lanes not covered by `v` are set to `remaining`.
fn from_incomplete_chunk(v: &[T], remaining: T) -> Self;
/// Lane-wise `self == other`.
fn eq(self, other: Self) -> u8;
/// Lane-wise `self != other`.
fn neq(self, other: Self) -> u8;
/// Lane-wise `self <= other`.
fn lt_eq(self, other: Self) -> u8;
/// Lane-wise `self < other`.
fn lt(self, other: Self) -> u8;
/// Lane-wise `self > other`.
fn gt(self, other: Self) -> u8;
/// Lane-wise `self >= other`.
fn gt_eq(self, other: Self) -> u8;
}

/// Packs eight lane-wise comparison results into a single byte.
///
/// Bit `i` of the returned value is set exactly when `op(lhs[i], rhs[i])` is `true`,
/// so lane 0 maps to the least significant bit and lane 7 to the most significant one.
#[inline]
pub(super) fn set<T: Copy, F: Fn(T, T) -> bool>(lhs: [T; 8], rhs: [T; 8], op: F) -> u8 {
    (0..8).fold(0u8, |mask, lane| {
        // `u8::from(bool)` is 0 or 1, so the shift contributes either nothing or bit `lane`.
        mask | (u8::from(op(lhs[lane], rhs[lane])) << lane)
    })
}

/// Implements [`Simd8`] and [`Simd8Lanes`] for a native type, using a plain
/// `[$native; 8]` array as the lane container and delegating every comparison to `set`.
///
/// The expansion refers to `set`, `Simd8`, `Simd8Lanes` and the `try_into` method
/// unqualified, so the invoking module must have them in scope (including
/// `std::convert::TryInto` where the prelude does not already provide it).
macro_rules! simd8_native {
    ($native:ty) => {
        impl Simd8 for $native {
            type Simd = [$native; 8];
        }

        impl Simd8Lanes<$native> for [$native; 8] {
            #[inline]
            fn from_chunk(v: &[$native]) -> Self {
                // The slice-to-array conversion fails (and this panics) unless `v`
                // holds exactly 8 elements.
                v.try_into().unwrap()
            }

            #[inline]
            fn from_incomplete_chunk(v: &[$native], remaining: $native) -> Self {
                // Start fully padded, then overwrite the leading lanes with the input;
                // anything past the eighth element of `v` is ignored.
                let mut lanes = [remaining; 8];
                let filled = v.len().min(8);
                lanes[..filled].copy_from_slice(&v[..filled]);
                lanes
            }

            #[inline]
            fn eq(self, other: Self) -> u8 {
                set(self, other, |a, b| a == b)
            }

            #[inline]
            fn neq(self, other: Self) -> u8 {
                // The macro is also instantiated for float types, where clippy flags
                // a direct `!=` comparison.
                #[allow(clippy::float_cmp)]
                set(self, other, |a, b| a != b)
            }

            #[inline]
            fn lt_eq(self, other: Self) -> u8 {
                set(self, other, |a, b| a <= b)
            }

            #[inline]
            fn lt(self, other: Self) -> u8 {
                set(self, other, |a, b| a < b)
            }

            #[inline]
            fn gt(self, other: Self) -> u8 {
                set(self, other, |a, b| a > b)
            }

            #[inline]
            fn gt_eq(self, other: Self) -> u8 {
                set(self, other, |a, b| a >= b)
            }
        }
    };
}

// Exactly one backend module is compiled: `native` when the `simd` feature is
// disabled, `packed` when it is enabled. Whichever is selected has its public
// items re-exported from this module.
#[cfg(not(feature = "simd"))]
mod native;
#[cfg(not(feature = "simd"))]
pub use native::*;
#[cfg(feature = "simd")]
mod packed;
#[cfg(feature = "simd")]
pub use packed::*;
15 changes: 15 additions & 0 deletions src/compute/comparison/simd/native.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use std::convert::TryInto;

use super::{set, Simd8, Simd8Lanes};

// Array-backed `Simd8` / `Simd8Lanes` implementations, grouped by type family.

// Unsigned integers.
simd8_native!(u8);
simd8_native!(u16);
simd8_native!(u32);
simd8_native!(u64);

// Signed integers.
simd8_native!(i8);
simd8_native!(i16);
simd8_native!(i32);
simd8_native!(i64);
simd8_native!(i128);

// Floating point.
simd8_native!(f32);
simd8_native!(f64);
Loading