Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][spirv] Add folding for [I|Logical][Not]Equal #74194

Merged
merged 3 commits into from
Dec 20, 2023

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Dec 2, 2023

No description provided.

@inbelic inbelic force-pushed the inbelic/spirv-folding-eq-ops branch from 1ef8988 to 73ff53c Compare December 3, 2023 18:00
@inbelic inbelic changed the title Inbelic/spirv folding eq ops [mlir][spirv] Add folding for [I|Logical][Not]Equal Dec 3, 2023
@inbelic inbelic marked this pull request as ready for review December 3, 2023 18:45
@llvmbot
Copy link
Collaborator

llvmbot commented Dec 3, 2023

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Finn Plummer (inbelic)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/74194.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td (+8-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp (+95-2)
  • (modified) mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir (+8-8)
  • (modified) mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir (+165)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index cf38c15d20dc3..0053cd5fc9448 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -473,6 +473,8 @@ def SPIRV_IEqualOp : SPIRV_LogicalBinaryOp<"IEqual",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -506,6 +508,8 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -644,6 +648,8 @@ def SPIRV_LogicalEqualOp : SPIRV_LogicalBinaryOp<"LogicalEqual",
     %2 = spirv.LogicalEqual %0, %1 : vector<4xi1>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -713,7 +719,8 @@ def SPIRV_LogicalNotEqualOp : SPIRV_LogicalBinaryOp<"LogicalNotEqual",
     %2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1>
     ```
   }];
-  let hasFolder = true;
+
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af..16efe8797f4a3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -309,6 +309,32 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.LogicalEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
+  // x == x -> true
+  if (getOperand1() == getOperand2()) {
+    auto type = getType();
+    if (isa<IntegerType>(type)) {
+      return BoolAttr::get(getContext(), true);
+    }
+    if (isa<VectorType>(type)) {
+      auto vtType = cast<ShapedType>(type);
+      auto element = BoolAttr::get(getContext(), true);
+      return DenseElementsAttr::get(vtType, element);
+    }
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+                                        [](const APInt &a, const APInt &b) {
+                                          APInt zero = APInt::getZero(1);
+                                          return a == b ? (zero + 1) : zero;
+                                        });
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalNotEqualOp
 //===----------------------------------------------------------------------===//
@@ -316,12 +342,29 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
 OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
   if (std::optional<bool> rhs =
           getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
-    // x && false = x
+    // x != false -> x
     if (!rhs.value())
       return getOperand1();
   }
 
-  return Attribute();
+  // x == x -> false
+  if (getOperand1() == getOperand2()) {
+    auto type = getType();
+    if (isa<IntegerType>(type)) {
+      return BoolAttr::get(getContext(), false);
+    }
+    if (isa<VectorType>(type)) {
+      auto vtType = cast<ShapedType>(type);
+      auto element = BoolAttr::get(getContext(), false);
+      return DenseElementsAttr::get(vtType, element);
+    }
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+                                        [](const APInt &a, const APInt &b) {
+                                          APInt zero = APInt::getZero(1);
+                                          return a == b ? zero : (zero + 1);
+                                        });
 }
 
 //===----------------------------------------------------------------------===//
@@ -356,6 +399,56 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.IEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
+  // x == x -> true
+  if (getOperand1() == getOperand2()) {
+    auto type = getType();
+    if (isa<IntegerType>(type)) {
+      return BoolAttr::get(getContext(), true);
+    }
+    if (isa<VectorType>(type)) {
+      auto vtType = cast<ShapedType>(type);
+      auto element = BoolAttr::get(getContext(), true);
+      return DenseElementsAttr::get(vtType, element);
+    }
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
+                                        [](const APInt &a, const APInt &b) {
+                                          APInt zero = APInt::getZero(1);
+                                          return a == b ? (zero + 1) : zero;
+                                        });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INotEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
+  // x == x -> false
+  if (getOperand1() == getOperand2()) {
+    auto type = getType();
+    if (isa<IntegerType>(type)) {
+      return BoolAttr::get(getContext(), false);
+    }
+    if (isa<VectorType>(type)) {
+      auto vtType = cast<ShapedType>(type);
+      auto element = BoolAttr::get(getContext(), false);
+      return DenseElementsAttr::get(vtType, element);
+    }
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
+                                        [](const APInt &a, const APInt &b) {
+                                          APInt zero = APInt::getZero(1);
+                                          return a == b ? zero : (zero + 1);
+                                        });
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
index 6d93480d3ed14..aab2dce980ca7 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
@@ -7,14 +7,14 @@
 // CHECK-LABEL: @logical_equal_scalar
 spirv.func @logical_equal_scalar(%arg0: i1, %arg1: i1) "None" {
   // CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : i1
-  %0 = spirv.LogicalEqual %arg0, %arg0 : i1
+  %0 = spirv.LogicalEqual %arg0, %arg1 : i1
   spirv.Return
 }
 
 // CHECK-LABEL: @logical_equal_vector
 spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
   // CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : vector<4xi1>
-  %0 = spirv.LogicalEqual %arg0, %arg0 : vector<4xi1>
+  %0 = spirv.LogicalEqual %arg0, %arg1 : vector<4xi1>
   spirv.Return
 }
 
@@ -25,14 +25,14 @@ spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None
 // CHECK-LABEL: @logical_not_equal_scalar
 spirv.func @logical_not_equal_scalar(%arg0: i1, %arg1: i1) "None" {
   // CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i1
-  %0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
+  %0 = spirv.LogicalNotEqual %arg0, %arg1 : i1
   spirv.Return
 }
 
 // CHECK-LABEL: @logical_not_equal_vector
 spirv.func @logical_not_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
   // CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : vector<4xi1>
-  %0 = spirv.LogicalNotEqual %arg0, %arg0 : vector<4xi1>
+  %0 = spirv.LogicalNotEqual %arg0, %arg1 : vector<4xi1>
   spirv.Return
 }
 
@@ -63,14 +63,14 @@ spirv.func @logical_not_vector(%arg0: vector<4xi1>) "None" {
 // CHECK-LABEL: @logical_and_scalar
 spirv.func @logical_and_scalar(%arg0: i1, %arg1: i1) "None" {
   // CHECK: llvm.and %{{.*}}, %{{.*}} : i1
-  %0 = spirv.LogicalAnd %arg0, %arg0 : i1
+  %0 = spirv.LogicalAnd %arg0, %arg1 : i1
   spirv.Return
 }
 
 // CHECK-LABEL: @logical_and_vector
 spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
   // CHECK: llvm.and %{{.*}}, %{{.*}} : vector<4xi1>
-  %0 = spirv.LogicalAnd %arg0, %arg0 : vector<4xi1>
+  %0 = spirv.LogicalAnd %arg0, %arg1 : vector<4xi1>
   spirv.Return
 }
 
@@ -81,13 +81,13 @@ spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None"
 // CHECK-LABEL: @logical_or_scalar
 spirv.func @logical_or_scalar(%arg0: i1, %arg1: i1) "None" {
   // CHECK: llvm.or %{{.*}}, %{{.*}} : i1
-  %0 = spirv.LogicalOr %arg0, %arg0 : i1
+  %0 = spirv.LogicalOr %arg0, %arg1 : i1
   spirv.Return
 }
 
 // CHECK-LABEL: @logical_or_vector
 spirv.func @logical_or_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
   // CHECK: llvm.or %{{.*}}, %{{.*}} : vector<4xi1>
-  %0 = spirv.LogicalOr %arg0, %arg0 : vector<4xi1>
+  %0 = spirv.LogicalOr %arg0, %arg1 : vector<4xi1>
   spirv.Return
 }
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a44439..7a8e262db266a 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -569,6 +569,48 @@ func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<
   spirv.ReturnValue %3 : vector<3xi1>
 }
 
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.LogicalEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @logical_equal_same
+func.func @logical_equal_same(%arg0 : i1, %arg1 : vector<3xi1>) -> (i1, vector<3xi1>) {
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
+
+  %0 = spirv.LogicalEqual %arg0, %arg0 : i1
+  %1 = spirv.LogicalEqual %arg1, %arg1 : vector<3xi1>
+  // CHECK: return %[[CTRUE]], %[[CVTRUE]]
+  return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_logical_equal
+func.func @const_fold_scalar_logical_equal() -> (i1, i1) {
+  %true = spirv.Constant true
+  %false = spirv.Constant false
+
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  %0 = spirv.LogicalEqual %true, %false : i1
+  %1 = spirv.LogicalEqual %false, %false : i1
+
+  // CHECK: return %[[CFALSE]], %[[CTRUE]]
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_logical_equal
+func.func @const_fold_vector_logical_equal() -> vector<3xi1> {
+  %cv0 = spirv.Constant dense<[true, false, true]> : vector<3xi1>
+  %cv1 = spirv.Constant dense<[true, false, false]> : vector<3xi1>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[true, true, false]>
+  %0 = spirv.LogicalEqual %cv0, %cv1 : vector<3xi1>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi1>
+}
 
 // -----
 
@@ -585,6 +627,43 @@ func.func @convert_logical_not_equal_false(%arg: vector<4xi1>) -> vector<4xi1> {
   spirv.ReturnValue %0 : vector<4xi1>
 }
 
+// CHECK-LABEL: @logical_not_equal_same
+func.func @logical_not_equal_same(%arg0 : i1, %arg1 : vector<3xi1>) -> (i1, vector<3xi1>) {
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  // CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
+  %0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
+  %1 = spirv.LogicalNotEqual %arg1, %arg1 : vector<3xi1>
+
+  // CHECK: return %[[CFALSE]], %[[CVFALSE]]
+  return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_logical_not_equal
+func.func @const_fold_scalar_logical_not_equal() -> (i1, i1) {
+  %true = spirv.Constant true
+  %false = spirv.Constant false
+
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  %0 = spirv.LogicalNotEqual %true, %false : i1
+  %1 = spirv.LogicalNotEqual %false, %false : i1
+
+  // CHECK: return %[[CTRUE]], %[[CFALSE]]
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_logical_not_equal
+func.func @const_fold_vector_logical_not_equal() -> vector<3xi1> {
+  %cv0 = spirv.Constant dense<[true, false, true]> : vector<3xi1>
+  %cv1 = spirv.Constant dense<[true, false, false]> : vector<3xi1>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[false, false, true]>
+  %0 = spirv.LogicalNotEqual %cv0, %cv1 : vector<3xi1>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi1>
+}
+
 // -----
 
 func.func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
@@ -660,6 +739,92 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.IEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iequal_same
+func.func @iequal_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
+  %0 = spirv.IEqual %arg0, %arg0 : i32
+  %1 = spirv.IEqual %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[CTRUE]], %[[CVTRUE]]
+  return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iequal
+func.func @const_fold_scalar_iequal() -> (i1, i1) {
+  %c5 = spirv.Constant 5 : i32
+  %c6 = spirv.Constant 6 : i32
+
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  %0 = spirv.IEqual %c5, %c6 : i32
+  %1 = spirv.IEqual %c5, %c5 : i32
+
+  // CHECK: return %[[CFALSE]], %[[CTRUE]]
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_iequal
+func.func @const_fold_vector_iequal() -> vector<3xi1> {
+  %cv0 = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
+  %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[true, false, true]>
+  %0 = spirv.IEqual %cv0, %cv1 : vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INotEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @inotequal_same
+func.func @inotequal_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  // CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
+  %0 = spirv.INotEqual %arg0, %arg0 : i32
+  %1 = spirv.INotEqual %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[CFALSE]], %[[CVFALSE]]
+  return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_inotequal
+func.func @const_fold_scalar_inotequal() -> (i1, i1) {
+  %c5 = spirv.Constant 5 : i32
+  %c6 = spirv.Constant 6 : i32
+
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  %0 = spirv.INotEqual %c5, %c6 : i32
+  %1 = spirv.INotEqual %c5, %c5 : i32
+
+  // CHECK: return %[[CTRUE]], %[[CFALSE]]
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_inotequal
+func.func @const_fold_vector_inotequal() -> vector<3xi1> {
+  %cv0 = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
+  %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[false, true, false]>
+  %0 = spirv.INotEqual %cv0, %cv1 : vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi1>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, just some coding style suggestions here

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp Outdated Show resolved Hide resolved
@inbelic
Copy link
Contributor Author

inbelic commented Dec 10, 2023

Nice, thanks. Let me know if it looks good and I can squash and try my new commit privileges :)

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just one cosmetic suggestion for single-line if statements. Feel free to change and merge without asking for re-approval.

Comment on lines 674 to 676
if (isa<IntegerType>(getType())) {
return trueAttr;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we don't need braces around single-line if statements. also elsewhere

Add missing constant propogation folder for [I|Logical][N]Eq

Implement additional folding when lhs == rhs for all ops.

As well as, fix test cases in logical-ops-to-llvm that failed due to
introduced folding.

This helps for readability of lowered code into SPIR-V.

Part of work for llvm#70704
- fix coding style
@inbelic inbelic merged commit 4c83c27 into llvm:main Dec 20, 2023
4 checks passed
@inbelic inbelic deleted the inbelic/spirv-folding-eq-ops branch March 21, 2024 15:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants