@@ -14,21 +14,29 @@

from __future__ import annotations

from collections.abc import AsyncIterable
from typing import Any, Generic

from langchain_core.messages import AIMessage, BaseMessageChunk, HumanMessage, SystemMessage
from langchain_core.messages import (
AIMessage,
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
)
from langchain_core.runnables import RunnableConfig
from langgraph.pregel.protocol import PregelProtocol
from langgraph.types import StreamMode
from langgraph.typing import ContextT

from livekit.agents import llm, utils
from livekit.agents.llm import ToolChoice
from livekit.agents.llm import ChatChunk, ToolChoice
from livekit.agents.llm.chat_context import ChatContext
from livekit.agents.types import (
DEFAULT_API_CONNECT_OPTIONS,
NOT_GIVEN,
APIConnectOptions,
FlushSentinel,
NotGivenOr,
)

@@ -116,9 +124,16 @@ def __init__(
self._subgraphs = subgraphs
self._stream_mode = stream_mode

async def _metrics_monitor_task(self, event_aiter: AsyncIterable[ChatChunk]) -> None:
async def _filtered(aiter: AsyncIterable) -> AsyncIterable[ChatChunk]:
async for ev in aiter:
if isinstance(ev, ChatChunk):
yield ev

await super()._metrics_monitor_task(_filtered(event_aiter))

async def _run(self) -> None:
state = self._chat_ctx_to_state()
is_multi_mode = isinstance(self._stream_mode, list)

# Some LangGraph versions don't accept the `subgraphs` or `context` kwargs yet.
# Try with them first; fall back gracefully if unsupported.
@@ -137,41 +152,65 @@ async def _run(self) -> None:
stream_mode=self._stream_mode,
)

multi_mode = isinstance(self._stream_mode, list)

async for item in aiter:
# Multi-mode: item is (mode, data) tuple wrapper
if is_multi_mode and isinstance(item, tuple) and len(item) == 2:
# Strip subgraph namespace prefix when present.
# With subgraphs=True, items are prefixed with a namespace tuple:
# single-mode: (ns, data) -> data
# multi-mode: (ns, mode, data) -> (mode, data)
if self._subgraphs and isinstance(item, tuple) and isinstance(item[0], tuple):
item = item[1:]
if len(item) == 1:
item = item[0]

# Extract mode tag in multi-mode; infer in single-mode.
if multi_mode and isinstance(item, tuple) and len(item) == 2:
mode, data = item
if isinstance(mode, str):
if mode == "custom":
# data = payload (str, dict, object)
chat_chunk = _to_chat_chunk(data)
if chat_chunk:
self._event_ch.send_nowait(chat_chunk)
continue
elif mode == "messages":
# data = (token, metadata)
token_like = _extract_message_chunk(data)
if token_like is None:
continue
chat_chunk = _to_chat_chunk(token_like)
if chat_chunk:
self._event_ch.send_nowait(chat_chunk)
continue

# Single-mode: item is data directly (no tuple wrapper)
if self._stream_mode == "custom":
# item = payload (str, dict, object)
chat_chunk = _to_chat_chunk(item)
if chat_chunk:
self._event_ch.send_nowait(chat_chunk)
elif self._stream_mode == "messages":
# item = (token, metadata)
token_like = _extract_message_chunk(item)
if token_like is None:
continue
chat_chunk = _to_chat_chunk(token_like)
if chat_chunk:
self._event_ch.send_nowait(chat_chunk)
else:
mode = self._stream_mode if isinstance(self._stream_mode, str) else None
data = item

if mode == "messages":
self._send_message(data)
elif mode == "custom":
self._send_custom(data)

def _send_custom(self, data: Any) -> None:
"""Handle custom stream mode items from StreamWriter.

Custom mode emits raw values written by StreamWriter nodes — strings,
dicts, BaseMessages, or arbitrary objects (e.g. FlushSentinel).
FlushSentinel is forwarded directly to trigger immediate TTS playback.
We extract text content where possible; non-text values are silently
skipped since ChatChunk only carries text.
"""
if isinstance(data, FlushSentinel):
self._event_ch.send_nowait(data) # type: ignore[arg-type]

Review comment (P2): Preserve the ChatChunk contract for emitted stream events

Sending FlushSentinel directly into _event_ch makes this LLMStream yield non-ChatChunk objects, but shared consumers such as LLMStream.to_str_iterable()/collect() (livekit-agents/livekit/agents/llm/llm.py) and FallbackLLMStream._run() (livekit-agents/livekit/agents/llm/fallback_adapter.py) unconditionally access .delta, so a LangGraph custom stream that emits FlushSentinel will now raise AttributeError on those paths. This is a regression: previously, every emitted event was a ChatChunk instance.

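One way to restore the contract is a consumer-side guard. This is a sketch only, assuming ChatChunk exposes delta.content (as in current livekit-agents) and that FlushSentinel remains importable from livekit.agents.types; the actual fix is left to the PR author:

# Hypothetical guard inside a text-only consumer such as to_str_iterable():
async for chunk in stream:
    if not isinstance(chunk, llm.ChatChunk):
        continue  # skip control objects like FlushSentinel
    if chunk.delta and chunk.delta.content:
        yield chunk.delta.content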

return

content = _extract_custom_content(data)
if content:
chunk = _to_chat_chunk(content)
if chunk:
self._event_ch.send_nowait(chunk)

def _send_message(self, data: Any) -> None:
"""Handle messages stream mode items.

Messages mode yields (message, metadata) tuples where the message is
typically a BaseMessageChunk from LLM streaming.
Also tolerates bare token-like values for robustness.
"""
if isinstance(data, (BaseMessageChunk, str)):
token = data
elif isinstance(data, tuple) and len(data) == 2:
token, _meta = data
else:
return
chunk = _to_chat_chunk(token)
if chunk:
self._event_ch.send_nowait(chunk)

def _chat_ctx_to_state(self) -> dict[str, Any]:
"""Convert chat context to langgraph input"""
@@ -190,44 +229,20 @@ def _chat_ctx_to_state(self) -> dict[str, Any]:
return {"messages": messages}


def _extract_message_chunk(item: Any) -> BaseMessageChunk | str | None:
"""
Normalize outputs from graph.astream(..., stream_mode='messages', [subgraphs]).

Expected shapes:
- (token, meta)
- (namespace, (token, meta)) # with subgraphs=True
- (mode, (token, meta)) # future-friendly
- (namespace, mode, (token, meta)) # future-friendly
Also tolerate direct token-like values for robustness.
"""
# Already a token-like thing?
if isinstance(item, (BaseMessageChunk, str)):
return item

if not isinstance(item, tuple):
return None

# token is usually BaseMessageChunk, but could be a str
# (token, meta)
if len(item) == 2 and not isinstance(item[1], tuple):
token, _meta = item
return token # type: ignore

# (namespace, (token, meta)) OR (mode, (token, meta))
if len(item) == 2 and isinstance(item[1], tuple):
inner = item[1]
if len(inner) == 2:
token, _meta = inner
return token # type: ignore

# (namespace, mode, (token, meta))
if len(item) == 3 and isinstance(item[2], tuple):
inner = item[2]
if len(inner) == 2:
token, _meta = inner
return token # type: ignore
def _extract_custom_content(value: Any) -> str | None:
"""Extract text from a custom stream value.

StreamWriter can emit arbitrary types. We recognize common text-carrying
shapes (str, BaseMessage, dict with "content" key) and return the text.
Returns None for non-text values (e.g. FlushSentinel, control objects)
so the caller can skip them — ChatChunk only carries text content.
"""
if isinstance(value, str):
return value
if isinstance(value, BaseMessage) and isinstance(value.content, str):
return value.content
if isinstance(value, dict) and isinstance(value.get("content"), str):
return value["content"] # type: ignore[no-any-return]
return None
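# Illustrative behavior of the helper above (a sketch; values hypothetical):
#   _extract_custom_content("hello")                  -> "hello"
#   _extract_custom_content(AIMessage(content="hi"))  -> "hi"
#   _extract_custom_content({"content": "hey"})       -> "hey"
#   _extract_custom_content(FlushSentinel())          -> None (caller skips it)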


@@ -241,14 +256,6 @@ def _to_chat_chunk(msg: str | Any) -> llm.ChatChunk | None:
content = msg.text
if getattr(msg, "id", None):
message_id = msg.id # type: ignore
elif isinstance(msg, dict):
raw = msg.get("content")
if isinstance(raw, str):
content = raw
elif hasattr(msg, "content"):
raw = msg.content
if isinstance(raw, str):
content = raw

if not content:
return None
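For context, a producer-side sketch of the custom stream this diff consumes. It uses LangGraph's standard StreamWriter injection and stream_mode="custom"; the adapter class and its LiveKit wiring are not shown in this diff, so only the graph side is illustrated here:

from typing_extensions import TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import StreamWriter


class State(TypedDict):
    messages: list


def speak(state: State, writer: StreamWriter) -> State:
    # Values written here surface on stream_mode="custom"; the adapter's
    # _send_custom() turns str/dict/BaseMessage payloads into ChatChunks
    # and forwards FlushSentinel control objects as-is.
    writer("Hello ")
    writer({"content": "world"})
    return state


builder = StateGraph(State)
builder.add_node("speak", speak)
builder.add_edge(START, "speak")
builder.add_edge("speak", END)
graph = builder.compile()

# Multi-mode streaming then yields (mode, data) tuples, e.g.:
#   async for mode, data in graph.astream({"messages": []}, stream_mode=["custom", "messages"]):
#       ...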