diff --git a/libs/community/langchain_community/chat_models/mlx.py b/libs/community/langchain_community/chat_models/mlx.py
index 720dbd86..fdee8d65 100644
--- a/libs/community/langchain_community/chat_models/mlx.py
+++ b/libs/community/langchain_community/chat_models/mlx.py
@@ -1,6 +1,7 @@
 """MLX Chat Wrapper."""
 
 import json
+import logging
 import re
 import uuid
 from typing import (
@@ -48,6 +49,8 @@
 
 from langchain_community.llms.mlx_pipeline import MLXPipeline
 
+logger = logging.getLogger(__name__)
+
 DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant."""
 
 
@@ -125,6 +128,25 @@ def __init__(self, **kwargs: Any):
         super().__init__(**kwargs)
         self.tokenizer = self.llm.tokenizer
 
+    def _parse_tool_args(self, arg_text: str) -> Dict[str, Any]:
+        """Parse the arguments for a tool call.
+
+        Args:
+            arg_text: JSON string representation of the tool arguments.
+
+        Returns:
+            Parsed arguments dictionary. If parsing fails, returns a dict with
+            the original text under the ``input`` key.
+        """
+        try:
+            args = json.loads(arg_text)
+        except json.JSONDecodeError:
+            args = {"input": arg_text}
+        except Exception as e:  # pragma: no cover - defensive
+            logger.warning("Unexpected error during tool argument parsing: %s", e)
+            args = {"input": arg_text}
+        return args
+
     def _generate(
         self,
         messages: List[BaseMessage],