diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py index 0c5eba9496..5c001d4caf 100644 --- a/bigframes/bigquery/_operations/ai.py +++ b/bigframes/bigquery/_operations/ai.py @@ -28,6 +28,7 @@ from bigframes.operations import ai_ops, output_schemas PROMPT_TYPE = Union[ + str, series.Series, pd.Series, List[Union[str, series.Series, pd.Series]], @@ -73,7 +74,7 @@ def generate( dtype: struct>, status: string>[pyarrow] Args: - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series or pandas Series. connection_id (str, optional): @@ -165,7 +166,7 @@ def generate_bool( Name: result, dtype: boolean Args: - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series or pandas Series. connection_id (str, optional): @@ -240,7 +241,7 @@ def generate_int( Name: result, dtype: Int64 Args: - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series or pandas Series. connection_id (str, optional): @@ -315,7 +316,7 @@ def generate_double( Name: result, dtype: Float64 Args: - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series or pandas Series. connection_id (str, optional): @@ -386,7 +387,7 @@ def if_( dtype: string Args: - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series or pandas Series. connection_id (str, optional): @@ -433,7 +434,7 @@ def classify( [2 rows x 2 columns] Args: - input (Series | List[str|Series] | Tuple[str|Series, ...]): + input (str | Series | List[str|Series] | Tuple[str|Series, ...]): A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series or pandas Series. categories (tuple[str, ...] | list[str]): @@ -482,7 +483,7 @@ def score( dtype: Float64 Args: - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series or pandas Series. connection_id (str, optional): @@ -514,9 +515,12 @@ def _separate_context_and_series( Input: ("str1", series1, "str2", "str3", series2) Output: ["str1", None, "str2", "str3", None], [series1, series2] """ - if not isinstance(prompt, (list, tuple, series.Series)): + if not isinstance(prompt, (str, list, tuple, series.Series)): raise ValueError(f"Unsupported prompt type: {type(prompt)}") + if isinstance(prompt, str): + return [None], [series.Series([prompt])] + if isinstance(prompt, series.Series): if prompt.dtype == dtypes.OBJ_REF_DTYPE: # Multi-model support diff --git a/tests/system/small/bigquery/test_ai.py b/tests/system/small/bigquery/test_ai.py index 2ccdb01944..203de616ee 100644 --- a/tests/system/small/bigquery/test_ai.py +++ b/tests/system/small/bigquery/test_ai.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock + from packaging import version import pandas as pd import pyarrow as pa @@ -42,6 +44,27 @@ def test_ai_function_pandas_input(session): ) +def test_ai_function_string_input(session): + with mock.patch( + "bigframes.core.global_session.get_global_session" + ) as mock_get_session: + mock_get_session.return_value = session + prompt = "Is apple a fruit?" + + result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash") + + assert _contains_no_nulls(result) + assert result.dtype == pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.bool_()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), + pa.field("status", pa.string()), + ) + ) + ) + + def test_ai_function_compile_model_params(session): if version.Version(sqlglot.__version__) < version.Version("25.18.0"): pytest.skip(