-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR][NVVM] Add missing rounding modes in fp16x2 conversions #169005
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c07effb
447acaf
04d38d0
4f8cb48
9b65959
fb0d052
1139e5a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -390,18 +390,42 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() { | |
| // Stochastic Rounding Conversion Ops | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| LogicalResult ConvertF32x2ToF16x2Op::verify() { | ||
| if (getRnd() != FPRoundingMode::RS) | ||
| return emitOpError("Only RS rounding mode is supported for " | ||
| "conversions from f32x2 to f16x2."); | ||
| static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType, | ||
| FPRoundingMode rnd, | ||
| bool hasRandomBits, | ||
| Operation *op) { | ||
| static constexpr FPRoundingMode validRndModes[] = { | ||
| FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS}; | ||
|
|
||
| if (!llvm::is_contained(validRndModes, rnd)) { | ||
| return op->emitOpError( | ||
| "Only RN, RZ, and RS rounding modes are supported for " | ||
| "conversions from f32x2 to ") | ||
| << dstType << "."; | ||
| } | ||
|
|
||
| if (rnd == FPRoundingMode::RS) { | ||
| if (!hasRandomBits) { | ||
| return op->emitOpError("random_bits is required for RS rounding mode."); | ||
| } | ||
| } else { | ||
| if (hasRandomBits) { | ||
| return op->emitOpError( | ||
| "random_bits not supported for RN and RZ rounding modes."); | ||
| } | ||
| } | ||
|
|
||
| return success(); | ||
| } | ||
|
|
||
| LogicalResult ConvertF32x2ToF16x2Op::verify() { | ||
| return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(), | ||
| getRandomBits() ? true : false, *this); | ||
| } | ||
|
|
||
| LogicalResult ConvertF32x2ToBF16x2Op::verify() { | ||
| if (getRnd() != FPRoundingMode::RS) | ||
| return emitOpError("Only RS rounding mode is supported for " | ||
| "conversions from f32x2 to bf16x2."); | ||
| return success(); | ||
| return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(), | ||
| getRandomBits() ? true : false, *this); | ||
| } | ||
|
|
||
| LogicalResult ConvertF32x4ToF8x4Op::verify() { | ||
|
|
@@ -2727,30 +2751,100 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, | |
| return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \ | ||
| }() | ||
|
|
||
| llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() { | ||
| bool hasRelu = getRelu(); | ||
| bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE); | ||
| NVVM::IDArgPair | ||
| ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op, | ||
| LLVM::ModuleTranslation &mt, | ||
| llvm::IRBuilderBase &builder) { | ||
| static constexpr llvm::Intrinsic::ID rndRNIds[] = { | ||
| llvm::Intrinsic::nvvm_ff2f16x2_rn, | ||
| llvm::Intrinsic::nvvm_ff2f16x2_rn_relu, | ||
| llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite, | ||
| llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite, | ||
| }; | ||
| static constexpr llvm::Intrinsic::ID rndRZIds[] = { | ||
| llvm::Intrinsic::nvvm_ff2f16x2_rz, | ||
| llvm::Intrinsic::nvvm_ff2f16x2_rz_relu, | ||
| llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite, | ||
| llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite, | ||
| }; | ||
| static constexpr llvm::Intrinsic::ID rndRSIds[] = { | ||
| llvm::Intrinsic::nvvm_ff2f16x2_rs, | ||
| llvm::Intrinsic::nvvm_ff2f16x2_rs_relu, | ||
| llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite, | ||
| llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite, | ||
| }; | ||
|
|
||
| if (hasRelu && hasSatFinite) | ||
| return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite; | ||
| if (hasRelu) | ||
| return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu; | ||
| if (hasSatFinite) | ||
| return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite; | ||
| return llvm::Intrinsic::nvvm_ff2f16x2_rs; | ||
| unsigned hasRelu = op.getRelu() ? 1 : 0; | ||
| unsigned hasSatFinite = | ||
| (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0; | ||
| // idx: bit-0 - relu | ||
| // bit-1 - satfinite | ||
| unsigned idx = (hasSatFinite << 1) | hasRelu; | ||
|
|
||
| llvm::SmallVector<llvm::Value *> args; | ||
| args.push_back(mt.lookupValue(op.getSrcHi())); | ||
| args.push_back(mt.lookupValue(op.getSrcLo())); | ||
| if (op.getRandomBits()) | ||
| args.push_back(mt.lookupValue(op.getRandomBits())); | ||
|
|
||
| switch (op.getRnd()) { | ||
| case FPRoundingMode::RN: | ||
| return {rndRNIds[idx], std::move(args)}; | ||
| case FPRoundingMode::RZ: | ||
| return {rndRZIds[idx], std::move(args)}; | ||
| case FPRoundingMode::RS: | ||
| return {rndRSIds[idx], std::move(args)}; | ||
| default: | ||
| llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op"); | ||
| } | ||
| } | ||
|
|
||
| llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() { | ||
| bool hasRelu = getRelu(); | ||
| bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE); | ||
|
|
||
| if (hasRelu && hasSatFinite) | ||
| return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite; | ||
| if (hasRelu) | ||
| return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu; | ||
| if (hasSatFinite) | ||
| return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite; | ||
| return llvm::Intrinsic::nvvm_ff2bf16x2_rs; | ||
| NVVM::IDArgPair | ||
| ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op, | ||
| LLVM::ModuleTranslation &mt, | ||
| llvm::IRBuilderBase &builder) { | ||
| static constexpr llvm::Intrinsic::ID rndRNIds[] = { | ||
| llvm::Intrinsic::nvvm_ff2bf16x2_rn, | ||
| llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu, | ||
| llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite, | ||
| llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite, | ||
| }; | ||
| static constexpr llvm::Intrinsic::ID rndRZIds[] = { | ||
| llvm::Intrinsic::nvvm_ff2bf16x2_rz, | ||
| llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu, | ||
| llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite, | ||
| llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite, | ||
| }; | ||
| static constexpr llvm::Intrinsic::ID rndRSIds[] = { | ||
| llvm::Intrinsic::nvvm_ff2bf16x2_rs, | ||
| llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu, | ||
| llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite, | ||
| llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite, | ||
| }; | ||
|
Comment on lines
+2802
to
+2823
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just thinking out loud: can we combine these two tables for bf16 and f16, and also write a selector function? Then you can select the intrinsic nicely.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't have a strong opinion. I think we are still trying to find the best way to create large tables and select an intrinsic from them nicely.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I did consider having a single table initially, but went with separate ones because the valid rounding modes … |
||
|
|
||
| unsigned hasRelu = op.getRelu() ? 1 : 0; | ||
| unsigned hasSatFinite = | ||
| (op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0; | ||
| // idx: bit-0 - relu | ||
| // bit-1 - satfinite | ||
| unsigned idx = (hasSatFinite << 1) | hasRelu; | ||
|
|
||
| llvm::SmallVector<llvm::Value *> args; | ||
| args.push_back(mt.lookupValue(op.getSrcHi())); | ||
| args.push_back(mt.lookupValue(op.getSrcLo())); | ||
| if (op.getRandomBits()) | ||
| args.push_back(mt.lookupValue(op.getRandomBits())); | ||
|
|
||
| switch (op.getRnd()) { | ||
| case FPRoundingMode::RN: | ||
| return {rndRNIds[idx], std::move(args)}; | ||
| case FPRoundingMode::RZ: | ||
| return {rndRZIds[idx], std::move(args)}; | ||
| case FPRoundingMode::RS: | ||
| return {rndRSIds[idx], std::move(args)}; | ||
| default: | ||
| llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op"); | ||
| } | ||
| } | ||
|
|
||
| llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s | ||
|
|
||
| // CHECK-LABEL: @convert_f32x2_to_f16x2_rn | ||
| llvm.func @convert_f32x2_to_f16x2_rn(%srcA : f32, %srcB : f32) { | ||
| // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %{{.*}}, float %{{.*}}) | ||
| %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> | ||
| // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.satfinite(float %{{.*}}, float %{{.*}}) | ||
| %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> | ||
| // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %{{.*}}, float %{{.*}}) | ||
| %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16> | ||
| // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}}) | ||
| %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> | ||
|
|
||
| llvm.return | ||
| } | ||
|
|
||
| // CHECK-LABEL: @convert_f32x2_to_f16x2_rz | ||
| llvm.func @convert_f32x2_to_f16x2_rz(%srcA : f32, %srcB : f32) { | ||
| // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %{{.*}}, float %{{.*}}) | ||
| %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16> | ||
| // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.satfinite(float %{{.*}}, float %{{.*}}) | ||
| %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> | ||
| // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %{{.*}}, float %{{.*}}) | ||
| %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xf16> | ||
| // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}}) | ||
| %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> | ||
|
|
||
| llvm.return | ||
| } | ||
|
|
||
| // CHECK-LABEL: @convert_f32x2_to_f16x2_rs_stochastic | ||
| llvm.func @convert_f32x2_to_f16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) { | ||
| // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) | ||
| %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16> | ||
| // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) | ||
| %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16> | ||
| // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) | ||
| %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> | ||
| // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) | ||
| %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> | ||
|
|
||
| llvm.return | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: @convert_f32x2_to_bf16x2_rn | ||
| llvm.func @convert_f32x2_to_bf16x2_rn(%srcA : f32, %srcB : f32) { | ||
| // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %{{.*}}, float %{{.*}}) | ||
| %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> | ||
| // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.satfinite(float %{{.*}}, float %{{.*}}) | ||
| %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> | ||
| // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %{{.*}}, float %{{.*}}) | ||
| %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xbf16> | ||
| // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}}) | ||
| %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> | ||
|
|
||
| llvm.return | ||
| } | ||
|
|
||
| // CHECK-LABEL: @convert_f32x2_to_bf16x2_rz | ||
| llvm.func @convert_f32x2_to_bf16x2_rz(%srcA : f32, %srcB : f32) { | ||
| // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %{{.*}}, float %{{.*}}) | ||
| %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> | ||
| // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.satfinite(float %{{.*}}, float %{{.*}}) | ||
| %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> | ||
| // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %{{.*}}, float %{{.*}}) | ||
| %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xbf16> | ||
| // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}}) | ||
| %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> | ||
|
|
||
| llvm.return | ||
| } | ||
|
|
||
| // CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_stochastic | ||
| llvm.func @convert_f32x2_to_bf16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) { | ||
| // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) | ||
| %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> | ||
| // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) | ||
| %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> | ||
| // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) | ||
| %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> | ||
| // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) | ||
| %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> | ||
|
|
||
| llvm.return | ||
| } |
Uh oh!
There was an error while loading. Please reload this page.