Skip to content

Commit

Permalink
feat(cubesql): Support type coercion for IF function
Browse files Browse the repository at this point in the history
  • Loading branch information
ovr committed Dec 8, 2021
1 parent 434084e commit 3b3f48c
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 6 deletions.
53 changes: 53 additions & 0 deletions rust/cubesql/src/compile/engine/df/coerce.rs
@@ -0,0 +1,53 @@
use datafusion::arrow::datatypes::DataType;

pub fn is_signed_numeric(dt: &DataType) -> bool {
matches!(
dt,
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float16
| DataType::Float32
| DataType::Float64
)
}

/// Reports whether `dt` is numeric: either a signed numeric type (see
/// `is_signed_numeric`) or one of the unsigned integers `UInt8`..`UInt64`.
pub fn is_numeric(dt: &DataType) -> bool {
    let is_unsigned_int = matches!(
        dt,
        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64
    );

    // Same short-circuit order as before: the signed check runs first.
    is_signed_numeric(dt) || is_unsigned_int
}

/// Picks the common type that two numeric operands should be cast to.
///
/// Returns `None` when either side is not numeric (per `is_numeric`) or when
/// none of the rules below covers the pair. Identical types are returned
/// unchanged.
///
/// Rules are tried top to bottom, floats first: if a float operand were
/// coerced to an integer type, the cast would silently drop its fractional
/// part (e.g. `Float64` paired with `Int64` used to resolve to `Int64`).
pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    // error on any non-numeric type
    if !is_numeric(lhs_type) || !is_numeric(rhs_type) {
        return None;
    }

    // same type => all good
    if lhs_type == rhs_type {
        return Some(lhs_type.clone());
    }

    match (lhs_type, rhs_type) {
        // Floating point wins over every integer type and over narrower floats.
        (DataType::Float64, _) | (_, DataType::Float64) => Some(DataType::Float64),
        (DataType::Float32, _) | (_, DataType::Float32) => Some(DataType::Float32),
        // 64-bit integers absorb any remaining numeric type. UInt64 keeps its
        // original precedence over Int64.
        // NOTE(review): a signed operand paired with UInt64 still resolves to
        // UInt64, which cannot represent negative values — confirm this is intended.
        (DataType::UInt64, _) | (_, DataType::UInt64) => Some(DataType::UInt64),
        (DataType::Int64, _) | (_, DataType::Int64) => Some(DataType::Int64),
        // Pairs of narrower types (e.g. Int32 with Int16) have no rule yet.
        _ => None,
    }
}

/// Resolves the result type of `IF(cond, lhs, rhs)` from its two branch types.
///
/// Equal types (numeric or not) are returned as-is; differing types are
/// delegated to `numerical_coercion`, which yields `None` when no common
/// type exists.
pub fn if_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
    if lhs_type != rhs_type {
        // Only numeric branches can be reconciled when the types differ.
        return numerical_coercion(lhs_type, rhs_type);
    }

    Some(lhs_type.clone())
}
1 change: 1 addition & 0 deletions rust/cubesql/src/compile/engine/df/mod.rs
@@ -0,0 +1 @@
pub mod coerce;
1 change: 1 addition & 0 deletions rust/cubesql/src/compile/engine/mod.rs
@@ -1,4 +1,5 @@
pub mod context;
pub mod df;
pub mod information_schema;
pub mod provider;
pub mod udf;
23 changes: 17 additions & 6 deletions rust/cubesql/src/compile/engine/udf.rs
Expand Up @@ -7,6 +7,7 @@ use datafusion::{
ArrayRef, BooleanArray, BooleanBuilder, GenericStringArray, Int32Builder,
PrimitiveArray, StringBuilder, UInt32Builder,
},
compute::cast,
datatypes::{DataType, Int64Type},
},
error::DataFusionError,
Expand All @@ -17,7 +18,7 @@ use datafusion::{
},
};

use crate::compile::QueryPlannerExecutionProps;
use crate::compile::{engine::df::coerce::if_coercion, QueryPlannerExecutionProps};

pub fn create_version_udf() -> ScalarUDF {
let version = make_scalar_function(|_args: &[ArrayRef]| {
Expand Down Expand Up @@ -226,13 +227,16 @@ pub fn create_if_udf() -> ScalarUDF {
let left = &args[1];
let right = &args[2];

if left.data_type() != right.data_type() {
return Err(DataFusionError::Execution(format!(
let base_type = if_coercion(left.data_type(), right.data_type()).ok_or_else(|| {
DataFusionError::Execution(format!(
"positive and negative results must be the same type, actual: [{}, {}]",
left.data_type(),
right.data_type(),
)));
}
))
})?;

let left = cast(&left, &base_type)?;
let right = cast(&right, &base_type)?;

let is_true: bool = match condition.data_type() {
DataType::Boolean => {
Expand All @@ -254,7 +258,14 @@ pub fn create_if_udf() -> ScalarUDF {
let return_type: ReturnTypeFunction = Arc::new(move |types| {
assert!(types.len() == 3);

Ok(Arc::new(types[1].clone()))
let base_type = if_coercion(&types[1], &types[2]).ok_or_else(|| {
DataFusionError::Execution(format!(
"positive and negative results must be the same type, actual: [{}, {}]",
&types[1], &types[2],
))
})?;

Ok(Arc::new(base_type))
});

ScalarUDF::new(
Expand Down
54 changes: 54 additions & 0 deletions rust/cubesql/src/compile/mod.rs
Expand Up @@ -3169,6 +3169,31 @@ mod tests {
Ok(())
}

// Snapshot test for the IF() function: runs one SELECT through
// `execute_df_query` and compares the rendered result table.
//  - r1..r3: both branches share a type (boolean / string); a NULL condition
//    selects the negative branch (r1 is expected to be `false`).
//  - c1..c3: the branches mix `int` and `bigint`; the query is expected to
//    succeed and return the value of the selected branch.
#[tokio::test]
async fn test_if() -> Result<(), CubeError> {
assert_eq!(
execute_df_query(
r#"select
if(null, true, false) as r1,
if(true, false, true) as r2,
if(true, 'true', 'false') as r3,
if(true, CAST(1 as int), CAST(2 as bigint)) as c1,
if(false, CAST(1 as int), CAST(2 as bigint)) as c2,
if(true, CAST(1 as bigint), CAST(2 as int)) as c3
"#
.to_string()
)
.await?,
// Expected output: a single row, one column per aliased expression above.
"+-------+-------+------+----+----+----+\n\
| r1 | r2 | r3 | c1 | c2 | c3 |\n\
+-------+-------+------+----+----+----+\n\
| false | false | true | 1 | 2 | 1 |\n\
+-------+-------+------+----+----+----+"
);

Ok(())
}

#[tokio::test]
async fn test_least_tz() -> Result<(), CubeError> {
assert_eq!(
Expand Down Expand Up @@ -3199,4 +3224,33 @@ mod tests {

Ok(())
}

// Snapshot test: runs a large column-metadata query against
// INFORMATION_SCHEMA.COLUMNS for the 'KibanaSampleDataEcommerce' table through
// `execute_df_query` and compares the rendered result table.
// The query combines CASE, several (partly nested) IF() calls, LEAST, CONVERT
// and CAST expressions.
// NOTE(review): going by the test name this is presumably the query Metabase
// issues when it inspects a table — confirm.
#[tokio::test]
async fn test_metabase() -> Result<(), CubeError> {
assert_eq!(
execute_df_query(
"SELECT \
TABLE_SCHEMA TABLE_CAT, NULL TABLE_SCHEM, TABLE_NAME, COLUMN_NAME, \
CASE data_type WHEN 'bit' THEN -7 WHEN 'tinyblob' THEN -3 WHEN 'mediumblob' THEN -4 WHEN 'longblob' THEN -4 WHEN 'blob' THEN -4 WHEN 'tinytext' THEN 12 WHEN 'mediumtext' THEN -1 WHEN 'longtext' THEN -1 WHEN 'text' THEN -1 WHEN 'date' THEN 91 WHEN 'datetime' THEN 93 WHEN 'decimal' THEN 3 WHEN 'double' THEN 8 WHEN 'enum' THEN 12 WHEN 'float' THEN 7 WHEN 'int' THEN IF( COLUMN_TYPE like '%unsigned%', 4,4) WHEN 'bigint' THEN -5 WHEN 'mediumint' THEN 4 WHEN 'null' THEN 0 WHEN 'set' THEN 12 WHEN 'smallint' THEN IF( COLUMN_TYPE like '%unsigned%', 5,5) WHEN 'varchar' THEN 12 WHEN 'varbinary' THEN -3 WHEN 'char' THEN 1 WHEN 'binary' THEN -2 WHEN 'time' THEN 92 WHEN 'timestamp' THEN 93 WHEN 'tinyint' THEN IF(COLUMN_TYPE like 'tinyint(1)%',-7,-6) WHEN 'year' THEN 91 ELSE 1111 END DATA_TYPE, IF(COLUMN_TYPE like 'tinyint(1)%', 'BIT', UCASE(IF( COLUMN_TYPE LIKE '%(%)%', CONCAT(SUBSTRING( COLUMN_TYPE,1, LOCATE('(',COLUMN_TYPE) - 1 ), SUBSTRING(COLUMN_TYPE ,1+locate(')', COLUMN_TYPE))), COLUMN_TYPE))) TYPE_NAME, CASE DATA_TYPE WHEN 'time' THEN IF(DATETIME_PRECISION = 0, 10, CAST(11 + DATETIME_PRECISION as signed integer)) WHEN 'date' THEN 10 WHEN 'datetime' THEN IF(DATETIME_PRECISION = 0, 19, CAST(20 + DATETIME_PRECISION as signed integer)) WHEN 'timestamp' THEN IF(DATETIME_PRECISION = 0, 19, CAST(20 + DATETIME_PRECISION as signed integer)) ELSE IF(NUMERIC_PRECISION IS NULL, LEAST(CHARACTER_MAXIMUM_LENGTH,2147483647), NUMERIC_PRECISION) END COLUMN_SIZE, 65535 BUFFER_LENGTH, CONVERT (CASE DATA_TYPE WHEN 'year' THEN NUMERIC_SCALE WHEN 'tinyint' THEN 0 ELSE NUMERIC_SCALE END, UNSIGNED INTEGER) DECIMAL_DIGITS, 10 NUM_PREC_RADIX, IF(IS_NULLABLE = 'yes',1,0) NULLABLE,COLUMN_COMMENT REMARKS, COLUMN_DEFAULT COLUMN_DEF, 0 SQL_DATA_TYPE, 0 SQL_DATETIME_SUB, LEAST(CHARACTER_OCTET_LENGTH,2147483647) CHAR_OCTET_LENGTH, ORDINAL_POSITION, IS_NULLABLE, NULL SCOPE_CATALOG, NULL SCOPE_SCHEMA, NULL SCOPE_TABLE, NULL SOURCE_DATA_TYPE, IF(EXTRA = 'auto_increment','YES','NO') IS_AUTOINCREMENT, IF(EXTRA 
in ('VIRTUAL', 'PERSISTENT', 'VIRTUAL GENERATED', 'STORED GENERATED') ,'YES','NO') IS_GENERATEDCOLUMN \
FROM INFORMATION_SCHEMA.COLUMNS WHERE (ISNULL(database()) OR (TABLE_SCHEMA = database())) AND TABLE_NAME = 'KibanaSampleDataEcommerce' \
ORDER BY TABLE_CAT, TABLE_SCHEM, TABLE_NAME, ORDINAL_POSITION;".to_string()
)
.await?,
// Expected output: one row per column of the table, in the order returned by the query.
"+-----------+-------------+---------------------------+--------------------+-----------+--------------+-------------+---------------+----------------+----------------+----------+---------+------------+---------------+------------------+-------------------+------------------+-------------+---------------+--------------+-------------+------------------+------------------+--------------------+\n\
| TABLE_CAT | TABLE_SCHEM | TABLE_NAME | COLUMN_NAME | DATA_TYPE | TYPE_NAME | COLUMN_SIZE | BUFFER_LENGTH | DECIMAL_DIGITS | NUM_PREC_RADIX | NULLABLE | REMARKS | COLUMN_DEF | SQL_DATA_TYPE | SQL_DATETIME_SUB | CHAR_OCTET_LENGTH | ORDINAL_POSITION | IS_NULLABLE | SCOPE_CATALOG | SCOPE_SCHEMA | SCOPE_TABLE | SOURCE_DATA_TYPE | IS_AUTOINCREMENT | IS_GENERATEDCOLUMN |\n\
+-----------+-------------+---------------------------+--------------------+-----------+--------------+-------------+---------------+----------------+----------------+----------+---------+------------+---------------+------------------+-------------------+------------------+-------------+---------------+--------------+-------------+------------------+------------------+--------------------+\n\
| db | NULL | KibanaSampleDataEcommerce | count | 4 | int | 2147483647 | 65535 | 0 | 10 | 0 | | | 0 | 0 | 2147483647 | 0 | NO | NULL | NULL | NULL | NULL | NO | NO |\n\
| db | NULL | KibanaSampleDataEcommerce | maxPrice | 4 | int | 2147483647 | 65535 | 0 | 10 | 0 | | | 0 | 0 | 2147483647 | 0 | NO | NULL | NULL | NULL | NULL | NO | NO |\n\
| db | NULL | KibanaSampleDataEcommerce | minPrice | 4 | int | 2147483647 | 65535 | 0 | 10 | 0 | | | 0 | 0 | 2147483647 | 0 | NO | NULL | NULL | NULL | NULL | NO | NO |\n\
| db | NULL | KibanaSampleDataEcommerce | avgPrice | 4 | int | 2147483647 | 65535 | 0 | 10 | 0 | | | 0 | 0 | 2147483647 | 0 | NO | NULL | NULL | NULL | NULL | NO | NO |\n\
| db | NULL | KibanaSampleDataEcommerce | order_date | 93 | datetime | NULL | 65535 | 0 | 10 | 0 | | | 0 | 0 | 2147483647 | 0 | YES | NULL | NULL | NULL | NULL | NO | NO |\n\
| db | NULL | KibanaSampleDataEcommerce | customer_gender | 12 | varchar(255) | 2147483647 | 65535 | 0 | 10 | 0 | | | 0 | 0 | 2147483647 | 0 | YES | NULL | NULL | NULL | NULL | NO | NO |\n\
| db | NULL | KibanaSampleDataEcommerce | taxful_total_price | 12 | varchar(255) | 2147483647 | 65535 | 0 | 10 | 0 | | | 0 | 0 | 2147483647 | 0 | YES | NULL | NULL | NULL | NULL | NO | NO |\n\
| db | NULL | KibanaSampleDataEcommerce | is_male | 1111 | boolean | 2147483647 | 65535 | 0 | 10 | 0 | | | 0 | 0 | 2147483647 | 0 | NO | NULL | NULL | NULL | NULL | NO | NO |\n\
| db | NULL | KibanaSampleDataEcommerce | is_female | 1111 | boolean | 2147483647 | 65535 | 0 | 10 | 0 | | | 0 | 0 | 2147483647 | 0 | NO | NULL | NULL | NULL | NULL | NO | NO |\n\
+-----------+-------------+---------------------------+--------------------+-----------+--------------+-------------+---------------+----------------+----------------+----------+---------+------------+---------------+------------------+-------------------+------------------+-------------+---------------+--------------+-------------+------------------+------------------+--------------------+"
);

Ok(())
}
}

0 comments on commit 3b3f48c

Please sign in to comment.