Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions commit0/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import commit0.harness.lint
import commit0.harness.save
from commit0.harness.constants import SPLIT, SPLIT_ALL
from commit0.harness.utils import get_active_branch
import subprocess
import yaml
import os
Expand Down Expand Up @@ -245,14 +246,13 @@ def test(

commit0_config = read_commit0_dot_file(commit0_dot_file_path)

if not branch and not reference:
raise typer.BadParameter(
f"Invalid {highlight('BRANCH', Colors.RED)}. Either --reference or provide a branch name.",
param_hint="BRANCH",
)
if reference:
branch = "reference"
assert branch is not None, "branch is not specified"
if branch is None and not reference:
git_path = os.path.join(
commit0_config["base_dir"], repo_or_repo_path.split("/")[-1]
)
branch = get_active_branch(git_path)

if verbose == 2:
typer.echo(f"Running tests for repository: {repo_or_repo_path}")
Expand All @@ -264,7 +264,7 @@ def test(
commit0_config["dataset_split"],
commit0_config["base_dir"],
repo_or_repo_path,
branch,
branch, # type: ignore
test_ids,
backend,
timeout,
Expand Down Expand Up @@ -294,14 +294,8 @@ def evaluate(
) -> None:
"""Evaluate Commit0 split you choose in Setup Stage."""
check_commit0_path()
if not branch and not reference:
raise typer.BadParameter(
f"Invalid {highlight('BRANCH', Colors.RED)}. Either --reference or provide a branch name",
param_hint="BRANCH",
)
if reference:
branch = "reference"
assert branch is not None, "branch is not specified"

commit0_config = read_commit0_dot_file(commit0_dot_file_path)
check_valid(commit0_config["repo_split"], SPLIT)
Expand Down
2 changes: 2 additions & 0 deletions commit0/harness/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class Files(TypedDict):
patch: Dict[str, Path]


BASE_BRANCH = "commit0"

# Constants - Evaluation Log Directories
BASE_IMAGE_BUILD_DIR = Path("logs/build_images/base")
REPO_IMAGE_BUILD_DIR = Path("logs/build_images/repo")
Expand Down
15 changes: 9 additions & 6 deletions commit0/harness/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from datasets import load_dataset
from tqdm import tqdm
from typing import Iterator
from typing import Iterator, Union

from commit0.harness.run_pytest_ids import main as run_tests
from commit0.harness.get_pytest_ids import main as get_tests
from commit0.harness.constants import RepoInstance, SPLIT, RUN_PYTEST_LOG_DIR
from commit0.harness.utils import get_hash_string
from commit0.harness.utils import get_hash_string, get_active_branch

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
Expand All @@ -23,7 +23,7 @@ def main(
dataset_split: str,
repo_split: str,
base_dir: str,
branch: str,
branch: Union[str, None],
backend: str,
timeout: int,
num_cpus: int,
Expand All @@ -32,16 +32,19 @@ def main(
) -> None:
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
repos = SPLIT[repo_split]
pairs = []
triples = []
log_dirs = []
for example in dataset:
repo_name = example["repo"].split("/")[-1]
if repo_split != "all" and repo_name not in SPLIT[repo_split]:
continue
pairs.append((repo_name, example["test"]["test_dir"]))
hashed_test_ids = get_hash_string(example["test"]["test_dir"])
if branch is None:
git_path = os.path.join(base_dir, repo_name)
branch = get_active_branch(git_path)
log_dir = RUN_PYTEST_LOG_DIR / repo_name / branch / hashed_test_ids
log_dirs.append(str(log_dir))
triples.append((repo_name, example["test"]["test_dir"], branch))

with tqdm(total=len(repos), smoothing=0, desc="Evaluating repos") as pbar:
with ThreadPoolExecutor(max_workers=num_workers) as executor:
Expand All @@ -61,7 +64,7 @@ def main(
rebuild_image=rebuild_image,
verbose=0,
): None
for repo, test_dir in pairs
for repo, test_dir, branch in triples
}
# Wait for each future to complete
for future in as_completed(futures):
Expand Down
27 changes: 21 additions & 6 deletions commit0/harness/run_pytest_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,30 @@ def main(
)
except Exception as e:
raise e
commit_id = ""
if branch == "reference":
commit_id = example["reference_commit"]
else:
try:
local_repo.git.checkout(branch)
local_branch = local_repo.branches[branch]
commit_id = local_branch.commit.hexsha
except Exception as e:
raise Exception(f"Problem checking out branch {branch}.\n{e}")
# Check if it's a local branch
if branch in local_repo.branches:
commit_id = local_repo.commit(branch).hexsha
else:
found_remote_branch = False
for remote in local_repo.remotes:
remote.fetch() # Fetch latest updates from each remote

# Check if the branch exists in this remote
for ref in remote.refs:
if (
ref.remote_head == branch
): # Compare branch name without remote prefix
commit_id = local_repo.commit(ref.name).hexsha
found_remote_branch = True
break # Branch found, no need to keep checking this remote
if found_remote_branch:
break # Stop checking other remotes if branch is found
if not found_remote_branch:
raise Exception(f"Branch {branch} does not exist locally or remotely.")
patch = generate_patch_between_commits(
local_repo, example["base_commit"], commit_id
)
Expand Down
9 changes: 7 additions & 2 deletions commit0/harness/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from commit0.harness.utils import (
clone_repo,
)
from commit0.harness.constants import RepoInstance, SPLIT
from commit0.harness.constants import BASE_BRANCH, RepoInstance, SPLIT


logging.basicConfig(
Expand All @@ -29,7 +29,12 @@ def main(
continue
clone_url = f"https://github.com/{example['repo']}.git"
clone_dir = os.path.abspath(os.path.join(base_dir, repo_name))
clone_repo(clone_url, clone_dir, example["base_commit"], logger)
branch = dataset_name.split("/")[-1]
repo = clone_repo(clone_url, clone_dir, branch, logger)
if BASE_BRANCH in repo.branches:
repo.git.branch("-d", BASE_BRANCH)
repo.git.checkout("-b", BASE_BRANCH)
logger.info("Checked out the base commit: commit 0")


__all__ = []
42 changes: 35 additions & 7 deletions commit0/harness/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
import sys
from pathlib import Path
from typing import Optional
from typing import Optional, Union

from fastcore.net import HTTP404NotFoundError, HTTP403ForbiddenError # type: ignore
from ghapi.core import GhApi
Expand Down Expand Up @@ -85,7 +85,7 @@ def extract_test_output(ss: str, pattern: str) -> str:


def clone_repo(
clone_url: str, clone_dir: str, commit: str, logger: logging.Logger
clone_url: str, clone_dir: str, branch: str, logger: logging.Logger
) -> git.Repo:
"""Clone repo into the specified directory if it does not already exist.

Expand All @@ -98,8 +98,8 @@ def clone_repo(
URL of the repository to clone.
clone_dir : str
Directory where the repository will be cloned.
commit : str
The commit hash or branch/tag name to checkout.
branch : str
The branch/tag name to checkout.
logger : logging.Logger
The logger object.

Expand Down Expand Up @@ -129,11 +129,10 @@ def clone_repo(
except git.exc.GitCommandError as e:
raise RuntimeError(f"Failed to clone repository: {e}")

logger.info(f"Checking out {commit}")
try:
repo.git.checkout(commit)
repo.git.checkout(branch)
except git.exc.GitCommandError as e:
raise RuntimeError(f"Failed to check out {commit}: {e}")
raise RuntimeError(f"Failed to check out {branch}: {e}")

return repo

Expand Down Expand Up @@ -190,4 +189,33 @@ def generate_patch_between_commits(
raise Exception(f"Error generating patch: {e}")


def get_active_branch(repo_path: Union[str, Path]) -> str:
    """Retrieve the current active branch of a Git repository.

    Args:
    ----
        repo_path (Union[str, Path]): The path to git repo.

    Returns:
    -------
        str: The name of the active branch.

    Raises:
    ------
        Exception: If the repository is in a detached HEAD state. The
            underlying ``TypeError`` is attached as the exception's cause.

    """
    repo = git.Repo(repo_path)
    try:
        # `active_branch` raises TypeError when HEAD does not point at a
        # branch (detached HEAD); any other error propagates unchanged.
        return repo.active_branch.name
    except TypeError as e:
        # The second literal is not an f-string, so `{branch}` is emitted
        # verbatim; it reads as a placeholder for the user's branch name.
        raise Exception(
            f"{e}\nThis means the repository is in a detached HEAD state. "
            "To proceed, please specify a valid branch by using --branch {branch}."
        ) from e


__all__ = []