Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Improve arity_assign: ~2x improvement if we share data. (#1076)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 16, 2022
1 parent cd20ddf commit d4873fc
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 26 deletions.
4 changes: 2 additions & 2 deletions benches/assign_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn add_benchmark(c: &mut Criterion) {
c.bench_function(&format!("apply_mul 2^{}", log2_size), |b| {
b.iter(|| {
criterion::black_box(&mut arr_a)
.apply_values(|x| x.iter_mut().for_each(|x| *x *= 1.01));
.apply_values_mut(|x| x.iter_mut().for_each(|x| *x *= 1.01));
assert!(!arr_a.value(10).is_nan());
})
});
Expand All @@ -30,7 +30,7 @@ fn add_benchmark(c: &mut Criterion) {
let mut arr_a = create_primitive_array::<f32>(size, 0.2);
let mut arr_b = create_primitive_array_with_seed::<f32>(size, 0.2, 10);
// convert to be close to 1.01
arr_b.apply_values(|x| x.iter_mut().for_each(|x| *x = 1.01 + *x / 20.0));
arr_b.apply_values_mut(|x| x.iter_mut().for_each(|x| *x = 1.01 + *x / 20.0));

c.bench_function(&format!("apply_mul null 2^{}", log2_size), |b| {
b.iter(|| {
Expand Down
2 changes: 1 addition & 1 deletion examples/cow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn main() {
.unwrap();

// 2. call `apply_values` with the function to apply over the values
array.apply_values(|x| x.iter_mut().for_each(|x| *x *= 10));
array.apply_values_mut(|x| x.iter_mut().for_each(|x| *x *= 10));

// confirm that it gives the right result :)
assert_eq!(array, &PrimitiveArray::from_vec(vec![10i32, 20]));
Expand Down
4 changes: 2 additions & 2 deletions src/array/boolean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl BooleanArray {
/// if it is being shared (since it results in a `O(N)` memcopy).
/// # Panics
/// This function panics if the function modifies the length of the [`MutableBitmap`].
pub fn apply_values<F: Fn(&mut MutableBitmap)>(&mut self, f: F) {
pub fn apply_values_mut<F: Fn(&mut MutableBitmap)>(&mut self, f: F) {
let values = std::mem::take(&mut self.values);
let mut values = values.make_mut();
f(&mut values);
Expand All @@ -121,7 +121,7 @@ impl BooleanArray {
/// if it is being shared (since it results in a `O(N)` memcopy).
/// # Panics
/// This function panics if the function modifies the length of the [`MutableBitmap`].
pub fn apply_validity<F: Fn(&mut MutableBitmap)>(&mut self, f: F) {
pub fn apply_validity_mut<F: Fn(&mut MutableBitmap)>(&mut self, f: F) {
if let Some(validity) = self.validity.as_mut() {
let values = std::mem::take(validity);
let mut bitmap = values.make_mut();
Expand Down
69 changes: 59 additions & 10 deletions src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,41 @@ impl<T: NativeType> PrimitiveArray<T> {
/// This function panics iff `validity.len() != self.len()`.
#[must_use]
pub fn with_validity(&self, validity: Option<Bitmap>) -> Self {
let mut out = self.clone();
out.set_validity(validity);
out
}

/// Update the validity buffer of this [`PrimitiveArray`].
/// # Panics
/// This function panics iff `values.len() != self.len()`.
pub fn set_validity(&mut self, validity: Option<Bitmap>) {
if matches!(&validity, Some(bitmap) if bitmap.len() != self.len()) {
panic!("validity should be as least as large as the array")
}
let mut arr = self.clone();
arr.validity = validity;
arr
self.validity = validity;
}

/// Returns a clone of this [`PrimitiveArray`] with a new values.
/// # Panics
/// This function panics iff `values.len() != self.len()`.
#[must_use]
pub fn with_values(&self, values: Buffer<T>) -> Self {
let mut out = self.clone();
out.set_values(values);
out
}

/// Update the values buffer of this [`PrimitiveArray`].
/// # Panics
/// This function panics iff `values.len() != self.len()`.
pub fn set_values(&mut self, values: Buffer<T>) {
assert_eq!(
values.len(),
self.len(),
"values length should be equal to this arrays length"
);
self.values = values;
}

/// Applies a function `f` to the values of this array, cloning the values
Expand All @@ -260,10 +289,14 @@ impl<T: NativeType> PrimitiveArray<T> {
/// # Implementation
/// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)`
/// if it is being shared (since it results in a `O(N)` memcopy).
pub fn apply_values<F: Fn(&mut [T])>(&mut self, f: F) {
/// # Panics
/// This function panics, if `f` modifies the length of `&mut [T]`
pub fn apply_values_mut<F: Fn(&mut [T])>(&mut self, f: F) {
let values = std::mem::take(&mut self.values);
let mut values = values.make_mut();
let len = values.len();
f(&mut values);
assert_eq!(values.len(), len, "values length must remain the same");
self.values = values.into();
}

Expand All @@ -276,13 +309,29 @@ impl<T: NativeType> PrimitiveArray<T> {
/// if it is being shared (since it results in a `O(N)` memcopy).
/// # Panics
/// This function panics if the function modifies the length of the [`MutableBitmap`].
pub fn apply_validity<F: Fn(&mut MutableBitmap)>(&mut self, f: F) {
pub fn apply_validity_mut<F: Fn(&mut MutableBitmap)>(&mut self, f: F) {
if let Some(validity) = self.validity.as_mut() {
let owned_validity = std::mem::take(validity);
let mut mut_bitmap = owned_validity.make_mut();
f(&mut mut_bitmap);
assert_eq!(mut_bitmap.len(), self.values.len());
*validity = mut_bitmap.into();
}
}

/// Applies a function `f` to the validity of this array, the caller can decide to make
/// it mutable or not.
///
/// This is an API to leverage clone-on-write
/// # Implementation
/// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)`
/// if it is being shared (since it results in a `O(N)` memcopy).
/// # Panics
/// This function panics if the function modifies the length of the [`MutableBitmap`].
pub fn apply_validity<F: Fn(&mut Bitmap)>(&mut self, f: F) {
if let Some(validity) = self.validity.as_mut() {
let values = std::mem::take(validity);
let mut bitmap = values.make_mut();
f(&mut bitmap);
assert_eq!(bitmap.len(), self.values.len());
*validity = bitmap.into();
f(validity);
assert_eq!(validity.len(), self.values.len());
}
}

Expand Down
31 changes: 31 additions & 0 deletions src/array/primitive/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,37 @@ impl<T: NativeType> MutablePrimitiveArray<T> {
pub fn into_data(self) -> (DataType, Vec<T>, Option<MutableBitmap>) {
(self.data_type, self.values, self.validity)
}

/// Applies a function `f` to the values of this array, cloning the values
/// iff they are being shared with others
///
/// This is an API to use clone-on-write
/// # Implementation
/// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)`
/// if it is being shared (since it results in a `O(N)` memcopy).
/// # Panics
/// This function panics, if `f` modifies the length of `&mut [T]`
pub fn apply_values<F: Fn(&mut [T])>(&mut self, f: F) {
let len = self.values.len();
f(&mut self.values);
assert_eq!(len, self.values.len(), "values length must remain the same")
}

/// Applies a function `f` to the validity of this array, cloning it
/// iff it is being shared.
///
/// This is an API to leverage clone-on-write
/// # Implementation
/// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)`
/// if it is being shared (since it results in a `O(N)` memcopy).
/// # Panics
/// This function panics if the function modifies the length of the [`MutableBitmap`].
pub fn apply_validity<F: Fn(&mut MutableBitmap)>(&mut self, f: F) {
if let Some(validity) = &mut self.validity {
f(validity);
assert_eq!(validity.len(), self.values.len());
}
}
}

impl<T: NativeType> Default for MutablePrimitiveArray<T> {
Expand Down
63 changes: 56 additions & 7 deletions src/compute/arity_assign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use super::utils::check_same_len;
use crate::{array::PrimitiveArray, types::NativeType};
use either::Either;

/// Applies an unary function to a [`PrimitiveArray`] in-place via cow semantics.
///
Expand All @@ -17,7 +18,7 @@ where
I: NativeType,
F: Fn(I) -> I,
{
array.apply_values(|values| values.iter_mut().for_each(|v| *v = op(*v)));
array.apply_values_mut(|values| values.iter_mut().for_each(|v| *v = op(*v)));
}

/// Applies a binary operations to two [`PrimitiveArray`], applying the operation
Expand All @@ -38,20 +39,68 @@ where
{
check_same_len(lhs, rhs).unwrap();

// both for the validity and for the values
// we branch to check if we can mutate in place
// if we can, great that is fastest.
// if we cannot, we allocate a new buffer and assign values to that
// new buffer, that is benchmarked to be ~2x faster than first memcpy and assign in place
// for the validity bits it can be much faster as we might need to iterate all bits if the
// bitmap has an offset.
match rhs.validity() {
None => {}
Some(rhs) => {
if lhs.validity().is_none() {
*lhs = lhs.with_validity(Some(rhs.clone()))
} else {
lhs.apply_validity(|mut lhs| lhs &= rhs)
lhs.apply_validity(|bitmap| {
// we need to take ownership for the `into_mut` call, but leave the `&mut` lhs intact
// so that we can later assign the result to out `&mut bitmap`
let owned_lhs = std::mem::take(bitmap);

match owned_lhs.into_mut() {
// we take alloc and write to new buffer
Either::Left(immutable) => {
// we allocate a new bitmap because that is a lot faster
// than doing the memcpy or the potential iteration of bits if
// we are dealing with an offset
let new = &immutable & rhs;
*bitmap = new;
}
// we can mutate in place, happy days.
Either::Right(mut mutable) => {
let mut mutable_ref = &mut mutable;
mutable_ref &= rhs;
*bitmap = mutable.into()
}
}
});
}
}
}
// we need to take ownership for the `into_mut` call, but leave the `&mut` lhs intact
// so that we can later assign the result to out `&mut lhs`
let owned_lhs = std::mem::take(lhs);

lhs.apply_values(|x| {
x.iter_mut()
.zip(rhs.values().iter())
.for_each(|(l, r)| *l = op(*l, *r))
});
match owned_lhs.into_mut() {
// we take alloc and write to new buffer
Either::Left(mut immutable) => {
let values = immutable
.values()
.iter()
.zip(rhs.values().iter())
.map(|(l, r)| op(*l, *r))
.collect::<Vec<_>>();
immutable.set_values(values.into());
*lhs = immutable;
}
// we can mutate in place
Either::Right(mut mutable) => {
mutable.apply_values(|x| {
x.iter_mut()
.zip(rhs.values().iter())
.for_each(|(l, r)| *l = op(*l, *r))
});
*lhs = mutable.into()
}
}
}
4 changes: 2 additions & 2 deletions tests/it/array/boolean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ fn from_iter() {
#[test]
fn apply_values() {
let mut a = BooleanArray::from([Some(true), Some(false), None]);
a.apply_values(|x| {
a.apply_values_mut(|x| {
let mut a = std::mem::take(x);
a = !a;
*x = a;
Expand All @@ -147,7 +147,7 @@ fn apply_values() {
#[test]
fn apply_validity() {
let mut a = BooleanArray::from([Some(true), Some(false), None]);
a.apply_validity(|x| {
a.apply_validity_mut(|x| {
let mut a = std::mem::take(x);
a = !a;
*x = a;
Expand Down
4 changes: 2 additions & 2 deletions tests/it/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn into_mut_3() {
#[test]
fn apply_values() {
let mut a = PrimitiveArray::from([Some(1), Some(2), None]);
a.apply_values(|x| {
a.apply_values_mut(|x| {
x[0] = 10;
});
let expected = PrimitiveArray::from([Some(10), Some(2), None]);
Expand All @@ -138,7 +138,7 @@ fn apply_values() {
#[test]
fn apply_validity() {
let mut a = PrimitiveArray::from([Some(1), Some(2), None]);
a.apply_validity(|x| {
a.apply_validity_mut(|x| {
let mut a = std::mem::take(x);
a = !a;
*x = a;
Expand Down

0 comments on commit d4873fc

Please sign in to comment.