Skip to content

Commit ff81a2c

Browse files
committed
[mlir-lsp-server] Add support for textDocument/documentSymbols
This allows for building an outline of the symbols and symbol tables within the IR. This allows for easy navigations to functions/modules and other symbol/symbol table operations within the IR. Differential Revision: https://reviews.llvm.org/D103729
1 parent 1b894cc commit ff81a2c

File tree

11 files changed

+454
-53
lines changed

11 files changed

+454
-53
lines changed

mlir/include/mlir/Parser/AsmParserState.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,19 @@ class AsmParserState {
5353
SMDefinition definition;
5454
};
5555

56-
OperationDefinition(Operation *op, llvm::SMRange loc) : op(op), loc(loc) {}
56+
OperationDefinition(Operation *op, llvm::SMRange loc, llvm::SMLoc endLoc)
57+
: op(op), loc(loc), scopeLoc(loc.Start, endLoc) {}
5758

5859
/// The operation representing this definition.
5960
Operation *op;
6061

6162
/// The source location for the operation, i.e. the location of its name.
6263
llvm::SMRange loc;
6364

65+
/// The full source range of the operation definition, i.e. a range
66+
/// encompassing the start and end of the full operation definition.
67+
llvm::SMRange scopeLoc;
68+
6469
/// Source definitions for any result groups of this operation.
6570
SmallVector<std::pair<unsigned, SMDefinition>> resultGroups;
6671

@@ -110,6 +115,10 @@ class AsmParserState {
110115
/// state.
111116
iterator_range<OperationDefIterator> getOpDefs() const;
112117

118+
/// Return the definition for the given operation, or nullptr if the given
119+
/// operation does not have a definition.
120+
const OperationDefinition *getOpDef(Operation *op) const;
121+
113122
/// Returns (heuristically) the range of an identifier given a SMLoc
114123
/// corresponding to the start of an identifier location.
115124
static llvm::SMRange convertIdLocToRange(llvm::SMLoc loc);
@@ -130,7 +139,7 @@ class AsmParserState {
130139

131140
/// Finalize the most recently started operation definition.
132141
void finalizeOperationDefinition(
133-
Operation *op, llvm::SMRange nameLoc,
142+
Operation *op, llvm::SMRange nameLoc, llvm::SMLoc endLoc,
134143
ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups = llvm::None);
135144

136145
/// Start a definition for a region nested under the current operation.

mlir/lib/Parser/AsmParserState.cpp

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ struct AsmParserState::Impl {
3636
std::unique_ptr<SymbolUseMap> symbolTable;
3737
};
3838

39-
/// Resolve any symbol table uses under the given partial operation.
40-
void resolveSymbolUses(Operation *op, PartialOpDef &opDef);
39+
/// Resolve any symbol table uses in the IR.
40+
void resolveSymbolUses();
4141

4242
/// A mapping from operations in the input source file to their parser state.
4343
SmallVector<std::unique_ptr<OperationDefinition>> operations;
@@ -51,6 +51,10 @@ struct AsmParserState::Impl {
5151
/// This map should be empty if the parser finishes successfully.
5252
DenseMap<Value, SmallVector<llvm::SMLoc>> placeholderValueUses;
5353

54+
/// The symbol table operations within the IR.
55+
SmallVector<std::pair<Operation *, std::unique_ptr<SymbolUseMap>>>
56+
symbolTableOperations;
57+
5458
/// A stack of partial operation definitions that have been started but not
5559
/// yet finalized.
5660
SmallVector<PartialOpDef> partialOperations;
@@ -63,22 +67,21 @@ struct AsmParserState::Impl {
6367
SymbolTableCollection symbolTable;
6468
};
6569

66-
void AsmParserState::Impl::resolveSymbolUses(Operation *op,
67-
PartialOpDef &opDef) {
68-
assert(opDef.isSymbolTable() && "expected op to be a symbol table");
69-
70+
void AsmParserState::Impl::resolveSymbolUses() {
7071
SmallVector<Operation *> symbolOps;
71-
for (auto &it : *opDef.symbolTable) {
72-
symbolOps.clear();
73-
if (failed(symbolTable.lookupSymbolIn(op, it.first.cast<SymbolRefAttr>(),
74-
symbolOps)))
75-
continue;
76-
77-
for (ArrayRef<llvm::SMRange> useRange : it.second) {
78-
for (const auto &symIt : llvm::zip(symbolOps, useRange)) {
79-
auto opIt = operationToIdx.find(std::get<0>(symIt));
80-
if (opIt != operationToIdx.end())
81-
operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
72+
for (auto &opAndUseMapIt : symbolTableOperations) {
73+
for (auto &it : *opAndUseMapIt.second) {
74+
symbolOps.clear();
75+
if (failed(symbolTable.lookupSymbolIn(
76+
opAndUseMapIt.first, it.first.cast<SymbolRefAttr>(), symbolOps)))
77+
continue;
78+
79+
for (ArrayRef<llvm::SMRange> useRange : it.second) {
80+
for (const auto &symIt : llvm::zip(symbolOps, useRange)) {
81+
auto opIt = operationToIdx.find(std::get<0>(symIt));
82+
if (opIt != operationToIdx.end())
83+
operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
84+
}
8285
}
8386
}
8487
}
@@ -112,8 +115,13 @@ auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> {
112115
return llvm::make_pointee_range(llvm::makeArrayRef(impl->operations));
113116
}
114117

115-
/// Returns (heuristically) the range of an identifier given a SMLoc
116-
/// corresponding to the start of an identifier location.
118+
auto AsmParserState::getOpDef(Operation *op) const
119+
-> const OperationDefinition * {
120+
auto it = impl->operationToIdx.find(op);
121+
return it == impl->operationToIdx.end() ? nullptr
122+
: &*impl->operations[it->second];
123+
}
124+
117125
llvm::SMRange AsmParserState::convertIdLocToRange(llvm::SMLoc loc) {
118126
if (!loc.isValid())
119127
return llvm::SMRange();
@@ -124,7 +132,7 @@ llvm::SMRange AsmParserState::convertIdLocToRange(llvm::SMLoc loc) {
124132
};
125133

126134
const char *curPtr = loc.getPointer();
127-
while (isIdentifierChar(*(++curPtr)))
135+
while (*curPtr && isIdentifierChar(*(++curPtr)))
128136
continue;
129137
return llvm::SMRange(loc, llvm::SMLoc::getFromPointer(curPtr));
130138
}
@@ -147,33 +155,38 @@ void AsmParserState::finalize(Operation *topLevelOp) {
147155
Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
148156

149157
// If this operation is a symbol table, resolve any symbol uses.
150-
if (partialOpDef.isSymbolTable())
151-
impl->resolveSymbolUses(topLevelOp, partialOpDef);
158+
if (partialOpDef.isSymbolTable()) {
159+
impl->symbolTableOperations.emplace_back(
160+
topLevelOp, std::move(partialOpDef.symbolTable));
161+
}
162+
impl->resolveSymbolUses();
152163
}
153164

154165
void AsmParserState::startOperationDefinition(const OperationName &opName) {
155166
impl->partialOperations.emplace_back(opName);
156167
}
157168

158169
void AsmParserState::finalizeOperationDefinition(
159-
Operation *op, llvm::SMRange nameLoc,
170+
Operation *op, llvm::SMRange nameLoc, llvm::SMLoc endLoc,
160171
ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups) {
161172
assert(!impl->partialOperations.empty() &&
162173
"expected valid partial operation definition");
163174
Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
164175

165176
// Build the full operation definition.
166177
std::unique_ptr<OperationDefinition> def =
167-
std::make_unique<OperationDefinition>(op, nameLoc);
178+
std::make_unique<OperationDefinition>(op, nameLoc, endLoc);
168179
for (auto &resultGroup : resultGroups)
169180
def->resultGroups.emplace_back(resultGroup.first,
170181
convertIdLocToRange(resultGroup.second));
171182
impl->operationToIdx.try_emplace(op, impl->operations.size());
172183
impl->operations.emplace_back(std::move(def));
173184

174185
// If this operation is a symbol table, resolve any symbol uses.
175-
if (partialOpDef.isSymbolTable())
176-
impl->resolveSymbolUses(op, partialOpDef);
186+
if (partialOpDef.isSymbolTable()) {
187+
impl->symbolTableOperations.emplace_back(
188+
op, std::move(partialOpDef.symbolTable));
189+
}
177190
}
178191

179192
void AsmParserState::startRegionDefinition() {

mlir/lib/Parser/Parser.cpp

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,12 @@ namespace {
166166
/// operations.
167167
class OperationParser : public Parser {
168168
public:
169-
OperationParser(ParserState &state, Operation *topLevelOp);
169+
OperationParser(ParserState &state, ModuleOp topLevelOp);
170170
~OperationParser();
171171

172172
/// After parsing is finished, this function must be called to see if there
173173
/// are any remaining issues.
174-
ParseResult finalize(Operation *topLevelOp);
174+
ParseResult finalize();
175175

176176
//===--------------------------------------------------------------------===//
177177
// SSA Value Handling
@@ -399,9 +399,8 @@ class OperationParser : public Parser {
399399
};
400400
} // end anonymous namespace
401401

402-
OperationParser::OperationParser(ParserState &state, Operation *topLevelOp)
403-
: Parser(state), opBuilder(topLevelOp->getRegion(0)),
404-
topLevelOp(topLevelOp) {
402+
OperationParser::OperationParser(ParserState &state, ModuleOp topLevelOp)
403+
: Parser(state), opBuilder(topLevelOp.getRegion()), topLevelOp(topLevelOp) {
405404
// The top level operation starts a new name scope.
406405
pushSSANameScope(/*isIsolated=*/true);
407406

@@ -429,7 +428,7 @@ OperationParser::~OperationParser() {
429428

430429
/// After parsing is finished, this function must be called to see if there are
431430
/// any remaining issues.
432-
ParseResult OperationParser::finalize(Operation *topLevelOp) {
431+
ParseResult OperationParser::finalize() {
433432
// Check for any forward references that are left. If we find any, error
434433
// out.
435434
if (!forwardRefPlaceholders.empty()) {
@@ -466,12 +465,18 @@ ParseResult OperationParser::finalize(Operation *topLevelOp) {
466465
opOrArgument.get<BlockArgument>().setLoc(locAttr);
467466
}
468467

468+
// Pop the top level name scope.
469+
if (failed(popSSANameScope()))
470+
return failure();
471+
472+
// Verify that the parsed operations are valid.
473+
if (failed(verify(topLevelOp)))
474+
return failure();
475+
469476
// If we are populating the parser state, finalize the top-level operation.
470477
if (state.asmState)
471478
state.asmState->finalize(topLevelOp);
472-
473-
// Pop the top level name scope.
474-
return popSSANameScope();
479+
return success();
475480
}
476481

477482
//===----------------------------------------------------------------------===//
@@ -821,8 +826,9 @@ ParseResult OperationParser::parseOperation() {
821826
asmResultGroups.emplace_back(resultIt, std::get<2>(record));
822827
resultIt += std::get<1>(record);
823828
}
824-
state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange(),
825-
asmResultGroups);
829+
state.asmState->finalizeOperationDefinition(
830+
op, nameTok.getLocRange(), /*endLoc=*/getToken().getLoc(),
831+
asmResultGroups);
826832
}
827833

828834
// Add definitions for each of the result groups.
@@ -837,7 +843,8 @@ ParseResult OperationParser::parseOperation() {
837843

838844
// Add this operation to the assembly state if it was provided to populate.
839845
} else if (state.asmState) {
840-
state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange());
846+
state.asmState->finalizeOperationDefinition(op, nameTok.getLocRange(),
847+
/*endLoc=*/getToken().getLoc());
841848
}
842849

843850
return success();
@@ -1009,7 +1016,8 @@ Operation *OperationParser::parseGenericOperation(Block *insertBlock,
10091016
// If we are populating the parser asm state, finalize this operation
10101017
// definition.
10111018
if (state.asmState)
1012-
state.asmState->finalizeOperationDefinition(op, nameToken.getLocRange());
1019+
state.asmState->finalizeOperationDefinition(op, nameToken.getLocRange(),
1020+
/*endLoc=*/getToken().getLoc());
10131021
return op;
10141022
}
10151023

@@ -2019,6 +2027,10 @@ ParseResult OperationParser::parseRegionBody(
20192027

20202028
// Add arguments to the entry block.
20212029
if (!entryArguments.empty()) {
2030+
// If we had named arguments, then don't allow a block name.
2031+
if (getToken().is(Token::caret_identifier))
2032+
return emitError("invalid block name in region with named arguments");
2033+
20222034
for (auto &placeholderArgPair : entryArguments) {
20232035
auto &argInfo = placeholderArgPair.first;
20242036

@@ -2040,10 +2052,6 @@ ParseResult OperationParser::parseRegionBody(
20402052
if (addDefinition(argInfo, arg))
20412053
return failure();
20422054
}
2043-
2044-
// If we had named arguments, then don't allow a block name.
2045-
if (getToken().is(Token::caret_identifier))
2046-
return emitError("invalid block name in region with named arguments");
20472055
}
20482056

20492057
if (parseBlock(block))
@@ -2310,7 +2318,7 @@ ParseResult TopLevelOperationParser::parseTypeAliasDef() {
23102318
ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
23112319
Location parserLoc) {
23122320
// Create a top-level operation to contain the parsed state.
2313-
OwningOpRef<Operation *> topLevelOp(ModuleOp::create(parserLoc));
2321+
OwningOpRef<ModuleOp> topLevelOp(ModuleOp::create(parserLoc));
23142322
OperationParser opParser(state, topLevelOp.get());
23152323
while (true) {
23162324
switch (getToken().getKind()) {
@@ -2322,16 +2330,12 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
23222330

23232331
// If we got to the end of the file, then we're done.
23242332
case Token::eof: {
2325-
if (opParser.finalize(topLevelOp.get()))
2326-
return failure();
2327-
2328-
// Verify that the parsed operations are valid.
2329-
if (failed(verify(topLevelOp.get())))
2333+
if (opParser.finalize())
23302334
return failure();
23312335

23322336
// Splice the blocks of the parsed operation over to the provided
23332337
// top-level block.
2334-
auto &parsedOps = (*topLevelOp)->getRegion(0).front().getOperations();
2338+
auto &parsedOps = topLevelOp->getBody()->getOperations();
23352339
auto &destOps = topLevelBlock->getOperations();
23362340
destOps.splice(destOps.empty() ? destOps.end() : std::prev(destOps.end()),
23372341
parsedOps, parsedOps.begin(), parsedOps.end());

mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,16 @@ struct LSPServer::Impl {
5656
void onHover(const TextDocumentPositionParams &params,
5757
Callback<Optional<Hover>> reply);
5858

59+
//===--------------------------------------------------------------------===//
60+
// Document Symbols
61+
62+
void onDocumentSymbol(const DocumentSymbolParams &params,
63+
Callback<std::vector<DocumentSymbol>> reply);
64+
65+
//===--------------------------------------------------------------------===//
66+
// Fields
67+
//===--------------------------------------------------------------------===//
68+
5969
MLIRServer &server;
6070
JSONTransport &transport;
6171

@@ -73,6 +83,7 @@ struct LSPServer::Impl {
7383

7484
void LSPServer::Impl::onInitialize(const InitializeParams &params,
7585
Callback<llvm::json::Value> reply) {
86+
// Send a response with the capabilities of this server.
7687
llvm::json::Object serverCaps{
7788
{"textDocumentSync",
7889
llvm::json::Object{
@@ -83,6 +94,11 @@ void LSPServer::Impl::onInitialize(const InitializeParams &params,
8394
{"definitionProvider", true},
8495
{"referencesProvider", true},
8596
{"hoverProvider", true},
97+
98+
// For now we only support documenting symbols when the client supports
99+
// hierarchical symbols.
100+
{"documentSymbolProvider",
101+
params.capabilities.hierarchicalDocumentSymbol},
86102
};
87103

88104
llvm::json::Object result{
@@ -165,6 +181,17 @@ void LSPServer::Impl::onHover(const TextDocumentPositionParams &params,
165181
reply(server.findHover(params.textDocument.uri, params.position));
166182
}
167183

184+
//===----------------------------------------------------------------------===//
185+
// Document Symbols
186+
187+
void LSPServer::Impl::onDocumentSymbol(
188+
const DocumentSymbolParams &params,
189+
Callback<std::vector<DocumentSymbol>> reply) {
190+
std::vector<DocumentSymbol> symbols;
191+
server.findDocumentSymbols(params.textDocument.uri, symbols);
192+
reply(std::move(symbols));
193+
}
194+
168195
//===----------------------------------------------------------------------===//
169196
// LSPServer
170197
//===----------------------------------------------------------------------===//
@@ -198,6 +225,10 @@ LogicalResult LSPServer::run() {
198225
// Hover
199226
messageHandler.method("textDocument/hover", impl.get(), &Impl::onHover);
200227

228+
// Document Symbols
229+
messageHandler.method("textDocument/documentSymbol", impl.get(),
230+
&Impl::onDocumentSymbol);
231+
201232
// Diagnostics
202233
impl->publishDiagnostics =
203234
messageHandler.outgoingNotification<PublishDiagnosticsParams>(

0 commit comments

Comments
 (0)