diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py index 5c7a4d682e..3893ad12d1 100644 --- a/bigframes/bigquery/_operations/ai.py +++ b/bigframes/bigquery/_operations/ai.py @@ -35,6 +35,80 @@ ] +@log_adapter.method_logger(custom_base_name="bigquery_ai") +def generate( + prompt: PROMPT_TYPE, + *, + connection_id: str | None = None, + endpoint: str | None = None, + request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified", + model_params: Mapping[Any, Any] | None = None, + # TODO(b/446974666) Add output_schema parameter +) -> series.Series: + """ + Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data. + + **Examples:** + + >>> import bigframes.pandas as bpd + >>> import bigframes.bigquery as bbq + >>> bpd.options.display.progress_bar = None + >>> country = bpd.Series(["Japan", "Canada"]) + >>> bbq.ai.generate(("What's the capital city of ", country, " one word only")) + 0 {'result': 'Tokyo\\n', 'full_response': '{"cand... + 1 {'result': 'Ottawa\\n', 'full_response': '{"can... + dtype: struct>, status: string>[pyarrow] + + >>> bbq.ai.generate(("What's the capital city of ", country, " one word only")).struct.field("result") + 0 Tokyo\\n + 1 Ottawa\\n + Name: result, dtype: string + + Args: + prompt (Series | List[str|Series] | Tuple[str|Series, ...]): + A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series + or pandas Series. + connection_id (str, optional): + Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`. + If not provided, the connection from the current session will be used. + endpoint (str, optional): + Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any + generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and + uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML selects a recent stable + version of Gemini to use. + request_type (Literal["dedicated", "shared", "unspecified"]): + Specifies the type of inference request to send to the Gemini model. The request type determines what quota the request uses. + * "dedicated": function only uses Provisioned Throughput quota. The function returns the error Provisioned throughput is not + purchased or is not active if Provisioned Throughput quota isn't available. + * "shared": the function only uses dynamic shared quota (DSQ), even if you have purchased Provisioned Throughput quota. + * "unspecified": If you haven't purchased Provisioned Throughput quota, the function uses DSQ quota. + If you have purchased Provisioned Throughput quota, the function uses the Provisioned Throughput quota first. + If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota. + model_params (Mapping[Any, Any]): + Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format. + + Returns: + bigframes.series.Series: A new struct Series with the result data. The struct contains these fields: + * "result": a STRING value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI. + * "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model. + The generated text is in the text element. + * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful. + """ + + prompt_context, series_list = _separate_context_and_series(prompt) + assert len(series_list) > 0 + + operator = ai_ops.AIGenerate( + prompt_context=tuple(prompt_context), + connection_id=_resolve_connection_id(series_list[0], connection_id), + endpoint=endpoint, + request_type=request_type, + model_params=json.dumps(model_params) if model_params else None, + ) + + return series_list[0]._apply_nary_op(operator, series_list[1:]) + + @log_adapter.method_logger(custom_base_name="bigquery_ai") def generate_bool( prompt: PROMPT_TYPE, diff --git a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 635ba516e4..a0750ec73d 100644 --- a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1974,6 +1974,20 @@ def struct_op_impl( return ibis_types.struct(data) +@scalar_op_compiler.register_nary_op(ops.AIGenerate, pass_op=True) +def ai_generate( + *values: ibis_types.Value, op: ops.AIGenerate +) -> ibis_types.StructValue: + + return ai_ops.AIGenerate( + _construct_prompt(values, op.prompt_context), # type: ignore + op.connection_id, # type: ignore + op.endpoint, # type: ignore + op.request_type.upper(), # type: ignore + op.model_params, # type: ignore + ).to_expr() + + @scalar_op_compiler.register_nary_op(ops.AIGenerateBool, pass_op=True) def ai_generate_bool( *values: ibis_types.Value, op: ops.AIGenerateBool diff --git a/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/bigframes/core/compile/sqlglot/expressions/ai_ops.py index 0e6d079bd7..3f909ebc92 100644 --- a/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -26,6 +26,13 @@ register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op +@register_nary_op(ops.AIGenerate, pass_op=True) +def _(*exprs: TypedExpr, op: ops.AIGenerate) -> sge.Expression: + args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op) + + return sge.func("AI.GENERATE", *args) + + @register_nary_op(ops.AIGenerateBool, pass_op=True) def _(*exprs: TypedExpr, op: ops.AIGenerateBool) -> sge.Expression: args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op) diff --git a/bigframes/operations/__init__.py b/bigframes/operations/__init__.py index b14d15245a..031e42cf03 100644 --- a/bigframes/operations/__init__.py +++ b/bigframes/operations/__init__.py @@ -14,7 +14,12 @@ from __future__ import annotations -from bigframes.operations.ai_ops import AIGenerateBool, AIGenerateDouble, AIGenerateInt +from bigframes.operations.ai_ops import ( + AIGenerate, + AIGenerateBool, + AIGenerateDouble, + AIGenerateInt, +) from bigframes.operations.array_ops import ( ArrayIndexOp, ArrayReduceOp, @@ -412,6 +417,7 @@ "geo_y_op", "GeoStDistanceOp", # AI ops + "AIGenerate", "AIGenerateBool", "AIGenerateDouble", "AIGenerateInt", diff --git a/bigframes/operations/ai_ops.py b/bigframes/operations/ai_ops.py index 4404558497..5d710bf6b5 100644 --- a/bigframes/operations/ai_ops.py +++ b/bigframes/operations/ai_ops.py @@ -24,6 +24,28 @@ from bigframes.operations import base_ops +@dataclasses.dataclass(frozen=True) +class AIGenerate(base_ops.NaryOp): + name: ClassVar[str] = "ai_generate" + + prompt_context: Tuple[str | None, ...] + connection_id: str + endpoint: str | None + request_type: Literal["dedicated", "shared", "unspecified"] + model_params: str | None + + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + return pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.string()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), + pa.field("status", pa.string()), + ) + ) + ) + + @dataclasses.dataclass(frozen=True) class AIGenerateBool(base_ops.NaryOp): name: ClassVar[str] = "ai_generate_bool" diff --git a/tests/system/small/bigquery/test_ai.py b/tests/system/small/bigquery/test_ai.py index 7d32149726..890cd4fb2b 100644 --- a/tests/system/small/bigquery/test_ai.py +++ b/tests/system/small/bigquery/test_ai.py @@ -69,6 +69,24 @@ def test_ai_function_compile_model_params(session): ) +def test_ai_generate(session): + country = bpd.Series(["Japan", "Canada"], session=session) + prompt = ("What's the capital city of ", country, "? one word only") + + result = bbq.ai.generate(prompt, endpoint="gemini-2.5-flash") + + assert _contains_no_nulls(result) + assert result.dtype == pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.string()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), + pa.field("status", pa.string()), + ) + ) + ) + + def test_ai_generate_bool(session): s1 = bpd.Series(["apple", "bear"], session=session) s2 = bpd.Series(["fruit", "tree"], session=session) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate/out.sql new file mode 100644 index 0000000000..5c4ccefd7b --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate/out.sql @@ -0,0 +1,18 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + AI.GENERATE( + prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), + connection_id => 'bigframes-dev.us.bigframes-default-connection', + endpoint => 'gemini-2.5-flash', + request_type => 'SHARED' + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `result` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql index 584ccd9ce1..28905d0349 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql @@ -7,7 +7,7 @@ WITH `bfcte_0` AS ( *, AI.GENERATE_BOOL( prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), - connection_id => 'test_connection_id', + connection_id => 'bigframes-dev.us.bigframes-default-connection', endpoint => 'gemini-2.5-flash', request_type => 'SHARED' ) AS `bfcol_1` diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql index fca2b965bf..bf52361a52 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql @@ -7,7 +7,7 @@ WITH `bfcte_0` AS ( *, AI.GENERATE_BOOL( prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), - connection_id => 'test_connection_id', + connection_id => 'bigframes-dev.us.bigframes-default-connection', request_type => 'SHARED', model_params => JSON '{}' ) AS `bfcol_1` diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql index 0baab06c3b..cbb05264e9 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql @@ -7,7 +7,7 @@ WITH `bfcte_0` AS ( *, AI.GENERATE_DOUBLE( prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), - connection_id => 'test_connection_id', + connection_id => 'bigframes-dev.us.bigframes-default-connection', endpoint => 'gemini-2.5-flash', request_type => 'SHARED' ) AS `bfcol_1` diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_model_param/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_model_param/out.sql index 4756cbb509..a1c1a18664 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_model_param/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double_with_model_param/out.sql @@ -7,7 +7,7 @@ WITH `bfcte_0` AS ( *, AI.GENERATE_DOUBLE( prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), - connection_id => 'test_connection_id', + connection_id => 'bigframes-dev.us.bigframes-default-connection', request_type => 'SHARED', model_params => JSON '{}' ) AS `bfcol_1` diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int/out.sql index e48b64bead..ba5febe0cd 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int/out.sql @@ -7,7 +7,7 @@ WITH `bfcte_0` AS ( *, AI.GENERATE_INT( prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), - connection_id => 'test_connection_id', + connection_id => 'bigframes-dev.us.bigframes-default-connection', endpoint => 'gemini-2.5-flash', request_type => 'SHARED' ) AS `bfcol_1` diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_model_param/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_model_param/out.sql index 6f406dea18..996906fe9c 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_model_param/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_int_with_model_param/out.sql @@ -7,7 +7,7 @@ WITH `bfcte_0` AS ( *, AI.GENERATE_INT( prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), - connection_id => 'test_connection_id', + connection_id => 'bigframes-dev.us.bigframes-default-connection', request_type => 'SHARED', model_params => JSON '{}' ) AS `bfcol_1` diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_model_param/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_model_param/out.sql new file mode 100644 index 0000000000..8726910619 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_with_model_param/out.sql @@ -0,0 +1,18 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + AI.GENERATE( + prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`), + connection_id => 'bigframes-dev.us.bigframes-default-connection', + request_type => 'SHARED', + model_params => JSON '{}' + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `result` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index c95889e1b2..f20b39bc74 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -24,13 +24,56 @@ pytest.importorskip("pytest_snapshot") +CONNECTION_ID = "bigframes-dev.us.bigframes-default-connection" + + +def test_ai_generate(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AIGenerate( + prompt_context=(None, " is the same as ", None), + connection_id=CONNECTION_ID, + endpoint="gemini-2.5-flash", + request_type="shared", + model_params=None, + ) + + sql = utils._apply_unary_ops( + scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] + ) + + snapshot.assert_match(sql, "out.sql") + + +def test_ai_generate_with_model_param(scalar_types_df: dataframe.DataFrame, snapshot): + if version.Version(sqlglot.__version__) < version.Version("25.18.0"): + pytest.skip( + "Skip test because SQLGLot cannot compile model params to JSON at this version." + ) + + col_name = "string_col" + + op = ops.AIGenerate( + prompt_context=(None, " is the same as ", None), + connection_id=CONNECTION_ID, + endpoint=None, + request_type="shared", + model_params=json.dumps(dict()), + ) + + sql = utils._apply_unary_ops( + scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] + ) + + snapshot.assert_match(sql, "out.sql") + def test_ai_generate_bool(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" op = ops.AIGenerateBool( prompt_context=(None, " is the same as ", None), - connection_id="test_connection_id", + connection_id=CONNECTION_ID, endpoint="gemini-2.5-flash", request_type="shared", model_params=None, @@ -55,7 +98,7 @@ def test_ai_generate_bool_with_model_param( op = ops.AIGenerateBool( prompt_context=(None, " is the same as ", None), - connection_id="test_connection_id", + connection_id=CONNECTION_ID, endpoint=None, request_type="shared", model_params=json.dumps(dict()), @@ -74,7 +117,7 @@ def test_ai_generate_int(scalar_types_df: dataframe.DataFrame, snapshot): op = ops.AIGenerateInt( # The prompt does not make semantic sense but we only care about syntax correctness. prompt_context=(None, " is the same as ", None), - connection_id="test_connection_id", + connection_id=CONNECTION_ID, endpoint="gemini-2.5-flash", request_type="shared", model_params=None, @@ -100,7 +143,7 @@ def test_ai_generate_int_with_model_param( op = ops.AIGenerateInt( # The prompt does not make semantic sense but we only care about syntax correctness. prompt_context=(None, " is the same as ", None), - connection_id="test_connection_id", + connection_id=CONNECTION_ID, endpoint=None, request_type="shared", model_params=json.dumps(dict()), @@ -119,7 +162,7 @@ def test_ai_generate_double(scalar_types_df: dataframe.DataFrame, snapshot): op = ops.AIGenerateDouble( # The prompt does not make semantic sense but we only care about syntax correctness. prompt_context=(None, " is the same as ", None), - connection_id="test_connection_id", + connection_id=CONNECTION_ID, endpoint="gemini-2.5-flash", request_type="shared", model_params=None, @@ -145,7 +188,7 @@ def test_ai_generate_double_with_model_param( op = ops.AIGenerateDouble( # The prompt does not make semantic sense but we only care about syntax correctness. prompt_context=(None, " is the same as ", None), - connection_id="test_connection_id", + connection_id=CONNECTION_ID, endpoint=None, request_type="shared", model_params=json.dumps(dict()), diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index e4122e88da..836d15118c 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -1104,6 +1104,9 @@ def visit_StringAgg(self, op, *, arg, sep, order_by, where): expr = arg return self.agg.string_agg(expr, sep, where=where) + def visit_AIGenerate(self, op, **kwargs): + return sge.func("AI.GENERATE", *self._compile_ai_args(**kwargs)) + def visit_AIGenerateBool(self, op, **kwargs): return sge.func("AI.GENERATE_BOOL", *self._compile_ai_args(**kwargs)) diff --git a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index 708a459072..05c5e7e0af 100644 --- a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -13,6 +13,25 @@ from public import public +@public +class AIGenerate(Value): + """Generate content based on the prompt""" + + prompt: Value + connection_id: Value[dt.String] + endpoint: Optional[Value[dt.String]] + request_type: Value[dt.String] + model_params: Optional[Value[dt.String]] + + shape = rlz.shape_like("prompt") + + @attribute + def dtype(self) -> dt.Struct: + return dt.Struct.from_tuples( + (("result", dt.string), ("full_resposne", dt.string), ("status", dt.string)) + ) + + @public class AIGenerateBool(Value): """Generate Bool based on the prompt""" @@ -53,7 +72,7 @@ def dtype(self) -> dt.Struct: @public class AIGenerateDouble(Value): - """Generate integers based on the prompt""" + """Generate doubles based on the prompt""" prompt: Value connection_id: Value[dt.String]