Skip to content

Commit

Permalink
[mlir][ods] Added EnumAttr, an AttrDef implementation of enum attributes
Browse files Browse the repository at this point in the history
`EnumAttr` is a pure TableGen implementation of enum attributes using `AttrDef`. This is meant as a drop-in replacement for `StrEnumAttr`, which is soon to be deprecated. `StrEnumAttr` is often used over `IntEnumAttr` because its more readable in MLIR assembly formats. However, storing and manipulating strings is not efficient. Defining `StrEnumAttr` can also be awkward and relies on a lot of special logic in `EnumsGen`, and has some hidden sharp edges.

Also, `EnumAttr` stores the enum directly,  removing the need to convert to/from integers when calling attribute getters on ops.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D115181
  • Loading branch information
Mogball committed Dec 17, 2021
1 parent c50a4b3 commit 319d8cf
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 1 deletion.
96 changes: 96 additions & 0 deletions mlir/include/mlir/IR/EnumAttr.td
@@ -0,0 +1,96 @@
//===-- EnumAttr.td - Enum attributes ----------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef ENUM_ATTR
#define ENUM_ATTR

include "mlir/IR/OpBase.td"

// A C++ enum as an attribute parameter. The parameter implements a parser and
// printer for the enum by dispatching calls to `stringToSymbol` and
// `symbolToString`.
class EnumParameter<EnumAttrInfo enumInfo>
: AttrParameter<enumInfo.cppNamespace # "::" # enumInfo.className,
"an enum of type " # enumInfo.className> {
// Parse a keyword and pass it to `stringToSymbol`. Emit an error if a the
// symbol is not valid.
let parser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> {
auto loc = $_parser.getCurrentLocation();
::llvm::StringRef enumKeyword;
if (::mlir::failed($_parser.parseKeyword(&enumKeyword)))
return ::mlir::failure();
auto maybeEnum = }] # enumInfo.cppNamespace # "::" #
enumInfo.stringToSymbolFnName # [{(enumKeyword);
if (maybeEnum)
return *maybeEnum;
return {$_parser.emitError(loc, "expected }] #
cppType # [{ to be one of: }] #
!interleave(!foreach(enum, enumInfo.enumerants, enum.str), ", ") # [{")};
}()}];
// Print the enum by calling `symbolToString`.
let printer = "$_printer << " # enumInfo.symbolToStringFnName # "($_self)";
}

// An attribute backed by a C++ enum. The attribute contains a single
// parameter `value` whose type is the C++ enum class.
//
// Example:
//
// ```
// def MyEnum : I32EnumAttr<"MyEnum", "a simple enum", [
// I32EnumAttrCase<"First", 0, "first">,
// I32EnumAttrCase<"Second", 1, "second>]> {
// let genSpecializedAttr = 0;
// }
//
// def MyEnumAttr : EnumAttr<MyDialect, MyEnum, "enum">;
// ```
//
// By default, the assembly format of the attribute works best with operation
// assembly formats. For example:
//
// ```
// def MyOp : Op<MyDialect, "my_op"> {
// let arguments = (ins MyEnumAttr:$enum);
// let assemblyFormat = "$enum attr-dict";
// }
// ```
//
// The op will appear in the IR as `my_dialect.my_op first`. However, the
// generic format of the attribute will be `#my_dialect<"enum first">`. Override
// the attribute's assembly format as required.
class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
list <Trait> traits = []>
: AttrDef<dialect, enumInfo.className, traits> {
let summary = enumInfo.summary;

// Inherit the C++ namespace from the enum.
let cppNamespace = enumInfo.cppNamespace;

// Define a constant builder for the attribute to convert from C++ enums.
let constBuilderCall = cppNamespace # "::" # cppClassName #
"::get($_builder.getContext(), $0)";

// Op attribute getters should return the underlying C++ enum type.
let returnType = enumInfo.cppNamespace # "::" # enumInfo.className;

// Convert from attribute to the underlying C++ type in op getters.
let convertFromStorage = "$_self.getValue()";

// The enum attribute has one parameter: the C++ enum value.
let parameters = (ins EnumParameter<enumInfo>:$value);

// If a mnemonic was provided, use it to generate a custom assembly format.
let mnemonic = name;

// The default assembly format for enum attributes. Selected to best work with
// operation assembly formats.
let assemblyFormat = "$value";
}

#endif // ENUM_ATTR
30 changes: 30 additions & 0 deletions mlir/test/IR/enum-attr-invalid.mlir
@@ -0,0 +1,30 @@
// RUN: mlir-opt -verify-diagnostics -split-input-file %s

func @test_invalid_enum_case() -> () {
// expected-error@+2 {{expected test::TestEnum to be one of: first, second, third}}
// expected-error@+1 {{failed to parse TestEnumAttr}}
test.op_with_enum #test<"enum fourth">
}

// -----

func @test_invalid_enum_case() -> () {
// expected-error@+1 {{expected test::TestEnum to be one of: first, second, third}}
test.op_with_enum fourth
// expected-error@+1 {{failed to parse TestEnumAttr}}
}

// -----

func @test_invalid_attr() -> () {
// expected-error@+1 {{op attribute 'value' failed to satisfy constraint: a test enum}}
"test.op_with_enum"() {value = 1 : index} : () -> ()
}

// -----

func @test_parse_invalid_attr() -> () {
// expected-error@+2 {{expected valid keyword}}
// expected-error@+1 {{failed to parse TestEnumAttr parameter 'value'}}
test.op_with_enum 1 : index
}
28 changes: 28 additions & 0 deletions mlir/test/IR/enum-attr-roundtrip.mlir
@@ -0,0 +1,28 @@
// RUN: mlir-opt %s | mlir-opt -test-patterns | FileCheck %s

// CHECK-LABEL: @test_enum_attr_roundtrip
func @test_enum_attr_roundtrip() -> () {
// CHECK: value = #test<"enum first">
"test.op"() {value = #test<"enum first">} : () -> ()
// CHECK: value = #test<"enum second">
"test.op"() {value = #test<"enum second">} : () -> ()
// CHECK: value = #test<"enum third">
"test.op"() {value = #test<"enum third">} : () -> ()
return
}

// CHECK-LABEL: @test_op_with_enum
func @test_op_with_enum() -> () {
// CHECK: test.op_with_enum third
test.op_with_enum third
return
}

// CHECK-LABEL: @test_match_op_with_enum
func @test_match_op_with_enum() -> () {
// CHECK: test.op_with_enum third tag 0 : i32
test.op_with_enum third tag 0 : i32
// CHECK: test.op_with_enum second tag 1 : i32
test.op_with_enum first tag 0 : i32
return
}
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Test/TestAttributes.h
Expand Up @@ -23,6 +23,7 @@
#include "mlir/IR/DialectImplementation.h"

#include "TestAttrInterfaces.h.inc"
#include "TestOpEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "TestAttrDefs.h.inc"
Expand Down
1 change: 0 additions & 1 deletion mlir/test/lib/Dialect/Test/TestDialect.h
Expand Up @@ -39,7 +39,6 @@ class DLTIDialect;
class RewritePatternSet;
} // namespace mlir

#include "TestOpEnums.h.inc"
#include "TestOpInterfaces.h.inc"
#include "TestOpStructs.h.inc"
#include "TestOpsDialect.h.inc"
Expand Down
33 changes: 33 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Expand Up @@ -11,6 +11,7 @@

include "TestDialect.td"
include "mlir/Dialect/DLTI/DLTIBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/RegionKindInterface.td"
Expand Down Expand Up @@ -287,6 +288,38 @@ def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
);
}

//===----------------------------------------------------------------------===//
// Test Enum Attributes
//===----------------------------------------------------------------------===//

// Define the C++ enum.
def TestEnum
: I32EnumAttr<"TestEnum", "a test enum", [
I32EnumAttrCase<"First", 0, "first">,
I32EnumAttrCase<"Second", 1, "second">,
I32EnumAttrCase<"Third", 2, "third">,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "test";
}

// Define the enum attribute.
def TestEnumAttr : EnumAttr<Test_Dialect, TestEnum, "enum">;

// Define an op that contains the enum attribute.
def OpWithEnum : TEST_Op<"op_with_enum"> {
let arguments = (ins TestEnumAttr:$value, OptionalAttr<AnyAttr>:$tag);
let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
}

// Define a pattern that matches and creates an enum attribute.
def : Pat<(OpWithEnum ConstantAttr<TestEnumAttr,
"::test::TestEnum::First">:$value,
ConstantAttr<I32Attr, "0">:$tag),
(OpWithEnum ConstantAttr<TestEnumAttr,
"::test::TestEnum::Second">,
ConstantAttr<I32Attr, "1">)>;

//===----------------------------------------------------------------------===//
// Test Attribute Constraints
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Expand Up @@ -839,6 +839,7 @@ cc_binary(
td_library(
name = "OpBaseTdFiles",
srcs = [
"include/mlir/IR/EnumAttr.td",
"include/mlir/IR/OpAsmInterface.td",
"include/mlir/IR/OpBase.td",
"include/mlir/IR/RegionKindInterface.td",
Expand Down

0 comments on commit 319d8cf

Please sign in to comment.