Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 32 additions & 20 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1912,45 +1912,57 @@ def NVVM_ConvertF4x2ToF16x2Op :

// Base class for conversions from F32x2 to FPx2 formats
// (F16x2, BF16x2)
// TODO: In separate PR, add .rn and .rz rounding variants for this conversion
// as currently only support .rs rounding mode
class NVVM_ConvertF32x2ToFPx2OpBase<string dstFormat, string mnemonic, Type dstType> :
NVVM_Op<mnemonic, [Pure, NVVMRequiresSMa<[100, 103]>]>,
NVVM_Op<mnemonic, [Pure]>,
Results<(outs dstType:$dst)>,
Arguments<(ins F32:$src_hi, F32:$src_lo, I32:$rbits,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RS">:$rnd,
Arguments<(ins F32:$src_hi, F32:$src_lo,
Optional<I32>:$random_bits,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
DefaultValuedAttr<BoolAttr, "false">:$relu)> {
let summary = "Convert two F32 values to packed " # dstFormat # " with stochastic rounding (.rs)";
let summary = "Convert two F32 values to packed " # !tolower(dstFormat) # ".";
let description = [{
Converts two F32 values to packed }] # dstFormat # [{ format using stochastic
rounding (.rs) mode with randomness provided by the `rbits` parameter. The
`relu` attribute clamps negative results to 0. The `sat` attribute determines
saturation behavior. The `src_hi` and `src_lo` parameters correspond to operands
`a` and `b` in the PTX ISA, respectively.
Converts two F32 values to packed }] # !tolower(dstFormat) # [{ format with
the specified rounding mode. The `src_hi` and `src_lo` parameters
correspond to operands `a` and `b` in the PTX ISA, respectively.

The `random_bits` parameter is required for stochastic rounding and
provides the [random bits](}] #
!if(!eq(dstFormat, "F16x2"),
"https://docs.nvidia.com/cuda/parallel-thread-execution/#cvt-rs-rbits-layout-f16",
"https://docs.nvidia.com/cuda/parallel-thread-execution/#cvt-rs-rbits-layout-bf16") #
[{) to be used for the conversion.

The `relu` attribute clamps negative results to 0.

The `sat` attribute determines saturation behavior.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];

let assemblyFormat = "$src_hi `,` $src_lo `,` $rbits attr-dict `:` type($dst)";
let assemblyFormat = "$src_hi `,` $src_lo (`,` $random_bits^)? attr-dict `:` type($dst)";

let hasVerifier = 1;

let extraClassDeclaration = [{
llvm::Intrinsic::ID getIntrinsicID();
static NVVM::IDArgPair
getIntrinsicIDAndArgs(
NVVM::ConvertF32x2To}] # dstFormat # [{Op &op,
LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
}];

string llvmBuilder = [{
auto intId = op.getIntrinsicID();
$dst = createIntrinsicCall(builder, intId, {$src_hi, $src_lo, $rbits});
auto [intId, args] = mlir::NVVM::ConvertF32x2To}] # dstFormat #
[{Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
$dst = createIntrinsicCall(builder, intId, args);
}];
}
}

// F32x2 -> F16x2 with stochastic rounding
def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"f16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;
// F32x2 -> F16x2
def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"F16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>;

// F32x2 -> BF16x2 with stochastic rounding
def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"bf16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;
// F32x2 -> BF16x2
def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"BF16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>;

// Base class for stochastic rounding conversions from F32x4 to FPx4 formats
// (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4)
Expand Down
152 changes: 123 additions & 29 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,18 +390,42 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() {
// Stochastic Rounding Conversion Ops
//===----------------------------------------------------------------------===//

LogicalResult ConvertF32x2ToF16x2Op::verify() {
if (getRnd() != FPRoundingMode::RS)
return emitOpError("Only RS rounding mode is supported for "
"conversions from f32x2 to f16x2.");
static LogicalResult verifyConvertF32x2ToFP16x2Op(Twine dstType,
FPRoundingMode rnd,
bool hasRandomBits,
Operation *op) {
static constexpr FPRoundingMode validRndModes[] = {
FPRoundingMode::RN, FPRoundingMode::RZ, FPRoundingMode::RS};

if (!llvm::is_contained(validRndModes, rnd)) {
return op->emitOpError(
"Only RN, RZ, and RS rounding modes are supported for "
"conversions from f32x2 to ")
<< dstType << ".";
}

if (rnd == FPRoundingMode::RS) {
if (!hasRandomBits) {
return op->emitOpError("random_bits is required for RS rounding mode.");
}
} else {
if (hasRandomBits) {
return op->emitOpError(
"random_bits not supported for RN and RZ rounding modes.");
}
}

return success();
}

LogicalResult ConvertF32x2ToF16x2Op::verify() {
return verifyConvertF32x2ToFP16x2Op("f16x2", getRnd(),
getRandomBits() ? true : false, *this);
}

LogicalResult ConvertF32x2ToBF16x2Op::verify() {
if (getRnd() != FPRoundingMode::RS)
return emitOpError("Only RS rounding mode is supported for "
"conversions from f32x2 to bf16x2.");
return success();
return verifyConvertF32x2ToFP16x2Op("bf16x2", getRnd(),
getRandomBits() ? true : false, *this);
}

LogicalResult ConvertF32x4ToF8x4Op::verify() {
Expand Down Expand Up @@ -2727,30 +2751,100 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()

llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() {
bool hasRelu = getRelu();
bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);
NVVM::IDArgPair
ConvertF32x2ToF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF16x2Op &op,
LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase &builder) {
static constexpr llvm::Intrinsic::ID rndRNIds[] = {
llvm::Intrinsic::nvvm_ff2f16x2_rn,
llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
};
static constexpr llvm::Intrinsic::ID rndRZIds[] = {
llvm::Intrinsic::nvvm_ff2f16x2_rz,
llvm::Intrinsic::nvvm_ff2f16x2_rz_relu,
llvm::Intrinsic::nvvm_ff2f16x2_rz_satfinite,
llvm::Intrinsic::nvvm_ff2f16x2_rz_relu_satfinite,
};
static constexpr llvm::Intrinsic::ID rndRSIds[] = {
llvm::Intrinsic::nvvm_ff2f16x2_rs,
llvm::Intrinsic::nvvm_ff2f16x2_rs_relu,
llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite,
llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite,
};

if (hasRelu && hasSatFinite)
return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite;
if (hasRelu)
return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu;
if (hasSatFinite)
return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite;
return llvm::Intrinsic::nvvm_ff2f16x2_rs;
unsigned hasRelu = op.getRelu() ? 1 : 0;
unsigned hasSatFinite =
(op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
// idx: bit-0 - relu
// bit-1 - satfinite
unsigned idx = (hasSatFinite << 1) | hasRelu;

llvm::SmallVector<llvm::Value *> args;
args.push_back(mt.lookupValue(op.getSrcHi()));
args.push_back(mt.lookupValue(op.getSrcLo()));
if (op.getRandomBits())
args.push_back(mt.lookupValue(op.getRandomBits()));

switch (op.getRnd()) {
case FPRoundingMode::RN:
return {rndRNIds[idx], std::move(args)};
case FPRoundingMode::RZ:
return {rndRZIds[idx], std::move(args)};
case FPRoundingMode::RS:
return {rndRSIds[idx], std::move(args)};
default:
llvm_unreachable("Invalid rounding mode for ConvertF32x2ToF16x2Op");
}
}

llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() {
bool hasRelu = getRelu();
bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE);

if (hasRelu && hasSatFinite)
return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite;
if (hasRelu)
return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu;
if (hasSatFinite)
return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite;
return llvm::Intrinsic::nvvm_ff2bf16x2_rs;
NVVM::IDArgPair
ConvertF32x2ToBF16x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToBF16x2Op &op,
LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase &builder) {
static constexpr llvm::Intrinsic::ID rndRNIds[] = {
llvm::Intrinsic::nvvm_ff2bf16x2_rn,
llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu,
llvm::Intrinsic::nvvm_ff2bf16x2_rn_satfinite,
llvm::Intrinsic::nvvm_ff2bf16x2_rn_relu_satfinite,
};
static constexpr llvm::Intrinsic::ID rndRZIds[] = {
llvm::Intrinsic::nvvm_ff2bf16x2_rz,
llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu,
llvm::Intrinsic::nvvm_ff2bf16x2_rz_satfinite,
llvm::Intrinsic::nvvm_ff2bf16x2_rz_relu_satfinite,
};
static constexpr llvm::Intrinsic::ID rndRSIds[] = {
llvm::Intrinsic::nvvm_ff2bf16x2_rs,
llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu,
llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite,
llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite,
};
Comment on lines +2802 to +2823
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just thinking out loud - can we combine these two tables for b16 and f16:

static constexpr llvm::Intrinsic::ID ff2x16IntrinsicSet
    [2 /*Fp16Kind*/]
    [3 /*RoundingMode*/]
    [4 /*PostOp*/] = {
  // ===== F16 =====
  {
    // RN
    {
      llvm::Intrinsic::nvvm_ff2f16x2_rn,
      llvm::Intrinsic::nvvm_ff2f16x2_rn_relu,
      llvm::Intrinsic::nvvm_ff2f16x2_rn_satfinite,
      llvm::Intrinsic::nvvm_ff2f16x2_rn_relu_satfinite,
    },
     ....
}

and also write a selector function

inline llvm::Intrinsic::ID getIntrinsic(Fp16Kind kind,RoundingMode rnd, Post post) {
  return ff2x16IntrinsicSet[kind][rnd][post]
} 

Then you can select the intrinsic nicely:

llvm::Intrinsic::ID it = getIntrinsic(op.getType(), op.getRnd(), op.getMode());

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have strong opinion. I think we are still trying to find what is the best way to create large tables and select and intrinsic from them nicely.

Copy link
Contributor Author

@Wolfram70 Wolfram70 Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did consider having a single table initially but went with separate ones because the valid rounding modes rn, rz, and rs are not consecutive in the FPRoundingMode enum so we'd have to do some sort of mapping (or have empty entries in the table) which didn't look very nice.


unsigned hasRelu = op.getRelu() ? 1 : 0;
unsigned hasSatFinite =
(op.getSat() == NVVM::SaturationMode::SATFINITE) ? 1 : 0;
// idx: bit-0 - relu
// bit-1 - satfinite
unsigned idx = (hasSatFinite << 1) | hasRelu;

llvm::SmallVector<llvm::Value *> args;
args.push_back(mt.lookupValue(op.getSrcHi()));
args.push_back(mt.lookupValue(op.getSrcLo()));
if (op.getRandomBits())
args.push_back(mt.lookupValue(op.getRandomBits()));

switch (op.getRnd()) {
case FPRoundingMode::RN:
return {rndRNIds[idx], std::move(args)};
case FPRoundingMode::RZ:
return {rndRZIds[idx], std::move(args)};
case FPRoundingMode::RS:
return {rndRSIds[idx], std::move(args)};
default:
llvm_unreachable("Invalid rounding mode for ConvertF32x2ToBF16x2Op");
}
}

llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,15 @@

// Test invalid target architecture (sm_100 instead of sm_100a)
gpu.module @invalid_arch_sm_100 [#nvvm.target<chip = "sm_100">] {
func.func @convert_rs() {
%f1 = llvm.mlir.constant(1.0 : f32) : f32
%f2 = llvm.mlir.constant(2.0 : f32) : f32
%rbits = llvm.mlir.constant(0x12345678 : i32) : i32
// expected-error@+1 {{'nvvm.convert.f32x2.to.f16x2' op is not supported on sm_100}}
%res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16>
func.func @convert_rs(%src : vector<4xf32>, %rbits : i32) {
// expected-error@+1 {{'nvvm.convert.f32x4.to.f8x4' op is not supported on sm_100}}
%res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN)
return
}
}

// -----

// Test that operations require stochastic rounding mode
llvm.func @invalid_rnd_mode_f16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> {
// expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to f16x2.}}
%res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
llvm.return %res : vector<2xf16>
}

// -----

llvm.func @invalid_rnd_mode_bf16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> {
// expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to bf16x2.}}
%res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
llvm.return %res : vector<2xbf16>
}

// -----

// Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2)
llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> {
// expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}}
Expand Down
87 changes: 87 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/convert_fp16x2.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// CHECK-LABEL: @convert_f32x2_to_f16x2_rn
llvm.func @convert_f32x2_to_f16x2_rn(%srcA : f32, %srcB : f32) {
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %{{.*}}, float %{{.*}})
%res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16>
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
%res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %{{.*}}, float %{{.*}})
%res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16>
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
%res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>

llvm.return
}

// CHECK-LABEL: @convert_f32x2_to_f16x2_rz
llvm.func @convert_f32x2_to_f16x2_rz(%srcA : f32, %srcB : f32) {
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %{{.*}}, float %{{.*}})
%res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16>
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
%res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %{{.*}}, float %{{.*}})
%res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xf16>
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
%res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>

llvm.return
}

// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_stochastic
llvm.func @convert_f32x2_to_f16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
%res1 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
%res2 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xf16>
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
%res3 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>
// CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
%res4 = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16>

llvm.return
}

// -----

// CHECK-LABEL: @convert_f32x2_to_bf16x2_rn
llvm.func @convert_f32x2_to_bf16x2_rn(%srcA : f32, %srcB : f32) {
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %{{.*}}, float %{{.*}})
%res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16>
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.satfinite(float %{{.*}}, float %{{.*}})
%res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %{{.*}}, float %{{.*}})
%res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xbf16>
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
%res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>

llvm.return
}

// CHECK-LABEL: @convert_f32x2_to_bf16x2_rz
llvm.func @convert_f32x2_to_bf16x2_rz(%srcA : f32, %srcB : f32) {
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %{{.*}}, float %{{.*}})
%res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16>
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.satfinite(float %{{.*}}, float %{{.*}})
%res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %{{.*}}, float %{{.*}})
%res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : vector<2xbf16>
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu.satfinite(float %{{.*}}, float %{{.*}})
%res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>

llvm.return
}

// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_stochastic
llvm.func @convert_f32x2_to_bf16x2_rs_stochastic(%srcA : f32, %srcB : f32, %rbits : i32) {
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
%res1 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
%res2 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16>
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
%res3 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>
// CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}})
%res4 = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, rnd = #nvvm.fp_rnd_mode<rs>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16>

llvm.return
}
Loading