Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
fixed comparisson and validity kernels (#1243)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 5, 2022
1 parent a7428e0 commit 6c102a0
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 10 deletions.
67 changes: 62 additions & 5 deletions src/compute/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,24 @@ use crate::scalar::BooleanScalar;

use super::utils::combine_validities;

/// Helper function to implement binary kernels
fn binary_boolean_kernel<F>(lhs: &BooleanArray, rhs: &BooleanArray, op: F) -> BooleanArray
where
F: Fn(&Bitmap, &Bitmap) -> Bitmap,
{
fn assert_lengths(lhs: &BooleanArray, rhs: &BooleanArray) {
assert_eq!(
lhs.len(),
rhs.len(),
"lhs and rhs must have the same length"
);
}

/// Helper function to implement binary kernels
pub(crate) fn binary_boolean_kernel<F>(
lhs: &BooleanArray,
rhs: &BooleanArray,
op: F,
) -> BooleanArray
where
F: Fn(&Bitmap, &Bitmap) -> Bitmap,
{
assert_lengths(lhs, rhs);
let validity = combine_validities(lhs.validity(), rhs.validity());

let left_buffer = lhs.values();
Expand All @@ -41,6 +48,31 @@ where
/// assert_eq!(and_ab, BooleanArray::from(&[Some(false), Some(true), None]));
/// ```
pub fn and(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray {
if lhs.null_count() == 0 && rhs.null_count() == 0 {
let left_buffer = lhs.values();
let right_buffer = rhs.values();

match (left_buffer.unset_bits(), right_buffer.unset_bits()) {
// all values are `true` on both sides
(0, 0) => {
assert_lengths(lhs, rhs);
return lhs.clone();
}
// all values are `false` on left side
(l, _) if l == lhs.len() => {
assert_lengths(lhs, rhs);
return lhs.clone();
}
// all values are `false` on right side
(_, r) if r == rhs.len() => {
assert_lengths(lhs, rhs);
return rhs.clone();
}
// ignore the rest
_ => {}
}
}

binary_boolean_kernel(lhs, rhs, |lhs, rhs| lhs & rhs)
}

Expand All @@ -58,6 +90,31 @@ pub fn and(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray {
/// assert_eq!(or_ab, BooleanArray::from(vec![Some(true), Some(true), None]));
/// ```
pub fn or(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray {
if lhs.null_count() == 0 && rhs.null_count() == 0 {
let left_buffer = lhs.values();
let right_buffer = rhs.values();

match (left_buffer.unset_bits(), right_buffer.unset_bits()) {
// all values are `true` on left side
(0, _) => {
assert_lengths(lhs, rhs);
return lhs.clone();
}
// all values are `true` on right side
(_, 0) => {
assert_lengths(lhs, rhs);
return rhs.clone();
}
// all values on lhs and rhs are `false`
(l, r) if l == lhs.len() && r == rhs.len() => {
assert_lengths(lhs, rhs);
return rhs.clone();
}
// ignore the rest
_ => {}
}
}

binary_boolean_kernel(lhs, rhs, |lhs, rhs| lhs | rhs)
}

Expand Down
75 changes: 72 additions & 3 deletions src/compute/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ mod simd;
pub use simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd};

use super::take::take_boolean;
use crate::bitmap::Bitmap;
use crate::bitmap::{binary, Bitmap};
use crate::compute;
pub(crate) use primitive::{
compare_values_op as primitive_compare_values_op,
Expand Down Expand Up @@ -521,10 +521,43 @@ fn finish_eq_validities(
&BooleanArray::new(DataType::Boolean, rhs, None),
),
(Some(lhs), Some(rhs)) => {
let lhs_validity_unset_bits = lhs.unset_bits();
let rhs_validity_unset_bits = rhs.unset_bits();

// this branch is a bit more complicated as both arrays can have masked out values
// these masked out values might differ and lead to a `eq == false` that has to
// be corrected as both should be `null == null = true`

let lhs = BooleanArray::new(DataType::Boolean, lhs, None);
let rhs = BooleanArray::new(DataType::Boolean, rhs, None);
let eq_validities = compute::comparison::boolean::eq(&lhs, &rhs);
compute::boolean::and(&output_without_validities, &eq_validities)

// validity_bits are equal AND values are equal
let equal = compute::boolean::and(&output_without_validities, &eq_validities);

match (lhs_validity_unset_bits, rhs_validity_unset_bits) {
// there is at least one side with all values valid
// so we don't have to correct.
(0, _) | (_, 0) => equal,
_ => {
// we use the binary kernel here to save allocations
// and apply `!(lhs | rhs)` in one step
let both_sides_invalid =
compute::boolean::binary_boolean_kernel(&lhs, &rhs, |lhs, rhs| {
binary(lhs, rhs, |lhs, rhs| !(lhs | rhs))
});
// this still might include incorrect masked out values
// under the validity bits, so we must correct for that

// if not all true, e.g. at least one is set.
// then we propagate that null as `true` in equality
if both_sides_invalid.values().unset_bits() != both_sides_invalid.len() {
compute::boolean::or(&equal, &both_sides_invalid)
} else {
equal
}
}
}
}
}
}
Expand All @@ -547,10 +580,46 @@ fn finish_neq_validities(
compute::boolean::or(&output_without_validities, &rhs_negated)
}
(Some(lhs), Some(rhs)) => {
let lhs_validity_unset_bits = lhs.unset_bits();
let rhs_validity_unset_bits = rhs.unset_bits();

// this branch is a bit more complicated as both arrays can have masked out values
// these masked out values might differ and lead to a `neq == true` that has to
// be corrected as both should be `null != null = false`
let lhs = BooleanArray::new(DataType::Boolean, lhs, None);
let rhs = BooleanArray::new(DataType::Boolean, rhs, None);
let neq_validities = compute::comparison::boolean::neq(&lhs, &rhs);
compute::boolean::or(&output_without_validities, &neq_validities)

// validity_bits are not equal OR values not equal
let or = compute::boolean::or(&output_without_validities, &neq_validities);

match (lhs_validity_unset_bits, rhs_validity_unset_bits) {
// there is at least one side with all values valid
// so we don't have to correct.
(0, _) | (_, 0) => or,
_ => {
// we use the binary kernel here to save allocations
// and apply `!(lhs | rhs)` in one step
let both_sides_invalid =
compute::boolean::binary_boolean_kernel(&lhs, &rhs, |lhs, rhs| {
binary(lhs, rhs, |lhs, rhs| !(lhs | rhs))
});
// this still might include incorrect masked out values
// under the validity bits, so we must correct for that

// if not all true, e.g. at least one is set.
// then we propagate that null as `false` as the nulls are equal
if both_sides_invalid.values().unset_bits() != both_sides_invalid.len() {
// we use the `binary` kernel directly to save allocations
// and apply `lhs & !rhs)` in one shot.
compute::boolean::binary_boolean_kernel(&lhs, &rhs, |lhs, rhs| {
binary(lhs, rhs, |lhs, rhs| (lhs & !rhs))
})
} else {
or
}
}
}
}
}
}
24 changes: 22 additions & 2 deletions tests/it/compute/comparison.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use arrow2::array::*;
use arrow2::bitmap::Bitmap;
use arrow2::compute::comparison::{self, boolean::*, primitive};
use arrow2::datatypes::{DataType::*, IntegerType, IntervalUnit, TimeUnit};
use arrow2::compute::comparison::{self, boolean::*, primitive, utf8};
use arrow2::datatypes::{DataType, DataType::*, IntegerType, IntervalUnit, TimeUnit};
use arrow2::scalar::new_scalar;

#[test]
Expand Down Expand Up @@ -381,3 +381,23 @@ fn primitive_gt_eq() {
BooleanArray::from([Some(true), Some(false), Some(true), None, None])
)
}

#[test]
#[cfg(any(feature = "compute_cast", feature = "compute_boolean_kleene"))]
fn utf8_and_validity() {
use arrow2::compute::cast::CastOptions;
let a1 = Utf8Array::<i32>::from([Some("0"), Some("1"), None, Some("2")]);
let a2 = Int32Array::from([Some(0), Some(1), None, Some(2)]);

// due to the cast the values underneath the validity bits differ
let a2 = arrow2::compute::cast::cast(&a2, &DataType::Utf8, CastOptions::default()).unwrap();
let a2 = a2.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();

let expected = BooleanArray::from_slice([true, true, true, true]);
assert_eq!(utf8::eq_and_validity(&a1, &a1), expected);
assert_eq!(utf8::eq_and_validity(&a1, a2), expected);

let expected = BooleanArray::from_slice([false, false, false, false]);
assert_eq!(utf8::neq_and_validity(&a1, &a1), expected);
assert_eq!(utf8::neq_and_validity(&a1, a2), expected);
}

0 comments on commit 6c102a0

Please sign in to comment.