diff --git a/src/compute/boolean.rs b/src/compute/boolean.rs
index 46d72669541..8a17eeaba36 100644
--- a/src/compute/boolean.rs
+++ b/src/compute/boolean.rs
@@ -6,17 +6,25 @@ use crate::scalar::BooleanScalar;
 
 use super::utils::combine_validities;
 
-/// Helper function to implement binary kernels
-fn binary_boolean_kernel<F>(lhs: &BooleanArray, rhs: &BooleanArray, op: F) -> BooleanArray
-where
-    F: Fn(&Bitmap, &Bitmap) -> Bitmap,
-{
+/// Panics with a descriptive message iff `lhs` and `rhs` differ in length.
+fn assert_lengths(lhs: &BooleanArray, rhs: &BooleanArray) {
     assert_eq!(
         lhs.len(),
         rhs.len(),
         "lhs and rhs must have the same length"
     );
+}
 
+/// Helper function to implement binary kernels
+pub(crate) fn binary_boolean_kernel<F>(
+    lhs: &BooleanArray,
+    rhs: &BooleanArray,
+    op: F,
+) -> BooleanArray
+where
+    F: Fn(&Bitmap, &Bitmap) -> Bitmap,
+{
+    assert_lengths(lhs, rhs);
     let validity = combine_validities(lhs.validity(), rhs.validity());
 
     let left_buffer = lhs.values();
@@ -41,6 +49,31 @@ where
 /// assert_eq!(and_ab, BooleanArray::from(&[Some(false), Some(true), None]));
 /// ```
 pub fn and(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray {
+    if lhs.null_count() == 0 && rhs.null_count() == 0 {
+        let left_buffer = lhs.values();
+        let right_buffer = rhs.values();
+
+        match (left_buffer.unset_bits(), right_buffer.unset_bits()) {
+            // all values are `true` on both sides
+            (0, 0) => {
+                assert_lengths(lhs, rhs);
+                return lhs.clone();
+            }
+            // all values are `false` on left side
+            (l, _) if l == lhs.len() => {
+                assert_lengths(lhs, rhs);
+                return lhs.clone();
+            }
+            // all values are `false` on right side
+            (_, r) if r == rhs.len() => {
+                assert_lengths(lhs, rhs);
+                return rhs.clone();
+            }
+            // ignore the rest
+            _ => {}
+        }
+    }
+
     binary_boolean_kernel(lhs, rhs, |lhs, rhs| lhs & rhs)
 }
 
@@ -58,6 +91,31 @@ pub fn and(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray {
 /// assert_eq!(or_ab, BooleanArray::from(vec![Some(true), Some(true), None]));
 /// ```
 pub fn or(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray {
+    if lhs.null_count() == 0 && rhs.null_count() == 0 {
+        let left_buffer = lhs.values();
+        let right_buffer = rhs.values();
+
+        match (left_buffer.unset_bits(), right_buffer.unset_bits()) {
+            // all values are `true` on left side
+            (0, _) => {
+                assert_lengths(lhs, rhs);
+                return lhs.clone();
+            }
+            // all values are `true` on right side
+            (_, 0) => {
+                assert_lengths(lhs, rhs);
+                return rhs.clone();
+            }
+            // all values on lhs and rhs are `false`
+            (l, r) if l == lhs.len() && r == rhs.len() => {
+                assert_lengths(lhs, rhs);
+                return rhs.clone();
+            }
+            // ignore the rest
+            _ => {}
+        }
+    }
+
     binary_boolean_kernel(lhs, rhs, |lhs, rhs| lhs | rhs)
 }
 
diff --git a/src/compute/comparison/mod.rs b/src/compute/comparison/mod.rs
index 4ad7920705e..ba6a84e49b8 100644
--- a/src/compute/comparison/mod.rs
+++ b/src/compute/comparison/mod.rs
@@ -57,7 +57,7 @@ mod simd;
 pub use simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd};
 
 use super::take::take_boolean;
-use crate::bitmap::Bitmap;
+use crate::bitmap::{binary, Bitmap};
 use crate::compute;
 pub(crate) use primitive::{
     compare_values_op as primitive_compare_values_op,
@@ -521,10 +521,43 @@ fn finish_eq_validities(
             &BooleanArray::new(DataType::Boolean, rhs, None),
         ),
         (Some(lhs), Some(rhs)) => {
+            let lhs_validity_unset_bits = lhs.unset_bits();
+            let rhs_validity_unset_bits = rhs.unset_bits();
+
+            // this branch is a bit more complicated as both arrays can have masked out values
+            // these masked out values might differ and lead to a `eq == false` that has to
+            // be corrected as both should be `null == null = true`
+
             let lhs = BooleanArray::new(DataType::Boolean, lhs, None);
             let rhs = BooleanArray::new(DataType::Boolean, rhs, None);
             let eq_validities = compute::comparison::boolean::eq(&lhs, &rhs);
-            compute::boolean::and(&output_without_validities, &eq_validities)
+
+            // validity_bits are equal AND values are equal
+            let equal = compute::boolean::and(&output_without_validities, &eq_validities);
+
+            match (lhs_validity_unset_bits, rhs_validity_unset_bits) {
+                // there is at least one side with all values valid
+                // so we don't have to correct.
+                (0, _) | (_, 0) => equal,
+                _ => {
+                    // we use the binary kernel here to save allocations
+                    // and apply `!(lhs | rhs)` in one step
+                    let both_sides_invalid =
+                        compute::boolean::binary_boolean_kernel(&lhs, &rhs, |lhs, rhs| {
+                            binary(lhs, rhs, |lhs, rhs| !(lhs | rhs))
+                        });
+                    // this still might include incorrect masked out values
+                    // under the validity bits, so we must correct for that
+
+                    // if not all false, i.e. at least one slot is null on both sides,
+                    // then we propagate that null as `true` in equality
+                    if both_sides_invalid.values().unset_bits() != both_sides_invalid.len() {
+                        compute::boolean::or(&equal, &both_sides_invalid)
+                    } else {
+                        equal
+                    }
+                }
+            }
         }
     }
 }
@@ -547,10 +580,48 @@ fn finish_neq_validities(
             compute::boolean::or(&output_without_validities, &rhs_negated)
         }
         (Some(lhs), Some(rhs)) => {
+            let lhs_validity_unset_bits = lhs.unset_bits();
+            let rhs_validity_unset_bits = rhs.unset_bits();
+
+            // this branch is a bit more complicated as both arrays can have masked out values
+            // these masked out values might differ and lead to a `neq == true` that has to
+            // be corrected as both should be `null != null = false`
             let lhs = BooleanArray::new(DataType::Boolean, lhs, None);
             let rhs = BooleanArray::new(DataType::Boolean, rhs, None);
             let neq_validities = compute::comparison::boolean::neq(&lhs, &rhs);
-            compute::boolean::or(&output_without_validities, &neq_validities)
+
+            // validity_bits are not equal OR values not equal
+            let or = compute::boolean::or(&output_without_validities, &neq_validities);
+
+            match (lhs_validity_unset_bits, rhs_validity_unset_bits) {
+                // there is at least one side with all values valid
+                // so we don't have to correct.
+                (0, _) | (_, 0) => or,
+                _ => {
+                    // we use the binary kernel here to save allocations
+                    // and apply `!(lhs | rhs)` in one step
+                    let both_sides_invalid =
+                        compute::boolean::binary_boolean_kernel(&lhs, &rhs, |lhs, rhs| {
+                            binary(lhs, rhs, |lhs, rhs| !(lhs | rhs))
+                        });
+                    // this still might include incorrect masked out values
+                    // under the validity bits, so we must correct for that
+
+                    // if not all false, i.e. at least one slot is null on both sides,
+                    // then we propagate that null as `false` as the nulls are equal
+                    if both_sides_invalid.values().unset_bits() != both_sides_invalid.len() {
+                        // clear exactly the slots where both sides are null while keeping
+                        // every other result: `or & !both_sides_invalid` in one shot.
+                        compute::boolean::binary_boolean_kernel(
+                            &or,
+                            &both_sides_invalid,
+                            |lhs, rhs| binary(lhs, rhs, |lhs, rhs| lhs & !rhs),
+                        )
+                    } else {
+                        or
+                    }
+                }
+            }
         }
     }
 }
diff --git a/tests/it/compute/comparison.rs b/tests/it/compute/comparison.rs
index 09f3fd61458..b077e996100 100644
--- a/tests/it/compute/comparison.rs
+++ b/tests/it/compute/comparison.rs
@@ -1,7 +1,7 @@
 use arrow2::array::*;
 use arrow2::bitmap::Bitmap;
-use arrow2::compute::comparison::{self, boolean::*, primitive};
-use arrow2::datatypes::{DataType::*, IntegerType, IntervalUnit, TimeUnit};
+use arrow2::compute::comparison::{self, boolean::*, primitive, utf8};
+use arrow2::datatypes::{DataType, DataType::*, IntegerType, IntervalUnit, TimeUnit};
 use arrow2::scalar::new_scalar;
 
 #[test]
@@ -381,3 +381,31 @@ fn primitive_gt_eq() {
         BooleanArray::from([Some(true), Some(false), Some(true), None, None])
     )
 }
+
+#[test]
+#[cfg(any(feature = "compute_cast", feature = "compute_boolean_kleene"))]
+fn utf8_and_validity() {
+    use arrow2::compute::cast::CastOptions;
+    let a1 = Utf8Array::<i32>::from([Some("0"), Some("1"), None, Some("2")]);
+    let a2 = Int32Array::from([Some(0), Some(1), None, Some(2)]);
+
+    // due to the cast the values underneath the validity bits differ
+    let a2 = arrow2::compute::cast::cast(&a2, &DataType::Utf8, CastOptions::default()).unwrap();
+    let a2 = a2.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();
+
+    let expected = BooleanArray::from_slice([true, true, true, true]);
+    assert_eq!(utf8::eq_and_validity(&a1, &a1), expected);
+    assert_eq!(utf8::eq_and_validity(&a1, a2), expected);
+
+    let expected = BooleanArray::from_slice([false, false, false, false]);
+    assert_eq!(utf8::neq_and_validity(&a1, &a1), expected);
+    assert_eq!(utf8::neq_and_validity(&a1, a2), expected);
+
+    // regression check: both sides contain nulls, but valid values that differ
+    // (or a null paired with a valid value) must still compare as not equal
+    let a3 = Utf8Array::<i32>::from([Some("0"), Some("x"), None, None]);
+    let expected = BooleanArray::from_slice([true, false, true, false]);
+    assert_eq!(utf8::eq_and_validity(&a1, &a3), expected);
+    let expected = BooleanArray::from_slice([false, true, false, true]);
+    assert_eq!(utf8::neq_and_validity(&a1, &a3), expected);
+}