Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,21 +354,23 @@ def transform_asserts(self, code: str) -> str:
return "\n".join(result_lines)

def _transform_assert_line(self, line: str) -> Optional[str]:
indent = line[: len(line) - len(line.lstrip())]
indent_len = len(line) - len(line.lstrip())
indent = line[:indent_len]

assert_match = re.match(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$", line)
assert_match = self.assert_pattern.match(line)
if assert_match:
expression = assert_match.group(1).strip()
if expression.startswith("not "):
return f"{indent}{expression}"

expression = re.sub(r"[,;]\s*$", "", expression)
# Removing trailing commas or semicolons without using regex
if expression and expression[-1] in ",;":
expression = expression[:-1]
return f"{indent}{expression}"

unittest_match = re.match(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$", line)
unittest_match = self.unittest_pattern.match(line)
if unittest_match:
indent, assert_method, args = unittest_match.groups()

if args:
arg_parts = self._split_top_level_args(args)
if arg_parts and arg_parts[0]:
Expand Down Expand Up @@ -399,6 +401,10 @@ def _split_top_level_args(self, args_str: str) -> list[str]:

return result

def __init__(self):
self.assert_pattern = re.compile(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$")
self.unittest_pattern = re.compile(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$")


def clean_concolic_tests(test_suite_code: str) -> str:
try:
Expand Down
Loading