
Commit

add parameter max_tokens to 4 chat APIs with default value 1024 (#1744)
liunux4odoo committed Oct 12, 2023
1 parent 1ac1739 commit cd74812
Showing 6 changed files with 18 additions and 2 deletions.
2 changes: 2 additions & 0 deletions server/chat/agent_chat.py
@@ -25,6 +25,7 @@ async def agent_chat(query: str = Body(..., description="user input", examples
     stream: bool = Body(False, description="streaming output"),
     model_name: str = Body(LLM_MODEL, description="LLM model name."),
     temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
+    max_tokens: int = Body(1024, description="cap on the number of tokens the LLM may generate; currently defaults to 1024"),  # TODO: once fastchat is updated, set the default to None so the model's own maximum is used.
     prompt_name: str = Body("agent_chat",
                             description="name of the prompt template to use (configured in configs/prompt_config.py)"),
     # top_p: float = Body(TOP_P, description="LLM nucleus sampling; do not set together with temperature", gt=0.0, lt=1.0),
@@ -41,6 +42,7 @@ async def agent_chat_iterator(
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
+           max_tokens=max_tokens,
        )

prompt_template = CustomPromptTemplate(
2 changes: 2 additions & 0 deletions server/chat/chat.py
@@ -22,6 +22,7 @@ async def chat(query: str = Body(..., description="user input", examples=["恼
     stream: bool = Body(False, description="streaming output"),
     model_name: str = Body(LLM_MODEL, description="LLM model name."),
     temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
+    max_tokens: int = Body(1024, description="cap on the number of tokens the LLM may generate; currently defaults to 1024"),  # TODO: once fastchat is updated, set the default to None so the model's own maximum is used.
     # top_p: float = Body(TOP_P, description="LLM nucleus sampling; do not set together with temperature", gt=0.0, lt=1.0),
     prompt_name: str = Body("llm_chat", description="name of the prompt template to use (configured in configs/prompt_config.py)"),
):
@@ -36,6 +37,7 @@ async def chat_iterator(query: str,
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
+           max_tokens=max_tokens,
            callbacks=[callback],
        )

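Each of the four chat endpoints gains the same two lines: a max_tokens request field and a pass-through into get_ChatOpenAI. A minimal client-side sketch of overriding the new parameter on the plain chat endpoint — the server address, port, and model name here are illustrative assumptions, not part of this commit:

import requests

# Hypothetical call; assumes the API server is reachable at this address.
resp = requests.post(
    "http://127.0.0.1:7861/chat/chat",
    json={
        "query": "Hello",
        "model_name": "chatglm2-6b",   # placeholder; any model served through fastchat
        "temperature": 0.7,
        "max_tokens": 256,             # new in this commit; omit to fall back to the default of 1024
        "stream": False,
    },
)
print(resp.text)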
2 changes: 2 additions & 0 deletions server/chat/knowledge_base_chat.py
@@ -31,6 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="user input",
     stream: bool = Body(False, description="streaming output"),
     model_name: str = Body(LLM_MODEL, description="LLM model name."),
     temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
+    max_tokens: int = Body(1024, description="cap on the number of tokens the LLM may generate; currently defaults to 1024"),  # TODO: once fastchat is updated, set the default to None so the model's own maximum is used.
     prompt_name: str = Body("knowledge_base_chat", description="name of the prompt template to use (configured in configs/prompt_config.py)"),
     local_doc_url: bool = Body(False, description="return knowledge files as local paths (true) or as URLs (false)"),
     request: Request = None,
@@ -51,6 +52,7 @@ async def knowledge_base_chat_iterator(query: str,
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
+           max_tokens=max_tokens,
            callbacks=[callback],
        )
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
2 changes: 2 additions & 0 deletions server/chat/search_engine_chat.py
@@ -72,6 +72,7 @@ async def search_engine_chat(query: str = Body(..., description="user input",
     stream: bool = Body(False, description="streaming output"),
     model_name: str = Body(LLM_MODEL, description="LLM model name."),
     temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
+    max_tokens: int = Body(1024, description="cap on the number of tokens the LLM may generate; currently defaults to 1024"),  # TODO: once fastchat is updated, set the default to None so the model's own maximum is used.
     prompt_name: str = Body("knowledge_base_chat", description="name of the prompt template to use (configured in configs/prompt_config.py)"),
):
if search_engine_name not in SEARCH_ENGINES.keys():
@@ -93,6 +94,7 @@ async def search_engine_chat_iterator(query: str,
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
+           max_tokens=max_tokens,
            callbacks=[callback],
        )

10 changes: 9 additions & 1 deletion server/utils.py
@@ -34,6 +34,7 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
 def get_ChatOpenAI(
        model_name: str,
        temperature: float,
+       max_tokens: int = None,
        streaming: bool = True,
        callbacks: List[Callable] = [],
        verbose: bool = True,
@@ -48,6 +49,7 @@ def get_ChatOpenAI(
        openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
        model_name=model_name,
        temperature=temperature,
+       max_tokens=max_tokens,
        openai_proxy=config.get("openai_proxy"),
        **kwargs
)
@@ -144,7 +146,7 @@ def run_async(cor):
    return loop.run_until_complete(cor)


-def iter_over_async(ait, loop):
+def iter_over_async(ait, loop=None):
    '''
    Wrap an async generator into a sync generator.
    '''
@@ -157,6 +159,12 @@ async def get_next():
        except StopAsyncIteration:
            return True, None

+   if loop is None:
+       try:
+           loop = asyncio.get_event_loop()
+       except:
+           loop = asyncio.new_event_loop()
+
while True:
done, obj = loop.run_until_complete(get_next())
if done:
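Besides threading max_tokens through get_ChatOpenAI (where the None default defers to whatever the model allows), this file also makes the loop argument of iter_over_async optional: when no loop is passed, the function now looks up the current event loop or creates a fresh one. A minimal usage sketch under those assumptions; the toy async generator stands in for a streaming chat response, and the import path mirrors the file above:

import asyncio
from server.utils import iter_over_async  # assumed import path

async def numbers(n):
    # stand-in for an async token stream
    for i in range(n):
        await asyncio.sleep(0)
        yield i

# Before this commit, callers had to create and pass a loop themselves:
loop = asyncio.new_event_loop()
print(list(iter_over_async(numbers(3), loop)))  # [0, 1, 2]

# After it, the loop argument may be omitted; the fallback added above kicks in:
print(list(iter_over_async(numbers(3))))        # [0, 1, 2]

This lets synchronous callers such as the webui consume the async streaming generators without managing an event loop explicitly.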
2 changes: 1 addition & 1 deletion webui_pages/dialogue/dialogue.py
@@ -166,7 +166,7 @@ def on_kb_change():
    ans = ""
    support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"]  # models that currently support the Agent feature
    if not any(agent in llm_model for agent in support_agent):
-       ans += "正在思考... \n\n <span style='color:red'>改模型并没有进行Agent对齐,无法正常使用Agent功能!</span>\n\n\n<span style='color:red'>请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验! </span> \n\n\n"
+       ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,无法正常使用Agent功能!</span>\n\n\n<span style='color:red'>请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验! </span> \n\n\n"
    chat_box.update_msg(ans, element_index=0, streaming=False)

(The fix is a one-character typo in the warning string, 改模型 → 该模型, so it now reads: "This model has not been agent-aligned and cannot use the Agent feature; please switch to an Agent-capable model such as GPT4 or Qwen-14B for a better experience.")


