Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added compare_scalar.
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Aug 22, 2021
1 parent 852a1a3 commit b996f27
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 37 deletions.
28 changes: 18 additions & 10 deletions src/compute/comparison/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use crate::array::*;
use crate::bitmap::Bitmap;
use crate::scalar::{BooleanScalar, Scalar};
use crate::{
bitmap::MutableBitmap,
error::{ArrowError, Result},
Expand Down Expand Up @@ -56,14 +57,14 @@ where

/// Evaluate `op(left, right)` for [`BooleanArray`] and scalar using
/// a specified comparison function.
pub fn compare_op_scalar<F>(lhs: &BooleanArray, rhs: bool, op: F) -> Result<BooleanArray>
pub fn compare_op_scalar<F>(lhs: &BooleanArray, rhs: bool, op: F) -> BooleanArray
where
F: Fn(bool, bool) -> bool,
{
let lhs_iter = lhs.values().iter();

let values = Bitmap::from_trusted_len_iter(lhs_iter.map(|x| op(x, rhs)));
Ok(BooleanArray::from_data(values, lhs.validity().clone()))
BooleanArray::from_data(values, lhs.validity().clone())
}

/// Perform `lhs == rhs` operation on two arrays.
Expand All @@ -72,7 +73,7 @@ pub fn eq(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {
}

/// Perform `left == right` operation on an array and a scalar value.
pub fn eq_scalar(lhs: &BooleanArray, rhs: bool) -> Result<BooleanArray> {
pub fn eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
compare_op_scalar(lhs, rhs, |a, b| a == b)
}

Expand All @@ -82,7 +83,7 @@ pub fn neq(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {
}

/// Perform `left != right` operation on an array and a scalar value.
pub fn neq_scalar(lhs: &BooleanArray, rhs: bool) -> Result<BooleanArray> {
pub fn neq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
compare_op_scalar(lhs, rhs, |a, b| a != b)
}

Expand All @@ -92,7 +93,7 @@ pub fn lt(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {
}

/// Perform `left < right` operation on an array and a scalar value.
pub fn lt_scalar(lhs: &BooleanArray, rhs: bool) -> Result<BooleanArray> {
pub fn lt_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
compare_op_scalar(lhs, rhs, |a, b| !a & b)
}

Expand All @@ -103,7 +104,7 @@ pub fn lt_eq(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {

/// Perform `left <= right` operation on an array and a scalar value.
/// Null values are less than non-null values.
pub fn lt_eq_scalar(lhs: &BooleanArray, rhs: bool) -> Result<BooleanArray> {
pub fn lt_eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
compare_op_scalar(lhs, rhs, |a, b| a <= b)
}

Expand All @@ -115,7 +116,7 @@ pub fn gt(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {

/// Perform `left > right` operation on an array and a scalar value.
/// Non-null values are greater than null values.
pub fn gt_scalar(lhs: &BooleanArray, rhs: bool) -> Result<BooleanArray> {
pub fn gt_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
compare_op_scalar(lhs, rhs, |a, b| a & !b)
}

Expand All @@ -127,7 +128,7 @@ pub fn gt_eq(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {

/// Perform `left >= right` operation on an array and a scalar value.
/// Non-null values are greater than null values.
pub fn gt_eq_scalar(lhs: &BooleanArray, rhs: bool) -> Result<BooleanArray> {
pub fn gt_eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray {
compare_op_scalar(lhs, rhs, |a, b| a >= b)
}

Expand All @@ -142,7 +143,14 @@ pub fn compare(lhs: &BooleanArray, rhs: &BooleanArray, op: Operator) -> Result<B
}
}

pub fn compare_scalar(lhs: &BooleanArray, rhs: bool, op: Operator) -> Result<BooleanArray> {
/// Compares every slot of a [`BooleanArray`] against a [`BooleanScalar`] using `op`.
///
/// When `rhs` is not valid (i.e. it is a null scalar) the result is an
/// all-null array with the same length as `lhs`; otherwise the comparison is
/// delegated to [`compare_scalar_non_null`] with the scalar's inner value.
pub fn compare_scalar(lhs: &BooleanArray, rhs: &BooleanScalar, op: Operator) -> BooleanArray {
    if rhs.is_valid() {
        compare_scalar_non_null(lhs, rhs.value(), op)
    } else {
        // Comparing against a null scalar yields null in every slot.
        BooleanArray::new_null(lhs.len())
    }
}

pub fn compare_scalar_non_null(lhs: &BooleanArray, rhs: bool, op: Operator) -> BooleanArray {
match op {
Operator::Eq => eq_scalar(lhs, rhs),
Operator::Neq => neq_scalar(lhs, rhs),
Expand Down Expand Up @@ -180,7 +188,7 @@ mod tests {
macro_rules! cmp_bool_scalar {
($KERNEL:ident, $A_VEC:expr, $B:literal, $EXPECTED:expr) => {
let a = BooleanArray::from_slice($A_VEC);
let c = $KERNEL(&a, $B).unwrap();
let c = $KERNEL(&a, $B);
assert_eq!(BooleanArray::from_slice($EXPECTED), c);
};
}
Expand Down
132 changes: 121 additions & 11 deletions src/compute/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,12 @@
// specific language governing permissions and limitations
// under the License.

//! Defines basic comparison kernels for [`PrimitiveArray`]s.
//!
//! These kernels can leverage SIMD if available on your system. Currently no runtime
//! detection is provided, you should enable the specific SIMD intrinsics using
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
//! Defines basic comparison kernels.

use crate::array::*;
use crate::datatypes::{DataType, IntervalUnit};
use crate::error::{ArrowError, Result};
use crate::scalar::Scalar;

mod boolean;
mod primitive;
Expand All @@ -46,7 +42,7 @@ pub enum Operator {
pub fn compare(lhs: &dyn Array, rhs: &dyn Array, operator: Operator) -> Result<BooleanArray> {
let data_type = lhs.data_type();
if data_type != rhs.data_type() {
return Err(ArrowError::NotYetImplemented(
return Err(ArrowError::InvalidArgumentError(
"Comparison is only supported for arrays of the same logical type".to_string(),
));
}
Expand Down Expand Up @@ -136,10 +132,109 @@ pub fn compare(lhs: &dyn Array, rhs: &dyn Array, operator: Operator) -> Result<B
}
}

pub use boolean::compare_scalar as boolean_compare_scalar;
pub use primitive::compare_scalar as primitive_compare_scalar;
/// Compares each slot of `lhs` against the scalar `rhs` using `operator`,
/// dispatching on the logical type of `lhs` to the typed kernel.
///
/// # Errors
/// * [`ArrowError::InvalidArgumentError`] if `lhs` and `rhs` do not have the
///   same logical [`DataType`].
/// * [`ArrowError::NotYetImplemented`] if the logical type has no comparison
///   kernel (this includes `Float16`).
///
/// # Panics
/// Panics if the concrete type behind `lhs` or `rhs` does not match what its
/// reported [`DataType`] implies, since the downcast is unwrapped.
// NOTE(review): the downcast targets are inferred from the kernels' signatures,
// which are defined outside this function — confirm they match each `DataType`.
pub fn compare_scalar(
    lhs: &dyn Array,
    rhs: &dyn Scalar,
    operator: Operator,
) -> Result<BooleanArray> {
    let data_type = lhs.data_type();
    if data_type != rhs.data_type() {
        return Err(ArrowError::InvalidArgumentError(
            "Comparison is only supported for the same logical type".to_string(),
        ));
    }

    // Downcasts both operands to the concrete types expected by `$kernel`
    // (inferred from the kernel's signature) and invokes it.
    macro_rules! dispatch {
        ($kernel:expr) => {{
            let array = lhs.as_any().downcast_ref().unwrap();
            let scalar = rhs.as_any().downcast_ref().unwrap();
            $kernel(array, scalar, operator)
        }};
    }

    Ok(match data_type {
        DataType::Boolean => dispatch!(boolean::compare_scalar),
        DataType::Int8 => dispatch!(primitive::compare_scalar::<i8>),
        DataType::Int16 => dispatch!(primitive::compare_scalar::<i16>),
        DataType::Int32
        | DataType::Date32
        | DataType::Time32(_)
        | DataType::Interval(IntervalUnit::YearMonth) => {
            dispatch!(primitive::compare_scalar::<i32>)
        }
        DataType::Int64
        | DataType::Timestamp(_, None)
        | DataType::Date64
        | DataType::Time64(_)
        | DataType::Duration(_) => dispatch!(primitive::compare_scalar::<i64>),
        DataType::UInt8 => dispatch!(primitive::compare_scalar::<u8>),
        DataType::UInt16 => dispatch!(primitive::compare_scalar::<u16>),
        DataType::UInt32 => dispatch!(primitive::compare_scalar::<u32>),
        DataType::UInt64 => dispatch!(primitive::compare_scalar::<u64>),
        DataType::Float32 => dispatch!(primitive::compare_scalar::<f32>),
        DataType::Float64 => dispatch!(primitive::compare_scalar::<f64>),
        DataType::Decimal(_, _) => dispatch!(primitive::compare_scalar::<i128>),
        DataType::Utf8 => dispatch!(utf8::compare_scalar::<i32>),
        DataType::LargeUtf8 => dispatch!(utf8::compare_scalar::<i64>),
        // Every other type, including `Float16`, has no kernel. Report it as an
        // error instead of panicking: the data type comes from the caller.
        _ => {
            return Err(ArrowError::NotYetImplemented(format!(
                "Comparison between {:?} is not supported",
                data_type
            )))
        }
    })
}

pub use boolean::compare_scalar_non_null as boolean_compare_scalar;
pub use primitive::compare_scalar_non_null as primitive_compare_scalar;
pub(crate) use primitive::compare_values_op as primitive_compare_values_op;
pub use utf8::compare_scalar as utf8_compare_scalar;
pub use utf8::compare_scalar_non_null as utf8_compare_scalar;

/// Checks if an array of type `datatype` can be compared with another array of
/// the same type.
Expand Down Expand Up @@ -184,6 +279,8 @@ pub fn can_compare(data_type: &DataType) -> bool {

#[cfg(test)]
mod tests {
use crate::scalar::new_scalar;

use super::*;

#[test]
Expand Down Expand Up @@ -225,7 +322,8 @@ mod tests {
Duration(TimeUnit::Nanosecond),
];

datatypes.into_iter().for_each(|d1| {
// array <> array
datatypes.clone().into_iter().for_each(|d1| {
let array = new_null_array(d1.clone(), 10);
let op = Operator::Eq;
if can_compare(&d1) {
Expand All @@ -234,5 +332,17 @@ mod tests {
assert!(compare(array.as_ref(), array.as_ref(), op).is_err());
}
});

// array <> scalar
datatypes.into_iter().for_each(|d1| {
let array = new_null_array(d1.clone(), 10);
let scalar = new_scalar(array.as_ref(), 0);
let op = Operator::Eq;
if can_compare(&d1) {
assert!(compare_scalar(array.as_ref(), scalar.as_ref(), op).is_ok());
} else {
assert!(compare_scalar(array.as_ref(), scalar.as_ref(), op).is_err());
}
});
}
}
Loading

0 comments on commit b996f27

Please sign in to comment.