diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
index b892bbe427a18..37604fc17dd9b 100644
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -311,16 +311,17 @@ An example for a rank polymorphic operation is `fill`:
 
 ```python
 @linalg_structured_op
-def fill(value=ScalarDef(T1),
-         O=TensorDef(U, output=True)):
-  O[None] = TypeFn.cast_signed(U, value)
+def fill(value=ScalarDef(T),
+         O=TensorDef(T, output=True)):
+  O[None] = value
 ```
 
-The operation sets the elements of the output tensor `O` to `value`. All
-operands are either scalars or rank zero tensors that are accessed using the
-index `None`. The operation thus performs a scalar computation that trivially
-extends to a multi-dimensional pointwise computation. As a result, we may use
-`fill` with arbitrary ranked output tensors:
+The operation sets the elements of the output tensor `O` to `value`. The value
+type must match the element type of the output tensor. All operands are either
+scalars or rank zero tensors that are accessed using the index `None`. The
+operation thus performs a scalar computation that trivially extends to a
+multi-dimensional pointwise computation. As a result, we may use `fill` with
+arbitrary ranked output tensors:
 
 ```python
 tensor_2d = tensor.EmptyOp([4, 8], f32)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 9aae1b850c3a0..521afc991063f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -6054,9 +6054,9 @@ metadata: !LinalgOpMetadata
   doc: |-
     Fills the output tensor with the given value.
 
-    Works for arbitrary ranked output tensors since the operation performs scalar
-    accesses only and is thus rank polymorphic. Numeric casting is performed on
-    the value operand, promoting it to the same data type as the output.
+    Works for arbitrary ranked output tensors since the operation performs
+    scalar accesses only and is thus rank polymorphic. The value operand
+    type must match the element type of the output.
   implements:
   - LinalgFillOpInterface
   defines:
@@ -6066,11 +6066,11 @@ structured_op: !LinalgStructuredOpConfig
   - !LinalgOperandDefConfig
     name: value
     kind: scalar
-    type_var: T1
+    type_var: T
   - !LinalgOperandDefConfig
     name: O
     kind: output_tensor
-    type_var: U
+    type_var: T
     shape_map: affine_map<() -> ()>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
@@ -6081,13 +6081,7 @@ structured_op: !LinalgStructuredOpConfig
   - !ScalarAssign
     arg: O
     value: !ScalarExpression
-      scalar_fn:
-        kind: type
-        fn_name: cast_signed
-        type_var: U
-        operands:
-        - !ScalarExpression
-          scalar_arg: value
+      scalar_arg: value
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_rng_2d
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index dcc1ef9e997ea..b4b1347493529 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -1057,12 +1057,15 @@ LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
 // FillOpInterface implementation
 //===----------------------------------------------------------------------===//
 
+namespace {
 enum class MatchFillResult {
   Success = 0,
   NotLinalgOp,
   WrongNumOperands,
-  NotScalarInput
+  NotScalarInput,
+  TypeMismatch
 };
+} // namespace
 
 static MatchFillResult isFillInterfaceImpl(Operation *op) {
   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
@@ -1075,17 +1078,33 @@ static MatchFillResult isFillInterfaceImpl(Operation *op) {
   if (!linalgOp.isScalar(value))
     return MatchFillResult::NotScalarInput;
 
+  // Check that the scalar input type matches the output element type.
+  OpOperand *output = linalgOp.getDpsInitOperand(0);
+  Type scalarType = value->get().getType();
+  Type outputElementType = getElementTypeOrSelf(output->get().getType());
+  if (scalarType != outputElementType)
+    return MatchFillResult::TypeMismatch;
+
   return MatchFillResult::Success;
 }
 
 LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
-  auto res = isFillInterfaceImpl(op);
+  MatchFillResult res = isFillInterfaceImpl(op);
   if (res == MatchFillResult::NotLinalgOp)
     return op->emitError("expected a LinalgOp");
   if (res == MatchFillResult::WrongNumOperands)
     return op->emitError("expected op with 1 input and 1 output");
   if (res == MatchFillResult::NotScalarInput)
     return op->emitError("expected op with scalar input");
+  if (res == MatchFillResult::TypeMismatch) {
+    auto linalgOp = cast<linalg::LinalgOp>(op);
+    Type scalarType = linalgOp.getDpsInputOperand(0)->get().getType();
+    Type outputElementType =
+        getElementTypeOrSelf(linalgOp.getDpsInitOperand(0)->get().getType());
+    return op->emitOpError("expected fill value type (")
+           << scalarType << ") to match output element type ("
+           << outputElementType << ")";
+  }
   return success();
 }
 
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index fd4a5a848f1e3..9c24f94fcf612 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -1729,16 +1729,16 @@ def pooling_ndhwc_min(
 
 
 @linalg_structured_op
-def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)):
+def fill(value=ScalarDef(T), O=TensorDef(T, output=True)):
     """Fills the output tensor with the given value.
 
     Works for arbitrary ranked output tensors since the operation performs scalar
-    accesses only and is thus rank polymorphic. Numeric casting is performed on
-    the value operand, promoting it to the same data type as the output.
+    accesses only and is thus rank polymorphic. The value type must match the
+    element type of the output tensor or memref.
     """
     implements(FillOpInterface)
     defines(Canonicalizer)
-    O[None] = TypeFn.cast_signed(U, value)
+    O[None] = value
 
 
 @linalg_structured_op
diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
index 817614be50533..2e801028057a1 100644
--- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -36,13 +36,13 @@ func.func @reify_through_chain(%sz0: index, %sz2: index) -> (index, index, index
 // CHECK: "test.some_use"(%[[c5]])
 // CHECK: %[[c5:.*]] = arith.constant 5 : index
 // CHECK: "test.some_use"(%[[c5]])
-func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f32) {
+func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: i32) {
   %c0 = arith.constant 0 : index
   %c4 = arith.constant 4 : index
   scf.for %iv = %c0 to %ub step %c4 {
     %sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub]
     %slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor<?x?xi32> to tensor<1x?xi32>
-    %filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
+    %filled = linalg.fill ins(%f : i32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32>
 
     %bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
     "test.some_use"(%bound) : (index) -> ()
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index bc55c12c02f29..6f1a422324e08 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -921,30 +921,6 @@ func.func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
 
 // -----
 
-// CHECK-LABEL: func @fold_fill_generic_different_dtype
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf16>) -> tensor<?xf16> {
-// CHECK-NOT: linalg.fill
-// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf16>)
-// CHECK-SAME: outs({{.*}} : tensor<?xf16>) {
-#map0 = affine_map<(d0) -> (d0)>
-func.func @fold_fill_generic_different_dtype(%arg0: tensor<?xf16>) -> (tensor<?xf16>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 7.0 : f32
-  %0 = tensor.dim %arg0, %c0 : tensor<?xf16>
-  %1 = tensor.empty(%0) : tensor<?xf16>
-  %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<?xf16>) -> tensor<?xf16>
-  %3 = tensor.empty(%0) : tensor<?xf16>
-  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf16>, tensor<?xf16>) outs (%3:tensor<?xf16>) {
-  ^bb0(%arg1: f16, %arg2: f16, %arg3: f16):
-    %5 = arith.addf %arg1, %arg2 : f16
-    linalg.yield %5 : f16
-  } -> tensor<?xf16>
-  return %4 : tensor<?xf16>
-}
-
-// -----
-
 // CHECK-LABEL: func @fold_fill_generic_mixedaccess
 // CHECK-NOT: linalg.fill
 // CHECK: %[[GENERIC_OP:.*]] = linalg.generic
@@ -1079,4 +1055,4 @@ module {
 // CHECK-NOT: linalg.generic
 // CHECK: tensor.expand_shape
 // CHECK: linalg.generic {{.*}}, iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
\ No newline at end of file
+// CHECK-SAME: ins(%[[ARG0]], %[[FUSED]]#1 : tensor<1x1x2x1xf32>, tensor<4x1x1x1xf32>)
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 290c6c7c36f76..4526dc90fad2e 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -380,8 +380,8 @@ func.func @generalize_pooling_nwc_sum_i32(%input : tensor<1x16x1xi32>, %shape: t
 
 // -----
 
-func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
-  %0 = linalg.fill ins(%value: f64) outs(%O : tensor<f32>) -> tensor<f32>
+func.func @generalize_fill_0d(%value: f32, %O: tensor<f32>) -> tensor<f32> {
+  %0 = linalg.fill ins(%value: f32) outs(%O : tensor<f32>) -> tensor<f32>
   return %0: tensor<f32>
 }
 
@@ -394,8 +394,8 @@ func.func @generalize_fill_0d(%value: f64, %O: tensor<f32>) -> tensor<f32> {
 
 // -----
 
-func.func @generalize_fill_2d(%value: f64, %O: memref<16x32xf32>) {
-  linalg.fill ins(%value: f64) outs(%O : memref<16x32xf32>)
+func.func @generalize_fill_2d(%value: f32, %O: memref<16x32xf32>) {
+  linalg.fill ins(%value: f32) outs(%O : memref<16x32xf32>)
   return
 }
 
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index fabc8e610612d..1f554e6c45da7 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -352,6 +352,24 @@ func.func @illegal_fill_tensor_with_memref_return
 
 // -----
 
+func.func @illegal_fill_element_type_truncation(%arg0 : tensor<2xf32>, %arg1 : f64) -> tensor<2xf32>
+{
+  // expected-error @+1 {{'linalg.fill' op expected fill value type ('f64') to match output element type ('f32')}}
+  %0 = linalg.fill ins(%arg1 : f64) outs(%arg0 : tensor<2xf32>) -> tensor<2xf32>
+  return %0 : tensor<2xf32>
+}
+
+// -----
+
+func.func @illegal_fill_element_type_extension(%arg0 : tensor<2xi32>, %arg1 : i16) -> tensor<2xi32>
+{
+  // expected-error @+1 {{'linalg.fill' op expected fill value type ('i16') to match output element type ('i32')}}
+  %0 = linalg.fill ins(%arg1 : i16) outs(%arg0 : tensor<2xi32>) -> tensor<2xi32>
+  return %0 : tensor<2xi32>
+}
+
+// -----
+
 func.func @illegal_fill_value_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xf32>) -> tensor<2x2xf32>
 {
   // expected-error @+1 {{expected op with scalar input}}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
index 8fa32d7aeb586..bbda8d4e99d04 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir
@@ -27,8 +27,8 @@ func.func @main() {
   %A_dyn = tensor.cast %A : tensor<8x2xf32> to tensor<?x?xf32>
   %B_dyn = tensor.cast %B : tensor<2x4xf32> to tensor<?x?xf32>
 
-  %c0_i32 = arith.constant 0 : i32
-  %C_init = linalg.fill ins(%c0_i32 : i32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %c0_f32 = arith.constant 0.0 : f32
+  %C_init = linalg.fill ins(%c0_f32 : f32) outs(%C_dyn : tensor<?x?xf32>) -> tensor<?x?xf32>
 
   %res = linalg.matmul ins(%A_dyn, %B_dyn: tensor<?x?xf32>, tensor<?x?xf32>)
             outs(%C_init: tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
index a374d9a611258..e3fee917cdeaa 100644
--- a/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
+++ b/mlir/test/Integration/Dialect/Transform/match_matmul.mlir
@@ -63,11 +63,11 @@ func.func @matmul_simple(%lhs: tensor<10x20xf16>, %rhs: tensor<20x15xf32>) -> te
 }
 
 func.func @matmul_with_extra_ops_in_func(%lhs: tensor<10x20xf32>, %rhs: tensor<20x15xf32>) -> tensor<10x15xf32> {
-  %cst = arith.constant 0.0 : f64
+  %cst = arith.constant 0.0 : f32
   %empty = tensor.empty() : tensor<10x15xf32>
 
   // expected-remark @below {{fill}}
-  %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<10x15xf32>) -> tensor<10x15xf32>
 
   %real_lhs = linalg.mul ins(%lhs, %lhs : tensor<10x20xf32>, tensor<10x20xf32>)
     outs(%lhs : tensor<10x20xf32>) -> tensor<10x20xf32>
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index 8f202318146ee..8eff573f98ad3 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -25,13 +25,13 @@ def log(*args):
   %O1 = memref.alloc() : memref<16xi32>
   %O2 = memref.alloc() : memref<4x16xi32>
 
-  %val0 = arith.constant 1.0 : f32
-  %val1 = arith.constant 2.0 : f32
-  %val2 = arith.constant 3.0 : f32
+  %val0 = arith.constant 1 : i32
+  %val1 = arith.constant 2 : i32
+  %val2 = arith.constant 3 : i32
 
-  call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
-  call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
-  call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()
+  call @fill_0d_on_buffers(%val0, %O0) : (i32, memref<i32>) -> ()
+  call @fill_1d_on_buffers(%val1, %O1) : (i32, memref<16xi32>) -> ()
+  call @fill_2d_on_buffers(%val2, %O2) : (i32, memref<4x16xi32>) -> ()
 
   %c0 = arith.constant 0 : index
   %res0 = memref.load %O0[] : memref<i32>
@@ -149,19 +149,18 @@ def transform(module, boilerplate):
 def test_fill_builtin():
     with Context() as ctx, Location.unknown():
         module = Module.create()
-        f32 = F32Type.get()
         i32 = IntegerType.get_signless(32)
         with InsertionPoint(module.body):
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
             def fill_0d_on_buffers(value, out):
                 linalg.fill(value, outs=[out])
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
             def fill_1d_on_buffers(value, out):
                 linalg.fill(value, outs=[out])
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
             def fill_2d_on_buffers(value, out):
                 linalg.fill(value, outs=[out])
 
@@ -184,19 +183,18 @@ def fill_2d_on_buffers(value, out):
 def test_fill_generic():
     with Context() as ctx, Location.unknown():
         module = Module.create()
-        f32 = F32Type.get()
         i32 = IntegerType.get_signless(32)
         with InsertionPoint(module.body):
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([], i32))
             def fill_0d_on_buffers(value, out):
                 linalg.fill(value, outs=[out], emit_generic=True)
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([16], i32))
             def fill_1d_on_buffers(value, out):
                 linalg.fill(value, outs=[out], emit_generic=True)
 
-            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
+            @func.FuncOp.from_py_func(i32, MemRefType.get([4, 16], i32))
             def fill_2d_on_buffers(value, out):
                 linalg.fill(value, outs=[out], emit_generic=True)
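A minimal sketch, not part of the patch, of how IR that previously leaned on the implicit cast inside `linalg.fill` can be updated under this change; the `%t` tensor, the constant, and the chosen shapes are hypothetical. The conversion now has to be written out with an explicit `arith` cast so that the fill value type equals the output element type.

```mlir
// Previously an f64 value could be filled into an f32 tensor and linalg.fill
// performed the cast itself. With this change the cast must be explicit.
%cst_f64 = arith.constant 1.0 : f64
%cst_f32 = arith.truncf %cst_f64 : f64 to f32
%t = tensor.empty() : tensor<4x8xf32>
%filled = linalg.fill ins(%cst_f32 : f32) outs(%t : tensor<4x8xf32>) -> tensor<4x8xf32>
```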