Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][spirv] Lower arith overflow flags to corresponding SPIR-V op decorations #77714

Merged
merged 3 commits into from Jan 11, 2024

Conversation

Hardcode84
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 11, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Ivan Butygin (Hardcode84)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/77714.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+56-3)
  • (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+40)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index aba6a21deccb0c..3b851604f597af 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -158,8 +158,61 @@ getTypeConversionFailure(ConversionPatternRewriter &rewriter, Operation *op) {
   return getTypeConversionFailure(rewriter, op, op->getResultTypes().front());
 }
 
+// TODO: Move to some common place?
+static std::string getDecorationString(spirv::Decoration decor) {
+  return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decor));
+}
+
 namespace {
 
+/// Converts elementwise unary, binary and ternary arith operations to SPIR-V
+/// operations. Op can potentially support overflow flags.
+template <typename Op, typename SPIRVOp>
+struct ElementwiseArithOpPattern : public OpConversionPattern<Op> {
+  using OpConversionPattern<Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() <= 3);
+    Type dstType = this->getTypeConverter()->convertType(op.getType());
+    if (!dstType) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
+    }
+
+    if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
+        !getElementTypeOrSelf(op.getType()).isIndex() &&
+        dstType != op.getType()) {
+      return op.emitError("bitwidth emulation is not implemented yet on "
+                          "unsigned op pattern version");
+    }
+
+    auto converter = this->getTypeConverter<SPIRVTypeConverter>();
+    auto overflowFlags = arith::IntegerOverflowFlags::none;
+    if (auto overflowIface =
+            dyn_cast<arith::ArithIntegerOverflowFlagsInterface>(*op)) {
+      if (converter->getTargetEnv().allows(
+              spirv::Extension::SPV_KHR_no_integer_wrap_decoration))
+        overflowFlags = overflowIface.getOverflowAttr().getValue();
+    }
+
+    auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
+        op, dstType, adaptor.getOperands());
+
+    if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nsw))
+      newOp->setAttr(getDecorationString(spirv::Decoration::NoSignedWrap),
+                     rewriter.getUnitAttr());
+
+    if (bitEnumContainsAny(overflowFlags, arith::IntegerOverflowFlags::nuw))
+      newOp->setAttr(getDecorationString(spirv::Decoration::NoUnsignedWrap),
+                     rewriter.getUnitAttr());
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -1154,9 +1207,9 @@ void mlir::arith::populateArithToSPIRVPatterns(
   patterns.add<
     ConstantCompositeOpPattern,
     ConstantScalarOpPattern,
-    spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
-    spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
-    spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
+    ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
+    ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
+    ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
     spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
     spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
     spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 0221e4815a9397..8bf90ed0aec8ee 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1407,3 +1407,43 @@ func.func @float_scalar(%arg0: f16) {
 }
 
 } // end module
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Kernel], [SPV_KHR_no_integer_wrap_decoration]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @ops_flags
+func.func @ops_flags(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap} : i64
+  %0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
+  // CHECK: %{{.*}} = spirv.ISub %{{.*}}, %{{.*}} {no_unsigned_wrap} : i64
+  %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
+  // CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap} : i64
+  %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+  return
+}
+
+} // end module
+
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64], []>, #spirv.resource_limits<>>
+} {
+
+// No decorations should be generated is corresponding Extensions/Capabilities are missing
+// CHECK-LABEL: @ops_flags
+func.func @ops_flags(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = spirv.IAdd %{{.*}}, %{{.*}} : i64
+  %0 = arith.addi %arg0, %arg1 overflow<nsw> : i64
+  // CHECK: %{{.*}} = spirv.ISub %{{.*}}, %{{.*}} : i64
+  %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
+  // CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
+  %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+  return
+}
+
+} // end module

@Hardcode84 Hardcode84 changed the title [mlir][spirv] Lower arith overflow flags to corresponding op decorations [mlir][spirv] Lower arith overflow flags to corresponding SPIR-V op decorations Jan 11, 2024
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for plumbing this through!

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp Outdated Show resolved Hide resolved
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp Outdated Show resolved Hide resolved
@Hardcode84 Hardcode84 merged commit 4278d9b into llvm:main Jan 11, 2024
4 checks passed
@Hardcode84 Hardcode84 deleted the overgflow-spirv branch January 11, 2024 17:40
Hardcode84 added a commit that referenced this pull request Jan 11, 2024
…PIR-V op decorations (#77714)"

Temporaryly reverting as it broke python bindings

This reverts commit 4278d9b.
jpienaar pushed a commit to jpienaar/llvm-project that referenced this pull request Jan 17, 2024
Add overflow flags support to the following ops:
* `arith.addi`
* `arith.subi`
* `arith.muli`

Example of new syntax:
```
%res = arith.addi %arg1, %arg2 overflow<nsw> : i64
```
Similar to existing LLVM dialect syntax
```
%res = llvm.add %arg1, %arg2 overflow<nsw> : i64
```

Tablegen canonicalization patterns updated to always drop flags, proper
support with tests will be added later.

Updated LLVMIR translation as part of this commit as it currenly written
in a way that it will crash when new attributes added to arith ops
otherwise.

Also lower `arith` overflow flags to corresponding SPIR-V op decorations

Discussion
https://discourse.llvm.org/t/rfc-integer-overflow-flags-support-in-arith-dialect/76025

This effectively rolls forward llvm#77211, llvm#77700 and llvm#77714 while adding a
test to ensure the Python usage is not broken. More follow up needed but
unrelated to the core change here. The changes here are minimal and just
correspond to "textual namespacing" ODS side, no C++ or Python changes
were needed.

---------

Co-authored-by: Ivan Butygin <ivan.butygin@gmail.com>, Yi Wu <yi.wu2@arm.com>
Hardcode84 pushed a commit that referenced this pull request Jan 17, 2024
Add overflow flags support to the following ops:
* `arith.addi`
* `arith.subi`
* `arith.muli`

Example of new syntax:
```
%res = arith.addi %arg1, %arg2 overflow<nsw> : i64
```
Similar to existing LLVM dialect syntax
```
%res = llvm.add %arg1, %arg2 overflow<nsw> : i64
```

Tablegen canonicalization patterns updated to always drop flags, proper
support with tests will be added later.

Updated LLVMIR translation as part of this commit as it currenly written
in a way that it will crash when new attributes added to arith ops
otherwise.

Also lower `arith` overflow flags to corresponding SPIR-V op decorations

Discussion

https://discourse.llvm.org/t/rfc-integer-overflow-flags-support-in-arith-dialect/76025

This effectively rolls forward #77211, #77700 and #77714 while adding a
test to ensure the Python usage is not broken. More follow up needed but
unrelated to the core change here. The changes here are minimal and just
correspond to "textual namespacing" ODS side, no C++ or Python changes
were needed.

---------

---------

Co-authored-by: Ivan Butygin <ivan.butygin@gmail.com>, Yi Wu <yi.wu2@arm.com>
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…PIR-V op decorations (llvm#77714)"

Temporaryly reverting as it broke python bindings

This reverts commit 4278d9b.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants