Support using the chat template in the tokenizer and the generation config in the model for the vLLM and OpenAI endpoints #3351

Open · wants to merge 3 commits into base: main
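
This PR teaches the OpenAI-compatible API server to prefer the chat template stored in the serving worker's tokenizer. When the worker reports one, the prompt is rendered by the worker through a new /apply_chat_template route, the template's eos_token is used as the stop string, and any sampling parameters the request leaves unset (temperature, top_p, top_k, max_tokens) are filled from the model's generation_config; when no chat template exists, the existing Conversation-based path is kept. The vLLM worker is extended to expose the template and generation config through worker_get_conv_template. A rough, illustrative sketch of the extended response shape implied by the diff below (all concrete values are invented):

example_conv_template_response = {
    # the serialized FastChat Conversation, as before
    "conv": {"name": "vicuna_v1.1", "roles": ["USER", "ASSISTANT"]},
    # present only when the worker's tokenizer defines a chat template
    "chat_template": {
        "chat_template": "{% for m in messages %}...{% endfor %}",
        "eos_token": "</s>",
        # GenerationConfig.to_diff_dict() from the model, or None
        "generation_config": {"temperature": 0.6, "top_p": 0.9, "eos_token_id": 2},
    },
}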
81 changes: 60 additions & 21 deletions fastchat/serve/openai_api_server.py
@@ -284,24 +284,47 @@ async def get_gen_params(
use_beam_search: Optional[bool] = None,
) -> Dict[str, Any]:
conv = await get_conv(model_name, worker_addr)
conv = Conversation(
name=conv["name"],
system_template=conv["system_template"],
system_message=conv["system_message"],
roles=conv["roles"],
messages=list(conv["messages"]), # prevent in-place modification
offset=conv["offset"],
sep_style=SeparatorStyle(conv["sep_style"]),
sep=conv["sep"],
sep2=conv["sep2"],
stop_str=conv["stop_str"],
stop_token_ids=conv["stop_token_ids"],
)

images = []
conv_stop_str = None
conv_stop_token_ids = None

if isinstance(messages, str):
prompt = messages
images = []
elif "chat_template" in conv and conv["chat_template"]["chat_template"]:
conv = conv["chat_template"]
prompt = await apply_chat_template(worker_addr, messages)
conv_stop_str = conv.get("eos_token")
generation_config = conv.get("generation_config")
if generation_config:
conv_stop_token_ids = conv.get("eos_token_id")
if isinstance(conv_stop_token_ids, int):
conv_stop_token_ids = [conv_stop_token_ids]

if temperature is None:
temperature = generation_config.get("temperature")
if top_p is None:
top_p = generation_config.get("top_p")
if top_k is None:
top_k = generation_config.get("top_k")
if max_tokens is None:
max_tokens = generation_config.get("max_tokens")
else:
conv = conv["conv"]
conv = Conversation(
name=conv["name"],
system_template=conv["system_template"],
system_message=conv["system_message"],
roles=conv["roles"],
messages=list(conv["messages"]), # prevent in-place modification
offset=conv["offset"],
sep_style=SeparatorStyle(conv["sep_style"]),
sep=conv["sep"],
sep2=conv["sep2"],
stop_str=conv["stop_str"],
stop_token_ids=conv["stop_token_ids"],
)

for message in messages:
msg_role = message["role"]
if msg_role == "system":
@@ -330,10 +353,12 @@ async def get_gen_params(
else:
raise ValueError(f"Unknown role: {msg_role}")

# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
images = conv.get_images()
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
images = conv.get_images()
conv_stop_str = conv.stop_str
conv_stop_token_ids = conv.stop_token_ids

gen_params = {
"model": model_name,
@@ -346,7 +371,7 @@ async def get_gen_params(
"frequency_penalty": frequency_penalty,
"max_new_tokens": max_tokens,
"echo": echo,
"stop_token_ids": conv.stop_token_ids,
"stop_token_ids": conv_stop_token_ids,
}

if len(images) > 0:
@@ -359,7 +384,7 @@ async def get_gen_params(

new_stop = set()
_add_to_set(stop, new_stop)
_add_to_set(conv.stop_str, new_stop)
_add_to_set(conv_stop_str, new_stop)

gen_params["stop"] = list(new_stop)

@@ -391,12 +416,26 @@ async def get_conv(model_name: str, worker_addr: str):
conv_template = conv_template_map.get((worker_addr, model_name))
if conv_template is None:
conv_template = await fetch_remote(
worker_addr + "/worker_get_conv_template", {"model": model_name}, "conv"
worker_addr + "/worker_get_conv_template", {"model": model_name}
)
conv_template = json.loads(conv_template)
conv_template_map[(worker_addr, model_name)] = conv_template
return conv_template


async def apply_chat_template(worker_addr: str, conversation):
prompt = await fetch_remote(
worker_addr + "/apply_chat_template",
{
"conversation": conversation,
"tokenize": False,
"add_generation_prompt": True,
},
"prompt",
)
return prompt


@app.get("/v1/models", dependencies=[Depends(check_api_key)])
async def show_available_models():
controller_address = app_settings.controller_address
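The upshot of the get_gen_params changes in this file is a fallback rule: on the chat-template path, sampling parameters the client did not set are taken from the model's generation_config. A minimal, self-contained sketch of that rule (the function is ours and the example config values are invented):

from typing import Any, Dict, Optional

def fill_from_generation_config(
    temperature: Optional[float],
    top_p: Optional[float],
    top_k: Optional[int],
    max_tokens: Optional[int],
    generation_config: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
    # Mirrors the fallback in get_gen_params: only parameters the request
    # left as None are filled from the model's generation_config.
    if generation_config:
        if temperature is None:
            temperature = generation_config.get("temperature")
        if top_p is None:
            top_p = generation_config.get("top_p")
        if top_k is None:
            top_k = generation_config.get("top_k")
        if max_tokens is None:
            max_tokens = generation_config.get("max_tokens")
    return {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "max_tokens": max_tokens,
    }

# The request only pins temperature; the rest falls back to the model config.
print(fill_from_generation_config(0.2, None, None, None, {"temperature": 0.6, "top_p": 0.9, "top_k": 50}))
# -> {'temperature': 0.2, 'top_p': 0.9, 'top_k': 50, 'max_tokens': None}
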
67 changes: 66 additions & 1 deletion fastchat/serve/vllm_worker.py
@@ -7,7 +7,8 @@
import argparse
import asyncio
import json
from typing import List
import codecs
from typing import List, Optional

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
@@ -16,6 +17,7 @@
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from transformers.generation import GenerationConfig

from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.serve.model_worker import (
@@ -59,11 +61,45 @@ def __init__(
# and llm_engine.engine.tokenizer was no longer a raw tokenizer
if hasattr(self.tokenizer, "tokenizer"):
self.tokenizer = llm_engine.engine.tokenizer.tokenizer
self._load_chat_template(chat_template=None)
try:
self.generation_config = GenerationConfig.from_pretrained(
model_path, trust_remote_code=True
)
except Exception:
self.generation_config = None
self.context_len = get_context_length(llm_engine.engine.model_config.hf_config)

if not no_register:
self.init_heart_beat()

def _load_chat_template(self, chat_template: Optional[str]):
tokenizer = self.tokenizer

if chat_template is not None:
try:
with open(chat_template, "r") as f:
tokenizer.chat_template = f.read()
except OSError as e:
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (
f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}"
)
raise ValueError(msg) from e

# If opening the file fails, treat the supplied chat_template argument as an
# inline template string and decode it so escape characters (e.g. \n) are interpreted correctly
tokenizer.chat_template = codecs.decode(chat_template, "unicode_escape")

logger.info("Using supplied chat template:\n%s", tokenizer.chat_template)
elif tokenizer.chat_template is not None:
logger.info("Using default chat template:\n%s", tokenizer.chat_template)
else:
tokenizer.chat_template = ""

async def generate_stream(self, params):
self.call_ct += 1

@@ -169,6 +205,28 @@ async def generate_stream(self, params):
if aborted:
break

def get_conv_template(self):
if self.tokenizer.chat_template:
chat_template_kwargs = {
"chat_template": {
"chat_template": self.tokenizer.chat_template,
"eos_token": self.tokenizer.eos_token,
"generation_config": self.generation_config.to_diff_dict()
if self.generation_config
else None,
}
}
else:
chat_template_kwargs = {}

return {
"conv": self.conv,
**chat_template_kwargs,
}

def apply_chat_template(self, params):
return self.tokenizer.apply_chat_template(**params)

async def generate(self, params):
async for x in self.generate_stream(params):
pass
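
The apply_chat_template method above is a thin pass-through to the Hugging Face tokenizer. For reference, a sketch of the underlying call (the model name is only an example; any tokenizer that ships a chat_template behaves the same way):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-7B-Chat")
prompt = tokenizer.apply_chat_template(
    conversation=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    tokenize=False,              # return a string rather than token ids
    add_generation_prompt=True,  # append the assistant turn header
)
print(prompt)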
@@ -195,6 +253,13 @@ async def abort_request() -> None:
return background_tasks


@app.post("/apply_chat_template")
async def api_apply_chat_template(request: Request):
params = await request.json()
prompt = worker.apply_chat_template(params)
return JSONResponse({"prompt": prompt})


@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
params = await request.json()
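
End to end, the API server now asks the worker to render the prompt instead of building it locally. A rough sketch of the round trip against the new route (the worker address is an assumption, and FastChat's server actually goes through its async fetch_remote helper rather than requests):

import requests

worker_addr = "http://localhost:21002"  # assumed default worker address
payload = {
    "conversation": [{"role": "user", "content": "Hello!"}],
    "tokenize": False,
    "add_generation_prompt": True,
}
resp = requests.post(worker_addr + "/apply_chat_template", json=payload)
print(resp.json()["prompt"])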