From aa42ac252817a44097a4c503ac606b34ede342a0 Mon Sep 17 00:00:00 2001 From: Vasanthakumar Vijayasekaran Date: Sat, 18 Sep 2021 11:28:23 +0530 Subject: [PATCH] Add `extend`/`extend_unchecked` for `MutableUtf8Array` (#413) --- src/array/utf8/mutable.rs | 197 ++++++++++++++++++++++++++------- tests/it/array/utf8/mutable.rs | 36 ++++++ 2 files changed, 196 insertions(+), 37 deletions(-) diff --git a/src/array/utf8/mutable.rs b/src/array/utf8/mutable.rs index 80e1e292ae6..cbc3904e022 100644 --- a/src/array/utf8/mutable.rs +++ b/src/array/utf8/mutable.rs @@ -224,6 +224,75 @@ impl> FromIterator> for MutableUtf8Array { } impl MutableUtf8Array { + /// Extends the [`MutableUtf8Array`] from an iterator of values of trusted len. + /// This differs from `extended_trusted_len` which accepts iterator of optional values. + #[inline] + pub fn extend_trusted_len_values(&mut self, iterator: I) + where + P: AsRef, + I: TrustedLen, + { + unsafe { self.extend_trusted_len_values_unchecked(iterator) } + } + + /// Extends the [`MutableUtf8Array`] from an iterator of values of trusted len. + /// This differs from `extended_trusted_len_unchecked` which accepts iterator of optional + /// values. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_values_unchecked(&mut self, iterator: I) + where + P: AsRef, + I: Iterator, + { + let (_, upper) = iterator.size_hint(); + let additional = upper.expect("extend_trusted_len_values requires an upper limit"); + + extend_from_trusted_len_values_iter(&mut self.offsets, &mut self.values, iterator); + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + } + + /// Extends the [`MutableUtf8Array`] from an iterator of trusted len. + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: AsRef, + I: TrustedLen>, + { + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends [`MutableUtf8Array`] from an iterator of trusted len. + /// #Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: AsRef, + I: Iterator>, + { + if self.validity.is_none() { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + self.validity = Some(validity); + } + + extend_from_trusted_len_iter( + &mut self.offsets, + &mut self.values, + &mut self.validity.as_mut().unwrap(), + iterator, + ); + + if self.validity.as_mut().unwrap().null_count() == 0 { + self.validity = None; + } + } + /// Creates a [`MutableUtf8Array`] from an iterator of trusted length. /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). @@ -377,37 +446,13 @@ where P: AsRef, I: Iterator>, { - let (_, upper) = iterator.size_hint(); - let len = upper.expect("trusted_len_unzip requires an upper limit"); - - let mut validity = MutableBitmap::with_capacity(len); - let mut offsets = MutableBuffer::::with_capacity(len + 1); + let mut offsets = MutableBuffer::::with_capacity(1); let mut values = MutableBuffer::::new(); + let mut validity = MutableBitmap::new(); - let mut length = O::default(); - let mut dst = offsets.as_mut_ptr(); - std::ptr::write(dst, length); - dst = dst.add(1); - for item in iterator { - if let Some(item) = item { - validity.push(true); - let s = item.as_ref(); - length += O::from_usize(s.len()).unwrap(); - values.extend_from_slice(s.as_bytes()); - } else { - validity.push(false); - values.extend_from_slice(b""); - }; + offsets.push_unchecked(O::default()); - std::ptr::write(dst, length); - dst = dst.add(1); - } - assert_eq!( - dst.offset_from(offsets.as_ptr()) as usize, - len + 1, - "Trusted iterator length was not accurately reported" - ); - offsets.set_len(len + 1); + extend_from_trusted_len_iter(&mut offsets, &mut values, &mut validity, iterator); let validity = if validity.null_count() > 0 { Some(validity) @@ -481,33 +526,111 @@ where O: Offset, P: AsRef, I: Iterator, +{ + let mut offsets = MutableBuffer::::with_capacity(1 + iterator.size_hint().1.unwrap()); + let mut values = MutableBuffer::::new(); + + offsets.push_unchecked(O::default()); + + extend_from_trusted_len_values_iter(&mut offsets, &mut values, iterator); + + (offsets, values) +} + +/// Populates `offsets` and `values` [`Buffer`] with information +/// extracted from the incoming iterator. +/// # Safety +/// The caller must ensure that `iterator` is [`TrustedLen`] +#[inline] +unsafe fn extend_from_trusted_len_values_iter( + offsets: &mut MutableBuffer, + values: &mut MutableBuffer, + iterator: I, +) where + O: Offset, + P: AsRef, + I: Iterator, { let (_, upper) = iterator.size_hint(); - let len = upper.expect("trusted_len_unzip requires an upper limit"); + let additional = upper.expect("extend_from_trusted_len_iter_values requires an upper limit"); - let mut offsets = MutableBuffer::::with_capacity(len + 1); - let mut values = MutableBuffer::::new(); + offsets.reserve(additional); + + let mut length = *offsets.last().unwrap(); - let mut length = O::default(); let mut dst = offsets.as_mut_ptr(); - std::ptr::write(dst, length); - dst = dst.add(1); + dst = dst.add(offsets.len()); + for item in iterator { let s = item.as_ref(); + length += O::from_usize(s.len()).unwrap(); + values.extend_from_slice(s.as_bytes()); + std::ptr::write(dst, length); + + dst = dst.add(1); + } + + assert_eq!( + dst.offset_from(offsets.as_ptr()) as usize, + offsets.len() + additional, + "Trusted iterator length was not accurately reported" + ); + + offsets.set_len(offsets.len() + additional); +} + +/// Populates `offsets`, `values`, and validity [`Buffer`] with information +/// extracted from the incoming iterator. +/// # Safety +/// The caller must ensure that `iterator` is [`TrustedLen`] +#[inline] +unsafe fn extend_from_trusted_len_iter( + offsets: &mut MutableBuffer, + values: &mut MutableBuffer, + validity: &mut MutableBitmap, + iterator: I, +) where + O: Offset, + P: AsRef, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let additional = upper.expect("extend_from_trusted_len_values_iter requires an upper limit"); + + offsets.reserve(additional); + validity.reserve(additional); + + let mut length = *offsets.last().unwrap(); + + let mut dst = offsets.as_mut_ptr(); + dst = dst.add(offsets.len()); + + for item in iterator { + if let Some(item) = item { + let s = item.as_ref(); + + length += O::from_usize(s.len()).unwrap(); + + values.extend_from_slice(s.as_bytes()); + validity.push_unchecked(true); + } else { + validity.push_unchecked(false); + }; std::ptr::write(dst, length); + dst = dst.add(1); } + assert_eq!( dst.offset_from(offsets.as_ptr()) as usize, - len + 1, + offsets.len() + additional, "Trusted iterator length was not accurately reported" ); - offsets.set_len(len + 1); - (offsets, values) + offsets.set_len(offsets.len() + additional); } /// Creates two [`MutableBuffer`]s from an iterator of `&str`. diff --git a/tests/it/array/utf8/mutable.rs b/tests/it/array/utf8/mutable.rs index c7ee553ee5e..6afd6c666f2 100644 --- a/tests/it/array/utf8/mutable.rs +++ b/tests/it/array/utf8/mutable.rs @@ -45,3 +45,39 @@ fn wrong_data_type() { let values = MutableBuffer::from(b"abbb"); MutableUtf8Array::::from_data(DataType::Int8, offsets, values, None); } + +#[test] +fn test_extend_trusted_len_values() { + let mut array = MutableUtf8Array::::new(); + + array.extend_trusted_len_values(["hi", "there"].iter()); + array.extend_trusted_len_values(["hello"].iter()); + array.extend_trusted_len(vec![Some("again"), None].into_iter()); + + let array: Utf8Array = array.into(); + + assert_eq!(array.values().as_slice(), b"hitherehelloagain"); + assert_eq!(array.offsets().as_slice(), &[0, 2, 7, 12, 17, 17]); + assert_eq!( + array.validity(), + &Some(Bitmap::from_u8_slice(&[0b00001111], 5)) + ); +} + +#[test] +fn test_extend_trusted_len() { + let mut array = MutableUtf8Array::::new(); + + array.extend_trusted_len(vec![Some("hi"), Some("there")].into_iter()); + array.extend_trusted_len(vec![None, Some("hello")].into_iter()); + array.extend_trusted_len_values(["again"].iter()); + + let array: Utf8Array = array.into(); + + assert_eq!(array.values().as_slice(), b"hitherehelloagain"); + assert_eq!(array.offsets().as_slice(), &[0, 2, 7, 7, 12, 17]); + assert_eq!( + array.validity(), + &Some(Bitmap::from_u8_slice(&[0b00011011], 5)) + ); +}