diff --git a/tests/system/vertexai/test_generative_models.py b/tests/system/vertexai/test_generative_models.py index 52bf202eb2..0dfc49eb93 100644 --- a/tests/system/vertexai/test_generative_models.py +++ b/tests/system/vertexai/test_generative_models.py @@ -24,6 +24,7 @@ from google.cloud import aiplatform from tests.system.aiplatform import e2e_base from vertexai import generative_models +from vertexai.preview import generative_models as preview_generative_models class TestGenerativeModels(e2e_base.TestEndToEnd): @@ -134,6 +135,20 @@ def test_generate_content_from_text_and_remote_video(self): assert response.text assert "Zootopia" in response.text + def test_grounding_google_search_retriever(self): + model = preview_generative_models.GenerativeModel("gemini-pro") + google_search_retriever_tool = ( + preview_generative_models.Tool.from_google_search_retrieval( + preview_generative_models.grounding.GoogleSearchRetrieval( + disable_attribution=False + ) + ) + ) + response = model.generate_content( + "Why is sky blue?", tools=[google_search_retriever_tool] + ) + assert response.text + # Chat def test_send_message_from_text(self): diff --git a/tests/unit/vertexai/test_generative_models.py b/tests/unit/vertexai/test_generative_models.py index 37051b5aea..cf6f82d2c5 100644 --- a/tests/unit/vertexai/test_generative_models.py +++ b/tests/unit/vertexai/test_generative_models.py @@ -121,12 +121,27 @@ def mock_generate_content( contents: Optional[MutableSequence[gapic_content_types.Content]] = None, ) -> Iterable[gapic_prediction_service_types.GenerateContentResponse]: is_continued_chat = len(request.contents) > 1 - has_tools = bool(request.tools) + has_retrieval = any( + tool.retrieval or tool.google_search_retrieval for tool in request.tools + ) + has_function_declarations = any( + tool.function_declarations for tool in request.tools + ) + has_function_request = any( + content.parts[0].function_call for content in request.contents + ) + has_function_response = any( + content.parts[0].function_response for content in request.contents + ) - if has_tools: - has_function_response = any( - "function_response" in content.parts[0] for content in request.contents - ) + if has_function_request: + assert has_function_response + + if has_function_response: + assert has_function_request + assert has_function_declarations + + if has_function_declarations: needs_function_call = not has_function_response if needs_function_call: response_part_struct = _RESPONSE_FUNCTION_CALL_PART_STRUCT @@ -158,6 +173,24 @@ def mock_generate_content( gapic_content_types.Citation(_RESPONSE_CITATION_STRUCT), ] ), + grounding_metadata=gapic_content_types.GroundingMetadata( + web_search_queries=[request.contents[0].parts[0].text], + grounding_attributions=[ + gapic_content_types.GroundingAttribution( + segment=gapic_content_types.Segment( + start_index=0, + end_index=67, + ), + confidence_score=0.69857746, + web=gapic_content_types.GroundingAttribution.Web( + uri="https://math.ucr.edu/home/baez/physics/General/BlueSky/blue_sky.html", + title="Why is the sky blue? - UCR Math", + ), + ), + ], + ) + if has_retrieval and request.contents[0].parts[0].text + else None, ), ], ) @@ -288,3 +321,41 @@ def test_chat_function_calling(self, generative_models: generative_models): ), ) assert response2.text == "The weather in Boston is super nice!" + + @mock.patch.object( + target=prediction_service.PredictionServiceClient, + attribute="generate_content", + new=mock_generate_content, + ) + def test_generate_content_grounding_google_search_retriever(self): + model = preview_generative_models.GenerativeModel("gemini-pro") + google_search_retriever_tool = ( + preview_generative_models.Tool.from_google_search_retrieval( + preview_generative_models.grounding.GoogleSearchRetrieval( + disable_attribution=False + ) + ) + ) + response = model.generate_content( + "Why is sky blue?", tools=[google_search_retriever_tool] + ) + assert response.text + + @mock.patch.object( + target=prediction_service.PredictionServiceClient, + attribute="generate_content", + new=mock_generate_content, + ) + def test_generate_content_grounding_vertex_ai_search_retriever(self): + model = preview_generative_models.GenerativeModel("gemini-pro") + google_search_retriever_tool = preview_generative_models.Tool.from_retrieval( + retrieval=preview_generative_models.grounding.Retrieval( + source=preview_generative_models.grounding.VertexAISearch( + datastore=f"projects/{_TEST_PROJECT}/locations/global/collections/default_collection/dataStores/test-datastore", + ) + ) + ) + response = model.generate_content( + "Why is sky blue?", tools=[google_search_retriever_tool] + ) + assert response.text diff --git a/vertexai/generative_models/_generative_models.py b/vertexai/generative_models/_generative_models.py index 779ef355fb..a7643e3cbe 100644 --- a/vertexai/generative_models/_generative_models.py +++ b/vertexai/generative_models/_generative_models.py @@ -1132,6 +1132,40 @@ def __init__( function_declarations=gapic_function_declarations ) + @classmethod + def from_function_declarations( + cls, + function_declarations: List["FunctionDeclaration"], + ): + gapic_function_declarations = [ + function_declaration._raw_function_declaration + for function_declaration in function_declarations + ] + raw_tool = gapic_tool_types.Tool( + function_declarations=gapic_function_declarations + ) + return cls._from_gapic(raw_tool=raw_tool) + + @classmethod + def from_retrieval( + cls, + retrieval: "Retrieval", + ): + raw_tool = gapic_tool_types.Tool( + retrieval=retrieval._raw_retrieval + ) + return cls._from_gapic(raw_tool=raw_tool) + + @classmethod + def from_google_search_retrieval( + cls, + google_search_retrieval: "GoogleSearchRetrieval", + ): + raw_tool = gapic_tool_types.Tool( + google_search_retrieval=google_search_retrieval._raw_google_search_retrieval + ) + return cls._from_gapic(raw_tool=raw_tool) + @classmethod def _from_gapic( cls, @@ -1520,6 +1554,87 @@ def _image(self) -> "Image": return Image.from_bytes(data=self._raw_part.inline_data.data) +class grounding: # pylint: disable=invalid-name + """Grounding namespace.""" + + def __init__(self): + raise RuntimeError("This class must not be instantiated.") + + class Retrieval: + """Defines a retrieval tool that model can call to access external knowledge.""" + + def __init__( + self, + source: Union["grounding.VertexAISearch"], + disable_attribution: Optional[bool] = None, + ): + """Initializes a Retrieval tool. + + Args: + source (VertexAISearch): + Set to use data source powered by Vertex AI Search. + disable_attribution (bool): + Optional. Disable using the result from this + tool in detecting grounding attribution. This + does not affect how the result is given to the + model for generation. + """ + self._raw_retrieval = gapic_tool_types.Retrieval( + vertex_ai_search=source._raw_vertex_ai_search, + disable_attribution=disable_attribution, + ) + + class VertexAISearch: + r"""Retrieve from Vertex AI Search datastore for grounding. + See https://cloud.google.com/vertex-ai-search-and-conversation + """ + + def __init__( + self, + datastore: str, + ): + """Initializes a Vertex AI Search tool. + + Args: + datastore (str): + Required. Fully-qualified Vertex AI Search's + datastore resource ID. + projects/<>/locations/<>/collections/<>/dataStores/<> + """ + self._raw_vertex_ai_search = gapic_tool_types.VertexAISearch( + datastore=datastore, + ) + + class GoogleSearchRetrieval: + r"""Tool to retrieve public web data for grounding, powered by + Google Search. + + Attributes: + disable_attribution (bool): + Optional. Disable using the result from this + tool in detecting grounding attribution. This + does not affect how the result is given to the + model for generation. + """ + + def __init__( + self, + disable_attribution: Optional[bool] = None, + ): + """Initializes a Google Search Retrieval tool. + + Args: + disable_attribution (bool): + Optional. Disable using the result from this + tool in detecting grounding attribution. This + does not affect how the result is given to the + model for generation. + """ + self._raw_google_search_retrieval = gapic_tool_types.GoogleSearchRetrieval( + disable_attribution=disable_attribution, + ) + + def _to_content( value: Union[ gapic_content_types.Content, diff --git a/vertexai/preview/generative_models.py b/vertexai/preview/generative_models.py index 2d08bbf7b3..db6701b93c 100644 --- a/vertexai/preview/generative_models.py +++ b/vertexai/preview/generative_models.py @@ -17,6 +17,7 @@ # We just want to re-export certain classes # pylint: disable=g-multiple-import,g-importing-member from vertexai.generative_models._generative_models import ( + grounding, _PreviewGenerativeModel, GenerationConfig, GenerationResponse, @@ -39,6 +40,7 @@ class GenerativeModel(_PreviewGenerativeModel): __all__ = [ + "grounding", "GenerationConfig", "GenerativeModel", "GenerationResponse",