Skip to content

Commit

Permalink
Coprocessor: Support SUM() (tikv#4797)
Browse files Browse the repository at this point in the history
Signed-off-by: Breezewish <breezewish@pingcap.com>
  • Loading branch information
breezewish committed Jun 12, 2019
1 parent 6d6e2f8 commit 6cad183
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 0 deletions.
190 changes: 190 additions & 0 deletions src/coprocessor/dag/aggr_fn/impl_sum.rs
@@ -0,0 +1,190 @@
// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.

use cop_codegen::AggrFunction;
use cop_datatype::EvalType;
use tipb::expression::{Expr, ExprType, FieldType};

use super::summable::Summable;
use crate::coprocessor::codec::data_type::*;
use crate::coprocessor::codec::mysql::Tz;
use crate::coprocessor::dag::expr::EvalContext;
use crate::coprocessor::dag::rpn_expr::{RpnExpression, RpnExpressionBuilder};
use crate::coprocessor::Result;

/// The parser for SUM aggregate function.
pub struct AggrFnDefinitionParserSum;

impl super::parser::AggrDefinitionParser for AggrFnDefinitionParserSum {
fn check_supported(&self, aggr_def: &Expr) -> Result<()> {
assert_eq!(aggr_def.get_tp(), ExprType::Sum);
super::util::check_aggr_exp_supported_one_child(aggr_def)
}

fn parse(
&self,
mut aggr_def: Expr,
time_zone: &Tz,
src_schema: &[FieldType],
out_schema: &mut Vec<FieldType>,
out_exp: &mut Vec<RpnExpression>,
) -> Result<Box<dyn super::AggrFunction>> {
use cop_datatype::FieldTypeAccessor;
use std::convert::TryFrom;

assert_eq!(aggr_def.get_tp(), ExprType::Sum);

// SUM outputs one column.
out_schema.push(aggr_def.take_field_type());

// Rewrite expression, inserting CAST if necessary. See `typeInfer4Sum` in TiDB.
let child = aggr_def.take_children().into_iter().next().unwrap();
let mut exp =
RpnExpressionBuilder::build_from_expr_tree(child, time_zone, src_schema.len())?;
// The rewrite should always success.
super::util::rewrite_exp_for_sum_avg(src_schema, &mut exp).unwrap();

let rewritten_eval_type = EvalType::try_from(exp.ret_field_type(src_schema).tp()).unwrap();
out_exp.push(exp);

// Choose a type-aware SUM implementation based on the eval type after rewriting exp.
Ok(match rewritten_eval_type {
EvalType::Decimal => Box::new(AggrFnSum::<Decimal>::new()),
EvalType::Real => Box::new(AggrFnSum::<Real>::new()),
// If we meet unexpected types after rewriting, it is an implementation fault.
_ => unreachable!(),
})
}
}

/// The SUM aggregate function.
///
/// Note that there are `SUM(Decimal) -> Decimal` and `SUM(Double) -> Double`.
#[derive(Debug, AggrFunction)]
#[aggr_function(state = AggrFnStateSum::<T>::new())]
pub struct AggrFnSum<T>
where
T: Summable,
VectorValue: VectorValueExt<T>,
{
_phantom: std::marker::PhantomData<T>,
}

impl<T> AggrFnSum<T>
where
T: Summable,
VectorValue: VectorValueExt<T>,
{
pub fn new() -> Self {
Self {
_phantom: std::marker::PhantomData,
}
}
}

/// The state of the SUM aggregate function.
#[derive(Debug)]
pub struct AggrFnStateSum<T>
where
T: Summable,
VectorValue: VectorValueExt<T>,
{
sum: T,
has_value: bool,
}

impl<T> AggrFnStateSum<T>
where
T: Summable,
VectorValue: VectorValueExt<T>,
{
pub fn new() -> Self {
Self {
sum: T::zero(),
has_value: false,
}
}
}

impl<T> super::ConcreteAggrFunctionState for AggrFnStateSum<T>
where
T: Summable,
VectorValue: VectorValueExt<T>,
{
type ParameterType = T;

#[inline]
fn update_concrete(&mut self, ctx: &mut EvalContext, value: &Option<T>) -> Result<()> {
match value {
None => Ok(()),
Some(value) => {
self.sum.add_assign(ctx, value)?;
self.has_value = true;
Ok(())
}
}
}

#[inline]
fn push_result(&self, _ctx: &mut EvalContext, target: &mut [VectorValue]) -> Result<()> {
if !self.has_value {
target[0].push(None);
} else {
target[0].push(Some(self.sum.clone()));
}
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;

use cop_datatype::{FieldTypeAccessor, FieldTypeTp};
use tipb_helper::ExprDefBuilder;

use crate::coprocessor::codec::batch::{LazyBatchColumn, LazyBatchColumnVec};
use crate::coprocessor::dag::aggr_fn::parser::AggrDefinitionParser;

/// SUM(Bytes) should produce (Real).
#[test]
fn test_integration() {
let expr = ExprDefBuilder::aggr_func(ExprType::Sum, FieldTypeTp::Double)
.push_child(ExprDefBuilder::column_ref(0, FieldTypeTp::VarString))
.build();
AggrFnDefinitionParserSum.check_supported(&expr).unwrap();

let src_schema = [FieldTypeTp::VarString.into()];
let mut columns = LazyBatchColumnVec::from(vec![{
let mut col = LazyBatchColumn::decoded_with_capacity_and_tp(0, EvalType::Bytes);
col.mut_decoded().push_bytes(Some(b"12.5".to_vec()));
col.mut_decoded().push_bytes(None);
col.mut_decoded().push_bytes(Some(b"42.0".to_vec()));
col.mut_decoded().push_bytes(None);
col
}]);

let mut schema = vec![];
let mut exp = vec![];

let aggr_fn = AggrFnDefinitionParserSum
.parse(expr, &Tz::utc(), &src_schema, &mut schema, &mut exp)
.unwrap();
assert_eq!(schema.len(), 1);
assert_eq!(schema[0].tp(), FieldTypeTp::Double);

assert_eq!(exp.len(), 1);

let mut state = aggr_fn.create_state();
let mut ctx = EvalContext::default();

let exp_result = exp[0].eval(&mut ctx, 4, &src_schema, &mut columns).unwrap();
assert!(exp_result.is_vector());
let slice: &[Option<Real>] = exp_result.vector_value().unwrap().as_ref();
state.update_vector(&mut ctx, slice).unwrap();

let mut aggr_result = [VectorValue::with_capacity(0, EvalType::Real)];
state.push_result(&mut ctx, &mut aggr_result).unwrap();

assert_eq!(aggr_result[0].as_real_slice(), &[Real::new(54.5).ok()]);
}
}
1 change: 1 addition & 0 deletions src/coprocessor/dag/aggr_fn/mod.rs
Expand Up @@ -5,6 +5,7 @@
mod impl_avg;
mod impl_count;
mod impl_first;
mod impl_sum;
mod parser;
mod summable;
mod util;
Expand Down
1 change: 1 addition & 0 deletions src/coprocessor/dag/aggr_fn/parser.rs
Expand Up @@ -43,6 +43,7 @@ pub trait AggrDefinitionParser {
fn map_pb_sig_to_aggr_func_parser(value: ExprType) -> Result<Box<dyn AggrDefinitionParser>> {
match value {
ExprType::Count => Ok(Box::new(super::impl_count::AggrFnDefinitionParserCount)),
ExprType::Sum => Ok(Box::new(super::impl_sum::AggrFnDefinitionParserSum)),
ExprType::Avg => Ok(Box::new(super::impl_avg::AggrFnDefinitionParserAvg)),
ExprType::First => Ok(Box::new(super::impl_first::AggrFnDefinitionParserFirst)),
v => Err(box_err!(
Expand Down

0 comments on commit 6cad183

Please sign in to comment.