Skip to content

Commit 1f9ee37

Browse files
authored
feat: add bigquery.ml.transform function (#2394)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent e90e1d8 commit 1f9ee37

File tree

6 files changed

+78
-0
lines changed

6 files changed

+78
-0
lines changed

bigframes/bigquery/_operations/ml.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,3 +393,41 @@ def global_explain(
393393
return bpd.read_gbq_query(sql)
394394
else:
395395
return session.read_gbq_query(sql)
396+
397+
398+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
399+
def transform(
400+
model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
401+
input_: Union[pd.DataFrame, dataframe.DataFrame, str],
402+
) -> dataframe.DataFrame:
403+
"""
404+
Transforms input data using a BigQuery ML model.
405+
406+
See the `BigQuery ML TRANSFORM function syntax
407+
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-transform>`_
408+
for additional reference.
409+
410+
Args:
411+
model (bigframes.ml.base.BaseEstimator or str):
412+
The model to use for transformation.
413+
input_ (Union[bigframes.pandas.DataFrame, str]):
414+
The DataFrame or query to use for transformation.
415+
416+
Returns:
417+
bigframes.pandas.DataFrame:
418+
The transformed data.
419+
"""
420+
import bigframes.pandas as bpd
421+
422+
model_name, session = _get_model_name_and_session(model, input_)
423+
table_sql = _to_sql(input_)
424+
425+
sql = bigframes.core.sql.ml.transform(
426+
model_name=model_name,
427+
table=table_sql,
428+
)
429+
430+
if session is None:
431+
return bpd.read_gbq_query(sql)
432+
else:
433+
return session.read_gbq_query(sql)

bigframes/bigquery/ml.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
explain_predict,
2626
global_explain,
2727
predict,
28+
transform,
2829
)
2930

3031
__all__ = [
@@ -33,4 +34,5 @@
3334
"predict",
3435
"explain_predict",
3536
"global_explain",
37+
"transform",
3638
]

bigframes/core/sql/ml.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,14 @@ def global_explain(
213213
sql += _build_struct_sql(struct_options)
214214
sql += ")\n"
215215
return sql
216+
217+
218+
def transform(
219+
model_name: str,
220+
table: str,
221+
) -> str:
222+
"""Encode the ML.TRANSFORM statement.
223+
See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-transform for reference.
224+
"""
225+
sql = f"SELECT * FROM ML.TRANSFORM(MODEL {googlesql.identifier(model_name)}, ({table}))\n"
226+
return sql

tests/unit/bigquery/test_ml.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,21 @@ def test_global_explain_with_pandas_series_model(read_gbq_query_mock):
145145
generated_sql = read_gbq_query_mock.call_args[0][0]
146146
assert "ML.GLOBAL_EXPLAIN" in generated_sql
147147
assert f"MODEL `{MODEL_NAME}`" in generated_sql
148+
149+
150+
@mock.patch("bigframes.pandas.read_gbq_query")
151+
@mock.patch("bigframes.pandas.read_pandas")
152+
def test_transform_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mock):
153+
df = pd.DataFrame({"col1": [1, 2, 3]})
154+
read_pandas_mock.return_value._to_sql_query.return_value = (
155+
"SELECT * FROM `pandas_df`",
156+
[],
157+
[],
158+
)
159+
ml_ops.transform(MODEL_SERIES, input_=df)
160+
read_pandas_mock.assert_called_once()
161+
read_gbq_query_mock.assert_called_once()
162+
generated_sql = read_gbq_query_mock.call_args[0][0]
163+
assert "ML.TRANSFORM" in generated_sql
164+
assert f"MODEL `{MODEL_NAME}`" in generated_sql
165+
assert "(SELECT * FROM `pandas_df`)" in generated_sql
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
SELECT * FROM ML.TRANSFORM(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data))

tests/unit/core/sql/test_ml.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,11 @@ def test_global_explain_model_with_options(snapshot):
169169
class_level_explain=True,
170170
)
171171
snapshot.assert_match(sql, "global_explain_model_with_options.sql")
172+
173+
174+
def test_transform_model_basic(snapshot):
175+
sql = bigframes.core.sql.ml.transform(
176+
model_name="my_project.my_dataset.my_model",
177+
table="SELECT * FROM new_data",
178+
)
179+
snapshot.assert_match(sql, "transform_model_basic.sql")

0 commit comments

Comments
 (0)