-
Notifications
You must be signed in to change notification settings - Fork 153
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
25 changed files
with
488 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,4 +49,7 @@ dist/** | |
# translation files | ||
*.xlf | ||
*.nls.*.json | ||
*.i18n.json | ||
*.i18n.json | ||
|
||
# asdf | ||
.tool-versions |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
35 changes: 35 additions & 0 deletions
35
griptape/drivers/image_query/amazon_bedrock_image_query_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from __future__ import annotations | ||
from typing import TYPE_CHECKING, Any | ||
from attr import define, field, Factory | ||
from griptape.artifacts import ImageArtifact, TextArtifact | ||
from griptape.drivers import BaseMultiModelImageQueryDriver | ||
from griptape.utils import import_optional_dependency | ||
import json | ||
|
||
if TYPE_CHECKING: | ||
import boto3 | ||
|
||
|
||
@define | ||
class AmazonBedrockImageQueryDriver(BaseMultiModelImageQueryDriver): | ||
session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True) | ||
bedrock_client: Any = field( | ||
default=Factory(lambda self: self.session.client("bedrock-runtime"), takes_self=True), kw_only=True | ||
) | ||
|
||
def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: | ||
payload = self.image_query_model_driver.image_query_request_parameters(query, images, self.max_tokens) | ||
|
||
response = self.bedrock_client.invoke_model( | ||
modelId=self.model, contentType="application/json", accept="application/json", body=json.dumps(payload) | ||
) | ||
|
||
response_body = json.loads(response.get("body").read()) | ||
|
||
if response_body is None: | ||
raise ValueError("Model response is empty") | ||
|
||
try: | ||
return self.image_query_model_driver.process_output(response_body) | ||
except Exception as e: | ||
raise ValueError(f"Output is unable to be processed as returned {e}") |
59 changes: 59 additions & 0 deletions
59
griptape/drivers/image_query/anthropic_image_query_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from __future__ import annotations | ||
from typing import Optional, Any | ||
from attr import define, field, Factory | ||
from griptape.artifacts import ImageArtifact, TextArtifact | ||
from griptape.drivers import BaseImageQueryDriver | ||
from griptape.utils import import_optional_dependency | ||
|
||
|
||
@define | ||
class AnthropicImageQueryDriver(BaseImageQueryDriver): | ||
""" | ||
Attributes: | ||
api_key: Anthropic API key. | ||
model: Anthropic model name. | ||
client: Custom `Anthropic` client. | ||
""" | ||
|
||
api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) | ||
model: str = field(kw_only=True, metadata={"serializable": True}) | ||
client: Any = field( | ||
default=Factory( | ||
lambda self: import_optional_dependency("anthropic").Anthropic(api_key=self.api_key), takes_self=True | ||
), | ||
kw_only=True, | ||
) | ||
|
||
def try_query(self, query: str, images: list[ImageArtifact]) -> TextArtifact: | ||
if self.max_tokens is None: | ||
raise TypeError("max_output_tokens can't be empty") | ||
|
||
response = self.client.messages.create(**self._base_params(query, images)) | ||
content_blocks = response.content | ||
|
||
if len(content_blocks) < 1: | ||
raise ValueError("Response content is empty") | ||
|
||
text_content = content_blocks[0].text | ||
|
||
return TextArtifact(text_content) | ||
|
||
def _base_params(self, text_query: str, images: list[ImageArtifact]): | ||
content = [self._construct_image_message(image) for image in images] | ||
content.append(self._construct_text_message(text_query)) | ||
messages = self._construct_messages(content) | ||
params = {"model": self.model, "messages": messages, "max_tokens": self.max_tokens} | ||
|
||
return params | ||
|
||
def _construct_image_message(self, image_data: ImageArtifact) -> dict: | ||
data = image_data.base64 | ||
type = image_data.mime_type | ||
|
||
return {"source": {"data": data, "media_type": type, "type": "base64"}, "type": "image"} | ||
|
||
def _construct_text_message(self, query: str) -> dict: | ||
return {"text": query, "type": "text"} | ||
|
||
def _construct_messages(self, content: list) -> list: | ||
return [{"content": content, "role": "user"}] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
22 changes: 22 additions & 0 deletions
22
griptape/drivers/image_query/base_multi_model_image_query_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from __future__ import annotations | ||
from abc import ABC | ||
|
||
from attr import field, define | ||
|
||
from griptape.drivers import BaseImageQueryDriver, BaseImageQueryModelDriver | ||
|
||
|
||
@define | ||
class BaseMultiModelImageQueryDriver(BaseImageQueryDriver, ABC): | ||
"""Image Query Driver for platforms like Amazon Bedrock that host many LLM models. | ||
Instances of this Image Query Driver require a Image Query Model Driver which is used to structure the | ||
image generation request in the format required by the model and to process the output. | ||
Attributes: | ||
model: Model name to use | ||
image_query_model_driver: Image Model Driver to use. | ||
""" | ||
|
||
model: str = field(kw_only=True, metadata={"serializable": True}) | ||
image_query_model_driver: BaseImageQueryModelDriver = field(kw_only=True, metadata={"serializable": True}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
16 changes: 16 additions & 0 deletions
16
griptape/drivers/image_query_model/base_image_query_model_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from __future__ import annotations | ||
from abc import ABC, abstractmethod | ||
from attr import define | ||
from griptape.artifacts import TextArtifact, ImageArtifact | ||
from griptape.mixins import SerializableMixin | ||
|
||
|
||
@define | ||
class BaseImageQueryModelDriver(SerializableMixin, ABC): | ||
@abstractmethod | ||
def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_tokens: int) -> dict: | ||
... | ||
|
||
@abstractmethod | ||
def process_output(self, output: dict) -> TextArtifact: | ||
... |
39 changes: 39 additions & 0 deletions
39
griptape/drivers/image_query_model/bedrock_claude_image_query_model_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from __future__ import annotations | ||
from typing import Optional | ||
from attr import field, define | ||
from griptape.artifacts import ImageArtifact, TextArtifact | ||
from griptape.drivers import BaseImageQueryModelDriver | ||
|
||
|
||
@define | ||
class BedrockClaudeImageQueryModelDriver(BaseImageQueryModelDriver): | ||
ANTHROPIC_VERSION = "bedrock-2023-05-31" # static string for AWS: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html#api-inference-examples-claude-multimodal-code-example | ||
|
||
def image_query_request_parameters(self, query: str, images: list[ImageArtifact], max_tokens: int) -> dict: | ||
content = [self._construct_image_message(image) for image in images] | ||
content.append(self._construct_text_message(query)) | ||
messages = self._construct_messages(content) | ||
input_params = {"messages": messages, "anthropic_version": self.ANTHROPIC_VERSION, "max_tokens": max_tokens} | ||
|
||
return input_params | ||
|
||
def process_output(self, output: dict) -> TextArtifact: | ||
content_blocks = output["content"] | ||
if len(content_blocks) < 1: | ||
raise ValueError("Response content is empty") | ||
|
||
text_content = content_blocks[0]["text"] | ||
|
||
return TextArtifact(text_content) | ||
|
||
def _construct_image_message(self, image_data: ImageArtifact) -> dict: | ||
data = image_data.base64 | ||
type = image_data.mime_type | ||
|
||
return {"source": {"data": data, "media_type": type, "type": "base64"}, "type": "image"} | ||
|
||
def _construct_text_message(self, query: str) -> dict: | ||
return {"text": query, "type": "text"} | ||
|
||
def _construct_messages(self, content: list) -> list: | ||
return [{"content": content, "role": "user"}] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.