Skip to content

Commit

Permalink
Static verifier for type/attribute in DRR
Browse files Browse the repository at this point in the history
Generate static function for matching the type/attribute to reduce the
memory footprint.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D110199
  • Loading branch information
ChiaHungDuan committed Nov 8, 2021
1 parent ca47447 commit f3798ad
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 74 deletions.
40 changes: 25 additions & 15 deletions mlir/include/mlir/TableGen/CodeGenHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/Format.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
Expand Down Expand Up @@ -91,8 +92,7 @@ class NamespaceEmitter {
///
class StaticVerifierFunctionEmitter {
public:
StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
raw_ostream &os);
StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records);

/// Emit the static verifier functions for `llvm::Record`s. The
/// `signatureFormat` describes the required arguments and it must have a
Expand All @@ -112,30 +112,40 @@ class StaticVerifierFunctionEmitter {
///
/// `typeArgName` is used to identify the argument that needs to check its
/// type. The constraint template will replace `$_self` with it.
void emitFunctionsFor(StringRef signatureFormat, StringRef errorHandlerFormat,
StringRef typeArgName, ArrayRef<llvm::Record *> opDefs,
bool emitDecl);

/// This is the helper to generate the constraint functions from op
/// definitions.
void emitConstraintMethodsInNamespace(StringRef signatureFormat,
StringRef errorHandlerFormat,
StringRef cppNamespace,
ArrayRef<const void *> constraints,
raw_ostream &rawOs, bool emitDecl);

/// Emit the static functions for the giving type constraints.
void emitConstraintMethods(StringRef signatureFormat,
StringRef errorHandlerFormat,
ArrayRef<const void *> constraints,
raw_ostream &rawOs, bool emitDecl);

/// Get the name of the local function used for the given type constraint.
/// These functions are used for operand and result constraints and have the
/// form:
/// LogicalResult(Operation *op, Type type, StringRef valueKind,
/// unsigned valueGroupStartIndex);
StringRef getTypeConstraintFn(const Constraint &constraint) const;
StringRef getConstraintFn(const Constraint &constraint) const;

/// The setter to set `self` in format context.
StaticVerifierFunctionEmitter &setSelf(StringRef str);

/// The setter to set `builder` in format context.
StaticVerifierFunctionEmitter &setBuilder(StringRef str);

private:
/// Returns a unique name to use when generating local methods.
static std::string getUniqueName(const llvm::RecordKeeper &records);

/// Emit local methods for the type constraints used within the provided op
/// definitions.
void emitTypeConstraintMethods(StringRef signatureFormat,
StringRef errorHandlerFormat,
StringRef typeArgName,
ArrayRef<llvm::Record *> opDefs,
bool emitDecl);

raw_indented_ostream os;
/// The format context used for building the verifier function.
FmtContext fctx;

/// A unique label for the file currently being generated. This is used to
/// ensure that the local functions have a unique name.
Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/TableGen/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ class DagLeaf {
void print(raw_ostream &os) const;

private:
friend llvm::DenseMapInfo<DagLeaf>;
const void *getAsOpaquePointer() const { return def; }

// Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
// also a subclass of the given `superclass`.
bool isSubClassOf(StringRef superclass) const;
Expand Down Expand Up @@ -523,6 +526,24 @@ struct DenseMapInfo<mlir::tblgen::DagNode> {
return lhs.node == rhs.node;
}
};

template <>
struct DenseMapInfo<mlir::tblgen::DagLeaf> {
static mlir::tblgen::DagLeaf getEmptyKey() {
return mlir::tblgen::DagLeaf(
llvm::DenseMapInfo<llvm::Init *>::getEmptyKey());
}
static mlir::tblgen::DagLeaf getTombstoneKey() {
return mlir::tblgen::DagLeaf(
llvm::DenseMapInfo<llvm::Init *>::getTombstoneKey());
}
static unsigned getHashValue(mlir::tblgen::DagLeaf leaf) {
return llvm::hash_value(leaf.getAsOpaquePointer());
}
static bool isEqual(mlir::tblgen::DagLeaf lhs, mlir::tblgen::DagLeaf rhs) {
return lhs.def == rhs.def;
}
};
} // end namespace llvm

#endif // MLIR_TABLEGEN_PATTERN_H_
10 changes: 7 additions & 3 deletions mlir/test/mlir-tblgen/rewriter-static-matcher.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,16 @@ def COp : NS_Op<"c_op", []> {
// Test static matcher for duplicate DagNode
// ---

// CHECK: static ::mlir::LogicalResult static_dag_matcher_0
// CHECK-DAG: static ::mlir::LogicalResult [[$TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Type typeOrAttr}}
// CHECK-DAG: static ::mlir::LogicalResult [[$ATTR_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Attribute}}
// CHECK-DAG: static ::mlir::LogicalResult [[$DAG_MATCHER:static_dag_matcher.*]](
// CHECK: if(failed([[$TYPE_CONSTRAINT]]
// CHECK: if(failed([[$ATTR_CONSTRAINT]]

// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
// CHECK: if(failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops
def : Pat<(AOp (BOp I32Attr:$attr, I32:$int)),
(AOp $int)>;

// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
// CHECK: if(failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops
def : Pat<(COp $_, (BOp I32Attr:$attr, I32:$int)),
(COp $attr, $int)>;
59 changes: 29 additions & 30 deletions mlir/tools/mlir-tblgen/CodeGenHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
Expand All @@ -24,21 +23,34 @@ using namespace mlir;
using namespace mlir::tblgen;

StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
const llvm::RecordKeeper &records, raw_ostream &os)
: os(os), uniqueOutputLabel(getUniqueName(records)) {}
const llvm::RecordKeeper &records)
: uniqueOutputLabel(getUniqueName(records)) {}

void StaticVerifierFunctionEmitter::emitFunctionsFor(
StaticVerifierFunctionEmitter &
StaticVerifierFunctionEmitter::setSelf(StringRef str) {
fctx.withSelf(str);
return *this;
}

StaticVerifierFunctionEmitter &
StaticVerifierFunctionEmitter::setBuilder(StringRef str) {
fctx.withBuilder(str);
return *this;
}

void StaticVerifierFunctionEmitter::emitConstraintMethodsInNamespace(
StringRef signatureFormat, StringRef errorHandlerFormat,
StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
StringRef cppNamespace, ArrayRef<const void *> constraints, raw_ostream &os,
bool emitDecl) {
llvm::Optional<NamespaceEmitter> namespaceEmitter;
if (!emitDecl)
namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
namespaceEmitter.emplace(os, cppNamespace);

emitTypeConstraintMethods(signatureFormat, errorHandlerFormat, typeArgName,
opDefs, emitDecl);
emitConstraintMethods(signatureFormat, errorHandlerFormat, constraints, os,
emitDecl);
}

StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
StringRef StaticVerifierFunctionEmitter::getConstraintFn(
const Constraint &constraint) const {
auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
assert(it != localTypeConstraints.end() && "expected valid constraint fn");
Expand All @@ -65,28 +77,16 @@ std::string StaticVerifierFunctionEmitter::getUniqueName(
return uniqueName;
}

void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
void StaticVerifierFunctionEmitter::emitConstraintMethods(
StringRef signatureFormat, StringRef errorHandlerFormat,
StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
// Collect a set of all of the used type constraints within the operation
// definitions.
llvm::SetVector<const void *> typeConstraints;
for (Record *def : opDefs) {
Operator op(*def);
for (NamedTypeConstraint &operand : op.getOperands())
if (operand.hasPredicate())
typeConstraints.insert(operand.constraint.getAsOpaquePointer());
for (NamedTypeConstraint &result : op.getResults())
if (result.hasPredicate())
typeConstraints.insert(result.constraint.getAsOpaquePointer());
}
ArrayRef<const void *> constraints, raw_ostream &rawOs, bool emitDecl) {
raw_indented_ostream os(rawOs);

// Record the mapping from predicate to constraint. If two constraints has the
// same predicate and constraint summary, they can share the same verification
// function.
llvm::DenseMap<Pred, const void *> predToConstraint;
FmtContext fctx;
for (auto it : llvm::enumerate(typeConstraints)) {
for (auto it : llvm::enumerate(constraints)) {
std::string name;
Constraint constraint = Constraint::getFromOpaquePointer(it.value());
Pred pred = constraint.getPredicate();
Expand All @@ -101,7 +101,7 @@ void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
// summary, otherwise we may report the wrong message while verification
// fails.
if (constraint.getSummary() == built.getSummary()) {
name = getTypeConstraintFn(built).str();
name = getConstraintFn(built).str();
break;
}
++iter;
Expand All @@ -126,12 +126,11 @@ void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
continue;

os << formatv(signatureFormat.data(), name) << " {\n";
os.indent() << "if (!("
<< tgfmt(constraint.getConditionTemplate(),
&fctx.withSelf(typeArgName))
os.indent() << "if (!(" << tgfmt(constraint.getConditionTemplate(), &fctx)
<< ")) {\n";
os.indent() << "return "
<< formatv(errorHandlerFormat.data(), constraint.getSummary())
<< formatv(errorHandlerFormat.data(),
escapeString(constraint.getSummary()))
<< ";\n";
os.unindent() << "}\nreturn ::mlir::success();\n";
os.unindent() << "}\n\n";
Expand Down
26 changes: 21 additions & 5 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2233,7 +2233,7 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
continue;
// Emit a loop to check all the dynamic values in the pack.
StringRef constraintFn =
staticVerifierEmitter.getTypeConstraintFn(value.constraint);
staticVerifierEmitter.getConstraintFn(value.constraint);
body << " for (::mlir::Value v : valueGroup" << staticValue.index()
<< ") {\n"
<< " if (::mlir::failed(" << constraintFn
Expand Down Expand Up @@ -2639,11 +2639,27 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
return;

// Generate all of the locally instantiated methods first.
StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, os);
StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper);
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
staticVerifierEmitter.emitFunctionsFor(
typeVerifierSignature, typeVerifierErrorHandler, /*typeArgName=*/"type",
defs, emitDecl);
staticVerifierEmitter.setSelf("type");

// Collect a set of all of the used type constraints within the operation
// definitions.
llvm::SetVector<const void *> typeConstraints;
for (Record *def : defs) {
Operator op(*def);
for (NamedTypeConstraint &operand : op.getOperands())
if (operand.hasPredicate())
typeConstraints.insert(operand.constraint.getAsOpaquePointer());
for (NamedTypeConstraint &result : op.getResults())
if (result.hasPredicate())
typeConstraints.insert(result.constraint.getAsOpaquePointer());
}

staticVerifierEmitter.emitConstraintMethodsInNamespace(
typeVerifierSignature, typeVerifierErrorHandler,
Operator(*defs[0]).getCppNamespace(), typeConstraints.getArrayRef(), os,
emitDecl);

for (auto *def : defs) {
Operator op(*def);
Expand Down
Loading

0 comments on commit f3798ad

Please sign in to comment.