Skip to content

Commit

Permalink
[mlir][core] Add IndexElementsAttr helpers.
Browse files Browse the repository at this point in the history
Summary:
In a follow-up, I'll update the Shape dialect to use this instead of
I64ElementsAttr.

Differential Revision: https://reviews.llvm.org/D80601
  • Loading branch information
silvasean committed May 27, 2020
1 parent 98ef93e commit 9546d8b
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 0 deletions.
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/Builders.h
Expand Up @@ -128,6 +128,7 @@ class Builder {
/// as attributes.
DenseIntElementsAttr getI32TensorAttr(ArrayRef<int32_t> values);
DenseIntElementsAttr getI64TensorAttr(ArrayRef<int64_t> values);
DenseIntElementsAttr getIndexTensorAttr(ArrayRef<int64_t> values);

ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
ArrayAttr getBoolArrayAttr(ArrayRef<bool> values);
Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/IR/OpBase.td
Expand Up @@ -1218,6 +1218,13 @@ class IntElementsAttrBase<Pred condition, string description> :
let convertFromStorage = "$_self";
}

def IndexElementsAttr
: IntElementsAttrBase<CPred<[{$_self.cast<DenseIntElementsAttr>()
.getType()
.getElementType()
.isIndex()}]>,
"index elements attribute">;

class AnyIntElementsAttr<int width> : IntElementsAttrBase<
CPred<"$_self.cast<DenseIntElementsAttr>().getType()."
"getElementType().isInteger(" # width # ")">,
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/IR/Attributes.cpp
Expand Up @@ -624,6 +624,8 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
owner.getContext());
return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
}
if (eltTy.isa<IndexType>())
return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
IntElementIterator intIt(owner, index);
FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
Expand Down
7 changes: 7 additions & 0 deletions mlir/lib/IR/Builders.cpp
Expand Up @@ -130,6 +130,13 @@ DenseIntElementsAttr Builder::getI64TensorAttr(ArrayRef<int64_t> values) {
values);
}

DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
return DenseIntElementsAttr::get(
RankedTensorType::get(static_cast<int64_t>(values.size()),
getIndexType()),
values);
}

IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
return IntegerAttr::get(getIntegerType(32), APInt(32, value));
}
Expand Down
4 changes: 4 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Expand Up @@ -454,6 +454,10 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> {
let arguments = (ins I32ElementsAttr:$attr);
}

def IndexElementsAttrOp : TEST_Op<"indexElementsAttr"> {
let arguments = (ins IndexElementsAttr:$attr);
}

def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let arguments = (ins AnyTensor, AnyTensor);
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/mlir-tblgen/types.mlir
Expand Up @@ -489,3 +489,18 @@ func @elements_attr_i32(%arg0: tensor<1x2xi32>) {
"test.i32ElementsAttr"() {attr = dense<[1, 2]>:tensor<2xi32>} : () -> ()
return
}

// -----

func @elements_attr_index() {
"test.indexElementsAttr"() {attr = dense<[1, 2]>:tensor<2xindex>} : () -> ()
return
}

// -----

func @elements_attr_not_index() {
// expected-error@+1 {{index elements attribute}}
"test.indexElementsAttr"() {attr = dense<[1, 2]>:tensor<2xi32>} : () -> ()
return
}

0 comments on commit 9546d8b

Please sign in to comment.