Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Improved UnionArray #1331

Merged
merged 3 commits into from Dec 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/array/growable/union.rs
Expand Up @@ -89,11 +89,11 @@ impl<'a> Growable<'a> for GrowableUnion<'a> {
fn extend_validity(&mut self, _additional: usize) {}

fn as_arc(&mut self) -> Arc<dyn Array> {
Arc::new(self.to())
self.to().arced()
}

fn as_box(&mut self) -> Box<dyn Array> {
Box::new(self.to())
self.to().boxed()
}
}

Expand Down
5 changes: 4 additions & 1 deletion src/array/union/iterator.rs
Expand Up @@ -8,6 +8,7 @@ pub struct UnionIter<'a> {
}

impl<'a> UnionIter<'a> {
/// Creates a [`UnionIter`] over `array`, with the cursor (`current`) set to slot 0.
#[inline]
pub fn new(array: &'a UnionArray) -> Self {
Self { array, current: 0 }
}
Expand All @@ -16,16 +17,18 @@ impl<'a> UnionIter<'a> {
impl<'a> Iterator for UnionIter<'a> {
type Item = Box<dyn Scalar>;

#[inline]
fn next(&mut self) -> Option<Self::Item> {
if self.current == self.array.len() {
None
} else {
let old = self.current;
self.current += 1;
Some(self.array.value(old))
Some(unsafe { self.array.value_unchecked(old) })
}
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.array.len() - self.current;
(len, Some(len))
Expand Down
175 changes: 121 additions & 54 deletions src/array/union/mod.rs
@@ -1,5 +1,3 @@
use ahash::AHashMap;

use crate::{
bitmap::Bitmap,
buffer::Buffer,
Expand All @@ -14,7 +12,6 @@ mod ffi;
pub(super) mod fmt;
mod iterator;

type FieldEntry = (usize, Box<dyn Array>);
type UnionComponents<'a> = (&'a [Field], Option<&'a [i32]>, UnionMode);

/// [`UnionArray`] represents an array whose each slot can contain different values.
Expand All @@ -29,10 +26,13 @@ type UnionComponents<'a> = (&'a [Field], Option<&'a [i32]>, UnionMode);
// ```
#[derive(Clone)]
pub struct UnionArray {
// Invariant: every item in `types` is `>= 0`; when `map` is `None`, it is also `< fields.len()`
types: Buffer<i8>,
// None represents when there is no typeid
fields_hash: Option<AHashMap<i8, FieldEntry>>,
// Invariant: `map.len() == fields.len()`
// Invariant: every item in `map` that is indexed by an item of `types` is `< fields.len()`
map: Option<[usize; 127]>,
fields: Vec<Box<dyn Array>>,
// Invariant: when set, `offsets.len() == types.len()`
offsets: Option<Buffer<i32>>,
data_type: DataType,
offset: usize,
Expand All @@ -44,6 +44,7 @@ impl UnionArray {
/// This function errors iff:
/// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Union`].
/// * the number of `fields` is different from the number of the `data_type`'s children
/// * The number of `fields` is larger than `i8::MAX`
/// * the data type of any of the `fields` is different from the data type of its corresponding child
pub fn try_new(
data_type: DataType,
Expand All @@ -58,6 +59,10 @@ impl UnionArray {
"The number of `fields` must equal the number of children fields in DataType::Union",
));
};
let number_of_fields: i8 = fields
.len()
.try_into()
.map_err(|_| Error::oos("The number of `fields` cannot be larger than i8::MAX"))?;

f
.iter().map(|a| a.data_type())
Expand All @@ -74,27 +79,75 @@ impl UnionArray {
}
})?;

if let Some(offsets) = &offsets {
if offsets.len() != types.len() {
return Err(Error::oos(
"In a UnionArray, the offsets' length must be equal to the number of types",
));
}
}
if offsets.is_none() != mode.is_sparse() {
return Err(Error::oos(
"The offsets must be set when the Union is dense and vice-versa",
"In a sparse UnionArray, the offsets must be set (and vice-versa)",
));
}

let fields_hash = ids.as_ref().map(|ids| {
ids.iter()
.map(|x| *x as i8)
.enumerate()
.zip(fields.iter().cloned())
.map(|((i, type_), field)| (type_, (i, field)))
.collect()
});

// not validated:
// * `offsets` is valid
// * max id < fields.len()
// build hash
let map = if let Some(&ids) = ids.as_ref() {
if ids.len() != fields.len() {
return Err(Error::oos(
"In a union, when the ids are set, their length must be equal to the number of fields",
));
}

// example:
// * types = [5, 7, 5, 7, 7, 7, 5, 7, 7, 5, 5]
// * ids = [5, 7]
// => hash = [0, 0, 0, 0, 0, 0, 1, 0, ...]
let mut hash = [0; 127];

for (pos, &id) in ids.iter().enumerate() {
if !(0..=127).contains(&id) {
return Err(Error::oos(
"In a union, when the ids are set, every id must belong to [0, 128[",
));
}
hash[id as usize] = pos;
}

types.iter().try_for_each(|&type_| {
if type_ < 0 {
return Err(Error::oos("In a union, when the ids are set, every type must be >= 0"));
}
let id = hash[type_ as usize];
if id >= fields.len() {
Err(Error::oos("In a union, when the ids are set, each id must be smaller than the number of fields."))
} else {
Ok(())
}
})?;

Some(hash)
} else {
// Safety: every type in types is smaller than number of fields
let mut is_valid = true;
for &type_ in types.iter() {
if type_ < 0 || type_ >= number_of_fields {
is_valid = false
}
}
if !is_valid {
return Err(Error::oos(
"Every type in `types` must be larger than 0 and smaller than the number of fields.",
));
}

None
};

Ok(Self {
data_type,
fields_hash,
map,
fields,
offsets,
types,
Expand Down Expand Up @@ -128,7 +181,7 @@ impl UnionArray {
let offsets = if mode.is_sparse() {
None
} else {
Some((0..length as i32).collect::<Buffer<i32>>())
Some((0..length as i32).collect::<Vec<_>>().into())
};

// all from the same field
Expand All @@ -151,12 +204,12 @@ impl UnionArray {
let offsets = if mode.is_sparse() {
None
} else {
Some(Buffer::new())
Some(Buffer::default())
};

Self {
data_type,
fields_hash: None,
map: None,
fields,
offsets,
types: Buffer::new(),
Expand Down Expand Up @@ -186,17 +239,11 @@ impl UnionArray {
/// This function panics iff `offset + length > self.len()`.
#[inline]
pub fn slice(&self, offset: usize, length: usize) -> Self {
Self {
data_type: self.data_type.clone(),
fields: self.fields.clone(),
fields_hash: self.fields_hash.clone(),
types: self.types.clone().slice(offset, length),
offsets: self
.offsets
.clone()
.map(|offsets| offsets.slice(offset, length)),
offset: self.offset + offset,
}
assert!(
offset + length <= self.len(),
"the offset of the new array cannot exceed the existing length"
);
unsafe { self.slice_unchecked(offset, length) }
}

/// Returns a slice of this [`UnionArray`].
Expand All @@ -206,10 +253,11 @@ impl UnionArray {
/// The caller must ensure that `offset + length <= self.len()`.
#[inline]
pub unsafe fn slice_unchecked(&self, offset: usize, length: usize) -> Self {
debug_assert!(offset + length <= self.len());
Self {
data_type: self.data_type.clone(),
fields: self.fields.clone(),
fields_hash: self.fields_hash.clone(),
map: self.map,
types: self.types.clone().slice_unchecked(offset, length),
offsets: self
.offsets
Expand Down Expand Up @@ -243,38 +291,57 @@ impl UnionArray {
}

#[inline]
fn field(&self, type_: i8) -> &dyn Array {
self.fields_hash
.as_ref()
.map(|x| x[&type_].1.as_ref())
.unwrap_or_else(|| self.fields[type_ as usize].as_ref())
}

#[inline]
fn field_slot(&self, index: usize) -> usize {
unsafe fn field_slot_unchecked(&self, index: usize) -> usize {
self.offsets()
.as_ref()
.map(|x| x[index] as usize)
.map(|x| *x.get_unchecked(index) as usize)
.unwrap_or(index + self.offset)
}

/// Returns the index and slot of the field to select from `self.fields`.
#[inline]
pub fn index(&self, index: usize) -> (usize, usize) {
let type_ = self.types()[index];
let field_index = self
.fields_hash
assert!(index < self.len());
unsafe { self.index_unchecked(index) }
}

/// Returns the index and slot of the field to select from `self.fields`.
/// The first value is guaranteed to be `< self.fields().len()`
/// # Safety
/// This function is safe iff `index < self.len`.
#[inline]
pub unsafe fn index_unchecked(&self, index: usize) -> (usize, usize) {
debug_assert!(index < self.len());
// Safety: assumption of the function
let type_ = unsafe { *self.types.get_unchecked(index) };
// Safety: assumption of the struct
let type_ = self
.map
.as_ref()
.map(|x| x[&type_].0)
.unwrap_or_else(|| type_ as usize);
let index = self.field_slot(index);
(field_index, index)
.map(|map| unsafe { *map.get_unchecked(type_ as usize) })
.unwrap_or(type_ as usize);
// Safety: assumption of the function
let index = self.field_slot_unchecked(index);
(type_, index)
}

/// Returns the slot `index` as a [`Scalar`].
/// # Panics
/// iff `index >= self.len()`
pub fn value(&self, index: usize) -> Box<dyn Scalar> {
let type_ = self.types()[index];
let field = self.field(type_);
let index = self.field_slot(index);
assert!(index < self.len());
unsafe { self.value_unchecked(index) }
}

/// Returns the slot `index` as a [`Scalar`].
/// # Safety
/// The caller must ensure that `index < self.len()`.
pub unsafe fn value_unchecked(&self, index: usize) -> Box<dyn Scalar> {
debug_assert!(index < self.len());
// Resolve the slot into (position of the child in `self.fields`, slot within that child).
let (type_, index) = self.index_unchecked(index);
// Safety: assumption of the struct
debug_assert!(type_ < self.fields.len());
let field = self.fields.get_unchecked(type_).as_ref();
new_scalar(field, index)
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/compute/sort/row/mod.rs
Expand Up @@ -647,9 +647,9 @@ mod tests {
#[test]
fn test_fixed_width() {
let cols = [
Int16Array::from_iter([Some(1), Some(2), None, Some(-5), Some(2), Some(2), Some(0)])
Int16Array::from([Some(1), Some(2), None, Some(-5), Some(2), Some(2), Some(0)])
.to_boxed(),
Float32Array::from_iter([
Float32Array::from([
Some(1.3),
Some(2.5),
None,
Expand Down
6 changes: 3 additions & 3 deletions src/compute/sort/row/variable.rs
Expand Up @@ -76,10 +76,9 @@ pub fn encode<'a, I: Iterator<Item = Option<&'a [u8]>>>(out: &mut Rows, i: I, op
// Write `2_u8` to demarcate as non-empty, non-null string
to_write[0] = NON_EMPTY_SENTINEL;

let chunks = val.chunks_exact(BLOCK_SIZE);
let remainder = chunks.remainder();
let mut chunks = val.chunks_exact(BLOCK_SIZE);
for (input, output) in chunks
.clone()
.by_ref()
.zip(to_write[1..].chunks_exact_mut(BLOCK_SIZE + 1))
{
let input: &[u8; BLOCK_SIZE] = input.try_into().unwrap();
Expand All @@ -92,6 +91,7 @@ pub fn encode<'a, I: Iterator<Item = Option<&'a [u8]>>>(out: &mut Rows, i: I, op
output[BLOCK_SIZE] = BLOCK_CONTINUATION;
}

let remainder = chunks.remainder();
if !remainder.is_empty() {
let start_offset = 1 + (block_count - 1) * (BLOCK_SIZE + 1);
to_write[start_offset..start_offset + remainder.len()]
Expand Down
3 changes: 2 additions & 1 deletion src/io/json_integration/read/array.rs
Expand Up @@ -414,7 +414,8 @@ pub fn to_array(
}
_ => panic!(),
})
.collect(),
.collect::<Vec<_>>()
.into(),
)
})
.unwrap_or_default();
Expand Down