
Commit b925aa2

feat: add bigquery.ai.generate_table function (#2453)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕
1 parent a6aafaa commit b925aa2

File tree: 4 files changed, +163 −0 lines changed


bigframes/bigquery/_operations/ai.py

Lines changed: 95 additions & 0 deletions
@@ -601,6 +601,101 @@ def generate_text(
        return session.read_gbq_query(query)


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_table(
    model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
    data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
    *,
    output_schema: str,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    max_output_tokens: Optional[int] = None,
    stop_sequences: Optional[List[str]] = None,
    request_type: Optional[str] = None,
) -> dataframe.DataFrame:
    """
    Generates a table using a BigQuery ML model.

    See the `AI.GENERATE_TABLE function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-table>`_
    for additional reference.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> # The user is responsible for constructing a DataFrame that contains
        >>> # the necessary columns for the model's prompt. For example, a
        >>> # DataFrame with a 'prompt' column for text classification.
        >>> df = bpd.DataFrame({'prompt': ["some text to classify"]})
        >>> result = bbq.ai.generate_table(
        ...     "project.dataset.model_name",
        ...     data=df,
        ...     output_schema="category STRING"
        ... )  # doctest: +SKIP

    Args:
        model (bigframes.ml.base.BaseEstimator or str):
            The model to use for table generation.
        data (bigframes.pandas.DataFrame or bigframes.pandas.Series):
            The data containing the model prompt. If a Series is provided,
            it is treated as the 'prompt' column. If a DataFrame is provided,
            it must contain a 'prompt' column, or you must rename the column
            you wish to use as the prompt to 'prompt'.
        output_schema (str):
            A string defining the output schema (e.g., "col1 STRING, col2 INT64").
        temperature (float, optional):
            A FLOAT64 value that controls the degree of randomness in token
            selection. The value must be in the range ``[0.0, 1.0]``.
        top_p (float, optional):
            A FLOAT64 value that changes how the model selects tokens for
            output.
        max_output_tokens (int, optional):
            An INT64 value that sets the maximum number of tokens in the
            generated output.
        stop_sequences (List[str], optional):
            An ARRAY<STRING> value that contains the stop sequences for the model.
        request_type (str, optional):
            A STRING value that contains the request type for the model.

    Returns:
        bigframes.pandas.DataFrame:
            The generated table.
    """
    data = _to_dataframe(data, series_rename="prompt")
    model_name, session = bq_utils.get_model_name_and_session(model, data)
    table_sql = bq_utils.to_sql(data)

    struct_fields_bq: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {
        "output_schema": output_schema
    }
    if temperature is not None:
        struct_fields_bq["temperature"] = temperature
    if top_p is not None:
        struct_fields_bq["top_p"] = top_p
    if max_output_tokens is not None:
        struct_fields_bq["max_output_tokens"] = max_output_tokens
    if stop_sequences is not None:
        struct_fields_bq["stop_sequences"] = stop_sequences
    if request_type is not None:
        struct_fields_bq["request_type"] = request_type

    struct_sql = bigframes.core.sql.literals.struct_literal(struct_fields_bq)
    query = f"""
    SELECT *
    FROM AI.GENERATE_TABLE(
        MODEL `{model_name}`,
        ({table_sql}),
        {struct_sql}
    )
    """

    if session is None:
        return bpd.read_gbq_query(query)
    else:
        return session.read_gbq_query(query)


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def if_(
    prompt: PROMPT_TYPE,
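For context (not part of this commit's diff), a minimal usage sketch of the new function with some of the optional generation parameters. The model path, prompt text, and output schema below are placeholders and assume a Gemini-backed remote model already exists in the project:

```python
import bigframes.pandas as bpd
import bigframes.bigquery as bbq

# Prompts go in a 'prompt' column, mirroring the docstring example above.
df = bpd.DataFrame(
    {"prompt": ["List 3 cloud services and a one-line description of each."]}
)

# Optional arguments (temperature, max_output_tokens, ...) are forwarded
# into the STRUCT argument of AI.GENERATE_TABLE.
result = bbq.ai.generate_table(
    "project.dataset.model_name",  # placeholder model path
    data=df,
    output_schema="service STRING, description STRING",
    temperature=0.2,
    max_output_tokens=256,
)

# The columns requested in output_schema should appear in the result.
print(result[["service", "description"]].to_pandas())
```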

bigframes/bigquery/ai.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@
    generate_double,
    generate_embedding,
    generate_int,
    generate_table,
    generate_text,
    if_,
    score,
@@ -37,6 +38,7 @@
    "generate_double",
    "generate_embedding",
    "generate_int",
    "generate_table",
    "generate_text",
    "if_",
    "score",

tests/system/large/bigquery/test_ai.py

Lines changed: 17 additions & 0 deletions
@@ -94,3 +94,20 @@ def test_generate_text_with_options(text_model):
    # It basically asserts that the results are still returned.
    assert len(result) == 2


def test_generate_table(text_model):
    df = bpd.DataFrame(
        {"prompt": ["Generate a table of 2 programming languages and their creators."]}
    )

    result = ai.generate_table(
        text_model,
        df,
        output_schema="language STRING, creator STRING",
    )

    assert "language" in result.columns
    assert "creator" in result.columns
    # The model may not always return the exact number of rows requested.
    assert len(result) > 0
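A hedged follow-up sketch (not in this commit) of how the system test's result could be inspected further; it assumes the same `result` object and that the two requested STRING columns came back populated:

```python
# Materialize the generated rows locally for ad-hoc inspection.
local = result.to_pandas()
print(local[["language", "creator"]])

# Loose sanity check: generated values should be string-like when present.
assert local["language"].dropna().map(lambda v: isinstance(v, str)).all()
```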

tests/unit/bigquery/test_ai.py

Lines changed: 49 additions & 0 deletions
@@ -220,6 +220,55 @@ def test_generate_text_defaults(mock_dataframe, mock_session):
    assert "STRUCT()" in query


def test_generate_table_with_dataframe(mock_dataframe, mock_session):
    model_name = "project.dataset.model"

    bbq.ai.generate_table(
        model_name,
        mock_dataframe,
        output_schema="col1 STRING, col2 INT64",
    )

    mock_session.read_gbq_query.assert_called_once()
    query = mock_session.read_gbq_query.call_args[0][0]

    # Normalize whitespace for comparison
    query = " ".join(query.split())

    expected_part_1 = "SELECT * FROM AI.GENERATE_TABLE("
    expected_part_2 = f"MODEL `{model_name}`,"
    expected_part_3 = "(SELECT * FROM my_table),"
    expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS output_schema)"

    assert expected_part_1 in query
    assert expected_part_2 in query
    assert expected_part_3 in query
    assert expected_part_4 in query


def test_generate_table_with_options(mock_dataframe, mock_session):
    model_name = "project.dataset.model"

    bbq.ai.generate_table(
        model_name,
        mock_dataframe,
        output_schema="col1 STRING",
        temperature=0.5,
        max_output_tokens=100,
    )

    mock_session.read_gbq_query.assert_called_once()
    query = mock_session.read_gbq_query.call_args[0][0]
    query = " ".join(query.split())

    assert f"MODEL `{model_name}`" in query
    assert "(SELECT * FROM my_table)" in query
    assert (
        "STRUCT('col1 STRING' AS output_schema, 0.5 AS temperature, 100 AS max_output_tokens)"
        in query
    )


@mock.patch("bigframes.pandas.read_pandas")
def test_generate_text_with_pandas_dataframe(
    read_pandas_mock, mock_dataframe, mock_session
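The two unit tests above leave `top_p`, `stop_sequences`, and `request_type` unexercised. A sketch of a third test in the same style (not part of this commit); it only asserts that each field name reaches the STRUCT argument, since the exact literal rendering of list and string values is delegated to `struct_literal`:

```python
def test_generate_table_with_remaining_options(mock_dataframe, mock_session):
    model_name = "project.dataset.model"

    bbq.ai.generate_table(
        model_name,
        mock_dataframe,
        output_schema="col1 STRING",
        top_p=0.9,
        stop_sequences=["\n\n"],
        request_type="DEDICATED",  # placeholder value
    )

    mock_session.read_gbq_query.assert_called_once()
    query = " ".join(mock_session.read_gbq_query.call_args[0][0].split())

    # Each optional field should be forwarded into the STRUCT argument.
    assert "AS top_p" in query
    assert "AS stop_sequences" in query
    assert "AS request_type" in query
```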
