Add a MistralAI Adapter #1329

Open
wants to merge 1 commit into base: browser-version
326 changes: 326 additions & 0 deletions adapter/mistral/mistral.py
@@ -0,0 +1,326 @@
import json

Contributor:

issue (llm): I noticed that the new adapter for MistralAI has been added, but there are no unit tests accompanying this addition. It's crucial to have unit tests to verify the behavior of the new adapter, especially for methods like rollback, add_to_conversation, count_tokens, request, and ask. Could you please add unit tests to ensure these methods work as expected under various conditions?
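
A minimal pytest sketch of what such tests could look like, assuming the repository's config machinery can be loaded in a test environment and that pytest-asyncio is installed; the chatbot fixture and dummy key below are hypothetical, not part of this PR:

import pytest

from adapter.mistral.mistral import MistralAIChatbot
from config import MistralAIAPIKey


@pytest.fixture
def chatbot():
    # Hypothetical fixture: dummy credentials suffice because these
    # tests never touch the network.
    return MistralAIChatbot(MistralAIAPIKey(api_key="test-key"))


def test_add_to_conversation_appends_message(chatbot):
    chatbot.add_to_conversation("hello", "user", session_id="default")
    assert chatbot.conversation["default"][-1] == {"role": "user", "content": "hello"}


def test_add_to_conversation_rejects_empty_message(chatbot):
    with pytest.raises(ValueError):
        chatbot.add_to_conversation(None, "user", session_id="default")


def test_count_tokens_grows_with_history(chatbot):
    before = chatbot.count_tokens("default")
    chatbot.add_to_conversation("some longer message", "user")
    assert chatbot.count_tokens("default") > before


@pytest.mark.asyncio
async def test_rollback_removes_last_n_messages(chatbot):
    chatbot.add_to_conversation("q", "user")
    chatbot.add_to_conversation("a", "assistant")
    n_before = len(chatbot.conversation["default"])
    await chatbot.rollback("default", n=2)
    assert len(chatbot.conversation["default"]) == n_before - 2


@pytest.mark.asyncio
async def test_rollback_unknown_session_raises(chatbot):
    with pytest.raises(ValueError):
        await chatbot.rollback("no-such-session")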

Contributor:

suggestion (llm): The addition of the MistralAI adapter introduces significant new functionality. To ensure the adapter's resilience and maintainability, consider adding end-to-end tests that cover the full lifecycle of a conversation through the adapter, including error handling and retries for API failures. This will help ensure the adapter behaves correctly in the context of the larger application.
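
A hedged end-to-end sketch along these lines, using the third-party aioresponses package to stub the HTTP layer. It assumes a test config with at least one mistral account is loaded (so botManager.pick('mistral') succeeds) and that config.mistral.api_endpoint is left at the default; note the adapter currently has no retry logic, so failures are expected to propagate:

import pytest
from aioresponses import aioresponses

from adapter.mistral.mistral import MistralAIAPIAdapter

# Assumes the default endpoint from _prepare_request.
API_URL = "https://api.mistral.ai/v1/chat/completions"


def fake_completion(content: str) -> dict:
    # Mirrors the fields _process_response reads from the real API.
    return {
        "usage": {"total_tokens": 42},
        "choices": [{"message": {"role": "assistant", "content": content}}],
    }


@pytest.mark.asyncio
async def test_request_round_trip_appends_reply():
    adapter = MistralAIAPIAdapter(session_id="e2e-test")
    adapter.bot.add_to_conversation("你好", "user", session_id="e2e-test")
    with aioresponses() as mocked:
        mocked.post(API_URL, status=200, payload=fake_completion("你好!"))
        reply = await adapter.request(session_id="e2e-test")
    assert reply == "你好!"
    # The assistant reply should have been appended to the history.
    assert adapter.bot.conversation["e2e-test"][-1]["role"] == "assistant"


@pytest.mark.asyncio
async def test_request_surfaces_api_failures():
    adapter = MistralAIAPIAdapter(session_id="e2e-test-2")
    adapter.bot.add_to_conversation("hi", "user", session_id="e2e-test-2")
    with aioresponses() as mocked:
        mocked.post(API_URL, status=500, body="internal error")
        # No retry logic exists yet, so the error should propagate.
        with pytest.raises(Exception):
            await adapter.request(session_id="e2e-test-2")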

import time
import aiohttp
import async_timeout
import tiktoken
from loguru import logger
from typing import AsyncGenerator

from adapter.botservice import BotAdapter
from config import MistralAIAPIKey
from constants import botManager, config

DEFAULT_ENGINE: str = "mistral-large-latest"


class MistralAIChatbot:
def __init__(self, api_info: MistralAIAPIKey):
self.api_key = api_info.api_key
self.proxy = api_info.proxy
self.top_p = config.mistral.mistral_params.top_p
self.temperature = config.mistral.mistral_params.temperature
self.max_tokens = config.mistral.mistral_params.max_tokens
self.engine = api_info.model or DEFAULT_ENGINE
self.timeout = config.response.max_timeout
self.conversation: dict[str, list[dict]] = {
"default": [
{
"role": "system",
"content": "你是 MistralAI,现在需要用中文进行交流。",
},
],
}

async def rollback(self, session_id: str = "default", n: int = 1) -> None:
try:
if session_id not in self.conversation:
raise ValueError(f"会话 ID {session_id} 不存在。")

if n > len(self.conversation[session_id]):
raise ValueError(f"回滚次数 {n} 超过了会话 {session_id} 的消息数量。")

for _ in range(n):
self.conversation[session_id].pop()

except ValueError as ve:
logger.error(ve)
raise
except Exception as e:
logger.error(f"未知错误: {e}")
raise

def add_to_conversation(self, message: str, role: str, session_id: str = "default") -> None:
if role and message is not None:
self.conversation[session_id].append({"role": role, "content": message})
else:
logger.warning("出现错误!返回消息为空,不添加到会话。")
raise ValueError("出现错误!返回消息为空,不添加到会话。")

# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def count_tokens(self, session_id: str = "default", model: str = DEFAULT_ENGINE):
"""Return the number of tokens used by a list of messages."""
if model is None:
model = DEFAULT_ENGINE
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")

tokens_per_message = 4
tokens_per_name = 1

num_tokens = 0
for message in self.conversation[session_id]:
num_tokens += tokens_per_message
for key, value in message.items():
if value is not None:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
num_tokens += 3 # every reply is primed with assistant
return num_tokens

def get_max_tokens(self, session_id: str, model: str) -> int:
"""Get max tokens"""
return self.max_tokens - self.count_tokens(session_id, model)


class MistralAIAPIAdapter(BotAdapter):
api_info: MistralAIAPIKey = None
"""API Key"""

def __init__(self, session_id: str = "unknown"):
self.latest_role = None
self.__conversation_keep_from = 0
self.session_id = session_id
self.api_info = botManager.pick('mistral')
self.bot = MistralAIChatbot(self.api_info)
self.conversation_id = None
self.parent_id = None
super().__init__()
self.bot.conversation[self.session_id] = []
self.current_model = self.bot.engine
self.supported_models = [
"mistral-large-latest",
"mistral-medium-latest",
"mistral-small-latest",
"open-mixtral-8x7b",
"open-mistral-7b",
]

def manage_conversation(self, session_id: str, prompt: str):
if session_id not in self.bot.conversation:
self.bot.conversation[session_id] = [
{"role": "system", "content": prompt}
]
self.__conversation_keep_from = 1

while self.bot.max_tokens - self.bot.count_tokens(session_id) < config.mistral.mistral_params.min_tokens and \
len(self.bot.conversation[session_id]) > self.__conversation_keep_from:
self.bot.conversation[session_id].pop(self.__conversation_keep_from)
logger.debug(
f"清理 token,历史记录遗忘后使用 token 数:{str(self.bot.count_tokens(session_id))}"
)

async def switch_model(self, model_name):
self.current_model = model_name
self.bot.engine = self.current_model

async def rollback(self):
if len(self.bot.conversation[self.session_id]) <= 0:
return False
await self.bot.rollback(self.session_id, n=2)
return True

async def on_reset(self):
self.api_info = botManager.pick('mistral')
self.bot.api_key = self.api_info.api_key
self.bot.proxy = self.api_info.proxy
self.bot.conversation[self.session_id] = []
self.bot.engine = self.current_model
self.__conversation_keep_from = 0

def construct_data(self, messages: list = None, api_key: str = None, stream: bool = True):
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
data = {
'model': self.bot.engine,
'messages': messages,
'stream': stream,
'temperature': self.bot.temperature,
'top_p': self.bot.top_p,
'max_tokens': self.bot.get_max_tokens(self.session_id, self.bot.engine),
}
return headers, data

def _prepare_request(self, session_id: str = None, messages: list = None, stream: bool = False):
self.api_info = botManager.pick('mistral')
api_key = self.api_info.api_key
proxy = self.api_info.proxy
api_endpoint = config.mistral.api_endpoint or "https://api.mistral.ai/v1"

if not messages:
messages = self.bot.conversation[session_id]

headers, data = self.construct_data(messages, api_key, stream)

return proxy, api_endpoint, headers, data

async def _process_response(self, resp, session_id: str = None):

result = await resp.json()

total_tokens = result.get('usage', {}).get('total_tokens', None)
logger.debug(f"[MistralAI-API:{self.bot.engine}] 使用 token 数:{total_tokens}")
if total_tokens is None:
raise Exception("Response does not contain 'total_tokens'")

content = result.get('choices', [{}])[0].get('message', {}).get('content', None)
logger.debug(f"[MistralAI-API:{self.bot.engine}] 响应:{content}")
if content is None:
raise Exception("Response does not contain 'content'")

response_role = result.get('choices', [{}])[0].get('message', {}).get('role', None)
if response_role is None:
raise Exception("Response does not contain 'role'")

self.bot.add_to_conversation(content, response_role, session_id)

return content

async def request(self, session_id: str = None, messages: list = None) -> str:
proxy, api_endpoint, headers, data = self._prepare_request(session_id, messages, stream=False)

async with aiohttp.ClientSession() as session:
async with async_timeout.timeout(self.bot.timeout):
async with session.post(f'{api_endpoint}/chat/completions', headers=headers,
data=json.dumps(data), proxy=proxy) as resp:
if resp.status != 200:
response_text = await resp.text()
raise Exception(
f"{resp.status} {resp.reason} {response_text}",
)
return await self._process_response(resp, session_id)

async def request_with_stream(self, session_id: str = None, messages: list = None) -> AsyncGenerator[str, None]:
proxy, api_endpoint, headers, data = self._prepare_request(session_id, messages, stream=True)

async with aiohttp.ClientSession() as session:
async with async_timeout.timeout(self.bot.timeout):
async with session.post(f'{api_endpoint}/chat/completions', headers=headers, data=json.dumps(data),
proxy=proxy) as resp:
if resp.status != 200:
response_text = await resp.text()
raise Exception(
f"{resp.status} {resp.reason} {response_text}",
)

response_role: str = ''
completion_text: str = ''

async for line in resp.content:
try:
line = line.decode('utf-8').strip()
if not line.startswith("data: "):
continue
line = line[len("data: "):]
if line == "[DONE]":
break
if not line:
continue
event = json.loads(line)
except json.JSONDecodeError:
raise Exception(f"JSON解码错误: {line}") from None
except Exception as e:
logger.error(f"未知错误: {e}\n响应内容: {resp.content}")
logger.error("请将该段日记提交到项目issue中,以便修复该问题。")
raise Exception(f"未知错误: {e}") from None
if 'error' in event:
raise Exception(f"响应错误: {event['error']}")
if 'choices' in event and len(event['choices']) > 0 and 'delta' in event['choices'][0]:
delta = event['choices'][0]['delta']
if 'role' in delta:
if delta['role'] is not None:
response_role = delta['role']
if 'content' in delta:
event_text = delta['content']
if event_text is not None:
completion_text += event_text
self.latest_role = response_role
yield event_text
self.bot.add_to_conversation(completion_text, response_role, session_id)

async def compressed_session(self, session_id: str):
if session_id not in self.bot.conversation or not self.bot.conversation[session_id]:
logger.debug(f"不存在该会话,不进行压缩: {session_id}")
return

if self.bot.count_tokens(session_id) > config.mistral.mistral_params.compressed_tokens:
logger.debug('开始进行会话压缩')

filtered_data = [entry for entry in self.bot.conversation[session_id] if entry['role'] != 'system']
self.bot.conversation[session_id] = [entry for entry in self.bot.conversation[session_id] if
entry['role'] not in ['assistant', 'user']]

filtered_data.append({"role": "system",
"content": "Summarize the discussion briefly in 200 words or less to use as a prompt for future context."})

async for text in self.request_with_stream(session_id=session_id, messages=filtered_data):
pass

token_count = self.bot.count_tokens(self.session_id, self.bot.engine)
logger.debug(f"压缩会话后使用 token 数:{token_count}")

async def ask(self, prompt: str) -> AsyncGenerator[str, None]:
"""Send a message to api and return the response with stream."""

self.manage_conversation(self.session_id, prompt)

if config.mistral.mistral_params.compressed_session:
await self.compressed_session(self.session_id)

event_time = None

try:
if self.bot.engine not in self.supported_models:
logger.warning(f"当前模型非官方支持的模型,请注意控制台输出,当前使用的模型为 {self.bot.engine}")
logger.debug(f"[尝试使用MistralAI-API:{self.bot.engine}] 请求:{prompt}")
self.bot.add_to_conversation(prompt, "user", session_id=self.session_id)
start_time = time.time()

full_response = ''

if config.mistral.mistral_params.stream:
async for resp in self.request_with_stream(session_id=self.session_id):
full_response += resp
yield full_response

token_count = self.bot.count_tokens(self.session_id, self.bot.engine)
logger.debug(f"[MistralAI-API:{self.bot.engine}] 响应:{full_response}")
logger.debug(f"[MistralAI-API:{self.bot.engine}] 使用 token 数:{token_count}")
else:
yield await self.request(session_id=self.session_id)
event_time = time.time() - start_time
if event_time is not None:
logger.debug(f"[MistralAI-API:{self.bot.engine}] 接收到全部消息花费了{event_time:.2f}秒")

except Exception as e:
logger.error(f"[MistralAI-API:{self.bot.engine}] 请求失败:\n{e}")
yield f"发生错误: \n{e}"
raise

async def preset_ask(self, role: str, text: str):

Contributor:

suggestion (llm): It's great to see the implementation of preset_ask for handling predefined interactions. However, to ensure its reliability and correctness, could you add integration tests that simulate real-world scenarios where preset_ask is used? This would help in catching any potential issues with the interaction between this method and the MistralAI API.
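
A hedged sketch of such a test, under the same stubbed-config assumptions as the earlier test sketches; no real API calls are needed because preset_ask only mutates local conversation state:

import pytest

from adapter.mistral.mistral import MistralAIAPIAdapter


@pytest.mark.asyncio
async def test_preset_ask_seeds_conversation_and_echoes_assistant_lines():
    adapter = MistralAIAPIAdapter(session_id="preset-test")
    # 'user' presets are stored silently ...
    async for _ in adapter.preset_ask("user", "请扮演一位诗人。"):
        pass
    # ... while 'assistant' presets are also yielded back to the caller.
    echoed = [chunk async for chunk in adapter.preset_ask("assistant", "好的,我是诗人。")]
    assert echoed == ["好的,我是诗人。"]
    assert adapter.bot.conversation["preset-test"][-2:] == [
        {"role": "user", "content": "请扮演一位诗人。"},
        {"role": "assistant", "content": "好的,我是诗人。"},
    ]


@pytest.mark.asyncio
async def test_preset_ask_rejects_unknown_roles():
    adapter = MistralAIAPIAdapter(session_id="preset-test-2")
    with pytest.raises(ValueError):
        async for _ in adapter.preset_ask("narrator", "some text"):
            pass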

Contributor:

suggestion (llm): Given the complexity of managing conversations with different roles (assistant, user, system), it would be beneficial to have tests that specifically verify the conversation state is managed correctly over multiple interactions. This includes testing that the conversation history is accurately maintained and that role-based behaviors are correctly implemented.
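
For the trimming side of conversation management, a hedged sketch; the session id and message sizes are arbitrary, and the expectation is that manage_conversation drops the oldest non-system messages once the token budget runs low:

from adapter.mistral.mistral import MistralAIAPIAdapter
from constants import config


def test_trimming_keeps_system_prompt():
    adapter = MistralAIAPIAdapter(session_id="main")
    # Use a fresh session id so manage_conversation inserts the system prompt.
    adapter.manage_conversation("trim-test", "You are a helpful bot.")
    for i in range(200):
        adapter.bot.add_to_conversation(f"filler message {i} " * 50, "user", "trim-test")
    adapter.manage_conversation("trim-test", "You are a helpful bot.")
    history = adapter.bot.conversation["trim-test"]
    assert history[0]["role"] == "system"  # the prompt survives trimming
    # After trimming, at least min_tokens of headroom should remain.
    assert adapter.bot.get_max_tokens("trim-test", adapter.bot.engine) >= \
        config.mistral.mistral_params.min_tokens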

Contributor:

praise (llm): Nice work implementing the preset_ask method to handle predefined interactions; this is a valuable feature for simulating specific scenarios or for testing. To fully leverage its potential, consider documenting example use cases or scenarios where preset_ask is particularly useful, either in the method comments or in a separate documentation file.
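
One lightweight way to document this is an example driver next to the method; a hedged sketch, with an illustrative preset list:

async def load_catgirl_preset(adapter: "MistralAIAPIAdapter") -> None:
    """Hypothetical example: replay a stored preset through preset_ask.

    Each (role, text) pair seeds the conversation history; only lines
    with an assistant-like role are yielded back, so a UI can echo them.
    """
    preset = [
        ("system", "你是一只猫娘,说话要以喵结尾。"),
        ("user", "你好呀!"),
        ("assistant", "你好喵~有什么可以帮你的喵?"),
    ]
    for role, text in preset:
        async for echoed in adapter.preset_ask(role, text):
            print(f"[preset] {echoed}")  # only the assistant line arrives here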

self.bot.engine = self.current_model
if role.endswith('bot') or role in {'assistant', 'mistral'}:
logger.debug(f"[预设] 响应:{text}")
yield text
role = 'assistant'
if role not in ['assistant', 'user', 'system']:
raise ValueError(f"预设文本有误!仅支持设定 assistant、user 或 system 的预设文本,但你写了{role}。")
if self.session_id not in self.bot.conversation:
self.bot.conversation[self.session_id] = []
self.__conversation_keep_from = 0
self.bot.conversation[self.session_id].append({"role": role, "content": text})
self.__conversation_keep_from = len(self.bot.conversation[self.session_id])
10 changes: 10 additions & 0 deletions config.example.cfg
@@ -33,6 +33,16 @@ alias = 'g4f-chatgpt'
# Description of this AI shown when the bot is pinged
description = 'gpt4free的gpt-3.5-turbo'

[mistral]
api_endpoint = "https://api.mistral.ai/v1"
safe_prompt = true
temperature = 0.7
top_p = 1.0

[[mistral.accounts]]
api_key = ""
# proxy="http://127.0.0.1:7890"

[presets]
# Command for switching presets, e.g.: 加载预设 猫娘
command = "加载预设 (\\w+)"
32 changes: 32 additions & 0 deletions config.py
@@ -279,6 +279,37 @@ class G4fAuths(BaseModel):
"""支持的模型"""


class MistralAIParams(BaseModel):

Contributor:

issue (llm): The introduction of new classes (MistralAIParams, MistralAIAPIKey, MistralAuths) and the associated configuration options significantly increases the complexity of the codebase. While these changes add structure and potentially enhance flexibility, they also introduce a higher cognitive load for understanding and managing configurations. Here are a few suggestions to consider for reducing complexity:

  1. Consolidate Configuration Classes: If MistralAIParams and MistralAIAPIKey share common parameters or can logically be grouped, consider merging them or creating a hierarchical structure. This could reduce the need to navigate between multiple classes to understand related configurations.

  2. Utilize Inheritance: If the new classes share attributes or methods with existing ones, leveraging inheritance or mixins could help in sharing code and reducing redundancy, making the overall structure cleaner.

  3. Simplify Default Values Management: Consider using a configuration file (JSON, YAML, etc.) for managing default values. This approach can make the code cleaner and the configuration process more flexible, as it externalizes configuration from the code.

  4. Review the Necessity of Nested Classes: The use of nested classes, such as in MistralAuths, adds another layer of complexity. Evaluate if this nesting is essential or if there's a simpler way to achieve the same functionality.

By addressing these points, we can aim for a balance between flexibility and maintainability, ensuring the code remains accessible and manageable as it evolves.
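
As an illustration of points 1 and 2, one possible shape for the consolidation; a hedged sketch, not a drop-in replacement:

from typing import List, Optional

from pydantic import BaseModel


class MistralSamplingParams(BaseModel):
    # Shared sampling knobs, declared once instead of being duplicated
    # across MistralAuths and MistralAIParams.
    temperature: float = 0.7
    top_p: float = 1.0


class MistralAIParams(MistralSamplingParams):
    max_tokens: int = 4000
    min_tokens: int = 1000
    compressed_session: bool = False
    compressed_tokens: int = 1000
    stream: bool = True


class MistralAIAPIKey(BaseModel):
    api_key: str
    model: Optional[str] = "mistral-large-latest"
    proxy: Optional[str] = None


class MistralAuths(BaseModel):
    api_endpoint: Optional[str] = None
    mistral_params: MistralAIParams = MistralAIParams()
    accounts: List[MistralAIAPIKey] = []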

temperature: float = 0.7
max_tokens: int = 4000
top_p: float = 1.0
min_tokens: int = 1000
compressed_session: bool = False
compressed_tokens: int = 1000
stream: bool = True


class MistralAIAPIKey(BaseModel):
api_key: str
"""Custom Mistral API key"""
model: Optional[str] = "mistral-large-latest"
"""Default model to use; this option takes the highest priority"""
proxy: Optional[str] = None
"""Optional proxy address; leave empty to detect the system proxy"""


class MistralAuths(BaseModel):
api_endpoint: Optional[str] = None
"""Custom Mistral API endpoint"""
temperature: float = 0.7
top_p: float = 1.0

mistral_params: MistralAIParams = MistralAIParams()

accounts: List[MistralAIAPIKey] = []
"""List of MistralAI accounts"""


class SlackAppAccessToken(BaseModel):
channel_id: str
"""负责与机器人交互的 Channel ID"""
@@ -563,6 +594,7 @@ class Config(BaseModel):
slack: SlackAuths = SlackAuths()
xinghuo: XinghuoAuths = XinghuoAuths()
gpt4free: G4fAuths = G4fAuths()
mistral: MistralAuths = MistralAuths()

# === Response Settings ===
text_to_image: TextToImage = TextToImage()
3 changes: 3 additions & 0 deletions constants.py
@@ -31,6 +31,9 @@ class LlmName(Enum):
YiYan = "yiyan"
ChatGLM = "chatglm-api"
XunfeiXinghuo = "xinghuo"
MistralSmall = "mistral-small-latest"
MistralMedium = "mistral-medium-latest"
MistralLarge = "mistral-large-latest"


class BotPlatform(Enum):