Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1472,7 +1472,13 @@ unary_scalar_expr!(Atan, atan);
unary_scalar_expr!(Floor, floor);
unary_scalar_expr!(Ceil, ceil);
unary_scalar_expr!(Now, now);
unary_scalar_expr!(Round, round);
/// Builds a `Round` scalar-function expression over `args`.
///
/// The first argument is the value to round to the nearest integer; an optional
/// second argument gives the number of digits to round to (0 when omitted).
/// `args` is forwarded as-is; no arity check is performed here.
pub fn round(args: Vec<Expr>) -> Expr {
    let fun = functions::BuiltinScalarFunction::Round;
    Expr::ScalarFunction { fun, args }
}
unary_scalar_expr!(Trunc, trunc);
unary_scalar_expr!(Abs, abs);
unary_scalar_expr!(Signum, signum);
Expand Down Expand Up @@ -2050,6 +2056,18 @@ mod tests {
}};
}

// Checks that an n-ary expression builder (one taking `Vec<Expr>`) produces an
// `Expr::ScalarFunction` for the expected builtin and keeps every argument.
//
// * `test_nary_scalar_expr!(Enum, func)` calls `func` with two arguments.
// * `test_nary_scalar_expr!(Enum, func, arg1, arg2, ...)` calls `func` with the
//   listed arguments and expects exactly that many back.
macro_rules! test_nary_scalar_expr {
    ($ENUM:ident, $FUNC:ident) => {{
        // The short form asserts two arguments, so it must also pass two.
        test_nary_scalar_expr!($ENUM, $FUNC, col("tableA.a"), lit(2))
    }};
    ($ENUM:ident, $FUNC:ident, $($ARG:expr),+ $(,)?) => {{
        let input: Vec<Expr> = vec![$($ARG),+];
        let expected_len = input.len();
        if let Expr::ScalarFunction { fun, args } = $FUNC(input) {
            let name = functions::BuiltinScalarFunction::$ENUM;
            assert_eq!(name, fun);
            assert_eq!(expected_len, args.len());
        } else {
            panic!("unexpected");
        }
    }};
}

#[test]
fn scalar_function_definitions() {
test_unary_scalar_expr!(Sqrt, sqrt);
Expand All @@ -2062,7 +2080,6 @@ mod tests {
test_unary_scalar_expr!(Floor, floor);
test_unary_scalar_expr!(Ceil, ceil);
test_unary_scalar_expr!(Now, now);
test_unary_scalar_expr!(Round, round);
test_unary_scalar_expr!(Trunc, trunc);
test_unary_scalar_expr!(Abs, abs);
test_unary_scalar_expr!(Signum, signum);
Expand Down Expand Up @@ -2104,4 +2121,25 @@ mod tests {
test_unary_scalar_expr!(Trim, trim);
test_unary_scalar_expr!(Upper, upper);
}

#[test]
fn test_round_definition() {
    // `round` must wrap its arguments in a `Round` scalar function without
    // dropping any of them: first the 1-argument form, then the 2-argument form.
    let cases = vec![
        vec![col("tableA.a")],
        vec![col("tableA.a"), lit(2)],
    ];

    for input in cases {
        let expected_len = input.len();
        match round(input) {
            Expr::ScalarFunction { fun, args } => {
                assert_eq!(functions::BuiltinScalarFunction::Round, fun);
                assert_eq!(expected_len, args.len());
            }
            _ => panic!("unexpected"),
        }
    }
}
}
9 changes: 8 additions & 1 deletion datafusion/src/physical_plan/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10),
BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2),
BuiltinScalarFunction::Random => Arc::new(math_expressions::random),
BuiltinScalarFunction::Round => Arc::new(math_expressions::round),
BuiltinScalarFunction::Round => {
Arc::new(|args| make_scalar_function(math_expressions::round)(args))
}
BuiltinScalarFunction::Signum => Arc::new(math_expressions::signum),
BuiltinScalarFunction::Sin => Arc::new(math_expressions::sin),
BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt),
Expand Down Expand Up @@ -1279,6 +1281,11 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]),
Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]),
]),
BuiltinScalarFunction::Round => Signature::OneOf(vec![
Signature::Uniform(1, vec![DataType::Float64, DataType::Float32]),
Signature::Exact(vec![DataType::Float64, DataType::Int64]),
Signature::Exact(vec![DataType::Float32, DataType::Int64]),
]),
BuiltinScalarFunction::Random => Signature::Exact(vec![]),
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
Expand Down
126 changes: 123 additions & 3 deletions datafusion/src/physical_plan/math_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
//! Math expressions
use super::{ColumnarValue, ScalarValue};
use crate::error::{DataFusionError, Result};
use arrow::array::{Float32Array, Float64Array};
use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
use arrow::datatypes::DataType;
use rand::{thread_rng, Rng};
use std::any::type_name;
use std::iter;
use std::sync::Arc;

Expand Down Expand Up @@ -84,6 +85,33 @@ macro_rules! math_unary_function {
};
}

// Downcasts the array `$ARG` to the concrete `$ARRAY_TYPE`.
//
// Evaluates to a reference to the typed array. If the downcast fails, the
// enclosing function returns early (via `?`) with a
// `DataFusionError::Internal` that names the argument (`$NAME`) and the
// requested array type.
macro_rules! downcast_arg {
    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
        match $ARG.as_any().downcast_ref::<$ARRAY_TYPE>() {
            Some(typed) => typed,
            None => Err(DataFusionError::Internal(format!(
                "could not cast {} to {}",
                $NAME,
                type_name::<$ARRAY_TYPE>()
            )))?,
        }
    }};
}

// Applies the binary closure `$FUNC` element-wise over two arrays.
//
// `$ARG1` / `$ARG2` are downcast to `$ARRAY_TYPE1` / `$ARRAY_TYPE2` (an early
// error return happens if either downcast fails). The result is collected into
// a new `$ARRAY_TYPE1`. An output slot is null when either input slot is null,
// or when the second value cannot be converted (`try_into`) to the type of the
// closure's second parameter.
macro_rules! make_function_inputs2 {
    ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{
        let lhs = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1);
        let rhs = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2);

        lhs.iter()
            .zip(rhs.iter())
            .map(|pair| {
                // Both slots must be non-null to produce a value.
                let (left, right) = pair.0.zip(pair.1)?;
                Some($FUNC(left, right.try_into().ok()?))
            })
            .collect::<$ARRAY_TYPE1>()
    }};
}

math_unary_function!("sqrt", sqrt);
math_unary_function!("sin", sin);
math_unary_function!("cos", cos);
Expand All @@ -93,7 +121,6 @@ math_unary_function!("acos", acos);
math_unary_function!("atan", atan);
math_unary_function!("floor", floor);
math_unary_function!("ceil", ceil);
math_unary_function!("round", round);
math_unary_function!("trunc", trunc);
math_unary_function!("abs", abs);
math_unary_function!("signum", signum);
Expand All @@ -118,11 +145,64 @@ pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
Ok(ColumnarValue::Array(Arc::new(array)))
}

/// Round SQL function.
///
/// Rounds each value of `args[0]` (a `Float64` or `Float32` array) to the
/// number of decimal places given element-wise by the optional `Int64` array
/// `args[1]`. When only one argument is supplied, zero decimal places are used.
/// Negative decimal places round to the left of the decimal point.
///
/// An output slot is null when either input slot is null, or when its
/// decimal-places value does not fit in an `i32` (previously this panicked).
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when the number of arguments is not one
/// or two, when `args[0]` is neither `Float64` nor `Float32`, or when an
/// argument cannot be downcast to the expected array type.
pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
    let (values, places) = match args {
        [values] => (values, None),
        [values, places] => (values, Some(places)),
        _ => {
            return Err(DataFusionError::Internal(format!(
                "round function requires one or two arguments, got {}",
                args.len()
            )))
        }
    };

    // Only materialise the all-zero default when `decimal_places` was omitted.
    let default_places: ArrayRef;
    let decimal_places = match places {
        Some(places) => places,
        None => {
            default_places =
                Arc::new(Int64Array::from_value(0, values.len())) as ArrayRef;
            &default_places
        }
    };

    // The closures take `decimal_places` as `i32` (what `powi` expects), so the
    // fallible i64 -> i32 conversion happens inside `make_function_inputs2!`,
    // which turns an out-of-range value into a null instead of a panic.
    match values.data_type() {
        DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
            values,
            decimal_places,
            "value",
            "decimal_places",
            Float64Array,
            Int64Array,
            {
                |value: f64, decimal_places: i32| {
                    let factor = 10.0_f64.powi(decimal_places);
                    (value * factor).round() / factor
                }
            }
        )) as ArrayRef),

        DataType::Float32 => Ok(Arc::new(make_function_inputs2!(
            values,
            decimal_places,
            "value",
            "decimal_places",
            Float32Array,
            Int64Array,
            {
                |value: f32, decimal_places: i32| {
                    let factor = 10.0_f32.powi(decimal_places);
                    (value * factor).round() / factor
                }
            }
        )) as ArrayRef),

        other => Err(DataFusionError::Internal(format!(
            "Unsupported data type {other:?} for function round"
        ))),
    }
}

#[cfg(test)]
mod tests {

use super::*;
use arrow::array::{Float64Array, NullArray};
use arrow::array::{Float32Array, Float64Array, NullArray};

#[test]
fn test_random_expression() {
Expand All @@ -133,4 +213,44 @@ mod tests {
assert_eq!(floats.len(), 1);
assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0);
}

#[test]
fn test_round_f32() {
    // The same input value is rounded with ten different decimal-place settings,
    // including negative ones (rounding to the left of the decimal point).
    let input = Float32Array::from(vec![125.2345; 10]);
    let decimal_places = Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]);
    let args: Vec<ArrayRef> = vec![Arc::new(input), Arc::new(decimal_places)];

    let rounded = round(&args).expect("failed to initialize function round");
    let actual = rounded
        .as_any()
        .downcast_ref::<Float32Array>()
        .expect("failed to initialize function round");

    let expected = Float32Array::from(vec![
        125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
    ]);

    assert_eq!(actual, &expected);
}

#[test]
fn test_round_f64() {
    // The same input value is rounded with ten different decimal-place settings,
    // including negative ones (rounding to the left of the decimal point).
    let input = Float64Array::from(vec![125.2345; 10]);
    let decimal_places = Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4]);
    let args: Vec<ArrayRef> = vec![Arc::new(input), Arc::new(decimal_places)];

    let rounded = round(&args).expect("failed to initialize function round");
    let actual = rounded
        .as_any()
        .downcast_ref::<Float64Array>()
        .expect("failed to initialize function round");

    let expected = Float64Array::from(vec![
        125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
    ]);

    assert_eq!(actual, &expected);
}
}
5 changes: 4 additions & 1 deletion datafusion/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,10 @@ impl DefaultPhysicalPlanner {
extension_planners.insert(1, Arc::new(CrossJoinPlanner {}));
extension_planners.insert(2, Arc::new(CrossJoinAggPlanner {}));
extension_planners.insert(3, Arc::new(crate::cube_ext::rolling::Planner {}));
Self { should_evaluate_constants: true, extension_planners }
Self {
should_evaluate_constants: true,
extension_planners,
}
}

/// Create a physical plan from a logical plan
Expand Down
Loading