55from concurrent .futures import ThreadPoolExecutor , as_completed
66from datasets import load_dataset
77from tqdm import tqdm
8- from typing import Iterator
8+ from typing import Iterator , Union
99
1010from commit0 .harness .run_pytest_ids import main as run_tests
1111from commit0 .harness .get_pytest_ids import main as get_tests
@@ -23,7 +23,7 @@ def main(
2323 dataset_split : str ,
2424 repo_split : str ,
2525 base_dir : str ,
26- branch : str ,
26+ branch : Union [ str , None ] ,
2727 backend : str ,
2828 timeout : int ,
2929 num_cpus : int ,
@@ -32,19 +32,19 @@ def main(
3232) -> None :
3333 dataset : Iterator [RepoInstance ] = load_dataset (dataset_name , split = dataset_split ) # type: ignore
3434 repos = SPLIT [repo_split ]
35- pairs = []
35+ triples = []
3636 log_dirs = []
3737 for example in dataset :
3838 repo_name = example ["repo" ].split ("/" )[- 1 ]
3939 if repo_split != "all" and repo_name not in SPLIT [repo_split ]:
4040 continue
41- pairs .append ((repo_name , example ["test" ]["test_dir" ]))
4241 hashed_test_ids = get_hash_string (example ["test" ]["test_dir" ])
4342 if branch is None :
4443 git_path = os .path .join (base_dir , repo_name )
4544 branch = get_active_branch (git_path )
4645 log_dir = RUN_PYTEST_LOG_DIR / repo_name / branch / hashed_test_ids
4746 log_dirs .append (str (log_dir ))
47+ triples .append ((repo_name , example ["test" ]["test_dir" ], branch ))
4848
4949 with tqdm (total = len (repos ), smoothing = 0 , desc = "Evaluating repos" ) as pbar :
5050 with ThreadPoolExecutor (max_workers = num_workers ) as executor :
@@ -64,7 +64,7 @@ def main(
6464 rebuild_image = rebuild_image ,
6565 verbose = 0 ,
6666 ): None
67- for repo , test_dir in pairs
67+ for repo , test_dir , branch in triples
6868 }
6969 # Wait for each future to complete
7070 for future in as_completed (futures ):
0 commit comments