Skip to content

Commit 99779c9

Browse files
committed
update
1 parent c8e47cb commit 99779c9

File tree

3 files changed

+12
-11
lines changed

3 files changed

+12
-11
lines changed

agent/agent_utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,11 @@ def _find_files_to_edit(base_dir: str, src_dir: str, test_dir: str) -> list[str]
192192
return files
193193

194194

195-
def ignore_cycles(graph: dict):
195+
def ignore_cycles(graph: dict) -> list[str]:
196+
"""Ignore the cycles in the graph."""
196197
ts = TopologicalSorter(graph)
197198
try:
198-
return list(set(ts.static_order()))
199+
return list(ts.static_order())
199200
except CycleError as e:
200201
# print(f"Cycle detected: {e.args[1]}")
201202
# You can either break the cycle by modifying the graph or handle it as needed.
@@ -231,7 +232,7 @@ def get_target_edit_files(
231232
reference_commit: str,
232233
) -> list[str]:
233234
"""Find the files with functions with the pass statement."""
234-
target_dir = local_repo.working_dir
235+
target_dir = str(local_repo.working_dir)
235236
files = _find_files_to_edit(target_dir, src_dir, test_dir)
236237
filtered_files = []
237238
for file_path in files:
@@ -241,10 +242,8 @@ def get_target_edit_files(
241242
continue
242243
if " pass" in content:
243244
filtered_files.append(file_path)
244-
245245
# Change to reference commit to get the correct dependencies
246246
local_repo.git.checkout(reference_commit)
247-
248247
topological_sort_files = topological_sort_based_on_dependencies(filtered_files)
249248
if len(topological_sort_files) != len(filtered_files):
250249
if len(topological_sort_files) < len(filtered_files):

agent/run_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def run_agent_for_repo(
8787

8888
# in cases where the latest commit of branch is not commit 0
8989
# set it back to commit 0
90-
latest_commit = local_repo.commit(branch)
91-
if latest_commit.hexsha != example["base_commit"] and override_previous_changes:
90+
latest_commit = str(local_repo.commit(branch))
91+
if latest_commit != example["base_commit"] and override_previous_changes:
9292
local_repo.git.reset("--hard", example["base_commit"])
9393

9494
target_edit_files = get_target_edit_files(

agent/run_agent_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import os
22
import yaml
3-
import multiprocessing
4-
from tqdm import tqdm
53
from datasets import load_dataset
64
from git import Repo
75
from agent.agent_utils import (
@@ -13,7 +11,7 @@
1311
read_yaml_config,
1412
)
1513
from agent.agents import AiderAgents
16-
from typing import Optional, Type, cast
14+
from typing import Optional, Type
1715
from types import TracebackType
1816
from agent.class_types import AgentConfig
1917
from commit0.harness.constants import SPLIT
@@ -89,7 +87,11 @@ def run_agent_for_repo(
8987

9088
# get target files to edit and test files to run
9189
target_edit_files = get_target_edit_files(
92-
local_repo, example["src_dir"], example["test"]["test_dir"], latest_commit, example["reference_commit"]
90+
local_repo,
91+
example["src_dir"],
92+
example["test"]["test_dir"],
93+
latest_commit,
94+
example["reference_commit"],
9395
)
9496
print(target_edit_files)
9597
return

0 commit comments

Comments
 (0)