Patch summary: replace the inline deque/flag bookkeeping in
FunctionOptimizer.determine_best_candidate with a queue-backed
CandidateProcessor helper class.
NOTE(review): line breaks, indentation and blank context lines were restored
from the hunk counts and Python syntax; confirm with `git apply --check`.

diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 45c61b7f7..ba1b79492 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -3,11 +3,12 @@
 import ast
 import concurrent.futures
 import os
+import queue
 import random
 import subprocess
 import time
 import uuid
-from collections import defaultdict, deque
+from collections import defaultdict
 from pathlib import Path
 from typing import TYPE_CHECKING
 
@@ -104,6 +105,83 @@
     from codeflash.verification.verification_utils import TestConfig
 
 
+class CandidateProcessor:
+    """Handles candidate processing using a queue-based approach."""
+
+    def __init__(
+        self,
+        initial_candidates: list,
+        future_line_profile_results: concurrent.futures.Future,
+        future_all_refinements: list,
+    ) -> None:
+        self.candidate_queue = queue.Queue()
+        self.line_profiler_done = False
+        self.refinement_done = False
+        self.candidate_len = len(initial_candidates)
+
+        # Initialize queue with initial candidates
+        for candidate in initial_candidates:
+            self.candidate_queue.put(candidate)
+
+        self.future_line_profile_results = future_line_profile_results
+        self.future_all_refinements = future_all_refinements
+
+    def get_next_candidate(self) -> OptimizedCandidate | None:
+        """Get the next candidate from the queue, handling async results as needed."""
+        try:
+            return self.candidate_queue.get_nowait()
+        except queue.Empty:
+            return self._handle_empty_queue()
+
+    def _handle_empty_queue(self) -> OptimizedCandidate | None:
+        """Handle empty queue by checking for pending async results."""
+        if not self.line_profiler_done:
+            return self._process_line_profiler_results()
+        if self.line_profiler_done and not self.refinement_done:
+            return self._process_refinement_results()
+        return None  # All done
+
+    def _process_line_profiler_results(self) -> OptimizedCandidate | None:
+        """Process line profiler results and add to queue."""
+        logger.debug("all candidates processed, await candidates from line profiler")
+        concurrent.futures.wait([self.future_line_profile_results])
+        line_profile_results = self.future_line_profile_results.result()
+
+        for candidate in line_profile_results:
+            self.candidate_queue.put(candidate)
+
+        self.candidate_len += len(line_profile_results)
+        logger.info(f"Added results from line profiler to candidates, total candidates now: {self.candidate_len}")
+        self.line_profiler_done = True
+
+        return self.get_next_candidate()
+
+    def _process_refinement_results(self) -> OptimizedCandidate | None:
+        """Process refinement results and add to queue."""
+        concurrent.futures.wait(self.future_all_refinements)
+        refinement_response = []
+
+        for future_refinement in self.future_all_refinements:
+            possible_refinement = future_refinement.result()
+            if len(possible_refinement) > 0:
+                refinement_response.append(possible_refinement[0])
+
+        for candidate in refinement_response:
+            self.candidate_queue.put(candidate)
+
+        self.candidate_len += len(refinement_response)
+        logger.info(
+            f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
+        )
+        self.refinement_done = True
+
+        return self.get_next_candidate()
+
+    def is_done(self) -> bool:
+        """Check if processing is complete."""
+        return self.line_profiler_done and self.refinement_done and self.candidate_queue.empty()
+
+
 class FunctionOptimizer:
     def __init__(
         self,
@@ -378,15 +456,13 @@ def determine_best_candidate(
             f"{self.function_to_optimize.qualified_name}…"
         )
         console.rule()
-        candidates = deque(candidates)
-        refinement_done = False
-        line_profiler_done = False
+
         future_all_refinements: list[concurrent.futures.Future] = []
         ast_code_to_id = {}
         valid_optimizations = []
         optimizations_post = {}  # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
-        # Start a new thread for AI service request, start loop in main thread
-        # check if aiservice request is complete, when it is complete, append result to the candidates list
+
+        # Start a new thread for AI service request
         ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
         future_line_profile_results = self.executor.submit(
             ai_service_client.optimize_python_code_line_profiler,
@@ -401,48 +477,23 @@ def determine_best_candidate(
             if self.experiment_id
             else None,
         )
+
+        # Initialize candidate processor
+        processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements)
         candidate_index = 0
-        original_len = len(candidates)
-        # TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
-        # TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
-        while True:
-            try:
-                if len(candidates) > 0:
-                    candidate = candidates.popleft()
-                else:
-                    if not line_profiler_done:
-                        logger.debug("all candidates processed, await candidates from line profiler")
-                        concurrent.futures.wait([future_line_profile_results])
-                        line_profile_results = future_line_profile_results.result()
-                        candidates.extend(line_profile_results)
-                        original_len += len(line_profile_results)
-                        logger.info(
-                            f"Added results from line profiler to candidates, total candidates now: {original_len}"
-                        )
-                        line_profiler_done = True
-                        continue
-                    if line_profiler_done and not refinement_done:
-                        concurrent.futures.wait(future_all_refinements)
-                        refinement_response = []
-                        for future_refinement in future_all_refinements:
-                            possible_refinement = future_refinement.result()
-                            if len(possible_refinement) > 0:  # if the api returns a valid response
-                                refinement_response.append(possible_refinement[0])
-                        candidates.extend(refinement_response)
-                        original_len += len(refinement_response)
-                        logger.info(
-                            f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
-                        )
-                        refinement_done = True
-                        continue
-                    if line_profiler_done and refinement_done:
-                        logger.debug("everything done, exiting")
-                        break
 
+        # Process candidates using queue-based approach
+        while not processor.is_done():
+            candidate = processor.get_next_candidate()
+            if candidate is None:
+                logger.debug("everything done, exiting")
+                break
+
+            try:
                 candidate_index += 1
                 get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
                 get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
-                logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
+                logger.info(f"Optimization candidate {candidate_index}/{processor.candidate_len}:")
                 code_print(candidate.source_code.flat)
                 # map ast normalized code to diff len, unnormalized code
                 # map opt id to the shortest unnormalized code
@@ -467,7 +518,7 @@ def determine_best_candidate(
                 # check if this code has been evaluated before by checking the ast normalized code string
                 normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
                 if normalized_code in ast_code_to_id:
-                    logger.warning(
+                    logger.info(
                         "Current candidate has been encountered before in testing, Skipping optimization candidate."
                     )
                     past_opt_id = ast_code_to_id[normalized_code]["optimization_id"]