diff --git a/src/mistralai/extra/mcp/base.py b/src/mistralai/extra/mcp/base.py index 8be5585c..20083b2a 100644 --- a/src/mistralai/extra/mcp/base.py +++ b/src/mistralai/extra/mcp/base.py @@ -29,27 +29,21 @@ class MCPClientProtocol(Protocol): _name: str - async def initialize(self, exit_stack: Optional[AsyncExitStack]) -> None: - ... + async def initialize(self, exit_stack: Optional[AsyncExitStack]) -> None: ... - async def aclose(self) -> None: - ... + async def aclose(self) -> None: ... - async def get_tools(self) -> list[FunctionTool]: - ... + async def get_tools(self) -> list[FunctionTool]: ... async def execute_tool( self, name: str, arguments: dict - ) -> list[TextChunkTypedDict]: - ... + ) -> list[TextChunkTypedDict]: ... async def get_system_prompt( self, name: str, arguments: dict[str, Any] - ) -> MCPSystemPrompt: - ... + ) -> MCPSystemPrompt: ... - async def list_system_prompts(self) -> ListPromptsResult: - ... + async def list_system_prompts(self) -> ListPromptsResult: ... class MCPClientBase(MCPClientProtocol): @@ -65,7 +59,7 @@ def __init__(self, name: Optional[str] = None): def _convert_content( self, mcp_content: Union[TextContent, ImageContent, EmbeddedResource] ) -> TextChunkTypedDict: - if not mcp_content.type == "text": + if mcp_content.type != "text": raise MCPException("Only supporting text tool responses for now.") return {"type": "text", "text": mcp_content.text} @@ -107,13 +101,10 @@ async def get_system_prompt( return { "description": prompt_result.description, "messages": [ - typing.cast( - Union[SystemMessageTypedDict, AssistantMessageTypedDict], - { - "role": message.role, - "content": self._convert_content(mcp_content=message.content), - }, - ) + { + "role": message.role, + "content": self._convert_content(mcp_content=message.content), + } for message in prompt_result.messages ], }