diff --git a/codegen-examples/examples/swebench_agent_run/run_eval.py b/codegen-examples/examples/swebench_agent_run/run_eval.py
index ff349266c..65d4b1518 100644
--- a/codegen-examples/examples/swebench_agent_run/run_eval.py
+++ b/codegen-examples/examples/swebench_agent_run/run_eval.py
@@ -5,7 +5,7 @@
 import uuid
 import modal
 import click
-from datetime import datetime
+import time
 from codegen.extensions.swebench.harness import run_agent_on_entry
 from codegen.extensions.swebench.utils import SWEBenchDataset, SweBenchExample, get_swe_bench_examples
 from codegen.extensions.swebench.report import generate_report
@@ -17,28 +17,112 @@
 run_agent_modal = modal.Function.from_name(app_name="swebench-agent-run", name="run_agent_modal")


-async def process_batch_modal(examples: list[SweBenchExample], num_workers=10, max_retries=3):
-    """Process a batch of examples concurrently using a queue system.
+async def process_batch_modal(examples: list[SweBenchExample], num_workers=5, min_workers=1, max_retries=3):
+    """Process a batch of examples concurrently using a queue system with incremental worker scaling.

     Args:
         examples: List of SweBenchExample objects to process
-        num_workers: Number of examples to process concurrently
+        num_workers: Initial number of examples to process concurrently
+        min_workers: Minimum number of concurrent workers to maintain
         max_retries: Maximum number of retries for failed requests
     """
     results = {}
     queue = asyncio.Queue()

+    # Shared state for worker management
+    state = {
+        "active_workers": num_workers,
+        "success_streak": 0,
+        "last_scaling_time": time.time(),
+        "scaling_cooldown": 0,  # seconds between scaling operations
+        "worker_tasks": [],
+        "running": True,
+    }
+
+    # Use a lock to protect shared state during adjustments
+    state_lock = asyncio.Lock()
+
     # Initialize the queue with (example, attempt) tuples
     for example in examples:
         await queue.put((example, 0))  # 0 represents first attempt

-    async def process_example(example, attempt):
+    async def scale_down_worker(task_to_cancel=None):
+        """Remove a single worker when rate limiting is detected."""
+        async with state_lock:
+            # Only scale if the cooldown period has passed and we're above min_workers
+            current_time = time.time()
+            if current_time - state["last_scaling_time"] < state["scaling_cooldown"] or state["active_workers"] <= min_workers:
+                return False
+
+            # Reset the success streak when scaling down
+            state["success_streak"] = 0
+            state["last_scaling_time"] = current_time
+
+            # If a specific task was provided, cancel it
+            if task_to_cancel and task_to_cancel in state["worker_tasks"]:
+                print(f"Rate limiting detected! Removing 1 worker, going from {state['active_workers']} to {state['active_workers'] - 1}")
+                state["worker_tasks"].remove(task_to_cancel)
+                task_to_cancel.cancel()
+                state["active_workers"] -= 1
+                return True
+
+            # Otherwise, cancel the most recently added worker
+            elif state["worker_tasks"]:
+                print(f"Rate limiting detected! Removing 1 worker, going from {state['active_workers']} to {state['active_workers'] - 1}")
+                task = state["worker_tasks"].pop()
+                task.cancel()
+                state["active_workers"] -= 1
+                return True
+
+            return False
+
+    async def scale_up_worker():
+        """Add a single worker when operations have been consistently successful."""
+        async with state_lock:
+            # Only scale if the cooldown period has passed and we're below num_workers
+            current_time = time.time()
+            if current_time - state["last_scaling_time"] < state["scaling_cooldown"] or state["active_workers"] >= num_workers:
+                return False
+
+            # Add a worker after a streak of successful operations
+            if state["success_streak"] >= 5:
+                print(f"Operations succeeding! Adding 1 worker, going from {state['active_workers']} to {state['active_workers'] + 1}")
+
+                # Create a new worker
+                if state["running"]:
+                    new_task = asyncio.create_task(worker())
+                    state["worker_tasks"].append(new_task)
+                    state["active_workers"] += 1
+                    state["success_streak"] = 0
+                    state["last_scaling_time"] = current_time
+                    return True
+
+            return False
+
+    async def is_rate_limit_error(error):
+        """Determine whether an error is due to rate limiting."""
+        # Check for common rate limit error patterns
+        if isinstance(error, modal.exception.Error):
+            error_msg = str(error).lower()
+            rate_limit_indicators = ["rate limit", "too many requests", "429", "throttle", "quota exceeded", "capacity", "limit exceeded"]
+            return any(indicator in error_msg for indicator in rate_limit_indicators)
+        return False
+
+    async def process_example(example, attempt, current_task):
         try:
             result = await run_agent_modal.remote.aio(example)

             if result is None:
                 print(f"Warning: Null result for {example.instance_id}")
-                return {"instance_id": example.instance_id, "status": "error", "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}}
+                return {"status": "error", "instance_id": example.instance_id, "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}}
+
+            # Increment the success streak and potentially scale up
+            async with state_lock:
+                state["success_streak"] += 1
+
+            if state["success_streak"] % 5 == 0:  # Check after every 5 successes
+                await scale_up_worker()

             return result

         except Exception as e:
@@ -56,51 +140,86 @@ async def process_example(example, attempt):
             print(f"Error processing {example.instance_id} (attempt {attempt + 1}):")
             print(f"Type: {error_type}")
             print(f"Message: {str(e)}")
-            print("Traceback:")
-            print("".join(error_info["traceback"]))
+
+            # Check whether this is a rate limit error
+            if await is_rate_limit_error(e):
+                print(f"Rate limit detected on task for {example.instance_id}")
+
+                # Scale down by removing this specific worker
+                scaled_down = await scale_down_worker(current_task)
+
+                # If this worker is being removed, requeue the task for another worker
+                if scaled_down:
+                    # Requeue this example with the same attempt count (not incrementing)
+                    await queue.put((example, attempt))
+                    return None
+
+                # Otherwise add a small delay before retrying
+                await asyncio.sleep(2 * (attempt + 1))  # Linear backoff

             if attempt < max_retries:
                 await queue.put((example, attempt + 1))
                 return None

-            return {"instance_id": example.instance_id, "status": "error", "error_info": error_info}
+            return {"status": "error", "instance_id": example.instance_id, "error_info": error_info}

     async def worker():
-        while True:
+        # Store this task's reference to allow targeted cancellation
+        current_task = asyncio.current_task()
+
+        while state["running"]:
             try:
-                example, attempt = await queue.get()
+                # Use a timeout so the worker can periodically check whether it should exit
+                try:
+                    example, attempt = await asyncio.wait_for(queue.get(), timeout=1.0)
+                except asyncio.TimeoutError:
+                    continue

                 if example.instance_id in results:
                     queue.task_done()
                     continue

+                print(f"Processing example {example.instance_id}")
+                process_result = await process_example(example, attempt, current_task)

-                result = await process_example(example, attempt)
-
-                if result is not None:
-                    results[example.instance_id] = result
+                # If we're still processing this task (not requeued due to rate limiting)
+                if process_result is not None:
+                    results[example.instance_id] = {"instance_id": example.instance_id, **process_result}
+                    print(f"Processed example {example.instance_id}")
+                    queue.task_done()

-                queue.task_done()
+                # If None is returned, the task was requeued due to rate limiting
+                # and this worker is being shut down, so exit the loop
+                else:
+                    print(f"Task for {example.instance_id} has been requeued")
+                    queue.task_done()
+                    if current_task not in state["worker_tasks"]:
+                        break

+            except asyncio.CancelledError:
+                # Handle graceful cancellation
+                print("Worker task cancelled")
+                break
             except Exception as e:
                 print(f"Worker error: {str(e)}")
                 traceback.print_exc()
                 queue.task_done()

-    # Start workers
-    workers = [asyncio.create_task(worker()) for _ in range(num_workers)]
+    # Start the initial workers
+    state["worker_tasks"] = [asyncio.create_task(worker()) for _ in range(num_workers)]

     # Wait for queue to be fully processed
     await queue.join()

-    # Cancel workers
-    for w in workers:
+    # Mark as not running and cancel the remaining workers
+    state["running"] = False
+    for w in state["worker_tasks"]:
         w.cancel()

     # Wait for all workers to be cancelled
-    await asyncio.gather(*workers, return_exceptions=True)
+    await asyncio.gather(*state["worker_tasks"], return_exceptions=True)

     # Return results in the same order as input examples
-    return [results[example.instance_id] for example in examples]
+    return [results.get(example.instance_id, {"instance_id": example.instance_id, "status": "missing"}) for example in examples]


 def process_batch_local(examples: list[SweBenchExample], num_workers=5, codebases: dict[str, Codebase] = {}):
@@ -171,7 +290,7 @@ async def run_eval(
     predictions_dir.mkdir(exist_ok=True, parents=True)

     # Create a timestamp for this run
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    timestamp = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())

     # Process all examples in parallel batches
     if local:
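
A quick usage sketch for review purposes, not part of the patch: driving the new process_batch_modal from an async entrypoint. The dataset argument to get_swe_bench_examples is an assumption, as is the SWEBenchDataset.LITE member; check the real signature in codegen.extensions.swebench.utils before running.

import asyncio

from run_eval import process_batch_modal  # assumes run_eval.py is importable as a module
from codegen.extensions.swebench.utils import SWEBenchDataset, get_swe_bench_examples


async def main():
    # Hypothetical call; verify get_swe_bench_examples's actual parameters.
    examples = get_swe_bench_examples(dataset=SWEBenchDataset.LITE)

    # Start with 5 workers; scale down to 1 under rate limiting, retry failures up to 3 times.
    results = await process_batch_modal(examples, num_workers=5, min_workers=1, max_retries=3)

    failed = [r for r in results if r.get("status") in ("error", "missing")]
    print(f"{len(results) - len(failed)}/{len(results)} examples completed successfully")


if __name__ == "__main__":
    asyncio.run(main())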
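
On the design choice: shedding load by cancelling worker tasks works, but a cancellation can land mid-flight in a Modal call, which is why the requeue and task_done bookkeeping above needs care. An alternative sketch, explicitly not what this patch does and with all names hypothetical: keep the worker count fixed and throttle through an adjustable semaphore-style limiter, so scaling down never cancels a task, it only retires a permit once its current holder finishes.

import asyncio


class AdjustableLimiter:
    """Semaphore-like limiter whose effective concurrency can change at runtime.

    Illustrative only; not part of the patch.
    """

    def __init__(self, limit: int):
        self._sem = asyncio.Semaphore(limit)
        self.limit = limit

    async def __aenter__(self):
        await self._sem.acquire()

    async def __aexit__(self, exc_type, exc, tb):
        self._sem.release()

    def scale_up(self):
        # Releasing beyond the initial count adds a permit; asyncio.Semaphore allows this.
        self._sem.release()
        self.limit += 1

    async def scale_down(self):
        # Permanently consume a permit; takes effect as soon as one frees up.
        await self._sem.acquire()
        self.limit -= 1


# Workers would then wrap each remote call instead of being cancelled:
#     async with limiter:
#         result = await run_agent_modal.remote.aio(example)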