Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 74 additions & 23 deletions bigframes/bigquery/_operations/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,20 +348,20 @@ def if_(
provides optimization such that not all rows are evaluated with the LLM.

**Examples:**
>>> import bigframes.pandas as bpd
>>> import bigframes.bigquery as bbq
>>> bpd.options.display.progress_bar = None
>>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
>>> bbq.ai.if_((us_state, " has a city called Springfield"))
0 True
1 True
2 False
dtype: boolean

>>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
0 Massachusetts
1 Illinois
dtype: string
>>> import bigframes.pandas as bpd
>>> import bigframes.bigquery as bbq
>>> bpd.options.display.progress_bar = None
>>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
>>> bbq.ai.if_((us_state, " has a city called Springfield"))
0 True
1 True
2 False
dtype: boolean

>>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
0 Massachusetts
1 Illinois
dtype: string

Args:
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
Expand All @@ -386,6 +386,56 @@ def if_(
return series_list[0]._apply_nary_op(operator, series_list[1:])


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def classify(
    input: PROMPT_TYPE,
    categories: tuple[str, ...] | list[str],
    *,
    connection_id: str | None = None,
) -> series.Series:
    """
    Classifies a given input into one of the specified categories. It always returns the
    provided category that best fits the input.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> bpd.options.display.progress_bar = None
        >>> df = bpd.DataFrame({'creature': ['Cat', 'Salmon']})
        >>> df['type'] = bbq.ai.classify(df['creature'], ['Mammal', 'Fish'])
        >>> df
        creature type
        0 Cat Mammal
        1 Salmon Fish
        <BLANKLINE>
        [2 rows x 2 columns]

    Args:
        input (Series | List[str|Series] | Tuple[str|Series, ...]):
            A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series
            or pandas Series.
        categories (tuple[str, ...] | list[str]):
            Categories to classify the input into.
        connection_id (str, optional):
            Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
            If not provided, the connection from the current session will be used.

    Returns:
        bigframes.series.Series: A new series of strings.

    Raises:
        TypeError: If ``categories`` is a single string rather than a tuple or list of strings.
    """

    # A bare string is itself iterable, so ``tuple("Mammal")`` would silently
    # turn every character into its own category. Reject it up front.
    if isinstance(categories, str):
        raise TypeError(
            "categories must be a tuple or list of strings, not a single string."
        )

    prompt_context, series_list = _separate_context_and_series(input)
    assert len(series_list) > 0

    operator = ai_ops.AIClassify(
        prompt_context=tuple(prompt_context),
        categories=tuple(categories),
        # The first series is passed along with the optional explicit connection.
        connection_id=_resolve_connection_id(series_list[0], connection_id),
    )

    # The first series drives the n-ary op; the remaining series are extra operands.
    return series_list[0]._apply_nary_op(operator, series_list[1:])


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def score(
prompt: PROMPT_TYPE,
Expand All @@ -398,15 +448,16 @@ def score(
rubric with examples in the prompt.

**Examples:**
>>> import bigframes.pandas as bpd
>>> import bigframes.bigquery as bbq
>>> bpd.options.display.progress_bar = None
>>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
>>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP
0 2.0
1 1.0
2 3.0
dtype: Float64

>>> import bigframes.pandas as bpd
>>> import bigframes.bigquery as bbq
>>> bpd.options.display.progress_bar = None
>>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
>>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP
0 2.0
1 1.0
2 3.0
dtype: Float64

Args:
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
Expand Down
12 changes: 12 additions & 0 deletions bigframes/core/compile/ibis_compiler/scalar_op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2039,6 +2039,18 @@ def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue:
).to_expr()


@scalar_op_compiler.register_nary_op(ops.AIClassify, pass_op=True)
def ai_classify(
    *values: ibis_types.Value, op: ops.AIClassify
) -> ibis_types.StructValue:
    """Build the ibis ``AIClassify`` expression for ``op`` over ``values``.

    The prompt is assembled from ``values`` and ``op.prompt_context``; the
    categories and connection id are forwarded from ``op`` unchanged.
    """
    prompt = _construct_prompt(values, op.prompt_context)
    classify_node = ai_ops.AIClassify(
        prompt,  # type: ignore
        op.categories,  # type: ignore
        op.connection_id,  # type: ignore
    )
    return classify_node.to_expr()


@scalar_op_compiler.register_nary_op(ops.AIScore, pass_op=True)
def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.StructValue:

Expand Down
21 changes: 19 additions & 2 deletions bigframes/core/compile/sqlglot/expressions/ai_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,21 @@ def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression:
return sge.func("AI.IF", *args)


@register_nary_op(ops.AIClassify, pass_op=True)
def _(*exprs: TypedExpr, op: ops.AIClassify) -> sge.Expression:
    """Compile an ``AIClassify`` op into an ``AI.CLASSIFY`` function call.

    Arguments are emitted in this order: the prompt under the ``input`` name,
    the categories as an array of string literals, then whatever named
    arguments ``_construct_named_args`` produces for ``op``.
    """
    labels = sge.array(*(sge.Literal.string(label) for label in op.categories))
    categories_kwarg = sge.Kwarg(this="categories", expression=labels)
    input_kwarg = _construct_prompt(exprs, op.prompt_context, param_name="input")

    return sge.func(
        "AI.CLASSIFY",
        input_kwarg,
        categories_kwarg,
        *_construct_named_args(op),
    )


@register_nary_op(ops.AIScore, pass_op=True)
def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:
args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
Expand All @@ -69,7 +84,9 @@ def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:


def _construct_prompt(
exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...]
exprs: tuple[TypedExpr, ...],
prompt_context: tuple[str | None, ...],
param_name: str = "prompt",
) -> sge.Kwarg:
prompt: list[str | sge.Expression] = []
column_ref_idx = 0
Expand All @@ -80,7 +97,7 @@ def _construct_prompt(
else:
prompt.append(sge.Literal.string(elem))

return sge.Kwarg(this="prompt", expression=sge.Tuple(expressions=prompt))
return sge.Kwarg(this=param_name, expression=sge.Tuple(expressions=prompt))


def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
Expand Down
2 changes: 2 additions & 0 deletions bigframes/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from bigframes.operations.ai_ops import (
AIClassify,
AIGenerate,
AIGenerateBool,
AIGenerateDouble,
Expand Down Expand Up @@ -419,6 +420,7 @@
"geo_y_op",
"GeoStDistanceOp",
# AI ops
"AIClassify",
"AIGenerate",
"AIGenerateBool",
"AIGenerateDouble",
Expand Down
12 changes: 12 additions & 0 deletions bigframes/operations/ai_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,18 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
return dtypes.BOOL_DTYPE


@dataclasses.dataclass(frozen=True)
class AIClassify(base_ops.NaryOp):
    """N-ary op for AI classification; its output type is always string."""

    name: ClassVar[str] = "ai_classify"

    # Prompt pieces in order. NOTE(review): presumably a None entry marks a slot
    # filled by an input column and a str entry is a literal -- confirm against
    # the compilers' _construct_prompt helpers.
    prompt_context: Tuple[str | None, ...]
    # Candidate labels the input is classified into.
    categories: tuple[str, ...]
    # Connection identifier forwarded to the compiled AI call.
    connection_id: str

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        """Return the string dtype, regardless of the input types."""
        return dtypes.STRING_DTYPE


@dataclasses.dataclass(frozen=True)
class AIScore(base_ops.NaryOp):
name: ClassVar[str] = "ai_score"
Expand Down
21 changes: 21 additions & 0 deletions tests/system/small/bigquery/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,27 @@ def test_ai_if_multi_model(session):
assert result.dtype == dtypes.BOOL_DTYPE


def test_ai_classify(session):
    """Classify a small string series and check the result is a non-null string series."""
    s = bpd.Series(["cat", "orchid"], session=session)

    # ``repr_mode`` is set on the module-level options object, so restore the
    # previous value afterwards instead of leaking it into later tests.
    previous_repr_mode = bpd.options.display.repr_mode
    bpd.options.display.repr_mode = "deferred"
    try:
        result = bbq.ai.classify(s, ["animal", "plant"])

        assert _contains_no_nulls(result)
        assert result.dtype == dtypes.STRING_DTYPE
    finally:
        bpd.options.display.repr_mode = previous_repr_mode


def test_ai_classify_multi_model(session):
    """Classify a column loaded from a GCS glob path and check the result is a non-null string series."""
    images_df = session.from_glob_path(
        "gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
    )
    image_col = images_df["image"]
    labels = ["photo", "cartoon"]

    result = bbq.ai.classify(image_col, labels)

    assert _contains_no_nulls(result)
    assert result.dtype == dtypes.STRING_DTYPE


def test_ai_score(session):
s = bpd.Series(["Tiger", "Rabbit"], session=session)
prompt = ("Rank the relative weights of ", s, " on the scale from 1 to 3")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
WITH `bfcte_0` AS (
SELECT
`string_col` AS `bfcol_0`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
AI.CLASSIFY(
input => (`bfcol_0`),
categories => ['greeting', 'rejection'],
connection_id => 'bigframes-dev.us.bigframes-default-connection'
) AS `bfcol_1`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `result`
FROM `bfcte_1`
14 changes: 14 additions & 0 deletions tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,20 @@ def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot):
snapshot.assert_match(sql, "out.sql")


def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot):
    """Snapshot-check the SQL compiled for an ``AIClassify`` op over ``string_col``."""
    classify_op = ops.AIClassify(
        prompt_context=(None,),
        categories=("greeting", "rejection"),
        connection_id=CONNECTION_ID,
    )
    classify_expr = classify_op.as_expr("string_col")

    sql = utils._apply_unary_ops(scalar_types_df, [classify_expr], ["result"])

    snapshot.assert_match(sql, "out.sql")


def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot):
col_name = "string_col"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,9 @@ def visit_AIGenerateDouble(self, op, **kwargs):
def visit_AIIf(self, op, **kwargs):
    """Emit an ``AI.IF`` call whose arguments come from ``_compile_ai_args``."""
    compiled_args = self._compile_ai_args(**kwargs)
    return sge.func("AI.IF", *compiled_args)

def visit_AIClassify(self, op, **kwargs):
    """Emit an ``AI.CLASSIFY`` call whose arguments come from ``_compile_ai_args``."""
    compiled_args = self._compile_ai_args(**kwargs)
    return sge.func("AI.CLASSIFY", *compiled_args)

def visit_AIScore(self, op, **kwargs):
    """Emit an ``AI.SCORE`` call whose arguments come from ``_compile_ai_args``."""
    compiled_args = self._compile_ai_args(**kwargs)
    return sge.func("AI.SCORE", *compiled_args)

Expand Down
15 changes: 15 additions & 0 deletions third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,21 @@ def dtype(self) -> dt.Struct:
return dt.bool


@public
class AIClassify(Value):
    """Classify the input into one of the given categories"""

    input: Value
    categories: Value[dt.Array[dt.String]]
    connection_id: Value[dt.String]

    # The result takes its shape from the ``input`` argument.
    shape = rlz.shape_like("input")

    @attribute
    def dtype(self) -> dt.String:
        # The classification result is a string.
        return dt.string


@public
class AIScore(Value):
"""Generate doubles based on the prompt"""
Expand Down