Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

DRY of type check and len check code in compute #474

Merged
merged 1 commit into from
Sep 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 6 additions & 21 deletions src/compute/arithmetics/basic/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::Add;

use num_traits::{ops::overflowing::OverflowingAdd, CheckedAdd, SaturatingAdd, Zero};

use crate::compute::arithmetics::basic::check_same_type;
use crate::{
array::{Array, PrimitiveArray},
bitmap::Bitmap,
Expand All @@ -14,7 +15,7 @@ use crate::{
binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap,
},
},
error::{ArrowError, Result},
error::Result,
types::NativeType,
};

Expand All @@ -36,11 +37,7 @@ pub fn add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<Primit
where
T: NativeType + Add<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

binary(lhs, rhs, lhs.data_type().clone(), |a, b| a + b)
}
Expand All @@ -63,11 +60,7 @@ pub fn checked_add<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Resul
where
T: NativeType + CheckedAdd<Output = T> + Zero,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.checked_add(&b);

Expand Down Expand Up @@ -96,11 +89,7 @@ pub fn saturating_add<T>(
where
T: NativeType + SaturatingAdd<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.saturating_add(&b);

Expand Down Expand Up @@ -130,11 +119,7 @@ pub fn overflowing_add<T>(
where
T: NativeType + OverflowingAdd<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.overflowing_add(&b);

Expand Down
31 changes: 31 additions & 0 deletions src/compute/arithmetics/basic/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use crate::array::{Array, PrimitiveArray};
use crate::error::{ArrowError, Result};
use crate::types::NativeType;

// Checking if both arrays have the same type
#[inline]
pub fn check_same_type<L: NativeType, R: NativeType>(
jorgecarleitao marked this conversation as resolved.
Show resolved Hide resolved
lhs: &PrimitiveArray<L>,
rhs: &PrimitiveArray<R>,
) -> Result<()> {
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
Ok(())
}

// Checking if both arrays have the same length
#[inline]
pub fn check_same_len<L: NativeType, R: NativeType>(
jorgecarleitao marked this conversation as resolved.
Show resolved Hide resolved
lhs: &PrimitiveArray<L>,
rhs: &PrimitiveArray<R>,
) -> Result<()> {
if lhs.len() != rhs.len() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same length".to_string(),
));
}
Ok(())
}
21 changes: 5 additions & 16 deletions src/compute/arithmetics/basic/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ use std::ops::Div;

use num_traits::{CheckedDiv, NumCast, Zero};

use crate::compute::arithmetics::basic::{check_same_len, check_same_type};
use crate::datatypes::DataType;
use crate::{
array::{Array, PrimitiveArray},
compute::{
arithmetics::{ArrayCheckedDiv, ArrayDiv, NotI128},
arity::{binary, binary_checked, unary, unary_checked},
},
error::{ArrowError, Result},
error::Result,
types::NativeType,
};
use strength_reduce::{
Expand All @@ -35,20 +36,12 @@ pub fn div<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<Primit
where
T: NativeType + Div<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

if rhs.null_count() == 0 {
binary(lhs, rhs, lhs.data_type().clone(), |a, b| a / b)
} else {
if lhs.len() != rhs.len() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same length".to_string(),
));
}
check_same_len(lhs, rhs)?;
let values = lhs.iter().zip(rhs.iter()).map(|(l, r)| match (l, r) {
(Some(l), Some(r)) => Some(*l / *r),
_ => None,
Expand Down Expand Up @@ -77,11 +70,7 @@ pub fn checked_div<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Resul
where
T: NativeType + CheckedDiv<Output = T> + Zero,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.checked_div(&b);

Expand Down
3 changes: 3 additions & 0 deletions src/compute/arithmetics/basic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ mod rem;
pub use rem::*;
mod sub;
pub use sub::*;

mod common;
pub(crate) use common::*;
27 changes: 6 additions & 21 deletions src/compute/arithmetics/basic/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::Mul;

use num_traits::{ops::overflowing::OverflowingMul, CheckedMul, SaturatingMul, Zero};

use crate::compute::arithmetics::basic::check_same_type;
use crate::{
array::{Array, PrimitiveArray},
bitmap::Bitmap,
Expand All @@ -14,7 +15,7 @@ use crate::{
binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap,
},
},
error::{ArrowError, Result},
error::Result,
types::NativeType,
};

Expand All @@ -36,11 +37,7 @@ pub fn mul<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<Primit
where
T: NativeType + Mul<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

binary(lhs, rhs, lhs.data_type().clone(), |a, b| a * b)
}
Expand All @@ -64,11 +61,7 @@ pub fn checked_mul<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Resul
where
T: NativeType + CheckedMul<Output = T> + Zero,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.checked_mul(&b);

Expand Down Expand Up @@ -97,11 +90,7 @@ pub fn saturating_mul<T>(
where
T: NativeType + SaturatingMul<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.saturating_mul(&b);

Expand Down Expand Up @@ -131,11 +120,7 @@ pub fn overflowing_mul<T>(
where
T: NativeType + OverflowingMul<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.overflowing_mul(&b);

Expand Down
15 changes: 4 additions & 11 deletions src/compute/arithmetics/basic/rem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ use std::ops::Rem;

use num_traits::{CheckedRem, NumCast, Zero};

use crate::compute::arithmetics::basic::check_same_type;
use crate::datatypes::DataType;
use crate::{
array::{Array, PrimitiveArray},
compute::{
arithmetics::{ArrayCheckedRem, ArrayRem, NotI128},
arity::{binary, binary_checked, unary, unary_checked},
},
error::{ArrowError, Result},
error::Result,
types::NativeType,
};
use strength_reduce::{
Expand All @@ -34,11 +35,7 @@ pub fn rem<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<Primit
where
T: NativeType + Rem<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

binary(lhs, rhs, lhs.data_type().clone(), |a, b| a % b)
}
Expand All @@ -62,11 +59,7 @@ pub fn checked_rem<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Resul
where
T: NativeType + CheckedRem<Output = T> + Zero,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.checked_rem(&b);

Expand Down
27 changes: 6 additions & 21 deletions src/compute/arithmetics/basic/sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ops::Sub;

use num_traits::{ops::overflowing::OverflowingSub, CheckedSub, SaturatingSub, Zero};

use crate::compute::arithmetics::basic::check_same_type;
use crate::{
array::{Array, PrimitiveArray},
bitmap::Bitmap,
Expand All @@ -14,7 +15,7 @@ use crate::{
binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap,
},
},
error::{ArrowError, Result},
error::Result,
types::NativeType,
};

Expand All @@ -36,11 +37,7 @@ pub fn sub<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Result<Primit
where
T: NativeType + Sub<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

binary(lhs, rhs, lhs.data_type().clone(), |a, b| a - b)
}
Expand All @@ -63,11 +60,7 @@ pub fn checked_sub<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> Resul
where
T: NativeType + CheckedSub<Output = T> + Zero,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.checked_sub(&b);

Expand Down Expand Up @@ -96,11 +89,7 @@ pub fn saturating_sub<T>(
where
T: NativeType + SaturatingSub<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.saturating_sub(&b);

Expand Down Expand Up @@ -130,11 +119,7 @@ pub fn overflowing_sub<T>(
where
T: NativeType + OverflowingSub<Output = T>,
{
if lhs.data_type() != rhs.data_type() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same logical type".to_string(),
));
}
check_same_type(lhs, rhs)?;

let op = move |a: T, b: T| a.overflowing_sub(&b);

Expand Down
8 changes: 2 additions & 6 deletions src/compute/arithmetics/decimal/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

//! Defines the addition arithmetic kernels for Decimal `PrimitiveArrays`.
use crate::compute::arithmetics::basic::check_same_len;
use crate::{
array::{Array, PrimitiveArray},
buffer::Buffer,
Expand Down Expand Up @@ -253,12 +254,7 @@ pub fn adaptive_add(
lhs: &PrimitiveArray<i128>,
rhs: &PrimitiveArray<i128>,
) -> Result<PrimitiveArray<i128>> {
// Checking if both arrays have the same length
if lhs.len() != rhs.len() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same length".to_string(),
));
}
check_same_len(lhs, rhs)?;

if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) =
(lhs.data_type(), rhs.data_type())
Expand Down
8 changes: 2 additions & 6 deletions src/compute/arithmetics/decimal/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! Defines the division arithmetic kernels for Decimal
//! `PrimitiveArrays`.

use crate::compute::arithmetics::basic::check_same_len;
use crate::{
array::{Array, PrimitiveArray},
buffer::Buffer,
Expand Down Expand Up @@ -272,12 +273,7 @@ pub fn adaptive_div(
lhs: &PrimitiveArray<i128>,
rhs: &PrimitiveArray<i128>,
) -> Result<PrimitiveArray<i128>> {
// Checking if both arrays have the same length
if lhs.len() != rhs.len() {
return Err(ArrowError::InvalidArgumentError(
"Arrays must have the same length".to_string(),
));
}
check_same_len(lhs, rhs)?;

if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) =
(lhs.data_type(), rhs.data_type())
Expand Down
Loading