Skip to content

Commit

Permalink
[HF][5/n] Image2Text: Allow base64 inputs for images
Browse files Browse the repository at this point in the history
Before we didn't allow base64, only URI (either local or http or https). This is good becuase our text2Image model parser outputs into a base64 format, so this will allow us to chain model prompts!

## Test Plan
  • Loading branch information
Rossdan Craig rossdan@lastmileai.dev committed Jan 10, 2024
1 parent a5a26aa commit bdfa093
Showing 1 changed file with 16 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import base64
import json
from io import BytesIO
from PIL import Image
from typing import Any, Dict, Optional, List, TYPE_CHECKING
from transformers import (
Pipeline,
Expand Down Expand Up @@ -218,7 +221,7 @@ def validate_attachment_type_is_image(attachment: Attachment):
raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected image mimetype.")


def validate_and_retrieve_image_from_attachments(prompt: Prompt) -> list[str]:
def validate_and_retrieve_image_from_attachments(prompt: Prompt) -> list[str | Image]:
"""
Retrieves the image uri's from each attachment in the prompt input.
Expand All @@ -232,15 +235,23 @@ def validate_and_retrieve_image_from_attachments(prompt: Prompt) -> list[str]:
if not hasattr(prompt.input, "attachments") or len(prompt.input.attachments) == 0:
raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an image attachment to the prompt input.")

image_uris: list[str] = []
images: list[str | Image] = []

for i, attachment in enumerate(prompt.input.attachments):
validate_attachment_type_is_image(attachment)

if not isinstance(attachment.data, str):
input_data = attachment.data
if not isinstance(input_data, str):
# See todo above, but for now only support uri's
raise ValueError(f"Attachment #{i} data is not a uri. Please specify a uri for the image attachment in prompt {prompt.name}.")

image_uris.append(attachment.data)
# Really basic heurestic to check if the data is a base64 encoded str
# vs. uri. This will be fixed once we have standardized inputs
# See https://github.com/lastmile-ai/aiconfig/issues/829
if len(input_data) > 10000:
pil_image : Image = Image.open(BytesIO(base64.b64decode(input_data)))
images.append(pil_image)
else:
images.append(input_data)

return image_uris
return images

0 comments on commit bdfa093

Please sign in to comment.