Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ and this project attempts to adhere to [Semantic Versioning](https://semver.org/

### Added

- Pydantic is now a explicit dependency of the library. Previously, it was a transitive dependency via FastMCP.
- Pydantic is now an explicit dependency of the library. Previously, it was a transitive dependency via FastMCP.

### Changed

- The `code` argument to the `django_shell` MCP tool now has a helpful description.
- `django_shell` now returns structured output with execution status, error details, and filtered tracebacks instead of plain strings.

## [0.5.0]

Expand Down
125 changes: 125 additions & 0 deletions src/mcp_django_shell/output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from __future__ import annotations

import traceback
from enum import Enum
from types import TracebackType
from typing import Any

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import field_serializer

from mcp_django_shell.shell import ErrorResult
from mcp_django_shell.shell import ExpressionResult
from mcp_django_shell.shell import Result
from mcp_django_shell.shell import StatementResult


class DjangoShellOutput(BaseModel):
status: ExecutionStatus
input: str
output: Output
stdout: str
stderr: str

@classmethod
def from_result(cls, result: Result) -> DjangoShellOutput:
output: Output

match result:
case ExpressionResult():
output = ExpressionOutput(
value=result.value,
value_type=type(result.value),
)
case StatementResult():
output = StatementOutput()
case ErrorResult():
exception = ExceptionOutput(
exc_type=type(result.exception),
message=str(result.exception),
traceback=result.exception.__traceback__,
)
output = ErrorOutput(exception=exception)

return cls(
status=ExecutionStatus.from_output(output),
input=result.code,
output=output,
stdout=result.stdout,
stderr=result.stderr,
)


class ExecutionStatus(str, Enum):
SUCCESS = "success"
ERROR = "error"

@classmethod
def from_output(cls, output: Output) -> ExecutionStatus:
match output:
case ExpressionOutput() | StatementOutput():
return cls.SUCCESS
case ErrorOutput():
return cls.ERROR


# Output of evaluating an expression: the resulting value together with its
# runtime type. Both fields are serialized to strings.
class ExpressionOutput(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    value: Any
    value_type: type

    @field_serializer("value")
    def serialize_value(self, val: Any) -> str:
        """Serialize the value as its ``repr()``; ``None`` becomes ``"None"``."""
        return "None" if val is None else repr(val)

    @field_serializer("value_type")
    def serialize_value_type(self, typ: type) -> str:
        """Serialize the type as its bare class name."""
        type_name = typ.__name__
        return type_name


# Field-less marker model; one of the three members of the `Output` union
# defined at the bottom of this module.
class StatementOutput(BaseModel):
    """Output from evaluating a Python statement.

    Statements by definition do not return values, just side effects such as
    setting variables or executing functions.

    This is empty for now, but defined to gain type safety (see the `Output`
    tagged union below) and as a holder for any potential future metadata.
    """


# Output variant for a failed execution; wraps the details of the exception
# that was raised.
class ErrorOutput(BaseModel):
    exception: ExceptionOutput


# Serializable description of a raised exception: its class, its message,
# and its traceback.
class ExceptionOutput(BaseModel):
    # NOTE(review): presumably needed so pydantic accepts the raw
    # `TracebackType` field — confirm.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    exc_type: type[Exception]
    message: str
    traceback: TracebackType | None

    @field_serializer("exc_type")
    def serialize_exception_type(self, exc_type: type[Exception]) -> str:
        """Serialize the exception class as its bare name."""
        class_name = exc_type.__name__
        return class_name

    @field_serializer("traceback")
    def serialize_traceback(self, tb: TracebackType | None) -> list[str]:
        """Serialize the traceback as a list of stripped, filtered entries.

        Returns an empty list when there is no traceback. Otherwise every
        entry produced by `traceback.format_tb` is kept unless it mentions
        "mcp_django_shell" or is blank, and each kept entry has its
        surrounding whitespace stripped.
        """
        if tb is None:
            return []

        kept_entries: list[str] = []
        for entry in traceback.format_tb(tb):
            # Drop entries that mention this package anywhere in their text.
            if "mcp_django_shell" in entry:
                continue
            stripped = entry.strip()
            if stripped:
                kept_entries.append(stripped)

        return kept_entries


# Union of the three output variants. It is defined after the classes it
# names; the annotations above that mention `Output` are not evaluated at
# definition time because of `from __future__ import annotations`.
Output = ExpressionOutput | StatementOutput | ErrorOutput
20 changes: 8 additions & 12 deletions src/mcp_django_shell/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from fastmcp import Context
from fastmcp import FastMCP

from .output import DjangoShellOutput
from .output import ErrorOutput
from .shell import DjangoShell
from .shell import ErrorResult

mcp = FastMCP(
name="Django Shell",
Expand All @@ -21,7 +22,7 @@
async def django_shell(
code: Annotated[str, "Python code to be executed inside the Django shell session"],
ctx: Context,
) -> str:
) -> DjangoShellOutput:
"""Execute Python code in a stateful Django shell session.

Django is pre-configured and ready to use with your project. You can import and use any Django
Expand All @@ -47,22 +48,17 @@ async def django_shell(

try:
result = await shell.execute(code)
output = DjangoShellOutput.from_result(result)

logger.debug(
"django_shell execution completed - request_id: %s, result type: %s",
ctx.request_id,
type(result).__name__,
)
if isinstance(result, ErrorResult):
await ctx.debug(
f"Execution failed: {type(result.exception).__name__}",
extra={
"error_type": type(result.exception).__name__,
"code_length": len(code),
},
)

return result.output
if isinstance(output.output, ErrorOutput):
await ctx.debug(f"Execution failed: {output.output.exception.message}")

return output

except Exception as e:
logger.error(
Expand Down
47 changes: 0 additions & 47 deletions src/mcp_django_shell/shell.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
import traceback
from contextlib import redirect_stderr
from contextlib import redirect_stdout
from dataclasses import dataclass
Expand Down Expand Up @@ -196,17 +195,6 @@ def __post_init__(self):
logger.debug("%s.stdout: %s", self.__class__.__name__, self.stdout[:200])
if self.stderr:
logger.debug("%s.stderr: %s", self.__class__.__name__, self.stderr[:200])
logger.debug(
"%s output (for LLM): %s", self.__class__.__name__, self.output[:500]
)

@property
def output(self) -> str:
if self.stdout:
return self.stdout.rstrip()
elif self.value is not None:
return repr(self.value)
return ""


@dataclass
Expand All @@ -222,13 +210,6 @@ def __post_init__(self):
logger.debug("%s.stdout: %s", self.__class__.__name__, self.stdout[:200])
if self.stderr:
logger.debug("%s.stderr: %s", self.__class__.__name__, self.stderr[:200])
logger.debug(
"%s output (for LLM): %s", self.__class__.__name__, self.output[:500]
)

@property
def output(self) -> str:
return self.stdout or "OK"


@dataclass
Expand All @@ -250,34 +231,6 @@ def __post_init__(self):
logger.debug("%s.stdout: %s", self.__class__.__name__, self.stdout[:200])
if self.stderr:
logger.debug("%s.stderr: %s", self.__class__.__name__, self.stderr[:200])
logger.debug(
"%s output (filtered for LLM): %s",
self.__class__.__name__,
self.output[:500],
)

@property
def output(self) -> str:
error_type = self.exception.__class__.__name__

tb_str = "".join(
traceback.format_exception(
type(self.exception), self.exception, self.exception.__traceback__
)
)

# Filter out framework lines
tb_lines = tb_str.splitlines()
relevant_tb = "\n".join(
line for line in tb_lines if "mcp_django_shell" not in line
)

error_output = f"{error_type}: {self.exception}\n\nTraceback:\n{relevant_tb}"

if self.stdout:
return f"{self.stdout.rstrip()}\n{error_output}"

return error_output


Result = ExpressionResult | StatementResult | ErrorResult
128 changes: 128 additions & 0 deletions tests/test_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from __future__ import annotations

import traceback

from mcp_django_shell.output import DjangoShellOutput
from mcp_django_shell.output import ErrorOutput
from mcp_django_shell.output import ExceptionOutput
from mcp_django_shell.output import ExecutionStatus
from mcp_django_shell.output import ExpressionOutput
from mcp_django_shell.shell import ErrorResult
from mcp_django_shell.shell import ExpressionResult


def test_django_shell_output_from_expression_result():
    """An expression result maps to SUCCESS and keeps the value and its type."""
    expression = ExpressionResult(code="2 + 2", value=4, stdout="", stderr="")

    shell_output = DjangoShellOutput.from_result(expression)
    payload = shell_output.output

    assert shell_output.status == ExecutionStatus.SUCCESS
    assert shell_output.input == "2 + 2"
    assert isinstance(payload, ExpressionOutput)
    assert payload.value == 4
    assert payload.value_type is int


def test_expression_with_none_value():
    """A `None` expression value serializes to "None" with type "NoneType"."""
    none_result = ExpressionResult(code="None", value=None, stdout="", stderr="")

    shell_output = DjangoShellOutput.from_result(none_result)
    payload = shell_output.output

    assert shell_output.status == ExecutionStatus.SUCCESS
    assert isinstance(payload, ExpressionOutput)

    dumped = payload.model_dump(mode="json")

    assert dumped["value"] == "None"
    assert dumped["value_type"] == "NoneType"


def test_django_shell_output_from_error_result():
    """An error result maps to ERROR and exposes the exception type and message."""
    failure = ErrorResult(
        code="1 / 0",
        exception=ZeroDivisionError("division by zero"),
        stdout="",
        stderr="",
    )

    shell_output = DjangoShellOutput.from_result(failure)
    payload = shell_output.output

    assert shell_output.status == ExecutionStatus.ERROR
    assert isinstance(payload, ErrorOutput)
    assert payload.exception.exc_type is ZeroDivisionError
    assert "division by zero" in payload.exception.message


def test_exception_output_serialization():
    """Without a traceback, the dump has the type name, message, and an empty list."""
    error = ValueError("test error")

    # The exception was never raised, so it has no traceback to pass along.
    dumped = ExceptionOutput(
        exc_type=type(error),
        message=str(error),
        traceback=None,
    ).model_dump(mode="json")

    assert dumped["exc_type"] == "ValueError"
    assert dumped["message"] == "test error"
    assert dumped["traceback"] == []


def test_exception_output_with_real_traceback():
    """A genuinely raised exception serializes with a non-empty, filtered traceback."""
    # Raise for real so the exception carries a traceback.
    try:
        1 / 0
    except ZeroDivisionError as e:
        exc_output = ExceptionOutput(
            exc_type=type(e),
            message=str(e),
            traceback=e.__traceback__,
        )

    dumped = exc_output.model_dump(mode="json")
    frames = dumped["traceback"]

    assert dumped["exc_type"] == "ZeroDivisionError"
    assert "division by zero" in dumped["message"]
    assert isinstance(frames, list)
    assert len(frames) > 0
    assert any("1 / 0" in frame for frame in frames)
    assert not any("mcp_django_shell" in frame for frame in frames)


def test_traceback_filtering():
    """Entries mentioning "mcp_django_shell" are removed from the serialized traceback."""

    # The function name itself contains the filtered substring, so its frame
    # shows up in the raw traceback and must be absent after serialization.
    def mcp_django_shell_function():
        raise ValueError("test error")

    try:
        mcp_django_shell_function()
    except ValueError as e:
        raw_entries = traceback.format_tb(e.__traceback__)
        exc_output = ExceptionOutput(
            exc_type=type(e),
            message=str(e),
            traceback=e.__traceback__,
        )

    assert any("mcp_django_shell_function" in entry for entry in raw_entries)

    dumped = exc_output.model_dump(mode="json")

    # `not any(...)` is also true for an empty list, which covers the case
    # where every entry was filtered out.
    assert not any("mcp_django_shell" in entry for entry in dumped["traceback"])
Loading