From 1192927556ff69a50b377217cc6c63b54389f73e Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Tue, 14 Dec 2021 20:03:16 +0000 Subject: [PATCH] Added dynamic version of neg. --- src/compute/arithmetics/mod.rs | 68 ++++++++++++++++++++++++++++- src/types/mod.rs | 20 ++++++++- tests/it/compute/arithmetics/mod.rs | 24 +++++++++- 3 files changed, 109 insertions(+), 3 deletions(-) diff --git a/src/compute/arithmetics/mod.rs b/src/compute/arithmetics/mod.rs index 2dd1ae5a484..b8e61c9fc27 100644 --- a/src/compute/arithmetics/mod.rs +++ b/src/compute/arithmetics/mod.rs @@ -17,7 +17,7 @@ pub mod decimal; pub mod time; use crate::{ - array::{Array, PrimitiveArray}, + array::{Array, DictionaryArray, PrimitiveArray}, bitmap::Bitmap, datatypes::{DataType, IntervalUnit, TimeUnit}, scalar::{PrimitiveScalar, Scalar}, @@ -400,6 +400,72 @@ pub fn can_rem(lhs: &DataType, rhs: &DataType) -> bool { ) } +macro_rules! with_match_negatable {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + use crate::types::{days_ms, months_days_ns}; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + DaysMs => __with_ty__! { days_ms }, + MonthDayNano => __with_ty__! { months_days_ns }, + UInt8 | UInt16 | UInt32 | UInt64=> todo!(), + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + } +})} + +/// Negates an [`Array`]. +/// # Panic +/// This function panics iff either +/// * the opertion is not supported for the logical type (use [`can_neg`] to check) +/// * the operation overflows +pub fn neg(array: &dyn Array) -> Box { + use crate::datatypes::PhysicalType::*; + match array.data_type().to_physical_type() { + Primitive(primitive) => with_match_negatable!(primitive, |$T| { + let array = array.as_any().downcast_ref().unwrap(); + + let result = basic::negate::<$T>(array); + Box::new(result) as Box + }), + Dictionary(key) => match_integer_type!(key, |$T| { + let array = array.as_any().downcast_ref::>().unwrap(); + + let values = neg(array.values().as_ref()).into(); + + Box::new(DictionaryArray::<$T>::from_data(array.keys().clone(), values)) as Box + }), + _ => todo!(), + } +} + +/// Whether [`neg`] is supported for a given [`DataType`] +pub fn can_neg(data_type: &DataType) -> bool { + if let DataType::Dictionary(_, values) = data_type.to_logical_type() { + return can_neg(values.as_ref()); + } + + use crate::datatypes::PhysicalType::*; + use crate::datatypes::PrimitiveType::*; + matches!( + data_type.to_physical_type(), + Primitive(Int8) + | Primitive(Int16) + | Primitive(Int32) + | Primitive(Int64) + | Primitive(Float64) + | Primitive(Float32) + | Primitive(DaysMs) + | Primitive(MonthDayNano) + ) +} + /// Defines basic addition operation for primitive arrays pub trait ArrayAdd: Sized { /// Adds itself to `rhs` diff --git a/src/types/mod.rs b/src/types/mod.rs index 3ea5afd6e97..c13c11249e9 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -7,7 +7,7 @@ //! represent chunks of bits (e.g. `u8`, `u16`), and [`BitChunkIter`], that can be used to //! iterate over bitmaps in [`BitChunk`]s. //! Finally, this module also contains traits used to compile code optimized for SIMD instructions at [`mod@simd`]. -use std::convert::TryFrom; +use std::{convert::TryFrom, ops::Neg}; mod bit_chunk; pub use bit_chunk::{BitChunk, BitChunkIter}; @@ -399,3 +399,21 @@ impl months_days_ns { self.2 } } + +impl Neg for days_ms { + type Output = Self; + + #[inline(always)] + fn neg(self) -> Self::Output { + Self([-self.0[0], -self.0[0]]) + } +} + +impl Neg for months_days_ns { + type Output = Self; + + #[inline(always)] + fn neg(self) -> Self::Output { + Self(-self.0, -self.1, -self.2) + } +} diff --git a/tests/it/compute/arithmetics/mod.rs b/tests/it/compute/arithmetics/mod.rs index 222f1c3d84c..ed84daa3998 100644 --- a/tests/it/compute/arithmetics/mod.rs +++ b/tests/it/compute/arithmetics/mod.rs @@ -2,7 +2,7 @@ mod basic; mod decimal; mod time; -use arrow2::array::{new_empty_array, Int32Array}; +use arrow2::array::*; use arrow2::compute::arithmetics::*; use arrow2::datatypes::DataType::*; use arrow2::datatypes::{IntervalUnit, TimeUnit}; @@ -84,3 +84,25 @@ fn consistency() { } }); } + +#[test] +fn test_neg() { + let a = Int32Array::from(&[None, Some(6), None, Some(6)]); + let result = neg(&a); + let expected = Int32Array::from(&[None, Some(-6), None, Some(-6)]); + assert_eq!(expected, result.as_ref()); +} + +#[test] +fn test_neg_dict() { + let a = DictionaryArray::::from_data( + UInt8Array::from_slice(&[0, 0, 1]), + std::sync::Arc::new(Int8Array::from_slice(&[1, 2])), + ); + let result = neg(&a); + let expected = DictionaryArray::::from_data( + UInt8Array::from_slice(&[0, 0, 1]), + std::sync::Arc::new(Int8Array::from_slice(&[-1, -2])), + ); + assert_eq!(expected, result.as_ref()); +}