Skip to content

Commit

Permalink
Rebasing
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Dec 20, 2023
1 parent a3363da commit 1f90a94
Showing 1 changed file with 102 additions and 2 deletions.
104 changes: 102 additions & 2 deletions rust/lance-core/src/io/writer/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,38 @@ fn get_fixed_size_binary_statistics(arrays: &[&ArrayRef]) -> StatisticsRow {
}
}

/// Collects statistics for a batch of `List` or `LargeList` arrays.
///
/// The min/max values come from running [`collect_statistics`] over the child
/// value arrays of every list array. The `null_count` of that result is then
/// overwritten with the number of null *list entries* summed across `arrays`,
/// so null elements inside the lists do not contribute to the reported count.
///
/// # Panics
///
/// Panics if `arrays` is empty (the first array is indexed directly) or if the
/// first array's data type is neither `List` nor `LargeList`.
fn get_list_statistics(arrays: &[&ArrayRef]) -> StatisticsRow {
    // Null list entries, totalled over all input arrays.
    let list_null_count: i64 = arrays
        .iter()
        .map(|array| array.null_count() as i64)
        .sum();

    // NOTE(review): `values()` presumably exposes the entire child array, so for
    // sliced list arrays the bounds may cover values outside the slice — confirm.
    let mut stats = match arrays[0].data_type() {
        DataType::List(_) => {
            let child_values: Vec<_> = arrays
                .iter()
                .map(|array| array.as_list::<i32>().values())
                .collect();
            collect_statistics(&child_values)
        }
        DataType::LargeList(_) => {
            let child_values: Vec<_> = arrays
                .iter()
                .map(|array| array.as_list::<i64>().values())
                .collect();
            collect_statistics(&child_values)
        }
        _ => unreachable!(),
    };

    // Report list-level nulls rather than the child arrays' null count.
    stats.null_count = list_null_count;
    stats
}

fn get_boolean_statistics(arrays: &[&ArrayRef]) -> StatisticsRow {
let mut true_present = false;
let mut false_present = false;
Expand Down Expand Up @@ -591,8 +623,8 @@ pub fn collect_statistics(arrays: &[&ArrayRef]) -> StatisticsRow {
DataType::Utf8 => get_string_statistics::<i32>(arrays),
DataType::LargeUtf8 => get_string_statistics::<i64>(arrays),
DataType::Dictionary(_, _) => get_dictionary_statistics(arrays),
// DataType::List(_) => get_list_statistics(arrays),
// DataType::LargeList(_) => get_list_statistics(arrays),
DataType::List(_) => get_list_statistics(arrays),
DataType::LargeList(_) => get_list_statistics(arrays),
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -954,6 +986,7 @@ mod tests {
Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray,
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array,
UInt32Array, UInt64Array, UInt8Array,
ListArray, LargeListArray,
};
use arrow_select::interleave::interleave;
use num_traits::One;
Expand Down Expand Up @@ -2227,4 +2260,71 @@ mod tests {
}
}
}

/// `List<Int16>` statistics across two arrays: min/max are taken over every
/// child value, while `null_count` reflects only the single null list entry.
#[test]
fn test_collect_list_stats() {
    let first: ArrayRef = Arc::new(ListArray::from_iter_primitive::<Int16Type, _, _>(vec![
        Some(vec![Some(0)]),
        Some(vec![Some(9)]),
        Some(vec![Some(9), Some(2), Some(2)]),
    ]));
    // Contains one null list entry and one null element inside a list.
    let second: ArrayRef = Arc::new(ListArray::from_iter_primitive::<Int16Type, _, _>(vec![
        Some(vec![Some(0), Some(1), Some(2)]),
        None,
        Some(vec![Some(3), None]),
        Some(vec![Some(6), Some(7), Some(8)]),
    ]));

    let array_refs = [&first, &second];
    let actual = collect_statistics(&array_refs);

    assert_eq!(
        actual,
        StatisticsRow {
            null_count: 1,
            min_value: ScalarValue::from(0_i16),
            max_value: ScalarValue::from(9_i16),
        }
    );
}

/// `LargeList<Int64>` statistics across two arrays: min/max are taken over
/// every child value, while `null_count` reflects only the single null list
/// entry.
#[test]
fn test_collect_large_list_stats() {
    let first: ArrayRef = Arc::new(LargeListArray::from_iter_primitive::<Int64Type, _, _>(
        vec![
            Some(vec![Some(0)]),
            Some(vec![Some(9)]),
            Some(vec![Some(9), Some(2), Some(2)]),
        ],
    ));
    // Contains one null list entry and one null element inside a list.
    let second: ArrayRef = Arc::new(LargeListArray::from_iter_primitive::<Int64Type, _, _>(
        vec![
            Some(vec![Some(0), Some(1), Some(2)]),
            None,
            Some(vec![Some(3), None]),
            Some(vec![Some(6), Some(7), Some(8)]),
        ],
    ));

    let array_refs = [&first, &second];
    let actual = collect_statistics(&array_refs);

    assert_eq!(
        actual,
        StatisticsRow {
            null_count: 1,
            min_value: ScalarValue::from(0_i64),
            max_value: ScalarValue::from(9_i64),
        }
    );
}
}

0 comments on commit 1f90a94

Please sign in to comment.