189 changes: 189 additions & 0 deletions mlir/include/mlir/Tools/PDLL/ODS/Operation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
//===- Operation.h - MLIR PDLL ODS Operation --------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TOOLS_PDLL_ODS_OPERATION_H_
#define MLIR_TOOLS_PDLL_ODS_OPERATION_H_

#include <string>

#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SMLoc.h"

namespace mlir {
namespace pdll {
namespace ods {
class AttributeConstraint;
class TypeConstraint;

//===----------------------------------------------------------------------===//
// VariableLengthKind
//===----------------------------------------------------------------------===//

enum VariableLengthKind { Single, Optional, Variadic };

//===----------------------------------------------------------------------===//
// Attribute
//===----------------------------------------------------------------------===//

/// This class provides an ODS representation of a specific operation attribute.
/// This includes the name, optionality, and more.
class Attribute {
public:
/// Return the name of this operand.
StringRef getName() const { return name; }

/// Return true if this attribute is optional.
bool isOptional() const { return optional; }

/// Return the constraint of this attribute.
const AttributeConstraint &getConstraint() const { return constraint; }

private:
Attribute(StringRef name, bool optional,
const AttributeConstraint &constraint)
: name(name.str()), optional(optional), constraint(constraint) {}

/// The ODS name of the attribute.
std::string name;

/// A flag indicating if the attribute is optional.
bool optional;

/// The ODS constraint of this attribute.
const AttributeConstraint &constraint;

/// Allow access to the private constructor.
friend class Operation;
};

//===----------------------------------------------------------------------===//
// OperandOrResult
//===----------------------------------------------------------------------===//

/// This class provides an ODS representation of a specific operation operand or
/// result. This includes the name, variable length flags, and more.
class OperandOrResult {
public:
/// Return the name of this value.
StringRef getName() const { return name; }

/// Returns true if this value is variadic (Note this is false if the value is
/// Optional).
bool isVariadic() const {
return variableLengthKind == VariableLengthKind::Variadic;
}

/// Returns the variable length kind of this value.
VariableLengthKind getVariableLengthKind() const {
return variableLengthKind;
}

/// Return the constraint of this value.
const TypeConstraint &getConstraint() const { return constraint; }

private:
OperandOrResult(StringRef name, VariableLengthKind variableLengthKind,
const TypeConstraint &constraint)
: name(name.str()), variableLengthKind(variableLengthKind),
constraint(constraint) {}

/// The ODS name of this value.
std::string name;

/// The variable length kind of this value.
VariableLengthKind variableLengthKind;

/// The ODS constraint of this value.
const TypeConstraint &constraint;

/// Allow access to the private constructor.
friend class Operation;
};

//===----------------------------------------------------------------------===//
// Operation
//===----------------------------------------------------------------------===//

/// This class provides an ODS representation of a specific operation. This
/// includes all of the information necessary for use by the PDL frontend for
/// generating code for a pattern rewrite.
class Operation {
public:
/// Return the source location of this operation.
SMRange getLoc() const { return location; }

/// Append an attribute to this operation.
void appendAttribute(StringRef name, bool optional,
const AttributeConstraint &constraint) {
attributes.emplace_back(Attribute(name, optional, constraint));
}

/// Append an operand to this operation.
void appendOperand(StringRef name, VariableLengthKind variableLengthKind,
const TypeConstraint &constraint) {
operands.emplace_back(
OperandOrResult(name, variableLengthKind, constraint));
}

/// Append a result to this operation.
void appendResult(StringRef name, VariableLengthKind variableLengthKind,
const TypeConstraint &constraint) {
results.emplace_back(OperandOrResult(name, variableLengthKind, constraint));
}

/// Returns the name of the operation.
StringRef getName() const { return name; }

/// Returns the summary of the operation.
StringRef getSummary() const { return summary; }

/// Returns the description of the operation.
StringRef getDescription() const { return description; }

/// Returns the attributes of this operation.
ArrayRef<Attribute> getAttributes() const { return attributes; }

/// Returns the operands of this operation.
ArrayRef<OperandOrResult> getOperands() const { return operands; }

/// Returns the results of this operation.
ArrayRef<OperandOrResult> getResults() const { return results; }

private:
Operation(StringRef name, StringRef summary, StringRef desc, SMLoc loc);

/// The name of the operation.
std::string name;

/// The documentation of the operation.
std::string summary;
std::string description;

/// The source location of this operation.
SMRange location;

/// The operands of the operation.
SmallVector<OperandOrResult> operands;

/// The results of the operation.
SmallVector<OperandOrResult> results;

/// The attributes of the operation.
SmallVector<Attribute> attributes;

/// Allow access to the private constructor.
friend class Dialect;
};
} // namespace ods
} // namespace pdll
} // namespace mlir

#endif // MLIR_TOOLS_PDLL_ODS_OPERATION_H_
23 changes: 23 additions & 0 deletions mlir/lib/TableGen/Constraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,29 @@ StringRef Constraint::getSummary() const {
return def->getName();
}

StringRef Constraint::getDefName() const {
// Functor used to check a base def in the case where the current def is
// anonymous.
auto checkBaseDefFn = [&](StringRef baseName) {
if (const auto *init = dyn_cast<llvm::DefInit>(def->getValueInit(baseName)))
return Constraint(init->getDef(), kind).getDefName();
return def->getName();
};

switch (kind) {
case CK_Attr:
if (def->isAnonymous())
return checkBaseDefFn("baseAttr");
return def->getName();
case CK_Type:
if (def->isAnonymous())
return checkBaseDefFn("baseType");
return def->getName();
default:
return def->getName();
}
}

AppliedConstraint::AppliedConstraint(Constraint &&constraint,
llvm::StringRef self,
std::vector<std::string> &&entities)
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Tools/PDLL/AST/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ add_mlir_library(MLIRPDLLAST
Types.cpp

LINK_LIBS PUBLIC
MLIRPDLLODS
MLIRSupport
)
2 changes: 1 addition & 1 deletion mlir/lib/Tools/PDLL/AST/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
using namespace mlir;
using namespace mlir::pdll::ast;

Context::Context() {
Context::Context(ods::Context &odsContext) : odsContext(odsContext) {
typeUniquer.registerSingletonStorageType<detail::AttributeTypeStorage>();
typeUniquer.registerSingletonStorageType<detail::ConstraintTypeStorage>();
typeUniquer.registerSingletonStorageType<detail::RewriteTypeStorage>();
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Tools/PDLL/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(AST)
add_subdirectory(CodeGen)
add_subdirectory(ODS)
add_subdirectory(Parser)
31 changes: 29 additions & 2 deletions mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "mlir/Tools/PDLL/AST/Context.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/AST/Types.h"
#include "mlir/Tools/PDLL/ODS/Context.h"
#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
Expand All @@ -33,7 +35,8 @@ class CodeGen {
public:
CodeGen(MLIRContext *mlirContext, const ast::Context &context,
const llvm::SourceMgr &sourceMgr)
: builder(mlirContext), sourceMgr(sourceMgr) {
: builder(mlirContext), odsContext(context.getODSContext()),
sourceMgr(sourceMgr) {
// Make sure that the PDL dialect is loaded.
mlirContext->loadDialect<pdl::PDLDialect>();
}
Expand Down Expand Up @@ -117,6 +120,9 @@ class CodeGen {
llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
VariableMapTy variables;

/// A reference to the ODS context.
const ods::Context &odsContext;

/// The source manager of the PDLL ast.
const llvm::SourceMgr &sourceMgr;
};
Expand Down Expand Up @@ -435,7 +441,28 @@ Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
builder.getI32IntegerAttr(0));
return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
}
llvm_unreachable("unhandled operation member access expression");

assert(opType.getName() && "expected valid operation name");
const ods::Operation *odsOp = odsContext.lookupOperation(*opType.getName());
assert(odsOp && "expected valid ODS operation information");

// Find the result with the member name or by index.
ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
unsigned resultIndex = results.size();
if (llvm::isDigit(name[0])) {
name.getAsInteger(/*Radix=*/10, resultIndex);
} else {
auto findFn = [&](const ods::OperandOrResult &result) {
return result.getName() == name;
};
resultIndex = llvm::find_if(results, findFn) - results.begin();
}
assert(resultIndex < results.size() && "invalid result index");

// Generate the result access.
IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()),
parentExprs[0], index);
}

// Handle tuple based member access.
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Tools/PDLL/ODS/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
add_mlir_library(MLIRPDLLODS
Context.cpp
Dialect.cpp
Operation.cpp

LINK_LIBS PUBLIC
MLIRSupport
)
174 changes: 174 additions & 0 deletions mlir/lib/Tools/PDLL/ODS/Context.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
//===- Context.cpp --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Tools/PDLL/ODS/Context.h"
#include "mlir/Tools/PDLL/ODS/Constraint.h"
#include "mlir/Tools/PDLL/ODS/Dialect.h"
#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::pdll::ods;

//===----------------------------------------------------------------------===//
// Context
//===----------------------------------------------------------------------===//

Context::Context() = default;
Context::~Context() = default;

const AttributeConstraint &
Context::insertAttributeConstraint(StringRef name, StringRef summary,
StringRef cppClass) {
std::unique_ptr<AttributeConstraint> &constraint = attributeConstraints[name];
if (!constraint) {
constraint.reset(new AttributeConstraint(name, summary, cppClass));
} else {
assert(constraint->getCppClass() == cppClass &&
constraint->getSummary() == summary &&
"constraint with the same name was already registered with a "
"different class");
}
return *constraint;
}

const TypeConstraint &Context::insertTypeConstraint(StringRef name,
StringRef summary,
StringRef cppClass) {
std::unique_ptr<TypeConstraint> &constraint = typeConstraints[name];
if (!constraint)
constraint.reset(new TypeConstraint(name, summary, cppClass));
return *constraint;
}

Dialect &Context::insertDialect(StringRef name) {
std::unique_ptr<Dialect> &dialect = dialects[name];
if (!dialect)
dialect.reset(new Dialect(name));
return *dialect;
}

const Dialect *Context::lookupDialect(StringRef name) const {
auto it = dialects.find(name);
return it == dialects.end() ? nullptr : &*it->second;
}

std::pair<Operation *, bool> Context::insertOperation(StringRef name,
StringRef summary,
StringRef desc,
SMLoc loc) {
std::pair<StringRef, StringRef> dialectAndName = name.split('.');
return insertDialect(dialectAndName.first)
.insertOperation(name, summary, desc, loc);
}

const Operation *Context::lookupOperation(StringRef name) const {
std::pair<StringRef, StringRef> dialectAndName = name.split('.');
if (const Dialect *dialect = lookupDialect(dialectAndName.first))
return dialect->lookupOperation(name);
return nullptr;
}

template <typename T>
SmallVector<T *> sortMapByName(const llvm::StringMap<std::unique_ptr<T>> &map) {
SmallVector<T *> storage;
for (auto &entry : map)
storage.push_back(entry.second.get());
llvm::sort(storage, [](const auto &lhs, const auto &rhs) {
return lhs->getName() < rhs->getName();
});
return storage;
}

void Context::print(raw_ostream &os) const {
auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) {
switch (kind) {
case VariableLengthKind::Optional:
os << "Optional<" << cst << ">";
break;
case VariableLengthKind::Single:
os << cst;
break;
case VariableLengthKind::Variadic:
os << "Variadic<" << cst << ">";
break;
}
};

llvm::ScopedPrinter printer(os);
llvm::DictScope odsScope(printer, "ODSContext");
for (const Dialect *dialect : sortMapByName(dialects)) {
printer.startLine() << "Dialect `" << dialect->getName() << "` {\n";
printer.indent();

for (const Operation *op : sortMapByName(dialect->getOperations())) {
printer.startLine() << "Operation `" << op->getName() << "` {\n";
printer.indent();

// Attributes.
ArrayRef<Attribute> attributes = op->getAttributes();
if (!attributes.empty()) {
printer.startLine() << "Attributes { ";
llvm::interleaveComma(attributes, os, [&](const Attribute &attr) {
os << attr.getName() << " : ";

auto kind = attr.isOptional() ? VariableLengthKind::Optional
: VariableLengthKind::Single;
printVariableLengthCst(attr.getConstraint().getName(), kind);
});
os << " }\n";
}

// Operands.
ArrayRef<OperandOrResult> operands = op->getOperands();
if (!operands.empty()) {
printer.startLine() << "Operands { ";
llvm::interleaveComma(
operands, os, [&](const OperandOrResult &operand) {
os << operand.getName() << " : ";
printVariableLengthCst(operand.getConstraint().getName(),
operand.getVariableLengthKind());
});
os << " }\n";
}

// Results.
ArrayRef<OperandOrResult> results = op->getResults();
if (!results.empty()) {
printer.startLine() << "Results { ";
llvm::interleaveComma(results, os, [&](const OperandOrResult &result) {
os << result.getName() << " : ";
printVariableLengthCst(result.getConstraint().getName(),
result.getVariableLengthKind());
});
os << " }\n";
}

printer.objectEnd();
}
printer.objectEnd();
}
for (const AttributeConstraint *cst : sortMapByName(attributeConstraints)) {
printer.startLine() << "AttributeConstraint `" << cst->getName() << "` {\n";
printer.indent();

printer.startLine() << "Summary: " << cst->getSummary() << "\n";
printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
printer.objectEnd();
}
for (const TypeConstraint *cst : sortMapByName(typeConstraints)) {
printer.startLine() << "TypeConstraint `" << cst->getName() << "` {\n";
printer.indent();

printer.startLine() << "Summary: " << cst->getSummary() << "\n";
printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
printer.objectEnd();
}
printer.objectEnd();
}
39 changes: 39 additions & 0 deletions mlir/lib/Tools/PDLL/ODS/Dialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//===- Dialect.cpp --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Tools/PDLL/ODS/Dialect.h"
#include "mlir/Tools/PDLL/ODS/Constraint.h"
#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::pdll::ods;

//===----------------------------------------------------------------------===//
// Dialect
//===----------------------------------------------------------------------===//

Dialect::Dialect(StringRef name) : name(name.str()) {}
Dialect::~Dialect() = default;

std::pair<Operation *, bool> Dialect::insertOperation(StringRef name,
StringRef summary,
StringRef desc,
llvm::SMLoc loc) {
std::unique_ptr<Operation> &operation = operations[name];
if (operation)
return std::make_pair(&*operation, /*wasInserted*/ false);

operation.reset(new Operation(name, summary, desc, loc));
return std::make_pair(&*operation, /*wasInserted*/ true);
}

Operation *Dialect::lookupOperation(StringRef name) const {
auto it = operations.find(name);
return it != operations.end() ? it->second.get() : nullptr;
}
26 changes: 26 additions & 0 deletions mlir/lib/Tools/PDLL/ODS/Operation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- Operation.cpp ------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "mlir/Support/IndentedOstream.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::pdll::ods;

//===----------------------------------------------------------------------===//
// Operation
//===----------------------------------------------------------------------===//

Operation::Operation(StringRef name, StringRef summary, StringRef desc,
llvm::SMLoc loc)
: name(name.str()), summary(summary.str()),
location(loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)) {
llvm::raw_string_ostream descOS(description);
raw_indented_ostream(descOS).printReindented(desc.rtrim(" \t"));
}
6 changes: 6 additions & 0 deletions mlir/lib/Tools/PDLL/Parser/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
set(LLVM_LINK_COMPONENTS
Support
TableGen
)

add_mlir_library(MLIRPDLLParser
Lexer.cpp
Parser.cpp

LINK_LIBS PUBLIC
MLIRPDLLAST
MLIRSupport
MLIRTableGen
)
344 changes: 323 additions & 21 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp

Large diffs are not rendered by default.

20 changes: 19 additions & 1 deletion mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-pdll %s -I %S -split-input-file -x mlir | FileCheck %s
// RUN: mlir-pdll %s -I %S -I %S/../../../../include -split-input-file -x mlir | FileCheck %s

//===----------------------------------------------------------------------===//
// AttributeExpr
Expand Down Expand Up @@ -55,6 +55,24 @@ Pattern OpAllResultMemberAccess {

// -----

// Handle implicit "named" operation results access.

#include "include/ops.td"

// CHECK: pdl.pattern @OpResultMemberAccess
// CHECK: %[[OP0:.*]] = operation
// CHECK: %[[RES:.*]] = results 0 of %[[OP0]] -> !pdl.value
// CHECK: %[[RES1:.*]] = results 0 of %[[OP0]] -> !pdl.value
// CHECK: %[[RES2:.*]] = results 1 of %[[OP0]] -> !pdl.range<value>
// CHECK: %[[RES3:.*]] = results 1 of %[[OP0]] -> !pdl.range<value>
// CHECK: operation(%[[RES]], %[[RES1]], %[[RES2]], %[[RES3]] : !pdl.value, !pdl.value, !pdl.range<value>, !pdl.range<value>)
Pattern OpResultMemberAccess {
let op: Op<test.with_results>;
erase op<>(op.0, op.result, op.1, op.var_result);
}

// -----

// CHECK: pdl.pattern @TupleMemberAccessNumber
// CHECK: %[[FIRST:.*]] = operation "test.first"
// CHECK: %[[SECOND:.*]] = operation "test.second"
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/mlir-pdll/CodeGen/MLIR/include/ops.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
include "mlir/IR/OpBase.td"

def Test_Dialect : Dialect {
let name = "test";
}

def OpWithResults : Op<Test_Dialect, "with_results"> {
let results = (outs I64:$result, Variadic<I64>:$var_result);
}
2 changes: 1 addition & 1 deletion mlir/test/mlir-pdll/Parser/directive-failure.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@

// -----

// CHECK: expected include filename to end with `.pdll`
// CHECK: expected include filename to end with `.pdll` or `.td`
#include "unknown_file.foo"
22 changes: 21 additions & 1 deletion mlir/test/mlir-pdll/Parser/expr-failure.pdll
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
// RUN: not mlir-pdll %s -I %S -I %S/../../../include -split-input-file 2>&1 | FileCheck %s

//===----------------------------------------------------------------------===//
// Reference Expr
Expand Down Expand Up @@ -276,6 +276,26 @@ Pattern {

// -----

#include "include/ops.td"

Pattern {
// CHECK: invalid number of operand groups for `test.all_empty`; expected 0, but got 2
// CHECK: see the definition of `test.all_empty` here
let foo = op<test.all_empty>(operand1: Value, operand2: Value);
}

// -----

#include "include/ops.td"

Pattern {
// CHECK: invalid number of result groups for `test.all_empty`; expected 0, but got 2
// CHECK: see the definition of `test.all_empty` here
let foo = op<test.all_empty> -> (result1: Type, result2: Type);
}

// -----

//===----------------------------------------------------------------------===//
// `type` Expr
//===----------------------------------------------------------------------===//
Expand Down
21 changes: 20 additions & 1 deletion mlir/test/mlir-pdll/Parser/expr.pdll
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s
// RUN: mlir-pdll %s -I %S -I %S/../../../include -split-input-file | FileCheck %s

//===----------------------------------------------------------------------===//
// AttrExpr
Expand Down Expand Up @@ -71,6 +71,25 @@ Pattern {

// -----

#include "include/ops.td"

// CHECK: Module
// CHECK: `-VariableDecl {{.*}} Name<firstEltIndex> Type<Value>
// CHECK: `-MemberAccessExpr {{.*}} Member<0> Type<Value>
// CHECK: `-DeclRefExpr {{.*}} Type<Op<test.all_single>>
// CHECK: `-VariableDecl {{.*}} Name<firstEltName> Type<Value>
// CHECK: `-MemberAccessExpr {{.*}} Member<result> Type<Value>
// CHECK: `-DeclRefExpr {{.*}} Type<Op<test.all_single>>
Pattern {
let op: Op<test.all_single>;
let firstEltIndex = op.0;
let firstEltName = op.result;

erase op;
}

// -----

//===----------------------------------------------------------------------===//
// OperationExpr
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/test/mlir-pdll/Parser/include/interfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
include "mlir/IR/OpBase.td"

def TestAttrInterface : AttrInterface<"TestAttrInterface">;
def TestOpInterface : OpInterface<"TestOpInterface">;
def TestTypeInterface : TypeInterface<"TestTypeInterface">;
26 changes: 26 additions & 0 deletions mlir/test/mlir-pdll/Parser/include/ops.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
include "include/interfaces.td"

def Test_Dialect : Dialect {
let name = "test";
}

def OpAllEmpty : Op<Test_Dialect, "all_empty">;

def OpAllSingle : Op<Test_Dialect, "all_single"> {
let arguments = (ins I64:$operand, I64Attr:$attr);
let results = (outs I64:$result);
}

def OpAllOptional : Op<Test_Dialect, "all_optional"> {
let arguments = (ins Optional<I64>:$operand, OptionalAttr<I64Attr>:$attr);
let results = (outs Optional<I64>:$result);
}

def OpAllVariadic : Op<Test_Dialect, "all_variadic"> {
let arguments = (ins Variadic<I64>:$operands);
let results = (outs Variadic<I64>:$results);
}

def OpMultipleSingleResult : Op<Test_Dialect, "multiple_single_result"> {
let results = (outs I64:$result, I64:$result2);
}
52 changes: 52 additions & 0 deletions mlir/test/mlir-pdll/Parser/include_td.pdll
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// RUN: mlir-pdll %s -I %S -I %S/../../../include -dump-ods 2>&1 | FileCheck %s

#include "include/ops.td"

// CHECK: Operation `test.all_empty` {
// CHECK-NEXT: }

// CHECK: Operation `test.all_optional` {
// CHECK-NEXT: Attributes { attr : Optional<I64Attr> }
// CHECK-NEXT: Operands { operand : Optional<I64> }
// CHECK-NEXT: Results { result : Optional<I64> }
// CHECK-NEXT: }

// CHECK: Operation `test.all_single` {
// CHECK-NEXT: Attributes { attr : I64Attr }
// CHECK-NEXT: Operands { operand : I64 }
// CHECK-NEXT: Results { result : I64 }
// CHECK-NEXT: }

// CHECK: Operation `test.all_variadic` {
// CHECK-NEXT: Operands { operands : Variadic<I64> }
// CHECK-NEXT: Results { results : Variadic<I64> }
// CHECK-NEXT: }

// CHECK: AttributeConstraint `I64Attr` {
// CHECK-NEXT: Summary: 64-bit signless integer attribute
// CHECK-NEXT: CppClass: ::mlir::IntegerAttr
// CHECK-NEXT: }

// CHECK: TypeConstraint `I64` {
// CHECK-NEXT: Summary: 64-bit signless integer
// CHECK-NEXT: CppClass: ::mlir::IntegerType
// CHECK-NEXT: }

// CHECK: UserConstraintDecl {{.*}} Name<TestAttrInterface> ResultType<Tuple<>> Code<llvm::isa<::TestAttrInterface>(self)>
// CHECK: `Inputs`
// CHECK: `-VariableDecl {{.*}} Name<self> Type<Attr>
// CHECK: `Constraints`
// CHECK: `-AttrConstraintDecl

// CHECK: UserConstraintDecl {{.*}} Name<TestOpInterface> ResultType<Tuple<>> Code<llvm::isa<::TestOpInterface>(self)>
// CHECK: `Inputs`
// CHECK: `-VariableDecl {{.*}} Name<self> Type<Op>
// CHECK: `Constraints`
// CHECK: `-OpConstraintDecl
// CHECK: `-OpNameDecl

// CHECK: UserConstraintDecl {{.*}} Name<TestTypeInterface> ResultType<Tuple<>> Code<llvm::isa<::TestTypeInterface>(self)>
// CHECK: `Inputs`
// CHECK: `-VariableDecl {{.*}} Name<self> Type<Type>
// CHECK: `Constraints`
// CHECK: `-TypeConstraintDecl {{.*}}
24 changes: 23 additions & 1 deletion mlir/test/mlir-pdll/Parser/stmt-failure.pdll
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s
// RUN: not mlir-pdll %s -I %S -I %S/../../../include -split-input-file 2>&1 | FileCheck %s

// CHECK: expected top-level declaration, such as a `Pattern`
10
Expand Down Expand Up @@ -250,6 +250,28 @@ Pattern {

// -----

#include "include/ops.td"

Pattern {
// CHECK: unable to convert expression of type `Op<test.all_empty>` to the expected type of `Value`
// CHECK: see the definition of `test.all_empty`, which was defined with zero results
let value: Value = op<test.all_empty>;
erase _: Op;
}

// -----

#include "include/ops.td"

Pattern {
// CHECK: unable to convert expression of type `Op<test.multiple_single_result>` to the expected type of `Value`
// CHECK: see the definition of `test.multiple_single_result`, which was defined with at least 2 results
let value: Value = op<test.multiple_single_result>;
erase _: Op;
}

// -----

//===----------------------------------------------------------------------===//
// `replace`
//===----------------------------------------------------------------------===//
Expand Down
24 changes: 21 additions & 3 deletions mlir/tools/mlir-pdll/mlir-pdll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
#include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
#include "mlir/Tools/PDLL/ODS/Context.h"
#include "mlir/Tools/PDLL/Parser/Parser.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
Expand All @@ -35,16 +36,23 @@ enum class OutputType {

static LogicalResult
processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
OutputType outputType, std::vector<std::string> &includeDirs) {
OutputType outputType, std::vector<std::string> &includeDirs,
bool dumpODS) {
llvm::SourceMgr sourceMgr;
sourceMgr.setIncludeDirs(includeDirs);
sourceMgr.AddNewSourceBuffer(std::move(chunkBuffer), SMLoc());

ast::Context astContext;
ods::Context odsContext;
ast::Context astContext(odsContext);
FailureOr<ast::Module *> module = parsePDLAST(astContext, sourceMgr);
if (failed(module))
return failure();

// Print out the ODS information if requested.
if (dumpODS)
odsContext.print(llvm::errs());

// Generate the output.
if (outputType == OutputType::AST) {
(*module)->print(os);
return success();
Expand All @@ -66,6 +74,10 @@ processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
}

int main(int argc, char **argv) {
// FIXME: This is necessary because we link in TableGen, which defines its
// options as static variables.. some of which overlap with our options.
llvm::cl::ResetCommandLineParser();

llvm::cl::opt<std::string> inputFilename(
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"),
llvm::cl::value_desc("filename"));
Expand All @@ -78,6 +90,11 @@ int main(int argc, char **argv) {
"I", llvm::cl::desc("Directory of include files"),
llvm::cl::value_desc("directory"), llvm::cl::Prefix);

llvm::cl::opt<bool> dumpODS(
"dump-ods",
llvm::cl::desc(
"Print out the parsed ODS information from the input file"),
llvm::cl::init(false));
llvm::cl::opt<bool> splitInputFile(
"split-input-file",
llvm::cl::desc("Split the input file into pieces and process each "
Expand Down Expand Up @@ -118,7 +135,8 @@ int main(int argc, char **argv) {
// up into small pieces and checks each independently.
auto processFn = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
raw_ostream &os) {
return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs);
return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs,
dumpODS);
};
if (splitInputFile) {
if (failed(splitAndProcessBuffer(std::move(inputFile), processFn,
Expand Down