Skip to content

Commit

Permalink
refactor(backends): adjust backends to work with new array representa…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
cpcloud authored and kszucs committed Sep 26, 2023
1 parent b91ecf0 commit 90befb2
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 26 deletions.
9 changes: 5 additions & 4 deletions ibis/backends/bigquery/registry.py
Expand Up @@ -214,14 +214,15 @@ def _array_zip(translator, op):

def _array_map(translator, op):
arg = translator.translate(op.arg)
result = translator.translate(op.result)
return f"ARRAY(SELECT {result} FROM UNNEST({arg}) {op.parameter})"
result = translator.translate(op.body)
param = op.param
return f"ARRAY(SELECT {result} FROM UNNEST({arg}) {param})"


def _array_filter(translator, op):
arg = translator.translate(op.arg)
result = translator.translate(op.result)
param = op.parameter
result = translator.translate(op.body)
param = op.param
return f"ARRAY(SELECT {param} FROM UNNEST({arg}) {param} WHERE {result})"


Expand Down
29 changes: 20 additions & 9 deletions ibis/backends/duckdb/registry.py
Expand Up @@ -8,7 +8,6 @@
import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction
from toolz.curried import flip

import ibis.expr.operations as ops
from ibis.backends.base.sql import alchemy
Expand Down Expand Up @@ -225,9 +224,7 @@ def compiles_list_apply(element, compiler, **kw):

def _array_map(t, op):
return array_map(
t.translate(op.arg),
sa.literal_column(f"({op.parameter})"),
t.translate(op.result),
t.translate(op.arg), sa.literal_column(f"({op.param})"), t.translate(op.body)
)


Expand All @@ -239,15 +236,19 @@ def compiles_list_filter(element, compiler, **kw):

def _array_filter(t, op):
return array_filter(
t.translate(op.arg),
sa.literal_column(f"({op.parameter})"),
t.translate(op.result),
t.translate(op.arg), sa.literal_column(f"({op.param})"), t.translate(op.body)
)


def _array_intersect(t, op):
name = "x"
parameter = ops.Argument(
name=name, shape=op.left.shape, dtype=op.left.dtype.value_type
)
return t.translate(
ops.ArrayFilter(op.left, func=lambda x: ops.ArrayContains(op.right, x))
ops.ArrayFilter(
op.left, param=name, body=ops.ArrayContains(op.right, parameter)
)
)


Expand Down Expand Up @@ -372,7 +373,17 @@ def _try_cast(t, op):
),
ops.ArraySort: fixed_arity(sa.func.list_sort, 1),
ops.ArrayRemove: lambda t, op: _array_filter(
t, ops.ArrayFilter(op.arg, flip(ops.NotEquals, op.other))
t,
ops.ArrayFilter(
op.arg,
param="x",
body=ops.NotEquals(
ops.Argument(
name="x", shape=op.arg.shape, dtype=op.arg.dtype.value_type
),
op.other,
),
),
),
ops.ArrayUnion: lambda t, op: t.translate(
ops.ArrayDistinct(ops.ArrayConcat((op.left, op.right)))
Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/postgres/registry.py
Expand Up @@ -574,26 +574,26 @@ def _array_map(t, op):
return sa.func.array(
# this translates to the function call, with column names the same as
# the parameter names in the lambda
sa.select(t.translate(op.result))
sa.select(t.translate(op.body))
.select_from(
# unnest the input array
sa.func.unnest(t.translate(op.arg))
# name the columns of the result the same as the lambda parameter
# so that we can reference them as such in the outer query
.table_valued(op.parameter).render_derived()
.table_valued(op.param).render_derived()
)
.scalar_subquery()
)


def _array_filter(t, op):
param = op.parameter
param = op.param
return sa.func.array(
sa.select(sa.column(param, type_=t.get_sqla_type(op.arg.dtype.value_type)))
.select_from(
sa.func.unnest(t.translate(op.arg)).table_valued(param).render_derived()
)
.where(t.translate(op.result))
.where(t.translate(op.body))
.scalar_subquery()
)

Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/pyspark/compiler.py
Expand Up @@ -1702,7 +1702,7 @@ def compile_array_filter(t, op, **kwargs):
src_column = t.translate(op.arg, **kwargs)
return F.filter(
src_column,
lambda x: t.translate(op.result, arg_columns={op.parameter: x}, **kwargs),
lambda x: t.translate(op.body, arg_columns={op.param: x}, **kwargs),
)


Expand All @@ -1711,7 +1711,7 @@ def compile_array_map(t, op, **kwargs):
src_column = t.translate(op.arg, **kwargs)
return F.transform(
src_column,
lambda x: t.translate(op.result, arg_columns={op.parameter: x}, **kwargs),
lambda x: t.translate(op.body, arg_columns={op.param: x}, **kwargs),
)


Expand Down
11 changes: 4 additions & 7 deletions ibis/backends/trino/registry.py
Expand Up @@ -236,9 +236,7 @@ def compiles_list_apply(element, compiler, **kw):

def _array_map(t, op):
return array_map(
t.translate(op.arg),
sa.literal_column(f"({op.parameter})"),
t.translate(op.result),
t.translate(op.arg), sa.literal_column(f"({op.param})"), t.translate(op.body)
)


Expand All @@ -250,9 +248,7 @@ def compiles_list_filter(element, compiler, **kw):

def _array_filter(t, op):
return array_filter(
t.translate(op.arg),
sa.literal_column(f"({op.parameter})"),
t.translate(op.result),
t.translate(op.arg), sa.literal_column(f"({op.param})"), t.translate(op.body)
)


Expand Down Expand Up @@ -313,8 +309,9 @@ def _try_cast(t, op):


def _array_intersect(t, op):
x = ops.Argument(name="x", shape=op.left.shape, dtype=op.left.dtype.value_type)
return t.translate(
ops.ArrayFilter(op.left, func=lambda x: ops.ArrayContains(op.right, x))
ops.ArrayFilter(op.left, param="x", body=ops.ArrayContains(op.right, x))
)


Expand Down

0 comments on commit 90befb2

Please sign in to comment.