-
Notifications
You must be signed in to change notification settings - Fork 79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[HF][streaming][4/n] Image2Text (no streaming, but lots of fixing) #855
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,21 @@ | ||
import json | ||
from typing import Any, Dict, Optional, List, TYPE_CHECKING | ||
from transformers import ( | ||
Pipeline, | ||
pipeline, | ||
) | ||
|
||
from aiconfig import ParameterizedModelParser, InferenceOptions | ||
from aiconfig.callback import CallbackEvent | ||
import torch | ||
from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment | ||
|
||
from transformers import pipeline, Pipeline | ||
|
||
from aiconfig.schema import ( | ||
Attachment, | ||
ExecuteResult, | ||
Output, | ||
OutputDataWithValue, | ||
Prompt, | ||
) | ||
|
||
# Circular Dependency Type Hints | ||
if TYPE_CHECKING: | ||
from aiconfig import AIConfigRuntime | ||
|
||
|
@@ -93,10 +103,11 @@ async def deserialize( | |
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params})) | ||
|
||
# Build Completion data | ||
completion_params = self.get_model_settings(prompt, aiconfig) | ||
model_settings = self.get_model_settings(prompt, aiconfig) | ||
completion_params = refine_completion_params(model_settings) | ||
|
||
#Add image inputs | ||
inputs = validate_and_retrieve_image_from_attachments(prompt) | ||
|
||
completion_params["inputs"] = inputs | ||
|
||
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params})) | ||
|
@@ -110,24 +121,93 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio | |
{"prompt": prompt, "options": options, "parameters": parameters}, | ||
) | ||
) | ||
model_name = aiconfig.get_model_name(prompt) | ||
|
||
self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name) | ||
|
||
captioner = self.pipelines[model_name] | ||
completion_data = await self.deserialize(prompt, aiconfig, parameters) | ||
inputs = completion_data.pop("inputs") | ||
model = completion_data.pop("model") | ||
response = captioner(inputs, **completion_data) | ||
|
||
output = ExecuteResult(output_type="execute_result", data=response, metadata={}) | ||
model_name: str | None = aiconfig.get_model_name(prompt) | ||
if isinstance(model_name, str) and model_name not in self.pipelines: | ||
self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name) | ||
captioner = self.pipelines[model_name] | ||
|
||
outputs: List[Output] = [] | ||
response: List[Any] = captioner(inputs, **completion_data) | ||
for count, result in enumerate(response): | ||
output: Output = construct_regular_output(result, count) | ||
outputs.append(output) | ||
|
||
prompt.outputs = [output] | ||
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})) | ||
prompt.outputs = outputs | ||
print(f"{prompt.outputs=}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove print? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in #862 |
||
await aiconfig.callback_manager.run_callbacks( | ||
CallbackEvent( | ||
"on_run_complete", | ||
__name__, | ||
{"result": prompt.outputs}, | ||
) | ||
) | ||
return prompt.outputs | ||
|
||
def get_output_text(self, response: dict[str, Any]) -> str: | ||
raise NotImplementedError("get_output_text is not implemented for HuggingFaceImage2TextTransformer") | ||
def get_output_text( | ||
self, | ||
prompt: Prompt, | ||
aiconfig: "AIConfigRuntime", | ||
output: Optional[Output] = None, | ||
) -> str: | ||
if output is None: | ||
output = aiconfig.get_latest_output(prompt) | ||
|
||
if output is None: | ||
return "" | ||
|
||
# TODO (rossdanlm): Handle multiple outputs in list | ||
# https://github.com/lastmile-ai/aiconfig/issues/467 | ||
if output.output_type == "execute_result": | ||
output_data = output.data | ||
if isinstance(output_data, str): | ||
return output_data | ||
if isinstance(output_data, OutputDataWithValue): | ||
if isinstance(output_data.value, str): | ||
return output_data.value | ||
# HuggingFace image-to-text does not support function | ||
# calls so shouldn't get here, but just being safe | ||
Comment on lines
+171
to
+172
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in #862 |
||
return json.dumps(output_data.value, indent=2) | ||
return "" | ||
|
||
|
||
def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
    """
    Refine the completion params for the HF image-to-text API.

    Only settings whose lower-cased name is a supported key are kept; every
    other entry is dropped. Kept keys are returned in lower case. The
    supported keys were found by looking at the HF
    ImageToTextPipeline.__call__ method.

    Args:
        model_settings: Raw model settings, keyed by setting name.

    Returns:
        A new dict containing only the supported settings. The input dict is
        not modified.
    """
    supported_keys = {
        "max_new_tokens",
        "timeout",
    }

    # Keys are matched case-insensitively and normalised to lower case so the
    # result can be splatted straight into the pipeline call as kwargs.
    return {
        key.lower(): value
        for key, value in model_settings.items()
        if key.lower() in supported_keys
    }
|
||
# Helper methods | ||
def construct_regular_output(result: List[Dict[str, str]], execution_count: int) -> Output:
    """
    Construct regular output per response result, without streaming enabled.

    Args:
        result: One entry of the pipeline response. It is indexed as a list
            whose first element is a dict holding a "generated_text" key.
        execution_count: Position of this result in the response, stored on
            the output as its execution count.

    Returns:
        An ExecuteResult whose data is the generated text of the first
        element of `result`, with empty metadata.
    """
    # Each result arrives wrapped in a list. We haven't found a way to make
    # the image-to-text pipeline return multiple sequences, so only the first
    # element is read.
    generated_text = result[0]["generated_text"]
    return ExecuteResult(
        output_type="execute_result",
        data=generated_text,
        execution_count=execution_count,
        metadata={},
    )
|
||
|
||
def validate_attachment_type_is_image(attachment: Attachment): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch!