Claude 3 Image Query Support #700

Merged · 26 commits · Mar 21, 2024
Changes from 21 commits
Commits:
5b56352
Adding anthropic image query driver
emjay07 Mar 18, 2024
c7496c0
adding base files for bedrock-claude support
emjay07 Mar 18, 2024
afb28a5
adding bedrock driver and claude model driver
emjay07 Mar 19, 2024
0d8107c
reformatting
emjay07 Mar 19, 2024
da3e5e8
adding tests for anthropic driver
emjay07 Mar 19, 2024
f46c9ca
adding models tests
emjay07 Mar 19, 2024
e5e1221
adding max_tokens defaults back in
emjay07 Mar 19, 2024
059a4e2
resolving merge conflicts
emjay07 Mar 19, 2024
8af40e5
updating lock file for new dependencies
emjay07 Mar 19, 2024
9c4b5f2
stripping the extra structures around the output
emjay07 Mar 19, 2024
685fba3
adding OpenAIVisionImageQueryDriver testing
emjay07 Mar 20, 2024
ca814c4
updating to use ImageArtifact's base64 method
emjay07 Mar 20, 2024
3565689
updating whitespace and setting defaults
emjay07 Mar 20, 2024
7c38119
addressing pr comments for renames and minor refactors
emjay07 Mar 20, 2024
41157b4
updating toml after merge
emjay07 Mar 21, 2024
7bbdbd6
updating lock file after merge
emjay07 Mar 21, 2024
294dac1
updating changelog after merge
emjay07 Mar 21, 2024
fbee2d8
bringing lock file back up to dev
emjay07 Mar 21, 2024
978c8b9
updating drivers to make max_output_tokens be required
emjay07 Mar 21, 2024
3ee66c1
updating structures and anthropic_version to static
emjay07 Mar 21, 2024
8a2e640
updating changelog
emjay07 Mar 21, 2024
5528638
nit in changelog and surfacing better errors in bedrock image query d…
emjay07 Mar 21, 2024
ec6e7ea
updating name for static field
emjay07 Mar 21, 2024
91e3e0e
fixing naming in tests
emjay07 Mar 21, 2024
b70e373
renaming to max_tokens
emjay07 Mar 21, 2024
f8fc43a
removing default model from openai's image query driver
emjay07 Mar 21, 2024
5 changes: 4 additions & 1 deletion .gitignore
@@ -49,4 +49,7 @@ dist/**
# translation files
*.xlf
*.nls.*.json
*.i18n.json
*.i18n.json

# asdf
.tool-versions
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for `claude-3-opus`, `claude-3-sonnet`, and `claude-3-haiku` in `AnthropicPromptDriver`.
- Support for `anthropic.claude-3-sonnet-20240229-v1:0` and `anthropic.claude-3-haiku-20240307-v1:0` in `BedrockClaudePromptModelDriver`.
- `top_k` and `top_p` parameters in `AnthropicPromptDriver`.
- `AnthropicImageQueryDriver` for Claude 3 multi-modal models.
- `AmazonBedrockImageQueryDriver`, along with `BedrockClaudeImageQueryModelDriver`, for Claude 3 support in Bedrock.
- `BaseWebScraperDriver` allowing multiple web scraping implementations.
- `TrafilaturaWebScraperDriver` for scraping text from web pages using trafilatura.
- `MarkdownifyWebScraperDriver` for scraping text from web pages using playwright and converting to markdown using markdownify.
@@ -33,7 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `BaseTextLoader` to accept a `BaseChunker`.
- Default model of `AmazonBedrockStructureConfig` to `anthropic.claude-3-sonnet-20240229-v1:0`.
- `AnthropicPromptDriver` and `BedrockClaudePromptModelDriver` to use Anthropic's Messages API.

- `OpenAiVisionImageQueryDriver` now has a required `max_output_tokens` field with a default of 256.

## [0.23.2] - 2024-03-15

6 changes: 6 additions & 0 deletions griptape/config/amazon_bedrock_structure_config.py
@@ -12,9 +12,11 @@
)
from griptape.drivers import (
AmazonBedrockImageGenerationDriver,
AmazonBedrockImageQueryDriver,
AmazonBedrockPromptDriver,
AmazonBedrockTitanEmbeddingDriver,
BedrockClaudePromptModelDriver,
BedrockClaudeImageQueryModelDriver,
BedrockTitanImageGenerationModelDriver,
LocalVectorStoreDriver,
)
@@ -34,6 +36,10 @@ class AmazonBedrockStructureConfig(BaseStructureConfig):
model="amazon.titan-image-generator-v1",
image_generation_model_driver=BedrockTitanImageGenerationModelDriver(),
),
image_query_driver=AmazonBedrockImageQueryDriver(
model="anthropic.claude-3-sonnet-20240229-v1:0",
image_query_model_driver=BedrockClaudeImageQueryModelDriver(),
),
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1"),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=AmazonBedrockTitanEmbeddingDriver(model="amazon.titan-embed-text-v1")
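With this change, AmazonBedrockStructureConfig wires up an image query driver by default, pointing at Claude 3 Sonnet through the new BedrockClaudeImageQueryModelDriver. A minimal sketch of inspecting that default follows (not part of the diff; the global_drivers attribute name is an assumption based on the StructureGlobalDriversConfig factory shown above):

from griptape.config import AmazonBedrockStructureConfig

config = AmazonBedrockStructureConfig()

# Assumed accessor for the StructureGlobalDriversConfig instance built by the factory above.
image_query_driver = config.global_drivers.image_query_driver
print(type(image_query_driver).__name__)  # AmazonBedrockImageQueryDriver
print(image_query_driver.model)           # anthropic.claude-3-sonnet-20240229-v1:0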
8 changes: 7 additions & 1 deletion griptape/config/anthropic_structure_config.py
@@ -10,7 +10,12 @@
StructureTaskMemoryQueryEngineConfig,
StructureTaskMemorySummaryEngineConfig,
)
from griptape.drivers import LocalVectorStoreDriver, AnthropicPromptDriver, VoyageAiEmbeddingDriver
from griptape.drivers import (
LocalVectorStoreDriver,
AnthropicPromptDriver,
AnthropicImageQueryDriver,
VoyageAiEmbeddingDriver,
)


@define
@@ -23,6 +28,7 @@ class AnthropicStructureConfig(BaseStructureConfig):
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=VoyageAiEmbeddingDriver(model="voyage-large-2")
),
image_query_driver=AnthropicImageQueryDriver(model="claude-3-opus-20240229"),
)
),
kw_only=True,
2 changes: 1 addition & 1 deletion griptape/config/openai_structure_config.py
@@ -26,7 +26,7 @@ class OpenAiStructureConfig(BaseStructureConfig):
lambda: StructureGlobalDriversConfig(
prompt_driver=OpenAiChatPromptDriver(model="gpt-4"),
image_generation_driver=OpenAiImageGenerationDriver(model="dall-e-2", image_size="512x512"),
image_query_driver=OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview", max_tokens=300),
image_query_driver=OpenAiVisionImageQueryDriver(model="gpt-4-vision-preview"),
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small"),
vector_store_driver=LocalVectorStoreDriver(
embedding_driver=OpenAiEmbeddingDriver(model="text-embedding-3-small")
13 changes: 12 additions & 1 deletion griptape/drivers/__init__.py
@@ -72,9 +72,15 @@
from .image_generation.azure_openai_image_generation_driver import AzureOpenAiImageGenerationDriver
from .image_generation.dummy_image_generation_driver import DummyImageGenerationDriver

from .image_query_model.base_image_query_model_driver import BaseImageQueryModelDriver
from .image_query_model.bedrock_claude_image_query_model_driver import BedrockClaudeImageQueryModelDriver

from .image_query.base_image_query_driver import BaseImageQueryDriver
from .image_query.openai_vision_image_query_driver import OpenAiVisionImageQueryDriver
from .image_query.base_multi_model_image_query_driver import BaseMultiModelImageQueryDriver
from .image_query.dummy_image_query_driver import DummyImageQueryDriver
from .image_query.openai_vision_image_query_driver import OpenAiVisionImageQueryDriver
from .image_query.anthropic_image_query_driver import AnthropicImageQueryDriver
from .image_query.amazon_bedrock_image_query_driver import AmazonBedrockImageQueryDriver

from .web_scraper.base_web_scraper_driver import BaseWebScraperDriver
from .web_scraper.trafilatura_web_scraper_driver import TrafilaturaWebScraperDriver
@@ -144,9 +150,14 @@
"AmazonBedrockImageGenerationDriver",
"AzureOpenAiImageGenerationDriver",
"DummyImageGenerationDriver",
"BaseImageQueryModelDriver",
"BedrockClaudeImageQueryModelDriver",
"BaseImageQueryDriver",
"OpenAiVisionImageQueryDriver",
"DummyImageQueryDriver",
"AnthropicImageQueryDriver",
"BaseMultiModelImageQueryDriver",
"AmazonBedrockImageQueryDriver",
"BaseWebScraperDriver",
"TrafilaturaWebScraperDriver",
"MarkdownifyWebScraperDriver",
35 changes: 35 additions & 0 deletions griptape/drivers/image_query/amazon_bedrock_image_query_driver.py
@@ -0,0 +1,35 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from attr import define, field, Factory
from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.drivers import BaseMultiModelImageQueryDriver
from griptape.utils import import_optional_dependency
import json

if TYPE_CHECKING:
import boto3


@define
class AmazonBedrockImageQueryDriver(BaseMultiModelImageQueryDriver):
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
bedrock_client: Any = field(
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True
)

def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
payload = self.image_query_model_driver.image_query_request_parameters(query, images, self.max_output_tokens)

response = self.bedrock_client.invoke_model(
modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload)
)

response_body = json.loads(response.get("body").read())

if response_body is None:
raise ValueError("Model response is empty")

try:
return self.image_query_model_driver.process_output(response_body)
except Exception as e:
raise e
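For context, a usage sketch of the new driver (not part of the diff). It pairs AmazonBedrockImageQueryDriver with BedrockClaudeImageQueryModelDriver and calls try_query, the method defined above, directly; loading the image through griptape's ImageLoader is an assumption, and AWS credentials must be available to boto3.

import boto3

from griptape.drivers import AmazonBedrockImageQueryDriver, BedrockClaudeImageQueryModelDriver
from griptape.loaders import ImageLoader

driver = AmazonBedrockImageQueryDriver(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    image_query_model_driver=BedrockClaudeImageQueryModelDriver(),
    session=boto3.Session(region_name="us-east-1"),  # optional; a default Session is created otherwise
    max_output_tokens=512,  # field name at this point in the PR; later renamed back to max_tokens
)

with open("photo.png", "rb") as f:
    image = ImageLoader().load(f.read())

result = driver.try_query("What is pictured in this image?", [image])
print(result.value)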
59 changes: 59 additions & 0 deletions griptape/drivers/image_query/anthropic_image_query_driver.py
@@ -0,0 +1,59 @@
from __future__ import annotations
from typing import Optional, Any
from attr import define, field, Factory
from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.drivers import BaseImageQueryDriver
from griptape.utils import import_optional_dependency


@define
class AnthropicImageQueryDriver(BaseImageQueryDriver):
"""
Attributes:
api_key: Anthropic API key.
model: Anthropic model name.
client: Custom `Anthropic` client.
"""

api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
model: str = field(kw_only=True, metadata={"serializable": True})
client: Any = field(
default=Factory(
lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True
),
kw_only=True,
)

def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
if self.max_output_tokens is None:
raise TypeError("max_output_tokens can't be empty")

response = self.client.messages.create(**self._base_params(query, images))
content_blocks = response.content

if len(content_blocks) < 1:
raise ValueError("Response content is empty")

text_content = content_blocks[0].text

return TextArtifact(text_content)

def _base_params(self, text_query: str, images: list[ImageArtifact]):
content = [self._construct_image_message(image) for image in images]
content.append(self._construct_text_message(text_query))
messages = self._construct_messages(content)
params = {"model": self.model, "messages": messages, "max_tokens": self.max_output_tokens}

return params

def _construct_image_message(self, image_data: ImageArtifact) -> dict:
data = image_data.base64
type = image_data.mime_type

return {"source": {"data": data, "media_type": type, "type": "base64"}, "type": "image"}

def _construct_text_message(self, query: str) -> dict:
return {"text": query, "type": "text"}

def _construct_messages(self, content: list) -> list:
return [{"content": content, "role": "user"}]
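A usage sketch for the Anthropic driver (not part of the diff). It assumes an Anthropic API key passed through the api_key field and griptape's ImageLoader for building the ImageArtifact; try_query is the method defined above.

import os

from griptape.drivers import AnthropicImageQueryDriver
from griptape.loaders import ImageLoader

driver = AnthropicImageQueryDriver(
    model="claude-3-sonnet-20240229",
    api_key=os.environ["ANTHROPIC_API_KEY"],
    max_output_tokens=512,  # field name at this point in the PR; later renamed back to max_tokens
)

with open("receipt.jpg", "rb") as f:
    image = ImageLoader().load(f.read())

result = driver.try_query("Summarize the line items in this receipt.", [image])
print(result.value)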
1 change: 1 addition & 0 deletions griptape/drivers/image_query/base_image_query_driver.py
@@ -16,6 +16,7 @@
@define
class BaseImageQueryDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
structure: Optional[Structure] = field(default=None, kw_only=True)
max_output_tokens: int = field(default=256, kw_only=True, metadata={"serializable": True})
Review thread on this line:

Member:
Let's keep this named max_tokens for consistency with the other Drivers.

emjay07 (author):
I thought I saw a PR that updated max_tokens to max_input_tokens and max_output_tokens? Or was that only when a tokenizer is provided?

I'm fine with updating this to match the rest of the drivers, but I'd argue max_output_tokens is the better name because that is explicitly what it is. Is the same true for the other drivers?

Member:
Agreed that max_output_tokens is probably a better field name, but that PR was only in the context of the Tokenizers, not the Drivers.

emjay07 (author):
All the other drivers take text as input, not images, so to clarify my question: is max_tokens in those drivers intended for output tokens or input tokens?

emjay07 (author):
Updated back to max_tokens.

def before_run(self, query: str, images: list[ImageArtifact]) -> None:
if self.structure:
griptape/drivers/image_query/base_multi_model_image_query_driver.py (new file)
@@ -0,0 +1,22 @@
from __future__ import annotations
from abc import ABC

from attr import field, define

from griptape.drivers import BaseImageQueryDriver, BaseImageQueryModelDriver


@define
class BaseMultiModelImageQueryDriver(BaseImageQueryDriver, ABC):
"""Image Query Driver for platforms like Amazon Bedrock that host many LLM models.

Instances of this Image Query Driver require an Image Query Model Driver, which is used to structure the
image query request in the format required by the model and to process the output.

Attributes:
model: Model name to use.
image_query_model_driver: Image Query Model Driver to use.
"""

model: str = field(kw_only=True, metadata={"serializable": True})
image_query_model_driver: BaseImageQueryModelDriver = field(kw_only=True, metadata={"serializable": True})
1 change: 1 addition & 0 deletions griptape/drivers/image_query/dummy_image_query_driver.py
@@ -8,6 +8,7 @@
@define
class DummyImageQueryDriver(BaseImageQueryDriver):
model: str = field(init=False)
max_output_tokens: int = field(init=False)

def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
raise DummyException(__class__.__name__, "try_query")
griptape/drivers/image_query/openai_vision_image_query_driver.py
@@ -24,7 +24,6 @@ class OpenAiVisionImageQueryDriver(BaseImageQueryDriver):
api_key: Optional[str] = field(default=None, kw_only=True)
organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
image_quality: Literal["auto", "low", "high"] = field(default="auto", kw_only=True, metadata={"serializable": True})
max_tokens: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
client: openai.OpenAI = field(
default=Factory(
lambda self: openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization),
@@ -46,10 +45,7 @@ def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact:
)

messages = ChatCompletionUserMessageParam(content=message_parts, role="user")
params = {"model": self.model, "messages": [messages]}

if self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
params = {"model": self.model, "messages": [messages], "max_tokens": self.max_output_tokens}

response = self.client.chat.completions.create(**params)

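After this change the OpenAI driver no longer carries its own optional max_tokens field; the limit comes from the base class and is always sent to the API. A sketch of constructing the driver in this intermediate state (not part of the diff; ImageLoader usage is an assumption, and the field was later renamed back to max_tokens per the review thread above):

from griptape.drivers import OpenAiVisionImageQueryDriver
from griptape.loaders import ImageLoader

driver = OpenAiVisionImageQueryDriver(
    model="gpt-4-vision-preview",
    max_output_tokens=300,  # replaces the old per-driver Optional[int] max_tokens field
)

with open("chart.png", "rb") as f:
    image = ImageLoader().load(f.read())

print(driver.try_query("Describe the trend shown in this chart.", [image]).value)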
Empty file (likely griptape/drivers/image_query_model/__init__.py).
griptape/drivers/image_query_model/base_image_query_model_driver.py (new file)
@@ -0,0 +1,16 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from attr import define
from griptape.artifacts import TextArtifact, ImageArtifact
from griptape.mixins import SerializableMixin


@define
class BaseImageQueryModelDriver(SerializableMixin, ABC):
@abstractmethod
def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_output_tokens: int) -> dict:
...

@abstractmethod
def process_output(self, output: dict) -> TextArtifact:
...
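The abstract class above defines the two hooks a model driver must implement: one to build the provider-specific request body and one to turn the raw response into a TextArtifact. A hypothetical implementation for illustration only (the class name, payload shape, and "answer" key are invented, not from the PR):

from __future__ import annotations

from attr import define

from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.drivers import BaseImageQueryModelDriver


@define
class ExampleImageQueryModelDriver(BaseImageQueryModelDriver):
    def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_output_tokens: int) -> dict:
        # Build whatever request body the hosted model expects.
        return {
            "prompt": query,
            "images": [image.base64 for image in images],
            "max_tokens": max_output_tokens,
        }

    def process_output(self, output: dict) -> TextArtifact:
        # Pull the text answer out of the model-specific response body.
        return TextArtifact(output["answer"])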
griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py (new file)
@@ -0,0 +1,43 @@
from __future__ import annotations
from typing import Optional
from attr import field, define
from griptape.artifacts import ImageArtifact, TextArtifact
from griptape.drivers import BaseImageQueryModelDriver


@define
class BedrockClaudeImageQueryModelDriver(BaseImageQueryModelDriver):
anthropic_version: str = "bedrock-2023-05-31" # static string for AWS: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#api-inference-examples-claude-multimodal-code-example

def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_output_tokens: int) -> dict:
content = [self._construct_image_message(image) for image in images]
content.append(self._construct_text_message(query))
messages = self._construct_messages(content)
input_params = {
"messages": messages,
"anthropic_version": self.anthropic_version,
"max_tokens": max_output_tokens,
}

return input_params

def process_output(self, output: dict) -> TextArtifact:
content_blocks = output["content"]
if len(content_blocks) < 1:
raise ValueError("Response content is empty")

text_content = content_blocks[0]["text"]

return TextArtifact(text_content)

def _construct_image_message(self, image_data: ImageArtifact) -> dict:
data = image_data.base64
type = image_data.mime_type

return {"source": {"data": data, "media_type": type, "type": "base64"}, "type": "image"}

def _construct_text_message(self, query: str) -> dict:
return {"text": query, "type": "text"}

def _construct_messages(self, content: list) -> list:
return [{"content": content, "role": "user"}]
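To make the Bedrock request shape concrete, here is a small sketch (not part of the diff) of what image_query_request_parameters builds and how process_output unwraps a response body; loading the image through ImageLoader is an assumption.

from griptape.drivers import BedrockClaudeImageQueryModelDriver
from griptape.loaders import ImageLoader

model_driver = BedrockClaudeImageQueryModelDriver()

with open("diagram.png", "rb") as f:
    image = ImageLoader().load(f.read())

params = model_driver.image_query_request_parameters("What does this diagram show?", [image], 256)
# params is roughly:
# {
#     "anthropic_version": "bedrock-2023-05-31",
#     "max_tokens": 256,
#     "messages": [{"role": "user", "content": [<base64 image block>, {"type": "text", "text": "..."}]}],
# }

# process_output() pulls the text out of the first content block of the response body:
artifact = model_driver.process_output({"content": [{"type": "text", "text": "A network topology."}]})
print(artifact.value)  # A network topology.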
@@ -14,7 +14,7 @@ class BedrockClaudePromptModelDriver(BasePromptModelDriver):
top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
_tokenizer: BedrockClaudeTokenizer = field(default=None, kw_only=True)
prompt_driver: Optional[AmazonBedrockPromptDriver] = field(default=None, kw_only=True)
anthropic_version: str = field(default="bedrock-2023-05-31", kw_only=True, metadata={"serializable": True})
anthropic_version: str = "bedrock-2023-05-31" # static string for AWS: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#api-inference-examples-claude-multimodal-code-example

@property
def tokenizer(self) -> BedrockClaudeTokenizer: