Skip to content

Commit

Permalink
[mlir-lsp-server] Add support for tracking the use/def chains of symbols
Browse files Browse the repository at this point in the history
This revision adds assembly state tracking for uses of symbols, allowing for go-to-definition and references support for SymbolRefAttrs.

Differential Revision: https://reviews.llvm.org/D103585
  • Loading branch information
River707 committed Jun 3, 2021
1 parent e42def6 commit d6af89b
Show file tree
Hide file tree
Showing 9 changed files with 326 additions and 61 deletions.
36 changes: 33 additions & 3 deletions mlir/include/mlir/Parser/AsmParserState.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class Block;
class BlockArgument;
class FileLineColLoc;
class Operation;
class OperationName;
class SymbolRefAttr;
class Value;

/// This class represents state from a parsed MLIR textual format string. It is
Expand Down Expand Up @@ -61,6 +63,10 @@ class AsmParserState {

/// Source definitions for any result groups of this operation.
SmallVector<std::pair<unsigned, SMDefinition>> resultGroups;

/// If this operation is a symbol operation, this vector contains symbol
/// uses of this operation.
SmallVector<llvm::SMRange> symbolUses;
};

/// This class represents the information for a block definition within the
Expand Down Expand Up @@ -112,17 +118,41 @@ class AsmParserState {
// Populate State
//===--------------------------------------------------------------------===//

/// Add a definition of the given operation.
void addDefinition(
Operation *op, llvm::SMRange location,
/// Initialize the state in preparation for populating more parser state under
/// the given top-level operation.
void initialize(Operation *topLevelOp);

/// Finalize any in-progress parser state under the given top-level operation.
void finalize(Operation *topLevelOp);

/// Start a definition for an operation with the given name.
void startOperationDefinition(const OperationName &opName);

/// Finalize the most recently started operation definition.
void finalizeOperationDefinition(
Operation *op, llvm::SMRange nameLoc,
ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups = llvm::None);

/// Start a definition for a region nested under the current operation.
void startRegionDefinition();

/// Finalize the most recently started region definition.
void finalizeRegionDefinition();

/// Add a definition of the given entity.
void addDefinition(Block *block, llvm::SMLoc location);
void addDefinition(BlockArgument blockArg, llvm::SMLoc location);

/// Add a source uses of the given value.
void addUses(Value value, ArrayRef<llvm::SMLoc> locations);
void addUses(Block *block, ArrayRef<llvm::SMLoc> locations);

/// Add source uses for all the references nested under `refAttr`. The
/// provided `locations` should match 1-1 with the number of references in
/// `refAttr`, i.e.:
/// nestedReferences.size() + /*leafReference=*/1 == refLocations.size()
void addUses(SymbolRefAttr refAttr, ArrayRef<llvm::SMRange> refLocations);

/// Refine the `oldValue` to the `newValue`. This is used to indicate that
/// `oldValue` was a placeholder, and the uses of it should really refer to
/// `newValue`.
Expand Down
125 changes: 121 additions & 4 deletions mlir/lib/Parser/AsmParserState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "mlir/Parser/AsmParserState.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"

using namespace mlir;

Expand All @@ -16,6 +17,27 @@ using namespace mlir;
//===----------------------------------------------------------------------===//

struct AsmParserState::Impl {
/// A map from a SymbolRefAttr to a range of uses.
using SymbolUseMap = DenseMap<Attribute, SmallVector<llvm::SMRange>>;

struct PartialOpDef {
explicit PartialOpDef(const OperationName &opName) {
const auto *abstractOp = opName.getAbstractOperation();
if (abstractOp && abstractOp->hasTrait<OpTrait::SymbolTable>())
symbolTable = std::make_unique<SymbolUseMap>();
}

/// Return if this operation is a symbol table.
bool isSymbolTable() const { return symbolTable.get(); }

/// If this operation is a symbol table, the following contains symbol uses
/// within this operation.
std::unique_ptr<SymbolUseMap> symbolTable;
};

/// Resolve any symbol table uses under the given partial operation.
void resolveSymbolUses(Operation *op, PartialOpDef &opDef);

/// A mapping from operations in the input source file to their parser state.
SmallVector<std::unique_ptr<OperationDefinition>> operations;
DenseMap<Operation *, unsigned> operationToIdx;
Expand All @@ -27,8 +49,38 @@ struct AsmParserState::Impl {
/// A set of value definitions that are placeholders for forward references.
/// This map should be empty if the parser finishes successfully.
DenseMap<Value, SmallVector<llvm::SMLoc>> placeholderValueUses;

/// A stack of partial operation definitions that have been started but not
/// yet finalized.
SmallVector<PartialOpDef> partialOperations;

/// A stack of symbol use scopes. This is used when collecting symbol table
/// uses during parsing.
SmallVector<SymbolUseMap *> symbolUseScopes;

/// A symbol table containing all of the symbol table operations in the IR.
SymbolTableCollection symbolTable;
};

void AsmParserState::Impl::resolveSymbolUses(Operation *op,
PartialOpDef &opDef) {
assert(opDef.isSymbolTable() && "expected op to be a symbol table");

SmallVector<Operation *> symbolOps;
for (auto &it : *opDef.symbolTable) {
symbolOps.clear();
if (failed(symbolTable.lookupSymbolIn(op, it.first.cast<SymbolRefAttr>(),
symbolOps)))
continue;

for (const auto &symIt : llvm::zip(symbolOps, it.second)) {
auto opIt = operationToIdx.find(std::get<0>(symIt));
if (opIt != operationToIdx.end())
operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
}
}
}

//===----------------------------------------------------------------------===//
// AsmParserState
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -77,17 +129,70 @@ llvm::SMRange AsmParserState::convertIdLocToRange(llvm::SMLoc loc) {
//===----------------------------------------------------------------------===//
// Populate State

void AsmParserState::addDefinition(
Operation *op, llvm::SMRange location,
void AsmParserState::initialize(Operation *topLevelOp) {
startOperationDefinition(topLevelOp->getName());

// If the top-level operation is a symbol table, push a new symbol scope.
Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
if (partialOpDef.isSymbolTable())
impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
}

void AsmParserState::finalize(Operation *topLevelOp) {
assert(!impl->partialOperations.empty() &&
"expected valid partial operation definition");
Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();

// If this operation is a symbol table, resolve any symbol uses.
if (partialOpDef.isSymbolTable())
impl->resolveSymbolUses(topLevelOp, partialOpDef);
}

void AsmParserState::startOperationDefinition(const OperationName &opName) {
impl->partialOperations.emplace_back(opName);
}

void AsmParserState::finalizeOperationDefinition(
Operation *op, llvm::SMRange nameLoc,
ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups) {
assert(!impl->partialOperations.empty() &&
"expected valid partial operation definition");
Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();

// Build the full operation definition.
std::unique_ptr<OperationDefinition> def =
std::make_unique<OperationDefinition>(op, location);
std::make_unique<OperationDefinition>(op, nameLoc);
for (auto &resultGroup : resultGroups)
def->resultGroups.emplace_back(resultGroup.first,
convertIdLocToRange(resultGroup.second));

impl->operationToIdx.try_emplace(op, impl->operations.size());
impl->operations.emplace_back(std::move(def));

// If this operation is a symbol table, resolve any symbol uses.
if (partialOpDef.isSymbolTable())
impl->resolveSymbolUses(op, partialOpDef);
}

void AsmParserState::startRegionDefinition() {
assert(!impl->partialOperations.empty() &&
"expected valid partial operation definition");

// If the parent operation of this region is a symbol table, we also push a
// new symbol scope.
Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
if (partialOpDef.isSymbolTable())
impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
}

void AsmParserState::finalizeRegionDefinition() {
assert(!impl->partialOperations.empty() &&
"expected valid partial operation definition");

// If the parent operation of this region is a symbol table, pop the symbol
// scope for this region.
Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
if (partialOpDef.isSymbolTable())
impl->symbolUseScopes.pop_back();
}

void AsmParserState::addDefinition(Block *block, llvm::SMLoc location) {
Expand Down Expand Up @@ -169,6 +274,18 @@ void AsmParserState::addUses(Block *block, ArrayRef<llvm::SMLoc> locations) {
def.definition.uses.push_back(convertIdLocToRange(loc));
}

void AsmParserState::addUses(SymbolRefAttr refAttr,
ArrayRef<llvm::SMRange> locations) {
// Ignore this symbol if no scopes are active.
if (impl->symbolUseScopes.empty())
return;

assert((refAttr.getNestedReferences().size() + 1) == locations.size() &&
"expected the same number of references as provided locations");
(*impl->symbolUseScopes.back())[refAttr].append(locations.begin(),
locations.end());
}

void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
auto it = impl->placeholderValueUses.find(oldValue);
assert(it != impl->placeholderValueUses.end() &&
Expand Down
19 changes: 18 additions & 1 deletion mlir/lib/Parser/AttributeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Parser/AsmParserState.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"

Expand Down Expand Up @@ -153,6 +154,13 @@ Attribute Parser::parseAttribute(Type type) {

// Parse a symbol reference attribute.
case Token::at_identifier: {
// When populating the parser state, this is a list of locations for all of
// the nested references.
SmallVector<llvm::SMRange> referenceLocations;
if (state.asmState)
referenceLocations.push_back(getToken().getLocRange());

// Parse the top-level reference.
std::string nameStr = getToken().getSymbolReference();
consumeToken(Token::at_identifier);

Expand All @@ -174,12 +182,21 @@ Attribute Parser::parseAttribute(Type type) {
return Attribute();
}

// If we are populating the assembly state, add the location for this
// reference.
if (state.asmState)
referenceLocations.push_back(getToken().getLocRange());

std::string nameStr = getToken().getSymbolReference();
consumeToken(Token::at_identifier);
nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
}
SymbolRefAttr symbolRefAttr = builder.getSymbolRefAttr(nameStr, nestedRefs);

return builder.getSymbolRefAttr(nameStr, nestedRefs);
// If we are populating the assembly state, record this symbol reference.
if (state.asmState)
state.asmState->addUses(symbolRefAttr, referenceLocations);
return symbolRefAttr;
}

// Parse a 'unit' attribute.
Expand Down
Loading

0 comments on commit d6af89b

Please sign in to comment.