Skip to content

Commit ad3b8fa

Browse files
feat(bigframes): Defer unnamed @udf deployment until needed (#17217)
1 parent 4d20bab commit ad3b8fa

28 files changed

Lines changed: 824 additions & 976 deletions

File tree

packages/bigframes/bigframes/core/blocks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,14 +1091,14 @@ def multi_apply_window_op(
10911091

10921092
def multi_apply_unary_op(
10931093
self,
1094-
op: Union[ops.UnaryOp, ex.Expression],
1094+
op: Union[ops.UnaryOp, ops.NaryOp, ex.Expression],
10951095
) -> Block:
1096-
if isinstance(op, ops.UnaryOp):
1096+
if isinstance(op, (ops.UnaryOp, ops.NaryOp)):
10971097
input_varname = guid.generate_guid()
10981098
expr = op.as_expr(ex.free_var(input_varname))
10991099
else:
11001100
input_varnames = op.free_variables
1101-
assert len(input_varnames) == 1
1101+
assert len(set(input_varnames)) == 1
11021102
expr = op
11031103
input_varname = input_varnames[0]
11041104

packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,44 +1034,8 @@ def timedelta_floor_op_impl(x: ibis_types.NumericValue):
10341034
return ibis_api.case().when(x > ibis.literal(0), x.floor()).else_(x.ceil()).end()
10351035

10361036

1037-
@scalar_op_compiler.register_unary_op(ops.RemoteFunctionOp, pass_op=True)
1038-
def remote_function_op_impl(x: ibis_types.Value, op: ops.RemoteFunctionOp):
1039-
udf_sig = op.function_def.signature
1040-
assert not udf_sig.is_virtual # should have been devirtualized in lowering pass
1041-
ibis_py_sig = (tuple(arg.py_type for arg in udf_sig.inputs), udf_sig.output.py_type)
1042-
1043-
@ibis_udf.scalar.builtin(
1044-
name=str(op.function_def.routine_ref), signature=ibis_py_sig
1045-
)
1046-
def udf(input): ...
1047-
1048-
x_transformed = udf(x)
1049-
if not op.apply_on_null:
1050-
return ibis_api.case().when(x.isnull(), x).else_(x_transformed).end()
1051-
return x_transformed
1052-
1053-
1054-
@scalar_op_compiler.register_binary_op(ops.BinaryRemoteFunctionOp, pass_op=True)
1055-
def binary_remote_function_op_impl(
1056-
x: ibis_types.Value, y: ibis_types.Value, op: ops.BinaryRemoteFunctionOp
1057-
):
1058-
udf_sig = op.function_def.signature
1059-
assert not udf_sig.is_virtual # should have been devirtualized in lowering pass
1060-
ibis_py_sig = (tuple(arg.py_type for arg in udf_sig.inputs), udf_sig.output.py_type)
1061-
1062-
@ibis_udf.scalar.builtin(
1063-
name=str(op.function_def.routine_ref), signature=ibis_py_sig
1064-
)
1065-
def udf(input1, input2): ...
1066-
1067-
x_transformed = udf(x, y)
1068-
return x_transformed
1069-
1070-
1071-
@scalar_op_compiler.register_nary_op(ops.NaryRemoteFunctionOp, pass_op=True)
1072-
def nary_remote_function_op_impl(
1073-
*operands: ibis_types.Value, op: ops.NaryRemoteFunctionOp
1074-
):
1037+
@scalar_op_compiler.register_nary_op(ops.RemoteFunctionOp, pass_op=True)
1038+
def remote_function_op_impl(*values: ibis_types.Value, op: ops.RemoteFunctionOp):
10751039
udf_sig = op.function_def.signature
10761040
assert not udf_sig.is_virtual # should have been devirtualized in lowering pass
10771041
ibis_py_sig = (tuple(arg.py_type for arg in udf_sig.inputs), udf_sig.output.py_type)
@@ -1084,8 +1048,7 @@ def nary_remote_function_op_impl(
10841048
)
10851049
def udf(*inputs): ...
10861050

1087-
result = udf(*operands)
1088-
return result
1051+
return udf(*values)
10891052

10901053

10911054
@scalar_op_compiler.register_unary_op(ops.MapOp, pass_op=True)

packages/bigframes/bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -186,33 +186,9 @@ def _get_remote_function_name(op):
186186
)
187187

188188

189-
@register_unary_op(ops.RemoteFunctionOp, pass_op=True)
190-
def _(expr: TypedExpr, op: ops.RemoteFunctionOp) -> sge.Expression:
191-
func_name = _get_remote_function_name(op)
192-
func = sge.func(func_name, expr.expr)
193-
194-
if not op.apply_on_null:
195-
return sge.If(
196-
this=sge.Is(this=expr.expr, expression=sge.Null()),
197-
true=expr.expr,
198-
false=func,
199-
)
200-
201-
return func
202-
203-
204-
@register_binary_op(ops.BinaryRemoteFunctionOp, pass_op=True)
205-
def _(
206-
left: TypedExpr, right: TypedExpr, op: ops.BinaryRemoteFunctionOp
207-
) -> sge.Expression:
208-
func_name = _get_remote_function_name(op)
209-
return sge.func(func_name, left.expr, right.expr)
210-
211-
212-
@register_nary_op(ops.NaryRemoteFunctionOp, pass_op=True)
213-
def _(*operands: TypedExpr, op: ops.NaryRemoteFunctionOp) -> sge.Expression:
214-
func_name = _get_remote_function_name(op)
215-
return sge.func(func_name, *(operand.expr for operand in operands))
189+
@register_nary_op(ops.RemoteFunctionOp, pass_op=True)
190+
def _(*values: TypedExpr, op: ops.RemoteFunctionOp) -> sge.Expression:
191+
return sge.func(_get_remote_function_name(op), *(value.expr for value in values))
216192

217193

218194
@register_nary_op(ops.case_when_op)

packages/bigframes/bigframes/core/rewrite/udfs.py

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -32,55 +32,14 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
3232
func_def = expr.op.function_def
3333
devirtualized_expr = ops.RemoteFunctionOp(
3434
func_def.with_devirtualize(),
35-
apply_on_null=expr.op.apply_on_null,
3635
).as_expr(*expr.children)
3736
if isinstance(func_def.signature.output, udf_def.VirtualListTypeV1):
3837
return func_def.signature.output.out_expr(devirtualized_expr)
3938
else:
4039
return devirtualized_expr
4140

4241

43-
@dataclasses.dataclass
44-
class LowerBinaryRemoteFunctionRule(op_lowering.OpLoweringRule):
45-
@property
46-
def op(self) -> type[ops.ScalarOp]:
47-
return ops.BinaryRemoteFunctionOp
48-
49-
def lower(self, expr: expression.OpExpression) -> expression.Expression:
50-
assert isinstance(expr.op, ops.BinaryRemoteFunctionOp)
51-
func_def = expr.op.function_def
52-
devirtualized_expr = ops.BinaryRemoteFunctionOp(
53-
func_def.with_devirtualize(),
54-
).as_expr(*expr.children)
55-
if isinstance(func_def.signature.output, udf_def.VirtualListTypeV1):
56-
return func_def.signature.output.out_expr(devirtualized_expr)
57-
else:
58-
return devirtualized_expr
59-
60-
61-
@dataclasses.dataclass
62-
class LowerNaryRemoteFunctionRule(op_lowering.OpLoweringRule):
63-
@property
64-
def op(self) -> type[ops.ScalarOp]:
65-
return ops.NaryRemoteFunctionOp
66-
67-
def lower(self, expr: expression.OpExpression) -> expression.Expression:
68-
assert isinstance(expr.op, ops.NaryRemoteFunctionOp)
69-
func_def = expr.op.function_def
70-
devirtualized_expr = ops.NaryRemoteFunctionOp(
71-
func_def.with_devirtualize(),
72-
).as_expr(*expr.children)
73-
if isinstance(func_def.signature.output, udf_def.VirtualListTypeV1):
74-
return func_def.signature.output.out_expr(devirtualized_expr)
75-
else:
76-
return devirtualized_expr
77-
78-
79-
UDF_LOWERING_RULES = (
80-
LowerRemoteFunctionRule(),
81-
LowerBinaryRemoteFunctionRule(),
82-
LowerNaryRemoteFunctionRule(),
83-
)
42+
UDF_LOWERING_RULES = (LowerRemoteFunctionRule(),)
8443

8544

8645
def lower_udfs(root: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode:

packages/bigframes/bigframes/dataframe.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4694,12 +4694,14 @@ def map(self, func, na_action: Optional[str] = None) -> DataFrame:
46944694
if na_action not in {None, "ignore"}:
46954695
raise ValueError(f"na_action={na_action} not supported")
46964696

4697-
# TODO(shobs): Support **kwargs
4698-
return self._apply_unary_op(
4699-
ops.RemoteFunctionOp(
4700-
function_def=func.udf_def, apply_on_null=(na_action is None)
4697+
expr = ops.func_to_op(func).as_expr(ex.free_var("input"))
4698+
if na_action == "ignore":
4699+
# True case, predicate, False case
4700+
expr = ops.where_op.as_expr(
4701+
expr, ops.notnull_op.as_expr(ex.free_var("input")), ex.const(None)
47014702
)
4702-
)
4703+
4704+
return DataFrame(self._block.multi_apply_unary_op(expr))
47034705

47044706
def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
47054707
# In Bigframes BigQuery function, DataFrame '.apply' method is specifically
@@ -4770,17 +4772,11 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
47704772
)
47714773

47724774
# Apply the function
4773-
if args:
4774-
result_series = rows_as_json_series._apply_nary_op(
4775-
ops.NaryRemoteFunctionOp(function_def=func.udf_def),
4776-
list(args),
4777-
)
4778-
else:
4779-
result_series = rows_as_json_series._apply_unary_op(
4780-
ops.RemoteFunctionOp(
4781-
function_def=func.udf_def, apply_on_null=True
4782-
)
4783-
)
4775+
result_series = rows_as_json_series._apply_nary_op(
4776+
ops.func_to_op(func),
4777+
list(args),
4778+
)
4779+
47844780
else:
47854781
# This is a special case where we are providing not-pandas-like
47864782
# extension. If the bigquery function can take one or more
@@ -4838,7 +4834,7 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
48384834
series_list = [self[col] for col in self.columns]
48394835
op_list = series_list[1:] + list(args)
48404836
result_series = series_list[0]._apply_nary_op(
4841-
ops.NaryRemoteFunctionOp(function_def=func.udf_def), op_list
4837+
ops.func_to_op(func), op_list
48424838
)
48434839
result_series.name = None
48444840

0 commit comments

Comments
 (0)