Skip to content

Commit

Permalink
feat: refactor CHECK_ARGS_FUNC
Browse files Browse the repository at this point in the history
  • Loading branch information
yukkit committed Aug 3, 2023
1 parent b30aa49 commit 43c755b
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 13 deletions.
Expand Up @@ -18,7 +18,7 @@ use super::{GaugeData, TSPoint};
use crate::extension::expr::aggregate_function::{
scalar_to_points, AggResult, AggState, GAUGE_AGG_UDAF_NAME,
};
use crate::extension::expr::expr_utils::CHECK_ARGS_FUNC;
use crate::extension::expr::expr_utils::check_args;

pub fn register_udaf(func_manager: &mut dyn FunctionMetadataManager) -> Result<(), QueryError> {
func_manager.register_udaf(new())?;
Expand All @@ -27,7 +27,7 @@ pub fn register_udaf(func_manager: &mut dyn FunctionMetadataManager) -> Result<(

fn new() -> AggregateUDF {
let return_type_func: ReturnTypeFunction = Arc::new(move |input| {
CHECK_ARGS_FUNC(GAUGE_AGG_UDAF_NAME, 2, input)?;
check_args(GAUGE_AGG_UDAF_NAME, 2, input)?;

let result = GaugeData::try_new_null(input[0].clone(), input[1].clone())?;
let date_type = result.to_scalar()?.get_datatype();
Expand All @@ -45,7 +45,7 @@ fn new() -> AggregateUDF {
});

let accumulator: AccumulatorFactoryFunction = Arc::new(|input, output| {
CHECK_ARGS_FUNC(GAUGE_AGG_UDAF_NAME, 2, input)?;
check_args(GAUGE_AGG_UDAF_NAME, 2, input)?;

let ts_data_type = input[0].clone();
let value_data_type = input[1].clone();
Expand Down
Expand Up @@ -17,7 +17,7 @@ use crate::extension::expr::aggregate_function::state_agg::{
AggResult, StateAggData, LIST_ELEMENT_NAME,
};
use crate::extension::expr::aggregate_function::COMPACT_STATE_AGG_UDAF_NAME;
use crate::extension::expr::expr_utils::CHECK_ARGS_FUNC;
use crate::extension::expr::expr_utils::check_args;
use crate::extension::expr::INTEGERS;

pub fn register_udaf(func_manager: &mut dyn FunctionMetadataManager) -> Result<(), QueryError> {
Expand All @@ -27,7 +27,7 @@ pub fn register_udaf(func_manager: &mut dyn FunctionMetadataManager) -> Result<(

fn new() -> AggregateUDF {
let return_type_func: ReturnTypeFunction = Arc::new(move |input| {
CHECK_ARGS_FUNC(COMPACT_STATE_AGG_UDAF_NAME, 2, input)?;
check_args(COMPACT_STATE_AGG_UDAF_NAME, 2, input)?;

let result = StateAggData::new(input[0].clone(), input[1].clone(), false);
let date_type = result.to_scalar()?.get_datatype();
Expand All @@ -38,7 +38,7 @@ fn new() -> AggregateUDF {
});

let state_type_func: StateTypeFunction = Arc::new(move |input, _| {
CHECK_ARGS_FUNC(COMPACT_STATE_AGG_UDAF_NAME, 2, input)?;
check_args(COMPACT_STATE_AGG_UDAF_NAME, 2, input)?;

let types = input
.iter()
Expand All @@ -51,7 +51,7 @@ fn new() -> AggregateUDF {
});

let accumulator: AccumulatorFactoryFunction = Arc::new(|input, _| {
CHECK_ARGS_FUNC(COMPACT_STATE_AGG_UDAF_NAME, 2, input)?;
check_args(COMPACT_STATE_AGG_UDAF_NAME, 2, input)?;

Ok(Box::new(CompactStateAggAccumulator::try_new(
input.to_vec(),
Expand Down
6 changes: 2 additions & 4 deletions query_server/query/src/extension/expr/expr_utils.rs
Expand Up @@ -9,9 +9,7 @@ use spi::QueryError;

use super::selector_function::{BOTTOM, TOPK};

type CheckArgsFuncSignature<'a> = &'a dyn Fn(&str, usize, &[DataType]) -> DFResult<()>;

pub const CHECK_ARGS_FUNC: CheckArgsFuncSignature = &|func_name, expects, input| {
pub fn check_args(func_name: &str, expects: usize, input: &[DataType]) -> DFResult<()> {
if input.len() != expects {
return Err(DataFusionError::External(Box::new(QueryError::Analyzer {
err: format!(
Expand All @@ -24,7 +22,7 @@ pub const CHECK_ARGS_FUNC: CheckArgsFuncSignature = &|func_name, expects, input|
}

Ok(())
};
}

pub fn is_time_filter(expr: &Expr) -> bool {
match expr {
Expand Down
Expand Up @@ -9,10 +9,10 @@ macro_rules! object_accessor {
use datafusion::physical_plan::ColumnarValue;
use datafusion::scalar::ScalarValue;
use $crate::extension::expr::aggregate_function::$OBJECT;
use $crate::extension::expr::expr_utils::CHECK_ARGS_FUNC;
use $crate::extension::expr::expr_utils::check_args;

let return_type_fn: ReturnTypeFunction = Arc::new(|args| {
CHECK_ARGS_FUNC(stringify!($FUNC), 1, args)?;
check_args(stringify!($FUNC), 1, args)?;

let null_data = $OBJECT::try_from_scalar(ScalarValue::try_from(&args[0])?)?;
let output = null_data.$FUNC()?.get_datatype();
Expand Down

0 comments on commit 43c755b

Please sign in to comment.