
Conversation

@Men-cotton
Contributor

#164800

Ensures unsigned pooling ops in Linalg stay in the integer domain: the lowering now rejects floating/bool inputs with a clear diagnostic, new regression tests lock in both the error path and a valid integer example, and transform decompositions are updated to reflect the integer typing.

CC: @banach-space
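The type gate this PR adds in LinalgOps.cpp can be modeled in a few lines. The following is a hypothetical Python sketch, not LLVM code (`check_unsigned_minmax` and its string-based type spelling are invented for illustration): unsigned integer min/max is only meaningful on non-boolean integer element types, so float and `i1` operands should yield a diagnostic instead of silently lowering to a floating-point op.

```python
def check_unsigned_minmax(element_types, fn_name):
    """Return a diagnostic string, or None if the element types are acceptable.

    Simplified model: integer types are spelled "iN" (so "i32" is an
    integer, "f32" is not), and i1 is the boolean type that must also
    be rejected, mirroring the `!allInteger || allBool` condition.
    """
    all_integer = all(t.startswith("i") for t in element_types)
    all_bool = all(t == "i1" for t in element_types)
    if not all_integer or all_bool:
        return f"unsupported operation: unsigned {fn_name} not on uint"
    return None
```

For example, `check_unsigned_minmax(["f32", "f32"], "max")` produces the diagnostic exercised by the new named-ops-fail.mlir test, while i32 operands pass through.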

@github-actions

github-actions bot commented Nov 2, 2025

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository; in that case, you can instead tag reviewers by name in a comment, using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "pinging" it: add a comment saying "Ping". The common courtesy ping rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Member

llvmbot commented Nov 2, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Men-cotton (Men-cotton)

Changes

#164800

Ensures unsigned pooling ops in Linalg stay in the integer domain: the lowering now rejects floating/bool inputs with a clear diagnostic, new regression tests lock in both the error path and a valid integer example, and transform decompositions are updated to reflect the integer typing.

CC: @banach-space


Full diff: https://github.com/llvm/llvm-project/pull/166070.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+14-4)
  • (modified) mlir/test/Dialect/Linalg/named-ops-fail.mlir (+14-1)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+34)
  • (modified) mlir/test/Dialect/Linalg/transform-op-decompose.mlir (+14-14)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3dc45edf4a23f..8eb03dc182ae9 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -579,13 +579,23 @@ class RegionBuilderHelper {
       return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
     case BinaryFn::max_unsigned:
       assert(!allComplex);
-      if (allFloatingPoint)
-        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+      if (!allInteger || allBool) {
+        if (emitError) {
+          emitError() << "unsupported operation: unsigned max not on uint";
+          return nullptr;
+        }
+        llvm_unreachable("unsupported operation: unsigned max not on uint");
+      }
       return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
     case BinaryFn::min_unsigned:
       assert(!allComplex);
-      if (allFloatingPoint)
-        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
+      if (!allInteger || allBool) {
+        if (emitError) {
+          emitError() << "unsupported operation: unsigned min not on uint";
+          return nullptr;
+        }
+        llvm_unreachable("unsupported operation: unsigned min not on uint");
+      }
       return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
     case BinaryFn::powf:
       assert(allFloatingPoint);
diff --git a/mlir/test/Dialect/Linalg/named-ops-fail.mlir b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
index 552a0abaa797c..4ecf685b4c695 100644
--- a/mlir/test/Dialect/Linalg/named-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops-fail.mlir
@@ -80,6 +80,20 @@ func.func @divu_broadcast(%arg0: memref<8x16xi32>, %arg1: memref<4x8x16xi32>, %a
 
 // -----
 
+func.func @pooling_nhwc_max_unsigned_float(
+    %input: tensor<?x?x?x?xf32>,
+    %filter: tensor<?x?xf32>,
+    %init_val: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  // CHECK: unsupported operation: unsigned max not on uint
+  linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
+                                    strides = dense<1> : tensor<2xi64>}
+      ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?xf32>)
+     outs (%init_val: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %init_val : tensor<?x?x?x?xf32>
+}
+
+// -----
+
 func.func @exp_type_cast(%arg: memref<4x8x16xf16>, %out: memref<4x8x16xf32>) {
   // CHECK: operand 1 ('f16') doesn't match the element type of the enclosing linalg.generic op ('f32')
   linalg.exp ins(%arg : memref<4x8x16xf16>) outs(%out: memref<4x8x16xf32>)
@@ -349,4 +363,3 @@ func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<
   linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
   return
 }
-
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index a93e9799ceb3f..c2a8f24624d8e 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -705,6 +705,23 @@ func.func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
   return %res : tensor<1x2x2x1xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @pooling_nhwc_max_unsigned_tensor
+// CHECK:         %{{.+}} = linalg.pooling_nhwc_max_unsigned
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func.func @pooling_nhwc_max_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+  %fake = tensor.empty() : tensor<3x3xi32>
+  %init = tensor.empty() : tensor<1x2x2x1xi32>
+  %cst = arith.constant 0 : i32
+  %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+  %res = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+    outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+  return %res : tensor<1x2x2x1xi32>
+}
+
 // -----
 // CHECK-LABEL: func @pooling_nwc_max_tensor
 // CHECK:         %{{.+}} = linalg.pooling_nwc_max
@@ -1017,6 +1034,23 @@ func.func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x
 
 // -----
 
+// CHECK-LABEL: func @pooling_nhwc_min_unsigned_tensor
+// CHECK:         %{{.+}} = linalg.pooling_nhwc_min_unsigned
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+// CHECK-SAME:      outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+func.func @pooling_nhwc_min_unsigned_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
+  %fake = tensor.empty() : tensor<3x3xi32>
+  %init = tensor.empty() : tensor<1x2x2x1xi32>
+  %cst = arith.constant 0 : i32
+  %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+  %res = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
+    outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
+  return %res : tensor<1x2x2x1xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @pooling_nwc_min_tensor
 // CHECK:         %{{.+}} = linalg.pooling_nwc_min
 // CHECK-SAME:      dilations = dense<1> : tensor<1xi64>
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
index 72acf43361f50..60a4c555fa19a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -131,10 +131,10 @@ func.func @pooling_nhwc_max(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
 }
 
 // CHECK-LABEL: @pooling_nhwc_max_unsigned
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
-// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
-func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
+func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
   // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
   // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
   // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
@@ -142,10 +142,10 @@ func.func @pooling_nhwc_max_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tenso
   // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   %0 = linalg.pooling_nhwc_max_unsigned {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
-     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
-    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+     ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
+    outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
   // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xf32>
+  return %0 : tensor<?x1x?x?xi32>
 }
 
 // CHECK-LABEL: @pooling_nhwc_min
@@ -167,10 +167,10 @@ func.func @pooling_nhwc_min(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32
 }
 
 // CHECK-LABEL: @pooling_nhwc_min_unsigned
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
-// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
-func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xi32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?xi32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xi32>
+func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xi32>, %filter: tensor<1x?xi32>, %init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32> {
   // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
   // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
   // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
@@ -178,10 +178,10 @@ func.func @pooling_nhwc_min_unsigned(%input: tensor<?x1x?x?xf32>, %filter: tenso
   // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
   %0 = linalg.pooling_nhwc_min_unsigned {dilations = dense<1> : tensor<2xi64>,
                                 strides = dense<1> : tensor<2xi64>}
-     ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?xf32>)
-    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+     ins (%input, %filter: tensor<?x1x?x?xi32>, tensor<1x?xi32>)
+    outs (%init: tensor<?x1x?x?xi32>) -> tensor<?x1x?x?xi32>
   // CHECK: return %[[RES]]
-  return %0 : tensor<?x1x?x?xf32>
+  return %0 : tensor<?x1x?x?xi32>
 }
 
 // CHECK-LABEL: @pooling_nchw_max

@llvmbot
Member

llvmbot commented Nov 2, 2025

@llvm/pr-subscribers-mlir

Comment on lines +582 to +588
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned max not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned max not on uint");
      }
Contributor


Hi @Men-cotton - thanks for taking this up.
Although this addresses the immediate issue for the linalg.*_(max|min)_unsigned_* ops, a better solution might be to fix the OpDSL emitter itself:

def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
    if _is_floating_point_type(lhs.type):
        return arith.MaximumFOp(lhs, rhs).result
    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
        return arith.MaxUIOp(lhs, rhs).result
    raise NotImplementedError(f"Unsupported 'max_unsigned' operands: {lhs}, {rhs}")

That would ensure that if some other op implements (max|min)_unsigned in the future, it won't need a similar fix in its verifier like the one above.
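To sketch that point (hypothetical Python, not the actual MLIR bindings; the helper name and string-based type spelling are invented): a single type-dispatching builder helper centralizes the check, so every op that lowers through max_unsigned inherits it instead of each verifier being patched separately.

```python
# Hypothetical model of a shared builder helper: dispatch once on the
# element type, so every op built through it gets the same type check.
def binary_max_unsigned(lhs_type: str, rhs_type: str) -> str:
    """Return the arith op name to emit for max_unsigned, or raise."""
    # Simplified: integer types are spelled "iN"; i1 (bool) is excluded.
    if (lhs_type == rhs_type
            and lhs_type.startswith("i") and lhs_type != "i1"):
        return "arith.maxui"
    raise NotImplementedError(
        f"Unsupported 'max_unsigned' operands: {lhs_type}, {rhs_type}")
```

Here i32 operands dispatch to `arith.maxui`, while f32 or i1 operands raise, mirroring the diagnostic path in the C++ change.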

I also see something similar already done in this file for a non-convolution op, linalg.div_unsigned:

    case BinaryFn::div_unsigned:
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned div not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      }

Again, this is just my opinion - I'll let @banach-space take a closer look. :)

@Men-cotton
Contributor Author

@banach-space Ping

