diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py
index 0f69bed7a..f35bcfd5a 100644
--- a/codeflash/code_utils/code_extractor.py
+++ b/codeflash/code_utils/code_extractor.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import ast
+from itertools import chain
 from typing import TYPE_CHECKING, Optional
 
 import libcst as cst
@@ -119,6 +120,32 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c
 
         return updated_node
 
+    def _find_insertion_index(self, updated_node: cst.Module) -> int:
+        """Return the index just past the last top-level import statement in the module."""
+        insert_index = 0
+        for i, stmt in enumerate(updated_node.body):
+            is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
+                isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
+            )
+
+            is_conditional_import = isinstance(stmt, cst.If) and all(
+                isinstance(inner, cst.SimpleStatementLine)
+                and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
+                for inner in stmt.body.body
+            )
+
+            if is_top_level_import or is_conditional_import:
+                insert_index = i + 1
+
+            # Stop scanning once we reach a class or function definition.
+            # Imports are supposed to sit at the top of the file, but they can technically
+            # appear anywhere, even at the bottom. Without this check, a stray import later
+            # in the file would incorrectly shift our insertion index below actual code definitions.
+            if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
+                break
+
+        return insert_index
+
     def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
         # Add any new assignments that weren't in the original file
         new_statements = list(updated_node.body)
@@ -131,18 +158,26 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
         ]
 
         if assignments_to_append:
-            # Add a blank line before appending new assignments if needed
-            if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
-                new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
-                new_statements.pop()  # Remove the Pass statement but keep the empty line
-
-            # Add the new assignments
-            new_statements.extend(
-                [
-                    cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
-                    for assignment in assignments_to_append
-                ]
-            )
+            # Insert after the last top-level imports
+            insert_index = self._find_insertion_index(updated_node)
+
+            assignment_lines = [
+                cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
+                for assignment in assignments_to_append
+            ]
+
+            new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
+
+            # Add a blank line after the last assignment if needed
+            after_index = insert_index + len(assignment_lines)
+            if after_index < len(new_statements):
+                next_stmt = new_statements[after_index]
+                # If there's no empty line, add one
+                has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
+                if not has_empty:
+                    new_statements[after_index] = next_stmt.with_changes(
+                        leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
+                    )
 
         return updated_node.with_changes(body=new_statements)
diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py
index 26e6e915b..405896087 100644
--- a/tests/test_code_replacement.py
+++ b/tests/test_code_replacement.py
@@ -2104,6 +2104,8 @@ def new_function2(value):
 """
     expected_code = """import numpy as np
 
+a = 6
+
print("Hello world") if 2<3: a=4 @@ -2126,8 +2128,6 @@ def __call__(self, value): return "I am still old" def new_function2(value): return cst.ensure_type(value, str) - -a = 6 """ code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve() code_path.write_text(original_code, encoding="utf-8") @@ -3228,3 +3228,228 @@ def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool: assert not re.search(r"^import aiohttp as aiohttp_\b", new_code, re.MULTILINE) # conditional alias import: import as assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import + +def test_top_level_global_assignments() -> None: + root_dir = Path(__file__).parent.parent.resolve() + main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve() + + original_code = '''""" +Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions. +""" + +from typing import Any, Dict, List, Tuple + +import structlog +from pydantic import BaseModel + +from skyvern.forge import app +from skyvern.forge.sdk.prompting import PromptEngine +from skyvern.webeye.actions.actions import ActionType + +LOG = structlog.get_logger(__name__) + +# Initialize prompt engine +prompt_engine = PromptEngine("skyvern") + + +def hydrate_input_text_actions_with_field_names( + actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str] +) -> Dict[str, List[Dict[str, Any]]]: + """ + Add field_name to input_text actions based on generated mappings. + + Args: + actions_by_task: Dictionary mapping task IDs to lists of action dictionaries + field_mappings: Dictionary mapping "task_id:action_id" to field names + + Returns: + Updated actions_by_task with field_name added to input_text actions + """ + updated_actions_by_task = {} + + for task_id, actions in actions_by_task.items(): + updated_actions = [] + + for action in actions: + action_copy = action.copy() + + if action.get("action_type") == ActionType.INPUT_TEXT: + action_id = action.get("action_id", "") + mapping_key = f"{task_id}:{action_id}" + + if mapping_key in field_mappings: + action_copy["field_name"] = field_mappings[mapping_key] + else: + # Fallback field name if mapping not found + intention = action.get("intention", "") + if intention: + # Simple field name generation from intention + field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "") + field_name = "".join(c for c in field_name if c.isalnum() or c == "_") + action_copy["field_name"] = field_name or "unknown_field" + else: + action_copy["field_name"] = "unknown_field" + + updated_actions.append(action_copy) + + updated_actions_by_task[task_id] = updated_actions + + return updated_actions_by_task +''' + main_file.write_text(original_code, encoding="utf-8") + optim_code = f'''```python:{main_file.relative_to(root_dir)} +from skyvern.webeye.actions.actions import ActionType +from typing import Any, Dict, List +import re + +# Precompiled regex for efficiently generating simple field_name from intention +_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+") + +def hydrate_input_text_actions_with_field_names( + actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str] +) -> Dict[str, List[Dict[str, Any]]]: + """ + Add field_name to input_text actions based on generated mappings. 
+
+    Args:
+        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
+        field_mappings: Dictionary mapping "task_id:action_id" to field names
+
+    Returns:
+        Updated actions_by_task with field_name added to input_text actions
+    """
+    updated_actions_by_task = {{}}
+
+    input_text_type = ActionType.INPUT_TEXT  # local variable for faster access
+    intention_cleanup = _INTENTION_CLEANUP_RE
+
+    for task_id, actions in actions_by_task.items():
+        updated_actions = []
+
+        for action in actions:
+            action_copy = action.copy()
+
+            if action.get("action_type") == input_text_type:
+                action_id = action.get("action_id", "")
+                mapping_key = f"{{task_id}}:{{action_id}}"
+
+                if mapping_key in field_mappings:
+                    action_copy["field_name"] = field_mappings[mapping_key]
+                else:
+                    # Fallback field name if mapping not found
+                    intention = action.get("intention", "")
+                    if intention:
+                        # Simple field name generation from intention
+                        field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
+                        # Use compiled regex instead of "".join(c for ...)
+                        field_name = intention_cleanup.sub("", field_name)
+                        action_copy["field_name"] = field_name or "unknown_field"
+                    else:
+                        action_copy["field_name"] = "unknown_field"
+
+            updated_actions.append(action_copy)
+
+        updated_actions_by_task[task_id] = updated_actions
+
+    return updated_actions_by_task
+```
+'''
+    expected = '''"""
+Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
+"""
+
+from typing import Any, Dict, List, Tuple
+
+import structlog
+from pydantic import BaseModel
+
+from skyvern.forge import app
+from skyvern.forge.sdk.prompting import PromptEngine
+from skyvern.webeye.actions.actions import ActionType
+import re
+
+_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
+
+LOG = structlog.get_logger(__name__)
+
+# Initialize prompt engine
+prompt_engine = PromptEngine("skyvern")
+
+
+def hydrate_input_text_actions_with_field_names(
+    actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
+) -> Dict[str, List[Dict[str, Any]]]:
+    """
+    Add field_name to input_text actions based on generated mappings.
+
+    Args:
+        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
+        field_mappings: Dictionary mapping "task_id:action_id" to field names
+
+    Returns:
+        Updated actions_by_task with field_name added to input_text actions
+    """
+    updated_actions_by_task = {}
+
+    input_text_type = ActionType.INPUT_TEXT  # local variable for faster access
+    intention_cleanup = _INTENTION_CLEANUP_RE
+
+    for task_id, actions in actions_by_task.items():
+        updated_actions = []
+
+        for action in actions:
+            action_copy = action.copy()
+
+            if action.get("action_type") == input_text_type:
+                action_id = action.get("action_id", "")
+                mapping_key = f"{task_id}:{action_id}"
+
+                if mapping_key in field_mappings:
+                    action_copy["field_name"] = field_mappings[mapping_key]
+                else:
+                    # Fallback field name if mapping not found
+                    intention = action.get("intention", "")
+                    if intention:
+                        # Simple field name generation from intention
+                        field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
+                        # Use compiled regex instead of "".join(c for ...)
+                        field_name = intention_cleanup.sub("", field_name)
+                        action_copy["field_name"] = field_name or "unknown_field"
+                    else:
+                        action_copy["field_name"] = "unknown_field"
+
+            updated_actions.append(action_copy)
+
+        updated_actions_by_task[task_id] = updated_actions
+
+    return updated_actions_by_task
+'''
+
+    func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)
+    test_config = TestConfig(
+        tests_root=root_dir / "tests/pytest",
+        tests_project_rootdir=root_dir,
+        project_root_path=root_dir,
+        test_framework="pytest",
+        pytest_cmd="pytest",
+    )
+    func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
+    code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
+
+    original_helper_code: dict[Path, str] = {}
+    helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
+    for helper_function_path in helper_function_paths:
+        with helper_function_path.open(encoding="utf8") as f:
+            helper_code = f.read()
+            original_helper_code[helper_function_path] = helper_code
+
+    func_optimizer.args = Args()
+    func_optimizer.replace_function_and_helpers_with_optimized_code(
+        code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
+    )
+
+    new_code = main_file.read_text(encoding="utf-8")
+    main_file.unlink(missing_ok=True)
+
+    assert new_code == expected
diff --git a/tests/test_multi_file_code_replacement.py b/tests/test_multi_file_code_replacement.py
index 05a9c01c0..e33e98d24 100644
--- a/tests/test_multi_file_code_replacement.py
+++ b/tests/test_multi_file_code_replacement.py
@@ -18,6 +18,8 @@ def test_multi_file_replcement01() -> None:
 from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
 
+_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
+
 def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
     if not content:
         return 0
 
@@ -34,9 +36,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
 
             # TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl.
     return tokens
-
-
-_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
 """, encoding="utf-8")
 
     main_file = (root_dir / "code_to_optimize/temp_main.py").resolve()
@@ -131,6 +130,10 @@ def _get_string_usage(text: str) -> Usage:
 from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
 
+_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
+
+_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
+
 def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
     if not content:
         return 0
 
@@ -155,11 +158,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
 
                 tokens += len(part.data)
     return tokens
-
-
-_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
-
-_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
 """
 
     assert new_code.rstrip() == original_main.rstrip()  # No Change
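For reviewers, here is a minimal standalone sketch of the scan `_find_insertion_index` performs, written as a free function over a parsed `libcst` module. The function name `find_insertion_index`, the sample `source`, and the printed index are hypothetical illustrations, not part of this patch; only the `libcst` calls (`parse_module` and the node classes) are the real API.

```python
# Minimal sketch (not part of the patch): the insertion-index scan as a free
# function, assuming only libcst.
import libcst as cst


def find_insertion_index(module: cst.Module) -> int:
    """Return the index just past the last top-level or conditional import."""
    insert_index = 0
    for i, stmt in enumerate(module.body):
        # Plain `import x` / `from x import y` at module level.
        is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
            isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
        )
        # `if TYPE_CHECKING:`-style blocks whose body consists only of imports.
        is_conditional_import = isinstance(stmt, cst.If) and all(
            isinstance(inner, cst.SimpleStatementLine)
            and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
            for inner in stmt.body.body
        )
        if is_top_level_import or is_conditional_import:
            insert_index = i + 1
        # Stop at the first def/class so a stray late import cannot drag
        # the insertion point below real code.
        if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
            break
    return insert_index


source = """\
import os
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from pathlib import Path

x = 1

def f():
    return x

import json  # stray late import, ignored: the scan stops at `def f`
"""
# module.body index 3 sits just past the TYPE_CHECKING block, so migrated
# globals land between the imports and `x = 1`.
print(find_insertion_index(cst.parse_module(source)))  # -> 3
```

The early `break` is the load-bearing choice: without it, the stray bottom-of-file `import json` would push the result to 6, placing the migrated globals below `f`, which is exactly the regression the comment in `_find_insertion_index` warns about.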