From e6f1d0923ce8b3515843720d86f299e7fc381b48 Mon Sep 17 00:00:00 2001 From: jayhack Date: Wed, 19 Feb 2025 15:10:56 -0800 Subject: [PATCH 1/9] . --- src/codegen/extensions/tools/bash.py | 84 ++++++++----- src/codegen/extensions/tools/commit.py | 34 +++++- src/codegen/extensions/tools/create_file.py | 67 +++++++++-- src/codegen/extensions/tools/delete_file.py | 49 ++++++-- src/codegen/extensions/tools/edit_file.py | 80 +++++++++++-- .../extensions/tools/list_directory.py | 110 +++++++++++++----- src/codegen/extensions/tools/observation.py | 51 ++++++++ 7 files changed, 391 insertions(+), 84 deletions(-) create mode 100644 src/codegen/extensions/tools/observation.py diff --git a/src/codegen/extensions/tools/bash.py b/src/codegen/extensions/tools/bash.py index 0ebcaf855..dd9da037d 100644 --- a/src/codegen/extensions/tools/bash.py +++ b/src/codegen/extensions/tools/bash.py @@ -3,7 +3,11 @@ import re import shlex import subprocess -from typing import Any +from typing import ClassVar, Optional + +from pydantic import Field + +from .observation import Observation # Whitelist of allowed commands and their flags ALLOWED_COMMANDS = { @@ -22,6 +26,28 @@ } +class RunBashCommandObservation(Observation): + """Response from running a bash command.""" + + stdout: Optional[str] = Field( + default=None, + description="Standard output from the command", + ) + stderr: Optional[str] = Field( + default=None, + description="Standard error from the command", + ) + command: str = Field( + description="The command that was executed", + ) + pid: Optional[int] = Field( + default=None, + description="Process ID for background commands", + ) + + str_template: ClassVar[str] = "Command '{command}' completed" + + def validate_command(command: str) -> tuple[bool, str]: """Validate if a command is safe to execute. @@ -90,7 +116,7 @@ def validate_command(command: str) -> tuple[bool, str]: return False, f"Failed to validate command: {e!s}" -def run_bash_command(command: str, is_background: bool = False) -> dict[str, Any]: +def run_bash_command(command: str, is_background: bool = False) -> RunBashCommandObservation: """Run a bash command and return its output. 
Args: @@ -98,15 +124,16 @@ def run_bash_command(command: str, is_background: bool = False) -> dict[str, Any is_background: Whether to run the command in the background Returns: - Dictionary containing the command output or error + RunBashCommandObservation containing the command output or error """ # First validate the command is_valid, error_message = validate_command(command) if not is_valid: - return { - "status": "error", - "error": f"Invalid command: {error_message}", - } + return RunBashCommandObservation( + status="error", + error=f"Invalid command: {error_message}", + command=command, + ) try: if is_background: @@ -118,10 +145,11 @@ def run_bash_command(command: str, is_background: bool = False) -> dict[str, Any stderr=subprocess.PIPE, text=True, ) - return { - "status": "success", - "message": f"Command '{command}' started in background with PID {process.pid}", - } + return RunBashCommandObservation( + status="success", + command=command, + pid=process.pid, + ) # For foreground processes, we wait for completion result = subprocess.run( @@ -132,20 +160,24 @@ def run_bash_command(command: str, is_background: bool = False) -> dict[str, Any check=True, # This will raise CalledProcessError if command fails ) - return { - "status": "success", - "stdout": result.stdout, - "stderr": result.stderr, - } + return RunBashCommandObservation( + status="success", + command=command, + stdout=result.stdout, + stderr=result.stderr, + ) + except subprocess.CalledProcessError as e: - return { - "status": "error", - "error": f"Command failed with exit code {e.returncode}", - "stdout": e.stdout, - "stderr": e.stderr, - } + return RunBashCommandObservation( + status="error", + error=f"Command failed with exit code {e.returncode}", + command=command, + stdout=e.stdout, + stderr=e.stderr, + ) except Exception as e: - return { - "status": "error", - "error": f"Failed to run command: {e!s}", - } + return RunBashCommandObservation( + status="error", + error=f"Failed to run command: {e!s}", + command=command, + ) diff --git a/src/codegen/extensions/tools/commit.py b/src/codegen/extensions/tools/commit.py index eb47c17cf..3bd931756 100644 --- a/src/codegen/extensions/tools/commit.py +++ b/src/codegen/extensions/tools/commit.py @@ -1,18 +1,42 @@ """Tool for committing changes to disk.""" -from typing import Any +from typing import ClassVar + +from pydantic import Field from codegen import Codebase +from .observation import Observation + + +class CommitObservation(Observation): + """Response from committing changes to disk.""" + + message: str = Field( + description="Message describing the commit result", + ) + + str_template: ClassVar[str] = "{message}" + -def commit(codebase: Codebase) -> dict[str, Any]: +def commit(codebase: Codebase) -> CommitObservation: """Commit any pending changes to disk. 
Args: codebase: The codebase to operate on Returns: - Dict containing commit status + CommitObservation containing commit status """ - codebase.commit() - return {"status": "success", "message": "Changes committed to disk"} + try: + codebase.commit() + return CommitObservation( + status="success", + message="Changes committed to disk", + ) + except Exception as e: + return CommitObservation( + status="error", + error=f"Failed to commit changes: {e!s}", + message="Failed to commit changes", + ) diff --git a/src/codegen/extensions/tools/create_file.py b/src/codegen/extensions/tools/create_file.py index d340e4b1c..903cd11bf 100644 --- a/src/codegen/extensions/tools/create_file.py +++ b/src/codegen/extensions/tools/create_file.py @@ -1,13 +1,29 @@ """Tool for creating new files.""" -from typing import Any +from typing import ClassVar + +from pydantic import Field from codegen import Codebase -from .view_file import view_file +from .observation import Observation +from .view_file import ViewFileObservation, view_file + + +class CreateFileObservation(Observation): + """Response from creating a new file.""" + + filepath: str = Field( + description="Path to the created file", + ) + file_info: ViewFileObservation = Field( + description="Information about the created file", + ) + + str_template: ClassVar[str] = "Created file {filepath}" -def create_file(codebase: Codebase, filepath: str, content: str = "") -> dict[str, Any]: +def create_file(codebase: Codebase, filepath: str, content: str = "") -> CreateFileObservation: """Create a new file. Args: @@ -16,10 +32,45 @@ def create_file(codebase: Codebase, filepath: str, content: str = "") -> dict[st content: Initial file content Returns: - Dict containing new file state, or error information if file already exists + CreateFileObservation containing new file state, or error if file exists """ if codebase.has_file(filepath): - return {"error": f"File already exists: {filepath}"} - file = codebase.create_file(filepath, content=content) - codebase.commit() - return view_file(codebase, filepath) + return CreateFileObservation( + status="error", + error=f"File already exists: {filepath}", + filepath=filepath, + file_info=ViewFileObservation( + status="error", + error=f"File already exists: {filepath}", + filepath=filepath, + content="", + line_count=0, + ), + ) + + try: + file = codebase.create_file(filepath, content=content) + codebase.commit() + + # Get file info using view_file + file_info = view_file(codebase, filepath) + + return CreateFileObservation( + status="success", + filepath=filepath, + file_info=file_info, + ) + + except Exception as e: + return CreateFileObservation( + status="error", + error=f"Failed to create file: {e!s}", + filepath=filepath, + file_info=ViewFileObservation( + status="error", + error=f"Failed to create file: {e!s}", + filepath=filepath, + content="", + line_count=0, + ), + ) diff --git a/src/codegen/extensions/tools/delete_file.py b/src/codegen/extensions/tools/delete_file.py index 9703cd21f..1f23ef265 100644 --- a/src/codegen/extensions/tools/delete_file.py +++ b/src/codegen/extensions/tools/delete_file.py @@ -1,11 +1,25 @@ """Tool for deleting files.""" -from typing import Any +from typing import ClassVar + +from pydantic import Field from codegen import Codebase +from .observation import Observation + + +class DeleteFileObservation(Observation): + """Response from deleting a file.""" + + filepath: str = Field( + description="Path to the deleted file", + ) -def delete_file(codebase: Codebase, filepath: str) -> dict[str, 
Any]: + str_template: ClassVar[str] = "Deleted file {filepath}" + + +def delete_file(codebase: Codebase, filepath: str) -> DeleteFileObservation: """Delete a file. Args: @@ -13,15 +27,34 @@ def delete_file(codebase: Codebase, filepath: str) -> dict[str, Any]: filepath: Path to the file to delete Returns: - Dict containing deletion status, or error information if file not found + DeleteFileObservation containing deletion status, or error if file not found """ try: file = codebase.get_file(filepath) except ValueError: - return {"error": f"File not found: {filepath}"} + return DeleteFileObservation( + status="error", + error=f"File not found: {filepath}", + filepath=filepath, + ) + if file is None: - return {"error": f"File not found: {filepath}"} + return DeleteFileObservation( + status="error", + error=f"File not found: {filepath}", + filepath=filepath, + ) - file.remove() - codebase.commit() - return {"status": "success", "deleted_file": filepath} + try: + file.remove() + codebase.commit() + return DeleteFileObservation( + status="success", + filepath=filepath, + ) + except Exception as e: + return DeleteFileObservation( + status="error", + error=f"Failed to delete file: {e!s}", + filepath=filepath, + ) diff --git a/src/codegen/extensions/tools/edit_file.py b/src/codegen/extensions/tools/edit_file.py index a7f19e448..50e85b73d 100644 --- a/src/codegen/extensions/tools/edit_file.py +++ b/src/codegen/extensions/tools/edit_file.py @@ -1,13 +1,29 @@ """Tool for editing file contents.""" -from typing import Any +from typing import ClassVar + +from pydantic import Field from codegen import Codebase -from .view_file import view_file +from .observation import Observation +from .view_file import ViewFileObservation, view_file + + +class EditFileObservation(Observation): + """Response from editing a file.""" + + filepath: str = Field( + description="Path to the edited file", + ) + file_info: ViewFileObservation = Field( + description="Information about the edited file", + ) + + str_template: ClassVar[str] = "Edited file {filepath}" -def edit_file(codebase: Codebase, filepath: str, content: str) -> dict[str, Any]: +def edit_file(codebase: Codebase, filepath: str, content: str) -> EditFileObservation: """Edit a file by replacing its entire content. 
Args: @@ -16,15 +32,61 @@ def edit_file(codebase: Codebase, filepath: str, content: str) -> dict[str, Any] content: New content for the file Returns: - Dict containing updated file state, or error information if file not found + EditFileObservation containing updated file state, or error if file not found """ try: file = codebase.get_file(filepath) except ValueError: - return {"error": f"File not found: {filepath}"} + return EditFileObservation( + status="error", + error=f"File not found: {filepath}", + filepath=filepath, + file_info=ViewFileObservation( + status="error", + error=f"File not found: {filepath}", + filepath=filepath, + content="", + line_count=0, + ), + ) + if file is None: - return {"error": f"File not found: {filepath}"} + return EditFileObservation( + status="error", + error=f"File not found: {filepath}", + filepath=filepath, + file_info=ViewFileObservation( + status="error", + error=f"File not found: {filepath}", + filepath=filepath, + content="", + line_count=0, + ), + ) + + try: + file.edit(content) + codebase.commit() + + # Get updated file info using view_file + file_info = view_file(codebase, filepath) + + return EditFileObservation( + status="success", + filepath=filepath, + file_info=file_info, + ) - file.edit(content) - codebase.commit() - return view_file(codebase, filepath) + except Exception as e: + return EditFileObservation( + status="error", + error=f"Failed to edit file: {e!s}", + filepath=filepath, + file_info=ViewFileObservation( + status="error", + error=f"Failed to edit file: {e!s}", + filepath=filepath, + content="", + line_count=0, + ), + ) diff --git a/src/codegen/extensions/tools/list_directory.py b/src/codegen/extensions/tools/list_directory.py index 903983a45..ed76a953b 100644 --- a/src/codegen/extensions/tools/list_directory.py +++ b/src/codegen/extensions/tools/list_directory.py @@ -1,12 +1,37 @@ """Tool for listing directory contents.""" -from typing import Any +from typing import ClassVar, Union + +from pydantic import BaseModel, Field from codegen import Codebase from codegen.sdk.core.directory import Directory +from .observation import Observation + + +class DirectoryInfo(BaseModel): + """Information about a directory.""" + + name: str = Field(description="Name of the directory") + path: str = Field(description="Full path to the directory") + files: list[str] = Field(description="List of files in this directory") + subdirectories: list[Union[str, "DirectoryInfo"]] = Field( + description="List of subdirectories (either names or full DirectoryInfo objects depending on depth)", + ) + + +class ListDirectoryObservation(Observation): + """Response from listing directory contents.""" -def list_directory(codebase: Codebase, dirpath: str = "./", depth: int = 1) -> dict[str, Any]: + path: str = Field(description="Path to the listed directory") + directory_info: DirectoryInfo = Field(description="Information about the directory and its contents") + depth: int = Field(description="How deep the directory traversal went") + + str_template: ClassVar[str] = "Listed contents of {path} (depth={depth})" + + +def list_directory(codebase: Codebase, dirpath: str = "./", depth: int = 1) -> ListDirectoryObservation: """List contents of a directory. Args: @@ -16,31 +41,39 @@ def list_directory(codebase: Codebase, dirpath: str = "./", depth: int = 1) -> d Use -1 for unlimited depth. 
Returns: - Dict containing directory contents and metadata in a nested structure: - { - "path": str, - "name": str, - "files": list[str], - "subdirectories": [ - { - "path": str, - "name": str, - "files": list[str], - "subdirectories": [...], - }, - ... - ] - } + ListDirectoryObservation containing directory contents and metadata """ try: directory = codebase.get_directory(dirpath) except ValueError: - return {"error": f"Directory not found: {dirpath}"} + return ListDirectoryObservation( + status="error", + error=f"Directory not found: {dirpath}", + path=dirpath, + directory_info=DirectoryInfo( + name="", + path=dirpath, + files=[], + subdirectories=[], + ), + depth=depth, + ) if not directory: - return {"error": f"Directory not found: {dirpath}"} + return ListDirectoryObservation( + status="error", + error=f"Directory not found: {dirpath}", + path=dirpath, + directory_info=DirectoryInfo( + name="", + path=dirpath, + files=[], + subdirectories=[], + ), + depth=depth, + ) - def get_directory_info(dir_obj: Directory, current_depth: int) -> dict[str, Any]: + def get_directory_info(dir_obj: Directory, current_depth: int) -> DirectoryInfo: """Helper function to get directory info recursively.""" # Get direct files all_files = [] @@ -59,11 +92,32 @@ def get_directory_info(dir_obj: Directory, current_depth: int) -> dict[str, Any] else: # At max depth, just include name subdirs.append(subdir.name) - return { - "name": dir_obj.name, - "path": dir_obj.dirpath, - "files": all_files, - "subdirectories": subdirs, - } - - return get_directory_info(directory, depth) + + return DirectoryInfo( + name=dir_obj.name, + path=dir_obj.dirpath, + files=all_files, + subdirectories=subdirs, + ) + + try: + directory_info = get_directory_info(directory, depth) + return ListDirectoryObservation( + status="success", + path=dirpath, + directory_info=directory_info, + depth=depth, + ) + except Exception as e: + return ListDirectoryObservation( + status="error", + error=f"Failed to list directory: {e!s}", + path=dirpath, + directory_info=DirectoryInfo( + name="", + path=dirpath, + files=[], + subdirectories=[], + ), + depth=depth, + ) diff --git a/src/codegen/extensions/tools/observation.py b/src/codegen/extensions/tools/observation.py new file mode 100644 index 000000000..9f9e23335 --- /dev/null +++ b/src/codegen/extensions/tools/observation.py @@ -0,0 +1,51 @@ +"""Base class for tool observations/responses.""" + +from typing import Any, ClassVar, Optional + +from pydantic import BaseModel, Field + + +class Observation(BaseModel): + """Base class for all tool observations. + + All tool responses should inherit from this class to ensure consistent + handling and string representations. + """ + + status: str = Field( + default="success", + description="Status of the operation - 'success' or 'error'", + ) + error: Optional[str] = Field( + default=None, + description="Error message if status is 'error'", + ) + + # Class variable to store a template for string representation + str_template: ClassVar[str] = "{status}: {details}" + + def _get_details(self) -> dict[str, Any]: + """Get the details to include in string representation. + + Override this in subclasses to customize string output. + By default, includes all fields except status and error. 
+ """ + return {k: v for k, v in self.model_dump().items() if k not in {"status", "error"} and v is not None} + + def __str__(self) -> str: + """Get string representation of the observation.""" + if self.status == "error": + return f"Error: {self.error}" + + details = self._get_details() + if not details: + return self.status + + return self.str_template.format( + status=self.status, + details=", ".join(f"{k}={v}" for k, v in details.items()), + ) + + def __repr__(self) -> str: + """Get detailed string representation of the observation.""" + return f"{self.__class__.__name__}({self.model_dump_json()})" From 6b21d844a7448072968d5512498f321ceb9fee35 Mon Sep 17 00:00:00 2001 From: jayhack Date: Wed, 19 Feb 2025 15:20:26 -0800 Subject: [PATCH 2/9] . --- src/codegen/extensions/tools/move_symbol.py | 116 ++++++++++++--- src/codegen/extensions/tools/rename_file.py | 79 ++++++++-- .../extensions/tools/replacement_edit.py | 49 ++++-- src/codegen/extensions/tools/reveal_symbol.py | 94 +++++++++--- src/codegen/extensions/tools/search.py | 140 ++++++++++++------ src/codegen/extensions/tools/semantic_edit.py | 84 +++++++++-- .../extensions/tools/semantic_search.py | 84 ++++++++--- src/codegen/extensions/tools/view_file.py | 51 +++++-- 8 files changed, 544 insertions(+), 153 deletions(-) diff --git a/src/codegen/extensions/tools/move_symbol.py b/src/codegen/extensions/tools/move_symbol.py index 26058599a..6a86be4e4 100644 --- a/src/codegen/extensions/tools/move_symbol.py +++ b/src/codegen/extensions/tools/move_symbol.py @@ -1,10 +1,35 @@ """Tool for moving symbols between files.""" -from typing import Any, Literal +from typing import ClassVar, Literal + +from pydantic import Field from codegen import Codebase -from .view_file import view_file +from .observation import Observation +from .view_file import ViewFileObservation, view_file + + +class MoveSymbolObservation(Observation): + """Response from moving a symbol between files.""" + + symbol_name: str = Field( + description="Name of the symbol that was moved", + ) + source_file: str = Field( + description="Path to the source file", + ) + target_file: str = Field( + description="Path to the target file", + ) + source_file_info: ViewFileObservation = Field( + description="Information about the source file after move", + ) + target_file_info: ViewFileObservation = Field( + description="Information about the target file after move", + ) + + str_template: ClassVar[str] = "Moved symbol {symbol_name} from {source_file} to {target_file}" def move_symbol( @@ -14,7 +39,7 @@ def move_symbol( target_file: str, strategy: Literal["update_all_imports", "add_back_edge"] = "update_all_imports", include_dependencies: bool = True, -) -> dict[str, Any]: +) -> MoveSymbolObservation: """Move a symbol from one file to another. 
Args: @@ -28,34 +53,89 @@ def move_symbol( include_dependencies: Whether to move dependencies along with the symbol Returns: - Dict containing move status and updated file info, or error information if operation fails + MoveSymbolObservation containing move status and updated file info """ try: source = codebase.get_file(source_file) except ValueError: - return {"error": f"Source file not found: {source_file}"} - if source is None: - return {"error": f"Source file not found: {source_file}"} + return MoveSymbolObservation( + status="error", + error=f"Source file not found: {source_file}", + symbol_name=symbol_name, + source_file=source_file, + target_file=target_file, + source_file_info=ViewFileObservation( + status="error", + error=f"Source file not found: {source_file}", + filepath=source_file, + content="", + line_count=0, + ), + target_file_info=ViewFileObservation( + status="error", + error=f"Source file not found: {source_file}", + filepath=target_file, + content="", + line_count=0, + ), + ) try: target = codebase.get_file(target_file) except ValueError: - return {"error": f"Target file not found: {target_file}"} + return MoveSymbolObservation( + status="error", + error=f"Target file not found: {target_file}", + symbol_name=symbol_name, + source_file=source_file, + target_file=target_file, + source_file_info=ViewFileObservation( + status="error", + error=f"Target file not found: {target_file}", + filepath=source_file, + content="", + line_count=0, + ), + target_file_info=ViewFileObservation( + status="error", + error=f"Target file not found: {target_file}", + filepath=target_file, + content="", + line_count=0, + ), + ) symbol = source.get_symbol(symbol_name) if not symbol: - return {"error": f"Symbol '{symbol_name}' not found in {source_file}"} + return MoveSymbolObservation( + status="error", + error=f"Symbol '{symbol_name}' not found in {source_file}", + symbol_name=symbol_name, + source_file=source_file, + target_file=target_file, + source_file_info=view_file(codebase, source_file), + target_file_info=view_file(codebase, target_file), + ) try: symbol.move_to_file(target, include_dependencies=include_dependencies, strategy=strategy) codebase.commit() - return { - "status": "success", - "symbol": symbol_name, - "source_file": source_file, - "target_file": target_file, - "source_file_info": view_file(codebase, source_file), - "target_file_info": view_file(codebase, target_file), - } + + return MoveSymbolObservation( + status="success", + symbol_name=symbol_name, + source_file=source_file, + target_file=target_file, + source_file_info=view_file(codebase, source_file), + target_file_info=view_file(codebase, target_file), + ) except Exception as e: - return {"error": f"Failed to move symbol: {e!s}"} + return MoveSymbolObservation( + status="error", + error=f"Failed to move symbol: {e!s}", + symbol_name=symbol_name, + source_file=source_file, + target_file=target_file, + source_file_info=view_file(codebase, source_file), + target_file_info=view_file(codebase, target_file), + ) diff --git a/src/codegen/extensions/tools/rename_file.py b/src/codegen/extensions/tools/rename_file.py index 2be9df52b..ce4111865 100644 --- a/src/codegen/extensions/tools/rename_file.py +++ b/src/codegen/extensions/tools/rename_file.py @@ -1,13 +1,32 @@ """Tool for renaming files and updating imports.""" -from typing import Any +from typing import ClassVar + +from pydantic import Field from codegen import Codebase -from .view_file import view_file +from .observation import Observation +from .view_file import 
ViewFileObservation, view_file + + +class RenameFileObservation(Observation): + """Response from renaming a file.""" + old_filepath: str = Field( + description="Original path of the file", + ) + new_filepath: str = Field( + description="New path of the file", + ) + file_info: ViewFileObservation = Field( + description="Information about the renamed file", + ) -def rename_file(codebase: Codebase, filepath: str, new_filepath: str) -> dict[str, Any]: + str_template: ClassVar[str] = "Renamed file from {old_filepath} to {new_filepath}" + + +def rename_file(codebase: Codebase, filepath: str, new_filepath: str) -> RenameFileObservation: """Rename a file and update all imports to point to the new location. Args: @@ -16,21 +35,61 @@ def rename_file(codebase: Codebase, filepath: str, new_filepath: str) -> dict[st new_filepath: New path for the file relative to workspace root Returns: - Dict containing rename status and new file info, or error information if file not found + RenameFileObservation containing rename status and new file info """ try: file = codebase.get_file(filepath) except ValueError: - return {"error": f"File not found: {filepath}"} - if file is None: - return {"error": f"File not found: {filepath}"} + return RenameFileObservation( + status="error", + error=f"File not found: {filepath}", + old_filepath=filepath, + new_filepath=new_filepath, + file_info=ViewFileObservation( + status="error", + error=f"File not found: {filepath}", + filepath=filepath, + content="", + line_count=0, + ), + ) if codebase.has_file(new_filepath): - return {"error": f"Destination file already exists: {new_filepath}"} + return RenameFileObservation( + status="error", + error=f"Destination file already exists: {new_filepath}", + old_filepath=filepath, + new_filepath=new_filepath, + file_info=ViewFileObservation( + status="error", + error=f"Destination file already exists: {new_filepath}", + filepath=new_filepath, + content="", + line_count=0, + ), + ) try: file.update_filepath(new_filepath) codebase.commit() - return {"status": "success", "old_filepath": filepath, "new_filepath": new_filepath, "file_info": view_file(codebase, new_filepath)} + + return RenameFileObservation( + status="success", + old_filepath=filepath, + new_filepath=new_filepath, + file_info=view_file(codebase, new_filepath), + ) except Exception as e: - return {"error": f"Failed to rename file: {e!s}"} + return RenameFileObservation( + status="error", + error=f"Failed to rename file: {e!s}", + old_filepath=filepath, + new_filepath=new_filepath, + file_info=ViewFileObservation( + status="error", + error=f"Failed to rename file: {e!s}", + filepath=filepath, + content="", + line_count=0, + ), + ) diff --git a/src/codegen/extensions/tools/replacement_edit.py b/src/codegen/extensions/tools/replacement_edit.py index de3868ac7..7f04b94c0 100644 --- a/src/codegen/extensions/tools/replacement_edit.py +++ b/src/codegen/extensions/tools/replacement_edit.py @@ -2,13 +2,34 @@ import difflib import re -from typing import Optional +from typing import ClassVar, Optional + +from pydantic import Field from codegen import Codebase +from .observation import Observation from .view_file import add_line_numbers +class ReplacementEditObservation(Observation): + """Response from making regex-based replacements in a file.""" + + filepath: str = Field( + description="Path to the edited file", + ) + diff: Optional[str] = Field( + default=None, + description="Unified diff showing the changes made", + ) + new_content: Optional[str] = Field( + default=None, + 
description="New content with line numbers", + ) + + str_template: ClassVar[str] = "Edited file {filepath}" + + def generate_diff(original: str, modified: str) -> str: """Generate a unified diff between two strings. @@ -70,7 +91,7 @@ def replacement_edit( end: int = -1, count: Optional[int] = None, flags: re.RegexFlag = re.MULTILINE, -) -> dict[str, str]: +) -> ReplacementEditObservation: """Replace text in a file using regex pattern matching. Args: @@ -84,7 +105,7 @@ def replacement_edit( flags: Regex flags (default: re.MULTILINE) Returns: - Dict containing edit results and status + ReplacementEditObservation containing edit results and status Raises: FileNotFoundError: If file not found @@ -124,11 +145,11 @@ def replacement_edit( # If no changes were made, return early if new_section == section_content: - return { - "filepath": filepath, - "status": "unchanged", - "message": "No matches found for the given pattern", - } + return ReplacementEditObservation( + status="unchanged", + message="No matches found for the given pattern", + filepath=filepath, + ) # Merge the edited content with the original new_content = _merge_content(original_content, new_section, start, end) @@ -140,9 +161,9 @@ def replacement_edit( file.edit(new_content) codebase.commit() - return { - "filepath": filepath, - "diff": diff, - "status": "success", - "new_content": add_line_numbers(new_content), - } + return ReplacementEditObservation( + status="success", + filepath=filepath, + diff=diff, + new_content=add_line_numbers(new_content), + ) diff --git a/src/codegen/extensions/tools/reveal_symbol.py b/src/codegen/extensions/tools/reveal_symbol.py index 06897d88e..6279eaeba 100644 --- a/src/codegen/extensions/tools/reveal_symbol.py +++ b/src/codegen/extensions/tools/reveal_symbol.py @@ -1,6 +1,9 @@ -from typing import Any, Optional +"""Tool for revealing symbol dependencies and usages.""" + +from typing import Any, ClassVar, Optional import tiktoken +from pydantic import Field from codegen import Codebase from codegen.sdk.ai.utils import count_tokens @@ -8,6 +11,48 @@ from codegen.sdk.core.import_resolution import Import from codegen.sdk.core.symbol import Symbol +from .observation import Observation + + +class SymbolInfo(Observation): + """Information about a symbol.""" + + name: str = Field(description="Name of the symbol") + filepath: Optional[str] = Field(description="Path to the file containing the symbol") + source: str = Field(description="Source code of the symbol") + + str_template: ClassVar[str] = "{name} in {filepath}" + + +class RevealSymbolObservation(Observation): + """Response from revealing symbol dependencies and usages.""" + + dependencies: Optional[list[SymbolInfo]] = Field( + default=None, + description="List of symbols this symbol depends on", + ) + usages: Optional[list[SymbolInfo]] = Field( + default=None, + description="List of symbols that use this symbol", + ) + truncated: bool = Field( + default=False, + description="Whether results were truncated due to token limit", + ) + valid_filepaths: Optional[list[str]] = Field( + default=None, + description="List of valid filepaths when symbol is ambiguous", + ) + + str_template: ClassVar[str] = "Symbol info: {dependencies_count} dependencies, {usages_count} usages" + + def _get_details(self) -> dict[str, Any]: + """Get details for string representation.""" + return { + "dependencies_count": len(self.dependencies or []), + "usages_count": len(self.usages or []), + } + def truncate_source(source: str, max_tokens: int) -> str: """Truncate source code to 
fit within max_tokens while preserving meaning. @@ -70,7 +115,7 @@ def truncate_source(source: str, max_tokens: int) -> str: return "".join(result) -def get_symbol_info(symbol: Symbol, max_tokens: Optional[int] = None) -> dict[str, Any]: +def get_symbol_info(symbol: Symbol, max_tokens: Optional[int] = None) -> SymbolInfo: """Get relevant information about a symbol. Args: @@ -84,11 +129,12 @@ def get_symbol_info(symbol: Symbol, max_tokens: Optional[int] = None) -> dict[st if max_tokens: source = truncate_source(source, max_tokens) - return { - "name": symbol.name, - "filepath": symbol.file.filepath if symbol.file else None, - "source": source, - } + return SymbolInfo( + status="success", + name=symbol.name, + filepath=symbol.file.filepath if symbol.file else None, + source=source, + ) def hop_through_imports(symbol: Symbol, seen_imports: Optional[set[str]] = None) -> Symbol: @@ -122,7 +168,7 @@ def get_extended_context( total_tokens: int = 0, collect_dependencies: bool = True, collect_usages: bool = True, -) -> tuple[list[dict[str, Any]], list[dict[str, Any]], int]: +) -> tuple[list[SymbolInfo], list[SymbolInfo], int]: """Recursively collect dependencies and usages up to specified degree. Args: @@ -164,7 +210,7 @@ def under_token_limit() -> bool: if dep not in seen_symbols: # Calculate tokens for this symbol info = get_symbol_info(dep, max_tokens=max_tokens) - symbol_tokens = count_tokens(info["source"]) if info["source"] else 0 + symbol_tokens = count_tokens(info.source) if info.source else 0 if max_tokens and total_tokens + symbol_tokens > max_tokens: continue @@ -189,7 +235,7 @@ def under_token_limit() -> bool: if usage not in seen_symbols: # Calculate tokens for this symbol info = get_symbol_info(usage, max_tokens=max_tokens) - symbol_tokens = count_tokens(info["source"]) if info["source"] else 0 + symbol_tokens = count_tokens(info.source) if info.source else 0 if max_tokens and total_tokens + symbol_tokens > max_tokens: continue @@ -214,7 +260,7 @@ def reveal_symbol( max_tokens: Optional[int] = None, collect_dependencies: Optional[bool] = True, collect_usages: Optional[bool] = True, -) -> dict[str, Any]: +) -> RevealSymbolObservation: """Reveal the dependencies and usages of a symbol up to N degrees. 
Args: @@ -235,22 +281,36 @@ def reveal_symbol( """ symbols = codebase.get_symbols(symbol_name=symbol_name) if len(symbols) == 0: - return {"error": f"{symbol_name} not found"} + return RevealSymbolObservation( + status="error", + error=f"{symbol_name} not found", + ) if len(symbols) > 1: - return {"error": f"{symbol_name} is ambiguious", "valid_filepaths": [s.file.filepath for s in symbols]} + return RevealSymbolObservation( + status="error", + error=f"{symbol_name} is ambiguous", + valid_filepaths=[s.file.filepath for s in symbols], + ) symbol = symbols[0] if filepath: if symbol.file.filepath != filepath: - return {"error": f"{symbol_name} not found at {filepath}", "valid_filepaths": [s.file.filepath for s in symbols]} + return RevealSymbolObservation( + status="error", + error=f"{symbol_name} not found at {filepath}", + valid_filepaths=[s.file.filepath for s in symbols], + ) # Get dependencies and usages up to specified degree dependencies, usages, total_tokens = get_extended_context(symbol, max_depth, max_tokens, collect_dependencies=collect_dependencies, collect_usages=collect_usages) was_truncated = max_tokens is not None and total_tokens >= max_tokens - result = {"truncated": was_truncated} + result = RevealSymbolObservation( + status="success", + truncated=was_truncated, + ) if collect_dependencies: - result["dependencies"] = dependencies + result.dependencies = dependencies if collect_usages: - result["usages"] = usages + result.usages = usages return result diff --git a/src/codegen/extensions/tools/search.py b/src/codegen/extensions/tools/search.py index 7d81e412a..0923f6837 100644 --- a/src/codegen/extensions/tools/search.py +++ b/src/codegen/extensions/tools/search.py @@ -6,10 +6,72 @@ """ import re -from typing import Any, Optional +from typing import ClassVar, Optional + +from pydantic import Field from codegen import Codebase +from .observation import Observation + + +class SearchMatch(Observation): + """Information about a single line match.""" + + line_number: int = Field( + description="1-based line number of the match", + ) + line: str = Field( + description="The full line containing the match", + ) + match: str = Field( + description="The specific text that matched", + ) + + str_template: ClassVar[str] = "Line {line_number}: {match}" + + +class SearchFileResult(Observation): + """Search results for a single file.""" + + filepath: str = Field( + description="Path to the file containing matches", + ) + matches: list[SearchMatch] = Field( + description="List of matches found in this file", + ) + + str_template: ClassVar[str] = "{filepath}: {match_count} matches" + + def _get_details(self) -> dict[str, str | int]: + """Get details for string representation.""" + return {"match_count": len(self.matches)} + + +class SearchObservation(Observation): + """Response from searching the codebase.""" + + query: str = Field( + description="The search query that was used", + ) + page: int = Field( + description="Current page number (1-based)", + ) + total_pages: int = Field( + description="Total number of pages available", + ) + total_files: int = Field( + description="Total number of files with matches", + ) + files_per_page: int = Field( + description="Number of files shown per page", + ) + results: list[SearchFileResult] = Field( + description="Search results for this page", + ) + + str_template: ClassVar[str] = "Found {total_files} files with matches for '{query}' (page {page}/{total_pages})" + def search( codebase: Codebase, @@ -19,7 +81,7 @@ def search( page: int = 1, 
files_per_page: int = 10, use_regex: bool = False, -) -> dict[str, Any]: +) -> SearchObservation: """Search the codebase using text search or regex pattern matching. If use_regex is True, performs a regex pattern match on each line. @@ -38,29 +100,7 @@ def search( use_regex: Whether to treat query as a regex pattern (default: False) Returns: - Dict containing search results with matches and their sources, grouped by file: - { - "query": str, - "page": int, - "total_pages": int, - "total_files": int, - "files_per_page": int, - "results": [ - { - "filepath": str, - "matches": [ - { - "line_number": int, # 1-based line number - "line": str, # The full line containing the match - "match": str, # The specific text that matched - } - ] - } - ] - } - - Raises: - re.error: If use_regex is True and the regex pattern is invalid + SearchObservation containing search results with matches and their sources """ # Validate pagination parameters if page < 1: @@ -73,8 +113,16 @@ def search( try: pattern = re.compile(query) except re.error as e: - msg = f"Invalid regex pattern: {e!s}" - raise re.error(msg) from e + return SearchObservation( + status="error", + error=f"Invalid regex pattern: {e!s}", + query=query, + page=page, + total_pages=0, + total_files=0, + files_per_page=files_per_page, + results=[], + ) else: # For non-regex searches, escape special characters and make case-insensitive pattern = re.compile(re.escape(query), re.IGNORECASE) @@ -103,18 +151,25 @@ def search( match = pattern.search(line) if match: file_matches.append( - { - "line_number": line_number, - "line": line.strip(), - "match": match.group(0), # The full matched text - } + SearchMatch( + status="success", + line_number=line_number, + line=line.strip(), + match=match.group(0), + ) ) if file_matches: - all_results.append({"filepath": file.filepath, "matches": sorted(file_matches, key=lambda x: x["line_number"])}) + all_results.append( + SearchFileResult( + status="success", + filepath=file.filepath, + matches=sorted(file_matches, key=lambda x: x.line_number), + ) + ) # Sort all results by filepath - all_results.sort(key=lambda x: x["filepath"]) + all_results.sort(key=lambda x: x.filepath) # Calculate pagination total_files = len(all_results) @@ -125,11 +180,12 @@ def search( # Get the current page of results paginated_results = all_results[start_idx:end_idx] - return { - "query": query, - "page": page, - "total_pages": total_pages, - "total_files": total_files, - "files_per_page": files_per_page, - "results": paginated_results, - } + return SearchObservation( + status="success", + query=query, + page=page, + total_pages=total_pages, + total_files=total_files, + files_per_page=files_per_page, + results=paginated_results, + ) diff --git a/src/codegen/extensions/tools/semantic_edit.py b/src/codegen/extensions/tools/semantic_edit.py index 81cd98188..642a245f4 100644 --- a/src/codegen/extensions/tools/semantic_edit.py +++ b/src/codegen/extensions/tools/semantic_edit.py @@ -2,16 +2,41 @@ import difflib import re +from typing import ClassVar, Optional from langchain_anthropic import ChatAnthropic from langchain_core.prompts import ChatPromptTemplate +from pydantic import Field from codegen import Codebase +from .observation import Observation from .semantic_edit_prompts import _HUMAN_PROMPT_DRAFT_EDITOR, COMMANDER_SYSTEM_PROMPT from .view_file import add_line_numbers +class SemanticEditObservation(Observation): + """Response from making semantic edits to a file.""" + + filepath: str = Field( + description="Path to the edited file", + ) + 
diff: Optional[str] = Field( + default=None, + description="Unified diff showing the changes made", + ) + new_content: Optional[str] = Field( + default=None, + description="New content with line numbers", + ) + line_count: Optional[int] = Field( + default=None, + description="Total number of lines in file", + ) + + str_template: ClassVar[str] = "Edited file {filepath}" + + def generate_diff(original: str, modified: str) -> str: """Generate a unified diff between two strings. @@ -53,7 +78,8 @@ def _extract_code_block(llm_response: str) -> str: matches = re.findall(pattern, llm_response.strip(), re.DOTALL) if not matches: - raise ValueError("LLM response must contain code wrapped in ``` blocks. Got response: " + llm_response[:200] + "...") + msg = "LLM response must contain code wrapped in ``` blocks. Got response: " + llm_response[:200] + "..." + raise ValueError(msg) # Return the last code block exactly as is return matches[-1] @@ -106,7 +132,7 @@ def _validate_edit_boundaries(original_lines: list[str], modified_lines: list[st raise ValueError(msg) -def semantic_edit(codebase: Codebase, filepath: str, edit_content: str, start: int = 1, end: int = -1) -> dict[str, str]: +def semantic_edit(codebase: Codebase, filepath: str, edit_content: str, start: int = 1, end: int = -1) -> SemanticEditObservation: """Edit a file using semantic editing with line range support. This is an internal api and should not be called by the LLM.""" try: file = codebase.get_file(filepath) @@ -121,15 +147,16 @@ def semantic_edit(codebase: Codebase, filepath: str, edit_content: str, start: i # Check if file is too large for full edit MAX_LINES = 300 if len(original_lines) > MAX_LINES and start == 1 and end == -1: - return { - "error": ( + return SemanticEditObservation( + status="error", + error=( f"File is {len(original_lines)} lines long. For files longer than {MAX_LINES} lines, " "please specify a line range using start and end parameters. " "You may need to make multiple targeted edits." 
), - "status": "error", - "line_count": len(original_lines), - } + filepath=filepath, + line_count=len(original_lines), + ) # Handle append mode if start == -1 and end == -1: @@ -137,7 +164,12 @@ def semantic_edit(codebase: Codebase, filepath: str, edit_content: str, start: i file.add_symbol_from_source(edit_content) codebase.commit() - return {"filepath": filepath, "content": file.content, "diff": generate_diff(original_content, file.content), "status": "success"} + return SemanticEditObservation( + status="success", + filepath=filepath, + new_content=file.content, + diff=generate_diff(original_content, file.content), + ) except Exception as e: msg = f"Failed to append content: {e!s}" raise ValueError(msg) @@ -167,21 +199,43 @@ def semantic_edit(codebase: Codebase, filepath: str, edit_content: str, start: i try: modified_segment = _extract_code_block(response.content) except ValueError as e: - msg = f"Failed to parse LLM response: {e!s}" - raise ValueError(msg) + return SemanticEditObservation( + status="error", + error=f"Failed to parse LLM response: {e!s}", + filepath=filepath, + ) # Merge the edited content with the original new_content = _merge_content(original_content, modified_segment, start, end) new_lines = new_content.splitlines() # Validate that no changes were made before the start line - _validate_edit_boundaries(original_lines, new_lines, start_idx, end_idx) + try: + _validate_edit_boundaries(original_lines, new_lines, start_idx, end_idx) + except ValueError as e: + return SemanticEditObservation( + status="error", + error=str(e), + filepath=filepath, + ) # Generate diff diff = generate_diff(original_content, new_content) # Apply the edit - file.edit(new_content) - codebase.commit() - - return {"filepath": filepath, "diff": diff, "status": "success", "new_content": add_line_numbers(new_content)} + try: + file.edit(new_content) + codebase.commit() + except Exception as e: + return SemanticEditObservation( + status="error", + error=f"Failed to apply edit: {e!s}", + filepath=filepath, + ) + + return SemanticEditObservation( + status="success", + filepath=filepath, + diff=diff, + new_content=add_line_numbers(new_content), + ) diff --git a/src/codegen/extensions/tools/semantic_search.py b/src/codegen/extensions/tools/semantic_search.py index 7acc071e9..f2876ca3e 100644 --- a/src/codegen/extensions/tools/semantic_search.py +++ b/src/codegen/extensions/tools/semantic_search.py @@ -1,10 +1,50 @@ """Semantic search over codebase files.""" -from typing import Any, Optional +from typing import ClassVar, Optional + +from pydantic import Field from codegen import Codebase from codegen.extensions.index.file_index import FileIndex +from .observation import Observation + + +class SearchResult(Observation): + """Information about a single search result.""" + + filepath: str = Field( + description="Path to the matching file", + ) + score: float = Field( + description="Similarity score of the match", + ) + preview: str = Field( + description="Preview of the file content", + ) + + str_template: ClassVar[str] = "{filepath} (score: {score})" + + +class SemanticSearchObservation(Observation): + """Response from semantic search over codebase.""" + + query: str = Field( + description="The search query that was used", + ) + results: list[SearchResult] = Field( + description="List of search results", + ) + + str_template: ClassVar[str] = "Found {result_count} results for '{query}'" + + def _get_details(self) -> dict[str, str | int]: + """Get details for string representation.""" + return { + 
"result_count": len(self.results), + "query": self.query, + } + def semantic_search( codebase: Codebase, @@ -12,7 +52,7 @@ def semantic_search( k: int = 5, preview_length: int = 200, index_path: Optional[str] = None, -) -> dict[str, Any]: +) -> SemanticSearchObservation: """Search the codebase using semantic similarity. This function provides semantic search over a codebase by using OpenAI's embeddings. @@ -31,23 +71,7 @@ def semantic_search( index_path: Optional path to a saved vector index Returns: - Dict containing search results or error information. Format: - { - "status": "success", - "query": str, - "results": [ - { - "filepath": str, - "score": float, - "preview": str - }, - ... - ] - } - Or on error: - { - "error": str - } + SemanticSearchObservation containing search results or error information. """ try: # Initialize vector index @@ -74,9 +98,25 @@ def semantic_search( if len(file.content) > preview_length: preview += "..." - formatted_results.append({"filepath": file.filepath, "score": float(score), "preview": preview}) + formatted_results.append( + SearchResult( + status="success", + filepath=file.filepath, + score=float(score), + preview=preview, + ) + ) - return {"status": "success", "query": query, "results": formatted_results} + return SemanticSearchObservation( + status="success", + query=query, + results=formatted_results, + ) except Exception as e: - return {"error": f"Failed to perform semantic search: {e!s}"} + return SemanticSearchObservation( + status="error", + error=f"Failed to perform semantic search: {e!s}", + query=query, + results=[], + ) diff --git a/src/codegen/extensions/tools/view_file.py b/src/codegen/extensions/tools/view_file.py index 8b7e70d7d..41d1276c5 100644 --- a/src/codegen/extensions/tools/view_file.py +++ b/src/codegen/extensions/tools/view_file.py @@ -1,9 +1,30 @@ """Tool for viewing file contents and metadata.""" -from typing import Any +from typing import ClassVar, Optional + +from pydantic import Field from codegen import Codebase +from .observation import Observation + + +class ViewFileObservation(Observation): + """Response from viewing a file.""" + + filepath: str = Field( + description="Path to the file", + ) + content: str = Field( + description="Content of the file", + ) + line_count: Optional[int] = Field( + default=None, + description="Number of lines in the file", + ) + + str_template: ClassVar[str] = "File {filepath} ({line_count} lines)" + def add_line_numbers(content: str) -> str: """Add line numbers to content. @@ -19,32 +40,32 @@ def add_line_numbers(content: str) -> str: return "\n".join(f"{i + 1:>{width}}|{line}" for i, line in enumerate(lines)) -def view_file(codebase: Codebase, filepath: str, line_numbers: bool = True) -> dict[str, Any]: +def view_file(codebase: Codebase, filepath: str, line_numbers: bool = True) -> ViewFileObservation: """View the contents and metadata of a file. Args: codebase: The codebase to operate on filepath: Path to the file relative to workspace root line_numbers: If True, add line numbers to the content (1-indexed) - - Returns: - Dict containing file contents and metadata, or error information if file not found """ - file = None - try: file = codebase.get_file(filepath) except ValueError: - pass - - if not file: - return {"error": f"File not found: {filepath}. Please use full filepath relative to workspace root."} + return ViewFileObservation( + status="error", + error=f"File not found: {filepath}. 
Please use full filepath relative to workspace root.", + filepath=filepath, + content="", + line_count=0, + ) content = file.content if line_numbers: content = add_line_numbers(content) - return { - "filepath": file.filepath, - "content": content, - } + return ViewFileObservation( + status="success", + filepath=file.filepath, + content=content, + line_count=len(content.splitlines()), + ) From ba400d29fe6dcde35ecace1c41675e68c9fd9bb3 Mon Sep 17 00:00:00 2001 From: jayhack Date: Wed, 19 Feb 2025 15:26:54 -0800 Subject: [PATCH 3/9] . --- src/codegen/extensions/langchain/tools.py | 46 +++++++++++---------- src/codegen/extensions/tools/observation.py | 5 +++ 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index d2783b868..529a171e2 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -60,7 +60,7 @@ def __init__(self, codebase: Codebase) -> None: def _run(self, filepath: str) -> str: result = view_file(self.codebase, filepath) - return json.dumps(result, indent=2) + return result.render() class ListDirectoryInput(BaseModel): @@ -83,7 +83,7 @@ def __init__(self, codebase: Codebase) -> None: def _run(self, dirpath: str = "./", depth: int = 1) -> str: result = list_directory(self.codebase, dirpath, depth) - return json.dumps(result, indent=2) + return result.render() class SearchInput(BaseModel): @@ -106,7 +106,7 @@ def __init__(self, codebase: Codebase) -> None: def _run(self, query: str, target_directories: Optional[list[str]] = None) -> str: result = search(self.codebase, query, target_directories) - return json.dumps(result, indent=2) + return result.render() class EditFileInput(BaseModel): @@ -129,7 +129,7 @@ def __init__(self, codebase: Codebase) -> None: def _run(self, filepath: str, content: str) -> str: result = edit_file(self.codebase, filepath, content) - return json.dumps(result, indent=2) + return result.render() class CreateFileInput(BaseModel): @@ -152,7 +152,7 @@ def __init__(self, codebase: Codebase) -> None: def _run(self, filepath: str, content: str = "") -> str: result = create_file(self.codebase, filepath, content) - return json.dumps(result, indent=2) + return result.render() class DeleteFileInput(BaseModel): @@ -174,7 +174,7 @@ def __init__(self, codebase: Codebase) -> None: def _run(self, filepath: str) -> str: result = delete_file(self.codebase, filepath) - return json.dumps(result, indent=2) + return result.render() class CommitTool(BaseTool): @@ -189,7 +189,7 @@ def __init__(self, codebase: Codebase) -> None: def _run(self) -> str: result = commit(self.codebase) - return json.dumps(result, indent=2) + return result.render() class RevealSymbolInput(BaseModel): @@ -232,7 +232,7 @@ def _run( collect_dependencies=collect_dependencies, collect_usages=collect_usages, ) - return json.dumps(result, indent=2) + return result.render() _SEMANTIC_EDIT_BRIEF = """Tool for file editing via an LLM delegate. Describe the changes you want to make and an expert will apply them to the file. 
@@ -277,7 +277,7 @@ def __init__(self, codebase: Codebase) -> None: def _run(self, filepath: str, edit_content: str, start: int = 1, end: int = -1) -> str: # Create the the draft editor mini llm result = semantic_edit(self.codebase, filepath, edit_content, start=start, end=end) - return json.dumps(result, indent=2) + return result.render() class RenameFileInput(BaseModel): @@ -300,7 +300,7 @@ def __init__(self, codebase: Codebase) -> None: def _run(self, filepath: str, new_filepath: str) -> str: result = rename_file(self.codebase, filepath, new_filepath) - return json.dumps(result, indent=2) + return result.render() class MoveSymbolInput(BaseModel): @@ -343,7 +343,7 @@ def _run( strategy=strategy, include_dependencies=include_dependencies, ) - return json.dumps(result, indent=2) + return result.render() class SemanticSearchInput(BaseModel): @@ -367,7 +367,7 @@ def __init__(self, codebase: Codebase) -> None: def _run(self, query: str, k: int = 5, preview_length: int = 200) -> str: result = semantic_search(self.codebase, query, k=k, preview_length=preview_length) - return json.dumps(result, indent=2) + return result.render() ######################################################################################################################## @@ -391,7 +391,7 @@ class RunBashCommandTool(BaseTool): def _run(self, command: str, is_background: bool = False) -> str: result = run_bash_command(command, is_background) - return json.dumps(result, indent=2) + return result.render() ######################################################################################################################## @@ -419,7 +419,7 @@ def __init__(self, codebase: Codebase) -> None: def _run(self, title: str, body: str) -> str: result = create_pr(self.codebase, title, body) - return json.dumps(result, indent=2) + return result.render() class GithubViewPRInput(BaseModel): @@ -441,6 +441,7 @@ def __init__(self, codebase: Codebase) -> None: def _run(self, pr_id: int) -> str: result = view_pr(self.codebase, pr_id) + return result.render() return json.dumps(result, indent=2) @@ -464,7 +465,7 @@ def __init__(self, codebase: Codebase) -> None: def _run(self, pr_number: int, body: str) -> str: result = create_pr_comment(self.codebase, pr_number, body) - return json.dumps(result, indent=2) + return result.render() class GithubCreatePRReviewCommentInput(BaseModel): @@ -510,7 +511,7 @@ def _run( side=side, start_line=start_line, ) - return json.dumps(result, indent=2) + return result.render() ######################################################################################################################## @@ -537,7 +538,7 @@ def __init__(self, client: LinearClient) -> None: def _run(self, issue_id: str) -> str: result = linear_get_issue_tool(self.client, issue_id) - return json.dumps(result, indent=2) + return result.render() class LinearGetIssueCommentsInput(BaseModel): @@ -559,7 +560,7 @@ def __init__(self, client: LinearClient) -> None: def _run(self, issue_id: str) -> str: result = linear_get_issue_comments_tool(self.client, issue_id) - return json.dumps(result, indent=2) + return result.render() class LinearCommentOnIssueInput(BaseModel): @@ -582,7 +583,7 @@ def __init__(self, client: LinearClient) -> None: def _run(self, issue_id: str, body: str) -> str: result = linear_comment_on_issue_tool(self.client, issue_id, body) - return json.dumps(result, indent=2) + return result.render() class LinearSearchIssuesInput(BaseModel): @@ -605,7 +606,7 @@ def __init__(self, client: LinearClient) -> None: def _run(self, query: 
str, limit: int = 10) -> str: result = linear_search_issues_tool(self.client, query, limit) - return json.dumps(result, indent=2) + return result.render() class LinearCreateIssueInput(BaseModel): @@ -629,7 +630,7 @@ def __init__(self, client: LinearClient) -> None: def _run(self, title: str, description: str | None = None, team_id: str | None = None) -> str: result = linear_create_issue_tool(self.client, title, description, team_id) - return json.dumps(result, indent=2) + return result.render() class LinearGetTeamsTool(BaseTool): @@ -644,7 +645,7 @@ def __init__(self, client: LinearClient) -> None: def _run(self) -> str: result = linear_get_teams_tool(self.client) - return json.dumps(result, indent=2) + return result.render() ######################################################################################################################## @@ -677,6 +678,7 @@ def __init__(self, codebase: Codebase, say: Callable[[str], None]) -> None: self.codebase = codebase def _run(self, content: str) -> str: + # TODO - pull this out into a separate function print("> Adding links to message") content_formatted = add_links_to_message(content, self.codebase) print("> Sending message to Slack") diff --git a/src/codegen/extensions/tools/observation.py b/src/codegen/extensions/tools/observation.py index 9f9e23335..ac76c2b91 100644 --- a/src/codegen/extensions/tools/observation.py +++ b/src/codegen/extensions/tools/observation.py @@ -1,5 +1,6 @@ """Base class for tool observations/responses.""" +import json from typing import Any, ClassVar, Optional from pydantic import BaseModel, Field @@ -49,3 +50,7 @@ def __str__(self) -> str: def __repr__(self) -> str: """Get detailed string representation of the observation.""" return f"{self.__class__.__name__}({self.model_dump_json()})" + + def render(self) -> str: + """Render the observation as a string.""" + return json.dumps(self.model_dump(), indent=2) From 0aefbf4a42c359fb551b0d31b9f09dea1b4d9e96 Mon Sep 17 00:00:00 2001 From: jayhack Date: Wed, 19 Feb 2025 15:32:07 -0800 Subject: [PATCH 4/9] . --- .../extensions/tools/github/create_pr.py | 65 ++++-- .../tools/github/create_pr_comment.py | 41 +++- .../tools/github/create_pr_review_comment.py | 54 +++-- .../extensions/tools/github/view_pr.py | 42 +++- src/codegen/extensions/tools/linear/linear.py | 201 ++++++++++++++++-- 5 files changed, 334 insertions(+), 69 deletions(-) diff --git a/src/codegen/extensions/tools/github/create_pr.py b/src/codegen/extensions/tools/github/create_pr.py index 1f6a7e307..0e10b850b 100644 --- a/src/codegen/extensions/tools/github/create_pr.py +++ b/src/codegen/extensions/tools/github/create_pr.py @@ -1,28 +1,50 @@ """Tool for creating pull requests.""" import uuid -from typing import Any +from typing import ClassVar from github import GithubException +from pydantic import Field from codegen import Codebase +from ..observation import Observation -def create_pr(codebase: Codebase, title: str, body: str) -> dict[str, Any]: + +class CreatePRObservation(Observation): + """Response from creating a pull request.""" + + url: str = Field( + description="URL of the created PR", + ) + number: int = Field( + description="PR number", + ) + title: str = Field( + description="Title of the PR", + ) + + str_template: ClassVar[str] = "Created PR #{number}: {title}" + + +def create_pr(codebase: Codebase, title: str, body: str) -> CreatePRObservation: """Create a PR for the current branch. 
Args: codebase: The codebase to operate on title: The title of the PR body: The body/description of the PR - - Returns: - Dict containing PR info, or error information if operation fails """ try: # Check for uncommitted changes and commit them if len(codebase.get_diff()) == 0: - return {"error": "No changes to create a PR."} + return CreatePRObservation( + status="error", + error="No changes to create a PR.", + url="", + number=0, + title=title, + ) # TODO: this is very jank. We should ideally check out the branch before # making the changes, but it looks like `codebase.checkout` blows away @@ -37,13 +59,26 @@ def create_pr(codebase: Codebase, title: str, body: str) -> dict[str, Any]: try: pr = codebase.create_pr(title=title, body=body) except GithubException as e: - print(e) - return {"error": "Failed to create PR. Check if the PR already exists."} - return { - "status": "success", - "url": pr.html_url, - "number": pr.number, - "title": pr.title, - } + return CreatePRObservation( + status="error", + error="Failed to create PR. Check if the PR already exists.", + url="", + number=0, + title=title, + ) + + return CreatePRObservation( + status="success", + url=pr.html_url, + number=pr.number, + title=pr.title, + ) + except Exception as e: - return {"error": f"Failed to create PR: {e!s}"} + return CreatePRObservation( + status="error", + error=f"Failed to create PR: {e!s}", + url="", + number=0, + title=title, + ) diff --git a/src/codegen/extensions/tools/github/create_pr_comment.py b/src/codegen/extensions/tools/github/create_pr_comment.py index ae8bcae40..3f538413c 100644 --- a/src/codegen/extensions/tools/github/create_pr_comment.py +++ b/src/codegen/extensions/tools/github/create_pr_comment.py @@ -1,27 +1,46 @@ """Tool for creating PR comments.""" -from typing import Any +from typing import ClassVar + +from pydantic import Field from codegen import Codebase +from ..observation import Observation + + +class PRCommentObservation(Observation): + """Response from creating a PR comment.""" + + pr_number: int = Field( + description="PR number the comment was added to", + ) + body: str = Field( + description="Content of the comment", + ) -def create_pr_comment(codebase: Codebase, pr_number: int, body: str) -> dict[str, Any]: + str_template: ClassVar[str] = "Added comment to PR #{pr_number}" + + +def create_pr_comment(codebase: Codebase, pr_number: int, body: str) -> PRCommentObservation: """Create a general comment on a pull request. 
Args: codebase: The codebase to operate on pr_number: The PR number to comment on body: The comment text - - Returns: - Dict containing comment status """ try: codebase.create_pr_comment(pr_number=pr_number, body=body) - return { - "status": "success", - "message": "Comment created successfully", - "pr_number": pr_number, - } + return PRCommentObservation( + status="success", + pr_number=pr_number, + body=body, + ) except Exception as e: - return {"error": f"Failed to create PR comment: {e!s}"} + return PRCommentObservation( + status="error", + error=f"Failed to create PR comment: {e!s}", + pr_number=pr_number, + body=body, + ) diff --git a/src/codegen/extensions/tools/github/create_pr_review_comment.py b/src/codegen/extensions/tools/github/create_pr_review_comment.py index 11f6d04b4..ac6ea6b86 100644 --- a/src/codegen/extensions/tools/github/create_pr_review_comment.py +++ b/src/codegen/extensions/tools/github/create_pr_review_comment.py @@ -1,9 +1,33 @@ """Tool for creating PR review comments.""" -from typing import Any, Optional +from typing import ClassVar, Optional + +from pydantic import Field from codegen import Codebase +from ..observation import Observation + + +class PRReviewCommentObservation(Observation): + """Response from creating a PR review comment.""" + + pr_number: int = Field( + description="PR number the comment was added to", + ) + path: str = Field( + description="File path the comment was added to", + ) + line: Optional[int] = Field( + default=None, + description="Line number the comment was added to", + ) + body: str = Field( + description="Content of the comment", + ) + + str_template: ClassVar[str] = "Added review comment to PR #{pr_number} at {path}:{line}" + def create_pr_review_comment( codebase: Codebase, @@ -14,7 +38,7 @@ def create_pr_review_comment( line: Optional[int] = None, side: Optional[str] = None, start_line: Optional[int] = None, -) -> dict[str, Any]: +) -> PRReviewCommentObservation: """Create an inline review comment on a specific line in a pull request. 
Args: @@ -26,9 +50,6 @@ def create_pr_review_comment( line: The line number to comment on side: Which version of the file to comment on ('LEFT' or 'RIGHT') start_line: For multi-line comments, the starting line - - Returns: - Dict containing comment status """ try: codebase.create_pr_review_comment( @@ -40,12 +61,19 @@ def create_pr_review_comment( side=side, start_line=start_line, ) - return { - "status": "success", - "message": "Review comment created successfully", - "pr_number": pr_number, - "path": path, - "line": line, - } + return PRReviewCommentObservation( + status="success", + pr_number=pr_number, + path=path, + line=line, + body=body, + ) except Exception as e: - return {"error": f"Failed to create PR review comment: {e!s}"} + return PRReviewCommentObservation( + status="error", + error=f"Failed to create PR review comment: {e!s}", + pr_number=pr_number, + path=path, + line=line, + body=body, + ) diff --git a/src/codegen/extensions/tools/github/view_pr.py b/src/codegen/extensions/tools/github/view_pr.py index c077b56a1..a698d80d2 100644 --- a/src/codegen/extensions/tools/github/view_pr.py +++ b/src/codegen/extensions/tools/github/view_pr.py @@ -1,21 +1,47 @@ """Tool for viewing PR contents and modified symbols.""" -from typing import Any +from typing import ClassVar + +from pydantic import Field from codegen import Codebase +from ..observation import Observation + + +class ViewPRObservation(Observation): + """Response from viewing a PR.""" + + pr_id: int = Field( + description="ID of the PR", + ) + patch: str = Field( + description="The PR's patch/diff content", + ) -def view_pr(codebase: Codebase, pr_id: int) -> dict[str, Any]: + str_template: ClassVar[str] = "PR #{pr_id}" + + +def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: """Get the diff and modified symbols of a PR. 
Args: codebase: The codebase to operate on pr_id: Number of the PR to get the contents for - - Returns: - Dict containing modified symbols and patch """ - modified_symbols, patch = codebase.get_modified_symbols_in_pr(pr_id) + try: + modified_symbols, patch = codebase.get_modified_symbols_in_pr(pr_id) + + return ViewPRObservation( + status="success", + pr_id=pr_id, + patch=patch, + ) - # Convert modified_symbols set to list for JSON serialization - return {"status": "success", "patch": patch} + except Exception as e: + return ViewPRObservation( + status="error", + error=f"Failed to view PR: {e!s}", + pr_id=pr_id, + patch="", + ) diff --git a/src/codegen/extensions/tools/linear/linear.py b/src/codegen/extensions/tools/linear/linear.py index c15cb5d10..d079f45d7 100644 --- a/src/codegen/extensions/tools/linear/linear.py +++ b/src/codegen/extensions/tools/linear/linear.py @@ -1,66 +1,223 @@ -from typing import Any +"""Tools for interacting with Linear.""" + +from typing import ClassVar + +from pydantic import Field from codegen.extensions.linear.linear_client import LinearClient +from ..observation import Observation + + +class LinearIssueObservation(Observation): + """Response from getting a Linear issue.""" + + issue_id: str = Field(description="ID of the issue") + issue_data: dict = Field(description="Full issue data") + + str_template: ClassVar[str] = "Issue {issue_id}" + + +class LinearCommentsObservation(Observation): + """Response from getting Linear issue comments.""" + + issue_id: str = Field(description="ID of the issue") + comments: list[dict] = Field(description="List of comments") + + str_template: ClassVar[str] = "{comment_count} comments on issue {issue_id}" + + def _get_details(self) -> dict[str, int]: + """Get details for string representation.""" + return {"comment_count": len(self.comments)} + + +class LinearCommentObservation(Observation): + """Response from commenting on a Linear issue.""" + + issue_id: str = Field(description="ID of the issue") + comment: dict = Field(description="Created comment data") + + str_template: ClassVar[str] = "Added comment to issue {issue_id}" + + +class LinearWebhookObservation(Observation): + """Response from registering a Linear webhook.""" + + webhook_url: str = Field(description="URL of the registered webhook") + team_id: str = Field(description="ID of the team") + response: dict = Field(description="Full webhook registration response") + + str_template: ClassVar[str] = "Registered webhook for team {team_id}" + + +class LinearSearchObservation(Observation): + """Response from searching Linear issues.""" + + query: str = Field(description="Search query used") + issues: list[dict] = Field(description="List of matching issues") + + str_template: ClassVar[str] = "Found {issue_count} issues matching '{query}'" + + def _get_details(self) -> dict[str, str | int]: + """Get details for string representation.""" + return { + "issue_count": len(self.issues), + "query": self.query, + } + + +class LinearCreateIssueObservation(Observation): + """Response from creating a Linear issue.""" + + title: str = Field(description="Title of the created issue") + team_id: str | None = Field(description="Team ID if specified") + issue_data: dict = Field(description="Created issue data") + + str_template: ClassVar[str] = "Created issue '{title}'" + + +class LinearTeamsObservation(Observation): + """Response from getting Linear teams.""" + + teams: list[dict] = Field(description="List of teams") + + str_template: ClassVar[str] = "Found {team_count} teams" + + def 
_get_details(self) -> dict[str, int]: + """Get details for string representation.""" + return {"team_count": len(self.teams)} + -def linear_get_issue_tool(client: LinearClient, issue_id: str) -> dict[str, Any]: +def linear_get_issue_tool(client: LinearClient, issue_id: str) -> LinearIssueObservation: """Get an issue by its ID.""" try: issue = client.get_issue(issue_id) - return {"status": "success", "issue": issue.dict()} + return LinearIssueObservation( + status="success", + issue_id=issue_id, + issue_data=issue.dict(), + ) except Exception as e: - return {"error": f"Failed to get issue: {e!s}"} + return LinearIssueObservation( + status="error", + error=f"Failed to get issue: {e!s}", + issue_id=issue_id, + issue_data={}, + ) -def linear_get_issue_comments_tool(client: LinearClient, issue_id: str) -> dict[str, Any]: +def linear_get_issue_comments_tool(client: LinearClient, issue_id: str) -> LinearCommentsObservation: """Get comments for a specific issue.""" try: comments = client.get_issue_comments(issue_id) - return {"status": "success", "comments": [comment.dict() for comment in comments]} + return LinearCommentsObservation( + status="success", + issue_id=issue_id, + comments=[comment.dict() for comment in comments], + ) except Exception as e: - return {"error": f"Failed to get issue comments: {e!s}"} + return LinearCommentsObservation( + status="error", + error=f"Failed to get issue comments: {e!s}", + issue_id=issue_id, + comments=[], + ) -def linear_comment_on_issue_tool(client: LinearClient, issue_id: str, body: str) -> dict[str, Any]: +def linear_comment_on_issue_tool(client: LinearClient, issue_id: str, body: str) -> LinearCommentObservation: """Add a comment to an issue.""" try: comment = client.comment_on_issue(issue_id, body) - return {"status": "success", "comment": comment} + return LinearCommentObservation( + status="success", + issue_id=issue_id, + comment=comment, + ) except Exception as e: - return {"error": f"Failed to comment on issue: {e!s}"} + return LinearCommentObservation( + status="error", + error=f"Failed to comment on issue: {e!s}", + issue_id=issue_id, + comment={}, + ) -def linear_register_webhook_tool(client: LinearClient, webhook_url: str, team_id: str, secret: str, enabled: bool, resource_types: list[str]) -> dict[str, Any]: +def linear_register_webhook_tool( + client: LinearClient, + webhook_url: str, + team_id: str, + secret: str, + enabled: bool, + resource_types: list[str], +) -> LinearWebhookObservation: """Register a webhook with Linear.""" try: response = client.register_webhook(webhook_url, team_id, secret, enabled, resource_types) - return {"status": "success", "response": response} + return LinearWebhookObservation( + status="success", + webhook_url=webhook_url, + team_id=team_id, + response=response, + ) except Exception as e: - return {"error": f"Failed to register webhook: {e!s}"} + return LinearWebhookObservation( + status="error", + error=f"Failed to register webhook: {e!s}", + webhook_url=webhook_url, + team_id=team_id, + response={}, + ) -def linear_search_issues_tool(client: LinearClient, query: str, limit: int = 10) -> dict[str, Any]: +def linear_search_issues_tool(client: LinearClient, query: str, limit: int = 10) -> LinearSearchObservation: """Search for issues using a query string.""" try: issues = client.search_issues(query, limit) - return {"status": "success", "issues": [issue.dict() for issue in issues]} + return LinearSearchObservation( + status="success", + query=query, + issues=[issue.dict() for issue in issues], + ) except 
Exception as e: - return {"error": f"Failed to search issues: {e!s}"} + return LinearSearchObservation( + status="error", + error=f"Failed to search issues: {e!s}", + query=query, + issues=[], + ) -def linear_create_issue_tool(client: LinearClient, title: str, description: str | None = None, team_id: str | None = None) -> dict[str, Any]: +def linear_create_issue_tool(client: LinearClient, title: str, description: str | None = None, team_id: str | None = None) -> LinearCreateIssueObservation: """Create a new issue.""" try: issue = client.create_issue(title, description, team_id) - return {"status": "success", "issue": issue.dict()} + return LinearCreateIssueObservation( + status="success", + title=title, + team_id=team_id, + issue_data=issue.dict(), + ) except Exception as e: - return {"error": f"Failed to create issue: {e!s}"} + return LinearCreateIssueObservation( + status="error", + error=f"Failed to create issue: {e!s}", + title=title, + team_id=team_id, + issue_data={}, + ) -def linear_get_teams_tool(client: LinearClient) -> dict[str, Any]: +def linear_get_teams_tool(client: LinearClient) -> LinearTeamsObservation: """Get all teams the authenticated user has access to.""" try: teams = client.get_teams() - return {"status": "success", "teams": [team.dict() for team in teams]} + return LinearTeamsObservation( + status="success", + teams=[team.dict() for team in teams], + ) except Exception as e: - return {"error": f"Failed to get teams: {e!s}"} + return LinearTeamsObservation( + status="error", + error=f"Failed to get teams: {e!s}", + teams=[], + ) From 85241f46b2364117b2a0f746f6a9dad5a6a72c12 Mon Sep 17 00:00:00 2001 From: jayhack Date: Wed, 19 Feb 2025 15:36:14 -0800 Subject: [PATCH 5/9] . --- tests/integration/extension/test_linear.py | 39 ++++---- tests/unit/codegen/extensions/test_tools.py | 101 ++++++++++---------- 2 files changed, 72 insertions(+), 68 deletions(-) diff --git a/tests/integration/extension/test_linear.py b/tests/integration/extension/test_linear.py index 4d39c9788..ece4ab677 100644 --- a/tests/integration/extension/test_linear.py +++ b/tests/integration/extension/test_linear.py @@ -30,30 +30,35 @@ def client() -> LinearClient: def test_linear_get_issue(client: LinearClient) -> None: """Test getting an issue from Linear.""" # Link to issue: https://linear.app/codegen-sh/issue/CG-10775/read-file-and-reveal-symbol-tool-size-limits - issue = linear_get_issue_tool(client, "CG-10775") - assert issue["status"] == "success" - assert issue["issue"]["id"] == "d5a7d6db-e20d-4d67-98f8-acedef6d3536" + result = linear_get_issue_tool(client, "CG-10775") + assert result.status == "success" + assert result.issue_id == "CG-10775" + assert result.issue_data["id"] == "d5a7d6db-e20d-4d67-98f8-acedef6d3536" def test_linear_get_issue_comments(client: LinearClient) -> None: """Test getting comments for an issue from Linear.""" - comments = linear_get_issue_comments_tool(client, "CG-10775") - assert comments["status"] == "success" - assert len(comments["comments"]) > 1 + result = linear_get_issue_comments_tool(client, "CG-10775") + assert result.status == "success" + assert result.issue_id == "CG-10775" + assert len(result.comments) > 1 def test_linear_comment_on_issue(client: LinearClient) -> None: """Test commenting on a Linear issue.""" test_comment = "Test comment from automated testing" result = linear_comment_on_issue_tool(client, "CG-10775", test_comment) - assert result["status"] == "success" + assert result.status == "success" + assert result.issue_id == "CG-10775" + assert 
result.comment["body"] == test_comment def test_search_issues(client: LinearClient) -> None: """Test searching for issues in Linear.""" - issues = linear_search_issues_tool(client, "REVEAL_SYMBOL") - assert issues["status"] == "success" - assert len(issues["issues"]) > 0 + result = linear_search_issues_tool(client, "REVEAL_SYMBOL") + assert result.status == "success" + assert result.query == "REVEAL_SYMBOL" + assert len(result.issues) > 0 def test_create_issue(client: LinearClient) -> None: @@ -76,20 +81,20 @@ def test_create_issue(client: LinearClient) -> None: # Test the tool wrapper with default team_id result = linear_create_issue_tool(client, "Test Tool Issue", "Test description from tool") - assert result["status"] == "success" - assert result["issue"]["title"] == "Test Tool Issue" - assert result["issue"]["description"] == "Test description from tool" + assert result.status == "success" + assert result.title == "Test Tool Issue" + assert result.issue_data["title"] == "Test Tool Issue" + assert result.issue_data["description"] == "Test description from tool" def test_get_teams(client: LinearClient) -> None: """Test getting teams from Linear.""" result = linear_get_teams_tool(client) - assert result["status"] == "success" - assert len(result["teams"]) > 0 + assert result.status == "success" + assert len(result.teams) > 0 # Verify team structure - team = result["teams"][0] - print(result) + team = result.teams[0] assert "id" in team assert "name" in team assert "key" in team diff --git a/tests/unit/codegen/extensions/test_tools.py b/tests/unit/codegen/extensions/test_tools.py index 458159268..eeef231c3 100644 --- a/tests/unit/codegen/extensions/test_tools.py +++ b/tests/unit/codegen/extensions/test_tools.py @@ -42,53 +42,53 @@ def greet(self): def test_view_file(codebase): """Test viewing a file.""" result = view_file(codebase, "src/main.py") - assert "error" not in result - assert result["filepath"] == "src/main.py" - assert "hello()" in result["content"] + assert result.status == "success" + assert result.filepath == "src/main.py" + assert "hello()" in result.content def test_list_directory(codebase): """Test listing directory contents.""" result = list_directory(codebase, "./") - assert "error" not in result - assert "src" in result["subdirectories"] + assert result.status == "success" + assert "src" in result.directory_info.subdirectories def test_search(codebase): """Test searching the codebase.""" result = search(codebase, "hello") - assert "error" not in result - assert len(result["results"]) > 0 + assert result.status == "success" + assert len(result.results) > 0 def test_edit_file(codebase): """Test editing a file.""" result = edit_file(codebase, "src/main.py", "print('edited')") - assert "error" not in result - assert result["content"] == "1|print('edited')" + assert result.status == "success" + assert result.file_info.content == "1|print('edited')" def test_create_file(codebase): """Test creating a file.""" result = create_file(codebase, "src/new.py", "print('new')") - assert "error" not in result - assert result["filepath"] == "src/new.py" - assert result["content"] == "1|print('new')" + assert result.status == "success" + assert result.filepath == "src/new.py" + assert result.file_info.content == "1|print('new')" def test_delete_file(codebase): """Test deleting a file.""" result = delete_file(codebase, "src/main.py") - assert "error" not in result - assert result["status"] == "success" + assert result.status == "success" + assert result.filepath == "src/main.py" def 
test_rename_file(codebase): """Test renaming a file.""" result = rename_file(codebase, "src/main.py", "src/renamed.py") - assert "error" not in result - assert result["status"] == "success" - assert result["new_filepath"] == "src/renamed.py" + assert result.status == "success" + assert result.old_filepath == "src/main.py" + assert result.new_filepath == "src/renamed.py" def test_move_symbol(codebase): @@ -102,8 +102,10 @@ def test_move_symbol(codebase): symbol_name="hello", target_file="src/target.py", ) - assert "error" not in result - assert result["status"] == "success" + assert result.status == "success" + assert result.symbol_name == "hello" + assert result.source_file == "src/main.py" + assert result.target_file == "src/target.py" def test_reveal_symbol(codebase): @@ -113,8 +115,8 @@ def test_reveal_symbol(codebase): symbol_name="hello", max_depth=1, ) - assert "error" not in result - assert not result["truncated"] + assert result.status == "success" + assert not result.truncated @pytest.mark.skip("TODO") @@ -127,43 +129,42 @@ def hello(): # ... existing code ... """ result = semantic_edit(codebase, "src/main.py", edit_spec) - assert "error" not in result - assert result["status"] == "success" + assert result.status == "success" + assert "Hello from semantic edit!" in result.new_content @pytest.mark.skip("TODO") def test_semantic_search(codebase): """Test semantic search.""" result = semantic_search(codebase, "function that prints hello") - assert "error" not in result - assert result["status"] == "success" + assert result.status == "success" + assert len(result.results) > 0 @pytest.mark.skip("TODO: Github tests") def test_create_pr(codebase): """Test creating a PR.""" result = create_pr(codebase, "Test PR", "This is a test PR") - assert "error" not in result - assert result["status"] == "success" + assert result.status == "success" + assert result.title == "Test PR" @pytest.mark.skip("TODO: Github tests") def test_view_pr(codebase): """Test viewing a PR.""" result = view_pr(codebase, 1) - assert "error" not in result - assert result["status"] == "success" - assert "modified_symbols" in result - assert "patch" in result + assert result.status == "success" + assert result.pr_id == 1 + assert result.patch != "" @pytest.mark.skip("TODO: Github tests") def test_create_pr_comment(codebase): """Test creating a PR comment.""" result = create_pr_comment(codebase, 1, "Test comment") - assert "error" not in result - assert result["status"] == "success" - assert result["message"] == "Comment created successfully" + assert result.status == "success" + assert result.pr_number == 1 + assert result.body == "Test comment" @pytest.mark.skip("TODO: Github tests") @@ -177,9 +178,11 @@ def test_create_pr_review_comment(codebase): path="src/main.py", line=1, ) - assert "error" not in result - assert result["status"] == "success" - assert result["message"] == "Review comment created successfully" + assert result.status == "success" + assert result.pr_number == 1 + assert result.path == "src/main.py" + assert result.line == 1 + assert result.body == "Test review comment" def test_replacement_edit(codebase): @@ -191,9 +194,8 @@ def test_replacement_edit(codebase): pattern=r'print\("Hello, world!"\)', replacement='print("Goodbye, world!")', ) - assert "error" not in result - assert result["status"] == "success" - assert 'print("Goodbye, world!")' in result["new_content"] + assert result.status == "success" + assert 'print("Goodbye, world!")' in result.new_content # Test with line range result = 
replacement_edit( @@ -204,9 +206,8 @@ def test_replacement_edit(codebase): start=5, # Class definition line end=7, ) - assert "error" not in result - assert result["status"] == "success" - assert "class Welcomer" in result["new_content"] + assert result.status == "success" + assert "class Welcomer" in result.new_content # Test with regex groups result = replacement_edit( @@ -215,9 +216,8 @@ def test_replacement_edit(codebase): pattern=r"def (\w+)\(\):", replacement=r"def \1_function():", ) - assert "error" not in result - assert result["status"] == "success" - assert "def hello_function():" in result["new_content"] + assert result.status == "success" + assert "def hello_function():" in result.new_content # Test with count limit result = replacement_edit( @@ -227,9 +227,8 @@ def test_replacement_edit(codebase): replacement="async def", count=1, # Only replace first occurrence ) - assert "error" not in result - assert result["status"] == "success" - assert result["new_content"].count("async def") == 1 + assert result.status == "success" + assert result.new_content.count("async def") == 1 # Test no matches result = replacement_edit( @@ -238,5 +237,5 @@ def test_replacement_edit(codebase): pattern=r"nonexistent_pattern", replacement="replacement", ) - assert result["status"] == "unchanged" - assert "No matches found" in result["message"] + assert result.status == "unchanged" + assert "No matches found" in str(result) From 809458eec11906735f197920020feb0f5f481789 Mon Sep 17 00:00:00 2001 From: jayhack Date: Wed, 19 Feb 2025 15:48:51 -0800 Subject: [PATCH 6/9] . --- src/codegen/extensions/tools/observation.py | 11 ++--------- src/codegen/extensions/tools/replacement_edit.py | 6 +++++- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/codegen/extensions/tools/observation.py b/src/codegen/extensions/tools/observation.py index ac76c2b91..512b10117 100644 --- a/src/codegen/extensions/tools/observation.py +++ b/src/codegen/extensions/tools/observation.py @@ -31,21 +31,14 @@ def _get_details(self) -> dict[str, Any]: Override this in subclasses to customize string output. By default, includes all fields except status and error. 
""" - return {k: v for k, v in self.model_dump().items() if k not in {"status", "error"} and v is not None} + return self.model_dump() def __str__(self) -> str: """Get string representation of the observation.""" if self.status == "error": return f"Error: {self.error}" - details = self._get_details() - if not details: - return self.status - - return self.str_template.format( - status=self.status, - details=", ".join(f"{k}={v}" for k, v in details.items()), - ) + return self.render() def __repr__(self) -> str: """Get detailed string representation of the observation.""" diff --git a/src/codegen/extensions/tools/replacement_edit.py b/src/codegen/extensions/tools/replacement_edit.py index 7f04b94c0..3bac90e17 100644 --- a/src/codegen/extensions/tools/replacement_edit.py +++ b/src/codegen/extensions/tools/replacement_edit.py @@ -26,8 +26,12 @@ class ReplacementEditObservation(Observation): default=None, description="New content with line numbers", ) + message: Optional[str] = Field( + default=None, + description="Message describing the result", + ) - str_template: ClassVar[str] = "Edited file {filepath}" + str_template: ClassVar[str] = "{message}" if "{message}" else "Edited file {filepath}" def generate_diff(original: str, modified: str) -> str: From a70df48d7535204f6f40dfb1fdb22a21789f954e Mon Sep 17 00:00:00 2001 From: jayhack Date: Wed, 19 Feb 2025 16:07:35 -0800 Subject: [PATCH 7/9] . --- src/codegen/extensions/langchain/agent.py | 3 +- .../extensions/tools/list_directory.py | 165 ++++++++---------- src/codegen/extensions/tools/view_file.py | 5 + 3 files changed, 82 insertions(+), 91 deletions(-) diff --git a/src/codegen/extensions/langchain/agent.py b/src/codegen/extensions/langchain/agent.py index 9d778072b..faff0a8b4 100644 --- a/src/codegen/extensions/langchain/agent.py +++ b/src/codegen/extensions/langchain/agent.py @@ -24,7 +24,6 @@ RevealSymbolTool, SearchTool, SemanticEditTool, - SemanticSearchTool, ViewFileTool, ) @@ -70,8 +69,8 @@ def create_codebase_agent( MoveSymbolTool(codebase), RevealSymbolTool(codebase), SemanticEditTool(codebase), - SemanticSearchTool(codebase), ReplacementEditTool(codebase), + # SemanticSearchTool(codebase), # =====[ Github Integration ]===== # Enable Github integration # GithubCreatePRTool(codebase), diff --git a/src/codegen/extensions/tools/list_directory.py b/src/codegen/extensions/tools/list_directory.py index ed76a953b..e08889010 100644 --- a/src/codegen/extensions/tools/list_directory.py +++ b/src/codegen/extensions/tools/list_directory.py @@ -1,123 +1,110 @@ """Tool for listing directory contents.""" -from typing import ClassVar, Union +from typing import ClassVar -from pydantic import BaseModel, Field +from pydantic import Field from codegen import Codebase -from codegen.sdk.core.directory import Directory from .observation import Observation -class DirectoryInfo(BaseModel): +class DirectoryInfo(Observation): """Information about a directory.""" - name: str = Field(description="Name of the directory") - path: str = Field(description="Full path to the directory") - files: list[str] = Field(description="List of files in this directory") - subdirectories: list[Union[str, "DirectoryInfo"]] = Field( - description="List of subdirectories (either names or full DirectoryInfo objects depending on depth)", + path: str = Field( + description="Path to the directory", ) + files: list[str] = Field( + description="List of files in the directory", + ) + subdirectories: list[str] = Field( + description="List of subdirectories", + ) + + str_template: ClassVar[str] 
= "Directory {path} ({file_count} files, {dir_count} subdirs)" + + def _get_details(self) -> dict[str, int]: + """Get details for string representation.""" + return { + "file_count": len(self.files), + "dir_count": len(self.subdirectories), + } + + def render(self) -> str: + """Render directory listing as a file tree.""" + lines = [ + f"[LIST DIRECTORY]: {self.path}", + "", + ] + + def add_tree_item(name: str, prefix: str = "", is_last: bool = False) -> str: + """Helper to format a tree item with proper prefix.""" + marker = "└── " if is_last else "├── " + return prefix + marker + name + + # Sort files and directories + items = [] + for f in sorted(self.files): + items.append((f, False)) # False = not a directory + for d in sorted(self.subdirectories): + items.append((d + "/", True)) # True = is a directory + + if not items: + lines.append("(empty directory)") + return "\n".join(lines) + + # Generate tree + for i, (name, is_dir) in enumerate(items): + is_last = i == len(items) - 1 + lines.append(add_tree_item(name, is_last=is_last)) + + return "\n".join(lines) class ListDirectoryObservation(Observation): """Response from listing directory contents.""" - path: str = Field(description="Path to the listed directory") - directory_info: DirectoryInfo = Field(description="Information about the directory and its contents") - depth: int = Field(description="How deep the directory traversal went") + directory_info: DirectoryInfo = Field( + description="Information about the directory", + ) + + str_template: ClassVar[str] = "{directory_info}" - str_template: ClassVar[str] = "Listed contents of {path} (depth={depth})" + def render(self) -> str: + """Render directory listing.""" + return self.directory_info.render() -def list_directory(codebase: Codebase, dirpath: str = "./", depth: int = 1) -> ListDirectoryObservation: - """List contents of a directory. +def list_directory(codebase: Codebase, path: str) -> ListDirectoryObservation: + """List the contents of a directory. Args: codebase: The codebase to operate on - dirpath: Path to directory relative to workspace root - depth: How deep to traverse the directory tree. Default is 1 (immediate children only). - Use -1 for unlimited depth. 
- - Returns: - ListDirectoryObservation containing directory contents and metadata + path: Path to the directory relative to workspace root """ try: - directory = codebase.get_directory(dirpath) - except ValueError: - return ListDirectoryObservation( - status="error", - error=f"Directory not found: {dirpath}", - path=dirpath, - directory_info=DirectoryInfo( - name="", - path=dirpath, - files=[], - subdirectories=[], - ), - depth=depth, - ) - - if not directory: - return ListDirectoryObservation( - status="error", - error=f"Directory not found: {dirpath}", - path=dirpath, - directory_info=DirectoryInfo( - name="", - path=dirpath, - files=[], - subdirectories=[], - ), - depth=depth, - ) - - def get_directory_info(dir_obj: Directory, current_depth: int) -> DirectoryInfo: - """Helper function to get directory info recursively.""" - # Get direct files - all_files = [] - for file in dir_obj.files: - if file.directory == dir_obj: - all_files.append(file.filepath.split("/")[-1]) - - # Get direct subdirectories - subdirs = [] - for subdir in dir_obj.subdirectories: - # Only include direct descendants - if subdir.parent == dir_obj: - if current_depth != 1: - new_depth = current_depth - 1 if current_depth > 1 else -1 - subdirs.append(get_directory_info(subdir, new_depth)) - else: - # At max depth, just include name - subdirs.append(subdir.name) - - return DirectoryInfo( - name=dir_obj.name, - path=dir_obj.dirpath, - files=all_files, + files, subdirs = codebase.list_directory(path) + dir_info = DirectoryInfo( + status="success", + path=path, + files=files, subdirectories=subdirs, ) - - try: - directory_info = get_directory_info(directory, depth) return ListDirectoryObservation( status="success", - path=dirpath, - directory_info=directory_info, - depth=depth, + directory_info=dir_info, ) except Exception as e: + dir_info = DirectoryInfo( + status="error", + error=str(e), + path=path, + files=[], + subdirectories=[], + ) return ListDirectoryObservation( status="error", - error=f"Failed to list directory: {e!s}", - path=dirpath, - directory_info=DirectoryInfo( - name="", - path=dirpath, - files=[], - subdirectories=[], - ), - depth=depth, + error=str(e), + directory_info=dir_info, ) diff --git a/src/codegen/extensions/tools/view_file.py b/src/codegen/extensions/tools/view_file.py index 41d1276c5..2cb29e652 100644 --- a/src/codegen/extensions/tools/view_file.py +++ b/src/codegen/extensions/tools/view_file.py @@ -25,6 +25,11 @@ class ViewFileObservation(Observation): str_template: ClassVar[str] = "File {filepath} ({line_count} lines)" + def render(self) -> str: + return f"""[VIEW FILE]: {self.filepath} ({self.line_count} lines) +{self.content} +""" + def add_line_numbers(content: str) -> str: """Add line numbers to content. From 3af5e1220dedb5faeebfba1a6d34eb650358bb11 Mon Sep 17 00:00:00 2001 From: jayhack Date: Wed, 19 Feb 2025 16:26:04 -0800 Subject: [PATCH 8/9] . 
--- src/codegen/extensions/tools/edit_file.py | 79 ++++------- .../extensions/tools/list_directory.py | 129 +++++++++++++----- src/codegen/extensions/tools/search.py | 38 +++++- 3 files changed, 155 insertions(+), 91 deletions(-) diff --git a/src/codegen/extensions/tools/edit_file.py b/src/codegen/extensions/tools/edit_file.py index 50e85b73d..d89818e6a 100644 --- a/src/codegen/extensions/tools/edit_file.py +++ b/src/codegen/extensions/tools/edit_file.py @@ -7,7 +7,7 @@ from codegen import Codebase from .observation import Observation -from .view_file import ViewFileObservation, view_file +from .replacement_edit import generate_diff class EditFileObservation(Observation): @@ -16,23 +16,26 @@ class EditFileObservation(Observation): filepath: str = Field( description="Path to the edited file", ) - file_info: ViewFileObservation = Field( - description="Information about the edited file", + diff: str = Field( + description="Unified diff showing the changes made", ) str_template: ClassVar[str] = "Edited file {filepath}" + def render(self) -> str: + """Render edit results in a clean format.""" + return f"""[EDIT FILE]: {self.filepath} -def edit_file(codebase: Codebase, filepath: str, content: str) -> EditFileObservation: - """Edit a file by replacing its entire content. +{self.diff}""" + + +def edit_file(codebase: Codebase, filepath: str, new_content: str) -> EditFileObservation: + """Edit the contents of a file. Args: codebase: The codebase to operate on - filepath: Path to the file to edit - content: New content for the file - - Returns: - EditFileObservation containing updated file state, or error if file not found + filepath: Path to the file relative to workspace root + new_content: New content for the file """ try: file = codebase.get_file(filepath) @@ -41,52 +44,18 @@ def edit_file(codebase: Codebase, filepath: str, content: str) -> EditFileObserv status="error", error=f"File not found: {filepath}", filepath=filepath, - file_info=ViewFileObservation( - status="error", - error=f"File not found: {filepath}", - filepath=filepath, - content="", - line_count=0, - ), - ) - - if file is None: - return EditFileObservation( - status="error", - error=f"File not found: {filepath}", - filepath=filepath, - file_info=ViewFileObservation( - status="error", - error=f"File not found: {filepath}", - filepath=filepath, - content="", - line_count=0, - ), + diff="", ) - try: - file.edit(content) - codebase.commit() + # Generate diff before making changes + diff = generate_diff(file.content, new_content) - # Get updated file info using view_file - file_info = view_file(codebase, filepath) + # Apply the edit + file.edit(new_content) + codebase.commit() - return EditFileObservation( - status="success", - filepath=filepath, - file_info=file_info, - ) - - except Exception as e: - return EditFileObservation( - status="error", - error=f"Failed to edit file: {e!s}", - filepath=filepath, - file_info=ViewFileObservation( - status="error", - error=f"Failed to edit file: {e!s}", - filepath=filepath, - content="", - line_count=0, - ), - ) + return EditFileObservation( + status="success", + filepath=filepath, + diff=diff, + ) diff --git a/src/codegen/extensions/tools/list_directory.py b/src/codegen/extensions/tools/list_directory.py index e08889010..b029f848a 100644 --- a/src/codegen/extensions/tools/list_directory.py +++ b/src/codegen/extensions/tools/list_directory.py @@ -5,6 +5,7 @@ from pydantic import Field from codegen import Codebase +from codegen.sdk.core.directory import Directory from .observation import 
Observation @@ -12,14 +13,17 @@ class DirectoryInfo(Observation): """Information about a directory.""" + name: str = Field( + description="Name of the directory", + ) path: str = Field( - description="Path to the directory", + description="Full path to the directory", ) files: list[str] = Field( - description="List of files in the directory", + description="List of files in this directory", ) - subdirectories: list[str] = Field( - description="List of subdirectories", + subdirectories: list["DirectoryInfo | str"] = Field( + description="List of subdirectories (full info or just names at max depth)", ) str_template: ClassVar[str] = "Directory {path} ({file_count} files, {dir_count} subdirs)" @@ -38,26 +42,56 @@ def render(self) -> str: "", ] - def add_tree_item(name: str, prefix: str = "", is_last: bool = False) -> str: + def add_tree_item(name: str, prefix: str = "", is_last: bool = False) -> tuple[str, str]: """Helper to format a tree item with proper prefix.""" marker = "└── " if is_last else "├── " - return prefix + marker + name + indent = " " if is_last else "│ " + return prefix + marker + name, prefix + indent + + def build_tree(items: list[tuple[str, bool, "DirectoryInfo | None"]], prefix: str = "") -> list[str]: + """Recursively build tree with proper indentation.""" + if not items: + return [] + + result = [] + for i, (name, is_dir, dir_info) in enumerate(items): + is_last = i == len(items) - 1 + line, new_prefix = add_tree_item(name, prefix, is_last) + result.append(line) + + # If this is a directory with full info, recursively add its contents + if dir_info and isinstance(dir_info, DirectoryInfo): + subitems = [] + # Add files first + for f in sorted(dir_info.files): + subitems.append((f, False, None)) + # Then add subdirectories + for d in dir_info.subdirectories: + if isinstance(d, DirectoryInfo): + subitems.append((d.name + "/", True, d)) + else: + subitems.append((d + "/", True, None)) + + result.extend(build_tree(subitems, new_prefix)) + + return result # Sort files and directories items = [] for f in sorted(self.files): - items.append((f, False)) # False = not a directory - for d in sorted(self.subdirectories): - items.append((d + "/", True)) # True = is a directory + items.append((f, False, None)) # (name, is_dir, dir_info) + for d in self.subdirectories: + if isinstance(d, DirectoryInfo): + items.append((d.name + "/", True, d)) + else: + items.append((d + "/", True, None)) if not items: lines.append("(empty directory)") return "\n".join(lines) # Generate tree - for i, (name, is_dir) in enumerate(items): - is_last = i == len(items) - 1 - lines.append(add_tree_item(name, is_last=is_last)) + lines.extend(build_tree(items)) return "\n".join(lines) @@ -76,35 +110,60 @@ def render(self) -> str: return self.directory_info.render() -def list_directory(codebase: Codebase, path: str) -> ListDirectoryObservation: - """List the contents of a directory. +def list_directory(codebase: Codebase, path: str = "./", depth: int = 2) -> ListDirectoryObservation: + """List contents of a directory. Args: codebase: The codebase to operate on - path: Path to the directory relative to workspace root + path: Path to directory relative to workspace root + depth: How deep to traverse the directory tree. Default is 1 (immediate children only). + Use -1 for unlimited depth. 
""" try: - files, subdirs = codebase.list_directory(path) - dir_info = DirectoryInfo( - status="success", - path=path, - files=files, - subdirectories=subdirs, - ) + directory = codebase.get_directory(path) + except ValueError: return ListDirectoryObservation( - status="success", - directory_info=dir_info, - ) - except Exception as e: - dir_info = DirectoryInfo( status="error", - error=str(e), - path=path, - files=[], - subdirectories=[], + error=f"Directory not found: {path}", + directory_info=DirectoryInfo( + status="error", + name=path.split("/")[-1], + path=path, + files=[], + subdirectories=[], + ), ) - return ListDirectoryObservation( - status="error", - error=str(e), - directory_info=dir_info, + + def get_directory_info(dir_obj: Directory, current_depth: int) -> DirectoryInfo: + """Helper function to get directory info recursively.""" + # Get direct files + all_files = [] + for file in dir_obj.files: + if file.directory == dir_obj: + all_files.append(file.filepath.split("/")[-1]) + + # Get direct subdirectories + subdirs = [] + for subdir in dir_obj.subdirectories: + # Only include direct descendants + if subdir.parent == dir_obj: + if current_depth != 1: + new_depth = current_depth - 1 if current_depth > 1 else -1 + subdirs.append(get_directory_info(subdir, new_depth)) + else: + # At max depth, just include name + subdirs.append(subdir.name) + + return DirectoryInfo( + status="success", + name=dir_obj.name, + path=dir_obj.dirpath, + files=sorted(all_files), + subdirectories=subdirs, ) + + dir_info = get_directory_info(directory, depth) + return ListDirectoryObservation( + status="success", + directory_info=dir_info, + ) diff --git a/src/codegen/extensions/tools/search.py b/src/codegen/extensions/tools/search.py index 0923f6837..9be02f039 100644 --- a/src/codegen/extensions/tools/search.py +++ b/src/codegen/extensions/tools/search.py @@ -27,9 +27,12 @@ class SearchMatch(Observation): match: str = Field( description="The specific text that matched", ) - str_template: ClassVar[str] = "Line {line_number}: {match}" + def render(self) -> str: + """Render match in a VSCode-like format.""" + return f"{self.line_number:>4}: {self.line}" + class SearchFileResult(Observation): """Search results for a single file.""" @@ -43,6 +46,15 @@ class SearchFileResult(Observation): str_template: ClassVar[str] = "{filepath}: {match_count} matches" + def render(self) -> str: + """Render file results in a VSCode-like format.""" + lines = [ + f"📄 {self.filepath}", + ] + for match in self.matches: + lines.append(match.render()) + return "\n".join(lines) + def _get_details(self) -> dict[str, str | int]: """Get details for string representation.""" return {"match_count": len(self.matches)} @@ -72,6 +84,30 @@ class SearchObservation(Observation): str_template: ClassVar[str] = "Found {total_files} files with matches for '{query}' (page {page}/{total_pages})" + def render(self) -> str: + """Render search results in a VSCode-like format.""" + if self.status == "error": + return f"[SEARCH ERROR]: {self.error}" + + lines = [ + f"[SEARCH RESULTS]: {self.query}", + f"Found {self.total_files} files with matches (showing page {self.page} of {self.total_pages})", + "", + ] + + if not self.results: + lines.append("No matches found") + return "\n".join(lines) + + for result in self.results: + lines.append(result.render()) + lines.append("") # Add blank line between files + + if self.total_pages > 1: + lines.append(f"Page {self.page}/{self.total_pages} (use page parameter to see more results)") + + return "\n".join(lines) + 
def search( codebase: Codebase, From 25fa49b43ba7146fa4fa9b68e72978ac8c020060 Mon Sep 17 00:00:00 2001 From: jayhack Date: Wed, 19 Feb 2025 17:03:49 -0800 Subject: [PATCH 9/9] . --- .../extensions/tools/list_directory.py | 57 +++++++++++-------- tests/unit/codegen/extensions/test_tools.py | 35 +++++++++++- 2 files changed, 66 insertions(+), 26 deletions(-) diff --git a/src/codegen/extensions/tools/list_directory.py b/src/codegen/extensions/tools/list_directory.py index b029f848a..24023612b 100644 --- a/src/codegen/extensions/tools/list_directory.py +++ b/src/codegen/extensions/tools/list_directory.py @@ -19,11 +19,17 @@ class DirectoryInfo(Observation): path: str = Field( description="Full path to the directory", ) - files: list[str] = Field( - description="List of files in this directory", + files: list[str] | None = Field( + default=None, + description="List of files in this directory (None if at max depth)", ) - subdirectories: list["DirectoryInfo | str"] = Field( - description="List of subdirectories (full info or just names at max depth)", + subdirectories: list["DirectoryInfo"] = Field( + default_factory=list, + description="List of subdirectories", + ) + is_leaf: bool = Field( + default=False, + description="Whether this is a leaf node (at max depth)", ) str_template: ClassVar[str] = "Directory {path} ({file_count} files, {dir_count} subdirs)" @@ -31,7 +37,7 @@ class DirectoryInfo(Observation): def _get_details(self) -> dict[str, int]: """Get details for string representation.""" return { - "file_count": len(self.files), + "file_count": len(self.files or []), "dir_count": len(self.subdirectories), } @@ -59,18 +65,16 @@ def build_tree(items: list[tuple[str, bool, "DirectoryInfo | None"]], prefix: st line, new_prefix = add_tree_item(name, prefix, is_last) result.append(line) - # If this is a directory with full info, recursively add its contents - if dir_info and isinstance(dir_info, DirectoryInfo): + # If this is a directory and not a leaf node, show its contents + if dir_info and not dir_info.is_leaf: subitems = [] # Add files first - for f in sorted(dir_info.files): - subitems.append((f, False, None)) + if dir_info.files: + for f in sorted(dir_info.files): + subitems.append((f, False, None)) # Then add subdirectories for d in dir_info.subdirectories: - if isinstance(d, DirectoryInfo): - subitems.append((d.name + "/", True, d)) - else: - subitems.append((d + "/", True, None)) + subitems.append((d.name + "/", True, d)) result.extend(build_tree(subitems, new_prefix)) @@ -78,13 +82,11 @@ def build_tree(items: list[tuple[str, bool, "DirectoryInfo | None"]], prefix: st # Sort files and directories items = [] - for f in sorted(self.files): - items.append((f, False, None)) # (name, is_dir, dir_info) + if self.files: + for f in sorted(self.files): + items.append((f, False, None)) for d in self.subdirectories: - if isinstance(d, DirectoryInfo): - items.append((d.name + "/", True, d)) - else: - items.append((d + "/", True, None)) + items.append((d.name + "/", True, d)) if not items: lines.append("(empty directory)") @@ -136,7 +138,7 @@ def list_directory(codebase: Codebase, path: str = "./", depth: int = 2) -> List def get_directory_info(dir_obj: Directory, current_depth: int) -> DirectoryInfo: """Helper function to get directory info recursively.""" - # Get direct files + # Get direct files (always include files unless at max depth) all_files = [] for file in dir_obj.files: if file.directory == dir_obj: @@ -147,12 +149,21 @@ def get_directory_info(dir_obj: Directory, current_depth: int) -> 
DirectoryInfo: for subdir in dir_obj.subdirectories: # Only include direct descendants if subdir.parent == dir_obj: - if current_depth != 1: + if current_depth > 1 or current_depth == -1: + # For deeper traversal, get full directory info new_depth = current_depth - 1 if current_depth > 1 else -1 subdirs.append(get_directory_info(subdir, new_depth)) else: - # At max depth, just include name - subdirs.append(subdir.name) + # At max depth, return a leaf node + subdirs.append( + DirectoryInfo( + status="success", + name=subdir.name, + path=subdir.dirpath, + files=None, # Don't include files at max depth + is_leaf=True, + ) + ) return DirectoryInfo( status="success", diff --git a/tests/unit/codegen/extensions/test_tools.py b/tests/unit/codegen/extensions/test_tools.py index 51538a830..fd0fafb31 100644 --- a/tests/unit/codegen/extensions/test_tools.py +++ b/tests/unit/codegen/extensions/test_tools.py @@ -50,9 +50,36 @@ def test_view_file(codebase): def test_list_directory(codebase): """Test listing directory contents.""" - result = list_directory(codebase, "./") + # Create a nested directory structure for testing + create_file(codebase, "src/core/__init__.py", "") + create_file(codebase, "src/core/models.py", "") + create_file(codebase, "src/utils.py", "") + + result = list_directory(codebase, "./", depth=2) # Ensure we get nested structure assert result.status == "success" - assert "src" in result.directory_info.subdirectories + + # Check directory structure + dir_info = result.directory_info + + # Check that src exists and has proper structure + src_dir = next(d for d in dir_info.subdirectories) + assert src_dir.name == "src" + assert "main.py" in src_dir.files + assert "utils.py" in src_dir.files + + # Check nested core directory exists in subdirectories + assert any(d.name == "core" for d in src_dir.subdirectories) + core_dir = next(d for d in src_dir.subdirectories if d.name == "core") + + # Verify rendered output has proper tree structure + rendered = result.render() + print(rendered) + expected_tree = """ +└── src/ + ├── main.py + ├── utils.py + └── core/""" + assert expected_tree in rendered.strip() def test_search(codebase): @@ -66,7 +93,9 @@ def test_edit_file(codebase): """Test editing a file.""" result = edit_file(codebase, "src/main.py", "print('edited')") assert result.status == "success" - assert result.file_info.content == "1|print('edited')" + assert result.filepath == "src/main.py" + assert "+print('edited')" in result.diff + assert "-def hello():" in result.diff # Check that old content is shown in diff def test_create_file(codebase):
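# --- A minimal sketch of the diff-based edit_file output checked by the test above
# (illustrative, not part of the patch). It assumes `codebase` is the same fixture
# used in these tests; generate_diff() from replacement_edit.py produces the unified
# diff, whose header lines are omitted here.
result = edit_file(codebase, "src/main.py", "print('edited')")
assert result.status == "success"
print(result.render())
# Expected shape of the output, per EditFileObservation.render():
# [EDIT FILE]: src/main.py
#
# -def hello():
# -    print("Hello, world!")
# +print('edited')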