Skip to content

Commit ca81e37

Browse files
committed
pre-commit
1 parent 4b95135 commit ca81e37

File tree

3 files changed

+18
-11
lines changed

3 files changed

+18
-11
lines changed

commit0/cli.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,9 @@ def test(
249249
if reference:
250250
branch = "reference"
251251
if branch is None and not reference:
252-
git_path = os.path.join(commit0_config["base_dir"], repo_or_repo_path.split("/")[-1])
252+
git_path = os.path.join(
253+
commit0_config["base_dir"], repo_or_repo_path.split("/")[-1]
254+
)
253255
branch = get_active_branch(git_path)
254256

255257
if verbose == 2:
@@ -262,7 +264,7 @@ def test(
262264
commit0_config["dataset_split"],
263265
commit0_config["base_dir"],
264266
repo_or_repo_path,
265-
branch,
267+
branch, # type: ignore
266268
test_ids,
267269
backend,
268270
timeout,

commit0/harness/evaluate.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from concurrent.futures import ThreadPoolExecutor, as_completed
66
from datasets import load_dataset
77
from tqdm import tqdm
8-
from typing import Iterator
8+
from typing import Iterator, Union
99

1010
from commit0.harness.run_pytest_ids import main as run_tests
1111
from commit0.harness.get_pytest_ids import main as get_tests
@@ -23,7 +23,7 @@ def main(
2323
dataset_split: str,
2424
repo_split: str,
2525
base_dir: str,
26-
branch: str,
26+
branch: Union[str, None],
2727
backend: str,
2828
timeout: int,
2929
num_cpus: int,
@@ -32,19 +32,19 @@ def main(
3232
) -> None:
3333
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
3434
repos = SPLIT[repo_split]
35-
pairs = []
35+
triples = []
3636
log_dirs = []
3737
for example in dataset:
3838
repo_name = example["repo"].split("/")[-1]
3939
if repo_split != "all" and repo_name not in SPLIT[repo_split]:
4040
continue
41-
pairs.append((repo_name, example["test"]["test_dir"]))
4241
hashed_test_ids = get_hash_string(example["test"]["test_dir"])
4342
if branch is None:
4443
git_path = os.path.join(base_dir, repo_name)
4544
branch = get_active_branch(git_path)
4645
log_dir = RUN_PYTEST_LOG_DIR / repo_name / branch / hashed_test_ids
4746
log_dirs.append(str(log_dir))
47+
triples.append((repo_name, example["test"]["test_dir"], branch))
4848

4949
with tqdm(total=len(repos), smoothing=0, desc="Evaluating repos") as pbar:
5050
with ThreadPoolExecutor(max_workers=num_workers) as executor:
@@ -64,7 +64,7 @@ def main(
6464
rebuild_image=rebuild_image,
6565
verbose=0,
6666
): None
67-
for repo, test_dir in pairs
67+
for repo, test_dir, branch in triples
6868
}
6969
# Wait for each future to complete
7070
for future in as_completed(futures):

commit0/harness/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,25 +190,30 @@ def generate_patch_between_commits(
190190

191191

192192
def get_active_branch(repo_path: Union[str, Path]) -> str:
193-
"""
194-
Retrieve the current active branch of a Git repository.
193+
"""Retrieve the current active branch of a Git repository.
195194
196195
Args:
196+
----
197197
repo_path (Path): The path to git repo.
198198
199199
Returns:
200+
-------
200201
str: The name of the active branch.
201202
202203
Raises:
204+
------
203205
Exception: If the repository is in a detached HEAD state.
206+
204207
"""
205208
repo = git.Repo(repo_path)
206209
try:
207210
# Get the current active branch
208211
branch = repo.active_branch.name
209212
except TypeError as e:
210-
raise Exception(f"{e}\nThis means the repository is in a detached HEAD state. "
211-
"To proceed, please specify a valid branch.")
213+
raise Exception(
214+
f"{e}\nThis means the repository is in a detached HEAD state. "
215+
"To proceed, please specify a valid branch by using --branch {branch}."
216+
)
212217

213218
return branch
214219

0 commit comments

Comments
 (0)