Skip to content

Commit

Permalink
Merge pull request #3203 from qqlww1987/master
Browse files Browse the repository at this point in the history
百川调用示例
  • Loading branch information
zRzRzRzRzRzRzR committed May 1, 2024
2 parents cbc28d7 + 7198654 commit 24faba5
Showing 1 changed file with 34 additions and 37 deletions.
71 changes: 34 additions & 37 deletions server/model_workers/baichuan.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import time
import hashlib

import requests
from fastchat.conversation import Conversation
from server.model_workers.base import *
from server.utils import get_httpx_client
Expand Down Expand Up @@ -32,61 +32,58 @@ def __init__(
kwargs.setdefault("context_len", 32768)
super().__init__(**kwargs)
self.version = version

def do_chat(self, params: ApiChatParams) -> Dict:
params.load_config(self.model_names[0])

url = "https://api.baichuan-ai.com/v1/stream/chat"
url = "https://api.baichuan-ai.com/v1/chat/completions"
data = {
"model": params.version,
"messages": params.messages,
"parameters": {"temperature": params.temperature}
"stream": False,

}

json_data = json.dumps(data)
time_stamp = int(time.time())
signature = calculate_md5(params.secret_key + json_data + str(time_stamp))
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + params.api_key,
"X-BC-Request-Id": "your requestId",
"X-BC-Timestamp": str(time_stamp),
"X-BC-Signature": signature,
"X-BC-Sign-Algo": "MD5",

}

text = ""
if log_verbose:
logger.info(f'{self.__class__.__name__}:json_data: {json_data}')
logger.info(f'{self.__class__.__name__}:url: {url}')
logger.info(f'{self.__class__.__name__}:headers: {headers}')

with get_httpx_client() as client:
with client.stream("POST", url, headers=headers, json=data) as response:
for line in response.iter_lines():
if not line.strip():
continue
resp = json.loads(line)
if resp["code"] == 0:
text += resp["data"]["messages"][-1]["content"]
yield {
"error_code": resp["code"],
"text": text
response = requests.post(url, headers=headers, json=data)
if response.status_code == 200:
print("请求成功!"+response.text)
result = json.loads(response.text)
textMsg=""
result["choices"][0]["delta"]=result["choices"][0]["message"]
if 'choices' in result:
textMsg += result["choices"][0]["message"]["content"]
data = {
"error_code": response.status_code,
"text": textMsg,
"choices":result["choices"],
"model":result["model"],
"object":result["object"],
"object":result["object"],
"created":result["created"],
"id":result["id"],
}
else:
data = {
"error_code": resp["code"],
"text": resp["msg"],

yield data

else:

data = {
"error_code": response.status_code,
"text":response.text,
"error": {
"message": resp["msg"],
"message": response.text,
"type": "invalid_request_error",
"param": None,
"code": None,
}
}
self.logger.error(f"请求百川 API 时发生错误:{data}")
yield data

}
self.logger.error(f"请求百川 API 时发生错误:{data}")
yield data
def get_embeddings(self, params):
print("embedding")
print(params)
Expand Down

0 comments on commit 24faba5

Please sign in to comment.