Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Simplified arithmetics compute (#607)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao committed Nov 24, 2021
1 parent c979dbf commit a8af882
Show file tree
Hide file tree
Showing 32 changed files with 689 additions and 811 deletions.
2 changes: 1 addition & 1 deletion benches/arithmetic_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn bench_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>)
where
T: NativeType + Add<Output = T> + NumCast,
{
criterion::black_box(add(lhs, rhs)).unwrap();
criterion::black_box(add(lhs, rhs));
}

fn add_benchmark(c: &mut Criterion) {
Expand Down
2 changes: 1 addition & 1 deletion benches/comparison_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, Criterion};

use arrow2::scalar::*;
use arrow2::util::bench_util::*;
use arrow2::{compute::comparison::*, datatypes::DataType};
use arrow2::{compute::comparison::eq, datatypes::DataType};

fn add_benchmark(c: &mut Criterion) {
(10..=20).step_by(2).for_each(|log2_size| {
Expand Down
2 changes: 1 addition & 1 deletion benches/write_ipc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ fn write(array: &dyn Array) -> Result<()> {
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![clone(array).into()])?;

let writer = Cursor::new(vec![]);
let mut writer = FileWriter::try_new(writer, &schema)?;
let mut writer = FileWriter::try_new(writer, &schema, Default::default())?;

writer.write(&batch)
}
Expand Down
28 changes: 11 additions & 17 deletions examples/arithmetics.rs
Original file line number Diff line number Diff line change
@@ -1,48 +1,44 @@
use arrow2::array::{Array, PrimitiveArray};
use arrow2::compute::arithmetics::*;
use arrow2::compute::arithmetics::basic::*;
use arrow2::compute::arithmetics::{add as dyn_add, can_add};
use arrow2::compute::arity::{binary, unary};
use arrow2::datatypes::DataType;
use arrow2::error::Result;

fn main() -> Result<()> {
fn main() {
// say we have two arrays
let array0 = PrimitiveArray::<i64>::from(&[Some(1), Some(2), Some(3)]);
let array1 = PrimitiveArray::<i64>::from(&[Some(4), None, Some(6)]);

// we can add them as follows:
let added = arithmetic_primitive(&array0, Operator::Add, &array1)?;
let added = add(&array0, &array1);
assert_eq!(
added,
PrimitiveArray::<i64>::from(&[Some(5), None, Some(9)])
);

// subtract:
let subtracted = arithmetic_primitive(&array0, Operator::Subtract, &array1)?;
let subtracted = sub(&array0, &array1);
assert_eq!(
subtracted,
PrimitiveArray::<i64>::from(&[Some(-3), None, Some(-3)])
);

// add a scalar:
let plus10 = arithmetic_primitive_scalar(&array0, Operator::Add, &10)?;
let plus10 = add_scalar(&array0, &10);
assert_eq!(
plus10,
PrimitiveArray::<i64>::from(&[Some(11), Some(12), Some(13)])
);

// when the array is a trait object, there is a similar API
// a similar API for trait objects:
let array0 = &array0 as &dyn Array;
let array1 = &array1 as &dyn Array;

// check whether the logical types support addition (they could be any `Array`).
assert!(can_arithmetic(
array0.data_type(),
Operator::Add,
array1.data_type()
));
// check whether the logical types support addition.
assert!(can_add(array0.data_type(), array1.data_type()));

// add them
let added = arithmetic(array0, Operator::Add, array1).unwrap();
let added = dyn_add(array0, array1);
assert_eq!(
PrimitiveArray::<i64>::from(&[Some(5), None, Some(9)]),
added.as_ref(),
Expand All @@ -54,7 +50,7 @@ fn main() -> Result<()> {
let array1 = PrimitiveArray::<i64>::from(&[Some(4), None, Some(6)]);

let op = |x: i64, y: i64| x.pow(2) + y.pow(2);
let r = binary(&array0, &array1, DataType::Int64, op)?;
let r = binary(&array0, &array1, DataType::Int64, op);
assert_eq!(
r,
PrimitiveArray::<i64>::from(&[Some(1 + 16), None, Some(9 + 36)])
Expand All @@ -79,6 +75,4 @@ fn main() -> Result<()> {
rounded,
PrimitiveArray::<i64>::from(&[Some(4), None, Some(5)])
);

Ok(())
}
68 changes: 25 additions & 43 deletions src/compute/arithmetics/basic/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,18 @@ use std::ops::Add;

use num_traits::{ops::overflowing::OverflowingAdd, CheckedAdd, SaturatingAdd, WrappingAdd, Zero};

use crate::compute::arithmetics::basic::check_same_type;
use crate::compute::arithmetics::ArrayWrappingAdd;
use crate::{
array::{Array, PrimitiveArray},
bitmap::Bitmap,
compute::{
arithmetics::{
ArrayAdd, ArrayCheckedAdd, ArrayOverflowingAdd, ArraySaturatingAdd, NativeArithmetics,
ArrayAdd, ArrayCheckedAdd, ArrayOverflowingAdd, ArraySaturatingAdd, ArrayWrappingAdd,
NativeArithmetics,
},
arity::{
binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap,
},
},
error::Result,
types::NativeType,
};

Expand All @@ -30,16 +28,14 @@ use crate::{
///
/// let a = PrimitiveArray::from([None, Some(6), None, Some(6)]);
/// let b = PrimitiveArray::from([Some(5), None, None, Some(6)]);
/// let result = add(&a, &b).unwrap();
/// let result = add(&a, &b);
/// let expected = PrimitiveArray::from([None, None, None, Some(12)]);
/// assert_eq!(result, expected)
/// ```
pub fn add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<PrimitiveArray<T>>
pub fn add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
where
T: NativeType + Add<Output = T>,
{
check_same_type(lhs, rhs)?;

binary(lhs, rhs, lhs.data_type().clone(), |a, b| a + b)
}

Expand All @@ -53,19 +49,14 @@ where
///
/// let a = PrimitiveArray::from([Some(-100i8), Some(100i8), Some(100i8)]);
/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]);
/// let result = wrapping_add(&a, &b).unwrap();
/// let result = wrapping_add(&a, &b);
/// let expected = PrimitiveArray::from([Some(-100i8), Some(-56i8), Some(100i8)]);
/// assert_eq!(result, expected);
/// ```
pub fn wrapping_add<T>(
lhs: &PrimitiveArray<T>,
rhs: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
pub fn wrapping_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
where
T: NativeType + WrappingAdd<Output = T>,
{
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.wrapping_add(&b);

binary(lhs, rhs, lhs.data_type().clone(), op)
Expand All @@ -81,16 +72,14 @@ where
///
/// let a = PrimitiveArray::from([Some(100i8), Some(100i8), Some(100i8)]);
/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]);
/// let result = checked_add(&a, &b).unwrap();
/// let result = checked_add(&a, &b);
/// let expected = PrimitiveArray::from([Some(100i8), None, Some(100i8)]);
/// assert_eq!(result, expected);
/// ```
pub fn checked_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<PrimitiveArray<T>>
pub fn checked_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
where
T: NativeType + CheckedAdd<Output = T>,
{
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.checked_add(&b);

binary_checked(lhs, rhs, lhs.data_type().clone(), op)
Expand All @@ -107,19 +96,14 @@ where
///
/// let a = PrimitiveArray::from([Some(100i8)]);
/// let b = PrimitiveArray::from([Some(100i8)]);
/// let result = saturating_add(&a, &b).unwrap();
/// let result = saturating_add(&a, &b);
/// let expected = PrimitiveArray::from([Some(127)]);
/// assert_eq!(result, expected);
/// ```
pub fn saturating_add<T>(
lhs: &PrimitiveArray<T>,
rhs: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>>
pub fn saturating_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> PrimitiveArray<T>
where
T: NativeType + SaturatingAdd<Output = T>,
{
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.saturating_add(&b);

binary(lhs, rhs, lhs.data_type().clone(), op)
Expand All @@ -137,19 +121,17 @@ where
///
/// let a = PrimitiveArray::from([Some(1i8), Some(100i8)]);
/// let b = PrimitiveArray::from([Some(1i8), Some(100i8)]);
/// let (result, overflow) = overflowing_add(&a, &b).unwrap();
/// let (result, overflow) = overflowing_add(&a, &b);
/// let expected = PrimitiveArray::from([Some(2i8), Some(-56i8)]);
/// assert_eq!(result, expected);
/// ```
pub fn overflowing_add<T>(
lhs: &PrimitiveArray<T>,
rhs: &PrimitiveArray<T>,
) -> Result<(PrimitiveArray<T>, Bitmap)>
) -> (PrimitiveArray<T>, Bitmap)
where
T: NativeType + OverflowingAdd<Output = T>,
{
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.overflowing_add(&b);

binary_with_bitmap(lhs, rhs, lhs.data_type().clone(), op)
Expand All @@ -160,7 +142,7 @@ impl<T> ArrayAdd<PrimitiveArray<T>> for PrimitiveArray<T>
where
T: NativeArithmetics + Add<Output = T>,
{
fn add(&self, rhs: &PrimitiveArray<T>) -> Result<Self> {
fn add(&self, rhs: &PrimitiveArray<T>) -> Self {
add(self, rhs)
}
}
Expand All @@ -169,7 +151,7 @@ impl<T> ArrayWrappingAdd<PrimitiveArray<T>> for PrimitiveArray<T>
where
T: NativeArithmetics + WrappingAdd<Output = T>,
{
fn wrapping_add(&self, rhs: &PrimitiveArray<T>) -> Result<Self> {
fn wrapping_add(&self, rhs: &PrimitiveArray<T>) -> Self {
wrapping_add(self, rhs)
}
}
Expand All @@ -179,7 +161,7 @@ impl<T> ArrayCheckedAdd<PrimitiveArray<T>> for PrimitiveArray<T>
where
T: NativeArithmetics + CheckedAdd<Output = T>,
{
fn checked_add(&self, rhs: &PrimitiveArray<T>) -> Result<Self> {
fn checked_add(&self, rhs: &PrimitiveArray<T>) -> Self {
checked_add(self, rhs)
}
}
Expand All @@ -189,7 +171,7 @@ impl<T> ArraySaturatingAdd<PrimitiveArray<T>> for PrimitiveArray<T>
where
T: NativeArithmetics + SaturatingAdd<Output = T>,
{
fn saturating_add(&self, rhs: &PrimitiveArray<T>) -> Result<Self> {
fn saturating_add(&self, rhs: &PrimitiveArray<T>) -> Self {
saturating_add(self, rhs)
}
}
Expand All @@ -199,7 +181,7 @@ impl<T> ArrayOverflowingAdd<PrimitiveArray<T>> for PrimitiveArray<T>
where
T: NativeArithmetics + OverflowingAdd<Output = T>,
{
fn overflowing_add(&self, rhs: &PrimitiveArray<T>) -> Result<(Self, Bitmap)> {
fn overflowing_add(&self, rhs: &PrimitiveArray<T>) -> (Self, Bitmap) {
overflowing_add(self, rhs)
}
}
Expand Down Expand Up @@ -323,8 +305,8 @@ impl<T> ArrayAdd<T> for PrimitiveArray<T>
where
T: NativeArithmetics + Add<Output = T>,
{
fn add(&self, rhs: &T) -> Result<Self> {
Ok(add_scalar(self, rhs))
fn add(&self, rhs: &T) -> Self {
add_scalar(self, rhs)
}
}

Expand All @@ -333,8 +315,8 @@ impl<T> ArrayCheckedAdd<T> for PrimitiveArray<T>
where
T: NativeArithmetics + CheckedAdd<Output = T> + Zero,
{
fn checked_add(&self, rhs: &T) -> Result<Self> {
Ok(checked_add_scalar(self, rhs))
fn checked_add(&self, rhs: &T) -> Self {
checked_add_scalar(self, rhs)
}
}

Expand All @@ -343,8 +325,8 @@ impl<T> ArraySaturatingAdd<T> for PrimitiveArray<T>
where
T: NativeArithmetics + SaturatingAdd<Output = T>,
{
fn saturating_add(&self, rhs: &T) -> Result<Self> {
Ok(saturating_add_scalar(self, rhs))
fn saturating_add(&self, rhs: &T) -> Self {
saturating_add_scalar(self, rhs)
}
}

Expand All @@ -353,7 +335,7 @@ impl<T> ArrayOverflowingAdd<T> for PrimitiveArray<T>
where
T: NativeArithmetics + OverflowingAdd<Output = T>,
{
fn overflowing_add(&self, rhs: &T) -> Result<(Self, Bitmap)> {
Ok(overflowing_add_scalar(self, rhs))
fn overflowing_add(&self, rhs: &T) -> (Self, Bitmap) {
overflowing_add_scalar(self, rhs)
}
}
Loading

0 comments on commit a8af882

Please sign in to comment.