Skip to content

Commit

Permalink
fix: GenAI - Improved from_dict methods for content types (`Generat…
Browse files Browse the repository at this point in the history
…ionResponse`, `Candidate`, `Content`, `Part`)

Workaround for issue in the proto-plus library: googleapis/proto-plus-python#424

Fixes #3194

PiperOrigin-RevId: 615334182
  • Loading branch information
Ark-kun authored and Copybara-Service committed Mar 13, 2024
1 parent b30f5a6 commit 613ce69
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
32 changes: 29 additions & 3 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,29 @@ def test_chat_function_calling(self, generative_models: generative_models):
[generative_models, preview_generative_models],
)
def test_conversion_methods(self, generative_models: generative_models):
"""Tests the .to_dict, .from_dict and __repr__ methods"""
model = generative_models.GenerativeModel("gemini-pro")
response = model.generate_content("Why is sky blue?")
"""Tests the .to_dict, .from_dict and __repr__ methods."""
# Testing on a full chat conversation which includes function calling
get_current_weather_func = generative_models.FunctionDeclaration(
name="get_current_weather",
description="Get the current weather in a given location",
parameters=_REQUEST_FUNCTION_PARAMETER_SCHEMA_STRUCT,
)
weather_tool = generative_models.Tool(
function_declarations=[get_current_weather_func],
)

model = generative_models.GenerativeModel("gemini-pro", tools=[weather_tool])
chat = model.start_chat()
response = chat.send_message("What is the weather like in Boston?")
chat.send_message(
generative_models.Part.from_function_response(
name="get_current_weather",
response={
"location": "Boston",
"weather": "super nice",
},
),
)

response_new = generative_models.GenerationResponse.from_dict(
response.to_dict()
Expand All @@ -400,6 +420,12 @@ def test_conversion_methods(self, generative_models: generative_models):
part_new = generative_models.Part.from_dict(part.to_dict())
assert repr(part_new) == repr(part)

# Checking the history which contains different Part types
for content in chat.history:
for part in content.parts:
part_new = generative_models.Part.from_dict(part.to_dict())
assert repr(part_new) == repr(part)

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
Expand Down
15 changes: 9 additions & 6 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from vertexai.language_models import (
_language_models as tunable_models,
)
from google.protobuf import json_format
import warnings

try:
Expand Down Expand Up @@ -1377,9 +1378,8 @@ def _from_gapic(

@classmethod
def from_dict(cls, response_dict: Dict[str, Any]) -> "GenerationResponse":
raw_response = gapic_prediction_service_types.GenerateContentResponse(
response_dict
)
raw_response = gapic_prediction_service_types.GenerateContentResponse()
json_format.ParseDict(response_dict, raw_response._pb)
return cls._from_gapic(raw_response=raw_response)

def to_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -1418,7 +1418,8 @@ def _from_gapic(cls, raw_candidate: gapic_content_types.Candidate) -> "Candidate

@classmethod
def from_dict(cls, candidate_dict: Dict[str, Any]) -> "Candidate":
raw_candidate = gapic_content_types.Candidate(candidate_dict)
raw_candidate = gapic_content_types.Candidate()
json_format.ParseDict(candidate_dict, raw_candidate._pb)
return cls._from_gapic(raw_candidate=raw_candidate)

def to_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -1497,7 +1498,8 @@ def _from_gapic(cls, raw_content: gapic_content_types.Content) -> "Content":

@classmethod
def from_dict(cls, content_dict: Dict[str, Any]) -> "Content":
raw_content = gapic_content_types.Content(content_dict)
raw_content = gapic_content_types.Content()
json_format.ParseDict(content_dict, raw_content._pb)
return cls._from_gapic(raw_content=raw_content)

def to_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -1563,7 +1565,8 @@ def _from_gapic(cls, raw_part: gapic_content_types.Part) -> "Part":

@classmethod
def from_dict(cls, part_dict: Dict[str, Any]) -> "Part":
raw_part = gapic_content_types.Part(part_dict)
raw_part = gapic_content_types.Part()
json_format.ParseDict(part_dict, raw_part._pb)
return cls._from_gapic(raw_part=raw_part)

def __repr__(self):
Expand Down

0 comments on commit 613ce69

Please sign in to comment.