Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 36 additions & 135 deletions src/cohere/embed_jobs/raw_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,21 @@
from ..types.list_embed_job_response import ListEmbedJobResponse
from .types.create_embed_job_request_truncate import CreateEmbedJobRequestTruncate

_ERROR_MAP = {
400: BadRequestError,
401: UnauthorizedError,
403: ForbiddenError,
404: NotFoundError,
422: UnprocessableEntityError,
429: TooManyRequestsError,
498: InvalidTokenError,
499: ClientClosedRequestError,
500: InternalServerError,
501: NotImplementedError,
503: ServiceUnavailableError,
504: GatewayTimeoutError,
}

# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)

Expand Down Expand Up @@ -612,145 +627,18 @@ def cancel(self, id: str, *, request_options: typing.Optional[RequestOptions] =
method="POST",
request_options=request_options,
)
status_code = _response.status_code
try:
if 200 <= _response.status_code < 300:
if 200 <= status_code < 300:
return HttpResponse(response=_response, data=None)
if _response.status_code == 400:
raise BadRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 401:
raise UnauthorizedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 403:
raise ForbiddenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 404:
raise NotFoundError(
headers=dict(_response.headers),
body=typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 422:
raise UnprocessableEntityError(
headers=dict(_response.headers),
body=typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 429:
raise TooManyRequestsError(
headers=dict(_response.headers),
body=typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 498:
raise InvalidTokenError(
headers=dict(_response.headers),
body=typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 499:
raise ClientClosedRequestError(
headers=dict(_response.headers),
body=typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 500:
raise InternalServerError(
headers=dict(_response.headers),
body=typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 501:
raise NotImplementedError(
headers=dict(_response.headers),
body=typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 503:
raise ServiceUnavailableError(
headers=dict(_response.headers),
body=typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
),
)
if _response.status_code == 504:
raise GatewayTimeoutError(
headers=dict(_response.headers),
body=typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
),
)
error_cls = _ERROR_MAP.get(status_code)
if error_cls is not None:
response_json = _response.json()
raise _build_error(error_cls, response_json, _response.headers)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)
raise ApiError(status_code=status_code, headers=dict(_response.headers), body=_response.text)
raise ApiError(status_code=status_code, headers=dict(_response.headers), body=_response_json)


class AsyncRawEmbedJobsClient:
Expand Down Expand Up @@ -1478,3 +1366,16 @@ async def cancel(
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response.text)
raise ApiError(status_code=_response.status_code, headers=dict(_response.headers), body=_response_json)


def _build_error(error_cls, response_json, response_headers):
return error_cls(
headers=dict(response_headers),
body=typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=response_json,
),
),
)