Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions cecli/coders/base_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ async def create(
main_model = models.Model(models.DEFAULT_MODEL_NAME, io=io)

if edit_format == "code":
edit_format = None
if edit_format is None:
edit_format = main_model.edit_format
elif edit_format is None:
if from_coder:
edit_format = from_coder.edit_format
else:
Expand Down Expand Up @@ -259,12 +259,14 @@ async def create(

if res is not None:
if from_coder:
if from_coder.mcp_manager:
res.mcp_manager = from_coder.mcp_manager

# Transfer TUI app weak reference
res.tui = from_coder.tui
res.context_management_enabled = from_coder.context_management_enabled
if res.mcp_manager:
# When switching to a non-agent coder, disconnect the "Local" MCP server
# (which provides agent-only tools like tool calling and file editing)
# so it's not available in non-agent modes.
if not isinstance(res, coders.AgentCoder):
local_server = res.mcp_manager.get_server("Local")
if local_server and local_server.is_connected:
await res.mcp_manager.disconnect_server("Local")

await res.initialize_mcp_tools()

Expand Down
11 changes: 5 additions & 6 deletions cecli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,17 +1239,16 @@ def get_io(pretty):
kwargs["num_cache_warming_pings"] = 0
kwargs["args"] = coder.args

if kwargs["edit_format"] != AgentCoder.edit_format and (
coder := kwargs.get("from_coder")
):
if coder.mcp_manager.get_server("Local"):
await coder.mcp_manager.disconnect_server("Local")

for tag in [MessageTag.SYSTEM, MessageTag.EXAMPLES, MessageTag.STATIC]:
ConversationService.get_manager(coder).clear_tag(tag)

old_coder = coder
coder = await Coder.create(**kwargs)

if isinstance(old_coder, AgentCoder) and not isinstance(coder, AgentCoder):
if coder.mcp_manager and coder.mcp_manager.get_server("Local"):
await coder.mcp_manager.disconnect_server("Local")

if switch.kwargs.get("show_announcements") is False:
coder.suppress_announcements_for_next_prompt = True

Expand Down
73 changes: 73 additions & 0 deletions tests/coders/test_coder_switching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import asyncio
import unittest
from unittest.mock import MagicMock, patch

from cecli.coders import Coder


class TestCoderSwitching(unittest.TestCase):
@patch("cecli.coders.agent_coder.ToolRegistry")
def test_switch_from_agent_to_non_agent(self, mock_tool_registry):
async def run_test():
# Mock dependencies
io = MagicMock()
args = MagicMock()
args.agent_config = "{}"
args.verbose = False
args.tui = False
args.show_thinking = True
args.auto_save = False
args.file_diffs = True
args.max_reflections = 3
main_model = MagicMock()
main_model.edit_format = "code"
main_model.agent_model = None
main_model.weak_model = None
main_model.editor_model = None
main_model.get_repo_map_tokens.return_value = 1024
main_model.info = {}
main_model.name = "test-model"
main_model.reasoning_tag = "think"
main_model.get_active_model.return_value = main_model

mock_tool_registry.get_registered_tools.return_value = ["edittext"]
mock_tool_registry.get_tool.return_value = MagicMock()
mock_tool_registry.build_registry.return_value = None

# 1. Start with an AgentCoder
agent_coder = await Coder.create(
main_model=main_model,
edit_format="agent",
io=io,
args=args,
)
from cecli.coders import AgentCoder

self.assertIsInstance(agent_coder, AgentCoder)
self.assertTrue(agent_coder.mcp_manager.get_server("Local").is_connected)

# 2. Switch to a non-agent coder
code_coder = await Coder.create(
from_coder=agent_coder,
edit_format="code",
)
self.assertNotIsInstance(code_coder, AgentCoder)

# 3. Check that "Local" server is disconnected
self.assertFalse(code_coder.mcp_manager.get_server("Local").is_connected)

# 4. Switch back to agent coder
new_agent_coder = await Coder.create(
from_coder=code_coder,
edit_format="agent",
)
self.assertIsInstance(new_agent_coder, AgentCoder)

# 5. Check that "Local" server is re-connected
self.assertTrue(new_agent_coder.mcp_manager.get_server("Local").is_connected)

asyncio.run(run_test())


if __name__ == "__main__":
unittest.main()
Loading