74 changes: 74 additions & 0 deletions bigframes/bigquery/_operations/ai.py
@@ -35,6 +35,80 @@
]


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate(
prompt: PROMPT_TYPE,
*,
connection_id: str | None = None,
endpoint: str | None = None,
request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
model_params: Mapping[Any, Any] | None = None,
# TODO(b/446974666) Add output_schema parameter
) -> series.Series:
"""
Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data.

**Examples:**

>>> import bigframes.pandas as bpd
>>> import bigframes.bigquery as bbq
>>> bpd.options.display.progress_bar = None
>>> country = bpd.Series(["Japan", "Canada"])
>>> bbq.ai.generate(("What's the capital city of ", country, " one word only"))
0 {'result': 'Tokyo\\n', 'full_response': '{"cand...
1 {'result': 'Ottawa\\n', 'full_response': '{"can...
dtype: struct<result: string, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]

>>> bbq.ai.generate(("What's the capital city of ", country, " one word only")).struct.field("result")
0 Tokyo\\n
1 Ottawa\\n
Name: result, dtype: string

Args:
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
or pandas Series.
connection_id (str, optional):
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
If not provided, the connection from the current session will be used.
endpoint (str, optional):
Specifies the Vertex AI endpoint to use for the model. For example, `"gemini-2.5-flash"`. You can specify any
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
uses the full endpoint of the model. If you don't specify an `endpoint` value, BigQuery ML selects a recent stable
version of Gemini to use.
request_type (Literal["dedicated", "shared", "unspecified"]):
Specifies the type of inference request to send to the Gemini model. The request type determines the quota the request uses.
* "dedicated": the function only uses Provisioned Throughput quota. It returns the error "Provisioned throughput is not
purchased or is not active" if Provisioned Throughput quota isn't available.
* "shared": the function only uses dynamic shared quota (DSQ), even if you have purchased Provisioned Throughput quota.
* "unspecified": if you haven't purchased Provisioned Throughput quota, the function uses DSQ quota.
If you have purchased Provisioned Throughput quota, the function uses it first;
requests that exceed the Provisioned Throughput quota overflow to DSQ quota.
model_params (Mapping[Any, Any], optional):
Provides additional parameters to the model. The `model_params` value must conform to the `generateContent` request body format.

Returns:
bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
* "result": a STRING value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
* "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
The generated text is in the text element.
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
"""

prompt_context, series_list = _separate_context_and_series(prompt)
assert len(series_list) > 0

operator = ai_ops.AIGenerate(
prompt_context=tuple(prompt_context),
connection_id=_resolve_connection_id(series_list[0], connection_id),
endpoint=endpoint,
request_type=request_type,
model_params=json.dumps(model_params) if model_params else None,
)

return series_list[0]._apply_nary_op(operator, series_list[1:])
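
A minimal usage sketch of the new entry point, assuming a session with a default connection; the `model_params` keys here are illustrative (any payload must follow the `generateContent` request body format):

```python
import bigframes.pandas as bpd
import bigframes.bigquery as bbq

city = bpd.Series(["Tokyo", "Ottawa"])

# Literal strings stay fixed across rows; Series parts vary per row.
response = bbq.ai.generate(
    ("Which country is ", city, " the capital of? One word only."),
    endpoint="gemini-2.5-flash",
    model_params={"generationConfig": {"temperature": 0}},  # illustrative keys
)

answers = response.struct.field("result")   # per-row STRING output
statuses = response.struct.field("status")  # empty string on success
```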


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate_bool(
prompt: PROMPT_TYPE,
14 changes: 14 additions & 0 deletions bigframes/core/compile/ibis_compiler/scalar_op_registry.py
@@ -1974,6 +1974,20 @@ def struct_op_impl(
return ibis_types.struct(data)


@scalar_op_compiler.register_nary_op(ops.AIGenerate, pass_op=True)
def ai_generate(
*values: ibis_types.Value, op: ops.AIGenerate
) -> ibis_types.StructValue:

return ai_ops.AIGenerate(
_construct_prompt(values, op.prompt_context), # type: ignore
op.connection_id, # type: ignore
op.endpoint, # type: ignore
op.request_type.upper(), # type: ignore
op.model_params, # type: ignore
).to_expr()
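
`_construct_prompt` already exists in this module for the other AI ops; a plausible sketch of its contract, inferred from the ops' `prompt_context` shape, where each `None` slot stands in for the next column value:

```python
# Sketch only, not the vendored implementation: interleave static prompt
# parts with column values, filling each None placeholder in order.
def _construct_prompt_sketch(values, prompt_context):
    value_iter = iter(values)
    return tuple(
        next(value_iter) if part is None else part for part in prompt_context
    )
```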


@scalar_op_compiler.register_nary_op(ops.AIGenerateBool, pass_op=True)
def ai_generate_bool(
*values: ibis_types.Value, op: ops.AIGenerateBool
7 changes: 7 additions & 0 deletions bigframes/core/compile/sqlglot/expressions/ai_ops.py
@@ -26,6 +26,13 @@
register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op


@register_nary_op(ops.AIGenerate, pass_op=True)
def _(*exprs: TypedExpr, op: ops.AIGenerate) -> sge.Expression:
args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)

return sge.func("AI.GENERATE", *args)


@register_nary_op(ops.AIGenerateBool, pass_op=True)
def _(*exprs: TypedExpr, op: ops.AIGenerateBool) -> sge.Expression:
args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
8 changes: 7 additions & 1 deletion bigframes/operations/__init__.py
@@ -14,7 +14,12 @@

from __future__ import annotations

- from bigframes.operations.ai_ops import AIGenerateBool, AIGenerateDouble, AIGenerateInt
+ from bigframes.operations.ai_ops import (
+     AIGenerate,
+     AIGenerateBool,
+     AIGenerateDouble,
+     AIGenerateInt,
+ )
from bigframes.operations.array_ops import (
ArrayIndexOp,
ArrayReduceOp,
@@ -412,6 +417,7 @@
"geo_y_op",
"GeoStDistanceOp",
# AI ops
"AIGenerate",
"AIGenerateBool",
"AIGenerateDouble",
"AIGenerateInt",
22 changes: 22 additions & 0 deletions bigframes/operations/ai_ops.py
@@ -24,6 +24,28 @@
from bigframes.operations import base_ops


@dataclasses.dataclass(frozen=True)
class AIGenerate(base_ops.NaryOp):
name: ClassVar[str] = "ai_generate"

prompt_context: Tuple[str | None, ...]
connection_id: str
endpoint: str | None
request_type: Literal["dedicated", "shared", "unspecified"]
model_params: str | None

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
return pd.ArrowDtype(
pa.struct(
(
pa.field("result", pa.string()),
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
pa.field("status", pa.string()),
)
)
)
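
Because the op declares its struct dtype up front, callers can pick fields off the result lazily. A hedged sketch, assuming `response` is the struct Series returned by `bbq.ai.generate` (as in the usage sketch above), that `bbq.json_value` is available, and that `generateContent` responses carry a `modelVersion` key:

```python
full = response.struct.field("full_response")           # JSON-typed column
model_version = bbq.json_value(full, "$.modelVersion")  # assumed JSON path
failed_rows = response.struct.field("status") != ""     # non-empty => error
```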


@dataclasses.dataclass(frozen=True)
class AIGenerateBool(base_ops.NaryOp):
name: ClassVar[str] = "ai_generate_bool"
18 changes: 18 additions & 0 deletions tests/system/small/bigquery/test_ai.py
@@ -69,6 +69,24 @@ def test_ai_function_compile_model_params(session):
)


def test_ai_generate(session):
country = bpd.Series(["Japan", "Canada"], session=session)
prompt = ("What's the capital city of ", country, "? one word only")

result = bbq.ai.generate(prompt, endpoint="gemini-2.5-flash")

assert _contains_no_nulls(result)
assert result.dtype == pd.ArrowDtype(
pa.struct(
(
pa.field("result", pa.string()),
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
pa.field("status", pa.string()),
)
)
)
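
`_contains_no_nulls` is an existing helper defined earlier in this test module (not shown in the diff); a plausible sketch of the check it performs:

```python
def _contains_no_nulls(s) -> bool:
    # Series.count() excludes nulls, so matching len() means every row is set.
    return len(s) == s.count()
```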


def test_ai_generate_bool(session):
s1 = bpd.Series(["apple", "bear"], session=session)
s2 = bpd.Series(["fruit", "tree"], session=session)
@@ -0,0 +1,18 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
AI.GENERATE(
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
connection_id => 'bigframes-dev.us.bigframes-default-connection',
endpoint => 'gemini-2.5-flash',
request_type => 'SHARED'
) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `result`
FROM `bfcte_1`
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
*,
AI.GENERATE_BOOL(
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
- connection_id => 'test_connection_id',
+ connection_id => 'bigframes-dev.us.bigframes-default-connection',
endpoint => 'gemini-2.5-flash',
request_type => 'SHARED'
) AS `bfcol_1`
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
*,
AI.GENERATE_BOOL(
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
- connection_id => 'test_connection_id',
+ connection_id => 'bigframes-dev.us.bigframes-default-connection',
request_type => 'SHARED',
model_params => JSON '{}'
) AS `bfcol_1`
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
*,
AI.GENERATE_DOUBLE(
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
- connection_id => 'test_connection_id',
+ connection_id => 'bigframes-dev.us.bigframes-default-connection',
endpoint => 'gemini-2.5-flash',
request_type => 'SHARED'
) AS `bfcol_1`
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
*,
AI.GENERATE_DOUBLE(
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
- connection_id => 'test_connection_id',
+ connection_id => 'bigframes-dev.us.bigframes-default-connection',
request_type => 'SHARED',
model_params => JSON '{}'
) AS `bfcol_1`
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
*,
AI.GENERATE_INT(
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
- connection_id => 'test_connection_id',
+ connection_id => 'bigframes-dev.us.bigframes-default-connection',
endpoint => 'gemini-2.5-flash',
request_type => 'SHARED'
) AS `bfcol_1`
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
*,
AI.GENERATE_INT(
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
- connection_id => 'test_connection_id',
+ connection_id => 'bigframes-dev.us.bigframes-default-connection',
request_type => 'SHARED',
model_params => JSON '{}'
) AS `bfcol_1`
@@ -0,0 +1,18 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
AI.GENERATE(
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
connection_id => 'bigframes-dev.us.bigframes-default-connection',
request_type => 'SHARED',
model_params => JSON '{}'
) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `result`
FROM `bfcte_1`
55 changes: 49 additions & 6 deletions tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
@@ -24,13 +24,56 @@

pytest.importorskip("pytest_snapshot")

CONNECTION_ID = "bigframes-dev.us.bigframes-default-connection"


def test_ai_generate(scalar_types_df: dataframe.DataFrame, snapshot):
col_name = "string_col"

op = ops.AIGenerate(
prompt_context=(None, " is the same as ", None),
connection_id=CONNECTION_ID,
endpoint="gemini-2.5-flash",
request_type="shared",
model_params=None,
)

sql = utils._apply_unary_ops(
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
)

snapshot.assert_match(sql, "out.sql")


def test_ai_generate_with_model_param(scalar_types_df: dataframe.DataFrame, snapshot):
if version.Version(sqlglot.__version__) < version.Version("25.18.0"):
pytest.skip(
"Skip test because SQLGLot cannot compile model params to JSON at this version."
)

col_name = "string_col"

op = ops.AIGenerate(
prompt_context=(None, " is the same as ", None),
connection_id=CONNECTION_ID,
endpoint=None,
request_type="shared",
model_params=json.dumps(dict()),
)

sql = utils._apply_unary_ops(
scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
)

snapshot.assert_match(sql, "out.sql")


def test_ai_generate_bool(scalar_types_df: dataframe.DataFrame, snapshot):
col_name = "string_col"

op = ops.AIGenerateBool(
prompt_context=(None, " is the same as ", None),
connection_id="test_connection_id",
connection_id=CONNECTION_ID,
endpoint="gemini-2.5-flash",
request_type="shared",
model_params=None,
@@ -55,7 +98,7 @@ def test_ai_generate_bool_with_model_param(

op = ops.AIGenerateBool(
prompt_context=(None, " is the same as ", None),
connection_id="test_connection_id",
connection_id=CONNECTION_ID,
endpoint=None,
request_type="shared",
model_params=json.dumps(dict()),
Expand All @@ -74,7 +117,7 @@ def test_ai_generate_int(scalar_types_df: dataframe.DataFrame, snapshot):
op = ops.AIGenerateInt(
# The prompt does not make semantic sense but we only care about syntax correctness.
prompt_context=(None, " is the same as ", None),
connection_id="test_connection_id",
connection_id=CONNECTION_ID,
endpoint="gemini-2.5-flash",
request_type="shared",
model_params=None,
@@ -100,7 +143,7 @@ def test_ai_generate_int_with_model_param(
op = ops.AIGenerateInt(
# The prompt does not make semantic sense but we only care about syntax correctness.
prompt_context=(None, " is the same as ", None),
connection_id="test_connection_id",
connection_id=CONNECTION_ID,
endpoint=None,
request_type="shared",
model_params=json.dumps(dict()),
@@ -119,7 +162,7 @@ def test_ai_generate_double(scalar_types_df: dataframe.DataFrame, snapshot):
op = ops.AIGenerateDouble(
# The prompt does not make semantic sense but we only care about syntax correctness.
prompt_context=(None, " is the same as ", None),
connection_id="test_connection_id",
connection_id=CONNECTION_ID,
endpoint="gemini-2.5-flash",
request_type="shared",
model_params=None,
@@ -145,7 +188,7 @@ def test_ai_generate_double_with_model_param(
op = ops.AIGenerateDouble(
# The prompt does not make semantic sense but we only care about syntax correctness.
prompt_context=(None, " is the same as ", None),
connection_id="test_connection_id",
connection_id=CONNECTION_ID,
endpoint=None,
request_type="shared",
model_params=json.dumps(dict()),
@@ -1104,6 +1104,9 @@ def visit_StringAgg(self, op, *, arg, sep, order_by, where):
expr = arg
return self.agg.string_agg(expr, sep, where=where)

def visit_AIGenerate(self, op, **kwargs):
return sge.func("AI.GENERATE", *self._compile_ai_args(**kwargs))

def visit_AIGenerateBool(self, op, **kwargs):
return sge.func("AI.GENERATE_BOOL", *self._compile_ai_args(**kwargs))
