diff --git a/commit0/cli.py b/commit0/cli.py index 8c82b08..c0d8596 100644 --- a/commit0/cli.py +++ b/commit0/cli.py @@ -10,6 +10,7 @@ import commit0.harness.lint import commit0.harness.save from commit0.harness.constants import SPLIT, SPLIT_ALL +from commit0.harness.utils import get_active_branch import subprocess import yaml import os @@ -245,14 +246,13 @@ def test( commit0_config = read_commit0_dot_file(commit0_dot_file_path) - if not branch and not reference: - raise typer.BadParameter( - f"Invalid {highlight('BRANCH', Colors.RED)}. Either --reference or provide a branch name.", - param_hint="BRANCH", - ) if reference: branch = "reference" - assert branch is not None, "branch is not specified" + if branch is None and not reference: + git_path = os.path.join( + commit0_config["base_dir"], repo_or_repo_path.split("/")[-1] + ) + branch = get_active_branch(git_path) if verbose == 2: typer.echo(f"Running tests for repository: {repo_or_repo_path}") @@ -264,7 +264,7 @@ def test( commit0_config["dataset_split"], commit0_config["base_dir"], repo_or_repo_path, - branch, + branch, # type: ignore test_ids, backend, timeout, @@ -294,14 +294,8 @@ def evaluate( ) -> None: """Evaluate Commit0 split you choose in Setup Stage.""" check_commit0_path() - if not branch and not reference: - raise typer.BadParameter( - f"Invalid {highlight('BRANCH', Colors.RED)}. Either --reference or provide a branch name", - param_hint="BRANCH", - ) if reference: branch = "reference" - assert branch is not None, "branch is not specified" commit0_config = read_commit0_dot_file(commit0_dot_file_path) check_valid(commit0_config["repo_split"], SPLIT) diff --git a/commit0/harness/constants.py b/commit0/harness/constants.py index f2957fa..9a92337 100644 --- a/commit0/harness/constants.py +++ b/commit0/harness/constants.py @@ -16,6 +16,8 @@ class Files(TypedDict): patch: Dict[str, Path] +BASE_BRANCH = "commit0" + # Constants - Evaluation Log Directories BASE_IMAGE_BUILD_DIR = Path("logs/build_images/base") REPO_IMAGE_BUILD_DIR = Path("logs/build_images/repo") diff --git a/commit0/harness/evaluate.py b/commit0/harness/evaluate.py index 11dbc98..ddc4b15 100644 --- a/commit0/harness/evaluate.py +++ b/commit0/harness/evaluate.py @@ -5,12 +5,12 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from datasets import load_dataset from tqdm import tqdm -from typing import Iterator +from typing import Iterator, Union from commit0.harness.run_pytest_ids import main as run_tests from commit0.harness.get_pytest_ids import main as get_tests from commit0.harness.constants import RepoInstance, SPLIT, RUN_PYTEST_LOG_DIR -from commit0.harness.utils import get_hash_string +from commit0.harness.utils import get_hash_string, get_active_branch logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" @@ -23,7 +23,7 @@ def main( dataset_split: str, repo_split: str, base_dir: str, - branch: str, + branch: Union[str, None], backend: str, timeout: int, num_cpus: int, @@ -32,16 +32,19 @@ def main( ) -> None: dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore repos = SPLIT[repo_split] - pairs = [] + triples = [] log_dirs = [] for example in dataset: repo_name = example["repo"].split("/")[-1] if repo_split != "all" and repo_name not in SPLIT[repo_split]: continue - pairs.append((repo_name, example["test"]["test_dir"])) hashed_test_ids = get_hash_string(example["test"]["test_dir"]) + if branch is None: + git_path = os.path.join(base_dir, repo_name) + branch = get_active_branch(git_path) log_dir = RUN_PYTEST_LOG_DIR / repo_name / branch / hashed_test_ids log_dirs.append(str(log_dir)) + triples.append((repo_name, example["test"]["test_dir"], branch)) with tqdm(total=len(repos), smoothing=0, desc="Evaluating repos") as pbar: with ThreadPoolExecutor(max_workers=num_workers) as executor: @@ -61,7 +64,7 @@ def main( rebuild_image=rebuild_image, verbose=0, ): None - for repo, test_dir in pairs + for repo, test_dir, branch in triples } # Wait for each future to complete for future in as_completed(futures): diff --git a/commit0/harness/run_pytest_ids.py b/commit0/harness/run_pytest_ids.py index 09fbe45..fd45080 100644 --- a/commit0/harness/run_pytest_ids.py +++ b/commit0/harness/run_pytest_ids.py @@ -82,15 +82,30 @@ def main( ) except Exception as e: raise e + commit_id = "" if branch == "reference": commit_id = example["reference_commit"] else: - try: - local_repo.git.checkout(branch) - local_branch = local_repo.branches[branch] - commit_id = local_branch.commit.hexsha - except Exception as e: - raise Exception(f"Problem checking out branch {branch}.\n{e}") + # Check if it's a local branch + if branch in local_repo.branches: + commit_id = local_repo.commit(branch).hexsha + else: + found_remote_branch = False + for remote in local_repo.remotes: + remote.fetch() # Fetch latest updates from each remote + + # Check if the branch exists in this remote + for ref in remote.refs: + if ( + ref.remote_head == branch + ): # Compare branch name without remote prefix + commit_id = local_repo.commit(ref.name).hexsha + found_remote_branch = True + break # Branch found, no need to keep checking this remote + if found_remote_branch: + break # Stop checking other remotes if branch is found + if not found_remote_branch: + raise Exception(f"Branch {branch} does not exist locally or remotely.") patch = generate_patch_between_commits( local_repo, example["base_commit"], commit_id ) diff --git a/commit0/harness/setup.py b/commit0/harness/setup.py index edaadcd..b355572 100644 --- a/commit0/harness/setup.py +++ b/commit0/harness/setup.py @@ -7,7 +7,7 @@ from commit0.harness.utils import ( clone_repo, ) -from commit0.harness.constants import RepoInstance, SPLIT +from commit0.harness.constants import BASE_BRANCH, RepoInstance, SPLIT logging.basicConfig( @@ -29,7 +29,12 @@ def main( continue clone_url = f"https://github.com/{example['repo']}.git" clone_dir = os.path.abspath(os.path.join(base_dir, repo_name)) - clone_repo(clone_url, clone_dir, example["base_commit"], logger) + branch = dataset_name.split("/")[-1] + repo = clone_repo(clone_url, clone_dir, branch, logger) + if BASE_BRANCH in repo.branches: + repo.git.branch("-d", BASE_BRANCH) + repo.git.checkout("-b", BASE_BRANCH) + logger.info("Checked out the base commit: commit 0") __all__ = [] diff --git a/commit0/harness/utils.py b/commit0/harness/utils.py index 8836415..b006035 100644 --- a/commit0/harness/utils.py +++ b/commit0/harness/utils.py @@ -6,7 +6,7 @@ import time import sys from pathlib import Path -from typing import Optional +from typing import Optional, Union from fastcore.net import HTTP404NotFoundError, HTTP403ForbiddenError # type: ignore from ghapi.core import GhApi @@ -85,7 +85,7 @@ def extract_test_output(ss: str, pattern: str) -> str: def clone_repo( - clone_url: str, clone_dir: str, commit: str, logger: logging.Logger + clone_url: str, clone_dir: str, branch: str, logger: logging.Logger ) -> git.Repo: """Clone repo into the specified directory if it does not already exist. @@ -98,8 +98,8 @@ def clone_repo( URL of the repository to clone. clone_dir : str Directory where the repository will be cloned. - commit : str - The commit hash or branch/tag name to checkout. + branch : str + The branch/tag name to checkout. logger : logging.Logger The logger object. @@ -129,11 +129,10 @@ def clone_repo( except git.exc.GitCommandError as e: raise RuntimeError(f"Failed to clone repository: {e}") - logger.info(f"Checking out {commit}") try: - repo.git.checkout(commit) + repo.git.checkout(branch) except git.exc.GitCommandError as e: - raise RuntimeError(f"Failed to check out {commit}: {e}") + raise RuntimeError(f"Failed to check out {branch}: {e}") return repo @@ -190,4 +189,33 @@ def generate_patch_between_commits( raise Exception(f"Error generating patch: {e}") +def get_active_branch(repo_path: Union[str, Path]) -> str: + """Retrieve the current active branch of a Git repository. + + Args: + ---- + repo_path (Path): The path to git repo. + + Returns: + ------- + str: The name of the active branch. + + Raises: + ------ + Exception: If the repository is in a detached HEAD state. + + """ + repo = git.Repo(repo_path) + try: + # Get the current active branch + branch = repo.active_branch.name + except TypeError as e: + raise Exception( + f"{e}\nThis means the repository is in a detached HEAD state. " + "To proceed, please specify a valid branch by using --branch {branch}." + ) + + return branch + + __all__ = []