Skip to content

Commit

Permalink
feat: widen support of env vars in OpenAI components (#7653)
Browse files Browse the repository at this point in the history
* add environment variables to the _environment.py file

* add support for two of the three variables

* Add support for 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES' on OpenAIDocumentEmbedder.

* Replicate support for env vars in OpenAITextEmbedder.

* Add support for env vars in OpenAIGenerator.

* Add support for env vars in OpenAIChatGenerator.

* add docstrings and reno

* add params to __init__ in OpenAIDocumentEmbedder

* add params to __init__ in OpenAITextEmbedder

* make fully functional implementation of env vars and unit tests

* update reno

* Update haystack/components/embedders/openai_text_embedder.py

* reverse changes to telemetry/_environment.py

* Update haystack/components/embedders/openai_text_embedder.py

---------

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
  • Loading branch information
CarlosFerLo and masci committed May 15, 2024
1 parent af53e84 commit 686a499
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 10 deletions.
23 changes: 22 additions & 1 deletion haystack/components/embedders/openai_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any, Dict, List, Optional, Tuple

from openai import OpenAI
Expand Down Expand Up @@ -45,10 +46,15 @@ def __init__(
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
):
"""
Create an OpenAIDocumentEmbedder component.
By setting the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES' environment variables, you can change the timeout and max_retries parameters of the OpenAI client.
:param api_key:
The OpenAI API key.
:param model:
Expand All @@ -73,6 +79,10 @@ def __init__(
List of meta fields that will be embedded along with the Document text.
:param embedding_separator:
Separator used to concatenate the meta fields to the Document text.
:param timeout:
Timeout for OpenAI client calls. If not set, it is inferred from the `OPENAI_TIMEOUT` environment variable or set to 30.
:param max_retries:
Maximum number of retries to establish contact with OpenAI if it returns an internal error. If not set, it is inferred from the `OPENAI_MAX_RETRIES` environment variable or set to 5.
"""
self.api_key = api_key
self.model = model
Expand All @@ -86,7 +96,18 @@ def __init__(
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator

self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
if timeout is None:
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))

self.client = OpenAI(
api_key=api_key.resolve_value(),
organization=organization,
base_url=api_base_url,
timeout=timeout,
max_retries=max_retries,
)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand Down
25 changes: 24 additions & 1 deletion haystack/components/embedders/openai_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any, Dict, List, Optional

from openai import OpenAI

from haystack import component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace

# NOTE(review): these two module-level values are evaluated once, at import time,
# and are not referenced in any of the visible hunks; `__init__` re-reads both
# environment variables itself when `timeout` / `max_retries` are None.
# Presumably leftovers from an earlier revision — confirm nothing imports them
# before removing. A non-numeric value in either variable makes float()/int()
# raise ValueError while this module is being imported.
OPENAI_TIMEOUT = float(os.environ.get("OPENAI_TIMEOUT", 30))
OPENAI_MAX_RETRIES = int(os.environ.get("OPENAI_MAX_RETRIES", 5))


@component
class OpenAITextEmbedder:
Expand Down Expand Up @@ -40,10 +44,14 @@ def __init__(
organization: Optional[str] = None,
prefix: str = "",
suffix: str = "",
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
):
"""
Create an OpenAITextEmbedder component.
By setting the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES' environment variables, you can change the timeout and max_retries parameters of the OpenAI client.
:param api_key:
The OpenAI API key.
:param model:
Expand All @@ -60,6 +68,10 @@ def __init__(
A string to add at the beginning of each text.
:param suffix:
A string to add at the end of each text.
:param timeout:
Timeout for OpenAI client calls. If not set, it is inferred from the `OPENAI_TIMEOUT` environment variable or set to 30.
:param max_retries:
Maximum number of retries to establish contact with OpenAI if it returns an internal error. If not set, it is inferred from the `OPENAI_MAX_RETRIES` environment variable or set to 5.
"""
self.model = model
self.dimensions = dimensions
Expand All @@ -69,7 +81,18 @@ def __init__(
self.suffix = suffix
self.api_key = api_key

self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
if timeout is None:
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))

self.client = OpenAI(
api_key=api_key.resolve_value(),
organization=organization,
base_url=api_base_url,
timeout=timeout,
max_retries=max_retries,
)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand Down
23 changes: 22 additions & 1 deletion haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import copy
import json
import os
from typing import Any, Callable, Dict, List, Optional, Union

from openai import OpenAI, Stream
Expand Down Expand Up @@ -75,13 +76,17 @@ def __init__(
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
):
"""
Initializes the OpenAIChatGenerator component.
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
GPT-3.5 model.
By setting the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES' environment variables, you can change the timeout and max_retries parameters of the OpenAI client.
:param api_key: The OpenAI API key.
:param model: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
Expand All @@ -108,14 +113,30 @@ def __init__(
Bigger values mean the model will be less likely to repeat the same token in the text.
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
:param timeout:
Timeout for OpenAI client calls. If not set, it is inferred from the `OPENAI_TIMEOUT` environment variable or set to 30.
:param max_retries:
Maximum number of retries to establish contact with OpenAI if it returns an internal error. If not set, it is inferred from the `OPENAI_MAX_RETRIES` environment variable or set to 5.
"""
self.api_key = api_key
self.model = model
self.generation_kwargs = generation_kwargs or {}
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
self.organization = organization
self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)

if timeout is None:
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))

self.client = OpenAI(
api_key=api_key.resolve_value(),
organization=organization,
base_url=api_base_url,
timeout=timeout,
max_retries=max_retries,
)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand Down
25 changes: 24 additions & 1 deletion haystack/components/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any, Callable, Dict, List, Optional, Union

from openai import OpenAI, Stream
Expand Down Expand Up @@ -60,10 +61,15 @@ def __init__(
organization: Optional[str] = None,
system_prompt: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
):
"""
Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model`, this is for OpenAI's GPT-3.5 model.
By setting the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES' environment variables, you can change the timeout and max_retries parameters of the OpenAI client.
:param api_key: The OpenAI API key.
:param model: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
Expand Down Expand Up @@ -92,6 +98,11 @@ def __init__(
Bigger values mean the model will be less likely to repeat the same token in the text.
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
:param timeout:
Timeout for OpenAI client calls. If not set, it is inferred from the `OPENAI_TIMEOUT` environment variable or set to 30.
:param max_retries:
Maximum number of retries to establish contact with OpenAI if it returns an internal error. If not set, it is inferred from the `OPENAI_MAX_RETRIES` environment variable or set to 5.
"""
self.api_key = api_key
self.model = model
Expand All @@ -101,7 +112,19 @@ def __init__(

self.api_base_url = api_base_url
self.organization = organization
self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)

if timeout is None:
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))

self.client = OpenAI(
api_key=api_key.resolve_value(),
organization=organization,
base_url=api_base_url,
timeout=timeout,
max_retries=max_retries,
)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand Down
1 change: 0 additions & 1 deletion haystack/telemetry/_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

logger = logging.getLogger(__name__)


# This value cannot change during the lifetime of the process
_IS_DOCKER_CACHE = None

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
highlights: >
Add support for the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES' environment variables to the OpenAI components.
enhancements:
- |
Now you can set the timeout and max_retries parameters on OpenAI components by setting the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES' environment vars or passing them at __init__.
issues:
- |
7610
40 changes: 38 additions & 2 deletions test/components/embedders/test_openai_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class TestOpenAIDocumentEmbedder:
def test_init_default(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
embedder = OpenAIDocumentEmbedder()
assert embedder.api_key.resolve_value() == "fake-api-key"
assert embedder.model == "text-embedding-ada-002"
assert embedder.organization is None
assert embedder.prefix == ""
Expand All @@ -37,10 +38,14 @@ def test_init_default(self, monkeypatch):
assert embedder.progress_bar is True
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"
assert embedder.client.max_retries == 5
assert embedder.client.timeout == 30.0

def test_init_with_parameters(self):
def test_init_with_parameters(self, monkeypatch):
monkeypatch.setenv("OPENAI_TIMEOUT", "100")
monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
embedder = OpenAIDocumentEmbedder(
api_key=Secret.from_token("fake-api-key"),
api_key=Secret.from_token("fake-api-key-2"),
model="model",
organization="my-org",
prefix="prefix",
Expand All @@ -49,7 +54,10 @@ def test_init_with_parameters(self):
progress_bar=False,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
timeout=40.0,
max_retries=1,
)
assert embedder.api_key.resolve_value() == "fake-api-key-2"
assert embedder.organization == "my-org"
assert embedder.model == "model"
assert embedder.prefix == "prefix"
Expand All @@ -58,6 +66,34 @@ def test_init_with_parameters(self):
assert embedder.progress_bar is False
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "
assert embedder.client.max_retries == 1
assert embedder.client.timeout == 40.0

# Checks the environment-variable fallback: OPENAI_TIMEOUT and OPENAI_MAX_RETRIES
# are set through monkeypatch, and the embedder is built WITHOUT explicit
# `timeout` / `max_retries` arguments, so the client is expected to pick up the
# values from the environment (timeout 100.0, max_retries 10).
def test_init_with_parameters_and_env_vars(self, monkeypatch):
monkeypatch.setenv("OPENAI_TIMEOUT", "100")
monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
embedder = OpenAIDocumentEmbedder(
api_key=Secret.from_token("fake-api-key-2"),
model="model",
organization="my-org",
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
)
assert embedder.api_key.resolve_value() == "fake-api-key-2"
assert embedder.organization == "my-org"
assert embedder.model == "model"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "
# The env vars are strings; the asserted values are the parsed int / float.
assert embedder.client.max_retries == 10
assert embedder.client.timeout == 100.0

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
Expand Down
30 changes: 29 additions & 1 deletion test/components/embedders/test_openai_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,50 @@ def test_init_default(self, monkeypatch):
assert embedder.organization is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.client.timeout == 30
assert embedder.client.max_retries == 5

def test_init_with_parameters(self):
def test_init_with_parameters(self, monkeypatch):
monkeypatch.setenv("OPENAI_TIMEOUT", "100")
monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
embedder = OpenAITextEmbedder(
api_key=Secret.from_token("fake-api-key"),
model="model",
api_base_url="https://my-custom-base-url.com",
organization="fake-organization",
prefix="prefix",
suffix="suffix",
timeout=40.0,
max_retries=1,
)
assert embedder.client.api_key == "fake-api-key"
assert embedder.model == "model"
assert embedder.api_base_url == "https://my-custom-base-url.com"
assert embedder.organization == "fake-organization"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.client.timeout == 40.0
assert embedder.client.max_retries == 1

# Checks the environment-variable fallback: OPENAI_TIMEOUT and OPENAI_MAX_RETRIES
# are set through monkeypatch, and the embedder is built WITHOUT explicit
# `timeout` / `max_retries` arguments, so the client is expected to pick up the
# values from the environment (timeout 100.0, max_retries 10).
def test_init_with_parameters_and_env_vars(self, monkeypatch):
monkeypatch.setenv("OPENAI_TIMEOUT", "100")
monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
embedder = OpenAITextEmbedder(
api_key=Secret.from_token("fake-api-key"),
model="model",
api_base_url="https://my-custom-base-url.com",
organization="fake-organization",
prefix="prefix",
suffix="suffix",
)
assert embedder.client.api_key == "fake-api-key"
assert embedder.model == "model"
assert embedder.api_base_url == "https://my-custom-base-url.com"
assert embedder.organization == "fake-organization"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
# The env vars are strings; the asserted values are the parsed float / int.
assert embedder.client.timeout == 100.0
assert embedder.client.max_retries == 10

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
Expand Down

0 comments on commit 686a499

Please sign in to comment.