Skip to content

Commit

Permalink
De-duplicate EnumAttr overrides by defining defaults
Browse files Browse the repository at this point in the history
EnumAttr should provide meaningful defaults so concrete instances
do not need to duplicate the fields.

PiperOrigin-RevId: 282398431
  • Loading branch information
antiagainst authored and tensorflower-gardener committed Nov 25, 2019
1 parent bd485af commit 9b6e6ce
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 51 deletions.
12 changes: 2 additions & 10 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Expand Up @@ -174,11 +174,7 @@ def ICmpPredicate : I64EnumAttr<
[ICmpPredicateEQ, ICmpPredicateNE, ICmpPredicateSLT, ICmpPredicateSLE,
ICmpPredicateSGT, ICmpPredicateSGE, ICmpPredicateULT, ICmpPredicateULE,
ICmpPredicateUGT, ICmpPredicateUGE]> {
let cppNamespace = "mlir::LLVM";

let returnType = "ICmpPredicate";
let convertFromStorage =
"static_cast<" # returnType # ">($_self.getValue().getZExtValue())";
let cppNamespace = "::mlir::LLVM";
}

// Other integer operations.
Expand Down Expand Up @@ -225,11 +221,7 @@ def FCmpPredicate : I64EnumAttr<
FCmpPredicateUEQ, FCmpPredicateUGT, FCmpPredicateUGE, FCmpPredicateULT,
FCmpPredicateULE, FCmpPredicateUNE, FCmpPredicateUNO, FCmpPredicateTRUE
]> {
let cppNamespace = "mlir::LLVM";

let returnType = "FCmpPredicate";
let convertFromStorage =
"static_cast<" # returnType # ">($_self.getValue().getZExtValue())";
let cppNamespace = "::mlir::LLVM";
}

// Other integer operations.
Expand Down
34 changes: 0 additions & 34 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
Expand Up @@ -326,8 +326,6 @@ def SPV_AddressingModelAttr :
SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64,
SPV_AM_PhysicalStorageBuffer64
]> {
let returnType = "::mlir::spirv::AddressingModel";
let convertFromStorage = "static_cast<::mlir::spirv::AddressingModel>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand Down Expand Up @@ -462,8 +460,6 @@ def SPV_BuiltInAttr :
SPV_BI_HitTNV, SPV_BI_HitKindNV, SPV_BI_IncomingRayFlagsNV,
SPV_BI_WarpsPerSMNV, SPV_BI_SMCountNV, SPV_BI_WarpIDNV, SPV_BI_SMIDNV
]> {
let returnType = "::mlir::spirv::BuiltIn";
let convertFromStorage = "static_cast<::mlir::spirv::BuiltIn>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand Down Expand Up @@ -672,8 +668,6 @@ def SPV_CapabilityAttr :
SPV_C_SubgroupAvcMotionEstimationIntraINTEL,
SPV_C_SubgroupAvcMotionEstimationChromaINTEL
]> {
let returnType = "::mlir::spirv::Capability";
let convertFromStorage = "static_cast<::mlir::spirv::Capability>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand Down Expand Up @@ -763,8 +757,6 @@ def SPV_DecorationAttr :
SPV_D_AliasedPointer, SPV_D_CounterBuffer, SPV_D_UserSemantic,
SPV_D_UserTypeGOOGLE
]> {
let returnType = "::mlir::spirv::Decoration";
let convertFromStorage = "static_cast<::mlir::spirv::Decoration>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand All @@ -781,8 +773,6 @@ def SPV_DimAttr :
SPV_D_1D, SPV_D_2D, SPV_D_3D, SPV_D_Cube, SPV_D_Rect, SPV_D_Buffer,
SPV_D_SubpassData
]> {
let returnType = "::mlir::spirv::Dim";
let convertFromStorage = "static_cast<::mlir::spirv::Dim>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand Down Expand Up @@ -866,8 +856,6 @@ def SPV_ExecutionModeAttr :
SPV_EM_SampleInterlockOrderedEXT, SPV_EM_SampleInterlockUnorderedEXT,
SPV_EM_ShadingRateInterlockOrderedEXT, SPV_EM_ShadingRateInterlockUnorderedEXT
]> {
let returnType = "::mlir::spirv::ExecutionMode";
let convertFromStorage = "static_cast<::mlir::spirv::ExecutionMode>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand All @@ -894,8 +882,6 @@ def SPV_ExecutionModelAttr :
SPV_EM_TaskNV, SPV_EM_MeshNV, SPV_EM_RayGenerationNV, SPV_EM_IntersectionNV,
SPV_EM_AnyHitNV, SPV_EM_ClosestHitNV, SPV_EM_MissNV, SPV_EM_CallableNV
]> {
let returnType = "::mlir::spirv::ExecutionModel";
let convertFromStorage = "static_cast<::mlir::spirv::ExecutionModel>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand All @@ -909,8 +895,6 @@ def SPV_FunctionControlAttr :
BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [
SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const
]> {
let returnType = "::mlir::spirv::FunctionControl";
let convertFromStorage = "static_cast<::mlir::spirv::FunctionControl>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand Down Expand Up @@ -967,8 +951,6 @@ def SPV_ImageFormatAttr :
SPV_IF_Rgb10a2ui, SPV_IF_Rg32ui, SPV_IF_Rg16ui, SPV_IF_Rg8ui, SPV_IF_R16ui,
SPV_IF_R8ui
]> {
let returnType = "::mlir::spirv::ImageFormat";
let convertFromStorage = "static_cast<::mlir::spirv::ImageFormat>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand All @@ -979,8 +961,6 @@ def SPV_LinkageTypeAttr :
I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", [
SPV_LT_Export, SPV_LT_Import
]> {
let returnType = "::mlir::spirv::LinkageType";
let convertFromStorage = "static_cast<::mlir::spirv::LinkageType>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand All @@ -1001,8 +981,6 @@ def SPV_LoopControlAttr :
SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations,
SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount
]> {
let returnType = "::mlir::spirv::LoopControl";
let convertFromStorage = "static_cast<::mlir::spirv::LoopControl>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand All @@ -1020,8 +998,6 @@ def SPV_MemoryAccessAttr :
SPV_MA_MakePointerAvailable, SPV_MA_MakePointerVisible,
SPV_MA_NonPrivatePointer
]> {
let returnType = "::mlir::spirv::MemoryAccess";
let convertFromStorage = "static_cast<::mlir::spirv::MemoryAccess>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand All @@ -1034,8 +1010,6 @@ def SPV_MemoryModelAttr :
I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [
SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_Vulkan
]> {
let returnType = "::mlir::spirv::MemoryModel";
let convertFromStorage = "static_cast<::mlir::spirv::MemoryModel>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand Down Expand Up @@ -1063,8 +1037,6 @@ def SPV_MemorySemanticsAttr :
SPV_MS_AtomicCounterMemory, SPV_MS_ImageMemory, SPV_MS_OutputMemory,
SPV_MS_MakeAvailable, SPV_MS_MakeVisible, SPV_MS_Volatile
]> {
let returnType = "::mlir::spirv::MemorySemantics";
let convertFromStorage = "static_cast<::mlir::spirv::MemorySemantics>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand All @@ -1080,8 +1052,6 @@ def SPV_ScopeAttr :
SPV_S_CrossDevice, SPV_S_Device, SPV_S_Workgroup, SPV_S_Subgroup,
SPV_S_Invocation, SPV_S_QueueFamily
]> {
let returnType = "::mlir::spirv::Scope";
let convertFromStorage = "static_cast<::mlir::spirv::Scope>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand All @@ -1093,8 +1063,6 @@ def SPV_SelectionControlAttr :
BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", [
SPV_SC_None, SPV_SC_Flatten, SPV_SC_DontFlatten
]> {
let returnType = "::mlir::spirv::SelectionControl";
let convertFromStorage = "static_cast<::mlir::spirv::SelectionControl>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand Down Expand Up @@ -1128,8 +1096,6 @@ def SPV_StorageClassAttr :
SPV_SC_RayPayloadNV, SPV_SC_HitAttributeNV, SPV_SC_IncomingRayPayloadNV,
SPV_SC_ShaderRecordBufferNV, SPV_SC_PhysicalStorageBuffer
]> {
let returnType = "::mlir::spirv::StorageClass";
let convertFromStorage = "static_cast<::mlir::spirv::StorageClass>($_self.getInt())";
let cppNamespace = "::mlir::spirv";
}

Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/IR/OpBase.td
Expand Up @@ -938,12 +938,18 @@ class IntEnumAttr<I intType, string name, string description,
class I32EnumAttr<string name, string description,
list<I32EnumAttrCase> cases> :
IntEnumAttr<I32, name, description, cases> {
let returnType = cppNamespace # "::" # name;
let underlyingType = "uint32_t";
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";
}
class I64EnumAttr<string name, string description,
list<I64EnumAttrCase> cases> :
IntEnumAttr<I64, name, description, cases> {
let returnType = cppNamespace # "::" # name;
let underlyingType = "uint64_t";
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
let constBuilderCall = "$_builder.getI64IntegerAttr(static_cast<int64_t>($0))";
}

// A bit enum stored with 32-bit IntegerAttr.
Expand All @@ -963,7 +969,10 @@ class BitEnumAttr<string name, string description,
")))">
]>;

let returnType = cppNamespace # "::" # name;
let underlyingType = "uint32_t";
let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast<int32_t>($0))";

// We need to return a string because we may concatenate symbols for multiple
// bits together.
Expand Down
2 changes: 2 additions & 0 deletions mlir/test/lib/TestDialect/CMakeLists.txt
Expand Up @@ -6,6 +6,8 @@ set(LLVM_OPTIONAL_SOURCES
set(LLVM_TARGET_DEFINITIONS TestOps.td)
mlir_tablegen(TestOps.h.inc -gen-op-decls)
mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
mlir_tablegen(TestOpEnums.h.inc -gen-enum-decls)
mlir_tablegen(TestOpEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(TestPatterns.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestOpsIncGen)

Expand Down
3 changes: 3 additions & 0 deletions mlir/test/lib/TestDialect/TestDialect.cpp
Expand Up @@ -22,6 +22,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringSwitch.h"

using namespace mlir;

Expand Down Expand Up @@ -304,5 +305,7 @@ SmallVector<Type, 2> mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
// Static initialization for Test dialect registration.
static mlir::DialectRegistration<mlir::TestDialect> testDialect;

#include "TestOpEnums.cpp.inc"

#define GET_OP_CLASSES
#include "TestOps.cpp.inc"
2 changes: 2 additions & 0 deletions mlir/test/lib/TestDialect/TestDialect.h
Expand Up @@ -32,6 +32,8 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"

#include "TestOpEnums.h.inc"

namespace mlir {

class TestDialect : public Dialect {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/lib/TestDialect/TestOps.td
Expand Up @@ -694,7 +694,7 @@ def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>;
def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>;

def MultiResultOpEnum: I64EnumAttr<
"Multi-result op kinds", "", [
"MultiResultOpEnum", "Multi-result op kinds", [
MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6
]>;
Expand Down
6 changes: 0 additions & 6 deletions mlir/utils/spirv/gen_spirv_dialect.py
Expand Up @@ -200,9 +200,6 @@ def get_case_symbol(kind_name, case_name):
enum_attr = 'def SPV_{name}Attr :\n '\
'{category}EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n'\
' ]> {{\n'\
' let returnType = "::mlir::spirv::{name}";\n'\
' let convertFromStorage = '\
'"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\
' let cppNamespace = "::mlir::spirv";\n}}'.format(
name=kind_name, category=kind_category, cases=case_names)
return kind_name, case_defs + '\n\n' + enum_attr
Expand Down Expand Up @@ -240,9 +237,6 @@ def gen_opcode(instructions):
' I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'\
'{lst}\n'\
' ]> {{\n'\
' let returnType = "::mlir::spirv::{name}";\n'\
' let convertFromStorage = '\
'"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\
' let cppNamespace = "::mlir::spirv";\n}}'.format(
name='Opcode', lst=opcode_list)
return opcode_str + '\n\n' + enum_attr
Expand Down

0 comments on commit 9b6e6ce

Please sign in to comment.