diff --git a/src/coprocessor/dag/aggr_fn/impl_sum.rs b/src/coprocessor/dag/aggr_fn/impl_sum.rs new file mode 100644 index 000000000000..6c22de319a4b --- /dev/null +++ b/src/coprocessor/dag/aggr_fn/impl_sum.rs @@ -0,0 +1,190 @@ +// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0. + +use cop_codegen::AggrFunction; +use cop_datatype::EvalType; +use tipb::expression::{Expr, ExprType, FieldType}; + +use super::summable::Summable; +use crate::coprocessor::codec::data_type::*; +use crate::coprocessor::codec::mysql::Tz; +use crate::coprocessor::dag::expr::EvalContext; +use crate::coprocessor::dag::rpn_expr::{RpnExpression, RpnExpressionBuilder}; +use crate::coprocessor::Result; + +/// The parser for the SUM aggregate function. +pub struct AggrFnDefinitionParserSum; + +impl super::parser::AggrDefinitionParser for AggrFnDefinitionParserSum { + fn check_supported(&self, aggr_def: &Expr) -> Result<()> { + assert_eq!(aggr_def.get_tp(), ExprType::Sum); + super::util::check_aggr_exp_supported_one_child(aggr_def) + } + + fn parse( + &self, + mut aggr_def: Expr, + time_zone: &Tz, + src_schema: &[FieldType], + out_schema: &mut Vec, + out_exp: &mut Vec, + ) -> Result> { + use cop_datatype::FieldTypeAccessor; + use std::convert::TryFrom; + + assert_eq!(aggr_def.get_tp(), ExprType::Sum); + + // SUM outputs one column. + out_schema.push(aggr_def.take_field_type()); + + // Rewrite expression, inserting CAST if necessary. See `typeInfer4Sum` in TiDB. + let child = aggr_def.take_children().into_iter().next().unwrap(); + let mut exp = + RpnExpressionBuilder::build_from_expr_tree(child, time_zone, src_schema.len())?; + // The rewrite should always succeed. + super::util::rewrite_exp_for_sum_avg(src_schema, &mut exp).unwrap(); + + let rewritten_eval_type = EvalType::try_from(exp.ret_field_type(src_schema).tp()).unwrap(); + out_exp.push(exp); + + // Choose a type-aware SUM implementation based on the eval type after rewriting exp. 
+ Ok(match rewritten_eval_type { + EvalType::Decimal => Box::new(AggrFnSum::::new()), + EvalType::Real => Box::new(AggrFnSum::::new()), + // If we meet unexpected types after rewriting, it is an implementation fault. + _ => unreachable!(), + }) + } +} + +/// The SUM aggregate function. +/// +/// Note that there are two variants: `SUM(Decimal) -> Decimal` and `SUM(Double) -> Double`. +#[derive(Debug, AggrFunction)] +#[aggr_function(state = AggrFnStateSum::::new())] +pub struct AggrFnSum +where + T: Summable, + VectorValue: VectorValueExt, +{ + _phantom: std::marker::PhantomData, +} + +impl AggrFnSum +where + T: Summable, + VectorValue: VectorValueExt, +{ + pub fn new() -> Self { + Self { + _phantom: std::marker::PhantomData, + } + } +} + +/// The state of the SUM aggregate function. +#[derive(Debug)] +pub struct AggrFnStateSum +where + T: Summable, + VectorValue: VectorValueExt, +{ + sum: T, + has_value: bool, +} + +impl AggrFnStateSum +where + T: Summable, + VectorValue: VectorValueExt, +{ + pub fn new() -> Self { + Self { + sum: T::zero(), + has_value: false, + } + } +} + +impl super::ConcreteAggrFunctionState for AggrFnStateSum +where + T: Summable, + VectorValue: VectorValueExt, +{ + type ParameterType = T; + + #[inline] + fn update_concrete(&mut self, ctx: &mut EvalContext, value: &Option) -> Result<()> { + match value { + None => Ok(()), + Some(value) => { + self.sum.add_assign(ctx, value)?; + self.has_value = true; + Ok(()) + } + } + } + + #[inline] + fn push_result(&self, _ctx: &mut EvalContext, target: &mut [VectorValue]) -> Result<()> { + if !self.has_value { + target[0].push(None); + } else { + target[0].push(Some(self.sum.clone())); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use cop_datatype::{FieldTypeAccessor, FieldTypeTp}; + use tipb_helper::ExprDefBuilder; + + use crate::coprocessor::codec::batch::{LazyBatchColumn, LazyBatchColumnVec}; + use crate::coprocessor::dag::aggr_fn::parser::AggrDefinitionParser; + + /// SUM(Bytes) should produce 
(Real). + #[test] + fn test_integration() { + let expr = ExprDefBuilder::aggr_func(ExprType::Sum, FieldTypeTp::Double) + .push_child(ExprDefBuilder::column_ref(0, FieldTypeTp::VarString)) + .build(); + AggrFnDefinitionParserSum.check_supported(&expr).unwrap(); + + let src_schema = [FieldTypeTp::VarString.into()]; + let mut columns = LazyBatchColumnVec::from(vec![{ + let mut col = LazyBatchColumn::decoded_with_capacity_and_tp(0, EvalType::Bytes); + col.mut_decoded().push_bytes(Some(b"12.5".to_vec())); + col.mut_decoded().push_bytes(None); + col.mut_decoded().push_bytes(Some(b"42.0".to_vec())); + col.mut_decoded().push_bytes(None); + col + }]); + + let mut schema = vec![]; + let mut exp = vec![]; + + let aggr_fn = AggrFnDefinitionParserSum + .parse(expr, &Tz::utc(), &src_schema, &mut schema, &mut exp) + .unwrap(); + assert_eq!(schema.len(), 1); + assert_eq!(schema[0].tp(), FieldTypeTp::Double); + + assert_eq!(exp.len(), 1); + + let mut state = aggr_fn.create_state(); + let mut ctx = EvalContext::default(); + + let exp_result = exp[0].eval(&mut ctx, 4, &src_schema, &mut columns).unwrap(); + assert!(exp_result.is_vector()); + let slice: &[Option] = exp_result.vector_value().unwrap().as_ref(); + state.update_vector(&mut ctx, slice).unwrap(); + + let mut aggr_result = [VectorValue::with_capacity(0, EvalType::Real)]; + state.push_result(&mut ctx, &mut aggr_result).unwrap(); + + assert_eq!(aggr_result[0].as_real_slice(), &[Real::new(54.5).ok()]); + } +} diff --git a/src/coprocessor/dag/aggr_fn/mod.rs b/src/coprocessor/dag/aggr_fn/mod.rs index cf5326a9513b..afe1b73db5ed 100644 --- a/src/coprocessor/dag/aggr_fn/mod.rs +++ b/src/coprocessor/dag/aggr_fn/mod.rs @@ -5,6 +5,7 @@ mod impl_avg; mod impl_count; mod impl_first; +mod impl_sum; mod parser; mod summable; mod util; diff --git a/src/coprocessor/dag/aggr_fn/parser.rs b/src/coprocessor/dag/aggr_fn/parser.rs index ec6f25583b53..0a7fcd1dcada 100644 --- a/src/coprocessor/dag/aggr_fn/parser.rs +++ 
b/src/coprocessor/dag/aggr_fn/parser.rs @@ -43,6 +43,7 @@ pub trait AggrDefinitionParser { fn map_pb_sig_to_aggr_func_parser(value: ExprType) -> Result> { match value { ExprType::Count => Ok(Box::new(super::impl_count::AggrFnDefinitionParserCount)), + ExprType::Sum => Ok(Box::new(super::impl_sum::AggrFnDefinitionParserSum)), ExprType::Avg => Ok(Box::new(super::impl_avg::AggrFnDefinitionParserAvg)), ExprType::First => Ok(Box::new(super::impl_first::AggrFnDefinitionParserFirst)), v => Err(box_err!(