diff --git a/configs/kb_config.py.example b/configs/kb_config.py.example index 04e09ecf3..9c727b775 100644 --- a/configs/kb_config.py.example +++ b/configs/kb_config.py.example @@ -100,7 +100,7 @@ kbs_config = { "index_name": "test_index", "user": "", "password": "" - } + } } # TextSplitter配置项,如果你不明白其中的含义,就不要修改。 diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py index 47d6099dd..220ed5814 100644 --- a/server/model_workers/minimax.py +++ b/server/model_workers/minimax.py @@ -106,7 +106,7 @@ def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: data = { "model": params.embed_model or self.DEFAULT_EMBED_MODEL, - "texts": params.texts, + "texts": [], "type": "query" if params.to_query else "db", } if log_verbose: @@ -115,21 +115,26 @@ def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: logger.info(f'{self.__class__.__name__}:headers: {headers}') with get_httpx_client() as client: - r = client.post(url, headers=headers, json=data).json() - if embeddings := r.get("vectors"): - return {"code": 200, "data": embeddings} - elif error := r.get("base_resp"): - return { - "code": error["status_code"], - "msg": error["status_msg"], - - "error": { - "message": error["status_msg"], - "type": "invalid_request_error", - "param": None, - "code": None, + result = [] + i = 0 + for texts in params.texts[i:i+10]: + data["texts"] = texts + r = client.post(url, headers=headers, json=data).json() + if embeddings := r.get("vectors"): + result += embeddings + elif error := r.get("base_resp"): + return { + "code": error["status_code"], + "msg": error["status_msg"], + "error": { + "message": error["status_msg"], + "type": "invalid_request_error", + "param": None, + "code": None, + } } - } + i += 10 + return {"code": 200, "data": embeddings} def get_embeddings(self, params): # TODO: 支持embeddings diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py index dacb8d1b6..2ab39d19e 100644 --- a/server/model_workers/qianfan.py +++ b/server/model_workers/qianfan.py @@ -184,21 +184,26 @@ def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: logger.info(f'{self.__class__.__name__}:url: {url}') with get_httpx_client() as client: - resp = client.post(url, json={"input": params.texts}).json() - if "error_cdoe" not in resp: - embeddings = [x["embedding"] for x in resp.get("data", [])] - return {"code": 200, "data": embeddings} - else: - return { - "code": resp["error_code"], - "msg": resp["error_msg"], - "error": { - "message": resp["error_msg"], - "type": "invalid_request_error", - "param": None, - "code": None, + result = [] + i = 0 + for texts in params.texts[i:i+10]: + resp = client.post(url, json={"input": texts}).json() + if "error_cdoe" in resp: + return { + "code": resp["error_code"], + "msg": resp["error_msg"], + "error": { + "message": resp["error_msg"], + "type": "invalid_request_error", + "param": None, + "code": None, + } } - } + else: + embeddings = [x["embedding"] for x in resp.get("data", [])] + result += embeddings + i += 10 + return {"code": 200, "data": result} # TODO: qianfan支持续写模型 def get_embeddings(self, params):