Skip to content

Commit

Permalink
Merge pull request #2752 from zRzRzRzRzRzRzR/dev
Browse files Browse the repository at this point in the history
gemini API 修复
  • Loading branch information
zRzRzRzRzRzRzR committed Jan 22, 2024
2 parents 6437883 + 17803cb commit 54e5b41
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 60 deletions.
11 changes: 5 additions & 6 deletions configs/server_config.py.example
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,10 @@ FSCHAT_MODEL_WORKERS = {
# 'disable_log_requests': False

},
# 可以如下示例方式更改默认配置
# "Qwen-1_8B-Chat": { # 使用default中的IP和端口
# "device": "cpu",
# },
"chatglm3-6b": { # 使用default中的IP和端口
"Qwen-1_8B-Chat": {
"device": "cpu",
},
"chatglm3-6b": {
"device": "cuda",
},

Expand Down Expand Up @@ -129,7 +128,7 @@ FSCHAT_MODEL_WORKERS = {
"port": 21009,
},
"gemini-api": {
"port": 21012,
"port": 21010,
},
}

Expand Down
10 changes: 5 additions & 5 deletions server/knowledge_base/kb_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ def list_kbs():


def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL),
) -> BaseResponse:
vector_store_type: str = Body("faiss"),
embed_model: str = Body(EMBEDDING_MODEL),
) -> BaseResponse:
# Create selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
Expand All @@ -39,8 +39,8 @@ def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),


def delete_kb(
knowledge_base_name: str = Body(..., examples=["samples"])
) -> BaseResponse:
knowledge_base_name: str = Body(..., examples=["samples"])
) -> BaseResponse:
# Delete selected knowledge base
if not validate_kb_name(knowledge_base_name):
return BaseResponse(code=403, msg="Don't attack me")
Expand Down
18 changes: 9 additions & 9 deletions server/model_workers/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from server.model_workers.base import *
from server.utils import get_httpx_client
from fastchat import conversation as conv
import json,httpx
import json, httpx
from typing import List, Dict
from configs import logger, log_verbose

Expand All @@ -14,14 +14,14 @@ def __init__(
*,
controller_addr: str = None,
worker_addr: str = None,
model_names: List[str] = ["Gemini-api"],
model_names: List[str] = ["gemini-api"],
**kwargs,
):
kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr)
kwargs.setdefault("context_len", 4096)
super().__init__(**kwargs)

def create_gemini_messages(self,messages) -> json:
def create_gemini_messages(self, messages) -> json:
has_history = any(msg['role'] == 'assistant' for msg in messages)
gemini_msg = []

Expand All @@ -42,20 +42,20 @@ def create_gemini_messages(self,messages) -> json:

msg = dict(contents=gemini_msg)
return msg

def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])
data = self.create_gemini_messages(messages=params.messages)
generationConfig=dict(
generationConfig = dict(
temperature=params.temperature,
topK=1,
topP=1,
maxOutputTokens=4096,
stopSequences=[]
)

data['generationConfig'] = generationConfig
url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent"+ '?key=' + params.api_key
data['generationConfig'] = generationConfig
url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent" + '?key=' + params.api_key
headers = {
'Content-Type': 'application/json',
}
Expand All @@ -67,7 +67,7 @@ def do_chat(self, params: ApiChatParams) -> Dict:
text = ""
json_string = ""
timeout = httpx.Timeout(60.0)
client=get_httpx_client(timeout=timeout)
client = get_httpx_client(timeout=timeout)
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
line = line.strip()
Expand All @@ -89,7 +89,7 @@ def do_chat(self, params: ApiChatParams) -> Dict:
"error_code": 0,
"text": text
}
print(text)
print(text)
except json.JSONDecodeError as e:
print("Failed to decode JSON:", e)
print("Invalid JSON string:", json_string)
Expand Down
79 changes: 39 additions & 40 deletions webui_pages/dialogue/dialogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import uuid
from typing import List, Dict


chat_box = ChatBox(
assistant_avatar=os.path.join(
"img",
Expand Down Expand Up @@ -138,11 +137,11 @@ def on_mode_change():
st.toast(text)

dialogue_modes = ["LLM 对话",
"知识库问答",
"文件对话",
"搜索引擎问答",
"自定义Agent问答",
]
"知识库问答",
"文件对话",
"搜索引擎问答",
"自定义Agent问答",
]
dialogue_mode = st.selectbox("请选择对话模式:",
dialogue_modes,
index=0,
Expand All @@ -166,9 +165,9 @@ def llm_model_format_func(x):
available_models = []
config_models = api.list_config_models()
if not is_lite:
for k, v in config_models.get("local", {}).items(): # 列出配置了有效本地路径的模型
for k, v in config_models.get("local", {}).items(): # 列出配置了有效本地路径的模型
if (v.get("model_path_exists")
and k not in running_models):
and k not in running_models):
available_models.append(k)
for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型
if not v.get("provider") and k not in running_models:
Expand Down Expand Up @@ -250,14 +249,14 @@ def on_kb_change():
elif dialogue_mode == "文件对话":
with st.expander("文件对话配置", True):
files = st.file_uploader("上传知识文件:",
[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True,
)
[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True,
)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, VECTOR_SEARCH_TOP_K)

## Bge 模型会超过1
score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, float(SCORE_THRESHOLD), 0.01)
if st.button("开始上传", disabled=len(files)==0):
if st.button("开始上传", disabled=len(files) == 0):
st.session_state["file_chat_id"] = upload_temp_docs(files, api)
elif dialogue_mode == "搜索引擎问答":
search_engine_list = api.list_search_engines()
Expand All @@ -279,9 +278,9 @@ def on_kb_change():
chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "

def on_feedback(
feedback,
message_id: str = "",
history_index: int = -1,
feedback,
message_id: str = "",
history_index: int = -1,
):
reason = feedback["text"]
score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
Expand All @@ -296,7 +295,7 @@ def on_feedback(
}

if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
if parse_command(text=prompt, modal=modal): # 用户输入自定义命令
if parse_command(text=prompt, modal=modal): # 用户输入自定义命令
st.rerun()
else:
history = get_messages_history(history_len)
Expand All @@ -306,11 +305,11 @@ def on_feedback(
text = ""
message_id = ""
r = api.chat_chat(prompt,
history=history,
conversation_id=conversation_id,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature)
history=history,
conversation_id=conversation_id,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature)
for t in r:
if error_msg := check_error_msg(t): # check whether error occured
st.error(error_msg)
Expand All @@ -321,12 +320,12 @@ def on_feedback(

metadata = {
"message_id": message_id,
}
}
chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
chat_box.show_feedback(**feedback_kwargs,
key=message_id,
on_submit=on_feedback,
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
key=message_id,
on_submit=on_feedback,
kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})

elif dialogue_mode == "自定义Agent问答":
if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
Expand Down Expand Up @@ -373,13 +372,13 @@ def on_feedback(
])
text = ""
for d in api.knowledge_base_chat(prompt,
knowledge_base_name=selected_kb,
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
knowledge_base_name=selected_kb,
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
elif chunk := d.get("answer"):
Expand All @@ -397,13 +396,13 @@ def on_feedback(
])
text = ""
for d in api.file_chat(prompt,
knowledge_id=st.session_state["file_chat_id"],
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
knowledge_id=st.session_state["file_chat_id"],
top_k=kb_top_k,
score_threshold=score_threshold,
history=history,
model=llm_model,
prompt_name=prompt_template_name,
temperature=temperature):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
elif chunk := d.get("answer"):
Expand Down Expand Up @@ -455,4 +454,4 @@ def on_feedback(
file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
mime="text/markdown",
use_container_width=True,
)
)

0 comments on commit 54e5b41

Please sign in to comment.