diff --git a/src/codegen/extensions/tools/github/create_pr.py b/src/codegen/extensions/tools/github/create_pr.py index 42eab67ac..1f6a7e307 100644 --- a/src/codegen/extensions/tools/github/create_pr.py +++ b/src/codegen/extensions/tools/github/create_pr.py @@ -3,6 +3,8 @@ import uuid from typing import Any +from github import GithubException + from codegen import Codebase @@ -18,12 +20,25 @@ def create_pr(codebase: Codebase, title: str, body: str) -> dict[str, Any]: Dict containing PR info, or error information if operation fails """ try: + # Check for uncommitted changes and commit them + if len(codebase.get_diff()) == 0: + return {"error": "No changes to create a PR."} + + # TODO: this is very jank. We should ideally check out the branch before + # making the changes, but it looks like `codebase.checkout` blows away + # all of your changes + codebase.git_commit(".") + # If on default branch, create a new branch if codebase._op.git_cli.active_branch.name == codebase._op.default_branch: codebase.checkout(branch=f"{uuid.uuid4()}", create_if_missing=True) # Create the PR - pr = codebase.create_pr(title=title, body=body) + try: + pr = codebase.create_pr(title=title, body=body) + except GithubException as e: + print(e) + return {"error": "Failed to create PR. Check if the PR already exists."} return { "status": "success", "url": pr.html_url, diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index 8b57df722..0c31dd644 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -740,7 +740,7 @@ def get_relative_path(self, from_file: str, to_file: str) -> str: #################################################################################################################### def git_commit(self, message: str, *, verify: bool = False) -> GitCommit | None: - """Commits all staged changes to the codebase and git. + """Stages + commits all changes to the codebase and git. Args: message (str): The commit message @@ -753,6 +753,8 @@ def git_commit(self, message: str, *, verify: bool = False) -> GitCommit | None: if self._op.stage_and_commit_all_changes(message, verify): logger.info(f"Commited repository to {self._op.head_commit} on {self._op.get_active_branch_or_commit()}") return self._op.head_commit + else: + logger.info("No changes to commit") return None def commit(self, sync_graph: bool = True) -> None: