Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions common/aggregate_functions/src/aggregate_arg_max.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright 2020-2021 The Datafuse Authors.
//
// SPDX-License-Identifier: Apache-2.0.

use std::convert::TryInto;
use std::fmt;

use common_datavalues::DataArrayAggregate;
use common_datavalues::DataColumnarValue;
use common_datavalues::DataSchema;
use common_datavalues::DataType;
use common_datavalues::DataValue;
use common_datavalues::DataValueAggregate;
use common_datavalues::DataValueAggregateOperator;
use common_exception::Result;

use crate::IAggregateFunction;

#[derive(Clone)]
pub struct AggregateArgMaxFunction {
display_name: String,
depth: usize,
state: DataValue
}

impl AggregateArgMaxFunction {
pub fn try_create(display_name: &str) -> Result<Box<dyn IAggregateFunction>> {
Ok(Box::new(AggregateArgMaxFunction {
display_name: display_name.to_string(),
depth: 0,
state: DataValue::Struct(vec![DataValue::Null, DataValue::Null])
}))
}
}

impl IAggregateFunction for AggregateArgMaxFunction {
fn name(&self) -> &str {
"AggregateArgMaxFunction"
}

fn return_type(&self, args: &[DataType]) -> Result<DataType> {
Ok(args[0].clone())
}

fn nullable(&self, _input_schema: &DataSchema) -> Result<bool> {
Ok(false)
}

fn set_depth(&mut self, depth: usize) {
self.depth = depth;
}

fn accumulate(&mut self, columns: &[DataColumnarValue], _input_rows: usize) -> Result<()> {
if let DataValue::Struct(max_arg_val) = DataArrayAggregate::data_array_aggregate_op(
DataValueAggregateOperator::ArgMax,
columns[1].to_array()?
)? {
let index: u64 = max_arg_val[0].clone().try_into()?;
let max_arg = DataValue::try_from_array(&columns[0].to_array()?, index as usize)?;
let max_val = max_arg_val[1].clone();

if let DataValue::Struct(old_max_arg_val) = self.state.clone() {
let old_max_arg = old_max_arg_val[0].clone();
let old_max_val = old_max_arg_val[1].clone();
let new_max_val = DataValueAggregate::data_value_aggregate_op(
DataValueAggregateOperator::Max,
old_max_val.clone(),
max_val
)?;
self.state = DataValue::Struct(vec![
if new_max_val == old_max_val {
old_max_arg
} else {
max_arg
},
new_max_val,
]);
}
}
Ok(())
}

fn accumulate_result(&self) -> Result<Vec<DataValue>> {
Ok(vec![self.state.clone()])
}

fn merge(&mut self, states: &[DataValue]) -> Result<()> {
let arg_val = states[self.depth].clone();
if let (DataValue::Struct(new_states), DataValue::Struct(old_states)) =
(arg_val, self.state.clone())
{
let new_max_val = DataValueAggregate::data_value_aggregate_op(
DataValueAggregateOperator::Max,
new_states[1].clone(),
old_states[1].clone()
)?;
self.state = DataValue::Struct(vec![
if new_max_val == old_states[1] {
old_states[0].clone()
} else {
new_states[0].clone()
},
new_max_val,
]);
}
Ok(())
}

fn merge_result(&self) -> Result<DataValue> {
Ok(if let DataValue::Struct(state) = self.state.clone() {
state[0].clone()
} else {
self.state.clone()
})
}
}

impl fmt::Display for AggregateArgMaxFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.display_name)
}
}
122 changes: 122 additions & 0 deletions common/aggregate_functions/src/aggregate_arg_min.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright 2020-2021 The Datafuse Authors.
//
// SPDX-License-Identifier: Apache-2.0.

use std::convert::TryInto;
use std::fmt;

use common_datavalues::DataArrayAggregate;
use common_datavalues::DataColumnarValue;
use common_datavalues::DataSchema;
use common_datavalues::DataType;
use common_datavalues::DataValue;
use common_datavalues::DataValueAggregate;
use common_datavalues::DataValueAggregateOperator;
use common_exception::Result;

use crate::IAggregateFunction;

#[derive(Clone)]
pub struct AggregateArgMinFunction {
display_name: String,
depth: usize,
state: DataValue
}

impl AggregateArgMinFunction {
pub fn try_create(display_name: &str) -> Result<Box<dyn IAggregateFunction>> {
Ok(Box::new(AggregateArgMinFunction {
display_name: display_name.to_string(),
depth: 0,
state: DataValue::Struct(vec![DataValue::Null, DataValue::Null])
}))
}
}

impl IAggregateFunction for AggregateArgMinFunction {
fn name(&self) -> &str {
"AggregateArgMinFunction"
}

fn return_type(&self, args: &[DataType]) -> Result<DataType> {
Ok(args[0].clone())
}

fn nullable(&self, _input_schema: &DataSchema) -> Result<bool> {
Ok(false)
}

fn set_depth(&mut self, depth: usize) {
self.depth = depth;
}

fn accumulate(&mut self, columns: &[DataColumnarValue], _input_rows: usize) -> Result<()> {
if let DataValue::Struct(min_arg_val) = DataArrayAggregate::data_array_aggregate_op(
DataValueAggregateOperator::ArgMin,
columns[1].to_array()?
)? {
let index: u64 = min_arg_val[0].clone().try_into()?;
let min_arg = DataValue::try_from_array(&columns[0].to_array()?, index as usize)?;
let min_val = min_arg_val[1].clone();

if let DataValue::Struct(old_min_arg_val) = self.state.clone() {
let old_min_arg = old_min_arg_val[0].clone();
let old_min_val = old_min_arg_val[1].clone();
let new_min_val = DataValueAggregate::data_value_aggregate_op(
DataValueAggregateOperator::Min,
old_min_val.clone(),
min_val
)?;
self.state = DataValue::Struct(vec![
if new_min_val == old_min_val {
old_min_arg
} else {
min_arg
},
new_min_val,
]);
}
}
Ok(())
}

fn accumulate_result(&self) -> Result<Vec<DataValue>> {
Ok(vec![self.state.clone()])
}

fn merge(&mut self, states: &[DataValue]) -> Result<()> {
let arg_val = states[self.depth].clone();
if let (DataValue::Struct(new_states), DataValue::Struct(old_states)) =
(arg_val, self.state.clone())
{
let new_min_val = DataValueAggregate::data_value_aggregate_op(
DataValueAggregateOperator::Min,
new_states[1].clone(),
old_states[1].clone()
)?;
self.state = DataValue::Struct(vec![
if new_min_val == old_states[1] {
old_states[0].clone()
} else {
new_states[0].clone()
},
new_min_val,
]);
}
Ok(())
}

fn merge_result(&self) -> Result<DataValue> {
Ok(if let DataValue::Struct(state) = self.state.clone() {
state[0].clone()
} else {
self.state.clone()
})
}
}

impl fmt::Display for AggregateArgMinFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.display_name)
}
}
5 changes: 5 additions & 0 deletions common/aggregate_functions/src/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
use common_exception::Result;

use crate::aggregate_function_factory::FactoryFuncRef;
use crate::AggregateArgMaxFunction;
use crate::AggregateArgMinFunction;
use crate::AggregateAvgFunction;
use crate::AggregateCountFunction;
use crate::AggregateMaxFunction;
Expand All @@ -16,11 +18,14 @@ pub struct AggregatorFunction;
impl AggregatorFunction {
pub fn register(map: FactoryFuncRef) -> Result<()> {
let mut map = map.write();
// FuseQuery always uses lowercase function names to get functions.
map.insert("count", AggregateCountFunction::try_create);
map.insert("min", AggregateMinFunction::try_create);
map.insert("max", AggregateMaxFunction::try_create);
map.insert("sum", AggregateSumFunction::try_create);
map.insert("avg", AggregateAvgFunction::try_create);
map.insert("argmin", AggregateArgMinFunction::try_create);
map.insert("argmax", AggregateArgMaxFunction::try_create);
Ok(())
}
}
22 changes: 22 additions & 0 deletions common/aggregate_functions/src/aggregator_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,28 @@ fn test_aggregate_function() -> Result<()> {
expect: DataValue::Int64(Some(10)),
error: ""
},
Test {
name: "argMax-passed",
eval_nums: 1,
types: vec![DataType::Int64, DataType::Int64],
display: "argmax",
nullable: false,
func: AggregateArgMaxFunction::try_create("argmax")?,
columns: columns.clone(),
expect: DataValue::Int64(Some(1)),
error: ""
},
Test {
name: "argMin-passed",
eval_nums: 1,
types: vec![DataType::Int64, DataType::Int64],
display: "argmin",
nullable: false,
func: AggregateArgMinFunction::try_create("argmin")?,
columns: columns.clone(),
expect: DataValue::Int64(Some(4)),
error: ""
},
];

for t in tests {
Expand Down
4 changes: 4 additions & 0 deletions common/aggregate_functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#[cfg(test)]
mod aggregator_test;

mod aggregate_arg_max;
mod aggregate_arg_min;
mod aggregate_avg;
mod aggregate_count;
mod aggregate_function;
Expand All @@ -14,6 +16,8 @@ mod aggregate_min;
mod aggregate_sum;
mod aggregator;

pub use aggregate_arg_max::AggregateArgMaxFunction;
pub use aggregate_arg_min::AggregateArgMinFunction;
pub use aggregate_avg::AggregateAvgFunction;
pub use aggregate_count::AggregateCountFunction;
pub use aggregate_function::IAggregateFunction;
Expand Down
Loading