Skip to content

Commit

Permalink
Claude 3 Image Query Support (#700)
Browse files Browse the repository at this point in the history
  • Loading branch information
emjay07 committed Mar 21, 2024
1 parent 7ab53c2 commit f96af2e
Show file tree
Hide file tree
Showing 25 changed files with 488 additions and 37 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,7 @@ dist/**
# translation files
*.xlf
*.nls.*.json
*.i18n.json
*.i18n.json

# asdf
.tool-versions
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for `claude-3-opus`, `claude-3-sonnet`, and `claude-3-haiku` in `AnthropicPromptDriver`.
- Support for `anthropic.claude-3-sonnet-20240229-v1:0` and `anthropic.claude-3-haiku-20240307-v1:0` in `BedrockClaudePromptModelDriver`.
- `top_k` and `top_p` parameters in `AnthropicPromptDriver`.
- Added `AnthropicImageQueryDriver` for Claude 3 multi-modal models.
- Added `AmazonBedrockImageQueryDriver` along with `BedrockClaudeImageQueryModelDriver` for Claude 3 support in Bedrock.
- `BaseWebScraperDriver` allowing multiple web scraping implementations.
- `TrafilaturaWebScraperDriver` for scraping text from web pages using trafilatura.
- `MarkdownifyWebScraperDriver` for scraping text from web pages using playwright and converting to markdown using markdownify.
Expand All @@ -28,12 +30,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- **BREAKING**: `ActionSubtask` was renamed to `ActionsSubtask`.
- **BREAKING**: Removed `subtask_action_name`, `subtask_action_path`, and `subtask_action_input` in `BaseActionSubtaskEvent`.
- **BREAKING**: `OpenAiVisionImageQueryDriver` field `model` no longer defaults to `gpt-4-vision-preview` and must be specified.
- Default model of `OpenAiEmbeddingDriver` to `text-embedding-3-small`.
- Default model of `OpenAiStructureConfig` to `text-embedding-3-small`.
- `BaseTextLoader` to accept a `BaseChunker`.
- Default model of `AmazonBedrockStructureConfig` to `anthropic.claude-3-sonnet-20240229-v1:0`.
- `AnthropicPromptDriver` and `BedrockClaudePromptModelDriver` to use Anthropic's Messages API.

- `OpenAiVisionImageQueryDriver` now has a required field `max_tokens` that defaults to 256.

## [0.23.2] - 2024-03-15

Expand Down
6 changes: 6 additions & 0 deletions griptape/config/amazon_bedrock_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
)
from griptape.drivers import (
AmazonBedrockImageGenerationDriver,
AmazonBedrockImageQueryDriver,
AmazonBedrockPromptDriver,
AmazonBedrockTitanEmbeddingDriver,
BedrockClaudePromptModelDriver,
BedrockClaudeImageQueryModelDriver,
BedrockTitanImageGenerationModelDriver,
LocalVectorStoreDriver,
)
Expand All @@ -34,6 +36,10 @@ class AmazonBedrockStructureConfig(BaseStructureConfig):
model="amazon.titan-image-generator-v1",
image_generation_model_driver=BedrockTitanImageGenerationModelDriver(),
),
image_query_driver=AmazonBedrockImageQueryDriver(
model="anthropic.claude-3-sonnet-20240229-v1:0",
image_query_model_driver=BedrockClaudeImageQueryModelDriver(),
),
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1"),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1")
Expand Down
8 changes: 7 additions & 1 deletion griptape/config/anthropic_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
StructureTaskMemoryQueryEngineConfig,
StructureTaskMemorySummaryEngineConfig,
)
from griptape.drivers import LocalVectorStoreDriver, AnthropicPromptDriver, VoyageAiEmbeddingDriver
from griptape.drivers import (
LocalVectorStoreDriver,
AnthropicPromptDriver,
AnthropicImageQueryDriver,
VoyageAiEmbeddingDriver,
)


@define
Expand All @@ -23,6 +28,7 @@ class AnthropicStructureConfig(BaseStructureConfig):
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2")
),
image_query_driver=AnthropicImageQueryDriver(model="claude-3-opus-20240229"),
)
),
kw_only=True,
Expand Down
2 changes: 1 addition & 1 deletion griptape/config/openai_structure_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class OpenAiStructureConfig(BaseStructureConfig):
lambda: StructureGlobalDriversConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4"),
image_generation_driver=OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512"),
image_query_driver=OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview", max_tokens=300),
image_query_driver=OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview"),
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small"),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small")
Expand Down
13 changes: 12 additions & 1 deletion griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,15 @@
from .image_generation.azure_openai_image_generation_driver import AzureOpenAiImageGenerationDriver
from .image_generation.dummy_image_generation_driver import DummyImageGenerationDriver

from .image_query_model.base_image_query_model_driver import BaseImageQueryModelDriver
from .image_query_model.bedrock_claude_image_query_model_driver import BedrockClaudeImageQueryModelDriver

from .image_query.base_image_query_driver import BaseImageQueryDriver
from .image_query.openai_vision_image_query_driver import OpenAiVisionImageQueryDriver
from .image_query.base_multi_model_image_query_driver import BaseMultiModelImageQueryDriver
from .image_query.dummy_image_query_driver import DummyImageQueryDriver
from .image_query.openai_vision_image_query_driver import OpenAiVisionImageQueryDriver
from .image_query.anthropic_image_query_driver import AnthropicImageQueryDriver
from .image_query.amazon_bedrock_image_query_driver import AmazonBedrockImageQueryDriver

from .web_scraper.base_web_scraper_driver import BaseWebScraperDriver
from .web_scraper.trafilatura_web_scraper_driver import TrafilaturaWebScraperDriver
Expand Down Expand Up @@ -144,9 +150,14 @@
"AmazonBedrockImageGenerationDriver",
"AzureOpenAiImageGenerationDriver",
"DummyImageGenerationDriver",
"BaseImageQueryModelDriver",
"BedrockClaudeImageQueryModelDriver",
"BaseImageQueryDriver",
"OpenAiVisionImageQueryDriver",
"DummyImageQueryDriver",
"AnthropicImageQueryDriver",
"BaseMultiModelImageQueryDriver",
"AmazonBedrockImageQueryDriver",
"BaseWebScraperDriver",
"TrafilaturaWebScraperDriver",
"MarkdownifyWebScraperDriver",
Expand Down
35 changes: 35 additions & 0 deletions griptape/drivers/image_query/amazon_bedrock_image_query_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from attr import define, field, Factory
from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.drivers import BaseMultiModelImageQueryDriver
from griptape.utils import import_optional_dependency
import json

if TYPE_CHECKING:
import boto3


@define
class AmazonBedrockImageQueryDriver(BaseMultiModelImageQueryDriver):
    """Multi-model image query driver backed by the Amazon Bedrock runtime.

    Attributes:
        session: boto3 session used to build the Bedrock runtime client.
        bedrock_client: Bedrock runtime client, constructed from `session` by default.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bedrock_client: Any = field(
        default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
    )

    def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        """Send an image query to Bedrock and return the model's text response.

        Raises:
            ValueError: If the model response is empty or cannot be processed
                by the configured image query model driver.
        """
        # Request construction is delegated to the model driver so this class stays model-agnostic.
        payload = self.image_query_model_driver.image_query_request_parameters(query, images, self.max_tokens)

        response = self.bedrock_client.invoke_model(
            modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
        )

        response_body = json.loads(response.get("body").read())

        if response_body is None:
            raise ValueError("Model response is empty")

        try:
            return self.image_query_model_driver.process_output(response_body)
        except Exception as e:
            # Chain the original exception so the root cause is preserved for debugging.
            raise ValueError(f"Output is unable to be processed as returned {e}") from e
59 changes: 59 additions & 0 deletions griptape/drivers/image_query/anthropic_image_query_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations
from typing import Optional, Any
from attr import define, field, Factory
from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.drivers import BaseImageQueryDriver
from griptape.utils import import_optional_dependency


@define
class AnthropicImageQueryDriver(BaseImageQueryDriver):
    """Image query driver for Anthropic's Claude 3 multi-modal models.

    Attributes:
        api_key: Anthropic API key.
        model: Anthropic model name.
        client: Custom `Anthropic` client.
    """

    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    model: str = field(kw_only=True, metadata={"serializable": True})
    client: Any = field(
        default=Factory(
            lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True
        ),
        kw_only=True,
    )

    def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        """Send a multi-modal query to Anthropic and return the first text block.

        Raises:
            TypeError: If `max_tokens` is None (required by the Messages API).
            ValueError: If the response contains no content blocks.
        """
        if self.max_tokens is None:
            # Error message must reference the actual field name, `max_tokens`.
            raise TypeError("max_tokens can't be empty")

        response = self.client.messages.create(**self._base_params(query, images))
        content_blocks = response.content

        if len(content_blocks) < 1:
            raise ValueError("Response content is empty")

        return TextArtifact(content_blocks[0].text)

    def _base_params(self, text_query: str, images: list[ImageArtifact]) -> dict:
        # Images come first, followed by the text query, matching the usage shown
        # in Anthropic's multi-modal Messages API examples.
        content = [self._construct_image_message(image) for image in images]
        content.append(self._construct_text_message(text_query))
        messages = self._construct_messages(content)

        return {"model": self.model, "messages": messages, "max_tokens": self.max_tokens}

    def _construct_image_message(self, image_data: ImageArtifact) -> dict:
        # Anthropic expects base64-encoded image data with an explicit media type.
        # `media_type` avoids shadowing the builtin `type`.
        return {
            "source": {"data": image_data.base64, "media_type": image_data.mime_type, "type": "base64"},
            "type": "image",
        }

    def _construct_text_message(self, query: str) -> dict:
        return {"text": query, "type": "text"}

    def _construct_messages(self, content: list) -> list:
        return [{"content": content, "role": "user"}]
1 change: 1 addition & 0 deletions griptape/drivers/image_query/base_image_query_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
@define
class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
structure: Optional[Structure] = field(default=None, kw_only=True)
max_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True})

def before_run(self, query: str, images: list[ImageArtifact]) -> None:
if self.structure:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations
from abc import ABC

from attr import field, define

from griptape.drivers import BaseImageQueryDriver, BaseImageQueryModelDriver


@define
class BaseMultiModelImageQueryDriver(BaseImageQueryDriver, ABC):
    """Image Query Driver for platforms like Amazon Bedrock that host many LLM models.

    Instances of this Image Query Driver require an Image Query Model Driver, which is used to structure the
    image query request in the format required by the model and to process the output.

    Attributes:
        model: Model name to use
        image_query_model_driver: Image Model Driver to use.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})
    image_query_model_driver: BaseImageQueryModelDriver = field(kw_only=True, metadata={"serializable": True})
1 change: 1 addition & 0 deletions griptape/drivers/image_query/dummy_image_query_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
@define
class DummyImageQueryDriver(BaseImageQueryDriver):
    """Placeholder image query driver used when no real driver is configured."""

    # init=False excludes these fields from __init__ — this driver is never configured by users.
    model: str = field(init=False)
    max_tokens: int = field(init=False)

    def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
        # Always raises: signals to the caller that no image query driver was set up.
        raise DummyException(__class__.__name__, "try_query")
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@

@define
class OpenAiVisionImageQueryDriver(BaseImageQueryDriver):
model: str = field(default="gpt-4-vision-preview", kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
api_type: str = field(default=openai.api_type, kw_only=True)
api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
api_key: Optional[str] = field(default=None, kw_only=True)
organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
image_quality: Literal["auto", "low", "high"] = field(default="auto", kw_only=True, metadata={"serializable": True})
max_tokens: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
client: openai.OpenAI = field(
default=Factory(
lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
Expand All @@ -46,10 +45,7 @@ def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
)

messages = ChatCompletionUserMessageParam(content=message_parts, role="user")
params = {"model": self.model, "messages": [messages]}

if self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
params = {"model": self.model, "messages": [messages], "max_tokens": self.max_tokens}

response = self.client.chat.completions.create(**params)

Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from attr import define
from griptape.artifacts import TextArtifact, ImageArtifact
from griptape.mixins import SerializableMixin


@define
class BaseImageQueryModelDriver(SerializableMixin, ABC):
    """Interface for model drivers that build image-query payloads and parse responses."""

    @abstractmethod
    def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_tokens: int) -> dict:
        """Build the model-specific request payload for an image query."""
        ...

    @abstractmethod
    def process_output(self, output: dict) -> TextArtifact:
        """Extract a TextArtifact from the model's raw response body."""
        ...
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations
from typing import Optional
from attr import field, define
from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.drivers import BaseImageQueryModelDriver


@define
class BedrockClaudeImageQueryModelDriver(BaseImageQueryModelDriver):
    """Image query model driver for Anthropic Claude models hosted on Amazon Bedrock."""

    # Static version string required by AWS:
    # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#api-inference-examples-claude-multimodal-code-example
    ANTHROPIC_VERSION = "bedrock-2023-05-31"

    def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_tokens: int) -> dict:
        """Build the Bedrock Claude Messages API payload for an image query."""
        # Images come first, followed by the text query, matching the Claude multi-modal examples.
        content = [self._construct_image_message(image) for image in images]
        content.append(self._construct_text_message(query))
        messages = self._construct_messages(content)

        return {"messages": messages, "anthropic_version": self.ANTHROPIC_VERSION, "max_tokens": max_tokens}

    def process_output(self, output: dict) -> TextArtifact:
        """Extract the first text content block from the model response.

        Raises:
            ValueError: If the response contains no content blocks.
        """
        content_blocks = output["content"]
        if len(content_blocks) < 1:
            raise ValueError("Response content is empty")

        return TextArtifact(content_blocks[0]["text"])

    def _construct_image_message(self, image_data: ImageArtifact) -> dict:
        # Claude expects base64-encoded image data with an explicit media type.
        # `media_type` avoids shadowing the builtin `type`.
        return {
            "source": {"data": image_data.base64, "media_type": image_data.mime_type, "type": "base64"},
            "type": "image",
        }

    def _construct_text_message(self, query: str) -> dict:
        return {"text": query, "type": "text"}

    def _construct_messages(self, content: list) -> list:
        return [{"content": content, "role": "user"}]
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

@define
class BedrockClaudePromptModelDriver(BasePromptModelDriver):
ANTHROPIC_VERSION = "bedrock-2023-05-31" # static string for AWS: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#api-inference-examples-claude-multimodal-code-example

top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True})
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
_tokenizer: BedrockClaudeTokenizer = field(default=None, kw_only=True)
prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)
anthropic_version: str = field(default="bedrock-2023-05-31", kw_only=True, metadata={"serializable": True})

@property
def tokenizer(self) -> BedrockClaudeTokenizer:
Expand Down Expand Up @@ -60,7 +61,7 @@ def prompt_stack_to_model_params(self, prompt_stack: PromptStack) -> dict:
"top_p": self.top_p,
"top_k": self.top_k,
"max_tokens": self.prompt_driver.max_output_tokens(self.prompt_driver.prompt_stack_to_string(prompt_stack)),
"anthropic_version": self.anthropic_version,
"anthropic_version": self.ANTHROPIC_VERSION,
**input,
}

Expand Down

0 comments on commit f96af2e

Please sign in to comment.