diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index fd205357282..30e23aacd11 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -202,8 +202,8 @@ async def create( main_model = models.Model(models.DEFAULT_MODEL_NAME, io=io) if edit_format == "code": - edit_format = None - if edit_format is None: + edit_format = main_model.edit_format + elif edit_format is None: if from_coder: edit_format = from_coder.edit_format else: @@ -259,12 +259,14 @@ async def create( if res is not None: if from_coder: - if from_coder.mcp_manager: - res.mcp_manager = from_coder.mcp_manager - - # Transfer TUI app weak reference - res.tui = from_coder.tui - res.context_management_enabled = from_coder.context_management_enabled + if res.mcp_manager: + # When switching to a non-agent coder, disconnect the "Local" MCP server + # (which provides agent-only tools like tool calling and file editing) + # so it's not available in non-agent modes. + if not isinstance(res, coders.AgentCoder): + local_server = res.mcp_manager.get_server("Local") + if local_server and local_server.is_connected: + await res.mcp_manager.disconnect_server("Local") await res.initialize_mcp_tools() diff --git a/cecli/main.py b/cecli/main.py index 89f637e0910..eab1e8ccb2b 100644 --- a/cecli/main.py +++ b/cecli/main.py @@ -1239,17 +1239,16 @@ def get_io(pretty): kwargs["num_cache_warming_pings"] = 0 kwargs["args"] = coder.args - if kwargs["edit_format"] != AgentCoder.edit_format and ( - coder := kwargs.get("from_coder") - ): - if coder.mcp_manager.get_server("Local"): - await coder.mcp_manager.disconnect_server("Local") - for tag in [MessageTag.SYSTEM, MessageTag.EXAMPLES, MessageTag.STATIC]: ConversationService.get_manager(coder).clear_tag(tag) + old_coder = coder coder = await Coder.create(**kwargs) + if isinstance(old_coder, AgentCoder) and not isinstance(coder, AgentCoder): + if coder.mcp_manager and coder.mcp_manager.get_server("Local"): + await coder.mcp_manager.disconnect_server("Local") + if switch.kwargs.get("show_announcements") is False: coder.suppress_announcements_for_next_prompt = True diff --git a/tests/coders/test_coder_switching.py b/tests/coders/test_coder_switching.py new file mode 100644 index 00000000000..f00bc72b637 --- /dev/null +++ b/tests/coders/test_coder_switching.py @@ -0,0 +1,73 @@ +import asyncio +import unittest +from unittest.mock import MagicMock, patch + +from cecli.coders import Coder + + +class TestCoderSwitching(unittest.TestCase): + @patch("cecli.coders.agent_coder.ToolRegistry") + def test_switch_from_agent_to_non_agent(self, mock_tool_registry): + async def run_test(): + # Mock dependencies + io = MagicMock() + args = MagicMock() + args.agent_config = "{}" + args.verbose = False + args.tui = False + args.show_thinking = True + args.auto_save = False + args.file_diffs = True + args.max_reflections = 3 + main_model = MagicMock() + main_model.edit_format = "code" + main_model.agent_model = None + main_model.weak_model = None + main_model.editor_model = None + main_model.get_repo_map_tokens.return_value = 1024 + main_model.info = {} + main_model.name = "test-model" + main_model.reasoning_tag = "think" + main_model.get_active_model.return_value = main_model + + mock_tool_registry.get_registered_tools.return_value = ["edittext"] + mock_tool_registry.get_tool.return_value = MagicMock() + mock_tool_registry.build_registry.return_value = None + + # 1. Start with an AgentCoder + agent_coder = await Coder.create( + main_model=main_model, + edit_format="agent", + io=io, + args=args, + ) + from cecli.coders import AgentCoder + + self.assertIsInstance(agent_coder, AgentCoder) + self.assertTrue(agent_coder.mcp_manager.get_server("Local").is_connected) + + # 2. Switch to a non-agent coder + code_coder = await Coder.create( + from_coder=agent_coder, + edit_format="code", + ) + self.assertNotIsInstance(code_coder, AgentCoder) + + # 3. Check that "Local" server is disconnected + self.assertFalse(code_coder.mcp_manager.get_server("Local").is_connected) + + # 4. Switch back to agent coder + new_agent_coder = await Coder.create( + from_coder=code_coder, + edit_format="agent", + ) + self.assertIsInstance(new_agent_coder, AgentCoder) + + # 5. Check that "Local" server is re-connected + self.assertTrue(new_agent_coder.mcp_manager.get_server("Local").is_connected) + + asyncio.run(run_test()) + + +if __name__ == "__main__": + unittest.main()