Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 95 additions & 44 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import ast
import concurrent.futures
import os
import queue
import random
import subprocess
import time
import uuid
from collections import defaultdict, deque
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -104,6 +105,83 @@
from codeflash.verification.verification_utils import TestConfig


class CandidateProcessor:
"""Handles candidate processing using a queue-based approach."""

def __init__(
self,
initial_candidates: list,
future_line_profile_results: concurrent.futures.Future,
future_all_refinements: list,
) -> None:
self.candidate_queue = queue.Queue()
self.line_profiler_done = False
self.refinement_done = False
self.candidate_len = len(initial_candidates)

# Initialize queue with initial candidates
for candidate in initial_candidates:
self.candidate_queue.put(candidate)

self.future_line_profile_results = future_line_profile_results
self.future_all_refinements = future_all_refinements

def get_next_candidate(self) -> OptimizedCandidate | None:
"""Get the next candidate from the queue, handling async results as needed."""
try:
return self.candidate_queue.get_nowait()
except queue.Empty:
return self._handle_empty_queue()

def _handle_empty_queue(self) -> OptimizedCandidate | None:
"""Handle empty queue by checking for pending async results."""
if not self.line_profiler_done:
return self._process_line_profiler_results()
if self.line_profiler_done and not self.refinement_done:
return self._process_refinement_results()
return None # All done

def _process_line_profiler_results(self) -> OptimizedCandidate | None:
"""Process line profiler results and add to queue."""
logger.debug("all candidates processed, await candidates from line profiler")
concurrent.futures.wait([self.future_line_profile_results])
line_profile_results = self.future_line_profile_results.result()

for candidate in line_profile_results:
self.candidate_queue.put(candidate)

self.candidate_len += len(line_profile_results)
logger.info(f"Added results from line profiler to candidates, total candidates now: {self.candidate_len}")
self.line_profiler_done = True

return self.get_next_candidate()

def _process_refinement_results(self) -> OptimizedCandidate | None:
"""Process refinement results and add to queue."""
concurrent.futures.wait(self.future_all_refinements)
refinement_response = []

for future_refinement in self.future_all_refinements:
possible_refinement = future_refinement.result()
if len(possible_refinement) > 0:
refinement_response.append(possible_refinement[0])

for candidate in refinement_response:
self.candidate_queue.put(candidate)

self.candidate_len += len(refinement_response)
logger.info(
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
)
self.refinement_done = True

return self.get_next_candidate()

def is_done(self) -> bool:
"""Check if processing is complete."""
return self.line_profiler_done and self.refinement_done and self.candidate_queue.empty()


class FunctionOptimizer:
def __init__(
self,
Expand Down Expand Up @@ -378,15 +456,13 @@ def determine_best_candidate(
f"{self.function_to_optimize.qualified_name}…"
)
console.rule()
candidates = deque(candidates)
refinement_done = False
line_profiler_done = False

future_all_refinements: list[concurrent.futures.Future] = []
ast_code_to_id = {}
valid_optimizations = []
optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
# Start a new thread for AI service request, start loop in main thread
# check if aiservice request is complete, when it is complete, append result to the candidates list

# Start a new thread for AI service request
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
future_line_profile_results = self.executor.submit(
ai_service_client.optimize_python_code_line_profiler,
Expand All @@ -401,48 +477,23 @@ def determine_best_candidate(
if self.experiment_id
else None,
)

# Initialize candidate processor
processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements)
candidate_index = 0
original_len = len(candidates)
# TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
# TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
while True:
try:
if len(candidates) > 0:
candidate = candidates.popleft()
else:
if not line_profiler_done:
logger.debug("all candidates processed, await candidates from line profiler")
concurrent.futures.wait([future_line_profile_results])
line_profile_results = future_line_profile_results.result()
candidates.extend(line_profile_results)
original_len += len(line_profile_results)
logger.info(
f"Added results from line profiler to candidates, total candidates now: {original_len}"
)
line_profiler_done = True
continue
if line_profiler_done and not refinement_done:
concurrent.futures.wait(future_all_refinements)
refinement_response = []
for future_refinement in future_all_refinements:
possible_refinement = future_refinement.result()
if len(possible_refinement) > 0: # if the api returns a valid response
refinement_response.append(possible_refinement[0])
candidates.extend(refinement_response)
original_len += len(refinement_response)
logger.info(
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
)
refinement_done = True
continue
if line_profiler_done and refinement_done:
logger.debug("everything done, exiting")
break

# Process candidates using queue-based approach
while not processor.is_done():
candidate = processor.get_next_candidate()
if candidate is None:
logger.debug("everything done, exiting")
break

try:
candidate_index += 1
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
logger.info(f"Optimization candidate {candidate_index}/{processor.candidate_len}:")
code_print(candidate.source_code.flat)
# map ast normalized code to diff len, unnormalized code
# map opt id to the shortest unnormalized code
Expand All @@ -467,7 +518,7 @@ def determine_best_candidate(
# check if this code has been evaluated before by checking the ast normalized code string
normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
if normalized_code in ast_code_to_id:
logger.warning(
logger.info(
"Current candidate has been encountered before in testing, Skipping optimization candidate."
)
past_opt_id = ast_code_to_id[normalized_code]["optimization_id"]
Expand Down
Loading