Skip to content

Commit

Permalink
updating structures and anthropic_version to static
Browse files Browse the repository at this point in the history
  • Loading branch information
emjay07 committed Mar 21, 2024
1 parent 978c8b9 commit 3ee66c1
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 31 deletions.
6 changes: 6 additions & 0 deletions griptape/config/amazon_bedrock_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
)
from griptape.drivers import (
AmazonBedrockImageGenerationDriver,
AmazonBedrockImageQueryDriver,
AmazonBedrockPromptDriver,
AmazonBedrockTitanEmbeddingDriver,
BedrockClaudePromptModelDriver,
BedrockClaudeImageQueryModelDriver,
BedrockTitanImageGenerationModelDriver,
LocalVectorStoreDriver,
)
Expand All @@ -34,6 +36,10 @@ class AmazonBedrockStructureConfig(BaseStructureConfig):
model="amazon.titan-image-generator-v1",
image_generation_model_driver=BedrockTitanImageGenerationModelDriver(),
),
image_query_driver=AmazonBedrockImageQueryDriver(
model="anthropic.claude-3-sonnet-20240229-v1:0",
image_query_model_driver=BedrockClaudeImageQueryModelDriver(),
),
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1"),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1")
Expand Down
8 changes: 7 additions & 1 deletion griptape/config/anthropic_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
StructureTaskMemoryQueryEngineConfig,
StructureTaskMemorySummaryEngineConfig,
)
from griptape.drivers import LocalVectorStoreDriver, AnthropicPromptDriver, VoyageAiEmbeddingDriver
from griptape.drivers import (
LocalVectorStoreDriver,
AnthropicPromptDriver,
AnthropicImageQueryDriver,
VoyageAiEmbeddingDriver,
)


@define
Expand All @@ -23,6 +28,7 @@ class AnthropicStructureConfig(BaseStructureConfig):
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2")
),
image_query_driver=AnthropicImageQueryDriver(model="claude-3-opus-20240229"),
)
),
kw_only=True,
Expand Down
2 changes: 1 addition & 1 deletion griptape/config/openai_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class OpenAiStructureConfig(BaseStructureConfig):
lambda: StructureGlobalDriversConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4"),
image_generation_driver=OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512"),
image_query_driver=OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview", max_tokens=300),
image_query_driver=OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview"),
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small"),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small")
Expand Down
7 changes: 2 additions & 5 deletions griptape/drivers/image_query/anthropic_image_query_driver.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from typing import Optional, Any
from attr import define, field, Factory
from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.drivers import BaseImageQueryDriver
from griptape.utils import import_optional_dependency

if TYPE_CHECKING:
from anthropic import Anthropic


@define
class AnthropicImageQueryDriver(BaseImageQueryDriver):
Expand All @@ -20,7 +17,7 @@ class AnthropicImageQueryDriver(BaseImageQueryDriver):

api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
client: Anthropic = field(
client: Any = field(
default=Factory(
lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True
),
Expand Down
1 change: 1 addition & 0 deletions griptape/drivers/image_query/dummy_image_query_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
@define
class DummyImageQueryDriver(BaseImageQueryDriver):
model: str = field(init=False)
max_output_tokens: int = field(init=False)

def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
raise DummyException(__class__.__name__, "try_query")
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class BedrockClaudePromptModelDriver(BasePromptModelDriver):
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
_tokenizer: BedrockClaudeTokenizer = field(default=None, kw_only=True)
prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)
anthropic_version: str = field(default="bedrock-2023-05-31", kw_only=True, metadata={"serializable": True})
anthropic_version: str = "bedrock-2023-05-31" # static string for AWS: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#api-inference-examples-claude-multimodal-code-example

@property
def tokenizer(self) -> BedrockClaudeTokenizer:
Expand Down
30 changes: 9 additions & 21 deletions tests/unit/config/test_amazon_bedrock_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ def test_to_dict(self, config):
"seed": None,
"type": "AmazonBedrockImageGenerationDriver",
},
"image_query_driver": {"type": "DummyImageQueryDriver"},
"image_query_driver": {
"type": "AmazonBedrockImageQueryDriver",
"model": "anthropic.claude-3-sonnet-20240229-v1:0",
"max_output_tokens": 256,
"image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"},
},
"prompt_driver": {
"max_tokens": None,
"model": "anthropic.claude-3-sonnet-20240229-v1:0",
"prompt_model_driver": {
"type": "BedrockClaudePromptModelDriver",
"anthropic_version": "bedrock-2023-05-31",
"top_k": 250,
"top_p": 0.999,
},
"prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999},
"stream": False,
"temperature": 0.1,
"type": "AmazonBedrockPromptDriver",
Expand All @@ -66,12 +66,7 @@ def test_to_dict(self, config):
"temperature": 0.1,
"max_tokens": None,
"model": "anthropic.claude-3-sonnet-20240229-v1:0",
"prompt_model_driver": {
"type": "BedrockClaudePromptModelDriver",
"anthropic_version": "bedrock-2023-05-31",
"top_k": 250,
"top_p": 0.999,
},
"prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999},
"stream": False,
},
"vector_store_driver": {
Expand All @@ -93,7 +88,6 @@ def test_to_dict(self, config):
"model": "anthropic.claude-3-sonnet-20240229-v1:0",
"prompt_model_driver": {
"type": "BedrockClaudePromptModelDriver",
"anthropic_version": "bedrock-2023-05-31",
"top_k": 250,
"top_p": 0.999,
},
Expand All @@ -109,7 +103,6 @@ def test_to_dict(self, config):
"model": "anthropic.claude-3-sonnet-20240229-v1:0",
"prompt_model_driver": {
"type": "BedrockClaudePromptModelDriver",
"anthropic_version": "bedrock-2023-05-31",
"top_k": 250,
"top_p": 0.999,
},
Expand All @@ -124,12 +117,7 @@ def test_to_dict(self, config):
"temperature": 0.1,
"max_tokens": None,
"model": "anthropic.claude-3-sonnet-20240229-v1:0",
"prompt_model_driver": {
"type": "BedrockClaudePromptModelDriver",
"anthropic_version": "bedrock-2023-05-31",
"top_k": 250,
"top_p": 0.999,
},
"prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999},
"stream": False,
},
},
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/config/test_anthropic_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ def test_to_dict(self, config):
"top_k": 250,
},
"image_generation_driver": {"type": "DummyImageGenerationDriver"},
"image_query_driver": {"type": "DummyImageQueryDriver"},
"image_query_driver": {
"type": "AnthropicImageQueryDriver",
"model": "claude-3-opus-20240229",
"api_key": None,
"max_output_tokens": 256,
},
"embedding_driver": {
"type": "VoyageAiEmbeddingDriver",
"model": "voyage-large-2",
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/config/test_openai_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_to_dict(self, config):
"api_version": None,
"base_url": None,
"image_quality": "auto",
"max_tokens": 300,
"max_output_tokens": 256,
"model": "gpt-4-vision-preview",
"organization": None,
"type": "OpenAiVisionImageQueryDriver",
Expand Down

0 comments on commit 3ee66c1

Please sign in to comment.