Skip to content

Commit

Permalink
feat: support aggregate function first
Browse files Browse the repository at this point in the history
  • Loading branch information
yukkit committed Aug 9, 2023
1 parent 1e49459 commit 1b7b06a
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 15 deletions.
18 changes: 9 additions & 9 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions query_server/query/benches/aggregate_function.rs
Expand Up @@ -42,6 +42,16 @@ fn criterion_benchmark(c: &mut Criterion) {
b.iter(|| data_utils::query(ctx.clone(), "select gauge_agg(ts, f64) FROM t"))
});
group.finish();

c.bench_function("aggregate_query_no_group_by_first", |b| {
b.iter(|| {
data_utils::query(
ctx.clone(),
"SELECT first(ts, f64) \
FROM t",
)
})
});
}

criterion_group!(benches, criterion_benchmark);
Expand Down
170 changes: 170 additions & 0 deletions query_server/query/src/extension/expr/aggregate_function/first.rs
@@ -0,0 +1,170 @@
use std::cmp::Ordering;
use std::sync::Arc;

use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::compute::{sort_to_indices, SortOptions};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::Result as DFResult;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::type_coercion::aggregates::{
DATES, NUMERICS, STRINGS, TIMES, TIMESTAMPS,
};
use datafusion::logical_expr::{
AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature, StateTypeFunction,
TypeSignature, Volatility,
};
use datafusion::physical_plan::Accumulator;
use datafusion::scalar::ScalarValue;
use spi::query::function::FunctionMetadataManager;
use spi::{QueryError, Result};

use super::TSPoint;
use crate::extension::expr::aggregate_function::FIRST_UDAF_NAME;
use crate::extension::expr::BINARYS;

pub fn register_udaf(func_manager: &mut dyn FunctionMetadataManager) -> Result<AggregateUDF> {
let udf = new();
func_manager.register_udaf(udf.clone())?;
Ok(udf)
}

fn new() -> AggregateUDF {
let return_type_func: ReturnTypeFunction =
Arc::new(move |input| Ok(Arc::new(input[1].clone())));

let state_type_func: StateTypeFunction = Arc::new(move |input, _| Ok(Arc::new(input.to_vec())));

let accumulator: AccumulatorFactoryFunction = Arc::new(|input, _| {
let time_data_type = input[0].clone();
let value_data_type = input[1].clone();

Ok(Box::new(FirstAccumulator::try_new(
time_data_type,
value_data_type,
)?))
});

// first(
// time TIMESTAMP,
// value ANY
// )
let type_signatures = STRINGS
.iter()
.chain(NUMERICS.iter())
.chain(TIMESTAMPS.iter())
.chain(DATES.iter())
.chain(BINARYS.iter())
.chain(TIMES.iter())
.flat_map(|t| {
TIMESTAMPS
.iter()
.map(|s_t| TypeSignature::Exact(vec![s_t.clone(), t.clone()]))
})
.collect();

AggregateUDF::new(
FIRST_UDAF_NAME,
&Signature::one_of(type_signatures, Volatility::Immutable),
&return_type_func,
&accumulator,
&state_type_func,
)
}

#[derive(Debug)]
struct FirstAccumulator {
first: TSPoint,

sort_opts: SortOptions,
}

impl FirstAccumulator {
fn try_new(time_data_type: DataType, value_data_type: DataType) -> DFResult<Self> {
let null = TSPoint::try_new_null(time_data_type, value_data_type)?;
Ok(Self {
first: null,
sort_opts: SortOptions {
descending: false,
nulls_first: false,
},
})
}

fn update_inner(&mut self, point: TSPoint) -> DFResult<()> {
if point.ts().is_null() || point.val().is_null() {
return Ok(());
}

if self.first.ts().is_null() {
self.first = point;
return Ok(());
}

match point.ts().partial_cmp(self.first.ts()) {
Some(ordering) => {
if ordering == Ordering::Less {
self.first = point;
}
}
None => {
return Err(DataFusionError::External(Box::new(QueryError::Internal {
reason: format!("cannot compare {:?} with {:?}", point.ts(), self.first.ts()),
})))
}
}

Ok(())
}
}

impl Accumulator for FirstAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
trace::trace!("update_batch: {:?}", values);

if values.is_empty() {
return Ok(());
}

debug_assert!(
values.len() == 2,
"gauge_agg can only take 2 param, but found {}",
values.len()
);

let times_records = values[0].as_ref();
let value_records = values[1].as_ref();

let indices = sort_to_indices(times_records, Some(self.sort_opts), Some(1))?;

if !indices.is_empty() {
let idx = indices.value(0) as usize;
let ts = ScalarValue::try_from_array(times_records, idx)?;
let val = ScalarValue::try_from_array(value_records, idx)?;
let point = TSPoint { ts, val };
self.update_inner(point)?;
}

Ok(())
}

fn evaluate(&self) -> DFResult<ScalarValue> {
Ok(self.first.val().clone())
}

fn size(&self) -> usize {
std::mem::size_of_val(self) - std::mem::size_of_val(self.first.ts())
+ self.first.ts().size()
- std::mem::size_of_val(self.first.ts())
+ self.first.ts().size()
}

fn state(&self) -> DFResult<Vec<ScalarValue>> {
Ok(vec![self.first.ts().clone(), self.first.val().clone()])
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
trace::trace!("merge_batch: {:?}", states);

self.update_batch(states)
}
}
@@ -1,5 +1,6 @@
#[cfg(test)]
mod example;
mod first;
mod gauge;
mod sample;
mod state_agg;
Expand All @@ -17,6 +18,7 @@ use spi::{QueryError, Result};
pub const SAMPLE_UDAF_NAME: &str = "sample";
pub const COMPACT_STATE_AGG_UDAF_NAME: &str = "compact_state_agg";
pub const GAUGE_AGG_UDAF_NAME: &str = "gauge_agg";
pub const FIRST_UDAF_NAME: &str = "first";
pub use gauge::GaugeData;

pub fn register_udafs(func_manager: &mut dyn FunctionMetadataManager) -> Result<()> {
Expand All @@ -26,6 +28,7 @@ pub fn register_udafs(func_manager: &mut dyn FunctionMetadataManager) -> Result<
sample::register_udaf(func_manager)?;
state_agg::register_udafs(func_manager)?;
gauge::register_udafs(func_manager)?;
first::register_udaf(func_manager)?;
Ok(())
}

Expand Down
4 changes: 2 additions & 2 deletions query_server/sqllogicaltests/cases/ddl/user.slt
Expand Up @@ -38,5 +38,5 @@ select * from cluster_schema.users where user_name = 'test_alter_options_u';
test_alter_options_u false {"password":"*****","must_change_password":true,"comment":"ooo ooo","granted_admin":false}

# table not found
statement error .*Table not found: \\"a_non_existent_table\\".*
drop table a_non_existent_table;
statement error .*Table not found: \\"a_non_existent_table\\".*
drop table a_non_existent_table;
Expand Up @@ -8,8 +8,8 @@ statement ok
with tmp as (select compact_state_agg(time, f1) as state from func_tbl)
select state.state_duration, state.state_periods from tmp;

statement error Arrow error: Io error: Status \{ code: Internal, message: "Build logical plan: Failed to do analyze. err: The function \\"compact_state_agg\\" expects 2 arguments, but 3 were provided",.*
statement error Arrow error: Io error: Status \{ code: Internal, message: "Build logical plan: Datafusion: Error during planning: No function matches the given name and argument types 'compact_state_agg\(Timestamp\(Nanosecond, None\), Timestamp\(Nanosecond, None\), Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\.\\n\\tCandidate functions:\\n\\tcompact_state_agg\(Timestamp\(Second, None\), Utf8\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), Utf8\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), Utf8\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), Utf8\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), LargeUtf8\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), LargeUtf8\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), LargeUtf8\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), LargeUtf8\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), Int8\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), Int8\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), Int8\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), Int8\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), Int16\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), Int16\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), Int16\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), Int16\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), Int32\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), Int32\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), Int32\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), Int32\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), Int64\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), Int64\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), Int64\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), Int64\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), UInt8\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), UInt8\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), UInt8\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), UInt8\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), UInt16\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), UInt16\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), UInt16\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), UInt16\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), UInt32\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), UInt32\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), UInt32\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), UInt32\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), UInt64\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), UInt64\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), UInt64\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), UInt64\)", .*
select compact_state_agg(time, time, time) as state from func_tbl;

statement error Arrow error: Io error: Status \{ code: Internal, message: "Build logical plan: Failed to do analyze. err: The function \\"compact_state_agg\\" expects 2 arguments, but 1 were provided",.*
statement error Arrow error: Io error: Status \{ code: Internal, message: "Build logical plan: Datafusion: Error during planning: No function matches the given name and argument types 'compact_state_agg\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\.\\n\\tCandidate functions:\\n\\tcompact_state_agg\(Timestamp\(Second, None\), Utf8\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), Utf8\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), Utf8\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), Utf8\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), LargeUtf8\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), LargeUtf8\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), LargeUtf8\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), LargeUtf8\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), Int8\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), Int8\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), Int8\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), Int8\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), Int16\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), Int16\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), Int16\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), Int16\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), Int32\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), Int32\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), Int32\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), Int32\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), Int64\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), Int64\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), Int64\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), Int64\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), UInt8\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), UInt8\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), UInt8\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), UInt8\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), UInt16\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), UInt16\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), UInt16\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), UInt16\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), UInt32\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), UInt32\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), UInt32\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), UInt32\)\\n\\tcompact_state_agg\(Timestamp\(Second, None\), UInt64\)\\n\\tcompact_state_agg\(Timestamp\(Millisecond, None\), UInt64\)\\n\\tcompact_state_agg\(Timestamp\(Microsecond, None\), UInt64\)\\n\\tcompact_state_agg\(Timestamp\(Nanosecond, None\), UInt64\)", .*
select compact_state_agg(time) as state from func_tbl;

0 comments on commit 1b7b06a

Please sign in to comment.