From 9ad3c5d92eab8d3bfa5ef87f55be71270fb95496 Mon Sep 17 00:00:00 2001 From: KopekC Date: Fri, 21 Feb 2025 12:39:22 -0500 Subject: [PATCH 1/2] feat: final set of upgrades for tools --- src/codegen/extensions/tools/github/view_pr.py | 9 +++++++-- src/codegen/git/utils/pr_review.py | 5 +++-- src/codegen/sdk/core/codebase.py | 4 ++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/codegen/extensions/tools/github/view_pr.py b/src/codegen/extensions/tools/github/view_pr.py index 65537d672..d30edc8e6 100644 --- a/src/codegen/extensions/tools/github/view_pr.py +++ b/src/codegen/extensions/tools/github/view_pr.py @@ -5,7 +5,7 @@ from pydantic import Field from codegen.sdk.core.codebase import Codebase - +from codegen.sdk.core.symbol import Symbol from ..observation import Observation @@ -21,6 +21,9 @@ class ViewPRObservation(Observation): file_commit_sha: dict[str, str] = Field( description="Commit SHAs for each file in the PR", ) + modified_symbols: list[str] = Field( + description="Names of modified symbols in the PR", + ) str_template: ClassVar[str] = "PR #{pr_id}" @@ -33,13 +36,14 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: pr_id: Number of the PR to get the contents for """ try: - patch, file_commit_sha = codebase.get_modified_symbols_in_pr(pr_id) + patch, file_commit_sha, moddified_symbols = codebase.get_modified_symbols_in_pr(pr_id) return ViewPRObservation( status="success", pr_id=pr_id, patch=patch, file_commit_sha=file_commit_sha, + modified_symbols=moddified_symbols, ) except Exception as e: @@ -49,4 +53,5 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: pr_id=pr_id, patch="", file_commit_sha={}, + modified_symbols=[], ) diff --git a/src/codegen/git/utils/pr_review.py b/src/codegen/git/utils/pr_review.py index a042ce709..ad0c637ac 100644 --- a/src/codegen/git/utils/pr_review.py +++ b/src/codegen/git/utils/pr_review.py @@ -150,7 +150,7 @@ def is_modified(self, editable: "Editable") -> bool: return False @property - def modified_symbols(self) -> list["Symbol"]: + def modified_symbols(self) -> list[str]: # Import SourceFile locally to avoid circular dependencies from codegen.sdk.core.file import SourceFile @@ -163,7 +163,8 @@ def modified_symbols(self) -> list["Symbol"]: continue for symbol in file.symbols: if self.is_modified(symbol): - all_modified.append(symbol) + all_modified.append(symbol.name) + return all_modified def get_pr_diff(self) -> str: diff --git a/src/codegen/sdk/core/codebase.py b/src/codegen/sdk/core/codebase.py index 90023da3d..6858462f5 100644 --- a/src/codegen/sdk/core/codebase.py +++ b/src/codegen/sdk/core/codebase.py @@ -1311,13 +1311,13 @@ def from_repo( logger.exception(f"Failed to initialize codebase: {e}") raise - def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[str, dict[str, str]]: + def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[str, dict[str, str], list[str]]: """Get all modified symbols in a pull request""" pr = self._op.get_pull_request(pr_id) cg_pr = CodegenPR(self._op, self, pr) patch = cg_pr.get_pr_diff() commit_sha = cg_pr.get_file_commit_shas() - return patch, commit_sha + return patch, commit_sha, cg_pr.modified_symbols def create_pr_comment(self, pr_number: int, body: str) -> None: """Create a comment on a pull request""" From 6295093f7db03c7c3463a93a7f5089764bf53b8d Mon Sep 17 00:00:00 2001 From: kopekC <28070492+kopekC@users.noreply.github.com> Date: Fri, 21 Feb 2025 17:40:27 +0000 Subject: [PATCH 2/2] Automated pre-commit update --- src/codegen/extensions/tools/github/view_pr.py | 2 +- src/codegen/git/utils/pr_review.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/codegen/extensions/tools/github/view_pr.py b/src/codegen/extensions/tools/github/view_pr.py index d30edc8e6..00c20f7bb 100644 --- a/src/codegen/extensions/tools/github/view_pr.py +++ b/src/codegen/extensions/tools/github/view_pr.py @@ -5,7 +5,7 @@ from pydantic import Field from codegen.sdk.core.codebase import Codebase -from codegen.sdk.core.symbol import Symbol + from ..observation import Observation diff --git a/src/codegen/git/utils/pr_review.py b/src/codegen/git/utils/pr_review.py index ad0c637ac..4ebdc204a 100644 --- a/src/codegen/git/utils/pr_review.py +++ b/src/codegen/git/utils/pr_review.py @@ -9,7 +9,7 @@ from codegen.git.repo_operator.repo_operator import RepoOperator if TYPE_CHECKING: - from codegen.sdk.core.codebase import Codebase, Editable, File, Symbol + from codegen.sdk.core.codebase import Codebase, Editable, File def get_merge_base(git_repo_client: Repository, pull: PullRequest | PullRequestContext) -> str: @@ -164,7 +164,7 @@ def modified_symbols(self) -> list[str]: for symbol in file.symbols: if self.is_modified(symbol): all_modified.append(symbol.name) - + return all_modified def get_pr_diff(self) -> str: