diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py index a789310683..0c5eba9496 100644 --- a/bigframes/bigquery/_operations/ai.py +++ b/bigframes/bigquery/_operations/ai.py @@ -25,7 +25,7 @@ from bigframes import clients, dtypes, series, session from bigframes.core import convert, log_adapter -from bigframes.operations import ai_ops +from bigframes.operations import ai_ops, output_schemas PROMPT_TYPE = Union[ series.Series, @@ -43,7 +43,7 @@ def generate( endpoint: str | None = None, request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified", model_params: Mapping[Any, Any] | None = None, - # TODO(b/446974666) Add output_schema parameter + output_schema: Mapping[str, str] | None = None, ) -> series.Series: """ Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data. @@ -64,6 +64,14 @@ def generate( 1 Ottawa\\n Name: result, dtype: string + You get structured output when the `output_schema` parameter is set: + + >>> animals = bpd.Series(["Rabbit", "Spider"]) + >>> bbq.ai.generate(animals, output_schema={"number_of_legs": "INT64", "is_herbivore": "BOOL"}) + 0 {'is_herbivore': True, 'number_of_legs': 4, 'f... + 1 {'is_herbivore': False, 'number_of_legs': 8, '... + dtype: struct>, status: string>[pyarrow] + Args: prompt (Series | List[str|Series] | Tuple[str|Series, ...]): A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series @@ -86,10 +94,14 @@ def generate( If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota. model_params (Mapping[Any, Any]): Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format. + output_schema (Mapping[str, str]): + A mapping value that specifies the schema of the output, in the form {field_name: data_type}. Supported data types include + `STRING`, `INT64`, `FLOAT64`, `BOOL`, `ARRAY`, and `STRUCT`. Returns: bigframes.series.Series: A new struct Series with the result data. The struct contains these fields: * "result": a STRING value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI. + If you specify an output schema then result is replaced by your custom schema. * "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model. The generated text is in the text element. * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful. @@ -98,12 +110,22 @@ def generate( prompt_context, series_list = _separate_context_and_series(prompt) assert len(series_list) > 0 + if output_schema is None: + output_schema_str = None + else: + output_schema_str = ", ".join( + [f"{name} {sql_type}" for name, sql_type in output_schema.items()] + ) + # Validate user input + output_schemas.parse_sql_fields(output_schema_str) + operator = ai_ops.AIGenerate( prompt_context=tuple(prompt_context), connection_id=_resolve_connection_id(series_list[0], connection_id), endpoint=endpoint, request_type=request_type, model_params=json.dumps(model_params) if model_params else None, + output_schema=output_schema_str, ) return series_list[0]._apply_nary_op(operator, series_list[1:]) diff --git a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 4c02e17d6f..e983fc7e21 100644 --- a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1985,6 +1985,7 @@ def ai_generate( op.endpoint, # type: ignore op.request_type.upper(), # type: ignore op.model_params, # type: ignore + op.output_schema, # type: ignore ).to_expr() diff --git a/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/bigframes/core/compile/sqlglot/expressions/ai_ops.py index 4129c91906..e40173d2fd 100644 --- a/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -15,7 +15,6 @@ from __future__ import annotations from dataclasses import asdict -import typing import sqlglot.expressions as sge @@ -105,16 +104,16 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]: op_args = asdict(op) - connection_id = typing.cast(str, op_args["connection_id"]) + connection_id = op_args["connection_id"] args.append( sge.Kwarg(this="connection_id", expression=sge.Literal.string(connection_id)) ) - endpoit = typing.cast(str, op_args.get("endpoint", None)) + endpoit = op_args.get("endpoint", None) if endpoit is not None: args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoit))) - request_type = typing.cast(str, op_args.get("request_type", None)) + request_type = op_args.get("request_type", None) if request_type is not None: args.append( sge.Kwarg( @@ -122,7 +121,7 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]: ) ) - model_params = typing.cast(str, op_args.get("model_params", None)) + model_params = op_args.get("model_params", None) if model_params is not None: args.append( sge.Kwarg( @@ -133,4 +132,13 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]: ) ) + output_schema = op_args.get("output_schema", None) + if output_schema is not None: + args.append( + sge.Kwarg( + this="output_schema", + expression=sge.Literal.string(output_schema), + ) + ) + return args diff --git a/bigframes/operations/ai_ops.py b/bigframes/operations/ai_ops.py index 7ba3737ba0..ea65b705e5 100644 --- a/bigframes/operations/ai_ops.py +++ b/bigframes/operations/ai_ops.py @@ -21,7 +21,7 @@ import pyarrow as pa from bigframes import dtypes -from bigframes.operations import base_ops +from bigframes.operations import base_ops, output_schemas @dataclasses.dataclass(frozen=True) @@ -33,12 +33,18 @@ class AIGenerate(base_ops.NaryOp): endpoint: str | None request_type: Literal["dedicated", "shared", "unspecified"] model_params: str | None + output_schema: str | None def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + if self.output_schema is None: + output_fields = (pa.field("result", pa.string()),) + else: + output_fields = output_schemas.parse_sql_fields(self.output_schema) + return pd.ArrowDtype( pa.struct( ( - pa.field("result", pa.string()), + *output_fields, pa.field("full_response", dtypes.JSON_ARROW_TYPE), pa.field("status", pa.string()), ) diff --git a/bigframes/operations/output_schemas.py b/bigframes/operations/output_schemas.py new file mode 100644 index 0000000000..ff9c9883dc --- /dev/null +++ b/bigframes/operations/output_schemas.py @@ -0,0 +1,90 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pyarrow as pa + + +def parse_sql_type(sql: str) -> pa.DataType: + """ + Parses a SQL type string to its PyArrow equivalence: + + For example: + "STRING" -> pa.string() + "ARRAY" -> pa.list_(pa.int64()) + "STRUCT, y BOOL>" -> pa.struct( + ( + pa.field("x", pa.list_(pa.float64())), + pa.field("y", pa.bool_()), + ) + ) + """ + sql = sql.strip() + + if sql.upper() == "STRING": + return pa.string() + + if sql.upper() == "INT64": + return pa.int64() + + if sql.upper() == "FLOAT64": + return pa.float64() + + if sql.upper() == "BOOL": + return pa.bool_() + + if sql.upper().startswith("ARRAY<") and sql.endswith(">"): + inner_type = sql[len("ARRAY<") : -1] + return pa.list_(parse_sql_type(inner_type)) + + if sql.upper().startswith("STRUCT<") and sql.endswith(">"): + inner_fields = parse_sql_fields(sql[len("STRUCT<") : -1]) + return pa.struct(inner_fields) + + raise ValueError(f"Unsupported SQL type: {sql}") + + +def parse_sql_fields(sql: str) -> tuple[pa.Field]: + sql = sql.strip() + + start_idx = 0 + nested_depth = 0 + fields: list[pa.field] = [] + + for end_idx in range(len(sql)): + c = sql[end_idx] + + if c == "<": + nested_depth += 1 + elif c == ">": + nested_depth -= 1 + elif c == "," and nested_depth == 0: + field = sql[start_idx:end_idx] + fields.append(parse_sql_field(field)) + start_idx = end_idx + 1 + + # Append the last field + fields.append(parse_sql_field(sql[start_idx:])) + + return tuple(sorted(fields, key=lambda f: f.name)) + + +def parse_sql_field(sql: str) -> pa.Field: + sql = sql.strip() + + space_idx = sql.find(" ") + + if space_idx == -1: + raise ValueError(f"Invalid struct field: {sql}") + + return pa.field(sql[:space_idx].strip(), parse_sql_type(sql[space_idx:])) diff --git a/tests/system/small/bigquery/test_ai.py b/tests/system/small/bigquery/test_ai.py index 7a6e5aea4f..2ccdb01944 100644 --- a/tests/system/small/bigquery/test_ai.py +++ b/tests/system/small/bigquery/test_ai.py @@ -87,6 +87,41 @@ def test_ai_generate(session): ) +def test_ai_generate_with_output_schema(session): + country = bpd.Series(["Japan", "Canada"], session=session) + prompt = ("Describe ", country) + + result = bbq.ai.generate( + prompt, + endpoint="gemini-2.5-flash", + output_schema={"population": "INT64", "is_in_north_america": "bool"}, + ) + + assert _contains_no_nulls(result) + assert result.dtype == pd.ArrowDtype( + pa.struct( + ( + pa.field("is_in_north_america", pa.bool_()), + pa.field("population", pa.int64()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), + pa.field("status", pa.string()), + ) + ) + ) + + +def test_ai_generate_with_invalid_output_schema_raise_error(session): + country = bpd.Series(["Japan", "Canada"], session=session) + prompt = ("Describe ", country) + + with pytest.raises(ValueError): + bbq.ai.generate( + prompt, + endpoint="gemini-2.5-flash", + output_schema={"population": "INT64", "is_in_north_america": "JSON"}, + ) + + def test_ai_generate_bool(session): s1 = bpd.Series(["apple", "bear"], session=session) s2 = bpd.Series(["fruit", "tree"], session=session) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_output_schema/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_output_schema/out.sql new file mode 100644 index 0000000000..62fc2f9db0 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_output_schema/out.sql @@ -0,0 +1,19 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + AI.GENERATE( + prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), + connection_id => 'bigframes-dev.us.bigframes-default-connection', + endpoint => 'gemini-2.5-flash', + request_type => 'SHARED', + output_schema => 'x INT64, y FLOAT64' + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `result` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index c809e90a90..13481d88c6 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -36,6 +36,26 @@ def test_ai_generate(scalar_types_df: dataframe.DataFrame, snapshot): endpoint="gemini-2.5-flash", request_type="shared", model_params=None, + output_schema=None, + ) + + sql = utils._apply_unary_ops( + scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] + ) + + snapshot.assert_match(sql, "out.sql") + + +def test_ai_generate_with_output_schema(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AIGenerate( + prompt_context=(None, " is the same as ", None), + connection_id=CONNECTION_ID, + endpoint="gemini-2.5-flash", + request_type="shared", + model_params=None, + output_schema="x INT64, y FLOAT64", ) sql = utils._apply_unary_ops( @@ -59,6 +79,7 @@ def test_ai_generate_with_model_param(scalar_types_df: dataframe.DataFrame, snap endpoint=None, request_type="shared", model_params=json.dumps(dict()), + output_schema=None, ) sql = utils._apply_unary_ops( diff --git a/tests/unit/operations/test_output_schemas.py b/tests/unit/operations/test_output_schemas.py new file mode 100644 index 0000000000..c609098c98 --- /dev/null +++ b/tests/unit/operations/test_output_schemas.py @@ -0,0 +1,99 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pyarrow as pa +import pytest + +from bigframes.operations import output_schemas + + +@pytest.mark.parametrize( + ("sql", "expected"), + [ + ("INT64", pa.int64()), + (" INT64 ", pa.int64()), + ("int64", pa.int64()), + ("FLOAT64", pa.float64()), + ("STRING", pa.string()), + ("BOOL", pa.bool_()), + ("ARRAY", pa.list_(pa.int64())), + ( + "STRUCT", + pa.struct((pa.field("x", pa.int64()), pa.field("y", pa.float64()))), + ), + ( + "STRUCT< x INT64, y FLOAT64>", + pa.struct((pa.field("x", pa.int64()), pa.field("y", pa.float64()))), + ), + ( + "STRUCT", + pa.struct((pa.field("x", pa.float64()), pa.field("y", pa.int64()))), + ), + ( + "ARRAY>", + pa.list_(pa.struct((pa.field("x", pa.int64()), pa.field("y", pa.int64())))), + ), + ( + "STRUCT, x ARRAY>", + pa.struct( + ( + pa.field("x", pa.list_(pa.float64())), + pa.field( + "y", + pa.struct( + (pa.field("a", pa.bool_()), pa.field("b", pa.string())) + ), + ), + ) + ), + ), + ], +) +def test_parse_sql_to_pyarrow_dtype(sql, expected): + assert output_schemas.parse_sql_type(sql) == expected + + +@pytest.mark.parametrize( + "sql", + [ + "a INT64", + "ARRAY<>", + "ARRAY" "ARRAY" "STRUCT<>", + "DATE", + "STRUCT", + "ARRAY>", + ], +) +def test_parse_sql_to_pyarrow_dtype_invalid_input_raies_error(sql): + with pytest.raises(ValueError): + output_schemas.parse_sql_type(sql) + + +@pytest.mark.parametrize( + ("sql", "expected"), + [ + ("x INT64", (pa.field("x", pa.int64()),)), + ( + "x INT64, y FLOAT64", + (pa.field("x", pa.int64()), pa.field("y", pa.float64())), + ), + ( + "y FLOAT64, x INT64", + (pa.field("x", pa.int64()), pa.field("y", pa.float64())), + ), + ], +) +def test_parse_sql_fields(sql, expected): + assert output_schemas.parse_sql_fields(sql) == expected diff --git a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index e9d704fa8e..da7f132de3 100644 --- a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -1,6 +1,6 @@ # Contains code from https://github.com/ibis-project/ibis/blob/9.2.0/ibis/expr/operations/maps.py -"""Operations for working with maps.""" +"""Operations for working with AI operators.""" from __future__ import annotations @@ -11,6 +11,9 @@ from bigframes_vendored.ibis.expr.operations.core import Value import bigframes_vendored.ibis.expr.rules as rlz from public import public +import pyarrow as pa + +from bigframes.operations import output_schemas @public @@ -22,15 +25,27 @@ class AIGenerate(Value): endpoint: Optional[Value[dt.String]] request_type: Value[dt.String] model_params: Optional[Value[dt.String]] + output_schema: Optional[Value[dt.String]] shape = rlz.shape_like("prompt") @attribute def dtype(self) -> dt.Struct: - return dt.Struct.from_tuples( - (("result", dt.string), ("full_resposne", dt.string), ("status", dt.string)) + if self.output_schema is None: + output_pa_fields = (pa.field("result", pa.string()),) + else: + output_pa_fields = output_schemas.parse_sql_fields(self.output_schema.value) + + pyarrow_output_type = pa.struct( + ( + *output_pa_fields, + pa.field("full_resposne", pa.string()), + pa.field("status", pa.string()), + ) ) + return dt.Struct.from_pyarrow(pyarrow_output_type) + @public class AIGenerateBool(Value):