diff --git a/CHANGELOG.md b/CHANGELOG.md index 02e2cab..70193a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,11 +20,12 @@ and this project attempts to adhere to [Semantic Versioning](https://semver.org/ ### Added -- Pydantic is now a explicit dependency of the library. Previously, it was a transitive dependency via FastMCP. +- Pydantic is now an explicit dependency of the library. Previously, it was a transitive dependency via FastMCP. ### Changed - The `code` argument to the `django_shell` MCP tool now has a helpful description. +- `django_shell` now returns structured output with execution status, error details, and filtered tracebacks instead of plain strings. ## [0.5.0] diff --git a/src/mcp_django_shell/output.py b/src/mcp_django_shell/output.py new file mode 100644 index 0000000..172ddfc --- /dev/null +++ b/src/mcp_django_shell/output.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import traceback +from enum import Enum +from types import TracebackType +from typing import Any + +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import field_serializer + +from mcp_django_shell.shell import ErrorResult +from mcp_django_shell.shell import ExpressionResult +from mcp_django_shell.shell import Result +from mcp_django_shell.shell import StatementResult + + +class DjangoShellOutput(BaseModel): + status: ExecutionStatus + input: str + output: Output + stdout: str + stderr: str + + @classmethod + def from_result(cls, result: Result) -> DjangoShellOutput: + output: Output + + match result: + case ExpressionResult(): + output = ExpressionOutput( + value=result.value, + value_type=type(result.value), + ) + case StatementResult(): + output = StatementOutput() + case ErrorResult(): + exception = ExceptionOutput( + exc_type=type(result.exception), + message=str(result.exception), + traceback=result.exception.__traceback__, + ) + output = ErrorOutput(exception=exception) + + return cls( + status=ExecutionStatus.from_output(output), + input=result.code, + output=output, + stdout=result.stdout, + stderr=result.stderr, + ) + + +class ExecutionStatus(str, Enum): + SUCCESS = "success" + ERROR = "error" + + @classmethod + def from_output(cls, output: Output) -> ExecutionStatus: + match output: + case ExpressionOutput() | StatementOutput(): + return cls.SUCCESS + case ErrorOutput(): + return cls.ERROR + + +class ExpressionOutput(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + value: Any + value_type: type + + @field_serializer("value") + def serialize_value(self, val: Any) -> str: + if val is None: + return "None" + return repr(val) + + @field_serializer("value_type") + def serialize_value_type(self, typ: type) -> str: + return typ.__name__ + + +class StatementOutput(BaseModel): + """Output from evaluating a Python statement. + + Statements by definition do not return values, just side effects such as + setting variables or executing functions. + + This is empty for now, but defined to gain type safety (see the `Output` + tagged union below) and as a holder for any potential future metadata. + """ + + +class ErrorOutput(BaseModel): + exception: ExceptionOutput + + +class ExceptionOutput(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + exc_type: type[Exception] + message: str + traceback: TracebackType | None + + @field_serializer("exc_type") + def serialize_exception_type(self, exc_type: type[Exception]) -> str: + return exc_type.__name__ + + @field_serializer("traceback") + def serialize_traceback(self, tb: TracebackType | None) -> list[str]: + if tb is None: + return [] + + tb_lines = traceback.format_tb(tb) + relevant_tb_lines = [ + line.strip() + for line in tb_lines + if "mcp_django_shell" not in line and line.strip() + ] + + return relevant_tb_lines + + +Output = ExpressionOutput | StatementOutput | ErrorOutput diff --git a/src/mcp_django_shell/server.py b/src/mcp_django_shell/server.py index 24303db..14a2c39 100644 --- a/src/mcp_django_shell/server.py +++ b/src/mcp_django_shell/server.py @@ -6,8 +6,9 @@ from fastmcp import Context from fastmcp import FastMCP +from .output import DjangoShellOutput +from .output import ErrorOutput from .shell import DjangoShell -from .shell import ErrorResult mcp = FastMCP( name="Django Shell", @@ -21,7 +22,7 @@ async def django_shell( code: Annotated[str, "Python code to be executed inside the Django shell session"], ctx: Context, -) -> str: +) -> DjangoShellOutput: """Execute Python code in a stateful Django shell session. Django is pre-configured and ready to use with your project. You can import and use any Django @@ -47,22 +48,17 @@ async def django_shell( try: result = await shell.execute(code) + output = DjangoShellOutput.from_result(result) logger.debug( "django_shell execution completed - request_id: %s, result type: %s", ctx.request_id, type(result).__name__, ) - if isinstance(result, ErrorResult): - await ctx.debug( - f"Execution failed: {type(result.exception).__name__}", - extra={ - "error_type": type(result.exception).__name__, - "code_length": len(code), - }, - ) - - return result.output + if isinstance(output.output, ErrorOutput): + await ctx.debug(f"Execution failed: {output.output.exception.message}") + + return output except Exception as e: logger.error( diff --git a/src/mcp_django_shell/shell.py b/src/mcp_django_shell/shell.py index 5eef2d9..68d8f32 100644 --- a/src/mcp_django_shell/shell.py +++ b/src/mcp_django_shell/shell.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import traceback from contextlib import redirect_stderr from contextlib import redirect_stdout from dataclasses import dataclass @@ -196,17 +195,6 @@ def __post_init__(self): logger.debug("%s.stdout: %s", self.__class__.__name__, self.stdout[:200]) if self.stderr: logger.debug("%s.stderr: %s", self.__class__.__name__, self.stderr[:200]) - logger.debug( - "%s output (for LLM): %s", self.__class__.__name__, self.output[:500] - ) - - @property - def output(self) -> str: - if self.stdout: - return self.stdout.rstrip() - elif self.value is not None: - return repr(self.value) - return "" @dataclass @@ -222,13 +210,6 @@ def __post_init__(self): logger.debug("%s.stdout: %s", self.__class__.__name__, self.stdout[:200]) if self.stderr: logger.debug("%s.stderr: %s", self.__class__.__name__, self.stderr[:200]) - logger.debug( - "%s output (for LLM): %s", self.__class__.__name__, self.output[:500] - ) - - @property - def output(self) -> str: - return self.stdout or "OK" @dataclass @@ -250,34 +231,6 @@ def __post_init__(self): logger.debug("%s.stdout: %s", self.__class__.__name__, self.stdout[:200]) if self.stderr: logger.debug("%s.stderr: %s", self.__class__.__name__, self.stderr[:200]) - logger.debug( - "%s output (filtered for LLM): %s", - self.__class__.__name__, - self.output[:500], - ) - - @property - def output(self) -> str: - error_type = self.exception.__class__.__name__ - - tb_str = "".join( - traceback.format_exception( - type(self.exception), self.exception, self.exception.__traceback__ - ) - ) - - # Filter out framework lines - tb_lines = tb_str.splitlines() - relevant_tb = "\n".join( - line for line in tb_lines if "mcp_django_shell" not in line - ) - - error_output = f"{error_type}: {self.exception}\n\nTraceback:\n{relevant_tb}" - - if self.stdout: - return f"{self.stdout.rstrip()}\n{error_output}" - - return error_output Result = ExpressionResult | StatementResult | ErrorResult diff --git a/tests/test_output.py b/tests/test_output.py new file mode 100644 index 0000000..0cffca0 --- /dev/null +++ b/tests/test_output.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import traceback + +from mcp_django_shell.output import DjangoShellOutput +from mcp_django_shell.output import ErrorOutput +from mcp_django_shell.output import ExceptionOutput +from mcp_django_shell.output import ExecutionStatus +from mcp_django_shell.output import ExpressionOutput +from mcp_django_shell.shell import ErrorResult +from mcp_django_shell.shell import ExpressionResult + + +def test_django_shell_output_from_expression_result(): + result = ExpressionResult( + code="2 + 2", + value=4, + stdout="", + stderr="", + ) + + output = DjangoShellOutput.from_result(result) + + assert output.status == ExecutionStatus.SUCCESS + assert output.input == "2 + 2" + assert isinstance(output.output, ExpressionOutput) + assert output.output.value == 4 + assert output.output.value_type is int + + +def test_expression_with_none_value(): + result = ExpressionResult( + code="None", + value=None, + stdout="", + stderr="", + ) + + output = DjangoShellOutput.from_result(result) + + assert output.status == ExecutionStatus.SUCCESS + assert isinstance(output.output, ExpressionOutput) + + serialized = output.output.model_dump(mode="json") + + assert serialized["value"] == "None" + assert serialized["value_type"] == "NoneType" + + +def test_django_shell_output_from_error_result(): + exc = ZeroDivisionError("division by zero") + + result = ErrorResult( + code="1 / 0", + exception=exc, + stdout="", + stderr="", + ) + + output = DjangoShellOutput.from_result(result) + + assert output.status == ExecutionStatus.ERROR + assert isinstance(output.output, ErrorOutput) + assert output.output.exception.exc_type is ZeroDivisionError + assert "division by zero" in output.output.exception.message + + +def test_exception_output_serialization(): + exc = ValueError("test error") + + exc_output = ExceptionOutput( + exc_type=type(exc), + message=str(exc), + traceback=None, # needs to be None since we didn't actually raise it + ) + + serialized = exc_output.model_dump(mode="json") + + assert serialized["exc_type"] == "ValueError" + assert serialized["message"] == "test error" + assert serialized["traceback"] == [] + + +def test_exception_output_with_real_traceback(): + # Do something that actually raises an error + try: + 1 / 0 + except ZeroDivisionError as e: + exc_output = ExceptionOutput( + exc_type=type(e), + message=str(e), + traceback=e.__traceback__, + ) + + serialized = exc_output.model_dump(mode="json") + + assert serialized["exc_type"] == "ZeroDivisionError" + assert "division by zero" in serialized["message"] + assert isinstance(serialized["traceback"], list) + assert len(serialized["traceback"]) > 0 + assert any("1 / 0" in line for line in serialized["traceback"]) + assert not any("mcp_django_shell" in line for line in serialized["traceback"]) + + +def test_traceback_filtering(): + # Create a function that will appear in the traceback + def mcp_django_shell_function(): + raise ValueError("test error") + + try: + mcp_django_shell_function() + except ValueError as e: + exc_output = ExceptionOutput( + exc_type=type(e), + message=str(e), + traceback=e.__traceback__, + ) + + assert any( + "mcp_django_shell_function" in line + for line in traceback.format_tb(e.__traceback__) + ) + + serialized = exc_output.model_dump(mode="json") + + assert len(serialized["traceback"]) == 0 or not any( + "mcp_django_shell" in line for line in serialized["traceback"] + ) diff --git a/tests/test_server.py b/tests/test_server.py index ad498f7..d6f0451 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,9 +4,12 @@ import pytest import pytest_asyncio +from django.conf import settings +from django.test import override_settings from fastmcp import Client from fastmcp.exceptions import ToolError +from mcp_django_shell.output import ExecutionStatus from mcp_django_shell.server import mcp from mcp_django_shell.server import shell @@ -35,9 +38,17 @@ async def test_tool_listing(): async def test_django_shell_tool(): async with Client(mcp) as client: result = await client.call_tool("django_shell", {"code": "2 + 2"}) - assert "4" in result.data + assert result.data.status == ExecutionStatus.SUCCESS + assert result.data.output.value == "4" +@override_settings( + INSTALLED_APPS=settings.INSTALLED_APPS + + [ + "django.contrib.auth", + "django.contrib.contenttypes", + ] +) async def test_django_shell_tool_orm(): async with Client(mcp) as client: result = await client.call_tool( @@ -46,7 +57,21 @@ async def test_django_shell_tool_orm(): "code": "from django.contrib.auth import get_user_model; get_user_model().__name__" }, ) - assert result.data is not None + assert result.data.status == ExecutionStatus.SUCCESS + + +async def test_django_shell_error_output(): + async with Client(mcp) as client: + result = await client.call_tool("django_shell", {"code": "1 / 0"}) + + assert result.data.status == ExecutionStatus.ERROR.value + assert "ZeroDivisionError" in str(result.data.output.exception.exc_type) + assert "division by zero" in result.data.output.exception.message + assert len(result.data.output.exception.traceback) > 0 + assert not any( + "mcp_django_shell" in line + for line in result.data.output.exception.traceback + ) async def test_django_shell_tool_unexpected_error(monkeypatch): @@ -64,9 +89,10 @@ async def test_django_reset_session(): await client.call_tool("django_shell", {"code": "x = 42"}) result = await client.call_tool("django_reset", {}) - assert "reset" in result.data.lower() + assert "reset" in result.data.lower() # This one still returns a string result = await client.call_tool( "django_shell", {"code": "print('x' in globals())"} ) - assert "False" in result.data + # Check stdout contains "False" + assert "False" in result.data.stdout diff --git a/tests/test_shell.py b/tests/test_shell.py index 34d1d0c..1b2f31d 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -11,8 +11,6 @@ from mcp_django_shell.shell import StatementResult from mcp_django_shell.shell import parse_code -from .models import AModel - @pytest.fixture def shell(): @@ -116,13 +114,11 @@ def test_execute_expression_returns_value(self, shell): assert isinstance(result, ExpressionResult) assert result.value == 4 - assert "4" in result.output def test_execute_statement_returns_ok(self, shell): result = shell._execute("x = 5") assert isinstance(result, StatementResult) - assert result.output == "OK" def test_execute_multiline_expression_returns_last_value(self, shell): code = """\ @@ -134,7 +130,6 @@ def test_execute_multiline_expression_returns_last_value(self, shell): assert isinstance(result, ExpressionResult) assert result.value == 15 - assert "15" in result.output def test_execute_multiline_statements_returns_ok(self, shell): code = """\ @@ -145,14 +140,12 @@ def test_execute_multiline_statements_returns_ok(self, shell): result = shell._execute(code.strip()) assert isinstance(result, StatementResult) - assert result.output == "OK" def test_execute_print_captures_stdout(self, shell): result = shell._execute('print("Hello, World!")') assert isinstance(result, ExpressionResult) assert result.value is None - assert "Hello, World!" in result.output def test_multiline_ending_with_print_no_none(self, shell): code = """ @@ -164,20 +157,16 @@ def test_multiline_ending_with_print_no_none(self, shell): assert isinstance(result, ExpressionResult) assert result.value is None - assert "Sum: 15" in result.output - assert result.output == "Sum: 15" def test_execute_invalid_code_returns_error(self, shell): result = shell._execute("1 / 0") assert isinstance(result, ErrorResult) - assert "ZeroDivisionError" in result.output def test_execute_empty_string_returns_ok(self, shell): result = shell._execute("") assert isinstance(result, StatementResult) - assert result.output == "OK" def test_execute_whitespace_only_returns_ok(self, shell): result = shell._execute(" \n \t ") @@ -194,102 +183,6 @@ async def test_async_execute_returns_result(self): assert result.value == 4 -class TestResultOutput: - def test_expression_with_value_no_stdout(self, shell): - result = shell._execute("42") - assert isinstance(result, ExpressionResult) - assert result.output == "42" - - def test_expression_with_none_no_stdout(self, shell): - result = shell._execute("None") - assert isinstance(result, ExpressionResult) - assert result.output == "" - - def test_expression_with_stdout_and_value(self, shell): - code = """\ -print('hello') -42 -""" - result = shell._execute(code.strip()) - assert isinstance(result, ExpressionResult) - assert result.output == "hello" # No "42" shown - - def test_expression_with_stdout_and_none(self, shell): - result = shell._execute("print('hello')") - assert isinstance(result, ExpressionResult) - assert result.output == "hello" # No "None" shown - - def test_multiline_ending_with_print(self, shell): - code = """\ -x = 5 -y = 10 -print(f"Sum: {x + y}") -""" - result = shell._execute(code.strip()) - assert isinstance(result, ExpressionResult) - assert result.output == "Sum: 15" # No "None" appended - - def test_function_returning_none(self, shell): - code = """\ -def foo(): - x = 2 -foo() -""" - result = shell._execute(code.strip()) - assert isinstance(result, ExpressionResult) - assert result.output == "" - - @pytest.mark.django_db - def test_queryset_shows_standard_repr(self, shell): - for i in range(3): - AModel.objects.create(name=f"Item {i}", value=i) - - shell._execute("from tests.models import AModel") - result = shell._execute("AModel.objects.all()") - - assert isinstance(result, ExpressionResult) - assert result.output.startswith("