Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/codegen/extensions/langchain/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
from .tools import (
CommitTool,
CreateFileTool,
CreatePRCommentTool,
CreatePRReviewCommentTool,
CreatePRTool,
DeleteFileTool,
EditFileTool,
GetPRcontentsTool,
ListDirectoryTool,
MoveSymbolTool,
RenameFileTool,
Expand Down Expand Up @@ -64,6 +68,10 @@ def create_codebase_agent(
SemanticEditTool(codebase),
SemanticSearchTool(codebase),
CommitTool(codebase),
CreatePRTool(codebase),
GetPRcontentsTool(codebase),
CreatePRCommentTool(codebase),
CreatePRReviewCommentTool(codebase),
]

# Get the prompt to use
Expand Down
39 changes: 19 additions & 20 deletions src/codegen/extensions/langchain/tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Langchain tools for workspace operations."""

import json
import uuid
from typing import ClassVar, Literal, Optional

from langchain.tools import BaseTool
Expand All @@ -12,6 +11,9 @@
from ..tools import (
commit,
create_file,
create_pr,
create_pr_comment,
create_pr_review_comment,
delete_file,
edit_file,
list_directory,
Expand All @@ -22,6 +24,7 @@
semantic_edit,
semantic_search,
view_file,
view_pr,
)


Expand Down Expand Up @@ -205,12 +208,11 @@ def _run(
collect_dependencies: bool = True,
collect_usages: bool = True,
) -> str:
# Find the symbol first
found_symbol = self.codebase.get_symbol(symbol_name)
result = reveal_symbol(
found_symbol,
degree,
max_tokens,
codebase=self.codebase,
symbol_name=symbol_name,
degree=degree,
max_tokens=max_tokens,
collect_dependencies=collect_dependencies,
collect_usages=collect_usages,
)
Expand Down Expand Up @@ -356,11 +358,8 @@ def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def _run(self, title: str, body: str) -> str:
if self.codebase._op.git_cli.active_branch.name == self.codebase._op.default_branch:
# If the current checked out branch is the default branch, checkout onto a new branch
self.codebase.checkout(branch=f"{uuid.uuid4()}", create_if_missing=True)
pr = self.codebase.create_pr(title=title, body=body)
return pr.html_url
result = create_pr(self.codebase, title, body)
return json.dumps(result, indent=2)


class GetPRContentsInput(BaseModel):
Expand All @@ -381,11 +380,7 @@ def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def _run(self, pr_id: int) -> str:
modified_symbols, patch = self.codebase.get_modified_symbols_in_pr(pr_id)

# Convert modified_symbols set to list for JSON serialization
result = {"modified_symbols": list(modified_symbols), "patch": patch}

result = view_pr(self.codebase, pr_id)
return json.dumps(result, indent=2)


Expand All @@ -408,8 +403,8 @@ def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def _run(self, pr_number: int, body: str) -> str:
self.codebase.create_pr_comment(pr_number=pr_number, body=body)
return "Comment created successfully"
result = create_pr_comment(self.codebase, pr_number, body)
return json.dumps(result, indent=2)


class CreatePRReviewCommentInput(BaseModel):
Expand Down Expand Up @@ -445,7 +440,8 @@ def _run(
side: str | None = None,
start_line: int | None = None,
) -> str:
self.codebase.create_pr_review_comment(
result = create_pr_review_comment(
self.codebase,
pr_number=pr_number,
body=body,
commit_sha=commit_sha,
Expand All @@ -454,7 +450,7 @@ def _run(
side=side,
start_line=start_line,
)
return "Review comment created successfully"
return json.dumps(result, indent=2)


def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]:
Expand All @@ -476,8 +472,11 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]:
EditFileTool(codebase),
GetPRcontentsTool(codebase),
ListDirectoryTool(codebase),
MoveSymbolTool(codebase),
RenameFileTool(codebase),
RevealSymbolTool(codebase),
SearchTool(codebase),
SemanticEditTool(codebase),
SemanticSearchTool(codebase),
ViewFileTool(codebase),
]
18 changes: 7 additions & 11 deletions src/codegen/extensions/mcp/codebase_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

mcp = FastMCP(
"codebase-tools-mcp",
instructions="Use this server to access any information from your codebase. This tool can provide information ranging from AST Symbol details and information from across the codebase. Use this tool for all questions, queries regarding your codebase.",
instructions="""Use this server to access any information from your codebase. This tool can provide information ranging from AST Symbol details and information from across the codebase.
Use this tool for all questions, queries regarding your codebase.""",
)


Expand All @@ -20,21 +21,16 @@
target_file: Annotated[Optional[str], "The file path of the file containing the symbol to inspect"],
codebase_dir: Annotated[str, "The root directory of your codebase"],
codebase_language: Annotated[ProgrammingLanguage, "The language the codebase is written in"],
degree: Annotated[Optional[int], "depth do which symbol information is retrieved"],
max_depth: Annotated[Optional[int], "depth up to which symbol information is retrieved"],
collect_dependencies: Annotated[Optional[bool], "includes dependencies of symbol"],
collect_usages: Annotated[Optional[bool], "includes usages of symbol"],
):
codebase = Codebase(repo_path=codebase_dir, programming_language=codebase_language)

Check failure on line 28 in src/codegen/extensions/mcp/codebase_tools.py

View workflow job for this annotation

GitHub Actions / mypy

error: Need type annotation for "codebase" [var-annotated]
found_symbol = None
if target_file:
file = codebase.get_file(target_file)
found_symbol = file.get_symbol(symbol_name)
else:
found_symbol = codebase.get_symbol(symbol_name)

result = reveal_symbol(
found_symbol,
degree,
codebase=codebase,
symbol_name=symbol_name,
filepath=target_file,
max_depth=max_depth,
collect_dependencies=collect_dependencies,
collect_usages=collect_usages,
)
Expand All @@ -49,7 +45,7 @@
codebase_language: Annotated[ProgrammingLanguage, "The language the codebase is written in"],
use_regex: Annotated[bool, "use regex for the search query"],
):
codebase = Codebase(repo_path=codebase_dir, programming_language=codebase_language)

Check failure on line 48 in src/codegen/extensions/mcp/codebase_tools.py

View workflow job for this annotation

GitHub Actions / mypy

error: Need type annotation for "codebase" [var-annotated]
result = search(codebase, query, target_directories, use_regex=use_regex)
return json.dumps(result, indent=2)

Expand Down
4 changes: 4 additions & 0 deletions src/codegen/extensions/tools/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Tools

- should take in a `codebase` and string args
- gets "wrapped" by extensions, e.g. MCP or Langchain
35 changes: 21 additions & 14 deletions src/codegen/extensions/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,42 @@
"""Tools for workspace operations."""

from .file_operations import (
commit,
create_file,
delete_file,
edit_file,
list_directory,
move_symbol,
rename_file,
view_file,
)
from .commit import commit
from .create_file import create_file
from .delete_file import delete_file
from .edit_file import edit_file
from .github.create_pr import create_pr
from .github.create_pr_comment import create_pr_comment
from .github.create_pr_review_comment import create_pr_review_comment
from .github.view_pr import view_pr
from .list_directory import list_directory
from .move_symbol import move_symbol
from .rename_file import rename_file
from .reveal_symbol import reveal_symbol
from .search import search
from .semantic_edit import semantic_edit
from .semantic_search import semantic_search
from .view_file import view_file

__all__ = [
# Git operations
"commit",
# File operations
"create_file",
"create_pr",
"create_pr_comment",
"create_pr_review_comment",
"delete_file",
"edit_file",
"list_directory",
# Symbol analysis
# Symbol operations
"move_symbol",
# File operations
"rename_file",
"reveal_symbol",
# Search
# Search operations
"search",
# Semantic edit
# Edit operations
"semantic_edit",
"semantic_search",
"view_file",
"view_pr",
]
18 changes: 18 additions & 0 deletions src/codegen/extensions/tools/commit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Tool for committing changes to disk."""

from typing import Any

from codegen import Codebase


def commit(codebase: Codebase) -> dict[str, Any]:
"""Commit any pending changes to disk.

Args:
codebase: The codebase to operate on

Returns:
Dict containing commit status
"""
codebase.commit()
return {"status": "success", "message": "Changes committed to disk"}
25 changes: 25 additions & 0 deletions src/codegen/extensions/tools/create_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Tool for creating new files."""

from typing import Any

from codegen import Codebase

from .view_file import view_file


def create_file(codebase: Codebase, filepath: str, content: str = "") -> dict[str, Any]:
"""Create a new file.

Args:
codebase: The codebase to operate on
filepath: Path where to create the file
content: Initial file content

Returns:
Dict containing new file state, or error information if file already exists
"""
if codebase.has_file(filepath):
return {"error": f"File already exists: {filepath}"}
file = codebase.create_file(filepath, content=content)
codebase.commit()
return view_file(codebase, filepath)
27 changes: 27 additions & 0 deletions src/codegen/extensions/tools/delete_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Tool for deleting files."""

from typing import Any

from codegen import Codebase


def delete_file(codebase: Codebase, filepath: str) -> dict[str, Any]:
"""Delete a file.

Args:
codebase: The codebase to operate on
filepath: Path to the file to delete

Returns:
Dict containing deletion status, or error information if file not found
"""
try:
file = codebase.get_file(filepath)
except ValueError:
return {"error": f"File not found: {filepath}"}
if file is None:
return {"error": f"File not found: {filepath}"}

file.remove()
codebase.commit()
return {"status": "success", "deleted_file": filepath}
30 changes: 30 additions & 0 deletions src/codegen/extensions/tools/edit_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Tool for editing file contents."""

from typing import Any

from codegen import Codebase

from .view_file import view_file


def edit_file(codebase: Codebase, filepath: str, content: str) -> dict[str, Any]:
"""Edit a file by replacing its entire content.

Args:
codebase: The codebase to operate on
filepath: Path to the file to edit
content: New content for the file

Returns:
Dict containing updated file state, or error information if file not found
"""
try:
file = codebase.get_file(filepath)
except ValueError:
return {"error": f"File not found: {filepath}"}
if file is None:
return {"error": f"File not found: {filepath}"}

file.edit(content)
codebase.commit()
return view_file(codebase, filepath)
34 changes: 34 additions & 0 deletions src/codegen/extensions/tools/github/create_pr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Tool for creating pull requests."""

import uuid
from typing import Any

from codegen import Codebase


def create_pr(codebase: Codebase, title: str, body: str) -> dict[str, Any]:
"""Create a PR for the current branch.

Args:
codebase: The codebase to operate on
title: The title of the PR
body: The body/description of the PR

Returns:
Dict containing PR info, or error information if operation fails
"""
try:
# If on default branch, create a new branch
if codebase._op.git_cli.active_branch.name == codebase._op.default_branch:
codebase.checkout(branch=f"{uuid.uuid4()}", create_if_missing=True)

# Create the PR
pr = codebase.create_pr(title=title, body=body)
return {
"status": "success",
"url": pr.html_url,
"number": pr.number,
"title": pr.title,
}
except Exception as e:
return {"error": f"Failed to create PR: {e!s}"}
Loading