diff --git a/ibis_substrait/compiler/decompile.py b/ibis_substrait/compiler/decompile.py index 07147059..9f529b9f 100644 --- a/ibis_substrait/compiler/decompile.py +++ b/ibis_substrait/compiler/decompile.py @@ -570,7 +570,7 @@ def _decompile_expression_aggregate_function( ) -> ir.ValueExpr: extension = decompiler.function_extensions[aggregate_function.function_reference] function_name = extension.name - op_type = getattr(ops, SUBSTRAIT_IBIS_OP_MAPPING[function_name]) + op_type = SUBSTRAIT_IBIS_OP_MAPPING[function_name] args = [ decompile(arg, children, field_offsets, decompiler) for arg in aggregate_function.args @@ -774,6 +774,15 @@ def decompile_cast( ) -> ir.ValueExpr: return decompile(cast, children, offsets, decompiler) + @staticmethod + def decompile_enum( + enum: stalg.Expression.Enum, + children: Sequence[ir.TableExpr], + offsets: Sequence[int], + decompiler: SubstraitDecompiler, + ) -> ir.ValueExpr: + return decompile(enum) + @decompile.register def _decompile_expression( @@ -800,7 +809,7 @@ def _decompile_expression_scalar_function( ) -> ir.ValueExpr: extension = decompiler.function_extensions[msg.function_reference] function_name = extension.name - op_type = getattr(ops, SUBSTRAIT_IBIS_OP_MAPPING[function_name]) + op_type = SUBSTRAIT_IBIS_OP_MAPPING[function_name] args = [decompile(arg, children, field_offsets, decompiler) for arg in msg.args] expr = op_type(*args).to_expr() output_type = _decompile_type(msg.output_type) @@ -923,6 +932,13 @@ def _decompile_expression_cast( return column.cast(_decompile_type(msg.type)) +@decompile.register +def _decompile_expression_enum( + msg: stalg.Expression.Enum, +) -> str: + return msg.specified + + class LiteralDecompiler: @staticmethod def decompile_boolean(value: bool) -> tuple[bool, dt.Boolean]: diff --git a/ibis_substrait/compiler/mapping.py b/ibis_substrait/compiler/mapping.py index 7cb739e0..32cfce6a 100644 --- a/ibis_substrait/compiler/mapping.py +++ b/ibis_substrait/compiler/mapping.py @@ -1,3 +1,5 @@ +import ibis.expr.operations as ops + IBIS_SUBSTRAIT_OP_MAPPING = { "Add": "add", "And": "and", @@ -7,6 +9,9 @@ "CountDistinct": "countdistinct", "Divide": "divide", "Equals": "equal", + "ExtractYear": "extract", + "ExtractMonth": "extract", + "ExtractDay": "extract", "Greater": "gt", "GreaterEqual": "gte", "Less": "lt", @@ -25,4 +30,10 @@ "Sum": "sum", } -SUBSTRAIT_IBIS_OP_MAPPING = {v: k for k, v in IBIS_SUBSTRAIT_OP_MAPPING.items()} +SUBSTRAIT_IBIS_OP_MAPPING = { + v: getattr(ops, k) for k, v in IBIS_SUBSTRAIT_OP_MAPPING.items() +} +# override when reversing many-to-one mappings +SUBSTRAIT_IBIS_OP_MAPPING["extract"] = lambda table, span: getattr( + ops, f"Extract{span.capitalize()}" +)(table) diff --git a/ibis_substrait/compiler/translate.py b/ibis_substrait/compiler/translate.py index 289a0649..2ceb5c8f 100644 --- a/ibis_substrait/compiler/translate.py +++ b/ibis_substrait/compiler/translate.py @@ -959,3 +959,25 @@ def _cast( type=translate(op.to), input=translate(op.arg, compiler, **kwargs) ) ) + + +@translate.register(ops.ExtractDateField) +def _extractdatefield( + op: ops.ExtractDateField, + expr: ir.TableExpr, + compiler: SubstraitCompiler, + **kwargs: Any, +) -> stalg.Expression: + scalar_func = stalg.Expression.ScalarFunction( + function_reference=compiler.function_id(expr), + output_type=translate(expr.type()), + args=[ + translate(arg, compiler, **kwargs) + for arg in op.args + if isinstance(arg, ir.Expr) + ], + ) + # e.g. "ExtractYear" -> "YEAR" + span = type(op).__name__.lstrip("Extract").upper() + scalar_func.args.add(enum=stalg.Expression.Enum(specified=span)) + return stalg.Expression(scalar_function=scalar_func) diff --git a/ibis_substrait/tests/compiler/test_decompiler.py b/ibis_substrait/tests/compiler/test_decompiler.py index da638054..7805d2b4 100644 --- a/ibis_substrait/tests/compiler/test_decompiler.py +++ b/ibis_substrait/tests/compiler/test_decompiler.py @@ -265,3 +265,19 @@ def test_searchedcase(compiler): plan = compiler.compile(expr) (result,) = decompile(plan) assert result.equals(expr) + + +@pytest.mark.parametrize( + "span", + [ + "year", + "month", + "day", + ], +) +def test_extract_date(compiler, span): + t = ibis.table([("o_orderdate", dt.date)], name="t") + expr = t[getattr(t.o_orderdate, span)()] + plan = compiler.compile(expr) + (result,) = decompile(plan) + assert result.equals(expr)