Skip to content

Commit 45935b7

Browse files
committed
feat(datafusion): add support for scalar pyarrow UDFs
1 parent 3283333 commit 45935b7

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

ibis/backends/datafusion/compiler.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import ibis.common.exceptions as com
1313
import ibis.expr.datatypes as dt
1414
import ibis.expr.operations as ops
15+
from ibis.expr.operations.udf import InputType
1516
from ibis.formats.pyarrow import PyArrowType
1617

1718

@@ -467,6 +468,23 @@ def elementwise_udf(op):
467468
return udf(*args)
468469

469470

471+
@translate.register(ops.ScalarUDF)
472+
def scalar_udf(op):
473+
if (input_type := op.__input_type__) != InputType.PYARROW:
474+
raise NotImplementedError(
475+
f"DataFusion only supports pyarrow UDFs: got a {input_type.name.lower()} UDF"
476+
)
477+
udf = df.udf(
478+
op.__func__,
479+
input_types=[PyArrowType.from_ibis(arg.output_dtype) for arg in op.args],
480+
return_type=PyArrowType.from_ibis(op.output_dtype),
481+
volatility="volatile",
482+
)
483+
args = map(translate, op.args)
484+
485+
return udf(*args)
486+
487+
470488
@translate.register(ops.StringConcat)
471489
def string_concat(op):
472490
return df.functions.concat(*map(translate, op.arg))

ibis/backends/tests/test_udf.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
"bigquery",
1414
"clickhouse",
1515
"dask",
16-
"datafusion",
1716
"druid",
1817
"impala",
1918
"mssql",
@@ -29,6 +28,7 @@
2928

3029

3130
@no_python_udfs
31+
@mark.notyet(["datafusion"], raises=NotImplementedError)
3232
def test_udf(batting):
3333
@udf.scalar.python
3434
def num_vowels(s: str, include_y: bool = False) -> int:
@@ -49,6 +49,7 @@ def num_vowels(s: str, include_y: bool = False) -> int:
4949
@mark.notyet(
5050
["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
5151
)
52+
@mark.notyet(["datafusion"], raises=NotImplementedError)
5253
@mark.xfail(
5354
sys.version_info[:2] < (3, 9), reason="annotations not supported with Python 3.8"
5455
)
@@ -73,6 +74,7 @@ def num_vowels_map(s: str, include_y: bool = False) -> dict[str, int]:
7374
@mark.notyet(
7475
["postgres"], raises=TypeError, reason="postgres only supports map<string, string>"
7576
)
77+
@mark.notyet(["datafusion"], raises=NotImplementedError)
7678
@mark.xfail(
7779
sys.version_info[:2] < (3, 9), reason="annotations not supported with Python 3.8"
7880
)
@@ -141,9 +143,9 @@ def add_one_pyarrow(s: int) -> int: # s is series, int is the element type
141143
add_one_pandas,
142144
marks=[
143145
mark.notyet(
144-
["duckdb"],
146+
["duckdb", "datafusion"],
145147
raises=NotImplementedError,
146-
reason="duckdb doesn't support pandas UDFs",
148+
reason="backend doesn't support pandas UDFs",
147149
),
148150
],
149151
),
@@ -153,7 +155,7 @@ def add_one_pyarrow(s: int) -> int: # s is series, int is the element type
153155
mark.notyet(
154156
["snowflake"],
155157
raises=NotImplementedError,
156-
reason="snowflake doesn't support pyarrow UDFs",
158+
reason="backend doesn't support pyarrow UDFs",
157159
)
158160
],
159161
),

0 commit comments

Comments
 (0)