fix: LLM - Fixed the async streaming
Fixes #2853

PiperOrigin-RevId: 577345792
Ark-kun authored and Copybara-Service committed Oct 28, 2023
1 parent 087f3c8 · commit 41bfcb6
Showing 2 changed files with 5 additions and 2 deletions.
google/cloud/aiplatform/_streaming_prediction.py (1 addition, 1 deletion)

@@ -130,7 +130,7 @@ async def predict_stream_of_tensor_lists_from_single_tensor_list_async(
         inputs=tensor_list,
         parameters=parameters_tensor,
     )
-    async for response in prediction_service_async_client.server_streaming_predict(
+    async for response in await prediction_service_async_client.server_streaming_predict(
        request=request
     ):
         yield response.outputs
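Why the added `await` matters: on the async client, calling `server_streaming_predict` does not itself produce the stream of responses; the call has to be awaited first, and only the awaited result can be consumed with `async for`. Iterating the un-awaited call fails, because a coroutine is not an async iterable. Below is a minimal, self-contained sketch of that calling convention; `fake_server_streaming_predict` is a hypothetical stand-in for the real client method, not the library's API.

```python
import asyncio
from typing import AsyncIterator

async def fake_server_streaming_predict(request: str) -> AsyncIterator[str]:
    # A coroutine that *returns* an async iterator, mirroring the shape
    # the fix assumes for the async client: the call itself must be
    # awaited before the stream can be iterated.
    async def stream() -> AsyncIterator[str]:
        for chunk in ("alpha", "beta"):
            yield f"{request}:{chunk}"
    return stream()

async def main() -> None:
    # Without `await`, `async for` would raise a TypeError here, because
    # the un-awaited call is a coroutine, not an async iterable.
    async for response in await fake_server_streaming_predict("req"):
        print(response)

asyncio.run(main())
```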
tests/unit/aiplatform/test_language_models.py (4 additions, 1 deletion)

@@ -1484,12 +1484,15 @@ async def test_text_generation_model_predict_streaming_async(self):
             "text-bison@001"
         )
 
-        async def mock_server_streaming_predict_async(*args, **kwargs):
+        async def mock_server_streaming_predict_async_iter(*args, **kwargs):
             for response_dict in _TEST_TEXT_GENERATION_PREDICTION_STREAMING:
                 yield gca_prediction_service.StreamingPredictResponse(
                     outputs=[_streaming_prediction.value_to_tensor(response_dict)]
                 )
 
+        async def mock_server_streaming_predict_async(*args, **kwargs):
+            return mock_server_streaming_predict_async_iter(*args, **kwargs)
+
         with mock.patch.object(
             target=prediction_service_async_client.PredictionServiceAsyncClient,
             attribute="server_streaming_predict",
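The test change mirrors the same distinction in Python's async model: an `async def` that contains `yield` is an async generator function, so calling it returns an async iterator directly, whereas an `async def` that only returns is a coroutine function whose call must be awaited. Because the production code now awaits `server_streaming_predict` before iterating, the mock is split in two: the `_iter` helper builds the stream, and the outer coroutine returns it when awaited. A self-contained sketch of the pattern, using a hypothetical `FakeAsyncClient` in place of the real `PredictionServiceAsyncClient`:

```python
import asyncio
from unittest import mock

class FakeAsyncClient:
    async def server_streaming_predict(self, request=None):
        raise NotImplementedError  # replaced by the mock below

# An `async def` containing `yield` is an async *generator* function:
# calling it returns an async generator with no await needed.
async def mock_stream_iter(*args, **kwargs):
    for value in (1, 2, 3):
        yield value

# An `async def` containing only `return` is a *coroutine* function:
# awaiting the call yields the async generator, matching the real
# client's `await client.server_streaming_predict(...)` shape.
async def mock_stream(*args, **kwargs):
    return mock_stream_iter(*args, **kwargs)

async def main():
    with mock.patch.object(
        FakeAsyncClient, "server_streaming_predict", new=mock_stream
    ):
        client = FakeAsyncClient()
        async for value in await client.server_streaming_predict(request="r"):
            print(value)

asyncio.run(main())
```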
