Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][emitc] Add EmitC lowering for arith.trunci, arith.extsi, arith.extui #91491

Merged
merged 4 commits into from
May 22, 2024

Conversation

cferry-AMD
Copy link
Contributor

These operations can be lowered to EmitC provided the sign-extension and truncation behavior is respected.

Per C++ Reference: when casting to a narrower integer, truncation is guaranteed if unsigned casts are performed, or C++20 is used regardless of the sign. This implementation sticks to unsigned for trunci, so C++20 is not necessary.

This implementation is a bit more generic than needed by these three operations to accomodate index_cast and index_castui at a later point (specific emitc.size_t and emitc.ssize_t types are being discussed).

@llvmbot llvmbot added the mlir label May 8, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented May 8, 2024

@llvm/pr-subscribers-mlir-emitc

@llvm/pr-subscribers-mlir

Author: Corentin Ferry (cferry-AMD)

Changes

These operations can be lowered to EmitC provided the sign-extension and truncation behavior is respected.

Per C++ Reference: when casting to a narrower integer, truncation is guaranteed if unsigned casts are performed, or C++20 is used regardless of the sign. This implementation sticks to unsigned for trunci, so C++20 is not necessary.

This implementation is a bit more generic than needed by these three operations to accomodate index_cast and index_castui at a later point (specific emitc.size_t and emitc.ssize_t types are being discussed).


Full diff: https://github.com/llvm/llvm-project/pull/91491.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp (+76)
  • (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir (+19)
  • (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir (+39)
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 1447b182ccfdb..6216e6ea89b9b 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -112,6 +112,78 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
   }
 };
 
+template <typename ArithOp, bool needsUnsigned>
+class CastConversion : public OpConversionPattern<ArithOp> {
+public:
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Type opReturnType = this->getTypeConverter()->convertType(op.getType());
+    if (!isa_and_nonnull<IntegerType>(opReturnType)) {
+      return rewriter.notifyMatchFailure(op, "expected integer result type");
+    }
+
+    if (adaptor.getOperands().size() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "CastConversion only supports unary ops");
+    }
+
+    Type operandType = adaptor.getIn().getType();
+    if (!isa_and_nonnull<IntegerType>(operandType)) {
+      return rewriter.notifyMatchFailure(op, "expected integer operand type");
+    }
+
+    bool isTruncation = operandType.getIntOrFloatBitWidth() >
+                        opReturnType.getIntOrFloatBitWidth();
+    bool doUnsigned = needsUnsigned || isTruncation;
+
+    Type castType = opReturnType;
+    // For int conversions: if the op is a ui variant and the type wanted as
+    // return type isn't unsigned, we need to issue an unsigned type to do
+    // the conversion.
+    if (castType.isUnsignedInteger() != doUnsigned) {
+      castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
+                                         /*isSigned=*/!doUnsigned);
+    }
+
+    Value actualOp = adaptor.getIn();
+    // Fix the signedness of the operand if necessary
+    if (operandType.isUnsignedInteger() != doUnsigned) {
+      Type correctSignednessType =
+          rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
+                                  /*isSigned=*/!doUnsigned);
+      actualOp = rewriter.template create<emitc::CastOp>(
+          op.getLoc(), correctSignednessType, actualOp);
+    }
+
+    auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
+                                                          actualOp);
+
+    // Fix the signedness of what this operation returns (for integers,
+    // the arith ops want signless results)
+    if (castType != opReturnType) {
+      result = rewriter.template create<emitc::CastOp>(op.getLoc(),
+                                                       opReturnType, result);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+template <typename ArithOp>
+class UnsignedCastConversion : public CastConversion<ArithOp, true> {
+  using CastConversion<ArithOp, true>::CastConversion;
+};
+
+template <typename ArithOp>
+class SignedCastConversion : public CastConversion<ArithOp, false> {
+  using CastConversion<ArithOp, false>::CastConversion;
+};
+
 template <typename ArithOp, typename EmitCOp>
 class ArithOpConversion final : public OpConversionPattern<ArithOp> {
 public:
@@ -313,6 +385,10 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
     CmpIOpConversion,
     SelectOpConversion,
+    // Truncation is guaranteed for unsigned types.
+    UnsignedCastConversion<arith::TruncIOp>,
+    SignedCastConversion<arith::ExtSIOp>,
+    UnsignedCastConversion<arith::ExtUIOp>,
     ItoFCastOpConversion<arith::SIToFPOp>,
     ItoFCastOpConversion<arith::UIToFPOp>,
     FtoICastOpConversion<arith::FPToSIOp>,
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index 66dfa8fa3e157..551c3ba7a77ef 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -63,3 +63,22 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
   return %t: i1
 }
 
+// -----
+
+func.func @index_cast(%arg0: i32) -> i32 {
+  // expected-error @+1 {{failed to legalize operation 'arith.index_cast'}}
+  %idx = arith.index_cast %arg0 : i32 to index
+  %int = arith.index_cast %idx : index to i32
+
+  return %int : i32
+}
+
+// -----
+
+func.func @index_castui(%arg0: i32) -> i32 {
+  // expected-error @+1 {{failed to legalize operation 'arith.index_castui'}}
+  %idx = arith.index_castui %arg0 : i32 to index
+  %int = arith.index_castui %idx : index to i32
+
+  return %int : i32
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 79fecd61494d0..80665bacd2a5c 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -177,3 +177,42 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
 
   return
 }
+
+// -----
+
+func.func @trunci(%arg0: i32) -> i8 {
+  // CHECK-LABEL: trunci
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+  // CHECK: %[[CastUI:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
+  // CHECK: %[[Trunc:.*]] = emitc.cast %[[CastUI]] : ui32 to ui8
+  // CHECK: emitc.cast %[[Trunc]] : ui8 to i8
+  %truncd = arith.trunci %arg0 : i32 to i8
+
+  return %truncd : i8
+}
+
+// -----
+
+func.func @extsi(%arg0: i32) {
+  // CHECK-LABEL: extsi
+  // CHECK-SAME: ([[Arg0:[^ ]*]]: i32)
+  // CHECK: emitc.cast [[Arg0]] : i32 to i64
+
+  %extd = arith.extsi %arg0 : i32 to i64
+
+  return
+}
+
+// -----
+
+func.func @extui(%arg0: i32) {
+  // CHECK-LABEL: extui
+  // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
+  // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
+  // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to ui64
+  // CHECK: emitc.cast %[[Conv1]] : ui64 to i64
+
+  %extd = arith.extui %arg0 : i32 to i64
+
+  return
+}

@cferry-AMD
Copy link
Contributor Author

Pool of reviewers: @simon-camp @marbre @TinaAMD @mgehre-amd

@TinaAMD TinaAMD self-requested a review May 10, 2024 09:03
Copy link
Contributor

@TinaAMD TinaAMD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat trick with the unsigned interpretation for truncation!

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp Outdated Show resolved Hide resolved
mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp Outdated Show resolved Hide resolved
mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp Outdated Show resolved Hide resolved
mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp Outdated Show resolved Hide resolved
mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp Outdated Show resolved Hide resolved
mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp Outdated Show resolved Hide resolved
mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir Outdated Show resolved Hide resolved
mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir Outdated Show resolved Hide resolved
Copy link
Contributor

@simon-camp simon-camp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think truncation to i1 needs to be handled specially, as the arith dialect discards the high bits, but a conversion to bool is similar to x != 0. For the same reason signed extension from i1 should be rejected by this pattern. Unsigned extension from i1 works correctly I think.

https://godbolt.org/z/sEcdnz4s4

@cferry-AMD
Copy link
Contributor Author

Yes, the i1 case is special indeed -- thanks for the remark! I added its handling.

Copy link

github-actions bot commented May 22, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@simon-camp simon-camp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other than the formatting and missing newline this looks good to me.

Out of curiosity, what would be a good way of implementing the sign extension for i1? Doing unsigned extension, oring with 0b1111...00... and casting back to signed?

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir Outdated Show resolved Hide resolved
@cferry-AMD
Copy link
Contributor Author

cferry-AMD commented May 22, 2024

I would do just the unsigned extension: I don't think it makes sense to interpret the only bit of an i1 as a sign bit (then what's its value?)... then the other choice would be to see a 1 as +1, which needs no sign extension to be preserved.

Now assuming we still sign-extend: as an alternative to the or you suggest, we could go unsigned to i8, do x | 0xFE, and then perform regular sign-extension with a cast. That would let us work e.g. with index types, which bitwidth is unspecified.

@cferry-AMD cferry-AMD merged commit 7630379 into llvm:main May 22, 2024
3 of 4 checks passed
@cferry-AMD cferry-AMD deleted the corentin.upstream_arith_trunc branch May 22, 2024 14:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants