[AIC-py] hf image2text parser #821

Merged: 1 commit, Jan 9, 2024
@@ -1,19 +1,20 @@
from .local_inference.image_2_text import HuggingFaceImage2TextTransformer
from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor
from .local_inference.text_2_speech import HuggingFaceText2SpeechTransformer
from .local_inference.text_generation import HuggingFaceTextGenerationTransformer
from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
from .local_inference.text_translation import HuggingFaceTextTranslationTransformer

from .remote_inference_client.text_generation import HuggingFaceTextGenerationParser

# from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient

LOCAL_INFERENCE_CLASSES = [
"HuggingFaceText2ImageDiffusor",
"HuggingFaceTextGenerationTransformer",
"HuggingFaceTextSummarizationTransformer",
"HuggingFaceTextTranslationTransformer",
"HuggingFaceText2SpeechTransformer",
"HuggingFaceAutomaticSpeechRecognition",
Review (Contributor): cc @Ankush-lastmile you may have merge conflicts with your other PR?
Review (Contributor): nit: Can we also do these in alphabetical order?
Review (Contributor): Fixed in #862

"HuggingFaceImage2TextTransformer",
]
REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationParser"]
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES
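
For orientation, a minimal usage sketch following the class docstring in the new file below; the import path, config filename, and prompt name are assumptions rather than part of this diff:

import asyncio

from aiconfig import AIConfigRuntime
# import path assumed for this extension package
from aiconfig_extension_hugging_face import HuggingFaceImage2TextTransformer


async def main():
    config = AIConfigRuntime.load("caption.aiconfig.json")  # hypothetical aiconfig file
    config.register_model_parser(HuggingFaceImage2TextTransformer())
    outputs = await config.run("caption_my_image")  # a prompt with an image attachment
    print(outputs)


asyncio.run(main())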
@@ -0,0 +1,166 @@
from typing import Any, Dict, Optional, List, TYPE_CHECKING
from aiconfig import ParameterizedModelParser, InferenceOptions
from aiconfig.callback import CallbackEvent
import torch
from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment

from transformers import pipeline, Pipeline

if TYPE_CHECKING:
from aiconfig import AIConfigRuntime


class HuggingFaceImage2TextTransformer(ParameterizedModelParser):
def __init__(self):
"""
Returns:
HuggingFaceImage2TextTransformer

Usage:
1. Create a new model parser object with the model ID of the model to use.
parser = HuggingFaceImage2TextTransformer()
2. Add the model parser to the registry.
config.register_model_parser(parser)
"""
super().__init__()
self.pipelines: dict[str, Pipeline] = {}

def id(self) -> str:
"""
Returns an identifier for the Model Parser
"""
return "HuggingFaceImage2TextTransformer"

async def serialize(
self,
prompt_name: str,
data: Any,
ai_config: "AIConfigRuntime",
parameters: Optional[Dict[str, Any]] = None,
) -> List[Prompt]:
"""
Defines how a prompt and model inference settings get serialized in the .aiconfig.
Assume input in the form of input(s) being passed into an already constructed pipeline.

Args:
    prompt_name (str): Name of the prompt to be serialized.
    data (Any): Model-specific inference settings to be serialized.
    ai_config (AIConfigRuntime): The AIConfig Runtime.
    parameters (Dict[str, Any], optional): Model-specific parameters. Defaults to None.

Returns:
    List[Prompt]: Serialized representation of the prompt and inference settings.
"""
await ai_config.callback_manager.run_callbacks(
    CallbackEvent(
        "on_serialize_start",
        __name__,
        {
            "prompt_name": prompt_name,
            "data": data,
            "parameters": parameters,
        },
    )
)

Review (Contributor): Can you add a TODO linking to #822 to fix later (and add automated testing)? I'll do this later.
Review (Contributor, Author): What's broken?
Review (Contributor): We're not using the correct model_id, so I need to pass this in so we can re-create the prompt.
Review (Contributor): Added TODO comment in code in #862.

prompts = []

if not isinstance(data, dict):
raise ValueError("Invalid data type. Expected dict when serializing prompt data to aiconfig.")
if data.get("inputs", None) is None:
raise ValueError("Invalid data when serializing prompt to aiconfig. Input data must contain an inputs field.")

prompt = Prompt(
**{
"name": prompt_name,
"input": {"attachments": [{"data": data["inputs"]}]},
"metadata": None,
"outputs": None,
}
)

prompts.append(prompt)

await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts}))
return prompts
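
Concretely, the data argument validated above is expected to look something like this (the URL is a hypothetical value):

data = {"inputs": "https://example.com/photo.png"}  # stored as an attachment on the prompt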

async def deserialize(
self,
prompt: Prompt,
aiconfig: "AIConfigRuntime",
params: Optional[Dict[str, Any]] = {},
) -> Dict[str, Any]:
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

# Build Completion data
completion_params = self.get_model_settings(prompt, aiconfig)

inputs = validate_and_retrieve_image_from_attachments(prompt)

completion_params["inputs"] = inputs
Review on lines +98 to +100 (Contributor): cc @Ankush-lastmile when this lands, can you link to the Attachment format/standardizing inputs issue we mentioned? Jonathan, you don't need to do any work, just making sure Ankush is aware.


await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params}))
return completion_params
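
For a single-attachment prompt, the returned completion data looks roughly like this (values are hypothetical; the extra keys come from the prompt's model settings in the aiconfig):

completion_params = {
    "inputs": ["https://example.com/photo.png"],  # image URIs pulled from the attachments
    "model": "Salesforce/blip-image-captioning-base",  # present per the review thread below
    # ...plus any other model settings
}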

async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_run_start",
__name__,
{"prompt": prompt, "options": options, "parameters": parameters},
)
)
model_name = aiconfig.get_model_name(prompt)

# Reuse a cached pipeline for this model if we already built one
if model_name not in self.pipelines:
    self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name)

captioner = self.pipelines[model_name]
completion_data = await self.deserialize(prompt, aiconfig, parameters)
inputs = completion_data.pop("inputs")
model = completion_data.pop("model")
Review (rossdanlm, Contributor, Jan 9, 2024): This is never used again; why would we have it in completion data in the first place? If it's never used, pls prefix with _model.
Review (Contributor, Author): It has to be removed from completion data. Something is definitely off here, I just don't know exactly what. cc @Ankush-lastmile
Review (Contributor): Fixed in #855 (I just deleted it; we must've had model in the settings param). cc @saqadri, who is working on adding this explicitly.

response = captioner(inputs, **completion_data)
Review (Contributor): Does pipeline only support inputs as a URI, or does it also work with base64 encoded? If not, pls make a task that we need to convert from base64 --> image URI first.
Review (Contributor): Fixed in #856
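
For context on this thread: the transformers image-to-text pipeline accepts image URLs, local file paths, and PIL images, so base64 payloads need decoding first. A minimal sketch of that conversion (the data-URI prefix handling and the model name in the usage note are assumptions, not part of this PR):

import base64
import io

from PIL import Image


def decode_base64_image(data: str) -> Image.Image:
    """Decode a base64 string (optionally a data URI) into a PIL image."""
    if data.startswith("data:"):
        # strip a prefix like "data:image/png;base64,"
        data = data.split(",", 1)[1]
    return Image.open(io.BytesIO(base64.b64decode(data)))


# Usage sketch (hypothetical payload):
# captioner = pipeline(task="image-to-text", model="Salesforce/blip-image-captioning-base")
# print(captioner(decode_base64_image(encoded_png)))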

output = ExecuteResult(output_type="execute_result", data=response, metadata={})
Review (rossdanlm, Contributor, Jan 9, 2024): Oh sweet, so response is just purely text? Nice! Also, let's add "execution_count=0".
Review (Contributor): Fixed in #855


prompt.outputs = [output]
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
return prompt.outputs

def get_output_text(self, response: dict[str, Any]) -> str:
raise NotImplementedError("get_output_text is not implemented for HuggingFaceImage2TextTransformer")
Review (Contributor): Pls update to match others, like the ones in text_generation.py.
Review (Contributor): Fixed in #855

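For reference, a hypothetical sketch of what an updated get_output_text could look like, assuming the usual [{"generated_text": ...}] shape that image-to-text pipelines return (the actual fix landed in #855 and may differ):

def get_output_text(self, response: Any) -> str:
    # image-to-text pipelines typically return [{"generated_text": "..."}]
    if isinstance(response, list) and response:
        first = response[0]
        if isinstance(first, dict) and "generated_text" in first:
            return first["generated_text"]
    return str(response)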


def validate_attachment_type_is_image(attachment: Attachment):
if not hasattr(attachment, "mime_type"):
raise ValueError(f"Attachment has no mime type. Specify the image mimetype in the aiconfig")
Review (Contributor): Nit: add the word "Please" before "Specify".
Review (Contributor): Updated in #862


if not attachment.mime_type.startswith("image/"):
raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected image mimetype.")


def validate_and_retrieve_image_from_attachments(prompt: Prompt) -> list[str]:
"""
Retrieves the image uri's from each attachment in the prompt input.

Throws an exception if
- attachment is not image
- attachment data is not a uri
- no attachments are found
- operation fails for any reason
"""

if not hasattr(prompt.input, "attachments") or len(prompt.input.attachments) == 0:
raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an image attachment to the prompt input.")

image_uris: list[str] = []

for i, attachment in enumerate(prompt.input.attachments):
validate_attachment_type_is_image(attachment)

if not isinstance(attachment.data, str):
# See TODO above; for now, only URIs are supported
raise ValueError(f"Attachment #{i} data is not a uri. Please specify a uri for the image attachment in prompt {prompt.name}.")
Review on lines +161 to +162 (Contributor): Doesn't have to be this diff, but please add support for base64 as well. This is important since, if we want to chain prompts, some of our models output in base64 format (ex: text_2_image). At the very least, create an issue to track.
Review (Contributor): Fixed in #856


image_uris.append(attachment.data)

return image_uris
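
To make the attachment contract concrete, a hypothetical prompt that passes this validation (field values assumed):

from aiconfig.schema import Prompt

prompt = Prompt(
    name="caption_photo",
    input={
        "attachments": [
            {"data": "https://example.com/photo.png", "mime_type": "image/png"}
        ]
    },
)
assert validate_and_retrieve_image_from_attachments(prompt) == [
    "https://example.com/photo.png"
]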