diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 4f483859ac18d..bbc2e8d888173 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1604,10 +1604,11 @@ def FPRoundingModeRM : I32EnumAttrCase<"RM", 2, "rm">; def FPRoundingModeRP : I32EnumAttrCase<"RP", 3, "rp">; def FPRoundingModeRZ : I32EnumAttrCase<"RZ", 4, "rz">; def FPRoundingModeRNA : I32EnumAttrCase<"RNA", 5, "rna">; +def FPRoundingModeRS : I32EnumAttrCase<"RS", 6, "rs">; def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind", [FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM, - FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA]> { + FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA, FPRoundingModeRS]> { let genSpecializedAttr = 0; let cppNamespace = "::mlir::NVVM"; } @@ -1921,6 +1922,96 @@ def NVVM_ConvertF6x2ToF16x2Op : def NVVM_ConvertF4x2ToF16x2Op : NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">; +//===----------------------------------------------------------------------===// +// NVVM Stochastic Rounding Conversion Ops +//===----------------------------------------------------------------------===// + +// Base class for conversions from F32x2 to FPx2 formats +// (F16x2, BF16x2) +// TODO: In separate PR, add .rn and .rz rounding variants for this conversion +// as currently only support .rs rounding mode +class NVVM_ConvertF32x2ToFPx2OpBase : + NVVM_Op]>, + Results<(outs dstType:$dst)>, + Arguments<(ins F32:$src_hi, F32:$src_lo, I32:$rbits, + DefaultValuedAttr:$rnd, + DefaultValuedAttr:$sat, + DefaultValuedAttr:$relu)> { + let summary = "Convert two F32 values to packed " # dstFormat # " with stochastic rounding (.rs)"; + let description = [{ + Converts two F32 values to packed }] # dstFormat # [{ format using stochastic + rounding (.rs) mode with randomness provided by the `rbits` parameter. The + `relu` attribute clamps negative results to 0. The `sat` attribute determines + saturation behavior. The `src_hi` and `src_lo` parameters correspond to operands + `a` and `b` in the PTX ISA, respectively. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + + let assemblyFormat = "$src_hi `,` $src_lo `,` $rbits attr-dict `:` type($dst)"; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + llvm::Intrinsic::ID getIntrinsicID(); + }]; + + string llvmBuilder = [{ + auto intId = op.getIntrinsicID(); + $dst = createIntrinsicCall(builder, intId, {$src_hi, $src_lo, $rbits}); + }]; + } + +// F32x2 -> F16x2 with stochastic rounding +def NVVM_ConvertF32x2ToF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"f16x2", "convert.f32x2.to.f16x2", VectorOfLengthAndType<[2], [F16]>>; + +// F32x2 -> BF16x2 with stochastic rounding +def NVVM_ConvertF32x2ToBF16x2Op : NVVM_ConvertF32x2ToFPx2OpBase<"bf16x2", "convert.f32x2.to.bf16x2", VectorOfLengthAndType<[2], [BF16]>>; + +// Base class for stochastic rounding conversions from F32x4 to FPx4 formats +// (E4M3x4, E5M2x4, E2M3x4, E3M2x4, E2M1x4) +// These operations always use RS (stochastic rounding) mode with SATFINITE saturation. +class NVVM_ConvertF32x4ToFPx4OpBase : + NVVM_Op]>, + Results<(outs dstType:$dst)>, + Arguments<(ins VectorOfLengthAndType<[4], [F32]>:$src, I32:$rbits, + DefaultValuedAttr:$relu, + TypeAttr:$dstTy)> { + let summary = "Convert vector<4xf32> to packed " # dstFormat # " with stochastic rounding (.rs) and satfinite"; + let description = [{ + Converts a vector<4xf32> to packed }] # dstFormat # [{ format using + stochastic rounding (.rs) mode with SATFINITE saturation. Randomness is + provided by the `rbits` parameter. The `dstTy` attribute specifies the + target floating-point format. The `relu` attribute clamps negative results to 0. + + Note: These operations always use RS rounding mode and SATFINITE saturation mode. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + + let assemblyFormat = "$src `,` $rbits attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`"; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + llvm::Intrinsic::ID getIntrinsicID(); + }]; + + string llvmBuilder = [{ + auto intId = op.getIntrinsicID(); + $dst = createIntrinsicCall(builder, intId, {$src, $rbits}); + }]; +} + +// F32x4 -> F8x4 with stochastic rounding (supports E4M3FN, E5M2) +def NVVM_ConvertF32x4ToF8x4Op : NVVM_ConvertF32x4ToFPx4OpBase<"f8x4", "convert.f32x4.to.f8x4", VectorOfLengthAndType<[4], [I8]>>; + +// F32x4 -> F6x4 with stochastic rounding (supports E2M3FN, E3M2FN) +def NVVM_ConvertF32x4ToF6x4Op : NVVM_ConvertF32x4ToFPx4OpBase<"f6x4", "convert.f32x4.to.f6x4", VectorOfLengthAndType<[4], [I8]>>; + +// F32x4 -> F4x4 with stochastic rounding (supports E2M1FN) +def NVVM_ConvertF32x4ToF4x4Op : NVVM_ConvertF32x4ToFPx4OpBase<"f4x4", "convert.f32x4.to.f4x4", I16>; + //===----------------------------------------------------------------------===// // NVVM MMA Ops //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index f0de4dbcc1d4b..091b4a93842bb 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -365,6 +365,59 @@ LogicalResult ConvertF4x2ToF16x2Op::verify() { return success(); } +//===----------------------------------------------------------------------===// +// Stochastic Rounding Conversion Ops +//===----------------------------------------------------------------------===// + +LogicalResult ConvertF32x2ToF16x2Op::verify() { + if (getRnd() != FPRoundingMode::RS) + return emitOpError("Only RS rounding mode is supported for " + "conversions from f32x2 to f16x2."); + return success(); +} + +LogicalResult ConvertF32x2ToBF16x2Op::verify() { + if (getRnd() != FPRoundingMode::RS) + return emitOpError("Only RS rounding mode is supported for " + "conversions from f32x2 to bf16x2."); + return success(); +} + +LogicalResult ConvertF32x4ToF8x4Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa(getDstTy())) + return emitOpError("Only ") + << mlir::Float8E4M3FNType::get(ctx) << " and " + << mlir::Float8E5M2Type::get(ctx) + << " types are supported for conversions from f32x4 to f8x4."; + + return success(); +} + +LogicalResult ConvertF32x4ToF6x4Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa(getDstTy())) + return emitOpError("Only ") + << mlir::Float6E2M3FNType::get(ctx) << " and " + << mlir::Float6E3M2FNType::get(ctx) + << " types are supported for conversions from f32x4 to f6x4."; + + return success(); +} + +LogicalResult ConvertF32x4ToF4x4Op::verify() { + mlir::MLIRContext *ctx = getContext(); + + if (!llvm::isa(getDstTy())) + return emitOpError("Only ") << mlir::Float4E2M1FNType::get(ctx) + << " type is supported for conversions from " + "f32x4 to f4x4."; + + return success(); +} + LogicalResult BulkStoreOp::verify() { if (getInitVal() != 0) return emitOpError("only 0 is supported for initVal, got ") << getInitVal(); @@ -2412,6 +2465,85 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \ }() +llvm::Intrinsic::ID ConvertF32x2ToF16x2Op::getIntrinsicID() { + bool hasRelu = getRelu(); + bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE); + + if (hasRelu && hasSatFinite) + return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu_satfinite; + if (hasRelu) + return llvm::Intrinsic::nvvm_ff2f16x2_rs_relu; + if (hasSatFinite) + return llvm::Intrinsic::nvvm_ff2f16x2_rs_satfinite; + return llvm::Intrinsic::nvvm_ff2f16x2_rs; +} + +llvm::Intrinsic::ID ConvertF32x2ToBF16x2Op::getIntrinsicID() { + bool hasRelu = getRelu(); + bool hasSatFinite = (getSat() == NVVM::SaturationMode::SATFINITE); + + if (hasRelu && hasSatFinite) + return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu_satfinite; + if (hasRelu) + return llvm::Intrinsic::nvvm_ff2bf16x2_rs_relu; + if (hasSatFinite) + return llvm::Intrinsic::nvvm_ff2bf16x2_rs_satfinite; + return llvm::Intrinsic::nvvm_ff2bf16x2_rs; +} + +llvm::Intrinsic::ID ConvertF32x4ToF8x4Op::getIntrinsicID() { + mlir::Type dstTy = getDstTy(); + bool hasRelu = getRelu(); + + return llvm::TypeSwitch(dstTy) + .Case([&](mlir::Float8E4M3FNType) { + return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e4m3x4_rs_satfinite; + }) + .Case([&](mlir::Float8E5M2Type) { + return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e5m2x4_rs_satfinite; + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid F8 type in ConvertF32x4ToF8x4Op"); + return llvm::Intrinsic::not_intrinsic; + }); +} + +llvm::Intrinsic::ID ConvertF32x4ToF6x4Op::getIntrinsicID() { + mlir::Type dstTy = getDstTy(); + bool hasRelu = getRelu(); + + return llvm::TypeSwitch(dstTy) + .Case([&](mlir::Float6E2M3FNType) { + return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e2m3x4_rs_satfinite; + }) + .Case([&](mlir::Float6E3M2FNType) { + return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e3m2x4_rs_satfinite; + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid F6 type in ConvertF32x4ToF6x4Op"); + return llvm::Intrinsic::not_intrinsic; + }); +} + +llvm::Intrinsic::ID ConvertF32x4ToF4x4Op::getIntrinsicID() { + mlir::Type dstTy = getDstTy(); + bool hasRelu = getRelu(); + + return llvm::TypeSwitch(dstTy) + .Case([&](mlir::Float4E2M1FNType) { + return hasRelu ? llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_relu_satfinite + : llvm::Intrinsic::nvvm_f32x4_to_e2m1x4_rs_satfinite; + }) + .Default([](mlir::Type) { + llvm_unreachable("Invalid F4 type in ConvertF32x4ToF4x4Op"); + return llvm::Intrinsic::not_intrinsic; + }); +} + llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) { auto curOp = cast(op); bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2; diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir new file mode 100644 index 0000000000000..35f5e1b3c8ba2 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir @@ -0,0 +1,90 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// Test invalid target architecture (sm_100 instead of sm_100a) +gpu.module @invalid_arch_sm_100 [#nvvm.target] { + func.func @convert_rs() { + %f1 = llvm.mlir.constant(1.0 : f32) : f32 + %f2 = llvm.mlir.constant(2.0 : f32) : f32 + %rbits = llvm.mlir.constant(0x12345678 : i32) : i32 + // expected-error@+1 {{'nvvm.convert.f32x2.to.f16x2' op is not supported on sm_100}} + %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16> + return + } +} + +// ----- + +// Test that operations require stochastic rounding mode +llvm.func @invalid_rnd_mode_f16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { + // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to f16x2.}} + %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xf16> + llvm.return %res : vector<2xf16> +} + +// ----- + +llvm.func @invalid_rnd_mode_bf16x2(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { + // expected-error@+1 {{Only RS rounding mode is supported for conversions from f32x2 to bf16x2.}} + %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> + llvm.return %res : vector<2xbf16> +} + +// ----- + +// Test invalid destination types for f8x4 (should only accept f8E4M3FN, f8E5M2) +llvm.func @invalid_dst_type_f8x4_e3m4(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}} + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E3M4) + llvm.return %res : vector<4xi8> +} + +// ----- + +llvm.func @invalid_dst_type_f8x4_e8m0(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // expected-error@+1 {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f32x4 to f8x4.}} + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E8M0FNU) + llvm.return %res : vector<4xi8> +} + +// ----- + +// Test invalid destination types for f6x4 (should only accept f6E2M3FN, f6E3M2FN) +llvm.func @invalid_dst_type_f6x4_f8(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // expected-error@+1 {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x4 to f6x4.}} + %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN) + llvm.return %res : vector<4xi8> +} + +// ----- + +// Test invalid destination type for f4x4 (should only accept f4E2M1FN) +llvm.func @invalid_dst_type_f4x4_f6(%src : vector<4xf32>, %rbits : i32) -> i16 { + // expected-error@+1 {{Only 'f4E2M1FN' type is supported for conversions from f32x4 to f4x4.}} + %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits : vector<4xf32> -> i16 (f6E2M3FN) + llvm.return %res : i16 +} + +// ----- + +// Test invalid rounding modes for non-stochastic ops +llvm.func @convert_float_to_tf32_rs_not_supported(%src : f32) -> i32 { + // expected-error @below {{Only {rn,rz,rna} rounding modes supported for ConvertFloatToTF32Op.}} + %res = nvvm.convert.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode} + llvm.return %res : i32 +} + +// ----- + +llvm.func @convert_f32x2_to_f8x2_rs_not_supported(%a : f32, %b : f32) { + // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to 'f8E4M3FN' and 'f8E5M2' types}} + %res = nvvm.convert.f32x2.to.f8x2 %a, %b {rnd = #nvvm.fp_rnd_mode, sat = #nvvm.sat_mode} : i16 (f8E4M3FN) + llvm.return +} + +// ----- + +llvm.func @convert_bf16x2_to_f8x2_rs_not_supported(%src : vector<2xbf16>) { + // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}} + %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode} : vector<2xbf16> -> i16 (f8E8M0FNU) + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir new file mode 100644 index 0000000000000..b5bb22350dcd7 --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/convert_stochastic_rounding.mlir @@ -0,0 +1,182 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// ----- + +// Test valid architectures work + +// Valid case on sm_100a +gpu.module @valid_f16x2_rs_sm_100a [#nvvm.target] { + func.func @convert_rs() { + %f1 = llvm.mlir.constant(1.0 : f32) : f32 + %f2 = llvm.mlir.constant(2.0 : f32) : f32 + %rbits = llvm.mlir.constant(0x12345678 : i32) : i32 + %res = nvvm.convert.f32x2.to.f16x2 %f1, %f2, %rbits : vector<2xf16> + return + } +} + +// Valid case on sm_103a +gpu.module @valid_bf16x2_rs_sm_103a [#nvvm.target] { + func.func @convert_rs() { + %f1 = llvm.mlir.constant(1.0 : f32) : f32 + %f2 = llvm.mlir.constant(2.0 : f32) : f32 + %rbits = llvm.mlir.constant(0 : i32) : i32 + %res = nvvm.convert.f32x2.to.bf16x2 %f1, %f2, %rbits : vector<2xbf16> + return + } +} + +// ----- + +// Test F32x2 -> F16x2 with stochastic rounding (.rs) + +// CHECK-LABEL: @convert_f32x2_to_f16x2_rs +llvm.func @convert_f32x2_to_f16x2_rs(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits : vector<2xf16> + llvm.return %res : vector<2xf16> +} + +// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_satfinite +llvm.func @convert_f32x2_to_f16x2_rs_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {sat = #nvvm.sat_mode} : vector<2xf16> + llvm.return %res : vector<2xf16> +} + +// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_relu +llvm.func @convert_f32x2_to_f16x2_rs_relu(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true} : vector<2xf16> + llvm.return %res : vector<2xf16> +} + +// CHECK-LABEL: @convert_f32x2_to_f16x2_rs_relu_satfinite +llvm.func @convert_f32x2_to_f16x2_rs_relu_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xf16> { + // CHECK: %{{.*}} = call <2 x half> @llvm.nvvm.ff2f16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x2.to.f16x2 %srcA, %srcB, %rbits {relu = true, sat = #nvvm.sat_mode} : vector<2xf16> + llvm.return %res : vector<2xf16> +} + +// ----- + +// Test F32x2 -> BF16x2 with stochastic rounding (.rs) + +// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs +llvm.func @convert_f32x2_to_bf16x2_rs(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits : vector<2xbf16> + llvm.return %res : vector<2xbf16> +} + +// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_satfinite +llvm.func @convert_f32x2_to_bf16x2_rs_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {sat = #nvvm.sat_mode} : vector<2xbf16> + llvm.return %res : vector<2xbf16> +} + +// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_relu +llvm.func @convert_f32x2_to_bf16x2_rs_relu(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true} : vector<2xbf16> + llvm.return %res : vector<2xbf16> +} + +// CHECK-LABEL: @convert_f32x2_to_bf16x2_rs_relu_satfinite +llvm.func @convert_f32x2_to_bf16x2_rs_relu_satfinite(%srcA : f32, %srcB : f32, %rbits : i32) -> vector<2xbf16> { + // CHECK: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rs.relu.satfinite(float %{{.*}}, float %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x2.to.bf16x2 %srcA, %srcB, %rbits {relu = true, sat = #nvvm.sat_mode} : vector<2xbf16> + llvm.return %res : vector<2xbf16> +} + +// ----- + +// Test F32x4 -> F8x4 (E4M3) with stochastic rounding (.rs) + +// CHECK-LABEL: @convert_f32x4_to_f8x4_e4m3_rs +llvm.func @convert_f32x4_to_f8x4_e4m3_rs(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E4M3FN) + llvm.return %res : vector<4xi8> +} + +// CHECK-LABEL: @convert_f32x4_to_f8x4_e4m3_rs_relu +llvm.func @convert_f32x4_to_f8x4_e4m3_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e4m3x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f8E4M3FN) + llvm.return %res : vector<4xi8> +} + +// ----- + +// Test F32x4 -> F8x4 (E5M2) with stochastic rounding (.rs) + +// CHECK-LABEL: @convert_f32x4_to_f8x4_e5m2_rs +llvm.func @convert_f32x4_to_f8x4_e5m2_rs(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f8E5M2) + llvm.return %res : vector<4xi8> +} + +// CHECK-LABEL: @convert_f32x4_to_f8x4_e5m2_rs_relu +llvm.func @convert_f32x4_to_f8x4_e5m2_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e5m2x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f8x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f8E5M2) + llvm.return %res : vector<4xi8> +} + +// ----- + +// Test F32x4 -> F6x4 (E2M3) with stochastic rounding (.rs) + +// CHECK-LABEL: @convert_f32x4_to_f6x4_e2m3_rs +llvm.func @convert_f32x4_to_f6x4_e2m3_rs(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f6E2M3FN) + llvm.return %res : vector<4xi8> +} + +// CHECK-LABEL: @convert_f32x4_to_f6x4_e2m3_rs_relu +llvm.func @convert_f32x4_to_f6x4_e2m3_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e2m3x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f6E2M3FN) + llvm.return %res : vector<4xi8> +} + +// ----- + +// Test F32x4 -> F6x4 (E3M2) with stochastic rounding (.rs) + +// CHECK-LABEL: @convert_f32x4_to_f6x4_e3m2_rs +llvm.func @convert_f32x4_to_f6x4_e3m2_rs(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits : vector<4xf32> -> vector<4xi8> (f6E3M2FN) + llvm.return %res : vector<4xi8> +} + +// CHECK-LABEL: @convert_f32x4_to_f6x4_e3m2_rs_relu +llvm.func @convert_f32x4_to_f6x4_e3m2_rs_relu(%src : vector<4xf32>, %rbits : i32) -> vector<4xi8> { + // CHECK: %{{.*}} = call <4 x i8> @llvm.nvvm.f32x4.to.e3m2x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f6x4 %src, %rbits {relu = true} : vector<4xf32> -> vector<4xi8> (f6E3M2FN) + llvm.return %res : vector<4xi8> +} + +// ----- + +// Test F32x4 -> F4x4 (E2M1) with stochastic rounding (.rs) + +// CHECK-LABEL: @convert_f32x4_to_f4x4_e2m1_rs +llvm.func @convert_f32x4_to_f4x4_e2m1_rs(%src : vector<4xf32>, %rbits : i32) -> i16 { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits : vector<4xf32> -> i16 (f4E2M1FN) + llvm.return %res : i16 +} + +// CHECK-LABEL: @convert_f32x4_to_f4x4_e2m1_rs_relu +llvm.func @convert_f32x4_to_f4x4_e2m1_rs_relu(%src : vector<4xf32>, %rbits : i32) -> i16 { + // CHECK: %{{.*}} = call i16 @llvm.nvvm.f32x4.to.e2m1x4.rs.relu.satfinite(<4 x float> %{{.*}}, i32 %{{.*}}) + %res = nvvm.convert.f32x4.to.f4x4 %src, %rbits {relu = true} : vector<4xf32> -> i16 (f4E2M1FN) + llvm.return %res : i16 +} +