Skip to content

Commit

Permalink
add timeout to AzureOpenAIGenerator (#7724)
Browse files Browse the repository at this point in the history
* add timeout to AzureOpenAIGenerator

* add to chat also

* Update azure-openai-generator-timeout-c39ecd6d4b0cdb4b.yaml
  • Loading branch information
masci committed May 23, 2024
1 parent 83d3970 commit e3dccf4
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 0 deletions.
5 changes: 5 additions & 0 deletions haystack/components/generators/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
organization: Optional[str] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
system_prompt: Optional[str] = None,
timeout: Optional[float] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Expand All @@ -77,6 +78,7 @@ def __init__(
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param system_prompt: The prompt to use for the system. If not provided, the system prompt will be
:param timeout: The timeout to be passed to the underlying `AzureOpenAI` client.
:param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for
more details.
Expand Down Expand Up @@ -123,6 +125,7 @@ def __init__(
self.azure_deployment = azure_deployment
self.organization = organization
self.model: str = azure_deployment or "gpt-35-turbo"
self.timeout = timeout

self.client = AzureOpenAI(
api_version=api_version,
Expand All @@ -131,6 +134,7 @@ def __init__(
api_key=api_key.resolve_value() if api_key is not None else None,
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
organization=organization,
timeout=timeout,
)

def to_dict(self) -> Dict[str, Any]:
Expand All @@ -152,6 +156,7 @@ def to_dict(self) -> Dict[str, Any]:
system_prompt=self.system_prompt,
api_key=self.api_key.to_dict() if self.api_key is not None else None,
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
timeout=self.timeout,
)

@classmethod
Expand Down
3 changes: 3 additions & 0 deletions haystack/components/generators/chat/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
organization: Optional[str] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
timeout: Optional[float] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Expand Down Expand Up @@ -139,6 +140,7 @@ def __init__(
self.azure_deployment = azure_deployment
self.organization = organization
self.model = azure_deployment or "gpt-35-turbo"
self.timeout = timeout

self.client = AzureOpenAI(
api_version=api_version,
Expand All @@ -165,6 +167,7 @@ def to_dict(self) -> Dict[str, Any]:
api_version=self.api_version,
streaming_callback=callback_name,
generation_kwargs=self.generation_kwargs,
timeout=self.timeout,
api_key=self.api_key.to_dict() if self.api_key is not None else None,
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
`AzureOpenAIGenerator` and `AzureOpenAIChatGenerator` can now be configured passing a timeout for the underlying `AzureOpenAI` client.
3 changes: 3 additions & 0 deletions test/components/generators/chat/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def test_to_dict_default(self, monkeypatch):
"organization": None,
"streaming_callback": None,
"generation_kwargs": {},
"timeout": None,
},
}

Expand All @@ -64,6 +65,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
api_key=Secret.from_env_var("ENV_VAR", strict=False),
azure_ad_token=Secret.from_env_var("ENV_VAR1", strict=False),
azure_endpoint="some-non-existing-endpoint",
timeout=2.5,
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
data = component.to_dict()
Expand All @@ -77,6 +79,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
"azure_deployment": "gpt-35-turbo",
"organization": None,
"streaming_callback": None,
"timeout": 2.5,
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
Expand Down
4 changes: 4 additions & 0 deletions test/components/generators/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_init_with_parameters(self):
assert component.client.api_key == "fake-api-key"
assert component.azure_deployment == "gpt-35-turbo"
assert component.streaming_callback is print_streaming_chunk
assert component.timeout is None
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

def test_to_dict_default(self, monkeypatch):
Expand All @@ -56,6 +57,7 @@ def test_to_dict_default(self, monkeypatch):
"azure_endpoint": "some-non-existing-endpoint",
"organization": None,
"system_prompt": None,
"timeout": None,
"generation_kwargs": {},
},
}
Expand All @@ -66,6 +68,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
api_key=Secret.from_env_var("ENV_VAR", strict=False),
azure_ad_token=Secret.from_env_var("ENV_VAR1", strict=False),
azure_endpoint="some-non-existing-endpoint",
timeout=3.5,
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)

Expand All @@ -81,6 +84,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
"azure_endpoint": "some-non-existing-endpoint",
"organization": None,
"system_prompt": None,
"timeout": 3.5,
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
Expand Down

0 comments on commit e3dccf4

Please sign in to comment.