Skip to content

Commit

Permalink
Revert "[DDR] Introduce implicit equality check for the source patter…
Browse files Browse the repository at this point in the history
…n operands with the same name."

This reverts commit 7271c1b.

This broke the gcc-5 build:

/usr/include/c++/5/ext/new_allocator.h:120:4: error: no matching function for call to 'std::pair<const std::__cxx11::basic_string<char>, mlir::tblgen::SymbolInfoMap::SymbolInfo>::pair(llvm::StringRef&, mlir::tblgen::SymbolInfoMap::SymbolInfo)'
  { ::new((void *)__p) _Up(std::forward<_Args>(__args)...); }
    ^
In file included from /usr/include/c++/5/utility:70:0,
                 from llvm/include/llvm/Support/type_traits.h:18,
                 from llvm/include/llvm/Support/Casting.h:18,
                 from mlir/include/mlir/Support/LLVM.h:24,
                 from mlir/include/mlir/TableGen/Pattern.h:17,
                 from mlir/lib/TableGen/Pattern.cpp:14:
/usr/include/c++/5/bits/stl_pair.h:206:9: note: candidate: template<class ... _Args1, long unsigned int ..._Indexes1, class ... _Args2, long unsigned int ..._Indexes2> std::pair<_T1, _T2>::pair(std::tuple<_Args1 ...>&, std::tuple<_Args2 ...>&, std::_Index_tuple<_Indexes1 ...>, std::_Index_tuple<_Indexes2 ...>)
         pair(tuple<_Args1...>&, tuple<_Args2...>&,
         ^
  • Loading branch information
joker-eph committed Oct 14, 2020
1 parent 5fe53c4 commit 0b793c4
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 250 deletions.
29 changes: 3 additions & 26 deletions mlir/include/mlir/TableGen/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"

#include <unordered_map>

namespace llvm {
class DagInit;
class Init;
Expand Down Expand Up @@ -230,9 +228,6 @@ class SymbolInfoMap {
// value bound by this symbol.
std::string getVarDecl(StringRef name) const;

// Returns a variable name for the symbol named as `name`.
std::string getVarName(StringRef name) const;

private:
// Allow SymbolInfoMap to access private methods.
friend class SymbolInfoMap;
Expand Down Expand Up @@ -290,12 +285,9 @@ class SymbolInfoMap {
Kind kind; // The kind of the bound entity
// The argument index (for `Attr` and `Operand` only)
Optional<int> argIndex;
// Alternative name for the symbol. It is used in case the name
// is not unique. Applicable for `Operand` only.
Optional<std::string> alternativeName;
};

using BaseT = std::unordered_multimap<std::string, SymbolInfo>;
using BaseT = llvm::StringMap<SymbolInfo>;

// Iterators for accessing all symbols.
using iterator = BaseT::iterator;
Expand All @@ -308,7 +300,7 @@ class SymbolInfoMap {
const_iterator end() const { return symbolInfoMap.end(); }

// Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
// Returns false if `symbol` is already bound and symbols are not operands.
// Returns false if `symbol` is already bound.
bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);

// Binds the given `symbol` to the results the given `op`. Returns false if
Expand All @@ -325,18 +317,6 @@ class SymbolInfoMap {
// Returns an iterator to the information of the given symbol named as `key`.
const_iterator find(StringRef key) const;

// Returns an iterator to the information of the given symbol named as `key`,
// with index `argIndex` for operator `op`.
const_iterator findBoundSymbol(StringRef key, const Operator &op,
int argIndex) const;

// Returns the bounds of a range that includes all the elements which
// bind to the `key`.
std::pair<iterator, iterator> getRangeOfEqualElements(StringRef key);

// Returns number of times symbol named as `key` was used.
int count(StringRef key) const;

// Returns the number of static values of the given `symbol` corresponds to.
// A static value is an operand/result declared in ODS. Normally a symbol only
// represents one static value, but symbols bound to op results can represent
Expand All @@ -358,9 +338,6 @@ class SymbolInfoMap {
std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}",
const char *separator = ", ") const;

// Assign alternative unique names to Operands that have equal names.
void assignUniqueAlternativeNames();

// Splits the given `symbol` into a value pack name and an index. Returns the
// value pack name and writes the index to `index` on success. Returns
// `symbol` itself if it does not contain an index.
Expand All @@ -370,7 +347,7 @@ class SymbolInfoMap {
static StringRef getValuePackName(StringRef symbol, int *index = nullptr);

private:
BaseT symbolInfoMap;
llvm::StringMap<SymbolInfo> symbolInfoMap;

// Pattern instantiation location. This is intended to be used as parameter
// to PrintFatalError() to report errors.
Expand Down
108 changes: 11 additions & 97 deletions mlir/lib/TableGen/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,6 @@ int SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
llvm_unreachable("unknown kind");
}

std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const {
return alternativeName.hasValue() ? alternativeName.getValue() : name.str();
}

std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': ");
switch (kind) {
Expand All @@ -223,9 +219,8 @@ std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
case Kind::Operand: {
// Use operand range for captured operands (to support potential variadic
// operands).
return std::string(
formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n",
getVarName(name)));
return std::string(formatv(
"::mlir::Operation::operand_range {0}(op0->getOperands());\n", name));
}
case Kind::Value: {
return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name));
Expand Down Expand Up @@ -364,73 +359,27 @@ bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
? SymbolInfo::getAttr(&op, argIndex)
: SymbolInfo::getOperand(&op, argIndex);

std::string key = symbol.str();
if (auto numberOfEntries = symbolInfoMap.count(key)) {
// Only non unique name for the operand is supported.
if (symInfo.kind != SymbolInfo::Kind::Operand) {
return false;
}

// Cannot add new operand if there is already non operand with the same
// name.
if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) {
return false;
}
}

symbolInfoMap.emplace(key, symInfo);
return true;
return symbolInfoMap.insert({symbol, symInfo}).second;
}

bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
StringRef name = getValuePackName(symbol);
auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op));

return symbolInfoMap.count(inserted->first) == 1;
return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
}

bool SymbolInfoMap::bindValue(StringRef symbol) {
auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getValue());
return symbolInfoMap.count(inserted->first) == 1;
return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
}

bool SymbolInfoMap::contains(StringRef symbol) const {
return find(symbol) != symbolInfoMap.end();
}

SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const {
std::string name = getValuePackName(key).str();

StringRef name = getValuePackName(key);
return symbolInfoMap.find(name);
}

SymbolInfoMap::const_iterator
SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op,
int argIndex) const {
std::string name = getValuePackName(key).str();
auto range = symbolInfoMap.equal_range(name);

for (auto it = range.first; it != range.second; ++it) {
if (it->second.op == &op && it->second.argIndex == argIndex) {
return it;
}
}

return symbolInfoMap.end();
}

std::pair<SymbolInfoMap::iterator, SymbolInfoMap::iterator>
SymbolInfoMap::getRangeOfEqualElements(StringRef key) {
std::string name = getValuePackName(key).str();

return symbolInfoMap.equal_range(name);
}

int SymbolInfoMap::count(StringRef key) const {
std::string name = getValuePackName(key).str();
return symbolInfoMap.count(name);
}

int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
StringRef name = getValuePackName(symbol);
if (name != symbol) {
Expand All @@ -439,7 +388,7 @@ int SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
return 1;
}
// Otherwise, find how many it represents by querying the symbol's info.
return find(name)->second.getStaticValueCount();
return find(name)->getValue().getStaticValueCount();
}

std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
Expand All @@ -448,58 +397,27 @@ std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol,
int index = -1;
StringRef name = getValuePackName(symbol, &index);

auto it = symbolInfoMap.find(name.str());
auto it = symbolInfoMap.find(name);
if (it == symbolInfoMap.end()) {
auto error = formatv("referencing unbound symbol '{0}'", symbol);
PrintFatalError(loc, error);
}

return it->second.getValueAndRangeUse(name, index, fmt, separator);
return it->getValue().getValueAndRangeUse(name, index, fmt, separator);
}

std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt,
const char *separator) const {
int index = -1;
StringRef name = getValuePackName(symbol, &index);

auto it = symbolInfoMap.find(name.str());
auto it = symbolInfoMap.find(name);
if (it == symbolInfoMap.end()) {
auto error = formatv("referencing unbound symbol '{0}'", symbol);
PrintFatalError(loc, error);
}

return it->second.getAllRangeUse(name, index, fmt, separator);
}

void SymbolInfoMap::assignUniqueAlternativeNames() {
llvm::StringSet<> usedNames;

for (auto symbolInfoIt = symbolInfoMap.begin();
symbolInfoIt != symbolInfoMap.end();) {
auto range = symbolInfoMap.equal_range(symbolInfoIt->first);
auto startRange = range.first;
auto endRange = range.second;

auto operandName = symbolInfoIt->first;
int startSearchIndex = 0;
for (++startRange; startRange != endRange; ++startRange) {
// Current operand name is not unique, find a unique one
// and set the alternative name.
for (int i = startSearchIndex;; ++i) {
std::string alternativeName = operandName + std::to_string(i);
if (!usedNames.contains(alternativeName) &&
symbolInfoMap.count(alternativeName) == 0) {
usedNames.insert(alternativeName);
startRange->second.alternativeName = alternativeName;
startSearchIndex = i + 1;

break;
}
}
}

symbolInfoIt = endRange;
}
return it->getValue().getAllRangeUse(name, index, fmt, separator);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -527,10 +445,6 @@ void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) {
LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n");
collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n");

LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n");
infoMap.assignUniqueAlternativeNames();
LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n");
}

void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) {
Expand Down
26 changes: 0 additions & 26 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -619,32 +619,6 @@ def OpM : TEST_Op<"op_m"> {
let results = (outs I32);
}

def OpN : TEST_Op<"op_n"> {
let arguments = (ins I32, I32);
let results = (outs I32);
}

def OpO : TEST_Op<"op_o"> {
let arguments = (ins I32);
let results = (outs I32);
}

def OpP : TEST_Op<"op_p"> {
let arguments = (ins I32, I32, I32, I32, I32, I32);
let results = (outs I32);
}

// Test same operand name enforces equality condition check.
def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;

// Test when equality is enforced at different depth.
def TestNestedOpEqualArgsPattern :
Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;

// Test multiple equal arguments check enforced.
def TestMultipleEqualArgsPattern :
Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;

// Test for memrefs normalization of an op with normalizable memrefs.
def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
Expand Down
58 changes: 0 additions & 58 deletions mlir/test/mlir-tblgen/pattern.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -111,64 +111,6 @@ func @verifyManyArgs(%arg: i32) {
return
}

// CHECK-LABEL: verifyEqualArgs
func @verifyEqualArgs(%arg0: i32, %arg1: i32) {
// def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;

// CHECK: "test.op_o"(%arg0) : (i32) -> i32
"test.op_n"(%arg0, %arg0) : (i32, i32) -> (i32)

// CHECK: "test.op_n"(%arg0, %arg1) : (i32, i32) -> i32
"test.op_n"(%arg0, %arg1) : (i32, i32) -> (i32)

return
}

// CHECK-LABEL: verifyNestedOpEqualArgs
func @verifyNestedOpEqualArgs(
%arg0: i32, %arg1: i32, %arg2 : i32, %arg3 : i32, %arg4 : i32, %arg5 : i32) {
// def TestNestedOpEqualArgsPattern :
// Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;

// CHECK: %arg1
%0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
: (i32, i32, i32, i32, i32, i32) -> (i32)
%1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32)

// CHECK: test.op_p
// CHECK: test.op_n
%2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5)
: (i32, i32, i32, i32, i32, i32) -> (i32)
%3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32)

return
}

// CHECK-LABEL: verifyMultipleEqualArgs
func @verifyMultipleEqualArgs(
%arg0: i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) {
// def TestMultipleEqualArgsPattern :
// Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;

// CHECK: "test.op_n"(%arg2, %arg1) : (i32, i32) -> i32
"test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg1, %arg2) :
(i32, i32, i32, i32 , i32, i32) -> i32

// CHECK: test.op_p
"test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg0, %arg2) :
(i32, i32, i32, i32 , i32, i32) -> i32

// CHECK: test.op_p
"test.op_p"(%arg0, %arg1, %arg1, %arg0, %arg1, %arg2) :
(i32, i32, i32, i32 , i32, i32) -> i32

// CHECK: test.op_p
"test.op_p"(%arg0, %arg1, %arg2, %arg2, %arg3, %arg4) :
(i32, i32, i32, i32 , i32, i32) -> i32

return
}

//===----------------------------------------------------------------------===//
// Test Symbol Binding
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 0b793c4

Please sign in to comment.