Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 90 additions & 1 deletion libs/community/langchain_community/chat_models/mlx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""MLX Chat Wrapper."""

import json
import re
import uuid
from typing import (
Any,
Callable,
Expand All @@ -9,6 +12,7 @@
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
)
Expand All @@ -24,7 +28,13 @@
AIMessageChunk,
BaseMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
)
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import (
ChatGeneration,
Expand All @@ -41,6 +51,52 @@
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""


def _parse_react_tool_calls(
    text: str,
) -> Tuple[list[ToolCall] | None, list[InvalidToolCall]]:
    """Extract ReAct-style tool calls from plain text output.

    Two layouts are recognised, tried in this order:

    1. Bracketed, single line: ``Action: tool_name[argument text]``
    2. Two-line: ``Action: tool_name`` followed by ``Action Input: argument text``

    The captured argument text is parsed as JSON when possible; a JSON object
    is used as the call's ``args`` directly, while any other JSON value or
    unparseable text is wrapped as ``{"input": ...}``.

    Args:
        text: Raw model generation text.

    Returns:
        A tuple ``(tool_calls, invalid_tool_calls)``. ``tool_calls`` is a
        list of parsed ``ToolCall`` objects if any were detected, otherwise
        ``None``. ``invalid_tool_calls`` holds an ``InvalidToolCall`` when an
        ``Action:`` marker is present but neither pattern matched.
    """
    tool_calls: list[ToolCall] = []
    invalid_tool_calls: list[InvalidToolCall] = []

    # NOTE: the bracketed form cannot represent an argument containing "]";
    # such output falls through to the two-line pattern or is reported invalid.
    bracket_pattern = r"Action:\s*(?P<name>[\w.-]+)\[(?P<input>[^\]]+)\]"
    separate_pattern = (
        r"Action:\s*(?P<name>[^\n]+)\nAction Input:\s*(?P<input>[^\n]+)"
    )

    matches = list(re.finditer(bracket_pattern, text))
    if not matches:
        matches = list(re.finditer(separate_pattern, text))

    for match in matches:
        name = match.group("name").strip()
        arg_text = match.group("input").strip()
        try:
            args = json.loads(arg_text)
            if not isinstance(args, dict):
                args = {"input": args}
        except Exception:
            # Non-JSON argument text is passed through verbatim.
            args = {"input": arg_text}
        tool_calls.append(ToolCall(id=str(uuid.uuid4()), name=name, args=args))

    if not tool_calls and "Action:" in text:
        # BUG FIX: the previous code called make_invalid_tool_call(text, ...),
        # but that helper expects an OpenAI-style dict (it indexes
        # raw_tool_call["function"]["name"]), so passing the raw string raised
        # TypeError instead of recording the failure. Build the
        # InvalidToolCall directly with the raw text as its args.
        invalid_tool_calls.append(
            InvalidToolCall(
                name=None,
                args=text,
                id=str(uuid.uuid4()),
                error="Could not parse ReAct tool call",
            )
        )
        return None, invalid_tool_calls

    return tool_calls or None, invalid_tool_calls


class ChatMLX(BaseChatModel):
"""MLX chat models.

Expand Down Expand Up @@ -135,8 +191,41 @@ def _to_chat_result(llm_result: LLMResult) -> ChatResult:
chat_generations = []

for g in llm_result.generations[0]:
tool_calls: list[ToolCall] = []
invalid_tool_calls: list[InvalidToolCall] = []
additional_kwargs: Dict[str, Any] = {}

if isinstance(g.generation_info, dict):
raw_tool_calls = g.generation_info.get("tool_calls")
else:
raw_tool_calls = None

if raw_tool_calls:
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tc = parse_tool_call(raw_tool_call, return_id=True)
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
else:
if tc:
tool_calls.append(tc)
else:
react_tool_calls, invalid_reacts = _parse_react_tool_calls(g.text)
if react_tool_calls is not None:
tool_calls.extend(react_tool_calls)
invalid_tool_calls.extend(invalid_reacts)

chat_generation = ChatGeneration(
message=AIMessage(content=g.text), generation_info=g.generation_info
message=AIMessage(
content=g.text,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
),
generation_info=g.generation_info,
)
chat_generations.append(chat_generation)

Expand Down