From 45c5ec1dc99d01f7d68b4e612f901c1978770234 Mon Sep 17 00:00:00 2001 From: nanjiangwill Date: Thu, 19 Sep 2024 15:30:17 -0400 Subject: [PATCH] add specification for aider --- baselines/commit0_utils.py | 55 ++++++++++++++++++++++++++++---------- baselines/run_agent.py | 7 ++--- pyproject.toml | 3 ++- 3 files changed, 47 insertions(+), 18 deletions(-) diff --git a/baselines/commit0_utils.py b/baselines/commit0_utils.py index 42fc88e..1bcd553 100644 --- a/baselines/commit0_utils.py +++ b/baselines/commit0_utils.py @@ -1,10 +1,10 @@ import git import os import re -import subprocess from dataclasses import asdict from pathlib import Path from typing import List +import fitz from baselines.class_types import AgentConfig @@ -13,7 +13,7 @@ REPO_INFO_HEADER = "\n\n>>> Here is the Repository Information:\n" UNIT_TESTS_INFO_HEADER = "\n\n>>> Here are the Unit Tests Information:\n" LINT_INFO_HEADER = "\n\n>>> Here is the Lint Information:\n" - +SPEC_INFO_HEADER = "\n\n>>> Here is the Specification Information:\n" # prefix components: space = " " branch = "│ " @@ -122,14 +122,14 @@ def get_target_edit_files(target_dir: str) -> list[str]: """Find the files with the error 'NotImplementedError('IMPLEMENT ME HERE')'. """ - # The grep command - command = f"grep -R -l \"NotImplementedError('IMPLEMENT ME HERE')\" {target_dir}" - - # Run the command and capture the output - result = subprocess.run(command, shell=True, capture_output=True, text=True) - - # Split the output into lines and remove the base_dir prefix - files = result.stdout.strip().split("\n") + files = [] + for root, _, filenames in os.walk(target_dir): + for filename in filenames: + if filename.endswith(".py"): + file_path = os.path.join(root, filename) + with open(file_path, "r") as file: + if "NotImplementedError('IMPLEMENT ME HERE')" in file.read(): + files.append(file_path) # Remove the base_dir prefix files = [file.replace(target_dir, "").lstrip("/") for file in files] @@ -143,7 +143,8 @@ def get_target_edit_files(target_dir: str) -> list[str]: def get_message( agent_config: AgentConfig, repo_path: str, - test_dir: str, + test_dir: str | None = None, + test_file: str | None = None, ) -> str: """Get the message to Aider.""" prompt = f"{PROMPT_HEADER}" + agent_config.user_prompt @@ -157,6 +158,13 @@ def get_message( include_stubs=True, )[: agent_config.max_unit_tests_info_length] ) + elif agent_config.use_unit_tests_info and test_file: + unit_tests_info = ( + f"\n{UNIT_TESTS_INFO_HEADER} " + + get_file_info( + file_path=Path(os.path.join(repo_path, test_file)), prefix="" + )[: agent_config.max_unit_tests_info_length] + ) else: unit_tests_info = "" @@ -171,15 +179,34 @@ def get_message( else: repo_info = "" - message_to_agent = prompt + repo_info + unit_tests_info + if agent_config.use_spec_info: + spec_info = ( + f"\n{SPEC_INFO_HEADER} " + + get_specification(specification_pdf_path=Path(repo_path, "spec.pdf"))[ + : agent_config.max_spec_info_length + ] + ) + else: + spec_info = "" + + message_to_agent = prompt + repo_info + unit_tests_info + spec_info return message_to_agent -def get_reference(specification_pdf_path: str) -> str: +def get_specification(specification_pdf_path: Path) -> str: """Get the reference for a given specification PDF path.""" # TODO: after pdf_to_text is available, use it to extract the text from the PDF - return f"/pdf {specification_pdf_path}" + # Open the specified PDF file + document = fitz.open(specification_pdf_path) + text = "" + + # Iterate through the pages + for page_num in range(len(document)): + page = document.load_page(page_num) # loads the specified page + text += page.get_text() # type: ignore + + return text def create_branch(repo: git.Repo, branch: str, from_commit: str) -> None: diff --git a/baselines/run_agent.py b/baselines/run_agent.py index bb24d45..0535ba3 100644 --- a/baselines/run_agent.py +++ b/baselines/run_agent.py @@ -81,12 +81,10 @@ def run_agent_for_repo( if latest_commit.hexsha != example["base_commit"]: local_repo.git.reset("--hard", example["base_commit"]) target_edit_files = get_target_edit_files(repo_path) - with DirContext(repo_path): if commit0_config is None or agent_config is None: raise ValueError("Invalid input") - message = get_message(agent_config, repo_path, example["test"]["test_dir"]) if agent_config.run_tests: # when unit test feedback is available, iterate over test files for test_file in test_files: @@ -94,7 +92,7 @@ def run_agent_for_repo( test_file_name = test_file.replace(".py", "").replace("/", "__") log_dir = RUN_AIDER_LOG_DIR / "with_tests" / test_file_name lint_cmd = get_lint_cmd(local_repo, agent_config.use_lint_info) - + message = get_message(agent_config, repo_path, test_file=test_file) agent.run( message, test_cmd, @@ -104,6 +102,9 @@ def run_agent_for_repo( ) else: # when unit test feedback is not available, iterate over target files to edit + message = get_message( + agent_config, repo_path, test_dir=example["test"]["test_dir"] + ) for f in target_edit_files: file_name = f.replace(".py", "").replace("/", "__") log_dir = RUN_AIDER_LOG_DIR / "no_tests" / file_name diff --git a/pyproject.toml b/pyproject.toml index 8c7fb6f..c3f1d03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,9 @@ dependencies = [ "ruff>=0.6.4", "pre-commit>=3.8.0", "hydra-core>=1.3.2", + "PyMuPDF>=1.24.5", + "aider-chat>=0.56.0", "modal>=0.64.95", - "aider-chat", "datasets>=3.0.0", "docker>=7.1.0", "fastcore>=1.7.8",