feat: GenAI - Added the `GenerativeModel.start_chat(response_validation: bool = True)` parameter

The error messages are now more informative.

The use of the `raise_on_blocked` parameter has been deprecated. Use `response_validation` instead.
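
For example, a minimal usage sketch (the model name and prompts are illustrative):

    from vertexai.generative_models import GenerativeModel, ResponseValidationError

    model = GenerativeModel("gemini-pro")

    # Validation is on by default: a blocked or otherwise incomplete response
    # raises ResponseValidationError and is kept out of the chat history.
    chat = model.start_chat()
    try:
        response = chat.send_message("Why is the sky blue?")
        print(response.text)
    except ResponseValidationError:
        pass  # chat.history is unchanged; the failed exchange was not recorded.

    # Opting out of validation accumulates every exchange in history,
    # including blocked or incomplete ones.
    permissive_chat = model.start_chat(response_validation=False)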

PiperOrigin-RevId: 607188491
Ark-kun authored and Copybara-Service committed Feb 15, 2024
1 parent 0c3e294 commit 94f7cd9
Showing 4 changed files with 145 additions and 36 deletions.
40 changes: 40 additions & 0 deletions tests/unit/vertexai/test_generative_models.py
@@ -120,6 +120,23 @@ def mock_generate_content(
model: Optional[str] = None,
contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
) -> Iterable[gapic_prediction_service_types.GenerateContentResponse]:
last_message_part = request.contents[-1].parts[0]
should_fail = last_message_part.text and "Please fail" in last_message_part.text
if should_fail:
response = gapic_prediction_service_types.GenerateContentResponse(
candidates=[
gapic_content_types.Candidate(
finish_reason=gapic_content_types.Candidate.FinishReason.SAFETY,
finish_message="Failed due to: " + last_message_part.text,
safety_ratings=[
gapic_content_types.SafetyRating(rating)
for rating in _RESPONSE_SAFETY_RATINGS_STRUCT
],
),
],
)
return response

is_continued_chat = len(request.contents) > 1
has_retrieval = any(
tool.retrieval or tool.google_search_retrieval for tool in request.tools
@@ -281,6 +298,29 @@ def test_chat_send_message(self, generative_models: generative_models):
response2 = chat.send_message("Is sky blue on other planets?")
assert response2.text

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
)
def test_chat_send_message_response_validation_errors(
self, generative_models: generative_models
):
model = generative_models.GenerativeModel("gemini-pro")
chat = model.start_chat()
response1 = chat.send_message("Why is sky blue?")
assert response1.text
assert len(chat.history) == 2

with pytest.raises(generative_models.ResponseValidationError):
chat.send_message("Please fail!")
# Checking that history did not get updated
assert len(chat.history) == 2

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
4 changes: 2 additions & 2 deletions vertexai/generative_models/__init__.py
@@ -29,7 +29,7 @@
HarmBlockThreshold,
Image,
Part,
ResponseBlockedError,
ResponseValidationError,
Tool,
)

@@ -46,6 +46,6 @@
"HarmBlockThreshold",
"Image",
"Part",
"ResponseBlockedError",
"ResponseValidationError",
"Tool",
]
129 changes: 96 additions & 33 deletions vertexai/generative_models/_generative_models.py
@@ -45,6 +45,7 @@
from vertexai.language_models import (
_language_models as tunable_models,
)
import warnings

try:
from PIL import Image as PIL_Image # pylint: disable=g-import-not-at-top
@@ -606,18 +607,28 @@ def start_chat(
self,
*,
history: Optional[List["Content"]] = None,
response_validation: bool = True,
) -> "ChatSession":
"""Creates a stateful chat session.
Args:
history: Previous history to initialize the chat session.
response_validation: Whether to validate responses before adding
them to chat history. By default, `send_message` will raise an
error if the request or response is blocked or if the response
is incomplete due to exceeding the max token limit.
If set to `False`, the chat session history will always
accumulate the request and response messages even if the
response is blocked or incomplete. This can result in an unusable
chat session state.
Returns:
A ChatSession object.
"""
return ChatSession(
model=self,
history=history,
response_validation=response_validation,
)


@@ -628,6 +639,29 @@ def start_chat(
]


def _validate_response(
response: "GenerationResponse",
request_contents: Optional[List["Content"]] = None,
response_chunks: Optional[List["GenerationResponse"]] = None,
) -> None:
candidate = response.candidates[0]
if candidate.finish_reason not in _SUCCESSFUL_FINISH_REASONS:
message = (
"The model response did not completed successfully.\n"
f"Finish reason: {candidate.finish_reason}.\n"
f"Finish message: {candidate.finish_message}.\n"
f"Safety ratings: {candidate.safety_ratings}.\n"
"To protect the integrity of the chat session, the request and response were not added to chat history.\n"
"To skip the response validation, specify `model.start_chat(response_validation=False)`.\n"
"Note that letting blocked or otherwise incomplete responses into chat history might lead to future interactions being blocked by the service."
)
raise ResponseValidationError(
message=message,
request_contents=request_contents,
responses=response_chunks,
)


class ChatSession:
"""Chat session holds the chat history."""

@@ -639,15 +673,15 @@ def __init__(
model: _GenerativeModel,
*,
history: Optional[List["Content"]] = None,
raise_on_blocked: bool = True,
response_validation: bool = True,
):
if history:
if not all(isinstance(item, Content) for item in history):
raise ValueError("history must be a list of Content objects.")

self._model = model
self._history = history or []
self._raise_on_blocked = raise_on_blocked
self._response_validator = _validate_response if response_validation else None

@property
def history(self) -> List["Content"]:
@@ -784,13 +818,12 @@ def _send_message(
tools=tools,
)
# By default we're not adding incomplete interactions to history.
if self._raise_on_blocked:
if response.candidates[0].finish_reason not in _SUCCESSFUL_FINISH_REASONS:
raise ResponseBlockedError(
message="The response was blocked.",
request_contents=request_history,
responses=[response],
)
if self._response_validator is not None:
self._response_validator(
response=response,
request_contents=request_history,
response_chunks=[response],
)

# Adding the request and the first response candidate to history
response_message = response.candidates[0].content
@@ -841,13 +874,13 @@ async def _send_message_async(
tools=tools,
)
# By default we're not adding incomplete interactions to history.
if self._raise_on_blocked:
if response.candidates[0].finish_reason not in _SUCCESSFUL_FINISH_REASONS:
raise ResponseBlockedError(
message="The response was blocked.",
request_contents=request_history,
responses=[response],
)
if self._response_validator is not None:
self._response_validator(
response=response,
request_contents=request_history,
response_chunks=[response],
)

# Adding the request and the first response candidate to history
response_message = response.candidates[0].content
# Response role is NOT set by the model.
@@ -905,13 +938,12 @@ def _send_message_streaming(
else:
full_response = chunk
# By default we're not adding incomplete interactions to history.
if self._raise_on_blocked:
if chunk.candidates[0].finish_reason not in _SUCCESSFUL_FINISH_REASONS:
raise ResponseBlockedError(
message="The response was blocked.",
request_contents=request_history,
responses=chunks,
)
if self._response_validator is not None:
self._response_validator(
response=chunk,
request_contents=request_history,
response_chunks=chunks,
)
yield chunk
if not full_response:
return
@@ -973,16 +1005,13 @@ async def async_generator():
else:
full_response = chunk
# By default we're not adding incomplete interactions to history.
if self._raise_on_blocked:
if (
chunk.candidates[0].finish_reason
not in _SUCCESSFUL_FINISH_REASONS
):
raise ResponseBlockedError(
message="The response was blocked.",
request_contents=request_history,
responses=chunks,
)
if self._response_validator is not None:
self._response_validator(
response=chunk,
request_contents=request_history,
response_chunks=chunks,
)

yield chunk
if not full_response:
return
@@ -996,6 +1025,36 @@ async def async_generator():
return async_generator()


class _PreviewChatSession(ChatSession):
__doc__ = ChatSession.__doc__

# This class preserves backwards compatibility with the `raise_on_blocked` parameter.

def __init__(
    self,
    model: _GenerativeModel,
    *,
    history: Optional[List["Content"]] = None,
    response_validation: Optional[bool] = None,
    # Deprecated
    raise_on_blocked: Optional[bool] = None,
):
    if raise_on_blocked is not None:
        warnings.warn(
            message="Use `response_validation` instead of `raise_on_blocked`."
        )
        if response_validation is not None:
            raise ValueError(
                "Cannot use `response_validation` when `raise_on_blocked` is set."
            )
        response_validation = raise_on_blocked
    if response_validation is None:
        # Neither parameter was passed explicitly; validate responses by default.
        response_validation = True
    super().__init__(
        model=model,
        history=history,
        response_validation=response_validation,
    )


class ResponseBlockedError(Exception):
def __init__(
self,
@@ -1008,6 +1067,10 @@ def __init__(
self.responses = responses


class ResponseValidationError(ResponseBlockedError):
pass


### Structures


8 changes: 7 additions & 1 deletion vertexai/preview/generative_models.py
@@ -19,10 +19,10 @@
from vertexai.generative_models._generative_models import (
grounding,
_PreviewGenerativeModel,
_PreviewChatSession,
GenerationConfig,
GenerationResponse,
Candidate,
ChatSession,
Content,
FinishReason,
FunctionDeclaration,
@@ -31,6 +31,7 @@
Image,
Part,
ResponseBlockedError,
ResponseValidationError,
Tool,
)

@@ -39,6 +40,10 @@ class GenerativeModel(_PreviewGenerativeModel):
__doc__ = _PreviewGenerativeModel.__doc__


class ChatSession(_PreviewChatSession):
__doc__ = _PreviewChatSession.__doc__


__all__ = [
"grounding",
"GenerationConfig",
@@ -54,5 +59,6 @@ class GenerativeModel(_PreviewGenerativeModel):
"Image",
"Part",
"ResponseBlockedError",
"ResponseValidationError",
"Tool",
]
