Skip to content

Commit

Permalink
feat: enable grounding in ChatModel send_message and send_message_async methods
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 579999652
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed Nov 7, 2023
1 parent eaf4420 commit d4667f2
Show file tree
Hide file tree
Showing 3 changed files with 355 additions and 19 deletions.
21 changes: 16 additions & 5 deletions tests/system/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def test_preview_text_embedding_top_level_from_pretrained(self):

def test_chat_on_chat_model(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

chat_model = ChatModel.from_pretrained("google/chat-bison@001")
grounding_source = language_models.GroundingSource.WebSearch()
chat = chat_model.start_chat(
context="My name is Ned. You are my personal assistant. My favorite movies are Lord of the Rings and Hobbit.",
examples=[
Expand All @@ -143,19 +143,23 @@ def test_chat_on_chat_model(self):
)

message1 = "Are my favorite movies based on a book series?"
response1 = chat.send_message(message1)
response1 = chat.send_message(
message1,
grounding_source=grounding_source,
)
assert response1.text
assert response1.grounding_metadata
assert len(chat.message_history) == 2
assert chat.message_history[0].author == chat.USER_AUTHOR
assert chat.message_history[0].content == message1
assert chat.message_history[1].author == chat.MODEL_AUTHOR

message2 = "When were these books published?"
response2 = chat.send_message(
message2,
temperature=0.1,
message2, temperature=0.1, grounding_source=grounding_source
)
assert response2.text
assert response2.grounding_metadata
assert len(chat.message_history) == 4
assert chat.message_history[2].author == chat.USER_AUTHOR
assert chat.message_history[2].content == message2
Expand Down Expand Up @@ -189,6 +193,7 @@ async def test_chat_model_async(self):
aiplatform.init(project=e2e_base._PROJECT, location=e2e_base._LOCATION)

chat_model = ChatModel.from_pretrained("google/chat-bison@001")
grounding_source = language_models.GroundingSource.WebSearch()
chat = chat_model.start_chat(
context="My name is Ned. You are my personal assistant. My favorite movies are Lord of the Rings and Hobbit.",
examples=[
Expand All @@ -206,8 +211,12 @@ async def test_chat_model_async(self):
)

message1 = "Are my favorite movies based on a book series?"
response1 = await chat.send_message_async(message1)
response1 = await chat.send_message_async(
message1,
grounding_source=grounding_source,
)
assert response1.text
assert response1.grounding_metadata
assert len(chat.message_history) == 2
assert chat.message_history[0].author == chat.USER_AUTHOR
assert chat.message_history[0].content == message1
Expand All @@ -217,8 +226,10 @@ async def test_chat_model_async(self):
response2 = await chat.send_message_async(
message2,
temperature=0.1,
grounding_source=grounding_source,
)
assert response2.text
assert response2.grounding_metadata
assert len(chat.message_history) == 4
assert chat.message_history[2].author == chat.USER_AUTHOR
assert chat.message_history[2].content == message2
Expand Down
306 changes: 306 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,97 @@
],
}

# Fake prediction payload for a two-candidate chat response in which the
# backend returns grounding metadata. Candidate 0 is unblocked with text
# content; candidate 1 is blocked (safety) and has empty content.
_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING = {
    "safetyAttributes": [
        {
            "scores": [],
            "categories": [],
            "blocked": False,
        },
        {
            "scores": [0.1],
            "categories": ["Finance"],
            "blocked": True,
        },
    ],
    # One grounding entry per candidate, each carrying a single citation
    # with camelCase keys as emitted by the prediction service.
    "groundingMetadata": [
        {
            "citations": [
                {
                    "startIndex": 1,
                    "endIndex": 2,
                    "url": "url1",
                }
            ]
        },
        {
            "citations": [
                {
                    "startIndex": 3,
                    "endIndex": 4,
                    "url": "url2",
                }
            ]
        },
    ],
    "candidates": [
        {
            "author": "1",
            "content": "Chat response 2",
        },
        {
            "author": "1",
            "content": "",
        },
    ],
}

# Variant of the grounding fixture above where the backend returns no
# grounding metadata for either candidate (explicit None entries), used to
# exercise the "no citations" parsing path.
_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING_NONE = {
    "safetyAttributes": [
        {
            "scores": [],
            "categories": [],
            "blocked": False,
        },
        {
            "scores": [0.1],
            "categories": ["Finance"],
            "blocked": True,
        },
    ],
    # One entry per candidate; None means the service produced no grounding
    # metadata for that candidate.
    "groundingMetadata": [
        None,
        None,
    ],
    "candidates": [
        {
            "author": "1",
            "content": "Chat response 2",
        },
        {
            "author": "1",
            "content": "",
        },
    ],
}

# Expected parsed form (via dataclasses.asdict) of the first candidate's
# grounding metadata from the grounding fixture above: keys become
# snake_case and citation fields absent from the payload surface as None.
_EXPECTED_PARSED_GROUNDING_METADATA_CHAT = {
    "citations": [
        {
            "url": "url1",
            "start_index": 1,
            "end_index": 2,
            "title": None,
            "license": None,
            "publication_date": None,
        },
    ],
}

# Expected parsed grounding metadata when the service returned None for the
# candidate: an empty citation list.
_EXPECTED_PARSED_GROUNDING_METADATA_CHAT_NONE = {
    "citations": [],
}

_TEST_CHAT_PREDICTION_STREAMING = [
{
"candidates": [
Expand Down Expand Up @@ -2312,6 +2403,221 @@ def test_chat(self):
assert prediction_parameters["topK"] == message_top_k
assert prediction_parameters["topP"] == message_top_p

gca_predict_response4 = gca_prediction_service.PredictResponse()
gca_predict_response4.predictions.append(
_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING
)
test_grounding_sources = [
_TEST_GROUNDING_WEB_SEARCH,
_TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE,
]
datastore_path = (
"projects/test-project/locations/global/"
"collections/default_collection/dataStores/test_datastore"
)
expected_grounding_sources = [
{"sources": [{"type": "WEB"}]},
{
"sources": [
{
"type": "ENTERPRISE",
"enterpriseDatastore": datastore_path,
}
]
},
]
for test_grounding_source, expected_grounding_source in zip(
test_grounding_sources, expected_grounding_sources
):
with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response4,
) as mock_predict4:
response = chat2.send_message(
"Are my favorite movies based on a book series?",
grounding_source=test_grounding_source,
)
prediction_parameters = mock_predict4.call_args[1]["parameters"]
assert (
prediction_parameters["groundingConfig"]
== expected_grounding_source
)
assert (
dataclasses.asdict(response.grounding_metadata)
== _EXPECTED_PARSED_GROUNDING_METADATA_CHAT
)

gca_predict_response5 = gca_prediction_service.PredictResponse()
gca_predict_response5.predictions.append(
_TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING_NONE
)
test_grounding_sources = [
_TEST_GROUNDING_WEB_SEARCH,
_TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE,
]
datastore_path = (
"projects/test-project/locations/global/"
"collections/default_collection/dataStores/test_datastore"
)
expected_grounding_sources = [
{"sources": [{"type": "WEB"}]},
{
"sources": [
{
"type": "ENTERPRISE",
"enterpriseDatastore": datastore_path,
}
]
},
]
for test_grounding_source, expected_grounding_source in zip(
test_grounding_sources, expected_grounding_sources
):
with mock.patch.object(
target=prediction_service_client.PredictionServiceClient,
attribute="predict",
return_value=gca_predict_response5,
) as mock_predict5:
response = chat2.send_message(
"Are my favorite movies based on a book series?",
grounding_source=test_grounding_source,
)
prediction_parameters = mock_predict5.call_args[1]["parameters"]
assert (
prediction_parameters["groundingConfig"]
== expected_grounding_source
)
assert (
dataclasses.asdict(response.grounding_metadata)
== _EXPECTED_PARSED_GROUNDING_METADATA_CHAT_NONE
)

@pytest.mark.asyncio
async def test_chat_async(self):
    """Tests ChatSession.send_message_async with grounding sources.

    Verifies that WebSearch and Vertex AI Search grounding sources are
    translated into the expected ``groundingConfig`` prediction
    parameters, and that grounding metadata in the response (or its
    absence) is parsed into the expected dataclass contents.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
    )
    with mock.patch.object(
        target=model_garden_service_client.ModelGardenServiceClient,
        attribute="get_publisher_model",
        return_value=gca_publisher_model.PublisherModel(
            _CHAT_BISON_PUBLISHER_MODEL_DICT
        ),
    ) as mock_get_publisher_model:
        model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")

    mock_get_publisher_model.assert_called_once_with(
        name="publishers/google/models/chat-bison@001", retry=base._DEFAULT_RETRY
    )
    chat_temperature = 0.1
    chat_max_output_tokens = 100
    chat_top_k = 1
    chat_top_p = 0.1

    chat = model.start_chat(
        temperature=chat_temperature,
        max_output_tokens=chat_max_output_tokens,
        top_k=chat_top_k,
        top_p=chat_top_p,
    )

    # Grounding sources to exercise, paired (by position) with the
    # groundingConfig each should produce in the prediction parameters.
    # Built once and shared by both scenarios below (was duplicated).
    datastore_path = (
        "projects/test-project/locations/global/"
        "collections/default_collection/dataStores/test_datastore"
    )
    test_grounding_sources = [
        _TEST_GROUNDING_WEB_SEARCH,
        _TEST_GROUNDING_VERTEX_AI_SEARCH_DATASTORE,
    ]
    expected_grounding_sources = [
        {"sources": [{"type": "WEB"}]},
        {
            "sources": [
                {
                    "type": "ENTERPRISE",
                    "enterpriseDatastore": datastore_path,
                }
            ]
        },
    ]

    # Scenario 1: the service returns grounding metadata with citations.
    gca_predict_response6 = gca_prediction_service.PredictResponse()
    gca_predict_response6.predictions.append(
        _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING
    )
    for test_grounding_source, expected_grounding_source in zip(
        test_grounding_sources, expected_grounding_sources
    ):
        with mock.patch.object(
            target=prediction_service_async_client.PredictionServiceAsyncClient,
            attribute="predict",
            return_value=gca_predict_response6,
        ) as mock_predict6:
            response = await chat.send_message_async(
                "Are my favorite movies based on a book series?",
                grounding_source=test_grounding_source,
            )
            prediction_parameters = mock_predict6.call_args[1]["parameters"]
            # Session-level sampling parameters must be forwarded as-is.
            assert prediction_parameters["temperature"] == chat_temperature
            assert prediction_parameters["maxDecodeSteps"] == chat_max_output_tokens
            assert prediction_parameters["topK"] == chat_top_k
            assert prediction_parameters["topP"] == chat_top_p
            assert (
                prediction_parameters["groundingConfig"]
                == expected_grounding_source
            )
            assert (
                dataclasses.asdict(response.grounding_metadata)
                == _EXPECTED_PARSED_GROUNDING_METADATA_CHAT
            )

    # Scenario 2: same request flow, but the service returns None grounding
    # metadata for every candidate; parsing must yield empty citations.
    gca_predict_response7 = gca_prediction_service.PredictResponse()
    gca_predict_response7.predictions.append(
        _TEST_CHAT_GENERATION_MULTI_CANDIDATE_PREDICTION_GROUNDING_NONE
    )
    for test_grounding_source, expected_grounding_source in zip(
        test_grounding_sources, expected_grounding_sources
    ):
        with mock.patch.object(
            target=prediction_service_async_client.PredictionServiceAsyncClient,
            attribute="predict",
            return_value=gca_predict_response7,
        ) as mock_predict7:
            response = await chat.send_message_async(
                "Are my favorite movies based on a book series?",
                grounding_source=test_grounding_source,
            )
            prediction_parameters = mock_predict7.call_args[1]["parameters"]
            assert (
                prediction_parameters["groundingConfig"]
                == expected_grounding_source
            )
            assert (
                dataclasses.asdict(response.grounding_metadata)
                == _EXPECTED_PARSED_GROUNDING_METADATA_CHAT_NONE
            )

def test_chat_ga(self):
"""Tests the chat generation model."""
aiplatform.init(
Expand Down
Loading

0 comments on commit d4667f2

Please sign in to comment.