|
| 1 | +#!/usr/bin/env python |
| 2 | +"""A script to generate FileCheck statements for mlir unit tests. |
| 3 | +
|
| 4 | +This script is a utility to add FileCheck patterns to an mlir file. |
| 5 | +
|
| 6 | +NOTE: The input .mlir is expected to be the output from the parser, not a |
| 7 | +stripped down variant. |
| 8 | +
|
| 9 | +Example usage: |
| 10 | +$ generate-test-checks.py foo.mlir |
| 11 | +$ mlir-opt foo.mlir -transformation | generate-test-checks.py |
| 12 | +
|
| 13 | +The script will heuristically insert CHECK/CHECK-LABEL commands for each line |
| 14 | +within the file. By default this script will also try to insert string |
| 15 | +substitution blocks for all SSA value names. The script is designed to make |
| 16 | +adding checks to a test case fast, it is *not* designed to be authoritative |
| 17 | +about what constitutes a good test! |
| 18 | +""" |
| 19 | + |
| 20 | +# Copyright 2019 The MLIR Authors. |
| 21 | +# |
| 22 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 23 | +# you may not use this file except in compliance with the License. |
| 24 | +# You may obtain a copy of the License at |
| 25 | +# |
| 26 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 27 | +# |
| 28 | +# Unless required by applicable law or agreed to in writing, software |
| 29 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 30 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 31 | +# See the License for the specific language governing permissions and |
| 32 | +# limitations under the License. |
| 33 | + |
| 34 | +import argparse |
| 35 | +import os # Used to advertise this file's name ("autogenerated_note"). |
| 36 | +import re |
| 37 | +import sys |
| 38 | +import string |
| 39 | + |
# Prefix of the banner line written at the top of the generated output; main()
# appends 'utils/' and this script's file name to it.
ADVERT = '// NOTE: Assertions have been autogenerated by '

# Regex command to match an SSA identifier: either a run of digits, or a name
# that starts with a letter, '$', '.', '_' or '-' and continues with those
# characters or digits. It is applied to the text that follows a '%'.
SSA_RE_STR = '[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*'
SSA_RE = re.compile(SSA_RE_STR)
| 45 | + |
| 46 | + |
# Class used to generate and manage string substitution blocks for SSA value
# names.
class SSAVariableNamer:
  """Hands out 'VAL_<n>' substitution names and tracks them per name scope.

  Attributes are plain and read directly by callers:
    scopes: A list used as a stack of dicts, each mapping an SSA value name
      to the substitution variable generated for it. The last element is the
      innermost scope.
    name_counter: The number that will be used for the next generated
      variable. It only ever increases, so names are unique across scopes.
  """

  def __init__(self):
    self.scopes = []
    self.name_counter = 0

  # Generate a substitution name for the given ssa value name.
  def generate_name(self, ssa_name):
    """Create a new variable for `ssa_name` in the innermost scope.

    Args:
      ssa_name: The SSA value name (without the leading '%').

    Returns:
      The generated variable name, 'VAL_' followed by the counter value.
    """
    # A value may be named before any scope was opened; open an implicit
    # outermost scope instead of failing with an IndexError on scopes[-1].
    if not self.scopes:
      self.push_name_scope()

    variable = 'VAL_' + str(self.name_counter)
    self.name_counter += 1
    self.scopes[-1][ssa_name] = variable
    return variable

  # Push a new variable name scope.
  def push_name_scope(self):
    """Open a new, empty innermost scope."""
    self.scopes.append({})

  # Pop the last variable name scope.
  def pop_name_scope(self):
    """Discard the innermost scope and every name registered in it.

    Popping with no open scope is a no-op so that an unbalanced closing
    brace in the input does not abort the whole run.
    """
    if self.scopes:
      self.scopes.pop()
| 69 | + |
| 70 | + |
# Process a line of input that has been split at each SSA identifier '%'.
def process_line(line_chunks, variable_namer):
  """Rewrite the '%'-separated tail of a line using substitution variables.

  Args:
    line_chunks: The pieces of a line that followed each '%' when the line
      was split on '%'. The text before the first '%' is not part of it.
    variable_namer: Object providing a `scopes` list of name-to-variable
      dicts and a `generate_name(ssa_name)` method, e.g. SSAVariableNamer.

  Returns:
    The rewritten text followed by a newline. The first use of a name becomes
    '[[VAR:%.*]]' (a definition), later uses become '[[VAR]]'. With no chunks
    the result is just the newline.
  """
  output_parts = []

  for chunk in line_chunks:
    m = SSA_RE.match(chunk)
    if m is None:
      # The '%' was not followed by an SSA identifier (e.g. a trailing '%' or
      # '%' inside other text). Put the '%' back and keep the text verbatim
      # rather than failing on m.group().
      output_parts.append('%' + chunk)
      continue
    ssa_name = m.group(0)

    # Check if an existing variable exists for this name. Scopes are searched
    # in list order and the first hit wins.
    variable = None
    for scope in variable_namer.scopes:
      variable = scope.get(ssa_name)
      if variable is not None:
        break

    if variable is not None:
      # A variable already exists: reference it.
      output_parts.append('[[' + variable + ']]')
    else:
      # Otherwise, generate a new variable and emit its definition.
      variable = variable_namer.generate_name(ssa_name)
      output_parts.append('[[' + variable + ':%.*]]')

    # Append whatever followed the identifier in this chunk.
    output_parts.append(chunk[len(ssa_name):])

  return ''.join(output_parts) + '\n'
| 99 | + |
| 100 | + |
def main():
  """Read MLIR text and write it back annotated as FileCheck directives.

  Command line: an optional positional input file (default: stdin), an
  optional '-o/--output' file (default: stdout) and '--check-prefix'
  (default: 'CHECK'). Blank input lines are skipped. Every other line becomes
  a check line; lines classified as top-level operations become a '-LABEL'
  line, followed by a '-SAME' line when they contain SSA values.
  """
  parser = argparse.ArgumentParser(
      description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
  parser.add_argument(
      '--check-prefix', default='CHECK', help='Prefix to use from check file.')
  parser.add_argument(
      '-o',
      '--output',
      nargs='?',
      type=argparse.FileType('w'),
      default=sys.stdout)
  parser.add_argument(
      'input',
      nargs='?',
      type=argparse.FileType('r'),
      default=sys.stdin)
  args = parser.parse_args()

  # Read the whole input up front, dropping trailing whitespace on each line.
  input_lines = [l.rstrip() for l in args.input]
  # Only close files this function opened; leave the process-wide stdin alone.
  if args.input is not sys.stdin:
    args.input.close()

  output_lines = []

  # Generate a note used for the generated check file.
  script_name = os.path.basename(__file__)
  autogenerated_note = (ADVERT + 'utils/' + script_name)
  output_lines.append(autogenerated_note + '\n')

  # Tracks the substitution variables generated for SSA value names.
  variable_namer = SSAVariableNamer()
  for input_line in input_lines:
    if not input_line:
      continue
    lstripped_input_line = input_line.lstrip()

    # Lines with blocks begin with a ^. These lines have a trailing comment
    # that needs to be stripped.
    is_block = lstripped_input_line[0] == '^'
    if is_block:
      input_line = input_line.rsplit('//', 1)[0].rstrip()

    # Top-level operations are heuristically the operations at nesting level 1.
    # The length check keeps very short lines from raising an IndexError when
    # the third character is inspected.
    is_toplevel_op = (not is_block and input_line.startswith(' ') and
                      len(input_line) > 2 and
                      input_line[2] not in (' ', '}'))

    # If the line starts with a '}', pop the last name scope.
    if lstripped_input_line[0] == '}':
      variable_namer.pop_name_scope()

    # If the line ends with a '{', push a new name scope. This happens before
    # the line is processed, so values named on this line land in the new
    # scope.
    if input_line[-1] == '{':
      variable_namer.push_name_scope()

    # Split the line at each SSA value name.
    ssa_split = input_line.split('%')

    # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
    # A line that starts with '%' has an empty first chunk and is never used
    # as a label.
    if not is_toplevel_op or not ssa_split[0]:
      output_line = '// ' + args.check_prefix + ': '
      # Pad to align with the 'LABEL' statements.
      output_line += (' ' * len('-LABEL'))

      # Output the first line chunk that does not contain an SSA name.
      output_line += ssa_split[0]

      # Process the rest of the input line.
      output_line += process_line(ssa_split[1:], variable_namer)

    else:
      # Append a newline to the output to separate the logical blocks.
      output_lines.append('\n')
      output_line = '// ' + args.check_prefix + '-LABEL: '

      # Output the first line chunk that does not contain an SSA name for the
      # label.
      output_line += ssa_split[0] + '\n'

      # Process the rest of the input line on a separate check line.
      if len(ssa_split) > 1:
        output_line += '// ' + args.check_prefix + '-SAME: '

        # Pad to align with the original position in the line.
        output_line += ' ' * len(ssa_split[0])

        # Process the rest of the line.
        output_line += process_line(ssa_split[1:], variable_namer)

    # Append the output line.
    output_lines.append(output_line)

  # Write the output, followed by a final newline.
  args.output.writelines(output_lines)
  args.output.write('\n')
  if args.output is sys.stdout:
    # Keep the shared stdout usable for the rest of the process.
    args.output.flush()
  else:
    args.output.close()
| 199 | + |
# Run the generator only when this file is executed as a script.
if __name__ == '__main__':
  main()
0 commit comments