# Task-specific response builders added to instill/helpers/protobufs/parse.py
# (git diff fe9a1ef..07ff568). Reconstructed as plain Python from a unified
# diff whose line breaks had been lost.
#
# `Metadata` and `construct_infer_response` live in the unchanged part of
# parse.py (both appear as diff context).
# NOTE(review): `construct_metadata_response`, `List`, `ModelMetadataRequest`,
# `ModelMetadataResponse` and `RayServiceCallRequest` are used here but not
# shown in the diff; presumably they are defined/imported earlier in parse.py
# -- confirm before using this excerpt on its own.

from instill.helpers.const import DataType


def _scalar_input(name: str, datatype: DataType) -> Metadata:
    """Return a ``Metadata`` entry called *name* with shape ``[1]``.

    The datatype is stored as the string name of the ``DataType`` member,
    e.g. ``"TYPE_STRING"``.
    """
    return Metadata(name=name, datatype=str(datatype.name), shape=[1])


def _text_generation_inputs() -> List[Metadata]:
    """Return the inputs shared by the text-generation style tasks.

    Used by text generation, text generation chat and visual question
    answering, which all declare exactly the same nine inputs. A new list
    with new ``Metadata`` objects is built on every call so callers never
    share mutable state.
    """
    return [
        _scalar_input("prompt", DataType.TYPE_STRING),
        _scalar_input("prompt_images", DataType.TYPE_STRING),
        _scalar_input("chat_history", DataType.TYPE_STRING),
        _scalar_input("system_message", DataType.TYPE_STRING),
        _scalar_input("max_new_tokens", DataType.TYPE_UINT32),
        _scalar_input("temperature", DataType.TYPE_FP32),
        _scalar_input("top_k", DataType.TYPE_UINT32),
        _scalar_input("random_seed", DataType.TYPE_UINT64),
        _scalar_input("extra_params", DataType.TYPE_STRING),
    ]


def _image_generation_inputs() -> List[Metadata]:
    """Return the inputs shared by text-to-image and image-to-image.

    A new list with new ``Metadata`` objects is built on every call.
    """
    return [
        _scalar_input("prompt", DataType.TYPE_STRING),
        _scalar_input("negative_prompt", DataType.TYPE_STRING),
        _scalar_input("prompt_image", DataType.TYPE_STRING),
        _scalar_input("samples", DataType.TYPE_INT32),
        _scalar_input("scheduler", DataType.TYPE_STRING),
        _scalar_input("steps", DataType.TYPE_INT32),
        _scalar_input("guidance_scale", DataType.TYPE_FP32),
        _scalar_input("seed", DataType.TYPE_INT64),
        _scalar_input("extra_params", DataType.TYPE_STRING),
    ]


def _text_output_metadata_response(
    req: ModelMetadataRequest,
) -> ModelMetadataResponse:
    """Build a metadata response with text-generation inputs and a text output.

    The single output is named ``"text"``, typed ``TYPE_STRING`` and declared
    with shape ``[-1, -1]``.
    NOTE(review): the -1 entries presumably mean "dynamic dimension" -- confirm
    against the serving backend.
    """
    return construct_metadata_response(
        req=req,
        inputs=_text_generation_inputs(),
        outputs=[
            Metadata(
                name="text",
                datatype=str(DataType.TYPE_STRING.name),
                shape=[-1, -1],
            ),
        ],
    )


def _image_output_metadata_response(
    req: ModelMetadataRequest,
) -> ModelMetadataResponse:
    """Build a metadata response with image-generation inputs and an image output.

    The single output is named ``"images"``, typed ``TYPE_FP32`` and declared
    with shape ``[-1, -1, -1, -1]``.
    """
    return construct_metadata_response(
        req=req,
        inputs=_image_generation_inputs(),
        outputs=[
            Metadata(
                name="images",
                datatype=str(DataType.TYPE_FP32.name),
                shape=[-1, -1, -1, -1],
            ),
        ],
    )


def _single_output_infer_response(
    req: RayServiceCallRequest,
    name: str,
    datatype: DataType,
    shape: list,
    raw_outputs: List[bytes],
):
    """Build an infer response that declares exactly one output tensor.

    Args:
        req: The incoming call request, forwarded unchanged.
        name: Name of the output tensor.
        datatype: ``DataType`` member; its name is stored as a string.
        shape: Shape reported for the output tensor.
        raw_outputs: Serialized output contents, forwarded unchanged.

    Returns:
        Whatever ``construct_infer_response`` returns.
    """
    return construct_infer_response(
        req=req,
        outputs=[
            Metadata(
                name=name,
                datatype=str(datatype.name),
                shape=shape,
            )
        ],
        raw_outputs=raw_outputs,
    )


def construct_text_generation_metadata_response(
    req: ModelMetadataRequest,
) -> ModelMetadataResponse:
    """Return the model metadata response for the text-generation task."""
    return _text_output_metadata_response(req)


def construct_text_generation_infer_response(
    req: RayServiceCallRequest,
    shape: list,
    raw_outputs: List[bytes],
):
    """Return the infer response for text generation.

    Declares one ``"text"`` output of type ``TYPE_STRING`` with the given
    *shape* and attaches *raw_outputs*.
    """
    return _single_output_infer_response(
        req, "text", DataType.TYPE_STRING, shape, raw_outputs
    )


def construct_text_generation_chat_metadata_response(
    req: ModelMetadataRequest,
) -> ModelMetadataResponse:
    """Return the model metadata response for the text-generation-chat task."""
    return _text_output_metadata_response(req)


def construct_text_generation_chat_infer_response(
    req: RayServiceCallRequest,
    shape: list,
    raw_outputs: List[bytes],
):
    """Return the infer response for text generation chat.

    Declares one ``"text"`` output of type ``TYPE_STRING`` with the given
    *shape* and attaches *raw_outputs*.
    """
    return _single_output_infer_response(
        req, "text", DataType.TYPE_STRING, shape, raw_outputs
    )


def construct_text_to_image_metadata_response(
    req: ModelMetadataRequest,
) -> ModelMetadataResponse:
    """Return the model metadata response for the text-to-image task."""
    return _image_output_metadata_response(req)


def construct_text_to_image_infer_response(
    req: RayServiceCallRequest,
    shape: list,
    raw_outputs: List[bytes],
):
    """Return the infer response for text-to-image.

    Declares one ``"images"`` output of type ``TYPE_FP32`` with the given
    *shape* and attaches *raw_outputs*.
    """
    return _single_output_infer_response(
        req, "images", DataType.TYPE_FP32, shape, raw_outputs
    )


def construct_image_to_image_metadata_response(
    req: ModelMetadataRequest,
) -> ModelMetadataResponse:
    """Return the model metadata response for the image-to-image task."""
    return _image_output_metadata_response(req)


def construct_image_to_image_infer_response(
    req: RayServiceCallRequest,
    shape: list,
    raw_outputs: List[bytes],
):
    """Return the infer response for image-to-image.

    Declares one ``"images"`` output of type ``TYPE_FP32`` with the given
    *shape* and attaches *raw_outputs*.
    """
    return _single_output_infer_response(
        req, "images", DataType.TYPE_FP32, shape, raw_outputs
    )


def construct_visual_question_answering_metadata_response(
    req: ModelMetadataRequest,
) -> ModelMetadataResponse:
    """Return the model metadata response for visual question answering."""
    return _text_output_metadata_response(req)


def construct_visual_question_answering_infer_response(
    req: RayServiceCallRequest,
    shape: list,
    raw_outputs: List[bytes],
):
    """Return the infer response for visual question answering.

    Declares one ``"text"`` output of type ``TYPE_STRING`` with the given
    *shape* and attaches *raw_outputs*.
    """
    return _single_output_infer_response(
        req, "text", DataType.TYPE_STRING, shape, raw_outputs
    )