Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .github/workflows/pre-commit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# CI lint workflow: runs the repository's pre-commit hooks (ruff lint + format,
# per .pre-commit-config.yaml) on every pull request and on pushes to main.
name: Lint
on:
  pull_request:
  push:
    branches:
      - main

# Cancel any still-running instance of this workflow for the same ref when a
# newer run starts, so superseded pushes don't consume CI minutes.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  lint:
    name: Run pre-commit hooks
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      # NOTE(review): no python-version is pinned here — setup-python will fall
      # back to a .python-version file or the runner default; confirm that
      # matches the Python versions this project supports.
      - uses: actions/setup-python@v5
      # Installs and runs all hooks declared in .pre-commit-config.yaml.
      - uses: pre-commit/action@v3.0.1
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Pre-commit hook configuration: lint and format the codebase with ruff.
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Pinned ruff version; bump deliberately so local hooks and CI agree.
    rev: "v0.11.0"
    hooks:
      # Linter with autofix. --exit-non-zero-on-fix makes the hook fail even
      # when all violations were auto-fixed, so the fixed files must be staged
      # before the commit succeeds. Rule selection comes from pyproject.toml.
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix, --config=pyproject.toml]
      # Code formatter hook (runs after the linter).
      - id: ruff-format
17 changes: 8 additions & 9 deletions codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def make_ai_service_request(
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
return response

def optimize_python_code(
def optimize_python_code( # noqa: D417
self,
source_code: str,
dependency_code: str,
Expand Down Expand Up @@ -139,7 +139,7 @@ def optimize_python_code(
console.rule()
return []

def optimize_python_code_line_profiler(
def optimize_python_code_line_profiler( # noqa: D417
self,
source_code: str,
dependency_code: str,
Expand Down Expand Up @@ -208,7 +208,7 @@ def optimize_python_code_line_profiler(
console.rule()
return []

def log_results(
def log_results( # noqa: D417
self,
function_trace_id: str,
speedup_ratio: dict[str, float | None] | None,
Expand Down Expand Up @@ -240,7 +240,7 @@ def log_results(
except requests.exceptions.RequestException as e:
logger.exception(f"Error logging features: {e}")

def generate_regression_tests(
def generate_regression_tests( # noqa: D417
self,
source_code_being_tested: str,
function_to_optimize: FunctionToOptimize,
Expand Down Expand Up @@ -270,10 +270,9 @@ def generate_regression_tests(
- Dict[str, str] | None: The generated regression tests and instrumented tests, or None if an error occurred.

"""
assert test_framework in [
"pytest",
"unittest",
], f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
assert test_framework in ["pytest", "unittest"], (
f"Invalid test framework, got {test_framework} but expected 'pytest' or 'unittest'"
)
payload = {
"source_code_being_tested": source_code_being_tested,
"function_to_optimize": function_to_optimize,
Expand Down Expand Up @@ -308,7 +307,7 @@ def generate_regression_tests(
error = response.json()["error"]
logger.error(f"Error generating tests: {response.status_code} - {error}")
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": error})
return None
return None # noqa: TRY300
except Exception:
logger.error(f"Error generating tests: {response.status_code} - {response.text}")
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})
Expand Down
59 changes: 46 additions & 13 deletions codeflash/benchmarking/codeflash_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sqlite3
import threading
import time
from typing import Callable
from typing import Any, Callable

from codeflash.picklepatch.pickle_patcher import PicklePatcher

Expand Down Expand Up @@ -69,7 +69,7 @@ def write_function_timings(self) -> None:
"(function_name, class_name, module_name, file_path, benchmark_function_name, "
"benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
self.function_calls_data
self.function_calls_data,
)
self._connection.commit()
self.function_calls_data = []
Expand Down Expand Up @@ -100,9 +100,10 @@ def __call__(self, func: Callable) -> Callable:
The wrapped function

"""
func_id = (func.__module__,func.__name__)
func_id = (func.__module__, func.__name__)

@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
# Initialize thread-local active functions set if it doesn't exist
if not hasattr(self._thread_local, "active_functions"):
self._thread_local.active_functions = set()
Expand Down Expand Up @@ -139,9 +140,19 @@ def wrapper(*args, **kwargs):
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, None, None)
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,
execution_time,
overhead_time,
None,
None,
)
)
return result

Expand All @@ -155,9 +166,19 @@ def wrapper(*args, **kwargs):
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, None, None)
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,
execution_time,
overhead_time,
None,
None,
)
)
return result
# Flush to database every 100 calls
Expand All @@ -168,12 +189,24 @@ def wrapper(*args, **kwargs):
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, pickled_args, pickled_kwargs)
(
func.__name__,
class_name,
func.__module__,
func.__code__.co_filename,
benchmark_function_name,
benchmark_module_path,
benchmark_line_number,
execution_time,
overhead_time,
pickled_args,
pickled_kwargs,
)
)
return result

return wrapper


# Create a singleton instance
codeflash_trace = CodeflashTrace()
72 changes: 32 additions & 40 deletions codeflash/benchmarking/instrument_codeflash_trace.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from pathlib import Path
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Union

import isort
import libcst as cst

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
if TYPE_CHECKING:
from pathlib import Path

from libcst import BaseStatement, ClassDef, FlattenSentinel, FunctionDef, RemovalSentinel

from codeflash.discovery.functions_to_optimize import FunctionToOptimize


class AddDecoratorTransformer(cst.CSTTransformer):
Expand All @@ -13,57 +20,48 @@ def __init__(self, target_functions: set[tuple[str, str]]) -> None:
self.added_codeflash_trace = False
self.class_name = ""
self.function_name = ""
self.decorator = cst.Decorator(
decorator=cst.Name(value="codeflash_trace")
)
self.decorator = cst.Decorator(decorator=cst.Name(value="codeflash_trace"))

def leave_ClassDef(self, original_node, updated_node):
def leave_ClassDef(
self, original_node: ClassDef, updated_node: ClassDef
) -> Union[BaseStatement, FlattenSentinel[BaseStatement], RemovalSentinel]:
if self.class_name == original_node.name.value:
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
return updated_node

def visit_ClassDef(self, node):
if self.class_name: # Don't go into nested class
def visit_ClassDef(self, node: ClassDef) -> Optional[bool]:
if self.class_name: # Don't go into nested class
return False
self.class_name = node.name.value
self.class_name = node.name.value # noqa: RET503

def visit_FunctionDef(self, node):
if self.function_name: # Don't go into nested function
def visit_FunctionDef(self, node: FunctionDef) -> Optional[bool]:
if self.function_name: # Don't go into nested function
return False
self.function_name = node.name.value
self.function_name = node.name.value # noqa: RET503

def leave_FunctionDef(self, original_node, updated_node):
def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> FunctionDef:
if self.function_name == original_node.name.value:
self.function_name = ""
if (self.class_name, original_node.name.value) in self.target_functions:
# Add the new decorator after any existing decorators, so it gets executed first
updated_decorators = list(updated_node.decorators) + [self.decorator]
updated_decorators = [*list(updated_node.decorators), self.decorator]
self.added_codeflash_trace = True
return updated_node.with_changes(
decorators=updated_decorators
)
return updated_node.with_changes(decorators=updated_decorators)

return updated_node

def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
# Create import statement for codeflash_trace
if not self.added_codeflash_trace:
return updated_node
import_stmt = cst.SimpleStatementLine(
body=[
cst.ImportFrom(
module=cst.Attribute(
value=cst.Attribute(
value=cst.Name(value="codeflash"),
attr=cst.Name(value="benchmarking")
),
attr=cst.Name(value="codeflash_trace")
value=cst.Attribute(value=cst.Name(value="codeflash"), attr=cst.Name(value="benchmarking")),
attr=cst.Name(value="codeflash_trace"),
),
names=[
cst.ImportAlias(
name=cst.Name(value="codeflash_trace")
)
]
names=[cst.ImportAlias(name=cst.Name(value="codeflash_trace"))],
)
]
)
Expand All @@ -73,12 +71,13 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c

return updated_node.with_changes(body=new_body)


def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[FunctionToOptimize]) -> str:
"""Add codeflash_trace to a function.

Args:
code: The source code as a string
function_to_optimize: The FunctionToOptimize instance containing function details
functions_to_optimize: List of FunctionToOptimize instances containing function details

Returns:
The modified source code as a string
Expand All @@ -91,25 +90,18 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
class_name = function_to_optimize.parents[0].name
target_functions.add((class_name, function_to_optimize.function_name))

transformer = AddDecoratorTransformer(
target_functions = target_functions,
)
transformer = AddDecoratorTransformer(target_functions=target_functions)

module = cst.parse_module(code)
modified_module = module.visit(transformer)
return modified_module.code


def instrument_codeflash_trace_decorator(
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
) -> None:
def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]) -> None:
"""Instrument codeflash_trace decorator to functions to optimize."""
for file_path, functions_to_optimize in file_to_funcs_to_optimize.items():
original_code = file_path.read_text(encoding="utf-8")
new_code = add_codeflash_decorator_to_code(
original_code,
functions_to_optimize
)
new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize)
# Modify the code
modified_code = isort.code(code=new_code, float_to_top=True)

Expand Down
Loading
Loading