Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

fixed comparison and validity kernels #1243

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 62 additions & 5 deletions src/compute/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,24 @@ use crate::scalar::BooleanScalar;

use super::utils::combine_validities;

/// Helper function to implement binary kernels
fn binary_boolean_kernel<F>(lhs: &BooleanArray, rhs: &BooleanArray, op: F) -> BooleanArray
where
F: Fn(&Bitmap, &Bitmap) -> Bitmap,
{
/// Panics unless `lhs` and `rhs` hold the same number of elements.
fn assert_lengths(lhs: &BooleanArray, rhs: &BooleanArray) {
    let (left_len, right_len) = (lhs.len(), rhs.len());
    assert_eq!(left_len, right_len, "lhs and rhs must have the same length");
}

/// Helper function to implement binary kernels
pub(crate) fn binary_boolean_kernel<F>(
lhs: &BooleanArray,
rhs: &BooleanArray,
op: F,
) -> BooleanArray
where
F: Fn(&Bitmap, &Bitmap) -> Bitmap,
{
assert_lengths(lhs, rhs);
let validity = combine_validities(lhs.validity(), rhs.validity());

let left_buffer = lhs.values();
Expand All @@ -41,6 +48,31 @@ where
/// assert_eq!(and_ab, BooleanArray::from(&[Some(false), Some(true), None]));
/// ```
pub fn and(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray {
    // The shortcuts below are only considered when neither input contains nulls;
    // everything else is delegated to the generic binary kernel.
    if lhs.null_count() == 0 && rhs.null_count() == 0 {
        let lhs_unset = lhs.values().unset_bits();
        let rhs_unset = rhs.values().unset_bits();

        // Pick the input that already equals the result, if there is one.
        let shortcut = if (lhs_unset == 0 && rhs_unset == 0) || lhs_unset == lhs.len() {
            // both sides are entirely `true`, or the left side is entirely `false`
            Some(lhs)
        } else if rhs_unset == rhs.len() {
            // the right side is entirely `false`
            Some(rhs)
        } else {
            None
        };

        if let Some(result) = shortcut {
            // the shortcut bypasses the kernel, so the length check happens here
            assert_lengths(lhs, rhs);
            return result.clone();
        }
    }

    binary_boolean_kernel(lhs, rhs, |left, right| left & right)
}

Expand All @@ -58,6 +90,31 @@ pub fn and(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray {
/// assert_eq!(or_ab, BooleanArray::from(vec![Some(true), Some(true), None]));
/// ```
pub fn or(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray {
    // The shortcuts below are only considered when neither input contains nulls;
    // everything else is delegated to the generic binary kernel.
    if lhs.null_count() == 0 && rhs.null_count() == 0 {
        let lhs_unset = lhs.values().unset_bits();
        let rhs_unset = rhs.values().unset_bits();

        // Pick the input that already equals the result, if there is one.
        let shortcut = if lhs_unset == 0 {
            // the left side is entirely `true`
            Some(lhs)
        } else if rhs_unset == 0 || (lhs_unset == lhs.len() && rhs_unset == rhs.len()) {
            // the right side is entirely `true`, or both sides are entirely `false`
            Some(rhs)
        } else {
            None
        };

        if let Some(result) = shortcut {
            // the shortcut bypasses the kernel, so the length check happens here
            assert_lengths(lhs, rhs);
            return result.clone();
        }
    }

    binary_boolean_kernel(lhs, rhs, |left, right| left | right)
}

Expand Down
75 changes: 72 additions & 3 deletions src/compute/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ mod simd;
pub use simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd};

use super::take::take_boolean;
use crate::bitmap::Bitmap;
use crate::bitmap::{binary, Bitmap};
use crate::compute;
pub(crate) use primitive::{
compare_values_op as primitive_compare_values_op,
Expand Down Expand Up @@ -521,10 +521,43 @@ fn finish_eq_validities(
&BooleanArray::new(DataType::Boolean, rhs, None),
),
(Some(lhs), Some(rhs)) => {
let lhs_validity_unset_bits = lhs.unset_bits();
let rhs_validity_unset_bits = rhs.unset_bits();

// this branch is a bit more complicated as both arrays can have masked out values
// these masked out values might differ and lead to a `eq == false` that has to
// be corrected as both should be `null == null = true`

let lhs = BooleanArray::new(DataType::Boolean, lhs, None);
let rhs = BooleanArray::new(DataType::Boolean, rhs, None);
let eq_validities = compute::comparison::boolean::eq(&lhs, &rhs);
compute::boolean::and(&output_without_validities, &eq_validities)

// validity_bits are equal AND values are equal
let equal = compute::boolean::and(&output_without_validities, &eq_validities);

match (lhs_validity_unset_bits, rhs_validity_unset_bits) {
// there is at least one side with all values valid
// so we don't have to correct.
(0, _) | (_, 0) => equal,
_ => {
// we use the binary kernel here to save allocations
// and apply `!(lhs | rhs)` in one step
let both_sides_invalid =
compute::boolean::binary_boolean_kernel(&lhs, &rhs, |lhs, rhs| {
binary(lhs, rhs, |lhs, rhs| !(lhs | rhs))
});
// this still might include incorrect masked out values
// under the validity bits, so we must correct for that

// if not all false, i.e. at least one bit is set,
// then we propagate that null as `true` in equality
if both_sides_invalid.values().unset_bits() != both_sides_invalid.len() {
compute::boolean::or(&equal, &both_sides_invalid)
} else {
equal
}
}
}
}
}
}
Expand All @@ -547,10 +580,46 @@ fn finish_neq_validities(
compute::boolean::or(&output_without_validities, &rhs_negated)
}
(Some(lhs), Some(rhs)) => {
let lhs_validity_unset_bits = lhs.unset_bits();
let rhs_validity_unset_bits = rhs.unset_bits();

// this branch is a bit more complicated as both arrays can have masked out values
// these masked out values might differ and lead to a `neq == true` that has to
// be corrected as both should be `null != null = false`
let lhs = BooleanArray::new(DataType::Boolean, lhs, None);
let rhs = BooleanArray::new(DataType::Boolean, rhs, None);
let neq_validities = compute::comparison::boolean::neq(&lhs, &rhs);
compute::boolean::or(&output_without_validities, &neq_validities)

// validity_bits are not equal OR values not equal
let or = compute::boolean::or(&output_without_validities, &neq_validities);

match (lhs_validity_unset_bits, rhs_validity_unset_bits) {
// there is at least one side with all values valid
// so we don't have to correct.
(0, _) | (_, 0) => or,
_ => {
// we use the binary kernel here to save allocations
// and apply `!(lhs | rhs)` in one step
let both_sides_invalid =
compute::boolean::binary_boolean_kernel(&lhs, &rhs, |lhs, rhs| {
binary(lhs, rhs, |lhs, rhs| !(lhs | rhs))
});
// this still might include incorrect masked out values
// under the validity bits, so we must correct for that

// if not all false, i.e. at least one bit is set,
// then we propagate that null as `false` as the nulls are equal
if both_sides_invalid.values().unset_bits() != both_sides_invalid.len() {
// we use the `binary` kernel directly to save allocations
// and apply `lhs & !rhs` in one shot.
compute::boolean::binary_boolean_kernel(&lhs, &rhs, |lhs, rhs| {
binary(lhs, rhs, |lhs, rhs| (lhs & !rhs))
})
} else {
or
}
}
}
}
}
}
24 changes: 22 additions & 2 deletions tests/it/compute/comparison.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use arrow2::array::*;
use arrow2::bitmap::Bitmap;
use arrow2::compute::comparison::{self, boolean::*, primitive};
use arrow2::datatypes::{DataType::*, IntegerType, IntervalUnit, TimeUnit};
use arrow2::compute::comparison::{self, boolean::*, primitive, utf8};
use arrow2::datatypes::{DataType, DataType::*, IntegerType, IntervalUnit, TimeUnit};
use arrow2::scalar::new_scalar;

#[test]
Expand Down Expand Up @@ -381,3 +381,23 @@ fn primitive_gt_eq() {
BooleanArray::from([Some(true), Some(false), Some(true), None, None])
)
}

#[test]
// NOTE(review): the body calls `arrow2::compute::cast`; `any(...)` also enables this
// test when only `compute_boolean_kleene` is on — confirm the gate is as intended.
#[cfg(any(feature = "compute_cast", feature = "compute_boolean_kleene"))]
fn utf8_and_validity() {
    use arrow2::compute::cast::CastOptions;

    let strings = Utf8Array::<i32>::from([Some("0"), Some("1"), None, Some("2")]);
    let ints = Int32Array::from([Some(0), Some(1), None, Some(2)]);

    // due to the cast the values underneath the validity bits differ
    let casted =
        arrow2::compute::cast::cast(&ints, &DataType::Utf8, CastOptions::default()).unwrap();
    let casted = casted.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();

    // the array is compared with itself first, then with its casted counterpart
    let all_true = BooleanArray::from_slice([true, true, true, true]);
    for other in [&strings, casted] {
        assert_eq!(utf8::eq_and_validity(&strings, other), all_true);
    }

    let all_false = BooleanArray::from_slice([false, false, false, false]);
    for other in [&strings, casted] {
        assert_eq!(utf8::neq_and_validity(&strings, other), all_false);
    }
}