Skip to content

Commit

Permalink
adding bf16 support to NVPTX
Browse files Browse the repository at this point in the history
Currently, bf16 support has been added to the PTX codegen in a piecemeal fashion. This patch aims to complete the set of instructions and code paths required to support the bf16 data type.

Reviewed By: tra

Differential Revision: https://reviews.llvm.org/D144911

Co-authored-by: Artem Belevich <tra@google.com>
  • Loading branch information
kushanam and Artem-B committed Jun 28, 2023
1 parent 85bdea0 commit 250f2bb
Show file tree
Hide file tree
Showing 24 changed files with 1,706 additions and 370 deletions.
72 changes: 40 additions & 32 deletions clang/include/clang/Basic/BuiltinsNVPTX.def
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,20 @@ TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "UsUsUs", "",
TARGET_BUILTIN(__nvvm_fmin_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_ftz_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "yyy", "", AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "yyy", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
TARGET_BUILTIN(__nvvm_fmin_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_ftz_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "V2yV2yV2y", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "V2yV2yV2y", "",
AND(SM_86, PTX72))
BUILTIN(__nvvm_fmin_f, "fff", "")
BUILTIN(__nvvm_fmin_ftz_f, "fff", "")
Expand Down Expand Up @@ -215,16 +219,20 @@ TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "UsUsUs", "",
TARGET_BUILTIN(__nvvm_fmax_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_ftz_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16, "yyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "yyy", "", AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "yyy", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
TARGET_BUILTIN(__nvvm_fmax_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_ftz_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "V2yV2yV2y", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "V2yV2yV2y", "",
AND(SM_86, PTX72))
BUILTIN(__nvvm_fmax_f, "fff", "")
BUILTIN(__nvvm_fmax_ftz_f, "fff", "")
Expand Down Expand Up @@ -352,10 +360,10 @@ TARGET_BUILTIN(__nvvm_fma_rn_sat_f16x2, "V2hV2hV2hV2h", "", AND(SM_53, PTX42))
TARGET_BUILTIN(__nvvm_fma_rn_ftz_sat_f16x2, "V2hV2hV2hV2h", "", AND(SM_53, PTX42))
TARGET_BUILTIN(__nvvm_fma_rn_relu_f16x2, "V2hV2hV2hV2h", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_ftz_relu_f16x2, "V2hV2hV2hV2h", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_bf16, "UsUsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16, "UsUsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_bf16x2, "ZUiZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16x2, "ZUiZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_bf16, "yyyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16, "yyyy", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_bf16x2, "V2yV2yV2yV2y", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16x2, "V2yV2yV2yV2y", "", AND(SM_80, PTX70))
BUILTIN(__nvvm_fma_rn_ftz_f, "ffff", "")
BUILTIN(__nvvm_fma_rn_f, "ffff", "")
BUILTIN(__nvvm_fma_rz_ftz_f, "ffff", "")
Expand Down Expand Up @@ -543,20 +551,20 @@ BUILTIN(__nvvm_ull2d_rp, "dULLi", "")
BUILTIN(__nvvm_f2h_rn_ftz, "Usf", "")
BUILTIN(__nvvm_f2h_rn, "Usf", "")

TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "ZUiff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "ZUiff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "ZUiff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "ZUiff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "V2yff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "V2yff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "V2yff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "V2yff", "", AND(SM_80,PTX70))

TARGET_BUILTIN(__nvvm_ff2f16x2_rn, "V2hff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2f16x2_rn_relu, "V2hff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2f16x2_rz, "V2hff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2f16x2_rz_relu, "V2hff", "", AND(SM_80,PTX70))

TARGET_BUILTIN(__nvvm_f2bf16_rn, "ZUsf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "ZUsf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rz, "ZUsf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "ZUsf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rn, "yf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "yf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rz, "yf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "yf", "", AND(SM_80,PTX70))

TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))

Expand Down Expand Up @@ -1024,10 +1032,10 @@ TARGET_BUILTIN(__nvvm_cp_async_wait_all, "v", "", AND(SM_80,PTX70))


// bf16, bf16x2 abs, neg
TARGET_BUILTIN(__nvvm_abs_bf16, "UsUs", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_abs_bf16x2, "ZUiZUi", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_neg_bf16, "UsUs", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_neg_bf16x2, "ZUiZUi", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_abs_bf16, "yy", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_abs_bf16x2, "V2yV2y", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_neg_bf16, "yy", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_neg_bf16x2, "V2yV2y", "", AND(SM_80,PTX70))

TARGET_BUILTIN(__nvvm_mapa, "v*v*i", "", AND(SM_90, PTX78))
TARGET_BUILTIN(__nvvm_mapa_shared_cluster, "v*3v*3i", "", AND(SM_90, PTX78))
Expand Down
114 changes: 65 additions & 49 deletions clang/test/CodeGen/builtins-nvptx.c
Original file line number Diff line number Diff line change
Expand Up @@ -899,13 +899,13 @@ __device__ void nvvm_async_copy(__attribute__((address_space(3))) void* dst, __a
// CHECK-LABEL: nvvm_cvt_sm80
__device__ void nvvm_cvt_sm80() {
#if __CUDA_ARCH__ >= 800
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rn(1, 1);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rn_relu(1, 1);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rz(1, 1);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rz_relu(1, 1);

// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn(float 1.000000e+00, float 1.000000e+00)
Expand All @@ -917,13 +917,13 @@ __device__ void nvvm_cvt_sm80() {
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2f16x2_rz_relu(1, 1);

// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
__nvvm_f2bf16_rn(1);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
__nvvm_f2bf16_rn_relu(1);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
__nvvm_f2bf16_rz(1);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
__nvvm_f2bf16_rz_relu(1);

// CHECK_PTX70_SM80: call i32 @llvm.nvvm.f2tf32.rna(float 1.000000e+00)
Expand All @@ -932,32 +932,32 @@ __device__ void nvvm_cvt_sm80() {
// CHECK: ret void
}

#define NAN32 0x7FBFFFFF
#define NAN16 (__bf16)0x7FBF
#define BF16 (__bf16)0.1f
#define BF16_2 (__bf16)0.2f
#define NANBF16 (__bf16)0xFFC1
#define BF16X2 {(__bf16)0.1f, (__bf16)0.1f}
#define BF16X2_2 {(__bf16)0.2f, (__bf16)0.2f}
#define NANBF16X2 {NANBF16, NANBF16}

// CHECK-LABEL: nvvm_abs_neg_bf16_bf16x2_sm80
__device__ void nvvm_abs_neg_bf16_bf16x2_sm80() {
#if __CUDA_ARCH__ >= 800

// CHECK_PTX70_SM80: call i16 @llvm.nvvm.abs.bf16(i16 -1)
__nvvm_abs_bf16(0xFFFF);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.abs.bf16x2(i32 -1)
__nvvm_abs_bf16x2(0xFFFFFFFF);
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.abs.bf16(bfloat 0xR3DCD)
__nvvm_abs_bf16(BF16);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.abs.bf16x2(<2 x bfloat> <bfloat 0xR3DCD, bfloat 0xR3DCD>)
__nvvm_abs_bf16x2(BF16X2);

// CHECK_PTX70_SM80: call i16 @llvm.nvvm.neg.bf16(i16 -1)
__nvvm_neg_bf16(0xFFFF);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.neg.bf16x2(i32 -1)
__nvvm_neg_bf16x2(0xFFFFFFFF);
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.neg.bf16(bfloat 0xR3DCD)
__nvvm_neg_bf16(BF16);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.neg.bf16x2(<2 x bfloat> <bfloat 0xR3DCD, bfloat 0xR3DCD>)
__nvvm_neg_bf16x2(BF16X2);
#endif
// CHECK: ret void
}

#define NAN32 0x7FBFFFFF
#define NAN16 0x7FBF
#define BF16 0x1234
#define BF16_2 0x4321
#define NANBF16 0xFFC1
#define BF16X2 0x12341234
#define BF16X2_2 0x32343234
#define NANBF16X2 0xFFC1FFC1

// CHECK-LABEL: nvvm_min_max_sm80
__device__ void nvvm_min_max_sm80() {
#if __CUDA_ARCH__ >= 800
Expand All @@ -967,14 +967,22 @@ __device__ void nvvm_min_max_sm80() {
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmin.ftz.nan.f
__nvvm_fmin_ftz_nan_f(0.1f, (float)NAN32);

// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmin.bf16
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.bf16
__nvvm_fmin_bf16(BF16, BF16_2);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmin.nan.bf16
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.ftz.bf16
__nvvm_fmin_ftz_bf16(BF16, BF16_2);
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.nan.bf16
__nvvm_fmin_nan_bf16(BF16, NANBF16);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmin.bf16x2
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.ftz.nan.bf16
__nvvm_fmin_ftz_nan_bf16(BF16, NANBF16);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.bf16x2
__nvvm_fmin_bf16x2(BF16X2, BF16X2_2);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmin.nan.bf16x2
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.ftz.bf16x2
__nvvm_fmin_ftz_bf16x2(BF16X2, BF16X2_2);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.nan.bf16x2
__nvvm_fmin_nan_bf16x2(BF16X2, NANBF16X2);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.ftz.nan.bf16x2
__nvvm_fmin_ftz_nan_bf16x2(BF16X2, NANBF16X2);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.nan.f
__nvvm_fmax_nan_f(0.1f, 0.11f);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
Expand All @@ -984,14 +992,22 @@ __device__ void nvvm_min_max_sm80() {
__nvvm_fmax_nan_f(0.1f, (float)NAN32);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
__nvvm_fmax_ftz_nan_f(0.1f, (float)NAN32);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmax.bf16
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.bf16
__nvvm_fmax_bf16(BF16, BF16_2);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmax.nan.bf16
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.ftz.bf16
__nvvm_fmax_ftz_bf16(BF16, BF16_2);
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.nan.bf16
__nvvm_fmax_nan_bf16(BF16, NANBF16);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmax.bf16x2
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.ftz.nan.bf16
__nvvm_fmax_ftz_nan_bf16(BF16, NANBF16);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.bf16x2
__nvvm_fmax_bf16x2(BF16X2, BF16X2_2);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmax.nan.bf16x2
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.ftz.bf16x2
__nvvm_fmax_ftz_bf16x2(BF16X2, BF16X2_2);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.nan.bf16x2
__nvvm_fmax_nan_bf16x2(NANBF16X2, BF16X2);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.ftz.nan.bf16x2
__nvvm_fmax_ftz_nan_bf16x2(NANBF16X2, BF16X2);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.nan.f
__nvvm_fmax_nan_f(0.1f, (float)NAN32);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
Expand All @@ -1004,14 +1020,14 @@ __device__ void nvvm_min_max_sm80() {
// CHECK-LABEL: nvvm_fma_bf16_bf16x2_sm80
__device__ void nvvm_fma_bf16_bf16x2_sm80() {
#if __CUDA_ARCH__ >= 800
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fma.rn.bf16
__nvvm_fma_rn_bf16(0x1234, 0x7FBF, 0x1234);
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fma.rn.relu.bf16
__nvvm_fma_rn_relu_bf16(0x1234, 0x7FBF, 0x1234);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fma.rn.bf16x2
__nvvm_fma_rn_bf16x2(0x7FBFFFFF, 0xFFFFFFFF, 0x7FBFFFFF);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fma.rn.relu.bf16x2
__nvvm_fma_rn_relu_bf16x2(0x7FBFFFFF, 0xFFFFFFFF, 0x7FBFFFFF);
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fma.rn.bf16
__nvvm_fma_rn_bf16(BF16, BF16_2, BF16_2);
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fma.rn.relu.bf16
__nvvm_fma_rn_relu_bf16(BF16, BF16_2, BF16_2);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fma.rn.bf16x2
__nvvm_fma_rn_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fma.rn.relu.bf16x2
__nvvm_fma_rn_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
#endif
// CHECK: ret void
}
Expand All @@ -1020,13 +1036,13 @@ __device__ void nvvm_fma_bf16_bf16x2_sm80() {
__device__ void nvvm_min_max_sm86() {
#if __CUDA_ARCH__ >= 860

// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmin.xorsign.abs.bf16
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmin.xorsign.abs.bf16
__nvvm_fmin_xorsign_abs_bf16(BF16, BF16_2);
// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmin.nan.xorsign.abs.bf16
__nvvm_fmin_nan_xorsign_abs_bf16(BF16, NANBF16);
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmin.xorsign.abs.bf16x2
__nvvm_fmin_xorsign_abs_bf16x2(BF16X2, BF16X2_2);
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2
__nvvm_fmin_nan_xorsign_abs_bf16x2(BF16X2, NANBF16X2);
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmin.xorsign.abs.f
__nvvm_fmin_xorsign_abs_f(-0.1f, 0.1f);
Expand All @@ -1037,13 +1053,13 @@ __device__ void nvvm_min_max_sm86() {
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f
__nvvm_fmin_ftz_nan_xorsign_abs_f(-0.1f, (float)NAN32);

// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmax.xorsign.abs.bf16
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmax.xorsign.abs.bf16
__nvvm_fmax_xorsign_abs_bf16(BF16, BF16_2);
// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmax.nan.xorsign.abs.bf16
__nvvm_fmax_nan_xorsign_abs_bf16(BF16, NANBF16);
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmax.xorsign.abs.bf16x2
__nvvm_fmax_xorsign_abs_bf16x2(BF16X2, BF16X2_2);
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2
__nvvm_fmax_nan_xorsign_abs_bf16x2(BF16X2, NANBF16X2);
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmax.xorsign.abs.f
__nvvm_fmax_xorsign_abs_f(-0.1f, 0.1f);
Expand Down
12 changes: 6 additions & 6 deletions clang/test/CodeGenCUDA/bf16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

// CHECK-LABEL: .visible .func _Z8test_argPDF16bDF16b(
// CHECK: .param .b64 _Z8test_argPDF16bDF16b_param_0,
// CHECK: .param .b16 _Z8test_argPDF16bDF16b_param_1
// CHECK: .param .align 2 .b8 _Z8test_argPDF16bDF16b_param_1[2]
//
__device__ void test_arg(__bf16 *out, __bf16 in) {
// CHECK-DAG: ld.param.u64 %[[A:rd[0-9]+]], [_Z8test_argPDF16bDF16b_param_0];
Expand All @@ -20,8 +20,8 @@ __device__ void test_arg(__bf16 *out, __bf16 in) {
}


// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z8test_retDF16b(
// CHECK: .param .b16 _Z8test_retDF16b_param_0
// CHECK-LABEL: .visible .func (.param .align 2 .b8 func_retval0[2]) _Z8test_retDF16b(
// CHECK: .param .align 2 .b8 _Z8test_retDF16b_param_0[2]
__device__ __bf16 test_ret( __bf16 in) {
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z8test_retDF16b_param_0];
return in;
Expand All @@ -31,12 +31,12 @@ __device__ __bf16 test_ret( __bf16 in) {

__device__ __bf16 external_func( __bf16 in);

// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z9test_callDF16b(
// CHECK: .param .b16 _Z9test_callDF16b_param_0
// CHECK-LABEL: .visible .func (.param .align 2 .b8 func_retval0[2]) _Z9test_callDF16b(
// CHECK: .param .align 2 .b8 _Z9test_callDF16b_param_0[2]
__device__ __bf16 test_call( __bf16 in) {
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z9test_callDF16b_param_0];
// CHECK: st.param.b16 [param0+0], %[[R]];
// CHECK: .param .b32 retval0;
// CHECK: .param .align 2 .b8 retval0[2];
// CHECK: call.uni (retval0),
// CHECK-NEXT: _Z13external_funcDF16b,
// CHECK-NEXT: (
Expand Down

0 comments on commit 250f2bb

Please sign in to comment.