From 77d920c828d3331ae1463640a76f7396e827441a Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Sun, 5 Oct 2025 01:48:12 +0000 Subject: [PATCH 01/12] feat: add output_schema to ai.generate() --- bigframes/bigquery/_operations/ai.py | 27 +++++- .../ibis_compiler/scalar_op_registry.py | 1 + .../compile/sqlglot/expressions/ai_ops.py | 17 +++- bigframes/operations/ai_ops.py | 10 ++- bigframes/operations/output_schemas.py | 90 +++++++++++++++++++ tests/system/small/bigquery/test_ai.py | 35 ++++++++ .../out.sql | 19 ++++ .../sqlglot/expressions/test_ai_ops.py | 21 +++++ tests/unit/operations/test_output_schemas.py | 90 +++++++++++++++++++ .../ibis/expr/operations/ai_ops.py | 21 ++++- 10 files changed, 320 insertions(+), 11 deletions(-) create mode 100644 bigframes/operations/output_schemas.py create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_output_schema/out.sql create mode 100644 tests/unit/operations/test_output_schemas.py diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py index 4759c99016..820c4e8a31 100644 --- a/bigframes/bigquery/_operations/ai.py +++ b/bigframes/bigquery/_operations/ai.py @@ -25,7 +25,7 @@ from bigframes import clients, dtypes, series, session from bigframes.core import convert, log_adapter -from bigframes.operations import ai_ops +from bigframes.operations import ai_ops, output_schemas PROMPT_TYPE = Union[ series.Series, @@ -43,7 +43,7 @@ def generate( endpoint: str | None = None, request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified", model_params: Mapping[Any, Any] | None = None, - # TODO(b/446974666) Add output_schema parameter + output_schema: Mapping[str, str] | None = None, ) -> series.Series: """ Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data. 
@@ -64,6 +64,14 @@ def generate( 1 Ottawa\\n Name: result, dtype: string + You get structured output when the `output_schema` parameter is set: + + >>> animals = bpd.Series(["Rabbit", "Spider"]) + >>> bbq.ai.generate(animals, output_schema={"number_of_legs": "INT64", "is_herbivore": "BOOL"}) + 0 {'is_herbivore': True, 'number_of_legs': 4, 'f... + 1 {'is_herbivore': False, 'number_of_legs': 8, '... + dtype: struct<is_herbivore: bool, number_of_legs: int64, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow] + Args: prompt (Series | List[str|Series] | Tuple[str|Series, ...]): A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series @@ -86,10 +94,14 @@ def generate( If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota. model_params (Mapping[Any, Any]): Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format. + output_schema (Mapping[str, str]): + A mapping value that specifies the schema of the output, in the form {field_name: data_type}. Supported data types include + `STRING`, `INT64`, `FLOAT64`, `BOOL`, `ARRAY`, and `STRUCT`. Returns: bigframes.series.Series: A new struct Series with the result data. The struct contains these fields: * "result": a STRING value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI. + If you specify an output schema then result is replaced by your custom schema. * "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model. The generated text is in the text element. * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful. 
@@ -98,12 +110,23 @@ def generate( prompt_context, series_list = _separate_context_and_series(prompt) assert len(series_list) > 0 + if output_schema is None: + output_schema_str = None + else: + sorted_fields = sorted(tuple(output_schema.items()), key=lambda t: t[0]) + output_schema_str = ", ".join( + [f"{name} {sql_type}" for name, sql_type in sorted_fields] + ) + # Validate user input + output_schemas.parse_sql_fields(output_schema_str) + operator = ai_ops.AIGenerate( prompt_context=tuple(prompt_context), connection_id=_resolve_connection_id(series_list[0], connection_id), endpoint=endpoint, request_type=request_type, model_params=json.dumps(model_params) if model_params else None, + output_schema=output_schema_str, ) return series_list[0]._apply_nary_op(operator, series_list[1:]) diff --git a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 7280e9a40a..a7e30f438a 100644 --- a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1985,6 +1985,7 @@ def ai_generate( op.endpoint, # type: ignore op.request_type.upper(), # type: ignore op.model_params, # type: ignore + op.output_schema, # type: ignore ).to_expr() diff --git a/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/bigframes/core/compile/sqlglot/expressions/ai_ops.py index 46a79d1440..2fe15965bc 100644 --- a/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -88,16 +88,16 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]: op_args = asdict(op) - connection_id = typing.cast(str, op_args["connection_id"]) + connection_id = op_args["connection_id"] args.append( sge.Kwarg(this="connection_id", expression=sge.Literal.string(connection_id)) ) - endpoit = typing.cast(str, op_args.get("endpoint", None)) + endpoit = op_args.get("endpoint", None) if endpoit is not None: 
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoit))) - request_type = typing.cast(str, op_args.get("request_type", None)) + request_type = op_args.get("request_type", None) if request_type is not None: args.append( sge.Kwarg( @@ -105,7 +105,7 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]: ) ) - model_params = typing.cast(str, op_args.get("model_params", None)) + model_params = op_args.get("model_params", None) if model_params is not None: args.append( sge.Kwarg( @@ -116,4 +116,13 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]: ) ) + output_schema = op_args.get("output_schema", None) + if output_schema is not None: + args.append( + sge.Kwarg( + this="output_schema", + expression=sge.Literal.string(output_schema), + ) + ) + return args diff --git a/bigframes/operations/ai_ops.py b/bigframes/operations/ai_ops.py index 05d37d2a90..f4e9ed7c94 100644 --- a/bigframes/operations/ai_ops.py +++ b/bigframes/operations/ai_ops.py @@ -21,7 +21,7 @@ import pyarrow as pa from bigframes import dtypes -from bigframes.operations import base_ops +from bigframes.operations import base_ops, output_schemas @dataclasses.dataclass(frozen=True) @@ -33,12 +33,18 @@ class AIGenerate(base_ops.NaryOp): endpoint: str | None request_type: Literal["dedicated", "shared", "unspecified"] model_params: str | None + output_schema: str | None def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + if self.output_schema is None: + output_fields = [pa.field("result", pa.string())] + else: + output_fields = output_schemas.parse_sql_fields(self.output_schema) + return pd.ArrowDtype( pa.struct( ( - pa.field("result", pa.string()), + *output_fields, pa.field("full_response", dtypes.JSON_ARROW_TYPE), pa.field("status", pa.string()), ) diff --git a/bigframes/operations/output_schemas.py b/bigframes/operations/output_schemas.py new file mode 100644 index 0000000000..c1ab318bf4 --- /dev/null +++ 
b/bigframes/operations/output_schemas.py @@ -0,0 +1,90 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pyarrow as pa + + +def parse_sql_type(sql: str) -> pa.DataType: + """ + Parses a SQL type string to its PyArrow equivalence: + + For example: + "STRING" -> pa.string() + "ARRAY<INT64>" -> pa.list_(pa.int64()) + "STRUCT<x ARRAY<FLOAT64>, y BOOL>" -> pa.struct( + ( + pa.field("x", pa.list_(pa.float64())), + pa.field("y", pa.bool_()), + ) + ) + """ + sql = sql.strip() + + if sql == "STRING": + return pa.string() + + if sql == "INT64": + return pa.int64() + + if sql == "FLOAT64": + return pa.float64() + + if sql == "BOOL": + return pa.bool_() + + if sql.startswith("ARRAY<") and sql.endswith(">"): + inner_type = sql[len("ARRAY<") : -1] + return pa.list_(parse_sql_type(inner_type)) + + if sql.startswith("STRUCT<") and sql.endswith(">"): + inner_fields = parse_sql_fields(sql[len("STRUCT<") : -1]) + return pa.struct(parse_sql_fields(inner_fields)) + + raise ValueError(f"Unsupported SQL type: {sql}") + + +def parse_sql_fields(sql: str) -> tuple[pa.Field]: + sql = sql.strip() + + start_idx = 0 + nested_depth = 0 + fields: list[pa.Field] = [] + + for end_idx in range(len(sql)): + c = sql[end_idx] + + if c == "<": + nested_depth += 1 + elif c == ">": + nested_depth -= 1 + elif c == "," and nested_depth == 0: + field = sql[start_idx:end_idx] + fields.append(parse_sql_field(field)) + start_idx = end_idx + 1 + + # Append the last field + 
fields.append(parse_sql_field(sql[start_idx:])) + + return tuple(fields) + + +def parse_sql_field(sql: str) -> pa.Field: + sql = sql.strip() + + space_idx = sql.find(" ") + + if space_idx == -1: + raise ValueError(f"Invalid struct field: {sql}") + + return pa.field(sql[:space_idx].strip(), parse_sql_type(sql[space_idx:])) diff --git a/tests/system/small/bigquery/test_ai.py b/tests/system/small/bigquery/test_ai.py index 91499d0efe..8e9e037311 100644 --- a/tests/system/small/bigquery/test_ai.py +++ b/tests/system/small/bigquery/test_ai.py @@ -87,6 +87,41 @@ def test_ai_generate(session): ) +def test_ai_generate_with_output_schema(session): + country = bpd.Series(["Japan", "Canada"], session=session) + prompt = ("Describe ", country) + + result = bbq.ai.generate( + prompt, + endpoint="gemini-2.5-flash", + output_schema={"population": "INT64", "is_in_north_america": "BOOL"}, + ) + + assert _contains_no_nulls(result) + assert result.dtype == pd.ArrowDtype( + pa.struct( + ( + pa.field("is_in_north_america", pa.bool_()), + pa.field("population", pa.int64()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), + pa.field("status", pa.string()), + ) + ) + ) + + +def test_ai_generate_with_invalid_output_schema_raise_error(session): + country = bpd.Series(["Japan", "Canada"], session=session) + prompt = ("Describe ", country) + + with pytest.raises(ValueError): + result = bbq.ai.generate( + prompt, + endpoint="gemini-2.5-flash", + output_schema={"population": "INT64", "is_in_north_america": "bool"}, + ) + + def test_ai_generate_bool(session): s1 = bpd.Series(["apple", "bear"], session=session) s2 = bpd.Series(["fruit", "tree"], session=session) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_output_schema/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_output_schema/out.sql new file mode 100644 index 0000000000..62fc2f9db0 --- /dev/null +++ 
b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_output_schema/out.sql @@ -0,0 +1,19 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + AI.GENERATE( + prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), + connection_id => 'bigframes-dev.us.bigframes-default-connection', + endpoint => 'gemini-2.5-flash', + request_type => 'SHARED', + output_schema => 'x INT64, y FLOAT64' + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `result` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index 8f048a5bbf..f9bc5f983a 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -36,6 +36,26 @@ def test_ai_generate(scalar_types_df: dataframe.DataFrame, snapshot): endpoint="gemini-2.5-flash", request_type="shared", model_params=None, + output_schema=None, + ) + + sql = utils._apply_unary_ops( + scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] + ) + + snapshot.assert_match(sql, "out.sql") + + +def test_ai_generate_with_output_schema(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AIGenerate( + prompt_context=(None, " is the same as ", None), + connection_id=CONNECTION_ID, + endpoint="gemini-2.5-flash", + request_type="shared", + model_params=None, + output_schema='x INT64, y FLOAT64', ) sql = utils._apply_unary_ops( @@ -59,6 +79,7 @@ def test_ai_generate_with_model_param(scalar_types_df: dataframe.DataFrame, snap endpoint=None, request_type="shared", model_params=json.dumps(dict()), + output_schema=None, ) sql = utils._apply_unary_ops( diff --git a/tests/unit/operations/test_output_schemas.py b/tests/unit/operations/test_output_schemas.py new file mode 100644 index 
0000000000..a0183b8f48 --- /dev/null +++ b/tests/unit/operations/test_output_schemas.py @@ -0,0 +1,90 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pyarrow as pa +import pytest + +from bigframes.operations import output_schemas + + +@pytest.mark.parametrize( + ("sql", "expected"), + [ + ("INT64", pa.int64()), + (" INT64 ", pa.int64()), + ("FLOAT64", pa.float64()), + ("STRING", pa.string()), + ("BOOL", pa.bool_()), + ("ARRAY", pa.list_(pa.int64())), + ( + "STRUCT", + pa.struct((pa.field("x", pa.int64()), pa.field("y", pa.float64()))), + ), + ( + "STRUCT< x INT64, y FLOAT64>", + pa.struct((pa.field("x", pa.int64()), pa.field("y", pa.float64()))), + ), + ( + "ARRAY>", + pa.list_(pa.struct((pa.field("x", pa.int64()), pa.field("y", pa.int64())))), + ), + ( + "STRUCT, y STRUCT>", + pa.struct( + ( + pa.field("x", pa.list_(pa.float64())), + pa.field( + "y", + pa.struct( + (pa.field("a", pa.bool_()), pa.field("b", pa.string())) + ), + ), + ) + ), + ), + ], +) +def test_parse_sql_to_pyarrow_dtype(sql, expected): + assert output_schemas.parse_sql_type(sql) == expected + + +@pytest.mark.parametrize( + "sql", + [ + "int64", + "a INT64", + "ARRAY<>", + "ARRAY" "ARRAY" "STRUCT<>", + "DATE", + "STRUCT", + "STRUCT", + ], +) +def test_parse_sql_to_pyarrow_dtype_invalid_input_raies_error(sql): + with pytest.raises(ValueError): + output_schemas.parse_sql_type(sql) + + +@pytest.mark.parametrize( + ("sql", "expected"), + [ + ("x INT64", 
(pa.field("x", pa.int64()),))( + "x INT64, y FLOAT64", + (pa.field("x", pa.int64()), pa.field("y", pa.float64())), + ) + ], +) +def test_parse_sql_fields(sql, expected): + assert output_schemas.parse_sql_fields(sql) == expected diff --git a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index 5289ee7e60..f89377fa18 100644 --- a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -1,6 +1,6 @@ # Contains code from https://github.com/ibis-project/ibis/blob/9.2.0/ibis/expr/operations/maps.py -"""Operations for working with maps.""" +"""Operations for working with AI operators.""" from __future__ import annotations @@ -11,6 +11,9 @@ from bigframes_vendored.ibis.expr.operations.core import Value import bigframes_vendored.ibis.expr.rules as rlz from public import public +import pyarrow as pa + +from bigframes.operations import output_schemas @public @@ -22,15 +25,27 @@ class AIGenerate(Value): endpoint: Optional[Value[dt.String]] request_type: Value[dt.String] model_params: Optional[Value[dt.String]] + output_schema: Optional[Value[dt.String]] shape = rlz.shape_like("prompt") @attribute def dtype(self) -> dt.Struct: - return dt.Struct.from_tuples( - (("result", dt.string), ("full_resposne", dt.string), ("status", dt.string)) + if self.output_schema is None: + output_pa_fields = pa.field("result", pa.string()) + else: + output_pa_fields = output_schemas.parse_sql_fields(self.output_schema.value) + + pyarrow_output_type = pa.struct( + ( + *output_pa_fields, + pa.field("full_response", pa.string()), + pa.field("status", pa.string()), + ) ) + return dt.Struct.from_pyarrow(pyarrow_output_type) + @public class AIGenerateBool(Value): From d6225af11168b60d61155b16221a2acdccde1ccf Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Sun, 5 Oct 2025 01:51:00 +0000 Subject: [PATCH 02/12] fix lint --- 
tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index f9bc5f983a..2fb5a1f46a 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -55,7 +55,7 @@ def test_ai_generate_with_output_schema(scalar_types_df: dataframe.DataFrame, sn endpoint="gemini-2.5-flash", request_type="shared", model_params=None, - output_schema='x INT64, y FLOAT64', + output_schema="x INT64, y FLOAT64", ) sql = utils._apply_unary_ops( From ca083a0d0fa64c64e525e20a73428fb12b41f85c Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Sun, 5 Oct 2025 01:53:47 +0000 Subject: [PATCH 03/12] fix lint --- bigframes/core/compile/sqlglot/expressions/ai_ops.py | 1 - tests/system/small/bigquery/test_ai.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/bigframes/core/compile/sqlglot/expressions/ai_ops.py index 2fe15965bc..dfcee72a76 100644 --- a/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -15,7 +15,6 @@ from __future__ import annotations from dataclasses import asdict -import typing import sqlglot.expressions as sge diff --git a/tests/system/small/bigquery/test_ai.py b/tests/system/small/bigquery/test_ai.py index 8e9e037311..fe0d679306 100644 --- a/tests/system/small/bigquery/test_ai.py +++ b/tests/system/small/bigquery/test_ai.py @@ -115,7 +115,7 @@ def test_ai_generate_with_invalid_output_schema_raise_error(session): prompt = ("Describe ", country) with pytest.raises(ValueError): - result = bbq.ai.generate( + bbq.ai.generate( prompt, endpoint="gemini-2.5-flash", output_schema={"population": "INT64", "is_in_north_america": "bool"}, From 565a9636a04e3bdb06d0045573d3540bf0ac672e Mon Sep 17 
00:00:00 2001 From: Shenyang Cai Date: Sun, 5 Oct 2025 01:57:37 +0000 Subject: [PATCH 04/12] fix test --- tests/unit/operations/test_output_schemas.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/operations/test_output_schemas.py b/tests/unit/operations/test_output_schemas.py index a0183b8f48..bb8baa789a 100644 --- a/tests/unit/operations/test_output_schemas.py +++ b/tests/unit/operations/test_output_schemas.py @@ -80,10 +80,11 @@ def test_parse_sql_to_pyarrow_dtype_invalid_input_raies_error(sql): @pytest.mark.parametrize( ("sql", "expected"), [ - ("x INT64", (pa.field("x", pa.int64()),))( + ("x INT64", (pa.field("x", pa.int64()),)), + ( "x INT64, y FLOAT64", (pa.field("x", pa.int64()), pa.field("y", pa.float64())), - ) + ), ], ) def test_parse_sql_fields(sql, expected): From 8f841e5e3387e5579e3de078e59135c7257b0fdd Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Sun, 5 Oct 2025 02:12:18 +0000 Subject: [PATCH 05/12] fix mypy --- bigframes/operations/ai_ops.py | 2 +- bigframes/operations/output_schemas.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/bigframes/operations/ai_ops.py b/bigframes/operations/ai_ops.py index f4e9ed7c94..5642a19bc2 100644 --- a/bigframes/operations/ai_ops.py +++ b/bigframes/operations/ai_ops.py @@ -37,7 +37,7 @@ class AIGenerate(base_ops.NaryOp): def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: if self.output_schema is None: - output_fields = [pa.field("result", pa.string())] + output_fields = (pa.field("result", pa.string()),) else: output_fields = output_schemas.parse_sql_fields(self.output_schema) diff --git a/bigframes/operations/output_schemas.py b/bigframes/operations/output_schemas.py index c1ab318bf4..e4b78c68d5 100644 --- a/bigframes/operations/output_schemas.py +++ b/bigframes/operations/output_schemas.py @@ -49,7 +49,8 @@ def parse_sql_type(sql: str) -> pa.DataType: if sql.startswith("STRUCT<") and sql.endswith(">"): inner_fields = 
parse_sql_fields(sql[len("STRUCT<") : -1]) - return pa.struct(parse_sql_fields(inner_fields)) + + return pa.struct(sorted(inner_fields, key=lambda f:f.name)) raise ValueError(f"Unsupported SQL type: {sql}") From 4074ca4e786251a682df072ad104702bd942209a Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Sun, 5 Oct 2025 02:15:38 +0000 Subject: [PATCH 06/12] fix lint --- bigframes/operations/output_schemas.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bigframes/operations/output_schemas.py b/bigframes/operations/output_schemas.py index e4b78c68d5..3a0225a4c9 100644 --- a/bigframes/operations/output_schemas.py +++ b/bigframes/operations/output_schemas.py @@ -49,8 +49,7 @@ def parse_sql_type(sql: str) -> pa.DataType: if sql.startswith("STRUCT<") and sql.endswith(">"): inner_fields = parse_sql_fields(sql[len("STRUCT<") : -1]) - - return pa.struct(sorted(inner_fields, key=lambda f:f.name)) + return pa.struct(sorted(inner_fields, key=lambda f: f.name)) raise ValueError(f"Unsupported SQL type: {sql}") From 7b3d36f864100a6868fa93bdf94094433b52b986 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Sun, 5 Oct 2025 02:35:44 +0000 Subject: [PATCH 07/12] code optimization --- bigframes/bigquery/_operations/ai.py | 3 +-- bigframes/operations/output_schemas.py | 4 ++-- tests/unit/operations/test_output_schemas.py | 6 +++++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py index 820c4e8a31..99ddbb22c6 100644 --- a/bigframes/bigquery/_operations/ai.py +++ b/bigframes/bigquery/_operations/ai.py @@ -113,9 +113,8 @@ def generate( if output_schema is None: output_schema_str = None else: - sorted_fields = sorted(tuple(output_schema.items()), key=lambda t: t[0]) output_schema_str = ", ".join( - [f"{name} {sql_type}" for name, sql_type in sorted_fields] + [f"{name} {sql_type}" for name, sql_type in output_schema.items()] ) # Validate user input 
output_schemas.parse_sql_fields(output_schema_str) diff --git a/bigframes/operations/output_schemas.py b/bigframes/operations/output_schemas.py index 3a0225a4c9..fb1bb25932 100644 --- a/bigframes/operations/output_schemas.py +++ b/bigframes/operations/output_schemas.py @@ -49,7 +49,7 @@ def parse_sql_type(sql: str) -> pa.DataType: if sql.startswith("STRUCT<") and sql.endswith(">"): inner_fields = parse_sql_fields(sql[len("STRUCT<") : -1]) - return pa.struct(sorted(inner_fields, key=lambda f: f.name)) + return pa.struct(inner_fields) raise ValueError(f"Unsupported SQL type: {sql}") @@ -76,7 +76,7 @@ def parse_sql_fields(sql: str) -> tuple[pa.Field]: # Append the last field fields.append(parse_sql_field(sql[start_idx:])) - return tuple(fields) + return tuple(sorted(fields, key=lambda f: f.name)) def parse_sql_field(sql: str) -> pa.Field: diff --git a/tests/unit/operations/test_output_schemas.py b/tests/unit/operations/test_output_schemas.py index bb8baa789a..551f61659d 100644 --- a/tests/unit/operations/test_output_schemas.py +++ b/tests/unit/operations/test_output_schemas.py @@ -36,7 +36,11 @@ pa.struct((pa.field("x", pa.int64()), pa.field("y", pa.float64()))), ), ( - "ARRAY>", + "STRUCT", + pa.struct((pa.field("x", pa.float64()), pa.field("y", pa.int64()))), + ), + ( + "ARRAY>", pa.list_(pa.struct((pa.field("x", pa.int64()), pa.field("y", pa.int64())))), ), ( From 06a32b8b238c115db4ee326e3ce6482ca151646e Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Sun, 5 Oct 2025 03:20:06 +0000 Subject: [PATCH 08/12] fix tests --- third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index f89377fa18..3e899dce93 100644 --- a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -32,7 +32,7 @@ class 
AIGenerate(Value): @attribute def dtype(self) -> dt.Struct: if self.output_schema is None: - output_pa_fields = pa.field("result", pa.string()) + output_pa_fields = (pa.field("result", pa.string()),) else: output_pa_fields = output_schemas.parse_sql_fields(self.output_schema.value) From 36b5dfba10905817dbe36114ae84b53262325c50 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Mon, 6 Oct 2025 17:40:56 +0000 Subject: [PATCH 09/12] support case-insensitive type parsing --- bigframes/operations/output_schemas.py | 12 ++++++------ tests/system/small/bigquery/test_ai.py | 2 +- tests/unit/operations/test_output_schemas.py | 10 +++++++--- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/bigframes/operations/output_schemas.py b/bigframes/operations/output_schemas.py index fb1bb25932..ff9c9883dc 100644 --- a/bigframes/operations/output_schemas.py +++ b/bigframes/operations/output_schemas.py @@ -31,23 +31,23 @@ def parse_sql_type(sql: str) -> pa.DataType: """ sql = sql.strip() - if sql == "STRING": + if sql.upper() == "STRING": return pa.string() - if sql == "INT64": + if sql.upper() == "INT64": return pa.int64() - if sql == "FLOAT64": + if sql.upper() == "FLOAT64": return pa.float64() - if sql == "BOOL": + if sql.upper() == "BOOL": return pa.bool_() - if sql.startswith("ARRAY<") and sql.endswith(">"): + if sql.upper().startswith("ARRAY<") and sql.endswith(">"): inner_type = sql[len("ARRAY<") : -1] return pa.list_(parse_sql_type(inner_type)) - if sql.startswith("STRUCT<") and sql.endswith(">"): + if sql.upper().startswith("STRUCT<") and sql.endswith(">"): inner_fields = parse_sql_fields(sql[len("STRUCT<") : -1]) return pa.struct(inner_fields) diff --git a/tests/system/small/bigquery/test_ai.py b/tests/system/small/bigquery/test_ai.py index fe0d679306..f2b7aa7037 100644 --- a/tests/system/small/bigquery/test_ai.py +++ b/tests/system/small/bigquery/test_ai.py @@ -94,7 +94,7 @@ def test_ai_generate_with_output_schema(session): result = bbq.ai.generate( prompt, 
endpoint="gemini-2.5-flash", - output_schema={"population": "INT64", "is_in_north_america": "BOOL"}, + output_schema={"population": "INT64", "is_in_north_america": "bool"}, ) assert _contains_no_nulls(result) diff --git a/tests/unit/operations/test_output_schemas.py b/tests/unit/operations/test_output_schemas.py index 551f61659d..c609098c98 100644 --- a/tests/unit/operations/test_output_schemas.py +++ b/tests/unit/operations/test_output_schemas.py @@ -23,6 +23,7 @@ [ ("INT64", pa.int64()), (" INT64 ", pa.int64()), + ("int64", pa.int64()), ("FLOAT64", pa.float64()), ("STRING", pa.string()), ("BOOL", pa.bool_()), @@ -44,7 +45,7 @@ pa.list_(pa.struct((pa.field("x", pa.int64()), pa.field("y", pa.int64())))), ), ( - "STRUCT, y STRUCT>", + "STRUCT, x ARRAY>", pa.struct( ( pa.field("x", pa.list_(pa.float64())), @@ -66,14 +67,13 @@ def test_parse_sql_to_pyarrow_dtype(sql, expected): @pytest.mark.parametrize( "sql", [ - "int64", "a INT64", "ARRAY<>", "ARRAY" "ARRAY" "STRUCT<>", "DATE", "STRUCT", - "STRUCT", + "ARRAY>", ], ) def test_parse_sql_to_pyarrow_dtype_invalid_input_raies_error(sql): @@ -89,6 +89,10 @@ def test_parse_sql_to_pyarrow_dtype_invalid_input_raies_error(sql): "x INT64, y FLOAT64", (pa.field("x", pa.int64()), pa.field("y", pa.float64())), ), + ( + "y FLOAT64, x INT64", + (pa.field("x", pa.int64()), pa.field("y", pa.float64())), + ), ], ) def test_parse_sql_fields(sql, expected): From 99c7bfc7d70eb565bc99a6562686d8dfbc2c4ca9 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Mon, 6 Oct 2025 17:57:49 +0000 Subject: [PATCH 10/12] fix test --- tests/system/small/bigquery/test_ai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/system/small/bigquery/test_ai.py b/tests/system/small/bigquery/test_ai.py index f2b7aa7037..944a84c06b 100644 --- a/tests/system/small/bigquery/test_ai.py +++ b/tests/system/small/bigquery/test_ai.py @@ -118,7 +118,7 @@ def test_ai_generate_with_invalid_output_schema_raise_error(session): bbq.ai.generate( prompt, 
endpoint="gemini-2.5-flash", - output_schema={"population": "INT64", "is_in_north_america": "bool"}, + output_schema={"population": "INT64", "is_in_north_america": "JSON"}, ) From 93706b6042b9368bf3c3e0c445cfdb6962f0e2bb Mon Sep 17 00:00:00 2001 From: TrevorBergeron Date: Mon, 6 Oct 2025 14:21:15 -0700 Subject: [PATCH 11/12] fix: Fix row count local execution bug (#2133) --- bigframes/core/rewrite/pruning.py | 19 ++++++++---------- .../system/small/engines/test_aggregation.py | 19 ++++++++++++++++++ .../test_row_number/out.sql | 20 ++++++++++++++++--- .../test_nullary_compiler/test_size/out.sql | 20 ++++++++++++++++--- 4 files changed, 61 insertions(+), 17 deletions(-) diff --git a/bigframes/core/rewrite/pruning.py b/bigframes/core/rewrite/pruning.py index 8a07f0b87e..41664e1c47 100644 --- a/bigframes/core/rewrite/pruning.py +++ b/bigframes/core/rewrite/pruning.py @@ -13,6 +13,7 @@ # limitations under the License. import dataclasses import functools +import itertools import typing from bigframes.core import identifiers, nodes @@ -51,17 +52,9 @@ def prune_columns(node: nodes.BigFrameNode): if isinstance(node, nodes.SelectionNode): result = prune_selection_child(node) elif isinstance(node, nodes.ResultNode): - result = node.replace_child( - prune_node( - node.child, node.consumed_ids or frozenset(list(node.child.ids)[0:1]) - ) - ) + result = node.replace_child(prune_node(node.child, node.consumed_ids)) elif isinstance(node, nodes.AggregateNode): - result = node.replace_child( - prune_node( - node.child, node.consumed_ids or frozenset(list(node.child.ids)[0:1]) - ) - ) + result = node.replace_child(prune_node(node.child, node.consumed_ids)) elif isinstance(node, nodes.InNode): result = dataclasses.replace( node, @@ -149,9 +142,13 @@ def prune_node( if not (set(node.ids) - ids): return node else: + # If no child ids are needed, probably a size op or numbering op above, keep a single column always + ids_to_keep = tuple(id for id in node.ids if id in ids) or tuple( + 
itertools.islice(node.ids, 0, 1) + ) return nodes.SelectionNode( node, - tuple(nodes.AliasedRef.identity(id) for id in node.ids if id in ids), + tuple(nodes.AliasedRef.identity(id) for id in ids_to_keep), ) diff --git a/tests/system/small/engines/test_aggregation.py b/tests/system/small/engines/test_aggregation.py index 9b4efe8cbe..a25c167f71 100644 --- a/tests/system/small/engines/test_aggregation.py +++ b/tests/system/small/engines/test_aggregation.py @@ -48,6 +48,25 @@ def apply_agg_to_all_valid( return new_arr +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +def test_engines_aggregate_post_filter_size( + scalars_array_value: array_value.ArrayValue, + engine, +): + w_offsets, offsets_id = ( + scalars_array_value.select_columns(("bool_col", "string_col")) + .filter(expression.deref("bool_col")) + .promote_offsets() + ) + plan = ( + w_offsets.select_columns((offsets_id, "bool_col", "string_col")) + .row_count() + .node + ) + + assert_equivalence_execution(plan, REFERENCE_ENGINE, engine) + + @pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_aggregate_size( scalars_array_value: array_value.ArrayValue, diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number/out.sql index d20a635e3d..b48dcfa01b 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_row_number/out.sql @@ -1,13 +1,27 @@ WITH `bfcte_0` AS ( SELECT - `bool_col` AS `bfcol_0` + `bool_col` AS `bfcol_0`, + `bytes_col` AS `bfcol_1`, + `date_col` AS `bfcol_2`, + `datetime_col` AS `bfcol_3`, + `geography_col` AS `bfcol_4`, + `int64_col` AS `bfcol_5`, + `int64_too` AS `bfcol_6`, + `numeric_col` AS `bfcol_7`, + `float64_col` AS 
`bfcol_8`, + `rowindex` AS `bfcol_9`, + `rowindex_2` AS `bfcol_10`, + `string_col` AS `bfcol_11`, + `time_col` AS `bfcol_12`, + `timestamp_col` AS `bfcol_13`, + `duration_col` AS `bfcol_14` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT *, - ROW_NUMBER() OVER () AS `bfcol_1` + ROW_NUMBER() OVER () AS `bfcol_32` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `row_number` + `bfcol_32` AS `row_number` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_size/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_size/out.sql index 19ae8aa3fd..8cda9a3d80 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_size/out.sql +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_nullary_compiler/test_size/out.sql @@ -1,12 +1,26 @@ WITH `bfcte_0` AS ( SELECT - `rowindex` AS `bfcol_0` + `bool_col` AS `bfcol_0`, + `bytes_col` AS `bfcol_1`, + `date_col` AS `bfcol_2`, + `datetime_col` AS `bfcol_3`, + `geography_col` AS `bfcol_4`, + `int64_col` AS `bfcol_5`, + `int64_too` AS `bfcol_6`, + `numeric_col` AS `bfcol_7`, + `float64_col` AS `bfcol_8`, + `rowindex` AS `bfcol_9`, + `rowindex_2` AS `bfcol_10`, + `string_col` AS `bfcol_11`, + `time_col` AS `bfcol_12`, + `timestamp_col` AS `bfcol_13`, + `duration_col` AS `bfcol_14` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT - COUNT(1) AS `bfcol_2` + COUNT(1) AS `bfcol_32` FROM `bfcte_0` ) SELECT - `bfcol_2` AS `size` + `bfcol_32` AS `size` FROM `bfcte_1` \ No newline at end of file From 780ff9f7116407b30033f2d7b68d0b0a9c718db0 Mon Sep 17 00:00:00 2001 From: TrevorBergeron Date: Tue, 7 Oct 2025 09:41:47 -0700 Subject: [PATCH 12/12] fix: join on, how args are now positional (#2140) --- bigframes/dataframe.py | 1 - third_party/bigframes_vendored/pandas/core/frame.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) 
diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index de153fca48..1bde29506d 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -3721,7 +3721,6 @@ def _validate_left_right_on( def join( self, other: Union[DataFrame, bigframes.series.Series], - *, on: Optional[str] = None, how: str = "left", lsuffix: str = "", diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index 1d8f5cbace..557c332797 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -4601,9 +4601,8 @@ def map(self, func, na_action: Optional[str] = None) -> DataFrame: def join( self, other, - *, on: Optional[str] = None, - how: str, + how: str = "left", lsuffix: str = "", rsuffix: str = "", ) -> DataFrame: