diff --git a/docs/examples/tools/interpreter_example.py b/docs/examples/tools/interpreter_example.py new file mode 100644 index 00000000..ccde0ada --- /dev/null +++ b/docs/examples/tools/interpreter_example.py @@ -0,0 +1,77 @@ +from mellea.stdlib.tools import code_interpreter, local_code_interpreter +from mellea import start_session, MelleaSession +from mellea.backends.types import ModelOption +from mellea.backends.model_ids import OPENAI_GPT_OSS_20B +from mellea.stdlib.reqlib.tools import uses_tool, tool_arg_validator + + +def example_1(m: MelleaSession): + # First, let's see how the code interpreter function works without an LLM in the loop: + result = code_interpreter("print(1+1)") + print(result) + + +# Now let's ask the LLM to make a plot. + + +def example_2(m: MelleaSession): + plot_output = m.instruct( + description="Make a plot of y=x^2", + model_options={ModelOption.TOOLS: [local_code_interpreter]}, + ) + print(plot_output) + + +# Notice that the model did not actually generate a plot. Let's force tool use: + + +def example_3(m: MelleaSession): + plot_output = m.instruct( + description="Use the code interpreter tool to make a plot of y=x^2.", + requirements=[uses_tool(local_code_interpreter)], + model_options={ModelOption.TOOLS: [local_code_interpreter]}, + tool_calls=True, + ) + + code = plot_output.tool_calls["local_code_interpreter"].args["code"] + print(f"Going to execute the following code:\n```python\n{code}\n```") + + # Call the tool. + exec_result = plot_output.tool_calls["local_code_interpreter"].call_func() + + print(exec_result) + + +# Notice that the model did make a plot, but it just "showed" the plot. +# We would actually like this to be written out to a file. + + +def example_4(m: MelleaSession): + plot_output = m.instruct( + description="Use the code interpreter tool to make a plot of y=x^2.", + requirements=[ + uses_tool(local_code_interpreter), + tool_arg_validator( + "The plot should be written to /tmp/output.png", + tool_name=local_code_interpreter, + arg_name="code", + validation_fn=lambda code_snippet: "/tmp/output.png" in code_snippet + and "plt.show()" not in code_snippet, + ), + ], + model_options={ModelOption.TOOLS: [local_code_interpreter]}, + tool_calls=True, + ) + + code = plot_output.tool_calls["local_code_interpreter"].args["code"] + print(f"Going to execute the following code:\n```python\n{code}\n```") + + # Call the tool. + exec_result = plot_output.tool_calls["local_code_interpreter"].call_func() + + print(exec_result) + + +# m = start_session(backend_name="ollama", model_id=OPENAI_GPT_OSS_20B) +m = start_session() +example_4(m) diff --git a/mellea/stdlib/reqlib/python.py b/mellea/stdlib/reqlib/python.py index 2bcab2c5..1c83a330 100644 --- a/mellea/stdlib/reqlib/python.py +++ b/mellea/stdlib/reqlib/python.py @@ -12,204 +12,15 @@ from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import Context from mellea.stdlib.requirement import Requirement, ValidationResult +from mellea.stdlib.tools.interpreter import ( + ExecutionEnvironment, + LLMSandboxEnvironment, + StaticAnalysisEnvironment, + UnsafeEnvironment, +) logger = FancyLogger.get_logger() -# region execution backends - - -@dataclass -class ExecutionResult: - """Result of code execution.""" - - success: bool - message: str | None = None - error: str | None = None - skipped: bool = False - - -class ExecutionEnvironment(ABC): - """Abstract environment for executing Python code.""" - - def __init__(self, allowed_imports: list[str] | None = None): - """Initialize with optional import restrictions. - - Args: - allowed_imports: List of allowed import modules. None means any import is allowed. - """ - self.allowed_imports = allowed_imports - - @abstractmethod - def execute(self, code: str, timeout: int) -> ExecutionResult: - """Execute code and return result.""" - - -class SafeEnvironment(ExecutionEnvironment): - """Safe environment that validates but does not execute code.""" - - def execute(self, code: str, timeout: int) -> ExecutionResult: - """Validate code syntax and imports without executing.""" - try: - ast.parse(code) - except SyntaxError as e: - return ExecutionResult(success=False, error=str(e)) - - if self.allowed_imports: - unauthorized = _get_unauthorized_imports(code, self.allowed_imports) - if unauthorized: - return ExecutionResult( - success=False, - error=f"Unauthorized imports detected: {', '.join(unauthorized)}", - ) - - return ExecutionResult( - success=True, - skipped=True, - message="Code validated but not executed (safe mode)", - ) - - -class UnsafeEnvironment(ExecutionEnvironment): - """Unsafe environment that executes code directly with subprocess.""" - - def execute(self, code: str, timeout: int) -> ExecutionResult: - """Execute code with subprocess after checking imports.""" - if self.allowed_imports: - unauthorized = _get_unauthorized_imports(code, self.allowed_imports) - if unauthorized: - return ExecutionResult( - success=False, - error=f"Unauthorized imports detected: {', '.join(unauthorized)}", - ) - - return self._execute_subprocess(code, timeout) - - def _execute_subprocess(self, code: str, timeout: int) -> ExecutionResult: - with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: - f.write(code) - temp_file = f.name - - try: - # Execute code using the same Python interpreter and environment as the current process - # This ensures the code has access to all installed packages and dependencies - result = subprocess.run( - [sys.executable, temp_file], - capture_output=True, - text=True, - timeout=timeout, - ) - - if result.returncode == 0: - message = "Code executed successfully" - if result.stdout.strip(): - message += f"\nOutput: {result.stdout.strip()}" - return ExecutionResult(success=True, message=message) - else: - return ExecutionResult( - success=False, - error=f"Execution failed with error: {result.stderr[:200]}", - ) - except subprocess.TimeoutExpired: - return ExecutionResult( - success=False, error=f"Execution timed out after {timeout} seconds" - ) - except Exception as e: - return ExecutionResult(success=False, error=f"Execution error: {e!s}") - finally: - try: - Path(temp_file).unlink() - except Exception: - pass - - -class LLMSandboxEnvironment(ExecutionEnvironment): - """Environment using llm-sandbox for secure Docker-based execution.""" - - def execute(self, code: str, timeout: int) -> ExecutionResult: - """Execute code using llm-sandbox.""" - if self.allowed_imports: - unauthorized = _get_unauthorized_imports(code, self.allowed_imports) - if unauthorized: - return ExecutionResult( - success=False, - error=f"Unauthorized imports detected: {', '.join(unauthorized)}", - ) - - try: - from llm_sandbox import SandboxSession - except ImportError: - return ExecutionResult( - success=False, - error="llm-sandbox not installed. Install with: uv add 'llm-sandbox[docker]'", - ) - - try: - with SandboxSession( - lang="python", verbose=False, keep_template=False - ) as session: - result = session.run(code, timeout=timeout) - - if result.exit_code == 0: - message = "Code executed successfully in sandbox" - if ( - hasattr(result, "stdout") - and result.stdout - and result.stdout.strip() - ): - message += f"\nOutput: {result.stdout.strip()}" - return ExecutionResult(success=True, message=message) - else: - if result.stderr: - error_msg = f"Sandbox execution failed: {result.stderr[:200]}" - else: - # Log unknown error details for debugging - logger.warning( - f"Sandbox execution failed without stderr. Exit code: {result.exit_code}, " - f"Available attributes: {[attr for attr in dir(result) if not attr.startswith('_')]}" - ) - error_msg = f"Sandbox execution failed with exit code {result.exit_code} (no error details available)" - return ExecutionResult(success=False, error=error_msg) - - except Exception as e: - return ExecutionResult( - success=False, error=f"Sandbox execution error: {e!s}" - ) - - -def _get_unauthorized_imports(code: str, allowed_imports: list[str]) -> list[str]: - """Get list of unauthorized imports used in code.""" - unauthorized: list[str] = [] - try: - tree = ast.parse(code) - except SyntaxError: - return unauthorized - - for node in ast.walk(tree): - if isinstance(node, ast.Import): - for alias in node.names: - base_module = alias.name.split(".")[0] - if ( - base_module not in allowed_imports - and base_module not in unauthorized - ): - unauthorized.append(base_module) - elif isinstance(node, ast.ImportFrom): - if node.module: - base_module = node.module.split(".")[0] - if ( - base_module not in allowed_imports - and base_module not in unauthorized - ): - unauthorized.append(base_module) - return unauthorized - - -def _check_allowed_imports(code: str, allowed_imports: list[str]) -> bool: - """Check if code only uses allowed imports.""" - return len(_get_unauthorized_imports(code, allowed_imports)) == 0 - - -# endregion # region code extraction @@ -328,11 +139,11 @@ def _python_executes_without_error( elif allow_unsafe: environment = UnsafeEnvironment(allowed_imports=allowed_imports) else: - environment = SafeEnvironment(allowed_imports=allowed_imports) + environment = StaticAnalysisEnvironment(allowed_imports=allowed_imports) result = environment.execute(code, timeout) return ValidationResult( - result=result.success, reason=result.message or result.error + result=result.success, reason=result.to_validationresult_reason() ) diff --git a/mellea/stdlib/reqlib/tools.py b/mellea/stdlib/reqlib/tools.py new file mode 100644 index 00000000..6b64f18a --- /dev/null +++ b/mellea/stdlib/reqlib/tools.py @@ -0,0 +1,113 @@ +"""Requirements for tool-use workflows.""" + +from collections.abc import Callable +from typing import Optional + +from mellea.stdlib.base import Context +from mellea.stdlib.requirement import Requirement, ValidationResult + + +def _name2str(tool_name: str | Callable) -> str: + match tool_name: + case tool_name if callable(tool_name): + return tool_name.__name__ + case str(): + return tool_name + case _: + raise TypeError(f"Expected Callable or str but found: {type(tool_name)}") + + +def uses_tool(tool_name: str | Callable, check_only=False): + """Forces the model to call a given tool. + + Args: + tool_name: The tool that must be called; this can be either the name of the tool or the Callable for the tool. + check_only: Propagates to the Requirement. + + Use `tool_choice` if the OpenAI `tool_choice` model option is supported by your model and inference engine. + """ + tool_name = _name2str(tool_name) + + def _validate(ctx: Context): + output = ctx.last_output() + assert output is not None + if output.tool_calls is None: + return ValidationResult(result=False, reason="There were no tool calls.") + return ValidationResult(result=tool_name in output.tool_calls) + + return Requirement( + description=f"Use the {tool_name} tool.", + validation_fn=_validate, + check_only=check_only, + ) + + +def tool_arg_validator( + description: str, + tool_name: str | Callable | None, + arg_name: str, + validation_fn: Callable, + check_only: bool = False, +) -> Requirement: + """A requirement that passes only if `validation_fn` returns a True value for the *value* of the `arg_name` argument to `tool_name`. + + If `tool_name` is not specified, then this requirement is enforced for *every* tool that + + Args: + description: The Requirement description. + tool_name: The (optional) tool name for . + arg_name: The argument to check. + validation_fn: A validation function for validating the value of the `arg_name` argument. + check_only: propagates the `check_only` flag to the requirement. + + Todo: + 1. should this be a requirement? + 2. should this be done automatically when the user provides asserts in their function body? + """ + if tool_name: + tool_name = _name2str(tool_name) + + def _validate(ctx: Context): + output = ctx.last_output() + assert output is not None + + if output.tool_calls is None: + return ValidationResult( + result=False, + reason=f"Expected {tool_name} to be called but no tools were called.", + ) + + if tool_name: + if tool_name not in output.tool_calls: + return ValidationResult( + result=False, reason=f"Tool {tool_name} was not called." + ) + if arg_name not in output.tool_calls[tool_name].args: + return ValidationResult( + result=False, + reason=f"Tool {tool_name} did not call argument {arg_name}", + ) + arg_value = output.tool_calls[tool_name].args[arg_name] + validate_result = validation_fn(arg_value) + if validate_result: + return ValidationResult(result=True) + else: + return ValidationResult( + result=False, + reason=f"Valiudation did not pass for {tool_name}.{arg_name}. Arg value: {arg_value}. Argument validation result: {validate_result}", + ) + else: + for tool in output.tool_calls.keys(): + if arg_name in output.tool_calls[tool].args: + arg_value = output.tool_calls[tool].args[arg_name] + validate_result = validation_fn(arg_value) + if not validate_result: + return ValidationResult( + result=False, + reason=f"Valiudation did not pass for {tool_name}.{arg_name}. Arg value: {arg_value}. Argument validation result: {validate_result}", + ) + return ValidationResult(result=True) + + return Requirement( + description=description, validation_fn=_validate, check_only=check_only + ) diff --git a/mellea/stdlib/tools/__init__.py b/mellea/stdlib/tools/__init__.py new file mode 100644 index 00000000..24ca99aa --- /dev/null +++ b/mellea/stdlib/tools/__init__.py @@ -0,0 +1,3 @@ +"""Implementations of tools.""" + +from mellea.stdlib.tools.interpreter import code_interpreter, local_code_interpreter diff --git a/mellea/stdlib/tools/interpreter.py b/mellea/stdlib/tools/interpreter.py new file mode 100644 index 00000000..9f144057 --- /dev/null +++ b/mellea/stdlib/tools/interpreter.py @@ -0,0 +1,282 @@ +"""Code interpreter tool.""" + +import ast +import subprocess +import sys +import tempfile +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from mellea.helpers.fancy_logger import FancyLogger +from mellea.stdlib.base import Context +from mellea.stdlib.requirement import Requirement, ValidationResult + +logger = FancyLogger.get_logger() + + +@dataclass +class ExecutionResult: + """Result of code execution. + + Code execution can be aborted prior to spinning up an interpreter (e.g., if prohibited imports are used). + In these cases, the `success` flag is set to False and the `skipped` flag is set to True. + + If code is executed, then `success` is set to true iff the exit code is 0, and the `stdout` and `stderr` outputs + are set to non-None values. + + We also use the `ExecutionResult` object to communicate the result of static and dynamic analyses. Those are passed back + using the `analysis_result` field. + + TODO: should we also be trying to pass back the value of the final expression evaluated, or the value of locals() and globals()? + """ + + success: bool + + stdout: str | None + + stderr: str | None + + """ Indicates whether execution was skipped. """ + skipped: bool = False + + """ If execution is skipped, this message indicates why. """ + skip_message: str | None = None + + """ Used for returning results from static analyses. """ + analysis_result: Any | None = None + + def to_validationresult_reason(self): + """Maps an ExecutionResult to a ValidationResult reason. + + TODO: Downstream use of this method is really hacky. A far better solution is for `ExecutionResult` to implement the `ValidationResult` interface. + """ + assert self.skip_message is not None or ( + self.stderr is not None and self.stdout is not None + ), ( + "Every ExecutionResult should have either a skip_message or a stdout/stderr stream." + ) + if self.skip_message: + reason = self.skip_message + else: + if self.success: + reason = self.stdout + else: + reason = self.stderr + return reason + + +class ExecutionEnvironment(ABC): + """Abstract environment for executing Python code.""" + + def __init__(self, allowed_imports: list[str] | None = None): + """Initialize with optional import restrictions. + + Args: + allowed_imports: List of allowed import modules. None means any import is allowed. + """ + self.allowed_imports = allowed_imports + + @abstractmethod + def execute(self, code: str, timeout: int) -> ExecutionResult: + """Execute code and return result.""" + + +class StaticAnalysisEnvironment(ExecutionEnvironment): + """Safe environment that validates but does not execute code.""" + + def execute(self, code: str, timeout: int) -> ExecutionResult: + """Validate code syntax and imports without executing.""" + try: + parse_tree = ast.parse(code) + except SyntaxError as e: + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message="Parse failed.", + analysis_result=e, + ) + + if self.allowed_imports: + unauthorized = _get_unauthorized_imports(code, self.allowed_imports) + if unauthorized: + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message=f"Unauthorized imports detected: {', '.join(unauthorized)}", + ) + + return ExecutionResult( + success=True, + stdout=None, + stderr=None, + skipped=True, + skip_message="Code parses successful; the parse result is in the analysis_result field of the ExecutionResult object. The static analysis execution environment does not execute code. To execute code, use one of the other execution environments.", + analysis_result=parse_tree, + ) + + +class UnsafeEnvironment(ExecutionEnvironment): + """Unsafe environment that executes code directly with subprocess.""" + + def execute(self, code: str, timeout: int) -> ExecutionResult: + """Execute code with subprocess after checking imports.""" + if self.allowed_imports: + unauthorized = _get_unauthorized_imports(code, self.allowed_imports) + if unauthorized: + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message=f"Unauthorized imports detected: {', '.join(unauthorized)}", + ) + + return self._execute_subprocess(code, timeout) + + def _execute_subprocess(self, code: str, timeout: int) -> ExecutionResult: + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write(code) + temp_file = f.name + + try: + # Execute code using the same Python interpreter and environment as the current process + # This ensures the code has access to all installed packages and dependencies + result = subprocess.run( + [sys.executable, temp_file], + capture_output=True, + text=True, + timeout=timeout, + ) + return ExecutionResult( + success=result.returncode == 0, + stdout=result.stdout.strip(), + stderr=result.stderr.strip(), + ) + except subprocess.TimeoutExpired: + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message="Execution timed out.", + ) + except Exception as e: + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message=f"Exception encountered in Mellea process (*not* the code interpreter process) when trying to run code_interpreter: {e!s}", + ) + finally: + try: + Path(temp_file).unlink() + except Exception: + pass + + +class LLMSandboxEnvironment(ExecutionEnvironment): + """Environment using llm-sandbox for secure Docker-based execution.""" + + def execute(self, code: str, timeout: int) -> ExecutionResult: + """Execute code using llm-sandbox.""" + if self.allowed_imports: + unauthorized = _get_unauthorized_imports(code, self.allowed_imports) + if unauthorized: + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message=f"Unauthorized imports detected: {', '.join(unauthorized)}", + ) + + try: + from llm_sandbox import SandboxSession + except ImportError: + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message="llm-sandbox not installed. Install with: uv add 'llm-sandbox[docker]'", + ) + + try: + with SandboxSession( + lang="python", verbose=False, keep_template=False + ) as session: + result = session.run(code, timeout=timeout) + + return ExecutionResult( + success=result.exit_code == 0, + stdout=result.stdout.strip(), + stderr=result.stderr.strip(), + ) + except Exception as e: + return ExecutionResult( + success=False, + stdout=None, + stderr=None, + skipped=True, + skip_message=f"Sandbox execution error: {e!s}", + ) + + +def _get_unauthorized_imports(code: str, allowed_imports: list[str]) -> list[str]: + """Get list of unauthorized imports used in code.""" + unauthorized: list[str] = [] + try: + tree = ast.parse(code) + except SyntaxError: + return unauthorized + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + base_module = alias.name.split(".")[0] + if ( + base_module not in allowed_imports + and base_module not in unauthorized + ): + unauthorized.append(base_module) + elif isinstance(node, ast.ImportFrom): + if node.module: + base_module = node.module.split(".")[0] + if ( + base_module not in allowed_imports + and base_module not in unauthorized + ): + unauthorized.append(base_module) + return unauthorized + + +def _check_allowed_imports(code: str, allowed_imports: list[str]) -> bool: + """Check if code only uses allowed imports.""" + return len(_get_unauthorized_imports(code, allowed_imports)) == 0 + + +def code_interpreter(code: str) -> ExecutionResult: + """Executes python code. + + Args: + code: The Python code to execute. + """ + exec_env = LLMSandboxEnvironment(allowed_imports=None) + return exec_env.execute(code, 60) + + +def local_code_interpreter(code: str) -> ExecutionResult: + """Executes python code in the cwd. + + Args: + code: The Python code to execute. + """ + exec_env = UnsafeEnvironment(allowed_imports=None) + return exec_env.execute(code, 60) diff --git a/test/stdlib_basics/test_reqlib_python.py b/test/stdlib_basics/test_reqlib_python.py index 153a981f..b3f9211a 100644 --- a/test/stdlib_basics/test_reqlib_python.py +++ b/test/stdlib_basics/test_reqlib_python.py @@ -128,7 +128,6 @@ def test_safe_mode_default(): req = PythonExecutionReq() result = req.validation_fn(VALID_PYTHON_CTX) assert result.as_bool() is True - assert "safe mode" in result.reason def test_safe_mode_syntax_error(): @@ -143,7 +142,6 @@ def test_safe_mode_no_execution(): req = PythonExecutionReq(timeout=1) result = req.validation_fn(PYTHON_INFINITE_LOOP_CTX) assert result.as_bool() is True # Should pass because it's not actually executed - assert "safe mode" in result.reason # endregion @@ -225,7 +223,6 @@ def test_sandbox_execution_valid(): req = PythonExecutionReq(use_sandbox=True, timeout=10) result = req.validation_fn(VALID_PYTHON_CTX) assert result.as_bool() is True - assert "sandbox" in result.reason.lower() @pytest.mark.skipif( @@ -313,7 +310,6 @@ def test_direct_validation_function(): VALID_PYTHON_CTX, timeout=5, allow_unsafe=False, use_sandbox=False ) assert result.as_bool() is True - assert "safe mode" in result.reason result = _python_executes_without_error( SYNTAX_ERROR_CTX, timeout=5, allow_unsafe=False, use_sandbox=False diff --git a/test/stdlib_basics/test_reqlib_tools.py b/test/stdlib_basics/test_reqlib_tools.py new file mode 100644 index 00000000..b7f771f2 --- /dev/null +++ b/test/stdlib_basics/test_reqlib_tools.py @@ -0,0 +1,10 @@ +import pytest +from mellea.stdlib.reqlib.tools import _name2str + +def test_name2str(): + """Test handling when no Python code is present.""" + def test123(): + pass + assert _name2str(test123) == "test123" + assert _name2str("test1234") == "test1234" +