|
| 1 | +import os |
| 2 | +import yaml |
| 3 | +import multiprocessing |
| 4 | +from tqdm import tqdm |
| 5 | +from datasets import load_dataset |
| 6 | +from git import Repo |
| 7 | +from agent.agent_utils import ( |
| 8 | + args2string, |
| 9 | + create_branch, |
| 10 | + get_message, |
| 11 | + get_target_edit_files, |
| 12 | + get_lint_cmd, |
| 13 | + read_yaml_config, |
| 14 | +) |
| 15 | +from agent.agents import AiderAgents |
| 16 | +from typing import Optional, Type, cast |
| 17 | +from types import TracebackType |
| 18 | +from agent.class_types import AgentConfig |
| 19 | +from commit0.harness.constants import SPLIT |
| 20 | +from commit0.harness.get_pytest_ids import main as get_tests |
| 21 | +from commit0.harness.constants import RUN_AIDER_LOG_DIR, RepoInstance |
| 22 | +from commit0.cli import read_commit0_dot_file |
| 23 | +from pathlib import Path |
| 24 | +from datetime import datetime |
| 25 | + |
| 26 | + |
| 27 | +class DirContext: |
| 28 | + def __init__(self, d: str): |
| 29 | + self.dir = d |
| 30 | + self.cwd = os.getcwd() |
| 31 | + |
| 32 | + def __enter__(self): |
| 33 | + os.chdir(self.dir) |
| 34 | + |
| 35 | + def __exit__( |
| 36 | + self, |
| 37 | + exctype: Optional[Type[BaseException]], |
| 38 | + excinst: Optional[BaseException], |
| 39 | + exctb: Optional[TracebackType], |
| 40 | + ) -> None: |
| 41 | + os.chdir(self.cwd) |
| 42 | + |
| 43 | + |
| 44 | +def run_agent_for_repo( |
| 45 | + repo_base_dir: str, |
| 46 | + agent_config: AgentConfig, |
| 47 | + example: RepoInstance, |
| 48 | + branch: Optional[str] = None, |
| 49 | + override_previous_changes: bool = False, |
| 50 | + backend: str = "modal", |
| 51 | + log_dir: str = str(RUN_AIDER_LOG_DIR.resolve()), |
| 52 | +) -> None: |
| 53 | + """Run Aider for a given repository.""" |
| 54 | + # get repo info |
| 55 | + _, repo_name = example["repo"].split("/") |
| 56 | + |
| 57 | + repo_name = repo_name.lower() |
| 58 | + repo_name = repo_name.replace(".", "-") |
| 59 | + |
| 60 | + repo_path = os.path.join(repo_base_dir, repo_name) |
| 61 | + repo_path = os.path.abspath(repo_path) |
| 62 | + |
| 63 | + src_dir = os.path.join(repo_path, example["src_dir"]) |
| 64 | + |
| 65 | + try: |
| 66 | + local_repo = Repo(repo_path) |
| 67 | + except Exception: |
| 68 | + raise Exception( |
| 69 | + f"{repo_path} is not a git repo. Check if base_dir is correctly specified." |
| 70 | + ) |
| 71 | + |
| 72 | + if agent_config.agent_name == "aider": |
| 73 | + agent = AiderAgents(agent_config.max_iteration, agent_config.model_name) |
| 74 | + else: |
| 75 | + raise NotImplementedError( |
| 76 | + f"{agent_config.agent_name} is not implemented; please add your implementations in baselines/agents.py." |
| 77 | + ) |
| 78 | + |
| 79 | + # if branch_name is not provided, create a new branch name based on agent_config |
| 80 | + if branch is None: |
| 81 | + branch = args2string(agent_config) |
| 82 | + |
| 83 | + create_branch(local_repo, branch, example["base_commit"]) |
| 84 | + |
| 85 | + # in cases where the latest commit of branch is not commit 0 |
| 86 | + # set it back to commit 0 |
| 87 | + latest_commit = local_repo.commit(branch) |
| 88 | + if latest_commit.hexsha != example["base_commit"] and override_previous_changes: |
| 89 | + local_repo.git.reset("--hard", example["base_commit"]) |
| 90 | + |
| 91 | + # prepare the log dir |
| 92 | + experiment_log_dir = ( |
| 93 | + Path(log_dir) |
| 94 | + / repo_name |
| 95 | + / branch |
| 96 | + / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") |
| 97 | + ) |
| 98 | + experiment_log_dir.mkdir(parents=True, exist_ok=True) |
| 99 | + |
| 100 | + # write agent_config to .agent.yaml in the log_dir for record |
| 101 | + agent_config_log_file = experiment_log_dir / ".agent.yaml" |
| 102 | + with open(agent_config_log_file, "w") as agent_config_file: |
| 103 | + yaml.dump(agent_config, agent_config_file) |
| 104 | + |
| 105 | + # TODO: make this path more general |
| 106 | + commit0_dot_file_path = str(Path(repo_path).parent.parent / ".commit0.yaml") |
| 107 | + with DirContext(repo_path): |
| 108 | + if agent_config is None: |
| 109 | + raise ValueError("Invalid input") |
| 110 | + |
| 111 | + target_edit_files = get_target_edit_files( |
| 112 | + src_dir, src_prefix=example["src_dir"] |
| 113 | + ) |
| 114 | + |
| 115 | + if agent_config.run_tests: |
| 116 | + # Call the commit0 get-tests command to retrieve test files |
| 117 | + test_files_str = get_tests(repo_name, verbose=0) |
| 118 | + test_files = sorted(list(set([i.split(":")[0] for i in test_files_str]))) |
| 119 | + |
| 120 | + # when unit test feedback is available, iterate over test files |
| 121 | + for test_file in test_files: |
| 122 | + test_cmd = f"python -m commit0 test {repo_path} {test_file} --branch {branch} --backend {backend} --commit0_dot_file_path {commit0_dot_file_path}" |
| 123 | + test_file_name = test_file.replace(".py", "").replace("/", "__") |
| 124 | + test_log_dir = experiment_log_dir / test_file_name |
| 125 | + lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info) |
| 126 | + message = get_message(agent_config, repo_path, test_file=test_file) |
| 127 | + _ = agent.run( |
| 128 | + message, |
| 129 | + test_cmd, |
| 130 | + lint_cmd, |
| 131 | + target_edit_files, |
| 132 | + test_log_dir, |
| 133 | + ) |
| 134 | + # cost = agent_return.last_cost |
| 135 | + else: |
| 136 | + # when unit test feedback is not available, iterate over target files to edit |
| 137 | + message = get_message( |
| 138 | + agent_config, repo_path, test_dir=example["test"]["test_dir"] |
| 139 | + ) |
| 140 | + for f in target_edit_files: |
| 141 | + file_name = f.replace(".py", "").replace("/", "__") |
| 142 | + file_log_dir = experiment_log_dir / file_name |
| 143 | + lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info) |
| 144 | + _ = agent.run(message, "", lint_cmd, [f], file_log_dir) |
| 145 | + # cost = agent_return.last_cost |
| 146 | + |
| 147 | + |
| 148 | +def run_agent( |
| 149 | + branch: str, |
| 150 | + override_previous_changes: bool, |
| 151 | + backend: str, |
| 152 | + agent_config_file: str, |
| 153 | + log_dir: str, |
| 154 | + max_parallel_repos: int, |
| 155 | +) -> None: |
| 156 | + """Main function to run Aider for a given repository. |
| 157 | +
|
| 158 | + Will run in parallel for each repo. |
| 159 | + """ |
| 160 | + config = read_yaml_config(agent_config_file) |
| 161 | + |
| 162 | + agent_config = AgentConfig(**config) |
| 163 | + |
| 164 | + commit0_config = read_commit0_dot_file(".commit0.yaml") |
| 165 | + |
| 166 | + dataset = load_dataset( |
| 167 | + commit0_config["dataset_name"], split=commit0_config["dataset_split"] |
| 168 | + ) |
| 169 | + filtered_dataset = [ |
| 170 | + example |
| 171 | + for example in dataset |
| 172 | + if commit0_config["repo_split"] == "all" |
| 173 | + or ( |
| 174 | + isinstance(example, dict) |
| 175 | + and "repo" in example |
| 176 | + and isinstance(example["repo"], str) |
| 177 | + and example["repo"].split("/")[-1] |
| 178 | + in SPLIT.get(commit0_config["repo_split"], []) |
| 179 | + ) |
| 180 | + ] |
| 181 | + assert len(filtered_dataset) > 0, "No examples available" |
| 182 | + |
| 183 | + # if len(filtered_dataset) > 1: |
| 184 | + # sys.stdout = open(os.devnull, "w") |
| 185 | + print("jere") |
| 186 | + print(filtered_dataset[0]) |
| 187 | + for example in filtered_dataset: |
| 188 | + if "joblib" in example["repo"]: |
| 189 | + print(example) |
| 190 | + run_agent_for_repo( |
| 191 | + commit0_config["base_dir"], |
| 192 | + agent_config, |
| 193 | + cast(RepoInstance, example), |
| 194 | + branch, |
| 195 | + override_previous_changes, |
| 196 | + backend, |
| 197 | + log_dir, |
| 198 | + ) |
| 199 | + # with tqdm( |
| 200 | + # total=len(filtered_dataset), smoothing=0, desc="Running Aider for repos" |
| 201 | + # ) as pbar: |
| 202 | + # with multiprocessing.Pool(processes=max_parallel_repos) as pool: |
| 203 | + # results = [] |
| 204 | + |
| 205 | + # # Use apply_async to submit jobs and add progress bar updates |
| 206 | + # for example in filtered_dataset: |
| 207 | + # result = pool.apply_async( |
| 208 | + # run_agent_for_repo, |
| 209 | + # args=( |
| 210 | + # commit0_config["base_dir"], |
| 211 | + # agent_config, |
| 212 | + # cast(RepoInstance, example), |
| 213 | + # branch, |
| 214 | + # override_previous_changes, |
| 215 | + # backend, |
| 216 | + # log_dir, |
| 217 | + # ), |
| 218 | + # callback=lambda _: pbar.update( |
| 219 | + # 1 |
| 220 | + # ), # Update progress bar on task completion |
| 221 | + # ) |
| 222 | + # results.append(result) |
| 223 | + |
| 224 | + # for result in results: |
| 225 | + # result.wait() |
0 commit comments