diff --git a/src/openai/_base_client.py b/src/openai/_base_client.py index 17863bc067..4d8d94f661 100644 --- a/src/openai/_base_client.py +++ b/src/openai/_base_client.py @@ -809,11 +809,61 @@ def _idempotency_key(self) -> str: return f"stainless-python-retry-{uuid.uuid4()}" +def _sanitize_no_proxy(value: str) -> str: + if "\n" not in value and "\r" not in value: + return value + + return ",".join(part.strip() for part in value.replace("\r", ",").replace("\n", ",").split(",") if part.strip()) + + +def _get_sanitized_environment_proxies() -> dict[str, str | None]: + from urllib.request import getproxies + + from httpx._utils import is_ipv4_hostname, is_ipv6_hostname # pyright: ignore[reportPrivateImportUsage] + + proxy_info = getproxies() + mounts: dict[str, str | None] = {} + + for scheme in ("http", "https", "all"): + if proxy_info.get(scheme): + hostname = proxy_info[scheme] + mounts[f"{scheme}://"] = hostname if "://" in hostname else f"http://{hostname}" + + no_proxy_hosts = [host.strip() for host in _sanitize_no_proxy(proxy_info.get("no", "")).split(",")] + for hostname in no_proxy_hosts: + if hostname == "*": + return {} + elif hostname: + if "://" in hostname: + mounts[hostname] = None + elif is_ipv4_hostname(hostname): + mounts[f"all://{hostname}"] = None + elif is_ipv6_hostname(hostname): + mounts[f"all://[{hostname}]"] = None + elif hostname.lower() == "localhost": + mounts[f"all://{hostname}"] = None + else: + mounts[f"all://*{hostname}"] = None + + return mounts + + class _DefaultHttpxClient(httpx.Client): + @override + def _get_proxy_map(self, proxy: Any | None, allow_env_proxies: bool) -> dict[str, httpx.Proxy | None]: + if proxy is None and allow_env_proxies: + return { + key: None if url is None else httpx.Proxy(url=url) + for key, url in _get_sanitized_environment_proxies().items() + } + + return super()._get_proxy_map(proxy, allow_env_proxies) + def __init__(self, **kwargs: Any) -> None: kwargs.setdefault("timeout", DEFAULT_TIMEOUT) kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS) kwargs.setdefault("follow_redirects", True) + super().__init__(**kwargs) @@ -1384,10 +1434,21 @@ def get_api_list( class _DefaultAsyncHttpxClient(httpx.AsyncClient): + @override + def _get_proxy_map(self, proxy: Any | None, allow_env_proxies: bool) -> dict[str, httpx.Proxy | None]: + if proxy is None and allow_env_proxies: + return { + key: None if url is None else httpx.Proxy(url=url) + for key, url in _get_sanitized_environment_proxies().items() + } + + return super()._get_proxy_map(proxy, allow_env_proxies) + def __init__(self, **kwargs: Any) -> None: kwargs.setdefault("timeout", DEFAULT_TIMEOUT) kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS) kwargs.setdefault("follow_redirects", True) + super().__init__(**kwargs) diff --git a/tests/test_client.py b/tests/test_client.py index 396f6dea99..df4a3a1e05 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1043,6 +1043,15 @@ def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> N assert len(mounts) == 1 assert mounts[0][0].pattern == "https://" + def test_no_proxy_environment_variable_with_newlines(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NO_PROXY", "localhost\n192.168.1.1") + + client = DefaultHttpxClient() + + patterns = {mount.pattern for mount in client._mounts} + assert patterns == {"all://localhost", "all://192.168.1.1"} + assert os.environ["NO_PROXY"] == "localhost\n192.168.1.1" + @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning") def test_default_client_creation(self) -> None: # Ensure that the client can be initialized without any exceptions @@ -2086,6 +2095,15 @@ async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch assert len(mounts) == 1 assert mounts[0][0].pattern == "https://" + async def test_no_proxy_environment_variable_with_newlines(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("NO_PROXY", "localhost\n192.168.1.1") + + client = DefaultAsyncHttpxClient() + + patterns = {mount.pattern for mount in client._mounts} + assert patterns == {"all://localhost", "all://192.168.1.1"} + assert os.environ["NO_PROXY"] == "localhost\n192.168.1.1" + @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning") async def test_default_client_creation(self) -> None: # Ensure that the client can be initialized without any exceptions