Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 95 additions & 6 deletions agent/agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pathlib import Path
from typing import List
import fitz
from import_deps import ModuleSet
from graphlib import TopologicalSorter, CycleError
import yaml

from agent.class_types import AgentConfig
Expand All @@ -16,6 +18,7 @@
UNIT_TESTS_INFO_HEADER = "\n\n>>> Here are the Unit Tests Information:\n"
LINT_INFO_HEADER = "\n\n>>> Here is the Lint Information:\n"
SPEC_INFO_HEADER = "\n\n>>> Here is the Specification Information:\n"
IMPORT_DEPENDENCIES_HEADER = "\n\n>>> Here are the Import Dependencies:\n"
# prefix components:
space = " "
branch = "│ "
Expand Down Expand Up @@ -190,25 +193,97 @@ def _find_files_to_edit(base_dir: str, src_dir: str, test_dir: str) -> list[str]
return files


def get_target_edit_files(target_dir: str, src_dir: str, test_dir: str) -> list[str]:
def ignore_cycles(graph: dict) -> list[str]:
"""Ignore the cycles in the graph."""
ts = TopologicalSorter(graph)
try:
return list(ts.static_order())
except CycleError as e:
# print(f"Cycle detected: {e.args[1]}")
# You can either break the cycle by modifying the graph or handle it as needed.
# For now, let's just remove the first node in the cycle and try again.
cycle_nodes = e.args[1]
node_to_remove = cycle_nodes[0]
# print(f"Removing node {node_to_remove} to resolve cycle.")
graph.pop(node_to_remove, None)
return ignore_cycles(graph)


def topological_sort_based_on_dependencies(
pkg_paths: list[str],
) -> tuple[list[str], dict]:
"""Topological sort based on dependencies."""
module_set = ModuleSet([str(p) for p in pkg_paths])

import_dependencies = {}
for path in sorted(module_set.by_path.keys()):
module_name = ".".join(module_set.by_path[path].fqn)
mod = module_set.by_name[module_name]
try:
imports = module_set.get_imports(mod)
import_dependencies[path] = set([str(x) for x in imports])
except Exception:
import_dependencies[path] = set()

import_dependencies_files = ignore_cycles(import_dependencies)

return import_dependencies_files, import_dependencies


def get_target_edit_files(
local_repo: git.Repo,
src_dir: str,
test_dir: str,
latest_commit: str,
reference_commit: str,
) -> tuple[list[str], dict]:
"""Find the files with functions with the pass statement."""
target_dir = str(local_repo.working_dir)
files = _find_files_to_edit(target_dir, src_dir, test_dir)
filtered_files = []
for file_path in files:
with open(file_path, "r", encoding="utf-8", errors="ignore") as file:
with open(file_path, "r", encoding="utf-8-sig", errors="ignore") as file:
content = file.read()
if len(content.splitlines()) > 1500:
continue
if " pass" in content:
filtered_files.append(file_path)
# Change to reference commit to get the correct dependencies
local_repo.git.checkout(reference_commit)

topological_sort_files, import_dependencies = (
topological_sort_based_on_dependencies(filtered_files)
)
if len(topological_sort_files) != len(filtered_files):
if len(topological_sort_files) < len(filtered_files):
# Find the missing elements
missing_files = set(filtered_files) - set(topological_sort_files)
# Add the missing files to the end of the list
topological_sort_files = topological_sort_files + list(missing_files)
else:
raise ValueError(
"topological_sort_files should not be longer than filtered_files"
)
assert len(topological_sort_files) == len(
filtered_files
), "all files should be included"

# change to latest commit
local_repo.git.checkout(latest_commit)

# Remove the base_dir prefix
filtered_files = [
file.replace(target_dir, "").lstrip("/") for file in filtered_files
topological_sort_files = [
file.replace(target_dir, "").lstrip("/") for file in topological_sort_files
]
# Only keep python files

return filtered_files
# Remove the base_dir prefix from import dependencies
import_dependencies_without_prefix = {}
for key, value in import_dependencies.items():
key_without_prefix = key.replace(target_dir, "").lstrip("/")
value_without_prefix = [v.replace(target_dir, "").lstrip("/") for v in value]
import_dependencies_without_prefix[key_without_prefix] = value_without_prefix

return topological_sort_files, import_dependencies_without_prefix


def get_message(
Expand Down Expand Up @@ -268,6 +343,20 @@ def get_message(
return message_to_agent


def update_message_with_dependencies(message: str, dependencies: list[str]) -> str:
"""Update the message with the dependencies."""
if len(dependencies) == 0:
return message
import_dependencies_info = f"\n{IMPORT_DEPENDENCIES_HEADER}"
for dependency in dependencies:
with open(dependency, "r") as file:
import_dependencies_info += (
f"\nHere is the content of the file {dependency}:\n{file.read()}"
)
message += import_dependencies_info
return message


def get_specification(specification_pdf_path: Path) -> str:
"""Get the reference for a given specification PDF path."""
# TODO: after pdf_to_text is available, use it to extract the text from the PDF
Expand Down
5 changes: 5 additions & 0 deletions agent/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ def run(
sys.stdout = open(log_file, "a")
sys.stderr = open(log_file, "a")

# Log the message
agent_message_log_file = log_dir / "agent_message.log"
with open(agent_message_log_file, "a") as f:
f.write(f"Message Sent: {message}\n\n")

# Configure httpx and backoff logging
handle_logging("httpx", log_file)
handle_logging("backoff", log_file)
Expand Down
6 changes: 6 additions & 0 deletions agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ def run(
".agent.yaml",
help="Path to the agent config file",
),
commit0_config_file: str = typer.Option(
".commit0.yaml",
help="Path to the commit0 config file",
),
log_dir: str = typer.Option(
str(RUN_AGENT_LOG_DIR.resolve()),
help="Log directory to store the logs",
Expand All @@ -202,6 +206,7 @@ def run(
override_previous_changes,
backend,
agent_config_file,
commit0_config_file,
log_dir,
max_parallel_repos,
display_repo_progress_num,
Expand All @@ -212,6 +217,7 @@ def run(
override_previous_changes,
backend,
agent_config_file,
commit0_config_file,
log_dir,
max_parallel_repos,
)
28 changes: 28 additions & 0 deletions agent/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from rich.align import Align
from collections import OrderedDict
from types import TracebackType
import json
from datetime import datetime


class RepoBox:
Expand Down Expand Up @@ -404,3 +406,29 @@ def __exit__(
f"{'Total':<30} {self.total_time_spent:>13.2f}s {total_files:>18} {total_money:>13.2f}$"
)
print("-" * 80)

# Write summary to JSON file

summary_data = {
"timestamp": datetime.now().isoformat(),
"total_time_spent": self.total_time_spent,
"total_files_processed": total_files,
"total_money_spent": total_money,
"repositories": [
{
"name": repo_name,
"time_spent": self.end_time_per_repo[repo_name]
- self.start_time_per_repo[repo_name],
"files_processed": self.total_files_per_repo[repo_name],
"money_spent": sum(
self.repo_money_spent.get(repo_name, {}).values()
),
}
for repo_name in self.end_time_per_repo
],
}

with open("processing_summary.json", "w") as json_file:
json.dump(summary_data, json_file, indent=4)

print("\nSummary has been written to processing_summary.json")
25 changes: 16 additions & 9 deletions agent/run_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
create_branch,
get_message,
get_target_edit_files,
update_message_with_dependencies,
get_lint_cmd,
read_yaml_config,
)
Expand Down Expand Up @@ -66,13 +67,6 @@ def run_agent_for_repo(
repo_path = os.path.join(repo_base_dir, repo_name)
repo_path = os.path.abspath(repo_path)

target_edit_files = get_target_edit_files(
repo_path, example["src_dir"], example["test"]["test_dir"]
)
# Call the commit0 get-tests command to retrieve test files
test_files_str = get_tests(repo_name, verbose=0)
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))

try:
local_repo = Repo(repo_path)
except Exception:
Expand All @@ -90,7 +84,6 @@ def run_agent_for_repo(
# # if branch_name is not provided, create a new branch name based on agent_config
# if branch is None:
# branch = args2string(agent_config)

create_branch(local_repo, branch, example["base_commit"])

# in cases where the latest commit of branch is not commit 0
Expand All @@ -99,6 +92,17 @@ def run_agent_for_repo(
if latest_commit.hexsha != example["base_commit"] and override_previous_changes:
local_repo.git.reset("--hard", example["base_commit"])

target_edit_files, import_dependencies = get_target_edit_files(
local_repo,
example["src_dir"],
example["test"]["test_dir"],
str(latest_commit),
example["reference_commit"],
)
# Call the commit0 get-tests command to retrieve test files
test_files_str = get_tests(repo_name, verbose=0)
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))

# prepare the log dir
experiment_log_dir = (
Path(log_dir)
Expand Down Expand Up @@ -158,6 +162,8 @@ def run_agent_for_repo(
)
for f in target_edit_files:
update_queue.put(("set_current_file", (repo_name, f)))
dependencies = import_dependencies[f]
message = update_message_with_dependencies(message, dependencies)
file_name = f.replace(".py", "").replace("/", "__")
file_log_dir = experiment_log_dir / file_name
lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info)
Expand All @@ -176,6 +182,7 @@ def run_agent(
override_previous_changes: bool,
backend: str,
agent_config_file: str,
commit0_config_file: str,
log_dir: str,
max_parallel_repos: int,
display_repo_progress_num: int,
Expand All @@ -185,7 +192,7 @@ def run_agent(

agent_config = AgentConfig(**config)

commit0_config = read_commit0_dot_file(".commit0.yaml")
commit0_config = read_commit0_dot_file(commit0_config_file)

dataset = load_dataset(
commit0_config["dataset_name"], split=commit0_config["dataset_split"]
Expand Down
27 changes: 18 additions & 9 deletions agent/run_agent_no_rich.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
create_branch,
get_message,
get_target_edit_files,
update_message_with_dependencies,
get_lint_cmd,
read_yaml_config,
)
Expand Down Expand Up @@ -61,14 +62,6 @@ def run_agent_for_repo(
repo_path = os.path.join(repo_base_dir, repo_name)
repo_path = os.path.abspath(repo_path)

# get target files to edit and test files to run
target_edit_files = get_target_edit_files(
repo_path, example["src_dir"], example["test"]["test_dir"]
)
# Call the commit0 get-tests command to retrieve test files
test_files_str = get_tests(repo_name, verbose=0)
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))

try:
local_repo = Repo(repo_path)
except Exception:
Expand All @@ -95,6 +88,19 @@ def run_agent_for_repo(
if latest_commit.hexsha != example["base_commit"] and override_previous_changes:
local_repo.git.reset("--hard", example["base_commit"])

# get target files to edit and test files to run
target_edit_files, import_dependencies = get_target_edit_files(
local_repo,
example["src_dir"],
example["test"]["test_dir"],
str(latest_commit),
str(example["reference_commit"]),
)

# Call the commit0 get-tests command to retrieve test files
test_files_str = get_tests(repo_name, verbose=0)
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))

# prepare the log dir
experiment_log_dir = (
Path(log_dir)
Expand Down Expand Up @@ -139,6 +145,8 @@ def run_agent_for_repo(
agent_config, repo_path, test_dir=example["test"]["test_dir"]
)
for f in target_edit_files:
dependencies = import_dependencies[f]
message = update_message_with_dependencies(message, dependencies)
file_name = f.replace(".py", "").replace("/", "__")
file_log_dir = experiment_log_dir / file_name
lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info)
Expand All @@ -151,6 +159,7 @@ def run_agent(
override_previous_changes: bool,
backend: str,
agent_config_file: str,
commit0_config_file: str,
log_dir: str,
max_parallel_repos: int,
) -> None:
Expand All @@ -162,7 +171,7 @@ def run_agent(

agent_config = AgentConfig(**config)

commit0_config = read_commit0_dot_file(".commit0.yaml")
commit0_config = read_commit0_dot_file(commit0_config_file)

dataset = load_dataset(
commit0_config["dataset_name"], split=commit0_config["dataset_split"]
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ requires-python = ">=3.11"
dependencies = [
"ruff>=0.6.4",
"pre-commit>=3.8.0",
"import-deps>=0.3.0",
"PyMuPDF>=1.24.5",
"modal==0.64.95",
"typer>=0.12.0",
Expand Down