From 6607ee9799127ac1672ebc912a26836d259f9f40 Mon Sep 17 00:00:00 2001
From: diego-coder <39010417+diego-coder@users.noreply.github.com>
Date: Thu, 11 Sep 2025 17:17:48 -0700
Subject: [PATCH 1/3] feat(mlx): pass tools through to the chat template

---
 .../langchain_community/chat_models/mlx.py    | 18 ++++-
 .../tests/unit_tests/chat_models/test_mlx.py  | 71 +++++++++++++++++++
 2 files changed, 86 insertions(+), 3 deletions(-)

diff --git a/libs/community/langchain_community/chat_models/mlx.py b/libs/community/langchain_community/chat_models/mlx.py
index 44e76087..5cbec756 100644
--- a/libs/community/langchain_community/chat_models/mlx.py
+++ b/libs/community/langchain_community/chat_models/mlx.py
@@ -76,7 +76,8 @@ def _generate(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        llm_input = self._to_chat_prompt(messages)
+        tools = kwargs.pop("tools", None)
+        llm_input = self._to_chat_prompt(messages, tools=tools)
         llm_result = self.llm._generate(
             prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
         )
@@ -89,7 +90,8 @@ async def _agenerate(
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        llm_input = self._to_chat_prompt(messages)
+        tools = kwargs.pop("tools", None)
+        llm_input = self._to_chat_prompt(messages, tools=tools)
         llm_result = await self.llm._agenerate(
             prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
         )
@@ -100,8 +102,17 @@ def _to_chat_prompt(
         messages: List[BaseMessage],
         tokenize: bool = False,
         return_tensors: Optional[str] = None,
+        tools: Optional[Sequence[Dict[str, Any]]] = None,
     ) -> str:
-        """Convert a list of messages into a prompt format expected by wrapped LLM."""
+        """Convert messages to the prompt format expected by the wrapped LLM.
+
+        Args:
+            messages: Chat messages to include in the prompt.
+            tokenize: Whether to return token IDs instead of text.
+            return_tensors: Framework for returned tensors when ``tokenize`` is
+                True.
+            tools: Optional tool definitions to include in the prompt.
+        """
         if not messages:
             raise ValueError("At least one HumanMessage must be provided!")

@@ -114,6 +125,7 @@ def _to_chat_prompt(
             tokenize=tokenize,
             add_generation_prompt=True,
             return_tensors=return_tensors,
+            tools=tools,
         )

     def _to_chatml_format(self, message: BaseMessage) -> dict:
diff --git a/libs/community/tests/unit_tests/chat_models/test_mlx.py b/libs/community/tests/unit_tests/chat_models/test_mlx.py
index 6f27116d..ee9a1510 100644
--- a/libs/community/tests/unit_tests/chat_models/test_mlx.py
+++ b/libs/community/tests/unit_tests/chat_models/test_mlx.py
@@ -2,6 +2,42 @@

 from importlib import import_module

+import pytest
+from langchain_core.messages import HumanMessage
+
+from langchain_community.chat_models.mlx import ChatMLX
+
+
+class _FakeTokenizer:
+    def __init__(self) -> None:
+        self.tools = None
+
+    def apply_chat_template(
+        self,
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+        return_tensors=None,
+        tools=None,
+    ) -> str:
+        self.tools = tools
+        return "prompt"
+
+
+class _FakeLLM:
+    def __init__(self) -> None:
+        self.tokenizer = _FakeTokenizer()
+
+    def _generate(self, prompts, stop=None, run_manager=None, **kwargs):
+        class _Res:
+            generations = [[type("G", (), {"text": "", "generation_info": {}})]]
+            llm_output = {}
+
+        return _Res()
+
+    async def _agenerate(self, prompts, stop=None, run_manager=None, **kwargs):
+        return self._generate(prompts, stop=stop, run_manager=run_manager, **kwargs)
+

 def test_import_class() -> None:
     """Test that the class can be imported."""
@@ -10,3 +46,38 @@ def test_import_class() -> None:

     module = import_module(module_name)
     assert hasattr(module, class_name)
+
+
+def test_generate_passes_tools_to_tokenizer() -> None:
+    llm = _FakeLLM()
+    chat = ChatMLX.model_construct(llm=llm, tokenizer=llm.tokenizer)
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "foo",
+                "description": "",
+                "parameters": {"type": "object", "properties": {}},
+            },
+        }
+    ]
+    chat._generate([HumanMessage(content="hi")], tools=tools)
+    assert llm.tokenizer.tools == tools
+
+
+@pytest.mark.asyncio
+async def test_agenerate_passes_tools_to_tokenizer() -> None:
+    llm = _FakeLLM()
+    chat = ChatMLX.model_construct(llm=llm, tokenizer=llm.tokenizer)
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "foo",
+                "description": "",
+                "parameters": {"type": "object", "properties": {}},
+            },
+        }
+    ]
+    await chat._agenerate([HumanMessage(content="hi")], tools=tools)
+    assert llm.tokenizer.tools == tools

From 264a02e558b6f4eb4473456776e33ad4958e1274 Mon Sep 17 00:00:00 2001
From: diego-coder <39010417+diego-coder@users.noreply.github.com>
Date: Thu, 11 Sep 2025 18:26:28 -0700
Subject: [PATCH 2/3] Refine MLX ReAct tool-call parsing

---
 .../langchain_community/chat_models/mlx.py | 91 ++++++++++++++++++-
 1 file changed, 90 insertions(+), 1 deletion(-)

diff --git a/libs/community/langchain_community/chat_models/mlx.py b/libs/community/langchain_community/chat_models/mlx.py
index 44e76087..18149421 100644
--- a/libs/community/langchain_community/chat_models/mlx.py
+++ b/libs/community/langchain_community/chat_models/mlx.py
@@ -1,5 +1,8 @@
 """MLX Chat Wrapper."""

+import json
+import re
+import uuid
 from typing import (
     Any,
     Callable,
@@ -9,6 +12,7 @@
     Literal,
     Optional,
     Sequence,
+    Tuple,
     Type,
     Union,
 )
@@ -24,7 +28,13 @@
     AIMessageChunk,
     BaseMessage,
     HumanMessage,
+    InvalidToolCall,
     SystemMessage,
+    ToolCall,
+)
+from langchain_core.output_parsers.openai_tools import (
+    make_invalid_tool_call,
+    parse_tool_call,
 )
 from langchain_core.outputs import (
     ChatGeneration,
@@ -41,6 +51,52 @@
 DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""


+def _parse_react_tool_calls(
+    text: str,
+) -> Tuple[Optional[List[ToolCall]], List[InvalidToolCall]]:
+    """Extract ReAct-style tool calls from plain text output.
+
+    Args:
+        text: Raw model generation text.
+
+    Returns:
+        A tuple containing a list of parsed ``ToolCall`` objects if any were
+        detected, otherwise ``None``, and a list of ``InvalidToolCall`` objects
+        for unparseable patterns.
+    """
+
+    tool_calls: list[ToolCall] = []
+    invalid_tool_calls: list[InvalidToolCall] = []
+
+    bracket_pattern = r"Action:\s*(?P<name>[\w.-]+)\[(?P<input>[^\]]+)\]"
+    separate_pattern = (
+        r"Action:\s*(?P<name>[^\n]+)\nAction Input:\s*(?P<input>[^\n]+)"
+    )
+
+    matches = list(re.finditer(bracket_pattern, text))
+    if not matches:
+        matches = list(re.finditer(separate_pattern, text))
+
+    for match in matches:
+        name = match.group("name").strip()
+        arg_text = match.group("input").strip()
+        try:
+            args = json.loads(arg_text)
+            if not isinstance(args, dict):
+                args = {"input": args}
+        except Exception:
+            args = {"input": arg_text}
+        tool_calls.append(ToolCall(id=str(uuid.uuid4()), name=name, args=args))
+
+    if not tool_calls and "Action:" in text:
+        invalid_tool_calls.append(
+            InvalidToolCall(name=None, args=text, id=None, error="Unparsable Action")
+        )
+        return None, invalid_tool_calls
+
+    return tool_calls or None, invalid_tool_calls
+
+
 class ChatMLX(BaseChatModel):
     """MLX chat models.

@@ -135,8 +191,41 @@ def _to_chat_result(llm_result: LLMResult) -> ChatResult:
         chat_generations = []

         for g in llm_result.generations[0]:
+            tool_calls: list[ToolCall] = []
+            invalid_tool_calls: list[InvalidToolCall] = []
+            additional_kwargs: Dict[str, Any] = {}
+
+            if isinstance(g.generation_info, dict):
+                raw_tool_calls = g.generation_info.get("tool_calls")
+            else:
+                raw_tool_calls = None
+
+            if raw_tool_calls:
+                additional_kwargs["tool_calls"] = raw_tool_calls
+                for raw_tool_call in raw_tool_calls:
+                    try:
+                        tc = parse_tool_call(raw_tool_call, return_id=True)
+                    except Exception as e:
+                        invalid_tool_calls.append(
+                            make_invalid_tool_call(raw_tool_call, str(e))
+                        )
+                    else:
+                        if tc:
+                            tool_calls.append(tc)
+            else:
+                react_tool_calls, invalid_reacts = _parse_react_tool_calls(g.text)
+                if react_tool_calls is not None:
+                    tool_calls.extend(react_tool_calls)
+                invalid_tool_calls.extend(invalid_reacts)
+
             chat_generation = ChatGeneration(
-                message=AIMessage(content=g.text), generation_info=g.generation_info
+                message=AIMessage(
+                    content=g.text,
+                    additional_kwargs=additional_kwargs,
+                    tool_calls=tool_calls,
+                    invalid_tool_calls=invalid_tool_calls,
+                ),
+                generation_info=g.generation_info,
             )
             chat_generations.append(chat_generation)


From 87e5484d0a98e234b9d06168b1960f8aa77a9108 Mon Sep 17 00:00:00 2001
From: diego-coder <39010417+diego-coder@users.noreply.github.com>
Date: Thu, 11 Sep 2025 18:58:58 -0700
Subject: [PATCH 3/3] test: use phi-3 model for ChatMLX tools

---
 .../chat_models/test_mlx_tool_calls.py        | 55 +++++++++++++++++++
 1 file changed, 55 insertions(+)
 create mode 100644 libs/community/tests/integration_tests/chat_models/test_mlx_tool_calls.py

diff --git a/libs/community/tests/integration_tests/chat_models/test_mlx_tool_calls.py b/libs/community/tests/integration_tests/chat_models/test_mlx_tool_calls.py
new file mode 100644
index 00000000..396c0e1a
--- /dev/null
+++ b/libs/community/tests/integration_tests/chat_models/test_mlx_tool_calls.py
@@ -0,0 +1,55 @@
+"""Tests ChatMLX tool calling."""
+
+from typing import Dict
+
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
+from langchain_core.tools import tool
+
+from langchain_community.chat_models.mlx import ChatMLX
+from langchain_community.llms.mlx_pipeline import MLXPipeline
+
+# Use a Phi-3 model for more reliable tool-calling behavior
+MODEL_ID = "mlx-community/phi-3-mini-128k-instruct"
+
+
+@tool
+def multiply(a: int, b: int) -> int:
+    """Multiply two integers."""
+    return a * b
+
+
+@pytest.fixture(scope="module")
+def chat() -> ChatMLX:
+    """Return ChatMLX bound with the multiply tool or skip if unavailable."""
+    try:
+        llm = MLXPipeline.from_model_id(
+            model_id=MODEL_ID, pipeline_kwargs={"max_new_tokens": 150}
+        )
+    except Exception:
+        pytest.skip("Required MLX model isn't available.", allow_module_level=True)
+    chat_model = ChatMLX(llm=llm)
+    return chat_model.bind_tools(tools=[multiply], tool_choice=True)  # type: ignore[return-value]
+
+
+def _call_tool(tool_call: Dict) -> ToolMessage:
+    result = multiply.invoke(tool_call["args"])
+    return ToolMessage(content=str(result), tool_call_id=tool_call.get("id", ""))
+
+
+def test_mlx_tool_calls_soft(chat: ChatMLX) -> None:
+    messages = [HumanMessage(content="Use the multiply tool to compute 2 * 3.")]
+    ai_msg = chat.invoke(messages)
+    tool_msg = _call_tool(ai_msg.tool_calls[0])
+    final = chat.invoke(messages + [ai_msg, tool_msg])
+    assert "6" in final.content
+
+
+def test_mlx_tool_calls_hard(chat: ChatMLX) -> None:
+    messages = [HumanMessage(content="Use the multiply tool to compute 2 * 3.")]
+    ai_msg = chat.invoke(messages)
+    assert isinstance(ai_msg, AIMessage)
+    assert ai_msg.tool_calls
+    tool_call = ai_msg.tool_calls[0]
+    assert tool_call["name"] == "multiply"
+    assert tool_call["args"] == {"a": 2, "b": 3}