From 66e93e156c98f3bb3024124fb62b3533a02c4895 Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 27 Apr 2022 10:26:25 -0400 Subject: [PATCH] fix: dispatch underlying op for aliases and decimal literal `ops.Alias` was added in Ibis recently and the dispatcher was defaulting to `ops.ValueOp` when it received one -- instead this added in a specific Alias dispatch rule to pull out the underlying op. Need to add proper decimal handling, but for now this at least keeps things consistent between 2.1.1 and 3.x --- ibis_substrait/compiler/translate.py | 21 +++++++++++++++++++ ibis_substrait/tests/compiler/test_literal.py | 2 +- poetry.lock | 4 ++-- pyproject.toml | 1 + 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/ibis_substrait/compiler/translate.py b/ibis_substrait/compiler/translate.py index 48c283f9..4de7f016 100644 --- a/ibis_substrait/compiler/translate.py +++ b/ibis_substrait/compiler/translate.py @@ -8,17 +8,20 @@ import collections import collections.abc import datetime +import decimal import functools import itertools import operator import uuid from typing import Any, Mapping, MutableMapping, Sequence, TypeVar +import ibis import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.schema as sch import ibis.expr.types as ir from ibis import util +from packaging import version from ..proto.substrait import algebra_pb2 as stalg from ..proto.substrait import type_pb2 as stt @@ -264,6 +267,11 @@ def _literal_float64(_: dt.Float64, value: float) -> stalg.Expression.Literal: return stalg.Expression.Literal(fp64=value) +@translate_literal.register +def _literal_decimal(_: dt.Decimal, value: decimal.Decimal) -> stalg.Expression.Literal: + raise NotImplementedError + + @translate_literal.register def _literal_string(_: dt.String, value: str) -> stalg.Expression.Literal: return stalg.Expression.Literal(string=value) @@ -441,6 +449,19 @@ def _translate_window_bounds( return translate_preceding(*preceding), translate_following(*following) +if version.parse(ibis.__version__) >= version.parse("3.0.0"): + + @translate.register(ops.Alias) + def alias_op( + op: ops.Alias, + expr: ir.ValueExpr, + compiler: SubstraitCompiler, + **kwargs: Any, + ) -> stalg.Expression: + # For an alias, dispatch on the underlying argument + return translate(op.arg.op(), op.arg, compiler, **kwargs) + + @translate.register(ops.ValueOp) def value_op( op: ops.ValueOp, diff --git a/ibis_substrait/tests/compiler/test_literal.py b/ibis_substrait/tests/compiler/test_literal.py index 2d684788..88b94993 100644 --- a/ibis_substrait/tests/compiler/test_literal.py +++ b/ibis_substrait/tests/compiler/test_literal.py @@ -303,7 +303,7 @@ def test_literal(compiler, expr, ir): @pytest.mark.xfail( - raises=ibis.common.exceptions.IbisTypeError, + raises=(ibis.common.exceptions.IbisTypeError, NotImplementedError), reason="Ibis doesn't allow decimal values through validation", ) def test_decimal_literal(compiler): diff --git a/poetry.lock b/poetry.lock index 3df3b1c0..4d70d82b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -200,9 +200,9 @@ tabulate = ">=0.8.9,<1" toolz = ">=0.11,<0.12" [package.extras] -all = ["clickhouse-cityhash (>=1.0.2,<2)", "clickhouse-driver (>=0.1,<0.3)", "dask[array,dataframe] (>=2021.10.0)", "datafusion (>=0.4,<0.6)", "duckdb (>=0.3.2,<0.4.0)", "duckdb-engine (>=0.1.8,<0.2.0)", "fsspec (>=2022.1.0)", "GeoAlchemy2 (>=0.6.3,<0.12)", "geopandas (>=0.6,<0.11)", "graphviz (>=0.16,<0.21)", "impyla[kerberos] (>=0.17,<0.19)", "lz4 (>=3.1.10,<5)", "psycopg2 (>=2.8.4,<3)", "pyarrow (>=1,<8)", "pymysql (>=1,<2)", "pyspark (>=3,<4)", "requests (>=2,<3)", "Shapely (>=1.6,<1.8.2)", "sqlalchemy (>=1.4,<2.0)"] +all = ["clickhouse-cityhash (>=1.0.2,<2)", "clickhouse-driver (>=0.1,<0.3)", "dask[dataframe,array] (>=2021.10.0)", "datafusion (>=0.4,<0.6)", "duckdb (>=0.3.2,<0.4.0)", "duckdb-engine (>=0.1.8,<0.2.0)", "fsspec (>=2022.1.0)", "GeoAlchemy2 (>=0.6.3,<0.12)", "geopandas (>=0.6,<0.11)", "graphviz (>=0.16,<0.21)", "impyla[kerberos] (>=0.17,<0.19)", "lz4 (>=3.1.10,<5)", "psycopg2 (>=2.8.4,<3)", "pyarrow (>=1,<8)", "pymysql (>=1,<2)", "pyspark (>=3,<4)", "requests (>=2,<3)", "Shapely (>=1.6,<1.8.2)", "sqlalchemy (>=1.4,<2.0)"] clickhouse = ["clickhouse-cityhash (>=1.0.2,<2)", "clickhouse-driver (>=0.1,<0.3)", "lz4 (>=3.1.10,<5)"] -dask = ["dask[array,dataframe] (>=2021.10.0)", "pyarrow (>=1,<8)"] +dask = ["dask[dataframe,array] (>=2021.10.0)", "pyarrow (>=1,<8)"] datafusion = ["datafusion (>=0.4,<0.6)"] duckdb = ["duckdb (>=0.3.2,<0.4.0)", "duckdb-engine (>=0.1.8,<0.2.0)", "sqlalchemy (>=1.4,<2.0)"] impala = ["fsspec (>=2022.1.0)", "impyla[kerberos] (>=0.17,<0.19)", "requests (>=2,<3)"] diff --git a/pyproject.toml b/pyproject.toml index 2b7a5d02..2dda83f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ python = ">=3.8,<3.11" ibis-framework = ">=2,<4" protobuf = "^3.19.4" platformdirs = "<2.5.2" +packaging = "^21.3" [tool.poetry.dev-dependencies] black = "^21.9b0"