diff --git a/src/openai/_base_client.py b/src/openai/_base_client.py index 7a8595c17..502ed7c7a 100644 --- a/src/openai/_base_client.py +++ b/src/openai/_base_client.py @@ -361,6 +361,11 @@ def __init__( self._strict_response_validation = _strict_response_validation self._idempotency_header = None + if max_retries is None: # pyright: ignore[reportUnnecessaryComparison] + raise TypeError( + "max_retries cannot be None. If you want to disable retries, pass `0`; if you want unlimited retries, pass `math.inf` or a very high number; if you want the default behavior, pass `openai.DEFAULT_MAX_RETRIES`" + ) + def _enforce_trailing_slash(self, url: URL) -> URL: if url.raw_path.endswith(b"/"): return url diff --git a/tests/test_client.py b/tests/test_client.py index dab1cb0ef..ba85fd9d5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -646,6 +646,10 @@ class Model(BaseModel): assert isinstance(exc.value.__cause__, ValidationError) + def test_client_max_retries_validation(self) -> None: + with pytest.raises(TypeError, match=r"max_retries cannot be None"): + OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)) + @pytest.mark.respx(base_url=base_url) def test_default_stream_cls(self, respx_mock: MockRouter) -> None: class Model(BaseModel): @@ -1368,6 +1372,12 @@ class Model(BaseModel): assert isinstance(exc.value.__cause__, ValidationError) + async def test_client_max_retries_validation(self) -> None: + with pytest.raises(TypeError, match=r"max_retries cannot be None"): + AsyncOpenAI( + base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None) + ) + @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio async def test_default_stream_cls(self, respx_mock: MockRouter) -> None: