Skip to content

Commit

Permalink
feat: GenAI - Added to_dict() methods to response and content classes
Browse files Browse the repository at this point in the history
Also fixed couple of existing methods that were broken.

PiperOrigin-RevId: 607215896
  • Loading branch information
Ark-kun authored and Copybara-Service committed Feb 15, 2024
1 parent 94f7cd9 commit a78748e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
31 changes: 31 additions & 0 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,37 @@ def test_chat_function_calling(self, generative_models: generative_models):
)
assert response2.text == "The weather in Boston is super nice!"

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
)
def test_conversion_methods(self, generative_models: generative_models):
"""Tests the .to_dict, .from_dict and __repr__ methods"""
model = generative_models.GenerativeModel("gemini-pro")
response = model.generate_content("Why is sky blue?")

response_new = generative_models.GenerationResponse.from_dict(
response.to_dict()
)
assert repr(response_new) == repr(response)

for candidate in response.candidates:
candidate_new = generative_models.Candidate.from_dict(candidate.to_dict())
assert repr(candidate_new) == repr(candidate)

content = candidate.content
content_new = generative_models.Content.from_dict(content.to_dict())
assert repr(content_new) == repr(content)

for part in content.parts:
part_new = generative_models.Part.from_dict(part.to_dict())
assert repr(part_new) == repr(part)

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
Expand Down
19 changes: 17 additions & 2 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,8 +1137,11 @@ def from_dict(cls, generation_config_dict: Dict[str, Any]) -> "GenerationConfig"
)
return cls._from_gapic(raw_generation_config=raw_generation_config)

def to_dict(self) -> Dict[str, Any]:
return type(self._raw_generation_config).to_dict(self._raw_generation_config)

def __repr__(self):
return self._raw_tool.__repr__()
return self._raw_generation_config.__repr__()


class Tool:
Expand Down Expand Up @@ -1249,6 +1252,9 @@ def from_dict(cls, tool_dict: Dict[str, Any]) -> "Tool":
raw_tool = gapic_tool_types.Tool(tool_dict)
return cls._from_gapic(raw_tool=raw_tool)

def to_dict(self) -> Dict[str, Any]:
return type(self._raw_tool).to_dict(self._raw_tool)

def __repr__(self):
return self._raw_tool.__repr__()

Expand Down Expand Up @@ -1378,6 +1384,9 @@ def from_dict(cls, response_dict: Dict[str, Any]) -> "GenerationResponse":
)
return cls._from_gapic(raw_response=raw_response)

def to_dict(self) -> Dict[str, Any]:
return type(self._raw_response).to_dict(self._raw_response)

def __repr__(self):
return self._raw_response.__repr__()

Expand Down Expand Up @@ -1414,6 +1423,9 @@ def from_dict(cls, candidate_dict: Dict[str, Any]) -> "Candidate":
raw_candidate = gapic_content_types.Candidate(candidate_dict)
return cls._from_gapic(raw_candidate=raw_candidate)

def to_dict(self) -> Dict[str, Any]:
return type(self._raw_candidate).to_dict(self._raw_candidate)

def __repr__(self):
return self._raw_candidate.__repr__()

Expand Down Expand Up @@ -1480,6 +1492,9 @@ def from_dict(cls, content_dict: Dict[str, Any]) -> "Content":
raw_content = gapic_content_types.Content(content_dict)
return cls._from_gapic(raw_content=raw_content)

def to_dict(self) -> Dict[str, Any]:
return type(self._raw_content).to_dict(self._raw_content)

def __repr__(self):
return self._raw_content.__repr__()

Expand Down Expand Up @@ -1584,7 +1599,7 @@ def from_function_response(name: str, response: Dict[str, Any]) -> "Part":
)

def to_dict(self) -> Dict[str, Any]:
return self._raw_part.to_dict()
return type(self._raw_part).to_dict(self._raw_part)

@property
def text(self) -> str:
Expand Down

0 comments on commit a78748e

Please sign in to comment.