-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][linalg] Reject unsigned pooling on non-integer element types #166070
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
|
@llvm/pr-subscribers-mlir-linalg Author: Men-cotton (Men-cotton) Changes#164800 Ensures unsigned pooling ops in Linalg stay in the integer domain: the lowering now rejects floating/bool inputs with a clear diagnostic, new regression tests lock in both the error path and a valid integer example, and transform decompositions are updated to reflect the integer typing. CC: @banach-space Full diff: https://github.com/llvm/llvm-project/pull/166070.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3dc45edf4a23f..8eb03dc182ae9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -579,13 +579,23 @@ class RegionBuilderHelper {
return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::max_unsigned:
assert(!allComplex);
- if (allFloatingPoint)
- return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (!allInteger || allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: unsigned max not on uint";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported operation: unsigned max not on uint");
+ }
return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::min_unsigned:
assert(!allComplex);
- if (allFloatingPoint)
- return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (!allInteger || allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: unsigned min not on uint";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported operation: unsigned min not on uint");
+ }
return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::powf:
assert(allFloatingPoint);
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index 552a0abaa797c..4ecf685b4c695 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -80,6 +80,20 @@ func.func @divu_broadcast(%arg0: memref<8x16xi32>, %arg1: memref<4x8x16xi32>, %a
// -----
+func.func @pooling_nhwc_max_unsigned_float(
+ %input: tensor<?x?x?x?xf32>,
+ %filter: tensor<?x?xf32>,
+ %init_val: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ // CHECK: unsupported operation: unsigned max not on uint
+ linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%init_val: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %init_val : tensor<?x?x?x?xf32>
+}
+
+// -----
+
func.func @exp_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
// CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
linalg.exp ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
@@ -349,4 +363,3 @@ func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<
linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
return
}
-
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index a93e9799ceb3f..c2a8f24624d8e 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -705,6 +705,23 @@ func.func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
return %res : tensor<1x2x2x1xf32>
}
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_max_unsigned_tensor
+// CHECK: %{{.+}} = linalg.pooling_nhwc_max_unsigned
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func.func @pooling_nhwc_max_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+ %fake = tensor.empty() : tensor<3x3xi32>
+ %init = tensor.empty() : tensor<1x2x2x1xi32>
+ %cst = arith.constant 0 : i32
+ %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ %res = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+ outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ return %res : tensor<1x2x2x1xi32>
+}
+
// -----
// CHECK-LABEL: func @pooling_nwc_max_tensor
// CHECK: %{{.+}} = linalg.pooling_nwc_max
@@ -1017,6 +1034,23 @@ func.func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
// -----
+// CHECK-LABEL: func @pooling_nhwc_min_unsigned_tensor
+// CHECK: %{{.+}} = linalg.pooling_nhwc_min_unsigned
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func.func @pooling_nhwc_min_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+ %fake = tensor.empty() : tensor<3x3xi32>
+ %init = tensor.empty() : tensor<1x2x2x1xi32>
+ %cst = arith.constant 0 : i32
+ %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ %res = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+ outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ return %res : tensor<1x2x2x1xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @pooling_nwc_min_tensor
// CHECK: %{{.+}} = linalg.pooling_nwc_min
// CHECK-SAME: dilations = dense<1> : tensor<1xi64>
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 72acf43361f50..60a4c555fa19a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -131,10 +131,10 @@ func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
}
// CHECK-LABEL: @pooling_nhwc_max_unsigned
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
-// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
-func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
@@ -142,10 +142,10 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tenso
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
- ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
- outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
+ outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
// CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xf32>
+ return %0 : tensor<?x1x?x?xi32>
}
// CHECK-LABEL: @pooling_nhwc_min
@@ -167,10 +167,10 @@ func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
}
// CHECK-LABEL: @pooling_nhwc_min_unsigned
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
-// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
-func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
@@ -178,10 +178,10 @@ func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tenso
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
- ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
- outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
+ outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
// CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xf32>
+ return %0 : tensor<?x1x?x?xi32>
}
// CHECK-LABEL: @pooling_nchw_max
|
|
@llvm/pr-subscribers-mlir Author: Men-cotton (Men-cotton) Changes#164800 Ensures unsigned pooling ops in Linalg stay in the integer domain: the lowering now rejects floating/bool inputs with a clear diagnostic, new regression tests lock in both the error path and a valid integer example, and transform decompositions are updated to reflect the integer typing. CC: @banach-space Full diff: https://github.com/llvm/llvm-project/pull/166070.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3dc45edf4a23f..8eb03dc182ae9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -579,13 +579,23 @@ class RegionBuilderHelper {
return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::max_unsigned:
assert(!allComplex);
- if (allFloatingPoint)
- return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (!allInteger || allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: unsigned max not on uint";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported operation: unsigned max not on uint");
+ }
return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::min_unsigned:
assert(!allComplex);
- if (allFloatingPoint)
- return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+ if (!allInteger || allBool) {
+ if (emitError) {
+ emitError() << "unsupported operation: unsigned min not on uint";
+ return nullptr;
+ }
+ llvm_unreachable("unsupported operation: unsigned min not on uint");
+ }
return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
case BinaryFn::powf:
assert(allFloatingPoint);
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index 552a0abaa797c..4ecf685b4c695 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -80,6 +80,20 @@ func.func @divu_broadcast(%arg0: memref<8x16xi32>, %arg1: memref<4x8x16xi32>, %a
// -----
+func.func @pooling_nhwc_max_unsigned_float(
+ %input: tensor<?x?x?x?xf32>,
+ %filter: tensor<?x?xf32>,
+ %init_val: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ // CHECK: unsupported operation: unsigned max not on uint
+ linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+ outs (%init_val: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+ return %init_val : tensor<?x?x?x?xf32>
+}
+
+// -----
+
func.func @exp_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
// CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
linalg.exp ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
@@ -349,4 +363,3 @@ func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<
linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
return
}
-
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index a93e9799ceb3f..c2a8f24624d8e 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -705,6 +705,23 @@ func.func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
return %res : tensor<1x2x2x1xf32>
}
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_max_unsigned_tensor
+// CHECK: %{{.+}} = linalg.pooling_nhwc_max_unsigned
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func.func @pooling_nhwc_max_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+ %fake = tensor.empty() : tensor<3x3xi32>
+ %init = tensor.empty() : tensor<1x2x2x1xi32>
+ %cst = arith.constant 0 : i32
+ %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ %res = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+ outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ return %res : tensor<1x2x2x1xi32>
+}
+
// -----
// CHECK-LABEL: func @pooling_nwc_max_tensor
// CHECK: %{{.+}} = linalg.pooling_nwc_max
@@ -1017,6 +1034,23 @@ func.func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
// -----
+// CHECK-LABEL: func @pooling_nhwc_min_unsigned_tensor
+// CHECK: %{{.+}} = linalg.pooling_nhwc_min_unsigned
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func.func @pooling_nhwc_min_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+ %fake = tensor.empty() : tensor<3x3xi32>
+ %init = tensor.empty() : tensor<1x2x2x1xi32>
+ %cst = arith.constant 0 : i32
+ %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ %res = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+ outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+ return %res : tensor<1x2x2x1xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @pooling_nwc_min_tensor
// CHECK: %{{.+}} = linalg.pooling_nwc_min
// CHECK-SAME: dilations = dense<1> : tensor<1xi64>
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 72acf43361f50..60a4c555fa19a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -131,10 +131,10 @@ func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
}
// CHECK-LABEL: @pooling_nhwc_max_unsigned
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
-// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
-func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
@@ -142,10 +142,10 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tenso
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
- ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
- outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
+ outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
// CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xf32>
+ return %0 : tensor<?x1x?x?xi32>
}
// CHECK-LABEL: @pooling_nhwc_min
@@ -167,10 +167,10 @@ func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
}
// CHECK-LABEL: @pooling_nhwc_min_unsigned
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
-// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
-func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
@@ -178,10 +178,10 @@ func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tenso
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
%0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
- ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
- outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+ ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
+ outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
// CHECK: return %[[RES]]
- return %0 : tensor<?x1x?x?xf32>
+ return %0 : tensor<?x1x?x?xi32>
}
// CHECK-LABEL: @pooling_nchw_max
|
| if (!allInteger || allBool) { | ||
| if (emitError) { | ||
| emitError() << "unsupported operation: unsigned max not on uint"; | ||
| return nullptr; | ||
| } | ||
| llvm_unreachable("unsupported operation: unsigned max not on uint"); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @Men-cotton - thanks for taking this up.
Although this addresses the immediate issue of linalg.*_(max|min)_unsigned_* - a better solution might be to indeed fix :-
llvm-project/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
Lines 534 to 539 in f5885de
| def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value: | |
| if _is_floating_point_type(lhs.type): | |
| return arith.MaximumFOp(lhs, rhs).result | |
| if _is_integer_type(lhs.type) or _is_index_type(lhs.type): | |
| return arith.MaxUIOp(lhs, rhs).result | |
| raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}") |
Because that'd ensure that in future some other op implementing a (max|min)_unsigned also doesn't warrant a similar fix in their verifier as the one above.
I also see something similar already done in this file for a non-Convolution op : linalg.div_unsigned :-
llvm-project/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Lines 561 to 568 in 03ef5fc
| case BinaryFn::div_unsigned: | |
| if (!allInteger || allBool) { | |
| if (emitError) { | |
| emitError() << "unsupported operation: unsigned div not on uint"; | |
| return nullptr; | |
| } | |
| llvm_unreachable("unsupported operation: unsigned div not on uint"); | |
| } |
Again, these are just my opinion - I'll let @banach-space take a closer look. :)
|
@banach-space Ping |
#164800
Ensures unsigned pooling ops in Linalg stay in the integer domain: the lowering now rejects floating/bool inputs with a clear diagnostic, new regression tests lock in both the error path and a valid integer example, and transform decompositions are updated to reflect the integer typing.
CC: @banach-space