From d850da6364b98c4e01120725e1e609ad8f6c1263 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a=20=28Swast=29?= Date: Sun, 26 May 2024 15:10:36 -0500 Subject: [PATCH] feat: allow functions decorated with `@bpd.remote_function` to execute locally (#704) * feat: allow functions decorated with `@bpd.remote_function` to execute locally * fix read_gbq_function * fix for rare case where re-deploy exact same function object --- bigframes/core/compile/scalar_op_compiler.py | 10 +-- bigframes/functions/remote_function.py | 52 ++++++++++----- tests/system/large/test_remote_function.py | 5 +- tests/system/small/test_remote_function.py | 70 ++++++++++++++------ 4 files changed, 94 insertions(+), 43 deletions(-) diff --git a/bigframes/core/compile/scalar_op_compiler.py b/bigframes/core/compile/scalar_op_compiler.py index e8e5a1f3a..a79a4ecea 100644 --- a/bigframes/core/compile/scalar_op_compiler.py +++ b/bigframes/core/compile/scalar_op_compiler.py @@ -856,11 +856,12 @@ def to_timestamp_op_impl(x: ibis_types.Value, op: ops.ToTimestampOp): @scalar_op_compiler.register_unary_op(ops.RemoteFunctionOp, pass_op=True) def remote_function_op_impl(x: ibis_types.Value, op: ops.RemoteFunctionOp): - if not hasattr(op.func, "bigframes_remote_function"): + ibis_node = getattr(op.func, "ibis_node", None) + if ibis_node is None: raise TypeError( f"only a bigframes remote function is supported as a callable. {constants.FEEDBACK_LINK}" ) - x_transformed = op.func(x) + x_transformed = ibis_node(x) if not op.apply_on_null: x_transformed = ibis.case().when(x.isnull(), x).else_(x_transformed).end() return x_transformed @@ -1342,11 +1343,12 @@ def minimum_impl( def binary_remote_function_op_impl( x: ibis_types.Value, y: ibis_types.Value, op: ops.BinaryRemoteFunctionOp ): - if not hasattr(op.func, "bigframes_remote_function"): + ibis_node = getattr(op.func, "ibis_node", None) + if ibis_node is None: raise TypeError( f"only a bigframes remote function is supported as a callable. {constants.FEEDBACK_LINK}" ) - x_transformed = op.func(x, y) + x_transformed = ibis_node(x, y) return x_transformed diff --git a/bigframes/functions/remote_function.py b/bigframes/functions/remote_function.py index af4dd5982..fb4e3f2f3 100644 --- a/bigframes/functions/remote_function.py +++ b/bigframes/functions/remote_function.py @@ -1013,11 +1013,11 @@ def remote_function( bq_connection_manager = None if session is None else session.bqconnectionmanager - def wrapper(f): + def wrapper(func): nonlocal input_types, output_type - if not callable(f): - raise TypeError("f must be callable, got {}".format(f)) + if not callable(func): + raise TypeError("f must be callable, got {}".format(func)) if sys.version_info >= (3, 10): # Add `eval_str = True` so that deferred annotations are turned into their @@ -1028,7 +1028,7 @@ def wrapper(f): signature_kwargs = {} signature = inspect.signature( - f, + func, **signature_kwargs, ) @@ -1089,8 +1089,23 @@ def wrapper(f): session=session, # type: ignore ) + # In the unlikely case where the user is trying to re-deploy the same + # function, cleanup the attributes we add below, first. This prevents + # the pickle from having dependencies that might not otherwise be + # present such as ibis or pandas. + def try_delattr(attr): + try: + delattr(func, attr) + except AttributeError: + pass + + try_delattr("bigframes_cloud_function") + try_delattr("bigframes_remote_function") + try_delattr("output_dtype") + try_delattr("ibis_node") + rf_name, cf_name = remote_function_client.provision_bq_remote_function( - f, + func, ibis_signature.input_types, ibis_signature.output_type, reuse, @@ -1105,19 +1120,20 @@ def wrapper(f): # TODO: Move ibis logic to compiler step node = ibis.udf.scalar.builtin( - f, + func, name=rf_name, schema=f"{dataset_ref.project}.{dataset_ref.dataset_id}", signature=(ibis_signature.input_types, ibis_signature.output_type), ) - node.bigframes_cloud_function = ( + func.bigframes_cloud_function = ( remote_function_client.get_cloud_function_fully_qualified_name(cf_name) ) - node.bigframes_remote_function = str(dataset_ref.routine(rf_name)) # type: ignore - node.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype( + func.bigframes_remote_function = str(dataset_ref.routine(rf_name)) # type: ignore + func.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype( ibis_signature.output_type ) - return node + func.ibis_node = node + return func return wrapper @@ -1168,19 +1184,23 @@ def read_gbq_function( # The name "args" conflicts with the Ibis operator, so we use # non-standard names for the arguments here. - def node(*ignored_args, **ignored_kwargs): + def func(*ignored_args, **ignored_kwargs): f"""Remote function {str(routine_ref)}.""" + # TODO(swast): Construct an ibis client from bigquery_client and + # execute node via a query. # TODO: Move ibis logic to compiler step - node.__name__ = routine_ref.routine_id + func.__name__ = routine_ref.routine_id + node = ibis.udf.scalar.builtin( - node, + func, name=routine_ref.routine_id, schema=f"{routine_ref.project}.{routine_ref.dataset_id}", signature=(ibis_signature.input_types, ibis_signature.output_type), ) - node.bigframes_remote_function = str(routine_ref) # type: ignore - node.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype( # type: ignore + func.bigframes_remote_function = str(routine_ref) # type: ignore + func.output_dtype = bigframes.dtypes.ibis_dtype_to_bigframes_dtype( # type: ignore ibis_signature.output_type ) - return node + func.ibis_node = node # type: ignore + return func diff --git a/tests/system/large/test_remote_function.py b/tests/system/large/test_remote_function.py index cac8483b5..4114eaae0 100644 --- a/tests/system/large/test_remote_function.py +++ b/tests/system/large/test_remote_function.py @@ -194,6 +194,9 @@ def test_remote_function_stringify_with_ibis( def stringify(x): return f"I got {x}" + # Function should work locally. + assert stringify(42) == "I got 42" + _, dataset_name, table_name = scalars_table_id.split(".") if not ibis_client.dataset: ibis_client.dataset = dataset_name @@ -205,7 +208,7 @@ def stringify(x): pandas_df_orig = bigquery_client.query(sql).to_dataframe() col = table[col_name] - col_2x = stringify(col).name("int64_str_col") + col_2x = stringify.ibis_node(col).name("int64_str_col") table = table.mutate([col_2x]) sql = table.compile() pandas_df_new = bigquery_client.query(sql).to_dataframe() diff --git a/tests/system/small/test_remote_function.py b/tests/system/small/test_remote_function.py index 4a39e75ff..096a26844 100644 --- a/tests/system/small/test_remote_function.py +++ b/tests/system/small/test_remote_function.py @@ -67,7 +67,7 @@ def bq_cf_connection_location_project(bigquery_client) -> str: @pytest.fixture(scope="module") def bq_cf_connection_location_project_mismatched() -> str: - """Pre-created BQ connection in the migframes-metrics project in US location, + """Pre-created BQ connection in the bigframes-metrics project in US location, in format PROJECT_ID.LOCATION.CONNECTION_NAME, used to invoke cloud function. $ bq show --connection --location=us --project_id=PROJECT_ID bigframes-rf-conn @@ -108,11 +108,15 @@ def test_remote_function_direct_no_session_param( reuse=True, ) def square(x): - # This executes on a remote function, where coverage isn't tracked. - return x * x # pragma: NO COVER + return x * x - assert square.bigframes_remote_function - assert square.bigframes_cloud_function + # Function should still work normally. + assert square(2) == 4 + + # Function should have extra metadata attached for remote execution. + assert hasattr(square, "bigframes_remote_function") + assert hasattr(square, "bigframes_cloud_function") + assert hasattr(square, "ibis_node") scalars_df, scalars_pandas_df = scalars_dfs @@ -161,8 +165,10 @@ def test_remote_function_direct_no_session_param_location_specified( reuse=True, ) def square(x): - # This executes on a remote function, where coverage isn't tracked. - return x * x # pragma: NO COVER + return x * x + + # Function should still work normally. + assert square(2) == 4 scalars_df, scalars_pandas_df = scalars_dfs @@ -197,7 +203,10 @@ def test_remote_function_direct_no_session_param_location_mismatched( dataset_id_permanent, bq_cf_connection_location_mismatched, ): - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match=re.escape("The location does not match BigQuery connection location:"), + ): @rf.remote_function( [int], @@ -212,7 +221,8 @@ def test_remote_function_direct_no_session_param_location_mismatched( reuse=True, ) def square(x): - # This executes on a remote function, where coverage isn't tracked. + # Not expected to reach this code, as the location of the + # connection doesn't match the location of the dataset. return x * x # pragma: NO COVER @@ -239,8 +249,10 @@ def test_remote_function_direct_no_session_param_location_project_specified( reuse=True, ) def square(x): - # This executes on a remote function, where coverage isn't tracked. - return x * x # pragma: NO COVER + return x * x + + # Function should still work normally. + assert square(2) == 4 scalars_df, scalars_pandas_df = scalars_dfs @@ -275,7 +287,12 @@ def test_remote_function_direct_no_session_param_project_mismatched( dataset_id_permanent, bq_cf_connection_location_project_mismatched, ): - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match=re.escape( + "The project_id does not match BigQuery connection gcp_project_id:" + ), + ): @rf.remote_function( [int], @@ -290,7 +307,8 @@ def test_remote_function_direct_no_session_param_project_mismatched( reuse=True, ) def square(x): - # This executes on a remote function, where coverage isn't tracked. + # Not expected to reach this code, as the project of the + # connection doesn't match the project of the dataset. return x * x # pragma: NO COVER @@ -302,8 +320,10 @@ def test_remote_function_direct_session_param(session_with_bq_connection, scalar session=session_with_bq_connection, ) def square(x): - # This executes on a remote function, where coverage isn't tracked. - return x * x # pragma: NO COVER + return x * x + + # Function should still work normally. + assert square(2) == 4 scalars_df, scalars_pandas_df = scalars_dfs @@ -340,8 +360,10 @@ def test_remote_function_via_session_default(session_with_bq_connection, scalars # cloud function would be common and quickly reused. @session_with_bq_connection.remote_function([int], int) def square(x): - # This executes on a remote function, where coverage isn't tracked. - return x * x # pragma: NO COVER + return x * x + + # Function should still work normally. + assert square(2) == 4 scalars_df, scalars_pandas_df = scalars_dfs @@ -380,8 +402,10 @@ def test_remote_function_via_session_with_overrides( reuse=True, ) def square(x): - # This executes on a remote function, where coverage isn't tracked. - return x * x # pragma: NO COVER + return x * x + + # Function should still work normally. + assert square(2) == 4 scalars_df, scalars_pandas_df = scalars_dfs @@ -508,7 +532,7 @@ def test_skip_bq_connection_check(dataset_id_permanent): @session.remote_function([int], int, dataset=dataset_id_permanent) def add_one(x): - # This executes on a remote function, where coverage isn't tracked. + # Not expected to reach this code, as the connection doesn't exist. return x + 1 # pragma: NO COVER @@ -546,8 +570,10 @@ def test_read_gbq_function_like_original( reuse=True, ) def square1(x): - # This executes on a remote function, where coverage isn't tracked. - return x * x # pragma: NO COVER + return x * x + + # Function should still work normally. + assert square1(2) == 4 square2 = rf.read_gbq_function( function_name=square1.bigframes_remote_function,