Skip to content

Commit

Permalink
Fix type hint for token_provider, Add missing params in `AzureOpenAiC…
Browse files Browse the repository at this point in the history
…ompletionPromptDriver` openai client (#750)
  • Loading branch information
vachillo committed Apr 17, 2024
1 parent b944728 commit 77ba1a7
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREADKING**: Removed `mime_type` field from `ImageArtifact`. `mime_type` is now a property constructed using the Artifact type and `format` field.
- Moved [Griptape Docs](https://github.com/griptape-ai/griptape-docs) to this repository.

### Fixed
- Type hint for parameter `azure_ad_token_provider` on Azure OpenAI drivers to `Optional[Callable[[], str]]`.
- Missing parameters `azure_ad_token` and `azure_ad_token_provider` on the default client for `AzureOpenAiCompletionPromptDriver`.

## [0.24.2] - 2024-04-04

- Fixed FileManager.load_files_from_disk schema.
Expand Down
6 changes: 4 additions & 2 deletions griptape/drivers/embedding/azure_openai_embedding_driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Optional
from typing import Callable, Optional
from attr import define, field, Factory
from griptape.drivers import OpenAiEmbeddingDriver
from griptape.tokenizers import OpenAiTokenizer
Expand All @@ -23,7 +23,9 @@ class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
kw_only=True, default=None, metadata={"serializable": False}
)
api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True})
tokenizer: OpenAiTokenizer = field(
default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import openai
from attr import field, Factory, define
from typing import Optional
from typing import Callable, Optional

from griptape.drivers import OpenAiImageGenerationDriver

Expand All @@ -23,7 +23,9 @@ class AzureOpenAiImageGenerationDriver(OpenAiImageGenerationDriver):
azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
kw_only=True, default=None, metadata={"serializable": False}
)
api_version: str = field(default="2024-02-01", kw_only=True, metadata={"serializable": True})
client: openai.AzureOpenAI = field(
default=Factory(
Expand Down
6 changes: 4 additions & 2 deletions griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from attr import define, field, Factory
from typing import Optional
from typing import Callable, Optional
from griptape.utils import PromptStack
from griptape.drivers import OpenAiChatPromptDriver
import openai
Expand All @@ -20,7 +20,9 @@ class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
kw_only=True, default=None, metadata={"serializable": False}
)
api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True})
client: openai.AzureOpenAI = field(
default=Factory(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Callable, Optional
from attr import define, field, Factory
from griptape.drivers import OpenAiCompletionPromptDriver
import openai
Expand All @@ -19,7 +19,9 @@ class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver):
azure_deployment: str = field(kw_only=True, metadata={"serializable": True})
azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
azure_ad_token_provider: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
azure_ad_token_provider: Optional[Callable[[], str]] = field(
kw_only=True, default=None, metadata={"serializable": False}
)
api_version: str = field(default="2023-05-15", kw_only=True, metadata={"serializable": True})
client: openai.AzureOpenAI = field(
default=Factory(
Expand All @@ -29,6 +31,8 @@ class AzureOpenAiCompletionPromptDriver(OpenAiCompletionPromptDriver):
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
azure_ad_token=self.azure_ad_token,
azure_ad_token_provider=self.azure_ad_token_provider,
),
takes_self=True,
)
Expand Down

0 comments on commit 77ba1a7

Please sign in to comment.