diff --git a/src/codegen/extensions/langchain/agent.py b/src/codegen/extensions/langchain/agent.py index 9931d830e..8909e02d7 100644 --- a/src/codegen/extensions/langchain/agent.py +++ b/src/codegen/extensions/langchain/agent.py @@ -13,8 +13,12 @@ from .tools import ( CommitTool, CreateFileTool, + CreatePRCommentTool, + CreatePRReviewCommentTool, + CreatePRTool, DeleteFileTool, EditFileTool, + GetPRcontentsTool, ListDirectoryTool, MoveSymbolTool, RenameFileTool, @@ -64,6 +68,10 @@ def create_codebase_agent( SemanticEditTool(codebase), SemanticSearchTool(codebase), CommitTool(codebase), + CreatePRTool(codebase), + GetPRcontentsTool(codebase), + CreatePRCommentTool(codebase), + CreatePRReviewCommentTool(codebase), ] # Get the prompt to use diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index 00a6365e9..6299cc556 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -1,7 +1,6 @@ """Langchain tools for workspace operations.""" import json -import uuid from typing import ClassVar, Literal, Optional from langchain.tools import BaseTool @@ -12,6 +11,9 @@ from ..tools import ( commit, create_file, + create_pr, + create_pr_comment, + create_pr_review_comment, delete_file, edit_file, list_directory, @@ -22,6 +24,7 @@ semantic_edit, semantic_search, view_file, + view_pr, ) @@ -205,12 +208,11 @@ def _run( collect_dependencies: bool = True, collect_usages: bool = True, ) -> str: - # Find the symbol first - found_symbol = self.codebase.get_symbol(symbol_name) result = reveal_symbol( - found_symbol, - degree, - max_tokens, + codebase=self.codebase, + symbol_name=symbol_name, + degree=degree, + max_tokens=max_tokens, collect_dependencies=collect_dependencies, collect_usages=collect_usages, ) @@ -356,11 +358,8 @@ def __init__(self, codebase: Codebase) -> None: super().__init__(codebase=codebase) def _run(self, title: str, body: str) -> str: - if self.codebase._op.git_cli.active_branch.name == self.codebase._op.default_branch: - # If the current checked out branch is the default branch, checkout onto a new branch - self.codebase.checkout(branch=f"{uuid.uuid4()}", create_if_missing=True) - pr = self.codebase.create_pr(title=title, body=body) - return pr.html_url + result = create_pr(self.codebase, title, body) + return json.dumps(result, indent=2) class GetPRContentsInput(BaseModel): @@ -381,11 +380,7 @@ def __init__(self, codebase: Codebase) -> None: super().__init__(codebase=codebase) def _run(self, pr_id: int) -> str: - modified_symbols, patch = self.codebase.get_modified_symbols_in_pr(pr_id) - - # Convert modified_symbols set to list for JSON serialization - result = {"modified_symbols": list(modified_symbols), "patch": patch} - + result = view_pr(self.codebase, pr_id) return json.dumps(result, indent=2) @@ -408,8 +403,8 @@ def __init__(self, codebase: Codebase) -> None: super().__init__(codebase=codebase) def _run(self, pr_number: int, body: str) -> str: - self.codebase.create_pr_comment(pr_number=pr_number, body=body) - return "Comment created successfully" + result = create_pr_comment(self.codebase, pr_number, body) + return json.dumps(result, indent=2) class CreatePRReviewCommentInput(BaseModel): @@ -445,7 +440,8 @@ def _run( side: str | None = None, start_line: int | None = None, ) -> str: - self.codebase.create_pr_review_comment( + result = create_pr_review_comment( + self.codebase, pr_number=pr_number, body=body, commit_sha=commit_sha, @@ -454,7 +450,7 @@ def _run( side=side, start_line=start_line, ) - return "Review comment created successfully" + return json.dumps(result, indent=2) def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]: @@ -476,8 +472,11 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]: EditFileTool(codebase), GetPRcontentsTool(codebase), ListDirectoryTool(codebase), + MoveSymbolTool(codebase), + RenameFileTool(codebase), RevealSymbolTool(codebase), SearchTool(codebase), SemanticEditTool(codebase), + SemanticSearchTool(codebase), ViewFileTool(codebase), ] diff --git a/src/codegen/extensions/mcp/codebase_tools.py b/src/codegen/extensions/mcp/codebase_tools.py index bf3f4d2ff..fa64d877f 100644 --- a/src/codegen/extensions/mcp/codebase_tools.py +++ b/src/codegen/extensions/mcp/codebase_tools.py @@ -10,7 +10,8 @@ mcp = FastMCP( "codebase-tools-mcp", - instructions="Use this server to access any information from your codebase. This tool can provide information ranging from AST Symbol details and information from across the codebase. Use this tool for all questions, queries regarding your codebase.", + instructions="""Use this server to access any information from your codebase. This tool can provide information ranging from AST Symbol details and information from across the codebase. + Use this tool for all questions, queries regarding your codebase.""", ) @@ -20,21 +21,16 @@ def reveal_symbol_tool( target_file: Annotated[Optional[str], "The file path of the file containing the symbol to inspect"], codebase_dir: Annotated[str, "The root directory of your codebase"], codebase_language: Annotated[ProgrammingLanguage, "The language the codebase is written in"], - degree: Annotated[Optional[int], "depth do which symbol information is retrieved"], + max_depth: Annotated[Optional[int], "depth up to which symbol information is retrieved"], collect_dependencies: Annotated[Optional[bool], "includes dependencies of symbol"], collect_usages: Annotated[Optional[bool], "includes usages of symbol"], ): codebase = Codebase(repo_path=codebase_dir, programming_language=codebase_language) - found_symbol = None - if target_file: - file = codebase.get_file(target_file) - found_symbol = file.get_symbol(symbol_name) - else: - found_symbol = codebase.get_symbol(symbol_name) - result = reveal_symbol( - found_symbol, - degree, + codebase=codebase, + symbol_name=symbol_name, + filepath=target_file, + max_depth=max_depth, collect_dependencies=collect_dependencies, collect_usages=collect_usages, ) diff --git a/src/codegen/extensions/tools/README.md b/src/codegen/extensions/tools/README.md new file mode 100644 index 000000000..f69e74e8a --- /dev/null +++ b/src/codegen/extensions/tools/README.md @@ -0,0 +1,4 @@ +# Tools + +- should take in a `codebase` and string args +- gets "wrapped" by extensions, e.g. MCP or Langchain diff --git a/src/codegen/extensions/tools/__init__.py b/src/codegen/extensions/tools/__init__.py index 9ce7b4f90..74f8ba83c 100644 --- a/src/codegen/extensions/tools/__init__.py +++ b/src/codegen/extensions/tools/__init__.py @@ -1,35 +1,42 @@ """Tools for workspace operations.""" -from .file_operations import ( - commit, - create_file, - delete_file, - edit_file, - list_directory, - move_symbol, - rename_file, - view_file, -) +from .commit import commit +from .create_file import create_file +from .delete_file import delete_file +from .edit_file import edit_file +from .github.create_pr import create_pr +from .github.create_pr_comment import create_pr_comment +from .github.create_pr_review_comment import create_pr_review_comment +from .github.view_pr import view_pr +from .list_directory import list_directory +from .move_symbol import move_symbol +from .rename_file import rename_file from .reveal_symbol import reveal_symbol from .search import search from .semantic_edit import semantic_edit from .semantic_search import semantic_search +from .view_file import view_file __all__ = [ + # Git operations "commit", + # File operations "create_file", + "create_pr", + "create_pr_comment", + "create_pr_review_comment", "delete_file", "edit_file", "list_directory", - # Symbol analysis + # Symbol operations "move_symbol", - # File operations "rename_file", "reveal_symbol", - # Search + # Search operations "search", - # Semantic edit + # Edit operations "semantic_edit", "semantic_search", "view_file", + "view_pr", ] diff --git a/src/codegen/extensions/tools/commit.py b/src/codegen/extensions/tools/commit.py new file mode 100644 index 000000000..eb47c17cf --- /dev/null +++ b/src/codegen/extensions/tools/commit.py @@ -0,0 +1,18 @@ +"""Tool for committing changes to disk.""" + +from typing import Any + +from codegen import Codebase + + +def commit(codebase: Codebase) -> dict[str, Any]: + """Commit any pending changes to disk. + + Args: + codebase: The codebase to operate on + + Returns: + Dict containing commit status + """ + codebase.commit() + return {"status": "success", "message": "Changes committed to disk"} diff --git a/src/codegen/extensions/tools/create_file.py b/src/codegen/extensions/tools/create_file.py new file mode 100644 index 000000000..d340e4b1c --- /dev/null +++ b/src/codegen/extensions/tools/create_file.py @@ -0,0 +1,25 @@ +"""Tool for creating new files.""" + +from typing import Any + +from codegen import Codebase + +from .view_file import view_file + + +def create_file(codebase: Codebase, filepath: str, content: str = "") -> dict[str, Any]: + """Create a new file. + + Args: + codebase: The codebase to operate on + filepath: Path where to create the file + content: Initial file content + + Returns: + Dict containing new file state, or error information if file already exists + """ + if codebase.has_file(filepath): + return {"error": f"File already exists: {filepath}"} + file = codebase.create_file(filepath, content=content) + codebase.commit() + return view_file(codebase, filepath) diff --git a/src/codegen/extensions/tools/delete_file.py b/src/codegen/extensions/tools/delete_file.py new file mode 100644 index 000000000..9703cd21f --- /dev/null +++ b/src/codegen/extensions/tools/delete_file.py @@ -0,0 +1,27 @@ +"""Tool for deleting files.""" + +from typing import Any + +from codegen import Codebase + + +def delete_file(codebase: Codebase, filepath: str) -> dict[str, Any]: + """Delete a file. + + Args: + codebase: The codebase to operate on + filepath: Path to the file to delete + + Returns: + Dict containing deletion status, or error information if file not found + """ + try: + file = codebase.get_file(filepath) + except ValueError: + return {"error": f"File not found: {filepath}"} + if file is None: + return {"error": f"File not found: {filepath}"} + + file.remove() + codebase.commit() + return {"status": "success", "deleted_file": filepath} diff --git a/src/codegen/extensions/tools/edit_file.py b/src/codegen/extensions/tools/edit_file.py new file mode 100644 index 000000000..a7f19e448 --- /dev/null +++ b/src/codegen/extensions/tools/edit_file.py @@ -0,0 +1,30 @@ +"""Tool for editing file contents.""" + +from typing import Any + +from codegen import Codebase + +from .view_file import view_file + + +def edit_file(codebase: Codebase, filepath: str, content: str) -> dict[str, Any]: + """Edit a file by replacing its entire content. + + Args: + codebase: The codebase to operate on + filepath: Path to the file to edit + content: New content for the file + + Returns: + Dict containing updated file state, or error information if file not found + """ + try: + file = codebase.get_file(filepath) + except ValueError: + return {"error": f"File not found: {filepath}"} + if file is None: + return {"error": f"File not found: {filepath}"} + + file.edit(content) + codebase.commit() + return view_file(codebase, filepath) diff --git a/src/codegen/extensions/tools/github/create_pr.py b/src/codegen/extensions/tools/github/create_pr.py new file mode 100644 index 000000000..42eab67ac --- /dev/null +++ b/src/codegen/extensions/tools/github/create_pr.py @@ -0,0 +1,34 @@ +"""Tool for creating pull requests.""" + +import uuid +from typing import Any + +from codegen import Codebase + + +def create_pr(codebase: Codebase, title: str, body: str) -> dict[str, Any]: + """Create a PR for the current branch. + + Args: + codebase: The codebase to operate on + title: The title of the PR + body: The body/description of the PR + + Returns: + Dict containing PR info, or error information if operation fails + """ + try: + # If on default branch, create a new branch + if codebase._op.git_cli.active_branch.name == codebase._op.default_branch: + codebase.checkout(branch=f"{uuid.uuid4()}", create_if_missing=True) + + # Create the PR + pr = codebase.create_pr(title=title, body=body) + return { + "status": "success", + "url": pr.html_url, + "number": pr.number, + "title": pr.title, + } + except Exception as e: + return {"error": f"Failed to create PR: {e!s}"} diff --git a/src/codegen/extensions/tools/github/create_pr_comment.py b/src/codegen/extensions/tools/github/create_pr_comment.py new file mode 100644 index 000000000..ae8bcae40 --- /dev/null +++ b/src/codegen/extensions/tools/github/create_pr_comment.py @@ -0,0 +1,27 @@ +"""Tool for creating PR comments.""" + +from typing import Any + +from codegen import Codebase + + +def create_pr_comment(codebase: Codebase, pr_number: int, body: str) -> dict[str, Any]: + """Create a general comment on a pull request. + + Args: + codebase: The codebase to operate on + pr_number: The PR number to comment on + body: The comment text + + Returns: + Dict containing comment status + """ + try: + codebase.create_pr_comment(pr_number=pr_number, body=body) + return { + "status": "success", + "message": "Comment created successfully", + "pr_number": pr_number, + } + except Exception as e: + return {"error": f"Failed to create PR comment: {e!s}"} diff --git a/src/codegen/extensions/tools/github/create_pr_review_comment.py b/src/codegen/extensions/tools/github/create_pr_review_comment.py new file mode 100644 index 000000000..11f6d04b4 --- /dev/null +++ b/src/codegen/extensions/tools/github/create_pr_review_comment.py @@ -0,0 +1,51 @@ +"""Tool for creating PR review comments.""" + +from typing import Any, Optional + +from codegen import Codebase + + +def create_pr_review_comment( + codebase: Codebase, + pr_number: int, + body: str, + commit_sha: str, + path: str, + line: Optional[int] = None, + side: Optional[str] = None, + start_line: Optional[int] = None, +) -> dict[str, Any]: + """Create an inline review comment on a specific line in a pull request. + + Args: + codebase: The codebase to operate on + pr_number: The PR number to comment on + body: The comment text + commit_sha: The commit SHA to attach the comment to + path: The file path to comment on + line: The line number to comment on + side: Which version of the file to comment on ('LEFT' or 'RIGHT') + start_line: For multi-line comments, the starting line + + Returns: + Dict containing comment status + """ + try: + codebase.create_pr_review_comment( + pr_number=pr_number, + body=body, + commit_sha=commit_sha, + path=path, + line=line, + side=side, + start_line=start_line, + ) + return { + "status": "success", + "message": "Review comment created successfully", + "pr_number": pr_number, + "path": path, + "line": line, + } + except Exception as e: + return {"error": f"Failed to create PR review comment: {e!s}"} diff --git a/src/codegen/extensions/tools/github/view_pr.py b/src/codegen/extensions/tools/github/view_pr.py new file mode 100644 index 000000000..13b90a0f5 --- /dev/null +++ b/src/codegen/extensions/tools/github/view_pr.py @@ -0,0 +1,21 @@ +"""Tool for viewing PR contents and modified symbols.""" + +from typing import Any + +from codegen import Codebase + + +def view_pr(codebase: Codebase, pr_id: int) -> dict[str, Any]: + """Get the diff and modified symbols of a PR. + + Args: + codebase: The codebase to operate on + pr_id: Number of the PR to get the contents for + + Returns: + Dict containing modified symbols and patch + """ + modified_symbols, patch = codebase.get_modified_symbols_in_pr(pr_id) + + # Convert modified_symbols set to list for JSON serialization + return {"status": "success", "modified_symbols": list(modified_symbols), "patch": patch} diff --git a/src/codegen/extensions/tools/list_directory.py b/src/codegen/extensions/tools/list_directory.py new file mode 100644 index 000000000..903983a45 --- /dev/null +++ b/src/codegen/extensions/tools/list_directory.py @@ -0,0 +1,69 @@ +"""Tool for listing directory contents.""" + +from typing import Any + +from codegen import Codebase +from codegen.sdk.core.directory import Directory + + +def list_directory(codebase: Codebase, dirpath: str = "./", depth: int = 1) -> dict[str, Any]: + """List contents of a directory. + + Args: + codebase: The codebase to operate on + dirpath: Path to directory relative to workspace root + depth: How deep to traverse the directory tree. Default is 1 (immediate children only). + Use -1 for unlimited depth. + + Returns: + Dict containing directory contents and metadata in a nested structure: + { + "path": str, + "name": str, + "files": list[str], + "subdirectories": [ + { + "path": str, + "name": str, + "files": list[str], + "subdirectories": [...], + }, + ... + ] + } + """ + try: + directory = codebase.get_directory(dirpath) + except ValueError: + return {"error": f"Directory not found: {dirpath}"} + + if not directory: + return {"error": f"Directory not found: {dirpath}"} + + def get_directory_info(dir_obj: Directory, current_depth: int) -> dict[str, Any]: + """Helper function to get directory info recursively.""" + # Get direct files + all_files = [] + for file in dir_obj.files: + if file.directory == dir_obj: + all_files.append(file.filepath.split("/")[-1]) + + # Get direct subdirectories + subdirs = [] + for subdir in dir_obj.subdirectories: + # Only include direct descendants + if subdir.parent == dir_obj: + if current_depth != 1: + new_depth = current_depth - 1 if current_depth > 1 else -1 + subdirs.append(get_directory_info(subdir, new_depth)) + else: + # At max depth, just include name + subdirs.append(subdir.name) + return { + "name": dir_obj.name, + "path": dir_obj.dirpath, + "files": all_files, + "subdirectories": subdirs, + } + + return get_directory_info(directory, depth) diff --git a/src/codegen/extensions/tools/move_symbol.py b/src/codegen/extensions/tools/move_symbol.py new file mode 100644 index 000000000..26058599a --- /dev/null +++ b/src/codegen/extensions/tools/move_symbol.py @@ -0,0 +1,61 @@ +"""Tool for moving symbols between files.""" + +from typing import Any, Literal + +from codegen import Codebase + +from .view_file import view_file + + +def move_symbol( + codebase: Codebase, + source_file: str, + symbol_name: str, + target_file: str, + strategy: Literal["update_all_imports", "add_back_edge"] = "update_all_imports", + include_dependencies: bool = True, +) -> dict[str, Any]: + """Move a symbol from one file to another. + + Args: + codebase: The codebase to operate on + source_file: Path to the file containing the symbol + symbol_name: Name of the symbol to move + target_file: Path to the destination file + strategy: Strategy for handling imports: + - "update_all_imports": Updates all import statements across the codebase (default) + - "add_back_edge": Adds import and re-export in the original file + include_dependencies: Whether to move dependencies along with the symbol + + Returns: + Dict containing move status and updated file info, or error information if operation fails + """ + try: + source = codebase.get_file(source_file) + except ValueError: + return {"error": f"Source file not found: {source_file}"} + if source is None: + return {"error": f"Source file not found: {source_file}"} + + try: + target = codebase.get_file(target_file) + except ValueError: + return {"error": f"Target file not found: {target_file}"} + + symbol = source.get_symbol(symbol_name) + if not symbol: + return {"error": f"Symbol '{symbol_name}' not found in {source_file}"} + + try: + symbol.move_to_file(target, include_dependencies=include_dependencies, strategy=strategy) + codebase.commit() + return { + "status": "success", + "symbol": symbol_name, + "source_file": source_file, + "target_file": target_file, + "source_file_info": view_file(codebase, source_file), + "target_file_info": view_file(codebase, target_file), + } + except Exception as e: + return {"error": f"Failed to move symbol: {e!s}"} diff --git a/src/codegen/extensions/tools/rename_file.py b/src/codegen/extensions/tools/rename_file.py new file mode 100644 index 000000000..2be9df52b --- /dev/null +++ b/src/codegen/extensions/tools/rename_file.py @@ -0,0 +1,36 @@ +"""Tool for renaming files and updating imports.""" + +from typing import Any + +from codegen import Codebase + +from .view_file import view_file + + +def rename_file(codebase: Codebase, filepath: str, new_filepath: str) -> dict[str, Any]: + """Rename a file and update all imports to point to the new location. + + Args: + codebase: The codebase to operate on + filepath: Current path of the file relative to workspace root + new_filepath: New path for the file relative to workspace root + + Returns: + Dict containing rename status and new file info, or error information if file not found + """ + try: + file = codebase.get_file(filepath) + except ValueError: + return {"error": f"File not found: {filepath}"} + if file is None: + return {"error": f"File not found: {filepath}"} + + if codebase.has_file(new_filepath): + return {"error": f"Destination file already exists: {new_filepath}"} + + try: + file.update_filepath(new_filepath) + codebase.commit() + return {"status": "success", "old_filepath": filepath, "new_filepath": new_filepath, "file_info": view_file(codebase, new_filepath)} + except Exception as e: + return {"error": f"Failed to rename file: {e!s}"} diff --git a/src/codegen/extensions/tools/reveal_symbol.py b/src/codegen/extensions/tools/reveal_symbol.py index 04eb01746..341ddad1d 100644 --- a/src/codegen/extensions/tools/reveal_symbol.py +++ b/src/codegen/extensions/tools/reveal_symbol.py @@ -2,6 +2,7 @@ import tiktoken +from codegen import Codebase from codegen.sdk.core.external_module import ExternalModule from codegen.sdk.core.import_resolution import Import from codegen.sdk.core.symbol import Symbol @@ -211,8 +212,10 @@ def under_token_limit() -> bool: def reveal_symbol( - symbol: Symbol, - degree: Optional[int] = 1, + codebase: Codebase, + symbol_name: str, + filepath: Optional[str] = None, + max_depth: Optional[int] = 1, max_tokens: Optional[int] = None, collect_dependencies: Optional[bool] = True, collect_usages: Optional[bool] = True, @@ -220,8 +223,10 @@ def reveal_symbol( """Reveal the dependencies and usages of a symbol up to N degrees. Args: - symbol: The symbol to analyze - degree: How many degrees of separation to traverse (default: 1) + codebase: The codebase to analyze + symbol_name: The name of the symbol to analyze + filepath: Optional filepath to the symbol to analyze + max_depth: How many degrees of separation to traverse (default: 1) max_tokens: Optional maximum number of tokens for all source code combined collect_dependencies: Whether to collect dependencies (default: True) collect_usages: Whether to collect usages (default: True) @@ -233,12 +238,18 @@ def reveal_symbol( - truncated: Whether the results were truncated due to max_tokens - error: Optional error message if the symbol was not found """ - # Check if we got a valid symbol - if symbol is None: - return {"error": "Symbol not found", "truncated": False, "dependencies": [], "usages": []} + symbols = codebase.get_symbols(symbol_name=symbol_name) + if len(symbols) == 0: + return {"error": f"{symbol_name} not found"} + if len(symbols) > 1: + return {"error": f"{symbol_name} is ambiguious", "valid_filepaths": [s.file.filepath for s in symbols]} + symbol = symbols[0] + if filepath: + if symbol.file.filepath != filepath: + return {"error": f"{symbol_name} not found at {filepath}", "valid_filepaths": [s.file.filepath for s in symbols]} # Get dependencies and usages up to specified degree - dependencies, usages, total_tokens = get_extended_context(symbol, degree, max_tokens, collect_dependencies=collect_dependencies, collect_usages=collect_usages) + dependencies, usages, total_tokens = get_extended_context(symbol, max_depth, max_tokens, collect_dependencies=collect_dependencies, collect_usages=collect_usages) was_truncated = max_tokens is not None and total_tokens >= max_tokens diff --git a/src/codegen/extensions/tools/view_file.py b/src/codegen/extensions/tools/view_file.py new file mode 100644 index 000000000..6817e0ad9 --- /dev/null +++ b/src/codegen/extensions/tools/view_file.py @@ -0,0 +1,31 @@ +"""Tool for viewing file contents and metadata.""" + +from typing import Any + +from codegen import Codebase + + +def view_file(codebase: Codebase, filepath: str) -> dict[str, Any]: + """View the contents and metadata of a file. + + Args: + codebase: The codebase to operate on + filepath: Path to the file relative to workspace root + + Returns: + Dict containing file contents and metadata, or error information if file not found + """ + file = None + + try: + file = codebase.get_file(filepath) + except ValueError: + pass + + if not file: + return {"error": f"File not found: {filepath}. Please use full filepath relative to workspace root."} + + return { + "filepath": file.filepath, + "content": file.content, + } diff --git a/tests/unit/codegen/extensions/test_tools.py b/tests/unit/codegen/extensions/test_tools.py new file mode 100644 index 000000000..e8dcd8a6a --- /dev/null +++ b/tests/unit/codegen/extensions/test_tools.py @@ -0,0 +1,181 @@ +"""Tests for codebase tools.""" + +import pytest + +from codegen.extensions.tools import ( + create_file, + create_pr, + create_pr_comment, + create_pr_review_comment, + delete_file, + edit_file, + list_directory, + move_symbol, + rename_file, + reveal_symbol, + search, + semantic_edit, + semantic_search, + view_file, + view_pr, +) +from codegen.sdk.codebase.factory.get_session import get_codebase_session + + +@pytest.fixture +def codebase(tmpdir): + """Create a simple codebase for testing.""" + # language=python + content = """ +def hello(): + print("Hello, world!") + +class Greeter: + def greet(self): + hello() +""" + with get_codebase_session(tmpdir=tmpdir, files={"src/main.py": content}) as codebase: + yield codebase + + +def test_view_file(codebase): + """Test viewing a file.""" + result = view_file(codebase, "src/main.py") + assert "error" not in result + assert result["filepath"] == "src/main.py" + assert "hello()" in result["content"] + + +def test_list_directory(codebase): + """Test listing directory contents.""" + result = list_directory(codebase, "./") + assert "error" not in result + assert "src" in result["subdirectories"] + + +def test_search(codebase): + """Test searching the codebase.""" + result = search(codebase, "hello") + assert "error" not in result + assert len(result["results"]) > 0 + + +def test_edit_file(codebase): + """Test editing a file.""" + result = edit_file(codebase, "src/main.py", "print('edited')") + assert "error" not in result + assert result["content"] == "print('edited')" + + +def test_create_file(codebase): + """Test creating a file.""" + result = create_file(codebase, "src/new.py", "print('new')") + assert "error" not in result + assert result["filepath"] == "src/new.py" + assert result["content"] == "print('new')" + + +def test_delete_file(codebase): + """Test deleting a file.""" + result = delete_file(codebase, "src/main.py") + assert "error" not in result + assert result["status"] == "success" + + +def test_rename_file(codebase): + """Test renaming a file.""" + result = rename_file(codebase, "src/main.py", "src/renamed.py") + assert "error" not in result + assert result["status"] == "success" + assert result["new_filepath"] == "src/renamed.py" + + +def test_move_symbol(codebase): + """Test moving a symbol between files.""" + # Create target file first + create_file(codebase, "src/target.py", "") + + result = move_symbol( + codebase, + source_file="src/main.py", + symbol_name="hello", + target_file="src/target.py", + ) + assert "error" not in result + assert result["status"] == "success" + + +def test_reveal_symbol(codebase): + """Test revealing symbol relationships.""" + result = reveal_symbol( + codebase, + symbol_name="hello", + max_depth=1, + ) + assert "error" not in result + assert not result["truncated"] + + +@pytest.mark.skip("TODO") +def test_semantic_edit(codebase): + """Test semantic editing.""" + edit_spec = """ +# ... existing code ... +def hello(): + print("Hello from semantic edit!") +# ... existing code ... +""" + result = semantic_edit(codebase, "src/main.py", edit_spec) + assert "error" not in result + assert result["status"] == "success" + + +@pytest.mark.skip("TODO") +def test_semantic_search(codebase): + """Test semantic search.""" + result = semantic_search(codebase, "function that prints hello") + assert "error" not in result + assert result["status"] == "success" + + +@pytest.mark.skip("TODO: Github tests") +def test_create_pr(codebase): + """Test creating a PR.""" + result = create_pr(codebase, "Test PR", "This is a test PR") + assert "error" not in result + assert result["status"] == "success" + + +@pytest.mark.skip("TODO: Github tests") +def test_view_pr(codebase): + """Test viewing a PR.""" + result = view_pr(codebase, 1) + assert "error" not in result + assert result["status"] == "success" + assert "modified_symbols" in result + assert "patch" in result + + +@pytest.mark.skip("TODO: Github tests") +def test_create_pr_comment(codebase): + """Test creating a PR comment.""" + result = create_pr_comment(codebase, 1, "Test comment") + assert "error" not in result + assert result["status"] == "success" + assert result["message"] == "Comment created successfully" + + +@pytest.mark.skip("TODO: Github tests") +def test_create_pr_review_comment(codebase): + """Test creating a PR review comment.""" + result = create_pr_review_comment( + codebase, + pr_number=1, + body="Test review comment", + commit_sha="abc123", + path="src/main.py", + line=1, + ) + assert "error" not in result + assert result["status"] == "success" + assert result["message"] == "Review comment created successfully"