Skip to content

Commit

Permalink
Merge pull request #3912 from pearjelly/master
Browse files Browse the repository at this point in the history
修复zhipu-api向量化失败问题 (Fix zhipu-api embedding/vectorization failure)
  • Loading branch information
zRzRzRzRzRzRzR committed May 1, 2024
2 parents f9beb14 + 160f0b8 commit cbc28d7
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
4 changes: 2 additions & 2 deletions server/knowledge_base/kb_cache/faiss_cache.py
Expand Up @@ -57,7 +57,7 @@ def new_vector_store(
) -> FAISS:
embeddings = EmbeddingsFunAdapter(embed_model)
doc = Document(page_content="init", metadata={})
vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
vector_store = FAISS.from_documents([doc], embeddings, distance_strategy="METRIC_INNER_PRODUCT")
ids = list(vector_store.docstore._dict.keys())
vector_store.delete(ids)
return vector_store
Expand Down Expand Up @@ -94,7 +94,7 @@ def load_vector_store(

if os.path.isfile(os.path.join(vs_path, "index.faiss")):
embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
vector_store = FAISS.load_local(vs_path, embeddings, distance_strategy="METRIC_INNER_PRODUCT")
elif create:
# create an empty vector store
if not os.path.exists(vs_path):
Expand Down
40 changes: 29 additions & 11 deletions server/model_workers/zhipu.py
@@ -1,6 +1,7 @@
from contextlib import contextmanager

import httpx
import requests
from fastchat.conversation import Conversation
from httpx_sse import EventSource

Expand Down Expand Up @@ -44,7 +45,7 @@ class ChatGLMWorker(ApiModelWorker):
def __init__(
self,
*,
model_names: List[str] = ["zhipu-api"],
model_names: List[str] = ("zhipu-api",),
controller_addr: str = None,
worker_addr: str = None,
version: Literal["glm-4"] = "glm-4",
Expand Down Expand Up @@ -87,28 +88,45 @@ def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:


def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
    """Embed ``params.texts`` via the ZhipuAI embeddings API, one text per call.

    Returns ``{"code": 200, "data": [embedding, ...]}``.  Failed requests are
    silently skipped, so ``data`` may end up shorter than ``params.texts`` —
    NOTE(review): callers that zip results back to inputs should confirm this
    is acceptable.
    """
    embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL
    params.load_config(self.model_names[0])

    i = 0
    batch_size = 1  # one text per request; the v4 API is called per-chunk here
    result = []
    while i < len(params.texts):
        # Regenerate the short-lived (60 s) token inside the loop so that
        # long-running embedding jobs do not fail with an expired token.
        token = generate_token(params.api_key, 60)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
        }
        data = {
            "model": embed_model,
            "input": "".join(params.texts[i: i + batch_size]),
        }
        # Allow one retry per chunk on transient failures.
        embedding_data = self.request_embedding_api(headers, data, 1)
        if embedding_data:
            result.append(embedding_data)
        i += batch_size
        print(f"请求{embed_model}接口处理第{i}块文本,返回embeddings: \n{embedding_data}")

    return {"code": 200, "data": result}

# Call the embeddings endpoint; retries up to `retry` additional times on failure.
def request_embedding_api(self, headers, data, retry=0):
    """POST ``data`` to the ZhipuAI v4 embeddings endpoint.

    Returns the embedding vector (``ans["data"][0]["embedding"]``) on success,
    or ``None`` once all retries are exhausted.  ``response`` is initialised to
    ``''`` so the error log below is safe even when ``requests.post`` itself
    raises before a response object exists.
    """
    response = ''
    try:
        url = "https://open.bigmodel.cn/api/paas/v4/embeddings"
        response = requests.post(url, headers=headers, json=data)
        ans = response.json()
        return ans["data"][0]["embedding"]
    except Exception as e:
        # Broad catch is deliberate: any network, JSON-decoding, or
        # response-schema error should trigger the retry path, not a crash.
        print(f"request_embedding_api error={e} \nresponse={response}")
        if retry > 0:
            return self.request_embedding_api(headers, data, retry - 1)
        return None

def get_embeddings(self, params):
print("embedding")
print(params)
Expand Down

0 comments on commit cbc28d7

Please sign in to comment.