Skip to content

Commit

Permalink
[ODS] Add support for FloatElementsAttr
Browse files Browse the repository at this point in the history
This CL adds a new FloatElementsAttr definition to ODS for float
elements attributes of a certain type.

Tests are added to show both verification and how to use it in patterns.

PiperOrigin-RevId: 270455487
  • Loading branch information
antiagainst authored and tensorflower-gardener committed Sep 21, 2019
1 parent 33a3a91 commit 8e49063
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 0 deletions.
26 changes: 26 additions & 0 deletions mlir/include/mlir/IR/OpBase.td
Expand Up @@ -924,6 +924,32 @@ class IntElementsAttr<int width> : ElementsAttrBase<
def I32ElementsAttr : IntElementsAttr<32>;
def I64ElementsAttr : IntElementsAttr<64>;

// A `width`-bit floating point elements attribute. The attribute should be
// ranked and has a shape as specified in `dims`.
class RankedFloatElementsAttr<int width, list<int> dims> : ElementsAttrBase<
CPred<"$_self.isa<DenseFPElementsAttr>() &&"
"$_self.cast<DenseFPElementsAttr>().getType()."
"getElementType().isF" # width # "() && "
// Check that this is ranked and has the specified shape.
"$_self.cast<DenseFPElementsAttr>().getType().hasRank() && "
"$_self.cast<DenseFPElementsAttr>().getType().getShape() == "
"llvm::ArrayRef<int64_t>({" # StrJoinInt<dims>.result # "})">,
width # "-bit float elements attribute of shape [" #
StrJoinInt<dims>.result # "]"> {

let storageType = [{ DenseFPElementsAttr }];
let returnType = [{ DenseFPElementsAttr }];

let constBuilderCall = "DenseElementsAttr::get("
"$_builder.getTensorType({" # StrJoinInt<dims>.result #
"}, $_builder.getF" # width # "Type()), "
"llvm::makeArrayRef($0)).cast<DenseFPElementsAttr>()";
let convertFromStorage = "$_self";
}

class RankedF32ElementsAttr<list<int> dims> : RankedFloatElementsAttr<32, dims>;
class RankedF64ElementsAttr<list<int> dims> : RankedFloatElementsAttr<64, dims>;

// Base class for array attributes.
class ArrayAttrBase<Pred condition, string description> :
Attr<condition, description> {
Expand Down
38 changes: 38 additions & 0 deletions mlir/test/IR/attribute.mlir
Expand Up @@ -189,3 +189,41 @@ func @disallowed_case7_fail() {
%0 = "test.i64_enum_attr"() {attr = 5: i32} : () -> i32
return
}

// -----

//===----------------------------------------------------------------------===//
// Test FloatElementsAttr
//===----------------------------------------------------------------------===//

func @correct_type_pass() {
"test.float_elements_attr"() {
// CHECK: scalar_f32_attr = dense<5.000000e+00> : tensor<2xf32>
// CHECK: tensor_f64_attr = dense<6.000000e+00> : tensor<4x8xf64>
scalar_f32_attr = dense<5.0> : tensor<2xf32>,
tensor_f64_attr = dense<6.0> : tensor<4x8xf64>
} : () -> ()
return
}

// -----

func @wrong_element_type_pass() {
// expected-error @+1 {{failed to satisfy constraint: 32-bit float elements attribute of shape [2]}}
"test.float_elements_attr"() {
scalar_f32_attr = dense<5.0> : tensor<2xf64>,
tensor_f64_attr = dense<6.0> : tensor<4x8xf64>
} : () -> ()
return
}

// -----

func @correct_type_pass() {
// expected-error @+1 {{failed to satisfy constraint: 64-bit float elements attribute of shape [4, 8]}}
"test.float_elements_attr"() {
scalar_f32_attr = dense<5.0> : tensor<2xf32>,
tensor_f64_attr = dense<6.0> : tensor<4xf64>
} : () -> ()
return
}
17 changes: 17 additions & 0 deletions mlir/test/lib/TestDialect/TestOps.td
Expand Up @@ -162,6 +162,23 @@ def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> {
let results = (outs I32:$val);
}

def FloatElementsAttrOp : TEST_Op<"float_elements_attr"> {
let arguments = (ins
RankedF32ElementsAttr<[2]>:$scalar_f32_attr,
RankedF64ElementsAttr<[4, 8]>:$tensor_f64_attr
);
}

// A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>.
// This tests both matching and generating float elements attributes.
def UpdateFloatElementsAttr : Pat<
(FloatElementsAttrOp
ConstantAttr<RankedF32ElementsAttr<[2]>, "{3.0f, 4.0f}">:$f32attr,
$f64attr),
(FloatElementsAttrOp
ConstantAttr<RankedF32ElementsAttr<[2]>, "{5.0f, 6.0f}">:$f32attr,
$f64attr)>;

//===----------------------------------------------------------------------===//
// Test Regions
//===----------------------------------------------------------------------===//
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/mlir-tblgen/pattern.mlir
Expand Up @@ -163,6 +163,24 @@ func @rewrite_i32elementsattr() -> () {
return
}

// CHECK-LABEL: rewrite_f64elementsattr
func @rewrite_f64elementsattr() -> () {
"test.float_elements_attr"() {
// Should match
// CHECK: scalar_f32_attr = dense<[5.000000e+00, 6.000000e+00]> : tensor<2xf32>
scalar_f32_attr = dense<[3.0, 4.0]> : tensor<2xf32>,
tensor_f64_attr = dense<6.0> : tensor<4x8xf64>
} : () -> ()

"test.float_elements_attr"() {
// Should not match
// CHECK: scalar_f32_attr = dense<7.000000e+00> : tensor<2xf32>
scalar_f32_attr = dense<7.0> : tensor<2xf32>,
tensor_f64_attr = dense<3.0> : tensor<4x8xf64>
} : () -> ()
return
}

//===----------------------------------------------------------------------===//
// Test Multi-result Ops
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 8e49063

Please sign in to comment.