Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 38 additions & 10 deletions mlir/utils/generate-test-checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,26 @@
import os # Used to advertise this file's name ("autogenerated_note").
import re
import sys
from collections import Counter

ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by "
ADVERT_END = """
// The script is designed to make adding checks to
// a test case fast, it is *not* designed to be authoritative
// about what constitutes a good test! The CHECK should be
// minimized and named to reflect the test intent.
// This script is intended to make adding checks to a test case quick and easy.
// It is *not* authoritative about what constitutes a good test. After using the
// script, be sure to review and refine the generated checks. For example,
// CHECK lines should be minimized and named to reflect the test’s intent.
// For comprehensive guidelines, see:
// * https://mlir.llvm.org/getting_started/TestingGuide/
"""


# Regex matching an SSA identifier.
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
SSA_RE = re.compile(SSA_RE_STR)

# Regex matching `dialect.op_name` (e.g. `vector.transfer_read`).
SSA_OP_NAME_RE = re.compile(r"\b(?:\s=\s[a-z_]+)[.]([a-z_]+)\b")

# Regex matching the left-hand side of an assignment
SSA_RESULTS_STR = r'\s*(%' + SSA_RE_STR + r')(\s*,\s*(%' + SSA_RE_STR + r'))*\s*='
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)
Expand All @@ -63,7 +69,12 @@
class VariableNamer:
def __init__(self, variable_names):
self.scopes = []
# Counter for generic FileCheck names, e.g. VAL_#N
self.name_counter = 0
# Counters for FileCheck names derived from Op names, e.g.
# TRANSFER_READ_#N (based on `vector.transfer_read`). Note, there's a
# dedicated counter for every Op type present in the input.
self.op_name_counter = Counter()

# Number of variable names to still generate in parent scope
self.generate_in_parent_scope_left = 0
Expand All @@ -77,17 +88,29 @@ def generate_in_parent_scope(self, n):
self.generate_in_parent_scope_left = n

# Generate a substitution name for the given ssa value name.
def generate_name(self, source_variable_name, use_ssa_name):
def generate_name(self, source_variable_name, use_ssa_name, op_name=""):

# Compute variable name
variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
if variable_name == '':
variable_name = (
self.variable_names.pop(0) if len(self.variable_names) > 0 else ""
)
if variable_name == "":
# If `use_ssa_name` is set, use the MLIR SSA value name to generate
# a FileCheck substitution string. As FileCheck requires these
# strings to start with a character, skip MLIR variables starting
# with a digit (e.g. `%0`).
#
# The next fallback option is to use the op name, if the
# corresponding match succeeds.
#
# If neither worked, use a generic name: `VAL_#N`.
if use_ssa_name and source_variable_name[0].isalpha():
variable_name = source_variable_name.upper()
elif op_name != "":
variable_name = (
op_name.upper() + "_" + str(self.op_name_counter[op_name])
)
self.op_name_counter[op_name] += 1
else:
variable_name = "VAL_" + str(self.name_counter)
self.name_counter += 1
Expand Down Expand Up @@ -123,6 +146,7 @@ def num_scopes(self):
def clear_names(self):
self.name_counter = 0
self.used_variable_names = set()
self.op_name_counter.clear()

class AttributeNamer:

Expand Down Expand Up @@ -170,8 +194,12 @@ def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re

# Process the rest that contained an SSA value name.
for chunk in line_chunks:
m = SSA_RE.match(chunk)
ssa_name = m.group(0) if m is not None else ''
ssa = SSA_RE.match(chunk)
op_name_with_dialect = SSA_OP_NAME_RE.search(chunk)
ssa_name = ssa.group(0) if ssa is not None else ""
op_name = (
op_name_with_dialect.group(1) if op_name_with_dialect is not None else ""
)

# Check if an existing variable exists for this name.
variable = None
Expand All @@ -185,7 +213,7 @@ def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re
output_line += "%[[" + variable + "]]"
else:
# Otherwise, generate a new variable.
variable = variable_namer.generate_name(ssa_name, use_ssa_name)
variable = variable_namer.generate_name(ssa_name, use_ssa_name, op_name)
if strict_name_re:
# Use stricter regexp for the variable name, if requested.
# Greedy matching may cause issues with the generic '.*'
Expand Down