feat: LLM - Added support for async streaming
PiperOrigin-RevId: 573094790
Ark-kun authored and Copybara-Service committed Oct 13, 2023
1 parent 7944348 commit 760a025
Showing 3 changed files with 317 additions and 2 deletions.
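
As a quick orientation, here is a minimal sketch (not part of the commit) of how the new async streaming API might be consumed once this change lands. It assumes `vertexai.init(...)` has already been called; the model name, prompt, and sampling parameters are illustrative.

import asyncio

from vertexai.language_models import TextGenerationModel


async def main() -> None:
    model = TextGenerationModel.from_pretrained("text-bison@001")
    # predict_streaming_async yields partial TextGenerationResponse objects.
    async for response in model.predict_streaming_async(
        "Count to 50",
        max_output_tokens=1000,
        temperature=0.0,
    ):
        print(response.text, end="", flush=True)


asyncio.run(main())
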
88 changes: 87 additions & 1 deletion google/cloud/aiplatform/_streaming_prediction.py
@@ -16,7 +16,7 @@
#
"""Streaming prediction functions."""

from typing import Any, Dict, Iterator, List, Optional, Sequence
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence

from google.cloud.aiplatform_v1.services import prediction_service
from google.cloud.aiplatform_v1.types import (
@@ -108,6 +108,34 @@ def predict_stream_of_tensor_lists_from_single_tensor_list(
yield response.outputs


async def predict_stream_of_tensor_lists_from_single_tensor_list_async(
prediction_service_async_client: prediction_service.PredictionServiceAsyncClient,
endpoint_name: str,
tensor_list: List[aiplatform_types.Tensor],
parameters_tensor: Optional[aiplatform_types.Tensor] = None,
) -> AsyncIterator[List[aiplatform_types.Tensor]]:
"""Asynchronously predicts a stream of lists of `Tensor` objects from a single list of `Tensor` objects.
Args:
tensor_list: Model input as a list of `Tensor` objects.
parameters_tensor: Optional. Prediction parameters in `Tensor` form.
prediction_service_async_client: A PredictionServiceAsyncClient object.
endpoint_name: Resource name of Endpoint or PublisherModel.
Yields:
A generator of model prediction `Tensor` lists.
"""
request = prediction_service_types.StreamingPredictRequest(
endpoint=endpoint_name,
inputs=tensor_list,
parameters=parameters_tensor,
)
async for response in prediction_service_async_client.server_streaming_predict(
request=request
):
yield response.outputs


def predict_stream_of_dict_lists_from_single_dict_list(
prediction_service_client: prediction_service.PredictionServiceClient,
endpoint_name: str,
@@ -136,6 +164,34 @@ def predict_stream_of_dict_lists_from_single_dict_list(
yield [tensor_to_value(tensor._pb) for tensor in tensor_list]


async def predict_stream_of_dict_lists_from_single_dict_list_async(
prediction_service_async_client: prediction_service.PredictionServiceAsyncClient,
endpoint_name: str,
dict_list: List[Dict[str, Any]],
parameters: Optional[Dict[str, Any]] = None,
) -> AsyncIterator[List[Dict[str, Any]]]:
"""Asynchronously predicts a stream of lists of dicts from a stream of lists of dicts.
Args:
dict_list: Model input as a list of `dict` objects.
parameters: Optional. Prediction parameters in `dict` form.
prediction_service_async_client: A PredictionServiceAsyncClient object.
endpoint_name: Resource name of Endpoint or PublisherModel.
Yields:
A generator of model prediction dict lists.
"""
tensor_list = [value_to_tensor(d) for d in dict_list]
parameters_tensor = value_to_tensor(parameters) if parameters else None
async for tensor_list in predict_stream_of_tensor_lists_from_single_tensor_list_async(
prediction_service_async_client=prediction_service_async_client,
endpoint_name=endpoint_name,
tensor_list=tensor_list,
parameters_tensor=parameters_tensor,
):
yield [tensor_to_value(tensor._pb) for tensor in tensor_list]


def predict_stream_of_dicts_from_single_dict(
prediction_service_client: prediction_service.PredictionServiceClient,
endpoint_name: str,
@@ -164,3 +220,33 @@ def predict_stream_of_dicts_from_single_dict(
f"Expected to receive a single output, but got {dict_list}"
)
yield dict_list[0]


async def predict_stream_of_dicts_from_single_dict_async(
prediction_service_async_client: prediction_service.PredictionServiceAsyncClient,
endpoint_name: str,
instance: Dict[str, Any],
parameters: Optional[Dict[str, Any]] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""Asynchronously predicts a stream of dicts from a single instance dict.
Args:
instance: A single input instance `dict`.
parameters: Optional. Prediction parameters `dict`.
prediction_service_async_client: A PredictionServiceAsyncClient object.
endpoint_name: Resource name of Endpoint or PublisherModel.
Yields:
A generator of model prediction dicts.
"""
async for dict_list in predict_stream_of_dict_lists_from_single_dict_list_async(
prediction_service_async_client=prediction_service_async_client,
endpoint_name=endpoint_name,
dict_list=[instance],
parameters=parameters,
):
if len(dict_list) > 1:
raise ValueError(
f"Expected to receive a single output, but got {dict_list}"
)
yield dict_list[0]
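
For reference, a minimal sketch (not part of the diff) of driving the new low-level helper directly. The endpoint name, instance payload, and parameters are illustrative; the async client is constructed with default credentials here, whereas the SDK's higher-level classes obtain it from their Endpoint internally.

import asyncio

from google.cloud.aiplatform import _streaming_prediction
from google.cloud.aiplatform_v1.services import prediction_service


async def stream_once(endpoint_name: str) -> None:
    # Illustrative: default client construction; production code may pass
    # explicit credentials or client options.
    async_client = prediction_service.PredictionServiceAsyncClient()
    async for prediction in _streaming_prediction.predict_stream_of_dicts_from_single_dict_async(
        prediction_service_async_client=async_client,
        endpoint_name=endpoint_name,
        instance={"content": "Count to 50"},  # illustrative instance payload
        parameters={"temperature": 0.0},  # illustrative parameters
    ):
        print(prediction)


# Example (endpoint path is illustrative):
# asyncio.run(stream_once("projects/my-project/locations/us-central1/publishers/google/models/text-bison@001"))
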
35 changes: 35 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
@@ -1465,6 +1465,41 @@ def test_text_generation_model_predict_streaming(self):
):
assert len(response.text) > 10

@pytest.mark.asyncio
async def test_text_generation_model_predict_streaming_async(self):
"""Tests the TextGenerationModel.predict_streaming_async method."""
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_TEXT_BISON_PUBLISHER_MODEL_DICT
),
):
model = language_models.TextGenerationModel.from_pretrained(
"text-bison@001"
)

async def mock_server_streaming_predict_async(*args, **kwargs):
for response_dict in _TEST_TEXT_GENERATION_PREDICTION_STREAMING:
yield gca_prediction_service.StreamingPredictResponse(
outputs=[_streaming_prediction.value_to_tensor(response_dict)]
)

with mock.patch.object(
target=prediction_service_async_client.PredictionServiceAsyncClient,
attribute="server_streaming_predict",
new=mock_server_streaming_predict_async,
):
async for response in model.predict_streaming_async(
"Count to 50",
max_output_tokens=1000,
temperature=0.0,
top_p=1.0,
top_k=5,
stop_sequences=["# %%"],
):
assert len(response.text) > 10

def test_text_generation_response_repr(self):
response = language_models.TextGenerationResponse(
text="",
196 changes: 195 additions & 1 deletion vertexai/language_models/_language_models.py
@@ -15,7 +15,7 @@
"""Classes for working with language models."""

import dataclasses
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Union
import warnings

from google.cloud import aiplatform
@@ -871,6 +871,54 @@ def predict_streaming(
)
yield _parse_text_generation_model_response(prediction_obj)

async def predict_streaming_async(
self,
prompt: str,
*,
max_output_tokens: int = _DEFAULT_MAX_OUTPUT_TOKENS,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> AsyncIterator[TextGenerationResponse]:
"""Asynchronously gets a streaming model response for a single prompt.
The result is a stream (generator) of partial responses.
Args:
prompt: Question to ask the model.
max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
stop_sequences: Customized stop sequences to stop the decoding process.
Yields:
A stream of `TextGenerationResponse` objects that contain partial
responses produced by the model.
"""
prediction_request = _create_text_generation_prediction_request(
prompt=prompt,
max_output_tokens=max_output_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_sequences=stop_sequences,
)

prediction_service_async_client = self._endpoint._prediction_async_client
async for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict_async(
prediction_service_async_client=prediction_service_async_client,
endpoint_name=self._endpoint_name,
instance=prediction_request.instance,
parameters=prediction_request.parameters,
):
prediction_obj = aiplatform.models.Prediction(
predictions=[prediction_dict],
deployed_model_id="",
)
yield _parse_text_generation_model_response(prediction_obj)


def _create_text_generation_prediction_request(
prompt: str,
@@ -1928,6 +1976,75 @@ def send_message_streaming(
ChatMessage(content=full_response_text, author=self.MODEL_AUTHOR)
)

async def send_message_streaming_async(
self,
message: str,
*,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> AsyncIterator[TextGenerationResponse]:
"""Asynchronously sends message to the language model and gets a streamed response.
The response is only added to the history once it's fully read.
Args:
message: Message to send to the model
max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
Uses the value specified when calling `ChatModel.start_chat` by default.
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
Uses the value specified when calling `ChatModel.start_chat` by default.
top_k: The number of highest probability vocabulary tokens to keep for top-k-filtering. Range: [1, 40]. Default: 40.
Uses the value specified when calling `ChatModel.start_chat` by default.
top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling. Range: [0, 1]. Default: 0.95.
Uses the value specified when calling `ChatModel.start_chat` by default.
stop_sequences: Customized stop sequences to stop the decoding process.
Uses the value specified when calling `ChatModel.start_chat` by default.
Yields:
A stream of `TextGenerationResponse` objects that contain partial
responses produced by the model.
"""
prediction_request = self._prepare_request(
message=message,
max_output_tokens=max_output_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
stop_sequences=stop_sequences,
)

prediction_service_async_client = self._model._endpoint._prediction_async_client

full_response_text = ""

async for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict_async(
prediction_service_async_client=prediction_service_async_client,
endpoint_name=self._model._endpoint_name,
instance=prediction_request.instance,
parameters=prediction_request.parameters,
):
prediction_response = aiplatform.models.Prediction(
predictions=[prediction_dict],
deployed_model_id="",
)
text_generation_response = self._parse_chat_prediction_response(
prediction_response=prediction_response
)
full_response_text += text_generation_response.text
yield text_generation_response

# We only add the question and answer to the history if/when the answer
# was read fully. Otherwise, the answer would have been truncated.
self._message_history.append(
ChatMessage(content=message, author=self.USER_AUTHOR)
)
self._message_history.append(
ChatMessage(content=full_response_text, author=self.MODEL_AUTHOR)
)
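
A minimal sketch (not part of the diff) of consuming the new chat streaming method; the model name and message are illustrative, and `vertexai.init(...)` is assumed to have been called.

import asyncio

from vertexai.language_models import ChatModel


async def chat_stream() -> None:
    chat_model = ChatModel.from_pretrained("chat-bison@001")  # illustrative model name
    chat = chat_model.start_chat()
    async for response in chat.send_message_streaming_async("Tell me about streaming."):
        print(response.text, end="", flush=True)
    # Per the code above, the question/answer pair is appended to
    # chat.message_history only after the stream has been fully consumed.


asyncio.run(chat_stream())
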


class ChatSession(_ChatSessionBase):
"""ChatSession represents a chat session with a language model.
@@ -2073,6 +2190,38 @@ def send_message_streaming(
stop_sequences=stop_sequences,
)

def send_message_streaming_async(
self,
message: str,
*,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> AsyncIterator[TextGenerationResponse]:
"""Asynchronously sends message to the language model and gets a streamed response.
The response is only added to the history once it's fully read.
Args:
message: Message to send to the model
max_output_tokens: Max length of the output text in tokens. Range: [1, 1024].
Uses the value specified when calling `ChatModel.start_chat` by default.
temperature: Controls the randomness of predictions. Range: [0, 1]. Default: 0.
Uses the value specified when calling `ChatModel.start_chat` by default.
stop_sequences: Customized stop sequences to stop the decoding process.
Uses the value specified when calling `ChatModel.start_chat` by default.
Returns:
A stream of `TextGenerationResponse` objects that contain partial
responses produced by the model.
"""
return super().send_message_streaming_async(
message=message,
max_output_tokens=max_output_tokens,
temperature=temperature,
stop_sequences=stop_sequences,
)


class CodeGenerationModel(_LanguageModel):
"""A language model that generates code.
@@ -2255,6 +2404,51 @@ def predict_streaming(
)
yield _parse_text_generation_model_response(prediction_obj)

async def predict_streaming_async(
self,
prefix: str,
suffix: Optional[str] = None,
*,
max_output_tokens: Optional[int] = None,
temperature: Optional[float] = None,
stop_sequences: Optional[List[str]] = None,
) -> AsyncIterator[TextGenerationResponse]:
"""Asynchronously predicts the code based on previous code.
The result is a stream (generator) of partial responses.
Args:
prefix: Code before the current point.
suffix: Code after the current point.
max_output_tokens: Max length of the output text in tokens. Range: [1, 1000].
temperature: Controls the randomness of predictions. Range: [0, 1].
stop_sequences: Customized stop sequences to stop the decoding process.
Yields:
A stream of `TextGenerationResponse` objects that contain partial
responses produced by the model.
"""
prediction_request = self._create_prediction_request(
prefix=prefix,
suffix=suffix,
max_output_tokens=max_output_tokens,
temperature=temperature,
stop_sequences=stop_sequences,
)

prediction_service_async_client = self._endpoint._prediction_async_client
async for prediction_dict in _streaming_prediction.predict_stream_of_dicts_from_single_dict_async(
prediction_service_async_client=prediction_service_async_client,
endpoint_name=self._endpoint_name,
instance=prediction_request.instance,
parameters=prediction_request.parameters,
):
prediction_obj = aiplatform.models.Prediction(
predictions=[prediction_dict],
deployed_model_id="",
)
yield _parse_text_generation_model_response(prediction_obj)


class _PreviewCodeGenerationModel(CodeGenerationModel, _TunableModelMixin):
__name__ = "CodeGenerationModel"
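
Finally, a minimal sketch (not part of the diff) of the new code-generation streaming method; the model name, prompt, and parameters are illustrative and assume an initialized `vertexai` environment.

import asyncio

from vertexai.language_models import CodeGenerationModel


async def stream_code() -> None:
    model = CodeGenerationModel.from_pretrained("code-bison@001")  # illustrative model name
    async for response in model.predict_streaming_async(
        prefix="Write a Python function that reverses a string.",
        max_output_tokens=128,
        temperature=0.2,
    ):
        print(response.text, end="", flush=True)


asyncio.run(stream_code())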
