Skip to content

Commit

Permalink
Refactoring api (#1001)
Browse files Browse the repository at this point in the history
* Use asynchrony to implement openai api access itself

* Update api.py

* Update api.py

* Update api.py

* Update api.py

* Update api.py

* Added api address verification
  • Loading branch information
Haibersut committed Jun 29, 2023
1 parent 71e9647 commit c4ba8e0
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 57 deletions.
208 changes: 168 additions & 40 deletions adapter/chatgpt/api.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,118 @@
import ctypes
import os
import time
from typing import Generator
import openai
import json
import aiohttp
import async_timeout
from loguru import logger
from revChatGPT.V3 import Chatbot as OpenAIChatbot

from adapter.botservice import BotAdapter
from config import OpenAIAPIKey
from constants import botManager, config
import tiktoken

def hashu(word) -> int:
    """Return an unsigned 64-bit hash of *word*.

    Python's built-in ``hash()`` may be negative; reinterpreting it through
    ``ctypes.c_uint64`` maps it onto the 0..2**64-1 range (two's complement).
    Stable within one interpreter run only — str hashing is randomized per
    process (PYTHONHASHSEED).
    """
    return ctypes.c_uint64(hash(word)).value

class OpenAIChatbot:
    """Holds per-session conversation state and request parameters for the
    OpenAI chat-completions API.

    The HTTP calls themselves are made by the adapter; this class only keeps
    the message history, the sampling parameters copied from config, and the
    token-accounting helpers used to trim history.
    """

    # String annotation: avoids requiring OpenAIAPIKey at class-definition
    # time while keeping the same documented contract.
    def __init__(self, api_info: "OpenAIAPIKey"):
        self.api_key = api_info.api_key
        self.proxy = api_info.proxy
        self.presence_penalty = config.openai.gpt3_params.presence_penalty
        self.frequency_penalty = config.openai.gpt3_params.frequency_penalty
        self.top_p = config.openai.gpt3_params.top_p
        self.temperature = config.openai.gpt3_params.temperature
        self.max_tokens = config.openai.gpt3_params.max_tokens
        self.engine = api_info.model or "gpt-3.5-turbo"
        self.timeout = config.response.max_timeout
        # session_id -> ordered list of {"role": ..., "content": ...} messages.
        self.conversation: dict[str, list[dict]] = {
            "default": [
                {
                    "role": "system",
                    "content": "You are ChatGPT, a large language model trained by OpenAI. Knowledge cutoff: 2021-09 Current date:[current date]",
                },
            ],
        }

    async def rollback(self, session_id: str = "default", n: int = 1) -> None:
        """Remove the last *n* messages of *session_id*.

        Raises ValueError when the session does not exist or *n* exceeds the
        number of stored messages; errors are logged before re-raising.
        """
        try:
            if session_id not in self.conversation:
                raise ValueError(f"会话 ID {session_id} 不存在。")

            if n > len(self.conversation[session_id]):
                raise ValueError(f"回滚次数 {n} 超过了会话 {session_id} 的消息数量。")

            for _ in range(n):
                self.conversation[session_id].pop()

        except ValueError as ve:
            logger.error(ve)
            raise
        except Exception as e:
            logger.error(f"未知错误: {e}")
            raise

    def add_to_conversation(
        self,
        message: str,
        role: str,
        session_id: str = "default",
    ) -> None:
        """Append one message with the given role to the session history."""
        self.conversation[session_id].append({"role": role, "content": message})

    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    def count_tokens(self, session_id: str = "default", model: "str | None" = None) -> int:
        """Return the number of tokens used by the session's message list.

        When *model* is omitted, the configured engine is used (the previous
        hard-coded "gpt-3.5-turbo" default miscounted for other engines).
        """
        if model is None:
            model = self.engine
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            encoding = tiktoken.get_encoding("cl100k_base")

        if model in {
            "gpt-3.5-turbo-0613",
            "gpt-3.5-turbo-16k-0613",
            "gpt-4-0314",
            "gpt-4-32k-0314",
            "gpt-4-0613",
            "gpt-4-32k-0613",
            "gpt-3.5-turbo",
            "gpt-4"
        }:
            tokens_per_message = 3
            tokens_per_name = 1
        elif model == "gpt-3.5-turbo-0301":
            tokens_per_message = 4  # every message follows {role/name}\n{content}\n
            tokens_per_name = -1  # if there's a name, the role is omitted
        else:
            # Unknown model: estimate with the gpt-3.5-turbo rules instead of
            # returning None, which made get_max_tokens() raise a TypeError.
            logger.warning("未找到相应模型计算方法,按 gpt-3.5-turbo 规则估算")
            tokens_per_message = 3
            tokens_per_name = 1

        num_tokens = 0
        for message in self.conversation[session_id]:
            num_tokens += tokens_per_message
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3  # every reply is primed with assistant
        return num_tokens

    def get_max_tokens(self, session_id: str, model: str) -> int:
        """Tokens still available for the completion given current history."""
        return self.max_tokens - self.count_tokens(session_id, model)


class ChatGPTAPIAdapter(BotAdapter):
api_info: OpenAIAPIKey = None
"""API Key"""

bot: OpenAIChatbot = None
"""实例"""

hashed_user_id: str

def __init__(self, session_id: str = "unknown"):
self.__conversation_keep_from = 0
self.session_id = session_id
self.hashed_user_id = "user-" + hashu("session_id").to_bytes(8, "big").hex()
self.api_info = botManager.pick('openai-api')
self.bot = OpenAIChatbot(
api_key=self.api_info.api_key,
proxy=self.api_info.proxy,
presence_penalty=config.openai.gpt3_params.presence_penalty,
frequency_penalty=config.openai.gpt3_params.frequency_penalty,
top_p=config.openai.gpt3_params.top_p,
temperature=config.openai.gpt3_params.temperature,
max_tokens=config.openai.gpt3_params.max_tokens,
)
self.bot = OpenAIChatbot(self.api_info)
self.conversation_id = None
self.parent_id = None
super().__init__()
self.bot.conversation[self.session_id] = []
self.current_model = self.api_info.model or "gpt-3.5-turbo"
self.current_model = self.bot.engine
self.supported_models = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-0301",
Expand All @@ -61,47 +134,102 @@ async def switch_model(self, model_name):
async def rollback(self):
if len(self.bot.conversation[self.session_id]) <= 0:
return False
self.bot.rollback(convo_id=self.session_id, n=2)
await self.bot.rollback(self.session_id, n=2)
return True

async def on_reset(self):
self.api_info = botManager.pick('openai-api')
self.bot.api_key = self.api_info.api_key
self.bot.proxy = self.api_info.proxy
self.bot.conversation[self.session_id] = []
self.bot.engine = self.api_info.model
self.__conversation_keep_from = 0

async def ask(self, prompt: str) -> Generator[str, None, None]:
self.api_info = botManager.pick('openai-api')
self.bot.api_key = self.api_info.api_key
self.bot.proxy = self.api_info.proxy
self.bot.session.proxies.update(
{
"http": self.bot.proxy,
"https": self.bot.proxy,
},
)
api_key = self.api_info.api_key
proxy = self.api_info.proxy
api_endpoint = config.openai.api_endpoint or "https://api.openai.com/v1"

if self.session_id not in self.bot.conversation:
self.bot.conversation[self.session_id] = [
{"role": "system", "content": self.bot.system_prompt}
{"role": "system", "content": prompt}
]
self.__conversation_keep_from = 1

while self.bot.max_tokens - self.bot.get_token_count(self.session_id) < config.openai.gpt3_params.min_tokens and \
while self.bot.max_tokens - self.bot.count_tokens(self.session_id) < config.openai.gpt3_params.min_tokens and \
len(self.bot.conversation[self.session_id]) > self.__conversation_keep_from:
self.bot.conversation[self.session_id].pop(self.__conversation_keep_from)
logger.debug(
f"清理 token,历史记录遗忘后使用 token 数:{str(self.bot.get_token_count(self.session_id))}"
f"清理 token,历史记录遗忘后使用 token 数:{str(self.bot.count_tokens(self.session_id))}"
)

os.environ['API_URL'] = f'{openai.api_base}/chat/completions'
full_response = ''
async for resp in self.bot.ask_stream_async(prompt=prompt, role=self.hashed_user_id, convo_id=self.session_id):
full_response += resp
yield full_response
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 响应:{full_response}")
logger.debug(f"使用 token 数:{str(self.bot.get_token_count(self.session_id))}")
try:
logger.debug(f"[尝试使用ChatGPT-API:{self.bot.engine}] 请求:{prompt}")
self.bot.add_to_conversation(prompt, "user", session_id=self.session_id)
start_time = time.time()
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
data = {
'model': self.bot.engine,
'messages': self.bot.conversation[self.session_id],
'stream': True,
'temperature': self.bot.temperature,
'top_p': self.bot.top_p,
'presence_penalty': self.bot.presence_penalty,
'frequency_penalty': self.bot.frequency_penalty,
"user": 'user',
'max_tokens': self.bot.get_max_tokens(self.session_id, self.bot.engine),
}
async with aiohttp.ClientSession() as session:
with async_timeout.timeout(self.bot.timeout):
async with session.post(api_endpoint + '/chat/completions', headers=headers,
data=json.dumps(data), proxy=proxy) as resp:
if resp.status != 200:
response_text = await resp.text()
raise Exception(
f"{resp.status} {resp.reason} {response_text}",
)

response_role: str = ''
completion_text: str = ''

async for line in resp.content:
line = line.decode('utf-8').strip()
if not line.startswith("data: "):
continue
line = line[len("data: "):]
if line == "[DONE]":
break
if not line:
continue
try:
event = json.loads(line)
except json.JSONDecodeError:
raise Exception(f"JSON解码错误: {line}") from None
event_time = time.time() - start_time
if 'error' in event:
raise Exception(f"响应错误: {event['error']}")
if 'choices' in event and len(event['choices']) > 0 and 'delta' in event['choices'][0]:
delta = event['choices'][0]['delta']
if 'role' in delta:
response_role = delta['role']
if 'content' in delta:
event_text = delta['content']
completion_text += event_text
yield completion_text

self.bot.add_to_conversation(completion_text, response_role, session_id=self.session_id)
token_count = self.bot.count_tokens(self.session_id, self.bot.engine)
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 响应:{completion_text}")
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 使用 token 数:{token_count}")
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 接收到全部消息花费了{event_time:.2f}秒")

except Exception as e:
logger.error(f"[ChatGPT-API:{self.bot.engine}] 请求失败:\n{e}")
yield f"发生错误: \n{e}"

async def preset_ask(self, role: str, text: str):
if role.endswith('bot') or role in {'assistant', 'chatgpt'}:
Expand Down
17 changes: 14 additions & 3 deletions manager/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import urllib.request
from typing import List, Dict
from urllib.parse import urlparse

import re
import base64
import json
import time
Expand Down Expand Up @@ -120,8 +120,19 @@ async def handle_openai(self):
openai.api_base = self.config.openai.api_endpoint or openai.api_base
if openai.api_base.endswith("/"):
openai.api_base.removesuffix("/")
logger.info(f"当前的 api_endpoint 为:{openai.api_base}")
await self.login_openai()

pattern = r'^https://[^/]+/v1$'
match = re.match(pattern, openai.api_base)

if match:
logger.info(f"当前的 api_endpoint 为:{openai.api_base}")
await self.login_openai()
else:
logger.error("API反代地址填写错误,正确格式应为 'https://<网址>/v1'")
raise ValueError("API反代地址填写错误,正确格式应为 'https://<网址>/v1'")




async def login(self):
self.bots = {
Expand Down
10 changes: 7 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
graia-ariadne==0.11.5
graiax-silkcoder
revChatGPT==6.3.3
revChatGPT~=6.5.0
toml~=0.10.2
Pillow>=9.3.0
tinydb~=4.8.0
tinydb~=4.7.1

loguru~=0.7.0
asyncio~=3.4.3
Expand Down Expand Up @@ -33,9 +33,13 @@ azure-cognitiveservices-speech
poe-api~=0.4.6

regex~=2023.6.3
httpx
httpx~=0.24.1
Quart==0.17.0

edge-tts
wechatpy~=2.0.0a26
pydub~=0.25.1

creart~=0.2.2
tiktoken~=0.4.0
httpcore~=0.17.2
4 changes: 3 additions & 1 deletion universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,10 @@ async def request(_session_id, prompt: str, conversation_context, _respond):
await conversation_context.switch_model(model_name)
await respond(f"已切换至 {model_name} 模型,让我们聊天吧!")
else:
logger.warning(f"模型 {model_name} 不在支持列表中,下次将尝试使用此模型创建对话。")
await conversation_context.switch_model(model_name)
await respond(
f"当前的 AI 不支持切换至 {model_name} 模型,目前仅支持{conversation_context.supported_models}!")
f"模型 {model_name} 不在支持列表中,下次将尝试使用此模型创建对话,目前AI仅支持{conversation_context.supported_models}!")
return

# 加载预设
Expand Down
24 changes: 14 additions & 10 deletions utils/text_to_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,16 +333,19 @@ async def text_to_image(text):
with StringIO(html) as input_file:
ok = False
try:
# 调用imgkit将html转为图片
ok = await asyncio.get_event_loop().run_in_executor(None, imgkit.from_file, input_file,
temp_jpg_filename, {
"enable-local-file-access": "",
"allow": asset_folder,
"width": config.text_to_image.width, # 图片宽度
}, None, None, None, imgkit_config)
# 调用PIL将图片读取为 JPEG,RGB 格式
image = Image.open(temp_jpg_filename, formats=['PNG']).convert('RGB')
ok = True
if config.text_to_image.wkhtmltoimage:
# 调用imgkit将html转为图片
ok = await asyncio.get_event_loop().run_in_executor(None, imgkit.from_file, input_file,
temp_jpg_filename, {
"enable-local-file-access": "",
"allow": asset_folder,
"width": config.text_to_image.width, # 图片宽度
}, None, None, None, imgkit_config)
# 调用PIL将图片读取为 JPEG,RGB 格式
image = Image.open(temp_jpg_filename, formats=['PNG']).convert('RGB')
ok = True
else:
ok = False
except Exception as e:
logger.exception(e)
finally:
Expand All @@ -357,6 +360,7 @@ async def text_to_image(text):

return image


async def to_image(text) -> GraiaImage:
img = await text_to_image(text=str(text))
b = BytesIO()
Expand Down

0 comments on commit c4ba8e0

Please sign in to comment.