feat: implement sample function
yukkit authored and roseboy-liu committed Jun 1, 2023
1 parent 27b11a6 commit 1553d7e
Showing 7 changed files with 394 additions and 12 deletions.
16 changes: 15 additions & 1 deletion query_server/query/benches/aggregate_function.rs
@@ -7,7 +7,7 @@ use crate::criterion::Criterion;

fn criterion_benchmark(c: &mut Criterion) {
let partitions_len = 8;
-    let array_len = 32768 * 2; // 2^16
+    let array_len = 32768000 * 2;
let batch_size = 2048; // 2^11
let ctx = data_utils::create_context(partitions_len, array_len, batch_size).unwrap();

@@ -20,6 +20,20 @@ fn criterion_benchmark(c: &mut Criterion) {
)
})
});

+    for i in [10, 100, 1000] {
+        c.bench_function(&format!("aggregate_query_no_group_by_sample_{i}"), |b| {
+            b.iter(|| {
+                data_utils::query(
+                    ctx.clone(),
+                    &format!(
+                        "SELECT sample(f64, {i}) \
+                         FROM t"
+                    ),
+                )
+            })
+        });
+    }
}

criterion_group!(benches, criterion_benchmark);
5 changes: 1 addition & 4 deletions query_server/query/benches/data_utils/mod.rs
@@ -15,7 +15,6 @@ use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::{Rng, SeedableRng};
use tokio::runtime::Runtime;
-use trace::warn;

pub fn query(ctx: Arc<Mutex<SessionContext>>, sql: &str) {
let rt = Runtime::new().unwrap();
@@ -32,9 +31,7 @@ pub fn create_context(
// temporary(database level): wrap SessionContext into function meta manager
let mut func_manager = DFSessionContextFuncAdapter::new(&mut ctx);
// temporary(database level): register function to function meta manager
-    if let Err(e) = load_all_functions(&mut func_manager) {
-        warn!("Failed to load consdb's built-in function. err: {}", e);
-    };
+    load_all_functions(&mut func_manager).expect("load_all_functions");
let provider = create_table_provider(partitions_len, array_len, batch_size)?;
ctx.register_table("t", provider)?;
Ok(Arc::new(Mutex::new(ctx)))
6 changes: 5 additions & 1 deletion query_server/query/src/extension/expr/aggregate_function/mod.rs
@@ -1,13 +1,17 @@
#[cfg(test)]
mod example;
+mod sample;

use spi::query::function::FunctionMetadataManager;
use spi::Result;

-pub fn register_udafs(_func_manager: &mut dyn FunctionMetadataManager) -> Result<()> {
+pub const SAMPLE_UDAF_NAME: &str = "sample";
+
+pub fn register_udafs(func_manager: &mut dyn FunctionMetadataManager) -> Result<()> {
// extend function...
// eg.
// example::register_udaf(func_manager)?;
+    sample::register_udaf(func_manager)?;
Ok(())
}

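The UDAF defined below is wired up the same way the benchmark change above does it; a minimal sketch, assuming a mutable DataFusion `SessionContext` named `ctx` and with error handling reduced to `expect`:

    let mut func_manager = DFSessionContextFuncAdapter::new(&mut ctx);
    load_all_functions(&mut func_manager).expect("load_all_functions");
    // `sample` now resolves in SQL, e.g. SELECT sample(f64, 100) FROM t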
295 changes: 295 additions & 0 deletions query_server/query/src/extension/expr/aggregate_function/sample.rs
@@ -0,0 +1,295 @@
use std::cmp;
use std::collections::HashSet;
use std::sync::Arc;

use datafusion::arrow::array::{Array, ArrayRef, UInt32Array};
use datafusion::arrow::compute::{self, take};
use datafusion::arrow::datatypes::{DataType, Field, UInt32Type};
use datafusion::common::cast::{as_list_array, as_primitive_array};
use datafusion::common::{downcast_value, DataFusionError, Result as DFResult};
use datafusion::logical_expr::type_coercion::aggregates::{
DATES, NUMERICS, STRINGS, TIMES, TIMESTAMPS,
};
use datafusion::logical_expr::{
AccumulatorFunctionImplementation, AggregateUDF, ReturnTypeFunction, Signature,
StateTypeFunction, TypeSignature, Volatility,
};
use datafusion::physical_plan::Accumulator;
use datafusion::scalar::ScalarValue;
use rand::Rng;
use spi::query::function::FunctionMetadataManager;
use spi::Result;

use super::SAMPLE_UDAF_NAME;
use crate::extension::expr::{BINARYS, INTEGERS};

pub fn register_udaf(func_manager: &mut dyn FunctionMetadataManager) -> Result<AggregateUDF> {
let udf = new();
func_manager.register_udaf(udf.clone())?;
Ok(udf)
}

fn new() -> AggregateUDF {
let return_type: ReturnTypeFunction = Arc::new(move |input| {
        // the UDAF returns a list of the input's element type
        let data_type = DataType::List(Box::new(Field::new("item", input[0].clone(), true)));
        Ok(Arc::new(data_type))
});

    // Intermediate aggregate state: the sampled values (same list type as the
    // output) plus the requested sample count as a UInt32
    let state_type: StateTypeFunction =
        Arc::new(move |output| Ok(Arc::new(vec![output.clone(), DataType::UInt32])));

let accumulator: AccumulatorFunctionImplementation = Arc::new(|output| match output {
DataType::List(f) => Ok(Box::new(SampleAccumulator::try_new(
output.clone(),
f.data_type().clone(),
)?)),
        // The accumulator factory receives the UDAF's return type, which is
        // always a list here, so any other type is unreachable
        _ => {
            panic!("sample accumulator expects a List output type, got {output:?}")
        }
});

    // Accept (value, sample_count): every supported value type crossed with
    // every integer type for the second (sample count) argument
    let type_signatures = STRINGS
.iter()
.chain(NUMERICS.iter())
.chain(TIMESTAMPS.iter())
.chain(DATES.iter())
.chain(BINARYS.iter())
.chain(TIMES.iter())
.flat_map(|t| {
INTEGERS
.iter()
.map(|s_t| TypeSignature::Exact(vec![t.clone(), s_t.clone()]))
})
.collect();

AggregateUDF::new(
SAMPLE_UDAF_NAME,
&Signature::one_of(type_signatures, Volatility::Volatile),
&return_type,
&accumulator,
&state_type,
)
}

/// Intermediate state data + number of samples
type IntermediateSampleState = (Vec<ScalarValue>, usize);
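// e.g. (vec![ScalarValue::Float64(Some(1.0)), ScalarValue::Float64(Some(2.5))], 100)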

/// An accumulator that computes a random sample of its input values
#[derive(Debug)]
struct SampleAccumulator {
states: Vec<IntermediateSampleState>,

list_type: DataType,
child_type: DataType,
}

impl Accumulator for SampleAccumulator {
fn state(&self) -> DFResult<Vec<ScalarValue>> {
if self.states.is_empty() {
return empty_intermediate_sample_state(&self.list_type);
}

let (scalars, sample_n) = self.sample_state()?;

let state = ScalarValue::new_list(Some(scalars), self.child_type.clone());
let sample_n = ScalarValue::UInt32(Some(sample_n as u32));

trace::trace!("SampleAccumulator state: {:?}", state);

Ok(vec![state, sample_n])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
trace::trace!("update_batch: {:?}", values);

// Get the number of samples
if let Some(sample_n) = extract_sample_n(values[1].as_ref())? {
return self.update_batch_inner(values[0].clone(), sample_n);
}

Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
trace::trace!("merge_batch: {:?}", states);

        // states[0]: list column of sampled values; states[1]: the per-partition sample count
        let state_col = as_list_array(states[0].as_ref())?;
        let sample_n_col = downcast_value!(states[1], UInt32Array);

if state_col.is_empty() {
return Ok(());
}

state_col.iter().zip(sample_n_col).try_for_each(|e| {
match e {
(Some(state), Some(sample_n)) if !state.is_empty() => {
self.update_batch_inner(state, sample_n as usize)?;
}
_ => {
trace::info!("merge_batch, skip empty state: {:?}", e)
}
}

Ok(())
})
}

fn evaluate(&self) -> DFResult<ScalarValue> {
        // state() returns [sampled values as a list, sample_n]; the aggregate's
        // final value is just the list
        let result = self.state()?;

trace::trace!("SampleAccumulator evaluate result: {:?}", result);

Ok(result[0].clone())
}

fn size(&self) -> usize {
std::mem::size_of_val(self)
}
}

impl SampleAccumulator {
/// Creates a new `SampleAccumulator`
pub fn try_new(list_type: DataType, child_type: DataType) -> DFResult<Self> {
Ok(Self {
states: vec![],
list_type,
child_type,
})
}

/// Sample the input data
fn sample_data(&self, arr: ArrayRef, sample_n: usize) -> DFResult<ArrayRef> {
if arr.len() <= sample_n {
trace::trace!("The size of data {} is less than the number of samples {}, use the original data directly", arr.len(), sample_n);
// use arr directly
Ok(arr)
} else {
trace::trace!("Take {} samples", sample_n);
// random sampling
let indices = UInt32Array::from(generate_unique_random_numbers(
sample_n as u32,
0,
arr.len() as u32,
));
Ok(take(arr.as_ref(), &indices, None)?)
}
}

/// Sample the state
fn sample_state(&self) -> DFResult<IntermediateSampleState> {
let states = &self.states;

let total_num = states.iter().map(|(e, _)| e.len()).sum::<usize>();

let mut sample_n = 0;
let mut result = vec![];

for (s, r) in states {
sample_n = *r;
let num = s.len();
let select_num = (num * sample_n + sample_n - 1) / total_num;
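            // e.g. two buffered states of 30 and 70 values with sample_n = 10
            // yield select_num = 3 and 7 respectively, so the merged sample
            // stays proportional to each state's share of the total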
let indices = generate_unique_random_numbers(select_num as u32, 0, num as u32);
for i in indices {
result.push(s[i as usize].clone());
}
}

Ok((result, sample_n))
}

    /// Try to compact the buffered states
    /// If the total number of buffered values exceeds 10x `sample_n`, re-sample them down
fn try_compact_state(&self, sample_n: usize) -> DFResult<Option<IntermediateSampleState>> {
let num_rows = self.states.iter().map(|(e, _)| e.len()).sum::<usize>();
if num_rows > sample_n * 10 {
trace::trace!("Merge existing data: {}", num_rows);
// compact
Ok(Some(self.sample_state()?))
} else {
Ok(None)
}
}

fn save_state(&mut self, state: IntermediateSampleState) {
self.states.push(state)
}

fn set_state(&mut self, state: IntermediateSampleState) {
self.states = vec![state];
}

fn update_batch_inner(&mut self, arr: ArrayRef, sample_n: usize) -> DFResult<()> {
trace::trace!("update_batch_inner: {:?}, sample_n: {}", arr, sample_n);
// sample
let sampled_arr = self.sample_data(arr, sample_n)?;

let df_values = arrow_array_to_df_values(sampled_arr.as_ref())?;

// save the sampling result
self.save_state((df_values, sample_n));
        // try to compact the saved sample results:
        // compacting on every batch would re-traverse all sampled values each
        // time (inefficient), while never compacting would let memory usage
        // grow unbounded
if let Some(state) = self.try_compact_state(sample_n)? {
self.set_state(state);
}

Ok(())
}
}

fn generate_unique_random_numbers(count: u32, min: u32, max: u32) -> Vec<u32> {
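    // Clamp the requested count to the range size so the draw-until-unique
    // loop below always terminates.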
let count = cmp::min(count, max - min);

let mut rng = rand::thread_rng();
let mut unique_numbers = HashSet::with_capacity(count as usize);
let mut result = Vec::new();

while (unique_numbers.len() as u32) < count {
let random_number = rng.gen_range(min..max);
unique_numbers.insert(random_number);
}

    // HashSet iteration order is arbitrary, so the sampled indices come out unordered
    for number in unique_numbers {
        result.push(number);
    }

result
}

fn arrow_array_to_df_values(arr: &dyn Array) -> DFResult<Vec<ScalarValue>> {
let size = arr.len();

let mut result = Vec::with_capacity(size);
for i in 0..size {
result.push(ScalarValue::try_from_array(arr, i)?);
}

Ok(result)
}

/// Get the number of samples
fn extract_sample_n(arr: &dyn Array) -> DFResult<Option<usize>> {
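    // The sample count arrives as a column but is expected to be constant, so
    // only the first element is inspected (emptiness is checked first, making
    // `value_unchecked(0)` safe).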
let sample_n = unsafe {
if arr.is_empty() {
return Ok(None);
}
let remain_arr = compute::cast(arr, &DataType::UInt32)?;
as_primitive_array::<UInt32Type>(remain_arr.as_ref())?.value_unchecked(0)
} as usize;

    // (0, 2000]
if sample_n > 0 && sample_n <= 2000 {
return Ok(Some(sample_n));
}

Err(DataFusionError::Plan(format!(
"The number of sample points for function '{SAMPLE_UDAF_NAME}' must be (0, 2000]"
)))
}

fn empty_intermediate_sample_state(output_type: &DataType) -> DFResult<Vec<ScalarValue>> {
let empty_value = ScalarValue::try_from(output_type)?;
Ok(vec![empty_value, ScalarValue::UInt32(None)])
}
13 changes: 7 additions & 6 deletions query_server/query/src/extension/expr/func_manager.rs
(The hunks below fix an inverted check: `ctx.udf(name)` / `ctx.udaf(name)` return an error when the function is not yet registered, so the old code rejected first-time registrations with FunctionExists; the new code registers on that path and reports FunctionExists only when the name is already taken.)

@@ -20,20 +20,21 @@ impl<'a> DFSessionContextFuncAdapter<'a> {
impl<'a> FunctionMetadataManager for DFSessionContextFuncAdapter<'a> {
fn register_udf(&mut self, udf: ScalarUDF) -> Result<()> {
if self.ctx.udf(udf.name.as_str()).is_err() {
-            return Err(QueryError::FunctionExists { name: udf.name });
+            self.ctx.register_udf(udf);
+
+            return Ok(());
}

-        self.ctx.register_udf(udf);
-        Ok(())
+        Err(QueryError::FunctionExists { name: udf.name })
}

fn register_udaf(&mut self, udaf: AggregateUDF) -> Result<()> {
if self.ctx.udaf(udaf.name.as_str()).is_err() {
-            return Err(QueryError::FunctionExists { name: udaf.name });
+            self.ctx.register_udaf(udaf);
+            return Ok(());
}

-        self.ctx.register_udaf(udaf);
-        Ok(())
+        Err(QueryError::FunctionExists { name: udaf.name })
}

fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
