Implement initial decimal support (#1073)
* Implement initial decimal support

* Remove unnecessary import

* Remove duplicates

* Add first batch of tests

* Remove todo

* Address reviews

* Address more reviews

* Unblock creation of scalar decimals

* Add scalar filter test

* Expose precision/scale to python, use when deciding to cast decimal columns

* Fix typo in dask_sql/mappings.py

Co-authored-by: Ayush Dattagupta <ayushdg95@gmail.com>

---------

Co-authored-by: Chris Jarrett <cjarrett@ipp1-3302.aselab.nvidia.com>
Co-authored-by: Ayush Dattagupta <ayushdg95@gmail.com>
Co-authored-by: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
4 people committed Apr 26, 2023
1 parent 9395804 commit 218ff24
Showing 8 changed files with 230 additions and 13 deletions.
23 changes: 23 additions & 0 deletions dask_planner/src/expression.rs
@@ -598,6 +598,29 @@ impl PyExpr {
        }))
    }

    /// Gets the precision/scale represented by the Expression's decimal datatype
    #[pyo3(name = "getPrecisionScale")]
    pub fn get_precision_scale(&self) -> PyResult<(u8, i8)> {
        Ok(match &self.expr {
            Expr::Cast(Cast { expr: _, data_type }) => match data_type {
                DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => {
                    (*precision, *scale)
                }
                _ => {
                    return Err(py_type_err(format!(
                        "Catch all triggered for Cast in get_precision_scale; {data_type:?}"
                    )))
                }
            },
            _ => {
                return Err(py_type_err(format!(
                    "Catch all triggered in get_precision_scale; {:?}",
                    &self.expr
                )))
            }
        })
    }

    #[pyo3(name = "getFilterExpr")]
    pub fn get_filter_expr(&self) -> PyResult<Option<PyExpr>> {
        // TODO refactor to avoid duplication
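A rough usage sketch from the Python side (hypothetical variable name; assumes `expr` is a PyExpr wrapping a cast to a decimal type, per the binding above):

    # `expr` wraps `CAST(x AS DECIMAL(12, 3))`; for non-decimal expressions
    # the binding raises a TypeError via py_type_err instead.
    precision, scale = expr.getPrecisionScale()
    assert (precision, scale) == (12, 3)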
44 changes: 43 additions & 1 deletion dask_planner/src/sql/types.rs
@@ -5,7 +5,7 @@ use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
use datafusion_sql::sqlparser::{ast::DataType as SQLType, parser::Parser, tokenizer::Tokenizer};
use pyo3::{prelude::*, types::PyDict};

-use crate::{dialect::DaskDialect, error::DaskPlannerError};
+use crate::{dialect::DaskDialect, error::DaskPlannerError, sql::exceptions::py_type_err};

#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(name = "RexType", module = "datafusion")]
@@ -111,6 +111,29 @@ impl DaskTypeMap {
            };
            DataType::Timestamp(unit, tz)
        }
        SqlTypeName::DECIMAL => {
            let (precision, scale) = match py_kwargs {
                Some(dict) => {
                    let precision: u8 = match dict.get_item("precision") {
                        Some(e) => {
                            let res: PyResult<u8> = e.extract();
                            res.unwrap()
                        }
                        None => 38,
                    };
                    let scale: i8 = match dict.get_item("scale") {
                        Some(e) => {
                            let res: PyResult<i8> = e.extract();
                            res.unwrap()
                        }
                        None => 0,
                    };
                    (precision, scale)
                }
                None => (38, 10),
            };
            DataType::Decimal128(precision, scale)
        }
        _ => sql_type.to_arrow()?,
    };

@@ -141,6 +164,25 @@ pub struct PyDataType {
    data_type: DataType,
}

#[pymethods]
impl PyDataType {
    /// Gets the precision/scale represented by the PyDataType's decimal datatype
    #[pyo3(name = "getPrecisionScale")]
    pub fn get_precision_scale(&self) -> PyResult<(u8, i8)> {
        Ok(match &self.data_type {
            DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => {
                (*precision, *scale)
            }
            _ => {
                return Err(py_type_err(format!(
                    "Catch all triggered in get_precision_scale, {:?}",
                    &self.data_type
                )))
            }
        })
    }
}

impl From<PyDataType> for DataType {
    fn from(data_type: PyDataType) -> DataType {
        data_type.data_type
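A minimal sketch of how this is exercised from Python, mirroring the kwargs used in mappings.py below; per the match above, omitting one kwarg falls back to precision 38 or scale 0, and passing no kwargs at all yields Decimal128(38, 10):

    from dask_planner.rust import DaskTypeMap, SqlTypeName

    # Maps to Arrow's Decimal128(12, 3) via the DECIMAL arm above.
    t = DaskTypeMap(SqlTypeName.DECIMAL, precision=12, scale=3)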
41 changes: 33 additions & 8 deletions dask_sql/mappings.py
@@ -1,4 +1,5 @@
import logging
from decimal import Decimal
from typing import Any

import dask.array as da
@@ -9,8 +10,14 @@
from dask_planner.rust import DaskTypeMap, SqlTypeName
from dask_sql._compat import FLOAT_NAN_IMPLEMENTED

try:
    import cudf
except ImportError:
    cudf = None

logger = logging.getLogger(__name__)


# Default mapping between python types and SQL types
_PYTHON_TO_SQL = {
    np.float64: SqlTypeName.DOUBLE,
@@ -51,7 +58,7 @@
_SQL_TO_PYTHON_SCALARS = {
    "SqlTypeName.DOUBLE": np.float64,
    "SqlTypeName.FLOAT": np.float32,
-    "SqlTypeName.DECIMAL": np.float32,
+    "SqlTypeName.DECIMAL": Decimal,
    "SqlTypeName.BIGINT": np.int64,
    "SqlTypeName.INTEGER": np.int32,
    "SqlTypeName.SMALLINT": np.int16,
@@ -68,7 +75,8 @@
_SQL_TO_PYTHON_FRAMES = {
    "SqlTypeName.DOUBLE": np.float64,
    "SqlTypeName.FLOAT": np.float32,
-    "SqlTypeName.DECIMAL": np.float64,  # We use np.float64 always, even though we might be able to use a smaller type
+    # a column of Decimals in pandas is `object`, but cuDF has a dedicated dtype
+    "SqlTypeName.DECIMAL": object if not cudf else cudf.Decimal128Dtype(38, 10),
    "SqlTypeName.BIGINT": pd.Int64Dtype(),
    "SqlTypeName.INTEGER": pd.Int32Dtype(),
    "SqlTypeName.SMALLINT": pd.Int16Dtype(),
@@ -107,6 +115,13 @@ def python_to_sql_type(python_type) -> "DaskTypeMap":
            tz=str(python_type.tz),
        )

    if is_decimal(python_type):
        return DaskTypeMap(
            SqlTypeName.DECIMAL,
            precision=python_type.precision,
            scale=python_type.scale,
        )

    try:
        return DaskTypeMap(_PYTHON_TO_SQL[python_type])
    except KeyError:  # pragma: no cover
@@ -179,10 +194,6 @@ def sql_to_python_value(sql_type: "SqlTypeName", literal_value: Any) -> Any:
        if sql_type == SqlTypeName.DATE:
            return literal_value.astype("<M8[D]")
        return literal_value.astype("<M8[ns]")
-    elif sql_type == SqlTypeName.DECIMAL:
-        # We use np.float64 always, even though we might
-        # be able to use a smaller type
-        python_type = np.float64
    else:
        try:
            python_type = _SQL_TO_PYTHON_SCALARS[str(sql_type)]
@@ -203,9 +214,11 @@ def sql_to_python_value(sql_type: "SqlTypeName", literal_value: Any) -> Any:
    return python_type(literal_value)


-def sql_to_python_type(sql_type: "SqlTypeName") -> type:
+def sql_to_python_type(sql_type: "SqlTypeName", *args) -> type:
    """Turn an SQL type into a dataframe dtype"""
    try:
        if str(sql_type) == "SqlTypeName.DECIMAL":
            return cudf.Decimal128Dtype(*args)
        return _SQL_TO_PYTHON_FRAMES[str(sql_type)]
    except KeyError:  # pragma: no cover
        raise NotImplementedError(
@@ -239,15 +252,20 @@ def similar_type(lhs: type, rhs: type) -> bool:
        is_sint,
        is_float,
        is_object,
-        is_string,
+        # is_string_dtype considers decimal columns to be string columns
+        lambda x: is_string(x) and not is_decimal(x),
        is_dt_tz,
        is_dt_ntz,
        is_td_ns,
        is_bool,
        is_decimal,
    ]

    for check in checks:
        if check(lhs) and check(rhs):
            # check that decimal columns have equal precision/scale
            if check is is_decimal:
                return lhs.precision == rhs.precision and lhs.scale == rhs.scale
            return True

    return False
@@ -298,3 +316,10 @@ def cast_column_to_type(col: dd.Series, expected_type: str):

logger.debug(f"Need to cast from {current_type} to {expected_type}")
return col.astype(expected_type)


def is_decimal(dtype):
"""
Check if dtype is a decimal type
"""
return "decimal" in str(dtype).lower()
8 changes: 7 additions & 1 deletion dask_sql/physical/rel/base.py
@@ -104,7 +104,13 @@ def fix_dtype_to_row_type(dc: DataContainer, row_type: "RelDataType"):
        }

        for field_name, field_type in field_types.items():
-            expected_type = sql_to_python_type(field_type.getSqlType())
+            sql_type = field_type.getSqlType()
+            sql_type_args = tuple()
+
+            if str(sql_type) == "SqlTypeName.DECIMAL":
+                sql_type_args = field_type.getDataType().getPrecisionScale()
+
+            expected_type = sql_to_python_type(sql_type, *sql_type_args)
            df_field_name = cc.get_backend_by_frontend_name(field_name)
            df = cast_column_type(df, df_field_name, expected_type)

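For a DECIMAL(12, 3) field, the loop above now resolves the expected dtype roughly as follows (a sketch; `field_type` stands in for one entry of `field_types`):

    sql_type = field_type.getSqlType()  # SqlTypeName.DECIMAL
    precision, scale = field_type.getDataType().getPrecisionScale()  # (12, 3)
    expected_type = sql_to_python_type(sql_type, precision, scale)  # cudf.Decimal128Dtype(12, 3)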
11 changes: 8 additions & 3 deletions dask_sql/physical/rex/core/call.py
@@ -243,13 +243,18 @@ def __init__(self):
        super().__init__(self.cast)

    def cast(self, operand, rex=None) -> SeriesOrScalar:
-        output_type = str(rex.getType())
-        sql_type = SqlTypeName.fromString(output_type.upper())
+        output_type = rex.getType()
+        sql_type = SqlTypeName.fromString(output_type)
        sql_type_args = ()

        # decimal datatypes require precision and scale
        if output_type == "DECIMAL":
            sql_type_args = rex.getPrecisionScale()

        if not is_frame(operand):  # pragma: no cover
            return sql_to_python_value(sql_type, operand)

-        python_type = sql_to_python_type(sql_type)
+        python_type = sql_to_python_type(sql_type, *sql_type_args)

        return_column = cast_column_to_type(operand, python_type)

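End to end, a cast can now carry its precision/scale into the resulting column dtype; a hedged sketch assuming a cuDF-backed table `df` with a numeric column `a` already registered on a context `c`:

    # The cast resolves to cudf.Decimal128Dtype(12, 3) rather than a float dtype.
    result = c.sql("SELECT CAST(a AS DECIMAL(12, 3)) AS a_dec FROM df")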
48 changes: 48 additions & 0 deletions tests/integration/test_filter.py
@@ -253,3 +253,51 @@ def test_filtered_csv(tmpdir, c):
expected_df = df[df["b"] < 10]

assert_eq(return_df, expected_df)


@pytest.mark.gpu
def test_filter_decimal(c):
import cudf

df = cudf.DataFrame(
{
"a": [304.5, 35.305, 9.043, 102.424, 53.34],
"b": [2.2, 82.4, 42, 76.9, 54.4],
"c": [1, 2, 2, 5, 9],
}
)
df["a"] = df["a"].astype(cudf.Decimal64Dtype(12, 3))
df["b"] = df["b"].astype(cudf.Decimal64Dtype(7, 1))
c.create_table("df", df)

result_df = c.sql(
"""
SELECT
c
FROM
df
WHERE
a < b
"""
)

expected_df = df.loc[df.a < df.b][["c"]]

assert_eq(result_df, expected_df)

result_df = c.sql(
"""
SELECT
b
FROM
df
WHERE
a < decimal '100.2'
"""
)

expected_df = cudf.DataFrame({"b": [82.4, 42, 54.4]})
expected_df["b"] = expected_df["b"].astype(cudf.Decimal64Dtype(7, 1))

assert_eq(result_df.reset_index(drop=True), expected_df)
c.drop_table("df")
58 changes: 58 additions & 0 deletions tests/integration/test_groupby.py
@@ -596,3 +596,61 @@ def test_groupby_split_every(c, gpu):
    assert_eq(split_every_4_df, expected_df, check_index=False)

    c.drop_table("split_every_input")


@pytest.mark.gpu
def test_agg_decimal(c):
    import cudf

    df = cudf.DataFrame(
        {
            "a": [1.23, 12.65, 134.64, -34.3, 945.19],
            "b": [1, 1, 2, 2, 3],
        }
    )
    df["a"] = df["a"].astype(cudf.Decimal64Dtype(10, 2))

    c.create_table("df", df, gpu=True)

    result_df = c.sql(
        """
        SELECT
            SUM(a) as s,
            COUNT(a) as c,
            SUM(a+a) as s2
        FROM
            df
        GROUP BY
            b
        """
    )

    expected_df = cudf.DataFrame(
        {
            "s": df.groupby("b").sum()["a"],
            "c": df.groupby("b").count()["a"].astype("int64"),
            "s2": df.groupby("b").sum()["a"] + df.groupby("b").sum()["a"],
        }
    )

    assert_eq(result_df, expected_df.reset_index(drop=True))

    result_df = c.sql(
        """
        SELECT
            MIN(a) as min,
            MAX(a) as max
        FROM
            df
        """
    )

    expected_df = cudf.DataFrame(
        {
            "min": [df.a.min()],
            "max": [df.a.max()],
        }
    )

    assert_eq(result_df, expected_df)
    c.drop_table("df")
10 changes: 10 additions & 0 deletions tests/unit/test_mapping.py
@@ -2,6 +2,7 @@

import numpy as np
import pandas as pd
import pytest

from dask_planner.rust import SqlTypeName
from dask_sql.mappings import python_to_sql_type, similar_type, sql_to_python_value
@@ -16,6 +17,15 @@ def test_python_to_sql():
    )


@pytest.mark.gpu
def test_python_decimal_to_sql():
    import cudf

    assert str(python_to_sql_type(cudf.Decimal64Dtype(12, 3))) == "DECIMAL"
    assert str(python_to_sql_type(cudf.Decimal128Dtype(32, 12))) == "DECIMAL"
    assert str(python_to_sql_type(cudf.Decimal32Dtype(5, -2))) == "DECIMAL"


def test_sql_to_python():
    assert sql_to_python_value(SqlTypeName.VARCHAR, "test 123") == "test 123"
    assert type(sql_to_python_value(SqlTypeName.BIGINT, 653)) == np.int64
