diff --git a/cecli/coders/agent_coder.py b/cecli/coders/agent_coder.py index 8524d707185..2050c2ddbcf 100644 --- a/cecli/coders/agent_coder.py +++ b/cecli/coders/agent_coder.py @@ -745,7 +745,13 @@ async def gather_and_await(): if self.auto_lint and used_write_tool: edited = list(self.files_edited_by_tools) - lint_errors = self.lint_edited(edited, show_output=False) + lint_coro = self.lint_edited(edited, show_output=False) + lint_errors, interrupted = await self.coroutines.interruptible( + lint_coro, self.interrupt_event + ) + if interrupted: + raise KeyboardInterrupt("Interrupted during linting") + self.lint_outcome = not lint_errors if lint_errors: @@ -847,7 +853,12 @@ async def reply_completed(self): " its outputs are no longer necessary" ) self.io.tool_output(waiting_msg) - await asyncio.sleep(command_timeout / 2) + sleep_coro = asyncio.sleep(command_timeout / 2) + _res, interrupted = await self.coroutines.interruptible( + sleep_coro, self.interrupt_event + ) + if interrupted: + raise KeyboardInterrupt("Interrupted while waiting for background commands") return True # Check for recently finished commands that need reflection diff --git a/cecli/coders/base_coder.py b/cecli/coders/base_coder.py index fd205357282..8102c211d71 100755 --- a/cecli/coders/base_coder.py +++ b/cecli/coders/base_coder.py @@ -60,7 +60,7 @@ from cecli.repo import ANY_GIT_ERROR, GitRepo from cecli.repomap import RepoMap from cecli.report import update_error_prefix -from cecli.run_cmd import run_cmd +from cecli.run_cmd import run_cmd_async from cecli.sessions import SessionManager from cecli.tools.utils.output import print_tool_response from cecli.tools.utils.registry import ToolRegistry @@ -579,7 +579,9 @@ def __init__( self.files_edited_by_tools = set() # Linting and testing - self.linter = Linter(root=self.root, encoding=io.encoding) + self.linter = Linter( + root=self.root, encoding=io.encoding, interrupt_event=self.interrupt_event + ) self.auto_lint = auto_lint self.setup_lint_cmds(lint_cmds) self.lint_cmds = lint_cmds @@ -2200,7 +2202,15 @@ async def send_message(self, inp): import asyncio loop = asyncio.get_running_loop() - result = await loop.run_in_executor(None, self.format_messages) + + async def format_in_executor(): + return await loop.run_in_executor(None, self.format_messages) + + result, interrupted = await self.coroutines.interruptible( + format_in_executor(), self.interrupt_event + ) + if interrupted: + raise KeyboardInterrupt("Interrupted during message formatting") messages = result if not await self.check_tokens(messages): @@ -2409,7 +2419,10 @@ async def send_message(self, inp): return if edited and self.auto_lint: - lint_errors = self.lint_edited(edited) + lint_errors = await self.lint_edited(edited) + if lint_errors is None: # Interrupted + return + await self.auto_commit(edited, context="Ran the linter") self.lint_outcome = not lint_errors if lint_errors: @@ -2887,12 +2900,16 @@ async def show_exhausted_error(self): self.io.tool_error(res) await self.io.offer_url(urls.token_limits) - def lint_edited(self, fnames, show_output=True): + async def lint_edited(self, fnames, show_output=True): res = "" for fname in fnames: if not fname: continue - errors = self.linter.lint(self.abs_root_path(fname)) + try: + errors = await self.linter.lint(self.abs_root_path(fname)) + except asyncio.CancelledError: + self.io.tool_warning("Linting interrupted.") + return None if errors: res += "\n" @@ -3201,122 +3218,127 @@ async def show_send_output_stream(self, completion): received_content = False chunk_index = 0 - async for chunk in completion: - if self.args.debug: - with open(".cecli/logs/chunks.log", "a") as f: - print(chunk, file=f) + try: + async for chunk in coroutines.interruptible_async_generator( + completion, self.interrupt_event + ): + if self.args.debug: + with open(".cecli/logs/chunks.log", "a") as f: + print(chunk, file=f) - # Check if confirmation is in progress and wait if needed - if not self.io.confirmation_in_progress_event.is_set(): - await self.io.confirmation_in_progress_event.wait() + # Check if confirmation is in progress and wait if needed + if not self.io.confirmation_in_progress_event.is_set(): + await self.io.confirmation_in_progress_event.wait() - if isinstance(chunk, str): - self.io.tool_error(chunk) - continue - else: - if len(chunk.choices) == 0: + if isinstance(chunk, str): + self.io.tool_error(chunk) continue + else: + if len(chunk.choices) == 0: + continue - if ( - hasattr(chunk.choices[0], "finish_reason") - and chunk.choices[0].finish_reason == "length" - ): - raise FinishReasonLength() - - try: - if chunk.choices[0].delta.tool_calls: - received_content = True - self.token_profiler.on_token() - for tool_call_chunk in chunk.choices[0].delta.tool_calls: - self.tool_reflection = True - - if tool_call_chunk.type: - self.io.update_spinner_suffix(tool_call_chunk.type) + if ( + hasattr(chunk.choices[0], "finish_reason") + and chunk.choices[0].finish_reason == "length" + ): + raise FinishReasonLength() - if tool_call_chunk.function: - if tool_call_chunk.function.name: - self.io.update_spinner_suffix(tool_call_chunk.function.name) + try: + if chunk.choices[0].delta.tool_calls: + received_content = True + self.token_profiler.on_token() + for tool_call_chunk in chunk.choices[0].delta.tool_calls: + self.tool_reflection = True - if tool_call_chunk.function.arguments: - self.io.update_spinner_suffix( - tool_call_chunk.function.arguments - ) + if tool_call_chunk.type: + self.io.update_spinner_suffix(tool_call_chunk.type) - except (AttributeError, IndexError): - # Handle cases where the response structure doesn't match expectations - pass + if tool_call_chunk.function: + if tool_call_chunk.function.name: + self.io.update_spinner_suffix(tool_call_chunk.function.name) - try: - func = chunk.choices[0].delta.function_call - # dump(func) - if func: - for k, v in func.items(): - self.tool_reflection = True - self.io.update_spinner_suffix(v) - - received_content = True - self.token_profiler.on_token() - except AttributeError: - pass + if tool_call_chunk.function.arguments: + self.io.update_spinner_suffix( + tool_call_chunk.function.arguments + ) - text = "" + except (AttributeError, IndexError): + # Handle cases where the response structure doesn't match expectations + pass - try: - reasoning_content = chunk.choices[0].delta.reasoning_content - except AttributeError: try: - reasoning_content = chunk.choices[0].delta.reasoning + func = chunk.choices[0].delta.function_call + # dump(func) + if func: + for k, v in func.items(): + self.tool_reflection = True + self.io.update_spinner_suffix(v) + + received_content = True + self.token_profiler.on_token() except AttributeError: - reasoning_content = None + pass - if reasoning_content: - if nested.getter(self.args, "show_thinking"): - if not self.got_reasoning_content: - text += f"<{REASONING_TAG}>\n\n" - text += reasoning_content - self.got_reasoning_content = True - received_content = True - self.token_profiler.on_token() - self.io.update_spinner_suffix(reasoning_content) - self.partial_response_reasoning_content += reasoning_content + text = "" - try: - content = chunk.choices[0].delta.content - if content: - if self.got_reasoning_content and not self.ended_reasoning_content: - text += f"\n\n\n\n" - self.ended_reasoning_content = True - - text += content - received_content = True + try: + reasoning_content = chunk.choices[0].delta.reasoning_content + except AttributeError: + try: + reasoning_content = chunk.choices[0].delta.reasoning + except AttributeError: + reasoning_content = None + + if reasoning_content: + if nested.getter(self.args, "show_thinking"): + if not self.got_reasoning_content: + text += f"<{REASONING_TAG}>\n\n" + text += reasoning_content + self.got_reasoning_content = True + received_content = True self.token_profiler.on_token() - self.io.update_spinner_suffix(content) - except AttributeError: - pass + self.io.update_spinner_suffix(reasoning_content) + self.partial_response_reasoning_content += reasoning_content + + try: + content = chunk.choices[0].delta.content + if content: + if self.got_reasoning_content and not self.ended_reasoning_content: + text += f"\n\n\n\n" + self.ended_reasoning_content = True + + text += content + received_content = True + self.token_profiler.on_token() + self.io.update_spinner_suffix(content) + except AttributeError: + pass - self.partial_response_content += text + self.partial_response_content += text - chunk_index += 1 - chunk._hidden_params["created_at"] = chunk_index - self.partial_response_chunks.append(chunk) + chunk_index += 1 + chunk._hidden_params["created_at"] = chunk_index + self.partial_response_chunks.append(chunk) - if self.show_pretty(): - # Use simplified streaming - just call the method with full content - content_to_show = self.live_incremental_response(False) - self.stream_wrapper(content_to_show, final=False) - elif text: - # Apply reasoning tag formatting for non-pretty output - if nested.getter(self.args, "show_thinking"): - text = replace_reasoning_tags(text, self.reasoning_tag_name) - try: - self.stream_wrapper(text, final=False) - except UnicodeEncodeError: - # Safely encode and decode the text - safe_text = text.encode(sys.stdout.encoding, errors="backslashreplace").decode( - sys.stdout.encoding - ) - self.stream_wrapper(safe_text, final=False) - yield text + if self.show_pretty(): + # Use simplified streaming - just call the method with full content + content_to_show = self.live_incremental_response(False) + self.stream_wrapper(content_to_show, final=False) + elif text: + # Apply reasoning tag formatting for non-pretty output + if nested.getter(self.args, "show_thinking"): + text = replace_reasoning_tags(text, self.reasoning_tag_name) + try: + self.stream_wrapper(text, final=False) + except UnicodeEncodeError: + # Safely encode and decode the text + safe_text = text.encode( + sys.stdout.encoding, errors="backslashreplace" + ).decode(sys.stdout.encoding) + self.stream_wrapper(safe_text, final=False) + yield text + except (asyncio.CancelledError, KeyboardInterrupt): + raise KeyboardInterrupt # The Part Doing the Heavy Lifting Now self.consolidate_chunks() @@ -4147,8 +4169,10 @@ async def handle_shell_commands(self, commands_str, group): self.io.tool_output(f"Running {command}") # Add the command to input history # self.io.add_to_input_history(f"/run {command.strip()}") - exit_status, output = await asyncio.to_thread( - run_cmd, command, error_print=self.io.tool_error, cwd=self.root + exit_status, output = await run_cmd_async( + command, + self.interrupt_event, + cwd=self.root, ) if output: diff --git a/cecli/commands/run.py b/cecli/commands/run.py index a23d5326c6b..13f1e028cc5 100644 --- a/cecli/commands/run.py +++ b/cecli/commands/run.py @@ -1,11 +1,10 @@ -import asyncio from typing import List import cecli.prompts.utils.system as prompts from cecli.commands.utils.base_command import BaseCommand from cecli.commands.utils.helpers import format_command_result from cecli.helpers.conversation import ConversationService, MessageTag -from cecli.run_cmd import run_cmd +from cecli.run_cmd import run_cmd_async class RunCommand(BaseCommand): @@ -22,11 +21,10 @@ async def execute(cls, io, coder, args, **kwargs): if coder.args.tui: should_print = False - exit_status, combined_output = await asyncio.to_thread( - run_cmd, + exit_status, combined_output = await run_cmd_async( args, + coder.interrupt_event, verbose=coder.args.verbose if hasattr(coder.args, "verbose") else False, - error_print=io.tool_error, cwd=coder.root, should_print=should_print, ) diff --git a/cecli/helpers/coroutines.py b/cecli/helpers/coroutines.py index 07f1a669d5a..ccddf957cf7 100644 --- a/cecli/helpers/coroutines.py +++ b/cecli/helpers/coroutines.py @@ -1,6 +1,40 @@ import asyncio +async def interruptible_async_generator(async_generator, interrupt_event): + """ + Wraps an async generator to make it interruptible. + """ + gen = async_generator.__aiter__() + interrupt_task = asyncio.create_task(interrupt_event.wait()) + + while True: + next_task = asyncio.create_task(gen.__anext__()) + done, pending = await asyncio.wait( + {next_task, interrupt_task}, return_when=asyncio.FIRST_COMPLETED + ) + + if interrupt_task in done: + next_task.cancel() + try: + await next_task + except asyncio.CancelledError: + pass + break + + if next_task in done: + try: + yield next_task.result() + except StopAsyncIteration: + break + + interrupt_task.cancel() + try: + await interrupt_task + except asyncio.CancelledError: + pass + + def is_active(task): if not task or task.done() or task.cancelled(): return False diff --git a/cecli/linter.py b/cecli/linter.py index 6b04ab546ff..9e91d826fd8 100644 --- a/cecli/linter.py +++ b/cecli/linter.py @@ -1,6 +1,6 @@ +import asyncio import os import re -import subprocess import sys import traceback import warnings @@ -12,16 +12,17 @@ from cecli.dump import dump # noqa: F401 from cecli.helpers.grep_ast import TreeContext, filename_to_lang from cecli.helpers.grep_ast.tsl import get_parser # noqa: E402 -from cecli.run_cmd import run_cmd_subprocess # noqa: F401 +from cecli.run_cmd import run_cmd_async, run_cmd_subprocess # noqa: F401 # tree_sitter is throwing a FutureWarning warnings.simplefilter("ignore", category=FutureWarning) class Linter: - def __init__(self, encoding="utf-8", root=None): + def __init__(self, encoding="utf-8", root=None, interrupt_event=None): self.encoding = encoding self.root = root + self.interrupt_event = interrupt_event or asyncio.Event() self.languages = dict( python=self.py_lint, @@ -44,20 +45,18 @@ def get_rel_fname(self, fname): else: return fname - def run_cmd(self, cmd, rel_fname, code): + async def run_cmd(self, cmd, rel_fname, code): cmd += " " + oslex.quote(rel_fname) - returncode = 0 - stdout = "" - try: - returncode, stdout = run_cmd_subprocess( - cmd, - cwd=self.root, - encoding=self.encoding, - ) - except OSError as err: - print(f"Unable to execute lint command: {err}") + returncode, stdout = await run_cmd_async( + cmd, + self.interrupt_event, + cwd=self.root, + encoding=self.encoding, + ) + if stdout == "Interrupted": return + errors = stdout if returncode == 0: return # zero exit status @@ -79,7 +78,7 @@ def errors_to_lint_result(self, rel_fname, errors): return LintResult(text=errors, lines=linenums) - def lint(self, fname, cmd=None): + async def lint(self, fname, cmd=None): rel_fname = self.get_rel_fname(fname) try: code = Path(fname).read_text(encoding=self.encoding, errors="replace") @@ -99,9 +98,13 @@ def lint(self, fname, cmd=None): cmd = self.languages.get(lang) if callable(cmd): - lintres = cmd(fname, rel_fname, code) + # Check if the callable is a coroutine function + if asyncio.iscoroutinefunction(cmd): + lintres = await cmd(fname, rel_fname, code) + else: + lintres = cmd(fname, rel_fname, code) elif cmd: - lintres = self.run_cmd(cmd, rel_fname, code) + lintres = await self.run_cmd(cmd, rel_fname, code) else: lintres = basic_lint(rel_fname, code) @@ -115,10 +118,10 @@ def lint(self, fname, cmd=None): return res - def py_lint(self, fname, rel_fname, code): + async def py_lint(self, fname, rel_fname, code): basic_res = basic_lint(rel_fname, code) compile_res = lint_python_compile(fname, code) - flake_res = self.flake8_lint(rel_fname) + flake_res = await self.flake8_lint(rel_fname) text = "" lines = set() @@ -133,9 +136,9 @@ def py_lint(self, fname, rel_fname, code): if text or lines: return LintResult(text, lines) - def flake8_lint(self, rel_fname): + async def flake8_lint(self, rel_fname): fatal = "E9,F821,F823,F831,F406,F407,F701,F702,F704,F706" - flake8_cmd = [ + flake8_cmd_list = [ sys.executable, "-m", "flake8", @@ -144,24 +147,21 @@ def flake8_lint(self, rel_fname): "--isolated", rel_fname, ] + flake8_cmd = " ".join(flake8_cmd_list) - text = f"## Running: {' '.join(flake8_cmd)}\n\n" + text = f"## Running: {flake8_cmd}\n\n" - try: - result = subprocess.run( - flake8_cmd, - capture_output=True, - text=True, - check=False, - encoding=self.encoding, - errors="replace", - cwd=self.root, - ) - errors = result.stdout + result.stderr - except Exception as e: - errors = f"Error running flake8: {str(e)}" + returncode, stdout = await run_cmd_async( + flake8_cmd, + self.interrupt_event, + cwd=self.root, + encoding=self.encoding, + ) + if stdout == "Interrupted": + return - if not errors: + errors = stdout + if returncode == 0 or not errors: return text += errors diff --git a/cecli/models.py b/cecli/models.py index 19a6f8cff35..3caebebe6bf 100644 --- a/cecli/models.py +++ b/cecli/models.py @@ -1278,7 +1278,11 @@ async def send_completion( if override_kwargs: kwargs = deep_merge(kwargs, override_kwargs) - res = await litellm.acompletion(**kwargs) + completion_coro = litellm.acompletion(**kwargs) + res, interrupted = await coroutines.interruptible(completion_coro, interrupt_event) + if interrupted: + raise KeyboardInterrupt("Interrupted during acompletion") + return hash_object, res except litellm.ContextWindowExceededError as err: raise err diff --git a/cecli/run_cmd.py b/cecli/run_cmd.py index 2de892f51a6..5cbb13d6601 100644 --- a/cecli/run_cmd.py +++ b/cecli/run_cmd.py @@ -1,3 +1,4 @@ +import asyncio import os import platform import subprocess @@ -97,6 +98,86 @@ def run_cmd_subprocess( return 1, str(e) +async def run_cmd_async( + command, + interrupt_event, + verbose=False, + cwd=None, + encoding=sys.stdout.encoding, + should_print=True, +): + if verbose: + print("Using run_cmd_async:", command) + + shell = os.environ.get("SHELL", "/bin/sh") + parent_process = None + + # Determine the appropriate shell + if platform.system() == "Windows": + parent_process = get_windows_parent_process_name() + if parent_process == "powershell.exe": + command = f"powershell -Command {command}" + + if verbose: + print("Running command:", command) + print("SHELL:", shell) + if platform.system() == "Windows": + print("Parent process:", parent_process) + + try: + process = await asyncio.create_subprocess_shell( + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + cwd=cwd, + ) + except FileNotFoundError: + return 1, f"Command not found: {command}" + + output = [] + + async def read_stream(stream): + while True: + try: + line_bytes = await stream.readline() + except (IOError, OSError): + # Stream closed + break + if not line_bytes: + break + line = line_bytes.decode(encoding, errors="replace") + output.append(line) + if should_print: + print(line, end="", flush=True) + + reader_task = asyncio.create_task(read_stream(process.stdout)) + interrupt_task = asyncio.create_task(interrupt_event.wait()) + + done, pending = await asyncio.wait( + {reader_task, interrupt_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + + if interrupt_task in done: + # Interrupted + for task in pending: + task.cancel() + try: + process.terminate() + except ProcessLookupError: + pass # process already finished + await process.wait() + return 1, "Interrupted" + + # Not interrupted, wait for process to finish + await process.wait() + # wait for reader to finish + if not reader_task.done(): + await reader_task + + return process.returncode, "".join(output) + + def run_cmd_pexpect(command, verbose=False, cwd=None, should_print=True): """ Run a shell command interactively using pexpect, capturing all output.