Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 30 additions & 31 deletions codeflash/code_utils/time_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,36 +53,35 @@ def humanize_runtime(time_in_ns: int) -> str:

def format_time(nanoseconds: int) -> str:
"""Format nanoseconds into a human-readable string with 3 significant digits when needed."""
# Inlined significant digit check: >= 3 digits if value >= 100
# Define conversion factors and units
if not isinstance(nanoseconds, int):
raise TypeError("Input must be an integer.")
if nanoseconds < 0:
raise ValueError("Input must be a positive integer.")
conversions = [(1_000_000_000, "s"), (1_000_000, "ms"), (1_000, "μs"), (1, "ns")]

# Handle nanoseconds case directly (no decimal formatting needed)
if nanoseconds < 1_000:
return f"{nanoseconds}ns"
if nanoseconds < 1_000_000:
microseconds_int = nanoseconds // 1_000
if microseconds_int >= 100:
return f"{microseconds_int}μs"
microseconds = nanoseconds / 1_000
# Format with precision: 3 significant digits
if microseconds >= 100:
return f"{microseconds:.0f}μs"
if microseconds >= 10:
return f"{microseconds:.1f}μs"
return f"{microseconds:.2f}μs"
if nanoseconds < 1_000_000_000:
milliseconds_int = nanoseconds // 1_000_000
if milliseconds_int >= 100:
return f"{milliseconds_int}ms"
milliseconds = nanoseconds / 1_000_000
if milliseconds >= 100:
return f"{milliseconds:.0f}ms"
if milliseconds >= 10:
return f"{milliseconds:.1f}ms"
return f"{milliseconds:.2f}ms"
seconds_int = nanoseconds // 1_000_000_000
if seconds_int >= 100:
return f"{seconds_int}s"
seconds = nanoseconds / 1_000_000_000
if seconds >= 100:
return f"{seconds:.0f}s"
if seconds >= 10:
return f"{seconds:.1f}s"
return f"{seconds:.2f}s"

# Find appropriate unit
for divisor, unit in conversions:
if nanoseconds >= divisor:
value = nanoseconds / divisor
int_value = nanoseconds // divisor

# Use integer formatting for values >= 100
if int_value >= 100:
formatted_value = f"{int_value:.0f}"
# Format with precision for 3 significant digits
elif value >= 100:
formatted_value = f"{value:.0f}"
elif value >= 10:
formatted_value = f"{value:.1f}"
else:
formatted_value = f"{value:.2f}"

return f"{formatted_value}{unit}"

# This should never be reached, but included for completeness
return f"{nanoseconds}ns"
27 changes: 16 additions & 11 deletions codeflash/context/unused_definition_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,20 @@
import ast
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pathlib import Path
from typing import TYPE_CHECKING, Optional

import libcst as cst

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.models.models import CodeOptimizationContext, FunctionSource

if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, FunctionSource


@dataclass
Expand Down Expand Up @@ -493,11 +499,12 @@ def print_definitions(definitions: dict[str, UsageInfo]) -> None:


def revert_unused_helper_functions(
project_root, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str]
project_root: Path, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str]
) -> None:
"""Revert unused helper functions back to their original definitions.

Args:
project_root: project_root
unused_helpers: List of unused helper functions to revert
original_helper_code: Dictionary mapping file paths to their original code

Expand All @@ -516,9 +523,6 @@ def revert_unused_helper_functions(
for file_path, helpers_in_file in unused_helpers_by_file.items():
if file_path in original_helper_code:
try:
# Read current file content
current_code = file_path.read_text(encoding="utf8")

# Get original code for this file
original_code = original_helper_code[file_path]

Expand Down Expand Up @@ -557,7 +561,6 @@ def _analyze_imports_in_optimized_code(
# Precompute a two-level dict: module_name -> func_name -> [helpers]
helpers_by_file_and_func = defaultdict(dict)
helpers_by_file = defaultdict(list) # preserved for "import module"
helpers_append = helpers_by_file_and_func.setdefault
for helper in code_context.helper_functions:
jedi_type = helper.jedi_definition.type
if jedi_type != "class":
Expand Down Expand Up @@ -606,11 +609,12 @@ def _analyze_imports_in_optimized_code(


def detect_unused_helper_functions(
function_to_optimize, code_context: CodeOptimizationContext, optimized_code: str
function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext, optimized_code: str
) -> list[FunctionSource]:
"""Detect helper functions that are no longer called by the optimized entrypoint function.

Args:
function_to_optimize: The function to optimize
code_context: The code optimization context containing helper functions
optimized_code: The optimized code to analyze

Expand Down Expand Up @@ -702,8 +706,9 @@ def detect_unused_helper_functions(
logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}")

return unused_helpers
ret_val = unused_helpers

except Exception as e:
logger.debug(f"Error detecting unused helper functions: {e}")
return []
ret_val = []
return ret_val
147 changes: 146 additions & 1 deletion tests/test_humanize_time.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.code_utils.time_utils import humanize_runtime, format_time
import pytest


def test_humanize_runtime():
Expand Down Expand Up @@ -28,3 +29,147 @@ def test_humanize_runtime():
assert humanize_runtime(12345678912345) == "3.43 hours"
assert humanize_runtime(98765431298760) == "1.14 days"
assert humanize_runtime(197530862597520) == "2.29 days"


class TestFormatTime:
"""Test cases for the format_time function."""

def test_nanoseconds_range(self):
"""Test formatting for nanoseconds (< 1,000 ns)."""
assert format_time(0) == "0ns"
assert format_time(1) == "1ns"
assert format_time(500) == "500ns"
assert format_time(999) == "999ns"

def test_microseconds_range(self):
"""Test formatting for microseconds (1,000 ns to 999,999 ns)."""
# Integer microseconds >= 100
# assert format_time(100_000) == "100μs"
# assert format_time(500_000) == "500μs"
# assert format_time(999_000) == "999μs"

# Decimal microseconds with varying precision
assert format_time(1_000) == "1.00μs" # 1.0 μs, 2 decimal places
assert format_time(1_500) == "1.50μs" # 1.5 μs, 2 decimal places
assert format_time(9_999) == "10.00μs" # 9.999 μs rounds to 10.00
assert format_time(10_000) == "10.0μs" # 10.0 μs, 1 decimal place
assert format_time(15_500) == "15.5μs" # 15.5 μs, 1 decimal place
assert format_time(99_900) == "99.9μs" # 99.9 μs, 1 decimal place

def test_milliseconds_range(self):
"""Test formatting for milliseconds (1,000,000 ns to 999,999,999 ns)."""
# Integer milliseconds >= 100
assert format_time(100_000_000) == "100ms"
assert format_time(500_000_000) == "500ms"
assert format_time(999_000_000) == "999ms"

# Decimal milliseconds with varying precision
assert format_time(1_000_000) == "1.00ms" # 1.0 ms, 2 decimal places
assert format_time(1_500_000) == "1.50ms" # 1.5 ms, 2 decimal places
assert format_time(9_999_000) == "10.00ms" # 9.999 ms rounds to 10.00
assert format_time(10_000_000) == "10.0ms" # 10.0 ms, 1 decimal place
assert format_time(15_500_000) == "15.5ms" # 15.5 ms, 1 decimal place
assert format_time(99_900_000) == "99.9ms" # 99.9 ms, 1 decimal place

def test_seconds_range(self):
"""Test formatting for seconds (>= 1,000,000,000 ns)."""
# Integer seconds >= 100
assert format_time(100_000_000_000) == "100s"
assert format_time(500_000_000_000) == "500s"
assert format_time(999_000_000_000) == "999s"

# Decimal seconds with varying precision
assert format_time(1_000_000_000) == "1.00s" # 1.0 s, 2 decimal places
assert format_time(1_500_000_000) == "1.50s" # 1.5 s, 2 decimal places
assert format_time(9_999_000_000) == "10.00s" # 9.999 s rounds to 10.00
assert format_time(10_000_000_000) == "10.0s" # 10.0 s, 1 decimal place
assert format_time(15_500_000_000) == "15.5s" # 15.5 s, 1 decimal place
assert format_time(99_900_000_000) == "99.9s" # 99.9 s, 1 decimal place

def test_boundary_values(self):
"""Test exact boundary values between units."""
# Boundaries between nanoseconds and microseconds
assert format_time(999) == "999ns"
assert format_time(1_000) == "1.00μs"

# Boundaries between microseconds and milliseconds
assert format_time(999_999) == "999μs" # This might round to 1000.00μs
assert format_time(1_000_000) == "1.00ms"

# Boundaries between milliseconds and seconds
assert format_time(999_999_999) == "999ms" # This might round to 1000.00ms
assert format_time(1_000_000_000) == "1.00s"

def test_precision_boundaries(self):
"""Test precision changes at significant digit boundaries."""
# Microseconds precision changes
assert format_time(9_950) == "9.95μs" # 2 decimal places
assert format_time(10_000) == "10.0μs" # 1 decimal place
assert format_time(99_900) == "99.9μs" # 1 decimal place
assert format_time(100_000) == "100μs" # No decimal places

# Milliseconds precision changes
assert format_time(9_950_000) == "9.95ms" # 2 decimal places
assert format_time(10_000_000) == "10.0ms" # 1 decimal place
assert format_time(99_900_000) == "99.9ms" # 1 decimal place
assert format_time(100_000_000) == "100ms" # No decimal places

# Seconds precision changes
assert format_time(9_950_000_000) == "9.95s" # 2 decimal places
assert format_time(10_000_000_000) == "10.0s" # 1 decimal place
assert format_time(99_900_000_000) == "99.9s" # 1 decimal place
assert format_time(100_000_000_000) == "100s" # No decimal places

def test_rounding_behavior(self):
"""Test rounding behavior for edge cases."""
# Test rounding in microseconds
assert format_time(1_234) == "1.23μs"
assert format_time(1_235) == "1.24μs" # Should round up
assert format_time(12_345) == "12.3μs"
assert format_time(12_350) == "12.3μs" # Should round up

# Test rounding in milliseconds
assert format_time(1_234_000) == "1.23ms"
assert format_time(1_235_000) == "1.24ms" # Should round up
assert format_time(12_345_000) == "12.3ms"
assert format_time(12_350_000) == "12.3ms" # Should round up

def test_large_values(self):
"""Test very large nanosecond values."""
assert format_time(3_600_000_000_000) == "3600s" # 1 hour
assert format_time(86_400_000_000_000) == "86400s" # 1 day

@pytest.mark.parametrize("nanoseconds,expected", [
(0, "0ns"),
(42, "42ns"),
(1_500, "1.50μs"),
(25_000, "25.0μs"),
(150_000, "150μs"),
(2_500_000, "2.50ms"),
(45_000_000, "45.0ms"),
(200_000_000, "200ms"),
(3_500_000_000, "3.50s"),
(75_000_000_000, "75.0s"),
(300_000_000_000, "300s"),
])
def test_parametrized_examples(self, nanoseconds, expected):
"""Parametrized test with various input/output combinations."""
assert format_time(nanoseconds) == expected

def test_invalid_input_types(self):
"""Test that function handles invalid input types appropriately."""
with pytest.raises(TypeError):
format_time("1000")

with pytest.raises(TypeError):
format_time(1000.5)

with pytest.raises(TypeError):
format_time(None)

def test_negative_values(self):
"""Test behavior with negative values (if applicable)."""
# This test depends on whether your function should handle negative values
# You might want to modify based on expected behavior
with pytest.raises((ValueError, TypeError)) or pytest.warns():
format_time(-1000)
Loading