diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f8e1ab38e80d4..57a0c67e82c47 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -6380,6 +6380,57 @@ def NVVM_SubFOp : NVVM_FloatBinaryOp<"subf"> {
   }];
 }
 
+def NVVM_FmaOp : NVVM_Op<"fma", [Pure, SameOperandsAndResultType]> {
+  let summary = [{
+    Performs floating point fused multiply-add operation on three operands of
+    the same scalar or vector type
+  }];
+  let description = [{
+    The `nvvm.fma` operation performs floating point fused multiply-add of
+    three operands of the same type.
+
+    The rounding mode is specified by the `rnd` attribute, saturation mode by
+    the `sat` attribute, flush-to-zero by the `ftz` attribute, and ReLU by the
+    `relu` attribute.
+
+    Out-of-bounds (OOB) behavior is controlled by the `oob` attribute. `oob`
+    clamps the result to 0 if any of the operands is `OOB NaN` (see [Tensors](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensors)).
+
+    The verifier enforces the following restrictions:
+    - `rnd` must be specified; f16 and bf16 (and their vector forms) only
+      support the `rn` rounding mode.
+    - `relu` and `oob` are only supported for f16 and bf16.
+    - `relu` cannot be combined with `sat`; `oob` cannot be combined with
+      `sat` or `ftz`.
+    - `sat` and `ftz` are not supported for bf16 and f64.
+
+    For more information, see PTX ISA:
+    - [floating point fused multiply-add](https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-fma)
+    - [half-precision floating point fused multiply-add](https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-fma)
+  }];
+  let arguments = (ins
+    SIMTFloatType:$a,
+    SIMTFloatType:$b,
+    SIMTFloatType:$c,
+    FPArithRoundingMode:$rnd,
+    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+    DefaultValuedAttr<BoolAttr, "false">:$ftz,
+    DefaultValuedAttr<BoolAttr, "false">:$relu,
+    DefaultValuedAttr<BoolAttr, "false">:$oob
+  );
+  let results = (outs SIMTFloatType:$res);
+  let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($a)";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    static void lowerFmaToLLVMIR(
+      Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
+  let llvmBuilder = [{
+    NVVM::FmaOp::lowerFmaToLLVMIR(*op, moduleTranslation, builder);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM tensormap.replace Op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 6ccd59cec65bc..7d49aa3878ebe 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -3104,6 +3104,56 @@ LogicalResult NVVM::AddFOp::verify() { return verifyAddSubFOp(*this); }
 
 LogicalResult NVVM::SubFOp::verify() { return verifyAddSubFOp(*this); }
 
+/// Verifies the attribute combination of `nvvm.fma` against the element type
+/// of its result: a rounding mode is mandatory, and `relu`, `oob`, `sat` and
+/// `ftz` are each only accepted for a subset of the supported float types.
+LogicalResult NVVM::FmaOp::verify() {
+  auto opType = getRes().getType();
+  mlir::NVVM::FPRoundingMode rndMode = getRnd();
+  mlir::NVVM::SaturationMode satMode = getSat();
+  bool isFTZ = getFtz();
+  bool isRelu = getRelu();
+  bool hasOOB = getOob();
+
+  auto getBaseFType = [](Type type) -> Type {
+    if (isa<VectorType>(type))
+      return cast<VectorType>(type).getElementType();
+    return type;
+  };
+
+  auto opBaseType = getBaseFType(opType);
+
+  if (rndMode == NVVM::FPRoundingMode::NONE)
+    return emitOpError("rounding mode must be specified");
+
+  if (isRelu && satMode == NVVM::SaturationMode::SAT)
+    return emitOpError("relu and saturation are not supported together");
+
+  if (hasOOB && (satMode == NVVM::SaturationMode::SAT || isFTZ))
+    return emitOpError("oob is not supported with saturation or FTZ");
+
+  if (!(opBaseType.isF16() || opBaseType.isBF16()) && (isRelu || hasOOB))
+    return emitOpError("relu and oob are only supported for f16 and bf16");
+
+  if (opBaseType.isF64() && (satMode != NVVM::SaturationMode::NONE || isFTZ))
+    return emitOpError("FTZ and saturation are not supported for f64 type");
+
+  if (opBaseType.isF16() && rndMode != NVVM::FPRoundingMode::RN)
+    return emitOpError(
+        "only RN rounding mode is supported for f16 and vector<2xf16>");
+
+  if (opBaseType.isBF16()) {
+    if (rndMode != NVVM::FPRoundingMode::RN)
+      return emitOpError(
+          "only RN rounding mode is supported for bf16 and vector<2xbf16>");
+    if (satMode != NVVM::SaturationMode::NONE || isFTZ)
+      return emitOpError(
+          "FTZ and saturation are not supported for bf16 and vector<2xbf16>");
+  }
+
+  return success();
+}
+
 /// Packs the given `field` into the `result`.
 /// The `result` is 64-bits and each `field` can be 32-bits or narrower.
 static llvm::Value *
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 092643f408ce6..5e5f6700c9fd7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -446,6 +446,41 @@ getFenceProxySyncRestrictID(NVVM::MemOrderKind order) {
                    nvvm_fence_proxy_async_generic_release_sync_restrict_space_cta_scope_cluster;
 }
 
+// Calls the LLVM intrinsic `IID` on `operands`. `opTypeLLVM` selects the
+// lowering: for f32/f64 vector types the intrinsic is called once per element
+// (every operand is expected to be a vector of that type) and the scalar
+// results are packed back into a vector; any other type is passed through.
+// `retType` is forwarded to `createIntrinsicCall` as the return type of the
+// call. On the per-element path a non-null `retType` is reduced to its scalar
+// type, so callers may pass either the vector type or its element type.
+static llvm::Value *
+createScalarizedIntrinsicCall(llvm::IRBuilderBase &builder,
+                              llvm::Intrinsic::ID IID, llvm::Type *opTypeLLVM,
+                              ArrayRef<llvm::Value *> operands,
+                              llvm::Type *retType) {
+  llvm::Type *elemType = opTypeLLVM->getScalarType();
+  const bool scalarize = opTypeLLVM->isVectorTy() &&
+                         (elemType->isFloatTy() || elemType->isDoubleTy());
+  if (!scalarize)
+    return createIntrinsicCall(builder, IID, retType, operands);
+
+  // Take the lane count from the type instead of hard-coding two lanes.
+  auto *vecType = llvm::cast<llvm::FixedVectorType>(opTypeLLVM);
+  llvm::Type *elemRetType = retType ? retType->getScalarType() : nullptr;
+  llvm::Value *result = llvm::PoisonValue::get(vecType);
+  for (unsigned i = 0, e = vecType->getNumElements(); i < e; ++i) {
+    llvm::Value *lane = builder.getInt32(i);
+    llvm::SmallVector<llvm::Value *> scalarArgs;
+    scalarArgs.reserve(operands.size());
+    for (llvm::Value *op : operands)
+      scalarArgs.push_back(builder.CreateExtractElement(op, lane));
+    llvm::Value *res =
+        createIntrinsicCall(builder, IID, elemRetType, scalarArgs);
+    result = builder.CreateInsertElement(result, res, lane);
+  }
+  return result;
+}
+
 void NVVM::AddFOp::lowerAddFToLLVMIR(llvm::Value *argLHS, llvm::Value *argRHS,
                                      Value res, NVVM::FPRoundingMode rndMode,
                                      NVVM::SaturationMode satMode, bool isFTZ,
@@ -493,31 +528,9 @@ void NVVM::AddFOp::lowerAddFToLLVMIR(llvm::Value *argLHS, llvm::Value *argRHS,
       llvm::Intrinsic::nvvm_add_rp_d, llvm::Intrinsic::nvvm_add_rz_d};
 
   auto addIntrinsic = [&](llvm::Intrinsic::ID IID) -> llvm::Value * {
-    auto createAddIntrinsicCall = [&](llvm::Intrinsic::ID IID, llvm::Value *LHS,
-                                      llvm::Value *RHS) -> llvm::CallInst * {
-      llvm::SmallVector callArgs;
-      callArgs.push_back(LHS);
-      callArgs.push_back(RHS);
-      return createIntrinsicCall(builder, IID, callArgs);
-    };
-
-    if (isVectorOp && (opTypeLLVM->getScalarType()->isFloatTy() ||
-                       opTypeLLVM->getScalarType()->isDoubleTy())) {
-      llvm::Value *result = llvm::PoisonValue::get(
-          llvm::FixedVectorType::get(opTypeLLVM->getScalarType(), 2));
-      for (int64_t i = 0; i < 2; ++i) {
-        llvm::Value *lhsElemi =
-            builder.CreateExtractElement(argLHS, builder.getInt32(i));
-        llvm::Value *rhsElemi =
-            builder.CreateExtractElement(argRHS, builder.getInt32(i));
-        llvm::Value *sum = createAddIntrinsicCall(IID, lhsElemi, rhsElemi);
-        result = builder.CreateInsertElement(result, sum, builder.getInt32(i));
-      };
-      return result;
-    }
-
-    return createAddIntrinsicCall(IID, argLHS, argRHS);
-  }; // addIntrinsic end
+    return createScalarizedIntrinsicCall(builder, IID, opTypeLLVM,
+                                         {argLHS, argRHS}, opTypeLLVM);
+  };
 
   // f16 + f16 -> f16 / vector<2xf16> + vector<2xf16> -> vector<2xf16>
   // FIXME: Allow lowering to add.rn.ftz.f16x2 and add.rn.ftz.f16 here when the
@@ -557,6 +570,135 @@ void NVVM::AddFOp::lowerAddFToLLVMIR(llvm::Value *argLHS, llvm::Value *argRHS,
   }
 }
 
+// Lowers `nvvm.fma` to the NVVM FMA intrinsic selected by the element type of
+// the result, by whether the op is a vector op and by its `rnd`, `sat`, `ftz`,
+// `relu` and `oob` attributes, then maps the op result to the emitted value.
+// The table lookups below rely on FmaOp::verify() having rejected unsupported
+// attribute/type combinations.
+void NVVM::FmaOp::lowerFmaToLLVMIR(Operation &op, LLVM::ModuleTranslation &mt,
+                                   llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::FmaOp>(op);
+  mlir::NVVM::FPRoundingMode rndMode = thisOp.getRnd();
+  unsigned rndIndex = static_cast<unsigned>(rndMode) - 1; // 1-4 mapped to 0-3
+  mlir::NVVM::SaturationMode satMode = thisOp.getSat();
+  bool isFTZ = thisOp.getFtz();
+  bool isRelu = thisOp.getRelu();
+  bool isSat = satMode == NVVM::SaturationMode::SAT;
+  bool isOOB = thisOp.getOob();
+
+  mlir::Type opType = thisOp.getRes().getType();
+  llvm::Type *opTypeLLVM = mt.convertType(opType);
+  llvm::Type *elemTypeLLVM = opTypeLLVM->getScalarType();
+  bool isVectorFma = opTypeLLVM->isVectorTy();
+
+  llvm::Value *argA = mt.lookupValue(thisOp.getA());
+  llvm::Value *argB = mt.lookupValue(thisOp.getB());
+  llvm::Value *argC = mt.lookupValue(thisOp.getC());
+
+  // Indexed by (relu << 3) | (sat << 2) | (ftz << 1) | isVector.
+  static constexpr llvm::Intrinsic::ID f16IDs[] = {
+      llvm::Intrinsic::nvvm_fma_rn_f16,
+      llvm::Intrinsic::nvvm_fma_rn_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_f16,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_sat_f16,
+      llvm::Intrinsic::nvvm_fma_rn_sat_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_relu_f16,
+      llvm::Intrinsic::nvvm_fma_rn_relu_f16x2,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_relu_f16x2};
+
+  // Indexed by (relu << 1) | isVector.
+  static constexpr llvm::Intrinsic::ID bf16IDs[] = {
+      llvm::Intrinsic::nvvm_fma_rn_bf16, llvm::Intrinsic::nvvm_fma_rn_bf16x2,
+      llvm::Intrinsic::nvvm_fma_rn_relu_bf16,
+      llvm::Intrinsic::nvvm_fma_rn_relu_bf16x2};
+
+  // Four groups of `numRndModes` entries (plain, sat, ftz, ftz+sat), each
+  // ordered RN, RM, RP, RZ.
+  static constexpr llvm::Intrinsic::ID f32IDs[] = {
+      llvm::Intrinsic::nvvm_fma_rn_f,
+      llvm::Intrinsic::nvvm_fma_rm_f,
+      llvm::Intrinsic::nvvm_fma_rp_f,
+      llvm::Intrinsic::nvvm_fma_rz_f,
+      llvm::Intrinsic::nvvm_fma_rn_sat_f,
+      llvm::Intrinsic::nvvm_fma_rm_sat_f,
+      llvm::Intrinsic::nvvm_fma_rp_sat_f,
+      llvm::Intrinsic::nvvm_fma_rz_sat_f,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_f,
+      llvm::Intrinsic::nvvm_fma_rm_ftz_f,
+      llvm::Intrinsic::nvvm_fma_rp_ftz_f,
+      llvm::Intrinsic::nvvm_fma_rz_ftz_f,
+      llvm::Intrinsic::nvvm_fma_rn_ftz_sat_f,
+      llvm::Intrinsic::nvvm_fma_rm_ftz_sat_f,
+      llvm::Intrinsic::nvvm_fma_rp_ftz_sat_f,
+      llvm::Intrinsic::nvvm_fma_rz_ftz_sat_f,
+  };
+
+  // Ordered RN, RM, RP, RZ.
+  static constexpr llvm::Intrinsic::ID f64IDs[] = {
+      llvm::Intrinsic::nvvm_fma_rn_d, llvm::Intrinsic::nvvm_fma_rm_d,
+      llvm::Intrinsic::nvvm_fma_rp_d, llvm::Intrinsic::nvvm_fma_rz_d};
+  const unsigned numRndModes = 4; // RN, RM, RP, RZ
+
+  // Emits the (possibly scalarized) intrinsic call and binds the op result to
+  // the value it produces.
+  auto emitFma = [&](llvm::Intrinsic::ID IID, llvm::Type *retType) {
+    llvm::Value *result = createScalarizedIntrinsicCall(
+        builder, IID, opTypeLLVM, {argA, argB, argC}, /*retType=*/retType);
+    mt.mapValue(thisOp.getRes(), result);
+  };
+
+  const bool isHalf = elemTypeLLVM->isHalfTy();
+  const bool isBFloat = elemTypeLLVM->isBFloatTy();
+
+  // f16 and bf16 (scalar and vector) share the two OOB intrinsics; the
+  // type-specific variant is selected through the return type.
+  if (isOOB && (isHalf || isBFloat)) {
+    emitFma(isRelu ? llvm::Intrinsic::nvvm_fma_rn_oob_relu
+                   : llvm::Intrinsic::nvvm_fma_rn_oob,
+            opTypeLLVM);
+    return;
+  }
+
+  // f16 * f16 + f16 -> f16, or the same on vector<2xf16>
+  if (isHalf) {
+    // The op verifier rejects relu together with sat, so the index stays
+    // within the twelve table entries.
+    unsigned index =
+        (isRelu << 3) | (isSat << 2) | (isFTZ << 1) | isVectorFma;
+    emitFma(f16IDs[index], opTypeLLVM);
+    return;
+  }
+
+  // bf16 * bf16 + bf16 -> bf16, or the same on vector<2xbf16>
+  if (isBFloat) {
+    unsigned index = (isRelu << 1) | isVectorFma;
+    emitFma(bf16IDs[index], opTypeLLVM);
+    return;
+  }
+
+  // f32 and f64 select the intrinsic by rounding mode.
+  assert(rndIndex < numRndModes && "unexpected rounding mode for nvvm.fma");
+
+  // f64 * f64 + f64 -> f64; vector<2xf64> is lowered per element
+  if (elemTypeLLVM->isDoubleTy()) {
+    emitFma(f64IDs[rndIndex], elemTypeLLVM);
+    return;
+  }
+
+  // f32 * f32 + f32 -> f32; vector<2xf32> is lowered per element
+  if (elemTypeLLVM->isFloatTy()) {
+    unsigned index = ((isFTZ << 1) | isSat) * numRndModes + rndIndex;
+    emitFma(f32IDs[index], elemTypeLLVM);
+    return;
+  }
+
+  llvm_unreachable("unsupported element type for nvvm.fma");
+}
+
 namespace {
 /// Implementation of the dialect interface that converts operations belonging
 /// to the NVVM dialect to LLVM IR.
diff --git a/mlir/test/Target/LLVMIR/nvvm/fma/fma.mlir b/mlir/test/Target/LLVMIR/nvvm/fma/fma.mlir new file mode 100644 index 0000000000000..236175daff21e --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/fma/fma.mlir @@ -0,0 +1,114 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @fma_f16(%a: f16, %b: f16, %c: f16) -> f16 { + // CHECK-LABEL: define half @fma_f16(half %0, half %1, half %2) { + // CHECK-NEXT: %4 = call half @llvm.nvvm.fma.rn.f16(half %0, half %1, half %2) + // CHECK-NEXT: %5 = call half @llvm.nvvm.fma.rn.ftz.f16(half %0, half %1, half %4) + // CHECK-NEXT: %6 = call half @llvm.nvvm.fma.rn.sat.f16(half %0, half %1, half %5) + // CHECK-NEXT: %7 = call half @llvm.nvvm.fma.rn.ftz.sat.f16(half %0, half %1, half %6) + // CHECK-NEXT: %8 = call half @llvm.nvvm.fma.rn.relu.f16(half %0, half %1, half %7) + // CHECK-NEXT: %9 = call half @llvm.nvvm.fma.rn.ftz.relu.f16(half %0, half %1, half %8) + // CHECK-NEXT: %10 = call half @llvm.nvvm.fma.rn.oob.f16(half %0, half %1, half %9) + // CHECK-NEXT: %11 = call half @llvm.nvvm.fma.rn.oob.relu.f16(half %0, half %1, half %10) + // CHECK-NEXT: ret half %11 + // CHECK-NEXT: } + %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode} : f16 + %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode, ftz = true} : f16 + %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : f16 + %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode, ftz = true} : f16 + %f4 = nvvm.fma %a, %b, %f3 {rnd = #nvvm.fp_rnd_mode, relu = true} : f16 + %f5 = nvvm.fma %a, %b, %f4 {rnd = #nvvm.fp_rnd_mode, relu = true, ftz = true} : f16 + %f6 = nvvm.fma %a, %b, %f5 {rnd = #nvvm.fp_rnd_mode, oob = true} : f16 + %f7 = nvvm.fma %a, %b, %f6 {rnd = #nvvm.fp_rnd_mode, oob = true, relu = true} : f16 + llvm.return %f7 : f16 +} + +llvm.func @fma_bf16(%a: bf16, %b: bf16, %c: bf16) -> bf16 { + // CHECK-LABEL: define bfloat @fma_bf16(bfloat %0, bfloat %1, bfloat %2) { + // CHECK-NEXT: %4 = call bfloat 
@llvm.nvvm.fma.rn.bf16(bfloat %0, bfloat %1, bfloat %2) + // CHECK-NEXT: %5 = call bfloat @llvm.nvvm.fma.rn.relu.bf16(bfloat %0, bfloat %1, bfloat %4) + // CHECK-NEXT: %6 = call bfloat @llvm.nvvm.fma.rn.oob.bf16(bfloat %0, bfloat %1, bfloat %5) + // CHECK-NEXT: %7 = call bfloat @llvm.nvvm.fma.rn.oob.relu.bf16(bfloat %0, bfloat %1, bfloat %6) + // CHECK-NEXT: ret bfloat %7 + // CHECK-NEXT: } + %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode} : bf16 + %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode, relu = true} : bf16 + %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode, oob = true} : bf16 + %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode, oob = true, relu = true} : bf16 + llvm.return %f3 : bf16 +} + +llvm.func @fma_f32_rn(%a: f32, %b: f32, %c: f32) -> f32 { + // CHECK-LABEL: define float @fma_f32_rn(float %0, float %1, float %2) { + // CHECK-NEXT: %4 = call float @llvm.nvvm.fma.rn.f(float %0, float %1, float %2) + // CHECK-NEXT: %5 = call float @llvm.nvvm.fma.rn.ftz.f(float %0, float %1, float %4) + // CHECK-NEXT: %6 = call float @llvm.nvvm.fma.rn.sat.f(float %0, float %1, float %5) + // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rn.ftz.sat.f(float %0, float %1, float %6) + // CHECK-NEXT: ret float %7 + // CHECK-NEXT: } + %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode} : f32 + %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode, ftz = true} : f32 + %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : f32 + %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode, ftz = true} : f32 + llvm.return %f3 : f32 +} + +llvm.func @fma_f32_rm(%a: f32, %b: f32, %c: f32) -> f32 { + // CHECK-LABEL: define float @fma_f32_rm(float %0, float %1, float %2) { + // CHECK-NEXT: %4 = call float @llvm.nvvm.fma.rm.f(float %0, float %1, float %2) + // CHECK-NEXT: %5 = call float @llvm.nvvm.fma.rm.ftz.f(float %0, float %1, float %4) + // CHECK-NEXT: %6 = call float @llvm.nvvm.fma.rm.sat.f(float %0, float %1, float %5) + // CHECK-NEXT: 
%7 = call float @llvm.nvvm.fma.rm.ftz.sat.f(float %0, float %1, float %6) + // CHECK-NEXT: ret float %7 + // CHECK-NEXT: } + %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode} : f32 + %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode, ftz = true} : f32 + %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : f32 + %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode, ftz = true} : f32 + llvm.return %f3 : f32 +} + +llvm.func @fma_f32_rp(%a: f32, %b: f32, %c: f32) -> f32 { + // CHECK-LABEL: define float @fma_f32_rp(float %0, float %1, float %2) { + // CHECK-NEXT: %4 = call float @llvm.nvvm.fma.rp.f(float %0, float %1, float %2) + // CHECK-NEXT: %5 = call float @llvm.nvvm.fma.rp.ftz.f(float %0, float %1, float %4) + // CHECK-NEXT: %6 = call float @llvm.nvvm.fma.rp.sat.f(float %0, float %1, float %5) + // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rp.ftz.sat.f(float %0, float %1, float %6) + // CHECK-NEXT: ret float %7 + // CHECK-NEXT: } + %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode} : f32 + %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode, ftz = true} : f32 + %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : f32 + %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode, ftz = true} : f32 + llvm.return %f3 : f32 +} + +llvm.func @fma_f32_rz(%a: f32, %b: f32, %c: f32) -> f32 { + // CHECK-LABEL: define float @fma_f32_rz(float %0, float %1, float %2) { + // CHECK-NEXT: %4 = call float @llvm.nvvm.fma.rz.f(float %0, float %1, float %2) + // CHECK-NEXT: %5 = call float @llvm.nvvm.fma.rz.ftz.f(float %0, float %1, float %4) + // CHECK-NEXT: %6 = call float @llvm.nvvm.fma.rz.sat.f(float %0, float %1, float %5) + // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rz.ftz.sat.f(float %0, float %1, float %6) + // CHECK-NEXT: ret float %7 + // CHECK-NEXT: } + %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode} : f32 + %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode, ftz = true} : f32 + 
%f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : f32 + %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode, ftz = true} : f32 + llvm.return %f3 : f32 +} + +llvm.func @fma_f64(%a: f64, %b: f64, %c: f64) -> f64 { + // CHECK-LABEL: define double @fma_f64(double %0, double %1, double %2) { + // CHECK-NEXT: %4 = call double @llvm.nvvm.fma.rn.d(double %0, double %1, double %2) + // CHECK-NEXT: %5 = call double @llvm.nvvm.fma.rm.d(double %0, double %1, double %4) + // CHECK-NEXT: %6 = call double @llvm.nvvm.fma.rp.d(double %0, double %1, double %5) + // CHECK-NEXT: %7 = call double @llvm.nvvm.fma.rz.d(double %0, double %1, double %6) + // CHECK-NEXT: ret double %7 + // CHECK-NEXT: } + %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode} : f64 + %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode} : f64 + %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode} : f64 + %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode} : f64 + llvm.return %f3 : f64 +} diff --git a/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir new file mode 100644 index 0000000000000..ea92b707b65de --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/fma/fma_invalid.mlir @@ -0,0 +1,89 @@ +// RUN: mlir-translate --mlir-to-llvmir --split-input-file --verify-diagnostics %s + +// ----- + +llvm.func @fma_invalid_rnd_mode(%a : f16, %b : f16, %c : f16) -> f16 { + // expected-error@+1 {{rounding mode must be specified}} + %f1 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode} : f16 + llvm.return %f1 : f16 +} + +// ----- + +llvm.func @fma_invalid_sat_mode(%a : f16, %b : f16, %c : f16) -> f16 { + // expected-error@+1 {{attribute 'sat' failed to satisfy constraint: Describes the saturation mode whose value is one of {none, sat}}} + %f1 = nvvm.fma %a, %b, %c {sat = #nvvm.sat_mode, rnd = #nvvm.fp_rnd_mode} : f16 + llvm.return %f1 : f16 +} + +// ----- + +llvm.func @fma_invalid_relu_sat(%a : f16, %b : f16, %c : f16) -> f16 { + 
// expected-error@+1 {{relu and saturation are not supported together}} + %f1 = nvvm.fma %a, %b, %c {relu = true, sat = #nvvm.sat_mode, rnd = #nvvm.fp_rnd_mode} : f16 + llvm.return %f1 : f16 +} + +// ----- + +llvm.func @fma_invalid_oob_sat(%a : f16, %b : f16, %c : f16) -> f16 { + // expected-error@+1 {{oob is not supported with saturation or FTZ}} + %f1 = nvvm.fma %a, %b, %c {oob = true, sat = #nvvm.sat_mode, rnd = #nvvm.fp_rnd_mode} : f16 + llvm.return %f1 : f16 +} + +// ----- + +llvm.func @fma_invalid_oob_f64(%a : f64, %b : f64, %c : f64) -> f64 { + // expected-error@+1 {{relu and oob are only supported for f16 and bf16}} + %f1 = nvvm.fma %a, %b, %c {oob = true, rnd = #nvvm.fp_rnd_mode} : f64 + llvm.return %f1 : f64 +} + +// ----- + +llvm.func @fma_invalid_relu_oob(%a : f32, %b : f32, %c : f32) -> f32 { + // expected-error@+1 {{relu and oob are only supported for f16 and bf16}} + %f1 = nvvm.fma %a, %b, %c {relu = true, rnd = #nvvm.fp_rnd_mode} : f32 + llvm.return %f1 : f32 +} + +// ----- + +llvm.func @fma_invalid_ftz_sat_f64(%a : f64, %b : f64, %c : f64) -> f64 { + // expected-error@+1 {{FTZ and saturation are not supported for f64 type}} + %f1 = nvvm.fma %a, %b, %c {ftz = true, sat = #nvvm.sat_mode, rnd = #nvvm.fp_rnd_mode} : f64 + llvm.return %f1 : f64 +} + +// ----- + +llvm.func @fma_invalid_v2f16_rnd_mode(%a : vector<2xf16>, %b : vector<2xf16>, %c : vector<2xf16>) -> vector<2xf16> { + // expected-error@+1 {{only RN rounding mode is supported for f16 and vector<2xf16>}} + %f1 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> + llvm.return %f1 : vector<2xf16> +} + +// ----- + +llvm.func @fma_invalid_v2bf16_rnd_mode(%a : vector<2xbf16>, %b : vector<2xbf16>, %c : vector<2xbf16>) -> vector<2xbf16> { + // expected-error@+1 {{only RN rounding mode is supported for bf16 and vector<2xbf16>}} + %f1 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + llvm.return %f1 : vector<2xbf16> +} + +// ----- + +llvm.func @fma_invalid_ftz_v2bf16(%a : 
vector<2xbf16>, %b : vector<2xbf16>, %c : vector<2xbf16>) -> vector<2xbf16> { + // expected-error@+1 {{FTZ and saturation are not supported for bf16 and vector<2xbf16>}} + %f1 = nvvm.fma %a, %b, %c {ftz = true, rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + llvm.return %f1 : vector<2xbf16> +} + +// ----- + +llvm.func @fma_invalid_sat_v2bf16(%a : vector<2xbf16>, %b : vector<2xbf16>, %c : vector<2xbf16>) -> vector<2xbf16> { + // expected-error@+1 {{FTZ and saturation are not supported for bf16 and vector<2xbf16>}} + %f1 = nvvm.fma %a, %b, %c {sat = #nvvm.sat_mode, rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + llvm.return %f1 : vector<2xbf16> +} diff --git a/mlir/test/Target/LLVMIR/nvvm/fma/fma_vector.mlir b/mlir/test/Target/LLVMIR/nvvm/fma/fma_vector.mlir new file mode 100644 index 0000000000000..020bdcfc27705 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/fma/fma_vector.mlir @@ -0,0 +1,294 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +llvm.func @fma_f16(%a: vector<2xf16>, %b: vector<2xf16>, %c: vector<2xf16>) -> vector<2xf16> { + // CHECK-LABEL: define <2 x half> @fma_f16(<2 x half> %0, <2 x half> %1, <2 x half> %2) { + // CHECK-NEXT: %4 = call <2 x half> @llvm.nvvm.fma.rn.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) + // CHECK-NEXT: %5 = call <2 x half> @llvm.nvvm.fma.rn.ftz.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %4) + // CHECK-NEXT: %6 = call <2 x half> @llvm.nvvm.fma.rn.sat.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %5) + // CHECK-NEXT: %7 = call <2 x half> @llvm.nvvm.fma.rn.ftz.sat.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %6) + // CHECK-NEXT: %8 = call <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %7) + // CHECK-NEXT: %9 = call <2 x half> @llvm.nvvm.fma.rn.ftz.relu.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %8) + // CHECK-NEXT: %10 = call <2 x half> @llvm.nvvm.fma.rn.oob.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %9) + // CHECK-NEXT: %11 = call <2 x half> 
@llvm.nvvm.fma.rn.oob.relu.v2f16(<2 x half> %0, <2 x half> %1, <2 x half> %10) + // CHECK-NEXT: ret <2 x half> %11 + // CHECK-NEXT: } + %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> + %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode, ftz = true} : vector<2xf16> + %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xf16> + %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode, ftz = true} : vector<2xf16> + %f4 = nvvm.fma %a, %b, %f3 {rnd = #nvvm.fp_rnd_mode, relu = true} : vector<2xf16> + %f5 = nvvm.fma %a, %b, %f4 {rnd = #nvvm.fp_rnd_mode, relu = true, ftz = true} : vector<2xf16> + %f6 = nvvm.fma %a, %b, %f5 {rnd = #nvvm.fp_rnd_mode, oob = true} : vector<2xf16> + %f7 = nvvm.fma %a, %b, %f6 {rnd = #nvvm.fp_rnd_mode, oob = true, relu = true} : vector<2xf16> + llvm.return %f7 : vector<2xf16> +} + +llvm.func @fma_bf16(%a: vector<2xbf16>, %b: vector<2xbf16>, %c: vector<2xbf16>) -> vector<2xbf16> { + // CHECK-LABEL: define <2 x bfloat> @fma_bf16(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %2) { + // CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.fma.rn.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %2) + // CHECK-NEXT: %5 = call <2 x bfloat> @llvm.nvvm.fma.rn.relu.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %4) + // CHECK-NEXT: %6 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.v2bf16(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %5) + // CHECK-NEXT: %7 = call <2 x bfloat> @llvm.nvvm.fma.rn.oob.relu.v2bf16(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %6) + // CHECK-NEXT: ret <2 x bfloat> %7 + // CHECK-NEXT: } + %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode, relu = true} : vector<2xbf16> + %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode, oob = true} : vector<2xbf16> + %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode, oob = true, relu = true} : vector<2xbf16> + llvm.return %f3 : vector<2xbf16> +} + 
+llvm.func @fma_f32_rn(%a: vector<2xf32>, %b: vector<2xf32>, %c: vector<2xf32>) -> vector<2xf32> { + // CHECK-LABEL: define <2 x float> @fma_f32_rn(<2 x float> %0, <2 x float> %1, <2 x float> %2) { + // CHECK-NEXT: %4 = extractelement <2 x float> %0, i32 0 + // CHECK-NEXT: %5 = extractelement <2 x float> %1, i32 0 + // CHECK-NEXT: %6 = extractelement <2 x float> %2, i32 0 + // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rn.f(float %4, float %5, float %6) + // CHECK-NEXT: %8 = insertelement <2 x float> poison, float %7, i32 0 + // CHECK-NEXT: %9 = extractelement <2 x float> %0, i32 1 + // CHECK-NEXT: %10 = extractelement <2 x float> %1, i32 1 + // CHECK-NEXT: %11 = extractelement <2 x float> %2, i32 1 + // CHECK-NEXT: %12 = call float @llvm.nvvm.fma.rn.f(float %9, float %10, float %11) + // CHECK-NEXT: %13 = insertelement <2 x float> %8, float %12, i32 1 + // CHECK-NEXT: %14 = extractelement <2 x float> %0, i32 0 + // CHECK-NEXT: %15 = extractelement <2 x float> %1, i32 0 + // CHECK-NEXT: %16 = extractelement <2 x float> %13, i32 0 + // CHECK-NEXT: %17 = call float @llvm.nvvm.fma.rn.ftz.f(float %14, float %15, float %16) + // CHECK-NEXT: %18 = insertelement <2 x float> poison, float %17, i32 0 + // CHECK-NEXT: %19 = extractelement <2 x float> %0, i32 1 + // CHECK-NEXT: %20 = extractelement <2 x float> %1, i32 1 + // CHECK-NEXT: %21 = extractelement <2 x float> %13, i32 1 + // CHECK-NEXT: %22 = call float @llvm.nvvm.fma.rn.ftz.f(float %19, float %20, float %21) + // CHECK-NEXT: %23 = insertelement <2 x float> %18, float %22, i32 1 + // CHECK-NEXT: %24 = extractelement <2 x float> %0, i32 0 + // CHECK-NEXT: %25 = extractelement <2 x float> %1, i32 0 + // CHECK-NEXT: %26 = extractelement <2 x float> %23, i32 0 + // CHECK-NEXT: %27 = call float @llvm.nvvm.fma.rn.sat.f(float %24, float %25, float %26) + // CHECK-NEXT: %28 = insertelement <2 x float> poison, float %27, i32 0 + // CHECK-NEXT: %29 = extractelement <2 x float> %0, i32 1 + // CHECK-NEXT: %30 = extractelement 
<2 x float> %1, i32 1 + // CHECK-NEXT: %31 = extractelement <2 x float> %23, i32 1 + // CHECK-NEXT: %32 = call float @llvm.nvvm.fma.rn.sat.f(float %29, float %30, float %31) + // CHECK-NEXT: %33 = insertelement <2 x float> %28, float %32, i32 1 + // CHECK-NEXT: %34 = extractelement <2 x float> %0, i32 0 + // CHECK-NEXT: %35 = extractelement <2 x float> %1, i32 0 + // CHECK-NEXT: %36 = extractelement <2 x float> %33, i32 0 + // CHECK-NEXT: %37 = call float @llvm.nvvm.fma.rn.ftz.sat.f(float %34, float %35, float %36) + // CHECK-NEXT: %38 = insertelement <2 x float> poison, float %37, i32 0 + // CHECK-NEXT: %39 = extractelement <2 x float> %0, i32 1 + // CHECK-NEXT: %40 = extractelement <2 x float> %1, i32 1 + // CHECK-NEXT: %41 = extractelement <2 x float> %33, i32 1 + // CHECK-NEXT: %42 = call float @llvm.nvvm.fma.rn.ftz.sat.f(float %39, float %40, float %41) + // CHECK-NEXT: %43 = insertelement <2 x float> %38, float %42, i32 1 + // CHECK-NEXT: ret <2 x float> %43 + // CHECK-NEXT: } + %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode} : vector<2xf32> + %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode, ftz = true} : vector<2xf32> + %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xf32> + %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode, ftz = true} : vector<2xf32> + llvm.return %f3 : vector<2xf32> +} + +llvm.func @fma_f32_rm(%a: vector<2xf32>, %b: vector<2xf32>, %c: vector<2xf32>) -> vector<2xf32> { + // CHECK-LABEL: define <2 x float> @fma_f32_rm(<2 x float> %0, <2 x float> %1, <2 x float> %2) { + // CHECK-NEXT: %4 = extractelement <2 x float> %0, i32 0 + // CHECK-NEXT: %5 = extractelement <2 x float> %1, i32 0 + // CHECK-NEXT: %6 = extractelement <2 x float> %2, i32 0 + // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rm.f(float %4, float %5, float %6) + // CHECK-NEXT: %8 = insertelement <2 x float> poison, float %7, i32 0 + // CHECK-NEXT: %9 = extractelement <2 x float> %0, i32 1 + // CHECK-NEXT: %10 = 
extractelement <2 x float> %1, i32 1 + // CHECK-NEXT: %11 = extractelement <2 x float> %2, i32 1 + // CHECK-NEXT: %12 = call float @llvm.nvvm.fma.rm.f(float %9, float %10, float %11) + // CHECK-NEXT: %13 = insertelement <2 x float> %8, float %12, i32 1 + // CHECK-NEXT: %14 = extractelement <2 x float> %0, i32 0 + // CHECK-NEXT: %15 = extractelement <2 x float> %1, i32 0 + // CHECK-NEXT: %16 = extractelement <2 x float> %13, i32 0 + // CHECK-NEXT: %17 = call float @llvm.nvvm.fma.rm.ftz.f(float %14, float %15, float %16) + // CHECK-NEXT: %18 = insertelement <2 x float> poison, float %17, i32 0 + // CHECK-NEXT: %19 = extractelement <2 x float> %0, i32 1 + // CHECK-NEXT: %20 = extractelement <2 x float> %1, i32 1 + // CHECK-NEXT: %21 = extractelement <2 x float> %13, i32 1 + // CHECK-NEXT: %22 = call float @llvm.nvvm.fma.rm.ftz.f(float %19, float %20, float %21) + // CHECK-NEXT: %23 = insertelement <2 x float> %18, float %22, i32 1 + // CHECK-NEXT: %24 = extractelement <2 x float> %0, i32 0 + // CHECK-NEXT: %25 = extractelement <2 x float> %1, i32 0 + // CHECK-NEXT: %26 = extractelement <2 x float> %23, i32 0 + // CHECK-NEXT: %27 = call float @llvm.nvvm.fma.rm.sat.f(float %24, float %25, float %26) + // CHECK-NEXT: %28 = insertelement <2 x float> poison, float %27, i32 0 + // CHECK-NEXT: %29 = extractelement <2 x float> %0, i32 1 + // CHECK-NEXT: %30 = extractelement <2 x float> %1, i32 1 + // CHECK-NEXT: %31 = extractelement <2 x float> %23, i32 1 + // CHECK-NEXT: %32 = call float @llvm.nvvm.fma.rm.sat.f(float %29, float %30, float %31) + // CHECK-NEXT: %33 = insertelement <2 x float> %28, float %32, i32 1 + // CHECK-NEXT: %34 = extractelement <2 x float> %0, i32 0 + // CHECK-NEXT: %35 = extractelement <2 x float> %1, i32 0 + // CHECK-NEXT: %36 = extractelement <2 x float> %33, i32 0 + // CHECK-NEXT: %37 = call float @llvm.nvvm.fma.rm.ftz.sat.f(float %34, float %35, float %36) + // CHECK-NEXT: %38 = insertelement <2 x float> poison, float %37, i32 0 + // CHECK-NEXT: 
%39 = extractelement <2 x float> %0, i32 1
+  // CHECK-NEXT: %40 = extractelement <2 x float> %1, i32 1
+  // CHECK-NEXT: %41 = extractelement <2 x float> %33, i32 1
+  // CHECK-NEXT: %42 = call float @llvm.nvvm.fma.rm.ftz.sat.f(float %39, float %40, float %41)
+  // CHECK-NEXT: %43 = insertelement <2 x float> %38, float %42, i32 1
+  // CHECK-NEXT: ret <2 x float> %43
+  // CHECK-NEXT: }
+  %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode} : vector<2xf32>
+  %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode, ftz = true} : vector<2xf32>
+  %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : vector<2xf32>
+  %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode, ftz = true} : vector<2xf32>
+  llvm.return %f3 : vector<2xf32>
+}
+
+llvm.func @fma_f32_rp(%a: vector<2xf32>, %b: vector<2xf32>, %c: vector<2xf32>) -> vector<2xf32> { // rp rounding on vector<2xf32>: plain, ftz, sat, ftz+sat; ops are chained through %c
+  // CHECK-LABEL: define <2 x float> @fma_f32_rp(<2 x float> %0, <2 x float> %1, <2 x float> %2) {
+  // CHECK-NEXT: %4 = extractelement <2 x float> %0, i32 0
+  // CHECK-NEXT: %5 = extractelement <2 x float> %1, i32 0
+  // CHECK-NEXT: %6 = extractelement <2 x float> %2, i32 0
+  // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rp.f(float %4, float %5, float %6)
+  // CHECK-NEXT: %8 = insertelement <2 x float> poison, float %7, i32 0
+  // CHECK-NEXT: %9 = extractelement <2 x float> %0, i32 1
+  // CHECK-NEXT: %10 = extractelement <2 x float> %1, i32 1
+  // CHECK-NEXT: %11 = extractelement <2 x float> %2, i32 1
+  // CHECK-NEXT: %12 = call float @llvm.nvvm.fma.rp.f(float %9, float %10, float %11)
+  // CHECK-NEXT: %13 = insertelement <2 x float> %8, float %12, i32 1
+  // CHECK-NEXT: %14 = extractelement <2 x float> %0, i32 0
+  // CHECK-NEXT: %15 = extractelement <2 x float> %1, i32 0
+  // CHECK-NEXT: %16 = extractelement <2 x float> %13, i32 0
+  // CHECK-NEXT: %17 = call float @llvm.nvvm.fma.rp.ftz.f(float %14, float %15, float %16)
+  // CHECK-NEXT: %18 = insertelement <2 x float> poison, float %17, i32 0
+  // CHECK-NEXT: %19 = extractelement <2 x float> %0, i32 1
+  // CHECK-NEXT: %20 = extractelement <2 x float> %1, i32 1
+  // CHECK-NEXT: %21 = extractelement <2 x float> %13, i32 1
+  // CHECK-NEXT: %22 = call float @llvm.nvvm.fma.rp.ftz.f(float %19, float %20, float %21)
+  // CHECK-NEXT: %23 = insertelement <2 x float> %18, float %22, i32 1
+  // CHECK-NEXT: %24 = extractelement <2 x float> %0, i32 0
+  // CHECK-NEXT: %25 = extractelement <2 x float> %1, i32 0
+  // CHECK-NEXT: %26 = extractelement <2 x float> %23, i32 0
+  // CHECK-NEXT: %27 = call float @llvm.nvvm.fma.rp.sat.f(float %24, float %25, float %26)
+  // CHECK-NEXT: %28 = insertelement <2 x float> poison, float %27, i32 0
+  // CHECK-NEXT: %29 = extractelement <2 x float> %0, i32 1
+  // CHECK-NEXT: %30 = extractelement <2 x float> %1, i32 1
+  // CHECK-NEXT: %31 = extractelement <2 x float> %23, i32 1
+  // CHECK-NEXT: %32 = call float @llvm.nvvm.fma.rp.sat.f(float %29, float %30, float %31)
+  // CHECK-NEXT: %33 = insertelement <2 x float> %28, float %32, i32 1
+  // CHECK-NEXT: %34 = extractelement <2 x float> %0, i32 0
+  // CHECK-NEXT: %35 = extractelement <2 x float> %1, i32 0
+  // CHECK-NEXT: %36 = extractelement <2 x float> %33, i32 0
+  // CHECK-NEXT: %37 = call float @llvm.nvvm.fma.rp.ftz.sat.f(float %34, float %35, float %36)
+  // CHECK-NEXT: %38 = insertelement <2 x float> poison, float %37, i32 0
+  // CHECK-NEXT: %39 = extractelement <2 x float> %0, i32 1
+  // CHECK-NEXT: %40 = extractelement <2 x float> %1, i32 1
+  // CHECK-NEXT: %41 = extractelement <2 x float> %33, i32 1
+  // CHECK-NEXT: %42 = call float @llvm.nvvm.fma.rp.ftz.sat.f(float %39, float %40, float %41)
+  // CHECK-NEXT: %43 = insertelement <2 x float> %38, float %42, i32 1
+  // CHECK-NEXT: ret <2 x float> %43
+  // CHECK-NEXT: }
+  %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rp>} : vector<2xf32> // expects @llvm.nvvm.fma.rp.f per lane
+  %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rp>, ftz = true} : vector<2xf32> // expects @llvm.nvvm.fma.rp.ftz.f per lane
+  %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<sat>} : vector<2xf32> // expects @llvm.nvvm.fma.rp.sat.f per lane
+  %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<sat>, ftz = true} : vector<2xf32> // expects @llvm.nvvm.fma.rp.ftz.sat.f per lane
+  llvm.return %f3 : vector<2xf32>
+}
+
+llvm.func @fma_f32_rz(%a: vector<2xf32>, %b: vector<2xf32>, %c: vector<2xf32>) -> vector<2xf32> { // rz rounding on vector<2xf32>: plain, ftz, sat, ftz+sat; ops are chained through %c
+  // CHECK-LABEL: define <2 x float> @fma_f32_rz(<2 x float> %0, <2 x float> %1, <2 x float> %2) {
+  // CHECK-NEXT: %4 = extractelement <2 x float> %0, i32 0
+  // CHECK-NEXT: %5 = extractelement <2 x float> %1, i32 0
+  // CHECK-NEXT: %6 = extractelement <2 x float> %2, i32 0
+  // CHECK-NEXT: %7 = call float @llvm.nvvm.fma.rz.f(float %4, float %5, float %6)
+  // CHECK-NEXT: %8 = insertelement <2 x float> poison, float %7, i32 0
+  // CHECK-NEXT: %9 = extractelement <2 x float> %0, i32 1
+  // CHECK-NEXT: %10 = extractelement <2 x float> %1, i32 1
+  // CHECK-NEXT: %11 = extractelement <2 x float> %2, i32 1
+  // CHECK-NEXT: %12 = call float @llvm.nvvm.fma.rz.f(float %9, float %10, float %11)
+  // CHECK-NEXT: %13 = insertelement <2 x float> %8, float %12, i32 1
+  // CHECK-NEXT: %14 = extractelement <2 x float> %0, i32 0
+  // CHECK-NEXT: %15 = extractelement <2 x float> %1, i32 0
+  // CHECK-NEXT: %16 = extractelement <2 x float> %13, i32 0
+  // CHECK-NEXT: %17 = call float @llvm.nvvm.fma.rz.ftz.f(float %14, float %15, float %16)
+  // CHECK-NEXT: %18 = insertelement <2 x float> poison, float %17, i32 0
+  // CHECK-NEXT: %19 = extractelement <2 x float> %0, i32 1
+  // CHECK-NEXT: %20 = extractelement <2 x float> %1, i32 1
+  // CHECK-NEXT: %21 = extractelement <2 x float> %13, i32 1
+  // CHECK-NEXT: %22 = call float @llvm.nvvm.fma.rz.ftz.f(float %19, float %20, float %21)
+  // CHECK-NEXT: %23 = insertelement <2 x float> %18, float %22, i32 1
+  // CHECK-NEXT: %24 = extractelement <2 x float> %0, i32 0
+  // CHECK-NEXT: %25 = extractelement <2 x float> %1, i32 0
+  // CHECK-NEXT: %26 = extractelement <2 x float> %23, i32 0
+  // CHECK-NEXT: %27 = call float @llvm.nvvm.fma.rz.sat.f(float %24, float %25, float %26)
+  // CHECK-NEXT: %28 = insertelement <2 x float> poison, float %27, i32 0
+  // CHECK-NEXT: %29 = extractelement <2 x float> %0, i32 1
+  // CHECK-NEXT: %30 = extractelement <2 x float> %1, i32 1
+  // CHECK-NEXT: %31 = extractelement <2 x float> %23, i32 1
+  // CHECK-NEXT: %32 = call float @llvm.nvvm.fma.rz.sat.f(float %29, float %30, float %31)
+  // CHECK-NEXT: %33 = insertelement <2 x float> %28, float %32, i32 1
+  // CHECK-NEXT: %34 = extractelement <2 x float> %0, i32 0
+  // CHECK-NEXT: %35 = extractelement <2 x float> %1, i32 0
+  // CHECK-NEXT: %36 = extractelement <2 x float> %33, i32 0
+  // CHECK-NEXT: %37 = call float @llvm.nvvm.fma.rz.ftz.sat.f(float %34, float %35, float %36)
+  // CHECK-NEXT: %38 = insertelement <2 x float> poison, float %37, i32 0
+  // CHECK-NEXT: %39 = extractelement <2 x float> %0, i32 1
+  // CHECK-NEXT: %40 = extractelement <2 x float> %1, i32 1
+  // CHECK-NEXT: %41 = extractelement <2 x float> %33, i32 1
+  // CHECK-NEXT: %42 = call float @llvm.nvvm.fma.rz.ftz.sat.f(float %39, float %40, float %41)
+  // CHECK-NEXT: %43 = insertelement <2 x float> %38, float %42, i32 1
+  // CHECK-NEXT: ret <2 x float> %43
+  // CHECK-NEXT: }
+  %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf32> // expects @llvm.nvvm.fma.rz.f per lane
+  %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rz>, ftz = true} : vector<2xf32> // expects @llvm.nvvm.fma.rz.ftz.f per lane
+  %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<sat>} : vector<2xf32> // expects @llvm.nvvm.fma.rz.sat.f per lane
+  %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<sat>, ftz = true} : vector<2xf32> // expects @llvm.nvvm.fma.rz.ftz.sat.f per lane
+  llvm.return %f3 : vector<2xf32>
+}
+
+llvm.func @fma_f64(%a: vector<2xf64>, %b: vector<2xf64>, %c: vector<2xf64>) -> vector<2xf64> { // vector<2xf64>: one op per rounding mode (rn, rm, rp, rz), chained through %c
+  // CHECK-LABEL: define <2 x double> @fma_f64(<2 x double> %0, <2 x double> %1, <2 x double> %2) {
+  // CHECK-NEXT: %4 = extractelement <2 x double> %0, i32 0
+  // CHECK-NEXT: %5 = extractelement <2 x double> %1, i32 0
+  // CHECK-NEXT: %6 = extractelement <2 x double> %2, i32 0
+  // CHECK-NEXT: %7 = call double @llvm.nvvm.fma.rn.d(double %4, double %5, double %6)
+  // CHECK-NEXT: %8 = insertelement <2 x double> poison, double %7, i32 0
+  // CHECK-NEXT: %9 = extractelement <2 x double> %0, i32 1
+  // CHECK-NEXT: %10 = extractelement <2 x double> %1, i32 1
+  // CHECK-NEXT: %11 = extractelement <2 x double> %2, i32 1
+  // CHECK-NEXT: %12 = call double @llvm.nvvm.fma.rn.d(double %9, double %10, double %11)
+  // CHECK-NEXT: %13 = insertelement <2 x double> %8, double %12, i32 1
+  // CHECK-NEXT: %14 = extractelement <2 x double> %0, i32 0
+  // CHECK-NEXT: %15 = extractelement <2 x double> %1, i32 0
+  // CHECK-NEXT: %16 = extractelement <2 x double> %13, i32 0
+  // CHECK-NEXT: %17 = call double @llvm.nvvm.fma.rm.d(double %14, double %15, double %16)
+  // CHECK-NEXT: %18 = insertelement <2 x double> poison, double %17, i32 0
+  // CHECK-NEXT: %19 = extractelement <2 x double> %0, i32 1
+  // CHECK-NEXT: %20 = extractelement <2 x double> %1, i32 1
+  // CHECK-NEXT: %21 = extractelement <2 x double> %13, i32 1
+  // CHECK-NEXT: %22 = call double @llvm.nvvm.fma.rm.d(double %19, double %20, double %21)
+  // CHECK-NEXT: %23 = insertelement <2 x double> %18, double %22, i32 1
+  // CHECK-NEXT: %24 = extractelement <2 x double> %0, i32 0
+  // CHECK-NEXT: %25 = extractelement <2 x double> %1, i32 0
+  // CHECK-NEXT: %26 = extractelement <2 x double> %23, i32 0
+  // CHECK-NEXT: %27 = call double @llvm.nvvm.fma.rp.d(double %24, double %25, double %26)
+  // CHECK-NEXT: %28 = insertelement <2 x double> poison, double %27, i32 0
+  // CHECK-NEXT: %29 = extractelement <2 x double> %0, i32 1
+  // CHECK-NEXT: %30 = extractelement <2 x double> %1, i32 1
+  // CHECK-NEXT: %31 = extractelement <2 x double> %23, i32 1
+  // CHECK-NEXT: %32 = call double @llvm.nvvm.fma.rp.d(double %29, double %30, double %31)
+  // CHECK-NEXT: %33 = insertelement <2 x double> %28, double %32, i32 1
+  // CHECK-NEXT: %34 = extractelement <2 x double> %0, i32 0
+  // CHECK-NEXT: %35 = extractelement <2 x double> %1, i32 0
+  // CHECK-NEXT: %36 = extractelement <2 x double> %33, i32 0
+  // CHECK-NEXT: %37 = call double @llvm.nvvm.fma.rz.d(double %34, double %35, double %36)
+  // CHECK-NEXT: %38 = insertelement <2 x double> poison, double %37, i32 0
+  // CHECK-NEXT: %39 = extractelement <2 x double> %0, i32 1
+  // CHECK-NEXT: %40 = extractelement <2 x double> %1, i32 1
+  // CHECK-NEXT: %41 = extractelement <2 x double> %33, i32 1
+  // CHECK-NEXT: %42 = call double @llvm.nvvm.fma.rz.d(double %39, double %40, double %41)
+  // CHECK-NEXT: %43 = insertelement <2 x double> %38, double %42, i32 1
+  // CHECK-NEXT: ret <2 x double> %43
+  // CHECK-NEXT: }
+  %f0 = nvvm.fma %a, %b, %c {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf64> // expects @llvm.nvvm.fma.rn.d per lane
+  %f1 = nvvm.fma %a, %b, %f0 {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xf64> // expects @llvm.nvvm.fma.rm.d per lane
+  %f2 = nvvm.fma %a, %b, %f1 {rnd = #nvvm.fp_rnd_mode<rp>} : vector<2xf64> // expects @llvm.nvvm.fma.rp.d per lane
+  %f3 = nvvm.fma %a, %b, %f2 {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf64> // expects @llvm.nvvm.fma.rz.d per lane
+  llvm.return %f3 : vector<2xf64>
+}