From 8e35481a6469a85240c503eb28d1e4e2b1982601 Mon Sep 17 00:00:00 2001
From: bodza
Date: Tue, 17 Oct 2023 17:32:22 +0200
Subject: [PATCH 1/9] Added top_k & penaltys

---
 fastchat/protocol/api_protocol.py        |  3 +++
 fastchat/protocol/openai_api_protocol.py |  1 +
 fastchat/serve/openai_api_server.py      | 24 +++++++++++++++++++++++-
 fastchat/serve/vllm_worker.py            |  8 ++++++++
 4 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/fastchat/protocol/api_protocol.py b/fastchat/protocol/api_protocol.py
index 7dc8fe1c30..cf7264feea 100644
--- a/fastchat/protocol/api_protocol.py
+++ b/fastchat/protocol/api_protocol.py
@@ -53,12 +53,15 @@ class APIChatCompletionRequest(BaseModel):
     messages: Union[str, List[Dict[str, str]]]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
+    top_k: Optional[int] = 1.0
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
     stream: Optional[bool] = False
     user: Optional[str] = None
     repetition_penalty: Optional[float] = 1.0
+    frequency_penalty: Optional[float] = 1.0
+    presence_penalty: Optional[float] = 1.0


 class ChatMessage(BaseModel):
diff --git a/fastchat/protocol/openai_api_protocol.py b/fastchat/protocol/openai_api_protocol.py
index 6232e8b9b8..9eba65f2c1 100644
--- a/fastchat/protocol/openai_api_protocol.py
+++ b/fastchat/protocol/openai_api_protocol.py
@@ -146,6 +146,7 @@ class CompletionRequest(BaseModel):
     stop: Optional[Union[str, List[str]]] = None
     stream: Optional[bool] = False
     top_p: Optional[float] = 1.0
+    top_k = Optional[int],
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0
diff --git a/fastchat/serve/openai_api_server.py b/fastchat/serve/openai_api_server.py
index 9743bde263..b1b2c09087 100644
--- a/fastchat/serve/openai_api_server.py
+++ b/fastchat/serve/openai_api_server.py
@@ -192,6 +192,11 @@ def check_requests(request) -> Optional[JSONResponse]:
             ErrorCode.PARAM_OUT_OF_RANGE,
             f"{request.top_p} is greater than the maximum of 1 - 'temperature'",
         )
+    if request.top_k is not None and (request.top_k > -1 and request.top_k < 1):
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.",
+        )
     if request.stop is not None and (
         not isinstance(request.stop, str) and not isinstance(request.stop, list)
     ):
@@ -233,6 +238,9 @@ async def get_gen_params(
     *,
     temperature: float,
     top_p: float,
+    top_k: Optional[int],
+    presence_penalty: Optional[float],
+    frequency_penalty: Optional[float],
     max_tokens: Optional[int],
     echo: Optional[bool],
     stop: Optional[Union[str, List[str]]],
@@ -275,6 +283,9 @@ async def get_gen_params(
         "prompt": prompt,
         "temperature": temperature,
         "top_p": top_p,
+        "top_k": top_k,
+        "presence_penalty": presence_penalty,
+        "frequency_penalty": frequency_penalty,
         "max_new_tokens": max_tokens,
         "echo": echo,
         "stop_token_ids": conv.stop_token_ids,
@@ -351,7 +362,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
         worker_addr,
         request.messages,
         temperature=request.temperature,
-        top_p=request.top_p,
+        top_k=request.top_k,
+        presence_penalty=request.presence_penalty,
+        frequency_penalty=request.frequency_penalty,
         max_tokens=request.max_tokens,
         echo=False,
         stop=request.stop,
@@ -484,6 +497,9 @@ async def create_completion(request: CompletionRequest):
             text,
             temperature=request.temperature,
             top_p=request.top_p,
+            top_k=request.top_k,
+            frequency_penalty=request.frequency_penalty,
+            presence_penalty=request.presence_penalty,
             max_tokens=request.max_tokens,
             echo=request.echo,
             stop=request.stop,
@@ -536,6 +552,9 @@ async def generate_completion_stream_generator(
                 text,
                 temperature=request.temperature,
                 top_p=request.top_p,
+                top_k = request.top_k,
+                presence_penalty = request.presence_penalty,
+                frequency_penalty = request.frequency_penalty,
                 max_tokens=request.max_tokens,
                 echo=request.echo,
                 stop=request.stop,
@@ -715,6 +734,9 @@ async def create_chat_completion(request: APIChatCompletionRequest):
         request.messages,
         temperature=request.temperature,
         top_p=request.top_p,
+        top_k = request.top_k,
+        presence_penalty = request.presence_penalty,
+        frequency_penalty = request.frequency_penalty,
         max_tokens=request.max_tokens,
         echo=False,
         stop=request.stop,
diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py
index eb0bfe26a4..bf878de175 100644
--- a/fastchat/serve/vllm_worker.py
+++ b/fastchat/serve/vllm_worker.py
@@ -68,6 +68,9 @@ async def generate_stream(self, params):
         request_id = params.pop("request_id")
         temperature = float(params.get("temperature", 1.0))
         top_p = float(params.get("top_p", 1.0))
+        top_k = params.get("top_k", 1.0)
+        presence_penalty = float(params.get("presence_penalty", 0.0))
+        frequency_penalty = float(params.get("frequency_penalty", 0.0))
         max_new_tokens = params.get("max_new_tokens", 256)
         stop_str = params.get("stop", None)
         stop_token_ids = params.get("stop_token_ids", None) or []
@@ -90,6 +93,7 @@ async def generate_stream(self, params):
         top_p = max(top_p, 1e-5)
         if temperature <= 1e-5:
             top_p = 1.0
+
         sampling_params = SamplingParams(
             n=1,
             temperature=temperature,
@@ -97,6 +101,10 @@ async def generate_stream(self, params):
             use_beam_search=False,
             stop=list(stop),
             max_tokens=max_new_tokens,
+            top_k=top_k,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty
+
         )
         results_generator = engine.generate(context, sampling_params, request_id)

From 3b9bbbbe3a8cebd5085b13d38a7156f27ef4b933 Mon Sep 17 00:00:00 2001
From: bodza
Date: Tue, 17 Oct 2023 17:57:10 +0200
Subject: [PATCH 2/9] Fixed top_k definition

---
 fastchat/protocol/api_protocol.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/fastchat/protocol/api_protocol.py b/fastchat/protocol/api_protocol.py
index cf7264feea..2381596506 100644
--- a/fastchat/protocol/api_protocol.py
+++ b/fastchat/protocol/api_protocol.py
@@ -53,15 +53,15 @@ class APIChatCompletionRequest(BaseModel):
     messages: Union[str, List[Dict[str, str]]]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
-    top_k: Optional[int] = 1.0
+    top_k: Optional[int] = 1
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
     stream: Optional[bool] = False
     user: Optional[str] = None
     repetition_penalty: Optional[float] = 1.0
-    frequency_penalty: Optional[float] = 1.0
-    presence_penalty: Optional[float] = 1.0
+    frequency_penalty: Optional[float] = 0.0
+    presence_penalty: Optional[float] = 0.0


 class ChatMessage(BaseModel):

From 27fad0da0670cd025b80b82a1261ef2839f92e49 Mon Sep 17 00:00:00 2001
From: bodza
Date: Tue, 17 Oct 2023 18:03:06 +0200
Subject: [PATCH 3/9] Added top_k to completion request

---
 fastchat/protocol/api_protocol.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/fastchat/protocol/api_protocol.py b/fastchat/protocol/api_protocol.py
index 2381596506..d750cc1a15 100644
--- a/fastchat/protocol/api_protocol.py
+++ b/fastchat/protocol/api_protocol.py
@@ -133,6 +133,7 @@ class CompletionRequest(BaseModel):
     stop: Optional[Union[str, List[str]]] = None
     stream: Optional[bool] = False
     top_p: Optional[float] = 1.0
+    top_k: Optional[int] = 1
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0

From e119ca9f7b8deba5384347775e07a65ce70ca7f0 Mon Sep 17 00:00:00 2001
From: bodza
Date: Tue, 17 Oct 2023 18:07:58 +0200
Subject: [PATCH 4/9] typo ,

---
 fastchat/protocol/openai_api_protocol.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fastchat/protocol/openai_api_protocol.py b/fastchat/protocol/openai_api_protocol.py
index 9eba65f2c1..89ad006006 100644
--- a/fastchat/protocol/openai_api_protocol.py
+++ b/fastchat/protocol/openai_api_protocol.py
@@ -146,7 +146,7 @@ class CompletionRequest(BaseModel):
     stop: Optional[Union[str, List[str]]] = None
     stream: Optional[bool] = False
     top_p: Optional[float] = 1.0
-    top_k = Optional[int],
+    top_k = Optional[int] = 1
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0

From ae6eed74ecff5e9238f9edc4e8ca3cd31ae3a523 Mon Sep 17 00:00:00 2001
From: bodza
Date: Tue, 17 Oct 2023 18:09:36 +0200
Subject: [PATCH 5/9] Fixed 2nd typo smh

---
 fastchat/protocol/openai_api_protocol.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fastchat/protocol/openai_api_protocol.py b/fastchat/protocol/openai_api_protocol.py
index 89ad006006..5e1062904c 100644
--- a/fastchat/protocol/openai_api_protocol.py
+++ b/fastchat/protocol/openai_api_protocol.py
@@ -146,7 +146,7 @@ class CompletionRequest(BaseModel):
     stop: Optional[Union[str, List[str]]] = None
     stream: Optional[bool] = False
     top_p: Optional[float] = 1.0
-    top_k = Optional[int] = 1
+    top_k: Optional[int] = 1
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0

From 7d9c46b500bc9f658519d6fd9e7c2d9e62d550a8 Mon Sep 17 00:00:00 2001
From: bodza
Date: Tue, 17 Oct 2023 18:13:10 +0200
Subject: [PATCH 6/9] added topk to ChatCompletionRequest

---
 fastchat/protocol/openai_api_protocol.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/fastchat/protocol/openai_api_protocol.py b/fastchat/protocol/openai_api_protocol.py
index 5e1062904c..588c010bfc 100644
--- a/fastchat/protocol/openai_api_protocol.py
+++ b/fastchat/protocol/openai_api_protocol.py
@@ -53,6 +53,7 @@ class ChatCompletionRequest(BaseModel):
     messages: Union[str, List[Dict[str, str]]]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
+    top_k: Optional[int] = 1
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None

From 8bed52fc9c854d2539643e9a187b7cb811b187cf Mon Sep 17 00:00:00 2001
From: bodza
Date: Tue, 17 Oct 2023 18:15:20 +0200
Subject: [PATCH 7/9] Added top_p to @app.post("/v1/chat/completions",
 dependencies=[Depends(check_api_key)])

---
 fastchat/serve/openai_api_server.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/fastchat/serve/openai_api_server.py b/fastchat/serve/openai_api_server.py
index b1b2c09087..7b72cc5c1d 100644
--- a/fastchat/serve/openai_api_server.py
+++ b/fastchat/serve/openai_api_server.py
@@ -362,6 +362,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
         worker_addr,
         request.messages,
         temperature=request.temperature,
+        top_p=request.top_p,
         top_k=request.top_k,
         presence_penalty=request.presence_penalty,
         frequency_penalty=request.frequency_penalty,

From 09fa8063d2ed233b50d5a1c1abb9c70bc562cb48 Mon Sep 17 00:00:00 2001
From: bodza
Date: Tue, 24 Oct 2023 17:03:50 +0200
Subject: [PATCH 8/9] ran format.sh

---
 fastchat/serve/openai_api_server.py | 16 ++++++++--------
 fastchat/serve/vllm_worker.py       |  7 +++----
 2 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/fastchat/serve/openai_api_server.py b/fastchat/serve/openai_api_server.py
index 7b72cc5c1d..84fda64e8f 100644
--- a/fastchat/serve/openai_api_server.py
+++ b/fastchat/serve/openai_api_server.py
@@ -365,7 +365,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
         top_p=request.top_p,
         top_k=request.top_k,
         presence_penalty=request.presence_penalty,
-        frequency_penalty=request.frequency_penalty, 
+        frequency_penalty=request.frequency_penalty,
         max_tokens=request.max_tokens,
         echo=False,
         stop=request.stop,
@@ -500,7 +500,7 @@ async def create_completion(request: CompletionRequest):
             top_p=request.top_p,
             top_k=request.top_k,
             frequency_penalty=request.frequency_penalty,
-            presence_penalty=request.presence_penalty, 
+            presence_penalty=request.presence_penalty,
             max_tokens=request.max_tokens,
             echo=request.echo,
             stop=request.stop,
@@ -553,9 +553,9 @@ async def generate_completion_stream_generator(
                 text,
                 temperature=request.temperature,
                 top_p=request.top_p,
-                top_k = request.top_k,
-                presence_penalty = request.presence_penalty,
-                frequency_penalty = request.frequency_penalty,
+                top_k=request.top_k,
+                presence_penalty=request.presence_penalty,
+                frequency_penalty=request.frequency_penalty,
                 max_tokens=request.max_tokens,
                 echo=request.echo,
                 stop=request.stop,
@@ -735,9 +735,9 @@ async def create_chat_completion(request: APIChatCompletionRequest):
         request.messages,
         temperature=request.temperature,
         top_p=request.top_p,
-        top_k = request.top_k,
-        presence_penalty = request.presence_penalty,
-        frequency_penalty = request.frequency_penalty,
+        top_k=request.top_k,
+        presence_penalty=request.presence_penalty,
+        frequency_penalty=request.frequency_penalty,
         max_tokens=request.max_tokens,
         echo=False,
         stop=request.stop,
diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py
index bf878de175..ec5626f075 100644
--- a/fastchat/serve/vllm_worker.py
+++ b/fastchat/serve/vllm_worker.py
@@ -93,7 +93,7 @@ async def generate_stream(self, params):
         top_p = max(top_p, 1e-5)
         if temperature <= 1e-5:
             top_p = 1.0
-        
+
         sampling_params = SamplingParams(
             n=1,
             temperature=temperature,
@@ -102,9 +102,8 @@ async def generate_stream(self, params):
             stop=list(stop),
             max_tokens=max_new_tokens,
             top_k=top_k,
-            presence_penalty=presence_penalty, 
-            frequency_penalty=frequency_penalty
-
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
         )
         results_generator = engine.generate(context, sampling_params, request_id)

From 43423ebb061092af282ad03a0bc9e6f1ce07e381 Mon Sep 17 00:00:00 2001
From: bodza
Date: Tue, 24 Oct 2023 17:09:31 +0200
Subject: [PATCH 9/9] Defaults from vllm

---
 fastchat/protocol/api_protocol.py        | 4 ++--
 fastchat/protocol/openai_api_protocol.py | 4 ++--
 fastchat/serve/vllm_worker.py            | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/fastchat/protocol/api_protocol.py b/fastchat/protocol/api_protocol.py
index d750cc1a15..2dc99449dc 100644
--- a/fastchat/protocol/api_protocol.py
+++ b/fastchat/protocol/api_protocol.py
@@ -53,7 +53,7 @@ class APIChatCompletionRequest(BaseModel):
     messages: Union[str, List[Dict[str, str]]]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
-    top_k: Optional[int] = 1
+    top_k: Optional[int] = -1
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
@@ -133,7 +133,7 @@ class CompletionRequest(BaseModel):
     stop: Optional[Union[str, List[str]]] = None
     stream: Optional[bool] = False
     top_p: Optional[float] = 1.0
-    top_k: Optional[int] = 1
+    top_k: Optional[int] = -1
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0
diff --git a/fastchat/protocol/openai_api_protocol.py b/fastchat/protocol/openai_api_protocol.py
index 588c010bfc..ca42b6aa43 100644
--- a/fastchat/protocol/openai_api_protocol.py
+++ b/fastchat/protocol/openai_api_protocol.py
@@ -53,7 +53,7 @@ class ChatCompletionRequest(BaseModel):
     messages: Union[str, List[Dict[str, str]]]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
-    top_k: Optional[int] = 1
+    top_k: Optional[int] = -1
     n: Optional[int] = 1
     max_tokens: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
@@ -147,7 +147,7 @@ class CompletionRequest(BaseModel):
     stop: Optional[Union[str, List[str]]] = None
     stream: Optional[bool] = False
     top_p: Optional[float] = 1.0
-    top_k: Optional[int] = 1
+    top_k: Optional[int] = -1
     logprobs: Optional[int] = None
     echo: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0
diff --git a/fastchat/serve/vllm_worker.py b/fastchat/serve/vllm_worker.py
index ec5626f075..824a9c8335 100644
--- a/fastchat/serve/vllm_worker.py
+++ b/fastchat/serve/vllm_worker.py
@@ -68,7 +68,7 @@ async def generate_stream(self, params):
         request_id = params.pop("request_id")
         temperature = float(params.get("temperature", 1.0))
         top_p = float(params.get("top_p", 1.0))
-        top_k = params.get("top_k", 1.0)
+        top_k = params.get("top_k", -1.0)
         presence_penalty = float(params.get("presence_penalty", 0.0))
         frequency_penalty = float(params.get("frequency_penalty", 0.0))
         max_new_tokens = params.get("max_new_tokens", 256)
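
Usage sketch: with the series applied and an OpenAI-compatible server plus a vLLM worker running, the new fields can be sent straight to the /v1/completions route. The host, port, model name, and API key below are illustrative assumptions, not values taken from the patches.

    # Minimal client-side sketch for the new top_k / penalty fields.
    # Host, port, model name, and API key are assumptions for illustration only.
    import requests

    payload = {
        "model": "vicuna-7b-v1.5",     # assumed model name served by the vLLM worker
        "prompt": "Once upon a time",
        "max_tokens": 64,
        "temperature": 0.7,
        "top_p": 1.0,
        "top_k": 40,                   # must be -1 (disabled) or >= 1 per the new check_requests rule
        "presence_penalty": 0.5,
        "frequency_penalty": 0.5,
    }

    resp = requests.post(
        "http://localhost:8000/v1/completions",      # assumed server address
        headers={"Authorization": "Bearer EMPTY"},    # placeholder key
        json=payload,
        timeout=60,
    )
    resp.raise_for_status()
    print(resp.json()["choices"][0]["text"])

With the validation added in PATCH 1/9, top_k values strictly between -1 and 1 (for example 0) are rejected with PARAM_OUT_OF_RANGE, while -1 falls through to vLLM's "consider all tokens" behaviour, matching the defaults switched in PATCH 9/9.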