Merged
26 changes: 24 additions & 2 deletions bigframes/bigquery/_operations/ai.py
@@ -25,7 +25,7 @@

from bigframes import clients, dtypes, series, session
from bigframes.core import convert, log_adapter
from bigframes.operations import ai_ops
from bigframes.operations import ai_ops, output_schemas

PROMPT_TYPE = Union[
series.Series,
@@ -43,7 +43,7 @@ def generate(
endpoint: str | None = None,
request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
model_params: Mapping[Any, Any] | None = None,
# TODO(b/446974666) Add output_schema parameter
output_schema: Mapping[str, str] | None = None,
) -> series.Series:
"""
Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data.
@@ -64,6 +64,14 @@
1 Ottawa\\n
Name: result, dtype: string

You get structured output when the `output_schema` parameter is set:

>>> animals = bpd.Series(["Rabbit", "Spider"])
>>> bbq.ai.generate(animals, output_schema={"number_of_legs": "INT64", "is_herbivore": "BOOL"})
0 {'is_herbivore': True, 'number_of_legs': 4, 'f...
1 {'is_herbivore': False, 'number_of_legs': 8, '...
dtype: struct<is_herbivore: bool, number_of_legs: int64, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]

Args:
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
@@ -86,10 +94,14 @@
If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
model_params (Mapping[Any, Any]):
Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format.
output_schema (Mapping[str, str]):
A mapping value that specifies the schema of the output, in the form {field_name: data_type}. Supported data types include
`STRING`, `INT64`, `FLOAT64`, `BOOL`, `ARRAY`, and `STRUCT`.

Returns:
bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
* "result": a STRING value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
If you specify an output schema then result is replaced by your custom schema.
* "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
The generated text is in the text element.
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
@@ -98,12 +110,22 @@
prompt_context, series_list = _separate_context_and_series(prompt)
assert len(series_list) > 0

if output_schema is None:
output_schema_str = None
else:
output_schema_str = ", ".join(
[f"{name} {sql_type}" for name, sql_type in output_schema.items()]
)
# Validate user input
output_schemas.parse_sql_fields(output_schema_str)

operator = ai_ops.AIGenerate(
prompt_context=tuple(prompt_context),
connection_id=_resolve_connection_id(series_list[0], connection_id),
endpoint=endpoint,
request_type=request_type,
model_params=json.dumps(model_params) if model_params else None,
output_schema=output_schema_str,
)

return series_list[0]._apply_nary_op(operator, series_list[1:])
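
For reference, a small sketch of the schema string this join produces (the mapping below is illustrative, not from the PR):

output_schema = {"population": "INT64", "is_in_north_america": "BOOL"}
output_schema_str = ", ".join(
    f"{name} {sql_type}" for name, sql_type in output_schema.items()
)
# -> "population INT64, is_in_north_america BOOL"; this is the form AI.GENERATE's
# output_schema argument expects and that output_schemas.parse_sql_fields() validates.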
@@ -1985,6 +1985,7 @@ def ai_generate(
op.endpoint, # type: ignore
op.request_type.upper(), # type: ignore
op.model_params, # type: ignore
op.output_schema, # type: ignore
).to_expr()


18 changes: 13 additions & 5 deletions bigframes/core/compile/sqlglot/expressions/ai_ops.py
@@ -15,7 +15,6 @@
from __future__ import annotations

from dataclasses import asdict
import typing

import sqlglot.expressions as sge

@@ -105,24 +104,24 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:

op_args = asdict(op)

connection_id = typing.cast(str, op_args["connection_id"])
connection_id = op_args["connection_id"]
args.append(
sge.Kwarg(this="connection_id", expression=sge.Literal.string(connection_id))
)

endpoit = typing.cast(str, op_args.get("endpoint", None))
endpoit = op_args.get("endpoint", None)
if endpoit is not None:
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoit)))

request_type = typing.cast(str, op_args.get("request_type", None))
request_type = op_args.get("request_type", None)
if request_type is not None:
args.append(
sge.Kwarg(
this="request_type", expression=sge.Literal.string(request_type.upper())
)
)

model_params = typing.cast(str, op_args.get("model_params", None))
model_params = op_args.get("model_params", None)
if model_params is not None:
args.append(
sge.Kwarg(
@@ -133,4 +132,13 @@
)
)

output_schema = op_args.get("output_schema", None)
if output_schema is not None:
args.append(
sge.Kwarg(
this="output_schema",
expression=sge.Literal.string(output_schema),
)
)

return args
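
As a rough sanity check of what the new kwarg renders to (a sketch only; the exact SQL is pinned by the out.sql snapshot further down):

import sqlglot.expressions as sge

kwarg = sge.Kwarg(
    this="output_schema",
    expression=sge.Literal.string("x INT64, y FLOAT64"),
)
print(kwarg.sql(dialect="bigquery"))  # roughly: output_schema => 'x INT64, y FLOAT64'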
10 changes: 8 additions & 2 deletions bigframes/operations/ai_ops.py
@@ -21,7 +21,7 @@
import pyarrow as pa

from bigframes import dtypes
from bigframes.operations import base_ops
from bigframes.operations import base_ops, output_schemas


@dataclasses.dataclass(frozen=True)
@@ -33,12 +33,18 @@ class AIGenerate(base_ops.NaryOp):
endpoint: str | None
request_type: Literal["dedicated", "shared", "unspecified"]
model_params: str | None
output_schema: str | None

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
if self.output_schema is None:
output_fields = (pa.field("result", pa.string()),)
else:
output_fields = output_schemas.parse_sql_fields(self.output_schema)

return pd.ArrowDtype(
pa.struct(
(
pa.field("result", pa.string()),
*output_fields,
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
pa.field("status", pa.string()),
)
90 changes: 90 additions & 0 deletions bigframes/operations/output_schemas.py
@@ -0,0 +1,90 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pyarrow as pa


def parse_sql_type(sql: str) -> pa.DataType:
"""
Parses a SQL type string to its PyArrow equivalent:

For example:
"STRING" -> pa.string()
"ARRAY<INT64>" -> pa.list_(pa.int64())
"STRUCT<x ARRAY<FLOAT64>, y BOOL>" -> pa.struct(
(
pa.field("x", pa.list_(pa.float64())),
pa.field("y", pa.bool_()),
)
)
"""
sql = sql.strip()

if sql.upper() == "STRING":
return pa.string()

if sql.upper() == "INT64":
return pa.int64()

if sql.upper() == "FLOAT64":
return pa.float64()

if sql.upper() == "BOOL":
return pa.bool_()

if sql.upper().startswith("ARRAY<") and sql.endswith(">"):
inner_type = sql[len("ARRAY<") : -1]
return pa.list_(parse_sql_type(inner_type))

if sql.upper().startswith("STRUCT<") and sql.endswith(">"):
inner_fields = parse_sql_fields(sql[len("STRUCT<") : -1])
return pa.struct(inner_fields)

raise ValueError(f"Unsupported SQL type: {sql}")


def parse_sql_fields(sql: str) -> tuple[pa.Field, ...]:
sql = sql.strip()

start_idx = 0
nested_depth = 0
fields: list[pa.Field] = []

for end_idx in range(len(sql)):
c = sql[end_idx]

if c == "<":
nested_depth += 1
elif c == ">":
nested_depth -= 1
elif c == "," and nested_depth == 0:
field = sql[start_idx:end_idx]
fields.append(parse_sql_field(field))
start_idx = end_idx + 1

# Append the last field
fields.append(parse_sql_field(sql[start_idx:]))

return tuple(sorted(fields, key=lambda f: f.name))


def parse_sql_field(sql: str) -> pa.Field:
sql = sql.strip()

space_idx = sql.find(" ")

if space_idx == -1:
raise ValueError(f"Invalid struct field: {sql}")

return pa.field(sql[:space_idx].strip(), parse_sql_type(sql[space_idx:]))
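
A quick sketch of how these helpers compose, following the docstring examples:

import pyarrow as pa

from bigframes.operations import output_schemas

fields = output_schemas.parse_sql_fields("x ARRAY<FLOAT64>, y BOOL")
# Fields come back sorted by name.
assert fields == (pa.field("x", pa.list_(pa.float64())), pa.field("y", pa.bool_()))

# Types outside STRING/INT64/FLOAT64/BOOL/ARRAY/STRUCT (for example "JSON") raise
# ValueError, which generate() surfaces when validating a user-supplied output_schema.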
35 changes: 35 additions & 0 deletions tests/system/small/bigquery/test_ai.py
@@ -87,6 +87,41 @@ def test_ai_generate(session):
)


def test_ai_generate_with_output_schema(session):
country = bpd.Series(["Japan", "Canada"], session=session)
prompt = ("Describe ", country)

result = bbq.ai.generate(
prompt,
endpoint="gemini-2.5-flash",
output_schema={"population": "INT64", "is_in_north_america": "bool"},
)

assert _contains_no_nulls(result)
assert result.dtype == pd.ArrowDtype(
pa.struct(
(
pa.field("is_in_north_america", pa.bool_()),
pa.field("population", pa.int64()),
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
pa.field("status", pa.string()),
)
)
)


def test_ai_generate_with_invalid_output_schema_raise_error(session):
country = bpd.Series(["Japan", "Canada"], session=session)
prompt = ("Describe ", country)

with pytest.raises(ValueError):
bbq.ai.generate(
prompt,
endpoint="gemini-2.5-flash",
output_schema={"population": "INT64", "is_in_north_america": "JSON"},
)


def test_ai_generate_bool(session):
s1 = bpd.Series(["apple", "bear"], session=session)
s2 = bpd.Series(["fruit", "tree"], session=session)
@@ -0,0 +1,19 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
AI.GENERATE(
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
connection_id => 'bigframes-dev.us.bigframes-default-connection',
endpoint => 'gemini-2.5-flash',
request_type => 'SHARED',
output_schema => 'x INT64, y FLOAT64'
) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `result`
FROM `bfcte_1`
21 changes: 21 additions & 0 deletions tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
@@ -36,6 +36,26 @@ def test_ai_generate(scalar_types_df: dataframe.DataFrame, snapshot):
endpoint="gemini-2.5-flash",
request_type="shared",
model_params=None,
output_schema=None,
)

sql = utils._apply_unary_ops(
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
)

snapshot.assert_match(sql, "out.sql")


def test_ai_generate_with_output_schema(scalar_types_df: dataframe.DataFrame, snapshot):
col_name = "string_col"

op = ops.AIGenerate(
prompt_context=(None, " is the same as ", None),
connection_id=CONNECTION_ID,
endpoint="gemini-2.5-flash",
request_type="shared",
model_params=None,
output_schema="x INT64, y FLOAT64",
)

sql = utils._apply_unary_ops(
@@ -59,6 +79,7 @@ def test_ai_generate_with_model_param(scalar_types_df: dataframe.DataFrame, snapshot):
endpoint=None,
request_type="shared",
model_params=json.dumps(dict()),
output_schema=None,
)

sql = utils._apply_unary_ops(