diff --git a/agent/agent_utils.py b/agent/agent_utils.py index c6ec4d5..3c5cb60 100644 --- a/agent/agent_utils.py +++ b/agent/agent_utils.py @@ -6,6 +6,8 @@ from pathlib import Path from typing import List import fitz +from import_deps import ModuleSet +from graphlib import TopologicalSorter, CycleError import yaml from agent.class_types import AgentConfig @@ -16,6 +18,7 @@ UNIT_TESTS_INFO_HEADER = "\n\n>>> Here are the Unit Tests Information:\n" LINT_INFO_HEADER = "\n\n>>> Here is the Lint Information:\n" SPEC_INFO_HEADER = "\n\n>>> Here is the Specification Information:\n" +IMPORT_DEPENDENCIES_HEADER = "\n\n>>> Here are the Import Dependencies:\n" # prefix components: space = " " branch = "│ " @@ -190,25 +193,97 @@ def _find_files_to_edit(base_dir: str, src_dir: str, test_dir: str) -> list[str] return files -def get_target_edit_files(target_dir: str, src_dir: str, test_dir: str) -> list[str]: +def ignore_cycles(graph: dict) -> list[str]: + """Ignore the cycles in the graph.""" + ts = TopologicalSorter(graph) + try: + return list(ts.static_order()) + except CycleError as e: + # print(f"Cycle detected: {e.args[1]}") + # You can either break the cycle by modifying the graph or handle it as needed. + # For now, let's just remove the first node in the cycle and try again. + cycle_nodes = e.args[1] + node_to_remove = cycle_nodes[0] + # print(f"Removing node {node_to_remove} to resolve cycle.") + graph.pop(node_to_remove, None) + return ignore_cycles(graph) + + +def topological_sort_based_on_dependencies( + pkg_paths: list[str], +) -> tuple[list[str], dict]: + """Topological sort based on dependencies.""" + module_set = ModuleSet([str(p) for p in pkg_paths]) + + import_dependencies = {} + for path in sorted(module_set.by_path.keys()): + module_name = ".".join(module_set.by_path[path].fqn) + mod = module_set.by_name[module_name] + try: + imports = module_set.get_imports(mod) + import_dependencies[path] = set([str(x) for x in imports]) + except Exception: + import_dependencies[path] = set() + + import_dependencies_files = ignore_cycles(import_dependencies) + + return import_dependencies_files, import_dependencies + + +def get_target_edit_files( + local_repo: git.Repo, + src_dir: str, + test_dir: str, + latest_commit: str, + reference_commit: str, +) -> tuple[list[str], dict]: """Find the files with functions with the pass statement.""" + target_dir = str(local_repo.working_dir) files = _find_files_to_edit(target_dir, src_dir, test_dir) filtered_files = [] for file_path in files: - with open(file_path, "r", encoding="utf-8", errors="ignore") as file: + with open(file_path, "r", encoding="utf-8-sig", errors="ignore") as file: content = file.read() if len(content.splitlines()) > 1500: continue if " pass" in content: filtered_files.append(file_path) + # Change to reference commit to get the correct dependencies + local_repo.git.checkout(reference_commit) + + topological_sort_files, import_dependencies = ( + topological_sort_based_on_dependencies(filtered_files) + ) + if len(topological_sort_files) != len(filtered_files): + if len(topological_sort_files) < len(filtered_files): + # Find the missing elements + missing_files = set(filtered_files) - set(topological_sort_files) + # Add the missing files to the end of the list + topological_sort_files = topological_sort_files + list(missing_files) + else: + raise ValueError( + "topological_sort_files should not be longer than filtered_files" + ) + assert len(topological_sort_files) == len( + filtered_files + ), "all files should be included" + + # change to latest commit + local_repo.git.checkout(latest_commit) # Remove the base_dir prefix - filtered_files = [ - file.replace(target_dir, "").lstrip("/") for file in filtered_files + topological_sort_files = [ + file.replace(target_dir, "").lstrip("/") for file in topological_sort_files ] - # Only keep python files - return filtered_files + # Remove the base_dir prefix from import dependencies + import_dependencies_without_prefix = {} + for key, value in import_dependencies.items(): + key_without_prefix = key.replace(target_dir, "").lstrip("/") + value_without_prefix = [v.replace(target_dir, "").lstrip("/") for v in value] + import_dependencies_without_prefix[key_without_prefix] = value_without_prefix + + return topological_sort_files, import_dependencies_without_prefix def get_message( @@ -268,6 +343,20 @@ def get_message( return message_to_agent +def update_message_with_dependencies(message: str, dependencies: list[str]) -> str: + """Update the message with the dependencies.""" + if len(dependencies) == 0: + return message + import_dependencies_info = f"\n{IMPORT_DEPENDENCIES_HEADER}" + for dependency in dependencies: + with open(dependency, "r") as file: + import_dependencies_info += ( + f"\nHere is the content of the file {dependency}:\n{file.read()}" + ) + message += import_dependencies_info + return message + + def get_specification(specification_pdf_path: Path) -> str: """Get the reference for a given specification PDF path.""" # TODO: after pdf_to_text is available, use it to extract the text from the PDF diff --git a/agent/agents.py b/agent/agents.py index 9255a9f..6e7d9d8 100644 --- a/agent/agents.py +++ b/agent/agents.py @@ -90,6 +90,11 @@ def run( sys.stdout = open(log_file, "a") sys.stderr = open(log_file, "a") + # Log the message + agent_message_log_file = log_dir / "agent_message.log" + with open(agent_message_log_file, "a") as f: + f.write(f"Message Sent: {message}\n\n") + # Configure httpx and backoff logging handle_logging("httpx", log_file) handle_logging("backoff", log_file) diff --git a/agent/cli.py b/agent/cli.py index 905191b..8d06891 100644 --- a/agent/cli.py +++ b/agent/cli.py @@ -178,6 +178,10 @@ def run( ".agent.yaml", help="Path to the agent config file", ), + commit0_config_file: str = typer.Option( + ".commit0.yaml", + help="Path to the commit0 config file", + ), log_dir: str = typer.Option( str(RUN_AGENT_LOG_DIR.resolve()), help="Log directory to store the logs", @@ -202,6 +206,7 @@ def run( override_previous_changes, backend, agent_config_file, + commit0_config_file, log_dir, max_parallel_repos, display_repo_progress_num, @@ -212,6 +217,7 @@ def run( override_previous_changes, backend, agent_config_file, + commit0_config_file, log_dir, max_parallel_repos, ) diff --git a/agent/display.py b/agent/display.py index a5f389c..53d01fe 100644 --- a/agent/display.py +++ b/agent/display.py @@ -17,6 +17,8 @@ from rich.align import Align from collections import OrderedDict from types import TracebackType +import json +from datetime import datetime class RepoBox: @@ -404,3 +406,29 @@ def __exit__( f"{'Total':<30} {self.total_time_spent:>13.2f}s {total_files:>18} {total_money:>13.2f}$" ) print("-" * 80) + + # Write summary to JSON file + + summary_data = { + "timestamp": datetime.now().isoformat(), + "total_time_spent": self.total_time_spent, + "total_files_processed": total_files, + "total_money_spent": total_money, + "repositories": [ + { + "name": repo_name, + "time_spent": self.end_time_per_repo[repo_name] + - self.start_time_per_repo[repo_name], + "files_processed": self.total_files_per_repo[repo_name], + "money_spent": sum( + self.repo_money_spent.get(repo_name, {}).values() + ), + } + for repo_name in self.end_time_per_repo + ], + } + + with open("processing_summary.json", "w") as json_file: + json.dump(summary_data, json_file, indent=4) + + print("\nSummary has been written to processing_summary.json") diff --git a/agent/run_agent.py b/agent/run_agent.py index 3ef2a08..5315086 100644 --- a/agent/run_agent.py +++ b/agent/run_agent.py @@ -7,6 +7,7 @@ create_branch, get_message, get_target_edit_files, + update_message_with_dependencies, get_lint_cmd, read_yaml_config, ) @@ -66,13 +67,6 @@ def run_agent_for_repo( repo_path = os.path.join(repo_base_dir, repo_name) repo_path = os.path.abspath(repo_path) - target_edit_files = get_target_edit_files( - repo_path, example["src_dir"], example["test"]["test_dir"] - ) - # Call the commit0 get-tests command to retrieve test files - test_files_str = get_tests(repo_name, verbose=0) - test_files = sorted(list(set([i.split(":")[0] for i in test_files_str]))) - try: local_repo = Repo(repo_path) except Exception: @@ -90,7 +84,6 @@ def run_agent_for_repo( # # if branch_name is not provided, create a new branch name based on agent_config # if branch is None: # branch = args2string(agent_config) - create_branch(local_repo, branch, example["base_commit"]) # in cases where the latest commit of branch is not commit 0 @@ -99,6 +92,17 @@ def run_agent_for_repo( if latest_commit.hexsha != example["base_commit"] and override_previous_changes: local_repo.git.reset("--hard", example["base_commit"]) + target_edit_files, import_dependencies = get_target_edit_files( + local_repo, + example["src_dir"], + example["test"]["test_dir"], + str(latest_commit), + example["reference_commit"], + ) + # Call the commit0 get-tests command to retrieve test files + test_files_str = get_tests(repo_name, verbose=0) + test_files = sorted(list(set([i.split(":")[0] for i in test_files_str]))) + # prepare the log dir experiment_log_dir = ( Path(log_dir) @@ -158,6 +162,8 @@ def run_agent_for_repo( ) for f in target_edit_files: update_queue.put(("set_current_file", (repo_name, f))) + dependencies = import_dependencies[f] + message = update_message_with_dependencies(message, dependencies) file_name = f.replace(".py", "").replace("/", "__") file_log_dir = experiment_log_dir / file_name lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info) @@ -176,6 +182,7 @@ def run_agent( override_previous_changes: bool, backend: str, agent_config_file: str, + commit0_config_file: str, log_dir: str, max_parallel_repos: int, display_repo_progress_num: int, @@ -185,7 +192,7 @@ def run_agent( agent_config = AgentConfig(**config) - commit0_config = read_commit0_dot_file(".commit0.yaml") + commit0_config = read_commit0_dot_file(commit0_config_file) dataset = load_dataset( commit0_config["dataset_name"], split=commit0_config["dataset_split"] diff --git a/agent/run_agent_no_rich.py b/agent/run_agent_no_rich.py index c46ae2f..ec1334a 100644 --- a/agent/run_agent_no_rich.py +++ b/agent/run_agent_no_rich.py @@ -9,6 +9,7 @@ create_branch, get_message, get_target_edit_files, + update_message_with_dependencies, get_lint_cmd, read_yaml_config, ) @@ -61,14 +62,6 @@ def run_agent_for_repo( repo_path = os.path.join(repo_base_dir, repo_name) repo_path = os.path.abspath(repo_path) - # get target files to edit and test files to run - target_edit_files = get_target_edit_files( - repo_path, example["src_dir"], example["test"]["test_dir"] - ) - # Call the commit0 get-tests command to retrieve test files - test_files_str = get_tests(repo_name, verbose=0) - test_files = sorted(list(set([i.split(":")[0] for i in test_files_str]))) - try: local_repo = Repo(repo_path) except Exception: @@ -95,6 +88,19 @@ def run_agent_for_repo( if latest_commit.hexsha != example["base_commit"] and override_previous_changes: local_repo.git.reset("--hard", example["base_commit"]) + # get target files to edit and test files to run + target_edit_files, import_dependencies = get_target_edit_files( + local_repo, + example["src_dir"], + example["test"]["test_dir"], + str(latest_commit), + str(example["reference_commit"]), + ) + + # Call the commit0 get-tests command to retrieve test files + test_files_str = get_tests(repo_name, verbose=0) + test_files = sorted(list(set([i.split(":")[0] for i in test_files_str]))) + # prepare the log dir experiment_log_dir = ( Path(log_dir) @@ -139,6 +145,8 @@ def run_agent_for_repo( agent_config, repo_path, test_dir=example["test"]["test_dir"] ) for f in target_edit_files: + dependencies = import_dependencies[f] + message = update_message_with_dependencies(message, dependencies) file_name = f.replace(".py", "").replace("/", "__") file_log_dir = experiment_log_dir / file_name lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info) @@ -151,6 +159,7 @@ def run_agent( override_previous_changes: bool, backend: str, agent_config_file: str, + commit0_config_file: str, log_dir: str, max_parallel_repos: int, ) -> None: @@ -162,7 +171,7 @@ def run_agent( agent_config = AgentConfig(**config) - commit0_config = read_commit0_dot_file(".commit0.yaml") + commit0_config = read_commit0_dot_file(commit0_config_file) dataset = load_dataset( commit0_config["dataset_name"], split=commit0_config["dataset_split"] diff --git a/pyproject.toml b/pyproject.toml index 8befc62..7666711 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ requires-python = ">=3.11" dependencies = [ "ruff>=0.6.4", "pre-commit>=3.8.0", + "import-deps>=0.3.0", "PyMuPDF>=1.24.5", "modal==0.64.95", "typer>=0.12.0",