Skip to content

Commit

Permalink
Fix OpenAI key validation (7042) (#7043)
Browse files Browse the repository at this point in the history
  • Loading branch information
ian-fox committed Jul 26, 2023
1 parent f4248af commit 5feec1a
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Bug Fixes / Nits
- tune prompt to get rid of KeyError in SubQ engine (#7039)
- Fix validation of Azure OpenAI keys (#7042)

## [0.7.12] - 2023-07-25

Expand Down
4 changes: 3 additions & 1 deletion llama_index/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,9 @@ def __init__(
# Validate that either the openai.api_key property
# or OPENAI_API_KEY env variable are set to a valid key
# Raises ValueError if missing or doesn't match valid format
validate_openai_api_key(kwargs.get("api_key", None))
validate_openai_api_key(
kwargs.get("api_key", None), kwargs.get("api_type", None)
)

"""Init params."""
super().__init__(embed_batch_size, tokenizer, callback_manager)
Expand Down
4 changes: 3 additions & 1 deletion llama_index/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ class OpenAI(LLM, BaseModel):
max_retries: int = 10

def __init__(self, *args: Any, **kwargs: Any) -> None:
validate_openai_api_key(kwargs.get("api_key", None))
validate_openai_api_key(
kwargs.get("api_key", None), kwargs.get("api_type", None)
)
super().__init__(*args, **kwargs)

@property
Expand Down
12 changes: 10 additions & 2 deletions llama_index/llms/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,17 @@ def to_openai_function(pydantic_class: Type[BaseModel]) -> Dict[str, Any]:
}


def validate_openai_api_key(api_key: Optional[str] = None) -> None:
def validate_openai_api_key(
api_key: Optional[str] = None, api_type: Optional[str] = None
) -> None:
openai_api_key = api_key or os.environ.get("OPENAI_API_KEY", "") or openai.api_key
openai_api_type = (
api_type or os.environ.get("OPENAI_API_TYPE", "") or openai.api_type
)

if not openai_api_key:
raise ValueError(MISSING_API_KEY_ERROR_MESSAGE)
elif not OPENAI_API_KEY_FORMAT.search(openai_api_key):
elif openai_api_type == "open_ai" and not OPENAI_API_KEY_FORMAT.search(
openai_api_key
):
raise ValueError(INVALID_API_KEY_ERROR_MESSAGE)
14 changes: 12 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def mock_openai_credentials() -> None:

class CachedOpenAIApiKeys:
"""
Saves the users' OpenAI API key either in the environment variable
or set to the library itself.
Saves the users' OpenAI API key and OpenAI API type either in
the environment variable or set to the library itself.
This allows us to run tests by setting it without plowing over
the local environment.
"""
Expand All @@ -104,22 +104,32 @@ def __init__(
set_env_key_to: Optional[str] = "",
set_library_key_to: Optional[str] = None,
set_fake_key: bool = False,
set_env_type_to: Optional[str] = "",
set_library_type_to: str = "open_ai", # default value in openai package
):
self.set_env_key_to = set_env_key_to
self.set_library_key_to = set_library_key_to
self.set_fake_key = set_fake_key
self.set_env_type_to = set_env_type_to
self.set_library_type_to = set_library_type_to

def __enter__(self) -> None:
self.api_env_variable_was = os.environ.get("OPENAI_API_KEY", "")
self.api_env_type_was = os.environ.get("OPENAI_API_TYPE", "")
self.openai_api_key_was = openai.api_key
self.openai_api_type_was = openai.api_type

os.environ["OPENAI_API_KEY"] = str(self.set_env_key_to)
os.environ["OPENAI_API_TYPE"] = str(self.set_env_type_to)
openai.api_key = self.set_library_key_to
openai.api_type = self.set_library_type_to

if self.set_fake_key:
openai.api_key = "sk-" + "a" * 48

# No matter what, set the environment variable back to what it was
def __exit__(self, *exc: Any) -> None:
os.environ["OPENAI_API_KEY"] = str(self.api_env_variable_was)
os.environ["OPENAI_API_TYPE"] = str(self.api_env_type_was)
openai.api_key = self.openai_api_key_was
openai.api_type = self.openai_api_type_was
10 changes: 10 additions & 0 deletions tests/embeddings/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,18 @@ def test_validates_api_key_format_from_env() -> None:
with pytest.raises(ValueError, match="Invalid OpenAI API key."):
OpenAIEmbedding()

with CachedOpenAIApiKeys(
set_env_key_to="api-hf47930g732gf372", set_env_type_to="azure"
):
assert OpenAIEmbedding()


def test_validates_api_key_format_in_library() -> None:
with CachedOpenAIApiKeys(set_library_key_to="api-hf47930g732gf372"):
with pytest.raises(ValueError, match="Invalid OpenAI API key."):
OpenAIEmbedding()

with CachedOpenAIApiKeys(
set_library_key_to="api-hf47930g732gf372", set_library_type_to="azure"
):
assert OpenAIEmbedding()

0 comments on commit 5feec1a

Please sign in to comment.