diff --git a/mlir/utils/generate-test-checks.py b/mlir/utils/generate-test-checks.py index 3712a6b9c963d..22774468bb403 100755 --- a/mlir/utils/generate-test-checks.py +++ b/mlir/utils/generate-test-checks.py @@ -230,14 +230,11 @@ def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re # Process the source file lines. The source file doesn't have to be .mlir. -def process_source_lines(source_lines, note, args): +def process_source_lines(source_lines, args): source_split_re = re.compile(args.source_delim_regex) source_segments = [[]] for line in source_lines: - # Remove previous note. - if line in note: - continue # Remove previous CHECK lines. if line.find(args.check_prefix) != -1: continue @@ -359,9 +356,10 @@ def main(): source_segments = None if args.source: - source_segments = process_source_lines( - [l.rstrip() for l in open(args.source, "r")], autogenerated_note, args - ) + with open(args.source, "r") as f: + raw_source = f.read().replace(autogenerated_note, "") + raw_source_lines = [l.rstrip() for l in raw_source.splitlines()] + source_segments = process_source_lines(raw_source_lines, args) if args.inplace: assert args.output is None