diff --git a/aider/tools/insert_block.py b/aider/tools/insert_block.py index 6fc72835eab..96f13262814 100644 --- a/aider/tools/insert_block.py +++ b/aider/tools/insert_block.py @@ -9,6 +9,7 @@ format_tool_result, generate_unified_diff_snippet, handle_tool_error, + is_provided, select_occurrence_index, validate_file_for_edit, ) @@ -32,7 +33,7 @@ class Tool(BaseTool): "occurrence": {"type": "integer", "default": 1}, "change_id": {"type": "string"}, "dry_run": {"type": "boolean", "default": False}, - "position": {"type": "string", "enum": ["top", "bottom"]}, + "position": {"type": "string", "enum": ["top", "bottom", ""]}, "auto_indent": {"type": "boolean", "default": True}, "use_regex": {"type": "boolean", "default": False}, }, @@ -68,14 +69,14 @@ def execute( occurrence: Which occurrence of the pattern to use (1-based, or -1 for last) change_id: Optional ID for tracking changes dry_run: If True, only simulate the change - position: Special position like "start_of_file" or "end_of_file" + position: Special position like "top" or "bottom" (mutually exclusive with before_pattern and after_pattern) auto_indent: If True, automatically adjust indentation of inserted content use_regex: If True, treat patterns as regular expressions """ tool_name = "InsertBlock" try: # 1. Validate parameters - if sum(x is not None for x in [after_pattern, before_pattern, position]) != 1: + if sum(is_provided(x) for x in [after_pattern, before_pattern, position]) != 1: raise ToolError( "Must specify exactly one of: after_pattern, before_pattern, or position" ) diff --git a/aider/tools/show_numbered_context.py b/aider/tools/show_numbered_context.py index c5376853851..45aff33b446 100644 --- a/aider/tools/show_numbered_context.py +++ b/aider/tools/show_numbered_context.py @@ -1,7 +1,12 @@ import os from aider.tools.utils.base_tool import BaseTool -from aider.tools.utils.helpers import ToolError, handle_tool_error, resolve_paths +from aider.tools.utils.helpers import ( + ToolError, + handle_tool_error, + is_provided, + resolve_paths, +) class Tool(BaseTool): @@ -34,9 +39,17 @@ def execute(cls, coder, file_path, pattern=None, line_number=None, context_lines tool_name = "ShowNumberedContext" try: # 1. Validate arguments - if not (pattern is None) ^ (line_number is None): + pattern_provided = is_provided(pattern) + line_number_provided = is_provided(line_number, treat_zero_as_missing=True) + + if sum([pattern_provided, line_number_provided]) != 1: raise ToolError("Provide exactly one of 'pattern' or 'line_number'.") + if not pattern_provided: + pattern = None + if not line_number_provided: + line_number = None + # 2. Resolve path abs_path, rel_path = resolve_paths(coder, file_path) if not os.path.exists(abs_path): diff --git a/aider/tools/utils/helpers.py b/aider/tools/utils/helpers.py index 63e068129a3..a0fbb871118 100644 --- a/aider/tools/utils/helpers.py +++ b/aider/tools/utils/helpers.py @@ -10,6 +10,21 @@ class ToolError(Exception): pass +def is_provided(value, *, treat_zero_as_missing=False): + """ + Normalizes parameter presence checks across tools. + + Returns True when the value should be considered user-provided. + """ + if value is None: + return False + if isinstance(value, str) and value == "": + return False + if treat_zero_as_missing and isinstance(value, (int, float)) and value == 0: + return False + return True + + def resolve_paths(coder, file_path): """Resolves absolute and relative paths for a given file path.""" try: @@ -105,10 +120,11 @@ def determine_line_range( Determines the end line index based on end_pattern or line_count. Raises ToolError if end_pattern is not found or line_count is invalid. """ + # Parameter validation: Ensure only one targeting method is used targeting_methods = [ - target_symbol is not None, - start_pattern_line_index is not None, + is_provided(target_symbol), + is_provided(start_pattern_line_index), # Note: line_count and end_pattern depend on start_pattern_line_index ] if sum(targeting_methods) > 1: diff --git a/tests/tools/test_insert_block.py b/tests/tools/test_insert_block.py new file mode 100644 index 00000000000..33955bb6509 --- /dev/null +++ b/tests/tools/test_insert_block.py @@ -0,0 +1,117 @@ +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from aider.tools import insert_block + + +class DummyIO: + def __init__(self): + self.tool_error = Mock() + self.tool_warning = Mock() + self.tool_output = Mock() + + def read_text(self, path): + return Path(path).read_text() + + def write_text(self, path, content): + Path(path).write_text(content) + + +class DummyChangeTracker: + def __init__(self): + self.calls = [] + + def track_change( + self, file_path, change_type, original_content, new_content, metadata, change_id=None + ): + self.calls.append( + { + "file_path": file_path, + "change_type": change_type, + "original_content": original_content, + "new_content": new_content, + "metadata": metadata, + "change_id": change_id, + } + ) + return f"change-{len(self.calls)}" + + +class DummyCoder: + def __init__(self, root): + self.root = str(root) + self.repo = SimpleNamespace(root=str(root)) + self.io = DummyIO() + self.change_tracker = DummyChangeTracker() + self.aider_edited_files = set() + self.files_edited_by_tools = set() + self.abs_read_only_fnames = set() + self.abs_fnames = set() + + def abs_root_path(self, file_path): + path = Path(file_path) + if path.is_absolute(): + return str(path) + return str((Path(self.root) / path).resolve()) + + def get_rel_fname(self, abs_path): + return str(Path(abs_path).resolve().relative_to(self.root)) + + +@pytest.fixture +def coder_with_file(tmp_path): + file_path = tmp_path / "example.txt" + file_path.write_text("first line\nsecond line\n") + coder = DummyCoder(tmp_path) + coder.abs_fnames.add(str(file_path.resolve())) + return coder, file_path + + +def test_position_top_succeeds_with_no_patterns(coder_with_file): + coder, file_path = coder_with_file + + result = insert_block.Tool.execute( + coder, + file_path="example.txt", + content="inserted line", + position="top", + ) + + assert result.startswith("Successfully executed InsertBlock.") + assert file_path.read_text().splitlines()[0] == "inserted line" + coder.io.tool_error.assert_not_called() + + +def test_position_top_ignores_blank_patterns(coder_with_file): + coder, file_path = coder_with_file + + result = insert_block.Tool.execute( + coder, + file_path="example.txt", + content="inserted line", + position="top", + after_pattern="", + ) + + assert result.startswith("Successfully executed InsertBlock.") + assert file_path.read_text().splitlines()[0] == "inserted line" + coder.io.tool_error.assert_not_called() + + +def test_mutually_exclusive_parameters_raise(coder_with_file): + coder, file_path = coder_with_file + + result = insert_block.Tool.execute( + coder, + file_path="example.txt", + content="new line", + position="top", + after_pattern="first line", + ) + + assert result.startswith("Error: Must specify exactly one of") + assert file_path.read_text().startswith("first line") + coder.io.tool_error.assert_called() diff --git a/tests/tools/test_show_numbered_context.py b/tests/tools/test_show_numbered_context.py new file mode 100644 index 00000000000..a33019b5ec1 --- /dev/null +++ b/tests/tools/test_show_numbered_context.py @@ -0,0 +1,106 @@ +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from aider.tools import show_numbered_context + + +class DummyIO: + def __init__(self): + self.tool_error = Mock() + self.tool_warning = Mock() + self.tool_output = Mock() + + def read_text(self, path): + return Path(path).read_text() + + def write_text(self, path, content): + Path(path).write_text(content) + + +class DummyCoder: + def __init__(self, root): + self.root = str(root) + self.repo = SimpleNamespace(root=str(root)) + self.io = DummyIO() + + def abs_root_path(self, file_path): + path = Path(file_path) + if path.is_absolute(): + return str(path) + return str((Path(self.root) / path).resolve()) + + def get_rel_fname(self, abs_path): + return str(Path(abs_path).resolve().relative_to(self.root)) + + +@pytest.fixture +def coder_with_file(tmp_path): + file_path = tmp_path / "example.txt" + file_path.write_text("alpha\nbeta\ngamma\n") + coder = DummyCoder(tmp_path) + return coder, file_path + + +def test_pattern_with_zero_line_number_is_allowed(coder_with_file): + coder, file_path = coder_with_file + + result = show_numbered_context.Tool.execute( + coder, + file_path="example.txt", + pattern="beta", + line_number=0, + context_lines=0, + ) + + assert "beta" in result + assert "line 2" in result or "2 | beta" in result + coder.io.tool_error.assert_not_called() + + +def test_empty_pattern_uses_line_number(coder_with_file): + coder, file_path = coder_with_file + + result = show_numbered_context.Tool.execute( + coder, + file_path="example.txt", + pattern="", + line_number=2, + context_lines=0, + ) + + assert "2 | beta" in result + coder.io.tool_error.assert_not_called() + + +def test_conflicting_pattern_and_line_number_raise(coder_with_file): + coder, file_path = coder_with_file + + result = show_numbered_context.Tool.execute( + coder, + file_path="example.txt", + pattern="beta", + line_number=2, + context_lines=0, + ) + + assert result.startswith("Error: Provide exactly one of") + coder.io.tool_error.assert_called() + + +def test_target_symbol_empty_string_treated_as_missing(): + from aider.tools.utils import helpers + from aider.tools.utils.helpers import ToolError + + with pytest.raises(ToolError, match="Must specify either target_symbol or start_pattern"): + helpers.determine_line_range( + coder=SimpleNamespace(repo_map=None), # repo_map not used in this path + file_path="dummy", + lines=["a", "b"], + target_symbol="", + start_pattern_line_index=None, + end_pattern=None, + line_count=1, + )