From 43c9e98e7e311111e8ef57426c3c28254dd94d15 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Mon, 13 Dec 2021 20:23:19 +0000 Subject: [PATCH] Fixed error in dispatching scalar arithmetics. --- src/compute/arithmetics/mod.rs | 105 ++++++++++++++++++++++++++-- tests/it/compute/arithmetics/mod.rs | 21 +++++- 2 files changed, 119 insertions(+), 7 deletions(-) diff --git a/src/compute/arithmetics/mod.rs b/src/compute/arithmetics/mod.rs index c7b5f7af39f..2dd1ae5a484 100644 --- a/src/compute/arithmetics/mod.rs +++ b/src/compute/arithmetics/mod.rs @@ -17,10 +17,10 @@ pub mod decimal; pub mod time; use crate::{ - array::Array, + array::{Array, PrimitiveArray}, bitmap::Bitmap, datatypes::{DataType, IntervalUnit, TimeUnit}, - scalar::Scalar, + scalar::{PrimitiveScalar, Scalar}, }; // Macro to evaluate match branch in arithmetic function. @@ -101,6 +101,99 @@ macro_rules! arith { }}; } +// Macro to evaluate match branch in arithmetic function. +macro_rules! primitive_scalar { + ($lhs:expr, $rhs:expr, $op:tt, $type:ty) => {{ + let lhs = $lhs + .as_any() + .downcast_ref::>() + .unwrap(); + let rhs = $rhs + .as_any() + .downcast_ref::>() + .unwrap(); + + let rhs = if let Some(rhs) = rhs.value() { + rhs + } else { + return Box::new(PrimitiveArray::<$type>::new_null( + lhs.data_type().clone(), + lhs.len(), + )) as Box; + }; + + let result = basic::$op::<$type>(lhs, &rhs); + Box::new(result) as Box + }}; +} + +// Macro to create a `match` statement with dynamic dispatch to functions based on +// the array's logical types +macro_rules! arith_scalar { + ($lhs:expr, $rhs:expr, $op:tt $(, decimal = $op_decimal:tt )? $(, duration = $op_duration:tt )? $(, interval = $op_interval:tt )? $(, timestamp = $op_timestamp:tt )?) => {{ + let lhs = $lhs; + let rhs = $rhs; + use DataType::*; + match (lhs.data_type(), rhs.data_type()) { + (Int8, Int8) => primitive_scalar!(lhs, rhs, $op, i8), + (Int16, Int16) => primitive_scalar!(lhs, rhs, $op, i16), + (Int32, Int32) => primitive_scalar!(lhs, rhs, $op, i32), + (Int64, Int64) | (Duration(_), Duration(_)) => { + primitive_scalar!(lhs, rhs, $op, i64) + } + (UInt8, UInt8) => primitive_scalar!(lhs, rhs, $op, u8), + (UInt16, UInt16) => primitive_scalar!(lhs, rhs, $op, u16), + (UInt32, UInt32) => primitive_scalar!(lhs, rhs, $op, u32), + (UInt64, UInt64) => primitive_scalar!(lhs, rhs, $op, u64), + (Float32, Float32) => primitive_scalar!(lhs, rhs, $op, f32), + (Float64, Float64) => primitive_scalar!(lhs, rhs, $op, f64), + $ ( + (Decimal(_, _), Decimal(_, _)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + Box::new(decimal::$op_decimal(lhs, rhs)) as Box + } + )? + $ ( + (Time32(TimeUnit::Second), Duration(_)) + | (Time32(TimeUnit::Millisecond), Duration(_)) + | (Date32, Duration(_)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + Box::new(time::$op_duration::(lhs, rhs)) as Box + } + (Time64(TimeUnit::Microsecond), Duration(_)) + | (Time64(TimeUnit::Nanosecond), Duration(_)) + | (Date64, Duration(_)) + | (Timestamp(_, _), Duration(_)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + Box::new(time::$op_duration::(lhs, rhs)) as Box + } + )? + $ ( + (Timestamp(_, _), Interval(IntervalUnit::MonthDayNano)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_interval(lhs, rhs).map(|x| Box::new(x) as Box).unwrap() + } + )? + $ ( + (Timestamp(_, None), Timestamp(_, None)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_timestamp(lhs, rhs).map(|x| Box::new(x) as Box).unwrap() + } + )? + _ => todo!( + "Addition of {:?} with {:?} is not supported", + lhs.data_type(), + rhs.data_type() + ), + } + }}; +} + /// Adds two [`Array`]s. /// # Panic /// This function panics iff @@ -124,7 +217,7 @@ pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> Box { /// * the arrays have a different length /// * one of the arrays is a timestamp with timezone and the timezone is not valid. pub fn add_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box { - arith!( + arith_scalar!( lhs, rhs, add_scalar, @@ -185,7 +278,7 @@ pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> Box { /// * the arrays have a different length /// * one of the arrays is a timestamp with timezone and the timezone is not valid. pub fn sub_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box { - arith!( + arith_scalar!( lhs, rhs, sub_scalar, @@ -236,7 +329,7 @@ pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> Box { /// This function panics iff /// * the opertion is not supported for the logical types (use [`can_mul`] to check) pub fn mul_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box { - arith!(lhs, rhs, mul_scalar, decimal = mul_scalar) + arith_scalar!(lhs, rhs, mul_scalar, decimal = mul_scalar) } /// Returns whether two [`DataType`]s can be multiplied by [`mul`]. @@ -272,7 +365,7 @@ pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> Box { /// This function panics iff /// * the opertion is not supported for the logical types (use [`can_div`] to check) pub fn div_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box { - arith!(lhs, rhs, div_scalar, decimal = div_scalar) + arith_scalar!(lhs, rhs, div_scalar, decimal = div_scalar) } /// Returns whether two [`DataType`]s can be divided by [`div`]. diff --git a/tests/it/compute/arithmetics/mod.rs b/tests/it/compute/arithmetics/mod.rs index 4b5e6a3e241..222f1c3d84c 100644 --- a/tests/it/compute/arithmetics/mod.rs +++ b/tests/it/compute/arithmetics/mod.rs @@ -2,10 +2,29 @@ mod basic; mod decimal; mod time; -use arrow2::array::new_empty_array; +use arrow2::array::{new_empty_array, Int32Array}; use arrow2::compute::arithmetics::*; use arrow2::datatypes::DataType::*; use arrow2::datatypes::{IntervalUnit, TimeUnit}; +use arrow2::scalar::PrimitiveScalar; + +#[test] +fn test_add() { + let a = Int32Array::from(&[None, Some(6), None, Some(6)]); + let b = Int32Array::from(&[Some(5), None, None, Some(6)]); + let result = add(&a, &b); + let expected = Int32Array::from(&[None, None, None, Some(12)]); + assert_eq!(expected, result.as_ref()); +} + +#[test] +fn test_add_scalar() { + let a = Int32Array::from(&[None, Some(6), None, Some(6)]); + let b: PrimitiveScalar = Some(1i32).into(); + let result = add_scalar(&a, &b); + let expected = Int32Array::from(&[None, Some(7), None, Some(7)]); + assert_eq!(expected, result.as_ref()); +} #[test] fn consistency() {