diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 265293b83f84c..ee694104dc918 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -856,6 +856,17 @@ convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) { llvm_unreachable("Unhandled rounding mode"); } +static bool isSignednessCast(Type srcType, Type dstType) { + if (srcType.isInteger() && dstType.isInteger()) { + return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); + } + if (isa(srcType) && isa(dstType)) { + return isSignednessCast(cast(srcType).getElementType(), + cast(dstType).getElementType()); + } + return false; +} + /// Converts type-casting standard operations to SPIR-V operations. template struct TypeCastingOpPattern final : public OpConversionPattern { @@ -864,42 +875,86 @@ struct TypeCastingOpPattern final : public OpConversionPattern { LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType(); - Type dstType = this->getTypeConverter()->convertType(op.getType()); - if (!dstType) - return getTypeConversionFailure(rewriter, op); + TypeRange dstTypes; + SmallVector newDstTypes; + SmallVector unrealizedConvCastSrcs; + SmallVector unrealizedConvCastDstTypes; + constexpr bool isUnrealizedConvCast = + std::is_same_v; + if constexpr (isUnrealizedConvCast) + dstTypes = op.getOutputs().getTypes(); + else + dstTypes = op.getType(); + LogicalResult matched = failure(); + for (auto [src, dstType] : llvm::zip(adaptor.getOperands(), dstTypes)) { + Type srcType = src.getType(); + // Use UnrealizedConversionCast as the bridge so that we don't need to + // pull in patterns for other dialects. + if (isUnrealizedConvCast && !isSignednessCast(srcType, dstType)) { + newDstTypes.push_back(dstType); + unrealizedConvCastSrcs.push_back(src); + unrealizedConvCastDstTypes.push_back(dstType); + continue; + } + dstType = this->getTypeConverter()->convertType(dstType); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + + if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) + return failure(); + matched = success(); + newDstTypes.push_back(dstType); + } - if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) + if (failed(matched)) return failure(); - if (dstType == srcType) { - // Due to type conversion, we are seeing the same source and target type. - // Then we can just erase this operation by forwarding its operand. - rewriter.replaceOp(op, adaptor.getOperands().front()); - } else { - // Compute new rounding mode (if any). - std::optional rm = std::nullopt; - if (auto roundingModeOp = - dyn_cast(*op)) { - if (arith::RoundingModeAttr roundingMode = - roundingModeOp.getRoundingModeAttr()) { - if (!(rm = - convertArithRoundingModeToSPIRV(roundingMode.getValue()))) { - return rewriter.notifyMatchFailure( - op->getLoc(), - llvm::formatv("unsupported rounding mode '{0}'", roundingMode)); - } + // Compute new rounding mode (if any). + Location loc = op->getLoc(); + std::optional rm = std::nullopt; + if (auto roundingModeOp = + dyn_cast(*op)) { + if (arith::RoundingModeAttr roundingMode = + roundingModeOp.getRoundingModeAttr()) { + if (!(rm = convertArithRoundingModeToSPIRV(roundingMode.getValue()))) { + return rewriter.notifyMatchFailure( + loc, + llvm::formatv("unsupported rounding mode '{0}'", roundingMode)); } } - // Create replacement op and attach rounding mode attribute (if any). - auto newOp = rewriter.template replaceOpWithNewOp( - op, dstType, adaptor.getOperands()); - if (rm) { - newOp->setAttr( - getDecorationString(spirv::Decoration::FPRoundingMode), - spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm)); + } + + llvm::DenseMap unrealizedConvCastSrcDstMap; + if (!unrealizedConvCastSrcs.empty()) { + auto newOp = rewriter.create( + loc, unrealizedConvCastDstTypes, unrealizedConvCastSrcs); + for (auto [src, dst] : + llvm::zip(unrealizedConvCastSrcs, newOp.getResults())) + unrealizedConvCastSrcDstMap[src] = dst; + } + + SmallVector newValues; + for (auto [src, dstType] : llvm::zip(adaptor.getOperands(), newDstTypes)) { + Type srcType = src.getType(); + if (dstType == srcType) { + // Due to type conversion, we are seeing the same source and target + // type. Then we can just erase this operation by forwarding its + // operand. + newValues.push_back(src); + } else if (isUnrealizedConvCast && !isSignednessCast(srcType, dstType)) { + newValues.push_back(unrealizedConvCastSrcDstMap[src]); + } else { + // Create replacement op and attach rounding mode attribute (if any). + auto newOp = rewriter.template create(loc, dstType, src); + if (rm) { + newOp->setAttr( + getDecorationString(spirv::Decoration::FPRoundingMode), + spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm)); + } + newValues.push_back(newOp.getResult()); } } + rewriter.replaceOp(op, newValues); return success(); } }; @@ -1331,6 +1386,7 @@ void mlir::arith::populateArithToSPIRVPatterns( TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, + TypeCastingOpPattern, CmpIOpBooleanPattern, CmpIOpPattern, CmpFOpNanNonePattern, CmpFOpPattern, AddUIExtendedOpPattern, @@ -1385,8 +1441,17 @@ struct ConvertArithToSPIRVPass SPIRVTypeConverter typeConverter(targetAttr, options); // Use UnrealizedConversionCast as the bridge so that we don't need to pull - // in patterns for other dialects. - target->addLegalOp(); + // in patterns for other dialects. If the UnrealizedConversionCast is + // between integers of the same bitwidth, it is either a nop or a + // signedness cast which the corresponding pattern convert to Bitcast. + target->addDynamicallyLegalOp( + [&](UnrealizedConversionCastOp op) { + for (auto [srcType, dstType] : + llvm::zip(op.getOperandTypes(), op.getResultTypes())) + if (isSignednessCast(srcType, dstType)) + return false; + return true; + }); // Fail hard when there are any remaining 'arith' ops. target->addIllegalDialect(); diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 6e2352e706acc..b9a4232758a17 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -743,6 +743,35 @@ func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) { return } +// CHECK-LABEL: @unrealized_conversion_cast +func.func @unrealized_conversion_cast(%arg0: vector<3xi64>, %arg1: i16, %arg2: f32) { + // CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64> + %0 = builtin.unrealized_conversion_cast %arg0 : vector<3xi64> to vector<3xui64> + // CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16 + %1 = builtin.unrealized_conversion_cast %arg1 : i16 to ui16 + + // CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64> + // CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16 + %2:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xui64>, ui16 + + // CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16 + %3:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xi64>, ui16 + // CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64> + %4:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xui64>, i16 + + // bitcast from float to int should be represented using arith.bitcast + // CHECK-NEXT: builtin.unrealized_conversion_cast %{{.+}} : f32 to i32 + %5 = builtin.unrealized_conversion_cast %arg2 : f32 to i32 + + // test mixed signedness and non-signedness cast + // CHECK-NEXT: builtin.unrealized_conversion_cast %{{.+}} : f32 to f16 + // CHECK-NEXT: spirv.Bitcast %{{.+}} : i32 to ui32 + %6:2 = builtin.unrealized_conversion_cast %5, %arg2 : i32, f32 to ui32, f16 + + // CHECK-NEXT: return + return +} + // CHECK-LABEL: @fpext1 func.func @fpext1(%arg0: f16) -> f64 { // CHECK: spirv.FConvert %{{.*}} : f16 to f64