diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 2dd491751..bbadc76c7 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -17,7 +17,8 @@ import dataclasses from collections.abc import Iterable import itertools -from typing import Any, Iterable, Union, Mapping, Optional, TypedDict +from typing import Any, Iterable, Union, Mapping, Optional +from typing_extensions import TypedDict import google.ai.generativelanguage as glm diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 6717f71a5..da981c860 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -74,7 +74,7 @@ def __init__( generation_config: generation_types.GenerationConfigType | None = None, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, - system_instructions: content_types.ContentType | None = None, + system_instruction: content_types.ContentType | None = None, ): if "/" not in model_name: model_name = "models/" + model_name @@ -90,10 +90,10 @@ def __init__( else: self._tool_config = content_types.to_tool_config(tool_config) - if system_instructions is None: - self._system_instructions = None + if system_instruction is None: + self._system_instruction = None else: - self._system_instructions = content_types.to_content(system_instructions) + self._system_instruction = content_types.to_content(system_instruction) self._client = None self._async_client = None @@ -155,7 +155,7 @@ def _prepare_request( safety_settings=merged_ss, tools=tools_lib, tool_config=tool_config, - system_instructions=self._system_instructions, + system_instruction=self._system_instruction, ) def _get_tools_lib( diff --git a/setup.py b/setup.py index 0575dcd28..46a8b1c71 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def get_version(): release_status = "Development Status :: 5 - Production/Stable" dependencies = [ - "google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py.tar.gz", + "google-ai-generativelanguage==0.6.1", "google-api-core", "google-api-python-client", "google-auth>=2.15.0", # 2.15 adds API key auth support diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 4a63a8767..4d5488421 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -21,8 +21,16 @@ TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes() +def simple_part(text: str) -> glm.Content: + return glm.Content({"parts": [{"text": text}]}) + + +def iter_part(texts: Iterable[str]) -> glm.Content: + return glm.Content({"parts": [{"text": t} for t in texts]}) + + def simple_response(text: str) -> glm.GenerateContentResponse: - return glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": text}]}}]}) + return glm.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]}) class CUJTests(parameterized.TestCase): @@ -605,6 +613,24 @@ def test_tools(self): self.assertLen(obr.tools, 1) self.assertEqual(type(obr.tools[0]).to_dict(obr.tools[0]), tools) + @parameterized.named_parameters( + ["bare_str", "talk like a pirate", simple_part("talk like a pirate")], + [ + "part_dict", + {"parts": [{"text": "talk like a pirate"}]}, + simple_part("talk like a pirate"), + ], + ["part_list", ["talk like:", "a pirate"], iter_part(["talk like:", "a pirate"])], + ) + def test_system_instruction(self, instruction, expected_instr): + self.responses["generate_content"] = [simple_response("echo echo")] + model = generative_models.GenerativeModel("gemini-pro", system_instruction=instruction) + + _ = model.generate_content("test") + + [req] = self.observed_requests + self.assertEqual(req.system_instruction, expected_instr) + @parameterized.named_parameters( ["basic", "Hello"], ["list", ["Hello"]], diff --git a/tests/test_typing_extensions.py b/tests/test_typing_extensions.py index 5187d6fa4..38b189c3c 100644 --- a/tests/test_typing_extensions.py +++ b/tests/test_typing_extensions.py @@ -33,7 +33,7 @@ class TypingExtensionsTests(absltest.TestCase): def test_no_typing_typed_dict(self): root = pathlib.Path(__file__).parent.parent - for fpath in root.rglob("*.py"): + for fpath in (root / "google").rglob("*.py"): source = fpath.read_text() if match := TYPING_RE.search(source): raise ValueError(