Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 142 additions & 23 deletions codegen-examples/examples/swebench_agent_run/run_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import uuid
import modal
import click
from datetime import datetime
import time
from codegen.extensions.swebench.harness import run_agent_on_entry
from codegen.extensions.swebench.utils import SWEBenchDataset, SweBenchExample, get_swe_bench_examples
from codegen.extensions.swebench.report import generate_report
Expand All @@ -17,28 +17,112 @@
run_agent_modal = modal.Function.from_name(app_name="swebench-agent-run", name="run_agent_modal")


async def process_batch_modal(examples: list[SweBenchExample], num_workers=10, max_retries=3):
"""Process a batch of examples concurrently using a queue system.
async def process_batch_modal(examples: list[SweBenchExample], num_workers=5, min_workers=1, max_retries=3):
"""Process a batch of examples concurrently using a queue system with incremental worker scaling.

Args:
examples: List of SweBenchExample objects to process
num_workers: Number of examples to process concurrently
num_workers: Initial number of examples to process concurrently
min_workers: Minimum number of concurrent workers to maintain
max_retries: Maximum number of retries for failed requests
"""
results = {}
queue = asyncio.Queue()

# Shared state for worker management
state = {
"active_workers": num_workers,
"success_streak": 0,
"last_scaling_time": time.time(),
"scaling_cooldown": 0, # seconds between scaling operations
"worker_tasks": [],
"running": True,
}

# Use a lock to protect shared state during adjustments
state_lock = asyncio.Lock()

# Initialize the queue with (example, attempt) tuples
for example in examples:
await queue.put((example, 0)) # 0 represents first attempt

async def process_example(example, attempt):
async def scale_down_worker(task_to_cancel=None):
    """Drop one worker from the pool in response to rate limiting.

    Args:
        task_to_cancel: Worker task to remove. When it is falsy or is not
            present in ``state["worker_tasks"]``, the last entry of that
            list is removed instead.

    Returns:
        True if a worker task was cancelled; False if the scaling cooldown
        has not elapsed, ``active_workers`` is already at or below
        ``min_workers``, or no worker tasks are tracked.
    """
    async with state_lock:
        now = time.time()
        cooling_down = now - state["last_scaling_time"] < state["scaling_cooldown"]
        at_minimum = state["active_workers"] <= min_workers
        if cooling_down or at_minimum:
            return False

        # The streak and the scaling timestamp are reset before a worker is
        # chosen, so they change even when nothing ends up being removed.
        state["success_streak"] = 0
        state["last_scaling_time"] = now

        pool = state["worker_tasks"]
        remove_requested = bool(task_to_cancel) and task_to_cancel in pool
        if not remove_requested and not pool:
            return False

        print(f"Rate limiting detected! Removing 1 worker, going from {state['active_workers']} to {state['active_workers'] - 1}")
        if remove_requested:
            # Targeted removal: drop exactly the task the caller named.
            pool.remove(task_to_cancel)
            victim = task_to_cancel
        else:
            # Fallback: drop whichever task sits at the end of the list.
            victim = pool.pop()

        victim.cancel()
        state["active_workers"] -= 1
        return True

async def scale_up_worker():
"""Add a single worker when operations have been consistently successful.

Returns False without scaling while the cooldown since the last scaling
operation has not elapsed, or once ``active_workers`` has reached
``num_workers``. All reads and writes of ``state`` happen while holding
``state_lock``.
"""
# NOTE(review): indentation is not preserved in this view, so the nesting of
# the statements below (in particular which branch the streak reset, the
# timestamp update and `return True` belong to) must be confirmed against the
# real file before relying on the comments here.
async with state_lock:
# Only scale if cooldown period has passed and we're below num_workers
current_time = time.time()
if current_time - state["last_scaling_time"] < state["scaling_cooldown"] or state["active_workers"] >= num_workers:
return False

# Add a worker after a streak of successful operations
if state["success_streak"] >= 5:
print(f"Operations succeeding! Adding 1 worker, going from {state['active_workers']} to {state['active_workers'] + 1}")

# Create new worker; a task is only spawned while state["running"] is truthy
if state["running"]:
new_task = asyncio.create_task(worker())
# Track the task in state["worker_tasks"] alongside the existing workers
state["worker_tasks"].append(new_task)
state["active_workers"] += 1
# Start a fresh streak and record when this scaling operation happened
state["success_streak"] = 0
state["last_scaling_time"] = current_time
return True

return False

async def is_rate_limit_error(error):
    """Report whether ``error`` looks like a rate-limiting failure.

    Only ``modal.exception.Error`` instances are inspected: their lowercased
    string form is searched for a fixed set of rate-limit phrases. Any other
    error type yields False.

    Args:
        error: The exception to classify.

    Returns:
        True if any known rate-limit phrase occurs in the error text,
        otherwise False.
    """
    if not isinstance(error, modal.exception.Error):
        return False

    message = str(error).lower()
    # Plain substring match, so a phrase counts wherever it appears in the message.
    for marker in ("rate limit", "too many requests", "429", "throttle", "quota exceeded", "capacity", "limit exceeded"):
        if marker in message:
            return True
    return False

async def process_example(example, attempt, current_task):
try:
result = await run_agent_modal.remote.aio(example)

if result is None:
print(f"Warning: Null result for {example.instance_id}")
return {"instance_id": example.instance_id, "status": "error", "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}}
return {"status": "error", "instance_id": example.instance_id, "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}}

# Increment success streak and potentially scale up
async with state_lock:
state["success_streak"] += 1

if state["success_streak"] % 5 == 0: # Check after every 5 successes
await scale_up_worker()

return result

except Exception as e:
Expand All @@ -56,51 +140,86 @@ async def process_example(example, attempt):
print(f"Error processing {example.instance_id} (attempt {attempt + 1}):")
print(f"Type: {error_type}")
print(f"Message: {str(e)}")
print("Traceback:")
print("".join(error_info["traceback"]))

# Check if this is a rate limit error
if await is_rate_limit_error(e):
print(f"Rate limit detected on task for {example.instance_id}")

# Scale down by removing this specific worker
scaled_down = await scale_down_worker(current_task)

# If we're removing this worker, we need to requeue the task for another worker
if scaled_down:
# Requeue this example with the same attempt count (not incrementing)
await queue.put((example, attempt))
return None

# Otherwise add a small delay before retrying
await asyncio.sleep(2 * (attempt + 1)) # Exponential backoff

if attempt < max_retries:
await queue.put((example, attempt + 1))
return None

return {"instance_id": example.instance_id, "status": "error", "error_info": error_info}
return {"status": "error", "instance_id": example.instance_id, "error_info": error_info}

async def worker():
while True:
# Store this task reference to allow targeted cancellation
current_task = asyncio.current_task()

while state["running"]:
try:
example, attempt = await queue.get()
# Use a timeout to allow worker to check if it should exit
try:
example, attempt = await asyncio.wait_for(queue.get(), timeout=1.0)
except asyncio.TimeoutError:
continue

if example.instance_id in results:
queue.task_done()
continue
print(f"Processing example {example.instance_id}")
process_result = await process_example(example, attempt, current_task)

result = await process_example(example, attempt)

if result is not None:
results[example.instance_id] = result
# If we're still processing this task (not requeued due to rate limiting)
if process_result is not None:
results[example.instance_id] = {"instance_id": example.instance_id, **process_result}
print(f"Processed example {example.instance_id}")
queue.task_done()

queue.task_done()
# If None is returned, the task was requeued due to rate limiting
# and this worker is being shut down, so exit the loop
else:
print(f"Task for {example.instance_id} has been requeued")
queue.task_done()
if current_task not in state["worker_tasks"]:
break

except asyncio.CancelledError:
# Handle graceful cancellation
print("Worker task cancelled")
break
except Exception as e:
print(f"Worker error: {str(e)}")
traceback.print_exc()
queue.task_done()

# Start workers
workers = [asyncio.create_task(worker()) for _ in range(num_workers)]
# Start initial workers
state["worker_tasks"] = [asyncio.create_task(worker()) for _ in range(num_workers)]

# Wait for queue to be fully processed
await queue.join()

# Cancel workers
for w in workers:
# Mark as not running and cancel remaining workers
state["running"] = False
for w in state["worker_tasks"]:
w.cancel()

# Wait for all workers to be cancelled
await asyncio.gather(*workers, return_exceptions=True)
await asyncio.gather(*state["worker_tasks"], return_exceptions=True)

# Return results in the same order as input examples
return [results[example.instance_id] for example in examples]
return [results.get(example.instance_id, {"instance_id": example.instance_id, "status": "missing"}) for example in examples]


def process_batch_local(examples: list[SweBenchExample], num_workers=5, codebases: dict[str, Codebase] = {}):
Expand Down Expand Up @@ -171,7 +290,7 @@ async def run_eval(
predictions_dir.mkdir(exist_ok=True, parents=True)

# Create a timestamp for this run
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
timestamp = time.strftime("%Y-%m-%d %H:%M %Z", time.localtime(time.time()))

# Process all examples in parallel batches
if local:
Expand Down
Loading