From d450e1a45a0b9bf9ed36921ff3948ad505702678 Mon Sep 17 00:00:00 2001 From: Heiru Wu Date: Thu, 7 Dec 2023 16:37:59 +0800 Subject: [PATCH] feat(ray): add text to image io help --- instill/helpers/const.py | 10 +++++ instill/helpers/ray_io.py | 80 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 89 insertions(+), 1 deletion(-) diff --git a/instill/helpers/const.py b/instill/helpers/const.py index 28fcec8..f89f4dc 100644 --- a/instill/helpers/const.py +++ b/instill/helpers/const.py @@ -26,3 +26,13 @@ class TextGenerationInput: random_seed = 0 stop_words: Any = "" extra_params: Dict[str, str] = {} + + +class TextToImageInput: + prompt = "" + negative_prompt = "" + steps = 5 + guidance_scale = 7.5 + seed = 0 + samples = 1 + extra_params: Dict[str, str] = {} diff --git a/instill/helpers/ray_io.py b/instill/helpers/ray_io.py index 7454234..bf074b6 100644 --- a/instill/helpers/ray_io.py +++ b/instill/helpers/ray_io.py @@ -4,7 +4,7 @@ import numpy as np -from instill.helpers.const import TextGenerationInput +from instill.helpers.const import TextGenerationInput, TextToImageInput def serialize_byte_tensor(input_tensor): @@ -182,6 +182,84 @@ def parse_task_text_generation_output(sequences: list): return serialize_byte_tensor(np.asarray(text_outputs)) + @staticmethod + def parse_task_text_to_image_input(request) -> TextToImageInput: + text_to_image_input = TextToImageInput() + + for i, b_input_tensor in zip(request.inputs, request.raw_input_contents): + input_name = i.name + + if input_name == "prompt": + input_tensor = deserialize_bytes_tensor(b_input_tensor) + text_to_image_input.prompt = str(input_tensor[0].decode("utf-8")) + print( + f"[DEBUG] input `prompt` type\ + ({type(text_to_image_input.prompt)}): {text_to_image_input.prompt}" + ) + + if input_name == "negative_prompt": + input_tensor = deserialize_bytes_tensor(b_input_tensor) + text_to_image_input.negative_prompt = str( + input_tensor[0].decode("utf-8") + ) + print( + f"[DEBUG] input `negative_prompt` type\ + ({type(text_to_image_input.negative_prompt)}): {text_to_image_input.negative_prompt}" + ) + + if input_name == "steps": + text_to_image_input.steps = int.from_bytes(b_input_tensor, "little") + print( + f"[DEBUG] input `steps` type\ + ({type(text_to_image_input.steps)}): {text_to_image_input.steps}" + ) + + if input_name == "seed": + text_to_image_input.seed = int.from_bytes(b_input_tensor, "little") + print( + f"[DEBUG] input `seed` type\ + ({type(text_to_image_input.seed)}): {text_to_image_input.seed}" + ) + + if input_name == "guidance_scale": + text_to_image_input.guidance_scale = struct.unpack("f", b_input_tensor)[ + 0 + ] + print( + f"[DEBUG] input `guidance_scale` type\ + ({type(text_to_image_input.guidance_scale)}): {text_to_image_input.guidance_scale}" + ) + text_to_image_input.guidance_scale = round( + text_to_image_input.guidance_scale, 2 + ) + + if input_name == "samples": + text_to_image_input.samples = int.from_bytes(b_input_tensor, "little") + print( + f"[DEBUG] input `samples` type\ + ({type(text_to_image_input.samples)}): {text_to_image_input.samples}" + ) + + if input_name == "extra_params": + input_tensor = deserialize_bytes_tensor(b_input_tensor) + extra_params_str = str(input_tensor[0].decode("utf-8")) + print( + f"[DEBUG] input `extra_params` type\ + ({type(extra_params_str)}): {extra_params_str}" + ) + + try: + text_to_image_input.extra_params = json.loads(extra_params_str) + except json.decoder.JSONDecodeError: + print("[DEBUG] WARNING `extra_params` parsing faield!") + continue + + return text_to_image_input + + @staticmethod + def parse_task_text_to_image_output(image): + return np.asarray(image).tobytes() + class RawIO: @staticmethod