Skip to content

Commit f86c3d7

Browse files
authored
fix: CLI output mess with stream output (#184)
Signed-off-by: Frost Ming <me@frostming.com>
1 parent 774057d commit f86c3d7

3 files changed

Lines changed: 59 additions & 10 deletions

File tree

src/bub/channels/cli/__init__.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from hashlib import md5
66
from pathlib import Path
77

8+
from loguru import logger
89
from prompt_toolkit import PromptSession
910
from prompt_toolkit.completion import WordCompleter
1011
from prompt_toolkit.formatted_text import FormattedText
@@ -43,10 +44,16 @@ def __init__(self, on_receive: MessageHandler, agent: Agent) -> None:
4344
self._mode = "agent" # or "shell"
4445
self._main_task: asyncio.Task | None = None
4546
self._renderer = CliRenderer(get_console())
47+
self._log_handler_id = self._install_log_sink()
4648
self._last_tape_info: TapeInfo | None = None
4749
self._workspace = self._agent.framework.workspace
4850
self._prompt = self._build_prompt(self._workspace)
4951

52+
def _install_log_sink(self) -> int:
53+
with contextlib.suppress(ValueError):
54+
logger.remove(0)
55+
return logger.add(self._renderer.log, colorize=False, format="{level:<8} | {message}")
56+
5057
async def _refresh_tape_info(self) -> None:
5158
tape = self._agent.tapes.session_tape(self._message_template["session_id"], self._workspace)
5259
info = await self._agent.tapes.info(tape.name)
@@ -67,6 +74,8 @@ async def stop(self) -> None:
6774
self._main_task.cancel()
6875
with contextlib.suppress(asyncio.CancelledError):
6976
await self._main_task
77+
with contextlib.suppress(ValueError):
78+
logger.remove(self._log_handler_id)
7079

7180
async def send(self, message: ChannelMessage) -> None:
7281
if message.kind != "error":
@@ -140,10 +149,11 @@ async def stream_events(
140149
content = str(event.data.get("delta", ""))
141150
if not content.strip() and not text:
142151
continue # skip leading whitespace-only events
143-
if live is None:
144-
live = self._renderer.start_stream(message.kind)
145152
text += content
146-
self._renderer.update_stream(live, kind=message.kind, text=text)
153+
if live is None:
154+
live = self._renderer.start_stream(message.kind, text)
155+
else:
156+
self._renderer.update_stream(live, kind=message.kind, text=text)
147157
yield event
148158
finally:
149159
if live is not None:

src/bub/channels/cli/renderer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,20 @@ def error(self, text: str) -> None:
5252
return
5353
self.console.print(self.panel("error", text))
5454

55-
def start_stream(self, kind: MessageKind) -> Live:
55+
def log(self, message: object) -> None:
56+
text = str(message).rstrip()
57+
if text:
58+
self.console.print(text)
59+
60+
def start_stream(self, kind: MessageKind, text: str) -> Live:
5661
live = Live(
57-
self.panel(kind, ""),
62+
self.panel(kind, text),
5863
console=self.console,
5964
auto_refresh=False,
6065
transient=False,
6166
vertical_overflow="visible",
6267
)
63-
live.start()
64-
live.refresh()
68+
live.start(refresh=True)
6569
return live
6670

6771
def update_stream(self, live: Live, *, kind: MessageKind, text: str) -> None:

tests/test_channels.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from republic import StreamEvent
1111

1212
from bub.channels.cli import CliChannel
13+
from bub.channels.cli import renderer as cli_renderer
14+
from bub.channels.cli.renderer import CliRenderer
1315
from bub.channels.handler import BufferedMessageHandler
1416
from bub.channels.manager import ChannelManager
1517
from bub.channels.message import ChannelMessage
@@ -301,7 +303,7 @@ async def test_cli_channel_stream_events_renders_stream_and_yields_events() -> N
301303
events: list[tuple[str, str, str]] = []
302304
live_handle = object()
303305
channel._renderer = SimpleNamespace(
304-
start_stream=lambda kind: events.append(("start", kind, "")) or live_handle,
306+
start_stream=lambda kind, text: events.append(("start", kind, text)) or live_handle,
305307
update_stream=lambda live, *, kind, text: events.append(("update", kind, text)),
306308
finish_stream=lambda live, *, kind, text: events.append(("finish", kind, text)),
307309
error=lambda content: events.append(("error", "error", content)),
@@ -319,8 +321,7 @@ async def source() -> asyncio.AsyncIterator[StreamEvent]:
319321
yielded = [event async for event in channel.stream_events(message, source())]
320322

321323
assert events == [
322-
("start", "command", ""),
323-
("update", "command", "hel"),
324+
("start", "command", "hel"),
324325
("update", "command", "hello"),
325326
("finish", "command", "hello"),
326327
]
@@ -337,6 +338,40 @@ def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None:
337338
assert result.suffix == ".history"
338339

339340

341+
def test_cli_renderer_stream_uses_live_with_initial_text(monkeypatch: pytest.MonkeyPatch) -> None:
342+
live_calls: list[tuple[str, object]] = []
343+
344+
class FakeLive:
345+
def __init__(self, renderable, **kwargs) -> None:
346+
live_calls.append(("init", renderable))
347+
live_calls.append(("transient", kwargs["transient"]))
348+
self.renderable = renderable
349+
350+
def start(self, *, refresh: bool = False) -> None:
351+
live_calls.append(("start_refresh", refresh))
352+
353+
def update(self, renderable, *, refresh: bool = False) -> None:
354+
live_calls.append(("update_refresh", refresh))
355+
self.renderable = renderable
356+
357+
def stop(self) -> None:
358+
live_calls.append(("stop", self.renderable))
359+
360+
printed: list[str] = []
361+
console = SimpleNamespace(print=printed.append)
362+
monkeypatch.setattr(cli_renderer, "Live", FakeLive)
363+
364+
renderer = CliRenderer(console) # type: ignore[arg-type]
365+
live = renderer.start_stream("normal", "hel")
366+
renderer.update_stream(live, kind="normal", text="hello") # type: ignore[arg-type]
367+
renderer.finish_stream(live, kind="normal", text="hello") # type: ignore[arg-type]
368+
369+
assert ("transient", False) in live_calls
370+
assert ("start_refresh", True) in live_calls
371+
assert ("update_refresh", True) in live_calls
372+
assert not printed
373+
374+
340375
def test_bub_message_filter_accepts_private_messages() -> None:
341376
message = SimpleNamespace(chat=SimpleNamespace(type="private"), text="hello")
342377

0 commit comments

Comments
 (0)