diff --git a/.gitignore b/.gitignore index 895ca52ac..f7478c108 100644 --- a/.gitignore +++ b/.gitignore @@ -49,4 +49,7 @@ dist/** # translation files *.xlf *.nls.*.json -*.i18n.json \ No newline at end of file +*.i18n.json + +# asdf +.tool-versions diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b2d5ecab..a3846afe3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for `claude-3-opus`, `claude-3-sonnet`, and `claude-3-haiku` in `AnthropicPromptDriver`. - Support for `anthropic.claude-3-sonnet-20240229-v1:0` and `anthropic.claude-3-haiku-20240307-v1:0` in `BedrockClaudePromptModelDriver`. - `top_k` and `top_p` parameters in `AnthropicPromptDriver`. +- Added `AnthropicImageQueryDriver` for Claude-3 multi-modal models +- Added `AmazonBedrockImageQueryDriver` along with `BedrockClaudeImageQueryDriverModel` for Claude-3 in Bedrock support - `BaseWebScraperDriver` allowing multiple web scraping implementations. - `TrafilaturaWebScraperDriver` for scraping text from web pages using trafilatura. - `MarkdownifyWebScraperDriver` for scraping text from web pages using playwright and converting to markdown using markdownify. @@ -28,12 +30,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **BREAKING**: `ActionSubtask` was renamed to `ActionsSubtask`. - **BREAKING**: Removed `subtask_action_name`, `subtask_action_path`, and `subtask_action_input` in `BaseActionSubtaskEvent`. +- **BREAKING**: `OpenAiVisionImageQueryDriver` field `model` no longer defaults to `gpt-4-vision-preview` and must be specified - Default model of `OpenAiEmbeddingDriver` to `text-embedding-3-small`. - Default model of `OpenAiStructureConfig` to `text-embedding-3-small`. - `BaseTextLoader` to accept a `BaseChunker`. - Default model of `AmazonBedrockStructureConfig` to `anthropic.claude-3-sonnet-20240229-v1:0`. 
- `AnthropicPromptDriver` and `BedrockClaudePromptModelDriver` to use Anthropic's Messages API. - +- `OpenAiVisionImageQueryDriver` now has a required field `max_tokens` that defaults to 256 ## [0.23.2] - 2024-03-15 diff --git a/griptape/config/amazon_bedrock_structure_config.py b/griptape/config/amazon_bedrock_structure_config.py index 8bd57db7d..54b8d91c7 100644 --- a/griptape/config/amazon_bedrock_structure_config.py +++ b/griptape/config/amazon_bedrock_structure_config.py @@ -12,9 +12,11 @@ ) from griptape.drivers import ( AmazonBedrockImageGenerationDriver, + AmazonBedrockImageQueryDriver, AmazonBedrockPromptDriver, AmazonBedrockTitanEmbeddingDriver, BedrockClaudePromptModelDriver, + BedrockClaudeImageQueryModelDriver, BedrockTitanImageGenerationModelDriver, LocalVectorStoreDriver, ) @@ -34,6 +36,10 @@ class AmazonBedrockStructureConfig(BaseStructureConfig): model="amazon.titan-image-generator-v1", image_generation_model_driver=BedrockTitanImageGenerationModelDriver(), ), + image_query_driver=AmazonBedrockImageQueryDriver( + model="anthropic.claude-3-sonnet-20240229-v1:0", + image_query_model_driver=BedrockClaudeImageQueryModelDriver(), + ), embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1"), vector_store_driver=LocalVectorStoreDriver( embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1") diff --git a/griptape/config/anthropic_structure_config.py b/griptape/config/anthropic_structure_config.py index 956df85ba..06978a5c2 100644 --- a/griptape/config/anthropic_structure_config.py +++ b/griptape/config/anthropic_structure_config.py @@ -10,7 +10,12 @@ StructureTaskMemoryQueryEngineConfig, StructureTaskMemorySummaryEngineConfig, ) -from griptape.drivers import LocalVectorStoreDriver, AnthropicPromptDriver, VoyageAiEmbeddingDriver +from griptape.drivers import ( + LocalVectorStoreDriver, + AnthropicPromptDriver, + AnthropicImageQueryDriver, + VoyageAiEmbeddingDriver, +) @define @@ -23,6 +28,7 @@ 
class AnthropicStructureConfig(BaseStructureConfig): vector_store_driver=LocalVectorStoreDriver( embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2") ), + image_query_driver=AnthropicImageQueryDriver(model="claude-3-opus-20240229"), ) ), kw_only=True, diff --git a/griptape/config/openai_structure_config.py b/griptape/config/openai_structure_config.py index 31a445d07..283fca2d1 100644 --- a/griptape/config/openai_structure_config.py +++ b/griptape/config/openai_structure_config.py @@ -26,7 +26,7 @@ class OpenAiStructureConfig(BaseStructureConfig): lambda: StructureGlobalDriversConfig( prompt_driver=OpenAiChatPromptDriver(model="gpt-4"), image_generation_driver=OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512"), - image_query_driver=OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview", max_tokens=300), + image_query_driver=OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview"), embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small"), vector_store_driver=LocalVectorStoreDriver( embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small") diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index 5a3180807..a431dcd26 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -72,9 +72,15 @@ from .image_generation.azure_openai_image_generation_driver import AzureOpenAiImageGenerationDriver from .image_generation.dummy_image_generation_driver import DummyImageGenerationDriver +from .image_query_model.base_image_query_model_driver import BaseImageQueryModelDriver +from .image_query_model.bedrock_claude_image_query_model_driver import BedrockClaudeImageQueryModelDriver + from .image_query.base_image_query_driver import BaseImageQueryDriver -from .image_query.openai_vision_image_query_driver import OpenAiVisionImageQueryDriver +from .image_query.base_multi_model_image_query_driver import BaseMultiModelImageQueryDriver from .image_query.dummy_image_query_driver import 
DummyImageQueryDriver +from .image_query.openai_vision_image_query_driver import OpenAiVisionImageQueryDriver +from .image_query.anthropic_image_query_driver import AnthropicImageQueryDriver +from .image_query.amazon_bedrock_image_query_driver import AmazonBedrockImageQueryDriver from .web_scraper.base_web_scraper_driver import BaseWebScraperDriver from .web_scraper.trafilatura_web_scraper_driver import TrafilaturaWebScraperDriver @@ -144,9 +150,14 @@ "AmazonBedrockImageGenerationDriver", "AzureOpenAiImageGenerationDriver", "DummyImageGenerationDriver", + "BaseImageQueryModelDriver", + "BedrockClaudeImageQueryModelDriver", "BaseImageQueryDriver", "OpenAiVisionImageQueryDriver", "DummyImageQueryDriver", + "AnthropicImageQueryDriver", + "BaseMultiModelImageQueryDriver", + "AmazonBedrockImageQueryDriver", "BaseWebScraperDriver", "TrafilaturaWebScraperDriver", "MarkdownifyWebScraperDriver", diff --git a/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py b/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py new file mode 100644 index 000000000..a5a7c6f15 --- /dev/null +++ b/griptape/drivers/image_query/amazon_bedrock_image_query_driver.py @@ -0,0 +1,35 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Any +from attr import define, field, Factory +from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.drivers import BaseMultiModelImageQueryDriver +from griptape.utils import import_optional_dependency +import json + +if TYPE_CHECKING: + import boto3 + + +@define +class AmazonBedrockImageQueryDriver(BaseMultiModelImageQueryDriver): + session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) + bedrock_client: Any = field( + default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True + ) + + def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: + payload = 
self.image_query_model_driver.image_query_request_parameters(query, images, self.max_tokens) + + response = self.bedrock_client.invoke_model( + modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload) + ) + + response_body = json.loads(response.get("body").read()) + + if response_body is None: + raise ValueError("Model response is empty") + + try: + return self.image_query_model_driver.process_output(response_body) + except Exception as e: + raise ValueError(f"Output is unable to be processed as returned {e}") diff --git a/griptape/drivers/image_query/anthropic_image_query_driver.py b/griptape/drivers/image_query/anthropic_image_query_driver.py new file mode 100644 index 000000000..1ffc24d95 --- /dev/null +++ b/griptape/drivers/image_query/anthropic_image_query_driver.py @@ -0,0 +1,59 @@ +from __future__ import annotations +from typing import Optional, Any +from attr import define, field, Factory +from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.drivers import BaseImageQueryDriver +from griptape.utils import import_optional_dependency + + +@define +class AnthropicImageQueryDriver(BaseImageQueryDriver): + """ + Attributes: + api_key: Anthropic API key. + model: Anthropic model name. + client: Custom `Anthropic` client. 
+ """ + + api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + model: str = field(kw_only=True, metadata={"serializable": True}) + client: Any = field( + default=Factory( + lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True + ), + kw_only=True, + ) + + def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: + if self.max_tokens is None: + raise TypeError("max_output_tokens can't be empty") + + response = self.client.messages.create(**self._base_params(query, images)) + content_blocks = response.content + + if len(content_blocks) < 1: + raise ValueError("Response content is empty") + + text_content = content_blocks[0].text + + return TextArtifact(text_content) + + def _base_params(self, text_query: str, images: list[ImageArtifact]): + content = [self._construct_image_message(image) for image in images] + content.append(self._construct_text_message(text_query)) + messages = self._construct_messages(content) + params = {"model": self.model, "messages": messages, "max_tokens": self.max_tokens} + + return params + + def _construct_image_message(self, image_data: ImageArtifact) -> dict: + data = image_data.base64 + type = image_data.mime_type + + return {"source": {"data": data, "media_type": type, "type": "base64"}, "type": "image"} + + def _construct_text_message(self, query: str) -> dict: + return {"text": query, "type": "text"} + + def _construct_messages(self, content: list) -> list: + return [{"content": content, "role": "user"}] diff --git a/griptape/drivers/image_query/base_image_query_driver.py b/griptape/drivers/image_query/base_image_query_driver.py index 05fbb945e..369cb13c6 100644 --- a/griptape/drivers/image_query/base_image_query_driver.py +++ b/griptape/drivers/image_query/base_image_query_driver.py @@ -16,6 +16,7 @@ @define class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC): structure: Optional[Structure] = 
field(default=None, kw_only=True) + max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True}) def before_run(self, query: str, images: list[ImageArtifact]) -> None: if self.structure: diff --git a/griptape/drivers/image_query/base_multi_model_image_query_driver.py b/griptape/drivers/image_query/base_multi_model_image_query_driver.py new file mode 100644 index 000000000..07d5b7c27 --- /dev/null +++ b/griptape/drivers/image_query/base_multi_model_image_query_driver.py @@ -0,0 +1,22 @@ +from __future__ import annotations +from abc import ABC + +from attr import field, define + +from griptape.drivers import BaseImageQueryDriver, BaseImageQueryModelDriver + + +@define +class BaseMultiModelImageQueryDriver(BaseImageQueryDriver, ABC): + """Image Query Driver for platforms like Amazon Bedrock that host many LLM models. + + Instances of this Image Query Driver require a Image Query Model Driver which is used to structure the + image generation request in the format required by the model and to process the output. + + Attributes: + model: Model name to use + image_query_model_driver: Image Model Driver to use. 
+ """ + + model: str = field(kw_only=True, metadata={"serializable": True}) + image_query_model_driver: BaseImageQueryModelDriver = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/drivers/image_query/dummy_image_query_driver.py b/griptape/drivers/image_query/dummy_image_query_driver.py index 327f0fe14..b8f677507 100644 --- a/griptape/drivers/image_query/dummy_image_query_driver.py +++ b/griptape/drivers/image_query/dummy_image_query_driver.py @@ -8,6 +8,7 @@ @define class DummyImageQueryDriver(BaseImageQueryDriver): model: str = field(init=False) + max_tokens: int = field(init=False) def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: raise DummyException(__class__.__name__, "try_query") diff --git a/griptape/drivers/image_query/openai_vision_image_query_driver.py b/griptape/drivers/image_query/openai_vision_image_query_driver.py index 4b9efc874..c37cebba9 100644 --- a/griptape/drivers/image_query/openai_vision_image_query_driver.py +++ b/griptape/drivers/image_query/openai_vision_image_query_driver.py @@ -17,14 +17,13 @@ @define class OpenAiVisionImageQueryDriver(BaseImageQueryDriver): - model: str = field(default="gpt-4-vision-preview", kw_only=True, metadata={"serializable": True}) + model: str = field(kw_only=True, metadata={"serializable": True}) api_type: str = field(default=openai.api_type, kw_only=True) api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True}) base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) api_key: Optional[str] = field(default=None, kw_only=True) organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True}) image_quality: Literal["auto", "low", "high"] = field(default="auto", kw_only=True, metadata={"serializable": True}) - max_tokens: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) client: openai.OpenAI = 
field( default=Factory( lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization), @@ -46,10 +45,7 @@ def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: ) messages = ChatCompletionUserMessageParam(content=message_parts, role="user") - params = {"model": self.model, "messages": [messages]} - - if self.max_tokens is not None: - params["max_tokens"] = self.max_tokens + params = {"model": self.model, "messages": [messages], "max_tokens": self.max_tokens} response = self.client.chat.completions.create(**params) diff --git a/griptape/drivers/image_query_model/__init__.py b/griptape/drivers/image_query_model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/drivers/image_query_model/base_image_query_model_driver.py b/griptape/drivers/image_query_model/base_image_query_model_driver.py new file mode 100644 index 000000000..46320f945 --- /dev/null +++ b/griptape/drivers/image_query_model/base_image_query_model_driver.py @@ -0,0 +1,16 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from attr import define +from griptape.artifacts import TextArtifact, ImageArtifact +from griptape.mixins import SerializableMixin + + +@define +class BaseImageQueryModelDriver(SerializableMixin, ABC): + @abstractmethod + def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_tokens: int) -> dict: + ... + + @abstractmethod + def process_output(self, output: dict) -> TextArtifact: + ... 
diff --git a/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py b/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py new file mode 100644 index 000000000..caa646ba7 --- /dev/null +++ b/griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py @@ -0,0 +1,39 @@ +from __future__ import annotations +from typing import Optional +from attr import field, define +from griptape.artifacts import ImageArtifact, TextArtifact +from griptape.drivers import BaseImageQueryModelDriver + + +@define +class BedrockClaudeImageQueryModelDriver(BaseImageQueryModelDriver): + ANTHROPIC_VERSION = "bedrock-2023-05-31" # static string for AWS: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#api-inference-examples-claude-multimodal-code-example + + def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_tokens: int) -> dict: + content = [self._construct_image_message(image) for image in images] + content.append(self._construct_text_message(query)) + messages = self._construct_messages(content) + input_params = {"messages": messages, "anthropic_version": self.ANTHROPIC_VERSION, "max_tokens": max_tokens} + + return input_params + + def process_output(self, output: dict) -> TextArtifact: + content_blocks = output["content"] + if len(content_blocks) < 1: + raise ValueError("Response content is empty") + + text_content = content_blocks[0]["text"] + + return TextArtifact(text_content) + + def _construct_image_message(self, image_data: ImageArtifact) -> dict: + data = image_data.base64 + type = image_data.mime_type + + return {"source": {"data": data, "media_type": type, "type": "base64"}, "type": "image"} + + def _construct_text_message(self, query: str) -> dict: + return {"text": query, "type": "text"} + + def _construct_messages(self, content: list) -> list: + return [{"content": content, "role": "user"}] diff --git 
a/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py b/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py index 4a0d655be..1233a8f04 100644 --- a/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py +++ b/griptape/drivers/prompt_model/bedrock_claude_prompt_model_driver.py @@ -10,11 +10,12 @@ @define class BedrockClaudePromptModelDriver(BasePromptModelDriver): + ANTHROPIC_VERSION = "bedrock-2023-05-31" # static string for AWS: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#api-inference-examples-claude-multimodal-code-example + top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True}) top_k: int = field(default=250, kw_only=True, metadata={"serializable": True}) _tokenizer: BedrockClaudeTokenizer = field(default=None, kw_only=True) prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True) - anthropic_version: str = field(default="bedrock-2023-05-31", kw_only=True, metadata={"serializable": True}) @property def tokenizer(self) -> BedrockClaudeTokenizer: @@ -60,7 +61,7 @@ def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict: "top_p": self.top_p, "top_k": self.top_k, "max_tokens": self.prompt_driver.max_output_tokens(self.prompt_driver.prompt_stack_to_string(prompt_stack)), - "anthropic_version": self.anthropic_version, + "anthropic_version": self.ANTHROPIC_VERSION, **input, } diff --git a/tests/unit/config/test_amazon_bedrock_structure_config.py b/tests/unit/config/test_amazon_bedrock_structure_config.py index 29c27a4e1..04fe4d7a1 100644 --- a/tests/unit/config/test_amazon_bedrock_structure_config.py +++ b/tests/unit/config/test_amazon_bedrock_structure_config.py @@ -33,16 +33,16 @@ def test_to_dict(self, config): "seed": None, "type": "AmazonBedrockImageGenerationDriver", }, - "image_query_driver": {"type": "DummyImageQueryDriver"}, + "image_query_driver": { + "type": 
"AmazonBedrockImageQueryDriver", + "model": "anthropic.claude-3-sonnet-20240229-v1:0", + "max_tokens": 256, + "image_query_model_driver": {"type": "BedrockClaudeImageQueryModelDriver"}, + }, "prompt_driver": { "max_tokens": None, "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "prompt_model_driver": { - "type": "BedrockClaudePromptModelDriver", - "anthropic_version": "bedrock-2023-05-31", - "top_k": 250, - "top_p": 0.999, - }, + "prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999}, "stream": False, "temperature": 0.1, "type": "AmazonBedrockPromptDriver", @@ -66,12 +66,7 @@ def test_to_dict(self, config): "temperature": 0.1, "max_tokens": None, "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "prompt_model_driver": { - "type": "BedrockClaudePromptModelDriver", - "anthropic_version": "bedrock-2023-05-31", - "top_k": 250, - "top_p": 0.999, - }, + "prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999}, "stream": False, }, "vector_store_driver": { @@ -93,7 +88,6 @@ def test_to_dict(self, config): "model": "anthropic.claude-3-sonnet-20240229-v1:0", "prompt_model_driver": { "type": "BedrockClaudePromptModelDriver", - "anthropic_version": "bedrock-2023-05-31", "top_k": 250, "top_p": 0.999, }, @@ -109,7 +103,6 @@ def test_to_dict(self, config): "model": "anthropic.claude-3-sonnet-20240229-v1:0", "prompt_model_driver": { "type": "BedrockClaudePromptModelDriver", - "anthropic_version": "bedrock-2023-05-31", "top_k": 250, "top_p": 0.999, }, @@ -124,12 +117,7 @@ def test_to_dict(self, config): "temperature": 0.1, "max_tokens": None, "model": "anthropic.claude-3-sonnet-20240229-v1:0", - "prompt_model_driver": { - "type": "BedrockClaudePromptModelDriver", - "anthropic_version": "bedrock-2023-05-31", - "top_k": 250, - "top_p": 0.999, - }, + "prompt_model_driver": {"type": "BedrockClaudePromptModelDriver", "top_k": 250, "top_p": 0.999}, "stream": False, }, }, diff --git 
a/tests/unit/config/test_anthropic_structure_config.py b/tests/unit/config/test_anthropic_structure_config.py index 292129e06..e05ff0d50 100644 --- a/tests/unit/config/test_anthropic_structure_config.py +++ b/tests/unit/config/test_anthropic_structure_config.py @@ -28,7 +28,12 @@ def test_to_dict(self, config): "top_k": 250, }, "image_generation_driver": {"type": "DummyImageGenerationDriver"}, - "image_query_driver": {"type": "DummyImageQueryDriver"}, + "image_query_driver": { + "type": "AnthropicImageQueryDriver", + "model": "claude-3-opus-20240229", + "api_key": None, + "max_tokens": 256, + }, "embedding_driver": { "type": "VoyageAiEmbeddingDriver", "model": "voyage-large-2", diff --git a/tests/unit/config/test_openai_structure_config.py b/tests/unit/config/test_openai_structure_config.py index b87e059a3..8eaabb27c 100644 --- a/tests/unit/config/test_openai_structure_config.py +++ b/tests/unit/config/test_openai_structure_config.py @@ -53,7 +53,7 @@ def test_to_dict(self, config): "api_version": None, "base_url": None, "image_quality": "auto", - "max_tokens": 300, + "max_tokens": 256, "model": "gpt-4-vision-preview", "organization": None, "type": "OpenAiVisionImageQueryDriver", diff --git a/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py b/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py new file mode 100644 index 000000000..7f5b44aa4 --- /dev/null +++ b/tests/unit/drivers/image_query/test_amazon_bedrock_image_query_driver.py @@ -0,0 +1,48 @@ +import pytest +import io +from unittest.mock import Mock +from griptape.drivers import AmazonBedrockImageQueryDriver +from griptape.artifacts import ImageArtifact, TextArtifact + + +class TestAmazonBedrockImageQueryDriver: + @pytest.fixture + def bedrock_client(self, mocker): + return Mock() + + @pytest.fixture + def session(self, bedrock_client): + session = Mock() + session.client.return_value = bedrock_client + + return session + + @pytest.fixture + def model_driver(self): 
+ model_driver = Mock() + model_driver.image_query_request_parameters.return_value = {} + model_driver.process_output.return_value = TextArtifact("content") + + return model_driver + + @pytest.fixture + def image_query_driver(self, session, model_driver): + return AmazonBedrockImageQueryDriver(session=session, model="model", image_query_model_driver=model_driver) + + def test_init(self, image_query_driver): + assert image_query_driver + + def test_try_query(self, image_query_driver): + image_query_driver.bedrock_client.invoke_model.return_value = {"body": io.BytesIO(b"""{"content": []}""")} + + text_artifact = image_query_driver.try_query( + "Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100)] + ) + + assert text_artifact.value == "content" + + def test_try_query_no_body(self, image_query_driver): + image_query_driver.bedrock_client.invoke_model.return_value = {"body": io.BytesIO(b"")} + + with pytest.raises(ValueError): + image_query_driver.try_query("Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100)]) diff --git a/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py b/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py new file mode 100644 index 000000000..ff53375b0 --- /dev/null +++ b/tests/unit/drivers/image_query/test_anthropic_image_query_driver.py @@ -0,0 +1,93 @@ +import pytest +import base64 +from unittest.mock import Mock +from griptape.drivers import AnthropicImageQueryDriver +from griptape.artifacts import ImageArtifact + + +class TestAnthropicImageQueryDriver: + @pytest.fixture + def mock_client(self, mocker): + mock_client = mocker.patch("anthropic.Anthropic") + return_value = Mock(text="Content") + mock_client.return_value.messages.create.return_value.content = [return_value] + + return mock_client + + @pytest.mark.parametrize( + "model", [("claude-3-haiku-20240307"), ("claude-3-sonnet-20240229"), ("claude-3-opus-20240229")] + ) + def test_init(self, model): + assert 
AnthropicImageQueryDriver(model=model) + + def test_try_query(self, mock_client): + driver = AnthropicImageQueryDriver(model="test-model") + test_prompt_string = "Prompt String" + test_binary_data = b"test-data" + + text_artifact = driver.try_query( + test_prompt_string, [ImageArtifact(value=test_binary_data, width=100, height=100)] + ) + + expected_message = self._expected_message(test_binary_data, "image/png", test_prompt_string) + + mock_client.return_value.messages.create.assert_called_once_with( + model=driver.model, max_tokens=256, messages=[expected_message] + ) + + assert text_artifact.value == "Content" + + def test_try_query_max_tokens_value(self, mock_client): + driver = AnthropicImageQueryDriver(model="test-model", max_tokens=1024) + test_prompt_string = "Prompt String" + test_binary_data = b"test-data" + + text_artifact = driver.try_query( + test_prompt_string, [ImageArtifact(value=test_binary_data, width=100, height=100)] + ) + + expected_message = self._expected_message(test_binary_data, "image/png", test_prompt_string) + + mock_client.return_value.messages.create.assert_called_once_with( + model=driver.model, max_tokens=1024, messages=[expected_message] + ) + + assert text_artifact.value == "Content" + + def test_try_query_max_tokens_none(self, mock_client): + driver = AnthropicImageQueryDriver(model="test-model", max_tokens=None) # pyright: ignore + test_prompt_string = "Prompt String" + test_binary_data = b"test-data" + with pytest.raises(TypeError): + driver.try_query(test_prompt_string, [ImageArtifact(value=test_binary_data, width=100, height=100)]) + + def test_try_query_wrong_media_type(self, mock_client): + driver = AnthropicImageQueryDriver(model="test-model") + test_prompt_string = "Prompt String" + test_binary_data = b"test-data" + + # we expect this to pass Griptape code as the model will error appropriately + text_artifact = driver.try_query( + test_prompt_string, [ImageArtifact(value=test_binary_data, mime_type="image/exr", width=100, 
height=100)] + ) + + expected_message = self._expected_message(test_binary_data, "image/exr", test_prompt_string) + + mock_client.return_value.messages.create.assert_called_once_with( + model=driver.model, messages=[expected_message], max_tokens=256 + ) + + assert text_artifact.value == "Content" + + def _expected_message(self, expected_data, expected_media_type, expected_prompt_string): + encoded_data = base64.b64encode(expected_data).decode("utf-8") + return { + "content": [ + { + "source": {"data": encoded_data, "media_type": expected_media_type, "type": "base64"}, + "type": "image", + }, + {"text": expected_prompt_string, "type": "text"}, + ], + "role": "user", + } diff --git a/tests/unit/drivers/image_query/test_dummy_image_query_driver.py b/tests/unit/drivers/image_query/test_dummy_image_query_driver.py new file mode 100644 index 000000000..6341717b8 --- /dev/null +++ b/tests/unit/drivers/image_query/test_dummy_image_query_driver.py @@ -0,0 +1,18 @@ +from griptape.drivers import DummyImageQueryDriver +from griptape.artifacts import ImageArtifact +import pytest + +from griptape.exceptions import DummyException + + +class TestDummyImageQueryDriver: + @pytest.fixture + def image_query_driver(self): + return DummyImageQueryDriver() + + def test_init(self, image_query_driver): + assert image_query_driver + + def test_try_query(self, image_query_driver): + with pytest.raises(DummyException): + image_query_driver.try_query("Prompt", [ImageArtifact(value=b"", width=100, height=100)]) diff --git a/tests/unit/drivers/image_query/test_openai_image_query_driver.py b/tests/unit/drivers/image_query/test_openai_image_query_driver.py new file mode 100644 index 000000000..f811f87cf --- /dev/null +++ b/tests/unit/drivers/image_query/test_openai_image_query_driver.py @@ -0,0 +1,59 @@ +import pytest +from unittest.mock import Mock +from griptape.drivers import OpenAiVisionImageQueryDriver +from griptape.artifacts import ImageArtifact + + +class TestOpenAiVisionImageQueryDriver: 
+ @pytest.fixture + def mock_completion_create(self, mocker): + mock_chat_create = mocker.patch("openai.OpenAI").return_value.chat.completions.create + mock_choice = Mock(message=Mock(content="expected_output_text")) + mock_chat_create.return_value.choices = [mock_choice] + return mock_chat_create + + def test_init(self): + assert OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview") + + def test_try_query_defaults(self, mock_completion_create): + driver = OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview") + test_prompt_string = "Prompt String" + test_binary_data = b"test-data" + test_image = ImageArtifact(value=test_binary_data, width=100, height=100) + text_artifact = driver.try_query(test_prompt_string, [test_image]) + + messages = self._expected_messages(test_prompt_string, test_image.base64) + + mock_completion_create.assert_called_once_with(model=driver.model, messages=[messages], max_tokens=256) + + assert text_artifact.value == "expected_output_text" + + def test_try_query_max_tokens(self, mock_completion_create): + driver = OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview", max_tokens=1024) + test_prompt_string = "Prompt String" + test_binary_data = b"test-data" + test_image = ImageArtifact(value=test_binary_data, width=100, height=100) + driver.try_query(test_prompt_string, [test_image]) + + messages = self._expected_messages(test_prompt_string, test_image.base64) + + mock_completion_create.assert_called_once_with(model=driver.model, messages=[messages], max_tokens=1024) + + def test_try_query_multiple_choices(self, mock_completion_create): + mock_completion_create.return_value.choices.append(Mock(message=Mock(content="expected_output_text2"))) + driver = OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview") + + with pytest.raises(Exception): + driver.try_query("Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100)]) + + def _expected_messages(self, expected_prompt_string, expected_binary_data): + return { + 
"content": [ + {"type": "text", "text": expected_prompt_string}, + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{expected_binary_data}", "detail": "auto"}, + }, + ], + "role": "user", + } diff --git a/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py b/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py new file mode 100644 index 000000000..a4c57917b --- /dev/null +++ b/tests/unit/drivers/image_query_models/test_bedrock_claude_image_query_model_driver.py @@ -0,0 +1,41 @@ +import pytest +from griptape.drivers import BedrockClaudeImageQueryModelDriver +from griptape.artifacts import ImageArtifact, TextArtifact + + +class TestBedrockClaudeImageQueryModelDriver: + def test_init(self): + assert BedrockClaudeImageQueryModelDriver() + + def test_image_query_request_parameters(self): + model_driver = BedrockClaudeImageQueryModelDriver() + params = model_driver.image_query_request_parameters( + "Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100)], 256 + ) + + assert isinstance(params, dict) + assert "anthropic_version" in params + assert params["anthropic_version"] == "bedrock-2023-05-31" + assert "messages" in params + assert len(params["messages"]) == 1 + assert "max_tokens" in params + assert params["max_tokens"] == 256 + + def test_process_output(self): + model_driver = BedrockClaudeImageQueryModelDriver() + output = model_driver.process_output({"content": [{"text": "Content"}]}) + + assert isinstance(output, TextArtifact) + assert output.value == "Content" + + def test_process_output_no_content_key(self): + with pytest.raises(KeyError): + BedrockClaudeImageQueryModelDriver().process_output({"explicitly-not-content": ["ContentBlock"]}) + + def test_process_output_bad_length(self): + with pytest.raises(ValueError): + BedrockClaudeImageQueryModelDriver().process_output({"content": []}) + + def test_process_output_no_text_key(self): + with 
pytest.raises(KeyError): + BedrockClaudeImageQueryModelDriver().process_output({"content": [{"not-text": "Content"}]}) diff --git a/tests/unit/drivers/prompt_models/test_bedrock_claude_prompt_model_driver.py b/tests/unit/drivers/prompt_models/test_bedrock_claude_prompt_model_driver.py index 98538f1bf..94bc97122 100644 --- a/tests/unit/drivers/prompt_models/test_bedrock_claude_prompt_model_driver.py +++ b/tests/unit/drivers/prompt_models/test_bedrock_claude_prompt_model_driver.py @@ -92,7 +92,7 @@ def test_prompt_stack_to_model_params(self, driver, system_enabled): expected = { "temperature": 0.12345, "max_tokens": max_tokens, - "anthropic_version": driver.anthropic_version, + "anthropic_version": driver.ANTHROPIC_VERSION, "messages": [ {"role": "user", "content": "bar"}, {"role": "assistant", "content": "baz"},