Skip to content

Commit

Permalink
Add MediaArtifact (#520)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewfrench committed Apr 12, 2024
1 parent 2cc293c commit 54ef8fb
Show file tree
Hide file tree
Showing 31 changed files with 164 additions and 94 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `input_task_input` and `input_task_output` fields to `StartStructureRunEvent`.
- `output_task_input` and `output_task_output` fields to `FinishStructureRunEvent`.
- `AmazonS3FileManagerDriver` for managing files on Amazon S3.
- `MediaArtifact` as a base class for `ImageArtifact` and future media Artifacts.

### Changed
- **BREAKING**: Secret fields (ex: api_key) removed from serialized Drivers.
Expand All @@ -33,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Replaced `EventListener.handler` with `EventListener.driver` and `LocalEventListenerDriver`.
- Improved RAG performance in `VectorQueryEngine`.
- **BREAKING**: Removed `workdir`, `loaders`, `default_loader`, and `save_file_encoding` fields from `FileManager` and added `file_manager_driver`.
- **BREADKING**: Removed `mime_type` field from `ImageArtifact`. `mime_type` is now a property constructed using the Artifact type and `format` field.

## [0.24.2] - 2024-04-04

Expand Down
2 changes: 2 additions & 0 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .blob_artifact import BlobArtifact
from .csv_row_artifact import CsvRowArtifact
from .list_artifact import ListArtifact
from .media_artifact import MediaArtifact
from .image_artifact import ImageArtifact


Expand All @@ -17,4 +18,5 @@
"CsvRowArtifact",
"ListArtifact",
"ImageArtifact",
"MediaArtifact",
]
46 changes: 11 additions & 35 deletions griptape/artifacts/image_artifact.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,23 @@
from __future__ import annotations

import base64
import string
import time
import random
from typing import Optional
from attr import define, field, Factory
from griptape.artifacts import BlobArtifact
from attr import define, field

from griptape.artifacts import MediaArtifact


@define
class ImageArtifact(BlobArtifact):
"""ImageArtifact is a type of BlobArtifact that represents an image.
class ImageArtifact(MediaArtifact):
"""ImageArtifact is a type of MediaArtifact representing an image.
Attributes:
value: Raw bytes representing the image.
value: Raw bytes representing media data.
media_type: The type of media, defaults to "image".
format: The format of the media, like png, jpeg, or gif.
name: Artifact name, generated using creation time and a random string.
mime_type: The mime type of the image, like image/png or image/jpeg.
width: The width of the image in pixels.
height: The height of the image in pixels.
model: Optionally specify the model used to generate the image.
prompt: Optionally specify the prompt used to generate the image.
model: Optionally specify the model used to generate the media.
prompt: Optionally specify the prompt used to generate the media.
"""

mime_type: str = field(kw_only=True, default="image/png", metadata={"serializable": True})
media_type: str = "image"
width: int = field(kw_only=True, metadata={"serializable": True})
height: int = field(kw_only=True, metadata={"serializable": True})
model: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
prompt: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
name: str = field(
default=Factory(lambda self: self.make_name(), takes_self=True), kw_only=True, metadata={"serializable": True}
)

@property
def base64(self) -> str:
return base64.b64encode(self.value).decode("utf-8")

def make_name(self) -> str:
entropy = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
fmt_time = time.strftime("%y%m%d%H%M%S", time.localtime())
extension = self.mime_type.split("/")[1].split("+")[0]

return f"image_artifact_{fmt_time}_{entropy}.{extension}"

def to_text(self) -> str:
return f"Image, dimensions: {self.width}x{self.height}, type: {self.mime_type}, size: {len(self.value)} bytes"
53 changes: 53 additions & 0 deletions griptape/artifacts/media_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

import string
import time
import random
from typing import Optional

from attr import define, field
from griptape.artifacts import BlobArtifact
import base64


@define
class MediaArtifact(BlobArtifact):
"""MediaArtifact is a type of BlobArtifact that represents media (image, audio, video, etc.)
and can be extended to support a specific media type.
Attributes:
value: Raw bytes representing media data.
media_type: The type of media, like image, audio, or video.
format: The format of the media, like png, wav, or mp4.
name: Artifact name, generated using creation time and a random string.
model: Optionally specify the model used to generate the media.
prompt: Optionally specify the prompt used to generate the media.
"""

media_type: str = field(default="media", kw_only=True, metadata={"serializable": True})
format: str = field(kw_only=True, metadata={"serializable": True})
model: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
prompt: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

def __attrs_post_init__(self):
# Generating the name string requires attributes set by child classes.
# This waits until all attributes are available before generating a name.
if self.name == self.id:
self.name = self.make_name()

@property
def mime_type(self) -> str:
return f"{self.media_type}/{self.format}"

@property
def base64(self) -> str:
return base64.b64encode(self.value).decode("utf-8")

def to_text(self) -> str:
return f"Media, type: {self.mime_type}, size: {len(self.value)} bytes"

def make_name(self) -> str:
entropy = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
fmt_time = time.strftime("%y%m%d%H%M%S", time.localtime())

return f"{self.media_type}_artifact_{fmt_time}_{entropy}.{self.format}"
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[
return ImageArtifact(
prompt=", ".join(prompts),
value=image_bytes,
mime_type="image/png",
format="png",
width=self.image_width,
height=self.image_height,
model=self.model,
Expand All @@ -62,7 +62,7 @@ def try_image_variation(
return ImageArtifact(
prompt=", ".join(prompts),
value=image_bytes,
mime_type="image/png",
format="png",
width=image.width,
height=image.height,
model=self.model,
Expand All @@ -84,7 +84,7 @@ def try_image_inpainting(
return ImageArtifact(
prompt=", ".join(prompts),
value=image_bytes,
mime_type="image/png",
format="png",
width=image.width,
height=image.height,
model=self.model,
Expand All @@ -106,7 +106,7 @@ def try_image_outpainting(
return ImageArtifact(
prompt=", ".join(prompts),
value=image_bytes,
mime_type="image/png",
format="png",
width=image.width,
height=image.height,
model=self.model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[

return ImageArtifact(
value=image_data,
mime_type="image/png",
format="png",
width=self.image_width,
height=self.image_height,
model=self.model,
Expand All @@ -75,7 +75,7 @@ def try_image_variation(

return ImageArtifact(
value=image_data,
mime_type="image/png",
format="png",
width=self.image_width,
height=self.image_height,
model=self.model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _parse_image_response(self, response: ImagesResponse, prompt: str) -> ImageA

return ImageArtifact(
value=image_data,
mime_type="image/png",
format="png",
width=image_dimensions[0],
height=image_dimensions[1],
model=self.model,
Expand Down
4 changes: 1 addition & 3 deletions griptape/loaders/image_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def load(self, source: bytes, *args, **kwargs) -> ImageArtifact:
image = Image.open(byte_stream)
source = byte_stream.getvalue()

image_artifact = ImageArtifact(
source, mime_type=self._get_mime_type(image.format), width=image.width, height=image.height
)
image_artifact = ImageArtifact(source, format=image.format.lower(), width=image.width, height=image.height)

return image_artifact

Expand Down
2 changes: 1 addition & 1 deletion tests/mocks/mock_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ def input(self, value: str):
self._input = TextArtifact(value)

def run(self) -> ImageArtifact:
return ImageArtifact(value=b"image data", mime_type="image/png", width=512, height=512)
return ImageArtifact(value=b"image data", format="png", width=512, height=512)
4 changes: 2 additions & 2 deletions tests/unit/artifacts/test_base_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def test_image_artifact_from_dict(self):
dict_value = {
"type": "ImageArtifact",
"value": b"aW1hZ2UgZGF0YQ==",
"mime_type": "image/png",
"dir_name": "foo",
"format": "png",
"width": 256,
"height": 256,
"model": "test-model",
Expand All @@ -60,7 +60,7 @@ def test_image_artifact_from_dict(self):
artifact = BaseArtifact.from_dict(dict_value)

assert isinstance(artifact, ImageArtifact)
assert artifact.to_text() == "Image, dimensions: 256x256, type: image/png, size: 10 bytes"
assert artifact.to_text() == "Media, type: image/png, size: 10 bytes"
assert artifact.value == b"image data"

def test_unsupported_from_dict(self):
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/artifacts/test_base_media_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from attr import define

from griptape.artifacts import MediaArtifact


class TestMediaArtifact:
@define
class ImaginaryMediaArtifact(MediaArtifact):
media_type: str = "imagination"

@pytest.fixture
def media_artifact(self):
return self.ImaginaryMediaArtifact(value=b"some binary dream data", format="dream")

def test_to_dict(self, media_artifact):
image_dict = media_artifact.to_dict()

assert image_dict["format"] == "dream"
assert image_dict["value"] == "c29tZSBiaW5hcnkgZHJlYW0gZGF0YQ=="

def test_name(self, media_artifact):
assert media_artifact.name.startswith("imagination_artifact")
assert media_artifact.name.endswith(".dream")

def test_mime_type(self, media_artifact):
assert media_artifact.mime_type == "imagination/dream"

def test_to_text(self, media_artifact):
assert media_artifact.to_text() == "Media, type: imagination/dream, size: 22 bytes"
11 changes: 6 additions & 5 deletions tests/unit/artifacts/test_image_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@ class TestImageArtifact:
def image_artifact(self):
return ImageArtifact(
value=b"some binary png image data",
mime_type="image/png",
format="png",
width=512,
height=512,
model="openai/dalle2",
prompt="a cute cat",
)

def test_to_text(self, image_artifact):
assert image_artifact.to_text() == "Image, dimensions: 512x512, type: image/png, size: 26 bytes"
def test_to_text(self, image_artifact: ImageArtifact):
assert image_artifact.to_text() == "Media, type: image/png, size: 26 bytes"

def test_to_dict(self, image_artifact):
def test_to_dict(self, image_artifact: ImageArtifact):
image_dict = image_artifact.to_dict()

assert image_dict["mime_type"] == "image/png"
assert image_dict["format"] == "png"
assert image_dict["width"] == 512
assert image_dict["height"] == 512
assert image_dict["model"] == "openai/dalle2"
Expand All @@ -35,6 +35,7 @@ def test_deserialization(self, image_artifact):

assert deserialized_artifact.value == b"some binary png image data"
assert deserialized_artifact.mime_type == "image/png"
assert deserialized_artifact.format == "png"
assert deserialized_artifact.width == 512
assert deserialized_artifact.height == 512
assert deserialized_artifact.model == "openai/dalle2"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def test_init(self, driver):

def test_init_requires_image_generation_model_driver(self, session):
with pytest.raises(TypeError):
AmazonBedrockImageGenerationDriver(session=session, model="stability.stable-diffusion-xl-v1")
AmazonBedrockImageGenerationDriver(
session=session, model="stability.stable-diffusion-xl-v1"
) # pyright: ignore

def test_try_text_to_image(self, driver):
driver.bedrock_client.invoke_model.return_value = {
Expand All @@ -56,6 +58,7 @@ def test_try_text_to_image(self, driver):
image_artifact = driver.try_text_to_image(prompts=["test prompt"], negative_prompts=["test negative prompt"])

assert image_artifact.value == b"image data"
assert image_artifact.format == "png"
assert image_artifact.mime_type == "image/png"
assert image_artifact.width == 512
assert image_artifact.height == 512
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,22 @@ def test_try_image_variation(self, image_generation_driver):
with pytest.raises(DummyException):
image_generation_driver.try_image_variation(
"prompt-stack",
ImageArtifact(value=b"", width=100, height=100),
ImageArtifact(value=b"", width=100, height=100),
ImageArtifact(value=b"", width=100, height=100, format="png"),
ImageArtifact(value=b"", width=100, height=100, format="png"),
)

def test_try_image_inpainting(self, image_generation_driver):
with pytest.raises(DummyException):
image_generation_driver.try_image_inpainting(
"prompt-stack",
ImageArtifact(value=b"", width=100, height=100),
ImageArtifact(value=b"", width=100, height=100),
ImageArtifact(value=b"", width=100, height=100, format="png"),
ImageArtifact(value=b"", width=100, height=100, format="png"),
)

def test_try_image_outpainting(self, image_generation_driver):
with pytest.raises(DummyException):
image_generation_driver.try_image_outpainting(
"prompt-stack",
ImageArtifact(value=b"", width=100, height=100),
ImageArtifact(value=b"", width=100, height=100),
ImageArtifact(value=b"", width=100, height=100, format="png"),
ImageArtifact(value=b"", width=100, height=100, format="png"),
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ def model_driver(self):

@pytest.fixture
def image_artifact(self):
return ImageArtifact(b"image", mime_type="image/png", width=1024, height=1024)
return ImageArtifact(b"image", format="png", width=1024, height=1024)

@pytest.fixture
def mask_artifact(self):
return ImageArtifact(b"mask", mime_type="image/png", width=1024, height=1024)
return ImageArtifact(b"mask", format="png", width=1024, height=1024)

def test_init(self, model_driver):
assert model_driver
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ def model_driver(self):

@pytest.fixture
def image_artifact(self):
return ImageArtifact(b"image", mime_type="image/png", width=1024, height=512)
return ImageArtifact(b"image", format="png", width=1024, height=512)

@pytest.fixture
def mask_artifact(self):
return ImageArtifact(b"mask", mime_type="image/png", width=1024, height=512)
return ImageArtifact(b"mask", format="png", width=1024, height=512)

def test_init(self, model_driver):
assert model_driver
Expand Down
Loading

0 comments on commit 54ef8fb

Please sign in to comment.