diff --git a/mlir/include/mlir/Parser/CodeComplete.h b/mlir/include/mlir/Parser/CodeComplete.h new file mode 100644 index 0000000000000..24d0181456f67 --- /dev/null +++ b/mlir/include/mlir/Parser/CodeComplete.h @@ -0,0 +1,58 @@ +//===- CodeComplete.h - MLIR Asm CodeComplete Context -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_PARSER_CODECOMPLETE_H +#define MLIR_PARSER_CODECOMPLETE_H + +#include "mlir/Support/LLVM.h" +#include "llvm/Support/SourceMgr.h" + +namespace mlir { +/// This class provides an abstract interface into the parser for hooking in +/// code completion events. This class is only really useful for providing +/// language tooling for MLIR, general clients should not need to use this +/// class. +class AsmParserCodeCompleteContext { +public: + virtual ~AsmParserCodeCompleteContext(); + + /// Return the source location used to provide code completion. + SMLoc getCodeCompleteLoc() const { return codeCompleteLoc; } + + //===--------------------------------------------------------------------===// + // Completion Hooks + //===--------------------------------------------------------------------===// + + /// Signal code completion for a dialect name. + virtual void completeDialectName() = 0; + + /// Signal code completion for an operation name within the given dialect. + virtual void completeOperationName(StringRef dialectName) = 0; + + /// Append the given SSA value as a code completion result for SSA value + /// completions. + virtual void appendSSAValueCompletion(StringRef name, + std::string typeData) = 0; + + /// Append the given block as a code completion result for block name + /// completions. 
+ virtual void appendBlockCompletion(StringRef name) = 0; + +protected: + /// Create a new code completion context with the given code complete + /// location. + explicit AsmParserCodeCompleteContext(SMLoc codeCompleteLoc) + : codeCompleteLoc(codeCompleteLoc) {} + +private: + /// The location used to code complete. + SMLoc codeCompleteLoc; +}; +} // namespace mlir + +#endif // MLIR_PARSER_CODECOMPLETE_H diff --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h index 69f02c45d6fbb..5c9543f187d6a 100644 --- a/mlir/include/mlir/Parser/Parser.h +++ b/mlir/include/mlir/Parser/Parser.h @@ -26,6 +26,7 @@ class StringRef; namespace mlir { class AsmParserState; +class AsmParserCodeCompleteContext; namespace detail { @@ -83,11 +84,13 @@ inline OwningOpRef constructContainerOpForParserIfNecessary( /// source file that is being parsed. If `asmState` is non-null, it is populated /// with detailed information about the parsed IR (including exact locations for /// SSA uses and definitions). `asmState` should only be provided if this -/// detailed information is desired. -LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, - const ParserConfig &config, - LocationAttr *sourceFileLoc = nullptr, - AsmParserState *asmState = nullptr); +/// detailed information is desired. If `codeCompleteContext` is non-null, it is +/// used to signal tracking of a code completion event (generally only ever +/// useful for LSP or other high level language tooling). +LogicalResult parseSourceFile( + const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, + LocationAttr *sourceFileLoc = nullptr, AsmParserState *asmState = nullptr, + AsmParserCodeCompleteContext *codeCompleteContext = nullptr); /// This parses the file specified by the indicated filename and appends parsed /// operations to the given block. 
If the block is non-empty, the operations are diff --git a/mlir/lib/Parser/AffineParser.cpp b/mlir/lib/Parser/AffineParser.cpp index 7fa054bf3913a..345b3a505b069 100644 --- a/mlir/lib/Parser/AffineParser.cpp +++ b/mlir/lib/Parser/AffineParser.cpp @@ -743,7 +743,8 @@ IntegerSet mlir::parseIntegerSet(StringRef inputStr, MLIRContext *context, sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); SymbolState symbolState; ParserConfig config(context); - ParserState state(sourceMgr, config, symbolState, /*asmState=*/nullptr); + ParserState state(sourceMgr, config, symbolState, /*asmState=*/nullptr, + /*codeCompleteContext=*/nullptr); Parser parser(state); raw_ostream &os = printDiagnosticInfo ? llvm::errs() : llvm::nulls(); diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp index 623ea8ef12ea1..4478d1936e661 100644 --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -277,7 +277,8 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); SymbolState aliasState; ParserConfig config(context); - ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr); + ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr, + /*codeCompleteContext=*/nullptr); Parser parser(state); SourceMgrDiagnosticHandler handler( diff --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp index f3e4ce6b2768e..cbebaf24159d4 100644 --- a/mlir/lib/Parser/Lexer.cpp +++ b/mlir/lib/Parser/Lexer.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/Parser/CodeComplete.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/SourceMgr.h" @@ -26,11 +27,16 @@ static bool isPunct(char c) { return c == '$' || c == '.' 
|| c == '_' || c == '-'; } -Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context) - : sourceMgr(sourceMgr), context(context) { +Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context, + AsmParserCodeCompleteContext *codeCompleteContext) + : sourceMgr(sourceMgr), context(context), codeCompleteLoc(nullptr) { auto bufferID = sourceMgr.getMainFileID(); curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer(); curPtr = curBuffer.begin(); + + // Set the code completion location if it was provided. + if (codeCompleteContext) + codeCompleteLoc = codeCompleteContext->getCodeCompleteLoc().getPointer(); } /// Encode the specified source location information into an attribute for @@ -61,6 +67,12 @@ Token Lexer::emitError(const char *loc, const Twine &message) { Token Lexer::lexToken() { while (true) { const char *tokStart = curPtr; + + // Check to see if the current token is at the code completion location. + if (tokStart == codeCompleteLoc) + return formToken(Token::code_complete, tokStart); + + // Lex the next token. switch (*curPtr++) { default: // Handle bare identifiers. @@ -357,17 +369,25 @@ Token Lexer::lexPrefixedIdentifier(const char *tokStart) { // Parse suffix-id. if (isdigit(*curPtr)) { // If suffix-id starts with a digit, the rest must be digits. - while (isdigit(*curPtr)) { + while (isdigit(*curPtr)) ++curPtr; - } } else if (isalpha(*curPtr) || isPunct(*curPtr)) { do { ++curPtr; } while (isalpha(*curPtr) || isdigit(*curPtr) || isPunct(*curPtr)); + } else if (curPtr == codeCompleteLoc) { + return formToken(Token::code_complete, tokStart); } else { return emitError(curPtr - 1, errorKind); } + // Check for a code completion within the identifier. 
+ if (codeCompleteLoc && codeCompleteLoc >= tokStart && + codeCompleteLoc <= curPtr) { + return Token(Token::code_complete, + StringRef(tokStart, codeCompleteLoc - tokStart)); + } + return formToken(kind, tokStart); } @@ -380,6 +400,13 @@ Token Lexer::lexString(const char *tokStart) { assert(curPtr[-1] == '"'); while (true) { + // Check to see if there is a code completion location within the string. In + // these cases we generate a completion location and place the currently + // lexed string within the token. This allows for the parser to use the + // partially lexed string when computing the completion results. + if (curPtr == codeCompleteLoc) + return formToken(Token::code_complete, tokStart); + switch (*curPtr++) { case '"': return formToken(Token::string, tokStart); diff --git a/mlir/lib/Parser/Lexer.h b/mlir/lib/Parser/Lexer.h index 58b5e2321c4d3..e09ae168e3bba 100644 --- a/mlir/lib/Parser/Lexer.h +++ b/mlir/lib/Parser/Lexer.h @@ -22,7 +22,8 @@ class Location; /// This class breaks up the current file into a token stream. class Lexer { public: - explicit Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context); + explicit Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context, + AsmParserCodeCompleteContext *codeCompleteContext); const llvm::SourceMgr &getSourceMgr() { return sourceMgr; } @@ -64,6 +65,10 @@ class Lexer { StringRef curBuffer; const char *curPtr; + /// An optional code completion point within the input file, used to indicate + /// the position of a code completion token. 
+ const char *codeCompleteLoc; + Lexer(const Lexer &) = delete; void operator=(const Lexer &) = delete; }; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 5bc1e43c60a0e..51a65a99dfa7a 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser/AsmParserState.h" +#include "mlir/Parser/CodeComplete.h" #include "mlir/Parser/Parser.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/ScopeExit.h" @@ -32,6 +33,12 @@ using namespace mlir::detail; using llvm::MemoryBuffer; using llvm::SourceMgr; +//===----------------------------------------------------------------------===// +// CodeComplete +//===----------------------------------------------------------------------===// + +AsmParserCodeCompleteContext::~AsmParserCodeCompleteContext() = default; + //===----------------------------------------------------------------------===// // Parser //===----------------------------------------------------------------------===// @@ -334,6 +341,60 @@ Parser::parseResourceHandle(const OpAsmDialectInterface *dialect, return entry.second; } +//===----------------------------------------------------------------------===// +// Code Completion + +ParseResult Parser::codeCompleteDialectName() { + state.codeCompleteContext->completeDialectName(); + return failure(); +} + +ParseResult Parser::codeCompleteOperationName(StringRef dialectName) { + // Perform some simple validation on the dialect name. This doesn't need to be + // extensive, it's more of an optimization (to avoid checking completion + // results when we know they will fail). + if (dialectName.empty() || dialectName.contains('.')) + return failure(); + state.codeCompleteContext->completeOperationName(dialectName); + return failure(); +} + +ParseResult Parser::codeCompleteDialectOrElidedOpName(SMLoc loc) { + // Check to see if there is anything else on the current line. 
This check + // isn't strictly necessary, but it does avoid unnecessarily triggering + // completions for operations and dialects in situations where we don't want + // them (e.g. at the end of an operation). + auto shouldIgnoreOpCompletion = [&]() { + const char *bufBegin = state.lex.getBufferBegin(); + const char *it = loc.getPointer() - 1; + for (; it > bufBegin && *it != '\n'; --it) + if (!llvm::is_contained(StringRef(" \t\r"), *it)) + return true; + return false; + }; + if (shouldIgnoreOpCompletion()) + return failure(); + + // The completion here is either for a dialect name, or an operation name + // whose dialect prefix was elided. For this we simply invoke both of the + // individual completion methods. + (void)codeCompleteDialectName(); + return codeCompleteOperationName(state.defaultDialectStack.back()); +} + +ParseResult Parser::codeCompleteStringDialectOrOperationName(StringRef name) { + // If the name is empty, this is the start of the string and contains the + // dialect. + if (name.empty()) + return codeCompleteDialectName(); + + // Otherwise, we treat this as completing an operation name. The current name + // is used as the dialect namespace. + if (name.consume_back(".")) + return codeCompleteOperationName(name); + return failure(); +} + //===----------------------------------------------------------------------===// // OperationParser //===----------------------------------------------------------------------===// @@ -497,6 +558,17 @@ class OperationParser : public Parser { /// us to diagnose references to blocks that are not defined precisely. Block *getBlockNamed(StringRef name, SMLoc loc); + //===--------------------------------------------------------------------===// + // Code Completion + //===--------------------------------------------------------------------===// + + /// The set of various code completion methods. Every completion method + /// returns `failure` to stop the parsing process after providing completion + /// results. 
+ + ParseResult codeCompleteSSAUse(); + ParseResult codeCompleteBlock(); + private: /// This class represents a definition of a Block. struct BlockDefinition { @@ -790,7 +862,7 @@ ParseResult OperationParser::addDefinition(UnresolvedOperand useInfo, /// ParseResult OperationParser::parseOptionalSSAUseList( SmallVectorImpl &results) { - if (getToken().isNot(Token::percent_identifier)) + if (!getToken().isOrIsCodeCompletionFor(Token::percent_identifier)) return success(); return parseCommaSeparatedList([&]() -> ParseResult { UnresolvedOperand result; @@ -807,6 +879,9 @@ ParseResult OperationParser::parseOptionalSSAUseList( /// ParseResult OperationParser::parseSSAUse(UnresolvedOperand &result, bool allowResultNumber) { + if (getToken().isCodeCompletion()) + return codeCompleteSSAUse(); + result.name = getTokenSpelling(); result.number = 0; result.location = getToken().getLoc(); @@ -1017,6 +1092,10 @@ ParseResult OperationParser::parseOperation() { op = parseCustomOperation(resultIDs); else if (nameTok.is(Token::string)) op = parseGenericOperation(); + else if (nameTok.isCodeCompletionFor(Token::string)) + return codeCompleteStringDialectOrOperationName(nameTok.getStringValue()); + else if (nameTok.isCodeCompletion()) + return codeCompleteDialectOrElidedOpName(loc); else return emitWrongTokenError("expected operation name in quotes"); @@ -1071,6 +1150,9 @@ ParseResult OperationParser::parseOperation() { /// successor ::= block-id /// ParseResult OperationParser::parseSuccessor(Block *&dest) { + if (getToken().isCodeCompletion()) + return codeCompleteBlock(); + // Verify branch is identifier and get the matching block. 
if (!getToken().is(Token::caret_identifier)) return emitWrongTokenError("expected block name"); @@ -1391,7 +1473,7 @@ class CustomOpAsmParser : public AsmParserImpl { OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber = true) override { - if (parser.getToken().is(Token::percent_identifier)) + if (parser.getToken().isOrIsCodeCompletionFor(Token::percent_identifier)) return parseOperand(result, allowResultNumber); return llvm::None; } @@ -1406,14 +1488,15 @@ class CustomOpAsmParser : public AsmParserImpl { if (delimiter == Delimiter::None) { // parseCommaSeparatedList doesn't handle the missing case for "none", // so we handle it custom here. - if (parser.getToken().isNot(Token::percent_identifier)) { + Token tok = parser.getToken(); + if (!tok.isOrIsCodeCompletionFor(Token::percent_identifier)) { // If we didn't require any operands or required exactly zero (weird) // then this is success. if (requiredOperandCount == -1 || requiredOperandCount == 0) return success(); // Otherwise, try to produce a nice error message. - if (parser.getToken().isAny(Token::l_paren, Token::l_square)) + if (tok.isAny(Token::l_paren, Token::l_square)) return parser.emitError("unexpected delimiter"); return parser.emitWrongTokenError("expected operand"); } @@ -1597,7 +1680,7 @@ class CustomOpAsmParser : public AsmParserImpl { /// Parse an optional operation successor and its operand list. 
OptionalParseResult parseOptionalSuccessor(Block *&dest) override { - if (parser.getToken().isNot(Token::caret_identifier)) + if (!parser.getToken().isOrIsCodeCompletionFor(Token::caret_identifier)) return llvm::None; return parseSuccessor(dest); } @@ -1681,7 +1764,8 @@ class CustomOpAsmParser : public AsmParserImpl { } // namespace FailureOr OperationParser::parseCustomOperationName() { - std::string opName = getTokenSpelling().str(); + Token nameTok = getToken(); + StringRef opName = nameTok.getSpelling(); if (opName.empty()) return (emitError("empty operation name is invalid"), failure()); consumeToken(); @@ -1694,11 +1778,17 @@ FailureOr OperationParser::parseCustomOperationName() { // If the operation doesn't have a dialect prefix try using the default // dialect. - auto opNameSplit = StringRef(opName).split('.'); + auto opNameSplit = opName.split('.'); StringRef dialectName = opNameSplit.first; + std::string opNameStorage; if (opNameSplit.second.empty()) { + // If the name didn't have a prefix, check for a code completion request. + if (getToken().isCodeCompletion() && opName.back() == '.') + return codeCompleteOperationName(dialectName); + dialectName = getState().defaultDialectStack.back(); - opName = (dialectName + "." + opName).str(); + opNameStorage = (dialectName + "." 
+ opName).str(); + opName = opNameStorage; } // Try to load the dialect before returning the operation name to make sure @@ -2091,6 +2181,58 @@ ParseResult OperationParser::parseOptionalBlockArgList(Block *owner) { }); } +//===----------------------------------------------------------------------===// +// Code Completion +//===----------------------------------------------------------------------===// + +ParseResult OperationParser::codeCompleteSSAUse() { + std::string detailData; + llvm::raw_string_ostream detailOS(detailData); + for (IsolatedSSANameScope &scope : isolatedNameScopes) { + for (auto &it : scope.values) { + if (it.second.empty()) + continue; + Value frontValue = it.second.front().value; + + // If the value isn't a forward reference, we also add the name of the op + // to the detail. + if (auto result = frontValue.dyn_cast()) { + if (!forwardRefPlaceholders.count(result)) + detailOS << result.getOwner()->getName() << ": "; + } else { + detailOS << "arg #" << frontValue.cast().getArgNumber() + << ": "; + } + + // Emit the type of the values to aid with completion selection. + detailOS << frontValue.getType(); + + // FIXME: We should define a policy for packed values, e.g. with a limit + // on the detail size, but it isn't clear what would be useful right now. + // For now we just only emit the first type. + if (it.second.size() > 1) + detailOS << ", ..."; + + state.codeCompleteContext->appendSSAValueCompletion( + it.getKey(), std::move(detailOS.str())); + } + } + + return failure(); +} + +ParseResult OperationParser::codeCompleteBlock() { + // Don't provide completions if the token isn't empty, e.g. this avoids + // weirdness when we encounter a `.` within the identifier. 
+ StringRef spelling = getTokenSpelling(); + if (!(spelling.empty() || spelling == "^")) + return failure(); + + for (const auto &it : blocksByName.back()) + state.codeCompleteContext->appendBlockCompletion(it.getFirst()); + return failure(); +} + //===----------------------------------------------------------------------===// // Top-level entity parsing. //===----------------------------------------------------------------------===// @@ -2418,10 +2560,11 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock, //===----------------------------------------------------------------------===// -LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, - Block *block, const ParserConfig &config, - LocationAttr *sourceFileLoc, - AsmParserState *asmState) { +LogicalResult +mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, + const ParserConfig &config, LocationAttr *sourceFileLoc, + AsmParserState *asmState, + AsmParserCodeCompleteContext *codeCompleteContext) { const auto *sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); Location parserLoc = @@ -2431,7 +2574,8 @@ LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, *sourceFileLoc = parserLoc; SymbolState aliasState; - ParserState state(sourceMgr, config, aliasState, asmState); + ParserState state(sourceMgr, config, aliasState, asmState, + codeCompleteContext); return TopLevelOperationParser(state).parse(block, parserLoc); } diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h index e762fc6d3543b..f1c87ca4916b5 100644 --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -306,6 +306,20 @@ class Parser { parseAffineExprOfSSAIds(AffineExpr &expr, function_ref parseElement); + //===--------------------------------------------------------------------===// + // Code Completion + //===--------------------------------------------------------------------===// + + /// The set of various code completion methods. 
Every completion method + /// returns `failure` to signal that parsing should abort after any desired + /// completions have been enqueued. Note that `failure` does not mean + /// completion failed, it's just a signal to the parser to stop. + + ParseResult codeCompleteDialectName(); + ParseResult codeCompleteOperationName(StringRef dialectName); + ParseResult codeCompleteDialectOrElidedOpName(SMLoc loc); + ParseResult codeCompleteStringDialectOrOperationName(StringRef name); + protected: /// The Parser is subclassed and reinstantiated. Do not add additional /// non-trivial state here, add it to the ParserState class. diff --git a/mlir/lib/Parser/ParserState.h b/mlir/lib/Parser/ParserState.h index 4d3ff2774f8e4..ec64a43c7bb8b 100644 --- a/mlir/lib/Parser/ParserState.h +++ b/mlir/lib/Parser/ParserState.h @@ -43,9 +43,12 @@ struct SymbolState { /// such as the current lexer position etc. struct ParserState { ParserState(const llvm::SourceMgr &sourceMgr, const ParserConfig &config, - SymbolState &symbols, AsmParserState *asmState) - : config(config), lex(sourceMgr, config.getContext()), - curToken(lex.lexToken()), symbols(symbols), asmState(asmState) {} + SymbolState &symbols, AsmParserState *asmState, + AsmParserCodeCompleteContext *codeCompleteContext) + : config(config), + lex(sourceMgr, config.getContext(), codeCompleteContext), + curToken(lex.lexToken()), symbols(symbols), asmState(asmState), + codeCompleteContext(codeCompleteContext) {} ParserState(const ParserState &) = delete; void operator=(const ParserState &) = delete; @@ -65,6 +68,9 @@ struct ParserState { /// populated during parsing. AsmParserState *asmState; + /// An optional code completion context. + AsmParserCodeCompleteContext *codeCompleteContext; + // Contains the stack of default dialect to use when parsing regions.
// A new dialect get pushed to the stack before parsing regions nested // under an operation implementing `OpAsmOpInterface`, and diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp index 23410b37088d4..7a8d2975d2f47 100644 --- a/mlir/lib/Parser/Token.cpp +++ b/mlir/lib/Parser/Token.cpp @@ -78,12 +78,15 @@ Optional Token::getIntTypeSignedness() const { /// removing the quote characters and unescaping the contents of the string. The /// lexer has already verified that this token is valid. std::string Token::getStringValue() const { - assert(getKind() == string || + assert(getKind() == string || getKind() == code_complete || (getKind() == at_identifier && getSpelling()[1] == '"')); // Start by dropping the quotes. - StringRef bytes = getSpelling().drop_front().drop_back(); - if (getKind() == at_identifier) - bytes = bytes.drop_front(); + StringRef bytes = getSpelling().drop_front(); + if (getKind() != Token::code_complete) { + bytes = bytes.drop_back(); + if (getKind() == at_identifier) + bytes = bytes.drop_front(); + } std::string result; result.reserve(bytes.size()); @@ -190,3 +193,22 @@ bool Token::isKeyword() const { #include "TokenKinds.def" } } + +bool Token::isCodeCompletionFor(Kind kind) const { + if (!isCodeCompletion() || spelling.empty()) + return false; + switch (kind) { + case Kind::string: + return spelling[0] == '"'; + case Kind::hash_identifier: + return spelling[0] == '#'; + case Kind::percent_identifier: + return spelling[0] == '%'; + case Kind::caret_identifier: + return spelling[0] == '^'; + case Kind::exclamation_identifier: + return spelling[0] == '!'; + default: + return false; + } +} diff --git a/mlir/lib/Parser/Token.h b/mlir/lib/Parser/Token.h index be0924cb9d67a..0e48805bb06f0 100644 --- a/mlir/lib/Parser/Token.h +++ b/mlir/lib/Parser/Token.h @@ -58,6 +58,19 @@ class Token { /// Return true if this is one of the keyword token kinds (e.g. kw_if). 
bool isKeyword() const; + /// Returns true if the current token represents a code completion. + bool isCodeCompletion() const { return is(code_complete); } + + /// Returns true if the current token represents a code completion for the + /// "normal" token type. + bool isCodeCompletionFor(Kind kind) const; + + /// Returns true if the current token is the given type, or represents a code + /// completion for that type. + bool isOrIsCodeCompletionFor(Kind kind) const { + return is(kind) || isCodeCompletionFor(kind); + } + // Helpers to decode specific sorts of tokens. /// For an integer token, return its value as an unsigned. If it doesn't fit, diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index f7146674234f6..207af3871f8d3 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -36,6 +36,7 @@ // Markers TOK_MARKER(eof) TOK_MARKER(error) +TOK_MARKER(code_complete) // Identifiers. TOK_IDENTIFIER(bare_identifier) // foo diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.h b/mlir/lib/Tools/lsp-server-support/Protocol.h index fdbbae2edfa63..03fe3ca6e41e7 100644 --- a/mlir/lib/Tools/lsp-server-support/Protocol.h +++ b/mlir/lib/Tools/lsp-server-support/Protocol.h @@ -780,6 +780,11 @@ enum class InsertTextFormat { }; struct CompletionItem { + CompletionItem() = default; + CompletionItem(StringRef label, CompletionItemKind kind) + : label(label.str()), kind(kind), + insertTextFormat(InsertTextFormat::PlainText) {} + /// The label of this completion item. By default also the text that is /// inserted when selecting this completion. 
std::string label; diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp index e5c24443a9efb..a3b86af4249d2 100644 --- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp @@ -62,6 +62,12 @@ struct LSPServer::Impl { void onDocumentSymbol(const DocumentSymbolParams ¶ms, Callback> reply); + //===--------------------------------------------------------------------===// + // Code Completion + + void onCompletion(const CompletionParams ¶ms, + Callback reply); + //===--------------------------------------------------------------------===// // Fields //===--------------------------------------------------------------------===// @@ -91,6 +97,15 @@ void LSPServer::Impl::onInitialize(const InitializeParams ¶ms, {"change", (int)TextDocumentSyncKind::Full}, {"save", true}, }}, + {"completionProvider", + llvm::json::Object{ + {"allCommitCharacters", + {"\t", "(", ")", "[", "]", "<", ">", ";", ",", "+", "-", "/", "*", + "&", "?", ".", "=", "|"}}, + {"resolveProvider", false}, + {"triggerCharacters", + {".", "%", "^", "!", "#", "(", ",", "<", ":", "[", " ", "\"", "/"}}, + }}, {"definitionProvider", true}, {"referencesProvider", true}, {"hoverProvider", true}, @@ -192,6 +207,14 @@ void LSPServer::Impl::onDocumentSymbol( reply(std::move(symbols)); } +//===----------------------------------------------------------------------===// +// Code Completion + +void LSPServer::Impl::onCompletion(const CompletionParams ¶ms, + Callback reply) { + reply(server.getCodeCompletion(params.textDocument.uri, params.position)); +} + //===----------------------------------------------------------------------===// // LSPServer //===----------------------------------------------------------------------===// @@ -229,6 +252,10 @@ LogicalResult LSPServer::run() { messageHandler.method("textDocument/documentSymbol", impl.get(), &Impl::onDocumentSymbol); + // Code Completion + 
messageHandler.method("textDocument/completion", impl.get(), + &Impl::onCompletion); + // Diagnostics impl->publishDiagnostics = messageHandler.outgoingNotification( diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp index f5671692e6a82..c2097747d8452 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/Operation.h" #include "mlir/Parser/AsmParserState.h" +#include "mlir/Parser/CodeComplete.h" #include "mlir/Parser/Parser.h" #include "llvm/Support/SourceMgr.h" @@ -276,6 +277,14 @@ struct MLIRDocument { void findDocumentSymbols(Operation *op, std::vector &symbols); + //===--------------------------------------------------------------------===// + // Code Completion + //===--------------------------------------------------------------------===// + + lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, + const lsp::Position &completePos, + const DialectRegistry ®istry); + //===--------------------------------------------------------------------===// // Fields //===--------------------------------------------------------------------===// @@ -615,6 +624,102 @@ void MLIRDocument::findDocumentSymbols( findDocumentSymbols(&childOp, *childSymbols); } +//===----------------------------------------------------------------------===// +// MLIRDocument: Code Completion +//===----------------------------------------------------------------------===// + +namespace { +class LSPCodeCompleteContext : public AsmParserCodeCompleteContext { +public: + LSPCodeCompleteContext(SMLoc completeLoc, lsp::CompletionList &completionList, + MLIRContext *ctx) + : AsmParserCodeCompleteContext(completeLoc), + completionList(completionList), ctx(ctx) {} + + /// Signal code completion for a dialect name. 
+ void completeDialectName() final { + for (StringRef dialect : ctx->getAvailableDialects()) { + lsp::CompletionItem item(dialect, lsp::CompletionItemKind::Module); + item.sortText = "2"; + item.detail = "dialect"; + completionList.items.emplace_back(item); + } + } + + /// Signal code completion for an operation name within the given dialect. + void completeOperationName(StringRef dialectName) final { + Dialect *dialect = ctx->getOrLoadDialect(dialectName); + if (!dialect) + return; + + for (const auto &op : ctx->getRegisteredOperations()) { + if (&op.getDialect() != dialect) + continue; + + lsp::CompletionItem item( + op.getStringRef().drop_front(dialectName.size() + 1), + lsp::CompletionItemKind::Field); + item.sortText = "1"; + item.detail = "operation"; + completionList.items.emplace_back(item); + } + } + + /// Append the given SSA value as a code completion result for SSA value + /// completions. + void appendSSAValueCompletion(StringRef name, std::string typeData) final { + // Check if we need to insert the `%` or not. + bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%'; + + lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable); + if (stripPrefix) + item.insertText = name.drop_front(1).str(); + item.detail = std::move(typeData); + completionList.items.emplace_back(item); + } + + /// Append the given block as a code completion result for block name + /// completions. + void appendBlockCompletion(StringRef name) final { + // Check if we need to insert the `^` or not. 
+ bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^'; + + lsp::CompletionItem item(name, lsp::CompletionItemKind::Field); + if (stripPrefix) + item.insertText = name.drop_front(1).str(); + completionList.items.emplace_back(item); + } + +private: + lsp::CompletionList &completionList; + MLIRContext *ctx; +}; +} // namespace + +lsp::CompletionList +MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri, + const lsp::Position &completePos, + const DialectRegistry ®istry) { + SMLoc posLoc = completePos.getAsSMLoc(sourceMgr); + if (!posLoc.isValid()) + return lsp::CompletionList(); + + // To perform code completion, we run another parse of the module with the + // code completion context provided. + MLIRContext tmpContext(registry, MLIRContext::Threading::DISABLED); + tmpContext.allowUnregisteredDialects(); + lsp::CompletionList completionList; + LSPCodeCompleteContext lspCompleteContext(posLoc, completionList, + &tmpContext); + + Block tmpIR; + AsmParserState tmpState; + (void)parseSourceFile(sourceMgr, &tmpIR, &tmpContext, + /*sourceFileLoc=*/nullptr, &tmpState, + &lspCompleteContext); + return completionList; +} + //===----------------------------------------------------------------------===// // MLIRTextFileChunk //===----------------------------------------------------------------------===// @@ -670,6 +775,8 @@ class MLIRTextFile { Optional findHover(const lsp::URIForFile &uri, lsp::Position hoverPos); void findDocumentSymbols(std::vector &symbols); + lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, + lsp::Position completePos); private: /// Find the MLIR document that contains the given position, and update the @@ -813,6 +920,22 @@ void MLIRTextFile::findDocumentSymbols( } } +lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri, + lsp::Position completePos) { + MLIRTextFileChunk &chunk = getChunkFor(completePos); + lsp::CompletionList completionList = chunk.document.getCodeCompletion( + uri, completePos, 
context.getDialectRegistry()); + + // Adjust any completion locations. + for (lsp::CompletionItem &item : completionList.items) { + if (item.textEdit) + chunk.adjustLocForChunkOffset(item.textEdit->range); + for (lsp::TextEdit &edit : item.additionalTextEdits) + chunk.adjustLocForChunkOffset(edit.range); + } + return completionList; +} + MLIRTextFileChunk &MLIRTextFile::getChunkFor(lsp::Position &pos) { if (chunks.size() == 1) return *chunks.front(); @@ -898,3 +1021,12 @@ void lsp::MLIRServer::findDocumentSymbols( if (fileIt != impl->files.end()) fileIt->second->findDocumentSymbols(symbols); } + +lsp::CompletionList +lsp::MLIRServer::getCodeCompletion(const URIForFile &uri, + const Position &completePos) { + auto fileIt = impl->files.find(uri.file()); + if (fileIt != impl->files.end()) + return fileIt->second->getCodeCompletion(uri, completePos); + return CompletionList(); +} diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h index d78d8c3dc6ec8..85ccf0a68f1b0 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h @@ -16,6 +16,7 @@ namespace mlir { class DialectRegistry; namespace lsp { +struct CompletionList; struct Diagnostic; struct DocumentSymbol; struct Hover; @@ -60,6 +61,10 @@ class MLIRServer { void findDocumentSymbols(const URIForFile &uri, std::vector &symbols); + /// Get the code completion list for the position within the given file. 
+ CompletionList getCodeCompletion(const URIForFile &uri, + const Position &completePos); + private: struct Impl; diff --git a/mlir/test/mlir-lsp-server/completion.test b/mlir/test/mlir-lsp-server/completion.test new file mode 100644 index 0000000000000..b13dbe147f057 --- /dev/null +++ b/mlir/test/mlir-lsp-server/completion.test @@ -0,0 +1,105 @@ +// RUN: mlir-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s +{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"mlir","capabilities":{},"trace":"off"}} +// ----- +{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{ + "uri":"test:///foo.mlir", + "languageId":"mlir", + "version":1, + "text":"func.func private @foo(%arg: i32) -> i32 {\n%cast = \"builtin.unrealized_conversion_cast\"() : () -> (i32)\nreturn %" +}}} +// ----- +{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{ + "textDocument":{"uri":"test:///foo.mlir"}, + "position":{"line":0,"character":0} +}} +// CHECK: "id": 1 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": { +// CHECK-NEXT: "isIncomplete": false, +// CHECK-NEXT: "items": [ +// CHECK: { +// CHECK: "detail": "dialect", +// CHECK: "insertTextFormat": 1, +// CHECK: "kind": 9, +// CHECK: "label": "builtin", +// CHECK: "sortText": "2" +// CHECK: }, +// CHECK: { +// CHECK: "detail": "operation", +// CHECK: "insertTextFormat": 1, +// CHECK: "kind": 5, +// CHECK: "label": "module", +// CHECK: "sortText": "1" +// CHECK: } +// CHECK: ] +// CHECK-NEXT: } +// ----- +{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{ + "textDocument":{"uri":"test:///foo.mlir"}, + "position":{"line":1,"character":9} +}} +// CHECK: "id": 1 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": { +// CHECK-NEXT: "isIncomplete": false, +// CHECK-NEXT: "items": [ +// CHECK: { +// CHECK: "detail": "dialect", +// CHECK: "insertTextFormat": 1, +// CHECK: "kind": 9, +// CHECK: "label": "builtin", +// CHECK: "sortText": "2" 
+// CHECK: }, +// CHECK-NOT: "detail": "operation", +// CHECK: ] +// CHECK-NEXT: } +// ----- +{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{ + "textDocument":{"uri":"test:///foo.mlir"}, + "position":{"line":1,"character":17} +}} +// CHECK: "id": 1 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": { +// CHECK-NEXT: "isIncomplete": false, +// CHECK-NEXT: "items": [ +// CHECK-NOT: "detail": "dialect", +// CHECK: { +// CHECK: "detail": "operation", +// CHECK: "insertTextFormat": 1, +// CHECK: "kind": 5, +// CHECK: "label": "module", +// CHECK: "sortText": "1" +// CHECK: } +// CHECK: ] +// CHECK-NEXT: } +// ----- +{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{ + "textDocument":{"uri":"test:///foo.mlir"}, + "position":{"line":2,"character":8} +}} +// CHECK: "id": 1 +// CHECK-NEXT: "jsonrpc": "2.0", +// CHECK-NEXT: "result": { +// CHECK-NEXT: "isIncomplete": false, +// CHECK-NEXT: "items": [ +// CHECK-NEXT: { +// CHECK-NEXT: "detail": "builtin.unrealized_conversion_cast: i32", +// CHECK-NEXT: "insertText": "cast", +// CHECK-NEXT: "insertTextFormat": 1, +// CHECK-NEXT: "kind": 6, +// CHECK-NEXT: "label": "%cast" +// CHECK-NEXT: }, +// CHECK-NEXT: { +// CHECK-NEXT: "detail": "arg #0: i32", +// CHECK-NEXT: "insertText": "arg", +// CHECK-NEXT: "insertTextFormat": 1, +// CHECK-NEXT: "kind": 6, +// CHECK-NEXT: "label": "%arg" +// CHECK-NEXT: } +// CHECK: ] +// CHECK-NEXT: } +// ----- +{"jsonrpc":"2.0","id":3,"method":"shutdown"} +// ----- +{"jsonrpc":"2.0","method":"exit"} diff --git a/mlir/test/mlir-lsp-server/initialize-params.test b/mlir/test/mlir-lsp-server/initialize-params.test index db41a61a8a1a8..2fe528f8ed2f2 100644 --- a/mlir/test/mlir-lsp-server/initialize-params.test +++ b/mlir/test/mlir-lsp-server/initialize-params.test @@ -5,6 +5,13 @@ // CHECK-NEXT: "jsonrpc": "2.0", // CHECK-NEXT: "result": { // CHECK-NEXT: "capabilities": { +// CHECK-NEXT: "completionProvider": { +// CHECK-NEXT: "allCommitCharacters": [ 
+// CHECK: ], +// CHECK-NEXT: "resolveProvider": false, +// CHECK-NEXT: "triggerCharacters": [ +// CHECK: ] +// CHECK-NEXT: }, // CHECK-NEXT: "definitionProvider": true, // CHECK-NEXT: "documentSymbolProvider": false, // CHECK-NEXT: "hoverProvider": true,