Skip to content

Commit

Permalink
feat: add support for ops.Extract<span>
Browse files Browse the repository at this point in the history
Covers `ExtractYear`, `ExtractMonth`, and `ExtractDay`.  Substrait also
allows `SECONDS` as the specified argument but I'm not clear on exactly
what they're looking for there.  Epoch Seconds?
  • Loading branch information
gforsyth committed Jun 28, 2022
1 parent 308d33a commit 2fe7f26
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 3 deletions.
20 changes: 18 additions & 2 deletions ibis_substrait/compiler/decompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def _decompile_expression_aggregate_function(
) -> ir.ValueExpr:
extension = decompiler.function_extensions[aggregate_function.function_reference]
function_name = extension.name
op_type = getattr(ops, SUBSTRAIT_IBIS_OP_MAPPING[function_name])
op_type = SUBSTRAIT_IBIS_OP_MAPPING[function_name]
args = [
decompile(arg, children, field_offsets, decompiler)
for arg in aggregate_function.args
Expand Down Expand Up @@ -774,6 +774,15 @@ def decompile_cast(
) -> ir.ValueExpr:
return decompile(cast, children, offsets, decompiler)

@staticmethod
def decompile_enum(
    enum: stalg.Expression.Enum,
    children: Sequence[ir.TableExpr],
    offsets: Sequence[int],
    decompiler: SubstraitDecompiler,
) -> ir.ValueExpr:
    """Decompile a Substrait enum argument via the ``decompile`` dispatcher.

    Only ``enum`` is forwarded to ``decompile``; ``children``, ``offsets``
    and ``decompiler`` are accepted but not used by this method.
    """
    # NOTE(review): the return annotation says ``ir.ValueExpr``, but the
    # ``decompile`` handler registered for ``stalg.Expression.Enum`` is
    # annotated ``-> str`` -- confirm which one is intended.
    return decompile(enum)


@decompile.register
def _decompile_expression(
Expand All @@ -800,7 +809,7 @@ def _decompile_expression_scalar_function(
) -> ir.ValueExpr:
extension = decompiler.function_extensions[msg.function_reference]
function_name = extension.name
op_type = getattr(ops, SUBSTRAIT_IBIS_OP_MAPPING[function_name])
op_type = SUBSTRAIT_IBIS_OP_MAPPING[function_name]
args = [decompile(arg, children, field_offsets, decompiler) for arg in msg.args]
expr = op_type(*args).to_expr()
output_type = _decompile_type(msg.output_type)
Expand Down Expand Up @@ -923,6 +932,13 @@ def _decompile_expression_cast(
return column.cast(_decompile_type(msg.type))


@decompile.register
def _decompile_expression_enum(
    msg: stalg.Expression.Enum,
) -> str:
    """Return the ``specified`` value carried by a Substrait enum message."""
    return msg.specified


class LiteralDecompiler:
@staticmethod
def decompile_boolean(value: bool) -> tuple[bool, dt.Boolean]:
Expand Down
13 changes: 12 additions & 1 deletion ibis_substrait/compiler/mapping.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import ibis.expr.operations as ops

IBIS_SUBSTRAIT_OP_MAPPING = {
"Add": "add",
"And": "and",
Expand All @@ -7,6 +9,9 @@
"CountDistinct": "countdistinct",
"Divide": "divide",
"Equals": "equal",
"ExtractYear": "extract",
"ExtractMonth": "extract",
"ExtractDay": "extract",
"Greater": "gt",
"GreaterEqual": "gte",
"Less": "lt",
Expand All @@ -25,4 +30,10 @@
"Sum": "sum",
}

SUBSTRAIT_IBIS_OP_MAPPING = {v: k for k, v in IBIS_SUBSTRAIT_OP_MAPPING.items()}
# Reverse lookup: Substrait function name -> ibis operation class.
SUBSTRAIT_IBIS_OP_MAPPING = {
    substrait_name: getattr(ops, ibis_name)
    for ibis_name, substrait_name in IBIS_SUBSTRAIT_OP_MAPPING.items()
}


def _extract_op_for_span(table, span):
    """Build the ``ops.Extract<Span>`` node selected by ``span``.

    ``span`` is capitalized (e.g. ``"YEAR"`` -> ``"Year"``) and appended to
    ``"Extract"`` to pick the operation class, which is then instantiated
    with ``table`` as its only argument.
    """
    op_class = getattr(ops, f"Extract{span.capitalize()}")
    return op_class(table)


# Several ibis ops share the single Substrait name "extract", so the plain
# reversal above cannot represent them; dispatch on the span argument instead.
SUBSTRAIT_IBIS_OP_MAPPING["extract"] = _extract_op_for_span
22 changes: 22 additions & 0 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,3 +959,25 @@ def _cast(
type=translate(op.to), input=translate(op.arg, compiler, **kwargs)
)
)


@translate.register(ops.ExtractDateField)
def _extractdatefield(
    op: ops.ExtractDateField,
    expr: ir.TableExpr,
    compiler: SubstraitCompiler,
    **kwargs: Any,
) -> stalg.Expression:
    """Translate an ``ops.ExtractDateField`` op into a Substrait scalar function.

    The expression-valued arguments of ``op`` are translated first, then an
    enum argument naming the extracted span is appended. The span is the
    op's class name with the ``"Extract"`` prefix removed, upper-cased
    (e.g. ``ExtractYear`` -> ``"YEAR"``).

    Parameters
    ----------
    op
        The extract operation being translated.
    expr
        The expression wrapping ``op``; supplies the function id and the
        output type.
    compiler
        Compiler used to obtain the function reference and passed through
        to nested ``translate`` calls.
    **kwargs
        Forwarded unchanged to nested ``translate`` calls.

    Returns
    -------
    stalg.Expression
        An expression whose ``scalar_function`` field holds the call.
    """
    scalar_func = stalg.Expression.ScalarFunction(
        function_reference=compiler.function_id(expr),
        output_type=translate(expr.type()),
        args=[
            translate(arg, compiler, **kwargs)
            for arg in op.args
            if isinstance(arg, ir.Expr)
        ],
    )
    # e.g. "ExtractYear" -> "YEAR".
    # Slice the prefix off explicitly: ``str.lstrip("Extract")`` strips any
    # leading run of the characters {E, x, t, r, a, c}, so a name such as
    # "ExtractEpochSeconds" would be mangled into "pochSeconds".
    prefix = "Extract"
    op_name = type(op).__name__
    if op_name.startswith(prefix):
        op_name = op_name[len(prefix):]
    span = op_name.upper()
    scalar_func.args.add(enum=stalg.Expression.Enum(specified=span))
    return stalg.Expression(scalar_function=scalar_func)
16 changes: 16 additions & 0 deletions ibis_substrait/tests/compiler/test_decompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,19 @@ def test_searchedcase(compiler):
plan = compiler.compile(expr)
(result,) = decompile(plan)
assert result.equals(expr)


@pytest.mark.parametrize("span", ["year", "month", "day"])
def test_extract_date(compiler, span):
    """Round-trip a date-field extraction through compile and decompile."""
    table = ibis.table([("o_orderdate", dt.date)], name="t")
    extract_method = getattr(table.o_orderdate, span)
    expr = table[extract_method()]

    plan = compiler.compile(expr)
    # Unpacking into a one-element list requires exactly one decompiled result.
    [roundtripped] = decompile(plan)

    assert roundtripped.equals(expr)

0 comments on commit 2fe7f26

Please sign in to comment.