diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 269dd4706..28e252c47 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -354,21 +354,23 @@ def transform_asserts(self, code: str) -> str: return "\n".join(result_lines) def _transform_assert_line(self, line: str) -> Optional[str]: - indent = line[: len(line) - len(line.lstrip())] + indent_len = len(line) - len(line.lstrip()) + indent = line[:indent_len] - assert_match = re.match(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$", line) + assert_match = self.assert_pattern.match(line) if assert_match: expression = assert_match.group(1).strip() if expression.startswith("not "): return f"{indent}{expression}" - expression = re.sub(r"[,;]\s*$", "", expression) + # Removing trailing commas or semicolons without using regex + if expression and expression[-1] in ",;": + expression = expression[:-1] return f"{indent}{expression}" - unittest_match = re.match(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$", line) + unittest_match = self.unittest_pattern.match(line) if unittest_match: indent, assert_method, args = unittest_match.groups() - if args: arg_parts = self._split_top_level_args(args) if arg_parts and arg_parts[0]: @@ -399,6 +401,10 @@ def _split_top_level_args(self, args_str: str) -> list[str]: return result + def __init__(self): + self.assert_pattern = re.compile(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$") + self.unittest_pattern = re.compile(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$") + def clean_concolic_tests(test_suite_code: str) -> str: try: