diff --git a/libs/community/langchain_community/chat_models/mlx.py b/libs/community/langchain_community/chat_models/mlx.py
index 44e76087..18149421 100644
--- a/libs/community/langchain_community/chat_models/mlx.py
+++ b/libs/community/langchain_community/chat_models/mlx.py
@@ -1,5 +1,8 @@
 """MLX Chat Wrapper."""
 
+import json
+import re
+import uuid
 from typing import (
     Any,
     Callable,
@@ -9,6 +12,7 @@
     Literal,
     Optional,
     Sequence,
+    Tuple,
     Type,
     Union,
 )
@@ -24,7 +28,13 @@
     AIMessageChunk,
     BaseMessage,
     HumanMessage,
+    InvalidToolCall,
     SystemMessage,
+    ToolCall,
+)
+from langchain_core.output_parsers.openai_tools import (
+    make_invalid_tool_call,
+    parse_tool_call,
 )
 from langchain_core.outputs import (
     ChatGeneration,
@@ -41,6 +51,52 @@
 DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
 
 
+def _parse_react_tool_calls(
+    text: str,
+) -> Tuple[list[ToolCall] | None, list[InvalidToolCall]]:
+    """Extract ReAct-style tool calls from plain text output.
+
+    Args:
+        text: Raw model generation text.
+
+    Returns:
+        A tuple containing a list of parsed ``ToolCall`` objects if any were
+        detected, otherwise ``None``, and a list of ``InvalidToolCall`` objects
+        for unparseable patterns.
+    """
+    tool_calls: list[ToolCall] = []
+    invalid_tool_calls: list[InvalidToolCall] = []
+
+    # "Action: tool[input]" single-line form, tried first; falls back to the
+    # two-line "Action: tool\nAction Input: input" form if nothing matches.
+    bracket_pattern = r"Action:\s*(?P<name>[\w.-]+)\[(?P<input>[^\]]+)\]"
+    separate_pattern = (
+        r"Action:\s*(?P<name>[^\n]+)\nAction Input:\s*(?P<input>[^\n]+)"
+    )
+
+    matches = list(re.finditer(bracket_pattern, text))
+    if not matches:
+        matches = list(re.finditer(separate_pattern, text))
+
+    for match in matches:
+        name = match.group("name").strip()
+        arg_text = match.group("input").strip()
+        try:
+            args = json.loads(arg_text)
+            if not isinstance(args, dict):
+                args = {"input": args}
+        except Exception:
+            args = {"input": arg_text}
+        tool_calls.append(ToolCall(id=str(uuid.uuid4()), name=name, args=args))
+
+    if not tool_calls and "Action:" in text:
+        invalid_tool_calls.append(
+            make_invalid_tool_call(text, "Could not parse ReAct tool call")
+        )
+        return None, invalid_tool_calls
+
+    return tool_calls or None, invalid_tool_calls
+
+
 class ChatMLX(BaseChatModel):
     """MLX chat models.
 
@@ -135,8 +191,41 @@
     def _to_chat_result(llm_result: LLMResult) -> ChatResult:
         chat_generations = []
         for g in llm_result.generations[0]:
+            tool_calls: list[ToolCall] = []
+            invalid_tool_calls: list[InvalidToolCall] = []
+            additional_kwargs: Dict[str, Any] = {}
+
+            if isinstance(g.generation_info, dict):
+                raw_tool_calls = g.generation_info.get("tool_calls")
+            else:
+                raw_tool_calls = None
+
+            if raw_tool_calls:
+                additional_kwargs["tool_calls"] = raw_tool_calls
+                for raw_tool_call in raw_tool_calls:
+                    try:
+                        tc = parse_tool_call(raw_tool_call, return_id=True)
+                    except Exception as e:
+                        invalid_tool_calls.append(
+                            make_invalid_tool_call(raw_tool_call, str(e))
+                        )
+                    else:
+                        if tc:
+                            tool_calls.append(tc)
+            else:
+                react_tool_calls, invalid_reacts = _parse_react_tool_calls(g.text)
+                if react_tool_calls is not None:
+                    tool_calls.extend(react_tool_calls)
+                invalid_tool_calls.extend(invalid_reacts)
+
             chat_generation = ChatGeneration(
-                message=AIMessage(content=g.text), generation_info=g.generation_info
+                message=AIMessage(
+                    content=g.text,
+                    additional_kwargs=additional_kwargs,
+                    tool_calls=tool_calls,
+                    invalid_tool_calls=invalid_tool_calls,
+                ),
+                generation_info=g.generation_info,
             )
             chat_generations.append(chat_generation)