Added settings vllm (#2599)
Co-authored-by: bodza <bodza@qnovi.de>
Co-authored-by: bodza <sebastian.bodza@qnovi.de>
3 people committed Nov 1, 2023
1 parent 7a31d3b commit d5e4b27
Showing 4 changed files with 36 additions and 0 deletions.
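
This commit threads three vLLM sampling settings (top_k, presence_penalty, frequency_penalty) from the OpenAI-compatible request schemas through the API server and into the vLLM worker. As a minimal sketch of how a client could exercise the new fields, assuming a FastChat OpenAI API server on localhost:8000 and a placeholder model name:

```python
# Hedged sketch: the address, model name, and parameter values are
# assumptions for illustration, not part of the commit.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "vicuna-7b-v1.5",                       # placeholder model name
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.7,
        "top_p": 1.0,
        "top_k": 40,                # new in this commit; -1 disables top-k
        "presence_penalty": 0.0,    # new in this commit
        "frequency_penalty": 0.5,   # new in this commit
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```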
fastchat/protocol/api_protocol.py (4 additions, 0 deletions)
@@ -53,12 +53,15 @@ class APIChatCompletionRequest(BaseModel):
messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
top_k: Optional[int] = -1
n: Optional[int] = 1
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
user: Optional[str] = None
repetition_penalty: Optional[float] = 1.0
frequency_penalty: Optional[float] = 0.0
presence_penalty: Optional[float] = 0.0


class ChatMessage(BaseModel):
@@ -130,6 +133,7 @@ class CompletionRequest(BaseModel):
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
top_p: Optional[float] = 1.0
top_k: Optional[int] = -1
logprobs: Optional[int] = None
echo: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
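
With the new fields in place, requests that omit them keep backward-compatible defaults: top_k = -1 (no top-k filtering) and both penalties at 0.0. A small sketch of those defaults, assuming the model field is required as in the rest of the class (not shown in this hunk):

```python
from fastchat.protocol.api_protocol import APIChatCompletionRequest

# "vicuna-7b-v1.5" is a placeholder model name.
req = APIChatCompletionRequest(model="vicuna-7b-v1.5", messages="Hello")
assert req.top_k == -1                 # -1 means "no top-k restriction"
assert req.presence_penalty == 0.0
assert req.frequency_penalty == 0.0
```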
fastchat/protocol/openai_api_protocol.py (2 additions, 0 deletions)
@@ -53,6 +53,7 @@ class ChatCompletionRequest(BaseModel):
messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
top_k: Optional[int] = -1
n: Optional[int] = 1
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
@@ -146,6 +147,7 @@ class CompletionRequest(BaseModel):
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
top_p: Optional[float] = 1.0
top_k: Optional[int] = -1
logprobs: Optional[int] = None
echo: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
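
top_k is not part of the official OpenAI request schema, so a client built on the openai SDK has to send it as an extra field. A hedged sketch, assuming an openai 1.x client pointed at the FastChat server; the base URL, API key, and model name are placeholders:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.chat.completions.create(
    model="vicuna-7b-v1.5",                  # placeholder model name
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.7,
    extra_body={"top_k": 40},                # forwarded to ChatCompletionRequest.top_k
)
print(completion.choices[0].message.content)
```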
fastchat/serve/openai_api_server.py (23 additions, 0 deletions)
@@ -199,6 +199,11 @@ def check_requests(request) -> Optional[JSONResponse]:
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.top_p} is greater than the maximum of 1 - 'temperature'",
)
if request.top_k is not None and (request.top_k > -1 and request.top_k < 1):
return create_error_response(
ErrorCode.PARAM_OUT_OF_RANGE,
f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.",
)
if request.stop is not None and (
not isinstance(request.stop, str) and not isinstance(request.stop, list)
):
@@ -240,6 +245,9 @@ async def get_gen_params(
*,
temperature: float,
top_p: float,
top_k: Optional[int],
presence_penalty: Optional[float],
frequency_penalty: Optional[float],
max_tokens: Optional[int],
echo: Optional[bool],
stop: Optional[Union[str, List[str]]],
@@ -284,6 +292,9 @@ async def get_gen_params(
"prompt": prompt,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
"max_new_tokens": max_tokens,
"echo": echo,
"stop_token_ids": conv.stop_token_ids,
@@ -366,6 +377,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
request.messages,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
max_tokens=request.max_tokens,
echo=False,
stop=request.stop,
@@ -498,6 +512,9 @@ async def create_completion(request: CompletionRequest):
text,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
frequency_penalty=request.frequency_penalty,
presence_penalty=request.presence_penalty,
max_tokens=request.max_tokens,
echo=request.echo,
stop=request.stop,
@@ -552,6 +569,9 @@ async def generate_completion_stream_generator(
text,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
max_tokens=request.max_tokens,
echo=request.echo,
stop=request.stop,
@@ -731,6 +751,9 @@ async def create_chat_completion(request: APIChatCompletionRequest):
request.messages,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
max_tokens=request.max_tokens,
echo=False,
stop=request.stop,
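
The new check_requests branch rejects top_k values strictly between -1 and 1 (for integers, only 0) with PARAM_OUT_OF_RANGE, while -1 (disabled) and anything >= 1 pass through. Restated as a standalone predicate for illustration; this is not the server code itself:

```python
def top_k_rejected(top_k) -> bool:
    # Mirrors the condition added to check_requests above.
    return top_k is not None and (top_k > -1 and top_k < 1)

assert not top_k_rejected(None)   # unset: no validation performed
assert not top_k_rejected(-1)     # -1 disables top-k sampling
assert not top_k_rejected(40)     # any value >= 1 is accepted
assert top_k_rejected(0)          # 0 triggers PARAM_OUT_OF_RANGE
```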
fastchat/serve/vllm_worker.py (7 additions, 0 deletions)
@@ -68,6 +68,9 @@ async def generate_stream(self, params):
request_id = params.pop("request_id")
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
top_k = params.get("top_k", -1.0)
presence_penalty = float(params.get("presence_penalty", 0.0))
frequency_penalty = float(params.get("frequency_penalty", 0.0))
max_new_tokens = params.get("max_new_tokens", 256)
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
@@ -92,13 +95,17 @@ async def generate_stream(self, params):
top_p = max(top_p, 1e-5)
if temperature <= 1e-5:
top_p = 1.0

sampling_params = SamplingParams(
n=1,
temperature=temperature,
top_p=top_p,
use_beam_search=use_beam_search,
stop=list(stop),
max_tokens=max_new_tokens,
top_k=top_k,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
best_of=best_of,
)
results_generator = engine.generate(context, sampling_params, request_id)
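
End to end, get_gen_params in the API server now packs the three settings into the payload that the vLLM worker's generate_stream unpacks into SamplingParams. An illustrative, partial payload with arbitrary values, limited to the keys visible in the hunks above:

```python
gen_params = {
    "prompt": "USER: Hello! ASSISTANT:",   # placeholder prompt
    "temperature": 0.7,
    "top_p": 1.0,
    "top_k": 40,                 # SamplingParams.top_k
    "presence_penalty": 0.0,     # SamplingParams.presence_penalty
    "frequency_penalty": 0.5,    # SamplingParams.frequency_penalty
    "max_new_tokens": 256,       # mapped to SamplingParams.max_tokens
    "echo": False,
    "stop_token_ids": [],
}
```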
