diff --git a/codeflash/code_utils/checkpoint.py b/codeflash/code_utils/checkpoint.py new file mode 100644 index 000000000..c924665d7 --- /dev/null +++ b/codeflash/code_utils/checkpoint.py @@ -0,0 +1,143 @@ +import argparse +import datetime +import json +import sys +import time +import uuid +from pathlib import Path +from typing import Any, Optional + +import click + + +class CodeflashRunCheckpoint: + def __init__(self, module_root: Path, checkpoint_dir: Path = Path("/tmp")) -> None: + self.module_root = module_root + self.checkpoint_dir = Path(checkpoint_dir) + # Create a unique checkpoint file name + unique_id = str(uuid.uuid4())[:8] + checkpoint_filename = f"codeflash_checkpoint_{unique_id}.jsonl" + self.checkpoint_path = self.checkpoint_dir / checkpoint_filename + + # Initialize the checkpoint file with metadata + self._initialize_checkpoint_file() + + def _initialize_checkpoint_file(self) -> None: + """Create a new checkpoint file with metadata.""" + metadata = { + "type": "metadata", + "module_root": str(self.module_root), + "created_at": time.time(), + "last_updated": time.time(), + } + + with open(self.checkpoint_path, "w") as f: + f.write(json.dumps(metadata) + "\n") + + def add_function_to_checkpoint( + self, + function_fully_qualified_name: str, + status: str = "optimized", + additional_info: Optional[dict[str, Any]] = None, + ) -> None: + """Add a function to the checkpoint after it has been processed. 
+ + Args: + function_fully_qualified_name: The fully qualified name of the function + status: Status of optimization (e.g., "optimized", "failed", "skipped") + additional_info: Any additional information to store about the function + + """ + if additional_info is None: + additional_info = {} + + function_data = { + "type": "function", + "function_name": function_fully_qualified_name, + "status": status, + "timestamp": time.time(), + **additional_info, + } + + with open(self.checkpoint_path, "a") as f: + f.write(json.dumps(function_data) + "\n") + + # Update the metadata last_updated timestamp + self._update_metadata_timestamp() + + def _update_metadata_timestamp(self) -> None: + """Update the last_updated timestamp in the metadata.""" + # Read the first line (metadata) + with self.checkpoint_path.open() as f: + metadata = json.loads(f.readline()) + rest_content = f.read() + + # Update the timestamp + metadata["last_updated"] = time.time() + + # Write all lines to a temporary file + + with self.checkpoint_path.open("w") as f: + f.write(json.dumps(metadata) + "\n") + f.write(rest_content) + + def cleanup(self) -> None: + """Unlink all the checkpoint files for this module_root.""" + to_delete = [] + self.checkpoint_path.unlink(missing_ok=True) + + for file in self.checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"): + with file.open() as f: + # Skip the first line (metadata) + first_line = next(f) + metadata = json.loads(first_line) + if metadata.get("module_root", str(self.module_root)) == str(self.module_root): + to_delete.append(file) + for file in to_delete: + file.unlink(missing_ok=True) + + +def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dict[str, dict[str, str]]: + """Get information about all processed functions, regardless of status. 
+ + Returns: + Dictionary mapping function names to their processing information + + """ + processed_functions = {} + to_delete = [] + + for file in checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"): + with file.open() as f: + # Skip the first line (metadata) + first_line = next(f) + metadata = json.loads(first_line) + if metadata.get("last_updated"): + last_updated = datetime.datetime.fromtimestamp(metadata["last_updated"]) + if datetime.datetime.now() - last_updated >= datetime.timedelta(days=7): + to_delete.append(file) + continue + if metadata.get("module_root") != str(module_root): + continue + + for line in f: + entry = json.loads(line) + if entry.get("type") == "function": + processed_functions[entry["function_name"]] = entry + for file in to_delete: + file.unlink(missing_ok=True) + return processed_functions + + +def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> Optional[dict[str, dict[str, str]]]: + previous_checkpoint_functions = None + if args.all and (sys.platform == "linux" or sys.platform == "darwin") and Path("/tmp").is_dir(): + previous_checkpoint_functions = get_all_historical_functions(args.module_root, Path("/tmp")) + if previous_checkpoint_functions and click.confirm( + "Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?", + default=True, + ): + pass + else: + previous_checkpoint_functions = None + return previous_checkpoint_functions diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 8adfb4e00..15c5c6e0d 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -8,7 +8,7 @@ from collections import defaultdict from functools import cache from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional import git import libcst as cst @@ -145,6 +145,7 @@ def qualified_name(self) -> str: def 
qualified_name_with_modules_from_root(self, project_root_path: Path) -> str: return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}" + def get_functions_to_optimize( optimize_all: str | None, replay_test: str | None, @@ -154,10 +155,11 @@ def get_functions_to_optimize( ignore_paths: list[Path], project_root: Path, module_root: Path, + previous_checkpoint_functions: dict[str, dict[str, str]] | None = None, ) -> tuple[dict[Path, list[FunctionToOptimize]], int]: - assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, ( - "Only one of optimize_all, replay_test, or file should be provided" - ) + assert ( + sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1 + ), "Only one of optimize_all, replay_test, or file should be provided" functions: dict[str, list[FunctionToOptimize]] with warnings.catch_warnings(): warnings.simplefilter(action="ignore", category=SyntaxWarning) @@ -198,7 +200,7 @@ def get_functions_to_optimize( ph("cli-optimizing-git-diff") functions = get_functions_within_git_diff() filtered_modified_functions, functions_count = filter_functions( - functions, test_cfg.tests_root, ignore_paths, project_root, module_root + functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions ) logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize") return filtered_modified_functions, functions_count @@ -414,6 +416,7 @@ def filter_functions( ignore_paths: list[Path], project_root: Path, module_root: Path, + previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None, disable_logs: bool = False, ) -> tuple[dict[Path, list[FunctionToOptimize]], int]: blocklist_funcs = get_blocklisted_functions() @@ -430,13 +433,16 @@ def filter_functions( ignore_paths_removed_count: int = 0 malformed_paths_count: int = 0 submodule_ignored_paths_count: int = 0 + blocklist_funcs_removed_count: int = 0 + 
previous_checkpoint_functions_removed_count: int = 0 tests_root_str = str(tests_root) module_root_str = str(module_root) # We desperately need Python 3.10+ only support to make this code readable with structural pattern matching for file_path_path, functions in modified_functions.items(): + _functions = functions file_path = str(file_path_path) if file_path.startswith(tests_root_str + os.sep): - test_functions_removed_count += len(functions) + test_functions_removed_count += len(_functions) continue if file_path in ignore_paths or any( file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths @@ -449,10 +455,10 @@ def filter_functions( submodule_ignored_paths_count += 1 continue if path_belongs_to_site_packages(Path(file_path)): - site_packages_removed_count += len(functions) + site_packages_removed_count += len(_functions) continue if not file_path.startswith(module_root_str + os.sep): - non_modules_removed_count += len(functions) + non_modules_removed_count += len(_functions) continue try: ast.parse(f"import {module_name_from_file_path(Path(file_path), project_root)}") @@ -460,16 +466,28 @@ def filter_functions( malformed_paths_count += 1 continue if blocklist_funcs: - functions = [ - function - for function in functions + functions_tmp = [] + for function in _functions: - if not ( + # Skip (and count) functions that ARE blocklisted; keep the rest. + # The previous revision of this hunk inverted the condition and kept + # only blocklisted functions — mirror the checkpoint filter below. + if ( function.file_path.name in blocklist_funcs and function.qualified_name in blocklist_funcs[function.file_path.name] - ) - ] - filtered_modified_functions[file_path] = functions - functions_count += len(functions) + ): + blocklist_funcs_removed_count += 1 + continue + functions_tmp.append(function) + _functions = functions_tmp + + if previous_checkpoint_functions: + functions_tmp = [] + for function in _functions: + if function.qualified_name_with_modules_from_root(project_root) in previous_checkpoint_functions: + previous_checkpoint_functions_removed_count += 1 + continue + functions_tmp.append(function) + _functions = functions_tmp + + 
filtered_modified_functions[file_path] = _functions + functions_count += len(_functions) if not disable_logs: log_info = { @@ -479,6 +497,8 @@ def filter_functions( f"{non_modules_removed_count} function{'s' if non_modules_removed_count != 1 else ''} outside module-root": non_modules_removed_count, f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count, f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count, + f"{blocklist_funcs_removed_count} function{'s' if blocklist_funcs_removed_count != 1 else ''} as previously optimized": blocklist_funcs_removed_count, + f"{previous_checkpoint_functions_removed_count} function{'s' if previous_checkpoint_functions_removed_count != 1 else ''} skipped from checkpoint": previous_checkpoint_functions_removed_count, } log_string = "\n".join([k for k, v in log_info.items() if v > 0]) if log_string: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 6ad25bc0e..786c4afb4 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -242,7 +242,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: # request for new optimizations but don't block execution, check for completion later # adding to control and experiment set but with same traceid best_optimization = None - for _u, (candidates, exp_type) in enumerate(zip([optimizations_set.control, optimizations_set.experiment],["EXP0","EXP1"])): + for _u, (candidates, exp_type) in enumerate( + zip([optimizations_set.control, optimizations_set.experiment], ["EXP0", "EXP1"]) + ): if candidates is None: continue @@ -254,7 +256,14 @@ def optimize_function(self) -> Result[BestOptimization, str]: file_path_to_helper_classes=file_path_to_helper_classes, exp_type=exp_type, ) - ph("cli-optimize-function-finished", 
{"function_trace_id": self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id}) + ph( + "cli-optimize-function-finished", + { + "function_trace_id": self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id + }, + ) generated_tests = remove_functions_from_generated_tests( generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove @@ -324,7 +333,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: explanation=explanation, existing_tests_source=existing_tests, generated_original_test_source=generated_tests_str, - function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, + function_trace_id=self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id, coverage_message=coverage_message, git_remote=self.args.git_remote, ) @@ -379,7 +390,7 @@ def determine_best_candidate( # Start a new thread for AI service request, start loop in main thread # check if aiservice request is complete, when it is complete, append result to the candidates list with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: - ai_service_client = self.aiservice_client if exp_type=="EXP0" else self.local_aiservice_client + ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client future_line_profile_results = executor.submit( ai_service_client.optimize_python_code_line_profiler, source_code=code_context.read_writable_code, @@ -387,7 +398,11 @@ def determine_best_candidate( trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, line_profiler_results=original_code_baseline.line_profile_results["str_out"], num_candidates=10, - experiment_metadata=ExperimentMetadata(id=self.experiment_id, group= "control" if exp_type == "EXP0" else "experiment") if self.experiment_id else None, + experiment_metadata=ExperimentMetadata( + 
id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment" + ) + if self.experiment_id + else None, ) try: candidate_index = 0 @@ -462,7 +477,7 @@ def determine_best_candidate( f"loop{'s' if candidate_result.max_loop_count > 1 else ''})" ) tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") - tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X") + tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") replay_perf_gain = {} if self.args.benchmark: test_results_by_benchmark = ( @@ -528,7 +543,9 @@ def determine_best_candidate( ) return best_optimization - def log_successful_optimization(self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str) -> None: + def log_successful_optimization( + self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str + ) -> None: explanation_panel = Panel( f"⚡️ Optimization successful! 📄 {self.function_to_optimize.qualified_name} in {explanation.file_path}\n" f"📈 {explanation.perf_improvement_line}\n" @@ -555,7 +572,9 @@ def log_successful_optimization(self, explanation: Explanation, generated_tests: ph( "cli-optimize-success", { - "function_trace_id": self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, + "function_trace_id": self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id, "speedup_x": explanation.speedup_x, "speedup_pct": explanation.speedup_pct, "best_runtime": explanation.best_runtime_ns, diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 946e3e822..1e1f98435 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -17,6 +17,7 @@ from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table from codeflash.cli_cmds.console import console, logger, progress_bar from codeflash.code_utils import env_utils +from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint, 
ask_should_use_checkpoint_get_functions from codeflash.code_utils.code_replacer import normalize_code, normalize_node from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast @@ -52,6 +53,8 @@ def __init__(self, args: Namespace) -> None: self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None) self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None self.replay_tests_dir = None + self.functions_checkpoint: CodeflashRunCheckpoint | None = None + def create_function_optimizer( self, function_to_optimize: FunctionToOptimize, @@ -71,7 +74,7 @@ def create_function_optimizer( args=self.args, function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None, total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None, - replay_tests_dir = self.replay_tests_dir + replay_tests_dir=self.replay_tests_dir, ) def run(self) -> None: @@ -83,7 +86,7 @@ def run(self) -> None: function_optimizer = None file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] num_optimizable_functions: int - + previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(self.args) # discover functions (file_to_funcs_to_optimize, num_optimizable_functions) = get_functions_to_optimize( optimize_all=self.args.all, @@ -94,14 +97,12 @@ def run(self) -> None: ignore_paths=self.args.ignore_paths, project_root=self.args.project_root, module_root=self.args.module_root, + previous_checkpoint_functions=previous_checkpoint_functions, ) function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {} total_benchmark_timings: dict[BenchmarkKey, int] = {} if self.args.benchmark and num_optimizable_functions > 0: - with progress_bar( - f"Running benchmarks in {self.args.benchmarks_root}", - transient=True, - ): + with progress_bar(f"Running benchmarks in {self.args.benchmarks_root}", 
transient=True): # Insert decorator file_path_to_source_code = defaultdict(str) for file in file_to_funcs_to_optimize: @@ -113,15 +114,23 @@ def run(self) -> None: if trace_file.exists(): trace_file.unlink() - self.replay_tests_dir = Path(tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root)) - trace_benchmarks_pytest(self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file) # Run all tests that use pytest-benchmark + self.replay_tests_dir = Path( + tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root) + ) + trace_benchmarks_pytest( + self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file + ) # Run all tests that use pytest-benchmark replay_count = generate_replay_test(trace_file, self.replay_tests_dir) if replay_count == 0: - logger.info(f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization") + logger.info( + f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization" + ) else: function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file) total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file) - function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings) + function_to_results = validate_and_format_benchmark_table( + function_benchmark_timings, total_benchmark_timings + ) print_benchmark_table(function_to_results) except Exception as e: logger.info(f"Error while tracing existing benchmarks: {e}") @@ -148,10 +157,13 @@ def run(self) -> None: function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg) num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()]) console.rule() - logger.info(f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s 
at {self.test_cfg.tests_root}") + logger.info( + f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}" + ) console.rule() ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests}) - + if self.args.all: + self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root) for original_module_path in file_to_funcs_to_optimize: logger.info(f"Examining file {original_module_path!s}…") @@ -212,17 +224,33 @@ def run(self) -> None: qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root( self.args.project_root ) - if self.args.benchmark and function_benchmark_timings and qualified_name_w_module in function_benchmark_timings and total_benchmark_timings: + if ( + self.args.benchmark + and function_benchmark_timings + and qualified_name_w_module in function_benchmark_timings + and total_benchmark_timings + ): function_optimizer = self.create_function_optimizer( - function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings[qualified_name_w_module], total_benchmark_timings + function_to_optimize, + function_to_optimize_ast, + function_to_tests, + validated_original_code[original_module_path].source_code, + function_benchmark_timings[qualified_name_w_module], + total_benchmark_timings, ) else: function_optimizer = self.create_function_optimizer( - function_to_optimize, function_to_optimize_ast, function_to_tests, - validated_original_code[original_module_path].source_code + function_to_optimize, + function_to_optimize_ast, + function_to_tests, + validated_original_code[original_module_path].source_code, ) best_optimization = function_optimizer.optimize_function() + if self.functions_checkpoint: + self.functions_checkpoint.add_function_to_checkpoint( + function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root) + ) if 
is_successful(best_optimization): optimizations_found += 1 else: @@ -230,6 +258,8 @@ def run(self) -> None: console.rule() continue ph("cli-optimize-run-finished", {"optimizations_found": optimizations_found}) + if self.functions_checkpoint: + self.functions_checkpoint.cleanup() if optimizations_found == 0: logger.info("❌ No optimizations found.") elif self.args.all: diff --git a/tests/test_codeflash_checkpoint.py b/tests/test_codeflash_checkpoint.py new file mode 100644 index 000000000..b9770b676 --- /dev/null +++ b/tests/test_codeflash_checkpoint.py @@ -0,0 +1,176 @@ +import json +import tempfile +from pathlib import Path + +import pytest +from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint, get_all_historical_functions + + +class TestCodeflashRunCheckpoint: + @pytest.fixture + def temp_dir(self): + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + def test_initialization(self, temp_dir): + module_root = Path("/fake/module/root") + checkpoint = CodeflashRunCheckpoint(module_root, checkpoint_dir=temp_dir) + + # Check if checkpoint file was created + assert checkpoint.checkpoint_path.exists() + + # Check if metadata was written correctly + with open(checkpoint.checkpoint_path) as f: + metadata = json.loads(f.readline()) + assert metadata["type"] == "metadata" + assert metadata["module_root"] == str(module_root) + assert "created_at" in metadata + assert "last_updated" in metadata + + def test_add_function_to_checkpoint(self, temp_dir): + module_root = Path("/fake/module/root") + checkpoint = CodeflashRunCheckpoint(module_root, checkpoint_dir=temp_dir) + + # Add a function to the checkpoint + function_name = "module.submodule.function" + checkpoint.add_function_to_checkpoint(function_name, status="optimized") + + # Read the checkpoint file and verify + with open(checkpoint.checkpoint_path) as f: + lines = f.readlines() + assert len(lines) == 2 # Metadata + function entry + + function_data = json.loads(lines[1]) + assert 
function_data["type"] == "function" + assert function_data["function_name"] == function_name + assert function_data["status"] == "optimized" + assert "timestamp" in function_data + + def test_add_function_with_additional_info(self, temp_dir): + module_root = Path("/fake/module/root") + checkpoint = CodeflashRunCheckpoint(module_root, checkpoint_dir=temp_dir) + + # Add a function with additional info + function_name = "module.submodule.function" + additional_info = {"execution_time": 1.5, "memory_usage": "10MB"} + checkpoint.add_function_to_checkpoint(function_name, status="optimized", additional_info=additional_info) + + # Read the checkpoint file and verify + with open(checkpoint.checkpoint_path) as f: + lines = f.readlines() + function_data = json.loads(lines[1]) + assert function_data["execution_time"] == 1.5 + assert function_data["memory_usage"] == "10MB" + + def test_update_metadata_timestamp(self, temp_dir): + module_root = Path("/fake/module/root") + checkpoint = CodeflashRunCheckpoint(module_root, checkpoint_dir=temp_dir) + + # Get initial timestamp + with open(checkpoint.checkpoint_path) as f: + initial_metadata = json.loads(f.readline()) + initial_timestamp = initial_metadata["last_updated"] + + # Wait a bit to ensure timestamp changes + import time + + time.sleep(0.01) + + # Update timestamp + checkpoint._update_metadata_timestamp() + + # Check if timestamp was updated + with open(checkpoint.checkpoint_path) as f: + updated_metadata = json.loads(f.readline()) + updated_timestamp = updated_metadata["last_updated"] + + assert updated_timestamp > initial_timestamp + + def test_cleanup(self, temp_dir): + module_root = Path("/fake/module/root") + + # Create multiple checkpoint files + checkpoint1 = CodeflashRunCheckpoint(module_root, checkpoint_dir=temp_dir) + checkpoint2 = CodeflashRunCheckpoint(module_root, checkpoint_dir=temp_dir) + + # Create a checkpoint for a different module + different_module = Path("/different/module") + checkpoint3 = 
CodeflashRunCheckpoint(different_module, checkpoint_dir=temp_dir) + + # Verify all files exist + assert checkpoint1.checkpoint_path.exists() + assert checkpoint2.checkpoint_path.exists() + assert checkpoint3.checkpoint_path.exists() + + # Clean up files for module_root + checkpoint1.cleanup() + + # Check that only the files for module_root were deleted + assert not checkpoint1.checkpoint_path.exists() + assert not checkpoint2.checkpoint_path.exists() + assert checkpoint3.checkpoint_path.exists() + + +class TestGetAllHistoricalFunctions: + @pytest.fixture + def setup_checkpoint_files(self): + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir_path = Path(temp_dir) + module_root = Path("/fake/module/root") + + # Create a checkpoint file with some functions + checkpoint = CodeflashRunCheckpoint(module_root, checkpoint_dir=temp_dir_path) + checkpoint.add_function_to_checkpoint("module.func1", status="optimized") + checkpoint.add_function_to_checkpoint("module.func2", status="failed") + + # Create an old checkpoint file (more than 7 days old) + old_checkpoint_path = temp_dir_path / "codeflash_checkpoint_old.jsonl" + with open(old_checkpoint_path, "w") as f: + # Create metadata with old timestamp (8 days ago) + import time + + old_time = time.time() - (8 * 24 * 60 * 60) + metadata = { + "type": "metadata", + "module_root": str(module_root), + "created_at": old_time, + "last_updated": old_time, + } + f.write(json.dumps(metadata) + "\n") + + # Add a function entry + function_data = { + "type": "function", + "function_name": "module.old_func", + "status": "optimized", + "timestamp": old_time, + } + f.write(json.dumps(function_data) + "\n") + + # Create a checkpoint for a different module + different_module = Path("/different/module") + diff_checkpoint = CodeflashRunCheckpoint(different_module, checkpoint_dir=temp_dir_path) + diff_checkpoint.add_function_to_checkpoint("different.func", status="optimized") + + yield module_root, temp_dir_path + + def 
test_get_all_historical_functions(self, setup_checkpoint_files): + module_root, checkpoint_dir = setup_checkpoint_files + + # Get historical functions + functions = get_all_historical_functions(module_root, checkpoint_dir) + + # Verify the functions from the current checkpoint are included + assert "module.func1" in functions + assert "module.func2" in functions + assert functions["module.func1"]["status"] == "optimized" + assert functions["module.func2"]["status"] == "failed" + + # Verify the old function is not included (file should be deleted) + assert "module.old_func" not in functions + + # Verify the function from the different module is not included + assert "different.func" not in functions + + # Verify the old checkpoint file was deleted + assert not (checkpoint_dir / "codeflash_checkpoint_old.jsonl").exists()