16 changes: 15 additions & 1 deletion mlir/lib/TableGen/Class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,13 @@ void ParentClass::writeTo(raw_indented_ostream &os) const {
//===----------------------------------------------------------------------===//

void UsingDeclaration::writeDeclTo(raw_indented_ostream &os) const {
if (!templateParams.empty()) {
os << "template <";
llvm::interleaveComma(templateParams, os, [&](StringRef paramName) {
os << "typename " << paramName;
});
os << ">\n";
}
os << "using " << name;
if (!value.empty())
os << " = " << value;
Expand Down Expand Up @@ -275,6 +282,13 @@ ParentClass &Class::addParent(ParentClass parent) {
}

void Class::writeDeclTo(raw_indented_ostream &os) const {
if (!templateParams.empty()) {
os << "template <";
llvm::interleaveComma(templateParams, os,
[&](StringRef param) { os << "typename " << param; });
os << ">\n";
}

// Declare the class.
os << (isStruct ? "struct" : "class") << ' ' << className << ' ';

Expand Down Expand Up @@ -341,7 +355,7 @@ Visibility Class::getLastVisibilityDecl() const {
});
return it == reverseDecls.end()
? (isStruct ? Visibility::Public : Visibility::Private)
: cast<VisibilityDeclaration>(*it).getVisibility();
: cast<VisibilityDeclaration>(**it).getVisibility();
}

Method *insertAndPruneMethods(std::vector<std::unique_ptr<Method>> &methods,
Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/TableGen/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ bool Dialect::isExtensible() const {
return def->getValueAsBit("isExtensible");
}

Dialect::FolderAPI Dialect::getFolderAPI() const {
int64_t value = def->getValueAsInt("useFoldAPI");
if (value < static_cast<int64_t>(FolderAPI::RawAttributes) ||
value > static_cast<int64_t>(FolderAPI::FolderAdaptor))
llvm::PrintFatalError(def->getLoc(),
"Invalid value for dialect field `useFoldAPI`");

return static_cast<FolderAPI>(value);
}

bool Dialect::operator==(const Dialect &other) const {
return def == other.def;
}
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/TableGen/Operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ std::string Operator::getAdaptorName() const {
return std::string(llvm::formatv("{0}Adaptor", getCppClassName()));
}

std::string Operator::getGenericAdaptorName() const {
return std::string(llvm::formatv("{0}GenericAdaptor", getCppClassName()));
}

/// Assert the invariants of accessors generated for the given name.
static void assertAccessorInvariants(const Operator &op, StringRef name) {
std::string accessorName =
Expand Down Expand Up @@ -741,3 +745,5 @@ std::string Operator::getSetterName(StringRef name) const {
std::string Operator::getRemoverName(StringRef name) const {
return "remove" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
}

bool Operator::hasFolder() const { return def.getValueAsBit("hasFolder"); }
16 changes: 16 additions & 0 deletions mlir/test/IR/test-fold-adaptor.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s

func.func @test() -> i32 {
%c5 = "test.constant"() {value = 5 : i32} : () -> i32
%c1 = "test.constant"() {value = 1 : i32} : () -> i32
%c2 = "test.constant"() {value = 2 : i32} : () -> i32
%c3 = "test.constant"() {value = 3 : i32} : () -> i32
%res = test.fold_with_fold_adaptor %c5, [ %c1, %c2], { (%c3), (%c3) } {
%c0 = "test.constant"() {value = 0 : i32} : () -> i32
}
return %res : i32
}

// CHECK-LABEL: func.func @test
// CHECK-NEXT: %[[C:.*]] = "test.constant"() {value = 33 : i32}
// CHECK-NEXT: return %[[C]]
21 changes: 21 additions & 0 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"

#include <numeric>

// Include this before the using namespace lines below to
// test that we don't have namespace dependencies.
#include "TestOpsDialect.cpp.inc"
Expand Down Expand Up @@ -1126,6 +1128,25 @@ OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
return getOperand();
}

OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
int64_t sum = 0;
if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
sum += value.getValue().getSExtValue();

for (Attribute attr : adaptor.getVariadic())
if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
sum += 2 * value.getValue().getSExtValue();

for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar())
for (Attribute attr : attrs)
if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
sum += 3 * value.getValue().getSExtValue();

sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());

return IntegerAttr::get(getType(), sum);
}

LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
MLIRContext *, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
Expand Down
25 changes: 25 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,31 @@ def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> {
}];
}

def TestOpFoldWithFoldAdaptor
: TEST_Op<"fold_with_fold_adaptor",
[AttrSizedOperandSegments, NoTerminator]> {
let arguments = (ins
I32:$op,
DenseI32ArrayAttr:$attr,
Variadic<I32>:$variadic,
VariadicOfVariadic<I32, "attr">:$var_of_var
);

let results = (outs I32:$res);

let regions = (region AnyRegion:$body);

let assemblyFormat = [{
$op `,` `[` $variadic `]` `,` `{` $var_of_var `}` $body attr-dict-with-keyword
}];

let hasFolder = 0;

let extraClassDeclaration = [{
::mlir::OpFoldResult fold(FoldAdaptor adaptor);
}];
}

// An op that always fold itself.
def TestPassthroughFold : TEST_Op<"passthrough_fold"> {
let arguments = (ins AnyType:$op);
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/mlir-tblgen/has-fold-invalid-values.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: not mlir-tblgen -gen-op-decls -I %S/../../include %s 2>&1 | FileCheck %s

include "mlir/IR/OpBase.td"

def Test_Dialect : Dialect {
let name = "test";
let cppNamespace = "NS";
let useFoldAPI = 3;
}

def InvalidValue_Op : Op<Test_Dialect, "invalid_op"> {
let hasFolder = 1;
}

// CHECK: Invalid value for dialect field `useFoldAPI`
85 changes: 68 additions & 17 deletions mlir/test/mlir-tblgen/op-decl-and-defs.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,34 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {

// CHECK-LABEL: NS::AOp declarations

// CHECK: class AOpAdaptor {
// CHECK: namespace detail {
// CHECK: class AOpGenericAdaptorBase {
// CHECK: public:
// CHECK: AOpAdaptor(::mlir::ValueRange values
// CHECK: ::mlir::ValueRange getODSOperands(unsigned index);
// CHECK: ::mlir::Value getA();
// CHECK: ::mlir::ValueRange getB();
// CHECK: ::mlir::IntegerAttr getAttr1Attr();
// CHECK: uint32_t getAttr1();
// CHECK: ::mlir::FloatAttr getSomeAttr2Attr();
// CHECK: ::std::optional< ::llvm::APFloat > getSomeAttr2();
// CHECK: ::mlir::Region &getSomeRegion();
// CHECK: ::mlir::RegionRange getSomeRegions();
// CHECK: };
// CHECK: }

// CHECK: template <typename RangeT>
// CHECK: class AOpGenericAdaptor : public detail::AOpGenericAdaptorBase {
// CHECK: public:
// CHECK: AOpGenericAdaptor(RangeT values,
// CHECK-SAME: odsOperands(values)
// CHECK: RangeT getODSOperands(unsigned index) {
// CHECK: ValueT getA() {
// CHECK: RangeT getB() {
// CHECK: private:
// CHECK: ::mlir::ValueRange odsOperands;
// CHECK: RangeT odsOperands;
// CHECK: };

// CHECK: class AOpAdaptor : public AOpGenericAdaptor<::mlir::ValueRange> {
// CHECK: public:
// CHECK: AOpAdaptor(AOp
// CHECK: ::mlir::LogicalResult verify(
// CHECK: };

// CHECK: class AOp : public ::mlir::Op<AOp, ::mlir::OpTrait::AtLeastNRegions<1>::Impl, ::mlir::OpTrait::AtLeastNResults<1>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::AtLeastNOperands<1>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::IsIsolatedFromAbove
Expand Down Expand Up @@ -108,10 +122,10 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {

// DEFS-LABEL: NS::AOp definitions

// DEFS: AOpAdaptor::AOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs, ::mlir::RegionRange regions) : odsOperands(values), odsAttrs(attrs), odsRegions(regions)
// DEFS: ::mlir::RegionRange AOpAdaptor::getSomeRegions()
// DEFS: AOpGenericAdaptorBase::AOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions)
// DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getSomeRegions()
// DEFS-NEXT: return odsRegions.drop_front(1);
// DEFS: ::mlir::RegionRange AOpAdaptor::getRegions()
// DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getRegions()

// Check AttrSizedOperandSegments
// ---
Expand All @@ -127,15 +141,17 @@ def NS_AttrSizedOperandOp : NS_Op<"attr_sized_operands",
);
}

// CHECK-LABEL: AttrSizedOperandOpAdaptor(
// CHECK-SAME: ::mlir::ValueRange values
// CHECK-SAME: ::mlir::DictionaryAttr attrs
// CHECK: ::mlir::ValueRange getA();
// CHECK: ::mlir::ValueRange getB();
// CHECK: ::mlir::Value getC();
// CHECK: ::mlir::ValueRange getD();
// CHECK-LABEL: class AttrSizedOperandOpGenericAdaptorBase {
// CHECK: ::mlir::DenseIntElementsAttr getOperandSegmentSizes();

// CHECK-LABEL: AttrSizedOperandOpGenericAdaptor(
// CHECK-SAME: RangeT values
// CHECK-SAME: ::mlir::DictionaryAttr attrs
// CHECK: RangeT getA() {
// CHECK: RangeT getB() {
// CHECK: ValueT getC() {
// CHECK: RangeT getD() {

// Check op trait for different number of operands
// ---

Expand Down Expand Up @@ -166,7 +182,7 @@ def NS_EOp : NS_Op<"op_with_optionals", []> {
}

// CHECK-LABEL: NS::EOp declarations
// CHECK: ::mlir::Value getA();
// CHECK: ::mlir::TypedValue<::mlir::IntegerType> getA();
// CHECK: ::mlir::MutableOperandRange getAMutable();
// CHECK: ::mlir::TypedValue<::mlir::FloatType> getB();
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::Type b, /*optional*/::mlir::Value a)
Expand Down Expand Up @@ -301,6 +317,29 @@ def NS_LOp : NS_Op<"op_with_same_operands_and_result_types_unwrapped_attr", [Sam
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});

def TestWithNewFold_Dialect : Dialect {
let name = "test";
let cppNamespace = "::mlir::testWithFold";
let useFoldAPI = kEmitFoldAdaptorFolder;
}

def NS_MOp : Op<TestWithNewFold_Dialect, "op_with_single_result_and_fold_adaptor_fold", []> {
let results = (outs AnyType:$res);

let hasFolder = 1;
}

// CHECK-LABEL: class MOp :
// CHECK: ::mlir::OpFoldResult fold(FoldAdaptor adaptor);

def NS_NOp : Op<TestWithNewFold_Dialect, "op_with_multiple_results_and_fold_adaptor_fold", []> {
let results = (outs AnyType:$res1, AnyType:$res2);

let hasFolder = 1;
}

// CHECK-LABEL: class NOp :
// CHECK: ::mlir::LogicalResult fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results);

// Test that type defs have the proper namespaces when used as a constraint.
// ---
Expand Down Expand Up @@ -335,6 +374,18 @@ def NS_SkipDefaultBuildersOp : NS_Op<"skip_default_builders", []> {
// Check leading underscore in op name
// ---

def NS_VarOfVarOperandOp : NS_Op<"var_of_var_operand", []> {
let arguments = (ins
VariadicOfVariadic<F32, "var_size">:$var_of_var_attr,
DenseI32ArrayAttr:$var_size
);
}

// CHECK-LABEL: class VarOfVarOperandOpGenericAdaptor
// CHECK: public:
// CHECK: ::llvm::SmallVector<RangeT> getVarOfVarAttr() {


def NS__AOp : NS_Op<"_op_with_leading_underscore", []>;

// CHECK-LABEL: NS::_AOp declarations
Expand Down
13 changes: 2 additions & 11 deletions mlir/test/mlir-tblgen/op-operand.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def OpA : NS_Op<"one_normal_operand_op", []> {

// CHECK-LABEL: OpA definitions

// CHECK: OpAAdaptor::OpAAdaptor
// CHECK-SAME: odsOperands(values), odsAttrs(attrs)
// CHECK: OpAGenericAdaptorBase::OpAGenericAdaptorBase
// CHECK-SAME: odsAttrs(attrs)

// CHECK: void OpA::build
// CHECK: ::mlir::Value input
Expand All @@ -39,15 +39,6 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3);
}

// CHECK-LABEL: ::mlir::ValueRange OpDAdaptor::getInput1
// CHECK-NEXT: return getODSOperands(0);

// CHECK-LABEL: ::mlir::Value OpDAdaptor::getInput2
// CHECK-NEXT: return *getODSOperands(1).begin();

// CHECK-LABEL: ::mlir::ValueRange OpDAdaptor::getInput3
// CHECK-NEXT: return getODSOperands(2);

// CHECK-LABEL: ::mlir::Operation::operand_range OpD::getInput1
// CHECK-NEXT: return getODSOperands(0);

Expand Down
5 changes: 5 additions & 0 deletions mlir/tools/mlir-tblgen/OpClass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ OpClass::OpClass(StringRef name, StringRef extraClassDeclaration,
declare<UsingDeclaration>("Op::print");
/// Type alias for the adaptor class.
declare<UsingDeclaration>("Adaptor", className + "Adaptor");
declare<UsingDeclaration>("GenericAdaptor",
className + "GenericAdaptor<RangeT>")
->addTemplateParam("RangeT");
declare<UsingDeclaration>(
"FoldAdaptor", "GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>");
}

void OpClass::finalize() {
Expand Down
267 changes: 180 additions & 87 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Large diffs are not rendered by default.

63 changes: 63 additions & 0 deletions mlir/unittests/IR/AdaptorTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
//===- AdaptorTest.cpp - Adaptor unit tests -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "../../test/lib/Dialect/Test/TestDialect.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

using namespace llvm;
using namespace mlir;
using namespace test;

using testing::ElementsAre;

TEST(Adaptor, GenericAdaptorsOperandAccess) {
MLIRContext context;
context.loadDialect<test::TestDialect>();
Builder builder(&context);

// Has normal and Variadic arguments.
MixedNormalVariadicOperandOp::FoldAdaptor a({});
{
SmallVector<int> v = {0, 1, 2, 3, 4};
MixedNormalVariadicOperandOp::GenericAdaptor<ArrayRef<int>> b(v);
EXPECT_THAT(b.getInput1(), ElementsAre(0, 1));
EXPECT_EQ(b.getInput2(), 2);
EXPECT_THAT(b.getInput3(), ElementsAre(3, 4));
}

// Has optional arguments.
OIListSimple::FoldAdaptor c({}, nullptr);
{
// Optional arguments return the default constructed value if not present.
// Using optional instead of plain int here to differentiate absence of
// value from the value 0.
SmallVector<std::optional<int>> v = {0, 4};
OIListSimple::GenericAdaptor<ArrayRef<std::optional<int>>> d(
v, builder.getDictionaryAttr({builder.getNamedAttr(
"operand_segment_sizes",
builder.getDenseI32ArrayAttr({1, 0, 1}))}));
EXPECT_EQ(d.getArg0(), 0);
EXPECT_EQ(d.getArg1(), std::nullopt);
EXPECT_EQ(d.getArg2(), 4);
}

// Has VariadicOfVariadic arguments.
FormatVariadicOfVariadicOperand::FoldAdaptor e({});
{
SmallVector<int> v = {0, 1, 2, 3, 4};
FormatVariadicOfVariadicOperand::GenericAdaptor<ArrayRef<int>> f(
v, builder.getDictionaryAttr({builder.getNamedAttr(
"operand_segments", builder.getDenseI32ArrayAttr({3, 2, 0}))}));
SmallVector<ArrayRef<int>> operand = f.getOperand();
ASSERT_EQ(operand.size(), (std::size_t)3);
EXPECT_THAT(operand[0], ElementsAre(0, 1, 2));
EXPECT_THAT(operand[1], ElementsAre(3, 4));
EXPECT_THAT(operand[2], ElementsAre());
}
}
1 change: 1 addition & 0 deletions mlir/unittests/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_unittest(MLIRIRTests
AdaptorTest.cpp
AttributeTest.cpp
BlockAndValueMapping.cpp
DialectTest.cpp
Expand Down