@@ -842,9 +842,12 @@ def classify(
842842 input : PROMPT_TYPE ,
843843 categories : tuple [str , ...] | list [str ],
844844 * ,
845- examples : list [tuple [str , str ]] | None = None ,
845+ examples : list [tuple [str , str ]]
846+ | list [tuple [str , list [str ] | tuple [str , ...]]]
847+ | None = None ,
846848 connection_id : str | None = None ,
847849 endpoint : str | None = None ,
850+ output_mode : Literal ["single" , "multi" ] | None = None ,
848851 optimization_mode : Literal ["minimize_cost" , "maximize_quality" ] | None = None ,
849852 max_error_ratio : float | None = None ,
850853) -> series .Series :
@@ -870,17 +873,21 @@ def classify(
870873 or pandas Series.
871874 categories (tuple[str, ...] | list[str]):
872875 Categories to classify the input into.
873- examples (list[tuple[str, str]], optional):
876+ examples (list[tuple[str, str]] | list[tuple[str, list[str] | tuple[str, ...]]] , optional):
874877 An array that contains representative examples of input strings and the output category
875- that you expect. You can provide examples to help the model understand your
876- intended threshold for a condition with nuanced or subjective logic. We recommend providing at most 5 examples.
878+ that you expect. If ``output_mode`` is ``multi``, each example output must be a list or tuple of strings.
879+ You can provide examples to help the model understand your intended threshold for a condition with nuanced
880+ or subjective logic. We recommend providing at most 5 examples.
877881 connection_id (str, optional):
878882 Specifies the connection to use to communicate with the model. For example, ``myproject.us.myconnection``.
879883 If not provided, the query uses your end-user credential.
880884 endpoint (str, optional):
881885 A STRING value that specifies the Vertex AI endpoint to use for the model. You can specify any
882886 generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically
883887 identifies and uses the full endpoint of the model.
888+ output_mode (Literal["single", "multi"], optional):
889+ A STRING value that indicates whether a single input can be classified into multiple categories.
890+ Supported values are ``single`` and ``multi``.
884891 optimization_mode (Literal["minimize_cost", "maximize_quality"], optional):
885892 A STRING value that specifies the optimization strategy to use. Supported values are ``minimize_cost``
886893 and ``maximize_quality``.
@@ -890,20 +897,27 @@ def classify(
890897 This argument isn't supported when ``optimization_mode`` is set to ``minimize_cost``.
891898
892899 Returns:
893- bigframes.series.Series: A new series of strings.
900+ bigframes.series.Series: A new series of strings (or a series of arrays of strings if ``output_mode`` is specified) .
894901 """
895902
896903 prompt_context , series_list = _separate_context_and_series (input )
897904 assert len (series_list ) > 0
898905
899- example_tuples = tuple (examples ) if examples is not None else None
906+ if examples is not None :
907+ example_tuples : Any = tuple (
908+ (ex [0 ], tuple (ex [1 ]) if isinstance (ex [1 ], (list , tuple )) else ex [1 ])
909+ for ex in examples
910+ )
911+ else :
912+ example_tuples = None
900913
901914 operator = ai_ops .AIClassify (
902915 prompt_context = tuple (prompt_context ),
903916 categories = tuple (categories ),
904917 examples = example_tuples ,
905918 connection_id = connection_id ,
906919 endpoint = endpoint ,
920+ output_mode = output_mode ,
907921 optimization_mode = _upper_optional (optimization_mode ),
908922 max_error_ratio = max_error_ratio ,
909923 )
0 commit comments