From 54368532a7312574916b7616ced35b51d9d8e9d1 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 23 Sep 2025 23:47:22 +0000 Subject: [PATCH 1/3] feat: add ai.generate_double to bigframes.bigquery package --- bigframes/bigquery/_operations/ai.py | 75 +++++++++++++++++++ .../ibis_compiler/scalar_op_registry.py | 15 +++- .../compile/sqlglot/expressions/ai_ops.py | 7 ++ bigframes/operations/__init__.py | 3 +- bigframes/operations/ai_ops.py | 22 ++++++ tests/system/small/bigquery/test_ai.py | 39 ++++++++++ .../test_ai_generate_double/out.sql | 18 +++++ .../out.sql | 18 +++++ .../sqlglot/expressions/test_ai_ops.py | 45 +++++++++++ .../sql/compilers/bigquery/__init__.py | 3 + .../ibis/expr/operations/ai_ops.py | 23 ++++++ 11 files changed, 266 insertions(+), 2 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_model_param/out.sql diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py index f0b4f51611..0a283c9199 100644 --- a/bigframes/bigquery/_operations/ai.py +++ b/bigframes/bigquery/_operations/ai.py @@ -188,6 +188,81 @@ def generate_int( return series_list[0]._apply_nary_op(operator, series_list[1:]) +@log_adapter.method_logger(custom_base_name="bigquery_ai") +def generate_double( + prompt: PROMPT_TYPE, + *, + connection_id: str | None = None, + endpoint: str | None = None, + request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified", + model_params: Mapping[Any, Any] | None = None, +) -> series.Series: + """ + Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data. 
+ + **Examples:** + + >>> import bigframes.pandas as bpd + >>> import bigframes.bigquery as bbq + >>> bpd.options.display.progress_bar = None + >>> animal = bpd.Series(["Kangaroo", "Rabbit", "Spider"]) + >>> bbq.ai.generate_int(("How many legs does a ", animal, " have?")) + 0 {'result': 2.0, 'full_response': '{"candidates... + 1 {'result': 4.0, 'full_response': '{"candidates... + 2 {'result': 8.0, 'full_response': '{"candidates... + dtype: struct>, status: string>[pyarrow] + + >>> bbq.ai.generate_int(("How many legs does a ", animal, " have?")).struct.field("result") + 0 2.0 + 1 4.0 + 2 8.0 + Name: result, dtype: Float64 + + Args: + prompt (Series | List[str|Series] | Tuple[str|Series, ...]): + A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series + or pandas Series. + connection_id (str, optional): + Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`. + If not provided, the connection from the current session will be used. + endpoint (str, optional): + Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any + generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and + uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML selects a recent stable + version of Gemini to use. + request_type (Literal["dedicated", "shared", "unspecified"]): + Specifies the type of inference request to send to the Gemini model. The request type determines what quota the request uses. + * "dedicated": function only uses Provisioned Throughput quota. The function returns the error Provisioned throughput is not + purchased or is not active if Provisioned Throughput quota isn't available. + * "shared": the function only uses dynamic shared quota (DSQ), even if you have purchased Provisioned Throughput quota. 
+ * "unspecified": If you haven't purchased Provisioned Throughput quota, the function uses DSQ quota. + If you have purchased Provisioned Throughput quota, the function uses the Provisioned Throughput quota first. + If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota. + model_params (Mapping[Any, Any]): + Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format. + + Returns: + bigframes.series.Series: A new struct Series with the result data. The struct contains these fields: + * "result": a DOUBLE value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI. + * "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model. + The generated text is in the text element. + * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful. 
+ """ + + prompt_context, series_list = _separate_context_and_series(prompt) + assert len(series_list) > 0 + + operator = ai_ops.AIGenerateDouble( + prompt_context=tuple(prompt_context), + connection_id=_resolve_connection_id(series_list[0], connection_id), + endpoint=endpoint, + request_type=request_type, + model_params=json.dumps(model_params) if model_params else None, + ) + + return series_list[0]._apply_nary_op(operator, series_list[1:]) + + def _separate_context_and_series( prompt: PROMPT_TYPE, ) -> Tuple[List[str | None], List[series.Series]]: diff --git a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index a750a625ad..f650cfed47 100644 --- a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1986,7 +1986,7 @@ def ai_generate_bool( @scalar_op_compiler.register_nary_op(ops.AIGenerateInt, pass_op=True) def ai_generate_int( - *values: ibis_types.Value, op: ops.AIGenerateBool + *values: ibis_types.Value, op: ops.AIGenerateInt ) -> ibis_types.StructValue: return ai_ops.AIGenerateInt( @@ -1997,6 +1997,19 @@ def ai_generate_int( op.model_params, # type: ignore ).to_expr() +@scalar_op_compiler.register_nary_op(ops.AIGenerateDouble, pass_op=True) +def ai_generate_int( + *values: ibis_types.Value, op: ops.AIGenerateDouble +) -> ibis_types.StructValue: + + return ai_ops.AIGenerateDouble( + _construct_prompt(values, op.prompt_context), # type: ignore + op.connection_id, # type: ignore + op.endpoint, # type: ignore + op.request_type.upper(), # type: ignore + op.model_params, # type: ignore + ).to_expr() + def _construct_prompt( col_refs: tuple[ibis_types.Value], prompt_context: tuple[str | None] diff --git a/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/bigframes/core/compile/sqlglot/expressions/ai_ops.py index 50d56611b1..0e6d079bd7 100644 --- a/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ 
b/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -40,6 +40,13 @@ def _(*exprs: TypedExpr, op: ops.AIGenerateInt) -> sge.Expression: return sge.func("AI.GENERATE_INT", *args) +@register_nary_op(ops.AIGenerateDouble, pass_op=True) +def _(*exprs: TypedExpr, op: ops.AIGenerateDouble) -> sge.Expression: + args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op) + + return sge.func("AI.GENERATE_DOUBLE", *args) + + def _construct_prompt( exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...] ) -> sge.Kwarg: diff --git a/bigframes/operations/__init__.py b/bigframes/operations/__init__.py index 17e1f7534f..b14d15245a 100644 --- a/bigframes/operations/__init__.py +++ b/bigframes/operations/__init__.py @@ -14,7 +14,7 @@ from __future__ import annotations -from bigframes.operations.ai_ops import AIGenerateBool, AIGenerateInt +from bigframes.operations.ai_ops import AIGenerateBool, AIGenerateDouble, AIGenerateInt from bigframes.operations.array_ops import ( ArrayIndexOp, ArrayReduceOp, @@ -413,6 +413,7 @@ "GeoStDistanceOp", # AI ops "AIGenerateBool", + "AIGenerateDouble", "AIGenerateInt", # Numpy ops mapping "NUMPY_TO_BINOP", diff --git a/bigframes/operations/ai_ops.py b/bigframes/operations/ai_ops.py index 7a8202abd2..4404558497 100644 --- a/bigframes/operations/ai_ops.py +++ b/bigframes/operations/ai_ops.py @@ -66,3 +66,25 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT ) ) ) + + +@dataclasses.dataclass(frozen=True) +class AIGenerateDouble(base_ops.NaryOp): + name: ClassVar[str] = "ai_generate_double" + + prompt_context: Tuple[str | None, ...] 
+ connection_id: str + endpoint: str | None + request_type: Literal["dedicated", "shared", "unspecified"] + model_params: str | None + + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + return pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.float64()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), + pa.field("status", pa.string()), + ) + ) + ) diff --git a/tests/system/small/bigquery/test_ai.py b/tests/system/small/bigquery/test_ai.py index 9f6feb0bbc..7d32149726 100644 --- a/tests/system/small/bigquery/test_ai.py +++ b/tests/system/small/bigquery/test_ai.py @@ -146,5 +146,44 @@ def test_ai_generate_int_multi_model(session): ) +def test_ai_generate_double(session): + s = bpd.Series(["Cat"], session=session) + prompt = ("How many legs does a ", s, " have?") + + result = bbq.ai.generate_double(prompt, endpoint="gemini-2.5-flash") + + assert _contains_no_nulls(result) + assert result.dtype == pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.float64()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), + pa.field("status", pa.string()), + ) + ) + ) + + +def test_ai_generate_double_multi_model(session): + df = session.from_glob_path( + "gs://bigframes-dev-testing/a_multimodel/images/*", name="image" + ) + + result = bbq.ai.generate_double( + ("How many animals are there in the picture ", df["image"]) + ) + + assert _contains_no_nulls(result) + assert result.dtype == pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.float64()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), + pa.field("status", pa.string()), + ) + ) + ) + + def _contains_no_nulls(s: series.Series) -> bool: return len(s) == s.count() diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql new file mode 100644 index 0000000000..0baab06c3b --- /dev/null +++ 
b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql @@ -0,0 +1,18 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + AI.GENERATE_DOUBLE( + prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), + connection_id => 'test_connection_id', + endpoint => 'gemini-2.5-flash', + request_type => 'SHARED' + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `result` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_model_param/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_model_param/out.sql new file mode 100644 index 0000000000..4756cbb509 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_model_param/out.sql @@ -0,0 +1,18 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + AI.GENERATE_DOUBLE( + prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), + connection_id => 'test_connection_id', + request_type => 'SHARED', + model_params => JSON '{}' + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `result` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index 33a257f9a9..c95889e1b2 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -111,3 +111,48 @@ def test_ai_generate_int_with_model_param( ) snapshot.assert_match(sql, "out.sql") + + +def test_ai_generate_double(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AIGenerateDouble( + # The prompt does not 
make semantic sense but we only care about syntax correctness. + prompt_context=(None, " is the same as ", None), + connection_id="test_connection_id", + endpoint="gemini-2.5-flash", + request_type="shared", + model_params=None, + ) + + sql = utils._apply_unary_ops( + scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] + ) + + snapshot.assert_match(sql, "out.sql") + + +def test_ai_generate_double_with_model_param( + scalar_types_df: dataframe.DataFrame, snapshot +): + if version.Version(sqlglot.__version__) < version.Version("25.18.0"): + pytest.skip( + "Skip test because SQLGLot cannot compile model params to JSON at this version." + ) + + col_name = "string_col" + + op = ops.AIGenerateDouble( + # The prompt does not make semantic sense but we only care about syntax correctness. + prompt_context=(None, " is the same as ", None), + connection_id="test_connection_id", + endpoint=None, + request_type="shared", + model_params=json.dumps(dict()), + ) + + sql = utils._apply_unary_ops( + scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] + ) + + snapshot.assert_match(sql, "out.sql") diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index ef150534ee..e4122e88da 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -1110,6 +1110,9 @@ def visit_AIGenerateBool(self, op, **kwargs): def visit_AIGenerateInt(self, op, **kwargs): return sge.func("AI.GENERATE_INT", *self._compile_ai_args(**kwargs)) + def visit_AIGenerateDouble(self, op, **kwargs): + return sge.func("AI.GENERATE_DOUBLE", *self._compile_ai_args(**kwargs)) + def _compile_ai_args(self, **kwargs): args = [] diff --git a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index 
4b855f71c0..708a459072 100644 --- a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -49,3 +49,26 @@ def dtype(self) -> dt.Struct: return dt.Struct.from_tuples( (("result", dt.int64), ("full_resposne", dt.string), ("status", dt.string)) ) + + +@public +class AIGenerateDouble(Value): + """Generate doubles based on the prompt""" + + prompt: Value + connection_id: Value[dt.String] + endpoint: Optional[Value[dt.String]] + request_type: Value[dt.String] + model_params: Optional[Value[dt.String]] + + shape = rlz.shape_like("prompt") + + @attribute + def dtype(self) -> dt.Struct: + return dt.Struct.from_tuples( + ( + ("result", dt.float64), + ("full_resposne", dt.string), + ("status", dt.string), + ) + ) From 401bea45399fd6687eb81dacf762bb89fbeb0e64 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 23 Sep 2025 23:54:17 +0000 Subject: [PATCH 2/3] fix lint --- bigframes/core/compile/ibis_compiler/scalar_op_registry.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index f650cfed47..a4ef0b44c9 100644 --- a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1997,8 +1997,9 @@ def ai_generate_int( op.model_params, # type: ignore ).to_expr() + @scalar_op_compiler.register_nary_op(ops.AIGenerateDouble, pass_op=True) -def ai_generate_int( +def ai_generate_double( *values: ibis_types.Value, op: ops.AIGenerateDouble ) -> ibis_types.StructValue: From 11ec67e43e794c574b9fac23257ecc4fd7cee069 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Wed, 24 Sep 2025 00:32:55 +0000 Subject: [PATCH 3/3] fix doctest --- bigframes/bigquery/_operations/ai.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bigframes/bigquery/_operations/ai.py 
b/bigframes/bigquery/_operations/ai.py index 0a283c9199..5c7a4d682e 100644 --- a/bigframes/bigquery/_operations/ai.py +++ b/bigframes/bigquery/_operations/ai.py @@ -206,13 +206,13 @@ def generate_double( >>> import bigframes.bigquery as bbq >>> bpd.options.display.progress_bar = None >>> animal = bpd.Series(["Kangaroo", "Rabbit", "Spider"]) - >>> bbq.ai.generate_int(("How many legs does a ", animal, " have?")) + >>> bbq.ai.generate_double(("How many legs does a ", animal, " have?")) 0 {'result': 2.0, 'full_response': '{"candidates... 1 {'result': 4.0, 'full_response': '{"candidates... 2 {'result': 8.0, 'full_response': '{"candidates... dtype: struct>, status: string>[pyarrow] - >>> bbq.ai.generate_int(("How many legs does a ", animal, " have?")).struct.field("result") + >>> bbq.ai.generate_double(("How many legs does a ", animal, " have?")).struct.field("result") 0 2.0 1 4.0 2 8.0