Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Add support for union scalars (#930)
Browse files Browse the repository at this point in the history
  • Loading branch information
ncpenke committed May 1, 2022
1 parent 3c64e7a commit 9a38663
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 49 deletions.
2 changes: 1 addition & 1 deletion src/array/null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl NullArray {
pub fn try_new(data_type: DataType, length: usize) -> Result<Self, ArrowError> {
if data_type.to_physical_type() != PhysicalType::Null {
return Err(ArrowError::oos(
"BooleanArray can only be initialized with a DataType whose physical type is Boolean",
"NullArray can only be initialized with a DataType whose physical type is Boolean",
));
}

Expand Down
2 changes: 1 addition & 1 deletion src/scalar/equal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool {
Struct => dyn_eq!(StructScalar, lhs, rhs),
FixedSizeBinary => dyn_eq!(FixedSizeBinaryScalar, lhs, rhs),
FixedSizeList => dyn_eq!(FixedSizeListScalar, lhs, rhs),
Union => unimplemented!("{:?}", Union),
Union => dyn_eq!(UnionScalar, lhs, rhs),
Map => unimplemented!("{:?}", Map),
}
}
12 changes: 11 additions & 1 deletion src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ mod fixed_size_list;
pub use fixed_size_list::*;
mod fixed_size_binary;
pub use fixed_size_binary::*;
mod union;
pub use union::UnionScalar;

/// Trait object declaring an optional value with a [`DataType`].
/// This trait is often used in APIs that accept multiple scalar types.
Expand Down Expand Up @@ -144,7 +146,15 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box<dyn Scalar> {
};
Box::new(FixedSizeListScalar::new(array.data_type().clone(), value))
}
Union | Map => todo!(),
Union => {
let array = array.as_any().downcast_ref::<UnionArray>().unwrap();
Box::new(UnionScalar::new(
array.data_type().clone(),
array.types()[index],
array.value(index).into(),
))
}
Map => todo!(),
Dictionary(key_type) => match_integer_type!(key_type, |$T| {
let array = array
.as_any()
Expand Down
54 changes: 54 additions & 0 deletions src/scalar/union.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use std::sync::Arc;

use crate::datatypes::DataType;

use super::Scalar;

/// One entry of a [`crate::array::UnionArray`], held as a standalone scalar.
///
/// The entry keeps three things together: the wrapped value (type-erased
/// behind [`Scalar`]), the `i8` type id it was created with, and the
/// [`DataType`] it reports through the [`Scalar`] trait.
#[derive(Debug, Clone, PartialEq)]
pub struct UnionScalar {
    /// Type-erased scalar carried by this entry.
    value: Arc<dyn Scalar>,
    /// Type id given at construction time.
    type_: i8,
    /// Data type returned by [`Scalar::data_type`].
    data_type: DataType,
}

impl UnionScalar {
/// Returns a new [`UnionScalar`]
#[inline]
pub fn new(data_type: DataType, type_: i8, value: Arc<dyn Scalar>) -> Self {
Self {
value,
type_,
data_type,
}
}

/// Returns the inner value
#[inline]
pub fn value(&self) -> &Arc<dyn Scalar> {
&self.value
}

/// Returns the type of the union scalar
#[inline]
pub fn type_(&self) -> i8 {
self.type_
}
}

impl Scalar for UnionScalar {
    /// Exposes `self` as [`std::any::Any`] so callers can downcast back to
    /// [`UnionScalar`].
    #[inline]
    fn as_any(&self) -> &dyn std::any::Any {
        self as &dyn std::any::Any
    }

    /// Always reports `true`.
    // NOTE(review): the validity of the wrapped `value` is not consulted here —
    // confirm that this is the intended semantics for union entries.
    #[inline]
    fn is_valid(&self) -> bool {
        true
    }

    /// Returns the data type supplied when the scalar was constructed.
    #[inline]
    fn data_type(&self) -> &DataType {
        let Self { data_type, .. } = self;
        data_type
    }
}
122 changes: 76 additions & 46 deletions tests/it/array/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,22 @@ use arrow2::{
buffer::Buffer,
datatypes::*,
error::Result,
scalar::{PrimitiveScalar, Utf8Scalar},
scalar::{new_scalar, PrimitiveScalar, Scalar, UnionScalar, Utf8Scalar},
};

/// Test helper: pulls the next scalar out of `iter`, downcasts it to the
/// concrete type `T` and returns an owned copy.
///
/// Panics if the iterator is exhausted or if the item is not a `T`.
fn next_unchecked<T, I>(iter: &mut I) -> T
where
    I: Iterator<Item = Box<dyn Scalar>>,
    T: Clone + 'static,
{
    let boxed = iter.next().unwrap();
    let concrete = boxed.as_any().downcast_ref::<T>().unwrap();
    concrete.clone()
}

#[test]
fn sparse_debug() -> Result<()> {
let fields = vec![
Expand Down Expand Up @@ -94,30 +107,15 @@ fn iter_sparse() -> Result<()> {
let mut iter = array.iter();

assert_eq!(
iter.next()
.unwrap()
.as_any()
.downcast_ref::<PrimitiveScalar<i32>>()
.unwrap()
.value(),
next_unchecked::<PrimitiveScalar<i32>, _>(&mut iter).value(),
Some(1)
);
assert_eq!(
iter.next()
.unwrap()
.as_any()
.downcast_ref::<PrimitiveScalar<i32>>()
.unwrap()
.value(),
next_unchecked::<PrimitiveScalar<i32>, _>(&mut iter).value(),
None
);
assert_eq!(
iter.next()
.unwrap()
.as_any()
.downcast_ref::<Utf8Scalar<i32>>()
.unwrap()
.value(),
next_unchecked::<Utf8Scalar<i32>, _>(&mut iter).value(),
Some("c")
);
assert_eq!(iter.next(), None);
Expand All @@ -143,30 +141,15 @@ fn iter_dense() -> Result<()> {
let mut iter = array.iter();

assert_eq!(
iter.next()
.unwrap()
.as_any()
.downcast_ref::<PrimitiveScalar<i32>>()
.unwrap()
.value(),
next_unchecked::<PrimitiveScalar<i32>, _>(&mut iter).value(),
Some(1)
);
assert_eq!(
iter.next()
.unwrap()
.as_any()
.downcast_ref::<PrimitiveScalar<i32>>()
.unwrap()
.value(),
next_unchecked::<PrimitiveScalar<i32>, _>(&mut iter).value(),
None
);
assert_eq!(
iter.next()
.unwrap()
.as_any()
.downcast_ref::<Utf8Scalar<i32>>()
.unwrap()
.value(),
next_unchecked::<Utf8Scalar<i32>, _>(&mut iter).value(),
Some("c")
);
assert_eq!(iter.next(), None);
Expand All @@ -192,12 +175,7 @@ fn iter_sparse_slice() -> Result<()> {
let mut iter = array_slice.iter();

assert_eq!(
iter.next()
.unwrap()
.as_any()
.downcast_ref::<PrimitiveScalar<i32>>()
.unwrap()
.value(),
next_unchecked::<PrimitiveScalar<i32>, _>(&mut iter).value(),
Some(3)
);
assert_eq!(iter.next(), None);
Expand All @@ -224,15 +202,67 @@ fn iter_dense_slice() -> Result<()> {
let mut iter = array_slice.iter();

assert_eq!(
iter.next()
next_unchecked::<PrimitiveScalar<i32>, _>(&mut iter).value(),
Some(3)
);
assert_eq!(iter.next(), None);

Ok(())
}

// Checks that `new_scalar` on a dense `UnionArray` yields a `UnionScalar`
// carrying both the expected inner value and the expected type id.
//
// Array under test:
//   types   = [0, 0, 1]
//   offsets = [0, 1, 0]
//   field 0 ("a", Int32) = [Some(1), None]
//   field 1 ("b", Utf8)  = [Some("c")]
#[test]
fn scalar() -> Result<()> {
    let fields = vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Utf8, true),
    ];
    let data_type = DataType::Union(fields, None, UnionMode::Dense);
    let types = Buffer::from_slice([0, 0, 1]);
    let offsets = Buffer::<i32>::from_slice([0, 1, 0]);
    let fields = vec![
        Arc::new(Int32Array::from(&[Some(1), None])) as Arc<dyn Array>,
        Arc::new(Utf8Array::<i32>::from(&[Some("c")])) as Arc<dyn Array>,
    ];

    // `fields` is not used again, so it is moved rather than cloned.
    let array = UnionArray::from_data(data_type, types, fields, Some(offsets));

    // Index 0: Int32 value `Some(1)`, type id 0.
    let scalar = new_scalar(&array, 0);
    let union_scalar = scalar.as_any().downcast_ref::<UnionScalar>().unwrap();
    assert_eq!(
        union_scalar
            .value()
            .as_any()
            .downcast_ref::<PrimitiveScalar<i32>>()
            .unwrap()
            .value(),
        Some(1)
    );
    assert_eq!(union_scalar.type_(), 0);

    // Index 1: Int32 null, type id 0.
    let scalar = new_scalar(&array, 1);
    let union_scalar = scalar.as_any().downcast_ref::<UnionScalar>().unwrap();
    assert_eq!(
        union_scalar
            .value()
            .as_any()
            .downcast_ref::<PrimitiveScalar<i32>>()
            .unwrap()
            .value(),
        None
    );
    assert_eq!(union_scalar.type_(), 0);

    // Index 2: Utf8 value `Some("c")`, type id 1.
    let scalar = new_scalar(&array, 2);
    let union_scalar = scalar.as_any().downcast_ref::<UnionScalar>().unwrap();
    assert_eq!(
        union_scalar
            .value()
            .as_any()
            .downcast_ref::<Utf8Scalar<i32>>()
            .unwrap()
            .value(),
        Some("c")
    );
    assert_eq!(union_scalar.type_(), 1);

    Ok(())
}

0 comments on commit 9a38663

Please sign in to comment.