378 changes: 378 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def LLVM_Dialect : Dialect {
/// can save us lots of verification time if there are many occurrences
/// of some deeply-nested aggregate types in the program.
ThreadLocalCache<DenseSet<Type>> compatibleTypes;

/// Register the attributes of this dialect.
void registerAttributes();
}];
}

Expand Down
335 changes: 0 additions & 335 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Large diffs are not rendered by default.

21 changes: 15 additions & 6 deletions mlir/include/mlir/IR/BuiltinLocationAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/SubElementInterfaces.td"

// Base class for Builtin dialect location attributes.
class Builtin_LocationAttr<string name>
: AttrDef<Builtin_Dialect, name, [], "::mlir::LocationAttr"> {
class Builtin_LocationAttr<string name, list<Trait> traits = []>
: AttrDef<Builtin_Dialect, name, traits, "::mlir::LocationAttr"> {
let cppClassName = name;
let mnemonic = ?;
}
Expand All @@ -27,7 +28,9 @@ class Builtin_LocationAttr<string name>
// CallSiteLoc
//===----------------------------------------------------------------------===//

def CallSiteLoc : Builtin_LocationAttr<"CallSiteLoc"> {
def CallSiteLoc : Builtin_LocationAttr<"CallSiteLoc", [
DeclareAttrInterfaceMethods<SubElementAttrInterface>
]> {
let summary = "A callsite source location";
let description = [{
Syntax:
Expand Down Expand Up @@ -104,7 +107,9 @@ def FileLineColLoc : Builtin_LocationAttr<"FileLineColLoc"> {
// FusedLoc
//===----------------------------------------------------------------------===//

def FusedLoc : Builtin_LocationAttr<"FusedLoc"> {
def FusedLoc : Builtin_LocationAttr<"FusedLoc", [
DeclareAttrInterfaceMethods<SubElementAttrInterface>
]> {
let summary = "A tuple of other source locations";
let description = [{
Syntax:
Expand Down Expand Up @@ -143,7 +148,9 @@ def FusedLoc : Builtin_LocationAttr<"FusedLoc"> {
// NameLoc
//===----------------------------------------------------------------------===//

def NameLoc : Builtin_LocationAttr<"NameLoc"> {
def NameLoc : Builtin_LocationAttr<"NameLoc", [
DeclareAttrInterfaceMethods<SubElementAttrInterface>
]> {
let summary = "A named source location";
let description = [{
Syntax:
Expand Down Expand Up @@ -180,7 +187,9 @@ def NameLoc : Builtin_LocationAttr<"NameLoc"> {
// OpaqueLoc
//===----------------------------------------------------------------------===//

def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc"> {
def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc", [
DeclareAttrInterfaceMethods<SubElementAttrInterface>
]> {
let summary = "An opaque source location";
let description = [{
An instance of this location essentially contains a pointer to some data
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/Location.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#define MLIR_IR_LOCATION_H

#include "mlir/IR/Attributes.h"
#include "mlir/IR/SubElementInterfaces.h"
#include "llvm/Support/PointerLikeTypeTraits.h"

namespace mlir {
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/LLVMIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_subdirectory(Transforms)

add_mlir_dialect_library(MLIRLLVMDialect
IR/FunctionCallUtils.cpp
IR/LLVMAttrs.cpp
IR/LLVMDialect.cpp
IR/LLVMIntrinsicOps.cpp
IR/LLVMTypes.cpp
Expand Down
206 changes: 206 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
//===- LLVMAttrs.cpp - LLVM Attributes registration -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the attribute details for the LLVM IR dialect in MLIR.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::LLVM;

#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"

//===----------------------------------------------------------------------===//
// LLVMDialect registration
//===----------------------------------------------------------------------===//

void LLVMDialect::registerAttributes() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"
>();
}

//===----------------------------------------------------------------------===//
// LoopOptionsAttrBuilder
//===----------------------------------------------------------------------===//

LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
: options(attr.getOptions().begin(), attr.getOptions().end()) {}

template <typename T>
LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag,
Optional<T> value) {
auto option = llvm::find_if(
options, [tag](auto option) { return option.first == tag; });
if (option != options.end()) {
if (value)
option->second = *value;
else
options.erase(option);
} else {
options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value));
}
return *this;
}

LoopOptionsAttrBuilder &
LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) {
return setOption(LoopOptionCase::disable_licm, value);
}

/// Set the `interleave_count` option to the provided value. If no value
/// is provided the option is deleted.
LoopOptionsAttrBuilder &
LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) {
return setOption(LoopOptionCase::interleave_count, count);
}

/// Set the `disable_unroll` option to the provided value. If no value
/// is provided the option is deleted.
LoopOptionsAttrBuilder &
LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) {
return setOption(LoopOptionCase::disable_unroll, value);
}

/// Set the `disable_pipeline` option to the provided value. If no value
/// is provided the option is deleted.
LoopOptionsAttrBuilder &
LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) {
return setOption(LoopOptionCase::disable_pipeline, value);
}

/// Set the `pipeline_initiation_interval` option to the provided value.
/// If no value is provided the option is deleted.
LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval(
Optional<uint64_t> count) {
return setOption(LoopOptionCase::pipeline_initiation_interval, count);
}

//===----------------------------------------------------------------------===//
// LoopOptionsAttr
//===----------------------------------------------------------------------===//

template <typename T>
static Optional<T>
getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options,
LoopOptionCase option) {
auto it =
lower_bound(options, option, [](auto optionPair, LoopOptionCase option) {
return optionPair.first < option;
});
if (it == options.end())
return {};
return static_cast<T>(it->second);
}

Optional<bool> LoopOptionsAttr::disableUnroll() {
return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll);
}

Optional<bool> LoopOptionsAttr::disableLICM() {
return getOption<bool>(getOptions(), LoopOptionCase::disable_licm);
}

Optional<int64_t> LoopOptionsAttr::interleaveCount() {
return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count);
}

/// Build the LoopOptions Attribute from a sorted array of individual options.
LoopOptionsAttr LoopOptionsAttr::get(
MLIRContext *context,
ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) {
assert(llvm::is_sorted(sortedOptions, llvm::less_first()) &&
"LoopOptionsAttr ctor expects a sorted options array");
return Base::get(context, sortedOptions);
}

/// Build the LoopOptions Attribute from a sorted array of individual options.
LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context,
LoopOptionsAttrBuilder &optionBuilders) {
llvm::sort(optionBuilders.options, llvm::less_first());
return Base::get(context, optionBuilders.options);
}

void LoopOptionsAttr::print(AsmPrinter &printer) const {
printer << "<";
llvm::interleaveComma(getOptions(), printer, [&](auto option) {
printer << stringifyEnum(option.first) << " = ";
switch (option.first) {
case LoopOptionCase::disable_licm:
case LoopOptionCase::disable_unroll:
case LoopOptionCase::disable_pipeline:
printer << (option.second ? "true" : "false");
break;
case LoopOptionCase::interleave_count:
case LoopOptionCase::pipeline_initiation_interval:
printer << option.second;
break;
}
});
printer << ">";
}

Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess()))
return {};

SmallVector<std::pair<LoopOptionCase, int64_t>> options;
llvm::SmallDenseSet<LoopOptionCase> seenOptions;
auto parseLoopOptions = [&]() -> ParseResult {
StringRef optionName;
if (parser.parseKeyword(&optionName))
return failure();

auto option = symbolizeLoopOptionCase(optionName);
if (!option)
return parser.emitError(parser.getNameLoc(), "unknown loop option: ")
<< optionName;
if (!seenOptions.insert(*option).second)
return parser.emitError(parser.getNameLoc(), "loop option present twice");
if (failed(parser.parseEqual()))
return failure();

int64_t value;
switch (*option) {
case LoopOptionCase::disable_licm:
case LoopOptionCase::disable_unroll:
case LoopOptionCase::disable_pipeline:
if (succeeded(parser.parseOptionalKeyword("true")))
value = 1;
else if (succeeded(parser.parseOptionalKeyword("false")))
value = 0;
else {
return parser.emitError(parser.getNameLoc(),
"expected boolean value 'true' or 'false'");
}
break;
case LoopOptionCase::interleave_count:
case LoopOptionCase::pipeline_initiation_interval:
if (failed(parser.parseInteger(value)))
return parser.emitError(parser.getNameLoc(), "expected integer value");
break;
}
options.push_back(std::make_pair(*option, value));
return success();
};
if (parser.parseCommaSeparatedList(parseLoopOptions) || parser.parseGreater())
return {};

llvm::sort(options, llvm::less_first());
return get(parser.getContext(), options);
}
215 changes: 1 addition & 214 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@ static constexpr const char kVolatileAttrName[] = "volatile_";
static constexpr const char kNonTemporalAttrName[] = "nontemporal";
static constexpr const char kElemTypeAttrName[] = "elem_type";

#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
#include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.cpp.inc"

static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
SmallVector<NamedAttribute, 8> filteredAttrs(
Expand Down Expand Up @@ -2564,7 +2561,7 @@ OpFoldResult LLVM::GEPOp::fold(ArrayRef<Attribute> operands) {
//===----------------------------------------------------------------------===//

void LLVMDialect::initialize() {
addAttributes<FastmathFlagsAttr, LinkageAttr, CConvAttr, LoopOptionsAttr>();
registerAttributes();

// clang-format off
addTypes<LLVMVoidType,
Expand Down Expand Up @@ -2796,213 +2793,3 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
return op->hasTrait<OpTrait::SymbolTable>() &&
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}

void LinkageAttr::print(AsmPrinter &printer) const {
printer << "<";
if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage())
printer << stringifyEnum(getLinkage());
else
printer << static_cast<uint64_t>(getLinkage());
printer << ">";
}

Attribute LinkageAttr::parse(AsmParser &parser, Type type) {
StringRef elemName;
if (parser.parseLess() || parser.parseKeyword(&elemName) ||
parser.parseGreater())
return {};
auto elem = linkage::symbolizeLinkage(elemName);
if (!elem) {
parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName;
return {};
}
Linkage linkage = *elem;
return LinkageAttr::get(parser.getContext(), linkage);
}

void CConvAttr::print(AsmPrinter &printer) const {
printer << "<";
if (static_cast<uint64_t>(getCallingConv()) <= cconv::getMaxEnumValForCConv())
printer << stringifyEnum(getCallingConv());
else
printer << "INVALID_cc_" << static_cast<uint64_t>(getCallingConv());
printer << ">";
}

Attribute CConvAttr::parse(AsmParser &parser, Type type) {
StringRef convName;

if (parser.parseLess() || parser.parseKeyword(&convName) ||
parser.parseGreater())
return {};
auto cconv = cconv::symbolizeCConv(convName);
if (!cconv) {
parser.emitError(parser.getNameLoc(), "unknown calling convention: ")
<< convName;
return {};
}
CConv cconvVal = *cconv;
return CConvAttr::get(parser.getContext(), cconvVal);
}

LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
: options(attr.getOptions().begin(), attr.getOptions().end()) {}

template <typename T>
LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setOption(LoopOptionCase tag,
Optional<T> value) {
auto option = llvm::find_if(
options, [tag](auto option) { return option.first == tag; });
if (option != options.end()) {
if (value)
option->second = *value;
else
options.erase(option);
} else {
options.push_back(LoopOptionsAttr::OptionValuePair(tag, *value));
}
return *this;
}

LoopOptionsAttrBuilder &
LoopOptionsAttrBuilder::setDisableLICM(Optional<bool> value) {
return setOption(LoopOptionCase::disable_licm, value);
}

/// Set the `interleave_count` option to the provided value. If no value
/// is provided the option is deleted.
LoopOptionsAttrBuilder &
LoopOptionsAttrBuilder::setInterleaveCount(Optional<uint64_t> count) {
return setOption(LoopOptionCase::interleave_count, count);
}

/// Set the `disable_unroll` option to the provided value. If no value
/// is provided the option is deleted.
LoopOptionsAttrBuilder &
LoopOptionsAttrBuilder::setDisableUnroll(Optional<bool> value) {
return setOption(LoopOptionCase::disable_unroll, value);
}

/// Set the `disable_pipeline` option to the provided value. If no value
/// is provided the option is deleted.
LoopOptionsAttrBuilder &
LoopOptionsAttrBuilder::setDisablePipeline(Optional<bool> value) {
return setOption(LoopOptionCase::disable_pipeline, value);
}

/// Set the `pipeline_initiation_interval` option to the provided value.
/// If no value is provided the option is deleted.
LoopOptionsAttrBuilder &LoopOptionsAttrBuilder::setPipelineInitiationInterval(
Optional<uint64_t> count) {
return setOption(LoopOptionCase::pipeline_initiation_interval, count);
}

template <typename T>
static Optional<T>
getOption(ArrayRef<std::pair<LoopOptionCase, int64_t>> options,
LoopOptionCase option) {
auto it =
lower_bound(options, option, [](auto optionPair, LoopOptionCase option) {
return optionPair.first < option;
});
if (it == options.end())
return {};
return static_cast<T>(it->second);
}

Optional<bool> LoopOptionsAttr::disableUnroll() {
return getOption<bool>(getOptions(), LoopOptionCase::disable_unroll);
}

Optional<bool> LoopOptionsAttr::disableLICM() {
return getOption<bool>(getOptions(), LoopOptionCase::disable_licm);
}

Optional<int64_t> LoopOptionsAttr::interleaveCount() {
return getOption<int64_t>(getOptions(), LoopOptionCase::interleave_count);
}

/// Build the LoopOptions Attribute from a sorted array of individual options.
LoopOptionsAttr LoopOptionsAttr::get(
MLIRContext *context,
ArrayRef<std::pair<LoopOptionCase, int64_t>> sortedOptions) {
assert(llvm::is_sorted(sortedOptions, llvm::less_first()) &&
"LoopOptionsAttr ctor expects a sorted options array");
return Base::get(context, sortedOptions);
}

/// Build the LoopOptions Attribute from a sorted array of individual options.
LoopOptionsAttr LoopOptionsAttr::get(MLIRContext *context,
LoopOptionsAttrBuilder &optionBuilders) {
llvm::sort(optionBuilders.options, llvm::less_first());
return Base::get(context, optionBuilders.options);
}

void LoopOptionsAttr::print(AsmPrinter &printer) const {
printer << "<";
llvm::interleaveComma(getOptions(), printer, [&](auto option) {
printer << stringifyEnum(option.first) << " = ";
switch (option.first) {
case LoopOptionCase::disable_licm:
case LoopOptionCase::disable_unroll:
case LoopOptionCase::disable_pipeline:
printer << (option.second ? "true" : "false");
break;
case LoopOptionCase::interleave_count:
case LoopOptionCase::pipeline_initiation_interval:
printer << option.second;
break;
}
});
printer << ">";
}

Attribute LoopOptionsAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess()))
return {};

SmallVector<std::pair<LoopOptionCase, int64_t>> options;
llvm::SmallDenseSet<LoopOptionCase> seenOptions;
auto parseLoopOptions = [&]() -> ParseResult {
StringRef optionName;
if (parser.parseKeyword(&optionName))
return failure();

auto option = symbolizeLoopOptionCase(optionName);
if (!option)
return parser.emitError(parser.getNameLoc(), "unknown loop option: ")
<< optionName;
if (!seenOptions.insert(*option).second)
return parser.emitError(parser.getNameLoc(), "loop option present twice");
if (failed(parser.parseEqual()))
return failure();

int64_t value;
switch (*option) {
case LoopOptionCase::disable_licm:
case LoopOptionCase::disable_unroll:
case LoopOptionCase::disable_pipeline:
if (succeeded(parser.parseOptionalKeyword("true")))
value = 1;
else if (succeeded(parser.parseOptionalKeyword("false")))
value = 0;
else {
return parser.emitError(parser.getNameLoc(),
"expected boolean value 'true' or 'false'");
}
break;
case LoopOptionCase::interleave_count:
case LoopOptionCase::pipeline_initiation_interval:
if (failed(parser.parseInteger(value)))
return parser.emitError(parser.getNameLoc(), "expected integer value");
break;
}
options.push_back(std::make_pair(*option, value));
return success();
};
if (parser.parseCommaSeparatedList(parseLoopOptions) || parser.parseGreater())
return {};

llvm::sort(options, llvm::less_first());
return get(parser.getContext(), options);
}
66 changes: 66 additions & 0 deletions mlir/lib/IR/Location.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,20 @@ CallSiteLoc CallSiteLoc::get(Location name, ArrayRef<Location> frames) {
return CallSiteLoc::get(name, caller);
}

void CallSiteLoc::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
walkAttrsFn(getCallee());
walkAttrsFn(getCaller());
}

Attribute
CallSiteLoc::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
ArrayRef<Type> replTypes) const {
return get(replAttrs[0].cast<LocationAttr>(),
replAttrs[1].cast<LocationAttr>());
}

//===----------------------------------------------------------------------===//
// FusedLoc
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -121,3 +135,55 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,

return Base::get(context, locs, metadata);
}

void FusedLoc::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
for (Attribute attr : getLocations())
walkAttrsFn(attr);
walkAttrsFn(getMetadata());
}

Attribute
FusedLoc::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
ArrayRef<Type> replTypes) const {
SmallVector<Location> newLocs;
newLocs.reserve(replAttrs.size() - 1);
for (Attribute attr : replAttrs.drop_back())
newLocs.push_back(attr.cast<LocationAttr>());
return get(getContext(), newLocs, replAttrs.back());
}

//===----------------------------------------------------------------------===//
// NameLoc
//===----------------------------------------------------------------------===//

void NameLoc::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
walkAttrsFn(getName());
walkAttrsFn(getChildLoc());
}

Attribute NameLoc::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
ArrayRef<Type> replTypes) const {
return get(replAttrs[0].cast<StringAttr>(),
replAttrs[1].cast<LocationAttr>());
}

//===----------------------------------------------------------------------===//
// OpaqueLoc
//===----------------------------------------------------------------------===//

void OpaqueLoc::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
walkAttrsFn(getFallbackLocation());
}

Attribute
OpaqueLoc::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
ArrayRef<Type> replTypes) const {
return get(getUnderlyingLocation(), getUnderlyingTypeID(),
replAttrs[0].cast<LocationAttr>());
}
3 changes: 2 additions & 1 deletion mlir/test/Dialect/LLVMIR/func.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,9 @@ module {
// -----

module {
// expected-error@+2 {{unknown calling convention: cc_12}}
"llvm.func"() ({
// expected-error @below {{invalid Calling Conventions specification: cc_12}}
// expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}}
}) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
}

Expand Down
49 changes: 49 additions & 0 deletions mlir/test/mlir-tblgen/enums-gen.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
// DECL: std::string stringifyMyBitEnum(MyBitEnum);
// DECL: ::llvm::Optional<MyBitEnum> symbolizeMyBitEnum(::llvm::StringRef);

// DECL: struct FieldParser<::MyBitEnum, ::MyBitEnum> {
// DECL: template <typename ParserT>
// DECL: static FailureOr<::MyBitEnum> parse(ParserT &parser) {
// DECL: // Parse the keyword/string containing the enum.
// DECL: std::string enumKeyword;
// DECL: auto loc = parser.getCurrentLocation();
// DECL: if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
// DECL: return parser.emitError(loc, "expected keyword for An example bit enum");
// DECL: // Symbolize the keyword.
// DECL: if (::llvm::Optional<::MyBitEnum> attr = ::symbolizeEnum<::MyBitEnum>(enumKeyword))
// DECL: return *attr;
// DECL: return parser.emitError(loc, "invalid An example bit enum specification: ") << enumKeyword;
// DECL: }

// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyBitEnum value) {
// DECL: auto valueStr = stringifyEnum(value);
// DECL: return p << valueStr;

// DEF-LABEL: std::string stringifyMyBitEnum
// DEF: auto val = static_cast<uint32_t>
// DEF: if (val == 0) return "None";
Expand All @@ -40,3 +58,34 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
// DEF: if (str == "None") return MyBitEnum::None;
// DEF: .Case("tagged", 1)
// DEF: .Case("Bit1", 2)

// Test enum printer generation for non non-keyword enums.

def NonKeywordBit: I32BitEnumAttrCaseBit<"Bit0", 0, "tag-ged">;
def MyMixedNonKeywordBitEnum: I32BitEnumAttr<"MyMixedNonKeywordBitEnum", "An example bit enum", [
NonKeywordBit,
Bit1
]> {
let genSpecializedAttr = 0;
}

def MyNonKeywordBitEnum: I32BitEnumAttr<"MyNonKeywordBitEnum", "An example bit enum", [
NonKeywordBit
]> {
let genSpecializedAttr = 0;
}

// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyMixedNonKeywordBitEnum value) {
// DECL: auto valueStr = stringifyEnum(value);
// DECL: switch (value) {
// DECL: case ::MyMixedNonKeywordBitEnum::Bit1:
// DECL: break;
// DECL: default:
// DECL: return p << '"' << valueStr << '"';
// DECL: }
// DECL: return p << valueStr;
// DECL: }

// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyNonKeywordBitEnum value) {
// DECL: auto valueStr = stringifyEnum(value);
// DECL: return p << '"' << valueStr << '"';
97 changes: 93 additions & 4 deletions mlir/tools/mlir-tblgen/EnumsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
//
//===----------------------------------------------------------------------===//

#include "FormatGen.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down Expand Up @@ -65,10 +67,92 @@ static void emitEnumClass(const Record &enumDef, StringRef enumName,
os << "};\n\n";
}

static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName,
StringRef cppNamespace, raw_ostream &os) {
if (enumAttr.getUnderlyingType().empty() ||
enumAttr.getConstBuilderTemplate().empty())
return;
auto cases = enumAttr.getAllCases();

// Check which cases shouldn't be printed using a keyword.
llvm::BitVector nonKeywordCases(cases.size());
for (auto [index, caseVal] : llvm::enumerate(cases))
if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr()))
nonKeywordCases.set(index);

// If this is a bit enum attribute, don't allow cases that may overlap with
// other cases. For simplicity sake, only allow cases with a single bit value.
if (enumAttr.isBitEnum()) {
for (auto [index, caseVal] : llvm::enumerate(cases)) {
int64_t value = caseVal.getValue();
if (value < 0 || (value != 0 && !llvm::isPowerOf2_64(value)))
nonKeywordCases.set(index);
}
}

// Generate the parser and the start of the printer for the enum.
const char *parsedAndPrinterStart = R"(
namespace mlir {
template <typename T, typename>
struct FieldParser;
template<>
struct FieldParser<{0}, {0}> {{
template <typename ParserT>
static FailureOr<{0}> parse(ParserT &parser) {{
// Parse the keyword/string containing the enum.
std::string enumKeyword;
auto loc = parser.getCurrentLocation();
if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
return parser.emitError(loc, "expected keyword for {2}");
// Symbolize the keyword.
if (::llvm::Optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
return *attr;
return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
}
};
} // namespace mlir
namespace llvm {
inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
auto valueStr = stringifyEnum(value);
)";
os << formatv(parsedAndPrinterStart, qualName, cppNamespace,
enumAttr.getSummary());

// If all cases require a string, always wrap.
if (nonKeywordCases.all()) {
os << " return p << '\"' << valueStr << '\"';\n"
"}\n"
"} // namespace llvm\n";
return;
}

// If there are any cases that can't be used with a keyword, switch on the
// case value to determine when to print in the string form.
if (nonKeywordCases.any()) {
os << " switch (value) {\n";
for (auto &it : llvm::enumerate(cases)) {
if (nonKeywordCases.test(it.index()))
continue;
StringRef symbol = it.value().getSymbol();
os << llvm::formatv(" case {0}::{1}:\n", qualName,
llvm::isDigit(symbol.front()) ? ("_" + symbol)
: symbol);
}
os << " break;\n"
" default:\n"
" return p << '\"' << valueStr << '\"';\n"
" }\n";
}
os << " return p << valueStr;\n"
"}\n"
"} // namespace llvm\n";
}

static void emitDenseMapInfo(StringRef qualName, std::string underlyingType,
StringRef cppNamespace, raw_ostream &os) {
std::string qualName =
std::string(formatv("{0}::{1}", cppNamespace, enumName));
if (underlyingType.empty())
underlyingType =
std::string(formatv("std::underlying_type_t<{0}>", qualName));
Expand Down Expand Up @@ -529,8 +613,13 @@ class {1} : public ::mlir::{2} {
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";

// Generate a generic parser and printer for the enum.
std::string qualName =
std::string(formatv("{0}::{1}", cppNamespace, enumName));
emitParserPrinter(enumAttr, qualName, cppNamespace, os);

// Emit DenseMapInfo for this enum class
emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
emitDenseMapInfo(qualName, underlyingType, cppNamespace, os);
}

static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
Expand Down
5 changes: 5 additions & 0 deletions mlir/tools/mlir-tblgen/FormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,11 @@ bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,

bool mlir::tblgen::canFormatStringAsKeyword(
StringRef value, function_ref<void(Twine)> emitError) {
if (value.empty()) {
if (emitError)
emitError("keywords cannot be empty");
return false;
}
if (!isalpha(value.front()) && value.front() != '_') {
if (emitError)
emitError("valid keyword starts with a letter or '_'");
Expand Down