From b8f18f5c6b121d09864cabcb039c018a78f39b08 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Fri, 3 Oct 2025 22:14:24 +0000 Subject: [PATCH] feat: add ai.classify() to bigframes.bigquery package --- bigframes/bigquery/_operations/ai.py | 97 ++++++++++++++----- .../ibis_compiler/scalar_op_registry.py | 12 +++ .../compile/sqlglot/expressions/ai_ops.py | 21 +++- bigframes/operations/__init__.py | 2 + bigframes/operations/ai_ops.py | 12 +++ tests/system/small/bigquery/test_ai.py | 21 ++++ .../test_ai_ops/test_ai_classify/out.sql | 17 ++++ .../sqlglot/expressions/test_ai_ops.py | 14 +++ .../sql/compilers/bigquery/__init__.py | 3 + .../ibis/expr/operations/ai_ops.py | 15 +++ 10 files changed, 189 insertions(+), 25 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/out.sql diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py index 4759c99016..a789310683 100644 --- a/bigframes/bigquery/_operations/ai.py +++ b/bigframes/bigquery/_operations/ai.py @@ -348,20 +348,20 @@ def if_( provides optimization such that not all rows are evaluated with the LLM. 
**Examples:** - >>> import bigframes.pandas as bpd - >>> import bigframes.bigquery as bbq - >>> bpd.options.display.progress_bar = None - >>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"]) - >>> bbq.ai.if_((us_state, " has a city called Springfield")) - 0 True - 1 True - 2 False - dtype: boolean - - >>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))] - 0 Massachusetts - 1 Illinois - dtype: string + >>> import bigframes.pandas as bpd + >>> import bigframes.bigquery as bbq + >>> bpd.options.display.progress_bar = None + >>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"]) + >>> bbq.ai.if_((us_state, " has a city called Springfield")) + 0 True + 1 True + 2 False + dtype: boolean + + >>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))] + 0 Massachusetts + 1 Illinois + dtype: string Args: prompt (Series | List[str|Series] | Tuple[str|Series, ...]): @@ -386,6 +386,56 @@ def if_( return series_list[0]._apply_nary_op(operator, series_list[1:]) +@log_adapter.method_logger(custom_base_name="bigquery_ai") +def classify( + input: PROMPT_TYPE, + categories: tuple[str, ...] | list[str], + *, + connection_id: str | None = None, +) -> series.Series: + """ + Classifies a given input into one of the specified categories. It always returns the provided category that best fits the input. + + **Examples:** + + >>> import bigframes.pandas as bpd + >>> import bigframes.bigquery as bbq + >>> bpd.options.display.progress_bar = None + >>> df = bpd.DataFrame({'creature': ['Cat', 'Salmon']}) + >>> df['type'] = bbq.ai.classify(df['creature'], ['Mammal', 'Fish']) + >>> df + creature type + 0 Cat Mammal + 1 Salmon Fish + + [2 rows x 2 columns] + + Args: + input (Series | List[str|Series] | Tuple[str|Series, ...]): + A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series + or pandas Series. + categories (tuple[str, ...] 
| list[str]): + Categories to classify the input into. + connection_id (str, optional): + Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`. + If not provided, the connection from the current session will be used. + + Returns: + bigframes.series.Series: A new series of strings. + """ + + prompt_context, series_list = _separate_context_and_series(input) + assert len(series_list) > 0 + + operator = ai_ops.AIClassify( + prompt_context=tuple(prompt_context), + categories=tuple(categories), + connection_id=_resolve_connection_id(series_list[0], connection_id), + ) + + return series_list[0]._apply_nary_op(operator, series_list[1:]) + + @log_adapter.method_logger(custom_base_name="bigquery_ai") def score( prompt: PROMPT_TYPE, @@ -398,15 +448,16 @@ def score( rubric with examples in the prompt. **Examples:** - >>> import bigframes.pandas as bpd - >>> import bigframes.bigquery as bbq - >>> bpd.options.display.progress_bar = None - >>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"]) - >>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP - 0 2.0 - 1 1.0 - 2 3.0 - dtype: Float64 + + >>> import bigframes.pandas as bpd + >>> import bigframes.bigquery as bbq + >>> bpd.options.display.progress_bar = None + >>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"]) + >>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP + 0 2.0 + 1 1.0 + 2 3.0 + dtype: Float64 Args: prompt (Series | List[str|Series] | Tuple[str|Series, ...]): diff --git a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 7280e9a40a..4c02e17d6f 100644 --- a/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -2039,6 +2039,18 @@ def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue: 
).to_expr() +@scalar_op_compiler.register_nary_op(ops.AIClassify, pass_op=True) +def ai_classify( + *values: ibis_types.Value, op: ops.AIClassify +) -> ibis_types.StructValue: + + return ai_ops.AIClassify( + _construct_prompt(values, op.prompt_context), # type: ignore + op.categories, # type: ignore + op.connection_id, # type: ignore + ).to_expr() + + @scalar_op_compiler.register_nary_op(ops.AIScore, pass_op=True) def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.StructValue: diff --git a/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/bigframes/core/compile/sqlglot/expressions/ai_ops.py index 46a79d1440..4129c91906 100644 --- a/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -61,6 +61,21 @@ def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression: return sge.func("AI.IF", *args) +@register_nary_op(ops.AIClassify, pass_op=True) +def _(*exprs: TypedExpr, op: ops.AIClassify) -> sge.Expression: + category_literals = [sge.Literal.string(cat) for cat in op.categories] + categories_arg = sge.Kwarg( + this="categories", expression=sge.array(*category_literals) + ) + + args = [ + _construct_prompt(exprs, op.prompt_context, param_name="input"), + categories_arg, + ] + _construct_named_args(op) + + return sge.func("AI.CLASSIFY", *args) + + @register_nary_op(ops.AIScore, pass_op=True) def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression: args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op) @@ -69,7 +84,9 @@ def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression: def _construct_prompt( - exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...] 
+ exprs: tuple[TypedExpr, ...], + prompt_context: tuple[str | None, ...], + param_name: str = "prompt", ) -> sge.Kwarg: prompt: list[str | sge.Expression] = [] column_ref_idx = 0 @@ -80,7 +97,7 @@ def _construct_prompt( else: prompt.append(sge.Literal.string(elem)) - return sge.Kwarg(this="prompt", expression=sge.Tuple(expressions=prompt)) + return sge.Kwarg(this=param_name, expression=sge.Tuple(expressions=prompt)) def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]: diff --git a/bigframes/operations/__init__.py b/bigframes/operations/__init__.py index e7d0751fc9..24a7d6542f 100644 --- a/bigframes/operations/__init__.py +++ b/bigframes/operations/__init__.py @@ -15,6 +15,7 @@ from __future__ import annotations from bigframes.operations.ai_ops import ( + AIClassify, AIGenerate, AIGenerateBool, AIGenerateDouble, @@ -419,6 +420,7 @@ "geo_y_op", "GeoStDistanceOp", # AI ops + "AIClassify", "AIGenerate", "AIGenerateBool", "AIGenerateDouble", diff --git a/bigframes/operations/ai_ops.py b/bigframes/operations/ai_ops.py index 05d37d2a90..7ba3737ba0 100644 --- a/bigframes/operations/ai_ops.py +++ b/bigframes/operations/ai_ops.py @@ -123,6 +123,18 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT return dtypes.BOOL_DTYPE +@dataclasses.dataclass(frozen=True) +class AIClassify(base_ops.NaryOp): + name: ClassVar[str] = "ai_classify" + + prompt_context: Tuple[str | None, ...] + categories: tuple[str, ...] 
+ connection_id: str + + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + return dtypes.STRING_DTYPE + + @dataclasses.dataclass(frozen=True) class AIScore(base_ops.NaryOp): name: ClassVar[str] = "ai_score" diff --git a/tests/system/small/bigquery/test_ai.py b/tests/system/small/bigquery/test_ai.py index 91499d0efe..7a6e5aea4f 100644 --- a/tests/system/small/bigquery/test_ai.py +++ b/tests/system/small/bigquery/test_ai.py @@ -225,6 +225,27 @@ def test_ai_if_multi_model(session): assert result.dtype == dtypes.BOOL_DTYPE +def test_ai_classify(session): + s = bpd.Series(["cat", "orchid"], session=session) + bpd.options.display.repr_mode = "deferred" + + result = bbq.ai.classify(s, ["animal", "plant"]) + + assert _contains_no_nulls(result) + assert result.dtype == dtypes.STRING_DTYPE + + +def test_ai_classify_multi_model(session): + df = session.from_glob_path( + "gs://bigframes-dev-testing/a_multimodel/images/*", name="image" + ) + + result = bbq.ai.classify(df["image"], ["photo", "cartoon"]) + + assert _contains_no_nulls(result) + assert result.dtype == dtypes.STRING_DTYPE + + def test_ai_score(session): s = bpd.Series(["Tiger", "Rabbit"], session=session) prompt = ("Rank the relative weights of ", s, " on the scale from 1 to 3") diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/out.sql new file mode 100644 index 0000000000..bb06760e4d --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/out.sql @@ -0,0 +1,17 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + AI.CLASSIFY( + input => (`bfcol_0`), + categories => ['greeting', 'rejection'], + connection_id => 'bigframes-dev.us.bigframes-default-connection' + ) AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + 
`bfcol_1` AS `result` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index 8f048a5bbf..c809e90a90 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -216,6 +216,20 @@ def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AIClassify( + prompt_context=(None,), + categories=("greeting", "rejection"), + connection_id=CONNECTION_ID, + ) + + sql = utils._apply_unary_ops(scalar_types_df, [op.as_expr(col_name)], ["result"]) + + snapshot.assert_match(sql, "out.sql") + + def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot): col_name = "string_col" diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index 8603c89cc8..cf205b69d6 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -1119,6 +1119,9 @@ def visit_AIGenerateDouble(self, op, **kwargs): def visit_AIIf(self, op, **kwargs): return sge.func("AI.IF", *self._compile_ai_args(**kwargs)) + def visit_AIClassify(self, op, **kwargs): + return sge.func("AI.CLASSIFY", *self._compile_ai_args(**kwargs)) + def visit_AIScore(self, op, **kwargs): return sge.func("AI.SCORE", *self._compile_ai_args(**kwargs)) diff --git a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index 5289ee7e60..e9d704fa8e 100644 --- a/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ 
b/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -107,6 +107,21 @@ def dtype(self) -> dt.Struct: return dt.bool +@public +class AIClassify(Value): + """Classify the input into one of the given categories""" + + input: Value + categories: Value[dt.Array[dt.String]] + connection_id: Value[dt.String] + + shape = rlz.shape_like("input") + + @attribute + def dtype(self) -> dt.Struct: + return dt.string + + @public class AIScore(Value): """Generate doubles based on the prompt"""