Skip to content

Commit 098c35c

Browse files
authored
feat(bigframes): support output_mode for ai.classify (#17097)
Fixes b/491582856 🦕
1 parent daaed67 commit 098c35c

9 files changed

Lines changed: 108 additions & 20 deletions

File tree

packages/bigframes/bigframes/bigquery/_operations/ai.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -842,9 +842,12 @@ def classify(
842842
input: PROMPT_TYPE,
843843
categories: tuple[str, ...] | list[str],
844844
*,
845-
examples: list[tuple[str, str]] | None = None,
845+
examples: list[tuple[str, str]]
846+
| list[tuple[str, list[str] | tuple[str, ...]]]
847+
| None = None,
846848
connection_id: str | None = None,
847849
endpoint: str | None = None,
850+
output_mode: Literal["single", "multi"] | None = None,
848851
optimization_mode: Literal["minimize_cost", "maximize_quality"] | None = None,
849852
max_error_ratio: float | None = None,
850853
) -> series.Series:
@@ -870,17 +873,21 @@ def classify(
870873
or pandas Series.
871874
categories (tuple[str, ...] | list[str]):
872875
Categories to classify the input into.
873-
examples (list[tuple[str, str]], optional):
876+
examples (list[tuple[str, str]] | list[tuple[str, list[str] | tuple[str, ...]]], optional):
874877
An array that contains representative examples of input strings and the output category
875-
that you expect. You can provide examples to help the model understand your
876-
intended threshold for a condition with nuanced or subjective logic. We recommend providing at most 5 examples.
878+
that you expect. If ``output_mode`` is ``multi``, each example output must be a list or tuple of strings.
879+
You can provide examples to help the model understand your intended threshold for a condition with nuanced
880+
or subjective logic. We recommend providing at most 5 examples.
877881
connection_id (str, optional):
878882
Specifies the connection to use to communicate with the model. For example, ``myproject.us.myconnection``.
879883
If not provided, the query uses your end-user credential.
880884
endpoint (str, optional):
881885
A STRING value that specifies the Vertex AI endpoint to use for the model. You can specify any
882886
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically
883887
identifies and uses the full endpoint of the model.
888+
output_mode (Literal["single", "multi"], optional):
889+
A STRING value that indicates whether a single input can be classified into multiple categories.
890+
Supported values are ``single`` and ``multi``.
884891
optimization_mode (Literal["minimize_cost", "maximize_quality"], optional):
885892
A STRING value that specifies the optimization strategy to use. Supported values are ``minimize_cost``
886893
and ``maximize_quality``.
@@ -890,20 +897,27 @@ def classify(
890897
This argument isn't supported when ``optimization_mode`` is set to ``minimize_cost``.
891898
892899
Returns:
893-
bigframes.series.Series: A new series of strings.
900+
bigframes.series.Series: A new series of strings (or a series of arrays of strings if ``output_mode`` is specified).
894901
"""
895902

896903
prompt_context, series_list = _separate_context_and_series(input)
897904
assert len(series_list) > 0
898905

899-
example_tuples = tuple(examples) if examples is not None else None
906+
if examples is not None:
907+
example_tuples: Any = tuple(
908+
(ex[0], tuple(ex[1]) if isinstance(ex[1], (list, tuple)) else ex[1])
909+
for ex in examples
910+
)
911+
else:
912+
example_tuples = None
900913

901914
operator = ai_ops.AIClassify(
902915
prompt_context=tuple(prompt_context),
903916
categories=tuple(categories),
904917
examples=example_tuples,
905918
connection_id=connection_id,
906919
endpoint=endpoint,
920+
output_mode=output_mode,
907921
optimization_mode=_upper_optional(optimization_mode),
908922
max_error_ratio=max_error_ratio,
909923
)

packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import functools
1818
import typing
19-
from typing import cast
19+
from typing import cast, Any
2020

2121
import bigframes_vendored.ibis.expr.api as ibis_api
2222
import bigframes_vendored.ibis.expr.datatypes as ibis_dtypes
@@ -1999,6 +1999,7 @@ def ai_classify(
19991999
_construct_examples(op.examples), # type: ignore
20002000
op.connection_id, # type: ignore
20012001
op.endpoint, # type: ignore
2002+
op.output_mode, # type: ignore
20022003
op.optimization_mode, # type: ignore
20032004
op.max_error_ratio, # type: ignore
20042005
).to_expr()
@@ -2045,20 +2046,19 @@ def _construct_prompt(
20452046

20462047

20472048
def _construct_examples(
2048-
examples: tuple[tuple[str, str]] | None,
2049+
examples: tuple[tuple[str, str | tuple[str, ...]], ...] | None,
20492050
) -> ibis_types.ArrayValue | None:
20502051
if examples is None:
20512052
return None
20522053

20532054
results: list[ibis_types.StructValue] = []
20542055

20552056
for example in examples:
2056-
ibis_example = ibis.struct(
2057-
{
2058-
"_field_1": example[0],
2059-
"_field_2": example[1],
2060-
}
2061-
)
2057+
value: Any = example[1]
2058+
if isinstance(example[1], (list, tuple)):
2059+
value = list(example[1])
2060+
2061+
ibis_example = ibis.struct({"_field_1": example[0], "_field_2": value})
20622062
results.append(ibis_example)
20632063

20642064
return ibis.array(results)

packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,17 @@ def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
140140
)
141141
)
142142
elif field == "examples":
143-
example_expressions = [
144-
sge.Tuple(
145-
expressions=[sge.Literal.string(key), sge.Literal.string(val)]
143+
example_expressions = []
144+
for key, val in value:
145+
if isinstance(val, (list, tuple)):
146+
val_expr: sge.Array | sge.Literal = sge.array(
147+
*[sge.Literal.string(v) for v in val]
148+
)
149+
else:
150+
val_expr = sge.Literal.string(val)
151+
example_expressions.append(
152+
sge.Tuple(expressions=[sge.Literal.string(key), val_expr])
146153
)
147-
for key, val in value
148-
]
149154
args.append(
150155
sge.Kwarg(this=field, expression=sge.array(*example_expressions))
151156
)

packages/bigframes/bigframes/operations/ai_ops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,18 @@ class AIClassify(base_ops.NaryOp):
160160

161161
prompt_context: Tuple[str | None, ...]
162162
categories: tuple[str, ...]
163-
examples: tuple[tuple[str, str], ...] | None = None
163+
examples: (
164+
tuple[tuple[str, str], ...] | tuple[tuple[str, tuple[str, ...]], ...] | None
165+
) = None
164166
connection_id: str | None = None
165167
endpoint: str | None = None
168+
output_mode: str | None = None
166169
optimization_mode: str | None = None
167170
max_error_ratio: float | None = None
168171

169172
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
173+
if self.output_mode is not None:
174+
return dtypes.list_type(dtypes.STRING_DTYPE)
170175
return dtypes.STRING_DTYPE
171176

172177

packages/bigframes/tests/system/small/bigquery/test_ai.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,17 @@ def test_ai_classify_with_examples(session):
336336
assert result.dtype == dtypes.STRING_DTYPE
337337

338338

339+
def test_ai_classify_output_mode(session, bq_connection):
340+
s = bpd.Series(["cat", "orchid"], session=session)
341+
342+
result = bbq.ai.classify(
343+
s, ["animal", "plant"], output_mode="multi", examples=[("dog", ["animal"])]
344+
)
345+
346+
assert len(result) == len(s)
347+
assert result.dtype == dtypes.list_type(dtypes.STRING_DTYPE)
348+
349+
339350
def test_ai_classify_multi_model(session, bq_connection):
340351
df = session.from_glob_path(
341352
"gs://bigframes-dev-testing/a_multimodal/images/*",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
SELECT
2+
AI.CLASSIFY(
3+
input => (`string_col`),
4+
categories => ['greeting', 'rejection'],
5+
examples => [('hi', ['greeting', 'positive']), ('bye', ['rejection', 'negative'])],
6+
output_mode => 'multi'
7+
) AS `result`
8+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
SELECT
2+
AI.CLASSIFY(
3+
input => (`string_col`),
4+
categories => ['greeting', 'rejection'],
5+
output_mode => 'multi'
6+
) AS `result`
7+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,41 @@ def test_ai_classify_with_params(scalar_types_df: dataframe.DataFrame, snapshot)
363363
snapshot.assert_match(sql, "out.sql")
364364

365365

366+
def test_ai_classify_with_output_mode(scalar_types_df: dataframe.DataFrame, snapshot):
367+
col_name = "string_col"
368+
369+
op = ops.AIClassify(
370+
prompt_context=(None,),
371+
categories=("greeting", "rejection"),
372+
output_mode="multi",
373+
)
374+
375+
sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"])
376+
377+
snapshot.assert_match(sql, "out.sql")
378+
379+
380+
def test_ai_classify_multi_with_list_examples(
381+
scalar_types_df: dataframe.DataFrame, snapshot
382+
):
383+
col_name = "string_col"
384+
385+
examples = (
386+
("hi", ("greeting", "positive")),
387+
("bye", ("rejection", "negative")),
388+
)
389+
op = ops.AIClassify(
390+
prompt_context=(None,),
391+
categories=("greeting", "rejection"),
392+
examples=examples,
393+
output_mode="multi",
394+
)
395+
396+
sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"])
397+
398+
snapshot.assert_match(sql, "out.sql")
399+
400+
366401
@pytest.mark.parametrize("connection_id", [None, CONNECTION_ID])
367402
def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot, connection_id):
368403
col_name = "string_col"

packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,16 @@ class AIClassify(Value):
158158
examples: Optional[Value]
159159
connection_id: Optional[Value[dt.String]]
160160
endpoint: Optional[Value[dt.String]]
161+
output_mode: Optional[Value[dt.String]]
161162
optimization_mode: Optional[Value[dt.String]]
162163
max_error_ratio: Optional[Value[dt.Float64]]
163164

164165
shape = rlz.shape_like("input")
165166

166167
@attribute
167168
def dtype(self) -> dt.DataType:
169+
if self.output_mode is not None:
170+
return dt.Array(dt.string)
168171
return dt.string
169172

170173

0 commit comments

Comments
 (0)