Skip to content

Commit

Permalink
fix: Adding missing component decorator to AzureOpenAIGenerator (#7698
Browse files Browse the repository at this point in the history
)

* initial import

* adding release notes

* tests avoiding I/O operations

* Update fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml
  • Loading branch information
davidsbatista committed May 15, 2024
1 parent cc1d4b1 commit 96b9d3e
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 2 deletions.
3 changes: 2 additions & 1 deletion haystack/components/generators/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
# pylint: disable=import-error
from openai.lib.azure import AzureOpenAI

from haystack import default_from_dict, default_to_dict, logging
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators import OpenAIGenerator
from haystack.dataclasses import StreamingChunk
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable

logger = logging.getLogger(__name__)


@component
class AzureOpenAIGenerator(OpenAIGenerator):
"""
A Generator component that uses OpenAI's large language models (LLMs) on Azure to generate text.
Expand Down
3 changes: 2 additions & 1 deletion haystack/components/generators/chat/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
# pylint: disable=import-error
from openai.lib.azure import AzureOpenAI

from haystack import default_from_dict, default_to_dict, logging
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import StreamingChunk
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable

logger = logging.getLogger(__name__)


@component
class AzureOpenAIChatGenerator(OpenAIChatGenerator):
"""
A Chat Generator component that uses the Azure OpenAI API to generate text.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
Azure generators components fixed, they were missing the `@component` decorator.
10 changes: 10 additions & 0 deletions test/components/generators/chat/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from openai import OpenAIError

from haystack import Pipeline
from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage
Expand Down Expand Up @@ -80,6 +81,15 @@ def test_to_dict_with_parameters(self, monkeypatch):
},
}

def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
p = Pipeline()
p.add_component(instance=generator, name="generator")
p_str = p.dumps()
q = Pipeline.loads(p_str)
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed."

@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
Expand Down
11 changes: 11 additions & 0 deletions test/components/generators/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0
import os

from haystack import Pipeline
from haystack.utils.auth import Secret

import pytest
Expand Down Expand Up @@ -83,6 +85,15 @@ def test_to_dict_with_parameters(self, monkeypatch):
},
}

def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
generator = AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")
p = Pipeline()
p.add_component(instance=generator, name="generator")
p_str = p.dumps()
q = Pipeline.loads(p_str)
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization with AzureOpenAIGenerator failed."

@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
Expand Down

0 comments on commit 96b9d3e

Please sign in to comment.