diff --git a/src/agents/retry.py b/src/agents/retry.py index f240a2d923..4b122cbfa7 100644 --- a/src/agents/retry.py +++ b/src/agents/retry.py @@ -16,13 +16,13 @@ class ModelRetryBackoffSettings: """Backoff configuration for runner-managed model retries.""" - initial_delay: float | None = None + initial_delay: float | None = Field(default=None, ge=0) """Delay in seconds before the first retry attempt.""" - max_delay: float | None = None + max_delay: float | None = Field(default=None, ge=0) """Maximum delay in seconds between retry attempts.""" - multiplier: float | None = None + multiplier: float | None = Field(default=None, ge=0) """Multiplier applied after each retry attempt.""" jitter: bool | None = None diff --git a/tests/models/test_model_retry.py b/tests/models/test_model_retry.py index 5a99efd282..04b9b1dc9c 100644 --- a/tests/models/test_model_retry.py +++ b/tests/models/test_model_retry.py @@ -7,6 +7,7 @@ import httpx import pytest from openai import APIConnectionError, APIStatusError, BadRequestError +from pydantic import ValidationError from agents.items import ModelResponse, TResponseStreamEvent from agents.models._openai_retry import get_openai_retry_advice @@ -28,6 +29,32 @@ from tests.test_responses import get_text_message +@pytest.mark.parametrize( + "make_backoff", + [ + lambda: ModelRetryBackoffSettings(initial_delay=-0.1), + lambda: ModelRetryBackoffSettings(max_delay=-0.1), + lambda: ModelRetryBackoffSettings(multiplier=-0.1), + ], +) +def test_model_retry_backoff_settings_reject_negative_values(make_backoff: Any) -> None: + with pytest.raises(ValidationError, match="greater than or equal to 0"): + make_backoff() + + +def test_model_retry_settings_rejects_negative_backoff_dict() -> None: + with pytest.raises(ValidationError, match="greater than or equal to 0"): + ModelRetrySettings(backoff={"initial_delay": -0.1}) + + +def test_model_retry_backoff_settings_allow_zero_values() -> None: + backoff = ModelRetryBackoffSettings(initial_delay=0, max_delay=0, multiplier=0) + + assert backoff.initial_delay == 0 + assert backoff.max_delay == 0 + assert backoff.multiplier == 0 + + def _connection_error(message: str = "connection error") -> APIConnectionError: return APIConnectionError( message=message,