diff --git a/bot.py b/bot.py index 61e938b5..a095f8d5 100644 --- a/bot.py +++ b/bot.py @@ -28,6 +28,7 @@ import chatbot from config import Config from utils.text_to_img import to_image +import conversation_manager from manager.ratelimit import RateLimitManager import time from revChatGPT.V1 import Error as V1Error @@ -71,11 +72,32 @@ async def create_timeout_task(target: Union[Friend, Group], source: Source): async def handle_message(target: Union[Friend, Group], session_id: str, message: str, source: Source) -> str: + number = session_id.split('-')[1] if not message.strip(): return config.response.placeholder timeout_task = None + # 如果消息包含help命令(config.trigger.help_command所定义的内容),则回滚会话 + if message.strip() in config.trigger.help_command: + return config.response.help_command.format(max_sessions=config.max_record.max_sessions) + + # 如果消息包含 会话列表 命令(config.trigger.talk_list_command所定义的内容),则输出会话列表 + if message.strip() in config.trigger.talk_list_command: + return "会话列表如下:\n" + '\n'.join(f"{i+1}. {x}" for i, x in enumerate(conversation_manager.get_user_sessions(number))) + + # 如果消息包含会话上限x命令,则进入会话x + max_record_search = re.search(config.trigger.max_talk_sessions, message) + if max_record_search: + conversation_manager.update_max_sessions(number, int(max_record_search.group(1))) + if int(max_record_search.group(1)) > 10: + max_record = 10 + elif int(max_record_search.group(1)) < 1: + max_record = 1 + else: + max_record = int(max_record_search.group(1)) + return f"会话上限已设置为{max_record}!" + session, is_new_session = chatbot.get_chat_session(session_id) # 回滚 @@ -170,14 +192,19 @@ async def friend_message_listener(app: Ariadne, friend: Friend, source: Source, if rate_usage >= 1: response = config.ratelimit.exceed else: - response = await handle_message(friend, f"friend-{friend.id}", chain.display, source) + response = await handle_message(friend, f"fd-{friend.id}-", chain.display, source) if rate_usage >= config.ratelimit.warning_rate: limit = rateLimitManager.get_limit('好友', friend.id) usage = rateLimitManager.get_usage('好友', friend.id) current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) response = response + '\n' + config.ratelimit.warning_msg.format(usage=usage['count'], limit=limit['rate'], current_time=current_time) - await respond(friend, source, response) + if len(response) <= 3000: + await respond(friend, source, response) + else: + chunks = [response[i:i+3000] for i in range(0, len(response), 3000)] + for chunk in chunks: + await respond(friend, source, chunk) GroupTrigger = Annotated[MessageChain, MentionMe(config.trigger.require_mention != "at"), DetectPrefix( @@ -193,7 +220,7 @@ async def group_message_listener(group: Group, source: Source, chain: GroupTrigg if rate_usage >= 1: return config.ratelimit.exceed else: - response = await handle_message(group, f"group-{group.id}", chain.display, source) + response = await handle_message(group, f"gp-{group.id}-", chain.display, source) if rate_usage >= config.ratelimit.warning_rate: limit = rateLimitManager.get_limit('群组', group.id) usage = rateLimitManager.get_usage('群组', group.id) diff --git a/chatbot.py b/chatbot.py index 2e5cbb2f..a5b085ee 100644 --- a/chatbot.py +++ b/chatbot.py @@ -4,6 +4,9 @@ from manager.bot import BotManager, BotInfo import atexit from loguru import logger +import conversation_manager +import re +from tinydb import TinyDB, Query import revChatGPT.V1 as V1 config = Config.load_config() @@ -24,10 +27,12 @@ def setup(): class ChatSession: chatbot: BotInfo = None session_id: str + number: str def __init__(self, session_id): self.session_id = session_id - self.prev_conversation_id = [] + self.number = session_id.split('-')[1] + self.prev_conversation_id = None self.prev_parent_id = [] self.parent_id = None self.conversation_id = None @@ -56,35 +61,143 @@ def reset_conversation(self): self.chatbot.bot.delete_conversation(self.conversation_id) self.conversation_id = None self.parent_id = None - self.prev_conversation_id = [] + self.prev_conversation_id = None self.prev_parent_id = [] self.chatbot = botManager.pick() def rollback_conversation(self) -> bool: if len(self.prev_parent_id) <= 0: return False - self.conversation_id = self.prev_conversation_id.pop() + # self.conversation_id = self.prev_conversation_id.pop() self.parent_id = self.prev_parent_id.pop() + # 回滚一次对话 + conversation_manager.rollback_last_parent_id(self.number, + self.conversation_id) return True + # 解析读取历史会话得到的字符串 + def extract_conversations(self, result): + + output = "" + message_mapping = result['mapping'] + current_node_id = result['current_node'] + current_node = message_mapping[current_node_id] + conversation = [] + + while current_node['parent'] is not None: + current_message = current_node['message'] + if current_message is not None: + author_role = current_message['author']['role'] + message_content = current_message['content']['parts'][0] + if author_role == 'user': + conversation.append("you:" + message_content) + elif author_role == 'assistant': + conversation.append("bot:" + message_content) + + current_node_id = current_node['parent'] + current_node = message_mapping[current_node_id] + + for line in reversed(conversation): + output += line + "\n" + + return output + async def get_chat_response(self, message) -> str: - self.prev_conversation_id.append(self.conversation_id) + + self.prev_conversation_id = self.conversation_id self.prev_parent_id.append(self.parent_id) + logger.info(f"当前id:{self.conversation_id}, 父节点id:{self.parent_id}") + + # 如果消息包含进入会话x命令,则进入会话x + into_talk_search = re.search(config.trigger.goto_talk, message) + if into_talk_search: + conversation_parents = conversation_manager.get_last_parent_id(self.number, int(into_talk_search.group(1))) + if conversation_parents: + self.conversation_id = conversation_parents[0] + self.parent_id = conversation_parents[1][-1] + self.prev_conversation_id = conversation_parents[0] + self.prev_parent_id = conversation_parents[1][0:-1] + self.chatbot = botManager.pick_id(conversation_parents[2]) + return f"进入会话{into_talk_search.group(1)}成功!" + else: + return f"进入会话{into_talk_search.group(1)}失败!停留在当前会话" + + # 如果消息包含删除会话x命令,则删除对应会话 + delete_talk_search = re.search(config.trigger.delete_talk, message) + if delete_talk_search: + result = conversation_manager.delete_session_record(self.number, int(delete_talk_search.group(1))) + if result == 2: # 如果删除的恰巧是最后一个,也就是当前会话 + self.reset_conversation() + if result: + return f"删除会话{delete_talk_search.group(1)}成功!" + else: + return f"删除会话{delete_talk_search.group(1)}失败!会话不存在" + + # 如果消息包含读取会话x命令,则读取对应会话 + read_talk_search = re.search(config.trigger.read_talk, message) + if read_talk_search: + conversation_parents = conversation_manager.get_last_parent_id(self.number, int(read_talk_search.group(1))) + if conversation_parents: + chatbot_save = self.chatbot + self.chatbot = botManager.pick_id(conversation_parents[2]) + result = self.chatbot.bot.get_msg_history(conversation_parents[0], encoding='utf-8') + result = self.extract_conversations(result) + self.chatbot = chatbot_save + logger.info( + f"会话{read_talk_search.group(1)}的聊天记录如下:\n{result}") + return f"读取会话成功!\n会话{read_talk_search.group(1)}的聊天记录如下:\n{result}" + else: + return f"读取会话{read_talk_search.group(1)}失败!停留在当前会话" + + # 如果消息包含清空会话命令(config.trigger.rollback_command所定义的内容),则清空会话 + if message.strip() in config.trigger.clear_talk_command: + self.reset_conversation() + if conversation_manager.clear_user_sessions(self.number): + return "清空会话成功!" + else: + return "清空会话失败!你是不是还没对话过。" + + # 如果消息包含 会话名:*** 命令,则删除对应会话 + rename_talk_search = re.search(config.trigger.rename_talk, message) + if rename_talk_search: + if self.conversation_id: + conversation_manager.update_session(self.number, self.conversation_id, rename_talk_search.group(1)) + self.chatbot.bot.change_title(self.conversation_id, ( + self.session_id.encode('unicode-escape') + rename_talk_search.group(1).encode( + 'unicode-escape')).decode('utf-8')) + return f"会话名已改为{rename_talk_search.group(1)}" + else: + return "会话还没开始,不能设置会话名!" bot = self.chatbot.bot + botManager.update_bot_time(self.chatbot.id) bot.conversation_id = self.conversation_id bot.parent_id = self.parent_id + logger.info( + f"当前id:{self.conversation_id}, 父节点id:{self.parent_id}") loop = asyncio.get_event_loop() resp = await loop.run_in_executor(None, self.chatbot.ask, message, self.conversation_id, self.parent_id) - if self.conversation_id is None and self.chatbot.account.title_pattern: - self.chatbot.bot.change_title(resp["conversation_id"], - self.chatbot.account.title_pattern.format(session_id=self.session_id)) + flag = False + if self.conversation_id is None: + flag = True self.conversation_id = resp["conversation_id"] self.parent_id = resp["parent_id"] + # 添加对话记录 + conversation_manager.add_session_record(self.number, self.conversation_id, message[:20], + self.chatbot.account_id, + self.parent_id) + + if flag: + self.chatbot.bot.change_title(self.conversation_id, f"{self.session_id}{message[:20].encode('utf-8')}") + # self.chatbot.bot.change_title(resp["conversation_id"],self.chatbot.account.title_pattern.format(session_id=f"{self.session_id}"+str(message[:20].encode("utf-8"))))########################## + + logger.info( + f"当前id:{self.conversation_id}, 父节点id:{self.parent_id}") + return resp["message"] @@ -92,11 +205,44 @@ async def get_chat_response(self, message) -> str: def get_chat_session(session_id: str) -> Tuple[ChatSession, bool]: + number = session_id.split('-')[1] new_session = False - if session_id not in __sessions: - __sessions[session_id] = ChatSession(session_id) - new_session = True - return __sessions[session_id], new_session + + if number not in __sessions: #有可能是重启容器了(读取旧聊天最后一个),也有可能是新的用户(创建新聊天) + + # 创建一个新的聊天会话 + __sessions[number] = ChatSession(session_id) + + # 创建一个新的聊天会话 + __sessions[number] = ChatSession(session_id) + + # 打开数据库 + db = TinyDB('data/session_records.json') + + # 获取数据表 + table = db.table('session_records') + + # 读取数据 + session_records = table.all() + if len(session_records) > 0: + session_records = session_records[0] + else: + session_records = {} + + # 读取session_records字典 + # with open('session_records.json', 'r') as f: + # session_records = json.loads(f.read()) + if number in session_records: #重启容器了 + if len(session_records[number]["sessions"]) > 0: + __sessions[number].conversation_id = session_records[number]["sessions"][-1]["conversation_id"] + __sessions[number].parent_id = session_records[number]["sessions"][-1]["parent_ids"][-1] + __sessions[number].prev_conversation_id = session_records[number]["sessions"][-1]["conversation_id"] + __sessions[number].prev_parent_id = session_records[number]["sessions"][-1]["parent_ids"][0:-1] + __sessions[number].chatbot = botManager.pick_id(session_records[number]["sessions"][-1]["account_id"]) + else: # 新用户开始聊天 + new_session = True + + return __sessions[number], new_session def conversation_remover(): diff --git a/config.py b/config.py index 15c1c846..5e0599be 100644 --- a/config.py +++ b/config.py @@ -47,13 +47,15 @@ class OpenAIAuthBase(BaseModel): auto_remove_old_conversations: bool = False """自动删除旧的对话""" + class Config(BaseConfig): extra = Extra.allow class OpenAIEmailAuth(OpenAIAuthBase): - email: str + email: str = '' """OpenAI 注册邮箱""" + password: str """OpenAI 密码""" isMicrosoftLogin: bool = False @@ -90,6 +92,9 @@ class TextToImage(BaseModel): """纵坐标""" wkhtmltoimage: Union[str, None] = None +class MaxRecord(BaseModel): + max_sessions: int = 5 + """会话数量上限""" class Trigger(BaseModel): prefix: List[str] = [""] @@ -100,6 +105,23 @@ class Trigger(BaseModel): """重置会话的命令""" rollback_command: List[str] = ["回滚会话"] """回滚会话的命令""" + help_command: List[str] = ["help"] + """请求帮助的命令""" + talk_list_command: List[str] = ["会话列表"] + """会话列表的命令""" + clear_talk_command: List[str] = ["清空会话"] + """清空会话的命令""" + max_talk_sessions: str = r"会话上限\s*(\d+)" + + goto_talk: str = r"进入会话\s*(\d+)" + + delete_talk: str = r"删除会话\s*(\d+)" + + read_talk: str = r"读取会话\s*(\d+)" + + rename_talk: str = r"会话名\s*[::]\s*([^::\s]+)" + + class Response(BaseModel): @@ -153,6 +175,21 @@ class Response(BaseModel): queued_notice: str = "消息已收到!当前我还有{queue_size}条消息要回复,请您稍等。" """新消息进入队列时,发送的通知。 queue_size 是当前排队的消息数""" + help_command: str = ( + "功能列表:\n" + "样例:指令 - 指令的功能\n" + "帮助 - 显示功能列表\n" + "重置会话 - 离开当前会话(保留),开启新的会话\n" + "回滚会话 - 回滚一次对话\n" + "会话名:*** - 设置当前会话名(开始会话之后才能设置)\n" + "会话列表 - 显示自己现有的会话列表\n" + "进入会话x - 进入会话列表中第x个会话\n" + "读取会话x - 读取会话列表中第x的会话的所有会话记录\n" + "删除会话x - 删除会话列表中第x个会话\n" + "清空会话 - 清空自己所有的会话\n" + "会话上限x - 设置自己会话数量上限x,最大为{max_sessions}" + ) + class System(BaseModel): accept_group_invite: bool = False @@ -186,6 +223,7 @@ class Config(BaseModel): response: Response = Response() system: System = System() presets: Preset = Preset() + max_record: MaxRecord = MaxRecord() ratelimit: Ratelimit = Ratelimit() def scan_presets(self): diff --git a/conversation_manager.py b/conversation_manager.py new file mode 100644 index 00000000..f784fa73 --- /dev/null +++ b/conversation_manager.py @@ -0,0 +1,463 @@ +"""本文件用于管理多账号下各用户的会话""" +"""需要注意的是,存储所有用户对话信息的session_records字典单独存储在一个文件,并映射到本地,这样即使容器重装、更新等也能读取之前的对话""" + +# session_records样例如下: +""" +session_records = { + "user1": { + "max_sessions": 5, # 用户1最多允许5个会话记录 + "sessions": [#记录user1的不同会话 + { + "conversation_id": "conv1", + "session_id": "fd-123-对话1" + "account_id": "account_1" + "parent_ids": ["parent1", "parent2"], + }, + { + "conversation_id": "conv2", + "session_id": "fd-123-对话2" + "account_id": "account_2" + "parent_ids": ["parent3"], + }, + # ... 更多的会话记录 + ], + }, + "user2": { + "max_sessions": 10, # 用户2最多允许10个会话记录 + "sessions": [#记录user1的不同会话 + { + "conversation_id": "conv1", + "session_id": "fd-123-对话1" + "account_id": "account_1" + "parent_ids": ["parent1", "parent2"], + }, + { + "conversation_id": "conv2", + "session_id": "fd-123-对话2" + "account_id": "account_2" + "parent_ids": ["parent3"], + }, + # ... 更多的会话记录 + ], + }, + # ... 更多的用户 +} +""" + +from config import Config, OpenAIAuths +from tinydb import TinyDB, Query + +config = Config.load_config() + + +# 添加对话记录 +def add_session_record(user_id, conversation_id, session_id, account_id, parent_id): + if conversation_id == None: + return + + # 打开数据库 + db = TinyDB('data/session_records.json') + + # 获取数据表 + table = db.table('session_records') + + # 读取数据 + session_records = table.all() + if len(session_records) > 0: + session_records = session_records[0] + else: + session_records = {} + # 读取session_records字典 + # with open('session_records.json', 'r') as f: + # session_records = json.loads(f.read()) + + # 检查用户是否存在 + if user_id not in session_records: + session_records[user_id] = { + "max_sessions": 5, + "sessions": [], + } + + # 检查新会话记录是否已经存在于会话记录列表中 + record_found = False + for record in session_records[user_id]["sessions"]: + if record["conversation_id"] == conversation_id: + if record["parent_ids"][-1] != parent_id: + record["parent_ids"].append(parent_id) + record_found = True + break + + # 如果新会话记录不存在,则添加新的会话记录 + if not record_found: + # 检查用户会话记录是否已经达到最大数量 + if len(session_records[user_id]["sessions"]) >= session_records[user_id]["max_sessions"]: + # 删除第一个会话记录,并将新记录插入到末尾 + session_records[user_id]["sessions"].pop(0) + + new_record = { + "conversation_id": conversation_id, + "session_id": session_id, + "account_id": account_id, + "parent_ids": [parent_id], + } + session_records[user_id]["sessions"].append(new_record) + + # 保存session_records字典 + # with open('session_records.json', 'w') as f: + # f.write(json.dumps(session_records)) + # 完全替代原有数据表的内容 + table.write_back = True + table.truncate() # 清空数据表 + table.insert(session_records) + # 关闭数据库 + db.close() + # return session_records # 返回更新后的会话记录字典 + + +# 返回指定用户的会话列表,返回样例:[session_id1, session_id2...] +def get_user_sessions(user_id): + # 打开数据库 + db = TinyDB('data/session_records.json') + + # 获取数据表 + table = db.table('session_records') + + # 读取数据 + session_records = table.all() + if len(session_records) > 0: + session_records = session_records[0] + else: + session_records = {} + # 读取session_records字典 + # with open('session_records.json', 'r') as f: + # session_records = json.loads(f.read()) + + # 检查用户是否存在 + if user_id not in session_records: + return [] # 用户不存在,返回空列表 + + # 获取用户的会话列表 + sessions = session_records[user_id]["sessions"] + + # 提取会话ID和会话记录ID,生成一个嵌套列表 + result = [session["session_id"] for session in sessions] + + return result # 返回用户的会话列表 + + +# 返回指定会话最后一个[conversation_id, parent_ids, account_id] +def get_last_parent_id(user_id, id): + # 打开数据库 + db = TinyDB('data/session_records.json') + + # 获取数据表 + table = db.table('session_records') + + # 读取数据 + session_records = table.all() + if len(session_records) > 0: + session_records = session_records[0] + else: + session_records = {} + # 读取session_records字典 + # with open('session_records.json', 'r') as f: + # session_records = json.loads(f.read()) + + # 检查用户是否存在 + if user_id not in session_records: + return None # 用户不存在,返回None + + # 获取用户的会话列表 + sessions = session_records[user_id]["sessions"] + + # 查找指定会话的会话记录 + record_found = False + if id > 0 and id <= len(sessions): + record_found = True + + # 如果找到了指定会话的会话记录,则返回最后一个parent_id + if record_found: + if len(sessions[id - 1]["parent_ids"]) > 0: + # session_records[user_id]["sessions"][id - 1]["parent_ids"].pop() + # 保存session_records字典 + # with open('session_records.json', 'w') as f: + # f.write(json.dumps(session_records)) + # 完全替代原有数据表的内容 + # table.write_back = True + # table.truncate() # 清空数据表 + # table.insert(session_records) + # 关闭数据库 + # db.close() + + return [sessions[id - 1]["conversation_id"], sessions[id - 1]["parent_ids"], sessions[id - 1]["account_id"]] + else: + return None # 如果parent_ids列表为空,则返回None + else: + return None # 如果没有找到指定会话的会话记录,则返回None + + +# 删除指定会话 +def delete_session_record(user_id, id): + # 打开数据库 + db = TinyDB('data/session_records.json') + + # 获取数据表 + table = db.table('session_records') + + # 读取数据 + session_records = table.all() + if len(session_records) > 0: + session_records = session_records[0] + else: + session_records = {} + # 读取session_records字典 + # with open('session_records.json', 'r') as f: + # session_records = json.loads(f.read()) + + # 检查用户是否存在 + if user_id not in session_records: + return None # 用户不存在,直接返回None + + # 获取用户的会话列表 + sessions = session_records[user_id]["sessions"] + + # 查找指定会话的会话记录 + record_found = False + last_record = False + if id > 0 and id <= len(sessions): + if id == len(sessions): + last_record = True + session_records[user_id]["sessions"].pop(id - 1) + record_found = True + + # 如果找到了指定会话的会话记录,则返回更新后的会话记录字典 + if record_found: + # 保存session_records字典 + # with open('session_records.json', 'w') as f: + # f.write(json.dumps(session_records)) + # 完全替代原有数据表的内容 + table.write_back = True + table.truncate() # 清空数据表 + table.insert(session_records) + # 关闭数据库 + db.close() + + if last_record: # 如果删除的恰巧是最后一个,也就是当前会话 + return 2 + return 1 + else: + return None # 如果没有找到指定会话的会话记录,则直接返回None + + +# 清空指定用户所有会话 +def clear_user_sessions(user_id): + # 打开数据库 + db = TinyDB('data/session_records.json') + + # 获取数据表 + table = db.table('session_records') + + # 读取数据 + session_records = table.all() + if len(session_records) > 0: + session_records = session_records[0] + else: + session_records = {} + # 读取session_records字典 + # with open('session_records.json', 'r') as f: + # session_records = json.loads(f.read()) + + # 检查用户是否存在 + if user_id not in session_records: + return None # 用户不存在,直接返回会话记录字典 + + # 删除指定用户的所有会话记录 + session_records[user_id]["sessions"] = [] + + # 保存session_records字典 + # with open('session_records.json', 'w') as f: + # f.write(json.dumps(session_records)) + # 完全替代原有数据表的内容 + table.write_back = True + table.truncate() # 清空数据表 + table.insert(session_records) + # 关闭数据库 + db.close() + + return True + + +# 回滚会话,返回parent_id,否则返回None +def rollback_last_parent_id(user_id, conversation_id): + # 打开数据库 + db = TinyDB('data/session_records.json') + + # 获取数据表 + table = db.table('session_records') + + # 读取数据 + session_records = table.all() + if len(session_records) > 0: + session_records = session_records[0] + else: + session_records = {} + # 读取session_records字典 + # with open('session_records.json', 'r') as f: + # session_records = json.loads(f.read()) + + # 检查用户是否存在 + if user_id not in session_records: + return None # 用户不存在,返回None + + # 获取用户的会话列表 + sessions = session_records[user_id]["sessions"] + + # 查找指定会话的会话记录 + record_found = False + for session in sessions: + if session["conversation_id"] == conversation_id: + record_found = True + break + + # 如果找到了指定会话的会话记录,则回滚最后一个parent_id + if record_found: + parent_ids = session["parent_ids"] + if len(parent_ids) > 0: + parent_ids.pop(-1) + + # 保存session_records字典 + # with open('session_records.json', 'w') as f: + # f.write(json.dumps(session_records)) + # 完全替代原有数据表的内容 + table.write_back = True + table.truncate() # 清空数据表 + table.insert(session_records) + # 关闭数据库 + db.close() + + if len(parent_ids) > 0: + return parent_ids[-1] # 如果还有剩余的parent_id,则返回新的最后一个parent_id + else: + return None # 如果parent_ids列表为空,则返回None + else: + return None # 如果parent_ids列表为空,则返回None + else: + return None # 如果没有找到指定会话的会话记录,则返回None + + +# 更新指定用户指定会话的会话名 +def update_session(user_id, conversation_id, session_id): + # 打开数据库 + db = TinyDB('data/session_records.json') + + # 获取数据表 + table = db.table('session_records') + + # 读取数据 + session_records = table.all() + if len(session_records) > 0: + session_records = session_records[0] + else: + session_records = {} + # 读取session_records字典 + # with open('session_records.json', 'r') as f: + # session_records = json.loads(f.read()) + + if user_id not in session_records: + # 如果用户不存在,则返回None + return None + + for i in range(len(session_records[user_id]["sessions"])): + if session_records[user_id]["sessions"][i]["conversation_id"] == conversation_id: + # 找到会话,更新session_id + session_records[user_id]["sessions"][i]["session_id"] = session_id + # 保存session_records字典 + # with open('session_records.json', 'w') as f: + # f.write(json.dumps(session_records)) + # 完全替代原有数据表的内容 + table.write_back = True + table.truncate() # 清空数据表 + table.insert(session_records) + # 关闭数据库 + db.close() + + return + + # 修改指定用户的会话上限 + + +def update_max_sessions(user_id, max_sessions): + # 打开数据库 + db = TinyDB('data/session_records.json') + + # 获取数据表 + table = db.table('session_records') + + # 读取数据 + session_records = table.all() + if len(session_records) > 0: + session_records = session_records[0] + else: + session_records = {} + # 读取session_records字典 + # with open('session_records.json', 'r') as f: + # session_records = json.loads(f.read()) + + if max_sessions > config.max_record.max_sessions: + max_sessions = config.max_record.max_sessions + elif max_sessions < 1: + max_sessions = 1 + + if user_id in session_records: + session_records[user_id]["max_sessions"] = max_sessions + else: + session_records[user_id] = {"max_sessions": max_sessions, "sessions": []} + + while max_sessions < len(session_records[user_id]["sessions"]): + session_records[user_id]["sessions"].pop(0) + + # 保存session_records字典 + # with open('session_records.json', 'w') as f: + # f.write(json.dumps(session_records)) + # 完全替代原有数据表的内容 + table.write_back = True + table.truncate() # 清空数据表 + table.insert(session_records) + # 关闭数据库 + db.close() + + +# 更新account_id(session_token登录的用户补上了email) +def update_account_id(account_id_before, account_id_after): + # 打开数据库 + db = TinyDB('data/session_records.json') + + # 获取数据表 + table = db.table('session_records') + + # 读取数据 + session_records = table.all() + if len(session_records) > 0: + session_records = session_records[0] + else: + session_records = {} + # 读取session_records字典 + # with open('session_records.json', 'r') as f: + # session_records = json.loads(f.read()) + + # 遍历所有用户的会话记录 + for user in session_records: + for session in session_records[user]["sessions"]: + # 如果会话记录中包含原始的account_id,则将其替换为新的account_id + if session["account_id"] == account_id_before: + session["account_id"] = account_id_after + + # 保存session_records字典 + # with open('session_records.json', 'w') as f: + # f.write(json.dumps(session_records)) + # 完全替代原有数据表的内容 + table.write_back = True + table.truncate() # 清空数据表 + table.insert(session_records) + # 关闭数据库 + db.close() diff --git a/manager/bot.py b/manager/bot.py index 0b4de42d..41607f2e 100644 --- a/manager/bot.py +++ b/manager/bot.py @@ -1,7 +1,11 @@ import datetime import os import sys + import time +import datetime +import conversation_manager + from requests.exceptions import SSLError @@ -21,17 +25,22 @@ import utils.network as network from tinydb import TinyDB, Query import hashlib - +import base64 +import json config = Config.load_config() class BotInfo(asyncio.Lock): id = 0 + account_id: str #记录当前账户的user_id,用于定位账户 + account: OpenAIAuthBase bot: Union[V1Chatbot, BrowserChatbot] + time: datetime.datetime + mode: str queue_size: int = 0 @@ -123,7 +132,7 @@ def __init__(self, accounts: List[Union[OpenAIEmailAuth, OpenAISessionTokenAuth] def login(self): for i, account in enumerate(self.accounts): - logger.info("正在登录第 {i} 个 OpenAI 账号", i=i + 1) + logger.info(f"正在登录第 {i+1} 个 OpenAI 账号\n") try: if account.mode == "proxy" or account.mode == "browserless": bot = self.__login_V1(account) @@ -133,6 +142,13 @@ def login(self): raise Exception("未定义的登录类型:" + account.mode) bot.id = i bot.account = account + bot.time = datetime.datetime.now() + bot_access_token = bot.bot.session.headers.get('Authorization').removeprefix('Bearer ') + bot_access_token = bot_access_token.split('.') + bot_access_token = (base64.urlsafe_b64decode(bot_access_token[1] + '=' * (4 - len(bot_access_token[1]) % 4))).decode('utf-8') + bot_access_token = json.loads(bot_access_token) + bot.account_id = bot_access_token["https://api.openai.com/auth"]["user_id"] #从bot_access_token中获取user_id + logger.success(f"该用户的user_id = {bot.account_id}") self.bots.append(bot) logger.success("登录成功!", i=i + 1) logger.debug("等待 8 秒……") @@ -153,7 +169,9 @@ def login(self): if len(self.bots) < 1: logger.error("所有账号均登录失败,无法继续启动!") exit(-2) - logger.success(f"成功登录 {len(self.bots)}/{len(self.accounts)} 个账号!") + logger.success(f"成功登录 {len(self.bots)}/{len(self.accounts)} 个账号!\n") + + def __login_browser(self, account) -> BotInfo: logger.info("模式:浏览器登录") @@ -240,7 +258,20 @@ def get_access_token(bot: V1Chatbot): self.__save_login_cache(account=account, cache={}) raise Exception("All login method failed") - def pick(self) -> BotInfo: - if self.roundrobin is None: - self.roundrobin = itertools.cycle(self.bots) - return next(self.roundrobin) + def pick(self) -> BotInfo: #选一个bot.time最小的,说明闲置时间最长 + # if self.roundrobin is None: + # self.roundrobin = itertools.cycle(self.bots) + # return next(self.roundrobin) + id = 0 + for i in range(1, len(self.bots)): + if self.bots[id].time > self.bots[i].time: + id = i + return self.bots[id] + + def pick_id(self, account_id) -> BotInfo: #返回指定id的账号 + for i in range(len(self.bots)): + if self.bots[i].account_id == account_id: + return self.bots[i] + + def update_bot_time(self, id): #更新指定bot的时间,在向机器人发消息的时候就更新bot的时间 + self.bots[id].time = datetime.datetime.now() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 6cbda67d..6ab91731 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,6 @@ revChatGPT[unofficial]==2.3.11 toml~=0.10.2 Pillow>=9.3.0 tinydb~=4.7.1 - loguru~=0.6.0 asyncio~=3.4.3 pydantic~=1.10.5