Commit
update
minmingzhu committed May 6, 2024
1 parent a328494 commit a8e7b38
Showing 6 changed files with 14 additions and 83 deletions.
7 changes: 6 additions & 1 deletion examples/inference/api_server_simple/query_single.py
@@ -55,7 +55,12 @@
)

args = parser.parse_args()
prompt = "Once upon a time,"
# prompt = "Once upon a time,"
prompt = [
{"role": "user", "content": "Which is bigger, the moon or the sun?"},
]


config: Dict[str, Union[int, float]] = {}
if args.max_new_tokens:
config["max_new_tokens"] = int(args.max_new_tokens)
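As a reference for the new chat-style prompt above, here is a minimal client sketch for the simple API server. The endpoint URL and the JSON field names (`text`, `config`, `stream`) are illustrative assumptions, not taken from this diff.

import requests

prompt = [
    {"role": "user", "content": "Which is bigger, the moon or the sun?"},
]
config = {"max_new_tokens": 128}

# Hypothetical request body; the real script builds it from parsed CLI arguments.
response = requests.post(
    "http://localhost:8000/my_model",  # stand-in for args.model_endpoint
    json={"text": prompt, "config": config, "stream": False},
    timeout=60,
)
print(response.text)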
2 changes: 1 addition & 1 deletion llm_on_ray/finetune/finetune_config.py
@@ -64,7 +64,7 @@ class General(BaseModel):
enable_gradient_checkpointing: bool = False
chat_template: Optional[str] = None
default_chat_template: str = (
"{{ bos_token }}"
"Below is an instruction that describes a task. Write a response that appropriately completes the request."
"{% if messages[0]['role'] == 'system' %}"
"{{ raise_exception('System role not supported') }}"
"{% endif %}"
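For orientation, a small sketch of how a Jinja chat-template string such as default_chat_template can be rendered with Hugging Face's apply_chat_template. The template below is a simplified stand-in, and the model name and messages are placeholders.

from transformers import AutoTokenizer

# Any tokenizer will do for rendering; gpt2 is just a small placeholder model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Simplified stand-in for General.default_chat_template.
template = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n"
    "{% for message in messages %}{{ message['content'] }}\n{% endfor %}"
)

messages = [{"role": "user", "content": "Summarize the plot of Hamlet."}]

# tokenize=False returns the rendered prompt string instead of token ids.
prompt = tokenizer.apply_chat_template(messages, chat_template=template, tokenize=False)
print(prompt)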
3 changes: 0 additions & 3 deletions llm_on_ray/inference/chat_template_process.py
@@ -14,7 +14,6 @@
# limitations under the License.
#
from typing import List

from llm_on_ray.inference.api_openai_backend.openai_protocol import ChatMessage


@@ -63,14 +62,12 @@ def _extract_messages(self, messages):
return texts, images

def _prepare_image(self, messages: list):
"""Prepare image from history messages."""
from PIL import Image
import requests
from io import BytesIO
import base64
import re

# prepare images
images: List = []
for msg in messages:
msg = dict(msg)
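The _prepare_image method above boils down to this data-URL check: a base64 "data:image/...;base64," prefix is decoded in place, anything else is fetched over HTTP. A self-contained sketch of that logic (the example URL is a placeholder):

import base64
import re
from io import BytesIO

import requests
from PIL import Image


def load_image(url: str) -> Image.Image:
    """Decode a base64 data URL, or download a regular image URL."""
    if re.match(r"^data:image/.+;base64,", url):
        encoded = re.sub(r"^data:image/.+;base64,", "", url)
        return Image.open(BytesIO(base64.b64decode(encoded)))
    return Image.open(requests.get(url, stream=True).raw)


image = load_image("https://example.com/cat.png")  # placeholder URL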
1 change: 0 additions & 1 deletion llm_on_ray/inference/models/gpt2.yaml
@@ -15,4 +15,3 @@ model_description:
tokenizer_name_or_path: gpt2
gpt_base_model: true
chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() }}{% endif %}{% endfor %}"
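To see what this template produces, it can be rendered directly with jinja2 by supplying the raise_exception helper that transformers' apply_chat_template normally provides. The conversation below is made up and the snippet is only a sketch.

from jinja2 import Environment


def raise_exception(message):
    raise ValueError(message)


# Same template string as the gpt2.yaml entry above.
GPT2_TEMPLATE = (
    "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
    "{% set system_message = messages[0]['content'] %}{% else %}"
    "{% set loop_messages = messages %}{% set system_message = false %}{% endif %}"
    "{% for message in loop_messages %}"
    "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
    "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
    "{% endif %}"
    "{% if message['role'] == 'user' %}{{ message['content'].strip() }}"
    "{% elif message['role'] == 'assistant' %}{{ message['content'].strip() }}{% endif %}"
    "{% endfor %}"
)

env = Environment()
env.globals["raise_exception"] = raise_exception

messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi, how can I help?"},
]
# Prints the two stripped contents back to back; swapping the roles raises ValueError.
print(env.from_string(GPT2_TEMPLATE).render(messages=messages))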

80 changes: 5 additions & 75 deletions llm_on_ray/inference/predictor_deployment.py
@@ -340,31 +340,6 @@ def preprocess_prompts(self, input: Union[str, list], tools=None, tool_choice=No
else:
prompt = self.process_tool.get_prompt(input)
return prompt
else:
if isinstance(input, list) and input and isinstance(input[0], dict):
prompt = self.predictor.tokenizer.apply_chat_template(input, tokenize=False)
elif isinstance(input, list) and input and isinstance(input[0], list):
prompt = [
self.predictor.tokenizer.apply_chat_template(t, tokenize=False)
for t in input
]
elif isinstance(input, list) and input and isinstance(input[0], ChatMessage):
messages = []
for chat_message in input:
message = {"role": chat_message.role, "content": chat_message.content}
messages.append(message)
prompt = self.predictor.tokenizer.apply_chat_template(
messages, tokenize=False
)
elif isinstance(input, list) and input and isinstance(input[0], str):
prompt = input
elif isinstance(input, str):
prompt = input
else:
raise TypeError(
f"Unsupported type {type(input)} for text. Expected dict or list of dicts."
)
return prompt
elif prompt_format == PromptFormat.PROMPTS_FORMAT:
raise HTTPException(400, "Invalid prompt format.")
return input
@@ -411,63 +386,18 @@ async def openai_call(
tool_choice=None,
):
self.use_openai = True
print("openai_call")
print(input)
print(type(input))

# return prompt or list of prompts preprocessed
prompts = self.preprocess_prompts(input, tools, tool_choice)
print(prompts)
print(type(prompts))

# Handle streaming response
if streaming_response:
async for result in self.handle_streaming(prompts, config):
yield result
else:
yield await self.handle_non_streaming(prompts, config)

def _extract_messages(self, messages):
texts, images = [], []
for message in messages:
if message["role"] == "user" and isinstance(message["content"], list):
texts.append({"role": "user", "content": message["content"][0]["text"]})
images.append(
{"role": "user", "content": message["content"][1]["image_url"]["url"]}
)
else:
texts.append(message)
return texts, images

def _prepare_image(self, messages: list):
"""Prepare image from history messages."""
from PIL import Image
import requests
from io import BytesIO
import base64
import re

# prepare images
images: List = []
if isinstance(messages[0], List):
for i in range(len(messages)):
for msg in messages[i]:
msg = dict(msg)
content = msg["content"]
if "url" not in content:
continue
is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
if is_data:
encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
images[i].append(Image.open(BytesIO(base64.b64decode(encoded_str))))
else:
images[i].append(Image.open(requests.get(content["url"], stream=True).raw))
elif isinstance(messages[0], dict):
for msg in messages:
msg = dict(msg)
content = msg["content"]
if "url" not in content:
continue
is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
if is_data:
encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
images.append(Image.open(BytesIO(base64.b64decode(encoded_str))))
else:
images.append(Image.open(requests.get(content["url"], stream=True).raw))

return images
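For reference, the _extract_messages helper removed here (and kept in chat_template_process.py, per the diff above) expects OpenAI-style multimodal messages. A standalone sketch with a made-up message, showing the text/image split it performs:

from typing import Dict, List, Tuple


def extract_messages(messages: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    """Split multimodal user messages into text parts and image-URL parts."""
    texts, images = [], []
    for message in messages:
        if message["role"] == "user" and isinstance(message["content"], list):
            texts.append({"role": "user", "content": message["content"][0]["text"]})
            images.append(
                {"role": "user", "content": message["content"][1]["image_url"]["url"]}
            )
        else:
            texts.append(message)
    return texts, images


# Illustrative multimodal message; the image URL is a placeholder.
example = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this picture?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }
]
print(extract_messages(example))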
4 changes: 2 additions & 2 deletions llm_on_ray/inference/utils.py
@@ -162,13 +162,13 @@ def is_cpu_without_ipex(infer_conf: InferenceConfig) -> bool:
return (not infer_conf.ipex.enabled) and infer_conf.device == DEVICE_CPU


def get_prompt_format(input: Union[List[str], List[dict], List[List[dict]], List[ChatMessage]]):
def get_prompt_format(input: Union[List[str], List[dict], List[ChatMessage]]):
chat_format = True
prompts_format = True
for item in input:
if isinstance(item, str):
chat_format = False
elif isinstance(item, dict) or isinstance(item, ChatMessage) or isinstance(item, list):
elif isinstance(item, dict) or isinstance(item, ChatMessage):
prompts_format = False
else:
chat_format = False
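A sketch of how this classification behaves, as a standalone reimplementation. The real function returns PromptFormat enum members rather than the strings used here, ChatMessage handling is omitted for brevity, and the return logic past this hunk is assumed.

from typing import List, Union


def classify_prompt_format(items: List[Union[str, dict]]) -> str:
    """Illustrative stand-in for get_prompt_format's classification loop."""
    chat_format = True
    prompts_format = True
    for item in items:
        if isinstance(item, str):
            chat_format = False
        elif isinstance(item, dict):
            prompts_format = False
        else:
            chat_format = False
            prompts_format = False
    if chat_format:
        return "chat"      # e.g. [{"role": "user", "content": "hi"}]
    if prompts_format:
        return "prompts"   # e.g. ["Once upon a time,"]
    return "invalid"       # mixed or unsupported item types


print(classify_prompt_format([{"role": "user", "content": "hi"}]))        # chat
print(classify_prompt_format(["Once upon a time,"]))                      # prompts
print(classify_prompt_format(["hi", {"role": "user", "content": "hi"}]))  # invalid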
