diff --git a/src/codegen/git/repo_operator/local_repo_operator.py b/src/codegen/git/repo_operator/local_repo_operator.py index 14f4c9dc9..7cc4e506e 100644 --- a/src/codegen/git/repo_operator/local_repo_operator.py +++ b/src/codegen/git/repo_operator/local_repo_operator.py @@ -4,11 +4,10 @@ from typing import Self, override from codeowners import CodeOwners as CodeOwnersParser -from git import Remote from git import Repo as GitCLI -from git.remote import PushInfoList from github import Github from github.PullRequest import PullRequest +from github.Repository import Repository from codegen.git.clients.git_repo_client import GitRepoClient from codegen.git.repo_operator.repo_operator import RepoOperator @@ -16,6 +15,7 @@ from codegen.git.schemas.repo_config import RepoConfig from codegen.git.utils.clone_url import url_to_github from codegen.git.utils.file_utils import create_files +from codegen.shared.configs.config import config logger = logging.getLogger(__name__) @@ -41,7 +41,7 @@ def __init__( github_api_key: str | None = None, bot_commit: bool = False, ) -> None: - self._github_api_key = github_api_key + self._github_api_key = github_api_key or config.secrets.github_token self._remote_git_repo = None super().__init__(repo_config, bot_commit) os.makedirs(self.repo_path, exist_ok=True) @@ -52,7 +52,7 @@ def __init__( #################################################################################################################### @property - def remote_git_repo(self) -> GitRepoClient: + def remote_git_repo(self) -> Repository: if self._remote_git_repo is None: if not self._github_api_key: return None @@ -173,10 +173,6 @@ def base_url(self) -> str | None: if remote := next(iter(self.git_cli.remotes), None): return url_to_github(remote.url, self.get_active_branch_or_commit()) - @override - def push_changes(self, remote: Remote | None = None, refspec: str | None = None, force: bool = False) -> PushInfoList: - raise OperatorIsLocal() - @override def pull_repo(self) -> None: """Pull the latest commit down to an existing local repo""" diff --git a/src/codegen/git/repo_operator/remote_repo_operator.py b/src/codegen/git/repo_operator/remote_repo_operator.py index 1b6141cf2..a9f711551 100644 --- a/src/codegen/git/repo_operator/remote_repo_operator.py +++ b/src/codegen/git/repo_operator/remote_repo_operator.py @@ -4,8 +4,7 @@ from typing import override from codeowners import CodeOwners as CodeOwnersParser -from git import GitCommandError, Remote -from git.remote import PushInfoList +from git import GitCommandError from codegen.git.clients.git_repo_client import GitRepoClient from codegen.git.repo_operator.repo_operator import RepoOperator @@ -165,43 +164,6 @@ def checkout_remote_branch(self, branch_name: str | None = None, remote_name: st """ return self.checkout_branch(branch_name, remote_name=remote_name, remote=True, create_if_missing=False) - @stopwatch - def push_changes(self, remote: Remote | None = None, refspec: str | None = None, force: bool = False) -> PushInfoList: - """Push the changes to the given refspec of the remote. - - Args: - refspec (str | None): refspec to push. If None, the current active branch is used. - remote (Remote | None): Remote to push too. Defaults to 'origin'. - force (bool): If True, force push the changes. Defaults to False. - """ - # Use default remote if not provided - if not remote: - remote = self.git_cli.remote(name="origin") - - # Use the current active branch if no branch is specified - if not refspec: - # TODO: doesn't work with detached HEAD state - refspec = self.git_cli.active_branch.name - - res = remote.push(refspec=refspec, force=force, progress=CustomRemoteProgress()) - for push_info in res: - if push_info.flags & push_info.ERROR: - # Handle the error case - logger.warning(f"Error pushing {refspec}: {push_info.summary}") - elif push_info.flags & push_info.FAST_FORWARD: - # Successful fast-forward push - logger.info(f"{refspec} pushed successfully (fast-forward).") - elif push_info.flags & push_info.NEW_HEAD: - # Successful push of a new branch - logger.info(f"{refspec} pushed successfully as a new branch.") - elif push_info.flags & push_info.NEW_TAG: - # Successful push of a new tag (if relevant) - logger.info("New tag pushed successfully.") - else: - # Successful push, general case - logger.info(f"{refspec} pushed successfully.") - return res - @cached_property def base_url(self) -> str | None: repo_config = self.repo_config diff --git a/src/codegen/git/repo_operator/repo_operator.py b/src/codegen/git/repo_operator/repo_operator.py index 9312b6135..692fbf28a 100644 --- a/src/codegen/git/repo_operator/repo_operator.py +++ b/src/codegen/git/repo_operator/repo_operator.py @@ -17,6 +17,7 @@ from codegen.git.configs.constants import CODEGEN_BOT_EMAIL, CODEGEN_BOT_NAME from codegen.git.schemas.enums import CheckoutResult, FetchResult from codegen.git.schemas.repo_config import RepoConfig +from codegen.git.utils.remote_progress import CustomRemoteProgress from codegen.shared.performance.stopwatch_utils import stopwatch from codegen.shared.performance.time_utils import humanize_duration @@ -137,7 +138,17 @@ def git_diff(self) -> str: @property def default_branch(self) -> str: - return self._default_branch or self.git_cli.active_branch.name + # Priority 1: If default branch has been set + if self._default_branch: + return self._default_branch + + # Priority 2: If origin/HEAD ref exists + origin_prefix = "origin" + if f"{origin_prefix}/HEAD" in self.git_cli.refs: + return self.git_cli.refs[f"{origin_prefix}/HEAD"].reference.name.removeprefix(f"{origin_prefix}/") + + # Priority 3: Fallback to the active branch + return self.git_cli.active_branch.name @abstractmethod def codeowners_parser(self) -> CodeOwnersParser | None: ... @@ -372,14 +383,42 @@ def commit_changes(self, message: str, verify: bool = False) -> bool: logger.info("No changes to commit. Do nothing.") return False - @abstractmethod + @stopwatch def push_changes(self, remote: Remote | None = None, refspec: str | None = None, force: bool = False) -> PushInfoList: - """Push the changes to the given refspec of the remote repository. + """Push the changes to the given refspec of the remote. Args: refspec (str | None): refspec to push. If None, the current active branch is used. remote (Remote | None): Remote to push too. Defaults to 'origin'. + force (bool): If True, force push the changes. Defaults to False. """ + # Use default remote if not provided + if not remote: + remote = self.git_cli.remote(name="origin") + + # Use the current active branch if no branch is specified + if not refspec: + # TODO: doesn't work with detached HEAD state + refspec = self.git_cli.active_branch.name + + res = remote.push(refspec=refspec, force=force, progress=CustomRemoteProgress()) + for push_info in res: + if push_info.flags & push_info.ERROR: + # Handle the error case + logger.warning(f"Error pushing {refspec}: {push_info.summary}") + elif push_info.flags & push_info.FAST_FORWARD: + # Successful fast-forward push + logger.info(f"{refspec} pushed successfully (fast-forward).") + elif push_info.flags & push_info.NEW_HEAD: + # Successful push of a new branch + logger.info(f"{refspec} pushed successfully as a new branch.") + elif push_info.flags & push_info.NEW_TAG: + # Successful push of a new tag (if relevant) + logger.info("New tag pushed successfully.") + else: + # Successful push, general case + logger.info(f"{refspec} pushed successfully.") + return res def relpath(self, abspath) -> str: # TODO: check if the path is an abspath (i.e. contains self.repo_path) diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index ca2a6087b..50b62b27b 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -15,6 +15,7 @@ from git import Commit as GitCommit from git import Diff from git.remote import PushInfoList +from github.PullRequest import PullRequest from networkx import Graph from rich.console import Console from typing_extensions import deprecated @@ -872,6 +873,19 @@ def restore_stashed_changes(self): """Restore the most recent stash in the codebase.""" self._op.stash_pop() + #################################################################################################################### + # GITHUB + #################################################################################################################### + + def create_pr(self, title: str, body: str) -> PullRequest: + """Creates a PR from the current branch.""" + if self._op.git_cli.head.is_detached: + msg = "Cannot make a PR from a detached HEAD" + raise ValueError(msg) + self._op.stage_and_commit_all_changes(message=title) + self._op.push_changes() + return self._op.remote_git_repo.create_pull(head=self._op.git_cli.active_branch.name, base=self._op.default_branch, title=title, body=body) + #################################################################################################################### # GRAPH VISUALIZATION ####################################################################################################################