diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 3f3f2b916..48e8eebdf 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -1826,7 +1826,10 @@ async def execute( output_text = "" try: - operation = _coerce_apply_patch_operation(call.tool_call) + operation = _coerce_apply_patch_operation( + call.tool_call, + context_wrapper=context_wrapper, + ) editor = apply_patch_tool.editor if operation.type == "create_file": result = editor.create_file(operation) @@ -2093,7 +2096,9 @@ def _extract_apply_patch_call_id(tool_call: Any) -> str: return str(value) -def _coerce_apply_patch_operation(tool_call: Any) -> ApplyPatchOperation: +def _coerce_apply_patch_operation( + tool_call: Any, *, context_wrapper: RunContextWrapper[Any] +) -> ApplyPatchOperation: raw_operation = _get_mapping_or_attr(tool_call, "operation") if raw_operation is None: raise ModelBehaviorError("Apply patch call is missing an operation payload.") @@ -2117,7 +2122,12 @@ def _coerce_apply_patch_operation(tool_call: Any) -> ApplyPatchOperation: else: diff = None - return ApplyPatchOperation(type=op_type_literal, path=str(path), diff=diff) + return ApplyPatchOperation( + type=op_type_literal, + path=str(path), + diff=diff, + ctx_wrapper=context_wrapper, + ) def _normalize_apply_patch_result( diff --git a/src/agents/editor.py b/src/agents/editor.py index 38dd616b3..40a1374b4 100644 --- a/src/agents/editor.py +++ b/src/agents/editor.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Literal, Protocol, runtime_checkable +from .run_context import RunContextWrapper from .util._types import MaybeAwaitable ApplyPatchOperationType = Literal["create_file", "update_file", "delete_file"] @@ -18,6 +19,7 @@ class ApplyPatchOperation: type: ApplyPatchOperationType path: str diff: str | None = None + ctx_wrapper: RunContextWrapper | None = None @dataclass(**_DATACLASS_KWARGS) diff --git a/tests/test_apply_patch_tool.py b/tests/test_apply_patch_tool.py index 197a7550f..a067a9d8a 100644 --- a/tests/test_apply_patch_tool.py +++ b/tests/test_apply_patch_tool.py @@ -63,6 +63,7 @@ async def test_apply_patch_tool_success() -> None: assert raw_item["status"] == "completed" assert raw_item["call_id"] == "call_apply" assert editor.operations[0].type == "update_file" + assert editor.operations[0].ctx_wrapper is context_wrapper assert isinstance(raw_item["output"], str) assert raw_item["output"].startswith("Updated tasks.md") input_payload = result.to_input_item() @@ -137,3 +138,4 @@ async def test_apply_patch_tool_accepts_mapping_call() -> None: raw_item = cast(dict[str, Any], result.raw_item) assert raw_item["call_id"] == "call_mapping" assert editor.operations[0].path == "notes.md" + assert editor.operations[0].ctx_wrapper is context_wrapper