100 changes: 100 additions & 0 deletions mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
//===- CodeComplete.h - PDLL Frontend CodeComplete Context ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TOOLS_PDLL_PARSER_CODECOMPLETE_H_
#define MLIR_TOOLS_PDLL_PARSER_CODECOMPLETE_H_

#include "mlir/Support/LLVM.h"
#include "llvm/Support/SourceMgr.h"

namespace mlir {
namespace pdll {
namespace ast {
class CallableDecl;
class DeclScope;
class Expr;
class OperationType;
class TupleType;
class Type;
class VariableDecl;
} // namespace ast

/// This class provides an abstract interface into the parser for hooking in
/// code completion events.
class CodeCompleteContext {
public:
virtual ~CodeCompleteContext();

/// Return the location used to provide code completion.
SMLoc getCodeCompleteLoc() const { return codeCompleteLoc; }

//===--------------------------------------------------------------------===//
// Completion Hooks
//===--------------------------------------------------------------------===//

/// Signal code completion for a member access into the given tuple type.
virtual void codeCompleteTupleMemberAccess(ast::TupleType tupleType);

/// Signal code completion for a member access into the given operation type.
virtual void codeCompleteOperationMemberAccess(ast::OperationType opType);

/// Signal code completion for a member access into the given operation type.
virtual void codeCompleteOperationAttributeName(StringRef opName) {}

/// Signal code completion for a constraint name with an optional decl scope.
/// `currentType` is the current type of the variable that will use the
/// constraint, or nullptr if a type is unknown. `allowNonCoreConstraints`
/// indicates if user defined constraints are allowed in the completion
/// results. `allowInlineTypeConstraints` enables inline type constraints for
/// Attr/Value/ValueRange.
virtual void codeCompleteConstraintName(ast::Type currentType,
bool allowNonCoreConstraints,
bool allowInlineTypeConstraints,
const ast::DeclScope *scope);

/// Signal code completion for a dialect name.
virtual void codeCompleteDialectName() {}

/// Signal code completion for an operation name in the given dialect.
virtual void codeCompleteOperationName(StringRef dialectName) {}

/// Signal code completion for Pattern metadata.
virtual void codeCompletePatternMetadata() {}

//===--------------------------------------------------------------------===//
// Signature Hooks
//===--------------------------------------------------------------------===//

/// Signal code completion for the signature of a callable.
virtual void codeCompleteCallSignature(const ast::CallableDecl *callable,
unsigned currentNumArgs) {}

/// Signal code completion for the signature of an operation's operands.
virtual void
codeCompleteOperationOperandsSignature(Optional<StringRef> opName,
unsigned currentNumOperands) {}

/// Signal code completion for the signature of an operation's results.
virtual void
codeCompleteOperationResultsSignature(Optional<StringRef> opName,
unsigned currentNumResults) {}

protected:
/// Create a new code completion context with the given code complete
/// location.
explicit CodeCompleteContext(SMLoc codeCompleteLoc)
: codeCompleteLoc(codeCompleteLoc) {}

private:
/// The location used to code complete.
SMLoc codeCompleteLoc;
};
} // namespace pdll
} // namespace mlir

#endif // MLIR_TOOLS_PDLL_PARSER_CODECOMPLETE_H_
11 changes: 8 additions & 3 deletions mlir/include/mlir/Tools/PDLL/Parser/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,19 @@ class SourceMgr;

namespace mlir {
namespace pdll {
class CodeCompleteContext;

namespace ast {
class Context;
class Module;
} // namespace ast

/// Parse an AST module from the main file of the given source manager.
FailureOr<ast::Module *> parsePDLAST(ast::Context &ctx,
llvm::SourceMgr &sourceMgr);
/// Parse an AST module from the main file of the given source manager. An
/// optional code completion context may be provided to receive code completion
/// suggestions. If a completion is hit, this method returns a failure.
FailureOr<ast::Module *>
parsePDLAST(ast::Context &ctx, llvm::SourceMgr &sourceMgr,
CodeCompleteContext *codeCompleteContext = nullptr);
} // namespace pdll
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//===- MlirPdllLspServerMain.h - MLIR PDLL Language Server main -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Main entry function for mlir-pdll-lsp-server for when built as standalone
// binary.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TOOLS_MLIR_PDLL_LSP_SERVER_MLIRPDLLLSPSERVERMAIN_H
#define MLIR_TOOLS_MLIR_PDLL_LSP_SERVER_MLIRPDLLLSPSERVERMAIN_H

namespace mlir {
struct LogicalResult;

/// Implementation for tools like `mlir-pdll-lsp-server`.
LogicalResult MlirPdllLspServerMain(int argc, char **argv);

} // namespace mlir

#endif // MLIR_TOOLS_MLIR_PDLL_LSP_SERVER_MLIRPDLLLSPSERVERMAIN_H
12 changes: 5 additions & 7 deletions mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,8 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
}
case Predicates::ConstraintQuestion: {
auto *cstQuestion = cast<ConstraintQuestion>(question);
builder.create<pdl_interp::ApplyConstraintOp>(
loc, cstQuestion->getName(), args, cstQuestion->getParams(), success,
failure);
builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(),
args, success, failure);
break;
}
default:
Expand Down Expand Up @@ -644,8 +643,7 @@ SymbolRefAttr PatternLowering::generateRewriter(
auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
args.append(mappedArgs.begin(), mappedArgs.end());
builder.create<pdl_interp::ApplyRewriteOp>(
rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args,
rewriter.externalConstParamsAttr());
rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
} else {
// Otherwise this is a dag rewriter defined using PDL operations.
for (Operation &rewriteOp : *rewriter.getBody()) {
Expand Down Expand Up @@ -678,8 +676,8 @@ void PatternLowering::generateRewriter(
arguments.push_back(mapRewriteValue(argument));
auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
arguments, rewriteOp.constParamsAttr());
for (auto it : llvm::zip(rewriteOp.results(), interpOp.getResults()))
arguments);
for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
rewriteValues[std::get<0>(it)] = std::get<1>(it);
}

Expand Down
23 changes: 7 additions & 16 deletions mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,10 +445,9 @@ struct AttributeQuestion

/// Apply a parameterized constraint to multiple position values.
struct ConstraintQuestion
: public PredicateBase<
ConstraintQuestion, Qualifier,
std::tuple<StringRef, ArrayRef<Position *>, Attribute>,
Predicates::ConstraintQuestion> {
: public PredicateBase<ConstraintQuestion, Qualifier,
std::tuple<StringRef, ArrayRef<Position *>>,
Predicates::ConstraintQuestion> {
using Base::Base;

/// Return the name of the constraint.
Expand All @@ -457,17 +456,11 @@ struct ConstraintQuestion
/// Return the arguments of the constraint.
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }

/// Return the constant parameters of the constraint.
ArrayAttr getParams() const {
return std::get<2>(key).dyn_cast_or_null<ArrayAttr>();
}

/// Construct an instance with the given storage allocator.
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
KeyTy key) {
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
alloc.copyInto(std::get<1>(key)),
std::get<2>(key)});
alloc.copyInto(std::get<1>(key))});
}
};

Expand Down Expand Up @@ -667,11 +660,9 @@ class PredicateBuilder {
}

/// Create a predicate that applies a generic constraint.
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
Attribute params) {
return {
ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, params)),
TrueAnswer::get(uniquer)};
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos) {
return {ConstraintQuestion::get(uniquer, std::make_tuple(name, pos)),
TrueAnswer::get(uniquer)};
}

/// Create a predicate comparing a value with null.
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
OperandRange arguments = op.args();
ArrayAttr parameters = op.constParamsAttr();

std::vector<Position *> allPositions;
allPositions.reserve(arguments.size());
Expand All @@ -274,7 +273,7 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
comparePosDepth);
PredicateBuilder::Predicate pred =
builder.getConstraint(op.name(), allPositions, parameters);
builder.getConstraint(op.name(), allPositions);
predList.emplace_back(pos, pred);
}

Expand Down
4 changes: 0 additions & 4 deletions mlir/lib/Dialect/PDL/IR/PDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,6 @@ LogicalResult RewriteOp::verifyRegions() {
return emitOpError() << "expected no external arguments when the "
"rewrite is specified inline";
}
if (externalConstParams()) {
return emitOpError() << "expected no external constant parameters when "
"the rewrite is specified inline";
}

return success();
}
Expand Down
14 changes: 4 additions & 10 deletions mlir/lib/Rewrite/ByteCode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,17 +757,15 @@ void Generator::generate(pdl_interp::ApplyConstraintOp op,
ByteCodeWriter &writer) {
assert(constraintToMemIndex.count(op.getName()) &&
"expected index for constraint function");
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()],
op.getConstParamsAttr());
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
writer.appendPDLValueList(op.getArgs());
writer.append(op.getSuccessors());
}
void Generator::generate(pdl_interp::ApplyRewriteOp op,
ByteCodeWriter &writer) {
assert(externalRewriterToMemIndex.count(op.getName()) &&
"expected index for rewrite function");
writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()],
op.getConstParamsAttr());
writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
writer.appendPDLValueList(op.getArgs());

ResultRange results = op.getResults();
Expand Down Expand Up @@ -1333,37 +1331,33 @@ class ByteCodeRewriteResultList : public PDLResultList {
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
ArrayAttr constParams = read<ArrayAttr>();
SmallVector<PDLValue, 16> args;
readList<PDLValue>(args);

LLVM_DEBUG({
llvm::dbgs() << " * Arguments: ";
llvm::interleaveComma(args, llvm::dbgs());
llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
});

// Invoke the constraint and jump to the proper destination.
selectJump(succeeded(constraintFn(args, constParams, rewriter)));
selectJump(succeeded(constraintFn(args, rewriter)));
}

void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
ArrayAttr constParams = read<ArrayAttr>();
SmallVector<PDLValue, 16> args;
readList<PDLValue>(args);

LLVM_DEBUG({
llvm::dbgs() << " * Arguments: ";
llvm::interleaveComma(args, llvm::dbgs());
llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
});

// Execute the rewrite function.
ByteCodeField numResults = read();
ByteCodeRewriteResultList results(numResults);
rewriteFn(args, constParams, rewriter, results);
rewriteFn(args, rewriter, results);

assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results");
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Tools/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_subdirectory(lsp-server-support)
add_subdirectory(mlir-lsp-server)
add_subdirectory(mlir-opt)
add_subdirectory(mlir-pdll-lsp-server)
add_subdirectory(mlir-reduce)
add_subdirectory(mlir-translate)
add_subdirectory(PDLL)
130 changes: 130 additions & 0 deletions mlir/lib/Tools/PDLL/AST/Nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/AST/Context.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::pdll::ast;
Expand All @@ -33,6 +34,135 @@ const Name &Name::create(Context &ctx, StringRef name, SMRange location) {
Name(copyStringWithNull(ctx, name), location);
}

//===----------------------------------------------------------------------===//
// Node
//===----------------------------------------------------------------------===//

namespace {
class NodeVisitor {
public:
explicit NodeVisitor(function_ref<void(const Node *)> visitFn)
: visitFn(visitFn) {}

void visit(const Node *node) {
if (!node || !alreadyVisited.insert(node).second)
return;

visitFn(node);
TypeSwitch<const Node *>(node)
.Case<
// Statements.
const CompoundStmt, const EraseStmt, const LetStmt,
const ReplaceStmt, const ReturnStmt, const RewriteStmt,

// Expressions.
const AttributeExpr, const CallExpr, const DeclRefExpr,
const MemberAccessExpr, const OperationExpr, const TupleExpr,
const TypeExpr,

// Core Constraint Decls.
const AttrConstraintDecl, const OpConstraintDecl,
const TypeConstraintDecl, const TypeRangeConstraintDecl,
const ValueConstraintDecl, const ValueRangeConstraintDecl,

// Decls.
const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
const UserConstraintDecl, const UserRewriteDecl, const VariableDecl,

const Module>(
[&](auto derivedNode) { this->visitImpl(derivedNode); })
.Default([](const Node *) { llvm_unreachable("unknown AST node"); });
}

private:
void visitImpl(const CompoundStmt *stmt) {
for (const Node *child : stmt->getChildren())
visit(child);
}
void visitImpl(const EraseStmt *stmt) { visit(stmt->getRootOpExpr()); }
void visitImpl(const LetStmt *stmt) { visit(stmt->getVarDecl()); }
void visitImpl(const ReplaceStmt *stmt) {
visit(stmt->getRootOpExpr());
for (const Node *child : stmt->getReplExprs())
visit(child);
}
void visitImpl(const ReturnStmt *stmt) { visit(stmt->getResultExpr()); }
void visitImpl(const RewriteStmt *stmt) {
visit(stmt->getRootOpExpr());
visit(stmt->getRewriteBody());
}

void visitImpl(const AttributeExpr *expr) {}
void visitImpl(const CallExpr *expr) {
visit(expr->getCallableExpr());
for (const Node *child : expr->getArguments())
visit(child);
}
void visitImpl(const DeclRefExpr *expr) { visit(expr->getDecl()); }
void visitImpl(const MemberAccessExpr *expr) { visit(expr->getParentExpr()); }
void visitImpl(const OperationExpr *expr) {
visit(expr->getNameDecl());
for (const Node *child : expr->getOperands())
visit(child);
for (const Node *child : expr->getResultTypes())
visit(child);
for (const Node *child : expr->getAttributes())
visit(child);
}
void visitImpl(const TupleExpr *expr) {
for (const Node *child : expr->getElements())
visit(child);
}
void visitImpl(const TypeExpr *expr) {}

void visitImpl(const AttrConstraintDecl *decl) { visit(decl->getTypeExpr()); }
void visitImpl(const OpConstraintDecl *decl) { visit(decl->getNameDecl()); }
void visitImpl(const TypeConstraintDecl *decl) {}
void visitImpl(const TypeRangeConstraintDecl *decl) {}
void visitImpl(const ValueConstraintDecl *decl) {
visit(decl->getTypeExpr());
}
void visitImpl(const ValueRangeConstraintDecl *decl) {
visit(decl->getTypeExpr());
}

void visitImpl(const NamedAttributeDecl *decl) { visit(decl->getValue()); }
void visitImpl(const OpNameDecl *decl) {}
void visitImpl(const PatternDecl *decl) { visit(decl->getBody()); }
void visitImpl(const UserConstraintDecl *decl) {
for (const Node *child : decl->getInputs())
visit(child);
for (const Node *child : decl->getResults())
visit(child);
visit(decl->getBody());
}
void visitImpl(const UserRewriteDecl *decl) {
for (const Node *child : decl->getInputs())
visit(child);
for (const Node *child : decl->getResults())
visit(child);
visit(decl->getBody());
}
void visitImpl(const VariableDecl *decl) {
visit(decl->getInitExpr());
for (const ConstraintRef &child : decl->getConstraints())
visit(child.constraint);
}

void visitImpl(const Module *module) {
for (const Node *child : module->getChildren())
visit(child);
}

function_ref<void(const Node *)> visitFn;
SmallPtrSet<const Node *, 16> alreadyVisited;
};
} // namespace

void Node::walk(function_ref<void(const Node *)> walkFn) const {
return NodeVisitor(walkFn).visit(this);
}

//===----------------------------------------------------------------------===//
// DeclScope
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint,
// what we need as a frontend.
os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name
<< "PDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, "
"::mlir::ArrayAttr constParams, ::mlir::PatternRewriter &rewriter"
"::mlir::PatternRewriter &rewriter"
<< (isConstraint ? "" : ", ::mlir::PDLResultList &results") << ") {\n";

const char *argumentInitStr = R"(
Expand Down
16 changes: 5 additions & 11 deletions mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ void CodeGen::genImpl(const ast::CompoundStmt *stmt) {
static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
Location loc) {
if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) {
pdl::RewriteOp rewrite = builder.create<pdl::RewriteOp>(
loc, rootExpr, /*name=*/StringAttr(),
/*externalArgs=*/ValueRange(), /*externalConstParams=*/ArrayAttr());
pdl::RewriteOp rewrite =
builder.create<pdl::RewriteOp>(loc, rootExpr, /*name=*/StringAttr(),
/*externalArgs=*/ValueRange());
builder.createBlock(&rewrite.body());
}
}
Expand Down Expand Up @@ -564,14 +564,8 @@ SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
} else {
resultTypes.push_back(genType(declResultType));
}

// FIXME: We currently do not have a modeling for the "constant params"
// support PDL provides. We should either figure out a modeling for this, or
// refactor the support within PDL to be something a bit more reasonable for
// what we need as a frontend.
Operation *pdlOp = builder.create<PDLOpT>(loc, resultTypes,
decl->getName().getName(), inputs,
/*params=*/ArrayAttr());
Operation *pdlOp = builder.create<PDLOpT>(
loc, resultTypes, decl->getName().getName(), inputs);
return pdlOp->getResults();
}

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Tools/PDLL/Parser/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

llvm_add_library(MLIRPDLLParser STATIC
CodeComplete.cpp
Lexer.cpp
Parser.cpp

Expand Down
28 changes: 28 additions & 0 deletions mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===- CodeComplete.cpp ---------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
#include "mlir/Tools/PDLL/AST/Types.h"

using namespace mlir;
using namespace mlir::pdll;

//===----------------------------------------------------------------------===//
// CodeCompleteContext
//===----------------------------------------------------------------------===//

CodeCompleteContext::~CodeCompleteContext() = default;

void CodeCompleteContext::codeCompleteTupleMemberAccess(
ast::TupleType tupleType) {}
void CodeCompleteContext::codeCompleteOperationMemberAccess(
ast::OperationType opType) {}

void CodeCompleteContext::codeCompleteConstraintName(
ast::Type currentType, bool allowNonCoreConstraints,
bool allowInlineTypeConstraints, const ast::DeclScope *scope) {}
17 changes: 15 additions & 2 deletions mlir/lib/Tools/PDLL/Parser/Lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "Lexer.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Tools/PDLL/AST/Diagnostic.h"
#include "mlir/Tools/PDLL/Parser/CodeComplete.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/SourceMgr.h"
Expand Down Expand Up @@ -67,12 +68,20 @@ std::string Token::getStringValue() const {
// Lexer
//===----------------------------------------------------------------------===//

Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine)
: srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false) {
Lexer::Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
CodeCompleteContext *codeCompleteContext)
: srcMgr(mgr), diagEngine(diagEngine), addedHandlerToDiagEngine(false),
codeCompletionLocation(nullptr) {
curBufferID = mgr.getMainFileID();
curBuffer = srcMgr.getMemoryBuffer(curBufferID)->getBuffer();
curPtr = curBuffer.begin();

// Set the code completion location if necessary.
if (codeCompleteContext) {
codeCompletionLocation =
codeCompleteContext->getCodeCompleteLoc().getPointer();
}

// If the diag engine has no handler, add a default that emits to the
// SourceMgr.
if (!diagEngine.getHandlerFn()) {
Expand Down Expand Up @@ -147,6 +156,10 @@ Token Lexer::lexToken() {
while (true) {
const char *tokStart = curPtr;

// Check to see if this token is at the code completion location.
if (tokStart == codeCompletionLocation)
return formToken(Token::code_complete, tokStart);

// This always consumes at least one character.
int curChar = getNextChar();
switch (curChar) {
Expand Down
9 changes: 8 additions & 1 deletion mlir/lib/Tools/PDLL/Parser/Lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ namespace mlir {
struct LogicalResult;

namespace pdll {
class CodeCompleteContext;

namespace ast {
class DiagnosticEngine;
} // namespace ast
Expand All @@ -35,6 +37,7 @@ class Token {
// Markers.
eof,
error,
code_complete,

// Keywords.
KW_BEGIN,
Expand Down Expand Up @@ -162,7 +165,8 @@ class Token {

class Lexer {
public:
Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine);
Lexer(llvm::SourceMgr &mgr, ast::DiagnosticEngine &diagEngine,
CodeCompleteContext *codeCompleteContext);
~Lexer();

/// Return a reference to the source manager used by the lexer.
Expand Down Expand Up @@ -215,6 +219,9 @@ class Lexer {
/// A flag indicating if we added a default diagnostic handler to the provided
/// diagEngine.
bool addedHandlerToDiagEngine;

/// The optional code completion point within the input file.
const char *codeCompletionLocation;
};
} // namespace pdll
} // namespace mlir
Expand Down
195 changes: 178 additions & 17 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp

Large diffs are not rendered by default.

211 changes: 211 additions & 0 deletions mlir/lib/Tools/lsp-server-support/Protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,3 +585,214 @@ llvm::json::Value mlir::lsp::toJSON(const PublishDiagnosticsParams &params) {
{"version", params.version},
};
}

//===----------------------------------------------------------------------===//
// TextEdit
//===----------------------------------------------------------------------===//

bool mlir::lsp::fromJSON(const llvm::json::Value &value, TextEdit &result,
llvm::json::Path path) {
llvm::json::ObjectMapper o(value, path);
return o && o.map("range", result.range) && o.map("newText", result.newText);
}

llvm::json::Value mlir::lsp::toJSON(const TextEdit &value) {
return llvm::json::Object{
{"range", value.range},
{"newText", value.newText},
};
}

raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const TextEdit &value) {
os << value.range << " => \"";
llvm::printEscapedString(value.newText, os);
return os << '"';
}

//===----------------------------------------------------------------------===//
// CompletionItemKind
//===----------------------------------------------------------------------===//

bool mlir::lsp::fromJSON(const llvm::json::Value &value,
CompletionItemKind &result, llvm::json::Path path) {
if (Optional<int64_t> intValue = value.getAsInteger()) {
if (*intValue < static_cast<int>(CompletionItemKind::Text) ||
*intValue > static_cast<int>(CompletionItemKind::TypeParameter))
return false;
result = static_cast<CompletionItemKind>(*intValue);
return true;
}
return false;
}

CompletionItemKind mlir::lsp::adjustKindToCapability(
CompletionItemKind kind,
CompletionItemKindBitset &supportedCompletionItemKinds) {
size_t kindVal = static_cast<size_t>(kind);
if (kindVal >= kCompletionItemKindMin &&
kindVal <= supportedCompletionItemKinds.size() &&
supportedCompletionItemKinds[kindVal])
return kind;

// Provide some fall backs for common kinds that are close enough.
switch (kind) {
case CompletionItemKind::Folder:
return CompletionItemKind::File;
case CompletionItemKind::EnumMember:
return CompletionItemKind::Enum;
case CompletionItemKind::Struct:
return CompletionItemKind::Class;
default:
return CompletionItemKind::Text;
}
}

bool mlir::lsp::fromJSON(const llvm::json::Value &value,
CompletionItemKindBitset &result,
llvm::json::Path path) {
if (const llvm::json::Array *arrayValue = value.getAsArray()) {
for (size_t i = 0, e = arrayValue->size(); i < e; ++i) {
CompletionItemKind kindOut;
if (fromJSON((*arrayValue)[i], kindOut, path.index(i)))
result.set(size_t(kindOut));
}
return true;
}
return false;
}

//===----------------------------------------------------------------------===//
// CompletionItem
//===----------------------------------------------------------------------===//

llvm::json::Value mlir::lsp::toJSON(const CompletionItem &value) {
assert(!value.label.empty() && "completion item label is required");
llvm::json::Object result{{"label", value.label}};
if (value.kind != CompletionItemKind::Missing)
result["kind"] = static_cast<int>(value.kind);
if (!value.detail.empty())
result["detail"] = value.detail;
if (value.documentation)
result["documentation"] = value.documentation;
if (!value.sortText.empty())
result["sortText"] = value.sortText;
if (!value.filterText.empty())
result["filterText"] = value.filterText;
if (!value.insertText.empty())
result["insertText"] = value.insertText;
if (value.insertTextFormat != InsertTextFormat::Missing)
result["insertTextFormat"] = static_cast<int>(value.insertTextFormat);
if (value.textEdit)
result["textEdit"] = *value.textEdit;
if (!value.additionalTextEdits.empty()) {
result["additionalTextEdits"] =
llvm::json::Array(value.additionalTextEdits);
}
if (value.deprecated)
result["deprecated"] = value.deprecated;
return std::move(result);
}

raw_ostream &mlir::lsp::operator<<(raw_ostream &os,
const CompletionItem &value) {
return os << value.label << " - " << toJSON(value);
}

bool mlir::lsp::operator<(const CompletionItem &lhs,
const CompletionItem &rhs) {
return (lhs.sortText.empty() ? lhs.label : lhs.sortText) <
(rhs.sortText.empty() ? rhs.label : rhs.sortText);
}

//===----------------------------------------------------------------------===//
// CompletionList
//===----------------------------------------------------------------------===//

llvm::json::Value mlir::lsp::toJSON(const CompletionList &value) {
return llvm::json::Object{
{"isIncomplete", value.isIncomplete},
{"items", llvm::json::Array(value.items)},
};
}

//===----------------------------------------------------------------------===//
// CompletionContext
//===----------------------------------------------------------------------===//

bool mlir::lsp::fromJSON(const llvm::json::Value &value,
CompletionContext &result, llvm::json::Path path) {
llvm::json::ObjectMapper o(value, path);
int triggerKind;
if (!o || !o.map("triggerKind", triggerKind) ||
!mapOptOrNull(value, "triggerCharacter", result.triggerCharacter, path))
return false;
result.triggerKind = static_cast<CompletionTriggerKind>(triggerKind);
return true;
}

//===----------------------------------------------------------------------===//
// CompletionParams
//===----------------------------------------------------------------------===//

bool mlir::lsp::fromJSON(const llvm::json::Value &value,
CompletionParams &result, llvm::json::Path path) {
if (!fromJSON(value, static_cast<TextDocumentPositionParams &>(result), path))
return false;
if (const llvm::json::Value *context = value.getAsObject()->get("context"))
return fromJSON(*context, result.context, path.field("context"));
return true;
}

//===----------------------------------------------------------------------===//
// ParameterInformation
//===----------------------------------------------------------------------===//

llvm::json::Value mlir::lsp::toJSON(const ParameterInformation &value) {
assert((value.labelOffsets.hasValue() || !value.labelString.empty()) &&
"parameter information label is required");
llvm::json::Object result;
if (value.labelOffsets)
result["label"] = llvm::json::Array(
{value.labelOffsets->first, value.labelOffsets->second});
else
result["label"] = value.labelString;
if (!value.documentation.empty())
result["documentation"] = value.documentation;
return std::move(result);
}

//===----------------------------------------------------------------------===//
// SignatureInformation
//===----------------------------------------------------------------------===//

llvm::json::Value mlir::lsp::toJSON(const SignatureInformation &value) {
assert(!value.label.empty() && "signature information label is required");
llvm::json::Object result{
{"label", value.label},
{"parameters", llvm::json::Array(value.parameters)},
};
if (!value.documentation.empty())
result["documentation"] = value.documentation;
return std::move(result);
}

raw_ostream &mlir::lsp::operator<<(raw_ostream &os,
const SignatureInformation &value) {
return os << value.label << " - " << toJSON(value);
}

//===----------------------------------------------------------------------===//
// SignatureHelp
//===----------------------------------------------------------------------===//

llvm::json::Value mlir::lsp::toJSON(const SignatureHelp &value) {
assert(value.activeSignature >= 0 &&
"Unexpected negative value for number of active signatures.");
assert(value.activeParameter >= 0 &&
"Unexpected negative value for active parameter index");
return llvm::json::Object{
{"activeSignature", value.activeSignature},
{"activeParameter", value.activeParameter},
{"signatures", llvm::json::Array(value.signatures)},
};
}
292 changes: 291 additions & 1 deletion mlir/lib/Tools/lsp-server-support/Protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file contains structs based on the LSP specification at
// https://github.com/Microsoft/language-server-protocol/blob/main/protocol.md
// https://microsoft.github.io/language-server-protocol/specification
//
// This is not meant to be a complete implementation, new interfaces are added
// when they're needed.
Expand All @@ -26,6 +26,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/Optional.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <bitset>
#include <memory>
Expand Down Expand Up @@ -245,6 +246,13 @@ struct Position {
Position(int line = 0, int character = 0)
: line(line), character(character) {}

/// Construct a position from the given source location.
Position(llvm::SourceMgr &mgr, SMLoc loc) {
std::pair<unsigned, unsigned> lineAndCol = mgr.getLineAndColumn(loc);
line = lineAndCol.first - 1;
character = lineAndCol.second - 1;
}

/// Line position in a document (zero-based).
int line = 0;

Expand All @@ -266,6 +274,13 @@ struct Position {
return std::tie(lhs.line, lhs.character) <=
std::tie(rhs.line, rhs.character);
}

/// Convert this position into a source location in the main file of the given
/// source manager.
SMLoc getAsSMLoc(llvm::SourceMgr &mgr) const {
return mgr.FindLocForLineAndColumn(mgr.getMainFileID(), line + 1,
character);
}
};

/// Add support for JSON serialization.
Expand All @@ -283,6 +298,10 @@ struct Range {
Range(Position start, Position end) : start(start), end(end) {}
Range(Position loc) : Range(loc, loc) {}

/// Construct a range from the given source range.
Range(llvm::SourceMgr &mgr, SMRange range)
: Range(Position(mgr, range.Start), Position(mgr, range.End)) {}

/// The range's start position.
Position start;

Expand Down Expand Up @@ -316,6 +335,13 @@ raw_ostream &operator<<(raw_ostream &os, const Range &value);
//===----------------------------------------------------------------------===//

struct Location {
Location() = default;
Location(const URIForFile &uri, Range range) : uri(uri), range(range) {}

/// Construct a Location from the given source range.
Location(const URIForFile &uri, llvm::SourceMgr &mgr, SMRange range)
: Location(uri, Range(mgr, range)) {}

/// The text document's URI.
URIForFile uri;
Range range;
Expand Down Expand Up @@ -640,6 +666,270 @@ struct PublishDiagnosticsParams {
/// Add support for JSON serialization.
llvm::json::Value toJSON(const PublishDiagnosticsParams &params);

//===----------------------------------------------------------------------===//
// TextEdit
//===----------------------------------------------------------------------===//

struct TextEdit {
/// The range of the text document to be manipulated. To insert
/// text into a document create a range where start === end.
Range range;

/// The string to be inserted. For delete operations use an
/// empty string.
std::string newText;
};

inline bool operator==(const TextEdit &lhs, const TextEdit &rhs) {
return std::tie(lhs.newText, lhs.range) == std::tie(rhs.newText, rhs.range);
}

bool fromJSON(const llvm::json::Value &value, TextEdit &result,
llvm::json::Path path);
llvm::json::Value toJSON(const TextEdit &value);
raw_ostream &operator<<(raw_ostream &os, const TextEdit &value);

//===----------------------------------------------------------------------===//
// CompletionItemKind
//===----------------------------------------------------------------------===//

/// The kind of a completion entry.
enum class CompletionItemKind {
Missing = 0,
Text = 1,
Method = 2,
Function = 3,
Constructor = 4,
Field = 5,
Variable = 6,
Class = 7,
Interface = 8,
Module = 9,
Property = 10,
Unit = 11,
Value = 12,
Enum = 13,
Keyword = 14,
Snippet = 15,
Color = 16,
File = 17,
Reference = 18,
Folder = 19,
EnumMember = 20,
Constant = 21,
Struct = 22,
Event = 23,
Operator = 24,
TypeParameter = 25,
};
bool fromJSON(const llvm::json::Value &value, CompletionItemKind &result,
llvm::json::Path path);

constexpr auto kCompletionItemKindMin =
static_cast<size_t>(CompletionItemKind::Text);
constexpr auto kCompletionItemKindMax =
static_cast<size_t>(CompletionItemKind::TypeParameter);
using CompletionItemKindBitset = std::bitset<kCompletionItemKindMax + 1>;
bool fromJSON(const llvm::json::Value &value, CompletionItemKindBitset &result,
llvm::json::Path path);

CompletionItemKind
adjustKindToCapability(CompletionItemKind kind,
CompletionItemKindBitset &supportedCompletionItemKinds);

//===----------------------------------------------------------------------===//
// CompletionItem
//===----------------------------------------------------------------------===//

/// Defines whether the insert text in a completion item should be interpreted
/// as plain text or a snippet.
enum class InsertTextFormat {
Missing = 0,
/// The primary text to be inserted is treated as a plain string.
PlainText = 1,
/// The primary text to be inserted is treated as a snippet.
///
/// A snippet can define tab stops and placeholders with `$1`, `$2`
/// and `${3:foo}`. `$0` defines the final tab stop, it defaults to the end
/// of the snippet. Placeholders with equal identifiers are linked, that is
/// typing in one will update others too.
///
/// See also:
/// https//github.com/Microsoft/vscode/blob/master/src/vs/editor/contrib/snippet/common/snippet.md
Snippet = 2,
};

struct CompletionItem {
/// The label of this completion item. By default also the text that is
/// inserted when selecting this completion.
std::string label;

/// The kind of this completion item. Based of the kind an icon is chosen by
/// the editor.
CompletionItemKind kind = CompletionItemKind::Missing;

/// A human-readable string with additional information about this item, like
/// type or symbol information.
std::string detail;

/// A human-readable string that represents a doc-comment.
Optional<MarkupContent> documentation;

/// A string that should be used when comparing this item with other items.
/// When `falsy` the label is used.
std::string sortText;

/// A string that should be used when filtering a set of completion items.
/// When `falsy` the label is used.
std::string filterText;

/// A string that should be inserted to a document when selecting this
/// completion. When `falsy` the label is used.
std::string insertText;

/// The format of the insert text. The format applies to both the `insertText`
/// property and the `newText` property of a provided `textEdit`.
InsertTextFormat insertTextFormat = InsertTextFormat::Missing;

/// An edit which is applied to a document when selecting this completion.
/// When an edit is provided `insertText` is ignored.
///
/// Note: The range of the edit must be a single line range and it must
/// contain the position at which completion has been requested.
Optional<TextEdit> textEdit;

/// An optional array of additional text edits that are applied when selecting
/// this completion. Edits must not overlap with the main edit nor with
/// themselves.
std::vector<TextEdit> additionalTextEdits;

/// Indicates if this item is deprecated.
bool deprecated = false;
};

/// Add support for JSON serialization.
llvm::json::Value toJSON(const CompletionItem &value);
raw_ostream &operator<<(raw_ostream &os, const CompletionItem &value);
bool operator<(const CompletionItem &lhs, const CompletionItem &rhs);

//===----------------------------------------------------------------------===//
// CompletionList
//===----------------------------------------------------------------------===//

/// Represents a collection of completion items to be presented in the editor.
struct CompletionList {
/// The list is not complete. Further typing should result in recomputing the
/// list.
bool isIncomplete = false;

/// The completion items.
std::vector<CompletionItem> items;
};

/// Add support for JSON serialization.
llvm::json::Value toJSON(const CompletionList &value);

//===----------------------------------------------------------------------===//
// CompletionContext
//===----------------------------------------------------------------------===//

enum class CompletionTriggerKind {
/// Completion was triggered by typing an identifier (24x7 code
/// complete), manual invocation (e.g Ctrl+Space) or via API.
Invoked = 1,

/// Completion was triggered by a trigger character specified by
/// the `triggerCharacters` properties of the `CompletionRegistrationOptions`.
TriggerCharacter = 2,

/// Completion was re-triggered as the current completion list is incomplete.
TriggerTriggerForIncompleteCompletions = 3
};

struct CompletionContext {
/// How the completion was triggered.
CompletionTriggerKind triggerKind = CompletionTriggerKind::Invoked;

/// The trigger character (a single character) that has trigger code complete.
/// Is undefined if `triggerKind !== CompletionTriggerKind.TriggerCharacter`
std::string triggerCharacter;
};

/// Add support for JSON serialization.
bool fromJSON(const llvm::json::Value &value, CompletionContext &result,
llvm::json::Path path);

//===----------------------------------------------------------------------===//
// CompletionParams
//===----------------------------------------------------------------------===//

struct CompletionParams : TextDocumentPositionParams {
CompletionContext context;
};

/// Add support for JSON serialization.
bool fromJSON(const llvm::json::Value &value, CompletionParams &result,
llvm::json::Path path);

//===----------------------------------------------------------------------===//
// ParameterInformation
//===----------------------------------------------------------------------===//

/// A single parameter of a particular signature.
struct ParameterInformation {
/// The label of this parameter. Ignored when labelOffsets is set.
std::string labelString;

/// Inclusive start and exclusive end offsets withing the containing signature
/// label.
Optional<std::pair<unsigned, unsigned>> labelOffsets;

/// The documentation of this parameter. Optional.
std::string documentation;
};

/// Add support for JSON serialization.
llvm::json::Value toJSON(const ParameterInformation &value);

//===----------------------------------------------------------------------===//
// SignatureInformation
//===----------------------------------------------------------------------===//

/// Represents the signature of something callable.
struct SignatureInformation {
/// The label of this signature. Mandatory.
std::string label;

/// The documentation of this signature. Optional.
std::string documentation;

/// The parameters of this signature.
std::vector<ParameterInformation> parameters;
};

/// Add support for JSON serialization.
llvm::json::Value toJSON(const SignatureInformation &value);
raw_ostream &operator<<(raw_ostream &os, const SignatureInformation &value);

//===----------------------------------------------------------------------===//
// SignatureHelp
//===----------------------------------------------------------------------===//

/// Represents the signature of a callable.
struct SignatureHelp {
/// The resulting signatures.
std::vector<SignatureInformation> signatures;

/// The active signature.
int activeSignature = 0;

/// The active parameter of the active signature.
int activeParameter = 0;
};

/// Add support for JSON serialization.
llvm::json::Value toJSON(const SignatureHelp &value);

} // namespace lsp
} // namespace mlir

Expand Down
60 changes: 17 additions & 43 deletions mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,6 @@

using namespace mlir;

/// Returns a language server position for the given source location.
static lsp::Position getPosFromLoc(llvm::SourceMgr &mgr, SMLoc loc) {
std::pair<unsigned, unsigned> lineAndCol = mgr.getLineAndColumn(loc);
lsp::Position pos;
pos.line = lineAndCol.first - 1;
pos.character = lineAndCol.second - 1;
return pos;
}

/// Returns a source location from the given language server position.
static SMLoc getPosFromLoc(llvm::SourceMgr &mgr, lsp::Position pos) {
return mgr.FindLocForLineAndColumn(mgr.getMainFileID(), pos.line + 1,
pos.character);
}

/// Returns a language server range for the given source range.
static lsp::Range getRangeFromLoc(llvm::SourceMgr &mgr, SMRange range) {
return {getPosFromLoc(mgr, range.Start), getPosFromLoc(mgr, range.End)};
}

/// Returns a language server location from the given source range.
static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range,
const lsp::URIForFile &uri) {
return lsp::Location{uri, getRangeFromLoc(mgr, range)};
}

/// Returns a language server location from the given MLIR file location.
static Optional<lsp::Location> getLocationFromLoc(FileLineColLoc loc) {
llvm::Expected<lsp::URIForFile> sourceURI =
Expand Down Expand Up @@ -348,13 +322,13 @@ MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
const lsp::Position &defPos,
std::vector<lsp::Location> &locations) {
SMLoc posLoc = getPosFromLoc(sourceMgr, defPos);
SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);

// Functor used to check if an SM definition contains the position.
auto containsPosition = [&](const AsmParserState::SMDefinition &def) {
if (!isDefOrUse(def, posLoc))
return false;
locations.push_back(getLocationFromLoc(sourceMgr, def.loc, uri));
locations.emplace_back(uri, sourceMgr, def.loc);
return true;
};

Expand All @@ -367,7 +341,7 @@ void MLIRDocument::getLocationsOf(const lsp::URIForFile &uri,
return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
for (const auto &symUse : op.symbolUses) {
if (contains(symUse, posLoc)) {
locations.push_back(getLocationFromLoc(sourceMgr, op.loc, uri));
locations.emplace_back(uri, sourceMgr, op.loc);
return collectLocationsFromLoc(op.op->getLoc(), locations, uri);
}
}
Expand All @@ -389,12 +363,12 @@ void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
// Functor used to append all of the definitions/uses of the given SM
// definition to the reference list.
auto appendSMDef = [&](const AsmParserState::SMDefinition &def) {
references.push_back(getLocationFromLoc(sourceMgr, def.loc, uri));
references.emplace_back(uri, sourceMgr, def.loc);
for (const SMRange &use : def.uses)
references.push_back(getLocationFromLoc(sourceMgr, use, uri));
references.emplace_back(uri, sourceMgr, use);
};

SMLoc posLoc = getPosFromLoc(sourceMgr, pos);
SMLoc posLoc = pos.getAsSMLoc(sourceMgr);

// Check all definitions related to operations.
for (const AsmParserState::OperationDefinition &op : asmState.getOpDefs()) {
Expand All @@ -403,7 +377,7 @@ void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
appendSMDef(result.definition);
for (const auto &symUse : op.symbolUses)
if (contains(symUse, posLoc))
references.push_back(getLocationFromLoc(sourceMgr, symUse, uri));
references.emplace_back(uri, sourceMgr, symUse);
return;
}
for (const auto &result : op.resultGroups)
Expand All @@ -413,7 +387,7 @@ void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,
if (!contains(symUse, posLoc))
continue;
for (const auto &symUse : op.symbolUses)
references.push_back(getLocationFromLoc(sourceMgr, symUse, uri));
references.emplace_back(uri, sourceMgr, symUse);
return;
}
}
Expand All @@ -435,7 +409,7 @@ void MLIRDocument::findReferencesOf(const lsp::URIForFile &uri,

Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri,
const lsp::Position &hoverPos) {
SMLoc posLoc = getPosFromLoc(sourceMgr, hoverPos);
SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr);
SMRange hoverRange;

// Check for Hovers on operations and results.
Expand Down Expand Up @@ -482,7 +456,7 @@ Optional<lsp::Hover> MLIRDocument::findHover(const lsp::URIForFile &uri,

Optional<lsp::Hover> MLIRDocument::buildHoverForOperation(
SMRange hoverRange, const AsmParserState::OperationDefinition &op) {
lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
llvm::raw_string_ostream os(hover.contents.value);

// Add the operation name to the hover.
Expand Down Expand Up @@ -518,7 +492,7 @@ lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
unsigned resultStart,
unsigned resultEnd,
SMLoc posLoc) {
lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
llvm::raw_string_ostream os(hover.contents.value);

// Add the parent operation name to the hover.
Expand Down Expand Up @@ -551,7 +525,7 @@ lsp::Hover MLIRDocument::buildHoverForOperationResult(SMRange hoverRange,
lsp::Hover
MLIRDocument::buildHoverForBlock(SMRange hoverRange,
const AsmParserState::BlockDefinition &block) {
lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
llvm::raw_string_ostream os(hover.contents.value);

// Print the given block to the hover output stream.
Expand Down Expand Up @@ -583,7 +557,7 @@ MLIRDocument::buildHoverForBlock(SMRange hoverRange,
lsp::Hover MLIRDocument::buildHoverForBlockArgument(
SMRange hoverRange, BlockArgument arg,
const AsmParserState::BlockDefinition &block) {
lsp::Hover hover(getRangeFromLoc(sourceMgr, hoverRange));
lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
llvm::raw_string_ostream os(hover.contents.value);

// Display the parent operation, block, the argument number, and the type.
Expand Down Expand Up @@ -618,16 +592,16 @@ void MLIRDocument::findDocumentSymbols(
isa<FunctionOpInterface>(op)
? lsp::SymbolKind::Function
: lsp::SymbolKind::Class,
getRangeFromLoc(sourceMgr, def->scopeLoc),
getRangeFromLoc(sourceMgr, def->loc));
lsp::Range(sourceMgr, def->scopeLoc),
lsp::Range(sourceMgr, def->loc));
childSymbols = &symbols.back().children;

} else if (op->hasTrait<OpTrait::SymbolTable>()) {
// Otherwise, if this is a symbol table push an anonymous document symbol.
symbols.emplace_back("<" + op->getName().getStringRef() + ">",
lsp::SymbolKind::Namespace,
getRangeFromLoc(sourceMgr, def->scopeLoc),
getRangeFromLoc(sourceMgr, def->loc));
lsp::Range(sourceMgr, def->scopeLoc),
lsp::Range(sourceMgr, def->loc));
childSymbols = &symbols.back().children;
}
}
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
llvm_add_library(MLIRPdllLspServerLib
LSPServer.cpp
PDLLServer.cpp
MlirPdllLspServerMain.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-pdll-lsp-server

LINK_LIBS PUBLIC
MLIRPDLLParser
MLIRLspServerSupportLib
)
286 changes: 286 additions & 0 deletions mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
//===- LSPServer.cpp - PDLL Language Server -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "LSPServer.h"

#include "../lsp-server-support/Logging.h"
#include "../lsp-server-support/Protocol.h"
#include "../lsp-server-support/Transport.h"
#include "PDLLServer.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/StringMap.h"

#define DEBUG_TYPE "pdll-lsp-server"

using namespace mlir;
using namespace mlir::lsp;

//===----------------------------------------------------------------------===//
// LSPServer
//===----------------------------------------------------------------------===//

namespace {
struct LSPServer {
LSPServer(PDLLServer &server, JSONTransport &transport)
: server(server), transport(transport) {}

//===--------------------------------------------------------------------===//
// Initialization

void onInitialize(const InitializeParams &params,
Callback<llvm::json::Value> reply);
void onInitialized(const InitializedParams &params);
void onShutdown(const NoParams &params, Callback<std::nullptr_t> reply);

//===--------------------------------------------------------------------===//
// Document Change

void onDocumentDidOpen(const DidOpenTextDocumentParams &params);
void onDocumentDidClose(const DidCloseTextDocumentParams &params);
void onDocumentDidChange(const DidChangeTextDocumentParams &params);

//===--------------------------------------------------------------------===//
// Definitions and References

void onGoToDefinition(const TextDocumentPositionParams &params,
Callback<std::vector<Location>> reply);
void onReference(const ReferenceParams &params,
Callback<std::vector<Location>> reply);

//===--------------------------------------------------------------------===//
// Hover

void onHover(const TextDocumentPositionParams &params,
Callback<Optional<Hover>> reply);

//===--------------------------------------------------------------------===//
// Document Symbols

void onDocumentSymbol(const DocumentSymbolParams &params,
Callback<std::vector<DocumentSymbol>> reply);

//===--------------------------------------------------------------------===//
// Code Completion

void onCompletion(const CompletionParams &params,
Callback<CompletionList> reply);

//===--------------------------------------------------------------------===//
// Signature Help

void onSignatureHelp(const TextDocumentPositionParams &params,
Callback<SignatureHelp> reply);

//===--------------------------------------------------------------------===//
// Fields
//===--------------------------------------------------------------------===//

PDLLServer &server;
JSONTransport &transport;

/// An outgoing notification used to send diagnostics to the client when they
/// are ready to be processed.
OutgoingNotification<PublishDiagnosticsParams> publishDiagnostics;

/// Used to indicate that the 'shutdown' request was received from the
/// Language Server client.
bool shutdownRequestReceived = false;
};
} // namespace

//===----------------------------------------------------------------------===//
// Initialization

void LSPServer::onInitialize(const InitializeParams &params,
Callback<llvm::json::Value> reply) {
// Send a response with the capabilities of this server.
llvm::json::Object serverCaps{
{"textDocumentSync",
llvm::json::Object{
{"openClose", true},
{"change", (int)TextDocumentSyncKind::Full},
{"save", true},
}},
{"completionProvider",
llvm::json::Object{
{"allCommitCharacters",
{" ", "\t", "(", ")", "[", "]", "{", "}", "<",
">", ":", ";", ",", "+", "-", "/", "*", "%",
"^", "&", "#", "?", ".", "=", "\"", "'", "|"}},
{"resolveProvider", false},
{"triggerCharacters", {".", ">", "(", "{", ",", "<", ":", "[", " "}},
}},
{"signatureHelpProvider",
llvm::json::Object{
{"triggerCharacters", {"(", ","}},
}},
{"definitionProvider", true},
{"referencesProvider", true},
{"hoverProvider", true},
{"documentSymbolProvider", true},
};

llvm::json::Object result{
{{"serverInfo", llvm::json::Object{{"name", "mlir-pdll-lsp-server"},
{"version", "0.0.1"}}},
{"capabilities", std::move(serverCaps)}}};
reply(std::move(result));
}
void LSPServer::onInitialized(const InitializedParams &) {}
void LSPServer::onShutdown(const NoParams &, Callback<std::nullptr_t> reply) {
shutdownRequestReceived = true;
reply(nullptr);
}

//===----------------------------------------------------------------------===//
// Document Change

void LSPServer::onDocumentDidOpen(const DidOpenTextDocumentParams &params) {
PublishDiagnosticsParams diagParams(params.textDocument.uri,
params.textDocument.version);
server.addOrUpdateDocument(params.textDocument.uri, params.textDocument.text,
params.textDocument.version,
diagParams.diagnostics);

// Publish any recorded diagnostics.
publishDiagnostics(diagParams);
}
void LSPServer::onDocumentDidClose(const DidCloseTextDocumentParams &params) {
Optional<int64_t> version = server.removeDocument(params.textDocument.uri);
if (!version)
return;

// Empty out the diagnostics shown for this document. This will clear out
// anything currently displayed by the client for this document (e.g. in the
// "Problems" pane of VSCode).
publishDiagnostics(
PublishDiagnosticsParams(params.textDocument.uri, *version));
}
void LSPServer::onDocumentDidChange(const DidChangeTextDocumentParams &params) {
// TODO: We currently only support full document updates, we should refactor
// to avoid this.
if (params.contentChanges.size() != 1)
return;
PublishDiagnosticsParams diagParams(params.textDocument.uri,
params.textDocument.version);
server.addOrUpdateDocument(
params.textDocument.uri, params.contentChanges.front().text,
params.textDocument.version, diagParams.diagnostics);

// Publish any recorded diagnostics.
publishDiagnostics(diagParams);
}

//===----------------------------------------------------------------------===//
// Definitions and References

void LSPServer::onGoToDefinition(const TextDocumentPositionParams &params,
Callback<std::vector<Location>> reply) {
std::vector<Location> locations;
server.getLocationsOf(params.textDocument.uri, params.position, locations);
reply(std::move(locations));
}

void LSPServer::onReference(const ReferenceParams &params,
Callback<std::vector<Location>> reply) {
std::vector<Location> locations;
server.findReferencesOf(params.textDocument.uri, params.position, locations);
reply(std::move(locations));
}

//===----------------------------------------------------------------------===//
// Hover

void LSPServer::onHover(const TextDocumentPositionParams &params,
Callback<Optional<Hover>> reply) {
reply(server.findHover(params.textDocument.uri, params.position));
}

//===----------------------------------------------------------------------===//
// Document Symbols

void LSPServer::onDocumentSymbol(const DocumentSymbolParams &params,
Callback<std::vector<DocumentSymbol>> reply) {
std::vector<DocumentSymbol> symbols;
server.findDocumentSymbols(params.textDocument.uri, symbols);
reply(std::move(symbols));
}

//===----------------------------------------------------------------------===//
// Code Completion

void LSPServer::onCompletion(const CompletionParams &params,
Callback<CompletionList> reply) {
reply(server.getCodeCompletion(params.textDocument.uri, params.position));
}

//===----------------------------------------------------------------------===//
// Signature Help

void LSPServer::onSignatureHelp(const TextDocumentPositionParams &params,
Callback<SignatureHelp> reply) {
reply(server.getSignatureHelp(params.textDocument.uri, params.position));
}

//===----------------------------------------------------------------------===//
// Entry Point
//===----------------------------------------------------------------------===//

LogicalResult mlir::lsp::runPdllLSPServer(PDLLServer &server,
JSONTransport &transport) {
LSPServer lspServer(server, transport);
MessageHandler messageHandler(transport);

// Initialization
messageHandler.method("initialize", &lspServer, &LSPServer::onInitialize);
messageHandler.notification("initialized", &lspServer,
&LSPServer::onInitialized);
messageHandler.method("shutdown", &lspServer, &LSPServer::onShutdown);

// Document Changes
messageHandler.notification("textDocument/didOpen", &lspServer,
&LSPServer::onDocumentDidOpen);
messageHandler.notification("textDocument/didClose", &lspServer,
&LSPServer::onDocumentDidClose);
messageHandler.notification("textDocument/didChange", &lspServer,
&LSPServer::onDocumentDidChange);

// Definitions and References
messageHandler.method("textDocument/definition", &lspServer,
&LSPServer::onGoToDefinition);
messageHandler.method("textDocument/references", &lspServer,
&LSPServer::onReference);

// Hover
messageHandler.method("textDocument/hover", &lspServer, &LSPServer::onHover);

// Document Symbols
messageHandler.method("textDocument/documentSymbol", &lspServer,
&LSPServer::onDocumentSymbol);

// Code Completion
messageHandler.method("textDocument/completion", &lspServer,
&LSPServer::onCompletion);

// Signature Help
messageHandler.method("textDocument/signatureHelp", &lspServer,
&LSPServer::onSignatureHelp);

// Diagnostics
lspServer.publishDiagnostics =
messageHandler.outgoingNotification<PublishDiagnosticsParams>(
"textDocument/publishDiagnostics");

// Run the main loop of the transport.
if (llvm::Error error = transport.run(messageHandler)) {
Logger::error("Transport error: {0}", error);
llvm::consumeError(std::move(error));
return failure();
}
return success(lspServer.shutdownRequestReceived);
}
28 changes: 28 additions & 0 deletions mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===- LSPServer.h - PDLL LSP Server ----------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_LSPSERVER_H
#define LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_LSPSERVER_H

#include <memory>

namespace mlir {
struct LogicalResult;

namespace lsp {
class JSONTransport;
class PDLLServer;

/// Run the main loop of the LSP server using the given PDLL server and
/// transport.
LogicalResult runPdllLSPServer(PDLLServer &server, JSONTransport &transport);

} // namespace lsp
} // namespace mlir

#endif // LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_LSPSERVER_H
72 changes: 72 additions & 0 deletions mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
//===- MlirPdllLspServerMain.cpp - MLIR PDLL Language Server main ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.h"
#include "../lsp-server-support/Logging.h"
#include "../lsp-server-support/Transport.h"
#include "LSPServer.h"
#include "PDLLServer.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Program.h"

using namespace mlir;
using namespace mlir::lsp;

LogicalResult mlir::MlirPdllLspServerMain(int argc, char **argv) {
llvm::cl::opt<JSONStreamStyle> inputStyle{
"input-style",
llvm::cl::desc("Input JSON stream encoding"),
llvm::cl::values(clEnumValN(JSONStreamStyle::Standard, "standard",
"usual LSP protocol"),
clEnumValN(JSONStreamStyle::Delimited, "delimited",
"messages delimited by `// -----` lines, "
"with // comment support")),
llvm::cl::init(JSONStreamStyle::Standard),
llvm::cl::Hidden,
};
llvm::cl::opt<bool> litTest{
"lit-test",
llvm::cl::desc(
"Abbreviation for -input-style=delimited -pretty -log=verbose. "
"Intended to simplify lit tests"),
llvm::cl::init(false),
};
llvm::cl::opt<Logger::Level> logLevel{
"log",
llvm::cl::desc("Verbosity of log messages written to stderr"),
llvm::cl::values(
clEnumValN(Logger::Level::Error, "error", "Error messages only"),
clEnumValN(Logger::Level::Info, "info",
"High level execution tracing"),
clEnumValN(Logger::Level::Debug, "verbose", "Low level details")),
llvm::cl::init(Logger::Level::Info),
};
llvm::cl::opt<bool> prettyPrint{
"pretty",
llvm::cl::desc("Pretty-print JSON output"),
llvm::cl::init(false),
};
llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR LSP Language Server");

if (litTest) {
inputStyle = JSONStreamStyle::Delimited;
logLevel = Logger::Level::Debug;
prettyPrint = true;
}

// Configure the logger.
Logger::setLogLevel(logLevel);

// Configure the transport used for communication.
llvm::sys::ChangeStdinToBinary();
JSONTransport transport(stdin, llvm::outs(), inputStyle, prettyPrint);

// Configure the servers and start the main language server.
PDLLServer server;
return runPdllLSPServer(server, transport);
}
1,295 changes: 1,295 additions & 0 deletions mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp

Large diffs are not rendered by default.

79 changes: 79 additions & 0 deletions mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//===- PDLLServer.h - PDL General Language Server ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LIB_MLIR_TOOLS_MLIRPDLLSPSERVER_SERVER_H_
#define LIB_MLIR_TOOLS_MLIRPDLLSPSERVER_SERVER_H_

#include "mlir/Support/LLVM.h"
#include <memory>

namespace mlir {
namespace lsp {
struct Diagnostic;
struct CompletionList;
struct DocumentSymbol;
struct Hover;
struct Location;
struct Position;
struct SignatureHelp;
class URIForFile;

/// This class implements all of the PDLL related functionality necessary for a
/// language server. This class allows for keeping the PDLL specific logic
/// separate from the logic that involves LSP server/client communication.
class PDLLServer {
public:
PDLLServer();
~PDLLServer();

/// Add or update the document, with the provided `version`, at the given URI.
/// Any diagnostics emitted for this document should be added to
/// `diagnostics`.
void addOrUpdateDocument(const URIForFile &uri, StringRef contents,
int64_t version,
std::vector<Diagnostic> &diagnostics);

/// Remove the document with the given uri. Returns the version of the removed
/// document, or None if the uri did not have a corresponding document within
/// the server.
Optional<int64_t> removeDocument(const URIForFile &uri);

/// Return the locations of the object pointed at by the given position.
void getLocationsOf(const URIForFile &uri, const Position &defPos,
std::vector<Location> &locations);

/// Find all references of the object pointed at by the given position.
void findReferencesOf(const URIForFile &uri, const Position &pos,
std::vector<Location> &references);

/// Find a hover description for the given hover position, or None if one
/// couldn't be found.
Optional<Hover> findHover(const URIForFile &uri, const Position &hoverPos);

/// Find all of the document symbols within the given file.
void findDocumentSymbols(const URIForFile &uri,
std::vector<DocumentSymbol> &symbols);

/// Get the code completion list for the position within the given file.
CompletionList getCodeCompletion(const URIForFile &uri,
const Position &completePos);

/// Get the signature help for the position within the given file.
SignatureHelp getSignatureHelp(const URIForFile &uri,
const Position &helpPos);

private:
struct Impl;

std::unique_ptr<Impl> impl;
};

} // namespace lsp
} // namespace mlir

#endif // LIB_MLIR_TOOLS_MLIRPDLLSPSERVER_SERVER_H_
12 changes: 3 additions & 9 deletions mlir/python/mlir/dialects/_pdl_ops_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,12 @@ class ApplyNativeConstraintOp:
def __init__(self,
name: Union[str, StringAttr],
args: Sequence[Union[OpView, Operation, Value]] = [],
params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
*,
loc=None,
ip=None):
name = _get_str_attr(name)
args = _get_values(args)
params = params if params is None else _get_array_attr(params)
super().__init__(name, args, params, loc=loc, ip=ip)
super().__init__(name, args, loc=loc, ip=ip)


class ApplyNativeRewriteOp:
Expand All @@ -76,14 +74,12 @@ def __init__(self,
results: Sequence[Type],
name: Union[str, StringAttr],
args: Sequence[Union[OpView, Operation, Value]] = [],
params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
*,
loc=None,
ip=None):
name = _get_str_attr(name)
args = _get_values(args)
params = params if params is None else _get_array_attr(params)
super().__init__(results, name, args, params, loc=loc, ip=ip)
super().__init__(results, name, args, loc=loc, ip=ip)


class AttributeOp:
Expand Down Expand Up @@ -236,15 +232,13 @@ def __init__(self,
root: Optional[Union[OpView, Operation, Value]] = None,
name: Optional[Union[StringAttr, str]] = None,
args: Sequence[Union[OpView, Operation, Value]] = [],
params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
*,
loc=None,
ip=None):
root = root if root is None else _get_value(root)
name = name if name is None else _get_str_attr(name)
args = _get_values(args)
params = params if params is None else _get_array_attr(params)
super().__init__(root, name, args, params, loc=loc, ip=ip)
super().__init__(root, name, args, loc=loc, ip=ip)

def add_body(self):
"""Add body (block) to the rewrite."""
Expand Down
1 change: 1 addition & 0 deletions mlir/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ set(MLIR_TEST_DEPENDS
mlir-capi-pdl-test
mlir-linalg-ods-yaml-gen
mlir-lsp-server
mlir-pdll-lsp-server
mlir-opt
mlir-pdll
mlir-reduce
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ module @constraints {
// CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
// CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]]
// CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
// CHECK: pdl_interp.apply_constraint "multi_constraint" [true](%[[INPUT]], %[[INPUT1]], %[[RESULT]]
// CHECK: pdl_interp.apply_constraint "multi_constraint"(%[[INPUT]], %[[INPUT1]], %[[RESULT]]

pdl.pattern : benefit(1) {
%input0 = operand
%input1 = operand
%root = operation(%input0, %input1 : !pdl.value, !pdl.value)
%result0 = result 0 of %root

pdl.apply_native_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value)
pdl.apply_native_constraint "multi_constraint"(%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value)
rewrite %root with "rewriter"
}
}
Expand Down Expand Up @@ -393,11 +393,11 @@ module @predicate_ordering {
// CHECK: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
// CHECK-NEXT: pdl_interp.is_not_null %[[RESULT]]
// CHECK: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]]
// CHECK: pdl_interp.apply_constraint "typeConstraint" [](%[[RESULT_TYPE]]
// CHECK: pdl_interp.apply_constraint "typeConstraint"(%[[RESULT_TYPE]]

pdl.pattern : benefit(1) {
%resultType = type
pdl.apply_native_constraint "typeConstraint"[](%resultType : !pdl.type)
pdl.apply_native_constraint "typeConstraint"(%resultType : !pdl.type)
%root = operation -> (%resultType : !pdl.type)
rewrite %root with "rewriter"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
module @external {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation, %[[INPUT:.*]]: !pdl.value)
// CHECK: pdl_interp.apply_rewrite "rewriter" [true](%[[ROOT]], %[[INPUT]] : !pdl.operation, !pdl.value)
// CHECK: pdl_interp.apply_rewrite "rewriter"(%[[ROOT]], %[[INPUT]] : !pdl.operation, !pdl.value)
pdl.pattern : benefit(1) {
%input = operand
%root = operation "foo.op"(%input : !pdl.value)
rewrite %root with "rewriter"[true](%input : !pdl.value)
rewrite %root with "rewriter"(%input : !pdl.value)
}
}

Expand Down Expand Up @@ -191,13 +191,13 @@ module @replace_with_no_results {
module @apply_native_rewrite {
// CHECK: module @rewriters
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
// CHECK: %[[TYPE:.*]] = pdl_interp.apply_rewrite "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type
// CHECK: %[[TYPE:.*]] = pdl_interp.apply_rewrite "functor"(%[[ROOT]] : !pdl.operation) : !pdl.type
// CHECK: pdl_interp.create_operation "foo.op" -> (%[[TYPE]] : !pdl.type)
pdl.pattern : benefit(1) {
%type = type
%root = operation "foo.op" -> (%type : !pdl.type)
rewrite %root {
%newType = apply_native_rewrite "functor"[true](%root : !pdl.operation) : !pdl.type
%newType = apply_native_rewrite "functor"(%root : !pdl.operation) : !pdl.type
%newOp = operation "foo.op" -> (%newType : !pdl.type)
}
}
Expand Down
17 changes: 2 additions & 15 deletions mlir/test/Dialect/PDL/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pdl.pattern : benefit(1) {
%op = operation "foo.op"

// expected-error@below {{expected at least one argument}}
"pdl.apply_native_constraint"() {name = "foo", params = []} : () -> ()
"pdl.apply_native_constraint"() {name = "foo"} : () -> ()
rewrite %op with "rewriter"
}

Expand All @@ -22,7 +22,7 @@ pdl.pattern : benefit(1) {
%op = operation "foo.op"
rewrite %op {
// expected-error@below {{expected at least one argument}}
"pdl.apply_native_rewrite"() {name = "foo", params = []} : () -> ()
"pdl.apply_native_rewrite"() {name = "foo"} : () -> ()
}
}

Expand Down Expand Up @@ -264,19 +264,6 @@ pdl.pattern : benefit(1) {

// -----

pdl.pattern : benefit(1) {
%op = operation "foo.op"

// expected-error@below {{expected no external constant parameters when the rewrite is specified inline}}
"pdl.rewrite"(%op) ({
^bb1:
}) {
operand_segment_sizes = dense<[1,0]> : vector<2xi32>,
externalConstParams = []} : (!pdl.operation) -> ()
}

// -----

pdl.pattern : benefit(1) {
%op = operation "foo.op"

Expand Down
19 changes: 2 additions & 17 deletions mlir/test/Dialect/PDL/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,6 @@ pdl.pattern @rewrite_with_args : benefit(1) {

// -----

pdl.pattern @rewrite_with_params : benefit(1) {
%root = operation
rewrite %root with "rewriter"["I am param"]
}

// -----

pdl.pattern @rewrite_with_args_and_params : benefit(1) {
%input = operand
%root = operation(%input : !pdl.value)
rewrite %root with "rewriter"["I am param"](%input : !pdl.value)
}

// -----

pdl.pattern @rewrite_multi_root_optimal : benefit(2) {
%input1 = operand
%input2 = operand
Expand All @@ -52,7 +37,7 @@ pdl.pattern @rewrite_multi_root_optimal : benefit(2) {
%op2 = operation(%input2 : !pdl.value) -> (%type : !pdl.type)
%val2 = result 0 of %op2
%root2 = operation(%val1, %val2 : !pdl.value, !pdl.value)
rewrite with "rewriter"["I am param"](%root1, %root2 : !pdl.operation, !pdl.operation)
rewrite with "rewriter"(%root1, %root2 : !pdl.operation, !pdl.operation)
}

// -----
Expand All @@ -67,7 +52,7 @@ pdl.pattern @rewrite_multi_root_forced : benefit(2) {
%op2 = operation(%input2 : !pdl.value) -> (%type : !pdl.type)
%val2 = result 0 of %op2
%root2 = operation(%val1, %val2 : !pdl.value, !pdl.value)
rewrite %root1 with "rewriter"["I am param"](%root2 : !pdl.operation)
rewrite %root1 with "rewriter"(%root2 : !pdl.operation)
}

// -----
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Rewrite/pdl-bytecode.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ module @patterns {
module @rewriters {
pdl_interp.func @success(%root : !pdl.operation) {
%operand = pdl_interp.get_operand 0 of %root
pdl_interp.apply_rewrite "rewriter"[42](%root, %operand : !pdl.operation, !pdl.value)
pdl_interp.apply_rewrite "rewriter"(%root, %operand : !pdl.operation, !pdl.value)
pdl_interp.finalize
}
}
Expand All @@ -99,7 +99,7 @@ module @patterns {
// CHECK-LABEL: test.apply_rewrite_1
// CHECK: %[[INPUT:.*]] = "test.op_input"
// CHECK-NOT: "test.op"
// CHECK: "test.success"(%[[INPUT]]) {constantParams = [42]}
// CHECK: "test.success"(%[[INPUT]])
module @ir attributes { test.apply_rewrite_1 } {
%input = "test.op_input"() : () -> i32
"test.op"(%input) : (i32) -> ()
Expand Down
18 changes: 6 additions & 12 deletions mlir/test/lib/Rewrite/TestPDLByteCode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,16 @@ using namespace mlir;

/// Custom constraint invoked from PDL.
static LogicalResult customSingleEntityConstraint(PDLValue value,
ArrayAttr constantParams,
PatternRewriter &rewriter) {
Operation *rootOp = value.cast<Operation *>();
return success(rootOp->getName().getStringRef() == "test.op");
}
static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
ArrayAttr constantParams,
PatternRewriter &rewriter) {
return customSingleEntityConstraint(values[1], constantParams, rewriter);
return customSingleEntityConstraint(values[1], rewriter);
}
static LogicalResult
customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
ArrayAttr constantParams,
PatternRewriter &rewriter) {
if (llvm::any_of(values, [](const PDLValue &value) { return !value; }))
return failure();
Expand All @@ -39,32 +36,29 @@ customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
}

// Custom creator invoked from PDL.
static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
PatternRewriter &rewriter, PDLResultList &results) {
static void customCreate(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
PDLResultList &results) {
results.push_back(rewriter.createOperation(
OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
}
static void customVariadicResultCreate(ArrayRef<PDLValue> args,
ArrayAttr constantParams,
PatternRewriter &rewriter,
PDLResultList &results) {
Operation *root = args[0].cast<Operation *>();
results.push_back(root->getOperands());
results.push_back(root->getOperands().getTypes());
}
static void customCreateType(ArrayRef<PDLValue> args, ArrayAttr constantParams,
PatternRewriter &rewriter,
static void customCreateType(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
PDLResultList &results) {
results.push_back(rewriter.getF32Type());
}

/// Custom rewriter invoked from PDL.
static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams,
PatternRewriter &rewriter, PDLResultList &results) {
static void customRewriter(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
PDLResultList &results) {
Operation *root = args[0].cast<Operation *>();
OperationState successOpState(root->getLoc(), "test.success");
successOpState.addOperands(args[1].cast<Value>());
successOpState.addAttribute("constantParams", constantParams);
rewriter.createOperation(successOpState);
rewriter.eraseOp(root);
}
Expand Down
205 changes: 205 additions & 0 deletions mlir/test/mlir-pdll-lsp-server/completion.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
// RUN: mlir-pdll-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"pdll","capabilities":{},"trace":"off"}}
// -----
{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
"uri":"test:///foo.pdll",
"languageId":"pdll",
"version":1,
"text":"Constraint ValueCst(value: Value);\nConstraint Cst();\nPattern FooPattern with benefit(1) {\nlet tuple = (value1 = _: Op, _: Op<test.op>);\nerase tuple.value1;\n}"
}}}
// -----
{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.pdll"},
"position":{"line":4,"character":12}
}}
// CHECK: "id": 1
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
// CHECK-NEXT: "items": [
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "0: Op",
// CHECK-NEXT: "filterText": "0",
// CHECK-NEXT: "insertText": "0",
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 5,
// CHECK-NEXT: "label": "0 (field #0)",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "0: Op",
// CHECK-NEXT: "filterText": "value1 (field #0)",
// CHECK-NEXT: "insertText": "value1",
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 5,
// CHECK-NEXT: "label": "value1 (field #0)",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "1: Op<test.op>",
// CHECK-NEXT: "filterText": "1",
// CHECK-NEXT: "insertText": "1",
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 5,
// CHECK-NEXT: "label": "1 (field #1)",
// CHECK-NEXT: "sortText": "1"
// CHECK-NEXT: }
// CHECK-NEXT: ]
// CHECK-NEXT: }
// -----
{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.pdll"},
"position":{"line":2,"character":23}
}}
// CHECK: "id": 1
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
// CHECK-NEXT: "items": [
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "pattern metadata",
// CHECK-NEXT: "documentation": {
// CHECK-NEXT: "kind": "markdown",
// CHECK-NEXT: "value": "The `benefit` of matching the pattern."
// CHECK-NEXT: },
// CHECK-NEXT: "insertText": "benefit($1)",
// CHECK-NEXT: "insertTextFormat": 2,
// CHECK-NEXT: "kind": 7,
// CHECK-NEXT: "label": "benefit"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "pattern metadata",
// CHECK-NEXT: "documentation": {
// CHECK-NEXT: "kind": "markdown",
// CHECK-NEXT: "value": "The pattern properly handles recursive application."
// CHECK-NEXT: },
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 7,
// CHECK-NEXT: "label": "recursion"
// CHECK-NEXT: }
// CHECK-NEXT: ]
// CHECK-NEXT: }
// -----
{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.pdll"},
"position":{"line":3,"character":24}
}}
// CHECK: "id": 1
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
// CHECK-NEXT: "items": [
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "Attr constraint",
// CHECK-NEXT: "documentation": {
// CHECK-NEXT: "kind": "markdown",
// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::Attribute`"
// CHECK-NEXT: },
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 7,
// CHECK-NEXT: "label": "Attr",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "Op constraint",
// CHECK-NEXT: "documentation": {
// CHECK-NEXT: "kind": "markdown",
// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::Operation *`"
// CHECK-NEXT: },
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 7,
// CHECK-NEXT: "label": "Op",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "Value constraint",
// CHECK-NEXT: "documentation": {
// CHECK-NEXT: "kind": "markdown",
// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::Value`"
// CHECK-NEXT: },
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 7,
// CHECK-NEXT: "label": "Value",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "ValueRange constraint",
// CHECK-NEXT: "documentation": {
// CHECK-NEXT: "kind": "markdown",
// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::ValueRange`"
// CHECK-NEXT: },
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 7,
// CHECK-NEXT: "label": "ValueRange",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "Type constraint",
// CHECK-NEXT: "documentation": {
// CHECK-NEXT: "kind": "markdown",
// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::Type`"
// CHECK-NEXT: },
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 7,
// CHECK-NEXT: "label": "Type",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "TypeRange constraint",
// CHECK-NEXT: "documentation": {
// CHECK-NEXT: "kind": "markdown",
// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::TypeRange`"
// CHECK-NEXT: },
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 7,
// CHECK-NEXT: "label": "TypeRange",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "Attr<type> constraint",
// CHECK-NEXT: "documentation": {
// CHECK-NEXT: "kind": "markdown",
// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::Attribute`"
// CHECK-NEXT: },
// CHECK-NEXT: "insertText": "Attr<$1>",
// CHECK-NEXT: "insertTextFormat": 2,
// CHECK-NEXT: "kind": 7,
// CHECK-NEXT: "label": "Attr<type>",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "Value<type> constraint",
// CHECK-NEXT: "documentation": {
// CHECK-NEXT: "kind": "markdown",
// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::Value`"
// CHECK-NEXT: },
// CHECK-NEXT: "insertText": "Value<$1>",
// CHECK-NEXT: "insertTextFormat": 2,
// CHECK-NEXT: "kind": 7,
// CHECK-NEXT: "label": "Value<type>",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "ValueRange<type> constraint",
// CHECK-NEXT: "documentation": {
// CHECK-NEXT: "kind": "markdown",
// CHECK-NEXT: "value": "A single entity core constraint of type `mlir::ValueRange`"
// CHECK-NEXT: },
// CHECK-NEXT: "insertText": "ValueRange<$1>",
// CHECK-NEXT: "insertTextFormat": 2,
// CHECK-NEXT: "kind": 7,
// CHECK-NEXT: "label": "ValueRange<type>",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "(value: Value) -> Tuple<>",
// CHECK-NEXT: "kind": 8,
// CHECK-NEXT: "label": "ValueCst",
// CHECK-NEXT: "sortText": "2_ValueCst"
// CHECK-NEXT: }
// CHECK-NEXT: ]
// CHECK-NEXT: }
// -----
{"jsonrpc":"2.0","id":3,"method":"shutdown"}
// -----
{"jsonrpc":"2.0","method":"exit"}
37 changes: 37 additions & 0 deletions mlir/test/mlir-pdll-lsp-server/definition-split-file.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// RUN: mlir-pdll-lsp-server -lit-test < %s | FileCheck %s
// This test checks support for split files by attempting to find the definition
// of a symbol in a split file. The interesting part of this test is that the
// file chunk before the one we are looking for the definition in has an error.
{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"pdll","capabilities":{},"trace":"off"}}
// -----
{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
"uri":"test:///foo.pdll",
"languageId":"pdll",
"version":1,
"text":"Pattern Foo {\n// -----\nPattern {\n erase root: Op<toy.test>;\n }"
}}}
// -----
{"jsonrpc":"2.0","id":1,"method":"textDocument/definition","params":{
"textDocument":{"uri":"test:///foo.pdll"},
"position":{"line":3,"character":12}
}}
// CHECK: "id": 1
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": [
// CHECK-NEXT: {
// CHECK-NEXT: "range": {
// CHECK-NEXT: "end": {
// CHECK-NEXT: "character": 12,
// CHECK-NEXT: "line": 3
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
// CHECK-NEXT: "character": 8,
// CHECK-NEXT: "line": 3
// CHECK-NEXT: }
// CHECK-NEXT: },
// CHECK-NEXT: "uri": "{{.*}}/foo.pdll"
// CHECK-NEXT: }
// -----
{"jsonrpc":"2.0","id":3,"method":"shutdown"}
// -----
{"jsonrpc":"2.0","method":"exit"}
55 changes: 55 additions & 0 deletions mlir/test/mlir-pdll-lsp-server/definition.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// RUN: mlir-pdll-lsp-server -lit-test < %s | FileCheck %s
{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"pdll","capabilities":{},"trace":"off"}}
// -----
{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
"uri":"test:///foo.pdll",
"languageId":"pdll",
"version":1,
"text":"Pattern FooPattern {\nlet root: Op<toy.test>;\nerase root;\n}"
}}}
// -----
{"jsonrpc":"2.0","id":1,"method":"textDocument/definition","params":{
"textDocument":{"uri":"test:///foo.pdll"},
"position":{"line":0,"character":12}
}}
// CHECK: "id": 1
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": [
// CHECK-NEXT: {
// CHECK-NEXT: "range": {
// CHECK-NEXT: "end": {
// CHECK-NEXT: "character": 18,
// CHECK-NEXT: "line": 0
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
// CHECK-NEXT: "character": 8,
// CHECK-NEXT: "line": 0
// CHECK-NEXT: }
// CHECK-NEXT: },
// CHECK-NEXT: "uri": "{{.*}}/foo.pdll"
// CHECK-NEXT: }
// -----
{"jsonrpc":"2.0","id":2,"method":"textDocument/definition","params":{
"textDocument":{"uri":"test:///foo.pdll"},
"position":{"line":2,"character":8}
}}
// CHECK: "id": 2
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": [
// CHECK-NEXT: {
// CHECK-NEXT: "range": {
// CHECK-NEXT: "end": {
// CHECK-NEXT: "character": 8,
// CHECK-NEXT: "line": 1
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
// CHECK-NEXT: "character": 4,
// CHECK-NEXT: "line": 1
// CHECK-NEXT: }
// CHECK-NEXT: },
// CHECK-NEXT: "uri": "{{.*}}/foo.pdll"
// CHECK-NEXT: }
// -----
{"jsonrpc":"2.0","id":3,"method":"shutdown"}
// -----
{"jsonrpc":"2.0","method":"exit"}
93 changes: 93 additions & 0 deletions mlir/test/mlir-pdll-lsp-server/document-symbols.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// RUN: mlir-pdll-lsp-server -lit-test < %s | FileCheck -strict-whitespace %s
{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootUri":"test:///workspace","capabilities":{"textDocument":{"documentSymbol":{"hierarchicalDocumentSymbolSupport":true}}},"trace":"off"}}
// -----
{"jsonrpc":"2.0","method":"textDocument/didOpen","params":{"textDocument":{
"uri":"test:///foo.pdll",
"languageId":"pdll",
"version":1,
"text":"Pattern Foo {\nerase op<foo.op>;\n}\nConstraint Cst() -> Op{\nreturn op<toy.test>;\n}\n\nRewrite SomeRewrite() -> Op {\nreturn op: Op;\n}"
}}}
// -----
{"jsonrpc":"2.0","id":1,"method":"textDocument/documentSymbol","params":{
"textDocument":{"uri":"test:///foo.pdll"}
}}
// CHECK: "id": 1
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": [
// CHECK-NEXT: {
// CHECK-NEXT: "kind": 5,
// CHECK-NEXT: "name": "Foo",
// CHECK-NEXT: "range": {
// CHECK-NEXT: "end": {
// CHECK-NEXT: "character": 1,
// CHECK-NEXT: "line": 2
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
// CHECK-NEXT: "character": 8,
// CHECK-NEXT: "line": 0
// CHECK-NEXT: }
// CHECK-NEXT: },
// CHECK-NEXT: "selectionRange": {
// CHECK-NEXT: "end": {
// CHECK-NEXT: "character": 11,
// CHECK-NEXT: "line": 0
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
// CHECK-NEXT: "character": 8,
// CHECK-NEXT: "line": 0
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "kind": 12,
// CHECK-NEXT: "name": "Cst",
// CHECK-NEXT: "range": {
// CHECK-NEXT: "end": {
// CHECK-NEXT: "character": 14,
// CHECK-NEXT: "line": 3
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
// CHECK-NEXT: "character": 11,
// CHECK-NEXT: "line": 3
// CHECK-NEXT: }
// CHECK-NEXT: },
// CHECK-NEXT: "selectionRange": {
// CHECK-NEXT: "end": {
// CHECK-NEXT: "character": 14,
// CHECK-NEXT: "line": 3
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
// CHECK-NEXT: "character": 11,
// CHECK-NEXT: "line": 3
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "kind": 12,
// CHECK-NEXT: "name": "SomeRewrite",
// CHECK-NEXT: "range": {
// CHECK-NEXT: "end": {
// CHECK-NEXT: "character": 19,
// CHECK-NEXT: "line": 7
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
// CHECK-NEXT: "character": 8,
// CHECK-NEXT: "line": 7
// CHECK-NEXT: }
// CHECK-NEXT: },
// CHECK-NEXT: "selectionRange": {
// CHECK-NEXT: "end": {
// CHECK-NEXT: "character": 19,
// CHECK-NEXT: "line": 7
// CHECK-NEXT: },
// CHECK-NEXT: "start": {
// CHECK-NEXT: "character": 8,
// CHECK-NEXT: "line": 7
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: ]
// -----
{"jsonrpc":"2.0","id":3,"method":"shutdown"}
// -----
{"jsonrpc":"2.0","method":"exit"}
7 changes: 7 additions & 0 deletions mlir/test/mlir-pdll-lsp-server/exit-eof.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// RUN: not mlir-pdll-lsp-server < %s 2> %t.err
// RUN: FileCheck %s < %t.err
//
// No LSP messages here, just let mlir-pdll-lsp-server see the end-of-file
// CHECK: Transport error:
// (Typically "Transport error: Input/output error" but platform-dependent).

6 changes: 6 additions & 0 deletions mlir/test/mlir-pdll-lsp-server/exit-with-shutdown.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// RUN: mlir-pdll-lsp-server -lit-test < %s
{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"pdll","capabilities":{},"trace":"off"}}
// -----
{"jsonrpc":"2.0","id":3,"method":"shutdown"}
// -----
{"jsonrpc":"2.0","method":"exit"}
4 changes: 4 additions & 0 deletions mlir/test/mlir-pdll-lsp-server/exit-without-shutdown.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// RUN: not mlir-pdll-lsp-server -lit-test < %s
{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"processId":123,"rootPath":"pdll","capabilities":{},"trace":"off"}}
// -----
{"jsonrpc":"2.0","method":"exit"}
Loading