Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MediaArtifact #520

Merged
merged 7 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .blob_artifact import BlobArtifact
from .csv_row_artifact import CsvRowArtifact
from .list_artifact import ListArtifact
from .media_artifact import MediaArtifact
from .image_artifact import ImageArtifact


Expand All @@ -17,4 +18,5 @@
"CsvRowArtifact",
"ListArtifact",
"ImageArtifact",
"MediaArtifact",
]
46 changes: 11 additions & 35 deletions griptape/artifacts/image_artifact.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,23 @@
from __future__ import annotations

import base64
import string
import time
import random
from typing import Optional
from attr import define, field, Factory
from griptape.artifacts import BlobArtifact
from attr import define, field

from griptape.artifacts import MediaArtifact


@define
class ImageArtifact(BlobArtifact):
"""ImageArtifact is a type of BlobArtifact that represents an image.
class ImageArtifact(MediaArtifact):
"""ImageArtifact is a type of MediaArtifact representing an image.

Attributes:
value: Raw bytes representing the image.
value: Raw bytes representing media data.
media_type: The type of media, defaults to "image".
format: The format of the media, like png, jpeg, or gif.
name: Artifact name, generated using creation time and a random string.
mime_type: The mime type of the image, like image/png or image/jpeg.
width: The width of the image in pixels.
height: The height of the image in pixels.
model: Optionally specify the model used to generate the image.
prompt: Optionally specify the prompt used to generate the image.
model: Optionally specify the model used to generate the media.
prompt: Optionally specify the prompt used to generate the media.
"""

mime_type: str = field(kw_only=True, default="image/png", metadata={"serializable": True})
media_type: str = "image"
width: int = field(kw_only=True, metadata={"serializable": True})
height: int = field(kw_only=True, metadata={"serializable": True})
model: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
prompt: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
name: str = field(
default=Factory(lambda self: self.make_name(), takes_self=True), kw_only=True, metadata={"serializable": True}
)

@property
def base64(self) -> str:
return base64.b64encode(self.value).decode("utf-8")

def make_name(self) -> str:
entropy = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
fmt_time = time.strftime("%y%m%d%H%M%S", time.localtime())
extension = self.mime_type.split("/")[1].split("+")[0]

return f"image_artifact_{fmt_time}_{entropy}.{extension}"

def to_text(self) -> str:
return f"Image, dimensions: {self.width}x{self.height}, type: {self.mime_type}, size: {len(self.value)} bytes"
53 changes: 53 additions & 0 deletions griptape/artifacts/media_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

import string
import time
import random
from typing import Optional

from attr import define, field
from griptape.artifacts import BlobArtifact
import base64


@define
class MediaArtifact(BlobArtifact):
    """Base artifact for binary media payloads (image, audio, video, etc.).

    Intended to be subclassed for a specific media type, typically by
    overriding the ``media_type`` default.

    Attributes:
        value: Raw bytes representing media data.
        media_type: The type of media, like image, audio, or video.
        format: The format of the media, like png, wav, or mp4.
        name: Artifact name, generated using creation time and a random string.
        model: Optionally specify the model used to generate the media.
        prompt: Optionally specify the prompt used to generate the media.
    """

    media_type: str = field(default="media", kw_only=True, metadata={"serializable": True})
    format: str = field(kw_only=True, metadata={"serializable": True})
    model: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    prompt: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

    def __attrs_post_init__(self):
        """Replace the placeholder name once every attribute has been set.

        The generated name embeds ``media_type`` and ``format``, which child
        classes may set, so it cannot be computed until initialisation is done.
        """
        # A name that differs from the id is left untouched.
        # NOTE(review): presumably the inherited default for ``name`` is the
        # artifact id, so equality means "no name was supplied" - confirm
        # against the base artifact class.
        if self.name != self.id:
            return

        self.name = self.make_name()

    @property
    def mime_type(self) -> str:
        """Return ``media_type`` and ``format`` joined by a slash."""
        kind = self.media_type
        subtype = self.format

        return f"{kind}/{subtype}"

    @property
    def base64(self) -> str:
        """Return the raw value base64-encoded and decoded to a UTF-8 string."""
        encoded = base64.b64encode(self.value)

        return encoded.decode("utf-8")

    def to_text(self) -> str:
        """Return a one-line summary with the mime type and the payload size in bytes."""
        size = len(self.value)

        return f"Media, type: {self.mime_type}, size: {size} bytes"

    def make_name(self) -> str:
        """Build a file-like name from media type, local timestamp, random suffix and format."""
        alphabet = string.ascii_lowercase + string.digits
        # Four random lowercase/digit characters keep names created within
        # the same second from colliding in the common case.
        suffix = "".join(random.choices(alphabet, k=4))
        now = time.localtime()
        stamp = time.strftime("%y%m%d%H%M%S", now)

        return f"{self.media_type}_artifact_{stamp}_{suffix}.{self.format}"
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[
return ImageArtifact(
prompt=", ".join(prompts),
value=image_bytes,
mime_type="image/png",
format="png",
width=self.image_width,
height=self.image_height,
model=self.model,
Expand All @@ -62,7 +62,7 @@ def try_image_variation(
return ImageArtifact(
prompt=", ".join(prompts),
value=image_bytes,
mime_type="image/png",
format="png",
width=image.width,
height=image.height,
model=self.model,
Expand All @@ -84,7 +84,7 @@ def try_image_inpainting(
return ImageArtifact(
prompt=", ".join(prompts),
value=image_bytes,
mime_type="image/png",
format="png",
width=image.width,
height=image.height,
model=self.model,
Expand All @@ -106,7 +106,7 @@ def try_image_outpainting(
return ImageArtifact(
prompt=", ".join(prompts),
value=image_bytes,
mime_type="image/png",
format="png",
width=image.width,
height=image.height,
model=self.model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[

return ImageArtifact(
value=image_data,
mime_type="image/png",
format="png",
width=self.image_width,
height=self.image_height,
model=self.model,
Expand All @@ -75,7 +75,7 @@ def try_image_variation(

return ImageArtifact(
value=image_data,
mime_type="image/png",
format="png",
width=self.image_width,
height=self.image_height,
model=self.model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _parse_image_response(self, response: ImagesResponse, prompt: str) -> ImageA

return ImageArtifact(
value=image_data,
mime_type="image/png",
format="png",
width=image_dimensions[0],
height=image_dimensions[1],
model=self.model,
Expand Down
4 changes: 1 addition & 3 deletions griptape/loaders/image_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def load(self, source: bytes, *args, **kwargs) -> ImageArtifact:
image = Image.open(byte_stream)
source = byte_stream.getvalue()

image_artifact = ImageArtifact(
source, mime_type=self._get_mime_type(image.format), width=image.width, height=image.height
)
image_artifact = ImageArtifact(source, format=image.format.lower(), width=image.width, height=image.height)

return image_artifact

Expand Down
2 changes: 1 addition & 1 deletion tests/mocks/mock_image_generation_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ def input(self, value: str):
self._input = TextArtifact(value)

def run(self) -> ImageArtifact:
return ImageArtifact(value=b"image data", mime_type="image/png", width=512, height=512)
return ImageArtifact(value=b"image data", format="png", width=512, height=512)
4 changes: 2 additions & 2 deletions tests/unit/artifacts/test_base_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def test_image_artifact_from_dict(self):
dict_value = {
"type": "ImageArtifact",
"value": b"aW1hZ2UgZGF0YQ==",
"mime_type": "image/png",
"dir_name": "foo",
"format": "png",
"width": 256,
"height": 256,
"model": "test-model",
Expand All @@ -60,7 +60,7 @@ def test_image_artifact_from_dict(self):
artifact = BaseArtifact.from_dict(dict_value)

assert isinstance(artifact, ImageArtifact)
assert artifact.to_text() == "Image, dimensions: 256x256, type: image/png, size: 10 bytes"
assert artifact.to_text() == "Media, type: image/png, size: 10 bytes"
assert artifact.value == b"image data"

def test_unsupported_from_dict(self):
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/artifacts/test_base_media_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from attr import define

from griptape.artifacts import MediaArtifact


class TestMediaArtifact:
    """Unit tests for MediaArtifact, exercised through a minimal subclass."""

    @define
    class ImaginaryMediaArtifact(MediaArtifact):
        # Minimal concrete subclass: only the media type default is overridden.
        media_type: str = "imagination"

    @pytest.fixture
    def media_artifact(self):
        """Provide an artifact holding a 22-byte payload in the "dream" format."""
        payload = b"some binary dream data"

        return self.ImaginaryMediaArtifact(value=payload, format="dream")

    def test_to_dict(self, media_artifact):
        """The serialised dict carries the format and the encoded value."""
        serialized = media_artifact.to_dict()

        assert serialized["format"] == "dream"
        assert serialized["value"] == "c29tZSBiaW5hcnkgZHJlYW0gZGF0YQ=="

    def test_name(self, media_artifact):
        """The generated name starts with the media type and ends with the format."""
        generated = media_artifact.name

        assert generated.startswith("imagination_artifact")
        assert generated.endswith(".dream")

    def test_mime_type(self, media_artifact):
        """The mime type combines the subclass media type and the format."""
        expected = "imagination/dream"

        assert media_artifact.mime_type == expected

    def test_to_text(self, media_artifact):
        """The text summary reports the mime type and the payload size."""
        summary = media_artifact.to_text()

        assert summary == "Media, type: imagination/dream, size: 22 bytes"
11 changes: 6 additions & 5 deletions tests/unit/artifacts/test_image_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@ class TestImageArtifact:
def image_artifact(self):
return ImageArtifact(
value=b"some binary png image data",
mime_type="image/png",
format="png",
width=512,
height=512,
model="openai/dalle2",
prompt="a cute cat",
)

def test_to_text(self, image_artifact):
assert image_artifact.to_text() == "Image, dimensions: 512x512, type: image/png, size: 26 bytes"
def test_to_text(self, image_artifact: ImageArtifact):
assert image_artifact.to_text() == "Media, type: image/png, size: 26 bytes"

def test_to_dict(self, image_artifact):
def test_to_dict(self, image_artifact: ImageArtifact):
image_dict = image_artifact.to_dict()

assert image_dict["mime_type"] == "image/png"
assert image_dict["format"] == "png"
assert image_dict["width"] == 512
assert image_dict["height"] == 512
assert image_dict["model"] == "openai/dalle2"
Expand All @@ -35,6 +35,7 @@ def test_deserialization(self, image_artifact):

assert deserialized_artifact.value == b"some binary png image data"
assert deserialized_artifact.mime_type == "image/png"
assert deserialized_artifact.format == "png"
assert deserialized_artifact.width == 512
assert deserialized_artifact.height == 512
assert deserialized_artifact.model == "openai/dalle2"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def test_init(self, driver):

def test_init_requires_image_generation_model_driver(self, session):
with pytest.raises(TypeError):
AmazonBedrockImageGenerationDriver(session=session, model="stability.stable-diffusion-xl-v1")
AmazonBedrockImageGenerationDriver(
session=session, model="stability.stable-diffusion-xl-v1"
) # pyright: ignore

def test_try_text_to_image(self, driver):
driver.bedrock_client.invoke_model.return_value = {
Expand All @@ -56,6 +58,7 @@ def test_try_text_to_image(self, driver):
image_artifact = driver.try_text_to_image(prompts=["test prompt"], negative_prompts=["test negative prompt"])

assert image_artifact.value == b"image data"
assert image_artifact.format == "png"
assert image_artifact.mime_type == "image/png"
assert image_artifact.width == 512
assert image_artifact.height == 512
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,22 @@ def test_try_image_variation(self, image_generation_driver):
with pytest.raises(DummyException):
image_generation_driver.try_image_variation(
"prompt-stack",
ImageArtifact(value=b"", width=100, height=100),
ImageArtifact(value=b"", width=100, height=100),
ImageArtifact(value=b"", width=100, height=100, format="png"),
ImageArtifact(value=b"", width=100, height=100, format="png"),
)

def test_try_image_inpainting(self, image_generation_driver):
with pytest.raises(DummyException):
image_generation_driver.try_image_inpainting(
"prompt-stack",
ImageArtifact(value=b"", width=100, height=100),
ImageArtifact(value=b"", width=100, height=100),
ImageArtifact(value=b"", width=100, height=100, format="png"),
ImageArtifact(value=b"", width=100, height=100, format="png"),
)

def test_try_image_outpainting(self, image_generation_driver):
with pytest.raises(DummyException):
image_generation_driver.try_image_outpainting(
"prompt-stack",
ImageArtifact(value=b"", width=100, height=100),
ImageArtifact(value=b"", width=100, height=100),
ImageArtifact(value=b"", width=100, height=100, format="png"),
ImageArtifact(value=b"", width=100, height=100, format="png"),
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ def model_driver(self):

@pytest.fixture
def image_artifact(self):
return ImageArtifact(b"image", mime_type="image/png", width=1024, height=1024)
return ImageArtifact(b"image", format="png", width=1024, height=1024)

@pytest.fixture
def mask_artifact(self):
return ImageArtifact(b"mask", mime_type="image/png", width=1024, height=1024)
return ImageArtifact(b"mask", format="png", width=1024, height=1024)

def test_init(self, model_driver):
assert model_driver
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ def model_driver(self):

@pytest.fixture
def image_artifact(self):
return ImageArtifact(b"image", mime_type="image/png", width=1024, height=512)
return ImageArtifact(b"image", format="png", width=1024, height=512)

@pytest.fixture
def mask_artifact(self):
return ImageArtifact(b"mask", mime_type="image/png", width=1024, height=512)
return ImageArtifact(b"mask", format="png", width=1024, height=512)

def test_init(self, model_driver):
assert model_driver
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_try_query(self, image_query_driver):
image_query_driver.bedrock_client.invoke_model.return_value = {"body": io.BytesIO(b"""{"content": []}""")}

text_artifact = image_query_driver.try_query(
"Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100)]
"Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")]
)

assert text_artifact.value == "content"
Expand All @@ -45,4 +45,6 @@ def test_try_query_no_body(self, image_query_driver):
image_query_driver.bedrock_client.invoke_model.return_value = {"body": io.BytesIO(b"")}

with pytest.raises(ValueError):
image_query_driver.try_query("Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100)])
image_query_driver.try_query(
"Prompt String", [ImageArtifact(value=b"test-data", width=100, height=100, format="png")]
)
Loading
Loading