Skip to content

Commit c8e47cb

Browse files
committed
tmp
1 parent e79cdf7 commit c8e47cb

File tree

4 files changed

+341
-13
lines changed

4 files changed

+341
-13
lines changed

agent/agent_utils.py

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from pathlib import Path
77
from typing import List
88
import fitz
9+
from import_deps import ModuleSet
10+
from graphlib import TopologicalSorter, CycleError
911
import yaml
1012

1113
from agent.class_types import AgentConfig
@@ -190,8 +192,46 @@ def _find_files_to_edit(base_dir: str, src_dir: str, test_dir: str) -> list[str]
190192
return files
191193

192194

193-
def get_target_edit_files(target_dir: str, src_dir: str, test_dir: str) -> list[str]:
195+
def ignore_cycles(graph: dict):
196+
ts = TopologicalSorter(graph)
197+
try:
198+
return list(set(ts.static_order()))
199+
except CycleError as e:
200+
# print(f"Cycle detected: {e.args[1]}")
201+
# You can either break the cycle by modifying the graph or handle it as needed.
202+
# For now, let's just remove the first node in the cycle and try again.
203+
cycle_nodes = e.args[1]
204+
node_to_remove = cycle_nodes[0]
205+
# print(f"Removing node {node_to_remove} to resolve cycle.")
206+
graph.pop(node_to_remove, None)
207+
return ignore_cycles(graph)
208+
209+
210+
def topological_sort_based_on_dependencies(pkg_paths: list[str]) -> list[str]:
211+
"""Topological sort based on dependencies."""
212+
module_set = ModuleSet([str(p) for p in pkg_paths])
213+
214+
import_dependencies = {}
215+
for path in sorted(module_set.by_path.keys()):
216+
module_name = ".".join(module_set.by_path[path].fqn)
217+
mod = module_set.by_name[module_name]
218+
imports = module_set.get_imports(mod)
219+
import_dependencies[path] = set([str(x) for x in imports])
220+
221+
import_dependencies_files = ignore_cycles(import_dependencies)
222+
223+
return import_dependencies_files
224+
225+
226+
def get_target_edit_files(
227+
local_repo: git.Repo,
228+
src_dir: str,
229+
test_dir: str,
230+
latest_commit: str,
231+
reference_commit: str,
232+
) -> list[str]:
194233
"""Find the files with functions with the pass statement."""
234+
target_dir = local_repo.working_dir
195235
files = _find_files_to_edit(target_dir, src_dir, test_dir)
196236
filtered_files = []
197237
for file_path in files:
@@ -202,13 +242,33 @@ def get_target_edit_files(target_dir: str, src_dir: str, test_dir: str) -> list[
202242
if " pass" in content:
203243
filtered_files.append(file_path)
204244

245+
# Change to reference commit to get the correct dependencies
246+
local_repo.git.checkout(reference_commit)
247+
248+
topological_sort_files = topological_sort_based_on_dependencies(filtered_files)
249+
if len(topological_sort_files) != len(filtered_files):
250+
if len(topological_sort_files) < len(filtered_files):
251+
# Find the missing elements
252+
missing_files = set(filtered_files) - set(topological_sort_files)
253+
# Add the missing files to the end of the list
254+
topological_sort_files = topological_sort_files + list(missing_files)
255+
else:
256+
raise ValueError(
257+
"topological_sort_files should not be longer than filtered_files"
258+
)
259+
assert len(topological_sort_files) == len(
260+
filtered_files
261+
), "all files should be included"
262+
263+
# change to latest commit
264+
local_repo.git.checkout(latest_commit)
265+
205266
# Remove the base_dir prefix
206-
filtered_files = [
207-
file.replace(target_dir, "").lstrip("/") for file in filtered_files
267+
topological_sort_files = [
268+
file.replace(target_dir, "").lstrip("/") for file in topological_sort_files
208269
]
209-
# Only keep python files
210270

211-
return filtered_files
271+
return topological_sort_files
212272

213273

214274
def get_message(

agent/display.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from rich.align import Align
1818
from collections import OrderedDict
1919
from types import TracebackType
20+
import json
21+
from datetime import datetime
2022

2123

2224
class RepoBox:
@@ -404,3 +406,29 @@ def __exit__(
404406
f"{'Total':<30} {self.total_time_spent:>13.2f}s {total_files:>18} {total_money:>13.2f}$"
405407
)
406408
print("-" * 80)
409+
410+
# Write summary to JSON file
411+
412+
summary_data = {
413+
"timestamp": datetime.now().isoformat(),
414+
"total_time_spent": self.total_time_spent,
415+
"total_files_processed": total_files,
416+
"total_money_spent": total_money,
417+
"repositories": [
418+
{
419+
"name": repo_name,
420+
"time_spent": self.end_time_per_repo[repo_name]
421+
- self.start_time_per_repo[repo_name],
422+
"files_processed": self.total_files_per_repo[repo_name],
423+
"money_spent": sum(
424+
self.repo_money_spent.get(repo_name, {}).values()
425+
),
426+
}
427+
for repo_name in self.end_time_per_repo
428+
],
429+
}
430+
431+
with open("processing_summary.json", "w") as json_file:
432+
json.dump(summary_data, json_file, indent=4)
433+
434+
print("\nSummary has been written to processing_summary.json")

agent/run_agent.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,6 @@ def run_agent_for_repo(
6666
repo_path = os.path.join(repo_base_dir, repo_name)
6767
repo_path = os.path.abspath(repo_path)
6868

69-
target_edit_files = get_target_edit_files(
70-
repo_path, example["src_dir"], example["test"]["test_dir"]
71-
)
72-
# Call the commit0 get-tests command to retrieve test files
73-
test_files_str = get_tests(repo_name, verbose=0)
74-
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))
75-
7669
try:
7770
local_repo = Repo(repo_path)
7871
except Exception:
@@ -90,7 +83,6 @@ def run_agent_for_repo(
9083
# # if branch_name is not provided, create a new branch name based on agent_config
9184
# if branch is None:
9285
# branch = args2string(agent_config)
93-
9486
create_branch(local_repo, branch, example["base_commit"])
9587

9688
# in cases where the latest commit of branch is not commit 0
@@ -99,6 +91,17 @@ def run_agent_for_repo(
9991
if latest_commit.hexsha != example["base_commit"] and override_previous_changes:
10092
local_repo.git.reset("--hard", example["base_commit"])
10193

94+
target_edit_files = get_target_edit_files(
95+
local_repo,
96+
example["src_dir"],
97+
example["test"]["test_dir"],
98+
latest_commit,
99+
example["reference_commit"],
100+
)
101+
# Call the commit0 get-tests command to retrieve test files
102+
test_files_str = get_tests(repo_name, verbose=0)
103+
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))
104+
102105
# prepare the log dir
103106
experiment_log_dir = (
104107
Path(log_dir)

0 commit comments

Comments
 (0)