Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 191 additions & 66 deletions google/genai/_api_client.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions google/genai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ def _nextgen_client(self) -> AsyncGeminiNextGenAPIClient:
stacklevel=5,
)

http_client: httpx.AsyncClient = self._api_client._async_httpx_client
http_client: Optional[httpx.AsyncClient] = (
self._api_client._async_httpx_client
)

async_client_args = self._api_client._http_options.async_client_args or {}
has_custom_transport = 'transport' in async_client_args
Expand Down Expand Up @@ -308,7 +310,6 @@ class DebugConfig(pydantic.BaseModel):
)



class Client:
"""Client for making synchronous requests.

Expand Down
59 changes: 45 additions & 14 deletions google/genai/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,25 @@
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import httpx
import json
import websockets
import requests
from . import _common


if TYPE_CHECKING:
from .replay_api_client import ReplayResponse
import aiohttp
from google.auth.aio.transport.aiohttp import Response as AsyncAuthorizedSessionResponse


class APIError(Exception):
"""General errors raised by the GenAI API."""
code: int
response: Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse']
response: Union[
requests.Response,
'ReplayResponse',
httpx.Response,
'AsyncAuthorizedSessionResponse',
]

status: Optional[str] = None
message: Optional[str] = None
Expand All @@ -40,7 +46,12 @@ def __init__(
code: int,
response_json: Any,
response: Optional[
Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse']
Union[
requests.Response,
'ReplayResponse',
httpx.Response,
'AsyncAuthorizedSessionResponse',
]
] = None,
):
if isinstance(response_json, list) and len(response_json) == 1:
Expand Down Expand Up @@ -112,7 +123,7 @@ def _to_replay_record(self) -> _common.StringDict:

@classmethod
def raise_for_response(
cls, response: Union['ReplayResponse', httpx.Response]
cls, response: Union['ReplayResponse', httpx.Response, requests.Response]
) -> None:
"""Raises an error with detailed error message if the response has an error status."""
if response.status_code == 200:
Expand All @@ -128,6 +139,16 @@ def raise_for_response(
'message': message,
'status': response.reason_phrase,
}
elif isinstance(response, requests.Response):
try:
# do not do any extra muanipulation on the response.
# return the raw response json as is.
response_json = response.json()
except requests.exceptions.JSONDecodeError:
response_json = {
'message': response.text,
'status': response.reason,
}
else:
response_json = response.body_segments[0].get('error', {})

Expand All @@ -139,7 +160,11 @@ def raise_error(
status_code: int,
response_json: Any,
response: Optional[
Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse']
Union[
'ReplayResponse',
httpx.Response,
requests.Response,
]
],
) -> None:
"""Raises an appropriate APIError subclass based on the status code.
Expand All @@ -166,12 +191,13 @@ def raise_error(
async def raise_for_async_response(
cls,
response: Union[
'ReplayResponse', httpx.Response, 'aiohttp.ClientResponse'
'ReplayResponse',
httpx.Response,
'aiohttp.ClientResponse',
'AsyncAuthorizedSessionResponse',
],
) -> None:
"""Raises an error with detailed error message if the response has an error status."""
status_code = 0
response_json = None
if isinstance(response, httpx.Response):
if response.status_code == 200:
return
Expand All @@ -196,18 +222,23 @@ async def raise_for_async_response(
try:
import aiohttp # pylint: disable=g-import-not-at-top

if isinstance(response, aiohttp.ClientResponse):
if response.status == 200:
# Use a local variable to help Mypy handle the unwrapped response
unwrapped_response: Any = response
if hasattr(unwrapped_response, '_response'):
unwrapped_response = unwrapped_response._response

if isinstance(unwrapped_response, aiohttp.ClientResponse):
if unwrapped_response.status == 200:
return
try:
response_json = await response.json()
response_json = await unwrapped_response.json()
except aiohttp.client_exceptions.ContentTypeError:
message = await response.text()
message = await unwrapped_response.text()
response_json = {
'message': message,
'status': response.reason,
'status': unwrapped_response.reason,
}
status_code = response.status
status_code = unwrapped_response.status
else:
raise ValueError(f'Unsupported response type: {type(response)}')
except ImportError:
Expand Down
8 changes: 8 additions & 0 deletions google/genai/tests/client/test_client_close.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_close_httpx_client():
vertexai=True,
project='test_project',
location='global',
http_options=api_client.HttpOptions(client_args={'max_redirects': 10}),
)
client.close()
assert client._api_client._httpx_client.is_closed
Expand All @@ -55,6 +56,7 @@ def test_httpx_client_context_manager():
vertexai=True,
project='test_project',
location='global',
http_options=api_client.HttpOptions(client_args={'max_redirects': 10}),
) as client:
pass
assert not client._api_client._httpx_client.is_closed
Expand Down Expand Up @@ -135,6 +137,9 @@ async def run():
vertexai=True,
project='test_project',
location='global',
http_options=api_client.HttpOptions(
async_client_args={'trust_env': False}
),
).aio
# aiohttp session is created in the first request instead of client
# initialization.
Expand Down Expand Up @@ -176,6 +181,9 @@ async def run():
vertexai=True,
project='test_project',
location='global',
http_options=api_client.HttpOptions(
async_client_args={'trust_env': False}
),
).aio as async_client:
# aiohttp session is created in the first request instead of client
# initialization.
Expand Down
33 changes: 24 additions & 9 deletions google/genai/tests/client/test_client_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import concurrent.futures
import logging
import os
import requests
import ssl
from unittest import mock

Expand Down Expand Up @@ -1331,18 +1332,32 @@ def refresh_side_effect(request):
mock_refresh = mock.Mock(side_effect=refresh_side_effect)
mock_creds.refresh = mock_refresh

# Mock the actual request to avoid network calls
mock_httpx_response = httpx.Response(
status_code=200,
headers={},
text='{"candidates": [{"content": {"parts": [{"text": "response"}]}}]}',
)
mock_request = mock.Mock(return_value=mock_httpx_response)
monkeypatch.setattr(api_client.SyncHttpxClient, "request", mock_request)

client = Client(
vertexai=True, project="fake_project_id", location="fake-location"
)
# Mock the actual request to avoid network calls
if client._api_client._use_google_auth_sync():
# Cloud environment enables mTLS and uses requests.Response
mock_http_response = requests.Response()
mock_http_response.status_code = 200
mock_http_response.headers = {}
mock_http_response._content = (
b'{"candidates": [{"content": {"parts": [{"text": "response"}]}}]}'
)
mock_request = mock.Mock(return_value=mock_http_response)
monkeypatch.setattr(
google.auth.transport.requests.AuthorizedSession, "request", mock_request
)
else:
# Non-cloud environment w/o certificates uses httpx.Response
mock_httpx_response = httpx.Response(
status_code=200,
headers={},
text='{"candidates": [{"content": {"parts": [{"text": "response"}]}}]}',
)
mock_request = mock.Mock(return_value=mock_httpx_response)
monkeypatch.setattr(api_client.SyncHttpxClient, "send", mock_request)

# Reset credentials to test initialization to ensure the sync lock is tested.
client._api_client._credentials = None

Expand Down
Loading
Loading