diff --git a/src/query/expression/src/type_check.rs b/src/query/expression/src/type_check.rs index fbe7ca0fa4391..2164c977276d6 100755 --- a/src/query/expression/src/type_check.rs +++ b/src/query/expression/src/type_check.rs @@ -112,7 +112,7 @@ pub fn check( if dest_ty.is_integer() && src_ty.is_integer() { if let Ok(casted_scalar) = - cast_scalar(*span, scalar.clone(), dest_ty, fn_registry) + cast_scalar(*span, scalar.clone(), &dest_ty, fn_registry) { *scalar = casted_scalar; *data_type = scalar.as_ref().infer_data_type(); diff --git a/src/query/expression/src/utils/mod.rs b/src/query/expression/src/utils/mod.rs index 0113f0ede43da..c7f338b172e75 100644 --- a/src/query/expression/src/utils/mod.rs +++ b/src/query/expression/src/utils/mod.rs @@ -27,6 +27,7 @@ pub mod visitor; use databend_common_ast::Span; use databend_common_column::bitmap::Bitmap; +use databend_common_exception::ErrorCode; use databend_common_exception::Result; pub use self::column_from::*; @@ -36,8 +37,12 @@ use crate::types::AnyType; use crate::types::DataType; use crate::types::Decimal; use crate::types::DecimalDataKind; +use crate::types::DecimalDataType; use crate::types::DecimalSize; +use crate::types::NumberDataType; use crate::types::NumberScalar; +use crate::types::F32; +use crate::types::F64; use crate::BlockEntry; use crate::Column; use crate::DataBlock; @@ -87,9 +92,13 @@ pub fn eval_function( pub fn cast_scalar( span: Span, scalar: Scalar, - dest_type: DataType, + dest_type: &DataType, fn_registry: &FunctionRegistry, ) -> Result { + if let Some(result) = try_fast_cast_scalar(&scalar, dest_type) { + return result; + } + let raw_expr = RawExpr::Cast { span, is_try: false, @@ -98,7 +107,7 @@ pub fn cast_scalar( scalar, data_type: None, }), - dest_type, + dest_type: dest_type.clone(), }; let expr = crate::type_check::check(&raw_expr, fn_registry)?; let block = DataBlock::empty(); @@ -107,6 +116,89 @@ pub fn cast_scalar( Ok(evaluator.run(&expr)?.into_scalar().unwrap()) } +fn try_fast_cast_scalar(scalar: &Scalar, dest_type: &DataType) -> Option> { + match dest_type { + DataType::Null => Some(Ok(Scalar::Null)), + DataType::Nullable(inner) => { + if matches!(scalar, Scalar::Null) { + Some(Ok(Scalar::Null)) + } else { + try_fast_cast_scalar(scalar, inner) + } + } + DataType::Number(NumberDataType::Float32) => match scalar { + Scalar::Null => Some(Ok(Scalar::Null)), + Scalar::Number(num) => Some(Ok(Scalar::Number(NumberScalar::Float32(num.to_f32())))), + Scalar::Decimal(dec) => Some(Ok(Scalar::Number(NumberScalar::Float32(F32::from( + dec.to_float32(), + ))))), + _ => None, + }, + DataType::Number(NumberDataType::Float64) => match scalar { + Scalar::Null => Some(Ok(Scalar::Null)), + Scalar::Number(num) => Some(Ok(Scalar::Number(NumberScalar::Float64(num.to_f64())))), + Scalar::Decimal(dec) => Some(Ok(Scalar::Number(NumberScalar::Float64(F64::from( + dec.to_float64(), + ))))), + _ => None, + }, + DataType::Decimal(size) => match scalar { + Scalar::Null => Some(Ok(Scalar::Null)), + Scalar::Decimal(dec) => Some(rescale_decimal_scalar(*dec, *size)), + _ => None, + }, + _ => None, + } +} + +fn rescale_decimal_scalar(decimal: DecimalScalar, target_size: DecimalSize) -> Result { + let from_size = decimal.size(); + if from_size == target_size { + return Ok(Scalar::Decimal(decimal)); + } + + let source_scale = from_size.scale(); + let target_scale = target_size.scale(); + let data_type: DecimalDataType = target_size.into(); + + let scaled = match data_type { + DecimalDataType::Decimal64(_) => { + let value = decimal.as_decimal::(); + let adjusted = rescale_decimal_value(value, source_scale, target_scale)?; + Scalar::Decimal(DecimalScalar::Decimal64(adjusted, target_size)) + } + DecimalDataType::Decimal128(_) => { + let value = decimal.as_decimal::(); + let adjusted = rescale_decimal_value(value, source_scale, target_scale)?; + Scalar::Decimal(DecimalScalar::Decimal128(adjusted, target_size)) + } + DecimalDataType::Decimal256(_) => { + let value = decimal.as_decimal::(); + let adjusted = rescale_decimal_value(value, source_scale, target_scale)?; + Scalar::Decimal(DecimalScalar::Decimal256(adjusted, target_size)) + } + }; + + Ok(scaled) +} + +fn rescale_decimal_value(value: T, source_scale: u8, target_scale: u8) -> Result { + if source_scale == target_scale { + return Ok(value); + } + + let diff = target_scale.abs_diff(source_scale); + if target_scale > source_scale { + value.checked_mul(T::e(diff)).ok_or_else(|| { + ErrorCode::Overflow("Decimal literal overflow after scale expansion".to_string()) + }) + } else { + value.checked_div(T::e(diff)).ok_or_else(|| { + ErrorCode::Overflow("Decimal literal overflow after scale reduction".to_string()) + }) + } +} + pub fn column_merge_validity(entry: &BlockEntry, bitmap: Option) -> Option { match entry { BlockEntry::Const(scalar, data_type, n) => { diff --git a/src/query/functions/src/scalars/decimal/src/comparison.rs b/src/query/functions/src/scalars/decimal/src/comparison.rs index 8aeb84379ebae..c8e61c6014b15 100644 --- a/src/query/functions/src/scalars/decimal/src/comparison.rs +++ b/src/query/functions/src/scalars/decimal/src/comparison.rs @@ -145,7 +145,12 @@ fn op_decimal( T::e(size_calc.scale() - a_type.scale()), T::e(size_calc.scale() - b_type.scale()), ); - compare_decimal(a, b, |a, b, _| Op::compare(a, b, f_a, f_b), ctx) + + if (f_a == f_b) { + compare_decimal(a, b, |a, b, _| Op::is(a.cmp(&b)), ctx) + } else { + compare_decimal(a, b, |a, b, _| Op::compare(a, b, f_a, f_b), ctx) + } } }) } diff --git a/src/query/service/src/interpreters/interpreter_set.rs b/src/query/service/src/interpreters/interpreter_set.rs index 9000ba8921375..acc570318d095 100644 --- a/src/query/service/src/interpreters/interpreter_set.rs +++ b/src/query/service/src/interpreters/interpreter_set.rs @@ -62,7 +62,7 @@ impl SetInterpreter { async fn execute_settings(&self, scalars: Vec, is_global: bool) -> Result<()> { let scalars: Vec = scalars .into_iter() - .map(|scalar| cast_scalar(None, scalar.clone(), DataType::String, &BUILTIN_FUNCTIONS)) + .map(|scalar| cast_scalar(None, scalar.clone(), &DataType::String, &BUILTIN_FUNCTIONS)) .collect::>>()?; let mut keys: Vec = vec![]; diff --git a/src/query/sql/src/planner/binder/statement_settings.rs b/src/query/sql/src/planner/binder/statement_settings.rs index 7dfdfb1e1efa6..1becd783a062c 100644 --- a/src/query/sql/src/planner/binder/statement_settings.rs +++ b/src/query/sql/src/planner/binder/statement_settings.rs @@ -83,7 +83,7 @@ impl Binder { let scalar = cast_scalar( None, scalar.clone(), - DataType::String, + &DataType::String, &BUILTIN_FUNCTIONS, )?; results.push(scalar); diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs index f21343f59f4df..e14fdd492df22 100644 --- a/src/query/sql/src/planner/semantic/type_check.rs +++ b/src/query/sql/src/planner/semantic/type_check.rs @@ -62,12 +62,14 @@ use databend_common_compress::CompressAlgorithm; use databend_common_compress::DecompressDecoder; use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use databend_common_expression::cast_scalar; use databend_common_expression::display::display_tuple_field_name; use databend_common_expression::expr; use databend_common_expression::infer_schema_type; use databend_common_expression::shrink_scalar; use databend_common_expression::type_check; use databend_common_expression::type_check::check_number; +use databend_common_expression::type_check::common_super_type; use databend_common_expression::type_check::convert_escape_pattern; use databend_common_expression::types::decimal::DecimalScalar; use databend_common_expression::types::decimal::DecimalSize; @@ -81,6 +83,7 @@ use databend_common_expression::types::F32; use databend_common_expression::udf_client::UDFFlightClient; use databend_common_expression::BlockEntry; use databend_common_expression::Column; +use databend_common_expression::ColumnBuilder; use databend_common_expression::ColumnIndex; use databend_common_expression::Constant; use databend_common_expression::ConstantFolder; @@ -3409,6 +3412,15 @@ impl<'a> TypeChecker<'a> { // Omit unary + operator self.resolve(child) } + UnaryOperator::Minus => { + if let Expr::Literal { value, .. } = child { + let box (value, data_type) = self.resolve_minus_literal_scalar(span, value)?; + let scalar_expr = ScalarExpr::ConstantExpr(ConstantExpr { span, value }); + return Ok(Box::new((scalar_expr, data_type))); + } + let name = op.to_func_name(); + self.resolve_function(span, name.as_str(), vec![], &[child]) + } other => { let name = other.to_func_name(); self.resolve_function(span, name.as_str(), vec![], &[child]) @@ -4821,18 +4833,121 @@ impl<'a> TypeChecker<'a> { Ok(Box::new((value, data_type))) } - // TODO(leiysky): use an array builder function instead, since we should allow declaring - // an array with variable as element. + pub fn resolve_minus_literal_scalar( + &self, + span: Span, + literal: &databend_common_ast::ast::Literal, + ) -> Result> { + let value = match literal { + Literal::UInt64(v) => { + if *v <= i64::MAX as u64 { + Scalar::Number(NumberScalar::Int64(-(*v as i64))) + } else { + Scalar::Decimal(DecimalScalar::Decimal128( + -(*v as i128), + DecimalSize::new_unchecked(i128::MAX_PRECISION, 0), + )) + } + } + Literal::Decimal256 { + value, + precision, + scale, + } => Scalar::Decimal(DecimalScalar::Decimal256( + i256(*value).checked_mul(i256::minus_one()).unwrap(), + DecimalSize::new_unchecked(*precision, *scale), + )), + Literal::Float64(v) => Scalar::Number(NumberScalar::Float64((-*v).into())), + Literal::Null => Scalar::Null, + Literal::String(_) | Literal::Boolean(_) => { + return Err(ErrorCode::InvalidArgument(format!( + "Invalid minus operator for {}", + literal + )) + .set_span(span)); + } + }; + let value = shrink_scalar(value); + let data_type = value.as_ref().infer_data_type(); + Ok(Box::new((value, data_type))) + } + + // Fast path for constant arrays so we don't need to go through the scalar `array()` function + // (which performs full type-checking and constant-folding). Non-constant elements still use + // the generic resolver to preserve the previous behaviour. fn resolve_array(&mut self, span: Span, exprs: &[Expr]) -> Result> { let mut elems = Vec::with_capacity(exprs.len()); + let mut constant_values: Option> = + Some(Vec::with_capacity(exprs.len())); + let mut element_type: Option = None; + for expr in exprs { - let box (arg, _data_type) = self.resolve(expr)?; + let box (arg, data_type) = self.resolve(expr)?; + if let Some(values) = constant_values.as_mut() { + let maybe_constant = match &arg { + ScalarExpr::ConstantExpr(constant) => Some(constant.value.clone()), + ScalarExpr::TypedConstantExpr(constant, _) => Some(constant.value.clone()), + _ => None, + }; + if let Some(value) = maybe_constant { + element_type = if let Some(current_ty) = element_type.clone() { + common_super_type( + current_ty.clone(), + data_type.clone(), + &BUILTIN_FUNCTIONS.default_cast_rules, + ) + } else { + Some(data_type.clone()) + }; + + if element_type.is_some() { + values.push((value, data_type)); + } else { + constant_values = None; + element_type = None; + } + } else { + constant_values = None; + element_type = None; + } + } elems.push(arg); } + if let (Some(values), Some(element_ty)) = (constant_values, element_type) { + let mut casted = Vec::with_capacity(values.len()); + for (value, ty) in values { + if ty == element_ty { + casted.push(value); + } else { + casted.push(cast_scalar(span, value, &element_ty, &BUILTIN_FUNCTIONS)?); + } + } + return Ok(Self::build_constant_array(span, element_ty, casted)); + } + self.resolve_scalar_function_call(span, "array", vec![], elems) } + fn build_constant_array( + span: Span, + element_ty: DataType, + values: Vec, + ) -> Box<(ScalarExpr, DataType)> { + let mut builder = ColumnBuilder::with_capacity(&element_ty, values.len()); + for value in &values { + builder.push(value.as_ref()); + } + let scalar = Scalar::Array(builder.build()); + Box::new(( + ScalarExpr::ConstantExpr(ConstantExpr { + span, + value: scalar, + }), + DataType::Array(Box::new(element_ty)), + )) + } + fn resolve_map( &mut self, span: Span,