55from concurrent .futures import ThreadPoolExecutor , as_completed
66from datasets import load_dataset
77from tqdm import tqdm
8- from typing import Iterator
8+ from typing import Iterator , Union
99
1010from commit0 .harness .run_pytest_ids import main as run_tests
1111from commit0 .harness .get_pytest_ids import main as get_tests
1212from commit0 .harness .constants import RepoInstance , SPLIT , RUN_PYTEST_LOG_DIR
13- from commit0 .harness .utils import get_hash_string
13+ from commit0 .harness .utils import get_hash_string , get_active_branch
1414
1515logging .basicConfig (
1616 level = logging .INFO , format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -23,7 +23,7 @@ def main(
2323 dataset_split : str ,
2424 repo_split : str ,
2525 base_dir : str ,
26- branch : str ,
26+ branch : Union [ str , None ] ,
2727 backend : str ,
2828 timeout : int ,
2929 num_cpus : int ,
@@ -32,16 +32,19 @@ def main(
3232) -> None :
3333 dataset : Iterator [RepoInstance ] = load_dataset (dataset_name , split = dataset_split ) # type: ignore
3434 repos = SPLIT [repo_split ]
35- pairs = []
35+ triples = []
3636 log_dirs = []
3737 for example in dataset :
3838 repo_name = example ["repo" ].split ("/" )[- 1 ]
3939 if repo_split != "all" and repo_name not in SPLIT [repo_split ]:
4040 continue
41- pairs .append ((repo_name , example ["test" ]["test_dir" ]))
4241 hashed_test_ids = get_hash_string (example ["test" ]["test_dir" ])
42+ if branch is None :
43+ git_path = os .path .join (base_dir , repo_name )
44+ branch = get_active_branch (git_path )
4345 log_dir = RUN_PYTEST_LOG_DIR / repo_name / branch / hashed_test_ids
4446 log_dirs .append (str (log_dir ))
47+ triples .append ((repo_name , example ["test" ]["test_dir" ], branch ))
4548
4649 with tqdm (total = len (repos ), smoothing = 0 , desc = "Evaluating repos" ) as pbar :
4750 with ThreadPoolExecutor (max_workers = num_workers ) as executor :
@@ -61,7 +64,7 @@ def main(
6164 rebuild_image = rebuild_image ,
6265 verbose = 0 ,
6366 ): None
64- for repo , test_dir in pairs
67+ for repo , test_dir , branch in triples
6568 }
6669 # Wait for each future to complete
6770 for future in as_completed (futures ):
0 commit comments