[MLIR][NVVM] Add missing rounding modes in fp16x2 conversions #169005
This change adds the `RN` and `RZ` rounding modes to the `convert.f32x2.to.f16x2` and `convert.f32x2.to.bf16x2` Ops. Tests are added in `convert_fp16x2.mlir` and `nvvmir-invalid.mlir`. Tests with these Ops in `convert_stochastic_rounding.mlir` and `invalid-convert-stochastic-rounding.mlir` have been removed or modified. PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm
Author: Srinivasa Ravi (Wolfram70)
Patch is 25.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169005.diff
6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6e3a92b5bde42..7a2cfb1fee5eb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1912,45 +1912,51 @@ def NVVM_ConvertF4x2ToF16x2Op :
// Base class for conversions from F32x2 to FPx2 formats
// (F16x2, BF16x2)
-// TODO: In separate PR, add .rn and .rz rounding variants for this conversion
-// as currently only support .rs rounding mode
class NVVM_ConvertF32x2ToFPx2OpBase<string dstFormat, string mnemonic, Type dstType> :
- NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
+ NVVM_Op<mnemonic, [Pure]>,
Results<(outs dstType:$dst)>,
- Arguments<(ins F32:$src_hi, F32:$src_lo, I32:$rbits,
- DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RS">:$rnd,
+ Arguments<(ins F32:$src_hi, F32:$src_lo, Optional<I32>:$rbits,
+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
DefaultValuedAttr<BoolAttr, "false">:$relu)> {
- let summary = "Convert two F32 values to packed " # dstFormat # " with stochastic rounding (.rs)";
+ let summary = "Convert two F32 values to packed " # !tolower(dstFormat) # ".";
let description = [{
- Converts two F32 values to packed }] # dstFormat # [{ format using stochastic
- rounding (.rs) mode with randomness provided by the `rbits` parameter. The
- `relu` attribute clamps negative results to 0. The `sat` attribute determines
- saturation behavior. The `src_hi` and `src_lo` parameters correspond to operands
- `a` and `b` in the PTX ISA, respectively.
+ Converts two F32 values to packed }] # !tolower(dstFormat) # [{ format with
+ the specified rounding mode. The `src_hi` and `src_lo` parameters
+ correspond to operands `a` and `b` in the PTX ISA, respectively.
+
+ The `rbits` parameter is required for stochastic rounding.
+
+ The `relu` attribute clamps negative results to 0.
+
+ The `sat` attribute determines saturation behavior.
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];
- let assemblyFormat = "$src_hi `,` $src_lo `,` $rbits attr-dict `:` type($dst)";
+ let assemblyFormat = "$src_hi `,` $src_lo (`,` $rbits^)? attr-dict `:` type($dst)";
let hasVerifier = 1;
let extraClassDeclaration = [{
- llvm::Intrinsic::ID getIntrinsicID();
+ static NVVM::IDArgPair
+ getIntrinsicIDAndArgs(
+ NVVM::ConvertF32x2To}] # dstFormat # [{Op &op,
+ LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
}];
string llvmBuilder = [{
- auto intId = op.getIntrinsicID();
- $dst = createIntrinsicCall(builder, intId, {$src_hi, $src_lo, $rbits});
+ auto [intId, args] = mlir::NVVM::ConvertF32x2To}] # dstFormat #
+ [{Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+ $dst = createIntrinsicCall(builder, intId, args);
}];
- }
+}
// F32x2 -> F16x2 with stochastic rounding
-def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"f16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
+def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"F16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
// F32x2 -> BF16x2 with stochastic rounding
-def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"bf16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
+def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"BF16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
// Base class for stochastic rounding conversions from F32x4 to FPx4 formats
// (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7ac427dbe3941..4654ed49a0ca1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -391,16 +391,40 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
//===----------------------------------------------------------------------===//
LogicalResult ConvertF32x2ToF16x2Op::verify() {
- if (getRnd() != FPRoundingMode::RS)
- return emitOpError("Only RS rounding mode is supported for "
+ switch (getRnd()) {
+ case FPRoundingMode::RN:
+ case FPRoundingMode::RZ:
+ if (getRbits())
+ return emitOpError("rbits not supported for RN and RZ rounding modes.");
+ break;
+ case FPRoundingMode::RS:
+ if (!getRbits())
+ return emitOpError("rbits is required for RS rounding mode.");
+ break;
+ default:
+ return emitOpError("Only RN, RZ, and RS rounding modes are supported for "
"conversions from f32x2 to f16x2.");
+ }
+
return success();
}
LogicalResult ConvertF32x2ToBF16x2Op::verify() {
- if (getRnd() != FPRoundingMode::RS)
- return emitOpError("Only RS rounding mode is supported for "
+ switch (getRnd()) {
+ case FPRoundingMode::RN:
+ case FPRoundingMode::RZ:
+ if (getRbits())
+ return emitOpError("rbits not supported for RN and RZ rounding modes.");
+ break;
+ case FPRoundingMode::RS:
+ if (!getRbits())
+ return emitOpError("rbits is required for RS rounding mode.");
+ break;
+ default:
+ return emitOpError("Only RN, RZ, and RS rounding modes are supported for "
"conversions from f32x2 to bf16x2.");
+ }
+
return success();
}
@@ -2727,30 +2751,98 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()
-llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
- bool hasRelu = getRelu();
- bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
+NVVM::IDArgPair
+ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rn,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rz,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+ llvm::Intrinsic::nvvm_ff2f16x2_rs,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
+ llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
+ };
- if (hasRelu && hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
- if (hasRelu)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
- if (hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
- return llvm::Intrinsic::nvvm_ff2f16x2_rs;
+ bool hasRelu = op.getRelu();
+ bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE);
+ // idx: bit-0 - relu
+ // bit-1 - satfinite
+ unsigned idx = (hasSatFinite << 1) | hasRelu;
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getSrcHi()));
+ args.push_back(mt.lookupValue(op.getSrcLo()));
+ if (op.getRbits())
+ args.push_back(mt.lookupValue(op.getRbits()));
+
+ switch (op.getRnd()) {
+ case FPRoundingMode::RN:
+ return {rndRNIds[idx], std::move(args)};
+ case FPRoundingMode::RZ:
+ return {rndRZIds[idx], std::move(args)};
+ case FPRoundingMode::RS:
+ return {rndRSIds[idx], std::move(args)};
+ default:
+ llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
+ }
}
-llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
- bool hasRelu = getRelu();
- bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
-
- if (hasRelu && hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite;
- if (hasRelu)
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu;
- if (hasSatFinite)
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite;
- return llvm::Intrinsic::nvvm_ff2bf16x2_rs;
+NVVM::IDArgPair
+ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ static constexpr llvm::Intrinsic::ID rndRNIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRZIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
+ };
+ static constexpr llvm::Intrinsic::ID rndRSIds[] = {
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
+ llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
+ };
+
+ bool hasRelu = op.getRelu();
+ bool hasSatFinite = (op.getSat() == NVVM::SaturationMode::SATFINITE);
+ // idx: bit-0 - relu
+ // bit-1 - satfinite
+ unsigned idx = (hasSatFinite << 1) | hasRelu;
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getSrcHi()));
+ args.push_back(mt.lookupValue(op.getSrcLo()));
+ if (op.getRbits())
+ args.push_back(mt.lookupValue(op.getRbits()));
+
+ switch (op.getRnd()) {
+ case FPRoundingMode::RN:
+ return {rndRNIds[idx], std::move(args)};
+ case FPRoundingMode::RZ:
+ return {rndRZIds[idx], std::move(args)};
+ case FPRoundingMode::RS:
+ return {rndRSIds[idx], std::move(args)};
+ default:
+ llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
+ }
}
llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
index 35f5e1b3c8ba2..506b81e1e7048 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
@@ -2,35 +2,15 @@
// Test invalid target architecture (sm_100 instead of sm_100a)
gpu.module @invalid_arch_sm_100 [#nvvm.target<chip = "sm_100">] {
- func.func @convert_rs() {
- %f1 = llvm.mlir.constant(1.0 : f32) : f32
- %f2 = llvm.mlir.constant(2.0 : f32) : f32
- %rbits = llvm.mlir.constant(0x12345678 : i32) : i32
- // expected-error@+1 {{'nvvm.convert.f32x2.to.f16x2' op is not supported on sm_100}}
- %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16>
+ func.func @convert_rs(%src : vector<4xf32>, %rbits : i32) {
+ // expected-error@+1 {{'nvvm.convert.f32x4.to.f8x4' op is not supported on sm_100}}
+ %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
return
}
}
// -----
-// Test that operations require stochastic rounding mode
-llvm.func @invalid_rnd_mode_f16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
- // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to f16x2.}}
- %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
- llvm.return %res : vector<2xf16>
-}
-
-// -----
-
-llvm.func @invalid_rnd_mode_bf16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
- // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to bf16x2.}}
- %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
- llvm.return %res : vector<2xbf16>
-}
-
-// -----
-
// Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2)
llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
// expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir
new file mode 100644
index 0000000000000..a4bece83f832a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rn
+llvm.func @convert_f32x2_to_f16x2_rn(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rz
+llvm.func @convert_f32x2_to_f16x2_rz(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_stochastic
+llvm.func @convert_f32x2_to_f16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+ // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
+
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rn
+llvm.func @convert_f32x2_to_bf16x2_rn(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rz
+llvm.func @convert_f32x2_to_bf16x2_rz(%srcA : f32, %srcB : f32) {
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_stochastic
+llvm.func @convert_f32x2_to_bf16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+ // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
+ %res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
index b5bb22350dcd7..03abcddd96cb0 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir
@@ -10,7 +10,7 @@ gpu.module @valid_f16x2_rs_sm_100a [#nvvm.target<chip = "sm_100a">] {
%f1 = llvm.mlir.constant(1.0 : f32) : f32
%f2 = llvm.mlir.constant(2.0 : f32) : f32
%rbits = llvm.mlir.constant(0x12345678 : i32) : i32
- %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16>
+ %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
return
}
}
@@ -21,77 +21,13 @@ gpu.module @valid_bf16x2_rs_sm_103a [#nvvm.target<chip = "sm_103a">] {
%f1 = llvm.mlir.constant(1.0 : f32) : f32
%f2 = llvm.mlir.constant(2.0 : f32) : f32
%rbits = llvm.mlir.constant(0 : i32) : i32
- %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits : vector<2xbf16>
+ %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
return
}
}
// ...
[truncated]
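The verifier change in the diff above boils down to one rule per rounding mode: `rbits` must be absent for RN and RZ, must be present for RS, and every other mode is rejected. Because the default for `rnd` moves from `RS` to `NONE`, an op written without an explicit `rnd` attribute would presumably hit the rejecting branch, which would explain why the existing stochastic-rounding tests now spell out `rnd = #nvvm.fp_rnd_mode<rs>`. A minimal standalone sketch of that rule follows; `RoundingMode` and `checkRounding` are hypothetical stand-ins written for illustration, not the dialect's API, and the sketch needs no MLIR headers.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for NVVM's FPRoundingMode. Only the member names that
// appear in the diff are mirrored; the extra member and the values are made up.
enum class RoundingMode { NONE, RN, RZ, RS, RM };

// Returns an error message when the (rounding mode, rbits) combination would
// be rejected, or std::nullopt when it would be accepted. The control flow
// follows the switch in the patched verify() methods.
std::optional<std::string> checkRounding(RoundingMode rnd, bool hasRbits) {
  switch (rnd) {
  case RoundingMode::RN:
  case RoundingMode::RZ:
    if (hasRbits)
      return "rbits not supported for RN and RZ rounding modes.";
    return std::nullopt;
  case RoundingMode::RS:
    if (!hasRbits)
      return "rbits is required for RS rounding mode.";
    return std::nullopt;
  default:
    return "Only RN, RZ, and RS rounding modes are supported.";
  }
}

// Small usage example: RZ without rbits is accepted, RS without rbits is not.
inline void checkRoundingExample() {
  assert(!checkRounding(RoundingMode::RZ, /*hasRbits=*/false).has_value());
  assert(checkRounding(RoundingMode::RS, /*hasRbits=*/false).has_value());
}
```

Returning an optional message keeps the sketch free of MLIR types; in the patch itself the same branches call `emitOpError` and return a `LogicalResult`.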
NVVM::IDArgPair
ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
                                              LLVM::ModuleTranslation &mt,
                                              llvm::IRBuilderBase &builder) {
  static constexpr llvm::Intrinsic::ID rndRNIds[] = {
      llvm::Intrinsic::nvvm_ff2bf16x2_rn,
      llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
      llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
      llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
  };
  static constexpr llvm::Intrinsic::ID rndRZIds[] = {
      llvm::Intrinsic::nvvm_ff2bf16x2_rz,
      llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
      llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
      llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
  };
  static constexpr llvm::Intrinsic::ID rndRSIds[] = {
      llvm::Intrinsic::nvvm_ff2bf16x2_rs,
      llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
      llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
      llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
  };
Just thinking out loud: can we combine these two tables for bf16 and f16?
static constexpr llvm::Intrinsic::ID ff2x16IntrinsicSet
    [2 /*Fp16Kind*/]
    [3 /*RoundingMode*/]
    [4 /*PostOp*/] = {
  // ===== F16 =====
  {
    // RN
    {
      llvm::Intrinsic::nvvm_ff2f16x2_rn,
      llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
      llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
      llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
    },
    ....
}
And also write a selector function:
inline llvm::Intrinsic::ID getIntrinsic(Fp16Kind kind, RoundingMode rnd, Post post) {
  return ff2x16IntrinsicSet[kind][rnd][post];
}
Then you can select the intrinsic nicely:
llvm::Intrinsic::ID it = getIntrinsic(op.getType(), op.getRnd(), op.getMode());
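To make the suggestion concrete, here is a compilable sketch of the combined three-dimensional table, under stated assumptions: `Fp16Kind`, `Rounding`, `kIntrinsicNames`, `postOpIndex`, and `selectIntrinsic` are hypothetical names, the two enums are deliberately dense (the real `FPRoundingMode` is not), and intrinsic name strings taken from the CHECK lines in `convert_fp16x2.mlir` stand in for `llvm::Intrinsic::ID` values so that no LLVM headers are needed. The post-op index uses the same bit packing as the patch (bit 0 for relu, bit 1 for satfinite).

```cpp
#include <cassert>
#include <cstddef>
#include <string_view>

// Hypothetical dense enums used only to index the table below.
enum class Fp16Kind : std::size_t { F16 = 0, BF16 = 1 };
enum class Rounding : std::size_t { RN = 0, RZ = 1, RS = 2 };

// [kind][rounding][post-op]; strings stand in for llvm::Intrinsic::ID values.
constexpr std::string_view kIntrinsicNames[2][3][4] = {
    // ===== F16 =====
    {{"llvm.nvvm.ff2f16x2.rn", "llvm.nvvm.ff2f16x2.rn.relu",
      "llvm.nvvm.ff2f16x2.rn.satfinite", "llvm.nvvm.ff2f16x2.rn.relu.satfinite"},
     {"llvm.nvvm.ff2f16x2.rz", "llvm.nvvm.ff2f16x2.rz.relu",
      "llvm.nvvm.ff2f16x2.rz.satfinite", "llvm.nvvm.ff2f16x2.rz.relu.satfinite"},
     {"llvm.nvvm.ff2f16x2.rs", "llvm.nvvm.ff2f16x2.rs.relu",
      "llvm.nvvm.ff2f16x2.rs.satfinite", "llvm.nvvm.ff2f16x2.rs.relu.satfinite"}},
    // ===== BF16 =====
    {{"llvm.nvvm.ff2bf16x2.rn", "llvm.nvvm.ff2bf16x2.rn.relu",
      "llvm.nvvm.ff2bf16x2.rn.satfinite", "llvm.nvvm.ff2bf16x2.rn.relu.satfinite"},
     {"llvm.nvvm.ff2bf16x2.rz", "llvm.nvvm.ff2bf16x2.rz.relu",
      "llvm.nvvm.ff2bf16x2.rz.satfinite", "llvm.nvvm.ff2bf16x2.rz.relu.satfinite"},
     {"llvm.nvvm.ff2bf16x2.rs", "llvm.nvvm.ff2bf16x2.rs.relu",
      "llvm.nvvm.ff2bf16x2.rs.satfinite", "llvm.nvvm.ff2bf16x2.rs.relu.satfinite"}},
};

// bit 0 = relu, bit 1 = satfinite, matching the idx computation in the patch.
constexpr std::size_t postOpIndex(bool relu, bool satfinite) {
  return (static_cast<std::size_t>(satfinite) << 1) |
         static_cast<std::size_t>(relu);
}

// Selector in the spirit of the proposed getIntrinsic().
constexpr std::string_view selectIntrinsic(Fp16Kind kind, Rounding rnd,
                                           bool relu, bool satfinite) {
  return kIntrinsicNames[static_cast<std::size_t>(kind)]
                        [static_cast<std::size_t>(rnd)]
                        [postOpIndex(relu, satfinite)];
}

// Small usage example: bf16, stochastic rounding, relu only.
inline void selectIntrinsicExample() {
  assert(selectIntrinsic(Fp16Kind::BF16, Rounding::RS, true, false) ==
         "llvm.nvvm.ff2bf16x2.rs.relu");
}
```

The sketch only works because both index enums are dense and zero-based; whether that holds for the dialect's own enums is a separate question.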
I don't have a strong opinion. I think we are still trying to find the best way to create large tables and select an intrinsic from them nicely.
I did consider having a single table initially but went with separate ones, because the valid rounding modes (rn, rz, and rs) are not consecutive in the `FPRoundingMode` enum. We'd have to do some sort of mapping (or have empty entries in the table), which didn't look very nice.
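For reference, the mapping step mentioned here could look roughly like the sketch below. It is an illustration only: `RoundingMode` is a hypothetical enum whose members and values are invented to show a gap between the supported modes (they are not copied from `FPRoundingMode`), and `denseIndex` is a made-up helper name.

```cpp
#include <cassert>
#include <optional>

// Hypothetical enum with unsupported modes sitting between the supported
// ones; the values are illustrative and not taken from FPRoundingMode.
enum class RoundingMode { NONE = 0, RN = 1, RM = 2, RP = 3, RZ = 4, RS = 5 };

// Maps a supported rounding mode to a dense index in [0, 3) suitable for a
// table dimension, and returns std::nullopt for everything else.
constexpr std::optional<unsigned> denseIndex(RoundingMode rnd) {
  switch (rnd) {
  case RoundingMode::RN:
    return 0u;
  case RoundingMode::RZ:
    return 1u;
  case RoundingMode::RS:
    return 2u;
  default:
    return std::nullopt;
  }
}

// Small usage example: RZ maps to slot 1, RM has no slot.
inline void denseIndexExample() {
  assert(denseIndex(RoundingMode::RZ) == 1u);
  assert(!denseIndex(RoundingMode::RM).has_value());
}
```

The trade-off is the one described above: a single table needs either this extra switch or padding entries for the unsupported modes, whereas the three per-mode tables in the patch let the existing `switch (op.getRnd())` do the selection directly.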
LGTM, nice job! Good to merge on my end after the other comments are addressed.
durga4github
left a comment
Latest revision LGTM, thanks for addressing the comments.
LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/21489
LLVM Buildbot has detected a new failure. Full details are available at: https://lab.llvm.org/buildbot/#/builders/157/builds/42571
…69005) This change adds the `RN` and `RZ` rounding modes to the `convert.f32x2.to.f16x2` and `convert.f32x2.to.bf16x2` Ops. Tests are added in `convert_fp16x2.mlir` and `invalid_convert_fp16x2.mlir`. Tests with these Ops in `convert_stochastic_rounding.mlir` and `invalid-convert-stochastic-rounding.mlir` have been removed or modified. PTX spec reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt