diff --git a/examples/inference/api_server_simple/query_single.py b/examples/inference/api_server_simple/query_single.py
index 62bb4dc45..b6d935c9a 100644
--- a/examples/inference/api_server_simple/query_single.py
+++ b/examples/inference/api_server_simple/query_single.py
@@ -55,7 +55,12 @@
 )
 args = parser.parse_args()
 
-prompt = "Once upon a time,"
+# prompt = "Once upon a time,"
+prompt = [
+    {"role": "user", "content": "Which is bigger, the moon or the sun?"},
+]
+
+
 config: Dict[str, Union[int, float]] = {}
 if args.max_new_tokens:
     config["max_new_tokens"] = int(args.max_new_tokens)
diff --git a/llm_on_ray/finetune/finetune_config.py b/llm_on_ray/finetune/finetune_config.py
index 2cb9d0ce8..136b698eb 100644
--- a/llm_on_ray/finetune/finetune_config.py
+++ b/llm_on_ray/finetune/finetune_config.py
@@ -64,7 +64,7 @@ class General(BaseModel):
     enable_gradient_checkpointing: bool = False
     chat_template: Optional[str] = None
     default_chat_template: str = (
-        "{{ bos_token }}"
+        "Below is an instruction that describes a task. Write a response that appropriately completes the request."
         "{% if messages[0]['role'] == 'system' %}"
         "{{ raise_exception('System role not supported') }}"
         "{% endif %}"
diff --git a/llm_on_ray/inference/chat_template_process.py b/llm_on_ray/inference/chat_template_process.py
index 2f7a64d27..851004b01 100644
--- a/llm_on_ray/inference/chat_template_process.py
+++ b/llm_on_ray/inference/chat_template_process.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 #
 from typing import List
-
 from llm_on_ray.inference.api_openai_backend.openai_protocol import ChatMessage
 
 
@@ -63,14 +62,12 @@ def _extract_messages(self, messages):
         return texts, images
 
     def _prepare_image(self, messages: list):
-        """Prepare image from history messages."""
         from PIL import Image
         import requests
         from io import BytesIO
         import base64
         import re
 
-        # prepare images
         images: List = []
         for msg in messages:
             msg = dict(msg)
diff --git a/llm_on_ray/inference/models/gpt2.yaml b/llm_on_ray/inference/models/gpt2.yaml
index cddc45cd8..9ad098c24 100644
--- a/llm_on_ray/inference/models/gpt2.yaml
+++ b/llm_on_ray/inference/models/gpt2.yaml
@@ -15,4 +15,3 @@ model_description:
   tokenizer_name_or_path: gpt2
   gpt_base_model: true
   chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() }}{% endif %}{% endfor %}"
-
diff --git a/llm_on_ray/inference/predictor_deployment.py b/llm_on_ray/inference/predictor_deployment.py
index aab110727..502ad150f 100644
--- a/llm_on_ray/inference/predictor_deployment.py
+++ b/llm_on_ray/inference/predictor_deployment.py
@@ -340,31 +340,6 @@ def preprocess_prompts(self, input: Union[str, list], tools=None, tool_choice=No
                     else:
                         prompt = self.process_tool.get_prompt(input)
                     return prompt
-                else:
-                    if isinstance(input, list) and input and isinstance(input[0], dict):
-                        prompt = self.predictor.tokenizer.apply_chat_template(input, tokenize=False)
-                    elif isinstance(input, list) and input and isinstance(input[0], list):
-                        prompt = [
-                            self.predictor.tokenizer.apply_chat_template(t, tokenize=False)
-                            for t in input
-                        ]
-                    elif isinstance(input, list) and input and isinstance(input[0], ChatMessage):
-                        messages = []
-                        for chat_message in input:
-                            message = {"role": chat_message.role, "content": chat_message.content}
-                            messages.append(message)
-                        prompt = self.predictor.tokenizer.apply_chat_template(
-                            messages, tokenize=False
-                        )
-                    elif isinstance(input, list) and input and isinstance(input[0], str):
-                        prompt = input
-                    elif isinstance(input, str):
-                        prompt = input
-                    else:
-                        raise TypeError(
-                            f"Unsupported type {type(input)} for text. Expected dict or list of dicts."
-                        )
-                    return prompt
             elif prompt_format == PromptFormat.PROMPTS_FORMAT:
                 raise HTTPException(400, "Invalid prompt format.")
             return input
@@ -411,9 +386,14 @@ async def openai_call(
         tool_choice=None,
     ):
         self.use_openai = True
+        print("openai_call")
+        print(input)
+        print(type(input))
 
         # return prompt or list of prompts preprocessed
         prompts = self.preprocess_prompts(input, tools, tool_choice)
+        print(prompts)
+        print(type(prompts))
 
         # Handle streaming response
         if streaming_response:
@@ -421,53 +401,3 @@
                 yield result
         else:
             yield await self.handle_non_streaming(prompts, config)
-
-    def _extract_messages(self, messages):
-        texts, images = [], []
-        for message in messages:
-            if message["role"] == "user" and isinstance(message["content"], list):
-                texts.append({"role": "user", "content": message["content"][0]["text"]})
-                images.append(
-                    {"role": "user", "content": message["content"][1]["image_url"]["url"]}
-                )
-            else:
-                texts.append(message)
-        return texts, images
-
-    def _prepare_image(self, messages: list):
-        """Prepare image from history messages."""
-        from PIL import Image
-        import requests
-        from io import BytesIO
-        import base64
-        import re
-
-        # prepare images
-        images: List = []
-        if isinstance(messages[0], List):
-            for i in range(len(messages)):
-                for msg in messages[i]:
-                    msg = dict(msg)
-                    content = msg["content"]
-                    if "url" not in content:
-                        continue
-                    is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
-                    if is_data:
-                        encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
-                        images[i].append(Image.open(BytesIO(base64.b64decode(encoded_str))))
-                    else:
-                        images[i].append(Image.open(requests.get(content["url"], stream=True).raw))
-        elif isinstance(messages[0], dict):
-            for msg in messages:
-                msg = dict(msg)
-                content = msg["content"]
-                if "url" not in content:
-                    continue
-                is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
-                if is_data:
-                    encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
-                    images.append(Image.open(BytesIO(base64.b64decode(encoded_str))))
-                else:
-                    images.append(Image.open(requests.get(content["url"], stream=True).raw))
-
-        return images
diff --git a/llm_on_ray/inference/utils.py b/llm_on_ray/inference/utils.py
index 5a8db0401..56b9146e5 100644
--- a/llm_on_ray/inference/utils.py
+++ b/llm_on_ray/inference/utils.py
@@ -162,13 +162,13 @@ def is_cpu_without_ipex(infer_conf: InferenceConfig) -> bool:
     return (not infer_conf.ipex.enabled) and infer_conf.device == DEVICE_CPU
 
 
-def get_prompt_format(input: Union[List[str], List[dict], List[List[dict]], List[ChatMessage]]):
+def get_prompt_format(input: Union[List[str], List[dict], List[ChatMessage]]):
     chat_format = True
     prompts_format = True
     for item in input:
         if isinstance(item, str):
             chat_format = False
-        elif isinstance(item, dict) or isinstance(item, ChatMessage) or isinstance(item, list):
+        elif isinstance(item, dict) or isinstance(item, ChatMessage):
             prompts_format = False
         else:
             chat_format = False
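
Note (not part of the diff): the removed preprocess_prompts branch rendered chat messages with Hugging Face's tokenizer.apply_chat_template. The standalone sketch below shows that same call in isolation, assuming the transformers package is installed; the messages list mirrors the new prompt in query_single.py and the template string is copied from gpt2.yaml above, everything else is illustrative.

# Standalone sketch: render an OpenAI-style message list to a plain-text prompt.
# Assumes `pip install transformers`; the template string is the one from gpt2.yaml.
from transformers import AutoTokenizer

messages = [
    {"role": "user", "content": "Which is bigger, the moon or the sun?"},
]

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# gpt2 ships without a built-in chat template, so assign the one defined in gpt2.yaml.
tokenizer.chat_template = (
    "{% if messages[0]['role'] == 'system' %}"
    "{% set loop_messages = messages[1:] %}"
    "{% set system_message = messages[0]['content'] %}"
    "{% else %}"
    "{% set loop_messages = messages %}"
    "{% set system_message = false %}"
    "{% endif %}"
    "{% for message in loop_messages %}"
    "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
    "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
    "{% endif %}"
    "{% if message['role'] == 'user' %}"
    "{{ message['content'].strip() }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ message['content'].strip() }}"
    "{% endif %}"
    "{% endfor %}"
)

prompt = tokenizer.apply_chat_template(messages, tokenize=False)
print(prompt)  # -> Which is bigger, the moon or the sun?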