From 508b997ef0cefaa115b3ca31e6481b3df34fab1a Mon Sep 17 00:00:00 2001 From: Dmitry Patsura Date: Thu, 13 Nov 2025 17:00:52 +0100 Subject: [PATCH] feat: Support round() function with two parameters it's based on https://github.com/cube-js/arrow-datafusion/commit/771c20ce2f0ade29a2d334e4e8494e9fbd7a5940 --- datafusion/src/logical_plan/expr.rs | 42 +++++- datafusion/src/physical_plan/functions.rs | 9 +- .../src/physical_plan/math_expressions.rs | 126 +++++++++++++++++- datafusion/src/physical_plan/planner.rs | 5 +- 4 files changed, 175 insertions(+), 7 deletions(-) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index aaf3997ad4f6..d859d4573665 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1472,7 +1472,13 @@ unary_scalar_expr!(Atan, atan); unary_scalar_expr!(Floor, floor); unary_scalar_expr!(Ceil, ceil); unary_scalar_expr!(Now, now); -unary_scalar_expr!(Round, round); +/// Returns the nearest integer value to the expression. Digits defaults to 0 if not provided. +pub fn round(args: Vec) -> Expr { + Expr::ScalarFunction { + fun: functions::BuiltinScalarFunction::Round, + args, + } +} unary_scalar_expr!(Trunc, trunc); unary_scalar_expr!(Abs, abs); unary_scalar_expr!(Signum, signum); @@ -2050,6 +2056,18 @@ mod tests { }}; } + macro_rules! test_nary_scalar_expr { + ($ENUM:ident, $FUNC:ident) => {{ + if let Expr::ScalarFunction { fun, args } = $FUNC(col("tableA.a")) { + let name = functions::BuiltinScalarFunction::$ENUM; + assert_eq!(name, fun); + assert_eq!(2, args.len()); + } else { + assert!(false, "unexpected"); + } + }}; + } + #[test] fn scalar_function_definitions() { test_unary_scalar_expr!(Sqrt, sqrt); @@ -2062,7 +2080,6 @@ mod tests { test_unary_scalar_expr!(Floor, floor); test_unary_scalar_expr!(Ceil, ceil); test_unary_scalar_expr!(Now, now); - test_unary_scalar_expr!(Round, round); test_unary_scalar_expr!(Trunc, trunc); test_unary_scalar_expr!(Abs, abs); test_unary_scalar_expr!(Signum, signum); @@ -2104,4 +2121,25 @@ mod tests { test_unary_scalar_expr!(Trim, trim); test_unary_scalar_expr!(Upper, upper); } + + #[test] + fn test_round_definition() { + // test round with 1 argument + if let Expr::ScalarFunction { fun, args } = round(vec![col("tableA.a")]) { + let name = functions::BuiltinScalarFunction::Round; + assert_eq!(name, fun); + assert_eq!(1, args.len()); + } else { + assert!(false, "unexpected"); + } + + // test round with 2 arguments + if let Expr::ScalarFunction { fun, args } = round(vec![col("tableA.a"), lit(2)]) { + let name = functions::BuiltinScalarFunction::Round; + assert_eq!(name, fun); + assert_eq!(2, args.len()); + } else { + assert!(false, "unexpected"); + } + } } diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 2170881af9c3..1bcc064ea047 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -571,7 +571,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10), BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2), BuiltinScalarFunction::Random => Arc::new(math_expressions::random), - BuiltinScalarFunction::Round => Arc::new(math_expressions::round), + BuiltinScalarFunction::Round => { + Arc::new(|args| make_scalar_function(math_expressions::round)(args)) + } BuiltinScalarFunction::Signum => Arc::new(math_expressions::signum), BuiltinScalarFunction::Sin => Arc::new(math_expressions::sin), BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt), @@ -1279,6 +1281,11 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]), Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]), ]), + BuiltinScalarFunction::Round => Signature::OneOf(vec![ + Signature::Uniform(1, vec![DataType::Float64, DataType::Float32]), + Signature::Exact(vec![DataType::Float64, DataType::Int64]), + Signature::Exact(vec![DataType::Float32, DataType::Int64]), + ]), BuiltinScalarFunction::Random => Signature::Exact(vec![]), // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we diff --git a/datafusion/src/physical_plan/math_expressions.rs b/datafusion/src/physical_plan/math_expressions.rs index eabacfc6eb18..9918563a6269 100644 --- a/datafusion/src/physical_plan/math_expressions.rs +++ b/datafusion/src/physical_plan/math_expressions.rs @@ -18,9 +18,10 @@ //! Math expressions use super::{ColumnarValue, ScalarValue}; use crate::error::{DataFusionError, Result}; -use arrow::array::{Float32Array, Float64Array}; +use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; use arrow::datatypes::DataType; use rand::{thread_rng, Rng}; +use std::any::type_name; use std::iter; use std::sync::Arc; @@ -84,6 +85,33 @@ macro_rules! math_unary_function { }; } +macro_rules! downcast_arg { + ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + $NAME, + type_name::<$ARRAY_TYPE>() + )) + })? + }}; +} + +macro_rules! make_function_inputs2 { + ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{ + let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1); + let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2); + + arg1.iter() + .zip(arg2.iter()) + .map(|(a1, a2)| match (a1, a2) { + (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), + _ => None, + }) + .collect::<$ARRAY_TYPE1>() + }}; +} + math_unary_function!("sqrt", sqrt); math_unary_function!("sin", sin); math_unary_function!("cos", cos); @@ -93,7 +121,6 @@ math_unary_function!("acos", acos); math_unary_function!("atan", atan); math_unary_function!("floor", floor); math_unary_function!("ceil", ceil); -math_unary_function!("round", round); math_unary_function!("trunc", trunc); math_unary_function!("abs", abs); math_unary_function!("signum", signum); @@ -118,11 +145,64 @@ pub fn random(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(array))) } +/// Round SQL function +pub fn round(args: &[ArrayRef]) -> Result { + if args.len() != 1 && args.len() != 2 { + return Err(DataFusionError::Internal(format!( + "round function requires one or two arguments, got {}", + args.len() + ))); + } + + let mut decimal_places = + &(Arc::new(Int64Array::from_value(0, args[0].len())) as ArrayRef); + + if args.len() == 2 { + decimal_places = &args[1]; + } + + match args[0].data_type() { + DataType::Float64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + decimal_places, + "value", + "decimal_places", + Float64Array, + Int64Array, + { + |value: f64, decimal_places: i64| { + (value * 10.0_f64.powi(decimal_places.try_into().unwrap())).round() + / 10.0_f64.powi(decimal_places.try_into().unwrap()) + } + } + )) as ArrayRef), + + DataType::Float32 => Ok(Arc::new(make_function_inputs2!( + &args[0], + decimal_places, + "value", + "decimal_places", + Float32Array, + Int64Array, + { + |value: f32, decimal_places: i64| { + (value * 10.0_f32.powi(decimal_places.try_into().unwrap())).round() + / 10.0_f32.powi(decimal_places.try_into().unwrap()) + } + } + )) as ArrayRef), + + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function round" + ))), + } +} + #[cfg(test)] mod tests { use super::*; - use arrow::array::{Float64Array, NullArray}; + use arrow::array::{Float32Array, Float64Array, NullArray}; #[test] fn test_random_expression() { @@ -133,4 +213,44 @@ mod tests { assert_eq!(floats.len(), 1); assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0); } + + #[test] + fn test_round_f32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![125.2345; 10])), // input + Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places + ]; + + let result = round(&args).expect("failed to initialize function round"); + let floats = result + .as_any() + .downcast_ref::() + .expect("failed to initialize function round"); + + let expected = Float32Array::from(vec![ + 125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0, + ]); + + assert_eq!(floats, &expected); + } + + #[test] + fn test_round_f64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![125.2345; 10])), // input + Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places + ]; + + let result = round(&args).expect("failed to initialize function round"); + let floats = result + .as_any() + .downcast_ref::() + .expect("failed to initialize function round"); + + let expected = Float64Array::from(vec![ + 125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0, + ]); + + assert_eq!(floats, &expected); + } } diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index fe163aa02643..f4f3c7e9924e 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -345,7 +345,10 @@ impl DefaultPhysicalPlanner { extension_planners.insert(1, Arc::new(CrossJoinPlanner {})); extension_planners.insert(2, Arc::new(CrossJoinAggPlanner {})); extension_planners.insert(3, Arc::new(crate::cube_ext::rolling::Planner {})); - Self { should_evaluate_constants: true, extension_planners } + Self { + should_evaluate_constants: true, + extension_planners, + } } /// Create a physical plan from a logical plan