Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions cecli/coders/agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,13 @@ async def gather_and_await():

if self.auto_lint and used_write_tool:
edited = list(self.files_edited_by_tools)
lint_errors = self.lint_edited(edited, show_output=False)
lint_coro = self.lint_edited(edited, show_output=False)
lint_errors, interrupted = await self.coroutines.interruptible(
lint_coro, self.interrupt_event
)
if interrupted:
raise KeyboardInterrupt("Interrupted during linting")

self.lint_outcome = not lint_errors

if lint_errors:
Expand Down Expand Up @@ -847,7 +853,12 @@ async def reply_completed(self):
" its outputs are no longer necessary"
)
self.io.tool_output(waiting_msg)
await asyncio.sleep(command_timeout / 2)
sleep_coro = asyncio.sleep(command_timeout / 2)
_res, interrupted = await self.coroutines.interruptible(
sleep_coro, self.interrupt_event
)
if interrupted:
raise KeyboardInterrupt("Interrupted while waiting for background commands")
return True

# Check for recently finished commands that need reflection
Expand Down
236 changes: 130 additions & 106 deletions cecli/coders/base_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
from cecli.repo import ANY_GIT_ERROR, GitRepo
from cecli.repomap import RepoMap
from cecli.report import update_error_prefix
from cecli.run_cmd import run_cmd
from cecli.run_cmd import run_cmd_async
from cecli.sessions import SessionManager
from cecli.tools.utils.output import print_tool_response
from cecli.tools.utils.registry import ToolRegistry
Expand Down Expand Up @@ -579,7 +579,9 @@ def __init__(
self.files_edited_by_tools = set()

# Linting and testing
self.linter = Linter(root=self.root, encoding=io.encoding)
self.linter = Linter(
root=self.root, encoding=io.encoding, interrupt_event=self.interrupt_event
)
self.auto_lint = auto_lint
self.setup_lint_cmds(lint_cmds)
self.lint_cmds = lint_cmds
Expand Down Expand Up @@ -2200,7 +2202,15 @@ async def send_message(self, inp):
import asyncio

loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, self.format_messages)

async def format_in_executor():
return await loop.run_in_executor(None, self.format_messages)

result, interrupted = await self.coroutines.interruptible(
format_in_executor(), self.interrupt_event
)
if interrupted:
raise KeyboardInterrupt("Interrupted during message formatting")
messages = result

if not await self.check_tokens(messages):
Expand Down Expand Up @@ -2409,7 +2419,10 @@ async def send_message(self, inp):
return

if edited and self.auto_lint:
lint_errors = self.lint_edited(edited)
lint_errors = await self.lint_edited(edited)
if lint_errors is None: # Interrupted
return

await self.auto_commit(edited, context="Ran the linter")
self.lint_outcome = not lint_errors
if lint_errors:
Expand Down Expand Up @@ -2887,12 +2900,16 @@ async def show_exhausted_error(self):
self.io.tool_error(res)
await self.io.offer_url(urls.token_limits)

def lint_edited(self, fnames, show_output=True):
async def lint_edited(self, fnames, show_output=True):
res = ""
for fname in fnames:
if not fname:
continue
errors = self.linter.lint(self.abs_root_path(fname))
try:
errors = await self.linter.lint(self.abs_root_path(fname))
except asyncio.CancelledError:
self.io.tool_warning("Linting interrupted.")
return None

if errors:
res += "\n"
Expand Down Expand Up @@ -3201,122 +3218,127 @@ async def show_send_output_stream(self, completion):
received_content = False
chunk_index = 0

async for chunk in completion:
if self.args.debug:
with open(".cecli/logs/chunks.log", "a") as f:
print(chunk, file=f)
try:
async for chunk in coroutines.interruptible_async_generator(
completion, self.interrupt_event
):
if self.args.debug:
with open(".cecli/logs/chunks.log", "a") as f:
print(chunk, file=f)

# Check if confirmation is in progress and wait if needed
if not self.io.confirmation_in_progress_event.is_set():
await self.io.confirmation_in_progress_event.wait()
# Check if confirmation is in progress and wait if needed
if not self.io.confirmation_in_progress_event.is_set():
await self.io.confirmation_in_progress_event.wait()

if isinstance(chunk, str):
self.io.tool_error(chunk)
continue
else:
if len(chunk.choices) == 0:
if isinstance(chunk, str):
self.io.tool_error(chunk)
continue
else:
if len(chunk.choices) == 0:
continue

if (
hasattr(chunk.choices[0], "finish_reason")
and chunk.choices[0].finish_reason == "length"
):
raise FinishReasonLength()

try:
if chunk.choices[0].delta.tool_calls:
received_content = True
self.token_profiler.on_token()
for tool_call_chunk in chunk.choices[0].delta.tool_calls:
self.tool_reflection = True

if tool_call_chunk.type:
self.io.update_spinner_suffix(tool_call_chunk.type)
if (
hasattr(chunk.choices[0], "finish_reason")
and chunk.choices[0].finish_reason == "length"
):
raise FinishReasonLength()

if tool_call_chunk.function:
if tool_call_chunk.function.name:
self.io.update_spinner_suffix(tool_call_chunk.function.name)
try:
if chunk.choices[0].delta.tool_calls:
received_content = True
self.token_profiler.on_token()
for tool_call_chunk in chunk.choices[0].delta.tool_calls:
self.tool_reflection = True

if tool_call_chunk.function.arguments:
self.io.update_spinner_suffix(
tool_call_chunk.function.arguments
)
if tool_call_chunk.type:
self.io.update_spinner_suffix(tool_call_chunk.type)

except (AttributeError, IndexError):
# Handle cases where the response structure doesn't match expectations
pass
if tool_call_chunk.function:
if tool_call_chunk.function.name:
self.io.update_spinner_suffix(tool_call_chunk.function.name)

try:
func = chunk.choices[0].delta.function_call
# dump(func)
if func:
for k, v in func.items():
self.tool_reflection = True
self.io.update_spinner_suffix(v)

received_content = True
self.token_profiler.on_token()
except AttributeError:
pass
if tool_call_chunk.function.arguments:
self.io.update_spinner_suffix(
tool_call_chunk.function.arguments
)

text = ""
except (AttributeError, IndexError):
# Handle cases where the response structure doesn't match expectations
pass

try:
reasoning_content = chunk.choices[0].delta.reasoning_content
except AttributeError:
try:
reasoning_content = chunk.choices[0].delta.reasoning
func = chunk.choices[0].delta.function_call
# dump(func)
if func:
for k, v in func.items():
self.tool_reflection = True
self.io.update_spinner_suffix(v)

received_content = True
self.token_profiler.on_token()
except AttributeError:
reasoning_content = None
pass

if reasoning_content:
if nested.getter(self.args, "show_thinking"):
if not self.got_reasoning_content:
text += f"<{REASONING_TAG}>\n\n"
text += reasoning_content
self.got_reasoning_content = True
received_content = True
self.token_profiler.on_token()
self.io.update_spinner_suffix(reasoning_content)
self.partial_response_reasoning_content += reasoning_content
text = ""

try:
content = chunk.choices[0].delta.content
if content:
if self.got_reasoning_content and not self.ended_reasoning_content:
text += f"\n\n</{self.reasoning_tag_name}>\n\n"
self.ended_reasoning_content = True

text += content
received_content = True
try:
reasoning_content = chunk.choices[0].delta.reasoning_content
except AttributeError:
try:
reasoning_content = chunk.choices[0].delta.reasoning
except AttributeError:
reasoning_content = None

if reasoning_content:
if nested.getter(self.args, "show_thinking"):
if not self.got_reasoning_content:
text += f"<{REASONING_TAG}>\n\n"
text += reasoning_content
self.got_reasoning_content = True
received_content = True
self.token_profiler.on_token()
self.io.update_spinner_suffix(content)
except AttributeError:
pass
self.io.update_spinner_suffix(reasoning_content)
self.partial_response_reasoning_content += reasoning_content

try:
content = chunk.choices[0].delta.content
if content:
if self.got_reasoning_content and not self.ended_reasoning_content:
text += f"\n\n</{self.reasoning_tag_name}>\n\n"
self.ended_reasoning_content = True

text += content
received_content = True
self.token_profiler.on_token()
self.io.update_spinner_suffix(content)
except AttributeError:
pass

self.partial_response_content += text
self.partial_response_content += text

chunk_index += 1
chunk._hidden_params["created_at"] = chunk_index
self.partial_response_chunks.append(chunk)
chunk_index += 1
chunk._hidden_params["created_at"] = chunk_index
self.partial_response_chunks.append(chunk)

if self.show_pretty():
# Use simplified streaming - just call the method with full content
content_to_show = self.live_incremental_response(False)
self.stream_wrapper(content_to_show, final=False)
elif text:
# Apply reasoning tag formatting for non-pretty output
if nested.getter(self.args, "show_thinking"):
text = replace_reasoning_tags(text, self.reasoning_tag_name)
try:
self.stream_wrapper(text, final=False)
except UnicodeEncodeError:
# Safely encode and decode the text
safe_text = text.encode(sys.stdout.encoding, errors="backslashreplace").decode(
sys.stdout.encoding
)
self.stream_wrapper(safe_text, final=False)
yield text
if self.show_pretty():
# Use simplified streaming - just call the method with full content
content_to_show = self.live_incremental_response(False)
self.stream_wrapper(content_to_show, final=False)
elif text:
# Apply reasoning tag formatting for non-pretty output
if nested.getter(self.args, "show_thinking"):
text = replace_reasoning_tags(text, self.reasoning_tag_name)
try:
self.stream_wrapper(text, final=False)
except UnicodeEncodeError:
# Safely encode and decode the text
safe_text = text.encode(
sys.stdout.encoding, errors="backslashreplace"
).decode(sys.stdout.encoding)
self.stream_wrapper(safe_text, final=False)
yield text
except (asyncio.CancelledError, KeyboardInterrupt):
raise KeyboardInterrupt

# The Part Doing the Heavy Lifting Now
self.consolidate_chunks()
Expand Down Expand Up @@ -4147,8 +4169,10 @@ async def handle_shell_commands(self, commands_str, group):
self.io.tool_output(f"Running {command}")
# Add the command to input history
# self.io.add_to_input_history(f"/run {command.strip()}")
exit_status, output = await asyncio.to_thread(
run_cmd, command, error_print=self.io.tool_error, cwd=self.root
exit_status, output = await run_cmd_async(
command,
self.interrupt_event,
cwd=self.root,
)

if output:
Expand Down
8 changes: 3 additions & 5 deletions cecli/commands/run.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import asyncio
from typing import List

import cecli.prompts.utils.system as prompts
from cecli.commands.utils.base_command import BaseCommand
from cecli.commands.utils.helpers import format_command_result
from cecli.helpers.conversation import ConversationService, MessageTag
from cecli.run_cmd import run_cmd
from cecli.run_cmd import run_cmd_async


class RunCommand(BaseCommand):
Expand All @@ -22,11 +21,10 @@ async def execute(cls, io, coder, args, **kwargs):
if coder.args.tui:
should_print = False

exit_status, combined_output = await asyncio.to_thread(
run_cmd,
exit_status, combined_output = await run_cmd_async(
args,
coder.interrupt_event,
verbose=coder.args.verbose if hasattr(coder.args, "verbose") else False,
error_print=io.tool_error,
cwd=coder.root,
should_print=should_print,
)
Expand Down
Loading
Loading