Add request options to chat. (#341)
* Add request options to chat

Change-Id: I6f7e4c980fd7e2a14fec4c3e2d837ad745c69c9a

* fix async

Change-Id: Ia224e9e8327443a9920ce5d9a877ebb8c272e583

* fix

Change-Id: I7eed70131346c7d7ffe435c8f6909f7eb3f7e9f7

* merge from main

Change-Id: I4b92a5bc25aa7bf11bfaf31aa6c029096f3e68bc

* add tests

Change-Id: I368315f220413ba9508012721e64093372555590

* format

Change-Id: I26c7fa1f040e7d1ea16068034d78fb9f6cc13db0
MarkDaoust committed May 22, 2024
1 parent 386994a commit 0dca4ce
Showing 2 changed files with 106 additions and 45 deletions.
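Taken together, the change lets callers pass per-request options (for example a timeout) through ChatSession.send_message and send_message_async, the same way GenerativeModel.generate_content and count_tokens already accept them. A minimal usage sketch, assuming a configured API key and the gemini-pro model name used in the tests below; a plain dict such as the {"timeout": 120} used in test_count_tokens_called_with_request_options should also be accepted:

    import google.generativeai as genai
    from google.generativeai.types import helper_types

    genai.configure(api_key="...")  # assumes an API key is available

    model = genai.GenerativeModel("gemini-pro")
    chat = model.start_chat()

    # Forward a 120-second timeout to the underlying client call for this message.
    response = chat.send_message(
        "hello",
        request_options=helper_types.RequestOptions(timeout=120),
    )
    print(response.text)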
34 changes: 32 additions & 2 deletions google/generativeai/generative_models.py
@@ -444,6 +444,7 @@ def send_message(
stream: bool = False,
tools: content_types.FunctionLibraryType | None = None,
tool_config: content_types.ToolConfigType | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> generation_types.GenerateContentResponse:
"""Sends the conversation history with the added message and returns the model's response.
@@ -476,6 +477,9 @@ def send_message(
safety_settings: Overrides for the model's safety settings.
stream: If True, yield response chunks as they are generated.
"""
if request_options is None:
request_options = {}

if self.enable_automatic_function_calling and stream:
raise NotImplementedError(
"Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`."
@@ -504,6 +508,7 @@ def send_message(
stream=stream,
tools=tools_lib,
tool_config=tool_config,
request_options=request_options,
)

self._check_response(response=response, stream=stream)
@@ -516,6 +521,7 @@ def send_message(
safety_settings=safety_settings,
stream=stream,
tools_lib=tools_lib,
request_options=request_options,
)

self._last_sent = content
@@ -546,7 +552,15 @@ def _get_function_calls(self, response) -> list[glm.FunctionCall]:
return function_calls

def _handle_afc(
-    self, *, response, history, generation_config, safety_settings, stream, tools_lib
+    self,
+    *,
+    response,
+    history,
+    generation_config,
+    safety_settings,
+    stream,
+    tools_lib,
+    request_options,
) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]:

while function_calls := self._get_function_calls(response):
@@ -572,6 +586,7 @@ def _handle_afc(
safety_settings=safety_settings,
stream=stream,
tools=tools_lib,
request_options=request_options,
)

self._check_response(response=response, stream=stream)
@@ -588,8 +603,12 @@ async def send_message_async(
stream: bool = False,
tools: content_types.FunctionLibraryType | None = None,
tool_config: content_types.ToolConfigType | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> generation_types.AsyncGenerateContentResponse:
"""The async version of `ChatSession.send_message`."""
if request_options is None:
request_options = {}

if self.enable_automatic_function_calling and stream:
raise NotImplementedError(
"Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`."
@@ -618,6 +637,7 @@ async def send_message_async(
stream=stream,
tools=tools_lib,
tool_config=tool_config,
request_options=request_options,
)

self._check_response(response=response, stream=stream)
@@ -630,6 +650,7 @@ async def send_message_async(
safety_settings=safety_settings,
stream=stream,
tools_lib=tools_lib,
request_options=request_options,
)

self._last_sent = content
@@ -638,7 +659,15 @@ async def send_message_async(
return response

async def _handle_afc_async(
-    self, *, response, history, generation_config, safety_settings, stream, tools_lib
+    self,
+    *,
+    response,
+    history,
+    generation_config,
+    safety_settings,
+    stream,
+    tools_lib,
+    request_options,
) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]:

while function_calls := self._get_function_calls(response):
@@ -664,6 +693,7 @@ async def _handle_afc_async(
safety_settings=safety_settings,
stream=stream,
tools=tools_lib,
request_options=request_options,
)

self._check_response(response=response, stream=stream)
117 changes: 74 additions & 43 deletions tests/test_generative_models.py
@@ -12,6 +12,8 @@
from google.generativeai import generative_models
from google.generativeai.types import content_types
from google.generativeai.types import generation_types
from google.generativeai.types import helper_types


import PIL.Image

@@ -37,49 +39,63 @@ def simple_response(text: str) -> glm.GenerateContentResponse:
return glm.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]})


class MockGenerativeServiceClient:
def __init__(self, test):
self.test = test
self.observed_requests = []
self.observed_kwargs = []
self.responses = collections.defaultdict(list)

def generate_content(
self,
request: glm.GenerateContentRequest,
**kwargs,
) -> glm.GenerateContentResponse:
self.test.assertIsInstance(request, glm.GenerateContentRequest)
self.observed_requests.append(request)
self.observed_kwargs.append(kwargs)
response = self.responses["generate_content"].pop(0)
return response

def stream_generate_content(
self,
request: glm.GetModelRequest,
**kwargs,
) -> Iterable[glm.GenerateContentResponse]:
self.observed_requests.append(request)
self.observed_kwargs.append(kwargs)
response = self.responses["stream_generate_content"].pop(0)
return response

def count_tokens(
self,
request: glm.CountTokensRequest,
**kwargs,
) -> Iterable[glm.GenerateContentResponse]:
self.observed_requests.append(request)
self.observed_kwargs.append(kwargs)
response = self.responses["count_tokens"].pop(0)
return response


class CUJTests(parameterized.TestCase):
"""Tests are in order with the design doc."""

-    def setUp(self):
-        self.client = unittest.mock.MagicMock()
-
-        client_lib._client_manager.clients["generative"] = self.client
-
-        def add_client_method(f):
-            name = f.__name__
-            setattr(self.client, name, f)
-            return f
-
-        self.observed_requests = []
-        self.responses = collections.defaultdict(list)
-
-        @add_client_method
-        def generate_content(
-            request: glm.GenerateContentRequest,
-            **kwargs,
-        ) -> glm.GenerateContentResponse:
-            self.assertIsInstance(request, glm.GenerateContentRequest)
-            self.observed_requests.append(request)
-            response = self.responses["generate_content"].pop(0)
-            return response
-
-        @add_client_method
-        def stream_generate_content(
-            request: glm.GetModelRequest,
-            **kwargs,
-        ) -> Iterable[glm.GenerateContentResponse]:
-            self.observed_requests.append(request)
-            response = self.responses["stream_generate_content"].pop(0)
-            return response
-
-        @add_client_method
-        def count_tokens(
-            request: glm.CountTokensRequest,
-            **kwargs,
-        ) -> Iterable[glm.GenerateContentResponse]:
-            self.observed_requests.append(request)
-            response = self.responses["count_tokens"].pop(0)
-            return response
+    @property
+    def observed_requests(self):
+        return self.client.observed_requests
+
+    @property
+    def observed_kwargs(self):
+        return self.client.observed_kwargs
+
+    @property
+    def responses(self):
+        return self.client.responses
+
+    def setUp(self):
+        self.client = MockGenerativeServiceClient(self)
+        client_lib._client_manager.clients["generative"] = self.client

def test_hello(self):
# Generate text from text prompt
@@ -451,7 +467,7 @@ def test_copy_history(self):
chat1 = model.start_chat()
chat1.send_message("hello1")

-        chat2 = copy.deepcopy(chat1)
+        chat2 = copy.copy(chat1)
chat2.send_message("hello2")

chat1.send_message("hello3")
@@ -810,7 +826,7 @@ def test_async_code_match(self, obj, aobj):
)

asource = re.sub(" *?# type: ignore", "", asource)
-        self.assertEqual(source, asource)
+        self.assertEqual(source, asource, f"error in {obj=}")

def test_repr_for_unary_non_streamed_response(self):
model = generative_models.GenerativeModel(model_name="gemini-pro")
@@ -1208,15 +1224,30 @@ def test_repr_for_system_instruction(self):
self.assertIn("system_instruction='Be excellent.'", result)

def test_count_tokens_called_with_request_options(self):
-        self.client.count_tokens = unittest.mock.MagicMock()
-        request = unittest.mock.ANY
+        self.responses["count_tokens"].append(glm.CountTokensResponse())
        request_options = {"timeout": 120}

-        self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
        model = generative_models.GenerativeModel("gemini-pro-vision")
        model.count_tokens([{"role": "user", "parts": ["hello"]}], request_options=request_options)

-        self.client.count_tokens.assert_called_once_with(request, **request_options)
+        self.assertEqual(request_options, self.observed_kwargs[0])

def test_chat_with_request_options(self):
self.responses["generate_content"].append(
glm.GenerateContentResponse(
{
"candidates": [{"finish_reason": "STOP"}],
}
)
)
request_options = {"timeout": 120}

model = generative_models.GenerativeModel("gemini-pro")
chat = model.start_chat()
chat.send_message("hello", request_options=helper_types.RequestOptions(**request_options))

request_options["retry"] = None
self.assertEqual(request_options, self.observed_kwargs[0])


if __name__ == "__main__":

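As a closing note, the new test_chat_with_request_options implies that helper_types.RequestOptions behaves like a mapping whose entries are forwarded to the client as keyword arguments, with the unset retry field normalized to None (the test adds "retry": None to the expected dict before comparing against observed_kwargs). A rough sketch of that assumption, not asserted verbatim by the test:

    from google.generativeai.types import helper_types

    opts = helper_types.RequestOptions(timeout=120)
    # Assumed behavior, inferred from the test: unpacking the options yields
    # both fields, with the unset retry defaulting to None.
    assert dict(opts) == {"retry": None, "timeout": 120}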