Skip to content

Commit

Permalink
Handle TGI error when streaming tokens (#1711)
Browse files Browse the repository at this point in the history
* Handle TGI error when streaming tokens

* make quality
  • Loading branch information
Wauplin committed Oct 5, 2023
1 parent cbcd8b2 commit 5d2d297
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 14 deletions.
14 changes: 9 additions & 5 deletions src/huggingface_hub/inference/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@
is_numpy_available,
is_pillow_available,
)
from ._text_generation import (
TextGenerationStreamResponse,
)
from ._text_generation import TextGenerationStreamResponse, _parse_text_generation_error


if TYPE_CHECKING:
Expand Down Expand Up @@ -275,7 +273,10 @@ def _stream_text_generation_response(
if payload.startswith("data:"):
# Decode payload
json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
# Parse payload
# Either an error is being returned
if json_payload.get("error") is not None:
raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
# Or parse token payload
output = TextGenerationStreamResponse(**json_payload)
yield output.token.text if not details else output

Expand All @@ -295,7 +296,10 @@ async def _async_stream_text_generation_response(
if payload.startswith("data:"):
# Decode payload
json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
# Parse payload
# Either an error is being returned
if json_payload.get("error") is not None:
raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
# Or parse token payload
output = TextGenerationStreamResponse(**json_payload)
yield output.token.text if not details else output

Expand Down
28 changes: 19 additions & 9 deletions src/huggingface_hub/inference/_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,10 @@ class IncompleteGenerationError(TextGenerationError):
pass


class UnknownError(TextGenerationError):
    # Fallback error: raised when TGI returns an `error_type` that does not
    # match any of the known categories (generation, incomplete_generation,
    # overloaded, validation). See `_parse_text_generation_error`.
    pass


def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
"""
Try to parse text-generation-inference error message and raise HTTPError in any case.
Expand All @@ -460,21 +464,27 @@ def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
try:
# Hacky way to retrieve payload in case of aiohttp error
payload = getattr(http_error, "response_error_payload", None) or http_error.response.json()
message = payload.get("error")
error = payload.get("error")
error_type = payload.get("error_type")
except Exception: # no payload
raise http_error

# If error_type => more information than `hf_raise_for_status`
if error_type is not None:
if error_type == "generation":
raise GenerationError(message) from http_error # type: ignore
if error_type == "incomplete_generation":
raise IncompleteGenerationError(message) from http_error # type: ignore
if error_type == "overloaded":
raise OverloadedError(message) from http_error # type: ignore
if error_type == "validation":
raise ValidationError(message) from http_error # type: ignore
exception = _parse_text_generation_error(error, error_type)
raise exception from http_error

# Otherwise, fallback to default error
raise http_error


def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
    """Map a TGI error payload to the matching exception instance.

    Args:
        error: the error message returned by the text-generation-inference server.
        error_type: the error category returned by the server, if any.

    Returns:
        A `TextGenerationError` subclass instance wrapping the message. Falls
        back to `UnknownError` when `error_type` is missing or unrecognized.
    """
    exception_classes = {
        "generation": GenerationError,
        "incomplete_generation": IncompleteGenerationError,
        "overloaded": OverloadedError,
        "validation": ValidationError,
    }
    exception_class = exception_classes.get(error_type, UnknownError)  # type: ignore
    return exception_class(error)  # type: ignore

0 comments on commit 5d2d297

Please sign in to comment.