Skip to content

Commit

Permalink
[mlir][Asm] Add support for resolving operation locations after parsi…
Browse files Browse the repository at this point in the history
…ng has finished

This revision adds support in the parser/printer for "deferrable" aliases, i.e. those that can be resolved after printing has finished. This allows for printing aliases for operation locations after the module instead of before, i.e. this is now supported:

```
"foo.op"() : () -> () loc(#loc)

#loc = loc("some_location")
```

Differential Revision: https://reviews.llvm.org/D91227
  • Loading branch information
River707 committed Nov 13, 2020
1 parent 92b036d commit 48e8129
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 117 deletions.
186 changes: 121 additions & 65 deletions mlir/lib/IR/AsmPrinter.cpp
Expand Up @@ -225,6 +225,38 @@ static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
//===----------------------------------------------------------------------===//

namespace {
/// This class represents a specific instance of a symbol Alias.
class SymbolAlias {
public:
SymbolAlias(StringRef name, bool isDeferrable)
: name(name), suffixIndex(0), hasSuffixIndex(false),
isDeferrable(isDeferrable) {}
SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
: name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
isDeferrable(isDeferrable) {}

/// Print this alias to the given stream.
void print(raw_ostream &os) const {
os << name;
if (hasSuffixIndex)
os << suffixIndex;
}

/// Returns true if this alias supports deferred resolution when parsing.
bool canBeDeferred() const { return isDeferrable; }

private:
/// The main name of the alias.
StringRef name;
/// The optional suffix index of the alias, if multiple aliases had the same
/// name.
uint32_t suffixIndex : 30;
/// A flag indicating whether this alias has a suffix or not.
bool hasSuffixIndex : 1;
/// A flag indicating whether this alias may be deferred or not.
bool isDeferrable : 1;
};

/// This class represents a utility that initializes the set of attribute and
/// type aliases, without the need to store the extra information within the
/// main AliasState class or pass it around via function arguments.
Expand All @@ -236,24 +268,26 @@ class AliasInitializer {
: interfaces(interfaces), aliasAllocator(aliasAllocator),
aliasOS(aliasBuffer) {}

void initialize(
Operation *op, const OpPrintingFlags &printerFlags,
llvm::MapVector<Attribute, std::pair<StringRef, Optional<int>>>
&attrToAlias,
llvm::MapVector<Type, std::pair<StringRef, Optional<int>>> &typeToAlias);
void initialize(Operation *op, const OpPrintingFlags &printerFlags,
llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
llvm::MapVector<Type, SymbolAlias> &typeToAlias);

/// Visit the given attribute to see if it has an alias.
void visit(Attribute attr);
/// Visit the given attribute to see if it has an alias. `canBeDeferred` is
/// set to true if the originator of this attribute can resolve the alias
/// after parsing has completed (e.g. in the case of operation locations).
void visit(Attribute attr, bool canBeDeferred = false);

/// Visit the given type to see if it has an alias.
void visit(Type type);

private:
/// Try to generate an alias for the provided symbol. If an alias is
/// generated, the provided alias mapping and reverse mapping are updated.
/// Returns success if an alias was generated, failure otherwise.
template <typename T>
void generateAlias(T symbol,
llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);
LogicalResult
generateAlias(T symbol,
llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);

/// The set of asm interfaces within the context.
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
Expand All @@ -268,6 +302,9 @@ class AliasInitializer {
/// The set of visited attributes.
DenseSet<Attribute> visitedAttributes;

/// The set of attributes that have aliases *and* can be deferred.
DenseSet<Attribute> deferrableAttributes;

/// The set of visited types.
DenseSet<Type> visitedTypes;

Expand All @@ -291,7 +328,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
void print(Operation *op) {
// Visit the operation location.
if (printerFlags.shouldPrintDebugInfo())
initializer.visit(op->getLoc());
initializer.visit(op->getLoc(), /*canBeDeferred=*/true);

// If requested, always print the generic form.
if (!printerFlags.shouldPrintGenericOpForm()) {
Expand Down Expand Up @@ -464,9 +501,10 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
/// Given a collection of aliases and symbols, initialize a mapping from a
/// symbol to a given alias.
template <typename T>
static void initializeAliases(
llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
llvm::MapVector<T, std::pair<StringRef, Optional<int>>> &symbolToAlias) {
static void
initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
llvm::MapVector<T, SymbolAlias> &symbolToAlias,
DenseSet<T> *deferrableAliases = nullptr) {
std::vector<std::pair<StringRef, std::vector<T>>> aliases =
aliasToSymbol.takeVector();
llvm::array_pod_sort(aliases.begin(), aliases.end(),
Expand All @@ -477,34 +515,50 @@ static void initializeAliases(
for (auto &it : aliases) {
// If there is only one instance for this alias, use the name directly.
if (it.second.size() == 1) {
symbolToAlias.insert({it.second.front(), {it.first, llvm::None}});
T symbol = it.second.front();
bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
continue;
}
// Otherwise, add the index to the name.
for (int i = 0, e = it.second.size(); i < e; ++i)
symbolToAlias.insert({it.second[i], {it.first, i}});
for (int i = 0, e = it.second.size(); i < e; ++i) {
T symbol = it.second[i];
bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
}
}
}

void AliasInitializer::initialize(
Operation *op, const OpPrintingFlags &printerFlags,
llvm::MapVector<Attribute, std::pair<StringRef, Optional<int>>>
&attrToAlias,
llvm::MapVector<Type, std::pair<StringRef, Optional<int>>> &typeToAlias) {
llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
// Use a dummy printer when walking the IR so that we can collect the
// attributes/types that will actually be used during printing when
// considering aliases.
DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
aliasPrinter.print(op);

// Initialize the aliases sorted by name.
initializeAliases(aliasToAttr, attrToAlias);
initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
initializeAliases(aliasToType, typeToAlias);
}

void AliasInitializer::visit(Attribute attr) {
if (!visitedAttributes.insert(attr).second)
void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
if (!visitedAttributes.insert(attr).second) {
// If this attribute already has an alias and this instance can't be
// deferred, make sure that the alias isn't deferred.
if (!canBeDeferred)
deferrableAttributes.erase(attr);
return;
}

// Try to generate an alias for this attribute.
if (succeeded(generateAlias(attr, aliasToAttr))) {
if (canBeDeferred)
deferrableAttributes.insert(attr);
return;
}

if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
for (Attribute element : arrayAttr.getValue())
Expand All @@ -515,15 +569,16 @@ void AliasInitializer::visit(Attribute attr) {
} else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
visit(typeAttr.getValue());
}

// Try to generate an alias for this attribute.
generateAlias(attr, aliasToAttr);
}

void AliasInitializer::visit(Type type) {
if (!visitedTypes.insert(type).second)
return;

// Try to generate an alias for this type.
if (succeeded(generateAlias(type, aliasToType)))
return;

// Visit several subtypes that contain types or atttributes.
if (auto funcType = type.dyn_cast<FunctionType>()) {
// Visit input and result types for functions.
Expand All @@ -539,13 +594,10 @@ void AliasInitializer::visit(Type type) {
for (auto map : memref.getAffineMaps())
visit(AffineMapAttr::get(map));
}

// Try to generate an alias for this type.
generateAlias(type, aliasToType);
}

template <typename T>
void AliasInitializer::generateAlias(
LogicalResult AliasInitializer::generateAlias(
T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
SmallString<16> tempBuffer;
for (const auto &interface : interfaces) {
Expand All @@ -559,8 +611,9 @@ void AliasInitializer::generateAlias(

aliasToSymbol[name].push_back(symbol);
aliasBuffer.clear();
break;
return success();
}
return failure();
}

//===----------------------------------------------------------------------===//
Expand All @@ -580,21 +633,31 @@ class AliasState {
/// Returns success if an alias was printed, failure otherwise.
LogicalResult getAlias(Attribute attr, raw_ostream &os) const;

/// Print all of the referenced attribute aliases.
void printAttributeAliases(raw_ostream &os, NewLineCounter &newLine) const;

/// Get an alias for the given type if it has one and print it in `os`.
/// Returns success if an alias was printed, failure otherwise.
LogicalResult getAlias(Type ty, raw_ostream &os) const;

/// Print all of the referenced type aliases.
void printTypeAliases(raw_ostream &os, NewLineCounter &newLine) const;
/// Print all of the referenced aliases that can not be resolved in a deferred
/// manner.
void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
printAliases(os, newLine, /*isDeferred=*/false);
}

/// Print all of the referenced aliases that support deferred resolution.
void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
printAliases(os, newLine, /*isDeferred=*/true);
}

private:
/// Mapping between attribute and a pair comprised of a base alias name and a
/// count suffix. If the suffix is set to None, it is not displayed.
llvm::MapVector<Attribute, std::pair<StringRef, Optional<int>>> attrToAlias;
llvm::MapVector<Type, std::pair<StringRef, Optional<int>>> typeToAlias;
/// Print all of the referenced aliases that support the provided resolution
/// behavior.
void printAliases(raw_ostream &os, NewLineCounter &newLine,
bool isDeferred) const;

/// Mapping between attribute and alias.
llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
/// Mapping between type and alias.
llvm::MapVector<Type, SymbolAlias> typeToAlias;

/// An allocator used for alias names.
llvm::BumpPtrAllocator aliasAllocator;
Expand All @@ -608,44 +671,34 @@ void AliasState::initialize(
initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
}

static void printAlias(raw_ostream &os,
const std::pair<StringRef, Optional<int>> &alias,
char prefix) {
os << prefix << alias.first;
if (alias.second)
os << *alias.second;
}

LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
auto it = attrToAlias.find(attr);
if (it == attrToAlias.end())
return failure();

printAlias(os, it->second, '#');
it->second.print(os << '#');
return success();
}

void AliasState::printAttributeAliases(raw_ostream &os,
NewLineCounter &newLine) const {
for (const auto &it : attrToAlias) {
printAlias(os, it.second, '#');
os << " = " << it.first << newLine;
}
}

LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
auto it = typeToAlias.find(ty);
if (it == typeToAlias.end())
return failure();

printAlias(os, it->second, '!');
it->second.print(os << '!');
return success();
}

void AliasState::printTypeAliases(raw_ostream &os,
NewLineCounter &newLine) const {
for (const auto &it : typeToAlias) {
printAlias(os, it.second, '!');
void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
bool isDeferred) const {
auto filterFn = [=](const auto &aliasIt) {
return aliasIt.second.canBeDeferred() == isDeferred;
};
for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
it.second.print(os << '#');
os << " = " << it.first << newLine;
}
for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
it.second.print(os << '!');
os << " = " << it.first << newLine;
}
}
Expand Down Expand Up @@ -2237,12 +2290,15 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
} // end anonymous namespace

void OperationPrinter::print(ModuleOp op) {
// Output the aliases at the top level.
state->getAliasState().printAttributeAliases(os, newLine);
state->getAliasState().printTypeAliases(os, newLine);
// Output the aliases at the top level that can't be deferred.
state->getAliasState().printNonDeferredAliases(os, newLine);

// Print the module.
print(op.getOperation());
os << newLine;

// Output the aliases at the top level that can be deferred.
state->getAliasState().printDeferredAliases(os, newLine);
}

void OperationPrinter::print(Operation *op) {
Expand Down
32 changes: 0 additions & 32 deletions mlir/lib/Parser/LocationParser.cpp
Expand Up @@ -177,35 +177,3 @@ ParseResult Parser::parseLocationInstance(LocationAttr &loc) {

return emitError("expected location instance");
}

ParseResult Parser::parseOptionalTrailingLocation(Location &loc) {
// If there is a 'loc' we parse a trailing location.
if (!consumeIf(Token::kw_loc))
return success();
if (parseToken(Token::l_paren, "expected '(' in location"))
return failure();
Token tok = getToken();

// Check to see if we are parsing a location alias.
LocationAttr directLoc;
if (tok.is(Token::hash_identifier)) {
// TODO: This should be reworked a bit to allow for resolving operation
// locations to aliases after the operation has already been parsed(i.e.
// allow post parse location fixups).
Attribute attr = parseExtendedAttr(Type());
if (!attr)
return failure();
if (!(directLoc = attr.dyn_cast<LocationAttr>()))
return emitError(tok.getLoc()) << "expected location, but found " << attr;

// Otherwise, we parse the location directly.
} else if (parseLocationInstance(directLoc)) {
return failure();
}

if (parseToken(Token::r_paren, "expected ')' in location"))
return failure();

loc = directLoc;
return success();
}

0 comments on commit 48e8129

Please sign in to comment.