Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion instill/helpers/const.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from enum import Enum
from typing import Any, Dict
from typing import Any, Dict, Union

import numpy as np


class DataType(Enum):
Expand Down Expand Up @@ -29,10 +31,40 @@ class TextGenerationInput:


class TextToImageInput:
    """Input fields for the text-to-image task, with default values.

    NOTE(review): these are class-level defaults; ``extra_params`` is a
    shared mutable dict — parsers should rebind it on the instance (as
    ray_io-style code does) rather than mutate it in place.
    """

    prompt_image: Union[np.ndarray, None] = None  # optional conditioning image — presumably a decoded RGB array; TODO confirm
    prompt: str = ""  # positive text prompt
    negative_prompt: str = ""  # text the generation should avoid
    steps: int = 5  # presumably sampling/denoising steps — verify against model
    guidance_scale: float = 7.5  # guidance strength (rounded to 2 decimals by parsers)
    seed: int = 0  # RNG seed
    samples: int = 1  # number of images requested
    extra_params: Dict[str, str] = {}  # model-specific extras, JSON-decoded by parsers


class ImageToImageInput:
    """Input fields for the image-to-image task, with default values.

    NOTE(review): class-level defaults; ``extra_params`` is a shared
    mutable dict — rebind on instances, do not mutate in place.
    """

    prompt_image: Union[np.ndarray, None] = None  # source image — presumably a decoded RGB array; TODO confirm
    prompt: str = ""  # text prompt guiding the transformation
    steps: int = 5  # presumably sampling/denoising steps — verify against model
    guidance_scale: float = 7.5  # guidance strength (rounded to 2 decimals by parsers)
    seed: int = 0  # RNG seed
    samples: int = 1  # number of images requested
    extra_params: Dict[str, str] = {}  # model-specific extras, JSON-decoded by parsers


class TextGenerationChatInput:
    """Input fields for the text-generation chat task, with default values.

    NOTE(review): class-level defaults; ``extra_params`` is a shared
    mutable dict — rebind on instances, do not mutate in place.
    """

    conversation: str = ""  # conversation text — presumably a serialized chat history; TODO confirm format
    max_new_tokens: int = 100  # cap on generated tokens
    top_k: int = 1  # top-k sampling parameter
    temperature: float = 0.8  # sampling temperature (rounded to 2 decimals by parsers)
    random_seed: int = 0  # RNG seed
    extra_params: Dict[str, str] = {}  # model-specific extras, JSON-decoded by parsers


class VisualQuestionAnsweringInput:
    """Input fields for the visual-question-answering task, with defaults.

    NOTE(review): class-level defaults; ``extra_params`` is a shared
    mutable dict — rebind on instances, do not mutate in place.
    """

    prompt_image: Union[np.ndarray, None] = None  # image to ask about — presumably a decoded RGB array; TODO confirm
    prompt: str = ""  # the question text
    max_new_tokens: int = 100  # cap on generated tokens
    top_k: int = 1  # top-k sampling parameter
    temperature: float = 0.8  # sampling temperature (rounded to 2 decimals by parsers)
    random_seed: int = 0  # RNG seed
    extra_params: Dict[str, str] = {}  # model-specific extras, JSON-decoded by parsers
290 changes: 287 additions & 3 deletions instill/helpers/ray_io.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import io
import json
import struct
from typing import List

import numpy as np
from PIL import Image

from instill.helpers.const import TextGenerationInput, TextToImageInput
from instill.helpers.const import (
ImageToImageInput,
TextGenerationChatInput,
TextGenerationInput,
TextToImageInput,
VisualQuestionAnsweringInput,
)


def serialize_byte_tensor(input_tensor):
Expand Down Expand Up @@ -101,12 +109,12 @@ def parse_task_text_generation_input(request) -> TextGenerationInput:
)

if input_name == "max_new_tokens":
text_generation_inputmax_new_tokens = int.from_bytes(
text_generation_input.max_new_tokens = int.from_bytes(
b_input_tensor, "little"
)
print(
f"[DEBUG] input `max_new_tokens` type\
({type(text_generation_inputmax_new_tokens)}): {text_generation_inputmax_new_tokens}"
({type(text_generation_input.max_new_tokens)}): {text_generation_input.max_new_tokens}"
)

if input_name == "top_k":
Expand Down Expand Up @@ -260,6 +268,282 @@ def parse_task_text_to_image_input(request) -> TextToImageInput:
def parse_task_text_to_image_output(image):
return np.asarray(image).tobytes()

@staticmethod
def parse_task_image_to_image_input(request) -> ImageToImageInput:
    """Parse an inference request into an ``ImageToImageInput``.

    Walks the named input tensors and their raw byte payloads in
    lockstep, decoding each recognized field; unrecognized input names
    are ignored and the ``ImageToImageInput`` defaults are kept.

    Args:
        request: request object exposing ``inputs`` (tensor metadata
            with a ``name`` attribute) and ``raw_input_contents`` (raw
            byte payloads in the same order).

    Returns:
        ImageToImageInput populated from the request.

    Raises:
        ValueError: if the prompt image decodes to a 2-D (grayscale)
            array instead of an RGB image.
    """
    image_to_image_input = ImageToImageInput()

    for meta, b_input_tensor in zip(request.inputs, request.raw_input_contents):
        input_name = meta.name

        if input_name == "prompt_image":
            input_tensors = deserialize_bytes_tensor(b_input_tensor)
            images = []
            for enc in input_tensors:
                pil_img = Image.open(io.BytesIO(enc.astype(bytes)))  # RGB
                image = np.array(pil_img)
                if len(image.shape) == 2:  # grayscale image not supported
                    raise ValueError(
                        f"The image shape with {image.shape} is not acceptable"
                    )
                images.append(image)
            # Only the first image is used; any extra entries are decoded
            # and validated but otherwise ignored.
            image_to_image_input.prompt_image = images[0]
            print(
                f"[DEBUG] input `prompt_image` type"
                f"({type(image_to_image_input.prompt_image)}): "
                f"{image_to_image_input.prompt_image}"
            )

        if input_name == "prompt":
            input_tensor = deserialize_bytes_tensor(b_input_tensor)
            image_to_image_input.prompt = str(input_tensor[0].decode("utf-8"))
            print(
                f"[DEBUG] input `prompt` type"
                f"({type(image_to_image_input.prompt)}): "
                f"{image_to_image_input.prompt}"
            )

        if input_name == "steps":
            # Scalar ints arrive as little-endian raw bytes.
            image_to_image_input.steps = int.from_bytes(b_input_tensor, "little")
            print(
                f"[DEBUG] input `steps` type"
                f"({type(image_to_image_input.steps)}): "
                f"{image_to_image_input.steps}"
            )

        if input_name == "seed":
            image_to_image_input.seed = int.from_bytes(b_input_tensor, "little")
            print(
                f"[DEBUG] input `seed` type"
                f"({type(image_to_image_input.seed)}): "
                f"{image_to_image_input.seed}"
            )

        if input_name == "guidance_scale":
            # 32-bit little-endian IEEE float, rounded to 2 decimals.
            image_to_image_input.guidance_scale = struct.unpack(
                "f", b_input_tensor
            )[0]
            print(
                f"[DEBUG] input `guidance_scale` type"
                f"({type(image_to_image_input.guidance_scale)}): "
                f"{image_to_image_input.guidance_scale}"
            )
            image_to_image_input.guidance_scale = round(
                image_to_image_input.guidance_scale, 2
            )

        if input_name == "samples":
            image_to_image_input.samples = int.from_bytes(b_input_tensor, "little")
            print(
                f"[DEBUG] input `samples` type"
                f"({type(image_to_image_input.samples)}): "
                f"{image_to_image_input.samples}"
            )

        if input_name == "extra_params":
            input_tensor = deserialize_bytes_tensor(b_input_tensor)
            extra_params_str = str(input_tensor[0].decode("utf-8"))
            print(
                f"[DEBUG] input `extra_params` type"
                f"({type(extra_params_str)}): {extra_params_str}"
            )

            try:
                image_to_image_input.extra_params = json.loads(extra_params_str)
            except json.decoder.JSONDecodeError:
                # Best-effort: keep the defaults when the JSON is malformed.
                print("[DEBUG] WARNING `extra_params` parsing failed!")
                continue

    return image_to_image_input

@staticmethod
def parse_task_image_to_image_output(image):
return np.asarray(image).tobytes()

@staticmethod
def parse_task_text_generation_chat_input(request) -> TextGenerationChatInput:
    """Parse an inference request into a ``TextGenerationChatInput``.

    Walks the named input tensors and their raw byte payloads in
    lockstep, decoding each recognized field; unrecognized input names
    are ignored and the ``TextGenerationChatInput`` defaults are kept.

    Args:
        request: request object exposing ``inputs`` (tensor metadata
            with a ``name`` attribute) and ``raw_input_contents`` (raw
            byte payloads in the same order).

    Returns:
        TextGenerationChatInput populated from the request.
    """
    text_generation_chat_input = TextGenerationChatInput()

    for meta, b_input_tensor in zip(request.inputs, request.raw_input_contents):
        input_name = meta.name

        if input_name == "conversation":
            input_tensor = deserialize_bytes_tensor(b_input_tensor)
            text_generation_chat_input.conversation = str(
                input_tensor[0].decode("utf-8")
            )
            print(
                f"[DEBUG] input `conversation` type"
                f"({type(text_generation_chat_input.conversation)}): "
                f"{text_generation_chat_input.conversation}"
            )

        if input_name == "max_new_tokens":
            # Scalar ints arrive as little-endian raw bytes.
            text_generation_chat_input.max_new_tokens = int.from_bytes(
                b_input_tensor, "little"
            )
            print(
                f"[DEBUG] input `max_new_tokens` type"
                f"({type(text_generation_chat_input.max_new_tokens)}): "
                f"{text_generation_chat_input.max_new_tokens}"
            )

        if input_name == "top_k":
            text_generation_chat_input.top_k = int.from_bytes(
                b_input_tensor, "little"
            )
            print(
                f"[DEBUG] input `top_k` type"
                f"({type(text_generation_chat_input.top_k)}): "
                f"{text_generation_chat_input.top_k}"
            )

        if input_name == "temperature":
            # 32-bit little-endian IEEE float, rounded to 2 decimals.
            text_generation_chat_input.temperature = struct.unpack(
                "f", b_input_tensor
            )[0]
            print(
                f"[DEBUG] input `temperature` type"
                f"({type(text_generation_chat_input.temperature)}): "
                f"{text_generation_chat_input.temperature}"
            )
            text_generation_chat_input.temperature = round(
                text_generation_chat_input.temperature, 2
            )

        if input_name == "random_seed":
            text_generation_chat_input.random_seed = int.from_bytes(
                b_input_tensor, "little"
            )
            print(
                f"[DEBUG] input `random_seed` type"
                f"({type(text_generation_chat_input.random_seed)}): "
                f"{text_generation_chat_input.random_seed}"
            )

        if input_name == "extra_params":
            input_tensor = deserialize_bytes_tensor(b_input_tensor)
            extra_params_str = str(input_tensor[0].decode("utf-8"))
            print(
                f"[DEBUG] input `extra_params` type"
                f"({type(extra_params_str)}): {extra_params_str}"
            )

            try:
                text_generation_chat_input.extra_params = json.loads(
                    extra_params_str
                )
            except json.decoder.JSONDecodeError:
                # Best-effort: keep the defaults when the JSON is malformed.
                print("[DEBUG] WARNING `extra_params` parsing failed!")
                continue

    return text_generation_chat_input

@staticmethod
def parse_task_text_generation_chat_output(sequences: list):
    """Encode each sequence's generated text as UTF-8 and pack the
    results into a serialized byte tensor."""
    encoded = []
    for seq in sequences:
        encoded.append(seq["generated_text"].encode("utf-8"))
    return serialize_byte_tensor(np.asarray(encoded))

@staticmethod
def parse_task_visual_question_answering_input(
    request,
) -> VisualQuestionAnsweringInput:
    """Parse an inference request into a ``VisualQuestionAnsweringInput``.

    Walks the named input tensors and their raw byte payloads in
    lockstep, decoding each recognized field; unrecognized input names
    are ignored and the ``VisualQuestionAnsweringInput`` defaults are
    kept.

    Args:
        request: request object exposing ``inputs`` (tensor metadata
            with a ``name`` attribute) and ``raw_input_contents`` (raw
            byte payloads in the same order).

    Returns:
        VisualQuestionAnsweringInput populated from the request.

    Raises:
        ValueError: if the prompt image decodes to a 2-D (grayscale)
            array instead of an RGB image.
    """
    text_visual_question_answering_input = VisualQuestionAnsweringInput()

    for meta, b_input_tensor in zip(request.inputs, request.raw_input_contents):
        input_name = meta.name

        if input_name == "prompt_image":
            input_tensors = deserialize_bytes_tensor(b_input_tensor)
            images = []
            for enc in input_tensors:
                pil_img = Image.open(io.BytesIO(enc.astype(bytes)))  # RGB
                image = np.array(pil_img)
                if len(image.shape) == 2:  # grayscale image not supported
                    raise ValueError(
                        f"The image shape with {image.shape} is not acceptable"
                    )
                images.append(image)
            # Only the first image is used; any extra entries are decoded
            # and validated but otherwise ignored.
            text_visual_question_answering_input.prompt_image = images[0]
            print(
                f"[DEBUG] input `prompt_image` type"
                f"({type(text_visual_question_answering_input.prompt_image)}): "
                f"{text_visual_question_answering_input.prompt_image}"
            )

        if input_name == "prompt":
            input_tensor = deserialize_bytes_tensor(b_input_tensor)
            text_visual_question_answering_input.prompt = str(
                input_tensor[0].decode("utf-8")
            )
            print(
                f"[DEBUG] input `prompt` type"
                f"({type(text_visual_question_answering_input.prompt)}): "
                f"{text_visual_question_answering_input.prompt}"
            )

        if input_name == "max_new_tokens":
            # Scalar ints arrive as little-endian raw bytes.
            text_visual_question_answering_input.max_new_tokens = int.from_bytes(
                b_input_tensor, "little"
            )
            print(
                f"[DEBUG] input `max_new_tokens` type"
                f"({type(text_visual_question_answering_input.max_new_tokens)}): "
                f"{text_visual_question_answering_input.max_new_tokens}"
            )

        if input_name == "top_k":
            text_visual_question_answering_input.top_k = int.from_bytes(
                b_input_tensor, "little"
            )
            print(
                f"[DEBUG] input `top_k` type"
                f"({type(text_visual_question_answering_input.top_k)}): "
                f"{text_visual_question_answering_input.top_k}"
            )

        if input_name == "temperature":
            # 32-bit little-endian IEEE float, rounded to 2 decimals.
            text_visual_question_answering_input.temperature = struct.unpack(
                "f", b_input_tensor
            )[0]
            print(
                f"[DEBUG] input `temperature` type"
                f"({type(text_visual_question_answering_input.temperature)}): "
                f"{text_visual_question_answering_input.temperature}"
            )
            text_visual_question_answering_input.temperature = round(
                text_visual_question_answering_input.temperature, 2
            )

        if input_name == "random_seed":
            text_visual_question_answering_input.random_seed = int.from_bytes(
                b_input_tensor, "little"
            )
            print(
                f"[DEBUG] input `random_seed` type"
                f"({type(text_visual_question_answering_input.random_seed)}): "
                f"{text_visual_question_answering_input.random_seed}"
            )

        if input_name == "extra_params":
            input_tensor = deserialize_bytes_tensor(b_input_tensor)
            extra_params_str = str(input_tensor[0].decode("utf-8"))
            print(
                f"[DEBUG] input `extra_params` type"
                f"({type(extra_params_str)}): {extra_params_str}"
            )

            try:
                text_visual_question_answering_input.extra_params = json.loads(
                    extra_params_str
                )
            except json.decoder.JSONDecodeError:
                # Best-effort: keep the defaults when the JSON is malformed.
                print("[DEBUG] WARNING `extra_params` parsing failed!")
                continue

    return text_visual_question_answering_input

@staticmethod
def parse_task_visual_question_answering_output(sequences: list):
    """Encode each sequence's generated text as UTF-8 and pack the
    results into a serialized byte tensor."""
    encoded = []
    for seq in sequences:
        encoded.append(seq["generated_text"].encode("utf-8"))
    return serialize_byte_tensor(np.asarray(encoded))


class RawIO:
@staticmethod
Expand Down
Loading