diff --git a/src/codegen/extensions/langchain/agent.py b/src/codegen/extensions/langchain/agent.py index 8909e02d7..22964c5ac 100644 --- a/src/codegen/extensions/langchain/agent.py +++ b/src/codegen/extensions/langchain/agent.py @@ -13,12 +13,12 @@ from .tools import ( CommitTool, CreateFileTool, - CreatePRCommentTool, - CreatePRReviewCommentTool, - CreatePRTool, DeleteFileTool, EditFileTool, - GetPRcontentsTool, + GithubCreatePRCommentTool, + GithubCreatePRReviewCommentTool, + GithubCreatePRTool, + GithubViewPRTool, ListDirectoryTool, MoveSymbolTool, RenameFileTool, @@ -68,10 +68,10 @@ def create_codebase_agent( SemanticEditTool(codebase), SemanticSearchTool(codebase), CommitTool(codebase), - CreatePRTool(codebase), - GetPRcontentsTool(codebase), - CreatePRCommentTool(codebase), - CreatePRReviewCommentTool(codebase), + GithubCreatePRTool(codebase), + GithubViewPRTool(codebase), + GithubCreatePRCommentTool(codebase), + GithubCreatePRReviewCommentTool(codebase), ] # Get the prompt to use diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index 8ab951b2d..2aa76f25d 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -8,7 +8,7 @@ from codegen import Codebase from codegen.extensions.linear.linear_client import LinearClient -from codegen.extensions.tools.linear_tools import ( +from codegen.extensions.tools.linear.linear import ( linear_comment_on_issue_tool, linear_create_issue_tool, linear_get_issue_comments_tool, @@ -354,19 +354,24 @@ def _run(self, query: str, k: int = 5, preview_length: int = 200) -> str: return json.dumps(result, indent=2) -class CreatePRInput(BaseModel): +######################################################################################################################## +# GITHUB +######################################################################################################################## + + +class GithubCreatePRInput(BaseModel): """Input for creating a PR""" title: str = Field(..., description="The title of the PR") body: str = Field(..., description="The body of the PR") -class CreatePRTool(BaseTool): +class GithubCreatePRTool(BaseTool): """Tool for creating a PR.""" name: ClassVar[str] = "create_pr" description: ClassVar[str] = "Create a PR for the current branch" - args_schema: ClassVar[type[BaseModel]] = CreatePRInput + args_schema: ClassVar[type[BaseModel]] = GithubCreatePRInput codebase: Codebase = Field(exclude=True) def __init__(self, codebase: Codebase) -> None: @@ -377,18 +382,18 @@ def _run(self, title: str, body: str) -> str: return json.dumps(result, indent=2) -class GetPRContentsInput(BaseModel): +class GithubViewPRInput(BaseModel): """Input for getting PR contents.""" pr_id: int = Field(..., description="Number of the PR to get the contents for") -class GetPRcontentsTool(BaseTool): +class GithubViewPRTool(BaseTool): """Tool for getting PR data.""" - name: ClassVar[str] = "get_pr_contents" - description: ClassVar[str] = "Get the diff and modified symbols of a PR along with the dependencies of the modified symbols" - args_schema: ClassVar[type[BaseModel]] = GetPRContentsInput + name: ClassVar[str] = "view_pr" + description: ClassVar[str] = "View the diff and associated context for a pull request" + args_schema: ClassVar[type[BaseModel]] = GithubViewPRInput codebase: Codebase = Field(exclude=True) def __init__(self, codebase: Codebase) -> None: @@ -399,19 +404,19 @@ def _run(self, pr_id: int) -> str: return json.dumps(result, indent=2) -class CreatePRCommentInput(BaseModel): +class GithubCreatePRCommentInput(BaseModel): """Input for creating a PR comment""" pr_number: int = Field(..., description="The PR number to comment on") body: str = Field(..., description="The comment text") -class CreatePRCommentTool(BaseTool): +class GithubCreatePRCommentTool(BaseTool): """Tool for creating a general PR comment.""" name: ClassVar[str] = "create_pr_comment" description: ClassVar[str] = "Create a general comment on a pull request" - args_schema: ClassVar[type[BaseModel]] = CreatePRCommentInput + args_schema: ClassVar[type[BaseModel]] = GithubCreatePRCommentInput codebase: Codebase = Field(exclude=True) def __init__(self, codebase: Codebase) -> None: @@ -422,7 +427,7 @@ def _run(self, pr_number: int, body: str) -> str: return json.dumps(result, indent=2) -class CreatePRReviewCommentInput(BaseModel): +class GithubCreatePRReviewCommentInput(BaseModel): """Input for creating an inline PR review comment""" pr_number: int = Field(..., description="The PR number to comment on") @@ -434,12 +439,12 @@ class CreatePRReviewCommentInput(BaseModel): start_line: int | None = Field(None, description="For multi-line comments, the starting line") -class CreatePRReviewCommentTool(BaseTool): +class GithubCreatePRReviewCommentTool(BaseTool): """Tool for creating inline PR review comments.""" name: ClassVar[str] = "create_pr_review_comment" description: ClassVar[str] = "Create an inline review comment on a specific line in a pull request" - args_schema: ClassVar[type[BaseModel]] = CreatePRReviewCommentInput + args_schema: ClassVar[type[BaseModel]] = GithubCreatePRReviewCommentInput codebase: Codebase = Field(exclude=True) def __init__(self, codebase: Codebase) -> None: @@ -468,6 +473,11 @@ def _run( return json.dumps(result, indent=2) +######################################################################################################################## +# LINEAR +######################################################################################################################## + + class LinearGetIssueInput(BaseModel): """Input for getting a Linear issue.""" @@ -597,6 +607,11 @@ def _run(self) -> str: return json.dumps(result, indent=2) +######################################################################################################################## +# EXPORT +######################################################################################################################## + + def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]: """Get all workspace tools initialized with a codebase. @@ -609,12 +624,9 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]: return [ CommitTool(codebase), CreateFileTool(codebase), - CreatePRTool(codebase), - CreatePRCommentTool(codebase), - CreatePRReviewCommentTool(codebase), DeleteFileTool(codebase), EditFileTool(codebase), - GetPRcontentsTool(codebase), + GithubViewPRTool(codebase), ListDirectoryTool(codebase), MoveSymbolTool(codebase), RenameFileTool(codebase), @@ -623,6 +635,12 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]: SemanticEditTool(codebase), SemanticSearchTool(codebase), ViewFileTool(codebase), + # Github + GithubCreatePRTool(codebase), + GithubCreatePRCommentTool(codebase), + GithubCreatePRReviewCommentTool(codebase), + GithubViewPRTool(codebase), + # Linear LinearGetIssueTool(codebase), LinearGetIssueCommentsTool(codebase), LinearCommentOnIssueTool(codebase), diff --git a/src/codegen/extensions/tools/__init__.py b/src/codegen/extensions/tools/__init__.py index 06ededaf8..70375f9b8 100644 --- a/src/codegen/extensions/tools/__init__.py +++ b/src/codegen/extensions/tools/__init__.py @@ -8,7 +8,7 @@ from .github.create_pr_comment import create_pr_comment from .github.create_pr_review_comment import create_pr_review_comment from .github.view_pr import view_pr -from .linear_tools import ( +from .linear import ( linear_comment_on_issue_tool, linear_get_issue_comments_tool, linear_get_issue_tool, diff --git a/src/codegen/extensions/tools/github/__init__.py b/src/codegen/extensions/tools/github/__init__.py new file mode 100644 index 000000000..a59669dd2 --- /dev/null +++ b/src/codegen/extensions/tools/github/__init__.py @@ -0,0 +1,11 @@ +from .create_pr import create_pr +from .create_pr_comment import create_pr_comment +from .create_pr_review_comment import create_pr_review_comment +from .view_pr import view_pr + +__all__ = [ + "create_pr", + "create_pr_comment", + "create_pr_review_comment", + "view_pr", +] diff --git a/src/codegen/extensions/tools/github/view_pr.py b/src/codegen/extensions/tools/github/view_pr.py index 13b90a0f5..c077b56a1 100644 --- a/src/codegen/extensions/tools/github/view_pr.py +++ b/src/codegen/extensions/tools/github/view_pr.py @@ -18,4 +18,4 @@ def view_pr(codebase: Codebase, pr_id: int) -> dict[str, Any]: modified_symbols, patch = codebase.get_modified_symbols_in_pr(pr_id) # Convert modified_symbols set to list for JSON serialization - return {"status": "success", "modified_symbols": list(modified_symbols), "patch": patch} + return {"status": "success", "patch": patch} diff --git a/src/codegen/extensions/tools/linear/__init__.py b/src/codegen/extensions/tools/linear/__init__.py new file mode 100644 index 000000000..8d15e1b72 --- /dev/null +++ b/src/codegen/extensions/tools/linear/__init__.py @@ -0,0 +1,19 @@ +from .linear import ( + linear_comment_on_issue_tool, + linear_create_issue_tool, + linear_get_issue_comments_tool, + linear_get_issue_tool, + linear_get_teams_tool, + linear_register_webhook_tool, + linear_search_issues_tool, +) + +__all__ = [ + "linear_comment_on_issue_tool", + "linear_create_issue_tool", + "linear_get_issue_comments_tool", + "linear_get_issue_tool", + "linear_get_teams_tool", + "linear_register_webhook_tool", + "linear_search_issues_tool", +] diff --git a/src/codegen/extensions/tools/linear_tools.py b/src/codegen/extensions/tools/linear/linear.py similarity index 100% rename from src/codegen/extensions/tools/linear_tools.py rename to src/codegen/extensions/tools/linear/linear.py diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/extension/__init__.py b/tests/integration/extension/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/extension/test_github.py b/tests/integration/extension/test_github.py new file mode 100644 index 000000000..a7ced0a3e --- /dev/null +++ b/tests/integration/extension/test_github.py @@ -0,0 +1,26 @@ +"""Tests for Linear tools.""" + +import os + +import pytest + +from codegen import Codebase +from codegen.extensions.linear.linear_client import LinearClient +from codegen.extensions.tools.github import view_pr + + +@pytest.fixture +def client() -> LinearClient: + """Create a Linear client for testing.""" + token = os.getenv("CODEGEN_SECRETS__GITHUB_TOKEN") + if not token: + pytest.skip("CODEGEN_SECRETS__GITHUB_TOKEN environment variable not set") + codebase = Codebase.from_repo("codegen-sh/Kevin-s-Adventure-Game") + return codebase + + +def test_github_view_pr(client: LinearClient) -> None: + """Test getting an issue from Linear.""" + # Link to PR: https://github.com/codegen-sh/Kevin-s-Adventure-Game/pull/419 + pr = view_pr(client, 419) + print(pr) diff --git a/tests/integration/extension/test_linear.py b/tests/integration/extension/test_linear.py index 5a758f2cb..4d39c9788 100644 --- a/tests/integration/extension/test_linear.py +++ b/tests/integration/extension/test_linear.py @@ -5,7 +5,7 @@ import pytest from codegen.extensions.linear.linear_client import LinearClient -from codegen.extensions.tools.linear_tools import ( +from codegen.extensions.tools.linear.linear import ( linear_comment_on_issue_tool, linear_create_issue_tool, linear_get_issue_comments_tool,