From 9546d8b108dce03e03e0448cebbca5fa0fe4be21 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Tue, 26 May 2020 16:44:20 -0700 Subject: [PATCH] [mlir][core] Add IndexElementsAttr helpers. Summary: In a follow-up, I'll update the Shape dialect to use this instead of I64ElementsAttr. Differential Revision: https://reviews.llvm.org/D80601 --- mlir/include/mlir/IR/Builders.h | 1 + mlir/include/mlir/IR/OpBase.td | 7 +++++++ mlir/lib/IR/Attributes.cpp | 2 ++ mlir/lib/IR/Builders.cpp | 7 +++++++ mlir/test/lib/Dialect/Test/TestOps.td | 4 ++++ mlir/test/mlir-tblgen/types.mlir | 15 +++++++++++++++ 6 files changed, 36 insertions(+) diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 4ade6bb1e4390..424eb980cd33a 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -128,6 +128,7 @@ class Builder { /// as attributes. DenseIntElementsAttr getI32TensorAttr(ArrayRef values); DenseIntElementsAttr getI64TensorAttr(ArrayRef values); + DenseIntElementsAttr getIndexTensorAttr(ArrayRef values); ArrayAttr getAffineMapArrayAttr(ArrayRef values); ArrayAttr getBoolArrayAttr(ArrayRef values); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 6a7542c7127c0..5ffb1727ee353 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1218,6 +1218,13 @@ class IntElementsAttrBase : let convertFromStorage = "$_self"; } +def IndexElementsAttr + : IntElementsAttrBase() + .getType() + .getElementType() + .isIndex()}]>, + "index elements attribute">; + class AnyIntElementsAttr : IntElementsAttrBase< CPred<"$_self.cast().getType()." "getElementType().isInteger(" # width # ")">, diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 540c3c6258e29..12fd08787fa75 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -624,6 +624,8 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { owner.getContext()); return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); } + if (eltTy.isa()) + return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); if (auto floatEltTy = eltTy.dyn_cast()) { IntElementIterator intIt(owner, index); FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index a72e03c739e3b..064889724f092 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -130,6 +130,13 @@ DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef values) { values); } +DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef values) { + return DenseIntElementsAttr::get( + RankedTensorType::get(static_cast(values.size()), + getIndexType()), + values); +} + IntegerAttr Builder::getI32IntegerAttr(int32_t value) { return IntegerAttr::get(getIntegerType(32), APInt(32, value)); } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 997d8eb44ae59..8e5b380dff452 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -454,6 +454,10 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> { let arguments = (ins I32ElementsAttr:$attr); } +def IndexElementsAttrOp : TEST_Op<"indexElementsAttr"> { + let arguments = (ins IndexElementsAttr:$attr); +} + def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [ DeclareOpInterfaceMethods]> { let arguments = (ins AnyTensor, AnyTensor); diff --git a/mlir/test/mlir-tblgen/types.mlir b/mlir/test/mlir-tblgen/types.mlir index 6a0a80ca5e5fc..5e4dac33012b9 100644 --- a/mlir/test/mlir-tblgen/types.mlir +++ b/mlir/test/mlir-tblgen/types.mlir @@ -489,3 +489,18 @@ func @elements_attr_i32(%arg0: tensor<1x2xi32>) { "test.i32ElementsAttr"() {attr = dense<[1, 2]>:tensor<2xi32>} : () -> () return } + +// ----- + +func @elements_attr_index() { + "test.indexElementsAttr"() {attr = dense<[1, 2]>:tensor<2xindex>} : () -> () + return +} + +// ----- + +func @elements_attr_not_index() { + // expected-error@+1 {{index elements attribute}} + "test.indexElementsAttr"() {attr = dense<[1, 2]>:tensor<2xi32>} : () -> () + return +}