diff --git a/common/aggregate_functions/src/aggregate_arg_max.rs b/common/aggregate_functions/src/aggregate_arg_max.rs new file mode 100644 index 0000000000000..db89b10c01162 --- /dev/null +++ b/common/aggregate_functions/src/aggregate_arg_max.rs @@ -0,0 +1,122 @@ +// Copyright 2020-2021 The Datafuse Authors. +// +// SPDX-License-Identifier: Apache-2.0. + +use std::convert::TryInto; +use std::fmt; + +use common_datavalues::DataArrayAggregate; +use common_datavalues::DataColumnarValue; +use common_datavalues::DataSchema; +use common_datavalues::DataType; +use common_datavalues::DataValue; +use common_datavalues::DataValueAggregate; +use common_datavalues::DataValueAggregateOperator; +use common_exception::Result; + +use crate::IAggregateFunction; + +#[derive(Clone)] +pub struct AggregateArgMaxFunction { + display_name: String, + depth: usize, + state: DataValue +} + +impl AggregateArgMaxFunction { + pub fn try_create(display_name: &str) -> Result> { + Ok(Box::new(AggregateArgMaxFunction { + display_name: display_name.to_string(), + depth: 0, + state: DataValue::Struct(vec![DataValue::Null, DataValue::Null]) + })) + } +} + +impl IAggregateFunction for AggregateArgMaxFunction { + fn name(&self) -> &str { + "AggregateArgMaxFunction" + } + + fn return_type(&self, args: &[DataType]) -> Result { + Ok(args[0].clone()) + } + + fn nullable(&self, _input_schema: &DataSchema) -> Result { + Ok(false) + } + + fn set_depth(&mut self, depth: usize) { + self.depth = depth; + } + + fn accumulate(&mut self, columns: &[DataColumnarValue], _input_rows: usize) -> Result<()> { + if let DataValue::Struct(max_arg_val) = DataArrayAggregate::data_array_aggregate_op( + DataValueAggregateOperator::ArgMax, + columns[1].to_array()? + )? { + let index: u64 = max_arg_val[0].clone().try_into()?; + let max_arg = DataValue::try_from_array(&columns[0].to_array()?, index as usize)?; + let max_val = max_arg_val[1].clone(); + + if let DataValue::Struct(old_max_arg_val) = self.state.clone() { + let old_max_arg = old_max_arg_val[0].clone(); + let old_max_val = old_max_arg_val[1].clone(); + let new_max_val = DataValueAggregate::data_value_aggregate_op( + DataValueAggregateOperator::Max, + old_max_val.clone(), + max_val + )?; + self.state = DataValue::Struct(vec![ + if new_max_val == old_max_val { + old_max_arg + } else { + max_arg + }, + new_max_val, + ]); + } + } + Ok(()) + } + + fn accumulate_result(&self) -> Result> { + Ok(vec![self.state.clone()]) + } + + fn merge(&mut self, states: &[DataValue]) -> Result<()> { + let arg_val = states[self.depth].clone(); + if let (DataValue::Struct(new_states), DataValue::Struct(old_states)) = + (arg_val, self.state.clone()) + { + let new_max_val = DataValueAggregate::data_value_aggregate_op( + DataValueAggregateOperator::Max, + new_states[1].clone(), + old_states[1].clone() + )?; + self.state = DataValue::Struct(vec![ + if new_max_val == old_states[1] { + old_states[0].clone() + } else { + new_states[0].clone() + }, + new_max_val, + ]); + } + Ok(()) + } + + fn merge_result(&self) -> Result { + Ok(if let DataValue::Struct(state) = self.state.clone() { + state[0].clone() + } else { + self.state.clone() + }) + } +} + +impl fmt::Display for AggregateArgMaxFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.display_name) + } +} diff --git a/common/aggregate_functions/src/aggregate_arg_min.rs b/common/aggregate_functions/src/aggregate_arg_min.rs new file mode 100644 index 0000000000000..bc21d1bcfdd10 --- /dev/null +++ b/common/aggregate_functions/src/aggregate_arg_min.rs @@ -0,0 +1,122 @@ +// Copyright 2020-2021 The Datafuse Authors. +// +// SPDX-License-Identifier: Apache-2.0. + +use std::convert::TryInto; +use std::fmt; + +use common_datavalues::DataArrayAggregate; +use common_datavalues::DataColumnarValue; +use common_datavalues::DataSchema; +use common_datavalues::DataType; +use common_datavalues::DataValue; +use common_datavalues::DataValueAggregate; +use common_datavalues::DataValueAggregateOperator; +use common_exception::Result; + +use crate::IAggregateFunction; + +#[derive(Clone)] +pub struct AggregateArgMinFunction { + display_name: String, + depth: usize, + state: DataValue +} + +impl AggregateArgMinFunction { + pub fn try_create(display_name: &str) -> Result> { + Ok(Box::new(AggregateArgMinFunction { + display_name: display_name.to_string(), + depth: 0, + state: DataValue::Struct(vec![DataValue::Null, DataValue::Null]) + })) + } +} + +impl IAggregateFunction for AggregateArgMinFunction { + fn name(&self) -> &str { + "AggregateArgMinFunction" + } + + fn return_type(&self, args: &[DataType]) -> Result { + Ok(args[0].clone()) + } + + fn nullable(&self, _input_schema: &DataSchema) -> Result { + Ok(false) + } + + fn set_depth(&mut self, depth: usize) { + self.depth = depth; + } + + fn accumulate(&mut self, columns: &[DataColumnarValue], _input_rows: usize) -> Result<()> { + if let DataValue::Struct(min_arg_val) = DataArrayAggregate::data_array_aggregate_op( + DataValueAggregateOperator::ArgMin, + columns[1].to_array()? + )? { + let index: u64 = min_arg_val[0].clone().try_into()?; + let min_arg = DataValue::try_from_array(&columns[0].to_array()?, index as usize)?; + let min_val = min_arg_val[1].clone(); + + if let DataValue::Struct(old_min_arg_val) = self.state.clone() { + let old_min_arg = old_min_arg_val[0].clone(); + let old_min_val = old_min_arg_val[1].clone(); + let new_min_val = DataValueAggregate::data_value_aggregate_op( + DataValueAggregateOperator::Min, + old_min_val.clone(), + min_val + )?; + self.state = DataValue::Struct(vec![ + if new_min_val == old_min_val { + old_min_arg + } else { + min_arg + }, + new_min_val, + ]); + } + } + Ok(()) + } + + fn accumulate_result(&self) -> Result> { + Ok(vec![self.state.clone()]) + } + + fn merge(&mut self, states: &[DataValue]) -> Result<()> { + let arg_val = states[self.depth].clone(); + if let (DataValue::Struct(new_states), DataValue::Struct(old_states)) = + (arg_val, self.state.clone()) + { + let new_min_val = DataValueAggregate::data_value_aggregate_op( + DataValueAggregateOperator::Min, + new_states[1].clone(), + old_states[1].clone() + )?; + self.state = DataValue::Struct(vec![ + if new_min_val == old_states[1] { + old_states[0].clone() + } else { + new_states[0].clone() + }, + new_min_val, + ]); + } + Ok(()) + } + + fn merge_result(&self) -> Result { + Ok(if let DataValue::Struct(state) = self.state.clone() { + state[0].clone() + } else { + self.state.clone() + }) + } +} + +impl fmt::Display for AggregateArgMinFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.display_name) + } +} diff --git a/common/aggregate_functions/src/aggregator.rs b/common/aggregate_functions/src/aggregator.rs index 369572e602ed7..5842d0ba04029 100644 --- a/common/aggregate_functions/src/aggregator.rs +++ b/common/aggregate_functions/src/aggregator.rs @@ -5,6 +5,8 @@ use common_exception::Result; use crate::aggregate_function_factory::FactoryFuncRef; +use crate::AggregateArgMaxFunction; +use crate::AggregateArgMinFunction; use crate::AggregateAvgFunction; use crate::AggregateCountFunction; use crate::AggregateMaxFunction; @@ -16,11 +18,14 @@ pub struct AggregatorFunction; impl AggregatorFunction { pub fn register(map: FactoryFuncRef) -> Result<()> { let mut map = map.write(); + // FuseQuery always uses lowercase function names to get functions. map.insert("count", AggregateCountFunction::try_create); map.insert("min", AggregateMinFunction::try_create); map.insert("max", AggregateMaxFunction::try_create); map.insert("sum", AggregateSumFunction::try_create); map.insert("avg", AggregateAvgFunction::try_create); + map.insert("argmin", AggregateArgMinFunction::try_create); + map.insert("argmax", AggregateArgMaxFunction::try_create); Ok(()) } } diff --git a/common/aggregate_functions/src/aggregator_test.rs b/common/aggregate_functions/src/aggregator_test.rs index a8d7a7b1f9aaa..a92f846a952eb 100644 --- a/common/aggregate_functions/src/aggregator_test.rs +++ b/common/aggregate_functions/src/aggregator_test.rs @@ -86,6 +86,28 @@ fn test_aggregate_function() -> Result<()> { expect: DataValue::Int64(Some(10)), error: "" }, + Test { + name: "argMax-passed", + eval_nums: 1, + types: vec![DataType::Int64, DataType::Int64], + display: "argmax", + nullable: false, + func: AggregateArgMaxFunction::try_create("argmax")?, + columns: columns.clone(), + expect: DataValue::Int64(Some(1)), + error: "" + }, + Test { + name: "argMin-passed", + eval_nums: 1, + types: vec![DataType::Int64, DataType::Int64], + display: "argmin", + nullable: false, + func: AggregateArgMinFunction::try_create("argmin")?, + columns: columns.clone(), + expect: DataValue::Int64(Some(4)), + error: "" + }, ]; for t in tests { diff --git a/common/aggregate_functions/src/lib.rs b/common/aggregate_functions/src/lib.rs index db166f19bf8c5..a131b417da257 100644 --- a/common/aggregate_functions/src/lib.rs +++ b/common/aggregate_functions/src/lib.rs @@ -5,6 +5,8 @@ #[cfg(test)] mod aggregator_test; +mod aggregate_arg_max; +mod aggregate_arg_min; mod aggregate_avg; mod aggregate_count; mod aggregate_function; @@ -14,6 +16,8 @@ mod aggregate_min; mod aggregate_sum; mod aggregator; +pub use aggregate_arg_max::AggregateArgMaxFunction; +pub use aggregate_arg_min::AggregateArgMinFunction; pub use aggregate_avg::AggregateAvgFunction; pub use aggregate_count::AggregateCountFunction; pub use aggregate_function::IAggregateFunction; diff --git a/common/datavalues/src/data_array_aggregate.rs b/common/datavalues/src/data_array_aggregate.rs index 9b1b9372a73e7..0cb708dbc52c1 100644 --- a/common/datavalues/src/data_array_aggregate.rs +++ b/common/datavalues/src/data_array_aggregate.rs @@ -50,6 +50,24 @@ impl DataArrayAggregate { value.data_type() ))) } + DataValueAggregateOperator::ArgMax => { + typed_array_values_min_max_to_data_value!( + value, + Int8Array, + Int8, + i8, + DataValueAggregateOperator::ArgMax + ) + } + DataValueAggregateOperator::ArgMin => { + typed_array_values_min_max_to_data_value!( + value, + Int8Array, + Int8, + i8, + DataValueAggregateOperator::ArgMin + ) + } }, DataType::Int16 => match op { DataValueAggregateOperator::Min => { @@ -71,6 +89,24 @@ impl DataArrayAggregate { value.data_type() ))) } + DataValueAggregateOperator::ArgMax => { + typed_array_values_min_max_to_data_value!( + value, + Int16Array, + Int16, + i16, + DataValueAggregateOperator::ArgMax + ) + } + DataValueAggregateOperator::ArgMin => { + typed_array_values_min_max_to_data_value!( + value, + Int16Array, + Int16, + i16, + DataValueAggregateOperator::ArgMin + ) + } }, DataType::Int32 => match op { DataValueAggregateOperator::Min => { @@ -93,6 +129,24 @@ impl DataArrayAggregate { value.data_type() ))) } + DataValueAggregateOperator::ArgMax => { + typed_array_values_min_max_to_data_value!( + value, + Int32Array, + Int32, + i32, + DataValueAggregateOperator::ArgMax + ) + } + DataValueAggregateOperator::ArgMin => { + typed_array_values_min_max_to_data_value!( + value, + Int32Array, + Int32, + i32, + DataValueAggregateOperator::ArgMin + ) + } }, DataType::Int64 => match op { DataValueAggregateOperator::Min => { @@ -115,6 +169,24 @@ impl DataArrayAggregate { value.data_type() ))) } + DataValueAggregateOperator::ArgMax => { + typed_array_values_min_max_to_data_value!( + value, + Int64Array, + Int64, + i64, + DataValueAggregateOperator::ArgMax + ) + } + DataValueAggregateOperator::ArgMin => { + typed_array_values_min_max_to_data_value!( + value, + Int64Array, + Int64, + i64, + DataValueAggregateOperator::ArgMin + ) + } }, DataType::UInt8 => match op { DataValueAggregateOperator::Min => { @@ -137,6 +209,24 @@ impl DataArrayAggregate { value.data_type() ))) } + DataValueAggregateOperator::ArgMax => { + typed_array_values_min_max_to_data_value!( + value, + UInt8Array, + UInt8, + u8, + DataValueAggregateOperator::ArgMax + ) + } + DataValueAggregateOperator::ArgMin => { + typed_array_values_min_max_to_data_value!( + value, + UInt8Array, + UInt8, + u8, + DataValueAggregateOperator::ArgMin + ) + } }, DataType::UInt16 => match op { DataValueAggregateOperator::Min => { @@ -159,6 +249,24 @@ impl DataArrayAggregate { value.data_type() ))) } + DataValueAggregateOperator::ArgMax => { + typed_array_values_min_max_to_data_value!( + value, + UInt16Array, + UInt16, + u16, + DataValueAggregateOperator::ArgMax + ) + } + DataValueAggregateOperator::ArgMin => { + typed_array_values_min_max_to_data_value!( + value, + UInt16Array, + UInt16, + u16, + DataValueAggregateOperator::ArgMin + ) + } }, DataType::UInt32 => match op { DataValueAggregateOperator::Min => { @@ -181,6 +289,24 @@ impl DataArrayAggregate { value.data_type() ))) } + DataValueAggregateOperator::ArgMax => { + typed_array_values_min_max_to_data_value!( + value, + UInt32Array, + UInt32, + u32, + DataValueAggregateOperator::ArgMax + ) + } + DataValueAggregateOperator::ArgMin => { + typed_array_values_min_max_to_data_value!( + value, + UInt32Array, + UInt32, + u32, + DataValueAggregateOperator::ArgMin + ) + } }, DataType::UInt64 => match op { DataValueAggregateOperator::Min => { @@ -202,6 +328,24 @@ impl DataArrayAggregate { value.data_type() ))) } + DataValueAggregateOperator::ArgMax => { + typed_array_values_min_max_to_data_value!( + value, + UInt64Array, + UInt64, + u64, + DataValueAggregateOperator::ArgMax + ) + } + DataValueAggregateOperator::ArgMin => { + typed_array_values_min_max_to_data_value!( + value, + UInt64Array, + UInt64, + u64, + DataValueAggregateOperator::ArgMin + ) + } }, DataType::Float32 => match op { DataValueAggregateOperator::Min => { @@ -223,6 +367,24 @@ impl DataArrayAggregate { value.data_type() ))) } + DataValueAggregateOperator::ArgMax => { + typed_array_values_min_max_to_data_value!( + value, + Float32Array, + Float32, + f32, + DataValueAggregateOperator::ArgMax + ) + } + DataValueAggregateOperator::ArgMin => { + typed_array_values_min_max_to_data_value!( + value, + Float32Array, + Float32, + f32, + DataValueAggregateOperator::ArgMin + ) + } }, DataType::Float64 => match op { DataValueAggregateOperator::Min => { @@ -244,6 +406,24 @@ impl DataArrayAggregate { value.data_type() ))) } + DataValueAggregateOperator::ArgMax => { + typed_array_values_min_max_to_data_value!( + value, + Float64Array, + Float64, + f64, + DataValueAggregateOperator::ArgMax + ) + } + DataValueAggregateOperator::ArgMin => { + typed_array_values_min_max_to_data_value!( + value, + Float64Array, + Float64, + f64, + DataValueAggregateOperator::ArgMin + ) + } }, DataType::Utf8 => match op { DataValueAggregateOperator::Min => { @@ -255,6 +435,22 @@ impl DataArrayAggregate { DataValueAggregateOperator::Count => { Ok(DataValue::UInt64(Some(value.len() as u64))) } + DataValueAggregateOperator::ArgMax => { + typed_array_values_min_max_string_to_data_value!( + value, + StringArray, + Utf8, + DataValueAggregateOperator::ArgMax + ) + } + DataValueAggregateOperator::ArgMin => { + typed_array_values_min_max_string_to_data_value!( + value, + StringArray, + Utf8, + DataValueAggregateOperator::ArgMin + ) + } _ => Result::Err(ErrorCodes::BadDataValueType(format!( "DataValue Error: Unsupported data_array_{} for data type: {:?}", op, diff --git a/common/datavalues/src/data_array_aggregate_test.rs b/common/datavalues/src/data_array_aggregate_test.rs index 952f94a108cc5..59b053727c677 100644 --- a/common/datavalues/src/data_array_aggregate_test.rs +++ b/common/datavalues/src/data_array_aggregate_test.rs @@ -144,6 +144,68 @@ fn test_array_aggregate() { "Code: 10, displayText = DataValue Error: Unsupported data_array_avg for data type: Float64.", ] }, + ArrayTest { + name: "argMax-passed", + args: vec![ + Arc::new(StringArray::from(vec!["x1", "x2"])), + Arc::new(Int8Array::from(vec![1, 2, 3, 4])), + Arc::new(Int16Array::from(vec![4, 3, 2, 1])), + Arc::new(Int32Array::from(vec![4, 3, 2, 1])), + Arc::new(Int64Array::from(vec![4, 3, 2, 1])), + Arc::new(UInt8Array::from(vec![4, 3, 2, 1])), + Arc::new(UInt16Array::from(vec![4, 3, 2, 1])), + Arc::new(UInt32Array::from(vec![4, 3, 2, 1])), + Arc::new(UInt64Array::from(vec![4, 3, 2, 1])), + Arc::new(Float32Array::from(vec![4.0, 3.0, 2.0, 1.0])), + Arc::new(Float64Array::from(vec![4.0, 3.0, 2.0, 1.0])), + ], + op: DataValueAggregateOperator::ArgMax, + expect: vec![ + DataValue::Struct(vec![DataValue::UInt64(Some(1)), DataValue::Utf8(Some("x2".to_string()))]), + DataValue::Struct(vec![DataValue::UInt64(Some(3)), DataValue::Int8(Some(4))]), + DataValue::Struct(vec![DataValue::UInt64(Some(0)), DataValue::Int16(Some(4))]), + DataValue::Struct(vec![DataValue::UInt64(Some(0)), DataValue::Int32(Some(4))]), + DataValue::Struct(vec![DataValue::UInt64(Some(0)), DataValue::Int64(Some(4))]), + DataValue::Struct(vec![DataValue::UInt64(Some(0)), DataValue::UInt8(Some(4))]), + DataValue::Struct(vec![DataValue::UInt64(Some(0)), DataValue::UInt16(Some(4))]), + DataValue::Struct(vec![DataValue::UInt64(Some(0)), DataValue::UInt32(Some(4))]), + DataValue::Struct(vec![DataValue::UInt64(Some(0)), DataValue::UInt64(Some(4))]), + DataValue::Struct(vec![DataValue::UInt64(Some(0)), DataValue::Float32(Some(4.0))]), + DataValue::Struct(vec![DataValue::UInt64(Some(0)), DataValue::Float64(Some(4.0))]), + ], + error: vec![""] + }, + ArrayTest { + name: "argMin-passed", + args: vec![ + Arc::new(StringArray::from(vec!["x1", "x2"])), + Arc::new(Int8Array::from(vec![1, 2, 3, 4])), + Arc::new(Int16Array::from(vec![4, 3, 2, 1])), + Arc::new(Int32Array::from(vec![4, 3, 2, 1])), + Arc::new(Int64Array::from(vec![4, 3, 2, 1])), + Arc::new(UInt8Array::from(vec![4, 3, 2, 1])), + Arc::new(UInt16Array::from(vec![4, 3, 2, 1])), + Arc::new(UInt32Array::from(vec![4, 3, 2, 1])), + Arc::new(UInt64Array::from(vec![4, 3, 2, 1])), + Arc::new(Float32Array::from(vec![4.0, 3.0, 2.0, 1.0])), + Arc::new(Float64Array::from(vec![4.0, 3.0, 2.0, 1.0])), + ], + op: DataValueAggregateOperator::ArgMin, + expect: vec![ + DataValue::Struct(vec![DataValue::UInt64(Some(0)), DataValue::Utf8(Some("x1".to_string()))]), + DataValue::Struct(vec![DataValue::UInt64(Some(0)), DataValue::Int8(Some(1))]), + DataValue::Struct(vec![DataValue::UInt64(Some(3)), DataValue::Int16(Some(1))]), + DataValue::Struct(vec![DataValue::UInt64(Some(3)), DataValue::Int32(Some(1))]), + DataValue::Struct(vec![DataValue::UInt64(Some(3)), DataValue::Int64(Some(1))]), + DataValue::Struct(vec![DataValue::UInt64(Some(3)), DataValue::UInt8(Some(1))]), + DataValue::Struct(vec![DataValue::UInt64(Some(3)), DataValue::UInt16(Some(1))]), + DataValue::Struct(vec![DataValue::UInt64(Some(3)), DataValue::UInt32(Some(1))]), + DataValue::Struct(vec![DataValue::UInt64(Some(3)), DataValue::UInt64(Some(1))]), + DataValue::Struct(vec![DataValue::UInt64(Some(3)), DataValue::Float32(Some(1.0))]), + DataValue::Struct(vec![DataValue::UInt64(Some(3)), DataValue::Float64(Some(1.0))]), + ], + error: vec![""] + }, ]; for t in tests { diff --git a/common/datavalues/src/data_value_operator.rs b/common/datavalues/src/data_value_operator.rs index 95e0986ba7fcd..c68dba2532eec 100644 --- a/common/datavalues/src/data_value_operator.rs +++ b/common/datavalues/src/data_value_operator.rs @@ -8,7 +8,9 @@ pub enum DataValueAggregateOperator { Max, Sum, Avg, - Count + Count, + ArgMin, + ArgMax } impl std::fmt::Display for DataValueAggregateOperator { @@ -18,7 +20,9 @@ impl std::fmt::Display for DataValueAggregateOperator { DataValueAggregateOperator::Max => "max", DataValueAggregateOperator::Sum => "sum", DataValueAggregateOperator::Avg => "avg", - DataValueAggregateOperator::Count => "count" + DataValueAggregateOperator::Count => "count", + DataValueAggregateOperator::ArgMin => "argMin", + DataValueAggregateOperator::ArgMax => "ArgMax" }; write!(f, "{}", display) } diff --git a/common/datavalues/src/macros.rs b/common/datavalues/src/macros.rs index d413e9a6bad74..515f507a72d2b 100644 --- a/common/datavalues/src/macros.rs +++ b/common/datavalues/src/macros.rs @@ -241,6 +241,71 @@ macro_rules! typed_array_min_max_string_to_data_value { Result::Ok(DataValue::$SCALAR(value)) }}; } + +macro_rules! typed_array_values_min_max_to_data_value { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TYPE:ident, $OP:expr) => {{ + let array = downcast_array!($VALUES, $ARRAYTYPE)?; + let vals_std: &[$TYPE] = array.values(); + let mut min_max_row_val: (u64, $TYPE) = (0, vals_std[0]); + for (row, val) in vals_std.iter().enumerate() { + match $OP { + DataValueAggregateOperator::ArgMin => { + if *val < min_max_row_val.1 { + min_max_row_val = (row as u64, *val); + } + } + DataValueAggregateOperator::ArgMax => { + if *val > min_max_row_val.1 { + min_max_row_val = (row as u64, *val); + } + } + _ => { + panic!( + "Unexpected {} for macro typed_array_values_min_max_to_data_value", + stringify!($OP), + ) + } + } + } + Result::Ok(DataValue::Struct(vec![ + DataValue::UInt64(Some(min_max_row_val.0)), + DataValue::$SCALAR(Some(min_max_row_val.1)), + ])) + }}; +} + +macro_rules! typed_array_values_min_max_string_to_data_value { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:expr) => {{ + let array = downcast_array!($VALUES, $ARRAYTYPE)?; + let mut min_max_row_val: (u64, &str) = (0, array.value(0)); + for (row, val) in array.iter().enumerate() { + let str_val = val.unwrap(); + match $OP { + DataValueAggregateOperator::ArgMin => { + if str_val < min_max_row_val.1 { + min_max_row_val = (row as u64, str_val); + } + } + DataValueAggregateOperator::ArgMax => { + if str_val > min_max_row_val.1 { + min_max_row_val = (row as u64, str_val); + } + } + _ => { + panic!( + "Unexpected {} for macro typed_array_values_min_max_to_data_value", + stringify!($OP), + ) + } + } + } + Result::Ok(DataValue::Struct(vec![ + DataValue::UInt64(Some(min_max_row_val.0)), + DataValue::$SCALAR(Some(min_max_row_val.1.to_string())), + ])) + }}; +} + // returns the sum of two data values, including coercion into $TYPE. macro_rules! typed_data_value_add { ($OLD_VALUE:expr, $DELTA:expr, $SCALAR:ident, $TYPE:ident) => {{ diff --git a/tests/suites/0_stateless/01_0000_system_numbers.result b/tests/suites/0_stateless/01_0000_system_numbers.result index 4296803401030..943bebc4897ee 100644 --- a/tests/suites/0_stateless/01_0000_system_numbers.result +++ b/tests/suites/0_stateless/01_0000_system_numbers.result @@ -52,3 +52,57 @@ SELECT sum(number)/count(number) from numbers_mt(10000) +-------------------------------+ | 4999.5 | +-------------------------------+ +-------------- +SELECT argMin(number, number) from numbers_mt(10000) +-------------- + ++------------------------+ +| argMin(number, number) | ++------------------------+ +| 0 | ++------------------------+ +-------------- +SELECT argMin(a, b) from (select number + 5 as a, number - 5 as b from numbers_mt(10000) JOIN numbers_mt(10000)) +-------------- + ++--------------+ +| argMin(a, b) | ++--------------+ +| 5 | ++--------------+ +-------------- +SELECT argMin(b, a) from (select number + 5 as a, number - 5 as b from numbers_mt(10000) JOIN numbers_mt(10000)) +-------------- + ++--------------+ +| argMin(b, a) | ++--------------+ +| -5 | ++--------------+ +-------------- +SELECT argMax(number, number) from numbers_mt(10000) +-------------- + ++------------------------+ +| argMax(number, number) | ++------------------------+ +| 9999 | ++------------------------+ +-------------- +SELECT argMax(a, b) from (select number + 5 as a, number - 5 as b from numbers_mt(10000) JOIN numbers_mt(10000)) +-------------- + ++--------------+ +| argMax(a, b) | ++--------------+ +| 10004 | ++--------------+ +-------------- +SELECT argMax(b, a) from (select number + 5 as a, number - 5 as b from numbers_mt(10000) JOIN numbers_mt(10000)) +-------------- + ++--------------+ +| argMax(b, a) | ++--------------+ +| 9994 | ++--------------+ diff --git a/tests/suites/0_stateless/01_0000_system_numbers.sql b/tests/suites/0_stateless/01_0000_system_numbers.sql index f8afe3d6ef45e..d132e338b8f9b 100644 --- a/tests/suites/0_stateless/01_0000_system_numbers.sql +++ b/tests/suites/0_stateless/01_0000_system_numbers.sql @@ -4,3 +4,9 @@ SELECT max(number) from numbers_mt(10000); SELECT avg(number) from numbers_mt(10000); SELECT count(number) from numbers_mt(10000); SELECT sum(number)/count(number) from numbers_mt(10000); +SELECT argMin(number, number) from numbers_mt(10000); +SELECT argMin(a, b) from (select number + 5 as a, number - 5 as b from numbers_mt(10000) JOIN numbers_mt(10000)); +SELECT argMin(b, a) from (select number + 5 as a, number - 5 as b from numbers_mt(10000) JOIN numbers_mt(10000)); +SELECT argMax(number, number) from numbers_mt(10000); +SELECT argMax(a, b) from (select number + 5 as a, number - 5 as b from numbers_mt(10000) JOIN numbers_mt(10000)); +SELECT argMax(b, a) from (select number + 5 as a, number - 5 as b from numbers_mt(10000) JOIN numbers_mt(10000)); \ No newline at end of file