diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 0bc001b5d576a..d6b1e9552fbc5 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -601,6 +601,64 @@ struct UIToFPI1Pattern final : public OpConversionPattern { } }; +/// Converts arith.uitofp/arith.sitofp to spirv.ConvertUToF/spirv.ConvertSToF. +/// When the source integer type was widened during type conversion (e.g., i8 +/// emulated as i32), the upper bits of the widened value may contain garbage. +/// This pattern cleans the upper bits before the conversion: +/// - For unsigned (IsSigned=false): mask with BitwiseAnd. +/// - For signed (IsSigned=true): sign-extend via ShiftLeftLogical + +/// ShiftRightArithmetic. +template +struct IntToFPPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type srcType = adaptor.getOperands().front().getType(); + if (isBoolScalarOrVector(srcType)) + return failure(); + + Type dstType = this->getTypeConverter()->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + + // Check if the source integer type was widened during type conversion. + unsigned originalBitwidth = + getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth(); + unsigned convertedBitwidth = + getElementTypeOrSelf(srcType).getIntOrFloatBitWidth(); + + if (originalBitwidth >= convertedBitwidth) { + rewriter.replaceOpWithNewOp(op, dstType, adaptor.getOperands()); + return success(); + } + + // The source was widened. Clean the upper bits before converting. + Location loc = op.getLoc(); + Value cleaned; + if constexpr (IsSigned) { + // Sign-extend by shifting left then arithmetic right. + unsigned shiftAmount = convertedBitwidth - originalBitwidth; + Value shiftSize = + getScalarOrVectorConstInt(srcType, shiftAmount, rewriter, loc); + Value shifted = spirv::ShiftLeftLogicalOp::create( + rewriter, loc, srcType, adaptor.getIn(), shiftSize); + cleaned = spirv::ShiftRightArithmeticOp::create(rewriter, loc, srcType, + shifted, shiftSize); + } else { + // Zero-extend by masking off the upper bits. + Value mask = getScalarOrVectorConstInt( + srcType, llvm::maskTrailingOnes(originalBitwidth), rewriter, + loc); + cleaned = spirv::BitwiseAndOp::create(rewriter, loc, srcType, + adaptor.getIn(), mask); + } + rewriter.replaceOpWithNewOp(op, dstType, cleaned); + return success(); + } +}; + //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// @@ -1376,8 +1434,9 @@ void mlir::arith::populateArithToSPIRVPatterns( TypeCastingOpPattern, TruncIPattern, TruncII1Pattern, TypeCastingOpPattern, - TypeCastingOpPattern, UIToFPI1Pattern, - TypeCastingOpPattern, + IntToFPPattern, + UIToFPI1Pattern, + IntToFPPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 3cb5294598994..9c726b8643a46 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -1491,6 +1491,66 @@ func.func @float_scalar(%arg0: f16) { return } +// When i8 is emulated as i32 (no Int8 capability), uitofp from i8 needs a +// bitmask to clear upper bits that may contain garbage from sign-extension +// during packed byte extraction. +// CHECK-LABEL: @uitofp_i8_emulated_f32 +func.func @uitofp_i8_emulated_f32(%arg0: i8) -> f32 { + // CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32 + // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %{{.*}}, %[[MASK]] : i32 + // CHECK: spirv.ConvertUToF %[[MASKED]] : i32 to f32 + %0 = arith.uitofp %arg0 : i8 to f32 + return %0 : f32 +} + +// CHECK-LABEL: @uitofp_i16_emulated_f32 +func.func @uitofp_i16_emulated_f32(%arg0: i16) -> f32 { + // CHECK: %[[MASK:.+]] = spirv.Constant 65535 : i32 + // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %{{.*}}, %[[MASK]] : i32 + // CHECK: spirv.ConvertUToF %[[MASKED]] : i32 to f32 + %0 = arith.uitofp %arg0 : i16 to f32 + return %0 : f32 +} + +// CHECK-LABEL: @uitofp_vec_i8_emulated_f32 +func.func @uitofp_vec_i8_emulated_f32(%arg0: vector<4xi8>) -> vector<4xf32> { + // CHECK: %[[MASK:.+]] = spirv.Constant dense<255> : vector<4xi32> + // CHECK: %[[MASKED:.+]] = spirv.BitwiseAnd %{{.*}}, %[[MASK]] : vector<4xi32> + // CHECK: spirv.ConvertUToF %[[MASKED]] : vector<4xi32> to vector<4xf32> + %0 = arith.uitofp %arg0 : vector<4xi8> to vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: @sitofp_i8_emulated_f32 +func.func @sitofp_i8_emulated_f32(%arg0: i8) -> f32 { + // CHECK: %[[SHIFT:.+]] = spirv.Constant 24 : i32 + // CHECK: %[[SHL:.+]] = spirv.ShiftLeftLogical %{{.*}}, %[[SHIFT]] : i32, i32 + // CHECK: %[[SHR:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[SHIFT]] : i32, i32 + // CHECK: spirv.ConvertSToF %[[SHR]] : i32 to f32 + %0 = arith.sitofp %arg0 : i8 to f32 + return %0 : f32 +} + +// CHECK-LABEL: @sitofp_i16_emulated_f32 +func.func @sitofp_i16_emulated_f32(%arg0: i16) -> f32 { + // CHECK: %[[SHIFT:.+]] = spirv.Constant 16 : i32 + // CHECK: %[[SHL:.+]] = spirv.ShiftLeftLogical %{{.*}}, %[[SHIFT]] : i32, i32 + // CHECK: %[[SHR:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[SHIFT]] : i32, i32 + // CHECK: spirv.ConvertSToF %[[SHR]] : i32 to f32 + %0 = arith.sitofp %arg0 : i16 to f32 + return %0 : f32 +} + +// CHECK-LABEL: @sitofp_vec_i8_emulated_f32 +func.func @sitofp_vec_i8_emulated_f32(%arg0: vector<4xi8>) -> vector<4xf32> { + // CHECK: %[[SHIFT:.+]] = spirv.Constant dense<24> : vector<4xi32> + // CHECK: %[[SHL:.+]] = spirv.ShiftLeftLogical %{{.*}}, %[[SHIFT]] : vector<4xi32>, vector<4xi32> + // CHECK: %[[SHR:.+]] = spirv.ShiftRightArithmetic %[[SHL]], %[[SHIFT]] : vector<4xi32>, vector<4xi32> + // CHECK: spirv.ConvertSToF %[[SHR]] : vector<4xi32> to vector<4xf32> + %0 = arith.sitofp %arg0 : vector<4xi8> to vector<4xf32> + return %0 : vector<4xf32> +} + } // end module // -----