Skip to content

Commit 9212dcb

Browse files
committed
fix small error
1 parent 910e5c5 commit 9212dcb

File tree

4 files changed

+273
-5
lines changed

4 files changed

+273
-5
lines changed

agent/agent_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def get_target_edit_files(target_dir: str, src_prefix: str) -> list[str]:
126126
for filename in filenames:
127127
if filename.endswith(".py"):
128128
file_path = os.path.join(root, filename)
129-
with open(file_path, "r") as file:
129+
with open(file_path, "r", encoding="utf-8", errors="ignore") as file:
130130
if " pass" in file.read():
131131
files.append(file_path)
132132

agent/cli.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import typer
22
from agent.run_agent_no_rich import run_agent as run_agent_no_rich
33
from agent.run_agent import run_agent
4+
from agent.run_agent_joblit import run_agent as run_agent_joblit
45
from commit0.harness.constants import RUN_AIDER_LOG_DIR
56
import subprocess
67
from agent.agent_utils import write_agent_config
@@ -243,3 +244,41 @@ def run_test_no_rich(
243244
log_dir,
244245
max_parallel_repos,
245246
)
247+
248+
249+
@agent_app.command()
250+
def run_test_joblit(
251+
branch: str = typer.Argument(
252+
...,
253+
help="Branch name of current run",
254+
),
255+
override_previous_changes: bool = typer.Option(
256+
False,
257+
help="If override the previous agent changes on `branch` or run the agent continuously on the new changes",
258+
),
259+
backend: str = typer.Option(
260+
"modal",
261+
help="Test backend to run the agent on, ignore this option if you are not adding `test` option to agent",
262+
),
263+
agent_config_file: str = typer.Option(
264+
".agent.yaml",
265+
help="Path to the agent config file",
266+
),
267+
log_dir: str = typer.Option(
268+
str(RUN_AIDER_LOG_DIR.resolve()),
269+
help="Log directory to store the logs",
270+
),
271+
max_parallel_repos: int = typer.Option(
272+
1,
273+
help="Maximum number of repositories for agent to run in parallel",
274+
),
275+
) -> None:
276+
"""Run the agent on the repository."""
277+
run_agent_joblit(
278+
branch,
279+
override_previous_changes,
280+
backend,
281+
agent_config_file,
282+
log_dir,
283+
max_parallel_repos,
284+
)

agent/run_agent.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,13 @@ def run_agent_for_repo(
5757
# get repo info
5858
_, repo_name = example["repo"].split("/")
5959

60+
original_repo_name = repo_name
61+
6062
repo_name = repo_name.lower()
6163
repo_name = repo_name.replace(".", "-")
6264

6365
# before starting, display all information to terminal
64-
update_queue.put(("start_repo", (repo_name, 0)))
66+
update_queue.put(("start_repo", (original_repo_name, 0)))
6567

6668
repo_path = os.path.join(repo_base_dir, repo_name)
6769
repo_path = os.path.abspath(repo_path)
@@ -128,7 +130,7 @@ def run_agent_for_repo(
128130
test_files_str = get_tests(repo_name, verbose=0)
129131
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))
130132

131-
update_queue.put(("start_repo", (repo_name, len(test_files))))
133+
update_queue.put(("start_repo", (original_repo_name, len(test_files))))
132134
# when unit test feedback is available, iterate over test files
133135
for test_file in test_files:
134136
update_queue.put(("set_current_file", (repo_name, test_file)))
@@ -159,7 +161,9 @@ def run_agent_for_repo(
159161
agent_config, repo_path, test_dir=example["test"]["test_dir"]
160162
)
161163

162-
update_queue.put(("start_repo", (repo_name, len(target_edit_files))))
164+
update_queue.put(
165+
("start_repo", (original_repo_name, len(target_edit_files)))
166+
)
163167
for f in target_edit_files:
164168
update_queue.put(("set_current_file", (repo_name, f)))
165169
file_name = f.replace(".py", "").replace("/", "__")
@@ -172,7 +176,7 @@ def run_agent_for_repo(
172176
(repo_name, file_name, agent_return.last_cost),
173177
)
174178
)
175-
update_queue.put(("finish_repo", repo_name))
179+
update_queue.put(("finish_repo", original_repo_name))
176180

177181

178182
def run_agent(

agent/run_agent_joblit.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import os
2+
import yaml
3+
import multiprocessing
4+
from tqdm import tqdm
5+
from datasets import load_dataset
6+
from git import Repo
7+
from agent.agent_utils import (
8+
args2string,
9+
create_branch,
10+
get_message,
11+
get_target_edit_files,
12+
get_lint_cmd,
13+
read_yaml_config,
14+
)
15+
from agent.agents import AiderAgents
16+
from typing import Optional, Type, cast
17+
from types import TracebackType
18+
from agent.class_types import AgentConfig
19+
from commit0.harness.constants import SPLIT
20+
from commit0.harness.get_pytest_ids import main as get_tests
21+
from commit0.harness.constants import RUN_AIDER_LOG_DIR, RepoInstance
22+
from commit0.cli import read_commit0_dot_file
23+
from pathlib import Path
24+
from datetime import datetime
25+
26+
27+
class DirContext:
28+
def __init__(self, d: str):
29+
self.dir = d
30+
self.cwd = os.getcwd()
31+
32+
def __enter__(self):
33+
os.chdir(self.dir)
34+
35+
def __exit__(
36+
self,
37+
exctype: Optional[Type[BaseException]],
38+
excinst: Optional[BaseException],
39+
exctb: Optional[TracebackType],
40+
) -> None:
41+
os.chdir(self.cwd)
42+
43+
44+
def run_agent_for_repo(
45+
repo_base_dir: str,
46+
agent_config: AgentConfig,
47+
example: RepoInstance,
48+
branch: Optional[str] = None,
49+
override_previous_changes: bool = False,
50+
backend: str = "modal",
51+
log_dir: str = str(RUN_AIDER_LOG_DIR.resolve()),
52+
) -> None:
53+
"""Run Aider for a given repository."""
54+
# get repo info
55+
_, repo_name = example["repo"].split("/")
56+
57+
repo_name = repo_name.lower()
58+
repo_name = repo_name.replace(".", "-")
59+
60+
repo_path = os.path.join(repo_base_dir, repo_name)
61+
repo_path = os.path.abspath(repo_path)
62+
63+
src_dir = os.path.join(repo_path, example["src_dir"])
64+
65+
try:
66+
local_repo = Repo(repo_path)
67+
except Exception:
68+
raise Exception(
69+
f"{repo_path} is not a git repo. Check if base_dir is correctly specified."
70+
)
71+
72+
if agent_config.agent_name == "aider":
73+
agent = AiderAgents(agent_config.max_iteration, agent_config.model_name)
74+
else:
75+
raise NotImplementedError(
76+
f"{agent_config.agent_name} is not implemented; please add your implementations in baselines/agents.py."
77+
)
78+
79+
# if branch_name is not provided, create a new branch name based on agent_config
80+
if branch is None:
81+
branch = args2string(agent_config)
82+
83+
create_branch(local_repo, branch, example["base_commit"])
84+
85+
# in cases where the latest commit of branch is not commit 0
86+
# set it back to commit 0
87+
latest_commit = local_repo.commit(branch)
88+
if latest_commit.hexsha != example["base_commit"] and override_previous_changes:
89+
local_repo.git.reset("--hard", example["base_commit"])
90+
91+
# prepare the log dir
92+
experiment_log_dir = (
93+
Path(log_dir)
94+
/ repo_name
95+
/ branch
96+
/ datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
97+
)
98+
experiment_log_dir.mkdir(parents=True, exist_ok=True)
99+
100+
# write agent_config to .agent.yaml in the log_dir for record
101+
agent_config_log_file = experiment_log_dir / ".agent.yaml"
102+
with open(agent_config_log_file, "w") as agent_config_file:
103+
yaml.dump(agent_config, agent_config_file)
104+
105+
# TODO: make this path more general
106+
commit0_dot_file_path = str(Path(repo_path).parent.parent / ".commit0.yaml")
107+
with DirContext(repo_path):
108+
if agent_config is None:
109+
raise ValueError("Invalid input")
110+
111+
target_edit_files = get_target_edit_files(
112+
src_dir, src_prefix=example["src_dir"]
113+
)
114+
115+
if agent_config.run_tests:
116+
# Call the commit0 get-tests command to retrieve test files
117+
test_files_str = get_tests(repo_name, verbose=0)
118+
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))
119+
120+
# when unit test feedback is available, iterate over test files
121+
for test_file in test_files:
122+
test_cmd = f"python -m commit0 test {repo_path} {test_file} --branch {branch} --backend {backend} --commit0_dot_file_path {commit0_dot_file_path}"
123+
test_file_name = test_file.replace(".py", "").replace("/", "__")
124+
test_log_dir = experiment_log_dir / test_file_name
125+
lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info)
126+
message = get_message(agent_config, repo_path, test_file=test_file)
127+
_ = agent.run(
128+
message,
129+
test_cmd,
130+
lint_cmd,
131+
target_edit_files,
132+
test_log_dir,
133+
)
134+
# cost = agent_return.last_cost
135+
else:
136+
# when unit test feedback is not available, iterate over target files to edit
137+
message = get_message(
138+
agent_config, repo_path, test_dir=example["test"]["test_dir"]
139+
)
140+
for f in target_edit_files:
141+
file_name = f.replace(".py", "").replace("/", "__")
142+
file_log_dir = experiment_log_dir / file_name
143+
lint_cmd = get_lint_cmd(repo_name, agent_config.use_lint_info)
144+
_ = agent.run(message, "", lint_cmd, [f], file_log_dir)
145+
# cost = agent_return.last_cost
146+
147+
148+
def run_agent(
149+
branch: str,
150+
override_previous_changes: bool,
151+
backend: str,
152+
agent_config_file: str,
153+
log_dir: str,
154+
max_parallel_repos: int,
155+
) -> None:
156+
"""Main function to run Aider for a given repository.
157+
158+
Will run in parallel for each repo.
159+
"""
160+
config = read_yaml_config(agent_config_file)
161+
162+
agent_config = AgentConfig(**config)
163+
164+
commit0_config = read_commit0_dot_file(".commit0.yaml")
165+
166+
dataset = load_dataset(
167+
commit0_config["dataset_name"], split=commit0_config["dataset_split"]
168+
)
169+
filtered_dataset = [
170+
example
171+
for example in dataset
172+
if commit0_config["repo_split"] == "all"
173+
or (
174+
isinstance(example, dict)
175+
and "repo" in example
176+
and isinstance(example["repo"], str)
177+
and example["repo"].split("/")[-1]
178+
in SPLIT.get(commit0_config["repo_split"], [])
179+
)
180+
]
181+
assert len(filtered_dataset) > 0, "No examples available"
182+
183+
# if len(filtered_dataset) > 1:
184+
# sys.stdout = open(os.devnull, "w")
185+
print("jere")
186+
print(filtered_dataset[0])
187+
for example in filtered_dataset:
188+
if "joblib" in example["repo"]:
189+
print(example)
190+
run_agent_for_repo(
191+
commit0_config["base_dir"],
192+
agent_config,
193+
cast(RepoInstance, example),
194+
branch,
195+
override_previous_changes,
196+
backend,
197+
log_dir,
198+
)
199+
# with tqdm(
200+
# total=len(filtered_dataset), smoothing=0, desc="Running Aider for repos"
201+
# ) as pbar:
202+
# with multiprocessing.Pool(processes=max_parallel_repos) as pool:
203+
# results = []
204+
205+
# # Use apply_async to submit jobs and add progress bar updates
206+
# for example in filtered_dataset:
207+
# result = pool.apply_async(
208+
# run_agent_for_repo,
209+
# args=(
210+
# commit0_config["base_dir"],
211+
# agent_config,
212+
# cast(RepoInstance, example),
213+
# branch,
214+
# override_previous_changes,
215+
# backend,
216+
# log_dir,
217+
# ),
218+
# callback=lambda _: pbar.update(
219+
# 1
220+
# ), # Update progress bar on task completion
221+
# )
222+
# results.append(result)
223+
224+
# for result in results:
225+
# result.wait()

0 commit comments

Comments
 (0)