From 22ba2adbeefb6f62bbc9d4d74733ecf3367caf22 Mon Sep 17 00:00:00 2001 From: Jordan Wu <101218661+jordan-definitive@users.noreply.github.com> Date: Wed, 28 Aug 2024 15:31:40 -0700 Subject: [PATCH] feat(internal): handle streaming error --- src/groq/_streaming.py | 79 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 76 insertions(+), 3 deletions(-) diff --git a/src/groq/_streaming.py b/src/groq/_streaming.py index 73cf6841..01defb2a 100755 --- a/src/groq/_streaming.py +++ b/src/groq/_streaming.py @@ -9,7 +9,8 @@ import httpx -from ._utils import extract_type_var_from_base +from ._utils import is_mapping, extract_type_var_from_base +from ._exceptions import APIError if TYPE_CHECKING: from ._client import Groq, AsyncGroq @@ -57,7 +58,43 @@ def __stream__(self) -> Iterator[_T]: for sse in iterator: if sse.data.startswith("[DONE]"): break - yield process_data(data=sse.json(), cast_to=cast_to, response=response) + + if sse.event is None: + data = sse.json() + if is_mapping(data) and data.get("error"): + message = None + error = data.get("error") + if is_mapping(error): + message = error.get("message") + if not message or not isinstance(message, str): + message = "An error occurred during streaming" + + raise APIError( + message=message, + request=self.response.request, + body=data["error"], + ) + + yield process_data(data=data, cast_to=cast_to, response=response) + + else: + data = sse.json() + + if sse.event == "error" and is_mapping(data) and data.get("error"): + message = None + error = data.get("error") + if is_mapping(error): + message = error.get("message") + if not message or not isinstance(message, str): + message = "An error occurred during streaming" + + raise APIError( + message=message, + request=self.response.request, + body=data["error"], + ) + + yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response) # Ensure the entire stream is consumed for _sse in iterator: @@ -123,7 +160,43 @@ async def __stream__(self) -> AsyncIterator[_T]: async for sse in iterator: if sse.data.startswith("[DONE]"): break - yield process_data(data=sse.json(), cast_to=cast_to, response=response) + + if sse.event is None: + data = sse.json() + if is_mapping(data) and data.get("error"): + message = None + error = data.get("error") + if is_mapping(error): + message = error.get("message") + if not message or not isinstance(message, str): + message = "An error occurred during streaming" + + raise APIError( + message=message, + request=self.response.request, + body=data["error"], + ) + + yield process_data(data=data, cast_to=cast_to, response=response) + + else: + data = sse.json() + + if sse.event == "error" and is_mapping(data) and data.get("error"): + message = None + error = data.get("error") + if is_mapping(error): + message = error.get("message") + if not message or not isinstance(message, str): + message = "An error occurred during streaming" + + raise APIError( + message=message, + request=self.response.request, + body=data["error"], + ) + + yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response) # Ensure the entire stream is consumed async for _sse in iterator: