Add request options to chat. #341

Merged · 7 commits · May 22, 2024
34 changes: 32 additions & 2 deletions google/generativeai/generative_models.py
@@ -441,6 +441,7 @@ def send_message(
stream: bool = False,
tools: content_types.FunctionLibraryType | None = None,
tool_config: content_types.ToolConfigType | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> generation_types.GenerateContentResponse:
"""Sends the conversation history with the added message and returns the model's response.

@@ -473,6 +474,9 @@ def send_message(
safety_settings: Overrides for the model's safety settings.
stream: If True, yield response chunks as they are generated.
"""
if request_options is None:
request_options = {}

if self.enable_automatic_function_calling and stream:
raise NotImplementedError(
"The `google.generativeai` SDK does not yet support `stream=True` with "
@@ -500,6 +504,7 @@ def send_message(
stream=stream,
tools=tools_lib,
tool_config=tool_config,
request_options=request_options,
)

self._check_response(response=response, stream=stream)
@@ -512,6 +517,7 @@ def send_message(
safety_settings=safety_settings,
stream=stream,
tools_lib=tools_lib,
request_options=request_options,
)

self._last_sent = content
@@ -542,7 +548,15 @@ def _get_function_calls(self, response) -> list[glm.FunctionCall]:
return function_calls

def _handle_afc(
self, *, response, history, generation_config, safety_settings, stream, tools_lib
self,
*,
response,
history,
generation_config,
safety_settings,
stream,
tools_lib,
request_options,
) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]:

while function_calls := self._get_function_calls(response):
@@ -568,6 +582,7 @@
safety_settings=safety_settings,
stream=stream,
tools=tools_lib,
request_options=request_options,
)

self._check_response(response=response, stream=stream)
@@ -584,8 +599,12 @@ async def send_message_async(
stream: bool = False,
tools: content_types.FunctionLibraryType | None = None,
tool_config: content_types.ToolConfigType | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> generation_types.AsyncGenerateContentResponse:
"""The async version of `ChatSession.send_message`."""
if request_options is None:
request_options = {}

if self.enable_automatic_function_calling and stream:
raise NotImplementedError(
"The `google.generativeai` SDK does not yet support `stream=True` with "
@@ -613,6 +632,7 @@ async def send_message_async(
stream=stream,
tools=tools_lib,
tool_config=tool_config,
request_options=request_options,
)

self._check_response(response=response, stream=stream)
@@ -625,6 +645,7 @@ async def send_message_async(
safety_settings=safety_settings,
stream=stream,
tools_lib=tools_lib,
request_options=request_options,
)

self._last_sent = content
@@ -633,7 +654,15 @@ async def send_message_async(
return response

async def _handle_afc_async(
self, *, response, history, generation_config, safety_settings, stream, tools_lib
self,
*,
response,
history,
generation_config,
safety_settings,
stream,
tools_lib,
request_options,
) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]:

while function_calls := self._get_function_calls(response):
@@ -659,6 +688,7 @@ async def _handle_afc_async(
safety_settings=safety_settings,
stream=stream,
tools=tools_lib,
request_options=request_options,
)

self._check_response(response=response, stream=stream)
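The diff above threads a new request_options argument through ChatSession.send_message, send_message_async, and the automatic-function-calling helpers, so per-call settings such as a timeout reach the underlying client. A minimal usage sketch follows; the model name and timeout value are illustrative, and the top-level google.generativeai import path is assumed from the package's public API rather than from this diff.

import google.generativeai as genai
from google.generativeai.types import helper_types

model = genai.GenerativeModel("gemini-pro")
chat = model.start_chat()

# Plain-dict form: the options are forwarded to the client call as kwargs.
response = chat.send_message("hello", request_options={"timeout": 120})

# Typed form, equivalent to the dict above.
response = chat.send_message(
    "hello",
    request_options=helper_types.RequestOptions(timeout=120),
)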
117 changes: 74 additions & 43 deletions tests/test_generative_models.py
@@ -12,6 +12,8 @@
from google.generativeai import generative_models
from google.generativeai.types import content_types
from google.generativeai.types import generation_types
from google.generativeai.types import helper_types


import PIL.Image

@@ -33,49 +35,63 @@ def simple_response(text: str) -> glm.GenerateContentResponse:
return glm.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]})


class MockGenerativeServiceClient:
def __init__(self, test):
self.test = test
self.observed_requests = []
self.observed_kwargs = []
self.responses = collections.defaultdict(list)

def generate_content(
self,
request: glm.GenerateContentRequest,
**kwargs,
) -> glm.GenerateContentResponse:
self.test.assertIsInstance(request, glm.GenerateContentRequest)
self.observed_requests.append(request)
self.observed_kwargs.append(kwargs)
response = self.responses["generate_content"].pop(0)
return response

def stream_generate_content(
self,
request: glm.GetModelRequest,
**kwargs,
) -> Iterable[glm.GenerateContentResponse]:
self.observed_requests.append(request)
self.observed_kwargs.append(kwargs)
response = self.responses["stream_generate_content"].pop(0)
return response

def count_tokens(
self,
request: glm.CountTokensRequest,
**kwargs,
) -> Iterable[glm.GenerateContentResponse]:
self.observed_requests.append(request)
self.observed_kwargs.append(kwargs)
response = self.responses["count_tokens"].pop(0)
return response


class CUJTests(parameterized.TestCase):
"""Tests are in order with the design doc."""

def setUp(self):
self.client = unittest.mock.MagicMock()
@property
def observed_requests(self):
return self.client.observed_requests

client_lib._client_manager.clients["generative"] = self.client

def add_client_method(f):
name = f.__name__
setattr(self.client, name, f)
return f
@property
def observed_kwargs(self):
return self.client.observed_kwargs

self.observed_requests = []
self.responses = collections.defaultdict(list)
@property
def responses(self):
return self.client.responses

@add_client_method
def generate_content(
request: glm.GenerateContentRequest,
**kwargs,
) -> glm.GenerateContentResponse:
self.assertIsInstance(request, glm.GenerateContentRequest)
self.observed_requests.append(request)
response = self.responses["generate_content"].pop(0)
return response

@add_client_method
def stream_generate_content(
request: glm.GetModelRequest,
**kwargs,
) -> Iterable[glm.GenerateContentResponse]:
self.observed_requests.append(request)
response = self.responses["stream_generate_content"].pop(0)
return response

@add_client_method
def count_tokens(
request: glm.CountTokensRequest,
**kwargs,
) -> Iterable[glm.GenerateContentResponse]:
self.observed_requests.append(request)
response = self.responses["count_tokens"].pop(0)
return response
def setUp(self):
self.client = MockGenerativeServiceClient(self)
client_lib._client_manager.clients["generative"] = self.client

def test_hello(self):
# Generate text from text prompt
Expand Down Expand Up @@ -443,7 +459,7 @@ def test_copy_history(self):
chat1 = model.start_chat()
chat1.send_message("hello1")

chat2 = copy.deepcopy(chat1)
chat2 = copy.copy(chat1)
chat2.send_message("hello2")

chat1.send_message("hello3")
@@ -787,7 +803,7 @@ def test_async_code_match(self, obj, aobj):
)

asource = re.sub(" *?# type: ignore", "", asource)
self.assertEqual(source, asource)
self.assertEqual(source, asource, f"error in {obj=}")

def test_repr_for_unary_non_streamed_response(self):
model = generative_models.GenerativeModel(model_name="gemini-pro")
@@ -1252,15 +1268,30 @@ def test_repr_for_system_instruction(self):
self.assertIn("system_instruction='Be excellent.'", result)

def test_count_tokens_called_with_request_options(self):
self.client.count_tokens = unittest.mock.MagicMock()
request = unittest.mock.ANY
self.responses["count_tokens"].append(glm.CountTokensResponse())
request_options = {"timeout": 120}

self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
model = generative_models.GenerativeModel("gemini-pro-vision")
model.count_tokens([{"role": "user", "parts": ["hello"]}], request_options=request_options)

self.client.count_tokens.assert_called_once_with(request, **request_options)
self.assertEqual(request_options, self.observed_kwargs[0])

def test_chat_with_request_options(self):
self.responses["generate_content"].append(
glm.GenerateContentResponse(
{
"candidates": [{"finish_reason": "STOP"}],
}
)
)
request_options = {"timeout": 120}

model = generative_models.GenerativeModel("gemini-pro")
chat = model.start_chat()
chat.send_message("hello", request_options=helper_types.RequestOptions(**request_options))

request_options["retry"] = None
self.assertEqual(request_options, self.observed_kwargs[0])


if __name__ == "__main__":
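The last test also shows that helper_types.RequestOptions fills in a default retry of None before the kwargs reach the client (the expected kwargs become {"timeout": 120, "retry": None}). A retry policy can therefore be supplied explicitly. The sketch below reuses the chat session from the earlier example and assumes, as is conventional for these clients but not established by this diff, that the retry field accepts a google.api_core.retry.Retry object.

from google.api_core import retry as api_retry
from google.generativeai.types import helper_types

# Assumption: RequestOptions.retry takes a google.api_core Retry policy.
options = helper_types.RequestOptions(
    retry=api_retry.Retry(initial=1.0, maximum=10.0, timeout=60.0),
    timeout=120,
)
response = chat.send_message("hello again", request_options=options)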