diff --git a/src/cohere/embed_jobs/raw_client.py b/src/cohere/embed_jobs/raw_client.py index 11c9106e4..8f7dfa067 100644 --- a/src/cohere/embed_jobs/raw_client.py +++ b/src/cohere/embed_jobs/raw_client.py @@ -28,6 +28,21 @@ from ..types.list_embed_job_response import ListEmbedJobResponse from .types.create_embed_job_request_truncate import CreateEmbedJobRequestTruncate +_ERROR_MAP = { + 400: BadRequestError, + 401: UnauthorizedError, + 403: ForbiddenError, + 404: NotFoundError, + 422: UnprocessableEntityError, + 429: TooManyRequestsError, + 498: InvalidTokenError, + 499: ClientClosedRequestError, + 500: InternalServerError, + 501: NotImplementedError, + 503: ServiceUnavailableError, + 504: GatewayTimeoutError, +} + # this is used as the default value for optional parameters OMIT = typing.cast(typing.Any, ...) @@ -612,145 +627,18 @@ def cancel(self, id: str, *, request_options: typing.Optional[RequestOptions] = method="POST", request_options=request_options, ) + status_code = _response.status_code try: - if 200 <= _response.status_code < 300: + if 200 <= status_code < 300: return HttpResponse(response=_response, data=None) - if _response.status_code == 400: - raise BadRequestError( - headers=dict(_response.headers), - body=typing.cast( - typing.Optional[typing.Any], - construct_type( - type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), - ), - ), - ) - if _response.status_code == 401: - raise UnauthorizedError( - headers=dict(_response.headers), - body=typing.cast( - typing.Optional[typing.Any], - construct_type( - type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), - ), - ), - ) - if _response.status_code == 403: - raise ForbiddenError( - headers=dict(_response.headers), - body=typing.cast( - typing.Optional[typing.Any], - construct_type( - type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), - ), - ), - ) - if _response.status_code == 404: - raise NotFoundError( - headers=dict(_response.headers), - body=typing.cast( - typing.Optional[typing.Any], - construct_type( - type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), - ), - ), - ) - if _response.status_code == 422: - raise UnprocessableEntityError( - headers=dict(_response.headers), - body=typing.cast( - typing.Optional[typing.Any], - construct_type( - type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), - ), - ), - ) - if _response.status_code == 429: - raise TooManyRequestsError( - headers=dict(_response.headers), - body=typing.cast( - typing.Optional[typing.Any], - construct_type( - type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), - ), - ), - ) - if _response.status_code == 498: - raise InvalidTokenError( - headers=dict(_response.headers), - body=typing.cast( - typing.Optional[typing.Any], - construct_type( - type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), - ), - ), - ) - if _response.status_code == 499: - raise ClientClosedRequestError( - headers=dict(_response.headers), - body=typing.cast( - typing.Optional[typing.Any], - construct_type( - type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), - ), - ), - ) - if _response.status_code == 500: - raise InternalServerError( - headers=dict(_response.headers), - body=typing.cast( - typing.Optional[typing.Any], - construct_type( - type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), - ), - ), - ) - if _response.status_code == 501: - raise NotImplementedError( - headers=dict(_response.headers), - body=typing.cast( - typing.Optional[typing.Any], - construct_type( - type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), - ), - ), - ) - if _response.status_code == 503: - raise ServiceUnavailableError( - headers=dict(_response.headers), - body=typing.cast( - typing.Optional[typing.Any], - construct_type( - type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), - ), - ), - ) - if _response.status_code == 504: - raise GatewayTimeoutError( - headers=dict(_response.headers), - body=typing.cast( - typing.Optional[typing.Any], - construct_type( - type_=typing.Optional[typing.Any], # type: ignore - object_=_response.json(), - ), - ), - ) + error_cls = _ERROR_MAP.get(status_code) + if error_cls is not None: + response_json = _response.json() + raise _build_error(error_cls, response_json, _response.headers) _response_json = _response.json() except JSONDecodeError: - raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) - raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) + raise ApiError(status_code=status_code, headers=dict(_response.headers), body=_response.text) + raise ApiError(status_code=status_code, headers=dict(_response.headers), body=_response_json) class AsyncRawEmbedJobsClient: @@ -1478,3 +1366,16 @@ async def cancel( except JSONDecodeError: raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text) raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json) + + +def _build_error(error_cls, response_json, response_headers): + return error_cls( + headers=dict(response_headers), + body=typing.cast( + typing.Optional[typing.Any], + construct_type( + type_=typing.Optional[typing.Any], # type: ignore + object_=response_json, + ), + ), + )