diff --git a/README.md b/README.md index 57f68ce..15e932c 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Available options include: | `--dataset-name` | str | Name of the Huggingface dataset | `wentingzhao/commit0_combined` | | `--dataset-split` | str | Split of the Huggingface dataset | `test` | | `--base-dir` | str | Base directory to clone repos to | `repos/` | -| `--commit0-dot-file-path` | str | Storing path for stateful commit0 configs | `.commit0.yaml` | +| `--commit0-config-file` | str | Storing path for stateful commit0 configs | `.commit0.yaml` | ### Build @@ -64,7 +64,7 @@ Available options include: | Argument | Type | Description | Default | |----------|------|-------------|---------| | `--num-workers` | int | Number of workers | `8` | -| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` | +| `--commit0-config-file` | str | Path to the commit0 dot file | `.commit0.yaml` | | `--verbose` | int | Verbosity level (1 or 2) | `1` | ### Get Tests @@ -91,7 +91,7 @@ Available options include: | `--reference` | bool | Test the reference commit | `False` | | `--coverage` | bool | Get coverage information | `False` | | `--rebuild` | bool | Rebuild an image | `False` | -| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` | +| `--commit0-config-file` | str | Path to the commit0 dot file | `.commit0.yaml` | | `--verbose` | int | Verbosity level (1 or 2) | `1` | | `--stdin` | bool | Read test names from stdin | `False` | @@ -109,7 +109,7 @@ Available options include: | `--num-workers` | int | Number of workers to use | `8` | | `--reference` | bool | Evaluate the reference commit | `False` | | `--coverage` | bool | Get coverage information | `False` | -| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` | +| `--commit0-config-file` | str | Path to the commit0 dot file | `.commit0.yaml` | | `--rebuild` | bool | Rebuild images | `False` | ### Lint @@ -121,7 +121,7 @@ Available options include: |----------|------|-------------|---------| | `repo_or_repo_dir` | str | Directory of the repository to test | | | `--files` | List[Path] | Files to lint (optional) | | -| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` | +| `--commit0-config-file` | str | Path to the commit0 dot file | `.commit0.yaml` | | `--verbose` | int | Verbosity level (1 or 2) | `1` | ### Save @@ -134,7 +134,7 @@ Available options include: | `owner` | str | Owner of the repository | | | `branch` | str | Branch to save | | | `--github-token` | str | GitHub token for authentication | | -| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` | +| `--commit0-config-file` | str | Path to the commit0 dot file | `.commit0.yaml` | ## Agent diff --git a/agent/agent_utils.py b/agent/agent_utils.py index 3c5cb60..4fdea82 100644 --- a/agent/agent_utils.py +++ b/agent/agent_utils.py @@ -234,8 +234,9 @@ def get_target_edit_files( local_repo: git.Repo, src_dir: str, test_dir: str, - latest_commit: str, + branch: str, reference_commit: str, + use_topo_sort_dependencies: bool = True, ) -> tuple[list[str], dict]: """Find the files with functions with the pass statement.""" target_dir = str(local_repo.working_dir) @@ -269,7 +270,7 @@ def get_target_edit_files( ), "all files should be included" # change to latest commit - local_repo.git.checkout(latest_commit) + local_repo.git.checkout(branch) # Remove the base_dir prefix topological_sort_files = [ @@ -282,35 +283,88 @@ def get_target_edit_files( key_without_prefix = key.replace(target_dir, "").lstrip("/") value_without_prefix = [v.replace(target_dir, "").lstrip("/") for v in value] import_dependencies_without_prefix[key_without_prefix] = value_without_prefix + if use_topo_sort_dependencies: + return topological_sort_files, import_dependencies_without_prefix + else: + filtered_files = [ + file.replace(target_dir, "").lstrip("/") for file in filtered_files + ] + return filtered_files, import_dependencies_without_prefix + + +def get_target_edit_files_from_patch( + local_repo: git.Repo, patch: str, use_topo_sort_dependencies: bool = True +) -> tuple[list[str], dict]: + """Get the target files from the patch.""" + working_dir = str(local_repo.working_dir) + target_files = set() + for line in patch.split("\n"): + if line.startswith("+++") or line.startswith("---"): + file_path = line.split()[1] + if file_path.startswith("a/"): + file_path = file_path[2:] + if file_path.startswith("b/"): + file_path = file_path[2:] + target_files.add(file_path) + + target_files_list = list(target_files) + target_files_list = [ + os.path.join(working_dir, file_path) for file_path in target_files_list + ] - return topological_sort_files, import_dependencies_without_prefix + if use_topo_sort_dependencies: + topological_sort_files, import_dependencies = ( + topological_sort_based_on_dependencies(target_files_list) + ) + if len(topological_sort_files) != len(target_files_list): + if len(topological_sort_files) < len(target_files_list): + missing_files = set(target_files_list) - set(topological_sort_files) + topological_sort_files = topological_sort_files + list(missing_files) + else: + raise ValueError( + "topological_sort_files should not be longer than target_files_list" + ) + assert len(topological_sort_files) == len( + target_files_list + ), "all files should be included" + + topological_sort_files = [ + file.replace(working_dir, "").lstrip("/") for file in topological_sort_files + ] + for key, value in import_dependencies.items(): + import_dependencies[key] = [ + v.replace(working_dir, "").lstrip("/") for v in value + ] + return topological_sort_files, import_dependencies + else: + target_files_list = [ + file.replace(working_dir, "").lstrip("/") for file in target_files_list + ] + return target_files_list, {} def get_message( agent_config: AgentConfig, repo_path: str, - test_dir: str | None = None, - test_file: str | None = None, + test_files: list[str] | None = None, ) -> str: """Get the message to Aider.""" prompt = f"{PROMPT_HEADER}" + agent_config.user_prompt - if agent_config.use_unit_tests_info and test_dir: - unit_tests_info = ( - f"\n{UNIT_TESTS_INFO_HEADER} " - + get_dir_info( - dir_path=Path(os.path.join(repo_path, test_dir)), - prefix="", - include_stubs=True, - )[: agent_config.max_unit_tests_info_length] - ) - elif agent_config.use_unit_tests_info and test_file: - unit_tests_info = ( - f"\n{UNIT_TESTS_INFO_HEADER} " - + get_file_info( + # if agent_config.use_unit_tests_info and test_file: + # unit_tests_info = ( + # f"\n{UNIT_TESTS_INFO_HEADER} " + # + get_file_info( + # file_path=Path(os.path.join(repo_path, test_file)), prefix="" + # )[: agent_config.max_unit_tests_info_length] + # ) + if agent_config.use_unit_tests_info and test_files: + unit_tests_info = f"\n{UNIT_TESTS_INFO_HEADER} " + for test_file in test_files: + unit_tests_info += get_file_info( file_path=Path(os.path.join(repo_path, test_file)), prefix="" - )[: agent_config.max_unit_tests_info_length] - ) + ) + unit_tests_info = unit_tests_info[: agent_config.max_unit_tests_info_length] else: unit_tests_info = "" @@ -405,6 +459,33 @@ def create_branch(repo: git.Repo, branch: str, from_commit: str) -> None: raise RuntimeError(f"Failed to create or switch to branch '{branch}': {e}") +def get_changed_files_from_commits( + repo: git.Repo, commit1: str, commit2: str +) -> list[str]: + """Get the changed files from two commits.""" + try: + # Get the commit objects + commit1_obj = repo.commit(commit1) + commit2_obj = repo.commit(commit2) + + # Get the diff between the two commits + diff = commit1_obj.diff(commit2_obj) + + # Extract the changed file paths + changed_files = [item.a_path for item in diff] + + # Check if each changed file is a Python file + python_files = [file for file in changed_files if file.endswith(".py")] + + # Update the changed_files list to only include Python files + changed_files = python_files + + return changed_files + except Exception as e: + print(f"An error occurred: {e}") + return [] + + def args2string(agent_config: AgentConfig) -> str: """Converts specific fields from an `AgentConfig` object into a formatted string. @@ -453,13 +534,14 @@ def get_changed_files(repo: git.Repo) -> list[str]: return files_changed -def get_lint_cmd(repo_name: str, use_lint_info: bool) -> str: +def get_lint_cmd(repo_name: str, use_lint_info: bool, commit0_config_file: str) -> str: """Generate a linting command based on whether to include files. Args: ---- repo_name (str): The name of the repository. use_lint_info (bool): A flag indicating whether to include changed files in the lint command. + commit0_config_file (str): The path to the commit0 dot file. Returns: ------- @@ -469,7 +551,9 @@ def get_lint_cmd(repo_name: str, use_lint_info: bool) -> str: """ lint_cmd = "python -m commit0 lint " if use_lint_info: - lint_cmd += repo_name + " --files " + lint_cmd += ( + repo_name + " --commit0-config-file " + commit0_config_file + " --files " + ) else: lint_cmd = "" return lint_cmd diff --git a/agent/agents.py b/agent/agents.py index 6e7d9d8..e908090 100644 --- a/agent/agents.py +++ b/agent/agents.py @@ -7,6 +7,7 @@ from aider.models import Model from aider.io import InputOutput import re +import os def handle_logging(logging_name: str, log_file: Path) -> None: @@ -24,6 +25,23 @@ def handle_logging(logging_name: str, log_file: Path) -> None: class AgentReturn(ABC): def __init__(self, log_file: Path): self.log_file = log_file + + self.last_cost = 0.0 + + +class Agents(ABC): + def __init__(self, max_iteration: int): + self.max_iteration = max_iteration + + @abstractmethod + def run(self) -> AgentReturn: + """Start agent""" + raise NotImplementedError + + +class AiderReturn(AgentReturn): + def __init__(self, log_file: Path): + super().__init__(log_file) self.last_cost = self.get_money_cost() def get_money_cost(self) -> float: @@ -40,20 +58,25 @@ def get_money_cost(self) -> float: return last_cost -class Agents(ABC): - def __init__(self, max_iteration: int): - self.max_iteration = max_iteration - - @abstractmethod - def run(self) -> AgentReturn: - """Start agent""" - raise NotImplementedError - - class AiderAgents(Agents): def __init__(self, max_iteration: int, model_name: str): super().__init__(max_iteration) self.model = Model(model_name) + # Check if API key is set for the model + if "gpt" in model_name: + api_key = os.environ.get("OPENAI_API_KEY", None) + elif "claude" in model_name: + api_key = os.environ.get("ANTHROPIC_API_KEY", None) + elif "gemini" in model_name: + api_key = os.environ.get("API_KEY", None) + else: + raise ValueError(f"Unsupported model: {model_name}") + + if not api_key: + raise ValueError( + "API Key Error: There is no API key associated with the model for this agent. " + "Edit model_name parameter in .agent.yaml, export API key for that model, and try again." + ) def run( self, @@ -63,6 +86,7 @@ def run( fnames: list[str], log_dir: Path, test_first: bool = False, + lint_first: bool = False, ) -> AgentReturn: """Start aider agent""" if test_cmd: @@ -90,11 +114,6 @@ def run( sys.stdout = open(log_file, "a") sys.stderr = open(log_file, "a") - # Log the message - agent_message_log_file = log_dir / "agent_message.log" - with open(agent_message_log_file, "a") as f: - f.write(f"Message Sent: {message}\n\n") - # Configure httpx and backoff logging handle_logging("httpx", log_file) handle_logging("backoff", log_file) @@ -113,7 +132,7 @@ def run( test_cmd=test_cmd, io=io, ) - coder.max_reflection = self.max_iteration + coder.max_reflections = self.max_iteration coder.stream = True # Run the agent @@ -121,23 +140,11 @@ def run( test_errors = coder.commands.cmd_test(test_cmd) if test_errors: coder.run(test_errors) + elif lint_first: + coder.commands.cmd_lint(fnames=fnames) else: coder.run(message) - # #### TMP - - # #### TMP - # import time - # import random - - # time.sleep(random.random() * 5) - # n = random.random() / 10 - # with open(log_file, "a") as f: - # f.write( - # f"> Tokens: 33k sent, 1.3k received. Cost: $0.12 message, ${n} session. \n" - # ) - # #### TMP - # Close redirected stdout and stderr sys.stdout.close() sys.stderr.close() @@ -145,4 +152,4 @@ def run( sys.stdout = sys.__stdout__ sys.stderr = sys.__stderr__ - return AgentReturn(log_file) + return AiderReturn(log_file) diff --git a/agent/class_types.py b/agent/class_types.py index 03debfa..c06e5f6 100644 --- a/agent/class_types.py +++ b/agent/class_types.py @@ -7,6 +7,8 @@ class AgentConfig: model_name: str use_user_prompt: bool user_prompt: str + use_topo_sort_dependencies: bool + add_import_module_to_context: bool use_repo_info: bool max_repo_info_length: int use_unit_tests_info: bool @@ -14,6 +16,7 @@ class AgentConfig: use_spec_info: bool max_spec_info_length: int use_lint_info: bool + run_entire_dir_lint: bool max_lint_info_length: int pre_commit_config_path: str run_tests: bool diff --git a/agent/cli.py b/agent/cli.py index 8d06891..1b1c371 100644 --- a/agent/cli.py +++ b/agent/cli.py @@ -83,6 +83,14 @@ def config( "Here is your task:\nYou need to complete the implementations for all functions (i.e., those with pass statements) and pass the unit tests.\nDo not change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc.\nWhen you generate code, you must maintain the original formatting of the function stubs (such as whitespaces), otherwise we will not able to search/replace blocks for code modifications, and therefore you will receive a score of 0 for your generated code.", help="User prompt to use", ), + topo_sort_dependencies: bool = typer.Option( + True, + help="Topologically sort the dependencies of the repository", + ), + add_import_module_to_context: bool = typer.Option( + True, + help="Add the import module code to the context", + ), run_tests: bool = typer.Option( False, help="Run the tests after the agent is done", @@ -123,6 +131,10 @@ def config( 10000, help="Maximum length of the lint information to use", ), + run_entire_dir_lint: bool = typer.Option( + False, + help="Run the lint on the entire directory", + ), pre_commit_config_path: str = typer.Option( ".pre-commit-config.yaml", help="Path to the pre-commit config file", @@ -145,6 +157,8 @@ def config( "use_user_prompt": use_user_prompt, "user_prompt": user_prompt, "run_tests": run_tests, + "use_topo_sort_dependencies": topo_sort_dependencies, + "add_import_module_to_context": add_import_module_to_context, "max_iteration": max_iteration, "use_repo_info": use_repo_info, "max_repo_info_length": max_repo_info_length, @@ -154,6 +168,7 @@ def config( "max_spec_info_length": max_spec_info_length, "use_lint_info": use_lint_info, "max_lint_info_length": max_lint_info_length, + "run_entire_dir_lint": run_entire_dir_lint, "pre_commit_config_path": pre_commit_config_path, } diff --git a/agent/display.py b/agent/display.py index 53d01fe..c321908 100644 --- a/agent/display.py +++ b/agent/display.py @@ -96,6 +96,7 @@ def __init__(self, total_repos: int): self.start_time_per_repo = {} self.end_time_per_repo = {} self.total_time_spent = 0 + self.branch_name = "" self.overall_progress = Progress( SpinnerColumn(), @@ -144,6 +145,7 @@ def __init__(self, total_repos: int): Layout(name="agent_name", ratio=1), Layout(name="model_name", ratio=1), Layout(name="run_tests", ratio=1), + Layout(name="use_topo_sort_dependencies", ratio=1), Layout(name="use_repo_info", ratio=1), Layout(name="use_unit_tests_info", ratio=1), Layout(name="use_spec_info", ratio=1), @@ -192,6 +194,7 @@ def update_agent_display( agent_name: str, model_name: str, run_tests: bool, + use_topo_sort_dependencies: bool, use_repo_info: bool, use_unit_tests_info: bool, use_spec_info: bool, @@ -202,6 +205,11 @@ def update_agent_display( ("agent_name", "Agent", agent_name), ("model_name", "Model", model_name), ("run_tests", "Run Tests", run_tests), + ( + "use_topo_sort_dependencies", + "Topo Sort Dependencies", + use_topo_sort_dependencies, + ), ("use_repo_info", "Use Repo Info", use_repo_info), ("use_unit_tests_info", "Use Unit Tests", use_unit_tests_info), ("use_spec_info", "Use Spec", use_spec_info), @@ -236,6 +244,7 @@ def update_time_display(self, time_in_seconds: int) -> None: def update_branch_display(self, branch: str) -> None: """Update the branch display with the given branch.""" + self.branch_name = branch self.branch_display = Text(f"{branch}", justify="center") self.layout["info"]["other_info"]["branch"].update( Panel(self.branch_display, title="Branch", border_style="blue") @@ -428,7 +437,10 @@ def __exit__( ], } - with open("processing_summary.json", "w") as json_file: + with open( + f"processing_summary_{self.branch_name}.json", + "w", + ) as json_file: json.dump(summary_data, json_file, indent=4) print("\nSummary has been written to processing_summary.json") diff --git a/agent/run_agent.py b/agent/run_agent.py index 5315086..a140a5a 100644 --- a/agent/run_agent.py +++ b/agent/run_agent.py @@ -7,10 +7,12 @@ create_branch, get_message, get_target_edit_files, + get_changed_files_from_commits, update_message_with_dependencies, get_lint_cmd, read_yaml_config, ) +import subprocess from agent.agents import AiderAgents from typing import Optional, Type, cast from types import TracebackType @@ -47,19 +49,22 @@ def run_agent_for_repo( repo_base_dir: str, agent_config: AgentConfig, example: RepoInstance, - update_queue: multiprocessing.Queue, branch: str, + update_queue: multiprocessing.Queue, override_previous_changes: bool = False, backend: str = "modal", log_dir: str = str(RUN_AGENT_LOG_DIR.resolve()), + commit0_config_file: str = "", ) -> None: """Run Aider for a given repository.""" # get repo info + commit0_config = read_commit0_dot_file(commit0_config_file) + + assert "commit0" in commit0_config["dataset_name"] _, repo_name = example["repo"].split("/") # before starting, display all information to terminal - original_repo_name = repo_name - update_queue.put(("start_repo", (original_repo_name, 0))) + update_queue.put(("start_repo", (repo_name, 0))) # repo_name = repo_name.lower() # repo_name = repo_name.replace(".", "-") @@ -81,6 +86,13 @@ def run_agent_for_repo( f"{agent_config.agent_name} is not implemented; please add your implementations in baselines/agents.py." ) + # Check if there are changes in the current branch + if local_repo.is_dirty(): + # Stage all changes + local_repo.git.add(A=True) + # Commit changes with the message "left from last change" + local_repo.index.commit("left from last change") + # # if branch_name is not provided, create a new branch name based on agent_config # if branch is None: # branch = args2string(agent_config) @@ -92,12 +104,18 @@ def run_agent_for_repo( if latest_commit.hexsha != example["base_commit"] and override_previous_changes: local_repo.git.reset("--hard", example["base_commit"]) + # get target files to edit and test files to run target_edit_files, import_dependencies = get_target_edit_files( local_repo, example["src_dir"], example["test"]["test_dir"], - str(latest_commit), + branch, example["reference_commit"], + agent_config.use_topo_sort_dependencies, + ) + + lint_files = get_changed_files_from_commits( + local_repo, "HEAD", example["base_commit"] ) # Call the commit0 get-tests command to retrieve test files test_files_str = get_tests(repo_name, verbose=0) @@ -117,27 +135,26 @@ def run_agent_for_repo( with open(agent_config_log_file, "w") as agent_config_file: yaml.dump(agent_config, agent_config_file) - # TODO: make this path more general - commit0_dot_file_path = str(Path(repo_path).parent.parent / ".commit0.yaml") - with DirContext(repo_path): if agent_config is None: raise ValueError("Invalid input") if agent_config.run_tests: - update_queue.put(("start_repo", (original_repo_name, len(test_files)))) + update_queue.put(("start_repo", (repo_name, len(test_files)))) # when unit test feedback is available, iterate over test files for test_file in test_files: update_queue.put(("set_current_file", (repo_name, test_file))) - test_cmd = f"python -m commit0 test {repo_path} {test_file} --branch {branch} --backend {backend} --commit0-dot-file-path {commit0_dot_file_path}" + test_cmd = f"python -m commit0 test {repo_path} {test_file} --branch {branch} --backend {backend} --commit0-config-file {commit0_config_file} --timeout 100" test_file_name = test_file.replace(".py", "").replace("/", "__") test_log_dir = experiment_log_dir / test_file_name - lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info) - message = get_message(agent_config, repo_path, test_file=test_file) + lint_cmd = get_lint_cmd( + repo_name, agent_config.use_lint_info, commit0_config_file + ) + message = get_message(agent_config, repo_path, test_files=[test_file]) # display the test file to terminal agent_return = agent.run( - message, + "", test_cmd, lint_cmd, target_edit_files, @@ -151,22 +168,48 @@ def run_agent_for_repo( (repo_name, test_file, agent_return.last_cost), ) ) + elif agent_config.run_entire_dir_lint: + update_queue.put(("start_repo", (repo_name, len(lint_files)))) + # when unit test feedback is available, iterate over test files + for lint_file in lint_files: + update_queue.put(("set_current_file", (repo_name, lint_file))) + lint_file_name = lint_file.replace(".py", "").replace("/", "__") + lint_log_dir = experiment_log_dir / lint_file_name + lint_cmd = get_lint_cmd( + repo_name, agent_config.use_lint_info, commit0_config_file + ) + + # display the test file to terminal + agent_return = agent.run( + "", + "", + lint_cmd, + [lint_file], + lint_log_dir, + lint_first=True, + ) + # after running the agent, update the money display + update_queue.put( + ( + "update_money_display", + (repo_name, lint_file, agent_return.last_cost), + ) + ) else: # when unit test feedback is not available, iterate over target files to edit - message = get_message( - agent_config, repo_path, test_dir=example["test"]["test_dir"] - ) + message = get_message(agent_config, repo_path, test_files=test_files) - update_queue.put( - ("start_repo", (original_repo_name, len(target_edit_files))) - ) + update_queue.put(("start_repo", (repo_name, len(target_edit_files)))) for f in target_edit_files: update_queue.put(("set_current_file", (repo_name, f))) - dependencies = import_dependencies[f] - message = update_message_with_dependencies(message, dependencies) + if agent_config.add_import_module_to_context: + dependencies = import_dependencies.get(f, []) + message = update_message_with_dependencies(message, dependencies) file_name = f.replace(".py", "").replace("/", "__") file_log_dir = experiment_log_dir / file_name - lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info) + lint_cmd = get_lint_cmd( + repo_name, agent_config.use_lint_info, commit0_config_file + ) agent_return = agent.run(message, "", lint_cmd, [f], file_log_dir) update_queue.put( ( @@ -174,7 +217,7 @@ def run_agent_for_repo( (repo_name, file_name, agent_return.last_cost), ) ) - update_queue.put(("finish_repo", original_repo_name)) + update_queue.put(("finish_repo", repo_name)) def run_agent( @@ -192,6 +235,7 @@ def run_agent( agent_config = AgentConfig(**config) + commit0_config_file = os.path.abspath(commit0_config_file) commit0_config = read_commit0_dot_file(commit0_config_file) dataset = load_dataset( @@ -214,6 +258,16 @@ def run_agent( # if len(filtered_dataset) > 1: # sys.stdout = open(os.devnull, "w") + if agent_config.add_import_module_to_context: + # Install Chrome for Playwright for browser-based agents + try: + subprocess.run(["playwright", "install", "chromium"], check=True) + print("Chrome installed successfully for Playwright") + except subprocess.CalledProcessError as e: + print(f"Error installing Chrome for Playwright: {e}") + except FileNotFoundError: + print("Playwright not found. Make sure it's installed and in your PATH.") + with TerminalDisplay(len(filtered_dataset)) as display: not_started_repos = [ cast(RepoInstance, example)["repo"].split("/")[-1] @@ -232,6 +286,7 @@ def run_agent( agent_config.agent_name, agent_config.model_name, agent_config.run_tests, + agent_config.use_topo_sort_dependencies, agent_config.use_repo_info, agent_config.use_unit_tests_info, agent_config.use_spec_info, @@ -251,11 +306,12 @@ def run_agent( commit0_config["base_dir"], agent_config, cast(RepoInstance, example), - update_queue, branch, + update_queue, override_previous_changes, backend, log_dir, + commit0_config_file, ), ) results.append(result) diff --git a/agent/run_agent_no_rich.py b/agent/run_agent_no_rich.py index ec1334a..ceb4fb2 100644 --- a/agent/run_agent_no_rich.py +++ b/agent/run_agent_no_rich.py @@ -5,14 +5,15 @@ from datasets import load_dataset from git import Repo from agent.agent_utils import ( - args2string, create_branch, get_message, get_target_edit_files, + get_changed_files_from_commits, update_message_with_dependencies, get_lint_cmd, read_yaml_config, ) +import subprocess from agent.agents import AiderAgents from typing import Optional, Type, cast from types import TracebackType @@ -46,15 +47,18 @@ def run_agent_for_repo( repo_base_dir: str, agent_config: AgentConfig, example: RepoInstance, - branch: Optional[str] = None, + branch: str, override_previous_changes: bool = False, backend: str = "modal", log_dir: str = str(RUN_AGENT_LOG_DIR.resolve()), + commit0_config_file: str = "", ) -> None: """Run Aider for a given repository.""" # get repo info + commit0_config = read_commit0_dot_file(commit0_config_file) + + assert "commit0" in commit0_config["dataset_name"] _, repo_name = example["repo"].split("/") - print("Working on repo: ", repo_name) # repo_name = repo_name.lower() # repo_name = repo_name.replace(".", "-") @@ -76,10 +80,16 @@ def run_agent_for_repo( f"{agent_config.agent_name} is not implemented; please add your implementations in baselines/agents.py." ) - # if branch_name is not provided, create a new branch name based on agent_config - if branch is None: - branch = args2string(agent_config) + # Check if there are changes in the current branch + if local_repo.is_dirty(): + # Stage all changes + local_repo.git.add(A=True) + # Commit changes with the message "left from last change" + local_repo.index.commit("left from last change") + # # if branch_name is not provided, create a new branch name based on agent_config + # if branch is None: + # branch = args2string(agent_config) create_branch(local_repo, branch, example["base_commit"]) # in cases where the latest commit of branch is not commit 0 @@ -93,10 +103,14 @@ def run_agent_for_repo( local_repo, example["src_dir"], example["test"]["test_dir"], - str(latest_commit), - str(example["reference_commit"]), + branch, + example["reference_commit"], + agent_config.use_topo_sort_dependencies, ) + lint_files = get_changed_files_from_commits( + local_repo, "HEAD", example["base_commit"] + ) # Call the commit0 get-tests command to retrieve test files test_files_str = get_tests(repo_name, verbose=0) test_files = sorted(list(set([i.split(":")[0] for i in test_files_str]))) @@ -115,9 +129,6 @@ def run_agent_for_repo( with open(agent_config_log_file, "w") as agent_config_file: yaml.dump(agent_config, agent_config_file) - # TODO: make this path more general - commit0_dot_file_path = str(Path(repo_path).parent.parent / ".commit0.yaml") - with DirContext(repo_path): if agent_config is None: raise ValueError("Invalid input") @@ -125,33 +136,55 @@ def run_agent_for_repo( if agent_config.run_tests: # when unit test feedback is available, iterate over test files for test_file in test_files: - test_cmd = f"python -m commit0 test {repo_path} {test_file} --branch {branch} --backend {backend} --commit0-dot-file-path {commit0_dot_file_path}" + test_cmd = f"python -m commit0 test {repo_path} {test_file} --branch {branch} --backend {backend} --commit0-config-file {commit0_config_file} --timeout 100" test_file_name = test_file.replace(".py", "").replace("/", "__") test_log_dir = experiment_log_dir / test_file_name - lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info) - message = get_message(agent_config, repo_path, test_file=test_file) + lint_cmd = get_lint_cmd( + repo_name, agent_config.use_lint_info, commit0_config_file + ) + message = get_message(agent_config, repo_path, test_files=[test_file]) + + # display the test file to terminal _ = agent.run( - message, + "", test_cmd, lint_cmd, target_edit_files, test_log_dir, test_first=True, ) - # cost = agent_return.last_cost + elif agent_config.run_entire_dir_lint: + # when unit test feedback is available, iterate over test files + for lint_file in lint_files: + lint_file_name = lint_file.replace(".py", "").replace("/", "__") + lint_log_dir = experiment_log_dir / lint_file_name + lint_cmd = get_lint_cmd( + repo_name, agent_config.use_lint_info, commit0_config_file + ) + + # display the test file to terminal + _ = agent.run( + "", + "", + lint_cmd, + [lint_file], + lint_log_dir, + lint_first=True, + ) else: # when unit test feedback is not available, iterate over target files to edit - message = get_message( - agent_config, repo_path, test_dir=example["test"]["test_dir"] - ) + message = get_message(agent_config, repo_path, test_files=test_files) + for f in target_edit_files: - dependencies = import_dependencies[f] - message = update_message_with_dependencies(message, dependencies) + if agent_config.add_import_module_to_context: + dependencies = import_dependencies.get(f, []) + message = update_message_with_dependencies(message, dependencies) file_name = f.replace(".py", "").replace("/", "__") file_log_dir = experiment_log_dir / file_name - lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info) + lint_cmd = get_lint_cmd( + repo_name, agent_config.use_lint_info, commit0_config_file + ) _ = agent.run(message, "", lint_cmd, [f], file_log_dir) - # cost = agent_return.last_cost def run_agent( @@ -171,6 +204,7 @@ def run_agent( agent_config = AgentConfig(**config) + commit0_config_file = os.path.abspath(commit0_config_file) commit0_config = read_commit0_dot_file(commit0_config_file) dataset = load_dataset( @@ -192,6 +226,15 @@ def run_agent( # if len(filtered_dataset) > 1: # sys.stdout = open(os.devnull, "w") + if agent_config.add_import_module_to_context: + # Install Chrome for Playwright for browser-based agents + try: + subprocess.run(["playwright", "install", "chromium"], check=True) + print("Chrome installed successfully for Playwright") + except subprocess.CalledProcessError as e: + print(f"Error installing Chrome for Playwright: {e}") + except FileNotFoundError: + print("Playwright not found. Make sure it's installed and in your PATH.") with tqdm( total=len(filtered_dataset), smoothing=0, desc="Running Aider for repos" @@ -211,6 +254,7 @@ def run_agent( override_previous_changes, backend, log_dir, + commit0_config_file, ), callback=lambda _: pbar.update( 1 diff --git a/docs/api.md b/docs/api.md index 4069e03..76fb52d 100644 --- a/docs/api.md +++ b/docs/api.md @@ -13,7 +13,7 @@ Available options include: | `--dataset-name` | str | Name of the Huggingface dataset | `wentingzhao/commit0_combined` | | `--dataset-split` | str | Split of the Huggingface dataset | `test` | | `--base-dir` | str | Base directory to clone repos to | `repos/` | -| `--commit0-dot-file-path` | str | Storing path for stateful commit0 configs | `.commit0.yaml` | +| `--commit0-config-file` | str | Storing path for stateful commit0 configs | `.commit0.yaml` | ### Build @@ -23,7 +23,7 @@ Available options include: | Argument | Type | Description | Default | |----------|------|-------------|---------| | `--num-workers` | int | Number of workers | `8` | -| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` | +| `--commit0-config-file` | str | Path to the commit0 dot file | `.commit0.yaml` | | `--verbose` | int | Verbosity level (1 or 2) | `1` | ### Get Tests @@ -50,7 +50,7 @@ Available options include: | `--reference` | bool | Test the reference commit | `False` | | `--coverage` | bool | Get coverage information | `False` | | `--rebuild` | bool | Rebuild an image | `False` | -| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` | +| `--commit0-config-file` | str | Path to the commit0 dot file | `.commit0.yaml` | | `--verbose` | int | Verbosity level (1 or 2) | `1` | | `--stdin` | bool | Read test names from stdin | `False` | @@ -68,7 +68,7 @@ Available options include: | `--num-workers` | int | Number of workers to use | `8` | | `--reference` | bool | Evaluate the reference commit | `False` | | `--coverage` | bool | Get coverage information | `False` | -| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` | +| `--commit0-config-file` | str | Path to the commit0 dot file | `.commit0.yaml` | | `--rebuild` | bool | Rebuild images | `False` | ### Lint @@ -80,7 +80,7 @@ Available options include: |----------|------|-------------|---------| | `repo_or_repo_dir` | str | Directory of the repository to test | | | `--files` | List[Path] | Files to lint (optional) | | -| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` | +| `--commit0-config-file` | str | Path to the commit0 dot file | `.commit0.yaml` | | `--verbose` | int | Verbosity level (1 or 2) | `1` | ### Save @@ -93,7 +93,7 @@ Available options include: | `owner` | str | Owner of the repository | | | `branch` | str | Branch to save | | | `--github-token` | str | GitHub token for authentication | | -| `--commit0-dot-file-path` | str | Path to the commit0 dot file | `.commit0.yaml` | +| `--commit0-config-file` | str | Path to the commit0 dot file | `.commit0.yaml` | ## Agent diff --git a/docs/render_submissions.py b/docs/render_submissions.py index e2c95c2..3fe45a0 100644 --- a/docs/render_submissions.py +++ b/docs/render_submissions.py @@ -385,7 +385,7 @@ def main(args): if args.do_setup: os.system( f"commit0 setup {args.split} --base-dir {analysis_files_path}/repos " - f"--commit0-dot-file-path {analysis_files_path}/repos/.commit0.yaml" + f"--commit0-config-file {analysis_files_path}/repos/.commit0.yaml" ) branch_name = "blank" if args.overwrite_previous_eval: @@ -429,7 +429,7 @@ def main(args): if args.do_setup: os.system( f"commit0 setup {args.split} --base-dir {submission_repos_path} " - f"--commit0-dot-file-path {commit0_dot_file_path}" + f"--commit0-config-file {commit0_dot_file_path}" ) submission_metrics_output_file = os.path.join( analysis_files_path, org_name, f"{branch_name}.json" @@ -456,7 +456,7 @@ def main(args): if args.overwrite_previous_eval or need_re_eval: os.system( "commit0 evaluate --reference " - f"--commit0-dot-file-path {commit0_dot_file_path}" + f"--commit0-config-file {commit0_dot_file_path}" ) # get coverage and pytest info for each repo for example in dataset: @@ -531,7 +531,7 @@ def main(args): # run pytests os.system( f"commit0 evaluate --branch {branch_name} " - f"--commit0-dot-file-path {commit0_dot_file_path}" + f"--commit0-config-file {commit0_dot_file_path}" ) for example in dataset: repo_name = example["repo"].split("/")[-1]