diff --git a/src/array_decoder/decimal.rs b/src/array_decoder/decimal.rs index 9ecf274..24634cf 100644 --- a/src/array_decoder/decimal.rs +++ b/src/array_decoder/decimal.rs @@ -101,6 +101,10 @@ impl ArrayBatchDecoder for DecimalArrayDecoder { let array = Arc::new(array) as ArrayRef; Ok(array) } + + fn skip_values(&mut self, n: usize, parent_present: Option<&NullBuffer>) -> Result<()> { + self.inner.skip_values(n, parent_present) + } } /// This iter fixes the scales of the varints decoded as scale is specified on a per @@ -112,6 +116,12 @@ struct DecimalScaleRepairDecoder { } impl PrimitiveValueDecoder for DecimalScaleRepairDecoder { + fn skip(&mut self, n: usize) -> Result<()> { + self.varint_iter.skip(n)?; + self.scale_iter.skip(n)?; + Ok(()) + } + fn decode(&mut self, out: &mut [i128]) -> Result<()> { // TODO: can probably optimize, reuse buffers? let mut varint = vec![0; out.len()]; diff --git a/src/array_decoder/list.rs b/src/array_decoder/list.rs index 34d40d0..6e13147 100644 --- a/src/array_decoder/list.rs +++ b/src/array_decoder/list.rs @@ -85,4 +85,21 @@ impl ArrayBatchDecoder for ListArrayDecoder { let array = Arc::new(array); Ok(array) } + + fn skip_values(&mut self, n: usize, parent_present: Option<&NullBuffer>) -> Result<()> { + use super::skip_present_and_get_non_null_count; + + let non_null_count = + skip_present_and_get_non_null_count(&mut self.present, parent_present, n)?; + + // Decode lengths to determine how many child values to skip + let mut lengths = vec![0; non_null_count]; + self.lengths.decode(&mut lengths)?; + let total_length: i64 = lengths.iter().sum(); + + // Skip the child values (children don't have parent_present from list) + self.inner.skip_values(total_length as usize, None)?; + + Ok(()) + } } diff --git a/src/array_decoder/map.rs b/src/array_decoder/map.rs index 175b09a..8ccd810 100644 --- a/src/array_decoder/map.rs +++ b/src/array_decoder/map.rs @@ -102,4 +102,22 @@ impl ArrayBatchDecoder for MapArrayDecoder { let array = 
Arc::new(array); Ok(array) } + + fn skip_values(&mut self, n: usize, parent_present: Option<&NullBuffer>) -> Result<()> { + use super::skip_present_and_get_non_null_count; + + let non_null_count = + skip_present_and_get_non_null_count(&mut self.present, parent_present, n)?; + + // Decode lengths to determine how many entries to skip + let mut lengths = vec![0; non_null_count]; + self.lengths.decode(&mut lengths)?; + let total_length: i64 = lengths.iter().sum(); + + // Skip both keys and values (they don't have parent_present from map) + self.keys.skip_values(total_length as usize, None)?; + self.values.skip_values(total_length as usize, None)?; + + Ok(()) + } } diff --git a/src/array_decoder/mod.rs b/src/array_decoder/mod.rs index afbdaa0..6074b53 100644 --- a/src/array_decoder/mod.rs +++ b/src/array_decoder/mod.rs @@ -75,6 +75,13 @@ pub trait ArrayBatchDecoder: Send { batch_size: usize, parent_present: Option<&NullBuffer>, ) -> Result; + + /// Skip the next `n` values without decoding them, failing if it cannot skip the enough values. + /// If parent nested type (e.g. Struct) indicates a null in it's PRESENT stream, + /// then the child doesn't have a value (similar to other nullability). So we need + /// to take care to insert these null values as Arrow requires the child to hold + /// data in the null slot of the child. 
+ fn skip_values(&mut self, n: usize, parent_present: Option<&NullBuffer>) -> Result<()>; } struct PrimitiveArrayDecoder { @@ -123,6 +130,12 @@ impl ArrayBatchDecoder for PrimitiveArrayDecoder { let array = Arc::new(array) as ArrayRef; Ok(array) } + + fn skip_values(&mut self, n: usize, parent_present: Option<&NullBuffer>) -> Result<()> { + let non_null_count = + skip_present_and_get_non_null_count(&mut self.present, parent_present, n)?; + self.iter.skip(non_null_count) + } } type Int64ArrayDecoder = PrimitiveArrayDecoder; @@ -168,6 +181,12 @@ impl ArrayBatchDecoder for BooleanArrayDecoder { }; Ok(Arc::new(array)) } + + fn skip_values(&mut self, n: usize, parent_present: Option<&NullBuffer>) -> Result<()> { + let non_null_count = + skip_present_and_get_non_null_count(&mut self.present, parent_present, n)?; + self.iter.skip(non_null_count) + } } struct PresentDecoder { @@ -232,6 +251,42 @@ fn derive_present_vec( } } +/// Skip n values and return the non-null count for the data stream +fn skip_present_and_get_non_null_count( + present: &mut Option, + parent_present: Option<&NullBuffer>, + n: usize, +) -> Result { + match (present, parent_present) { + (Some(present), Some(parent_present)) => { + // Parent has nulls, so we need to decode parent present to know how many + // of our present values to skip + let non_null_in_parent = parent_present.len() - parent_present.null_count(); + + // Skip our present values for non-null parents and count non-nulls + let mut our_present = vec![false; non_null_in_parent]; + present.inner.decode(&mut our_present)?; + let our_non_null_count = our_present.iter().filter(|&&v| v).count(); + + Ok(our_non_null_count) + } + (Some(present), None) => { + // No parent present, skip n values and count non-nulls + let mut present_values = vec![false; n]; + present.inner.decode(&mut present_values)?; + Ok(present_values.iter().filter(|&&v| v).count()) + } + (None, Some(parent_present)) => { + // No our present stream, all non-null parents have 
data + Ok(parent_present.len() - parent_present.null_count()) + } + (None, None) => { + // No nulls at all, all n values have data + Ok(n) + } + } +} + pub struct NaiveStripeDecoder { stripe: Stripe, schema_ref: SchemaRef, @@ -243,56 +298,81 @@ pub struct NaiveStripeDecoder { selection_index: usize, } +impl NaiveStripeDecoder { + /// Advance according to the configured row selection and return the next batch, if any. + /// + /// Behavior: + /// - Iterates `RowSelection` segments (skip/select) starting at `selection_index`. + /// - For skip segments: clamp to remaining rows in this stripe, advance decoders via + /// `skip_rows(actual_skip)`, and advance `index`. If the segment is fully consumed, + /// increment `selection_index`. + /// - For select segments: decode up to `min(row_count, batch_size, remaining_in_stripe)`, + /// advance `index`, update `selection_index` if fully consumed, and return the batch. + /// - If a segment requests rows beyond the end of the stripe, it is skipped (advancing + /// `selection_index`) without touching decoders. 
+ fn next_with_row_selection(&mut self) -> Option> { + // Process selectors until we produce a batch or exhaust selection + loop { + let (is_skip, row_count) = { + let selectors = self.row_selection.as_ref().unwrap().selectors(); + if self.selection_index >= selectors.len() { + return None; + } + let selector = selectors[self.selection_index]; + (selector.skip, selector.row_count) + }; + + if is_skip { + let remaining = self.number_of_rows - self.index; + let actual_skip = row_count.min(remaining); + + if actual_skip == 0 { + // Nothing to skip in this stripe; try next selector + self.selection_index += 1; + continue; + } + + // Keep decoders in sync by skipping values per column + if let Err(e) = self.skip_rows(actual_skip) { + return Some(Err(e)); + } + self.index += actual_skip; + + if actual_skip >= row_count { + self.selection_index += 1; + } + } else { + let rows_to_read = row_count.min(self.batch_size); + let remaining = self.number_of_rows - self.index; + let actual_rows = rows_to_read.min(remaining); + + if actual_rows == 0 { + // Nothing to read from this selector in this stripe; advance selector + self.selection_index += 1; + continue; + } + + let record = self.decode_next_batch(actual_rows).transpose()?; + self.index += actual_rows; + + if actual_rows >= row_count { + self.selection_index += 1; + } + return Some(record); + } + } + } +} + impl Iterator for NaiveStripeDecoder { type Item = Result; + // TODO: check if we can make this more efficient fn next(&mut self) -> Option { if self.index < self.number_of_rows { // Handle row selection if present if self.row_selection.is_some() { - // Process selectors until we find rows to select or exhaust the selection - loop { - let (is_skip, row_count) = { - // Safety: this has been checked above - let selectors = self.row_selection.as_ref().unwrap().selectors(); - if self.selection_index >= selectors.len() { - return None; - } - let selector = selectors[self.selection_index]; - (selector.skip, 
selector.row_count) - }; - - if is_skip { - // Skip these rows by advancing the index - self.index += row_count; - self.selection_index += 1; - - // Decode and discard the skipped rows to advance the internal decoders - if let Err(e) = self.skip_rows(row_count) { - return Some(Err(e)); - } - } else { - // Select these rows - let rows_to_read = row_count.min(self.batch_size); - let remaining = self.number_of_rows - self.index; - let actual_rows = rows_to_read.min(remaining); - - if actual_rows == 0 { - self.selection_index += 1; - continue; - } - - let record = self.decode_next_batch(actual_rows).transpose()?; - self.index += actual_rows; - - // Update selector to track progress - if actual_rows >= row_count { - self.selection_index += 1; - } - - return Some(record); - } - } + self.next_with_row_selection() } else { // No row selection - decode normally let record = self @@ -513,14 +593,12 @@ impl NaiveStripeDecoder { }) } - /// Skip the specified number of rows by decoding and discarding them + /// Skip the specified number of rows by calling skip_values on each decoder fn skip_rows(&mut self, count: usize) -> Result<()> { - // Decode in batches to avoid large memory allocations - let mut remaining = count; - while remaining > 0 { - let chunk = self.batch_size.min(remaining); - let _ = self.inner_decode_next_batch(chunk)?; - remaining -= chunk; + // Call skip_values on each decoder to efficiently skip rows + // Top-level decoders don't have parent_present + for decoder in &mut self.decoders { + decoder.skip_values(count, None)?; } Ok(()) } diff --git a/src/array_decoder/string.rs b/src/array_decoder/string.rs index 0518990..7da4fdb 100644 --- a/src/array_decoder/string.rs +++ b/src/array_decoder/string.rs @@ -155,6 +155,28 @@ impl ArrayBatchDecoder for GenericByteArrayDecoder { let array = Arc::new(array) as ArrayRef; Ok(array) } + + fn skip_values(&mut self, n: usize, parent_present: Option<&NullBuffer>) -> Result<()> { + use 
crate::array_decoder::skip_present_and_get_non_null_count; + + let non_null_count = + skip_present_and_get_non_null_count(&mut self.present, parent_present, n)?; + + // Decode lengths to determine how many bytes to skip + let mut lengths = vec![0; non_null_count]; + self.lengths.decode(&mut lengths)?; + let total_bytes: i64 = lengths.iter().sum(); + + // Skip the data bytes + // TODO: can we use the decompressor to skip the bytes? + std::io::copy( + &mut self.bytes.by_ref().take(total_bytes as u64), + &mut std::io::sink(), + ) + .context(IoSnafu)?; + + Ok(()) + } } pub struct DictionaryStringArrayDecoder { @@ -192,4 +214,8 @@ impl ArrayBatchDecoder for DictionaryStringArrayDecoder { let array = Arc::new(array); Ok(array) } + + fn skip_values(&mut self, n: usize, parent_present: Option<&NullBuffer>) -> Result<()> { + self.indexes.skip_values(n, parent_present) + } } diff --git a/src/array_decoder/struct_decoder.rs b/src/array_decoder/struct_decoder.rs index 576e45f..3eac359 100644 --- a/src/array_decoder/struct_decoder.rs +++ b/src/array_decoder/struct_decoder.rs @@ -76,4 +76,19 @@ impl ArrayBatchDecoder for StructArrayDecoder { let array = Arc::new(array); Ok(array) } + + fn skip_values(&mut self, n: usize, parent_present: Option<&NullBuffer>) -> Result<()> { + use super::derive_present_vec; + + // Derive the combined present buffer like in next_batch + let present = derive_present_vec(&mut self.present, parent_present, n).transpose()?; + + // Skip values in all child decoders + // Pass the present buffer to children so they know which values to skip + for decoder in &mut self.decoders { + decoder.skip_values(n, present.as_ref())?; + } + + Ok(()) + } } diff --git a/src/array_decoder/timestamp.rs b/src/array_decoder/timestamp.rs index ee3bd8d..c2d7bdd 100644 --- a/src/array_decoder/timestamp.rs +++ b/src/array_decoder/timestamp.rs @@ -267,6 +267,10 @@ impl ArrayBatchDecoder for TimestampOffsetArrayDecoder let array = Arc::new(array) as ArrayRef; Ok(array) } + + fn 
skip_values(&mut self, n: usize, parent_present: Option<&NullBuffer>) -> Result<()> { + self.inner.skip_values(n, parent_present) + } } /// Wrapper around PrimitiveArrayDecoder to allow specifying the timezone of the output @@ -286,6 +290,10 @@ impl ArrayBatchDecoder for TimestampInstantArrayDecoder) -> Result<()> { + self.0.skip_values(n, parent_present) + } } struct TimestampNanosecondAsDecimalWithTzDecoder(TimestampNanosecondAsDecimalDecoder, Tz); @@ -308,6 +316,11 @@ impl TimestampNanosecondAsDecimalWithTzDecoder { } impl PrimitiveValueDecoder for TimestampNanosecondAsDecimalWithTzDecoder { + fn skip(&mut self, n: usize) -> Result<()> { + self.0.skip(n)?; + Ok(()) + } + fn decode(&mut self, out: &mut [i128]) -> Result<()> { self.0.decode(out)?; for x in out.iter_mut() { diff --git a/src/array_decoder/union.rs b/src/array_decoder/union.rs index a674832..47a45e3 100644 --- a/src/array_decoder/union.rs +++ b/src/array_decoder/union.rs @@ -134,4 +134,30 @@ impl ArrayBatchDecoder for UnionArrayDecoder { let array = Arc::new(array); Ok(array) } + + fn skip_values(&mut self, n: usize, parent_present: Option<&NullBuffer>) -> Result<()> { + use super::derive_present_vec; + + // Derive the combined present buffer like in next_batch + let present = derive_present_vec(&mut self.present, parent_present, n).transpose()?; + + // Determine how many non-null values we need to skip from tags stream + let non_null_count = if let Some(present) = &present { + present.len() - present.null_count() + } else { + n + }; + + // Skip tags (only non-null values have tags) + self.tags.skip(non_null_count)?; + + // Skip values in all variant decoders + // For sparse union, each variant stores n values regardless of which variant is active + // Pass the present buffer to children + for decoder in &mut self.variants { + decoder.skip_values(n, present.as_ref())?; + } + + Ok(()) + } } diff --git a/src/async_arrow_reader.rs b/src/async_arrow_reader.rs index 69c7a83..d1c5aae 100644 --- 
a/src/async_arrow_reader.rs +++ b/src/async_arrow_reader.rs @@ -32,6 +32,7 @@ use crate::arrow_reader::Cursor; use crate::error::Result; use crate::reader::metadata::read_metadata_async; use crate::reader::AsyncChunkReader; +use crate::row_selection::RowSelection; use crate::stripe::{Stripe, StripeMetadata}; use crate::ArrowReaderBuilder; @@ -77,6 +78,7 @@ pub struct ArrowStreamReader { factory: Option>>, batch_size: usize, schema_ref: SchemaRef, + row_selection: Option, state: StreamState, } @@ -124,11 +126,17 @@ impl StripeFactory { } impl ArrowStreamReader { - pub(crate) fn new(cursor: Cursor, batch_size: usize, schema_ref: SchemaRef) -> Self { + pub(crate) fn new( + cursor: Cursor, + batch_size: usize, + schema_ref: SchemaRef, + row_selection: Option, + ) -> Self { Self { factory: Some(Box::new(cursor.into())), batch_size, schema_ref, + row_selection, state: StreamState::Init, } } @@ -171,10 +179,22 @@ impl ArrowStreamReader { StreamState::Reading(f) => match ready!(f.poll_unpin(cx)) { Ok((factory, Some(stripe))) => { self.factory = Some(Box::new(factory)); - match NaiveStripeDecoder::new( + + // Split off the row selection for this stripe + let stripe_rows = stripe.number_of_rows(); + let selection = self.row_selection.as_mut().and_then(|s| { + if s.row_count() > 0 { + Some(s.split_off(stripe_rows)) + } else { + None + } + }); + + match NaiveStripeDecoder::new_with_selection( stripe, self.schema_ref.clone(), self.batch_size, + selection, ) { Ok(decoder) => { self.state = StreamState::Decoding(Box::new(decoder)); @@ -229,6 +249,6 @@ impl ArrowReaderBuilder { stripe_index: 0, file_byte_range: self.file_byte_range, }; - ArrowStreamReader::new(cursor, self.batch_size, schema_ref) + ArrowStreamReader::new(cursor, self.batch_size, schema_ref, self.row_selection) } } diff --git a/src/bin/orc-stats.rs b/src/bin/orc-stats.rs index 1113a01..369117f 100644 --- a/src/bin/orc-stats.rs +++ b/src/bin/orc-stats.rs @@ -35,48 +35,48 @@ fn print_column_stats(col_stats: 
&ColumnStatistics) { match tstats { orc_rust::statistics::TypeStatistics::Integer { min, max, sum } => { println!("* Data type Integer"); - println!("* Minimum: {}", min); - println!("* Maximum: {}", max); + println!("* Minimum: {min}"); + println!("* Maximum: {max}"); if let Some(sum) = sum { - println!("* Sum: {}", sum); + println!("* Sum: {sum}"); } } orc_rust::statistics::TypeStatistics::Double { min, max, sum } => { println!("* Data type Double"); - println!("* Minimum: {}", min); - println!("* Maximum: {}", max); + println!("* Minimum: {min}"); + println!("* Maximum: {max}"); if let Some(sum) = sum { - println!("* Sum: {}", sum); + println!("* Sum: {sum}"); } } orc_rust::statistics::TypeStatistics::String { min, max, sum } => { println!("* Data type String"); - println!("* Minimum: {}", min); - println!("* Maximum: {}", max); - println!("* Sum: {}", sum); + println!("* Minimum: {min}"); + println!("* Maximum: {max}"); + println!("* Sum: {sum}"); } orc_rust::statistics::TypeStatistics::Bucket { true_count } => { println!("* Data type Bucket"); - println!("* True count: {}", true_count); + println!("* True count: {true_count}"); } orc_rust::statistics::TypeStatistics::Decimal { min, max, sum } => { println!("* Data type Decimal"); - println!("* Minimum: {}", min); - println!("* Maximum: {}", max); - println!("* Sum: {}", sum); + println!("* Minimum: {min}"); + println!("* Maximum: {max}"); + println!("* Sum: {sum}"); } orc_rust::statistics::TypeStatistics::Date { min, max } => { println!("* Data type Date"); if let Some(dt) = date32_to_datetime(*min) { - println!("* Minimum: {}", dt); + println!("* Minimum: {dt}"); } if let Some(dt) = date32_to_datetime(*max) { - println!("* Maximum: {}", dt); + println!("* Maximum: {dt}"); } } orc_rust::statistics::TypeStatistics::Binary { sum } => { println!("* Data type Binary"); - println!("* Sum: {}", sum); + println!("* Sum: {sum}"); } orc_rust::statistics::TypeStatistics::Timestamp { min, @@ -85,13 +85,13 @@ fn 
print_column_stats(col_stats: &ColumnStatistics) { max_utc, } => { println!("* Data type Timestamp"); - println!("* Minimum: {}", min); - println!("* Maximum: {}", max); + println!("* Minimum: {min}"); + println!("* Maximum: {max}"); if let Some(ts) = timestamp_ms_to_datetime(*min_utc) { - println!("* Minimum UTC: {}", ts); + println!("* Minimum UTC: {ts}"); } if let Some(ts) = timestamp_ms_to_datetime(*max_utc) { - println!("* Maximum UTC: {}", ts); + println!("* Maximum UTC: {ts}"); } } orc_rust::statistics::TypeStatistics::Collection { @@ -100,9 +100,9 @@ fn print_column_stats(col_stats: &ColumnStatistics) { total_children, } => { println!("* Data type Collection"); - println!("* Minimum children: {}", min_children); - println!("* Maximum children: {}", max_children); - println!("* Total children: {}", total_children); + println!("* Minimum children: {min_children}"); + println!("* Maximum children: {max_children}"); + println!("* Total children: {total_children}"); } } } diff --git a/src/encoding/boolean.rs b/src/encoding/boolean.rs index f86e585..4b10dd5 100644 --- a/src/encoding/boolean.rs +++ b/src/encoding/boolean.rs @@ -55,6 +55,48 @@ impl BooleanDecoder { } impl PrimitiveValueDecoder for BooleanDecoder { + fn skip(&mut self, n: usize) -> Result<()> { + let mut remaining_bits = n; + + // First consume from any buffered bits in `data` + if self.bits_in_data > 0 { + let take = remaining_bits.min(self.bits_in_data); + // Advance by shifting left (MSB-first) + self.data <<= take; + self.bits_in_data -= take; + remaining_bits -= take; + } + + if remaining_bits == 0 { + return Ok(()); + } + + // Skip whole bytes directly from byte RLE + let whole_bytes = remaining_bits / 8; + if whole_bytes > 0 { + self.decoder.skip(whole_bytes)?; + remaining_bits -= whole_bytes * 8; + } + + // Skip remaining bits by decoding one more byte and positioning inside it + if remaining_bits > 0 { + let mut byte = [0i8; 1]; + match self.decoder.decode(&mut byte) { + Ok(_) => { + 
self.data = (byte[0] as u8) << remaining_bits; + self.bits_in_data = 8 - remaining_bits; + } + Err(e) => { + // If we can't read more data, we're at the end of the stream + // This means we tried to skip more than available + return Err(e); + } + } + } + + Ok(()) + } + // TODO: can probably implement this better fn decode(&mut self, out: &mut [bool]) -> Result<()> { for x in out.iter_mut() { @@ -167,4 +209,111 @@ mod tests { decoder.decode(&mut actual).unwrap(); assert_eq!(actual, expected) } + + #[test] + fn test_skip_run() { + // Run: 100 false values (0x61, 0x00) + let data = [0x61u8, 0x00]; + let mut decoder = BooleanDecoder::new(data.as_ref()); + + // Decode first 10 values + let mut batch = vec![true; 10]; + decoder.decode(&mut batch).unwrap(); + assert_eq!(batch, vec![false; 10]); + + // Skip next 80 values + decoder.skip(80).unwrap(); + + // Decode last 10 values + let mut batch = vec![true; 10]; + decoder.decode(&mut batch).unwrap(); + assert_eq!(batch, vec![false; 10]); + } + + #[test] + fn test_skip_all() { + // Literal list of exactly 1 byte -> 8 bits + let data = [0xffu8, 0x00u8]; + let mut decoder = BooleanDecoder::new(data.as_ref()); + + // Skip all 8 bits + decoder.skip(8).unwrap(); + + // Next decode must error (EOF) + let mut batch = vec![true; 1]; + let result = decoder.decode(&mut batch); + assert!(result.is_err()); + } + + #[test] + fn test_skip_partial_bits() { + // Test skipping partial bits within a byte + let data = [0xfeu8, 0b01000100, 0b01000101]; // 16 bits of data + let mut decoder = BooleanDecoder::new(data.as_ref()); + + // Skip first 3 bits (should leave 5 bits in the first byte) + decoder.skip(3).unwrap(); + + // Decode next 5 bits should work + let mut batch = vec![true; 5]; + decoder.decode(&mut batch).unwrap(); + // Expected: After skipping 3 bits from 0b01000100, we get 0b000100 + // Which is [false, false, true, false, false] + assert_eq!(batch, vec![false, false, true, false, false]); + } + + #[test] + fn 
test_skip_cross_byte_boundary() { + // Test skipping across byte boundaries + let data = [0xfeu8, 0b01000100, 0b01000101]; // 16 bits of data + let mut decoder = BooleanDecoder::new(data.as_ref()); + + // Skip 6 bits (should consume first byte and 2 bits of second byte) + decoder.skip(6).unwrap(); + + // Decode remaining bits should work + let mut batch = vec![true; 4]; + decoder.decode(&mut batch).unwrap(); + // Expected: 0b0001 -> [false, false, false, true] + assert_eq!(batch, vec![false, false, false, true]); + } + + #[test] + fn test_skip_zero() { + let data = [0x61u8, 0x00]; // 100 false values + let mut decoder = BooleanDecoder::new(data.as_ref()); + + // Skip 0 values should be a no-op + decoder.skip(0).unwrap(); + + // Decode should still work normally + let mut batch = vec![true; 10]; + decoder.decode(&mut batch).unwrap(); + assert_eq!(batch, vec![false; 10]); + } + + #[test] + fn test_skip_exact_byte() { + let data = [0x61u8, 0x00]; // 100 false values + let mut decoder = BooleanDecoder::new(data.as_ref()); + + // Skip exactly 8 bits (1 byte) + decoder.skip(8).unwrap(); + + // Should be able to continue decoding + let mut batch = vec![true; 10]; + decoder.decode(&mut batch).unwrap(); + assert_eq!(batch, vec![false; 10]); + } + + #[test] + fn test_skip_more_than_available() { + // Literal list of exactly 1 byte -> 8 bits + let data = [0xffu8, 0x00u8]; + let mut decoder = BooleanDecoder::new(data.as_ref()); + + // Try to skip more than available should fail + let result = decoder.skip(9); + assert!(result.is_err()); + } } diff --git a/src/encoding/byte.rs b/src/encoding/byte.rs index 4cb8c40..ad992ca 100644 --- a/src/encoding/byte.rs +++ b/src/encoding/byte.rs @@ -20,12 +20,16 @@ use bytes::{BufMut, BytesMut}; use snafu::ResultExt; use crate::{ - error::{IoSnafu, Result}, + error::{IoSnafu, OutOfSpecSnafu, Result}, memory::EstimateMemory, }; use std::io::Read; -use super::{rle::GenericRle, util::read_u8, PrimitiveValueEncoder}; +use super::{ + 
rle::GenericRle, + util::{read_u8, try_read_u8}, + PrimitiveValueEncoder, +}; const MAX_LITERAL_LENGTH: usize = 128; const MIN_REPEAT_LENGTH: usize = 3; @@ -241,6 +245,79 @@ impl GenericRle for ByteRleDecoder { } Ok(()) } + + fn skip_values(&mut self, n: usize) -> Result<()> { + let mut remaining = n; + + // Try to skip from the internal buffer first + let available_count = self.available().len(); + if available_count >= remaining { + self.advance(remaining); + return Ok(()); + } + + // Buffer insufficient, consume what's available + self.advance(available_count); + remaining -= available_count; + + // Skip by reading headers and efficiently skipping blocks + while remaining > 0 { + // Read header to determine the next batch size + let header = match try_read_u8(&mut self.reader)? { + Some(byte) => byte, + None => { + // Stream ended but still have remaining values to skip + return OutOfSpecSnafu { + msg: "not enough values to skip in Byte RLE", + } + .fail(); + } + }; + + if header < 0x80 { + // Run of repeated value + let length = header as usize + MIN_REPEAT_LENGTH; + + if length <= remaining { + // Skip entire run, only read value byte but don't store + read_u8(&mut self.reader)?; + remaining -= length; + } else { + // Run exceeds remaining count, decode to buffer then skip from buffer + let value = read_u8(&mut self.reader)?; + self.leftovers.clear(); + self.index = 0; + self.leftovers.extend(std::iter::repeat(value).take(length)); + self.advance(remaining); + remaining = 0; + } + } else { + // List of values + let length = 0x100 - header as usize; + + if length <= remaining { + // Skip entire list, read but don't store + let mut discard_buffer = vec![0u8; length]; + self.reader + .read_exact(&mut discard_buffer) + .context(IoSnafu)?; + remaining -= length; + } else { + // List exceeds remaining count, decode to buffer then skip from buffer + self.leftovers.clear(); + self.index = 0; + self.leftovers.resize(length, 0); + self.reader + .read_exact(&mut 
self.leftovers) + .context(IoSnafu)?; + self.advance(remaining); + remaining = 0; + } + } + } + + Ok(()) + } } #[cfg(test)] @@ -278,6 +355,101 @@ mod tests { test_helper(&data, &expected); } + #[test] + fn test_skip_values() -> Result<()> { + // Test 1: Skip from buffer (buffer is sufficient) + let data = [0x61u8, 0x07]; // Run: 100 7s (header=0x61=97, length=97+3=100, value=0x07) + let mut decoder = ByteRleDecoder::new(Cursor::new(&data)); + + // Decode some to buffer + let mut batch = vec![0; 10]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![7; 10]); + + // Skip 5 from buffer (buffer still has 90) + decoder.skip(5)?; + + // Continue decoding to verify position is correct + let mut batch = vec![0; 5]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![7; 5]); + + // Test 2: Skip entire Run (length <= remaining) + let data = [0x61u8, 0x07]; // Run: 100 7s + let mut decoder = ByteRleDecoder::new(Cursor::new(&data)); + + // Skip entire run + decoder.skip(100)?; + + // Should reach stream end + let mut batch = vec![0; 1]; + let result = decoder.decode(&mut batch); + assert!(result.is_err()); // Expect error, because there is no more data + + // Test 3: Skip partial Run (length > remaining) + let data = [0x61u8, 0x07]; // Run: 100 7s + let mut decoder = ByteRleDecoder::new(Cursor::new(&data)); + + // Skip 50 + decoder.skip(50)?; + + // Decode next 10 + let mut batch = vec![0; 10]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![7; 10]); + + // Test 4: Skip entire Literals (length <= remaining) + let data = [0xfeu8, 0x44, 0x45]; // Literals: [0x44, 0x45] + let mut decoder = ByteRleDecoder::new(Cursor::new(&data)); + + // Skip all 2 + decoder.skip(2)?; + + // Should reach stream end + let mut batch = vec![0; 1]; + let result = decoder.decode(&mut batch); + assert!(result.is_err()); + + // Test 5: Skip partial Literals (length > remaining) + // 0xfb means length = 256 - 251 = 5 + let data = [0xfbu8, 0x01, 0x02, 0x03, 0x04, 0x05]; // Literals: 
[1,2,3,4,5] + let mut decoder = ByteRleDecoder::new(Cursor::new(&data)); + + // Skip first 2 + decoder.skip(2)?; + + // Decode remaining 3 + let mut batch = vec![0; 3]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![3, 4, 5]); + + // Test 6: Skip across multiple blocks + // Run: 10 zeros (header=0x07, length=7+3=10, value=0x00) + // Literals: [11, 12, 13] (header=0xfd, length=256-253=3) + // Run: 20 fives (header=0x11, length=17+3=20, value=0x05) + let data = [ + 0x07, 0x00, // Run: 10 zeros + 0xfdu8, 0x0b, 0x0c, 0x0d, // Literals: [11, 12, 13] + 0x11, 0x05, // Run: 20 fives + ]; + let mut decoder = ByteRleDecoder::new(Cursor::new(&data)); + + // Skip first 12 values (all 10 from run + 2 from literals) + decoder.skip(12)?; + + // Next value should be 13 (last literal) + let mut batch = vec![0; 1]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![13]); + + // Next values should be 5s from the run + let mut batch = vec![0; 5]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![5; 5]); + + Ok(()) + } + fn roundtrip_byte_rle_helper(values: &[i8]) -> Result> { let mut writer = ByteRleEncoder::new(); writer.write_slice(values); diff --git a/src/encoding/decimal.rs b/src/encoding/decimal.rs index f722cf0..0695432 100644 --- a/src/encoding/decimal.rs +++ b/src/encoding/decimal.rs @@ -36,6 +36,13 @@ impl UnboundedVarintStreamDecoder { } impl PrimitiveValueDecoder for UnboundedVarintStreamDecoder { + fn skip(&mut self, n: usize) -> Result<()> { + for _ in 0..n { + read_varint_zigzagged::(&mut self.reader)?; + } + Ok(()) + } + fn decode(&mut self, out: &mut [i128]) -> Result<()> { for x in out.iter_mut() { *x = read_varint_zigzagged::(&mut self.reader)?; @@ -43,3 +50,90 @@ impl PrimitiveValueDecoder for UnboundedVarintStreamDecoder { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Cursor; + + // Manually encode a few simple i128 values as zigzag varint for testing + // Format: zigzag encode, then varint encode + // 0 -> zigzag: 0 -> 
varint: [0x00] + // 1 -> zigzag: 2 -> varint: [0x02] + // -1 -> zigzag: 1 -> varint: [0x01] + // 100 -> zigzag: 200 -> varint: [0xc8, 0x01] + + #[test] + fn test_unbounded_varint_decoder_skip() -> Result<()> { + // Test data: [0, 1, -1, 100, 200] + // 0: 0x00 + // 1: 0x02 + // -1: 0x01 + // 100: 0xc8, 0x01 (zigzag: 200) + // 200: 0x90, 0x03 (zigzag: 400) + let encoded = vec![0x00, 0x02, 0x01, 0xc8, 0x01, 0x90, 0x03]; + let mut decoder = UnboundedVarintStreamDecoder::new(Cursor::new(&encoded)); + + // Decode first 2 values + let mut batch = vec![0i128; 2]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![0, 1]); + + // Skip next 2 values (-1, 100) + decoder.skip(2)?; + + // Decode remaining value (200) + let mut batch = vec![0i128; 1]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![200]); + + Ok(()) + } + + #[test] + fn test_unbounded_varint_skip_all() -> Result<()> { + // Test data: [0, 1, -1] + let encoded = vec![0x00, 0x02, 0x01]; + let mut decoder = UnboundedVarintStreamDecoder::new(Cursor::new(&encoded)); + + // Skip all 3 values + decoder.skip(3)?; + + // Try to decode should fail (EOF) + let mut batch = vec![0i128; 1]; + let result = decoder.decode(&mut batch); + assert!(result.is_err()); + + Ok(()) + } + + #[test] + fn test_unbounded_varint_skip_then_decode() -> Result<()> { + // Test data: [10, 20, 30, 40, 50] + // 10: zigzag 20 = 0x14 + // 20: zigzag 40 = 0x28 + // 30: zigzag 60 = 0x3c + // 40: zigzag 80 = 0x50 + // 50: zigzag 100 = 0x64 + let encoded = vec![0x14, 0x28, 0x3c, 0x50, 0x64]; + let mut decoder = UnboundedVarintStreamDecoder::new(Cursor::new(&encoded)); + + // Skip first 2 + decoder.skip(2)?; + + // Decode next 2 + let mut batch = vec![0i128; 2]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![30, 40]); + + // Skip last 1 + decoder.skip(1)?; + + // Try to decode should fail (EOF) + let mut batch = vec![0i128; 1]; + let result = decoder.decode(&mut batch); + assert!(result.is_err()); + + Ok(()) + } +} diff --git 
a/src/encoding/float.rs b/src/encoding/float.rs index 5b9fa7e..3183b18 100644 --- a/src/encoding/float.rs +++ b/src/encoding/float.rs @@ -51,6 +51,22 @@ impl FloatDecoder { } impl PrimitiveValueDecoder for FloatDecoder { + fn skip(&mut self, n: usize) -> Result<()> { + let bytes_to_skip = n * std::mem::size_of::(); + let mut remaining = bytes_to_skip; + // TODO: use seek instead of read to avoid copying data + let mut buf = [0u8; 8192]; + + while remaining > 0 { + let to_read = remaining.min(buf.len()); + self.reader + .read_exact(&mut buf[..to_read]) + .context(IoSnafu)?; + remaining -= to_read; + } + Ok(()) + } + fn decode(&mut self, out: &mut [F]) -> Result<()> { let bytes = must_cast_slice_mut::(out); self.reader.read_exact(bytes).context(IoSnafu)?; @@ -176,4 +192,117 @@ mod tests { f64::INFINITY, ]); } + + #[test] + fn test_skip_f32() -> Result<()> { + // Encode 10 f32 values: [0.0, 1.5, 3.0, 4.5, 6.0, 7.5, 9.0, 10.5, 12.0, 13.5] + let values: Vec = (0..10).map(|i| i as f32 * 1.5).collect(); + let mut encoder = FloatEncoder::::new(); + encoder.write_slice(&values); + let bytes = encoder.take_inner(); + + let mut decoder = FloatDecoder::::new(Cursor::new(bytes)); + + // Decode first 3 values + let mut batch = vec![0.0f32; 3]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![0.0, 1.5, 3.0]); + + // Skip next 4 values (4.5, 6.0, 7.5, 9.0) + decoder.skip(4)?; + + // Decode remaining 3 values (10.5, 12.0, 13.5) + let mut batch = vec![0.0f32; 3]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![10.5, 12.0, 13.5]); + + Ok(()) + } + + #[test] + fn test_skip_f64() -> Result<()> { + // Encode 10 f64 values + let values: Vec = (0..10).map(|i| i as f64 * 2.5).collect(); + let mut encoder = FloatEncoder::::new(); + encoder.write_slice(&values); + let bytes = encoder.take_inner(); + + let mut decoder = FloatDecoder::::new(Cursor::new(bytes)); + + // Skip first 5 values + decoder.skip(5)?; + + // Decode next 3 values + let mut batch = vec![0.0f64; 3]; + 
decoder.decode(&mut batch)?; + assert_eq!(batch, vec![12.5, 15.0, 17.5]); + + // Skip 1 value + decoder.skip(1)?; + + // Decode last value + let mut batch = vec![0.0f64; 1]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![22.5]); + + Ok(()) + } + + #[test] + fn test_skip_all_values() -> Result<()> { + // Test skipping all values + let values: Vec = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let mut encoder = FloatEncoder::::new(); + encoder.write_slice(&values); + let bytes = encoder.take_inner(); + + let mut decoder = FloatDecoder::::new(Cursor::new(bytes)); + + // Skip all 5 values + decoder.skip(5)?; + + // Try to decode should fail (EOF) + let mut batch = vec![0.0f32; 1]; + let result = decoder.decode(&mut batch); + assert!(result.is_err()); + + Ok(()) + } + + #[test] + fn test_skip_edge_cases() -> Result<()> { + // Test with special float values + let values = vec![ + f64::NAN, + f64::INFINITY, + f64::NEG_INFINITY, + 0.0, + -0.0, + f64::MIN, + f64::MAX, + ]; + let mut encoder = FloatEncoder::::new(); + encoder.write_slice(&values); + let bytes = encoder.take_inner(); + + let mut decoder = FloatDecoder::::new(Cursor::new(bytes)); + + // Skip first 3 (NAN, INF, NEG_INF) + decoder.skip(3)?; + + // Decode next 2 + let mut batch = vec![0.0f64; 2]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![0.0, -0.0]); + + // Skip 1 (MIN) + decoder.skip(1)?; + + // Decode last (MAX) + let mut batch = vec![0.0f64; 1]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![f64::MAX]); + + Ok(()) + } } diff --git a/src/encoding/integer/rle_v1.rs b/src/encoding/integer/rle_v1.rs index 1249504..134e47a 100644 --- a/src/encoding/integer/rle_v1.rs +++ b/src/encoding/integer/rle_v1.rs @@ -151,8 +151,79 @@ impl GenericRle for RleV1Decoder Some(EncodingType::Run { length, delta }) => { read_run::<_, _, S>(&mut self.reader, &mut self.decoded_ints, length, delta) } - None => Ok(()), + None => OutOfSpecSnafu { + msg: "not enough values to decode", + } + .fail(), + } + } + + fn 
skip_values(&mut self, n: usize) -> Result<()> { + let mut remaining = n; + + // Try to skip from the internal buffer first + let available_count = self.available().len(); + if available_count >= remaining { + self.advance(remaining); + return Ok(()); + } + + // Buffer insufficient, consume what's available + self.advance(available_count); + remaining -= available_count; + + // Skip by reading headers and efficiently skipping blocks + while remaining > 0 { + // Read header to determine the next batch type and size + match EncodingType::from_header(&mut self.reader)? { + Some(EncodingType::Literals { length }) => { + // Check if within skip range + if length <= remaining { + // Skip entire literal sequence, only read and discard varints + for _ in 0..length { + read_varint_zigzagged::(&mut self.reader)?; + } + remaining -= length; + } else { + // Literals exceed remaining count, decode to buffer then skip from buffer + self.decoded_ints.clear(); + self.current_head = 0; + read_literals::<_, _, S>(&mut self.reader, &mut self.decoded_ints, length)?; + self.advance(remaining); + remaining = 0; + } + } + Some(EncodingType::Run { length, delta }) => { + // Check if within skip range + if length <= remaining { + // Skip entire run, only read base value without computing sequence + read_varint_zigzagged::(&mut self.reader)?; + remaining -= length; + } else { + // Run exceeds remaining count, decode to buffer then skip from buffer + self.decoded_ints.clear(); + self.current_head = 0; + read_run::<_, _, S>( + &mut self.reader, + &mut self.decoded_ints, + length, + delta, + )?; + self.advance(remaining); + remaining = 0; + } + } + None => { + // Stream ended but still have remaining values to skip + return OutOfSpecSnafu { + msg: "not enough values to skip in RLE v1", + } + .fail(); + } + } } + + Ok(()) } } @@ -394,4 +465,87 @@ mod tests { test_helper(&original, &encoded); Ok(()) } + + #[test] + fn test_skip_values() -> Result<()> { + // Test 1: Skip from buffer (buffer is 
sufficient) + let encoded = [0x61, 0x00, 0x07]; // Run: 100 7s + let mut decoder = RleV1Decoder::::new(Cursor::new(&encoded)); + + // Decode some to buffer + let mut batch = vec![0; 10]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![7; 10]); + + // Skip 5 from buffer (buffer still has 90) + decoder.skip(5)?; + + // Continue decoding to verify position is correct + let mut batch = vec![0; 5]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![7; 5]); + + // Test 2: Skip entire Run (length <= remaining) + let encoded = [0x61, 0x00, 0x07]; // Run: 100 7s + let mut decoder = RleV1Decoder::::new(Cursor::new(&encoded)); + + // Skip entire run + decoder.skip(100)?; + + // Should reach stream end + let mut batch = vec![0; 1]; + let result = decoder.decode(&mut batch); + assert!(result.is_err()); // Expect error, because there is no more data + + // Test 3: Skip partial Run (length > remaining) + let encoded = [0x61, 0x00, 0x07]; // Run: 100 7s + let mut decoder = RleV1Decoder::::new(Cursor::new(&encoded)); + + // Skip 50 + decoder.skip(50)?; + + // Decode next 10 + let mut batch = vec![0; 10]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![7; 10]); + + // Test 4: Skip entire Literals (length <= remaining) + let encoded = [0xfb, 0x02, 0x03, 0x06, 0x07, 0xb]; // Literals: [2,3,6,7,11] + let mut decoder = RleV1Decoder::::new(Cursor::new(&encoded)); + + // Skip all 5 + decoder.skip(5)?; + + // Should reach stream end + let mut batch = vec![0; 1]; + let result = decoder.decode(&mut batch); + assert!(result.is_err()); + + // Test 5: Skip partial Literals (length > remaining) + let encoded = [0xfb, 0x02, 0x03, 0x06, 0x07, 0xb]; // Literals: [2,3,6,7,11] + let mut decoder = RleV1Decoder::::new(Cursor::new(&encoded)); + + // Skip first 2 + decoder.skip(2)?; + + // Decode remaining 3 + let mut batch = vec![0; 3]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![6, 7, 11]); + + // Test 6: Skip across multiple headers + // Encoded: 150 decreasing numbers 
(150, 149, ..., 1) + let encoded = [0x7f, 0xff, 0x96, 0x01, 0x11, 0xff, 0x14]; + let mut decoder = RleV1Decoder::::new(Cursor::new(&encoded)); + + // Skip first 100 + decoder.skip(100)?; + + // Decode next 10 (should be 50, 49, ..., 41) + let mut batch = vec![0; 10]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![50, 49, 48, 47, 46, 45, 44, 43, 42, 41]); + + Ok(()) + } } diff --git a/src/encoding/integer/rle_v2/mod.rs b/src/encoding/integer/rle_v2/mod.rs index ed871cf..f916c26 100644 --- a/src/encoding/integer/rle_v2/mod.rs +++ b/src/encoding/integer/rle_v2/mod.rs @@ -21,7 +21,7 @@ use bytes::BytesMut; use crate::{ encoding::{rle::GenericRle, util::try_read_u8, PrimitiveValueEncoder}, - error::Result, + error::{OutOfSpecSnafu, Result}, memory::EstimateMemory, }; @@ -114,7 +114,12 @@ impl GenericRle for RleV2Decoder self.decoded_ints.clear(); let header = match try_read_u8(&mut self.reader)? { Some(byte) => byte, - None => return Ok(()), + None => { + return OutOfSpecSnafu { + msg: "not enough values to decode in RLE v2", + } + .fail(); + } }; match EncodingType::from_header(header) { @@ -139,6 +144,35 @@ impl GenericRle for RleV2Decoder Ok(()) } + + fn skip_values(&mut self, n: usize) -> Result<()> { + let mut remaining = n; + + // Try to skip from the internal buffer first + let available = self.decoded_ints.len() - self.current_head; + if available >= remaining { + self.advance(remaining); + return Ok(()); + } + + // Buffer insufficient, consume what's available + self.advance(available); + remaining -= available; + + while remaining > 0 { + // Decode the next block into buffer + // TODO(optimization): avoid decode + self.decode_batch()?; + + // Skip from the newly decoded buffer + let decoded_count = self.decoded_ints.len(); + let to_skip = decoded_count.min(remaining); + self.advance(to_skip); + remaining -= to_skip; + } + + Ok(()) + } } struct DeltaEncodingCheckResult { @@ -649,6 +683,208 @@ mod tests { Ok(actual) } + #[test] + fn 
test_skip_values_short_repeat() -> Result<()> { + // Use the existing test data: ShortRepeat encoding + // 0x0a = 00_001_010 (encoding=ShortRepeat, width field=1 => 2 bytes, count=2+3=5 values) + // Followed by the value in 2 bytes (big-endian): 0x2710 = 10000 + let data = [0x0a, 0x27, 0x10]; + let mut decoder = RleV2Decoder::::new(Cursor::new(&data)); + + // Decode first 2 values + let mut batch = vec![0; 2]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![10000, 10000]); + + // Skip next 2 values from buffer + decoder.skip(2)?; + + // Decode remaining 1 value + let mut batch = vec![0; 1]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![10000]); + + Ok(()) + } + + #[test] + fn test_skip_values_entire_block() -> Result<()> { + // Generate test data using encoder + let mut encoder1 = RleV2Encoder::<_, UnsignedEncoding>::new(); + for _ in 0..5 { + encoder1.write_one(100); + } + encoder1.flush(); + let data1 = encoder1.take_inner(); + + let mut encoder2 = RleV2Encoder::<_, UnsignedEncoding>::new(); + for _ in 0..5 { + encoder2.write_one(200); + } + encoder2.flush(); + let data2 = encoder2.take_inner(); + + // Combine two blocks + let mut combined = Vec::new(); + combined.extend_from_slice(&data1); + combined.extend_from_slice(&data2); + + let mut decoder = RleV2Decoder::::new(Cursor::new(&combined)); + + // Skip entire first block + decoder.skip(5)?; + + // Decode from second block + let mut batch = vec![0; 3]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![200, 200, 200]); + + Ok(()) + } + + #[test] + fn test_skip_values_across_blocks() -> Result<()> { + // Generate test data using encoder + let mut encoder = RleV2Encoder::<_, SignedEncoding>::new(); + + // Block 1: 5 values of 100 + for _ in 0..5 { + encoder.write_one(100); + } + encoder.flush(); + + // Block 2: 5 values of 200 + for _ in 0..5 { + encoder.write_one(200); + } + encoder.flush(); + + // Block 3: 5 values of 300 + for _ in 0..5 { + encoder.write_one(300); + } + encoder.flush(); + + let data = 
encoder.take_inner(); + let mut decoder = RleV2Decoder::::new(Cursor::new(&data)); + + // Skip 7 values (entire first block + 2 from second) + decoder.skip(7)?; + + // Next value should be from second block + let mut batch = vec![0; 1]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![200]); + + // Skip 2 more (rest of second block) + decoder.skip(2)?; + + // Decode from third block + let mut batch = vec![0; 3]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![300, 300, 300]); + + Ok(()) + } + + #[test] + fn test_skip_values_direct_encoding() -> Result<()> { + // Direct encoding: 4 values with specific bit width + let data = [0x5e, 0x03, 0x5c, 0xa1, 0xab, 0x1e, 0xde, 0xad, 0xbe, 0xef]; + + let mut decoder = RleV2Decoder::::new(Cursor::new(&data)); + + // Decode first 2 + let mut batch = vec![0; 2]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![23713, 43806]); + + // Skip 1 + decoder.skip(1)?; + + // Decode last one + let mut batch = vec![0; 1]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![48879]); + + Ok(()) + } + + #[test] + fn test_skip_values_delta_encoding() -> Result<()> { + // Delta encoding: sequence with fixed delta + let data = [0xc6, 0x09, 0x02, 0x02, 0x22, 0x42, 0x42, 0x46]; + + let mut decoder = RleV2Decoder::::new(Cursor::new(&data)); + + // Skip first 5 values + decoder.skip(5)?; + + // Decode next 3 + let mut batch = vec![0; 3]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![13, 17, 19]); + + Ok(()) + } + + #[test] + fn test_skip_values_patched_base() -> Result<()> { + // PatchedBase encoding: with patches + let data = [ + 0x8e, 0x09, 0x2b, 0x21, 0x07, 0xd0, 0x1e, 0x00, 0x14, 0x70, 0x28, 0x32, 0x3c, 0x46, + 0x50, 0x5a, 0xfc, 0xe8, + ]; + + let mut decoder = RleV2Decoder::::new(Cursor::new(&data)); + + // Skip first 3 values + decoder.skip(3)?; + + // Decode next value (should be the patched one) + let mut batch = vec![0; 1]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![1000000]); + + // Skip 2 
more + decoder.skip(2)?; + + // Decode next 3 + let mut batch = vec![0; 3]; + decoder.decode(&mut batch)?; + assert_eq!(batch, vec![2060, 2070, 2080]); + + Ok(()) + } + + #[test] + fn test_skip_all_values() -> Result<()> { + // Test skipping all values in the stream + let mut encoder = RleV2Encoder::<_, SignedEncoding>::new(); + for _ in 0..5 { + encoder.write_one(10); + } + encoder.flush(); + for _ in 0..5 { + encoder.write_one(20); + } + encoder.flush(); + let data = encoder.take_inner(); + + let mut decoder = RleV2Decoder::::new(Cursor::new(&data)); + + // Skip all 10 values + decoder.skip(10)?; + + // Try to decode should result in empty read (no more data) + let mut batch = vec![0; 1]; + let result = decoder.decode(&mut batch); + // EOF is acceptable when stream is exhausted + assert!(result.is_err() || batch[0] == 0); + + Ok(()) + } + proptest! { #[test] fn roundtrip_i16(values in prop::collection::vec(any::(), 1..1_000)) { diff --git a/src/encoding/mod.rs b/src/encoding/mod.rs index 7a5dd08..57633c9 100644 --- a/src/encoding/mod.rs +++ b/src/encoding/mod.rs @@ -51,6 +51,9 @@ pub trait PrimitiveValueEncoder: EstimateMemory { } pub trait PrimitiveValueDecoder { + /// Skip the next `n` values without decoding them, failing if it cannot skip enough values. + fn skip(&mut self, n: usize) -> Result<()>; + /// Decode out.len() values into out at a time, failing if it cannot fill /// the buffer. fn decode(&mut self, out: &mut [V]) -> Result<()>; @@ -95,12 +98,27 @@ mod tests { use super::*; /// Emits numbers increasing from 0. 
- struct DummyDecoder; + struct DummyDecoder { + value: i32, + } + + impl DummyDecoder { + fn new() -> Self { + Self { value: 0 } + } + } impl PrimitiveValueDecoder for DummyDecoder { + fn skip(&mut self, n: usize) -> Result<()> { + self.value += n as i32; + Ok(()) + } fn decode(&mut self, out: &mut [i32]) -> Result<()> { - let values = (0..out.len()).map(|x| x as i32).collect::>(); + let values = (0..out.len()) + .map(|x| self.value + x as i32) + .collect::>(); out.copy_from_slice(&values); + self.value += out.len() as i32; Ok(()) } } @@ -122,7 +140,7 @@ mod tests { proptest! { #[test] fn decode_spaced_proptest(present: Vec) { - let mut decoder = DummyDecoder; + let mut decoder = DummyDecoder::new(); let mut out = vec![-1; present.len()]; decoder.decode_spaced(&mut out, &NullBuffer::from(present.clone())).unwrap(); let expected = gen_spaced_dummy_decoder_expected(&present); @@ -132,7 +150,7 @@ mod tests { #[test] fn decode_spaced_edge_cases() { - let mut decoder = DummyDecoder; + let mut decoder = DummyDecoder::new(); let len = 10; // all present @@ -151,4 +169,17 @@ mod tests { let expected = vec![-1; len]; assert_eq!(out, expected); } + + #[test] + fn test_skip() { + let mut decoder = DummyDecoder::new(); + decoder.skip(10).unwrap(); + let mut out = vec![-1; 1]; + decoder.decode(&mut out).unwrap(); + assert_eq!(out, vec![10]); + decoder.skip(10).unwrap(); + let mut out2 = vec![-1; 5]; + decoder.decode(&mut out2).unwrap(); + assert_eq!(out2, vec![21, 22, 23, 24, 25]); + } } diff --git a/src/encoding/rle.rs b/src/encoding/rle.rs index a330efa..c36f4ef 100644 --- a/src/encoding/rle.rs +++ b/src/encoding/rle.rs @@ -53,9 +53,18 @@ pub trait GenericRle { // directly to the output and skip the middle man. Ideally the internal buffer // should only be used for leftovers between calls to PrimitiveValueDecoder::decode. fn decode_batch(&mut self) -> Result<()>; + + /// Skip n values without decoding them, failing if it cannot skip enough values. 
+ /// This should first consume any leftover values in the internal buffer, then skip the remaining values from the reader. + fn skip_values(&mut self, n: usize) -> Result<()>; } impl + sealed::Rle> PrimitiveValueDecoder for G { + fn skip(&mut self, n: usize) -> Result<()> { + // Delegate to the GenericRle implementation + self.skip_values(n) + } + fn decode(&mut self, out: &mut [V]) -> Result<()> { let available = self.available(); // If we have enough leftover to copy, can skip decoding more. diff --git a/src/encoding/timestamp.rs b/src/encoding/timestamp.rs index 5f7fd5c..be01bd6 100644 --- a/src/encoding/timestamp.rs +++ b/src/encoding/timestamp.rs @@ -50,6 +50,12 @@ impl TimestampDecoder { } impl PrimitiveValueDecoder for TimestampDecoder { + fn skip(&mut self, n: usize) -> Result<()> { + self.data.skip(n)?; + self.secondary.skip(n)?; + Ok(()) + } + fn decode(&mut self, out: &mut [T::Native]) -> Result<()> { // TODO: can probably optimize, reuse buffers? let mut data = vec![0; out.len()]; @@ -90,6 +96,12 @@ impl TimestampNanosecondAsDecimalDecoder { } impl PrimitiveValueDecoder for TimestampNanosecondAsDecimalDecoder { + fn skip(&mut self, n: usize) -> Result<()> { + self.data.skip(n)?; + self.secondary.skip(n)?; + Ok(()) + } + fn decode(&mut self, out: &mut [i128]) -> Result<()> { // TODO: can probably optimize, reuse buffers? 
let mut data = vec![0; out.len()]; diff --git a/tests/row_selection/main.rs b/tests/row_selection/main.rs index 7a580f9..29356b3 100644 --- a/tests/row_selection/main.rs +++ b/tests/row_selection/main.rs @@ -25,6 +25,9 @@ use orc_rust::arrow_reader::ArrowReaderBuilder; use orc_rust::projection::ProjectionMask; use orc_rust::row_selection::{RowSelection, RowSelector}; +#[cfg(feature = "async")] +use futures_util::stream::TryStreamExt; + fn basic_path(path: &str) -> String { let dir = env!("CARGO_MANIFEST_DIR"); format!("{dir}/tests/basic/data/{path}") @@ -368,29 +371,296 @@ fn test_row_selection_with_compression() { assert_eq!(total_rows, 20); } -// TODO: Async version doesn't support row_selection yet -// Need to update async_arrow_reader.rs to pass row_selection to NaiveStripeDecoder -// #[cfg(feature = "async")] -// #[tokio::test] -// async fn test_row_selection_async() { -// let path = basic_path("test.orc"); -// let f = tokio::fs::File::open(path).await.unwrap(); -// -// let selection = vec![ -// RowSelector::skip(1), -// RowSelector::select(3), -// RowSelector::skip(1), -// ] -// .into(); -// -// let reader = ArrowReaderBuilder::try_new_async(f) -// .await -// .unwrap() -// .with_row_selection(selection) -// .build_async(); -// -// let batches = reader.try_collect::>().await.unwrap(); -// let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); -// -// assert_eq!(total_rows, 3); -// } +#[cfg(feature = "async")] +#[tokio::test] +async fn test_row_selection_async() { + let path = basic_path("test.orc"); + let f = tokio::fs::File::open(path).await.unwrap(); + + let selection = vec![ + RowSelector::skip(1), + RowSelector::select(3), + RowSelector::skip(1), + ] + .into(); + + let reader = ArrowReaderBuilder::try_new_async(f) + .await + .unwrap() + .with_row_selection(selection) + .build_async(); + + let batches = reader.try_collect::>().await.unwrap(); + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 3); +} 
+ +#[cfg(feature = "async")] +#[tokio::test] +async fn test_row_selection_async_select_all() { + let path = basic_path("test.orc"); + let f = tokio::fs::File::open(path).await.unwrap(); + + let selection = RowSelection::select_all(5); + + let reader = ArrowReaderBuilder::try_new_async(f) + .await + .unwrap() + .with_row_selection(selection) + .build_async(); + + let batches = reader.try_collect::>().await.unwrap(); + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 5); +} + +#[cfg(feature = "async")] +#[tokio::test] +async fn test_row_selection_async_skip_all() { + let path = basic_path("test.orc"); + let f = tokio::fs::File::open(path).await.unwrap(); + + let selection = RowSelection::skip_all(5); + + let reader = ArrowReaderBuilder::try_new_async(f) + .await + .unwrap() + .with_row_selection(selection) + .build_async(); + + let batches = reader.try_collect::>().await.unwrap(); + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 0); +} + +#[cfg(feature = "async")] +#[tokio::test] +async fn test_row_selection_async_with_consecutive_ranges() { + let path = basic_path("test.orc"); + let f = tokio::fs::File::open(path).await.unwrap(); + + // Select rows at indices 0-1 and 3-4 (skip row 2) + let selection = RowSelection::from_consecutive_ranges(vec![0..2, 3..5].into_iter(), 5); + + let reader = ArrowReaderBuilder::try_new_async(f) + .await + .unwrap() + .with_row_selection(selection) + .build_async(); + + let batches = reader.try_collect::>().await.unwrap(); + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 4); +} + +#[cfg(feature = "async")] +#[tokio::test] +async fn test_row_selection_async_select_first_only() { + let path = basic_path("test.orc"); + let f = tokio::fs::File::open(path).await.unwrap(); + + // Select only first row + let selection = vec![RowSelector::select(1), RowSelector::skip(4)].into(); + + let reader = 
ArrowReaderBuilder::try_new_async(f) + .await + .unwrap() + .with_row_selection(selection) + .build_async(); + + let batches = reader.try_collect::>().await.unwrap(); + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 1); + + let expected = [ + "+-----+------+------------+---+-----+-------+--------------------+------------------------+-----------+---------------+------------+----------------+---------------+-------------------+--------------+---------------+---------------+-------------------------+-------------+----------------+", + "| a | b | str_direct | d | e | f | int_short_repeated | int_neg_short_repeated | int_delta | int_neg_delta | int_direct | int_neg_direct | bigint_direct | bigint_neg_direct | bigint_other | utf8_increase | utf8_decrease | timestamp_simple | date_simple | tinyint_simple |", + "+-----+------+------------+---+-----+-------+--------------------+------------------------+-----------+---------------+------------+----------------+---------------+-------------------+--------------+---------------+---------------+-------------------------+-------------+----------------+", + "| 1.0 | true | a | a | ddd | aaaaa | 5 | -5 | 1 | 5 | 1 | -1 | 1 | -1 | 5 | a | eeeee | 2023-04-01T20:15:30.002 | 2023-04-01 | -1 |", + "+-----+------+------------+---+-----+-------+--------------------+------------------------+-----------+---------------+------------+----------------+---------------+-------------------+--------------+---------------+---------------+-------------------------+-------------+----------------+", + ]; + assert_batches_eq(&batches, &expected); +} + +#[cfg(feature = "async")] +#[tokio::test] +async fn test_row_selection_async_select_last_only() { + let path = basic_path("test.orc"); + let f = tokio::fs::File::open(path).await.unwrap(); + + // Skip first 4 rows, select last row + let selection = vec![RowSelector::skip(4), RowSelector::select(1)].into(); + + let reader = 
ArrowReaderBuilder::try_new_async(f) + .await + .unwrap() + .with_row_selection(selection) + .build_async(); + + let batches = reader.try_collect::>().await.unwrap(); + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 1); + + let expected = [ + "+-----+-------+------------+-----+---+-------+--------------------+------------------------+-----------+---------------+------------+----------------+---------------+-------------------+--------------+---------------+---------------+---------------------+-------------+----------------+", + "| a | b | str_direct | d | e | f | int_short_repeated | int_neg_short_repeated | int_delta | int_neg_delta | int_direct | int_neg_direct | bigint_direct | bigint_neg_direct | bigint_other | utf8_increase | utf8_decrease | timestamp_simple | date_simple | tinyint_simple |", + "+-----+-------+------------+-----+---+-------+--------------------+------------------------+-----------+---------------+------------+----------------+---------------+-------------------+--------------+---------------+---------------+---------------------+-------------+----------------+", + "| 5.0 | false | ee | ddd | a | ddddd | 5 | -5 | 5 | 1 | 2 | -2 | 2 | -2 | 5 | eeeee | a | 2023-03-01T00:00:00 | 2023-03-01 | -127 |", + "+-----+-------+------------+-----+---+-------+--------------------+------------------------+-----------+---------------+------------+----------------+---------------+-------------------+--------------+---------------+---------------+---------------------+-------------+----------------+", + ]; + assert_batches_eq(&batches, &expected); +} + +#[cfg(feature = "async")] +#[tokio::test] +async fn test_row_selection_async_with_nested_struct() { + let path = basic_path("nested_struct.orc"); + let f = tokio::fs::File::open(path).await.unwrap(); + + // Select first 2 rows and last row + let selection = vec![ + RowSelector::select(2), + RowSelector::skip(2), + RowSelector::select(1), + ] + .into(); + + let 
reader = ArrowReaderBuilder::try_new_async(f) + .await + .unwrap() + .with_row_selection(selection) + .build_async(); + + let batches = reader.try_collect::>().await.unwrap(); + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 3); + + let expected = [ + "+-------------------+", + "| nest |", + "+-------------------+", + "| {a: 1.0, b: true} |", + "| {a: 3.0, b: } |", + "| {a: -3.0, b: } |", + "+-------------------+", + ]; + assert_batches_eq(&batches, &expected); +} + +#[cfg(feature = "async")] +#[tokio::test] +async fn test_row_selection_async_with_nested_array() { + let path = basic_path("nested_array.orc"); + let f = tokio::fs::File::open(path).await.unwrap(); + + // Select middle rows (index 1-2) + let selection = vec![ + RowSelector::skip(1), + RowSelector::select(2), + RowSelector::skip(2), + ] + .into(); + + let reader = ArrowReaderBuilder::try_new_async(f) + .await + .unwrap() + .with_row_selection(selection) + .build_async(); + + let batches = reader.try_collect::>().await.unwrap(); + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 2); + + let expected = [ + "+--------------------+", + "| value |", + "+--------------------+", + "| [5, , 32, 4, 15] |", + "| [16, , 3, 4, 5, 6] |", + "+--------------------+", + ]; + assert_batches_eq(&batches, &expected); +} + +#[cfg(feature = "async")] +#[tokio::test] +async fn test_row_selection_async_with_large_file() { + // Test with a larger file that spans multiple stripes + let path = basic_path("string_long_long.orc"); + let f = tokio::fs::File::open(path).await.unwrap(); + + // Skip first 1000 rows, select next 500, skip rest + let selection = vec![ + RowSelector::skip(1000), + RowSelector::select(500), + RowSelector::skip(8500), + ] + .into(); + + let reader = ArrowReaderBuilder::try_new_async(f) + .await + .unwrap() + .with_row_selection(selection) + .build_async(); + + let batches = reader.try_collect::>().await.unwrap(); 
+ let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 500); +} + +#[cfg(feature = "async")] +#[tokio::test] +async fn test_row_selection_async_empty_selection() { + let path = basic_path("test.orc"); + let f = tokio::fs::File::open(path).await.unwrap(); + + // Empty selection - skip all rows + let selection = RowSelection::skip_all(5); + + let reader = ArrowReaderBuilder::try_new_async(f) + .await + .unwrap() + .with_row_selection(selection) + .build_async(); + + let batches = reader.try_collect::>().await.unwrap(); + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + // Empty selection should read 0 rows + assert_eq!(total_rows, 0); +} + +#[cfg(feature = "async")] +#[tokio::test] +async fn test_row_selection_async_with_compression() { + // Test that row selection works with compressed files + let path = basic_path("string_dict_gzip.orc"); + let f = tokio::fs::File::open(path).await.unwrap(); + + let selection = vec![ + RowSelector::skip(10), + RowSelector::select(20), + RowSelector::skip(34), + ] + .into(); + + let reader = ArrowReaderBuilder::try_new_async(f) + .await + .unwrap() + .with_row_selection(selection) + .build_async(); + + let batches = reader.try_collect::>().await.unwrap(); + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + assert_eq!(total_rows, 20); +}