Skip to content

Commit

Permalink
[DDR] Introduce implicit equality check for the source pattern operan…
Browse files Browse the repository at this point in the history
…ds with the same name.

This CL allows user to specify the same name for the operands in the source pattern which implicitly enforces equality on operands with the same name.
E.g., Pat<(OpA $a, $b, $a) ... > would create a matching rule for checking equality for the first and the last operands. Equality of the operands is enforced at any depth, e.g., OpA ($a, $b, OpB($a, $c, OpC ($a))).

Example usage: Pat<(Reshape $arg0, (Shape $arg0)), (replaceWithValue $arg0)>

Note, this feature only covers operands but not attributes.
Current use cases are based on the operand equality and explicitly add the constraint into the pattern. Attribute equality will be worked out on the different CL.

Differential Revision: https://reviews.llvm.org/D89254
  • Loading branch information
rdzhabarov authored and jpienaar committed Oct 13, 2020
1 parent ab870f3 commit 7271c1b
Show file tree
Hide file tree
Showing 5 changed files with 250 additions and 21 deletions.
29 changes: 26 additions & 3 deletions mlir/include/mlir/TableGen/Pattern.h
Expand Up @@ -21,6 +21,8 @@
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"

#include <unordered_map>

namespace llvm {
class DagInit;
class Init;
Expand Down Expand Up @@ -228,6 +230,9 @@ class SymbolInfoMap {
// value bound by this symbol.
std::string getVarDecl(StringRef name) const;

// Returns a variable name for the symbol named as `name`.
std::string getVarName(StringRef name) const;

private:
// Allow SymbolInfoMap to access private methods.
friend class SymbolInfoMap;
Expand Down Expand Up @@ -285,9 +290,12 @@ class SymbolInfoMap {
Kind kind; // The kind of the bound entity
// The argument index (for `Attr` and `Operand` only)
Optional<int> argIndex;
// Alternative name for the symbol. It is used in case the name
// is not unique. Applicable for `Operand` only.
Optional<std::string> alternativeName;
};

using BaseT = llvm::StringMap<SymbolInfo>;
using BaseT = std::unordered_multimap<std::string, SymbolInfo>;

// Iterators for accessing all symbols.
using iterator = BaseT::iterator;
Expand All @@ -300,7 +308,7 @@ class SymbolInfoMap {
const_iterator end() const { return symbolInfoMap.end(); }

// Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
// Returns false if `symbol` is already bound.
// Returns false if `symbol` is already bound and symbols are not operands.
bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);

// Binds the given `symbol` to the results the given `op`. Returns false if
Expand All @@ -317,6 +325,18 @@ class SymbolInfoMap {
// Returns an iterator to the information of the given symbol named as `key`.
const_iterator find(StringRef key) const;

// Returns an iterator to the information of the given symbol named as `key`,
// with index `argIndex` for operator `op`.
const_iterator findBoundSymbol(StringRef key, const Operator &op,
int argIndex) const;

// Returns the bounds of a range that includes all the elements which
// bind to the `key`.
std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);

// Returns number of times symbol named as `key` was used.
int count(StringRef key) const;

// Returns the number of static values of the given `symbol` corresponds to.
// A static value is an operand/result declared in ODS. Normally a symbol only
// represents one static value, but symbols bound to op results can represent
Expand All @@ -338,6 +358,9 @@ class SymbolInfoMap {
std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
const char *separator = ", ") const;

// Assign alternative unique names to Operands that have equal names.
void assignUniqueAlternativeNames();

// Splits the given `symbol` into a value pack name and an index. Returns the
// value pack name and writes the index to `index` on success. Returns
// `symbol` itself if it does not contain an index.
Expand All @@ -347,7 +370,7 @@ class SymbolInfoMap {
static StringRef getValuePackName(StringRef symbol, int *index = nullptr);

private:
llvm::StringMap<SymbolInfo> symbolInfoMap;
BaseT symbolInfoMap;

// Pattern instantiation location. This is intended to be used as parameter
// to PrintFatalError() to report errors.
Expand Down
108 changes: 97 additions & 11 deletions mlir/lib/TableGen/Pattern.cpp
Expand Up @@ -208,6 +208,10 @@ int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
llvm_unreachable("unknown kind");
}

std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
}

std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
switch (kind) {
Expand All @@ -219,8 +223,9 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
case Kind::Operand: {
// Use operand range for captured operands (to support potential variadic
// operands).
return std::string(formatv(
"::mlir::Operation::operand_range {0}(op0->getOperands());\n", name));
return std::string(
formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
getVarName(name)));
}
case Kind::Value: {
return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name));
Expand Down Expand Up @@ -359,27 +364,73 @@ bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
? SymbolInfo::getAttr(&op, argIndex)
: SymbolInfo::getOperand(&op, argIndex);

return symbolInfoMap.insert({symbol, symInfo}).second;
std::string key = symbol.str();
if (auto numberOfEntries = symbolInfoMap.count(key)) {
// Only non unique name for the operand is supported.
if (symInfo.kind != SymbolInfo::Kind::Operand) {
return false;
}

// Cannot add new operand if there is already non operand with the same
// name.
if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
return false;
}
}

symbolInfoMap.emplace(key, symInfo);
return true;
}

bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
StringRef name = getValuePackName(symbol);
return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));

return symbolInfoMap.count(inserted->first) == 1;
}

bool SymbolInfoMap::bindValue(StringRef symbol) {
return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getValue());
return symbolInfoMap.count(inserted->first) == 1;
}

bool SymbolInfoMap::contains(StringRef symbol) const {
return find(symbol) != symbolInfoMap.end();
}

SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
StringRef name = getValuePackName(key);
std::string name = getValuePackName(key).str();

return symbolInfoMap.find(name);
}

SymbolInfoMap::const_iterator
SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
int argIndex) const {
std::string name = getValuePackName(key).str();
auto range = symbolInfoMap.equal_range(name);

for (auto it = range.first; it != range.second; ++it) {
if (it->second.op == &op && it->second.argIndex == argIndex) {
return it;
}
}

return symbolInfoMap.end();
}

std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
std::string name = getValuePackName(key).str();

return symbolInfoMap.equal_range(name);
}

int SymbolInfoMap::count(StringRef key) const {
std::string name = getValuePackName(key).str();
return symbolInfoMap.count(name);
}

int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
StringRef name = getValuePackName(symbol);
if (name != symbol) {
Expand All @@ -388,7 +439,7 @@ int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
return 1;
}
// Otherwise, find how many it represents by querying the symbol's info.
return find(name)->getValue().getStaticValueCount();
return find(name)->second.getStaticValueCount();
}

std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
Expand All @@ -397,27 +448,58 @@ std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
int index = -1;
StringRef name = getValuePackName(symbol, &index);

auto it = symbolInfoMap.find(name);
auto it = symbolInfoMap.find(name.str());
if (it == symbolInfoMap.end()) {
auto error = formatv("referencing unbound symbol '{0}'", symbol);
PrintFatalError(loc, error);
}

return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
return it->second.getValueAndRangeUse(name, index, fmt, separator);
}

std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
const char *separator) const {
int index = -1;
StringRef name = getValuePackName(symbol, &index);

auto it = symbolInfoMap.find(name);
auto it = symbolInfoMap.find(name.str());
if (it == symbolInfoMap.end()) {
auto error = formatv("referencing unbound symbol '{0}'", symbol);
PrintFatalError(loc, error);
}

return it->getValue().getAllRangeUse(name, index, fmt, separator);
return it->second.getAllRangeUse(name, index, fmt, separator);
}

void SymbolInfoMap::assignUniqueAlternativeNames() {
llvm::StringSet<> usedNames;

for (auto symbolInfoIt = symbolInfoMap.begin();
symbolInfoIt != symbolInfoMap.end();) {
auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
auto startRange = range.first;
auto endRange = range.second;

auto operandName = symbolInfoIt->first;
int startSearchIndex = 0;
for (++startRange; startRange != endRange; ++startRange) {
// Current operand name is not unique, find a unique one
// and set the alternative name.
for (int i = startSearchIndex;; ++i) {
std::string alternativeName = operandName + std::to_string(i);
if (!usedNames.contains(alternativeName) &&
symbolInfoMap.count(alternativeName) == 0) {
usedNames.insert(alternativeName);
startRange->second.alternativeName = alternativeName;
startSearchIndex = i + 1;

break;
}
}
}

symbolInfoIt = endRange;
}
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -445,6 +527,10 @@ void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");

LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
infoMap.assignUniqueAlternativeNames();
LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
}

void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
Expand Down
26 changes: 26 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Expand Up @@ -619,6 +619,32 @@ def OpM : TEST_Op<"op_m"> {
let results = (outs I32);
}

def OpN : TEST_Op<"op_n"> {
let arguments = (ins I32, I32);
let results = (outs I32);
}

def OpO : TEST_Op<"op_o"> {
let arguments = (ins I32);
let results = (outs I32);
}

def OpP : TEST_Op<"op_p"> {
let arguments = (ins I32, I32, I32, I32, I32, I32);
let results = (outs I32);
}

// Test same operand name enforces equality condition check.
def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;

// Test when equality is enforced at different depth.
def TestNestedOpEqualArgsPattern :
Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;

// Test multiple equal arguments check enforced.
def TestMultipleEqualArgsPattern :
Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;

// Test for memrefs normalization of an op with normalizable memrefs.
def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
Expand Down
58 changes: 58 additions & 0 deletions mlir/test/mlir-tblgen/pattern.mlir
Expand Up @@ -111,6 +111,64 @@ func @verifyManyArgs(%arg: i32) {
return
}

// CHECK-LABEL: verifyEqualArgs
func @verifyEqualArgs(%arg0: i32, %arg1: i32) {
// def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;

// CHECK: "test.op_o"(%arg0) : (i32) -> i32
"test.op_n"(%arg0, %arg0) : (i32, i32) -> (i32)

// CHECK: "test.op_n"(%arg0, %arg1) : (i32, i32) -> i32
"test.op_n"(%arg0, %arg1) : (i32, i32) -> (i32)

return
}

// CHECK-LABEL: verifyNestedOpEqualArgs
func @verifyNestedOpEqualArgs(
%arg0: i32, %arg1: i32, %arg2 : i32, %arg3 : i32, %arg4 : i32, %arg5 : i32) {
// def TestNestedOpEqualArgsPattern :
// Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;

// CHECK: %arg1
%0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
: (i32, i32, i32, i32, i32, i32) -> (i32)
%1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)

// CHECK: test.op_p
// CHECK: test.op_n
%2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
: (i32, i32, i32, i32, i32, i32) -> (i32)
%3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)

return
}

// CHECK-LABEL: verifyMultipleEqualArgs
func @verifyMultipleEqualArgs(
%arg0: i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) {
// def TestMultipleEqualArgsPattern :
// Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;

// CHECK: "test.op_n"(%arg2, %arg1) : (i32, i32) -> i32
"test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg1, %arg2) :
(i32, i32, i32, i32 , i32, i32) -> i32

// CHECK: test.op_p
"test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg0, %arg2) :
(i32, i32, i32, i32 , i32, i32) -> i32

// CHECK: test.op_p
"test.op_p"(%arg0, %arg1, %arg1, %arg0, %arg1, %arg2) :
(i32, i32, i32, i32 , i32, i32) -> i32

// CHECK: test.op_p
"test.op_p"(%arg0, %arg1, %arg2, %arg2, %arg3, %arg4) :
(i32, i32, i32, i32 , i32, i32) -> i32

return
}

//===----------------------------------------------------------------------===//
// Test Symbol Binding
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 7271c1b

Please sign in to comment.