diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_combinator_distinct.rs b/src/query/functions/src/aggregates/adaptors/aggregate_combinator_distinct.rs index df51adc0a6655..68b973c615013 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_combinator_distinct.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_combinator_distinct.rs @@ -36,27 +36,43 @@ use super::aggregate_distinct_state::AggregateDistinctState; use super::aggregate_distinct_state::AggregateDistinctStringState; use super::aggregate_distinct_state::AggregateUniqStringState; use super::aggregate_distinct_state::DistinctStateFunc; +use super::aggregate_null_result::AggregateNullResultFunction; use super::assert_variadic_arguments; use super::AggrState; use super::AggrStateLoc; use super::AggregateCountFunction; use super::AggregateFunction; +use super::AggregateFunctionCombinatorNull; use super::AggregateFunctionCreator; use super::AggregateFunctionDescription; +use super::AggregateFunctionFeatures; use super::AggregateFunctionSortDesc; use super::CombinatorDescription; use super::StateAddr; -#[derive(Clone)] pub struct AggregateDistinctCombinator { name: String, nested_name: String, arguments: Vec, + skip_null: bool, nested: Arc, _s: PhantomData, } +impl Clone for AggregateDistinctCombinator { + fn clone(&self) -> Self { + Self { + name: self.name.clone(), + nested_name: self.nested_name.clone(), + arguments: self.arguments.clone(), + skip_null: self.skip_null, + nested: self.nested.clone(), + _s: PhantomData, + } + } +} + impl AggregateDistinctCombinator where State: Send + 'static { @@ -104,12 +120,12 @@ where State: DistinctStateFunc input_rows: usize, ) -> Result<()> { let state = Self::get_state(place); - state.batch_add(columns, validity, input_rows) + state.batch_add(columns, validity, input_rows, self.skip_null) } fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()> { let state = Self::get_state(place); - state.add(columns, row) + state.add(columns, row, self.skip_null) } fn serialize_type(&self) -> Vec { @@ -202,32 +218,63 @@ pub fn aggregate_combinator_distinct_desc() -> CombinatorDescription { CombinatorDescription::creator(Box::new(try_create)) } -pub fn aggregate_combinator_uniq_desc() -> AggregateFunctionDescription { - let features = super::AggregateFunctionFeatures { +pub fn aggregate_uniq_desc() -> AggregateFunctionDescription { + let features = AggregateFunctionFeatures { returns_default_when_only_null: true, ..Default::default() }; - AggregateFunctionDescription::creator_with_features(Box::new(try_create_uniq), features) + AggregateFunctionDescription::creator_with_features( + Box::new(|nested_name, params, arguments, sort_descs| { + let creator = Box::new(AggregateCountFunction::try_create) as _; + try_create(nested_name, params, arguments, sort_descs, &creator) + }), + features, + ) } -pub fn try_create_uniq( - nested_name: &str, - params: Vec, - arguments: Vec, - sort_descs: Vec, -) -> Result> { - let creator: AggregateFunctionCreator = Box::new(AggregateCountFunction::try_create); - try_create(nested_name, params, arguments, sort_descs, &creator) +pub fn aggregate_count_distinct_desc() -> AggregateFunctionDescription { + AggregateFunctionDescription::creator_with_features( + Box::new(|_, params, arguments, _| { + let count_creator = Box::new(AggregateCountFunction::try_create) as _; + match *arguments { + [DataType::Nullable(_)] => { + let new_arguments = + AggregateFunctionCombinatorNull::transform_arguments(&arguments)?; + let nested = try_create( + "count", + params.clone(), + new_arguments, + vec![], + &count_creator, + )?; + AggregateFunctionCombinatorNull::try_create(params, arguments, nested, true) + } + ref arguments + if !arguments.is_empty() && arguments.iter().all(DataType::is_null) => + { + AggregateNullResultFunction::try_create(DataType::Number( + NumberDataType::UInt64, + )) + } + _ => try_create("count", params, arguments, vec![], &count_creator), + } + }), + AggregateFunctionFeatures { + returns_default_when_only_null: true, + keep_nullable: true, + ..Default::default() + }, + ) } -pub fn try_create( +fn try_create( nested_name: &str, params: Vec, arguments: Vec, sort_descs: Vec, nested_creator: &AggregateFunctionCreator, ) -> Result> { - let name = format!("DistinctCombinator({})", nested_name); + let name = format!("DistinctCombinator({nested_name})"); assert_variadic_arguments(&name, arguments.len(), (1, 32))?; let nested_arguments = match nested_name { @@ -236,53 +283,54 @@ pub fn try_create( }; let nested = nested_creator(nested_name, params, nested_arguments, sort_descs)?; - if arguments.len() == 1 { - match &arguments[0] { - DataType::Number(ty) => with_number_mapped_type!(|NUM_TYPE| match ty { - NumberDataType::NUM_TYPE => { - return Ok(Arc::new(AggregateDistinctCombinator::< - AggregateDistinctNumberState, - > { - nested_name: nested_name.to_owned(), - arguments, - nested, - name, - _s: PhantomData, - })); - } - }), - DataType::String => { - return match nested_name { - "count" | "uniq" => Ok(Arc::new(AggregateDistinctCombinator::< - AggregateUniqStringState, - > { - name, - arguments, - nested, - nested_name: nested_name.to_owned(), - _s: PhantomData, - })), - _ => Ok(Arc::new(AggregateDistinctCombinator::< - AggregateDistinctStringState, - > { - nested_name: nested_name.to_owned(), - arguments, - nested, - name, - _s: PhantomData, - })), - }; + match *arguments { + [DataType::Number(ty)] => with_number_mapped_type!(|NUM_TYPE| match ty { + NumberDataType::NUM_TYPE => { + Ok(Arc::new(AggregateDistinctCombinator::< + AggregateDistinctNumberState, + > { + nested_name: nested_name.to_owned(), + arguments, + skip_null: false, + nested, + name, + _s: PhantomData, + })) } - _ => {} + }), + [DataType::String] if matches!(nested_name, "count" | "uniq") => { + Ok(Arc::new(AggregateDistinctCombinator::< + AggregateUniqStringState, + > { + name, + arguments, + skip_null: false, + nested, + nested_name: nested_name.to_owned(), + _s: PhantomData, + })) } + [DataType::String] => Ok(Arc::new(AggregateDistinctCombinator::< + AggregateDistinctStringState, + > { + nested_name: nested_name.to_owned(), + arguments, + skip_null: false, + nested, + name, + _s: PhantomData, + })), + _ => Ok(Arc::new(AggregateDistinctCombinator::< + AggregateDistinctState, + > { + nested_name: nested_name.to_owned(), + skip_null: nested_name == "count" + && arguments.len() > 1 + && arguments.iter().any(DataType::is_nullable_or_null), + arguments, + nested, + name, + _s: PhantomData, + })), } - Ok(Arc::new(AggregateDistinctCombinator::< - AggregateDistinctState, - > { - nested_name: nested_name.to_owned(), - arguments, - nested, - name, - _s: PhantomData, - })) } diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs index 302bcd519f142..0679e13df0769 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs @@ -32,7 +32,6 @@ use super::AggrStateLoc; use super::AggrStateRegistry; use super::AggrStateType; use super::AggregateFunction; -use super::AggregateFunctionFeatures; use super::AggregateFunctionRef; use super::AggregateNullResultFunction; use super::StateAddr; @@ -57,20 +56,15 @@ impl AggregateFunctionCombinatorNull { Ok(results) } - pub fn transform_params(params: &[Scalar]) -> Result> { - Ok(params.to_owned()) - } - pub fn try_create( - _name: &str, params: Vec, arguments: Vec, nested: AggregateFunctionRef, - properties: AggregateFunctionFeatures, + returns_default_when_only_null: bool, ) -> Result { // has_null_types if arguments.iter().any(|f| f == &DataType::Null) { - if properties.returns_default_when_only_null { + if returns_default_when_only_null { return AggregateNullResultFunction::try_create(DataType::Number( NumberDataType::UInt64, )); @@ -78,7 +72,6 @@ impl AggregateFunctionCombinatorNull { return AggregateNullResultFunction::try_create(DataType::Null); } } - let params = Self::transform_params(¶ms)?; let arguments = Self::transform_arguments(&arguments)?; let size = arguments.len(); @@ -90,8 +83,7 @@ impl AggregateFunctionCombinatorNull { } let return_type = nested.return_type()?; - let result_is_null = - !properties.returns_default_when_only_null && return_type.can_inside_nullable(); + let result_is_null = !returns_default_when_only_null && return_type.can_inside_nullable(); match size { 1 => match result_is_null { diff --git a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs index 7035ba52cffc9..febab59f8f708 100644 --- a/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs +++ b/src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs @@ -30,7 +30,6 @@ use databend_common_expression::StateSerdeItem; use super::AggrState; use super::AggrStateLoc; use super::AggregateFunction; -use super::AggregateFunctionFeatures; use super::AggregateFunctionRef; use super::StateAddr; @@ -44,13 +43,10 @@ pub struct AggregateFunctionOrNullAdaptor { } impl AggregateFunctionOrNullAdaptor { - pub fn create( - nested: AggregateFunctionRef, - features: AggregateFunctionFeatures, - ) -> Result { + pub fn create(nested: AggregateFunctionRef) -> Result { // count/count distinct should not be nullable for empty set, just return zero let inner_return_type = nested.return_type()?; - if features.returns_default_when_only_null || inner_return_type == DataType::Null { + if inner_return_type == DataType::Null { return Ok(nested); } diff --git a/src/query/functions/src/aggregates/aggregate_array_agg.rs b/src/query/functions/src/aggregates/aggregate_array_agg.rs index ad37f011aea1c..c573b2ea0604f 100644 --- a/src/query/functions/src/aggregates/aggregate_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_array_agg.rs @@ -62,6 +62,7 @@ use super::AggrState; use super::AggrStateLoc; use super::AggregateFunction; use super::AggregateFunctionDescription; +use super::AggregateFunctionFeatures; use super::AggregateFunctionSortDesc; use super::SerializeInfo; use super::StateAddr; @@ -809,5 +810,12 @@ fn try_create_aggregate_array_agg_function( } pub fn aggregate_array_agg_function_desc() -> AggregateFunctionDescription { - AggregateFunctionDescription::creator(Box::new(try_create_aggregate_array_agg_function)) + AggregateFunctionDescription::creator_with_features( + Box::new(try_create_aggregate_array_agg_function), + AggregateFunctionFeatures { + allow_sort: true, + keep_nullable: true, + ..Default::default() + }, + ) } diff --git a/src/query/functions/src/aggregates/aggregate_array_moving.rs b/src/query/functions/src/aggregates/aggregate_array_moving.rs index 2509aced68de6..5481f76aa5bb4 100644 --- a/src/query/functions/src/aggregates/aggregate_array_moving.rs +++ b/src/query/functions/src/aggregates/aggregate_array_moving.rs @@ -44,6 +44,7 @@ use super::AggrState; use super::AggrStateLoc; use super::AggregateFunction; use super::AggregateFunctionDescription; +use super::AggregateFunctionFeatures; use super::AggregateFunctionRef; use super::AggregateFunctionSortDesc; use super::SerializeInfo; @@ -678,7 +679,13 @@ pub fn try_create_aggregate_array_moving_avg_function( } pub fn aggregate_array_moving_avg_function_desc() -> AggregateFunctionDescription { - AggregateFunctionDescription::creator(Box::new(try_create_aggregate_array_moving_avg_function)) + AggregateFunctionDescription::creator_with_features( + Box::new(try_create_aggregate_array_moving_avg_function), + AggregateFunctionFeatures { + keep_nullable: true, + ..Default::default() + }, + ) } #[derive(Clone)] @@ -859,5 +866,11 @@ pub fn try_create_aggregate_array_moving_sum_function( } pub fn aggregate_array_moving_sum_function_desc() -> AggregateFunctionDescription { - AggregateFunctionDescription::creator(Box::new(try_create_aggregate_array_moving_sum_function)) + AggregateFunctionDescription::creator_with_features( + Box::new(try_create_aggregate_array_moving_sum_function), + AggregateFunctionFeatures { + keep_nullable: true, + ..Default::default() + }, + ) } diff --git a/src/query/functions/src/aggregates/aggregate_distinct_state.rs b/src/query/functions/src/aggregates/aggregate_distinct_state.rs index 3500952ad19ac..116f30cd98be6 100644 --- a/src/query/functions/src/aggregates/aggregate_distinct_state.rs +++ b/src/query/functions/src/aggregates/aggregate_distinct_state.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::hash_map::RandomState; use std::collections::HashSet; use std::hash::Hasher; use std::marker::Send; @@ -51,19 +50,20 @@ pub(super) trait DistinctStateFunc: Sized + Send + StateSerde + 'static { fn new() -> Self; fn is_empty(&self) -> bool; fn len(&self) -> usize; - fn add(&mut self, columns: ProjectedBlock, row: usize) -> Result<()>; + fn add(&mut self, columns: ProjectedBlock, row: usize, skip_null: bool) -> Result<()>; fn batch_add( &mut self, columns: ProjectedBlock, validity: Option<&Bitmap>, input_rows: usize, + skip_null: bool, ) -> Result<()>; fn merge(&mut self, rhs: &Self) -> Result<()>; fn build_entries(&mut self, types: &[DataType]) -> Result>; } pub struct AggregateDistinctState { - set: HashSet, RandomState>, + set: HashSet>, } impl DistinctStateFunc for AggregateDistinctState { @@ -81,16 +81,16 @@ impl DistinctStateFunc for AggregateDistinctState { self.set.len() } - fn add(&mut self, columns: ProjectedBlock, row: usize) -> Result<()> { + fn add(&mut self, columns: ProjectedBlock, row: usize, skip_null: bool) -> Result<()> { let values = columns .iter() - .map(|entry| match entry { - BlockEntry::Const(scalar, _, _) => scalar.clone(), - BlockEntry::Column(column) => { - unsafe { AnyType::index_column_unchecked(column, row) }.to_owned() - } - }) + .map(|entry| unsafe { entry.index_unchecked(row) }.to_owned()) .collect::>(); + + if skip_null && values.iter().all(Scalar::is_null) { + return Ok(()); + } + let mut buffer = Vec::with_capacity(values.len() * std::mem::size_of::()); values.serialize(&mut buffer)?; self.set.insert(buffer); @@ -102,26 +102,26 @@ impl DistinctStateFunc for AggregateDistinctState { columns: ProjectedBlock, validity: Option<&Bitmap>, input_rows: usize, + skip_null: bool, ) -> Result<()> { - for row in 0..input_rows { - if validity.map(|v| v.get_bit(row)).unwrap_or(true) { - let values = columns - .iter() - .map(|entry| match entry { - BlockEntry::Const(scalar, _, _) => scalar.clone(), - BlockEntry::Column(column) => { - unsafe { AnyType::index_column_unchecked(column, row) }.to_owned() - } - }) - .collect::>(); - - let mut buffer = Vec::with_capacity(values.len() * std::mem::size_of::()); - values.serialize(&mut buffer)?; - self.set.insert(buffer); + match validity { + Some(validity) => { + for (row, b) in (0..input_rows).zip(validity) { + if !b { + continue; + } + self.add(columns, row, skip_null)?; + } + } + None => { + for row in 0..input_rows { + self.add(columns, row, skip_null)?; + } } } Ok(()) } + fn merge(&mut self, rhs: &Self) -> Result<()> { self.set.extend(rhs.set.clone()); Ok(()) @@ -204,7 +204,7 @@ impl DistinctStateFunc for AggregateDistinctStringState { self.set.len() } - fn add(&mut self, columns: ProjectedBlock, row: usize) -> Result<()> { + fn add(&mut self, columns: ProjectedBlock, row: usize, _skip_null: bool) -> Result<()> { let view = columns[0].downcast::().unwrap(); let data = unsafe { view.index_unchecked(row) }; let _ = self.set.set_insert(data.as_bytes()); @@ -216,6 +216,7 @@ impl DistinctStateFunc for AggregateDistinctStringState { columns: ProjectedBlock, validity: Option<&Bitmap>, input_rows: usize, + _skip_null: bool, ) -> Result<()> { let view = columns[0].downcast::().unwrap(); match validity { @@ -311,7 +312,7 @@ where T: Number + HashtableKeyable self.set.len() } - fn add(&mut self, columns: ProjectedBlock, row: usize) -> Result<()> { + fn add(&mut self, columns: ProjectedBlock, row: usize, _skip_null: bool) -> Result<()> { let view = columns[0].downcast::>().unwrap(); let v = unsafe { view.index_unchecked(row) }; let _ = self.set.set_insert(v).is_ok(); @@ -323,6 +324,7 @@ where T: Number + HashtableKeyable columns: ProjectedBlock, validity: Option<&Bitmap>, input_rows: usize, + _skip_null: bool, ) -> Result<()> { let view = columns[0].downcast::>().unwrap(); match validity { @@ -421,7 +423,7 @@ impl DistinctStateFunc for AggregateUniqStringState { self.set.len() } - fn add(&mut self, columns: ProjectedBlock, row: usize) -> Result<()> { + fn add(&mut self, columns: ProjectedBlock, row: usize, _skip_null: bool) -> Result<()> { let view = columns[0].downcast::().unwrap(); let data = unsafe { view.index_unchecked(row) }.as_bytes(); let mut hasher = SipHasher24::new(); @@ -436,6 +438,7 @@ impl DistinctStateFunc for AggregateUniqStringState { columns: ProjectedBlock, validity: Option<&Bitmap>, input_rows: usize, + _skip_null: bool, ) -> Result<()> { let view = columns[0].downcast::().unwrap(); match validity { diff --git a/src/query/functions/src/aggregates/aggregate_function_factory.rs b/src/query/functions/src/aggregates/aggregate_function_factory.rs index b4f4e91791ffe..29f7bc4991353 100644 --- a/src/query/functions/src/aggregates/aggregate_function_factory.rs +++ b/src/query/functions/src/aggregates/aggregate_function_factory.rs @@ -27,18 +27,6 @@ use super::AggregateFunctionRef; use super::AggregateFunctionSortAdaptor; use super::Aggregators; -// The NULL value in the those function needs to be handled separately. -const NEED_NULL_AGGREGATE_FUNCTIONS: [&str; 8] = [ - "array_agg", - "list", - "json_agg", - "json_array_agg", - "json_object_agg", - "group_array_moving_avg", - "group_array_moving_sum", - "st_collect", -]; - const STATE_SUFFIX: &str = "_state"; pub type AggregateFunctionCreator = Box< @@ -72,7 +60,7 @@ static FACTORY: LazyLock> = LazyLock::new(|| { }); pub struct AggregateFunctionDescription { - pub(crate) aggregate_function_creator: AggregateFunctionCreator, + pub(crate) creator: AggregateFunctionCreator, pub(crate) features: AggregateFunctionFeatures, } @@ -95,6 +83,11 @@ pub struct AggregateFunctionFeatures { /// AVG(C) = SUM(C) / COUNT(C) pub(crate) is_decomposable: bool, + pub(crate) allow_sort: bool, + + // The NULL value in the those function needs to be handled separately. + pub(crate) keep_nullable: bool, + // Function Category pub category: &'static str, // Introduce the function in brief. @@ -108,7 +101,7 @@ pub struct AggregateFunctionFeatures { impl AggregateFunctionDescription { pub fn creator(creator: AggregateFunctionCreator) -> AggregateFunctionDescription { AggregateFunctionDescription { - aggregate_function_creator: creator, + creator, features: AggregateFunctionFeatures { returns_default_when_only_null: false, is_decomposable: false, @@ -121,10 +114,7 @@ impl AggregateFunctionDescription { creator: AggregateFunctionCreator, features: AggregateFunctionFeatures, ) -> AggregateFunctionDescription { - AggregateFunctionDescription { - aggregate_function_creator: creator, - features, - } + AggregateFunctionDescription { creator, features } } } @@ -214,111 +204,106 @@ impl AggregateFunctionFactory { sort_descs: Vec, or_null: bool, ) -> Result { - let name = name.as_ref(); - let mut features = AggregateFunctionFeatures::default(); + let name = name.as_ref().to_lowercase(); + + if let Some(desc) = self.case_insensitive_desc.get(&name) { + if desc.features.keep_nullable { + let agg = (desc.creator)(&name, params, arguments, sort_descs.clone())?; + return if desc.features.allow_sort { + AggregateFunctionSortAdaptor::create(agg, sort_descs) + } else { + Ok(agg) + }; + } - if NEED_NULL_AGGREGATE_FUNCTIONS.contains(&name) { - let mut agg = - self.get_impl(name, params, arguments, sort_descs.clone(), &mut features)?; - if !sort_descs.is_empty() { - agg = AggregateFunctionSortAdaptor::create(agg, sort_descs)? + if arguments.iter().all(|f| !f.is_nullable_or_null()) { + let mut agg = (desc.creator)(&name, params, arguments, sort_descs.clone())?; + if desc.features.allow_sort { + agg = AggregateFunctionSortAdaptor::create(agg, sort_descs)? + } + return if or_null && !desc.features.returns_default_when_only_null { + AggregateFunctionOrNullAdaptor::create(agg) + } else { + Ok(agg) + }; } - return Ok(agg); - } - if arguments.iter().all(|f| !f.is_nullable_or_null()) { - let mut agg = - self.get_impl(name, params, arguments, sort_descs.clone(), &mut features)?; - if !sort_descs.is_empty() { + let new_arguments = AggregateFunctionCombinatorNull::transform_arguments(&arguments)?; + + let nested = (desc.creator)(&name, params.clone(), new_arguments, sort_descs.clone())?; + let mut agg = AggregateFunctionCombinatorNull::try_create( + params, + arguments, + nested, + desc.features.returns_default_when_only_null, + )?; + if desc.features.allow_sort { agg = AggregateFunctionSortAdaptor::create(agg, sort_descs)? } - if or_null { - agg = AggregateFunctionOrNullAdaptor::create(agg, features)? - } - return Ok(agg); + return if or_null && !desc.features.returns_default_when_only_null { + AggregateFunctionOrNullAdaptor::create(agg) + } else { + Ok(agg) + }; } - let nested = if name.to_lowercase().strip_suffix(STATE_SUFFIX).is_some() { - self.get_impl( - name, + // find suffix + let Some((nested_name, suffix, nested_desc, desc)) = self + .case_insensitive_combinator_desc + .iter() + .find_map(|(suffix, desc)| { + name.strip_suffix(suffix) + .map(|nested_name| (nested_name, suffix, desc)) + }) + .and_then(|(nested_name, suffix, desc)| { + self.case_insensitive_desc + .get(nested_name) + .map(|nested_desc| (nested_name, suffix, nested_desc, desc)) + }) + else { + return Err(ErrorCode::UnknownAggregateFunction(format!( + "Unsupported AggregateFunction: {name}", + ))); + }; + + let (nested, features) = if suffix == STATE_SUFFIX { + let nested = (*desc.creator)( + nested_name, params.clone(), arguments.clone(), sort_descs.clone(), - &mut features, - )? + &nested_desc.creator, + )?; + let mut features = nested_desc.features.clone(); + features.returns_default_when_only_null = true; + (nested, features) } else { - let new_params = AggregateFunctionCombinatorNull::transform_params(¶ms)?; let new_arguments = AggregateFunctionCombinatorNull::transform_arguments(&arguments)?; - self.get_impl( - name, - new_params, + + let nested = (*desc.creator)( + nested_name, + params.clone(), new_arguments, sort_descs.clone(), - &mut features, - )? + &nested_desc.creator, + )?; + (nested, nested_desc.features.clone()) }; let mut agg = AggregateFunctionCombinatorNull::try_create( - name, params, arguments, nested, - features.clone(), + features.returns_default_when_only_null, )?; - if !sort_descs.is_empty() { + if features.allow_sort { agg = AggregateFunctionSortAdaptor::create(agg, sort_descs)? } - if or_null { - agg = AggregateFunctionOrNullAdaptor::create(agg, features)? - } - Ok(agg) - } - - fn get_impl( - &self, - name: &str, - params: Vec, - arguments: Vec, - sort_descs: Vec, - features: &mut AggregateFunctionFeatures, - ) -> Result { - let lowercase_name = name.to_lowercase(); - let aggregate_functions_map = &self.case_insensitive_desc; - if let Some(desc) = aggregate_functions_map.get(&lowercase_name) { - *features = desc.features.clone(); - return (desc.aggregate_function_creator)(name, params, arguments, sort_descs); - } - - // find suffix - for (suffix, desc) in &self.case_insensitive_combinator_desc { - if let Some(nested_name) = lowercase_name.strip_suffix(suffix) { - let aggregate_functions_map = &self.case_insensitive_desc; - - match aggregate_functions_map.get(nested_name) { - None => { - break; - } - Some(nested_desc) => { - *features = nested_desc.features.clone(); - if suffix.eq_ignore_ascii_case(STATE_SUFFIX) { - features.returns_default_when_only_null = true; - } - return (desc.creator)( - nested_name, - params, - arguments, - sort_descs, - &nested_desc.aggregate_function_creator, - ); - } - } - } + if or_null && !features.returns_default_when_only_null { + AggregateFunctionOrNullAdaptor::create(agg) + } else { + Ok(agg) } - - Err(ErrorCode::UnknownAggregateFunction(format!( - "Unsupported AggregateFunction: {}", - name - ))) } pub fn contains(&self, func_name: impl AsRef) -> bool { diff --git a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs index 575e8cedb65b3..76bea22c0ebbc 100644 --- a/src/query/functions/src/aggregates/aggregate_json_array_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_json_array_agg.rs @@ -49,6 +49,7 @@ use super::AggrState; use super::AggrStateLoc; use super::AggregateFunction; use super::AggregateFunctionDescription; +use super::AggregateFunctionFeatures; use super::AggregateFunctionSortDesc; use super::StateAddr; use super::StateSerde; @@ -349,7 +350,7 @@ where } } -pub fn try_create_aggregate_json_array_agg_function( +fn try_create_aggregate_json_array_agg_function( display_name: &str, params: Vec, argument_types: Vec, @@ -364,5 +365,11 @@ pub fn try_create_aggregate_json_array_agg_function( } pub fn aggregate_json_array_agg_function_desc() -> AggregateFunctionDescription { - AggregateFunctionDescription::creator(Box::new(try_create_aggregate_json_array_agg_function)) + AggregateFunctionDescription::creator_with_features( + Box::new(try_create_aggregate_json_array_agg_function), + AggregateFunctionFeatures { + keep_nullable: true, + ..Default::default() + }, + ) } diff --git a/src/query/functions/src/aggregates/aggregate_json_object_agg.rs b/src/query/functions/src/aggregates/aggregate_json_object_agg.rs index 6df4e177b4ed1..3081810bdf4a3 100644 --- a/src/query/functions/src/aggregates/aggregate_json_object_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_json_object_agg.rs @@ -51,6 +51,7 @@ use super::AggregateFunction; use super::AggregateFunctionDescription; use super::AggregateFunctionSortDesc; use super::StateAddr; +use crate::aggregates::AggregateFunctionFeatures; pub(super) trait BinaryScalarStateFunc: BorshSerialize + BorshDeserialize + Send + 'static @@ -444,5 +445,11 @@ pub fn try_create_aggregate_json_object_agg_function( } pub fn aggregate_json_object_agg_function_desc() -> AggregateFunctionDescription { - AggregateFunctionDescription::creator(Box::new(try_create_aggregate_json_object_agg_function)) + AggregateFunctionDescription::creator_with_features( + Box::new(try_create_aggregate_json_object_agg_function), + AggregateFunctionFeatures { + keep_nullable: true, + ..Default::default() + }, + ) } diff --git a/src/query/functions/src/aggregates/aggregate_st_collect.rs b/src/query/functions/src/aggregates/aggregate_st_collect.rs index 0df5f643ca2c4..119fd34669026 100644 --- a/src/query/functions/src/aggregates/aggregate_st_collect.rs +++ b/src/query/functions/src/aggregates/aggregate_st_collect.rs @@ -56,6 +56,7 @@ use super::AggrState; use super::AggrStateLoc; use super::AggregateFunction; use super::AggregateFunctionDescription; +use super::AggregateFunctionFeatures; use super::AggregateFunctionSortDesc; use super::StateAddr; use super::StateSerde; @@ -401,7 +402,7 @@ where } } -pub fn try_create_aggregate_st_collect_function( +fn try_create_aggregate_st_collect_function( display_name: &str, params: Vec, argument_types: Vec, @@ -423,5 +424,11 @@ pub fn try_create_aggregate_st_collect_function( } pub fn aggregate_st_collect_function_desc() -> AggregateFunctionDescription { - AggregateFunctionDescription::creator(Box::new(try_create_aggregate_st_collect_function)) + AggregateFunctionDescription::creator_with_features( + Box::new(try_create_aggregate_st_collect_function), + AggregateFunctionFeatures { + keep_nullable: true, + ..Default::default() + }, + ) } diff --git a/src/query/functions/src/aggregates/aggregate_string_agg.rs b/src/query/functions/src/aggregates/aggregate_string_agg.rs index 8b35d0c9a110f..68b8c8f8aec6a 100644 --- a/src/query/functions/src/aggregates/aggregate_string_agg.rs +++ b/src/query/functions/src/aggregates/aggregate_string_agg.rs @@ -39,6 +39,7 @@ use super::assert_variadic_arguments; use super::batch_merge1; use super::batch_serialize1; use super::AggregateFunctionDescription; +use super::AggregateFunctionFeatures; use super::AggregateFunctionSortDesc; use super::AggregateUnaryFunction; use super::SerializeInfo; @@ -201,5 +202,11 @@ pub fn try_create_aggregate_string_agg_function( } pub fn aggregate_string_agg_function_desc() -> AggregateFunctionDescription { - AggregateFunctionDescription::creator(Box::new(try_create_aggregate_string_agg_function)) + AggregateFunctionDescription::creator_with_features( + Box::new(try_create_aggregate_string_agg_function), + AggregateFunctionFeatures { + allow_sort: true, + ..Default::default() + }, + ) } diff --git a/src/query/functions/src/aggregates/aggregator.rs b/src/query/functions/src/aggregates/aggregator.rs index 0d0047c61b4d7..59a64f69124fa 100644 --- a/src/query/functions/src/aggregates/aggregator.rs +++ b/src/query/functions/src/aggregates/aggregator.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use super::adaptors::*; use super::aggregate_approx_count_distinct::aggregate_approx_count_distinct_function_desc; use super::aggregate_arg_min_max::aggregate_arg_max_function_desc; use super::aggregate_arg_min_max::aggregate_arg_min_function_desc; @@ -27,8 +28,6 @@ use super::aggregate_bitmap::aggregate_bitmap_or_count_function_desc; use super::aggregate_bitmap::aggregate_bitmap_union_function_desc; use super::aggregate_bitmap::aggregate_bitmap_xor_count_function_desc; use super::aggregate_boolean::aggregate_boolean_function_desc; -use super::aggregate_combinator_distinct_desc; -use super::aggregate_combinator_uniq_desc; use super::aggregate_covariance::aggregate_covariance_population_desc; use super::aggregate_covariance::aggregate_covariance_sample_desc; use super::aggregate_histogram::aggregate_histogram_function_desc; @@ -54,13 +53,13 @@ use super::aggregate_st_collect::aggregate_st_collect_function_desc; use super::aggregate_stddev::aggregate_stddev_pop_function_desc; use super::aggregate_stddev::aggregate_stddev_samp_function_desc; use super::aggregate_string_agg::aggregate_string_agg_function_desc; -use super::aggregate_sum_function_desc; +use super::aggregate_sum::aggregate_sum_function_desc; +use super::aggregate_sum_zero::AggregateSumZeroFunction; use super::aggregate_window_funnel::aggregate_window_funnel_function_desc; use super::AggregateCountFunction; use super::AggregateFunctionFactory; use super::AggregateIfCombinator; use super::AggregateStateCombinator; -use crate::aggregates::aggregate_sum_zero::AggregateSumZeroFunction; pub struct Aggregators; @@ -72,7 +71,8 @@ impl Aggregators { factory.register("sum0", AggregateSumZeroFunction::desc()); factory.register("sum_zero", AggregateSumZeroFunction::desc()); factory.register("avg", aggregate_avg_function_desc()); - factory.register("uniq", aggregate_combinator_uniq_desc()); + factory.register("uniq", aggregate_uniq_desc()); + factory.register("count_distinct", aggregate_count_distinct_desc()); factory.register("min", aggregate_min_function_desc()); factory.register("max", aggregate_max_function_desc()); diff --git a/src/query/functions/tests/it/aggregates/agg.rs b/src/query/functions/tests/it/aggregates/agg.rs index 28ffec61a53ea..d9a90f4235911 100644 --- a/src/query/functions/tests/it/aggregates/agg.rs +++ b/src/query/functions/tests/it/aggregates/agg.rs @@ -54,6 +54,7 @@ fn test_aggr_functions() { let file = &mut mint.new_goldenfile("agg.txt").unwrap(); test_count(file, eval_aggr); + test_count_distinct(file, eval_aggr); test_sum(file, eval_aggr); test_avg(file, eval_aggr); test_uniq(file, eval_aggr); @@ -103,6 +104,7 @@ fn test_aggr_functions_group_by() { let file = &mut mint.new_goldenfile("agg_group_by.txt").unwrap(); test_count(file, simulate_two_groups_group_by); + test_count_distinct(file, simulate_two_groups_group_by); test_sum(file, simulate_two_groups_group_by); test_avg(file, simulate_two_groups_group_by); test_uniq(file, simulate_two_groups_group_by); @@ -388,6 +390,31 @@ fn test_count(file: &mut impl Write, simulator: impl AggregationSimulator) { ); } +fn test_count_distinct(file: &mut impl Write, simulator: impl AggregationSimulator) { + let columns = &get_example(); + run_agg_ast(file, "count_distinct(null)", columns, simulator, vec![]); + run_agg_ast( + file, + "count_distinct(null,null)", + columns, + simulator, + vec![], + ); + run_agg_ast(file, "count_distinct(1)", columns, simulator, vec![]); + run_agg_ast(file, "count_distinct(a)", columns, simulator, vec![]); + run_agg_ast(file, "count_distinct(x_null)", columns, simulator, vec![]); + run_agg_ast(file, "count_distinct(x_null,a)", columns, simulator, vec![]); + run_agg_ast(file, "count_distinct(all_null)", columns, simulator, vec![]); + run_agg_ast( + file, + "count_distinct(all_null,s)", + columns, + simulator, + vec![], + ); + run_agg_ast(file, "count_distinct(s_null,s)", columns, simulator, vec![]); +} + fn test_sum(file: &mut impl Write, simulator: impl AggregationSimulator) { run_agg_ast(file, "sum(1)", get_example().as_slice(), simulator, vec![]); run_agg_ast(file, "sum(a)", get_example().as_slice(), simulator, vec![]); diff --git a/src/query/functions/tests/it/aggregates/testdata/agg.txt b/src/query/functions/tests/it/aggregates/testdata/agg.txt index ea42f5475089f..145a358e553e5 100644 --- a/src/query/functions/tests/it/aggregates/testdata/agg.txt +++ b/src/query/functions/tests/it/aggregates/testdata/agg.txt @@ -58,6 +58,99 @@ evaluation (internal): +----------+-------------------------------------------------------------------------+ +ast: count_distinct(null) +evaluation (internal): ++--------+---------------------+ +| Column | Data | ++--------+---------------------+ +| a | Int64([4, 3, 2, 1]) | +| Output | UInt64([0]) | ++--------+---------------------+ + + +ast: count_distinct(null,null) +evaluation (internal): ++--------+---------------------+ +| Column | Data | ++--------+---------------------+ +| a | Int64([4, 3, 2, 1]) | +| Output | UInt64([0]) | ++--------+---------------------+ + + +ast: count_distinct(1) +evaluation (internal): ++--------+---------------------+ +| Column | Data | ++--------+---------------------+ +| a | Int64([4, 3, 2, 1]) | +| Output | UInt64([1]) | ++--------+---------------------+ + + +ast: count_distinct(a) +evaluation (internal): ++--------+---------------------+ +| Column | Data | ++--------+---------------------+ +| a | Int64([4, 3, 2, 1]) | +| Output | UInt64([4]) | ++--------+---------------------+ + + +ast: count_distinct(x_null) +evaluation (internal): ++--------+-------------------------------------------------------------------------+ +| Column | Data | ++--------+-------------------------------------------------------------------------+ +| x_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0011] } | +| Output | UInt64([2]) | ++--------+-------------------------------------------------------------------------+ + + +ast: count_distinct(x_null,a) +evaluation (internal): ++--------+-------------------------------------------------------------------------+ +| Column | Data | ++--------+-------------------------------------------------------------------------+ +| a | Int64([4, 3, 2, 1]) | +| x_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0011] } | +| Output | UInt64([4]) | ++--------+-------------------------------------------------------------------------+ + + +ast: count_distinct(all_null) +evaluation (internal): ++----------+-------------------------------------------------------------------------+ +| Column | Data | ++----------+-------------------------------------------------------------------------+ +| all_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0000] } | +| Output | UInt64([0]) | ++----------+-------------------------------------------------------------------------+ + + +ast: count_distinct(all_null,s) +evaluation (internal): ++----------+-------------------------------------------------------------------------+ +| Column | Data | ++----------+-------------------------------------------------------------------------+ +| all_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0000] } | +| s | StringColumn[abc, def, opq, xyz] | +| Output | UInt64([4]) | ++----------+-------------------------------------------------------------------------+ + + +ast: count_distinct(s_null,s) +evaluation (internal): ++--------+----------------------------------------------------------------------------+ +| Column | Data | ++--------+----------------------------------------------------------------------------+ +| s | StringColumn[abc, def, opq, xyz] | +| s_null | NullableColumn { column: StringColumn[a, , c, d], validity: [0b____1101] } | +| Output | UInt64([4]) | ++--------+----------------------------------------------------------------------------+ + + ast: sum(1) evaluation (internal): +--------+----------------------------------------------------------------+ diff --git a/src/query/functions/tests/it/aggregates/testdata/agg_group_by.txt b/src/query/functions/tests/it/aggregates/testdata/agg_group_by.txt index 61c84a00ead15..8153bf2913d70 100644 --- a/src/query/functions/tests/it/aggregates/testdata/agg_group_by.txt +++ b/src/query/functions/tests/it/aggregates/testdata/agg_group_by.txt @@ -58,6 +58,99 @@ evaluation (internal): +----------+-------------------------------------------------------------------------+ +ast: count_distinct(null) +evaluation (internal): ++--------+---------------------+ +| Column | Data | ++--------+---------------------+ +| a | Int64([4, 3, 2, 1]) | +| Output | UInt64([0, 0]) | ++--------+---------------------+ + + +ast: count_distinct(null,null) +evaluation (internal): ++--------+---------------------+ +| Column | Data | ++--------+---------------------+ +| a | Int64([4, 3, 2, 1]) | +| Output | UInt64([0, 0]) | ++--------+---------------------+ + + +ast: count_distinct(1) +evaluation (internal): ++--------+---------------------+ +| Column | Data | ++--------+---------------------+ +| a | Int64([4, 3, 2, 1]) | +| Output | UInt64([1, 1]) | ++--------+---------------------+ + + +ast: count_distinct(a) +evaluation (internal): ++--------+---------------------+ +| Column | Data | ++--------+---------------------+ +| a | Int64([4, 3, 2, 1]) | +| Output | UInt64([2, 2]) | ++--------+---------------------+ + + +ast: count_distinct(x_null) +evaluation (internal): ++--------+-------------------------------------------------------------------------+ +| Column | Data | ++--------+-------------------------------------------------------------------------+ +| x_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0011] } | +| Output | UInt64([1, 1]) | ++--------+-------------------------------------------------------------------------+ + + +ast: count_distinct(x_null,a) +evaluation (internal): ++--------+-------------------------------------------------------------------------+ +| Column | Data | ++--------+-------------------------------------------------------------------------+ +| a | Int64([4, 3, 2, 1]) | +| x_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0011] } | +| Output | UInt64([2, 2]) | ++--------+-------------------------------------------------------------------------+ + + +ast: count_distinct(all_null) +evaluation (internal): ++----------+-------------------------------------------------------------------------+ +| Column | Data | ++----------+-------------------------------------------------------------------------+ +| all_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0000] } | +| Output | UInt64([0, 0]) | ++----------+-------------------------------------------------------------------------+ + + +ast: count_distinct(all_null,s) +evaluation (internal): ++----------+-------------------------------------------------------------------------+ +| Column | Data | ++----------+-------------------------------------------------------------------------+ +| all_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0000] } | +| s | StringColumn[abc, def, opq, xyz] | +| Output | UInt64([2, 2]) | ++----------+-------------------------------------------------------------------------+ + + +ast: count_distinct(s_null,s) +evaluation (internal): ++--------+----------------------------------------------------------------------------+ +| Column | Data | ++--------+----------------------------------------------------------------------------+ +| s | StringColumn[abc, def, opq, xyz] | +| s_null | NullableColumn { column: StringColumn[a, , c, d], validity: [0b____1101] } | +| Output | UInt64([2, 2]) | ++--------+----------------------------------------------------------------------------+ + + ast: sum(1) evaluation (internal): +--------+-------------------------------------------------------------------+ diff --git a/tests/sqllogictests/src/client/http_client.rs b/tests/sqllogictests/src/client/http_client.rs index 1aeb05448cc12..f88b39c04f4d4 100644 --- a/tests/sqllogictests/src/client/http_client.rs +++ b/tests/sqllogictests/src/client/http_client.rs @@ -24,11 +24,11 @@ use reqwest::Client; use reqwest::ClientBuilder; use serde::Deserialize; use sqllogictest::DBOutput; -use sqllogictest::DefaultColumnType; use crate::client::global_cookie_store::GlobalCookieStore; use crate::error::Result; use crate::util::parser_rows; +use crate::util::ColumnType; use crate::util::HttpSessionConf; pub struct HttpClient { @@ -57,7 +57,7 @@ struct SchemaItem { } impl SchemaItem { - fn parse_type(&self) -> Result { + fn parse_type(&self) -> Result { let nullable = Regex::new(r"^Nullable\((.+)\)$").unwrap(); let value = match nullable.captures(&self.r#type) { Some(captures) => { @@ -67,13 +67,14 @@ impl SchemaItem { None => &self.r#type, }; let typ = match value { - "String" => DefaultColumnType::Text, + "Boolean" => ColumnType::Bool, + "String" => ColumnType::Text, "Int8" | "Int16" | "Int32" | "Int64" | "UInt8" | "UInt16" | "UInt32" | "UInt64" => { - DefaultColumnType::Integer + ColumnType::Integer } - "Float32" | "Float64" => DefaultColumnType::FloatingPoint, - decimal if decimal.starts_with("Decimal") => DefaultColumnType::FloatingPoint, - _ => DefaultColumnType::Any, + "Float32" | "Float64" => ColumnType::FloatingPoint, + decimal if decimal.starts_with("Decimal") => ColumnType::FloatingPoint, + _ => ColumnType::Any, }; Ok(typ) } @@ -155,7 +156,7 @@ impl HttpClient { }) } - pub async fn query(&mut self, sql: &str) -> Result> { + pub async fn query(&mut self, sql: &str) -> Result> { let start = Instant::now(); let port = self.port; let mut response = self @@ -193,7 +194,7 @@ impl HttpClient { let types = schema .iter() - .map(|item| item.parse_type().unwrap_or(DefaultColumnType::Any)) + .map(|item| item.parse_type().unwrap_or(ColumnType::Any)) .collect(); Ok(DBOutput::Rows { diff --git a/tests/sqllogictests/src/client/mod.rs b/tests/sqllogictests/src/client/mod.rs index bf089df327465..fae04579ebd9f 100644 --- a/tests/sqllogictests/src/client/mod.rs +++ b/tests/sqllogictests/src/client/mod.rs @@ -26,10 +26,10 @@ use rand::distributions::Alphanumeric; use rand::Rng; use regex::Regex; use sqllogictest::DBOutput; -use sqllogictest::DefaultColumnType; pub use ttc_client::TTCClient; use crate::error::Result; +use crate::util::ColumnType; #[derive(Debug, Clone)] pub enum ClientType { @@ -54,7 +54,7 @@ pub enum Client { } impl Client { - pub async fn query(&mut self, sql: &str) -> Result> { + pub async fn query(&mut self, sql: &str) -> Result> { let sql = replace_rand_values(sql); match self { Client::MySQL(client) => client.query(&sql).await, diff --git a/tests/sqllogictests/src/client/mysql_client.rs b/tests/sqllogictests/src/client/mysql_client.rs index b821e760f3201..3a557942d86d1 100644 --- a/tests/sqllogictests/src/client/mysql_client.rs +++ b/tests/sqllogictests/src/client/mysql_client.rs @@ -19,9 +19,9 @@ use mysql_async::Conn; use mysql_async::Pool; use mysql_async::Row; use sqllogictest::DBOutput; -use sqllogictest::DefaultColumnType; use crate::error::Result; +use crate::util::ColumnType; #[derive(Debug)] pub struct MySQLClient { @@ -46,7 +46,7 @@ impl MySQLClient { self.bench = true; } - pub async fn query(&mut self, sql: &str) -> Result> { + pub async fn query(&mut self, sql: &str) -> Result> { let start = Instant::now(); let res = self.conn.query(sql).await; @@ -75,8 +75,29 @@ impl MySQLClient { } }; + let types = rows.first().map(|row| { + row.columns() + .iter() + .map(|c| { + use mysql_async::consts::ColumnType::*; + match c.column_type() { + MYSQL_TYPE_TINY => ColumnType::Any, + MYSQL_TYPE_SHORT | MYSQL_TYPE_LONG | MYSQL_TYPE_LONGLONG + | MYSQL_TYPE_INT24 => ColumnType::Integer, + MYSQL_TYPE_FLOAT | MYSQL_TYPE_DOUBLE | MYSQL_TYPE_DECIMAL => { + ColumnType::FloatingPoint + } + MYSQL_TYPE_VAR_STRING | MYSQL_TYPE_STRING | MYSQL_TYPE_VARCHAR => { + ColumnType::Text + } + _ => ColumnType::Any, + } + }) + .collect::>() + }); + let mut parsed_rows = Vec::with_capacity(rows.len()); - for row in rows.into_iter() { + for row in rows { let mut parsed_row = Vec::new(); for i in 0..row.len() { let value: Option> = row.get(i); @@ -90,13 +111,9 @@ impl MySQLClient { } parsed_rows.push(parsed_row); } - let mut types = vec![]; - if !parsed_rows.is_empty() { - types = vec![DefaultColumnType::Any; parsed_rows[0].len()]; - } - // Todo: add types to compare + Ok(DBOutput::Rows { - types, + types: types.unwrap_or_default(), rows: parsed_rows, }) } diff --git a/tests/sqllogictests/src/client/ttc_client.rs b/tests/sqllogictests/src/client/ttc_client.rs index 62531dc5d5866..698a07988eab2 100644 --- a/tests/sqllogictests/src/client/ttc_client.rs +++ b/tests/sqllogictests/src/client/ttc_client.rs @@ -16,12 +16,12 @@ use std::time::Instant; use regex::Regex; use sqllogictest::DBOutput; -use sqllogictest::DefaultColumnType; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::net::TcpStream; use crate::error::Result; +use crate::util::ColumnType; #[derive(Debug)] pub struct TTCClient { @@ -47,7 +47,7 @@ impl TTCClient { self.bench = true; } - pub async fn query(&mut self, sql: &str) -> Result> { + pub async fn query(&mut self, sql: &str) -> Result> { let start = Instant::now(); let res = self.query_response(sql).await; @@ -106,7 +106,7 @@ impl TTCClient { } let mut types = vec![]; if !parsed_rows.is_empty() { - types = vec![DefaultColumnType::Any; parsed_rows[0].len()]; + types = vec![ColumnType::Any; parsed_rows[0].len()]; } // Todo: add types to compare Ok(DBOutput::Rows { diff --git a/tests/sqllogictests/src/main.rs b/tests/sqllogictests/src/main.rs index ebaf84cdf7e05..eebdebb2dc1b0 100644 --- a/tests/sqllogictests/src/main.rs +++ b/tests/sqllogictests/src/main.rs @@ -26,7 +26,6 @@ use sqllogictest::default_column_validator; use sqllogictest::default_validator; use sqllogictest::parse_file; use sqllogictest::DBOutput; -use sqllogictest::DefaultColumnType; use sqllogictest::Location; use sqllogictest::QueryExpect; use sqllogictest::Record; @@ -48,6 +47,7 @@ use crate::util::get_files; use crate::util::lazy_prepare_data; use crate::util::lazy_run_dictionary_containers; use crate::util::run_ttc_container; +use crate::util::ColumnType; mod arg; mod client; @@ -97,7 +97,7 @@ impl Databend { #[async_trait::async_trait] impl sqllogictest::AsyncDB for Databend { type Error = DSqlLogicTestError; - type ColumnType = DefaultColumnType; + type ColumnType = ColumnType; async fn run(&mut self, sql: &str) -> Result> { self.client.query(sql).await @@ -291,7 +291,7 @@ async fn run_suits(args: SqlLogicTestArgs, client_type: ClientType) -> Result<() continue; } } - num_of_tests += parse_file::(suit_file.as_ref().unwrap().path()) + num_of_tests += parse_file::(suit_file.as_ref().unwrap().path()) .unwrap() .len(); @@ -355,23 +355,19 @@ async fn run_suits(args: SqlLogicTestArgs, client_type: ClientType) -> Result<() Ok(()) } -fn column_validator( - loc: Location, - actual: Vec, - expected: Vec, -) { +fn column_validator(loc: Location, actual: Vec, expected: Vec) { let equals = if actual.len() != expected.len() { false } else { actual.iter().zip(expected.iter()).all(|x| { - use DefaultColumnType::*; + use ColumnType::*; matches!( x, - (Text, Text) + (Bool, Bool) + | (Text, Text) | (Integer, Integer) | (FloatingPoint, FloatingPoint) | (Any, _) - | (_, Any) ) }) }; diff --git a/tests/sqllogictests/src/util.rs b/tests/sqllogictests/src/util.rs index f65b5f7fc3642..0a7661c2f09f8 100644 --- a/tests/sqllogictests/src/util.rs +++ b/tests/sqllogictests/src/util.rs @@ -508,3 +508,34 @@ async fn stop_container(docker: &Docker, container_name: &str) { } } } + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum ColumnType { + Bool, + Text, + Integer, + FloatingPoint, + Any, +} + +impl sqllogictest::ColumnType for ColumnType { + fn from_char(value: char) -> Option { + match value { + 'B' => Some(Self::Bool), + 'T' => Some(Self::Text), + 'I' => Some(Self::Integer), + 'R' => Some(Self::FloatingPoint), + _ => Some(Self::Any), + } + } + + fn to_char(&self) -> char { + match self { + Self::Bool => 'B', + Self::Text => 'T', + Self::Integer => 'I', + Self::FloatingPoint => 'R', + Self::Any => '?', + } + } +} diff --git a/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_distinct.test b/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_distinct.test new file mode 100644 index 0000000000000..a2ae4bc89f9f4 --- /dev/null +++ b/tests/sqllogictests/suites/query/functions/02_0000_function_aggregate_distinct.test @@ -0,0 +1,33 @@ +query TI +select 'a', count(distinct a,b,c) +from (values(1,null,1),(2,null,null),(4,null,4),(null,null,null)) t(a,b,c); +---- +a 3 + +query TI +select 'a', count(distinct b) +from (values(1,null,1),(2,null,null),(4,null,4),(null,null,null)) t(a,b,c); +---- +a 0 + +query TI +select 'a', count(distinct c) +from (values(1,null,1),(2,null,null),(4,null,4),(null,null,null)) t(a,b,c); +---- +a 2 + +query TI +select 'a', count(distinct b,c) +from (values(1,null,1),(2,null,null),(4,null,4),(null,null,null)) t(a,b,c); +---- +a 2 + +query TI +select 'a',count(distinct null) from numbers(5); +---- +a 0 + +query TI +select 'a',count(distinct null,null) from numbers(5); +---- +a 0