Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,6 @@ fabric.properties

# Mac
.DS_Store

# VSC Workspace settings
codeflash.code-workspace
80 changes: 0 additions & 80 deletions codeflash.code-workspace

This file was deleted.

2 changes: 1 addition & 1 deletion codeflash/cli_cmds/cli_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def inquirer_wrapper_path(*args: str, **kwargs: str) -> dict[str, str] | None:
new_kwargs["message"] = last_message
new_args.append(args[0])

return cast(dict[str, str], inquirer.prompt([inquirer.Path(*new_args, **new_kwargs)]))
return cast("dict[str, str]", inquirer.prompt([inquirer.Path(*new_args, **new_kwargs)]))


def split_string_to_fit_width(string: str, width: int) -> list[str]:
Expand Down
24 changes: 12 additions & 12 deletions codeflash/cli_cmds/cmd_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def init_codeflash() -> None:
did_add_new_key = prompt_api_key()

if should_modify_pyproject_toml():

setup_info: SetupInfo = collect_setup_info()

configure_pyproject_toml(setup_info)
Expand All @@ -81,7 +80,6 @@ def init_codeflash() -> None:
if "setup_info" in locals():
module_string = f" you selected ({setup_info.module_root})"


click.echo(
f"{LF}"
f"⚡️ Codeflash is now set up! You can now run:{LF}"
Expand Down Expand Up @@ -123,29 +121,31 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
bubble_sort_path, bubble_sort_test_path = create_bubble_sort_file_and_test(args)
run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)


def should_modify_pyproject_toml() -> bool:
"""
Check if the current directory contains a valid pyproject.toml file with codeflash config
"""Check if the current directory contains a valid pyproject.toml file with codeflash config
If it does, ask the user if they want to re-configure it.
"""
from rich.prompt import Confirm

pyproject_toml_path = Path.cwd() / "pyproject.toml"
if not pyproject_toml_path.exists():
return True
try:
config, config_file_path = parse_config_file(pyproject_toml_path)
except Exception as e:
except Exception:
return True

if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
return True
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
return True

create_toml = Confirm.ask(
f"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", default=False, show_default=True
return Confirm.ask(
"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?",
default=False,
show_default=True,
)
return create_toml


def collect_setup_info() -> SetupInfo:
Expand Down Expand Up @@ -224,7 +224,7 @@ def collect_setup_info() -> SetupInfo:
else:
apologize_and_exit()
else:
tests_root = Path(curdir) / Path(cast(str, tests_root_answer))
tests_root = Path(curdir) / Path(cast("str", tests_root_answer))
tests_root = tests_root.relative_to(curdir)
ph("cli-tests-root-provided")

Expand Down Expand Up @@ -278,9 +278,9 @@ def collect_setup_info() -> SetupInfo:
return SetupInfo(
module_root=str(module_root),
tests_root=str(tests_root),
test_framework=cast(str, test_framework),
test_framework=cast("str", test_framework),
ignore_paths=ignore_paths,
formatter=cast(str, formatter),
formatter=cast("str", formatter),
git_remote=str(git_remote),
)

Expand Down Expand Up @@ -390,7 +390,7 @@ def check_for_toml_or_setup_file() -> str | None:
click.echo("⏩️ Skipping pyproject.toml creation.")
apologize_and_exit()
click.echo()
return cast(str, project_name)
return cast("str", project_name)


def install_github_actions(override_formatter_check: bool = False) -> None:
Expand Down
6 changes: 4 additions & 2 deletions codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import ast
from pathlib import Path
from typing import TYPE_CHECKING

import libcst as cst
Expand All @@ -11,12 +10,15 @@
from libcst.helpers import calculate_module_and_package

from codeflash.cli_cmds.console import logger
from codeflash.models.models import FunctionParent, FunctionSource
from codeflash.models.models import FunctionParent

if TYPE_CHECKING:
from pathlib import Path

from libcst.helpers import ModuleNameAndPackage

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionSource


class FutureAliasedImportTransformer(cst.CSTTransformer):
Expand Down
2 changes: 1 addition & 1 deletion codeflash/code_utils/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def parse_config_file(
"In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest."
)
if len(config["formatter-cmds"]) > 0:
#see if this is happening during GitHub actions setup
# see if this is happening during GitHub actions setup
if not override_formatter_check:
assert config["formatter-cmds"][0] != "your-formatter $file", (
"The formatter command is not set correctly in pyproject.toml. Please set the "
Expand Down
68 changes: 39 additions & 29 deletions codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,15 @@
import os
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Optional

import jedi
import libcst as cst
import tiktoken
from jedi.api.classes import Name
from libcst import CSTNode

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import (
CodeContextType,
CodeOptimizationContext,
Expand All @@ -24,6 +21,15 @@
)
from codeflash.optimization.function_context import belongs_to_function_qualified

if TYPE_CHECKING:
from pathlib import Path

from jedi.api.classes import Name
from libcst import CSTNode

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from typing import Callable


def get_code_optimization_context(
function_to_optimize: FunctionToOptimize,
Expand Down Expand Up @@ -75,7 +81,8 @@ def get_code_optimization_context(
tokenizer = tiktoken.encoding_for_model("gpt-4o")
final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
if final_read_writable_tokens > optim_token_limit:
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
msg = "Read-writable code has exceeded token limit, cannot proceed"
raise ValueError(msg)

# Setup preexisting objects for code replacer
preexisting_objects = set(
Expand Down Expand Up @@ -122,7 +129,8 @@ def get_code_optimization_context(
testgen_context_code = testgen_code_markdown.code
testgen_context_code_tokens = len(tokenizer.encode(testgen_context_code))
if testgen_context_code_tokens > testgen_token_limit:
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
msg = "Testgen code context has exceeded token limit, cannot proceed"
raise ValueError(msg)

return CodeOptimizationContext(
testgen_context_code=testgen_context_code,
Expand All @@ -143,7 +151,7 @@ def extract_code_string_context_from_files(
"""Extract code context from files containing target functions and their helpers.
This function processes two sets of files:
1. Files containing the function to optimize (fto) and their first-degree helpers
2. Files containing only helpers of helpers (with no overlap with the first set)
2. Files containing only helpers of helpers (with no overlap with the first set).

For each file, it extracts relevant code based on the specified context type, adds necessary
imports, and combines them.
Expand Down Expand Up @@ -358,19 +366,17 @@ def get_function_to_optimize_as_function_source(
and name.name == function_to_optimize.function_name
and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name
):
function_source = FunctionSource(
return FunctionSource(
file_path=function_to_optimize.file_path,
qualified_name=function_to_optimize.qualified_name,
fully_qualified_name=name.full_name,
only_function_name=name.name,
source_code=name.get_line_code(),
jedi_definition=name,
)
return function_source

raise ValueError(
f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}"
)
msg = f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}"
raise ValueError(msg)


def get_function_sources_from_jedi(
Expand Down Expand Up @@ -436,7 +442,7 @@ def get_section_names(node: cst.CSTNode) -> list[str]:


def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
"""Removes the docstring from an indented block if it exists"""
"""Removes the docstring from an indented block if it exists."""
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
return indented_block
first_stmt = indented_block.body[0].body[0]
Expand Down Expand Up @@ -464,16 +470,13 @@ def parse_code_and_prune_cst(
filtered_node, found_target = prune_cst_for_testgen_code(
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
)
else:
raise ValueError(f"Unknown code_context_type: {code_context_type}")

if not found_target:
raise ValueError("No target functions found in the provided code")
msg = "No target functions found in the provided code"
raise ValueError(msg)
if filtered_node and isinstance(filtered_node, cst.Module):
return str(filtered_node.code)
return ""


def prune_cst_for_read_writable_code(
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
Expand All @@ -500,7 +503,8 @@ def prune_cst_for_read_writable_code(
return None, False
# Assuming always an IndentedBlock
if not isinstance(node.body, cst.IndentedBlock):
raise ValueError("ClassDef body is not an IndentedBlock")
msg = "ClassDef body is not an IndentedBlock"
raise ValueError(msg)
class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
new_body = []
found_target = False
Expand Down Expand Up @@ -593,7 +597,8 @@ def prune_cst_for_read_only_code(
return None, False
# Assuming always an IndentedBlock
if not isinstance(node.body, cst.IndentedBlock):
raise ValueError("ClassDef body is not an IndentedBlock")
msg = "ClassDef body is not an IndentedBlock"
raise ValueError(msg)

class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value

Expand All @@ -612,10 +617,12 @@ def prune_cst_for_read_only_code(
return None, False

if remove_docstrings:
return node.with_changes(
body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))
) if new_class_body else None, True
return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True
return (
node.with_changes(body=remove_docstring_from_body(node.body.with_changes(body=new_class_body)))
if new_class_body
else None
), True
return (node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None), True

# For other nodes, keep the node and recursively filter children
section_names = get_section_names(node)
Expand Down Expand Up @@ -698,7 +705,8 @@ def prune_cst_for_testgen_code(
return None, False
# Assuming always an IndentedBlock
if not isinstance(node.body, cst.IndentedBlock):
raise ValueError("ClassDef body is not an IndentedBlock")
msg = "ClassDef body is not an IndentedBlock"
raise ValueError(msg)

class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value

Expand All @@ -717,10 +725,12 @@ def prune_cst_for_testgen_code(
return None, False

if remove_docstrings:
return node.with_changes(
body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))
) if new_class_body else None, True
return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True
return (
node.with_changes(body=remove_docstring_from_body(node.body.with_changes(body=new_class_body)))
if new_class_body
else None
), True
return (node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None), True

# For other nodes, keep the node and recursively filter children
section_names = get_section_names(node)
Expand Down
Loading
Loading