Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/query/expression/src/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ pub fn check<Index: ColumnIndex>(

if dest_ty.is_integer() && src_ty.is_integer() {
if let Ok(casted_scalar) =
cast_scalar(*span, scalar.clone(), dest_ty, fn_registry)
cast_scalar(*span, scalar.clone(), &dest_ty, fn_registry)
{
*scalar = casted_scalar;
*data_type = scalar.as_ref().infer_data_type();
Expand Down
96 changes: 94 additions & 2 deletions src/query/expression/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub mod visitor;

use databend_common_ast::Span;
use databend_common_column::bitmap::Bitmap;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;

pub use self::column_from::*;
Expand All @@ -36,8 +37,12 @@ use crate::types::AnyType;
use crate::types::DataType;
use crate::types::Decimal;
use crate::types::DecimalDataKind;
use crate::types::DecimalDataType;
use crate::types::DecimalSize;
use crate::types::NumberDataType;
use crate::types::NumberScalar;
use crate::types::F32;
use crate::types::F64;
use crate::BlockEntry;
use crate::Column;
use crate::DataBlock;
Expand Down Expand Up @@ -87,9 +92,13 @@ pub fn eval_function(
pub fn cast_scalar(
span: Span,
scalar: Scalar,
dest_type: DataType,
dest_type: &DataType,
fn_registry: &FunctionRegistry,
) -> Result<Scalar> {
if let Some(result) = try_fast_cast_scalar(&scalar, dest_type) {
return result;
}

let raw_expr = RawExpr::Cast {
span,
is_try: false,
Expand All @@ -98,7 +107,7 @@ pub fn cast_scalar(
scalar,
data_type: None,
}),
dest_type,
dest_type: dest_type.clone(),
};
let expr = crate::type_check::check(&raw_expr, fn_registry)?;
let block = DataBlock::empty();
Expand All @@ -107,6 +116,89 @@ pub fn cast_scalar(
Ok(evaluator.run(&expr)?.into_scalar().unwrap())
}

fn try_fast_cast_scalar(scalar: &Scalar, dest_type: &DataType) -> Option<Result<Scalar>> {
match dest_type {
DataType::Null => Some(Ok(Scalar::Null)),
DataType::Nullable(inner) => {
if matches!(scalar, Scalar::Null) {
Some(Ok(Scalar::Null))
} else {
try_fast_cast_scalar(scalar, inner)
}
}
DataType::Number(NumberDataType::Float32) => match scalar {
Scalar::Null => Some(Ok(Scalar::Null)),
Scalar::Number(num) => Some(Ok(Scalar::Number(NumberScalar::Float32(num.to_f32())))),
Scalar::Decimal(dec) => Some(Ok(Scalar::Number(NumberScalar::Float32(F32::from(
dec.to_float32(),
))))),
_ => None,
},
DataType::Number(NumberDataType::Float64) => match scalar {
Scalar::Null => Some(Ok(Scalar::Null)),
Scalar::Number(num) => Some(Ok(Scalar::Number(NumberScalar::Float64(num.to_f64())))),
Scalar::Decimal(dec) => Some(Ok(Scalar::Number(NumberScalar::Float64(F64::from(
dec.to_float64(),
))))),
_ => None,
},
DataType::Decimal(size) => match scalar {
Scalar::Null => Some(Ok(Scalar::Null)),
Scalar::Decimal(dec) => Some(rescale_decimal_scalar(*dec, *size)),
_ => None,
},
_ => None,
}
}

fn rescale_decimal_scalar(decimal: DecimalScalar, target_size: DecimalSize) -> Result<Scalar> {
let from_size = decimal.size();
if from_size == target_size {
return Ok(Scalar::Decimal(decimal));
}

let source_scale = from_size.scale();
let target_scale = target_size.scale();
let data_type: DecimalDataType = target_size.into();

let scaled = match data_type {
DecimalDataType::Decimal64(_) => {
let value = decimal.as_decimal::<i64>();
let adjusted = rescale_decimal_value(value, source_scale, target_scale)?;
Scalar::Decimal(DecimalScalar::Decimal64(adjusted, target_size))
}
DecimalDataType::Decimal128(_) => {
let value = decimal.as_decimal::<i128>();
let adjusted = rescale_decimal_value(value, source_scale, target_scale)?;
Scalar::Decimal(DecimalScalar::Decimal128(adjusted, target_size))
}
DecimalDataType::Decimal256(_) => {
let value = decimal.as_decimal::<i256>();
let adjusted = rescale_decimal_value(value, source_scale, target_scale)?;
Scalar::Decimal(DecimalScalar::Decimal256(adjusted, target_size))
}
};

Ok(scaled)
}

fn rescale_decimal_value<T: Decimal>(value: T, source_scale: u8, target_scale: u8) -> Result<T> {
if source_scale == target_scale {
return Ok(value);
}

let diff = target_scale.abs_diff(source_scale);
if target_scale > source_scale {
value.checked_mul(T::e(diff)).ok_or_else(|| {
ErrorCode::Overflow("Decimal literal overflow after scale expansion".to_string())
})
} else {
value.checked_div(T::e(diff)).ok_or_else(|| {
ErrorCode::Overflow("Decimal literal overflow after scale reduction".to_string())
})
}
}

pub fn column_merge_validity(entry: &BlockEntry, bitmap: Option<Bitmap>) -> Option<Bitmap> {
match entry {
BlockEntry::Const(scalar, data_type, n) => {
Expand Down
7 changes: 6 additions & 1 deletion src/query/functions/src/scalars/decimal/src/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,12 @@ fn op_decimal<Op: CmpOp>(
T::e(size_calc.scale() - a_type.scale()),
T::e(size_calc.scale() - b_type.scale()),
);
compare_decimal(a, b, |a, b, _| Op::compare(a, b, f_a, f_b), ctx)

if (f_a == f_b) {
compare_decimal(a, b, |a, b, _| Op::is(a.cmp(&b)), ctx)
} else {
compare_decimal(a, b, |a, b, _| Op::compare(a, b, f_a, f_b), ctx)
}
}
})
}
Expand Down
2 changes: 1 addition & 1 deletion src/query/service/src/interpreters/interpreter_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl SetInterpreter {
async fn execute_settings(&self, scalars: Vec<Scalar>, is_global: bool) -> Result<()> {
let scalars: Vec<Scalar> = scalars
.into_iter()
.map(|scalar| cast_scalar(None, scalar.clone(), DataType::String, &BUILTIN_FUNCTIONS))
.map(|scalar| cast_scalar(None, scalar.clone(), &DataType::String, &BUILTIN_FUNCTIONS))
.collect::<Result<Vec<_>>>()?;

let mut keys: Vec<String> = vec![];
Expand Down
2 changes: 1 addition & 1 deletion src/query/sql/src/planner/binder/statement_settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl Binder {
let scalar = cast_scalar(
None,
scalar.clone(),
DataType::String,
&DataType::String,
&BUILTIN_FUNCTIONS,
)?;
results.push(scalar);
Expand Down
121 changes: 118 additions & 3 deletions src/query/sql/src/planner/semantic/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ use databend_common_compress::CompressAlgorithm;
use databend_common_compress::DecompressDecoder;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::cast_scalar;
use databend_common_expression::display::display_tuple_field_name;
use databend_common_expression::expr;
use databend_common_expression::infer_schema_type;
use databend_common_expression::shrink_scalar;
use databend_common_expression::type_check;
use databend_common_expression::type_check::check_number;
use databend_common_expression::type_check::common_super_type;
use databend_common_expression::type_check::convert_escape_pattern;
use databend_common_expression::types::decimal::DecimalScalar;
use databend_common_expression::types::decimal::DecimalSize;
Expand All @@ -81,6 +83,7 @@ use databend_common_expression::types::F32;
use databend_common_expression::udf_client::UDFFlightClient;
use databend_common_expression::BlockEntry;
use databend_common_expression::Column;
use databend_common_expression::ColumnBuilder;
use databend_common_expression::ColumnIndex;
use databend_common_expression::Constant;
use databend_common_expression::ConstantFolder;
Expand Down Expand Up @@ -3409,6 +3412,15 @@ impl<'a> TypeChecker<'a> {
// Omit unary + operator
self.resolve(child)
}
UnaryOperator::Minus => {
if let Expr::Literal { value, .. } = child {
let box (value, data_type) = self.resolve_minus_literal_scalar(span, value)?;
let scalar_expr = ScalarExpr::ConstantExpr(ConstantExpr { span, value });
return Ok(Box::new((scalar_expr, data_type)));
}
let name = op.to_func_name();
self.resolve_function(span, name.as_str(), vec![], &[child])
}
other => {
let name = other.to_func_name();
self.resolve_function(span, name.as_str(), vec![], &[child])
Expand Down Expand Up @@ -4821,18 +4833,121 @@ impl<'a> TypeChecker<'a> {
Ok(Box::new((value, data_type)))
}

// TODO(leiysky): use an array builder function instead, since we should allow declaring
// an array with variable as element.
pub fn resolve_minus_literal_scalar(
&self,
span: Span,
literal: &databend_common_ast::ast::Literal,
) -> Result<Box<(Scalar, DataType)>> {
let value = match literal {
Literal::UInt64(v) => {
if *v <= i64::MAX as u64 {
Scalar::Number(NumberScalar::Int64(-(*v as i64)))
} else {
Scalar::Decimal(DecimalScalar::Decimal128(
-(*v as i128),
DecimalSize::new_unchecked(i128::MAX_PRECISION, 0),
))
}
}
Literal::Decimal256 {
value,
precision,
scale,
} => Scalar::Decimal(DecimalScalar::Decimal256(
i256(*value).checked_mul(i256::minus_one()).unwrap(),
DecimalSize::new_unchecked(*precision, *scale),
)),
Literal::Float64(v) => Scalar::Number(NumberScalar::Float64((-*v).into())),
Literal::Null => Scalar::Null,
Literal::String(_) | Literal::Boolean(_) => {
return Err(ErrorCode::InvalidArgument(format!(
"Invalid minus operator for {}",
literal
))
.set_span(span));
}
};
let value = shrink_scalar(value);
let data_type = value.as_ref().infer_data_type();
Ok(Box::new((value, data_type)))
}

// Fast path for constant arrays so we don't need to go through the scalar `array()` function
// (which performs full type-checking and constant-folding). Non-constant elements still use
// the generic resolver to preserve the previous behaviour.
fn resolve_array(&mut self, span: Span, exprs: &[Expr]) -> Result<Box<(ScalarExpr, DataType)>> {
let mut elems = Vec::with_capacity(exprs.len());
let mut constant_values: Option<Vec<(Scalar, DataType)>> =
Some(Vec::with_capacity(exprs.len()));
let mut element_type: Option<DataType> = None;

for expr in exprs {
let box (arg, _data_type) = self.resolve(expr)?;
let box (arg, data_type) = self.resolve(expr)?;
if let Some(values) = constant_values.as_mut() {
let maybe_constant = match &arg {
ScalarExpr::ConstantExpr(constant) => Some(constant.value.clone()),
ScalarExpr::TypedConstantExpr(constant, _) => Some(constant.value.clone()),
_ => None,
};
if let Some(value) = maybe_constant {
element_type = if let Some(current_ty) = element_type.clone() {
common_super_type(
current_ty.clone(),
data_type.clone(),
&BUILTIN_FUNCTIONS.default_cast_rules,
)
} else {
Some(data_type.clone())
};

if element_type.is_some() {
values.push((value, data_type));
} else {
constant_values = None;
element_type = None;
}
} else {
constant_values = None;
element_type = None;
}
}
elems.push(arg);
}

if let (Some(values), Some(element_ty)) = (constant_values, element_type) {
let mut casted = Vec::with_capacity(values.len());
for (value, ty) in values {
if ty == element_ty {
casted.push(value);
} else {
casted.push(cast_scalar(span, value, &element_ty, &BUILTIN_FUNCTIONS)?);
}
}
return Ok(Self::build_constant_array(span, element_ty, casted));
}

self.resolve_scalar_function_call(span, "array", vec![], elems)
}

fn build_constant_array(
span: Span,
element_ty: DataType,
values: Vec<Scalar>,
) -> Box<(ScalarExpr, DataType)> {
let mut builder = ColumnBuilder::with_capacity(&element_ty, values.len());
for value in &values {
builder.push(value.as_ref());
}
let scalar = Scalar::Array(builder.build());
Box::new((
ScalarExpr::ConstantExpr(ConstantExpr {
span,
value: scalar,
}),
DataType::Array(Box::new(element_ty)),
))
}

fn resolve_map(
&mut self,
span: Span,
Expand Down
Loading