diff --git a/haystack/components/generators/azure.py b/haystack/components/generators/azure.py index 388b9eb995..2c432d8231 100644 --- a/haystack/components/generators/azure.py +++ b/haystack/components/generators/azure.py @@ -8,7 +8,7 @@ # pylint: disable=import-error from openai.lib.azure import AzureOpenAI -from haystack import default_from_dict, default_to_dict, logging +from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.generators import OpenAIGenerator from haystack.dataclasses import StreamingChunk from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable @@ -16,6 +16,7 @@ logger = logging.getLogger(__name__) +@component class AzureOpenAIGenerator(OpenAIGenerator): """ A Generator component that uses OpenAI's large language models (LLMs) on Azure to generate text. diff --git a/haystack/components/generators/chat/azure.py b/haystack/components/generators/chat/azure.py index 189b0c949a..b6cd2e153b 100644 --- a/haystack/components/generators/chat/azure.py +++ b/haystack/components/generators/chat/azure.py @@ -8,7 +8,7 @@ # pylint: disable=import-error from openai.lib.azure import AzureOpenAI -from haystack import default_from_dict, default_to_dict, logging +from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import StreamingChunk from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable @@ -16,6 +16,7 @@ logger = logging.getLogger(__name__) +@component class AzureOpenAIChatGenerator(OpenAIChatGenerator): """ A Chat Generator component that uses the Azure OpenAI API to generate text. diff --git a/releasenotes/notes/fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml b/releasenotes/notes/fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml new file mode 100644 index 0000000000..071969e62e --- /dev/null +++ b/releasenotes/notes/fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Azure generators components fixed, they were missing the `@component` decorator. diff --git a/test/components/generators/chat/test_azure.py b/test/components/generators/chat/test_azure.py index 4b92fdee50..c9693caac8 100644 --- a/test/components/generators/chat/test_azure.py +++ b/test/components/generators/chat/test_azure.py @@ -6,6 +6,7 @@ import pytest from openai import OpenAIError +from haystack import Pipeline from haystack.components.generators.chat import AzureOpenAIChatGenerator from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage @@ -80,6 +81,15 @@ def test_to_dict_with_parameters(self, monkeypatch): }, } + def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") + generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") + p = Pipeline() + p.add_component(instance=generator, name="generator") + p_str = p.dumps() + q = Pipeline.loads(p_str) + assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed." + @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None), diff --git a/test/components/generators/test_azure.py b/test/components/generators/test_azure.py index b2373eb68f..d5d52b6f26 100644 --- a/test/components/generators/test_azure.py +++ b/test/components/generators/test_azure.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 import os + +from haystack import Pipeline from haystack.utils.auth import Secret import pytest @@ -83,6 +85,15 @@ def test_to_dict_with_parameters(self, monkeypatch): }, } + def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") + generator = AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint") + p = Pipeline() + p.add_component(instance=generator, name="generator") + p_str = p.dumps() + q = Pipeline.loads(p_str) + assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization with AzureOpenAIGenerator failed." + @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),