109 changes: 105 additions & 4 deletions libs/community/langchain_community/chat_models/mlx.py
@@ -1,5 +1,8 @@
"""MLX Chat Wrapper."""

import json
import re
import uuid
from typing import (
Any,
Callable,
@@ -9,6 +12,7 @@
Literal,
Optional,
Sequence,
Tuple,
Type,
Union,
)
@@ -24,7 +28,13 @@
AIMessageChunk,
BaseMessage,
HumanMessage,
InvalidToolCall,
SystemMessage,
ToolCall,
)
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import (
ChatGeneration,
@@ -41,6 +51,52 @@
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""


def _parse_react_tool_calls(
text: str,
) -> Tuple[Optional[List[ToolCall]], List[InvalidToolCall]]:
"""Extract ReAct-style tool calls from plain text output.

Args:
text: Raw model generation text.

Returns:
A tuple containing a list of parsed ``ToolCall`` objects if any were
detected, otherwise ``None``, and a list of ``InvalidToolCall`` objects
for unparseable patterns.
"""

tool_calls: list[ToolCall] = []
invalid_tool_calls: list[InvalidToolCall] = []

bracket_pattern = r"Action:\s*(?P<name>[\w.-]+)\[(?P<input>[^\]]+)\]"
separate_pattern = (
r"Action:\s*(?P<name>[^\n]+)\nAction Input:\s*(?P<input>[^\n]+)"
)

matches = list(re.finditer(bracket_pattern, text))
if not matches:
matches = list(re.finditer(separate_pattern, text))

for match in matches:
name = match.group("name").strip()
arg_text = match.group("input").strip()
try:
args = json.loads(arg_text)
if not isinstance(args, dict):
args = {"input": args}
except Exception:
args = {"input": arg_text}
tool_calls.append(ToolCall(id=str(uuid.uuid4()), name=name, args=args))

if not tool_calls and "Action:" in text:
invalid_tool_calls.append(
            InvalidToolCall(
                name=None, args=text, id=None, error="Could not parse ReAct tool call"
            )
)
return None, invalid_tool_calls

return tool_calls or None, invalid_tool_calls
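
# Illustrative behavior of the fallback parser above (example strings only, not
# taken from a real generation):
#
#   'Action: multiply[{"a": 2, "b": 3}]'
#       -> ToolCall(name="multiply", args={"a": 2, "b": 3}, id=<random uuid>)
#   "Action: search\nAction Input: weather in Paris"
#       -> ToolCall(name="search", args={"input": "weather in Paris"}, id=<random uuid>)
#
# The bracketed single-line form is tried first; the two-line "Action:" /
# "Action Input:" form is only consulted when no bracketed match is found, and
# non-JSON action inputs are wrapped as {"input": <raw text>}.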


class ChatMLX(BaseChatModel):
"""MLX chat models.

@@ -76,7 +132,8 @@ def _generate(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
tools = kwargs.pop("tools", None)
llm_input = self._to_chat_prompt(messages, tools=tools)
llm_result = self.llm._generate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
@@ -89,7 +146,8 @@ async def _agenerate(
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
llm_input = self._to_chat_prompt(messages)
tools = kwargs.pop("tools", None)
llm_input = self._to_chat_prompt(messages, tools=tools)
llm_result = await self.llm._agenerate(
prompts=[llm_input], stop=stop, run_manager=run_manager, **kwargs
)
@@ -100,8 +158,17 @@ def _to_chat_prompt(
messages: List[BaseMessage],
tokenize: bool = False,
return_tensors: Optional[str] = None,
        tools: Optional[Sequence[Dict[str, Any]]] = None,
) -> str:
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
"""Convert messages to the prompt format expected by the wrapped LLM.

Args:
messages: Chat messages to include in the prompt.
tokenize: Whether to return token IDs instead of text.
return_tensors: Framework for returned tensors when ``tokenize`` is
True.
tools: Optional tool definitions to include in the prompt.
"""
if not messages:
raise ValueError("At least one HumanMessage must be provided!")

@@ -114,6 +181,7 @@ def _to_chat_prompt(
tokenize=tokenize,
add_generation_prompt=True,
return_tensors=return_tensors,
tools=tools,
)

def _to_chatml_format(self, message: BaseMessage) -> dict:
@@ -135,8 +203,41 @@ def _to_chat_result(llm_result: LLMResult) -> ChatResult:
chat_generations = []

for g in llm_result.generations[0]:
tool_calls: list[ToolCall] = []
invalid_tool_calls: list[InvalidToolCall] = []
additional_kwargs: Dict[str, Any] = {}
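
            # Structured tool calls reported by the pipeline (OpenAI function
            # format) take precedence; the ReAct fallback on the raw text below
            # only runs when the generation info carries no tool calls at all.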

if isinstance(g.generation_info, dict):
raw_tool_calls = g.generation_info.get("tool_calls")
else:
raw_tool_calls = None

if raw_tool_calls:
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tc = parse_tool_call(raw_tool_call, return_id=True)
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
else:
if tc:
tool_calls.append(tc)
else:
react_tool_calls, invalid_reacts = _parse_react_tool_calls(g.text)
if react_tool_calls is not None:
tool_calls.extend(react_tool_calls)
invalid_tool_calls.extend(invalid_reacts)

chat_generation = ChatGeneration(
message=AIMessage(content=g.text), generation_info=g.generation_info
message=AIMessage(
content=g.text,
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
),
generation_info=g.generation_info,
)
chat_generations.append(chat_generation)

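
Taken together, the changes above pass tool definitions through to the tokenizer's
chat template and surface parsed tool calls on the returned AIMessage. A minimal
usage sketch under those assumptions (the Phi-3 model id and the bind_tools call
mirror the integration test that follows; the output shown in the comment is
illustrative):

from langchain_core.tools import tool

from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


llm = MLXPipeline.from_model_id(
    model_id="mlx-community/phi-3-mini-128k-instruct",
    pipeline_kwargs={"max_new_tokens": 150},
)
chat = ChatMLX(llm=llm).bind_tools(tools=[multiply], tool_choice=True)

ai_msg = chat.invoke("Use the multiply tool to compute 2 * 3.")
print(ai_msg.tool_calls)
# e.g. [{"name": "multiply", "args": {"a": 2, "b": 3}, "id": "..."}]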
@@ -0,0 +1,55 @@
"""Tests ChatMLX tool calling."""

from typing import Dict

import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool

from langchain_community.chat_models.mlx import ChatMLX
from langchain_community.llms.mlx_pipeline import MLXPipeline

# Use a Phi-3 model for more reliable tool-calling behavior
MODEL_ID = "mlx-community/phi-3-mini-128k-instruct"


@tool
def multiply(a: int, b: int) -> int:
"""Multiply two integers."""
return a * b


@pytest.fixture(scope="module")
def chat() -> ChatMLX:
"""Return ChatMLX bound with the multiply tool or skip if unavailable."""
try:
llm = MLXPipeline.from_model_id(
model_id=MODEL_ID, pipeline_kwargs={"max_new_tokens": 150}
)
except Exception:
pytest.skip("Required MLX model isn't available.", allow_module_level=True)
chat_model = ChatMLX(llm=llm)
return chat_model.bind_tools(tools=[multiply], tool_choice=True) # type: ignore[return-value]


def _call_tool(tool_call: Dict) -> ToolMessage:
result = multiply.invoke(tool_call["args"])
return ToolMessage(content=str(result), tool_call_id=tool_call.get("id", ""))


def test_mlx_tool_calls_soft(chat: ChatMLX) -> None:
messages = [HumanMessage(content="Use the multiply tool to compute 2 * 3.")]
ai_msg = chat.invoke(messages)
tool_msg = _call_tool(ai_msg.tool_calls[0])
final = chat.invoke(messages + [ai_msg, tool_msg])
assert "6" in final.content


def test_mlx_tool_calls_hard(chat: ChatMLX) -> None:
messages = [HumanMessage(content="Use the multiply tool to compute 2 * 3.")]
ai_msg = chat.invoke(messages)
assert isinstance(ai_msg, AIMessage)
assert ai_msg.tool_calls
tool_call = ai_msg.tool_calls[0]
assert tool_call["name"] == "multiply"
assert tool_call["args"] == {"a": 2, "b": 3}
71 changes: 71 additions & 0 deletions libs/community/tests/unit_tests/chat_models/test_mlx.py
@@ -2,6 +2,42 @@

from importlib import import_module

import pytest
from langchain_core.messages import HumanMessage

from langchain_community.chat_models.mlx import ChatMLX


class _FakeTokenizer:
def __init__(self) -> None:
self.tools = None

def apply_chat_template(
self,
messages,
tokenize=False,
add_generation_prompt=True,
return_tensors=None,
tools=None,
) -> str:
self.tools = tools
return "prompt"


class _FakeLLM:
def __init__(self) -> None:
self.tokenizer = _FakeTokenizer()

def _generate(self, prompts, stop=None, run_manager=None, **kwargs):
class _Res:
generations = [[type("G", (), {"text": "", "generation_info": {}})]]
llm_output = {}

return _Res()

async def _agenerate(self, prompts, stop=None, run_manager=None, **kwargs):
return self._generate(prompts, stop=stop, run_manager=run_manager, **kwargs)
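
# The fakes above are plumbing-only stand-ins: _FakeTokenizer records whatever
# ``tools`` payload reaches ``apply_chat_template``, and _FakeLLM returns an
# empty generation, so the tests below verify argument passing rather than
# real MLX output.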


def test_import_class() -> None:
"""Test that the class can be imported."""
@@ -10,3 +46,38 @@ def test_import_class() -> None:

module = import_module(module_name)
assert hasattr(module, class_name)


def test_generate_passes_tools_to_tokenizer() -> None:
llm = _FakeLLM()
chat = ChatMLX(llm=llm)
tools = [
{
"type": "function",
"function": {
"name": "foo",
"description": "",
"parameters": {"type": "object", "properties": {}},
},
}
]
chat._generate([HumanMessage(content="hi")], tools=tools)
assert llm.tokenizer.tools == tools


@pytest.mark.asyncio
async def test_agenerate_passes_tools_to_tokenizer() -> None:
llm = _FakeLLM()
chat = ChatMLX(llm=llm)
tools = [
{
"type": "function",
"function": {
"name": "foo",
"description": "",
"parameters": {"type": "object", "properties": {}},
},
}
]
await chat._agenerate([HumanMessage(content="hi")], tools=tools)
assert llm.tokenizer.tools == tools
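

# A direct unit test for the ReAct fallback parser is not part of the diff above;
# a sketch of one (written against ``_parse_react_tool_calls`` as shown in the
# mlx.py changes; the example strings are illustrative) could look like this:
def test_parse_react_tool_calls_bracket_and_separate_forms() -> None:
    from langchain_community.chat_models.mlx import _parse_react_tool_calls

    # Single-line bracketed form with JSON arguments.
    calls, invalid = _parse_react_tool_calls('Action: multiply[{"a": 2, "b": 3}]')
    assert invalid == []
    assert calls is not None
    assert calls[0]["name"] == "multiply"
    assert calls[0]["args"] == {"a": 2, "b": 3}

    # Two-line form with a non-JSON input falls back to {"input": ...}.
    calls, invalid = _parse_react_tool_calls(
        "Action: search\nAction Input: weather in Paris"
    )
    assert calls is not None
    assert calls[0]["name"] == "search"
    assert calls[0]["args"] == {"input": "weather in Paris"}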