diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index e6ec3223d0e..5260536342a 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -104,7 +104,7 @@ def test_decimal_roundtrip(self): data = [ round(decimal.Decimal(722.82), 2), round(decimal.Decimal(-934.11), 2), - None + None, ] a = pyarrow.array(data, pyarrow.decimal128(5, 2)) b = arrow_pyarrow_integration_testing.round_trip(a) @@ -179,3 +179,17 @@ def test_list_list_array(self): b.validate(full=True) assert a.to_pylist() == b.to_pylist() assert a.type == b.type + + def test_dict(self): + """ + Python -> Rust -> Python + """ + a = pyarrow.array( + ["a", "a", "b", None, "c"], + pyarrow.dictionary(pyarrow.int64(), pyarrow.utf8()), + ) + b = arrow_pyarrow_integration_testing.round_trip(a) + + b.validate(full=True) + assert a.to_pylist() == b.to_pylist() + assert a.type == b.type diff --git a/src/array/ffi.rs b/src/array/ffi.rs index 123d925d36b..e707f98d6b0 100644 --- a/src/array/ffi.rs +++ b/src/array/ffi.rs @@ -29,13 +29,28 @@ pub unsafe trait FromFfi: Sized { macro_rules! ffi_dyn { ($array:expr, $ty:ty) => {{ let array = $array.as_any().downcast_ref::<$ty>().unwrap(); - (array.buffers(), array.children()) + (array.buffers(), array.children(), None) }}; } -type BuffersChildren = (Vec>>, Vec>); +macro_rules! ffi_dict_dyn { + ($array:expr, $ty:ty) => {{ + let array = $array.as_any().downcast_ref::<$ty>().unwrap(); + ( + array.buffers(), + array.children(), + Some(array.values().clone()), + ) + }}; +} + +type BuffersChildren = ( + Vec>>, + Vec>, + Option>, +); -pub fn buffers_children(array: &dyn Array) -> BuffersChildren { +pub fn buffers_children_dictionary(array: &dyn Array) -> BuffersChildren { match array.data_type() { DataType::Null => ffi_dyn!(array, NullArray), DataType::Boolean => ffi_dyn!(array, BooleanArray), @@ -72,14 +87,14 @@ pub fn buffers_children(array: &dyn Array) -> BuffersChildren { DataType::Struct(_) => ffi_dyn!(array, StructArray), DataType::Union(_) => unimplemented!(), DataType::Dictionary(key_type, _) => match key_type.as_ref() { - DataType::Int8 => ffi_dyn!(array, DictionaryArray::), - DataType::Int16 => ffi_dyn!(array, DictionaryArray::), - DataType::Int32 => ffi_dyn!(array, DictionaryArray::), - DataType::Int64 => ffi_dyn!(array, DictionaryArray::), - DataType::UInt8 => ffi_dyn!(array, DictionaryArray::), - DataType::UInt16 => ffi_dyn!(array, DictionaryArray::), - DataType::UInt32 => ffi_dyn!(array, DictionaryArray::), - DataType::UInt64 => ffi_dyn!(array, DictionaryArray::), + DataType::Int8 => ffi_dict_dyn!(array, DictionaryArray::), + DataType::Int16 => ffi_dict_dyn!(array, DictionaryArray::), + DataType::Int32 => ffi_dict_dyn!(array, DictionaryArray::), + DataType::Int64 => ffi_dict_dyn!(array, DictionaryArray::), + DataType::UInt8 => ffi_dict_dyn!(array, DictionaryArray::), + DataType::UInt16 => ffi_dict_dyn!(array, DictionaryArray::), + DataType::UInt32 => ffi_dict_dyn!(array, DictionaryArray::), + DataType::UInt64 => ffi_dict_dyn!(array, DictionaryArray::), _ => unreachable!(), }, } diff --git a/src/array/mod.rs b/src/array/mod.rs index 772c49b46c8..3505a6204a6 100644 --- a/src/array/mod.rs +++ b/src/array/mod.rs @@ -401,7 +401,7 @@ pub use specification::Offset; pub use struct_::StructArray; pub use utf8::{MutableUtf8Array, Utf8Array, Utf8ValuesIter}; -pub(crate) use self::ffi::buffers_children; +pub(crate) use self::ffi::buffers_children_dictionary; pub use self::ffi::FromFfi; pub use self::ffi::ToFfi; diff --git a/src/ffi/array.rs b/src/ffi/array.rs index cd1e1852f7f..88575b60f50 100644 --- a/src/ffi/array.rs +++ b/src/ffi/array.rs @@ -67,13 +67,15 @@ pub fn try_from(array: A) -> Result> { DataType::LargeList(_) => Box::new(ListArray::::try_from_ffi(array)?), DataType::Struct(_) => Box::new(StructArray::try_from_ffi(array)?), DataType::Dictionary(keys, _) => match keys.as_ref() { + DataType::Int8 => Box::new(DictionaryArray::::try_from_ffi(array)?), + DataType::Int16 => Box::new(DictionaryArray::::try_from_ffi(array)?), + DataType::Int32 => Box::new(DictionaryArray::::try_from_ffi(array)?), DataType::Int64 => Box::new(DictionaryArray::::try_from_ffi(array)?), - other => { - return Err(ArrowError::NotYetImplemented(format!( - "Reading dictionary of keys \"{}\" is not yet supported.", - other - ))) - } + DataType::UInt8 => Box::new(DictionaryArray::::try_from_ffi(array)?), + DataType::UInt16 => Box::new(DictionaryArray::::try_from_ffi(array)?), + DataType::UInt32 => Box::new(DictionaryArray::::try_from_ffi(array)?), + DataType::UInt64 => Box::new(DictionaryArray::::try_from_ffi(array)?), + _ => unreachable!(), }, data_type => { return Err(ArrowError::NotYetImplemented(format!( @@ -89,7 +91,6 @@ pub fn try_from(array: A) -> Result> { #[cfg(test)] mod tests { use super::*; - use crate::array::*; use crate::datatypes::TimeUnit; use crate::{error::Result, ffi}; use std::sync::Arc; @@ -209,7 +210,6 @@ mod tests { test_round_trip(array) } - /* #[test] fn test_dict() -> Result<()> { let data = vec![Some("a"), Some("a"), None, Some("b")]; @@ -221,5 +221,4 @@ mod tests { test_round_trip(array) } - */ } diff --git a/src/ffi/ffi.rs b/src/ffi/ffi.rs index fca89e4ae69..756b0b04da8 100644 --- a/src/ffi/ffi.rs +++ b/src/ffi/ffi.rs @@ -19,7 +19,7 @@ use std::{ptr::NonNull, sync::Arc}; use super::schema::{to_field, Ffi_ArrowSchema}; use crate::{ - array::{buffers_children, Array}, + array::{buffers_children_dictionary, Array}, bitmap::{utils::bytes_for, Bitmap}, buffer::{ bytes::{Bytes, Deallocation}, @@ -75,6 +75,10 @@ unsafe extern "C" fn c_release_array(array: *mut Ffi_ArrowArray) { let _ = Box::from_raw(*child); } + if let Some(ptr) = private.dictionary_ptr { + let _ = Box::from_raw(ptr); + } + array.release = None; } @@ -83,6 +87,7 @@ struct PrivateData { array: Arc, buffers_ptr: Box<[*const std::os::raw::c_void]>, children_ptr: Box<[*mut Ffi_ArrowArray]>, + dictionary_ptr: Option<*mut Ffi_ArrowArray>, } impl Ffi_ArrowArray { @@ -91,7 +96,7 @@ impl Ffi_ArrowArray { /// This method releases `buffers`. Consumers of this struct *must* call `release` before /// releasing this struct, or contents in `buffers` leak. fn new(array: Arc) -> Self { - let (buffers, children) = buffers_children(array.as_ref()); + let (buffers, children, dictionary) = buffers_children_dictionary(array.as_ref()); let buffers_ptr = buffers .iter() @@ -109,6 +114,9 @@ impl Ffi_ArrowArray { .collect::>(); let n_children = children_ptr.len() as i64; + let dictionary_ptr = + dictionary.map(|array| Box::into_raw(Box::new(Ffi_ArrowArray::new(array)))); + let length = array.len() as i64; let null_count = array.null_count() as i64; @@ -116,6 +124,7 @@ impl Ffi_ArrowArray { array, buffers_ptr, children_ptr, + dictionary_ptr, }); Self { @@ -126,7 +135,7 @@ impl Ffi_ArrowArray { n_children, buffers: private_data.buffers_ptr.as_mut_ptr(), children: private_data.children_ptr.as_mut_ptr(), - dictionary: std::ptr::null_mut(), + dictionary: private_data.dictionary_ptr.unwrap_or(std::ptr::null_mut()), release: Some(c_release_array), private_data: Box::into_raw(private_data) as *mut ::std::os::raw::c_void, } diff --git a/src/ffi/schema.rs b/src/ffi/schema.rs index 3261432226e..b9edf4d544d 100644 --- a/src/ffi/schema.rs +++ b/src/ffi/schema.rs @@ -9,6 +9,7 @@ use crate::{ struct SchemaPrivateData { field: Field, children_ptr: Box<[*mut Ffi_ArrowSchema]>, + dictionary: Option<*mut Ffi_ArrowSchema>, } /// ABI-compatible struct for `ArrowSchema` from C Data Interface @@ -43,6 +44,10 @@ unsafe extern "C" fn c_release_schema(schema: *mut Ffi_ArrowSchema) { let _ = Box::from_raw(*child); } + if let Some(ptr) = private.dictionary { + let _ = Box::from_raw(ptr); + } + schema.release = None; } @@ -75,9 +80,18 @@ impl Ffi_ArrowSchema { let flags = field.is_nullable() as i64 * 2; + let dictionary = if let DataType::Dictionary(_, values) = field.data_type() { + // we do not store field info in the dict values, so can't recover it all :( + let field = Field::new("item", values.as_ref().clone(), true); + Some(Box::new(Ffi_ArrowSchema::try_new(field)?)) + } else { + None + }; + let mut private = Box::new(SchemaPrivateData { field, children_ptr, + dictionary: dictionary.map(Box::into_raw), }); // @@ -88,7 +102,7 @@ impl Ffi_ArrowSchema { flags, n_children, children: private.children_ptr.as_mut_ptr(), - dictionary: std::ptr::null_mut(), + dictionary: private.dictionary.unwrap_or(std::ptr::null_mut()), release: Some(c_release_schema), private_data: Box::into_raw(private) as *mut ::std::os::raw::c_void, })