From 17803cb7c1e5954d5a6f80cc588f26845fcfb9d8 Mon Sep 17 00:00:00 2001
From: zR <2448370773@qq.com>
Date: Mon, 22 Jan 2024 13:14:13 +0800
Subject: [PATCH] Fix gemini api call
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 configs/server_config.py.example | 11 ++---
 server/knowledge_base/kb_api.py  | 10 ++--
 server/model_workers/gemini.py   | 18 ++++----
 webui_pages/dialogue/dialogue.py | 79 ++++++++++++++++----------------
 4 files changed, 58 insertions(+), 60 deletions(-)

diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index 2f51c3ad8..a812376b0 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -92,11 +92,10 @@ FSCHAT_MODEL_WORKERS = {
         # 'disable_log_requests': False
     },
 
-    # the default configuration can be overridden as in this example
-    # "Qwen-1_8B-Chat": {  # uses the IP and port from "default"
-    #     "device": "cpu",
-    # },
-    "chatglm3-6b": {  # uses the IP and port from "default"
+    "Qwen-1_8B-Chat": {
+        "device": "cpu",
+    },
+    "chatglm3-6b": {
         "device": "cuda",
     },
 
@@ -129,7 +128,7 @@ FSCHAT_MODEL_WORKERS = {
         "port": 21009,
     },
     "gemini-api": {
-        "port": 21012,
+        "port": 21010,
     },
 }
 
diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py
index f50d8a732..0d2fbce9d 100644
--- a/server/knowledge_base/kb_api.py
+++ b/server/knowledge_base/kb_api.py
@@ -13,9 +13,9 @@ def list_kbs():
 
 
 def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
-            vector_store_type: str = Body("faiss"),
-            embed_model: str = Body(EMBEDDING_MODEL),
-            ) -> BaseResponse:
+              vector_store_type: str = Body("faiss"),
+              embed_model: str = Body(EMBEDDING_MODEL),
+              ) -> BaseResponse:
     # Create selected knowledge base
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
@@ -39,8 +39,8 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
 
 
 def delete_kb(
-    knowledge_base_name: str = Body(..., examples=["samples"])
-    ) -> BaseResponse:
+        knowledge_base_name: str = Body(..., examples=["samples"])
+) -> BaseResponse:
     # Delete selected knowledge base
     if not validate_kb_name(knowledge_base_name):
         return BaseResponse(code=403, msg="Don't attack me")
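Note on the configs/server_config.py.example hunks above: each per-model entry in FSCHAT_MODEL_WORKERS carries only the keys it overrides, and the removed comments explain that a model block otherwise uses the IP and port from the "default" entry. A minimal sketch of that overlay behavior, assuming a merge along these lines (the helper name and the "default" values here are illustrative, not taken from this patch):

    # Hypothetical sketch of how a per-model entry overlays "default".
    # The host/port values are placeholders, not the project's real defaults.
    FSCHAT_MODEL_WORKERS = {
        "default": {"host": "127.0.0.1", "port": 20002, "device": "auto"},
        "Qwen-1_8B-Chat": {"device": "cpu"},  # inherits host/port from "default"
        "gemini-api": {"port": 21010},        # online API worker on its own port
    }

    def get_worker_config(model_name: str) -> dict:
        # Start from the defaults, then apply model-specific overrides.
        config = dict(FSCHAT_MODEL_WORKERS["default"])
        config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
        return config

    print(get_worker_config("gemini-api"))  # port overridden, host/device inherited
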
diff --git a/server/model_workers/gemini.py b/server/model_workers/gemini.py
index 0cd8e159b..db41029b6 100644
--- a/server/model_workers/gemini.py
+++ b/server/model_workers/gemini.py
@@ -3,7 +3,7 @@
 from server.model_workers.base import *
 from server.utils import get_httpx_client
 from fastchat import conversation as conv
-import json,httpx
+import json, httpx
 from typing import List, Dict
 from configs import logger, log_verbose
 
@@ -14,14 +14,14 @@ def __init__(
             *,
             controller_addr: str = None,
             worker_addr: str = None,
-            model_names: List[str] = ["Gemini-api"],
+            model_names: List[str] = ["gemini-api"],
             **kwargs,
     ):
         kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
         kwargs.setdefault("context_len", 4096)
         super().__init__(**kwargs)
 
-    def create_gemini_messages(self,messages) -> json:
+    def create_gemini_messages(self, messages) -> json:
         has_history = any(msg['role'] == 'assistant' for msg in messages)
         gemini_msg = []
 
@@ -42,11 +42,11 @@ def create_gemini_messages(self, messages) -> json:
             msg = dict(contents=gemini_msg)
         return msg
 
-    
+
     def do_chat(self, params: ApiChatParams) -> Dict:
         params.load_config(self.model_names[0])
         data = self.create_gemini_messages(messages=params.messages)
-        generationConfig=dict(
+        generationConfig = dict(
             temperature=params.temperature,
             topK=1,
             topP=1,
@@ -54,8 +54,8 @@ def do_chat(self, params: ApiChatParams) -> Dict:
             stopSequences=[]
         )
 
-        data['generationConfig'] = generationConfig
-        url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"+ '?key=' + params.api_key
+        data['generationConfig'] = generationConfig
+        url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent" + '?key=' + params.api_key
         headers = {
             'Content-Type': 'application/json',
         }
@@ -67,7 +67,7 @@ def do_chat(self, params: ApiChatParams) -> Dict:
         text = ""
         json_string = ""
         timeout = httpx.Timeout(60.0)
-        client=get_httpx_client(timeout=timeout)
+        client = get_httpx_client(timeout=timeout)
         with client.stream("POST", url, headers=headers, json=data) as response:
             for line in response.iter_lines():
                 line = line.strip()
@@ -89,7 +89,7 @@ def do_chat(self, params: ApiChatParams) -> Dict:
                             "error_code": 0,
                             "text": text
                         }
-                        print(text)
+                    print(text)
                 except json.JSONDecodeError as e:
                     print("Failed to decode JSON:", e)
                     print("Invalid JSON string:", json_string)
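Note on the gemini.py hunks above: do_chat posts to the Google Generative Language REST endpoint with the API key in the query string, a "contents" list of role/parts turns built by create_gemini_messages, and a "generationConfig" block. A self-contained sketch of that request shape, assuming httpx is installed and an API key sits in the GOOGLE_API_KEY environment variable (the hunk elides part of generationConfig, so only the fields visible above are shown):

    import os

    import httpx

    # Endpoint used by the worker above; the key travels as a query parameter.
    url = ("https://generativelanguage.googleapis.com/v1beta/models/"
           "gemini-pro:generateContent?key=" + os.environ["GOOGLE_API_KEY"])

    payload = {
        # One user turn; with history, create_gemini_messages alternates
        # "user" and "model" roles.
        "contents": [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}],
        "generationConfig": {
            "temperature": 0.7,
            "topK": 1,
            "topP": 1,
            "stopSequences": [],
        },
    }

    resp = httpx.post(url, headers={"Content-Type": "application/json"},
                      json=payload, timeout=httpx.Timeout(60.0))
    resp.raise_for_status()
    # The reply text sits under candidates -> content -> parts.
    print(resp.json()["candidates"][0]["content"]["parts"][0]["text"])
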
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 325cd5d14..b5691ffd0 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -12,7 +12,6 @@
 import uuid
 from typing import List, Dict
 
-
 chat_box = ChatBox(
     assistant_avatar=os.path.join(
         "img",
@@ -138,11 +137,11 @@ def on_mode_change():
         st.toast(text)
 
     dialogue_modes = ["LLM 对话",
-                        "知识库问答",
-                        "文件对话",
-                        "搜索引擎问答",
-                        "自定义Agent问答",
-                        ]
+                      "知识库问答",
+                      "文件对话",
+                      "搜索引擎问答",
+                      "自定义Agent问答",
+                      ]
     dialogue_mode = st.selectbox("请选择对话模式:",
                                  dialogue_modes,
                                  index=0,
@@ -166,9 +165,9 @@ def llm_model_format_func(x):
         available_models = []
         config_models = api.list_config_models()
         if not is_lite:
-            for k, v in config_models.get("local", {}).items(): # list models configured with a valid local path
+            for k, v in config_models.get("local", {}).items():  # list models configured with a valid local path
                 if (v.get("model_path_exists")
-                    and k not in running_models):
+                        and k not in running_models):
                     available_models.append(k)
         for k, v in config_models.get("online", {}).items():  # list models from ONLINE_MODELS that are accessed directly
             if not v.get("provider") and k not in running_models:
@@ -250,14 +249,14 @@ def on_kb_change():
     elif dialogue_mode == "文件对话":
         with st.expander("文件对话配置", True):
             files = st.file_uploader("上传知识文件:",
-                                        [i for ls in LOADER_DICT.values() for i in ls],
-                                        accept_multiple_files=True,
-                                        )
+                                     [i for ls in LOADER_DICT.values() for i in ls],
+                                     accept_multiple_files=True,
+                                     )
             kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)
 
             ## BGE models can score above 1
             score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
-            if st.button("开始上传", disabled=len(files)==0):
+            if st.button("开始上传", disabled=len(files) == 0):
                 st.session_state["file_chat_id"] = upload_temp_docs(files, api)
     elif dialogue_mode == "搜索引擎问答":
         search_engine_list = api.list_search_engines()
@@ -279,9 +278,9 @@ def on_kb_change():
     chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
 
     def on_feedback(
-        feedback,
-        message_id: str = "",
-        history_index: int = -1,
+            feedback,
+            message_id: str = "",
+            history_index: int = -1,
     ):
         reason = feedback["text"]
         score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
@@ -296,7 +295,7 @@ def on_feedback(
     }
 
     if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
-        if parse_command(text=prompt, modal=modal): # the user entered a custom command
+        if parse_command(text=prompt, modal=modal):  # the user entered a custom command
            st.rerun()
        else:
            history = get_messages_history(history_len)
@@ -306,11 +305,11 @@ def on_feedback(
            text = ""
            message_id = ""
            r = api.chat_chat(prompt,
-                                history=history,
-                                conversation_id=conversation_id,
-                                model=llm_model,
-                                prompt_name=prompt_template_name,
-                                temperature=temperature)
+                              history=history,
+                              conversation_id=conversation_id,
+                              model=llm_model,
+                              prompt_name=prompt_template_name,
+                              temperature=temperature)
            for t in r:
                if error_msg := check_error_msg(t):  # check whether an error occurred
                    st.error(error_msg)
@@ -321,12 +320,12 @@ def on_feedback(
            metadata = {
                "message_id": message_id,
-                }
+            }
            chat_box.update_msg(text, streaming=False, metadata=metadata)  # update the final string, removing the cursor
            chat_box.show_feedback(**feedback_kwargs,
-                                key=message_id,
-                                on_submit=on_feedback,
-                                kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
+                                   key=message_id,
+                                   on_submit=on_feedback,
+                                   kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
 
    elif dialogue_mode == "自定义Agent问答":
        if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
@@ -373,13 +372,13 @@ def on_feedback(
            ])
            text = ""
            for d in api.knowledge_base_chat(prompt,
-                                            knowledge_base_name=selected_kb,
-                                            top_k=kb_top_k,
-                                            score_threshold=score_threshold,
-                                            history=history,
-                                            model=llm_model,
-                                            prompt_name=prompt_template_name,
-                                            temperature=temperature):
+                                             knowledge_base_name=selected_kb,
+                                             top_k=kb_top_k,
+                                             score_threshold=score_threshold,
+                                             history=history,
+                                             model=llm_model,
+                                             prompt_name=prompt_template_name,
+                                             temperature=temperature):
                if error_msg := check_error_msg(d):  # check whether an error occurred
                    st.error(error_msg)
                elif chunk := d.get("answer"):
@@ -397,13 +396,13 @@ def on_feedback(
            ])
            text = ""
            for d in api.file_chat(prompt,
-                                  knowledge_id=st.session_state["file_chat_id"],
-                                  top_k=kb_top_k,
-                                  score_threshold=score_threshold,
-                                  history=history,
-                                  model=llm_model,
-                                  prompt_name=prompt_template_name,
-                                  temperature=temperature):
+                                   knowledge_id=st.session_state["file_chat_id"],
+                                   top_k=kb_top_k,
+                                   score_threshold=score_threshold,
+                                   history=history,
+                                   model=llm_model,
+                                   prompt_name=prompt_template_name,
+                                   temperature=temperature):
                if error_msg := check_error_msg(d):  # check whether an error occurred
                    st.error(error_msg)
                elif chunk := d.get("answer"):
@@ -455,4 +454,4 @@ def on_feedback(
        file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
        mime="text/markdown",
        use_container_width=True,
-    )
\ No newline at end of file
+    )
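Note on the dialogue.py hunks: the changes are indentation fixes, but they document the streaming contract of the webui client — api.chat_chat yields dicts that carry either an error message (surfaced via check_error_msg), a "text" chunk, or a "message_id". A minimal consumer outside Streamlit might look like the sketch below (assumes an ApiRequest-style `api` client and the check_error_msg helper as used above; argument values are illustrative):

    # Sketch: consume the streaming chat API the way dialogue.py does,
    # without the Streamlit UI around it.
    text, message_id = "", ""
    r = api.chat_chat("Hello",
                      history=[],
                      conversation_id="",
                      model="chatglm3-6b",
                      prompt_name="default",
                      temperature=0.7)
    for t in r:
        if error_msg := check_error_msg(t):  # surface server-side errors
            print("ERROR:", error_msg)
            break
        text += t.get("text", "")            # accumulate streamed chunks
        message_id = t.get("message_id", message_id)
    print(text)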