Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,11 +809,61 @@ def _idempotency_key(self) -> str:
return f"stainless-python-retry-{uuid.uuid4()}"


def _sanitize_no_proxy(value: str) -> str:
if "\n" not in value and "\r" not in value:
return value

return ",".join(part.strip() for part in value.replace("\r", ",").replace("\n", ",").split(",") if part.strip())


def _get_sanitized_environment_proxies() -> dict[str, str | None]:
from urllib.request import getproxies

from httpx._utils import is_ipv4_hostname, is_ipv6_hostname # pyright: ignore[reportPrivateImportUsage]

proxy_info = getproxies()
mounts: dict[str, str | None] = {}

for scheme in ("http", "https", "all"):
if proxy_info.get(scheme):
hostname = proxy_info[scheme]
mounts[f"{scheme}://"] = hostname if "://" in hostname else f"http://{hostname}"

no_proxy_hosts = [host.strip() for host in _sanitize_no_proxy(proxy_info.get("no", "")).split(",")]
for hostname in no_proxy_hosts:
if hostname == "*":
return {}
elif hostname:
if "://" in hostname:
mounts[hostname] = None
elif is_ipv4_hostname(hostname):
mounts[f"all://{hostname}"] = None
elif is_ipv6_hostname(hostname):
mounts[f"all://[{hostname}]"] = None
elif hostname.lower() == "localhost":
mounts[f"all://{hostname}"] = None
else:
mounts[f"all://*{hostname}"] = None

return mounts


class _DefaultHttpxClient(httpx.Client):
@override
def _get_proxy_map(self, proxy: Any | None, allow_env_proxies: bool) -> dict[str, httpx.Proxy | None]:
if proxy is None and allow_env_proxies:
return {
key: None if url is None else httpx.Proxy(url=url)
for key, url in _get_sanitized_environment_proxies().items()
}

return super()._get_proxy_map(proxy, allow_env_proxies)

def __init__(self, **kwargs: Any) -> None:
kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
kwargs.setdefault("follow_redirects", True)

super().__init__(**kwargs)


Expand Down Expand Up @@ -1384,10 +1434,21 @@ def get_api_list(


class _DefaultAsyncHttpxClient(httpx.AsyncClient):
@override
def _get_proxy_map(self, proxy: Any | None, allow_env_proxies: bool) -> dict[str, httpx.Proxy | None]:
if proxy is None and allow_env_proxies:
return {
key: None if url is None else httpx.Proxy(url=url)
for key, url in _get_sanitized_environment_proxies().items()
}

return super()._get_proxy_map(proxy, allow_env_proxies)

def __init__(self, **kwargs: Any) -> None:
kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
kwargs.setdefault("follow_redirects", True)

super().__init__(**kwargs)


Expand Down
18 changes: 18 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,15 @@ def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> N
assert len(mounts) == 1
assert mounts[0][0].pattern == "https://"

def test_no_proxy_environment_variable_with_newlines(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("NO_PROXY", "localhost\n192.168.1.1")

client = DefaultHttpxClient()

patterns = {mount.pattern for mount in client._mounts}
assert patterns == {"all://localhost", "all://192.168.1.1"}
assert os.environ["NO_PROXY"] == "localhost\n192.168.1.1"

@pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
def test_default_client_creation(self) -> None:
# Ensure that the client can be initialized without any exceptions
Expand Down Expand Up @@ -2086,6 +2095,15 @@ async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch
assert len(mounts) == 1
assert mounts[0][0].pattern == "https://"

async def test_no_proxy_environment_variable_with_newlines(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("NO_PROXY", "localhost\n192.168.1.1")

client = DefaultAsyncHttpxClient()

patterns = {mount.pattern for mount in client._mounts}
assert patterns == {"all://localhost", "all://192.168.1.1"}
assert os.environ["NO_PROXY"] == "localhost\n192.168.1.1"

@pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
async def test_default_client_creation(self) -> None:
# Ensure that the client can be initialized without any exceptions
Expand Down