Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.11.0"
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --config=pyproject.toml]
- id: ruff-format
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.7
hooks:
# Run the linter.
- id: ruff-check
# Run the formatter.
- id: ruff-format
6 changes: 4 additions & 2 deletions codeflash/benchmarking/instrument_codeflash_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ def leave_ClassDef(
def visit_ClassDef(self, node: ClassDef) -> Optional[bool]:
if self.class_name: # Don't go into nested class
return False
self.class_name = node.name.value # noqa: RET503
self.class_name = node.name.value
return None

def visit_FunctionDef(self, node: FunctionDef) -> Optional[bool]:
if self.function_name: # Don't go into nested function
return False
self.function_name = node.name.value # noqa: RET503
self.function_name = node.name.value
return None

def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> FunctionDef:
if self.function_name == original_node.name.value:
Expand Down
2 changes: 1 addition & 1 deletion codeflash/cli_cmds/cmd_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from argparse import Namespace

CODEFLASH_LOGO: str = (
f"{LF}" # noqa: ISC003
f"{LF}"
r" _ ___ _ _ " + f"{LF}"
r" | | / __)| | | | " + f"{LF}"
r" ____ ___ _ | | ____ | |__ | | ____ ___ | | _ " + f"{LF}"
Expand Down
2 changes: 1 addition & 1 deletion codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def unique_invocation_loop_id(self) -> str:
return f"{self.loop_index}:{self.id.id()}"


class TestResults(BaseModel):
class TestResults(BaseModel): # noqa: PLW1641
# don't modify these directly, use the add method
# also we don't support deletion of test results elements - caution is advised
test_results: list[FunctionTestInvocation] = []
Expand Down
11 changes: 5 additions & 6 deletions codeflash/result/critic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from codeflash.cli_cmds.console import logger
from codeflash.code_utils import env_utils
Expand Down Expand Up @@ -29,7 +29,8 @@ def speedup_critic(
candidate_result: OptimizedCandidateResult,
original_code_runtime: int,
best_runtime_until_now: int | None,
disable_gh_action_noise: Optional[bool] = None,
*,
disable_gh_action_noise: bool = False,
) -> bool:
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.

Expand All @@ -39,10 +40,8 @@ def speedup_critic(
The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance, also we want to be more confident there.
"""
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
if not disable_gh_action_noise:
in_github_actions_mode = bool(env_utils.get_pr_number())
if in_github_actions_mode:
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
if not disable_gh_action_noise and env_utils.is_ci():
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode

perf_gain = performance_gain(
original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime
Expand Down
2 changes: 1 addition & 1 deletion codeflash/verification/parse_test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def parse_test_xml(
groups = match.groups()
if len(groups[5].split(":")) > 1:
iteration_id = groups[5].split(":")[0]
groups = groups[:5] + (iteration_id,)
groups = (*groups[:5], iteration_id)
end_matches[groups] = match

if not begin_matches or not begin_matches:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ ignore = [
"D104",
"PERF203",
"LOG015",
"PLC0415"
"PLC0415",
"UP045"
]

[tool.ruff.lint.flake8-type-checking]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_speedup_critic() -> None:
total_candidate_timing=12,
)

assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now, True) # 20% improvement
assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True) # 20% improvement

candidate_result = OptimizedCandidateResult(
max_loop_count=5,
Expand All @@ -52,7 +52,7 @@ def test_speedup_critic() -> None:
optimization_candidate_index=0,
)

assert not speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now, True) # 6% improvement
assert not speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True) # 6% improvement

original_code_runtime = 100000
best_runtime_until_now = 100000
Expand All @@ -66,7 +66,7 @@ def test_speedup_critic() -> None:
optimization_candidate_index=0,
)

assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now, True) # 6% improvement
assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now, disable_gh_action_noise=True) # 6% improvement


def test_generated_test_critic() -> None:
Expand Down
Loading