Skip to content

Commit

Permalink
feat: allow functions decorated with @bpd.remote_function to execute locally (#704)
Browse files Browse the repository at this point in the history

* feat: allow functions decorated with `@bpd.remote_function` to execute locally

* fix read_gbq_function

* fix for the rare case where the exact same function object is re-deployed
  • Loading branch information
tswast committed May 26, 2024
1 parent 4a12e3c commit d850da6
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 43 deletions.
10 changes: 6 additions & 4 deletions bigframes/core/compile/scalar_op_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,11 +856,12 @@ def to_timestamp_op_impl(x: ibis_types.Value, op: ops.ToTimestampOp):

@scalar_op_compiler.register_unary_op(ops.RemoteFunctionOp, pass_op=True)
def remote_function_op_impl(x: ibis_types.Value, op: ops.RemoteFunctionOp):
if not hasattr(op.func, "bigframes_remote_function"):
ibis_node = getattr(op.func, "ibis_node", None)
if ibis_node is None:
raise TypeError(
f"only a bigframes remote function is supported as a callable. {constants.FEEDBACK_LINK}"
)
x_transformed = op.func(x)
x_transformed = ibis_node(x)
if not op.apply_on_null:
x_transformed = ibis.case().when(x.isnull(), x).else_(x_transformed).end()
return x_transformed
Expand Down Expand Up @@ -1342,11 +1343,12 @@ def minimum_impl(
def binary_remote_function_op_impl(
x: ibis_types.Value, y: ibis_types.Value, op: ops.BinaryRemoteFunctionOp
):
if not hasattr(op.func, "bigframes_remote_function"):
ibis_node = getattr(op.func, "ibis_node", None)
if ibis_node is None:
raise TypeError(
f"only a bigframes remote function is supported as a callable. {constants.FEEDBACK_LINK}"
)
x_transformed = op.func(x, y)
x_transformed = ibis_node(x, y)
return x_transformed


Expand Down
52 changes: 36 additions & 16 deletions bigframes/functions/remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,11 +1013,11 @@ def remote_function(

bq_connection_manager = None if session is None else session.bqconnectionmanager

def wrapper(f):
def wrapper(func):
nonlocal input_types, output_type

if not callable(f):
raise TypeError("f must be callable, got {}".format(f))
if not callable(func):
raise TypeError("f must be callable, got {}".format(func))

if sys.version_info >= (3, 10):
# Add `eval_str = True` so that deferred annotations are turned into their
Expand All @@ -1028,7 +1028,7 @@ def wrapper(f):
signature_kwargs = {}

signature = inspect.signature(
f,
func,
**signature_kwargs,
)

Expand Down Expand Up @@ -1089,8 +1089,23 @@ def wrapper(f):
session=session, # type: ignore
)

# In the unlikely case where the user is trying to re-deploy the same
# function, cleanup the attributes we add below, first. This prevents
# the pickle from having dependencies that might not otherwise be
# present such as ibis or pandas.
def try_delattr(attr):
try:
delattr(func, attr)
except AttributeError:
pass

try_delattr("bigframes_cloud_function")
try_delattr("bigframes_remote_function")
try_delattr("output_dtype")
try_delattr("ibis_node")

rf_name, cf_name = remote_function_client.provision_bq_remote_function(
f,
func,
ibis_signature.input_types,
ibis_signature.output_type,
reuse,
Expand All @@ -1105,19 +1120,20 @@ def wrapper(f):

# TODO: Move ibis logic to compiler step
node = ibis.udf.scalar.builtin(
f,
func,
name=rf_name,
schema=f"{dataset_ref.project}.{dataset_ref.dataset_id}",
signature=(ibis_signature.input_types, ibis_signature.output_type),
)
node.bigframes_cloud_function = (
func.bigframes_cloud_function = (
remote_function_client.get_cloud_function_fully_qualified_name(cf_name)
)
node.bigframes_remote_function = str(dataset_ref.routine(rf_name)) # type: ignore
node.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype(
func.bigframes_remote_function = str(dataset_ref.routine(rf_name)) # type: ignore
func.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype(
ibis_signature.output_type
)
return node
func.ibis_node = node
return func

return wrapper

Expand Down Expand Up @@ -1168,19 +1184,23 @@ def read_gbq_function(

# The name "args" conflicts with the Ibis operator, so we use
# non-standard names for the arguments here.
def node(*ignored_args, **ignored_kwargs):
def func(*ignored_args, **ignored_kwargs):
f"""Remote function {str(routine_ref)}."""
# TODO(swast): Construct an ibis client from bigquery_client and
# execute node via a query.

# TODO: Move ibis logic to compiler step
node.__name__ = routine_ref.routine_id
func.__name__ = routine_ref.routine_id

node = ibis.udf.scalar.builtin(
node,
func,
name=routine_ref.routine_id,
schema=f"{routine_ref.project}.{routine_ref.dataset_id}",
signature=(ibis_signature.input_types, ibis_signature.output_type),
)
node.bigframes_remote_function = str(routine_ref) # type: ignore
node.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype( # type: ignore
func.bigframes_remote_function = str(routine_ref) # type: ignore
func.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype( # type: ignore
ibis_signature.output_type
)
return node
func.ibis_node = node # type: ignore
return func
5 changes: 4 additions & 1 deletion tests/system/large/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,9 @@ def test_remote_function_stringify_with_ibis(
def stringify(x):
return f"I got {x}"

# Function should work locally.
assert stringify(42) == "I got 42"

_, dataset_name, table_name = scalars_table_id.split(".")
if not ibis_client.dataset:
ibis_client.dataset = dataset_name
Expand All @@ -205,7 +208,7 @@ def stringify(x):
pandas_df_orig = bigquery_client.query(sql).to_dataframe()

col = table[col_name]
col_2x = stringify(col).name("int64_str_col")
col_2x = stringify.ibis_node(col).name("int64_str_col")
table = table.mutate([col_2x])
sql = table.compile()
pandas_df_new = bigquery_client.query(sql).to_dataframe()
Expand Down
70 changes: 48 additions & 22 deletions tests/system/small/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def bq_cf_connection_location_project(bigquery_client) -> str:

@pytest.fixture(scope="module")
def bq_cf_connection_location_project_mismatched() -> str:
"""Pre-created BQ connection in the migframes-metrics project in US location,
"""Pre-created BQ connection in the bigframes-metrics project in US location,
in format PROJECT_ID.LOCATION.CONNECTION_NAME, used to invoke cloud function.
$ bq show --connection --location=us --project_id=PROJECT_ID bigframes-rf-conn
Expand Down Expand Up @@ -108,11 +108,15 @@ def test_remote_function_direct_no_session_param(
reuse=True,
)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
return x * x # pragma: NO COVER
return x * x

assert square.bigframes_remote_function
assert square.bigframes_cloud_function
# Function should still work normally.
assert square(2) == 4

# Function should have extra metadata attached for remote execution.
assert hasattr(square, "bigframes_remote_function")
assert hasattr(square, "bigframes_cloud_function")
assert hasattr(square, "ibis_node")

scalars_df, scalars_pandas_df = scalars_dfs

Expand Down Expand Up @@ -161,8 +165,10 @@ def test_remote_function_direct_no_session_param_location_specified(
reuse=True,
)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
return x * x # pragma: NO COVER
return x * x

# Function should still work normally.
assert square(2) == 4

scalars_df, scalars_pandas_df = scalars_dfs

Expand Down Expand Up @@ -197,7 +203,10 @@ def test_remote_function_direct_no_session_param_location_mismatched(
dataset_id_permanent,
bq_cf_connection_location_mismatched,
):
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=re.escape("The location does not match BigQuery connection location:"),
):

@rf.remote_function(
[int],
Expand All @@ -212,7 +221,8 @@ def test_remote_function_direct_no_session_param_location_mismatched(
reuse=True,
)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
# Not expected to reach this code, as the location of the
# connection doesn't match the location of the dataset.
return x * x # pragma: NO COVER


Expand All @@ -239,8 +249,10 @@ def test_remote_function_direct_no_session_param_location_project_specified(
reuse=True,
)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
return x * x # pragma: NO COVER
return x * x

# Function should still work normally.
assert square(2) == 4

scalars_df, scalars_pandas_df = scalars_dfs

Expand Down Expand Up @@ -275,7 +287,12 @@ def test_remote_function_direct_no_session_param_project_mismatched(
dataset_id_permanent,
bq_cf_connection_location_project_mismatched,
):
with pytest.raises(ValueError):
with pytest.raises(
ValueError,
match=re.escape(
"The project_id does not match BigQuery connection gcp_project_id:"
),
):

@rf.remote_function(
[int],
Expand All @@ -290,7 +307,8 @@ def test_remote_function_direct_no_session_param_project_mismatched(
reuse=True,
)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
# Not expected to reach this code, as the project of the
# connection doesn't match the project of the dataset.
return x * x # pragma: NO COVER


Expand All @@ -302,8 +320,10 @@ def test_remote_function_direct_session_param(session_with_bq_connection, scalar
session=session_with_bq_connection,
)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
return x * x # pragma: NO COVER
return x * x

# Function should still work normally.
assert square(2) == 4

scalars_df, scalars_pandas_df = scalars_dfs

Expand Down Expand Up @@ -340,8 +360,10 @@ def test_remote_function_via_session_default(session_with_bq_connection, scalars
# cloud function would be common and quickly reused.
@session_with_bq_connection.remote_function([int], int)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
return x * x # pragma: NO COVER
return x * x

# Function should still work normally.
assert square(2) == 4

scalars_df, scalars_pandas_df = scalars_dfs

Expand Down Expand Up @@ -380,8 +402,10 @@ def test_remote_function_via_session_with_overrides(
reuse=True,
)
def square(x):
# This executes on a remote function, where coverage isn't tracked.
return x * x # pragma: NO COVER
return x * x

# Function should still work normally.
assert square(2) == 4

scalars_df, scalars_pandas_df = scalars_dfs

Expand Down Expand Up @@ -508,7 +532,7 @@ def test_skip_bq_connection_check(dataset_id_permanent):

@session.remote_function([int], int, dataset=dataset_id_permanent)
def add_one(x):
# This executes on a remote function, where coverage isn't tracked.
# Not expected to reach this code, as the connection doesn't exist.
return x + 1 # pragma: NO COVER


Expand Down Expand Up @@ -546,8 +570,10 @@ def test_read_gbq_function_like_original(
reuse=True,
)
def square1(x):
# This executes on a remote function, where coverage isn't tracked.
return x * x # pragma: NO COVER
return x * x

# Function should still work normally.
assert square1(2) == 4

square2 = rf.read_gbq_function(
function_name=square1.bigframes_remote_function,
Expand Down

0 comments on commit d850da6

Please sign in to comment.