Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Fixed error in dispatching scalar arithmetics. (#682)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Dec 13, 2021
1 parent bfa2d43 commit 280ed1d
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 7 deletions.
105 changes: 99 additions & 6 deletions src/compute/arithmetics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ pub mod decimal;
pub mod time;

use crate::{
array::Array,
array::{Array, PrimitiveArray},
bitmap::Bitmap,
datatypes::{DataType, IntervalUnit, TimeUnit},
scalar::Scalar,
scalar::{PrimitiveScalar, Scalar},
};

// Macro to evaluate match branch in arithmetic function.
Expand Down Expand Up @@ -101,6 +101,99 @@ macro_rules! arith {
}};
}

// Macro to evaluate match branch in arithmetic function.
macro_rules! primitive_scalar {
($lhs:expr, $rhs:expr, $op:tt, $type:ty) => {{
let lhs = $lhs
.as_any()
.downcast_ref::<PrimitiveArray<$type>>()
.unwrap();
let rhs = $rhs
.as_any()
.downcast_ref::<PrimitiveScalar<$type>>()
.unwrap();

let rhs = if let Some(rhs) = rhs.value() {
rhs
} else {
return Box::new(PrimitiveArray::<$type>::new_null(
lhs.data_type().clone(),
lhs.len(),
)) as Box<dyn Array>;
};

let result = basic::$op::<$type>(lhs, &rhs);
Box::new(result) as Box<dyn Array>
}};
}

// Macro to create a `match` statement with dynamic dispatch to functions based on
// the array's logical types
macro_rules! arith_scalar {
($lhs:expr, $rhs:expr, $op:tt $(, decimal = $op_decimal:tt )? $(, duration = $op_duration:tt )? $(, interval = $op_interval:tt )? $(, timestamp = $op_timestamp:tt )?) => {{
let lhs = $lhs;
let rhs = $rhs;
use DataType::*;
match (lhs.data_type(), rhs.data_type()) {
(Int8, Int8) => primitive_scalar!(lhs, rhs, $op, i8),
(Int16, Int16) => primitive_scalar!(lhs, rhs, $op, i16),
(Int32, Int32) => primitive_scalar!(lhs, rhs, $op, i32),
(Int64, Int64) | (Duration(_), Duration(_)) => {
primitive_scalar!(lhs, rhs, $op, i64)
}
(UInt8, UInt8) => primitive_scalar!(lhs, rhs, $op, u8),
(UInt16, UInt16) => primitive_scalar!(lhs, rhs, $op, u16),
(UInt32, UInt32) => primitive_scalar!(lhs, rhs, $op, u32),
(UInt64, UInt64) => primitive_scalar!(lhs, rhs, $op, u64),
(Float32, Float32) => primitive_scalar!(lhs, rhs, $op, f32),
(Float64, Float64) => primitive_scalar!(lhs, rhs, $op, f64),
$ (
(Decimal(_, _), Decimal(_, _)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
Box::new(decimal::$op_decimal(lhs, rhs)) as Box<dyn Array>
}
)?
$ (
(Time32(TimeUnit::Second), Duration(_))
| (Time32(TimeUnit::Millisecond), Duration(_))
| (Date32, Duration(_)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
Box::new(time::$op_duration::<i32>(lhs, rhs)) as Box<dyn Array>
}
(Time64(TimeUnit::Microsecond), Duration(_))
| (Time64(TimeUnit::Nanosecond), Duration(_))
| (Date64, Duration(_))
| (Timestamp(_, _), Duration(_)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
Box::new(time::$op_duration::<i64>(lhs, rhs)) as Box<dyn Array>
}
)?
$ (
(Timestamp(_, _), Interval(IntervalUnit::MonthDayNano)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
time::$op_interval(lhs, rhs).map(|x| Box::new(x) as Box<dyn Array>).unwrap()
}
)?
$ (
(Timestamp(_, None), Timestamp(_, None)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
time::$op_timestamp(lhs, rhs).map(|x| Box::new(x) as Box<dyn Array>).unwrap()
}
)?
_ => todo!(
"Addition of {:?} with {:?} is not supported",
lhs.data_type(),
rhs.data_type()
),
}
}};
}

/// Adds two [`Array`]s.
/// # Panic
/// This function panics iff
Expand All @@ -124,7 +217,7 @@ pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> Box<dyn Array> {
/// * the arrays have a different length
/// * one of the arrays is a timestamp with timezone and the timezone is not valid.
pub fn add_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box<dyn Array> {
arith!(
arith_scalar!(
lhs,
rhs,
add_scalar,
Expand Down Expand Up @@ -185,7 +278,7 @@ pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> Box<dyn Array> {
/// * the arrays have a different length
/// * one of the arrays is a timestamp with timezone and the timezone is not valid.
pub fn sub_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box<dyn Array> {
arith!(
arith_scalar!(
lhs,
rhs,
sub_scalar,
Expand Down Expand Up @@ -236,7 +329,7 @@ pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> Box<dyn Array> {
/// This function panics iff
/// * the opertion is not supported for the logical types (use [`can_mul`] to check)
pub fn mul_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box<dyn Array> {
arith!(lhs, rhs, mul_scalar, decimal = mul_scalar)
arith_scalar!(lhs, rhs, mul_scalar, decimal = mul_scalar)
}

/// Returns whether two [`DataType`]s can be multiplied by [`mul`].
Expand Down Expand Up @@ -272,7 +365,7 @@ pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> Box<dyn Array> {
/// This function panics iff
/// * the opertion is not supported for the logical types (use [`can_div`] to check)
pub fn div_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box<dyn Array> {
arith!(lhs, rhs, div_scalar, decimal = div_scalar)
arith_scalar!(lhs, rhs, div_scalar, decimal = div_scalar)
}

/// Returns whether two [`DataType`]s can be divided by [`div`].
Expand Down
21 changes: 20 additions & 1 deletion tests/it/compute/arithmetics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,29 @@ mod basic;
mod decimal;
mod time;

use arrow2::array::new_empty_array;
use arrow2::array::{new_empty_array, Int32Array};
use arrow2::compute::arithmetics::*;
use arrow2::datatypes::DataType::*;
use arrow2::datatypes::{IntervalUnit, TimeUnit};
use arrow2::scalar::PrimitiveScalar;

#[test]
fn test_add() {
let a = Int32Array::from(&[None, Some(6), None, Some(6)]);
let b = Int32Array::from(&[Some(5), None, None, Some(6)]);
let result = add(&a, &b);
let expected = Int32Array::from(&[None, None, None, Some(12)]);
assert_eq!(expected, result.as_ref());
}

#[test]
fn test_add_scalar() {
let a = Int32Array::from(&[None, Some(6), None, Some(6)]);
let b: PrimitiveScalar<i32> = Some(1i32).into();
let result = add_scalar(&a, &b);
let expected = Int32Array::from(&[None, Some(7), None, Some(7)]);
assert_eq!(expected, result.as_ref());
}

#[test]
fn consistency() {
Expand Down

0 comments on commit 280ed1d

Please sign in to comment.