Skip to content

Commit

Permalink
[mlir][emitc] Add EmitC lowering for arith.trunci, arith.extsi, arith…
Browse files Browse the repository at this point in the history
….extui

This commit adds conversion to EmitC for arith dialect casts between integer types (trunc, extsi, extui), excluding indexes for now.
  • Loading branch information
cferry-AMD authored May 22, 2024
1 parent 267de85 commit 7630379
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 0 deletions.
92 changes: 92 additions & 0 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Tools/PDLL/AST/Types.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
Expand Down Expand Up @@ -112,6 +113,93 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
}
};

template <typename ArithOp, bool castToUnsigned>
class CastConversion : public OpConversionPattern<ArithOp> {
public:
using OpConversionPattern<ArithOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Type opReturnType = this->getTypeConverter()->convertType(op.getType());
if (!isa_and_nonnull<IntegerType>(opReturnType))
return rewriter.notifyMatchFailure(op, "expected integer result type");

if (adaptor.getOperands().size() != 1) {
return rewriter.notifyMatchFailure(
op, "CastConversion only supports unary ops");
}

Type operandType = adaptor.getIn().getType();
if (!isa_and_nonnull<IntegerType>(operandType))
return rewriter.notifyMatchFailure(op, "expected integer operand type");

// Signed (sign-extending) casts from i1 are not supported.
if (operandType.isInteger(1) && !castToUnsigned)
return rewriter.notifyMatchFailure(op,
"operation not supported on i1 type");

// to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
// equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
// truncation.
if (opReturnType.isInteger(1)) {
auto constOne = rewriter.create<emitc::ConstantOp>(
op.getLoc(), operandType, rewriter.getIntegerAttr(operandType, 1));
auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
op.getLoc(), operandType, adaptor.getIn(), constOne);
rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
oneAndOperand);
return success();
}

bool isTruncation = operandType.getIntOrFloatBitWidth() >
opReturnType.getIntOrFloatBitWidth();
bool doUnsigned = castToUnsigned || isTruncation;

Type castType = opReturnType;
// If the op is a ui variant and the type wanted as
// return type isn't unsigned, we need to issue an unsigned type to do
// the conversion.
if (castType.isUnsignedInteger() != doUnsigned) {
castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(),
/*isSigned=*/!doUnsigned);
}

Value actualOp = adaptor.getIn();
// Adapt the signedness of the operand if necessary
if (operandType.isUnsignedInteger() != doUnsigned) {
Type correctSignednessType =
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
/*isSigned=*/!doUnsigned);
actualOp = rewriter.template create<emitc::CastOp>(
op.getLoc(), correctSignednessType, actualOp);
}

auto result = rewriter.template create<emitc::CastOp>(op.getLoc(), castType,
actualOp);

// Cast to the expected output type
if (castType != opReturnType) {
result = rewriter.template create<emitc::CastOp>(op.getLoc(),
opReturnType, result);
}

rewriter.replaceOp(op, result);
return success();
}
};

template <typename ArithOp>
class UnsignedCastConversion : public CastConversion<ArithOp, true> {
using CastConversion<ArithOp, true>::CastConversion;
};

template <typename ArithOp>
class SignedCastConversion : public CastConversion<ArithOp, false> {
using CastConversion<ArithOp, false>::CastConversion;
};

template <typename ArithOp, typename EmitCOp>
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
public:
Expand Down Expand Up @@ -313,6 +401,10 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
CmpIOpConversion,
SelectOpConversion,
// Truncation is guaranteed for unsigned types.
UnsignedCastConversion<arith::TruncIOp>,
SignedCastConversion<arith::ExtSIOp>,
UnsignedCastConversion<arith::ExtUIOp>,
ItoFCastOpConversion<arith::SIToFPOp>,
ItoFCastOpConversion<arith::UIToFPOp>,
FtoICastOpConversion<arith::FPToSIOp>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,10 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
return %t: i1
}

// -----

func.func @arith_extsi_i1_to_i32(%arg0: i1) {
// expected-error @+1 {{failed to legalize operation 'arith.extsi'}}
%idx = arith.extsi %arg0 : i1 to i32
return
}
63 changes: 63 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,66 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {

return
}

// -----

func.func @arith_trunci(%arg0: i32) -> i8 {
// CHECK-LABEL: arith_trunci
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
// CHECK: %[[CastUI:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
// CHECK: %[[Trunc:.*]] = emitc.cast %[[CastUI]] : ui32 to ui8
// CHECK: emitc.cast %[[Trunc]] : ui8 to i8
%truncd = arith.trunci %arg0 : i32 to i8

return %truncd : i8
}

// -----

func.func @arith_trunci_to_i1(%arg0: i32) -> i1 {
// CHECK-LABEL: arith_trunci_to_i1
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
// CHECK: %[[Const:.*]] = "emitc.constant"
// CHECK-SAME: value = 1
// CHECK: %[[And:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32
// CHECK: emitc.cast %[[And]] : i32 to i1
%truncd = arith.trunci %arg0 : i32 to i1

return %truncd : i1
}

// -----

func.func @arith_extsi(%arg0: i32) {
// CHECK-LABEL: arith_extsi
// CHECK-SAME: ([[Arg0:[^ ]*]]: i32)
// CHECK: emitc.cast [[Arg0]] : i32 to i64
%extd = arith.extsi %arg0 : i32 to i64

return
}

// -----

func.func @arith_extui(%arg0: i32) {
// CHECK-LABEL: arith_extui
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i32)
// CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32
// CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to ui64
// CHECK: emitc.cast %[[Conv1]] : ui64 to i64
%extd = arith.extui %arg0 : i32 to i64

return
}

// -----

func.func @arith_extui_i1_to_i32(%arg0: i1) {
// CHECK-LABEL: arith_extui_i1_to_i32
// CHECK-SAME: (%[[Arg0:[^ ]*]]: i1)
// CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i1 to ui1
// CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui1 to ui32
// CHECK: emitc.cast %[[Conv1]] : ui32 to i32
%idx = arith.extui %arg0 : i1 to i32
return
}

0 comments on commit 7630379

Please sign in to comment.