Commit 58dee62: switch imports

Change-Id: I2853a88d7acc51a78174c97e30bde8eb24e1d457
MarkDaoust committed May 9, 2024
1 parent e21724e commit 58dee62
Showing 4 changed files with 16 additions and 16 deletions.
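The change itself is mechanical: every reference to the old safety_types module in the PaLM-era discuss and text APIs now goes through palm_safety_types instead. As a minimal sketch of what this means for code that consumes these types (assuming, as the changed lines below do, that palm_safety_types exposes the same names the old module provided, such as BlockedReason and ContentFilterDict; the helper name here is purely illustrative), an import-site update would look like this:

# Sketch only: assumes palm_safety_types re-exports the same names that
# safety_types previously provided (BlockedReason, ContentFilterDict, ...),
# which is what the changed lines below rely on.
from google.generativeai.types import palm_safety_types  # was: safety_types

def is_blocked_for_safety(content_filter: palm_safety_types.ContentFilterDict) -> bool:
    # Filter entries are dicts whose "reason" field is a BlockedReason enum,
    # now referenced under the new module name.
    return content_filter["reason"] == palm_safety_types.BlockedReason.SAFETY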
4 changes: 2 additions & 2 deletions google/generativeai/discuss.py
@@ -27,7 +27,7 @@
 from google.generativeai import string_utils
 from google.generativeai.types import discuss_types
 from google.generativeai.types import model_types
-from google.generativeai.types import safety_types
+from google.generativeai.types import palm_safety_types
 
 
 def _make_message(content: discuss_types.MessageOptions) -> glm.Message:
@@ -521,7 +521,7 @@ def _build_chat_response(
     response = type(response).to_dict(response)
     response.pop("messages")
 
-    response["filters"] = safety_types.convert_filters_to_enums(response["filters"])
+    response["filters"] = palm_safety_types.convert_filters_to_enums(response["filters"])
 
     if response["candidates"]:
         last = response["candidates"][0]
4 changes: 2 additions & 2 deletions google/generativeai/types/discuss_types.py
@@ -22,7 +22,7 @@
 import google.ai.generativelanguage as glm
 from google.generativeai import string_utils
 
-from google.generativeai.types import safety_types
+from google.generativeai.types import palm_safety_types
 from google.generativeai.types import citation_types
 
 
@@ -169,7 +169,7 @@ class ChatResponse(abc.ABC):
     temperature: Optional[float]
     candidate_count: Optional[int]
     candidates: List[MessageDict]
-    filters: List[safety_types.ContentFilterDict]
+    filters: List[palm_safety_types.ContentFilterDict]
     top_p: Optional[float] = None
     top_k: Optional[float] = None
 
8 changes: 4 additions & 4 deletions google/generativeai/types/text_types.py
@@ -21,7 +21,7 @@
 from typing_extensions import TypedDict
 
 from google.generativeai import string_utils
-from google.generativeai.types import safety_types
+from google.generativeai.types import palm_safety_types
 from google.generativeai.types import citation_types
 
 
@@ -42,7 +42,7 @@ class BatchEmbeddingDict(TypedDict):
 
 class TextCompletion(TypedDict, total=False):
     output: str
-    safety_ratings: List[safety_types.SafetyRatingDict | None]
+    safety_ratings: List[palm_safety_types.SafetyRatingDict | None]
     citation_metadata: citation_types.CitationMetadataDict | None
 
 
@@ -63,8 +63,8 @@ class Completion(abc.ABC):
 
     candidates: List[TextCompletion]
     result: str | None
-    filters: List[safety_types.ContentFilterDict | None]
-    safety_feedback: List[safety_types.SafetyFeedbackDict | None]
+    filters: List[palm_safety_types.ContentFilterDict | None]
+    safety_feedback: List[palm_safety_types.SafetyFeedbackDict | None]
 
     def to_dict(self) -> Dict[str, Any]:
         result = {
16 changes: 8 additions & 8 deletions tests/test_discuss.py
@@ -22,7 +22,7 @@
 from google.generativeai import discuss
 from google.generativeai import client
 import google.generativeai as genai
-from google.generativeai.types import safety_types
+from google.generativeai.types import palm_safety_types
 
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -289,32 +289,32 @@ def test_receive_and_reply_with_filters(self):
         self.mock_response = mock_response = glm.GenerateMessageResponse(
             candidates=[glm.Message(content="a", author="1")],
             filters=[
-                glm.ContentFilter(reason=safety_types.BlockedReason.SAFETY, message="unsafe"),
-                glm.ContentFilter(reason=safety_types.BlockedReason.OTHER),
+                glm.ContentFilter(reason=palm_safety_types.BlockedReason.SAFETY, message="unsafe"),
+                glm.ContentFilter(reason=palm_safety_types.BlockedReason.OTHER),
             ],
         )
         response = discuss.chat(messages="do filters work?")
 
         filters = response.filters
         self.assertLen(filters, 2)
-        self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason)
-        self.assertEqual(filters[0]["reason"], safety_types.BlockedReason.SAFETY)
+        self.assertIsInstance(filters[0]["reason"], palm_safety_types.BlockedReason)
+        self.assertEqual(filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY)
         self.assertEqual(filters[0]["message"], "unsafe")
 
         self.mock_response = glm.GenerateMessageResponse(
             candidates=[glm.Message(content="a", author="1")],
             filters=[
-                glm.ContentFilter(reason=safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED)
+                glm.ContentFilter(reason=palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED)
             ],
         )
 
         response = response.reply("Does reply work?")
         filters = response.filters
         self.assertLen(filters, 1)
-        self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason)
+        self.assertIsInstance(filters[0]["reason"], palm_safety_types.BlockedReason)
         self.assertEqual(
             filters[0]["reason"],
-            safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED,
+            palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED,
         )
 
     def test_chat_citations(self):
