Commit 768af67

feat: GenAI - Allowed callable functions to return values directly in Automatic Function Calling

PiperOrigin-RevId: 640574734
jaycee-li authored and Copybara-Service committed Jun 5, 2024
1 parent 945b9e4 commit 768af67
Showing 2 changed files with 51 additions and 1 deletion.
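
For context, here is a minimal usage sketch of what this change enables, distilled from the new unit test below; the project/location values and the "gemini-pro" model name are illustrative, and the weather function is the test's own toy example. A callable registered for Automatic Function Calling may now return a plain value (for example a str) instead of a dict; the SDK wraps such a value as {"result": value} before sending the function response back to the model.

    import vertexai
    from vertexai.preview import generative_models

    vertexai.init(project="my-project", location="us-central1")  # placeholder values

    def get_current_weather(location: str) -> str:
        """Gets weather in the specified location."""
        # Returning a bare string is now accepted; previously a dict was required.
        return "Super nice, but maybe a bit hot." if location == "Boston" else "Unavailable"

    weather_tool = generative_models.Tool(
        function_declarations=[
            generative_models.FunctionDeclaration.from_func(get_current_weather)
        ],
    )
    model = generative_models.GenerativeModel("gemini-pro", tools=[weather_tool])
    responder = generative_models.AutomaticFunctionCallingResponder(
        max_automatic_function_calls=5,
    )
    chat = model.start_chat(responder=responder)
    response = chat.send_message("What is the weather like in Boston?")
    print(response.text)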
tests/unit/vertexai/test_generative_models.py (46 additions, 1 deletion)
@@ -950,7 +950,7 @@ def test_generate_content_vertex_rag_retriever(self):
         attribute="generate_content",
         new=mock_generate_content,
     )
-    def test_chat_automatic_function_calling(self):
+    def test_chat_automatic_function_calling_with_function_returning_dict(self):
         generative_models = preview_generative_models
         get_current_weather_func = generative_models.FunctionDeclaration.from_func(
             get_current_weather
@@ -984,6 +984,51 @@ def test_chat_automatic_function_calling(self):
             chat2.send_message("What is the weather like in Boston?")
         assert err.match("Exceeded the maximum")
+
+    @mock.patch.object(
+        target=prediction_service.PredictionServiceClient,
+        attribute="generate_content",
+        new=mock_generate_content,
+    )
+    def test_chat_automatic_function_calling_with_function_returning_value(self):
+        # Define a new function that returns a value instead of a dict.
+        def get_current_weather(location: str):
+            """Gets weather in the specified location.
+
+            Args:
+                location: The location for which to get the weather.
+
+            Returns:
+                The weather information as a str.
+            """
+            if location == "Boston":
+                return "Super nice, but maybe a bit hot."
+            return "Unavailable"
+
+        generative_models = preview_generative_models
+        get_current_weather_func = generative_models.FunctionDeclaration.from_func(
+            get_current_weather
+        )
+        weather_tool = generative_models.Tool(
+            function_declarations=[get_current_weather_func],
+        )
+
+        model = generative_models.GenerativeModel(
+            "gemini-pro",
+            # Specifying the tools once to avoid specifying them in every request
+            tools=[weather_tool],
+        )
+        afc_responder = generative_models.AutomaticFunctionCallingResponder(
+            max_automatic_function_calls=5,
+        )
+        chat = model.start_chat(responder=afc_responder)
+
+        response1 = chat.send_message("What is the weather like in Boston?")
+        assert response1.text.startswith("The weather in Boston is")
+        assert "nice" in response1.text
+        assert len(chat.history) == 4
+        assert chat.history[-3].parts[0].function_call
+        assert chat.history[-2].parts[0].function_response


 EXPECTED_SCHEMA_FOR_GET_CURRENT_WEATHER = {
     "title": "get_current_weather",
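
As a reading aid for the history assertions in the test above: after one send_message that triggers a single automatic function call, the chat history is expected to hold four entries (the shapes below are a sketch, not an exact transcript).

    # chat.history after one automatically handled function call:
    # [0] the user prompt: "What is the weather like in Boston?"
    # [1] the model's function_call part (get_current_weather, location="Boston")
    # [2] the function_response part carrying {"result": "Super nice, but maybe a bit hot."}
    # [3] the model's final text, e.g. "The weather in Boston is ..."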
vertexai/generative_models/_generative_models.py (5 additions, 0 deletions)
@@ -15,6 +15,7 @@
 """Classes for working with generative models."""
 # pylint: disable=bad-continuation, line-too-long, protected-access
 
+from collections.abc import Mapping
 import copy
 import io
 import json
@@ -2422,6 +2423,10 @@ def respond_to_model_response(
                 # due to: AttributeError: type object 'MapComposite' has no attribute 'to_dict'
                 function_args = type(function_call).to_dict(function_call)["args"]
                 function_call_result = callable_function._function(**function_args)
+                if not isinstance(function_call_result, Mapping):
+                    # If the function returns a single value, wrap it in the
+                    # format that Part.from_function_response can accept.
+                    function_call_result = {"result": function_call_result}
             except Exception as ex:
                 raise RuntimeError(
                     f"""Error raised when calling function "{function_call.name}" as requested by the model."""
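
For reference, a rough sketch of what the added wrapping amounts to, assuming the Part.from_function_response helper mentioned in the new comment (the function name and return value here are illustrative):

    from collections.abc import Mapping

    from vertexai.preview.generative_models import Part

    function_call_result = "Super nice, but maybe a bit hot."  # bare value returned by the callable
    if not isinstance(function_call_result, Mapping):
        # Wrap the bare value so it fits the mapping shape that
        # Part.from_function_response expects for its response argument.
        function_call_result = {"result": function_call_result}

    part = Part.from_function_response(
        name="get_current_weather",
        response=function_call_result,
    )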
