Skip to content

Commit 9cba73f

Browse files
committed
Merge branch 'main' into aider
2 parents da0338c + 2f7a044 commit 9cba73f

File tree

10 files changed, with 128 additions and 53 deletions.

.github/workflows/system.yml

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,9 +25,19 @@ jobs:
2525
- name: Get tests
2626
run: uv run commit0 get-tests simpy
2727
- name: Test
28-
run: uv run commit0 test simpy tests/test_event.py::test_succeed --reference
28+
env:
29+
MODAL_TOKEN_ID: ${{secrets.MODAL_TOKEN_ID}}
30+
MODAL_TOKEN_SECRET: ${{secrets.MODAL_TOKEN_SECRET}}
31+
run: |
32+
uv run commit0 test simpy tests/test_event.py::test_succeed --reference --rebuild
33+
uv run commit0 test simpy tests/test_event.py::test_succeed --reference
2934
- name: Evaluate
30-
run: uv run commit0 evaluate --reference
35+
env:
36+
MODAL_TOKEN_ID: ${{secrets.MODAL_TOKEN_ID}}
37+
MODAL_TOKEN_SECRET: ${{secrets.MODAL_TOKEN_SECRET}}
38+
run: |
39+
uv run commit0 evaluate --reference --rebuild
40+
uv run commit0 evaluate --reference
3141
- name: Lint
3242
run: uv run commit0 lint commit0/harness/lint.py
3343
- name: Save

agent/cli.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@ def config(
8383
help="Use the user prompt instead of the default prompt",
8484
),
8585
user_prompt: str = typer.Option(
86-
"Here is your task:\nYou need to implement all functions with ' pass' and pass the unit tests.\nDo not change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc.\nWhen you generate code, you must maintain the original formatting of the function stubs (such as whitespaces), otherwise we will not able to search/replace blocks for code modifications, and therefore you will receive a score of 0 for your generated code.",
86+
"Here is your task:\nYou need to complete the implementations for all functions (i.e., those with pass statements) and pass the unit tests.\nDo not change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc.\nWhen you generate code, you must maintain the original formatting of the function stubs (such as whitespaces), otherwise we will not able to search/replace blocks for code modifications, and therefore you will receive a score of 0 for your generated code.",
8787
help="User prompt to use",
8888
),
8989
run_tests: bool = typer.Option(

agent/commit0_utils.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -119,9 +119,7 @@ def get_file_info(file_path: Path, prefix: str = "") -> str:
119119

120120

121121
def get_target_edit_files(target_dir: str) -> list[str]:
122-
"""Find the files with the error 'NotImplementedError('IMPLEMENT ME
123-
HERE')'.
124-
"""
122+
"""Find the files with functions with the pass statement."""
125123
files = []
126124
for root, _, filenames in os.walk(target_dir):
127125
for filename in filenames:

commit0/cli.py

Lines changed: 19 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import commit0.harness.lint
1111
import commit0.harness.save
1212
from commit0.harness.constants import SPLIT, SPLIT_ALL
13+
from commit0.harness.utils import get_active_branch
1314
import subprocess
1415
import yaml
1516
import os
@@ -216,12 +217,15 @@ def test(
216217
branch: Union[str, None] = typer.Option(
217218
None, help="Branch to test (branch MUST be provided or use --reference)"
218219
),
219-
backend: str = typer.Option("local", help="Backend to use for testing"),
220+
backend: str = typer.Option("modal", help="Backend to use for testing"),
220221
timeout: int = typer.Option(1800, help="Timeout for tests in seconds"),
221222
num_cpus: int = typer.Option(1, help="Number of CPUs to use"),
222223
reference: Annotated[
223224
bool, typer.Option("--reference", help="Test the reference commit.")
224225
] = False,
226+
rebuild: bool = typer.Option(
227+
False, "--rebuild", help="Whether to rebuild an image"
228+
),
225229
commit0_dot_file_path: str = typer.Option(
226230
".commit0.yaml",
227231
help="Path to the commit0 dot file, where the setup config is stored",
@@ -242,29 +246,30 @@ def test(
242246

243247
commit0_config = read_commit0_dot_file(commit0_dot_file_path)
244248

245-
if not branch and not reference:
246-
raise typer.BadParameter(
247-
f"Invalid {highlight('BRANCH', Colors.RED)}. Either --reference or provide a branch name.",
248-
param_hint="BRANCH",
249-
)
250249
if reference:
251250
branch = "reference"
252-
assert branch is not None, "branch is not specified"
251+
if branch is None and not reference:
252+
git_path = os.path.join(
253+
commit0_config["base_dir"], repo_or_repo_path.split("/")[-1]
254+
)
255+
branch = get_active_branch(git_path)
253256

254-
typer.echo(f"Running tests for repository: {repo_or_repo_path}")
255-
typer.echo(f"Branch: {branch}")
256-
typer.echo(f"Test IDs: {test_ids}")
257+
if verbose == 2:
258+
typer.echo(f"Running tests for repository: {repo_or_repo_path}")
259+
typer.echo(f"Branch: {branch}")
260+
typer.echo(f"Test IDs: {test_ids}")
257261

258262
commit0.harness.run_pytest_ids.main(
259263
commit0_config["dataset_name"],
260264
commit0_config["dataset_split"],
261265
commit0_config["base_dir"],
262266
repo_or_repo_path,
263-
branch,
267+
branch, # type: ignore
264268
test_ids,
265269
backend,
266270
timeout,
267271
num_cpus,
272+
rebuild,
268273
verbose,
269274
)
270275

@@ -274,7 +279,7 @@ def evaluate(
274279
branch: Union[str, None] = typer.Option(
275280
None, help="Branch to evaluate (branch MUST be provided or use --reference)"
276281
),
277-
backend: str = typer.Option("local", help="Backend to use for evaluation"),
282+
backend: str = typer.Option("modal", help="Backend to use for evaluation"),
278283
timeout: int = typer.Option(1800, help="Timeout for evaluation in seconds"),
279284
num_cpus: int = typer.Option(1, help="Number of CPUs to use"),
280285
num_workers: int = typer.Option(8, help="Number of workers to use"),
@@ -285,17 +290,12 @@ def evaluate(
285290
".commit0.yaml",
286291
help="Path to the commit0 dot file, where the setup config is stored",
287292
),
293+
rebuild: bool = typer.Option(False, "--rebuild", help="Whether to rebuild images"),
288294
) -> None:
289295
"""Evaluate Commit0 split you choose in Setup Stage."""
290296
check_commit0_path()
291-
if not branch and not reference:
292-
raise typer.BadParameter(
293-
f"Invalid {highlight('BRANCH', Colors.RED)}. Either --reference or provide a branch name",
294-
param_hint="BRANCH",
295-
)
296297
if reference:
297298
branch = "reference"
298-
assert branch is not None, "branch is not specified"
299299

300300
commit0_config = read_commit0_dot_file(commit0_dot_file_path)
301301
check_valid(commit0_config["repo_split"], SPLIT)
@@ -313,6 +313,7 @@ def evaluate(
313313
timeout,
314314
num_cpus,
315315
num_workers,
316+
rebuild,
316317
)
317318

318319

commit0/harness/constants.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@ class Files(TypedDict):
1616
patch: Dict[str, Path]
1717

1818

19+
BASE_BRANCH = "commit0"
20+
1921
# Constants - Evaluation Log Directories
2022
BASE_IMAGE_BUILD_DIR = Path("logs/build_images/base")
2123
REPO_IMAGE_BUILD_DIR = Path("logs/build_images/repo")

commit0/harness/evaluate.py

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,12 @@
55
from concurrent.futures import ThreadPoolExecutor, as_completed
66
from datasets import load_dataset
77
from tqdm import tqdm
8-
from typing import Iterator
8+
from typing import Iterator, Union
99

1010
from commit0.harness.run_pytest_ids import main as run_tests
1111
from commit0.harness.get_pytest_ids import main as get_tests
1212
from commit0.harness.constants import RepoInstance, SPLIT, RUN_PYTEST_LOG_DIR
13-
from commit0.harness.utils import get_hash_string
13+
from commit0.harness.utils import get_hash_string, get_active_branch
1414

1515
logging.basicConfig(
1616
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -23,24 +23,28 @@ def main(
2323
dataset_split: str,
2424
repo_split: str,
2525
base_dir: str,
26-
branch: str,
26+
branch: Union[str, None],
2727
backend: str,
2828
timeout: int,
2929
num_cpus: int,
3030
num_workers: int,
31+
rebuild_image: bool,
3132
) -> None:
3233
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
3334
repos = SPLIT[repo_split]
34-
pairs = []
35+
triples = []
3536
log_dirs = []
3637
for example in dataset:
3738
repo_name = example["repo"].split("/")[-1]
3839
if repo_split != "all" and repo_name not in SPLIT[repo_split]:
3940
continue
40-
pairs.append((repo_name, example["test"]["test_dir"]))
4141
hashed_test_ids = get_hash_string(example["test"]["test_dir"])
42+
if branch is None:
43+
git_path = os.path.join(base_dir, repo_name)
44+
branch = get_active_branch(git_path)
4245
log_dir = RUN_PYTEST_LOG_DIR / repo_name / branch / hashed_test_ids
4346
log_dirs.append(str(log_dir))
47+
triples.append((repo_name, example["test"]["test_dir"], branch))
4448

4549
with tqdm(total=len(repos), smoothing=0, desc="Evaluating repos") as pbar:
4650
with ThreadPoolExecutor(max_workers=num_workers) as executor:
@@ -57,9 +61,10 @@ def main(
5761
backend,
5862
timeout,
5963
num_cpus,
64+
rebuild_image=rebuild_image,
6065
verbose=0,
6166
): None
62-
for repo, test_dir in pairs
67+
for repo, test_dir, branch in triples
6368
}
6469
# Wait for each future to complete
6570
for future in as_completed(futures):

commit0/harness/execution_context.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@ def __init__(
4444
log_dir: Path,
4545
files_to_copy: Optional[Files] = None,
4646
files_to_collect: Optional[list[str]] = None,
47+
rebuild_image: bool = False,
4748
):
4849
"""Create the remote execution context
4950
@@ -85,6 +86,7 @@ def __init__(
8586
log_dir: Path,
8687
files_to_copy: Optional[Files] = None,
8788
files_to_collect: Optional[list[str]] = None,
89+
rebuild_image: bool = False,
8890
):
8991
super().__init__(
9092
spec,
@@ -145,6 +147,7 @@ def __init__(
145147
log_dir: Path,
146148
files_to_copy: Optional[Files] = None,
147149
files_to_collect: Optional[list[str]] = None,
150+
rebuild_image: bool = False,
148151
):
149152
super().__init__(
150153
spec,
@@ -161,7 +164,7 @@ def __init__(
161164
# the image must exist on dockerhub
162165
reponame = spec.repo.split("/")[-1]
163166
image_name = f"wentingzhao/{reponame}:latest".lower()
164-
image = modal.Image.from_registry(image_name)
167+
image = modal.Image.from_registry(image_name, force_build=rebuild_image)
165168
if files_to_copy:
166169
for _, f in files_to_copy.items():
167170
image = image.copy_local_file(f["src"], f["dest"]) # type: ignore
@@ -171,14 +174,12 @@ def exec_run_with_timeout(self, command: str) -> tuple[str, bool, float]:
171174
"""Execute command on modal sandbox"""
172175
start_time = time.time()
173176
with modal.Volume.ephemeral() as vol:
174-
cp_cmd = ""
175177
if self.files_to_collect:
178+
command += " && "
176179
for fname in self.files_to_collect:
177180
remote_file = Path(self.spec.repo_directory) / fname
178-
curr_cp_cmd = f" && cp {str(remote_file)} /vol/{fname} 2>/dev/null"
179-
cp_cmd += curr_cp_cmd
180-
181-
command += cp_cmd
181+
cp_cmd = f"test -e {str(remote_file)} && cp {str(remote_file)} /vol/{fname}; "
182+
command += cp_cmd
182183
self.sandbox = modal.Sandbox.create(
183184
"bash",
184185
"-c",
@@ -199,7 +200,9 @@ def exec_run_with_timeout(self, command: str) -> tuple[str, bool, float]:
199200
timed_out = False
200201

201202
if self.files_to_collect:
202-
for fname in self.files_to_collect:
203+
fnames = vol.listdir("")
204+
for fname in fnames:
205+
fname = fname.path
203206
with (self.log_dir / fname).open("wb") as f:
204207
for data in vol.read_file(fname):
205208
f.write(data)

commit0/harness/run_pytest_ids.py

Lines changed: 30 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ def main(
3737
backend: str,
3838
timeout: int,
3939
num_cpus: int,
40+
rebuild_image: bool,
4041
verbose: int,
4142
) -> None:
4243
"""Runs the pytests for repos in a dataset.
@@ -81,15 +82,30 @@ def main(
8182
)
8283
except Exception as e:
8384
raise e
85+
commit_id = ""
8486
if branch == "reference":
8587
commit_id = example["reference_commit"]
8688
else:
87-
try:
88-
local_repo.git.checkout(branch)
89-
local_branch = local_repo.branches[branch]
90-
commit_id = local_branch.commit.hexsha
91-
except Exception as e:
92-
raise Exception(f"Problem checking out branch {branch}.\n{e}")
89+
# Check if it's a local branch
90+
if branch in local_repo.branches:
91+
commit_id = local_repo.commit(branch).hexsha
92+
else:
93+
found_remote_branch = False
94+
for remote in local_repo.remotes:
95+
remote.fetch() # Fetch latest updates from each remote
96+
97+
# Check if the branch exists in this remote
98+
for ref in remote.refs:
99+
if (
100+
ref.remote_head == branch
101+
): # Compare branch name without remote prefix
102+
commit_id = local_repo.commit(ref.name).hexsha
103+
found_remote_branch = True
104+
break # Branch found, no need to keep checking this remote
105+
if found_remote_branch:
106+
break # Stop checking other remotes if branch is found
107+
if not found_remote_branch:
108+
raise Exception(f"Branch {branch} does not exist locally or remotely.")
93109
patch = generate_patch_between_commits(
94110
local_repo, example["base_commit"], commit_id
95111
)
@@ -125,7 +141,14 @@ def main(
125141

126142
try:
127143
with execution_context(
128-
spec, logger, timeout, num_cpus, log_dir, files_to_copy, files_to_collect
144+
spec,
145+
logger,
146+
timeout,
147+
num_cpus,
148+
log_dir,
149+
files_to_copy,
150+
files_to_collect,
151+
rebuild_image,
129152
) as context:
130153
output, timed_out, total_runtime = context.exec_run_with_timeout(
131154
"/bin/bash /eval.sh"

commit0/harness/setup.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from commit0.harness.utils import (
88
clone_repo,
99
)
10-
from commit0.harness.constants import RepoInstance, SPLIT
10+
from commit0.harness.constants import BASE_BRANCH, RepoInstance, SPLIT
1111

1212

1313
logging.basicConfig(
@@ -29,7 +29,12 @@ def main(
2929
continue
3030
clone_url = f"https://github.com/{example['repo']}.git"
3131
clone_dir = os.path.abspath(os.path.join(base_dir, repo_name))
32-
clone_repo(clone_url, clone_dir, example["base_commit"], logger)
32+
branch = dataset_name.split("/")[-1]
33+
repo = clone_repo(clone_url, clone_dir, branch, logger)
34+
if BASE_BRANCH in repo.branches:
35+
repo.git.branch("-d", BASE_BRANCH)
36+
repo.git.checkout("-b", BASE_BRANCH)
37+
logger.info(f"Checked out the base branch: {BASE_BRANCH}")
3338

3439

3540
__all__ = []

0 commit comments

Comments
 (0)