diff --git a/src/array/null.rs b/src/array/null.rs index 10d90d37fd0..1c3e7dce26c 100644 --- a/src/array/null.rs +++ b/src/array/null.rs @@ -22,7 +22,7 @@ impl NullArray { pub fn try_new(data_type: DataType, length: usize) -> Result { if data_type.to_physical_type() != PhysicalType::Null { return Err(ArrowError::oos( - "BooleanArray can only be initialized with a DataType whose physical type is Boolean", + "NullArray can only be initialized with a DataType whose physical type is Boolean", )); } diff --git a/src/scalar/equal.rs b/src/scalar/equal.rs index e2cf20ee4f7..3e7e2a620b9 100644 --- a/src/scalar/equal.rs +++ b/src/scalar/equal.rs @@ -53,7 +53,7 @@ fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { Struct => dyn_eq!(StructScalar, lhs, rhs), FixedSizeBinary => dyn_eq!(FixedSizeBinaryScalar, lhs, rhs), FixedSizeList => dyn_eq!(FixedSizeListScalar, lhs, rhs), - Union => unimplemented!("{:?}", Union), + Union => dyn_eq!(UnionScalar, lhs, rhs), Map => unimplemented!("{:?}", Map), } } diff --git a/src/scalar/mod.rs b/src/scalar/mod.rs index 44a909c990f..c7394af0097 100644 --- a/src/scalar/mod.rs +++ b/src/scalar/mod.rs @@ -25,6 +25,8 @@ mod fixed_size_list; pub use fixed_size_list::*; mod fixed_size_binary; pub use fixed_size_binary::*; +mod union; +pub use union::UnionScalar; /// Trait object declaring an optional value with a [`DataType`]. /// This strait is often used in APIs that accept multiple scalar types. @@ -144,7 +146,15 @@ pub fn new_scalar(array: &dyn Array, index: usize) -> Box { }; Box::new(FixedSizeListScalar::new(array.data_type().clone(), value)) } - Union | Map => todo!(), + Union => { + let array = array.as_any().downcast_ref::().unwrap(); + Box::new(UnionScalar::new( + array.data_type().clone(), + array.types()[index], + array.value(index).into(), + )) + } + Map => todo!(), Dictionary(key_type) => match_integer_type!(key_type, |$T| { let array = array .as_any() diff --git a/src/scalar/union.rs b/src/scalar/union.rs new file mode 100644 index 00000000000..fc273457b32 --- /dev/null +++ b/src/scalar/union.rs @@ -0,0 +1,54 @@ +use std::sync::Arc; + +use crate::datatypes::DataType; + +use super::Scalar; + +/// A single entry of a [`crate::array::UnionArray`]. +#[derive(Debug, Clone, PartialEq)] +pub struct UnionScalar { + value: Arc, + type_: i8, + data_type: DataType, +} + +impl UnionScalar { + /// Returns a new [`UnionScalar`] + #[inline] + pub fn new(data_type: DataType, type_: i8, value: Arc) -> Self { + Self { + value, + type_, + data_type, + } + } + + /// Returns the inner value + #[inline] + pub fn value(&self) -> &Arc { + &self.value + } + + /// Returns the type of the union scalar + #[inline] + pub fn type_(&self) -> i8 { + self.type_ + } +} + +impl Scalar for UnionScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + true + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/tests/it/array/union.rs b/tests/it/array/union.rs index 9045c560681..b52b5ff4fd7 100644 --- a/tests/it/array/union.rs +++ b/tests/it/array/union.rs @@ -5,9 +5,22 @@ use arrow2::{ buffer::Buffer, datatypes::*, error::Result, - scalar::{PrimitiveScalar, Utf8Scalar}, + scalar::{new_scalar, PrimitiveScalar, Scalar, UnionScalar, Utf8Scalar}, }; +fn next_unchecked(iter: &mut I) -> T +where + I: Iterator>, + T: Clone + 'static, +{ + iter.next() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .clone() +} + #[test] fn sparse_debug() -> Result<()> { let fields = vec![ @@ -94,30 +107,15 @@ fn iter_sparse() -> Result<()> { let mut iter = array.iter(); assert_eq!( - iter.next() - .unwrap() - .as_any() - .downcast_ref::>() - .unwrap() - .value(), + next_unchecked::, _>(&mut iter).value(), Some(1) ); assert_eq!( - iter.next() - .unwrap() - .as_any() - .downcast_ref::>() - .unwrap() - .value(), + next_unchecked::, _>(&mut iter).value(), None ); assert_eq!( - iter.next() - .unwrap() - .as_any() - .downcast_ref::>() - .unwrap() - .value(), + next_unchecked::, _>(&mut iter).value(), Some("c") ); assert_eq!(iter.next(), None); @@ -143,30 +141,15 @@ fn iter_dense() -> Result<()> { let mut iter = array.iter(); assert_eq!( - iter.next() - .unwrap() - .as_any() - .downcast_ref::>() - .unwrap() - .value(), + next_unchecked::, _>(&mut iter).value(), Some(1) ); assert_eq!( - iter.next() - .unwrap() - .as_any() - .downcast_ref::>() - .unwrap() - .value(), + next_unchecked::, _>(&mut iter).value(), None ); assert_eq!( - iter.next() - .unwrap() - .as_any() - .downcast_ref::>() - .unwrap() - .value(), + next_unchecked::, _>(&mut iter).value(), Some("c") ); assert_eq!(iter.next(), None); @@ -192,12 +175,7 @@ fn iter_sparse_slice() -> Result<()> { let mut iter = array_slice.iter(); assert_eq!( - iter.next() - .unwrap() - .as_any() - .downcast_ref::>() - .unwrap() - .value(), + next_unchecked::, _>(&mut iter).value(), Some(3) ); assert_eq!(iter.next(), None); @@ -224,15 +202,67 @@ fn iter_dense_slice() -> Result<()> { let mut iter = array_slice.iter(); assert_eq!( - iter.next() + next_unchecked::, _>(&mut iter).value(), + Some(3) + ); + assert_eq!(iter.next(), None); + + Ok(()) +} + +#[test] +fn scalar() -> Result<()> { + let fields = vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]; + let data_type = DataType::Union(fields, None, UnionMode::Dense); + let types = Buffer::from_slice([0, 0, 1]); + let offsets = Buffer::::from_slice([0, 1, 0]); + let fields = vec![ + Arc::new(Int32Array::from(&[Some(1), None])) as Arc, + Arc::new(Utf8Array::::from(&[Some("c")])) as Arc, + ]; + + let array = UnionArray::from_data(data_type, types, fields.clone(), Some(offsets)); + + let scalar = new_scalar(&array, 0); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() + .as_any() + .downcast_ref::>() .unwrap() + .value(), + Some(1) + ); + assert_eq!(union_scalar.type_(), 0); + let scalar = new_scalar(&array, 1); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() .as_any() .downcast_ref::>() .unwrap() .value(), - Some(3) + None ); - assert_eq!(iter.next(), None); + assert_eq!(union_scalar.type_(), 0); + + let scalar = new_scalar(&array, 2); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() + .as_any() + .downcast_ref::>() + .unwrap() + .value(), + Some("c") + ); + assert_eq!(union_scalar.type_(), 1); Ok(()) }