Skip to content

Commit

Permalink
System instruction (#270)
Browse files Browse the repository at this point in the history
* fix arg name: system_instructions to system_instruction

proto def for glm.GenerateContentRequest lists system_instruction as singular

* import TypeDict from typing_extensions

* Update glm dependecy to use 0.6.1 to support files, SI, tool_config

* Handle function_calling_mode when passed as a dict with allowed_func_names

* format

* Scope TypedDict test to package directory

* De-pluralise 'instructions'

* System instructions tests and blacken

* format

---------

Co-authored-by: Mark McDonald <macd@google.com>
Co-authored-by: Elsie L <elsieling95@gmail.com>
  • Loading branch information
3 people committed Apr 3, 2024
1 parent 0778d56 commit ba6b439
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 9 deletions.
3 changes: 2 additions & 1 deletion google/generativeai/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import dataclasses
from collections.abc import Iterable
import itertools
from typing import Any, Iterable, Union, Mapping, Optional, TypedDict
from typing import Any, Iterable, Union, Mapping, Optional
from typing_extensions import TypedDict

import google.ai.generativelanguage as glm

Expand Down
10 changes: 5 additions & 5 deletions google/generativeai/generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
generation_config: generation_types.GenerationConfigType | None = None,
tools: content_types.FunctionLibraryType | None = None,
tool_config: content_types.ToolConfigType | None = None,
system_instructions: content_types.ContentType | None = None,
system_instruction: content_types.ContentType | None = None,
):
if "/" not in model_name:
model_name = "models/" + model_name
Expand All @@ -90,10 +90,10 @@ def __init__(
else:
self._tool_config = content_types.to_tool_config(tool_config)

if system_instructions is None:
self._system_instructions = None
if system_instruction is None:
self._system_instruction = None
else:
self._system_instructions = content_types.to_content(system_instructions)
self._system_instruction = content_types.to_content(system_instruction)

self._client = None
self._async_client = None
Expand Down Expand Up @@ -155,7 +155,7 @@ def _prepare_request(
safety_settings=merged_ss,
tools=tools_lib,
tool_config=tool_config,
system_instructions=self._system_instructions,
system_instruction=self._system_instruction,
)

def _get_tools_lib(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_version():
release_status = "Development Status :: 5 - Production/Stable"

dependencies = [
"google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py.tar.gz",
"google-ai-generativelanguage==0.6.1",
"google-api-core",
"google-api-python-client",
"google-auth>=2.15.0", # 2.15 adds API key auth support
Expand Down
28 changes: 27 additions & 1 deletion tests/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,16 @@
TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes()


def simple_part(text: str) -> glm.Content:
return glm.Content({"parts": [{"text": text}]})


def iter_part(texts: Iterable[str]) -> glm.Content:
return glm.Content({"parts": [{"text": t} for t in texts]})


def simple_response(text: str) -> glm.GenerateContentResponse:
return glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": text}]}}]})
return glm.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]})


class CUJTests(parameterized.TestCase):
Expand Down Expand Up @@ -605,6 +613,24 @@ def test_tools(self):
self.assertLen(obr.tools, 1)
self.assertEqual(type(obr.tools[0]).to_dict(obr.tools[0]), tools)

@parameterized.named_parameters(
["bare_str", "talk like a pirate", simple_part("talk like a pirate")],
[
"part_dict",
{"parts": [{"text": "talk like a pirate"}]},
simple_part("talk like a pirate"),
],
["part_list", ["talk like:", "a pirate"], iter_part(["talk like:", "a pirate"])],
)
def test_system_instruction(self, instruction, expected_instr):
self.responses["generate_content"] = [simple_response("echo echo")]
model = generative_models.GenerativeModel("gemini-pro", system_instruction=instruction)

_ = model.generate_content("test")

[req] = self.observed_requests
self.assertEqual(req.system_instruction, expected_instr)

@parameterized.named_parameters(
["basic", "Hello"],
["list", ["Hello"]],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class TypingExtensionsTests(absltest.TestCase):

def test_no_typing_typed_dict(self):
root = pathlib.Path(__file__).parent.parent
for fpath in root.rglob("*.py"):
for fpath in (root / "google").rglob("*.py"):
source = fpath.read_text()
if match := TYPING_RE.search(source):
raise ValueError(
Expand Down

0 comments on commit ba6b439

Please sign in to comment.