Skip to content

Commit

Permalink
feat(struct): add struct field access
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed May 12, 2022
1 parent f87121c commit 26c329a
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 16 deletions.
69 changes: 56 additions & 13 deletions ibis_substrait/compiler/decompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,9 +489,8 @@ def _decompile_with_name(
decompiler: SubstraitDecompiler,
names: Deque[str],
) -> ir.ValueExpr:
ibis_expr = decompile(expr, children, field_offsets, decompiler).name(
names.popleft()
)
expr_name = names.popleft()
ibis_expr = decompile(expr, children, field_offsets, decompiler).name(expr_name)

# remove child names since those are encoded in the data type
_remove_names_below(names, ibis_expr.type())
Expand Down Expand Up @@ -652,6 +651,25 @@ def _decompile_set_op_rel(
)


@functools.singledispatch
def _get_field(
    child: ir.TableExpr | ir.StructValue,
    relative_offset: int,
) -> ir.ValueExpr:
    """Fallback for the single-dispatch field accessor.

    This body only runs when no overload has been registered for the
    runtime type of ``child``.

    Raises
    ------
    NotImplementedError
        Always; the message includes the type of ``child``.
    """
    unsupported = type(child)
    raise NotImplementedError(f"accessing field of type {unsupported} is not supported")


@_get_field.register(ir.TableExpr)
def _get_field_table_expr(child: ir.TableExpr, relative_offset: int) -> ir.ValueExpr:
    """Return ``child[relative_offset]`` for a table expression."""
    selected = child[relative_offset]
    return selected


@_get_field.register(ir.StructValue)
def _get_field_struct(child: ir.StructValue, relative_offset: int) -> ir.ValueExpr:
    """Return the field of a struct value found at a positional offset.

    The struct is indexed by field name, so the offset is first mapped to a
    name through the ``names`` of the struct's data type.

    Parameters
    ----------
    child
        Struct expression to read from.
    relative_offset
        Position of the wanted field within ``child.type().names``.
    """
    field_name = child.type().names[relative_offset]
    return child[field_name]


class ExpressionDecompiler:
@staticmethod
def decompile_literal(
Expand All @@ -663,12 +681,29 @@ def decompile_literal(
value, dtype = decompile(literal)
return ibis.literal(value, type=dtype)

@staticmethod
def _decompile_struct_field(
    struct_field: stalg.Expression.ReferenceSegment.StructField,
    field_offsets: Sequence[int],
    children: Sequence[ir.TableExpr | ir.StructValue],
) -> ir.ValueExpr:
    """Resolve a struct-field reference against a sequence of children.

    Parameters
    ----------
    struct_field
        Reference segment whose ``field`` is treated as an absolute offset.
    field_offsets
        Starting offsets used to locate the owning child; searched with
        ``bisect``, so presumably sorted ascending (TODO confirm at callers).
    children
        Candidate parents, indexed in parallel with ``field_offsets``.

    Returns
    -------
    ir.ValueExpr
        Whatever ``_get_field`` yields for the chosen child and the offset
        relative to that child's starting offset.
    """
    # NOTE(review): callers that pass a single-element ``children`` together
    # with multi-entry ``field_offsets`` rely on the bisect landing on index
    # 0 -- confirm that holds for nested struct references.
    offset = struct_field.field

    # last entry of field_offsets that is <= offset
    owner = bisect.bisect_right(field_offsets, offset) - 1

    return _get_field(children[owner], offset - field_offsets[owner])

@staticmethod
def decompile_selection(
ref: stalg.Expression.FieldReference,
children: Sequence[ir.TableExpr],
field_offsets: Sequence[int],
_: SubstraitDecompiler,
decompiler: SubstraitDecompiler,
) -> ir.ValueExpr:
ref_type, ref_variant = which_one_of(ref, "reference_type")
if ref_type != "direct_reference":
Expand All @@ -678,22 +713,30 @@ def decompile_selection(

assert isinstance(ref_variant, stalg.Expression.ReferenceSegment)

direct_ref_type, direct_ref_variant = which_one_of(
_, struct_field = which_one_of(
ref_variant,
"reference_type",
)
assert isinstance(
direct_ref_variant,
struct_field,
stalg.Expression.ReferenceSegment.StructField,
)
absolute_offset = direct_ref_variant.field

# get the index of the child relation from a sequence of field_offsets
child_relation_index = bisect.bisect_right(field_offsets, absolute_offset) - 1
child = children[child_relation_index]
# return the field of the child
relative_offset = absolute_offset - field_offsets[child_relation_index]
return child[relative_offset]
result = ExpressionDecompiler._decompile_struct_field(
struct_field,
field_offsets,
children,
)

while struct_field.HasField("child"):
struct_field = struct_field.child.struct_field
result = ExpressionDecompiler._decompile_struct_field(
struct_field,
field_offsets,
# the new child is always the previous result
children=[result],
)
return result

@staticmethod
def decompile_scalar_function(
Expand Down
31 changes: 31 additions & 0 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,37 @@ def table_column(
)


@translate.register(ops.StructField)
def struct_field(
    op: ops.StructField,
    _: ir.ColumnExpr,
    compiler: SubstraitCompiler,
    **kwargs: Any,
) -> stalg.Expression:
    """Translate a struct field access into a nested field reference.

    The struct operand (``op.arg``) is translated first; a new ``struct_field``
    segment holding the index of ``op.field`` is then merged in as the
    ``child`` of the innermost existing segment, and the (mutated) operand
    expression is returned.
    """
    parent = translate(op.arg, compiler=compiler, **kwargs)

    wanted = op.field
    # TODO: store names/types inverse mapping on datatypes.Struct for O(1)
    # access to field index
    position = op.arg.type().names.index(wanted)

    # walk down the chain of child references; the deepest one becomes the
    # parent of the segment added below
    innermost = parent.selection.direct_reference.struct_field
    while innermost.HasField("child"):
        innermost = innermost.child.struct_field

    leaf = stalg.Expression.ReferenceSegment(
        struct_field=stalg.Expression.ReferenceSegment.StructField(field=position)
    )
    innermost.child.MergeFrom(leaf)

    return parent


@translate.register(ops.UnboundTable)
def unbound_table(
op: ops.UnboundTable, expr: ir.TableExpr, _: SubstraitCompiler, **kwargs: Any
Expand Down
63 changes: 60 additions & 3 deletions ibis_substrait/tests/compiler/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,74 @@ def test_ibis_schema_to_substrait_schema():
assert translate(input) == expected


@pytest.mark.parametrize(("name", "expected_offset"), [("a", 0), ("b", 1), ("c", 2)])
def test_simple_field_access(compiler, name, expected_offset):
    """Selecting a top-level column yields a single struct_field reference."""
    schema = {
        "a": "string",
        "b": "struct<a: string, b: int64, c: float64>",
        "c": "array<array<float64>>",
    }
    column = ibis.table(schema)[name]

    reference = {"struct_field": {"field": expected_offset}}
    expected = json_format.ParseDict(
        {"selection": {"direct_reference": reference, "root_reference": {}}},
        stalg.Expression(),
    )

    assert translate(column, compiler) == expected


@pytest.mark.parametrize(("name", "expected_offset"), [("a", 0), ("b", 1), ("c", 2)])
def test_struct_field_access(compiler, name, expected_offset):
    """Accessing a field of a struct column nests a child reference.

    The outer segment points at column ``f`` (offset 0); its child segment
    carries the offset of the requested field inside the struct.
    """
    t = ibis.table(dict(f="struct<a: string, b: int64, c: float64>"))
    expr = t.f[name]
    expected = json_format.ParseDict(
        {
            "selection": {
                "direct_reference": {
                    "struct_field": {
                        "field": 0,
                        "child": {"struct_field": {"field": expected_offset}},
                    }
                },
                "root_reference": {},
            }
        },
        stalg.Expression(),
    )
    result = translate(expr, compiler)
    assert result == expected


def test_nested_struct_field_access(compiler):
t = ibis.table(
dict(f="struct<a: struct<a: int64, b: struct<a: int64, b: int64, c: int64>>>")
)
# 0 0 1 2
expr = t.f["a"]["b"]["c"]
expected = json_format.ParseDict(
{
"selection": {
"direct_reference": {
"struct_field": {
"field": 0,
"child": {
"struct_field": {
"field": 0,
"child": {
"struct_field": {
"field": 1,
"child": {"struct_field": {"field": 2}},
}
},
}
},
}
},
"root_reference": {},
Expand Down

0 comments on commit 26c329a

Please sign in to comment.