diff --git a/examples/parquet_write.rs b/examples/parquet_write.rs
index ad5487ebfa1..ecbbe486f24 100644
--- a/examples/parquet_write.rs
+++ b/examples/parquet_write.rs
@@ -7,7 +7,8 @@ use arrow2::{
     datatypes::{Field, Schema},
     error::Result,
     io::parquet::write::{
-        CompressionOptions, Encoding, FileWriter, RowGroupIterator, Version, WriteOptions,
+        transverse, CompressionOptions, Encoding, FileWriter, RowGroupIterator, Version,
+        WriteOptions,
     },
 };
 
@@ -20,12 +21,13 @@ fn write_batch(path: &str, schema: Schema, columns: Chunk<Box<dyn Array>>) -> Result<()> {
 
     let iter = vec![Ok(columns)];
 
-    let row_groups = RowGroupIterator::try_new(
-        iter.into_iter(),
-        &schema,
-        options,
-        vec![vec![Encoding::Plain]],
-    )?;
+    let encodings = schema
+        .fields
+        .iter()
+        .map(|f| transverse(&f.data_type, |_| Encoding::Plain))
+        .collect();
+
+    let row_groups = RowGroupIterator::try_new(iter.into_iter(), &schema, options, encodings)?;
 
     // Create a new empty file
     let file = File::create(path)?;
diff --git a/src/io/parquet/read/deserialize/mod.rs b/src/io/parquet/read/deserialize/mod.rs
index d252d6ac781..cfc5f2034df 100644
--- a/src/io/parquet/read/deserialize/mod.rs
+++ b/src/io/parquet/read/deserialize/mod.rs
@@ -291,7 +291,7 @@ where
                 .map(|f| {
                     let mut init = init.clone();
                     init.push(InitNested::Struct(field.is_nullable));
-                    let n = n_columns(f);
+                    let n = n_columns(&f.data_type);
                     let columns = columns.drain(columns.len() - n..).collect();
                     let types = types.drain(types.len() - n..).collect();
                     columns_to_iter_recursive(columns, types, f.clone(), init, chunk_size)
@@ -304,26 +304,27 @@ where
     })
 }
 
-fn n_columns(field: &Field) -> usize {
+/// Returns the number of (parquet) columns that a [`DataType`] contains.
+fn n_columns(data_type: &DataType) -> usize {
     use crate::datatypes::PhysicalType::*;
-    match field.data_type.to_physical_type() {
+    match data_type.to_physical_type() {
         Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8
         | Dictionary(_) | LargeUtf8 => 1,
         List | FixedSizeList | LargeList => {
-            let a = field.data_type().to_logical_type();
+            let a = data_type.to_logical_type();
             if let DataType::List(inner) = a {
-                n_columns(inner)
+                n_columns(&inner.data_type)
             } else if let DataType::LargeList(inner) = a {
-                n_columns(inner)
+                n_columns(&inner.data_type)
             } else if let DataType::FixedSizeList(inner, _) = a {
-                n_columns(inner)
+                n_columns(&inner.data_type)
             } else {
                 unreachable!()
             }
         }
         Struct => {
-            if let DataType::Struct(fields) = field.data_type.to_logical_type() {
-                fields.iter().map(n_columns).sum()
+            if let DataType::Struct(fields) = data_type.to_logical_type() {
+                fields.iter().map(|inner| n_columns(&inner.data_type)).sum()
             } else {
                 unreachable!()
             }
diff --git a/src/io/parquet/write/mod.rs b/src/io/parquet/write/mod.rs
index 7e2068cd06e..e29c3022fd1 100644
--- a/src/io/parquet/write/mod.rs
+++ b/src/io/parquet/write/mod.rs
@@ -4,7 +4,6 @@ mod boolean;
 mod dictionary;
 mod file;
 mod fixed_len_bytes;
-//mod levels;
 mod nested;
 mod pages;
 mod primitive;
@@ -448,3 +447,60 @@ fn array_to_page_nested(
     }
     .map(EncodedPage::Data)
 }
+
+fn transverse_recursive<T, F: Fn(&DataType) -> T + Clone>(
+    data_type: &DataType,
+    map: F,
+    encodings: &mut Vec<T>,
+) {
+    use crate::datatypes::PhysicalType::*;
+    match data_type.to_physical_type() {
+        Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8
+        | Dictionary(_) | LargeUtf8 => encodings.push(map(data_type)),
+        List | FixedSizeList | LargeList => {
+            let a = data_type.to_logical_type();
+            if let DataType::List(inner) = a {
+                transverse_recursive(&inner.data_type, map, encodings)
+            } else if let DataType::LargeList(inner) = a {
+                transverse_recursive(&inner.data_type, map, encodings)
+            } else if let DataType::FixedSizeList(inner, _) = a {
+                transverse_recursive(&inner.data_type, map, encodings)
+            } else {
+                unreachable!()
+            }
+        }
+        Struct => {
+            if let DataType::Struct(fields) = data_type.to_logical_type() {
+                for field in fields {
+                    transverse_recursive(&field.data_type, map.clone(), encodings)
+                }
+            } else {
+                unreachable!()
+            }
+        }
+        Union => todo!(),
+        Map => todo!(),
+    }
+}
+
+/// Transverses the `data_type` up to its (parquet) columns and returns a vector of
+/// items based on `map`.
+/// This is used to assign an [`Encoding`] to every parquet column based on the columns' type (see example).
+/// # Example
+/// ```
+/// use arrow2::io::parquet::write::{transverse, Encoding};
+/// use arrow2::datatypes::{DataType, Field};
+///
+/// let dt = DataType::Struct(vec![
+///     Field::new("a", DataType::Int64, true),
+///     Field::new("b", DataType::List(Box::new(Field::new("item", DataType::Int32, true))), true),
+/// ]);
+///
+/// let encodings = transverse(&dt, |_| Encoding::Plain);
+/// assert_eq!(encodings, vec![Encoding::Plain, Encoding::Plain]);
+/// ```
+pub fn transverse<T, F: Fn(&DataType) -> T + Clone>(data_type: &DataType, map: F) -> Vec<T> {
+    let mut encodings = vec![];
+    transverse_recursive(data_type, map, &mut encodings);
+    encodings
+}
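For reviewers: below is a minimal sketch (not part of the diff) of how `transverse` behaves on a type with more than one leaf column. It is grounded in the doctest added above; the three-field struct is illustrative only.

```rust
use arrow2::datatypes::{DataType, Field};
use arrow2::io::parquet::write::{transverse, Encoding};

fn main() {
    // A struct whose leaves map to three parquet columns:
    // `a` (Int64), the Int32 item inside list `b`, and `c` (Utf8).
    let dt = DataType::Struct(vec![
        Field::new("a", DataType::Int64, true),
        Field::new(
            "b",
            DataType::List(Box::new(Field::new("item", DataType::Int32, true))),
            true,
        ),
        Field::new("c", DataType::Utf8, true),
    ]);

    // `transverse` walks the type depth-first and calls `map` once per
    // (parquet) leaf column, so the result has one entry per leaf.
    let encodings = transverse(&dt, |_| Encoding::Plain);
    assert_eq!(encodings.len(), 3);
}
```

Because `map` receives each leaf's `&DataType`, a caller can also pick a different encoding per type instead of `Encoding::Plain` everywhere; this is exactly how the updated `examples/parquet_write.rs` builds one `Vec<Encoding>` per top-level field before handing them to `RowGroupIterator::try_new`.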