Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: widen support of env vars in OpenAI components #7653

Merged
merged 20 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
aa0e7ed
add enviroment variables to the _enviroment.py file
CarlosFerLo Apr 30, 2024
395e012
add support for two of the three variables
CarlosFerLo Apr 30, 2024
737a5b5
Add support for 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES' on OpenAIDo…
CarlosFerLo May 6, 2024
6503565
Replicate support for env vars in OpenAITextEmbedder.
CarlosFerLo May 6, 2024
704e38a
Add support for env vars in OpenAIGenerator..
CarlosFerLo May 6, 2024
8a754a1
Add support for env vars in OpenAIChatGenerator.
CarlosFerLo May 6, 2024
7d031d5
Merge branch 'deepset-ai:main' into issue-7610
CarlosFerLo May 6, 2024
6e55270
add docstrings and reno
CarlosFerLo May 7, 2024
3a321c7
Merge branch 'main' into issue-7610
CarlosFerLo May 8, 2024
06a3c3a
Merge branch 'main' into issue-7610
masci May 9, 2024
03234df
add params to __init__ in OpenAIDocumentEmbedder
CarlosFerLo May 10, 2024
a9f8248
add params to __init__ in OpenAITextEmbedder
CarlosFerLo May 10, 2024
f3a3a4b
make fully functional implementation of env vars and unit tests
CarlosFerLo May 12, 2024
21c3d5e
Merge branch 'main' into issue-7610
CarlosFerLo May 12, 2024
7327677
update reno
CarlosFerLo May 12, 2024
70d15ca
Pull from main.
CarlosFerLo May 12, 2024
3ff3c20
Update haystack/components/embedders/openai_text_embedder.py
masci May 15, 2024
10bc4bf
Merge branch 'main' into issue-7610
masci May 15, 2024
ce19f65
reverse changes to telemetry/_enviroment.py
CarlosFerLo May 15, 2024
411733c
Update haystack/components/embedders/openai_text_embedder.py
masci May 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 22 additions & 1 deletion haystack/components/embedders/openai_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any, Dict, List, Optional, Tuple

from openai import OpenAI
Expand Down Expand Up @@ -45,10 +46,15 @@ def __init__(
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
):
"""
Create an OpenAIDocumentEmbedder component.

By setting the `OPENAI_TIMEOUT` and `OPENAI_MAX_RETRIES` environment variables, you can change the `timeout` and `max_retries` parameters of the OpenAI client.


:param api_key:
The OpenAI API key.
:param model:
Expand All @@ -73,6 +79,10 @@ def __init__(
List of meta fields that will be embedded along with the Document text.
:param embedding_separator:
Separator used to concatenate the meta fields to the Document text.
:param timeout:
Timeout for OpenAI client calls. If not set, it is inferred from the `OPENAI_TIMEOUT` environment variable or set to 30.
:param max_retries:
Maximum number of retries to establish contact with OpenAI if it returns an internal error. If not set, it is inferred from the `OPENAI_MAX_RETRIES` environment variable or set to 5.
"""
self.api_key = api_key
self.model = model
Expand All @@ -86,7 +96,18 @@ def __init__(
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator

self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
if timeout is None:
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))

self.client = OpenAI(
api_key=api_key.resolve_value(),
organization=organization,
base_url=api_base_url,
timeout=timeout,
max_retries=max_retries,
)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand Down
25 changes: 24 additions & 1 deletion haystack/components/embedders/openai_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any, Dict, List, Optional

from openai import OpenAI

from haystack import component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace

# Default timeout for OpenAI client calls, overridable through the `OPENAI_TIMEOUT`
# environment variable. This must not read `OPENAI_API_KEY`: a real API key is not
# numeric, so `float()` would raise `ValueError` at import time.
OPENAI_TIMEOUT = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
masci marked this conversation as resolved.
Show resolved Hide resolved
OPENAI_MAX_RETRIES = int(os.environ.get("OPENAI_MAX_RETRIES", 5))


@component
class OpenAITextEmbedder:
Expand Down Expand Up @@ -40,10 +44,14 @@ def __init__(
organization: Optional[str] = None,
prefix: str = "",
suffix: str = "",
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
):
"""
Create an OpenAITextEmbedder component.

By setting the `OPENAI_TIMEOUT` and `OPENAI_MAX_RETRIES` environment variables, you can change the `timeout` and `max_retries` parameters of the OpenAI client.

:param api_key:
The OpenAI API key.
:param model:
Expand All @@ -60,6 +68,10 @@ def __init__(
A string to add at the beginning of each text.
:param suffix:
A string to add at the end of each text.
:param timeout:
Timeout for OpenAI client calls. If not set, it is inferred from the `OPENAI_TIMEOUT` environment variable or set to 30.
:param max_retries:
Maximum number of retries to establish contact with OpenAI if it returns an internal error. If not set, it is inferred from the `OPENAI_MAX_RETRIES` environment variable or set to 5.
"""
self.model = model
self.dimensions = dimensions
Expand All @@ -69,7 +81,18 @@ def __init__(
self.suffix = suffix
self.api_key = api_key

self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
if timeout is None:
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))

self.client = OpenAI(
api_key=api_key.resolve_value(),
organization=organization,
base_url=api_base_url,
timeout=timeout,
max_retries=max_retries,
)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand Down
23 changes: 22 additions & 1 deletion haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import copy
import json
import os
from typing import Any, Callable, Dict, List, Optional, Union

from openai import OpenAI, Stream
Expand Down Expand Up @@ -75,13 +76,17 @@ def __init__(
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
):
"""
Initializes the OpenAIChatGenerator component.

Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
GPT-3.5 model.

By setting the `OPENAI_TIMEOUT` and `OPENAI_MAX_RETRIES` environment variables, you can change the `timeout` and `max_retries` parameters of the OpenAI client.

:param api_key: The OpenAI API key.
:param model: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
Expand All @@ -108,14 +113,30 @@ def __init__(
Bigger values mean the model will be less likely to repeat the same token in the text.
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
:param timeout:
Timeout for OpenAI client calls. If not set, it is inferred from the `OPENAI_TIMEOUT` environment variable or set to 30.
:param max_retries:
Maximum number of retries to establish contact with OpenAI if it returns an internal error. If not set, it is inferred from the `OPENAI_MAX_RETRIES` environment variable or set to 5.
"""
self.api_key = api_key
self.model = model
self.generation_kwargs = generation_kwargs or {}
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
self.organization = organization
self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)

if timeout is None:
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))

self.client = OpenAI(
api_key=api_key.resolve_value(),
organization=organization,
base_url=api_base_url,
timeout=timeout,
max_retries=max_retries,
)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand Down
25 changes: 24 additions & 1 deletion haystack/components/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any, Callable, Dict, List, Optional, Union

from openai import OpenAI, Stream
Expand Down Expand Up @@ -60,10 +61,15 @@ def __init__(
organization: Optional[str] = None,
system_prompt: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
):
"""
Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model`, this is for OpenAI's GPT-3.5 model.

By setting the `OPENAI_TIMEOUT` and `OPENAI_MAX_RETRIES` environment variables, you can change the `timeout` and `max_retries` parameters of the OpenAI client.


:param api_key: The OpenAI API key.
:param model: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
Expand Down Expand Up @@ -92,6 +98,11 @@ def __init__(
Bigger values mean the model will be less likely to repeat the same token in the text.
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
:param timeout:
Timeout for OpenAI client calls. If not set, it is inferred from the `OPENAI_TIMEOUT` environment variable or set to 30.
:param max_retries:
Maximum number of retries to establish contact with OpenAI if it returns an internal error. If not set, it is inferred from the `OPENAI_MAX_RETRIES` environment variable or set to 5.

"""
self.api_key = api_key
self.model = model
Expand All @@ -101,7 +112,19 @@ def __init__(

self.api_base_url = api_base_url
self.organization = organization
self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)

if timeout is None:
timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))

self.client = OpenAI(
api_key=api_key.resolve_value(),
organization=organization,
base_url=api_base_url,
timeout=timeout,
max_retries=max_retries,
)

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand Down
1 change: 0 additions & 1 deletion haystack/telemetry/_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

logger = logging.getLogger(__name__)


# This value cannot change during the lifetime of the process
_IS_DOCKER_CACHE = None

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
highlights: >
Add support for the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES' environment variables to the OpenAI components.
enhancements:
- |
Now you can set the timeout and max_retries parameters on OpenAI components by setting the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES' environment variables or by passing them to `__init__`.
issues:
- |
7610
40 changes: 38 additions & 2 deletions test/components/embedders/test_openai_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class TestOpenAIDocumentEmbedder:
def test_init_default(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
embedder = OpenAIDocumentEmbedder()
assert embedder.api_key.resolve_value() == "fake-api-key"
assert embedder.model == "text-embedding-ada-002"
assert embedder.organization is None
assert embedder.prefix == ""
Expand All @@ -37,10 +38,14 @@ def test_init_default(self, monkeypatch):
assert embedder.progress_bar is True
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"
assert embedder.client.max_retries == 5
assert embedder.client.timeout == 30.0

def test_init_with_parameters(self):
def test_init_with_parameters(self, monkeypatch):
monkeypatch.setenv("OPENAI_TIMEOUT", "100")
monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")
embedder = OpenAIDocumentEmbedder(
api_key=Secret.from_token("fake-api-key"),
api_key=Secret.from_token("fake-api-key-2"),
model="model",
organization="my-org",
prefix="prefix",
Expand All @@ -49,7 +54,10 @@ def test_init_with_parameters(self):
progress_bar=False,
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
timeout=40.0,
max_retries=1,
)
assert embedder.api_key.resolve_value() == "fake-api-key-2"
assert embedder.organization == "my-org"
assert embedder.model == "model"
assert embedder.prefix == "prefix"
Expand All @@ -58,6 +66,34 @@ def test_init_with_parameters(self):
assert embedder.progress_bar is False
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "
assert embedder.client.max_retries == 1
assert embedder.client.timeout == 40.0

def test_init_with_parameters_and_env_vars(self, monkeypatch):
    """
    Check that `OPENAI_TIMEOUT` and `OPENAI_MAX_RETRIES` configure the client when
    `timeout` and `max_retries` are not passed to the constructor.
    """
    for env_name, env_value in (("OPENAI_TIMEOUT", "100"), ("OPENAI_MAX_RETRIES", "10")):
        monkeypatch.setenv(env_name, env_value)

    # Neither `timeout` nor `max_retries` is given, so the env vars must be used.
    init_kwargs = {
        "api_key": Secret.from_token("fake-api-key-2"),
        "model": "model",
        "organization": "my-org",
        "prefix": "prefix",
        "suffix": "suffix",
        "batch_size": 64,
        "progress_bar": False,
        "meta_fields_to_embed": ["test_field"],
        "embedding_separator": " | ",
    }
    embedder = OpenAIDocumentEmbedder(**init_kwargs)

    resolved_key = embedder.api_key.resolve_value()
    assert resolved_key == "fake-api-key-2"
    assert embedder.organization == "my-org"
    assert embedder.model == "model"
    assert embedder.prefix == "prefix"
    assert embedder.suffix == "suffix"
    assert embedder.batch_size == 64
    assert embedder.progress_bar is False
    assert embedder.meta_fields_to_embed == ["test_field"]
    assert embedder.embedding_separator == " | "

    # Values taken from the environment, converted from their string form.
    client = embedder.client
    assert client.max_retries == 10
    assert client.timeout == 100.0

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
Expand Down
30 changes: 29 additions & 1 deletion test/components/embedders/test_openai_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,50 @@ def test_init_default(self, monkeypatch):
assert embedder.organization is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.client.timeout == 30
assert embedder.client.max_retries == 5

def test_init_with_parameters(self):
def test_init_with_parameters(self, monkeypatch):
    """
    Check that explicit `timeout` and `max_retries` arguments are applied to the client
    even when `OPENAI_TIMEOUT` and `OPENAI_MAX_RETRIES` are set in the environment.
    """
    for env_name, env_value in (("OPENAI_TIMEOUT", "100"), ("OPENAI_MAX_RETRIES", "10")):
        monkeypatch.setenv(env_name, env_value)

    init_kwargs = {
        "api_key": Secret.from_token("fake-api-key"),
        "model": "model",
        "api_base_url": "https://my-custom-base-url.com",
        "organization": "fake-organization",
        "prefix": "prefix",
        "suffix": "suffix",
        "timeout": 40.0,
        "max_retries": 1,
    }
    embedder = OpenAITextEmbedder(**init_kwargs)

    client = embedder.client
    assert client.api_key == "fake-api-key"
    assert embedder.model == "model"
    assert embedder.api_base_url == "https://my-custom-base-url.com"
    assert embedder.organization == "fake-organization"
    assert embedder.prefix == "prefix"
    assert embedder.suffix == "suffix"

    # The constructor arguments win over the environment values (100 / 10).
    assert client.timeout == 40.0
    assert client.max_retries == 1

def test_init_with_parameters_and_env_vars(self, monkeypatch):
    """
    Check that `OPENAI_TIMEOUT` and `OPENAI_MAX_RETRIES` configure the client when
    `timeout` and `max_retries` are not passed to the constructor.
    """
    for env_name, env_value in (("OPENAI_TIMEOUT", "100"), ("OPENAI_MAX_RETRIES", "10")):
        monkeypatch.setenv(env_name, env_value)

    # Neither `timeout` nor `max_retries` is given, so the env vars must be used.
    init_kwargs = {
        "api_key": Secret.from_token("fake-api-key"),
        "model": "model",
        "api_base_url": "https://my-custom-base-url.com",
        "organization": "fake-organization",
        "prefix": "prefix",
        "suffix": "suffix",
    }
    embedder = OpenAITextEmbedder(**init_kwargs)

    client = embedder.client
    assert client.api_key == "fake-api-key"
    assert embedder.model == "model"
    assert embedder.api_base_url == "https://my-custom-base-url.com"
    assert embedder.organization == "fake-organization"
    assert embedder.prefix == "prefix"
    assert embedder.suffix == "suffix"

    # Values taken from the environment, converted from their string form.
    assert client.timeout == 100.0
    assert client.max_retries == 10

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
Expand Down