[X86][RFC] Using __bf16 for AVX512_BF16 intrinsics
This is an alternative to D120395 and D120411.

Previously we used `__bfloat16` as a typedef of `unsigned short`. The
name may give users the impression that it is a brand-new type
representing BF16, so they may use it in arithmetic operations, and we
have no good way to block that.

To solve the problem, we introduced `__bf16` to the X86 psABI and
landed support for it in Clang in D130964. Now we can fix the issue by
switching the intrinsics to the new type.
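
For illustration, a minimal sketch of the hazard (hypothetical user code, not part of this patch):

```c
// Old world: __bfloat16 was merely unsigned short, so this compiled
// silently and "added" the raw bit patterns rather than BF16 values.
typedef unsigned short __bfloat16_old;
__bfloat16_old bogus_sum(__bfloat16_old a, __bfloat16_old b) {
  return a + b; // integer addition of encodings -- a nonsense result
}

// New world: __bf16 is a genuine floating-point type, so the compiler
// can diagnose unsupported arithmetic or lower it with correct BF16
// semantics instead of silently doing integer math.
```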

Reviewed By: LuoYuanke, RKSimon

Differential Revision: https://reviews.llvm.org/D132329
phoebewang committed Oct 19, 2022
1 parent f0ca946 commit bc18193
Showing 21 changed files with 1,568 additions and 607 deletions.
1 change: 1 addition & 0 deletions clang/docs/ReleaseNotes.rst
@@ -583,6 +583,7 @@ X86 Support in Clang
--------------------
- Support ``-mindirect-branch-cs-prefix`` for call and jmp to indirect thunk.
- Fix 32-bit ``__fastcall`` and ``__vectorcall`` ABI mismatch with MSVC.
- Switch ``AVX512-BF16`` intrinsics types from ``short`` to ``__bf16``.

DWARF Support in Clang
----------------------
24 changes: 14 additions & 10 deletions clang/include/clang/Basic/BuiltinsX86.def
@@ -1749,16 +1749,16 @@ TARGET_BUILTIN(__builtin_ia32_vpmultishiftqb128, "V16cV16cV16c", "ncV:128:", "avx512vbmi,avx512vl")
TARGET_BUILTIN(__builtin_ia32_vpmultishiftqb256, "V32cV32cV32c", "ncV:256:", "avx512vbmi,avx512vl")

// bf16 intrinsics
TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_128, "V8sV4fV4f", "ncV:128:", "avx512bf16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_256, "V16sV8fV8f", "ncV:256:", "avx512bf16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_512, "V32sV16fV16f", "ncV:512:", "avx512bf16")
TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_128_mask, "V8sV4fV8sUc", "ncV:128:", "avx512bf16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_256_mask, "V8sV8fV8sUc", "ncV:256:", "avx512bf16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_512_mask, "V16sV16fV16sUs", "ncV:512:", "avx512bf16")
TARGET_BUILTIN(__builtin_ia32_dpbf16ps_128, "V4fV4fV4iV4i", "ncV:128:", "avx512bf16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_dpbf16ps_256, "V8fV8fV8iV8i", "ncV:256:", "avx512bf16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_dpbf16ps_512, "V16fV16fV16iV16i", "ncV:512:", "avx512bf16")
TARGET_BUILTIN(__builtin_ia32_cvtsbf162ss_32, "fUs", "nc", "avx512bf16")
TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_128, "V8yV4fV4f", "ncV:128:", "avx512bf16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_256, "V16yV8fV8f", "ncV:256:", "avx512bf16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_512, "V32yV16fV16f", "ncV:512:", "avx512bf16")
TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_128_mask, "V8yV4fV8yUc", "ncV:128:", "avx512bf16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_256_mask, "V8yV8fV8yUc", "ncV:256:", "avx512bf16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_512_mask, "V16yV16fV16yUs", "ncV:512:", "avx512bf16")
TARGET_BUILTIN(__builtin_ia32_dpbf16ps_128, "V4fV4fV8yV8y", "ncV:128:", "avx512bf16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_dpbf16ps_256, "V8fV8fV16yV16y", "ncV:256:", "avx512bf16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_dpbf16ps_512, "V16fV16fV32yV32y", "ncV:512:", "avx512bf16")
TARGET_BUILTIN(__builtin_ia32_cvtsbf162ss_32, "fy", "nc", "avx512bf16")

TARGET_BUILTIN(__builtin_ia32_vp2intersect_q_512, "vV8OiV8OiUc*Uc*", "nV:512:", "avx512vp2intersect")
TARGET_BUILTIN(__builtin_ia32_vp2intersect_q_256, "vV4OiV4OiUc*Uc*", "nV:256:", "avx512vp2intersect,avx512vl")
@@ -1977,6 +1977,9 @@ TARGET_BUILTIN(__builtin_ia32_selectd_512, "V16iUsV16iV16i", "ncV:512:", "avx512f")
TARGET_BUILTIN(__builtin_ia32_selectph_128, "V8xUcV8xV8x", "ncV:128:", "avx512fp16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectph_256, "V16xUsV16xV16x", "ncV:256:", "avx512fp16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectph_512, "V32xUiV32xV32x", "ncV:512:", "avx512fp16")
TARGET_BUILTIN(__builtin_ia32_selectpbf_128, "V8yUcV8yV8y", "ncV:128:", "avx512bf16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectpbf_256, "V16yUsV16yV16y", "ncV:256:", "avx512bf16,avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectpbf_512, "V32yUiV32yV32y", "ncV:512:", "avx512bf16")
TARGET_BUILTIN(__builtin_ia32_selectq_128, "V2OiUcV2OiV2Oi", "ncV:128:", "avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectq_256, "V4OiUcV4OiV4Oi", "ncV:256:", "avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectq_512, "V8OiUcV8OiV8Oi", "ncV:512:", "avx512f")
@@ -1987,6 +1990,7 @@ TARGET_BUILTIN(__builtin_ia32_selectpd_128, "V2dUcV2dV2d", "ncV:128:", "avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectpd_256, "V4dUcV4dV4d", "ncV:256:", "avx512vl")
TARGET_BUILTIN(__builtin_ia32_selectpd_512, "V8dUcV8dV8d", "ncV:512:", "avx512f")
TARGET_BUILTIN(__builtin_ia32_selectsh_128, "V8xUcV8xV8x", "ncV:128:", "avx512fp16")
TARGET_BUILTIN(__builtin_ia32_selectsbf_128, "V8yUcV8yV8y", "ncV:128:", "avx512bf16")
TARGET_BUILTIN(__builtin_ia32_selectss_128, "V4fUcV4fV4f", "ncV:128:", "avx512f")
TARGET_BUILTIN(__builtin_ia32_selectsd_128, "V2dUcV2dV2d", "ncV:128:", "avx512f")

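A note on the prototype strings above: by my reading of Clang's builtin type encoding (an assumption worth checking against the comments in `Builtins.def`), each letter is a type code, so the `s` → `y` switch is precisely the `short` → `__bf16` retype:

```c
/* Decoding "V8yV4fV4f" (__builtin_ia32_cvtne2ps2bf16_128):
 *   V8y -> vector of 8 __bf16  ('y' = __bf16; previously V8s = 8 x short)
 *   V4f -> vector of 4 float
 * i.e. roughly __v8bf f(__v4sf, __v4sf). Other codes seen above:
 * 'x' = _Float16, 'Uc' = unsigned char, 'Us' = unsigned short. */
```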
20 changes: 6 additions & 14 deletions clang/lib/CodeGen/CGBuiltin.cpp
@@ -12873,18 +12873,6 @@ static Value *EmitX86CvtF16ToFloatExpr(CodeGenFunction &CGF,
return Res;
}

// Convert a BF16 to a float.
static Value *EmitX86CvtBF16ToFloatExpr(CodeGenFunction &CGF,
const CallExpr *E,
ArrayRef<Value *> Ops) {
llvm::Type *Int32Ty = CGF.Builder.getInt32Ty();
Value *ZeroExt = CGF.Builder.CreateZExt(Ops[0], Int32Ty);
Value *Shl = CGF.Builder.CreateShl(ZeroExt, 16);
llvm::Type *ResultType = CGF.ConvertType(E->getType());
Value *BitCast = CGF.Builder.CreateBitCast(Shl, ResultType);
return BitCast;
}

Value *CodeGenFunction::EmitX86CpuIs(StringRef CPUStr) {

llvm::Type *Int32Ty = Builder.getInt32Ty();
@@ -14291,6 +14279,9 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
case X86::BI__builtin_ia32_selectph_128:
case X86::BI__builtin_ia32_selectph_256:
case X86::BI__builtin_ia32_selectph_512:
case X86::BI__builtin_ia32_selectpbf_128:
case X86::BI__builtin_ia32_selectpbf_256:
case X86::BI__builtin_ia32_selectpbf_512:
case X86::BI__builtin_ia32_selectps_128:
case X86::BI__builtin_ia32_selectps_256:
case X86::BI__builtin_ia32_selectps_512:
@@ -14299,6 +14290,7 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
case X86::BI__builtin_ia32_selectpd_512:
return EmitX86Select(*this, Ops[0], Ops[1], Ops[2]);
case X86::BI__builtin_ia32_selectsh_128:
case X86::BI__builtin_ia32_selectsbf_128:
case X86::BI__builtin_ia32_selectss_128:
case X86::BI__builtin_ia32_selectsd_128: {
Value *A = Builder.CreateExtractElement(Ops[1], (uint64_t)0);
@@ -15135,7 +15127,7 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
return EmitX86CvtF16ToFloatExpr(*this, Ops, ConvertType(E->getType()));
}

// AVX512 bf16 intrinsics
// AVX512 bf16 intrinsics
case X86::BI__builtin_ia32_cvtneps2bf16_128_mask: {
Ops[2] = getMaskVecValue(
*this, Ops[2],
@@ -15144,7 +15136,7 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
return Builder.CreateCall(CGM.getIntrinsic(IID), Ops);
}
case X86::BI__builtin_ia32_cvtsbf162ss_32:
return EmitX86CvtBF16ToFloatExpr(*this, E, Ops);
return Builder.CreateFPExt(Ops[0], Builder.getFloatTy());

case X86::BI__builtin_ia32_cvtneps2bf16_256_mask:
case X86::BI__builtin_ia32_cvtneps2bf16_512_mask: {
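The deleted helper open-coded the widening on an integer-typed operand; with a real `__bf16` operand the same conversion becomes a single `fpext`. A minimal C sketch of why the two agree (hypothetical helper name, not from this patch):

```c
#include <stdint.h>
#include <string.h>

/* BF16 is the upper 16 bits of an IEEE-754 binary32, so widening to
   float is exact: place the bits in the high half of a 32-bit pattern. */
static float bf16_bits_to_float(uint16_t bits) {
  uint32_t pattern = (uint32_t)bits << 16; /* the old zext + shl 16 */
  float f;
  memcpy(&f, &pattern, sizeof f);          /* the old bitcast */
  return f;                                /* equals fpext on the __bf16 value */
}
```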
35 changes: 20 additions & 15 deletions clang/lib/Headers/avx512bf16intrin.h
@@ -10,12 +10,16 @@
#error "Never use <avx512bf16intrin.h> directly; include <immintrin.h> instead."
#endif

#ifdef __SSE2__

#ifndef __AVX512BF16INTRIN_H
#define __AVX512BF16INTRIN_H

typedef short __m512bh __attribute__((__vector_size__(64), __aligned__(64)));
typedef short __m256bh __attribute__((__vector_size__(32), __aligned__(32)));
typedef unsigned short __bfloat16;
typedef __bf16 __v32bf __attribute__((__vector_size__(64), __aligned__(64)));
typedef __bf16 __m512bh __attribute__((__vector_size__(64), __aligned__(64)));
typedef __bf16 __v16bf __attribute__((__vector_size__(32), __aligned__(32)));
typedef __bf16 __m256bh __attribute__((__vector_size__(32), __aligned__(32)));
typedef __bf16 __bfloat16 __attribute__((deprecated("use __bf16 instead")));

#define __DEFAULT_FN_ATTRS512 \
__attribute__((__always_inline__, __nodebug__, __target__("avx512bf16"), \
@@ -33,7 +37,7 @@ typedef unsigned short __bfloat16;
/// A bfloat data.
/// \returns A float data whose sign field and exponent field keep unchanged,
/// and fraction field is extended to 23 bits.
static __inline__ float __DEFAULT_FN_ATTRS _mm_cvtsbh_ss(__bfloat16 __A) {
static __inline__ float __DEFAULT_FN_ATTRS _mm_cvtsbh_ss(__bf16 __A) {
return __builtin_ia32_cvtsbf162ss_32(__A);
}

@@ -74,9 +78,9 @@ _mm512_cvtne2ps_pbh(__m512 __A, __m512 __B) {
/// conversion of __B, and higher 256 bits come from conversion of __A.
static __inline__ __m512bh __DEFAULT_FN_ATTRS512
_mm512_mask_cvtne2ps_pbh(__m512bh __W, __mmask32 __U, __m512 __A, __m512 __B) {
return (__m512bh)__builtin_ia32_selectw_512((__mmask32)__U,
(__v32hi)_mm512_cvtne2ps_pbh(__A, __B),
(__v32hi)__W);
return (__m512bh)__builtin_ia32_selectpbf_512((__mmask32)__U,
(__v32bf)_mm512_cvtne2ps_pbh(__A, __B),
(__v32bf)__W);
}

/// Convert Two Packed Single Data to One Packed BF16 Data.
@@ -96,9 +100,9 @@ _mm512_mask_cvtne2ps_pbh(__m512bh __W, __mmask32 __U, __m512 __A, __m512 __B) {
/// conversion of __B, and higher 256 bits come from conversion of __A.
static __inline__ __m512bh __DEFAULT_FN_ATTRS512
_mm512_maskz_cvtne2ps_pbh(__mmask32 __U, __m512 __A, __m512 __B) {
return (__m512bh)__builtin_ia32_selectw_512((__mmask32)__U,
(__v32hi)_mm512_cvtne2ps_pbh(__A, __B),
(__v32hi)_mm512_setzero_si512());
return (__m512bh)__builtin_ia32_selectpbf_512((__mmask32)__U,
(__v32bf)_mm512_cvtne2ps_pbh(__A, __B),
(__v32bf)_mm512_setzero_si512());
}

/// Convert Packed Single Data to Packed BF16 Data.
Expand All @@ -113,7 +117,7 @@ _mm512_maskz_cvtne2ps_pbh(__mmask32 __U, __m512 __A, __m512 __B) {
static __inline__ __m256bh __DEFAULT_FN_ATTRS512
_mm512_cvtneps_pbh(__m512 __A) {
return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A,
(__v16hi)_mm256_undefined_si256(),
(__v16bf)_mm256_undefined_si256(),
(__mmask16)-1);
}

@@ -134,7 +138,7 @@ _mm512_cvtneps_pbh(__m512 __A) {
static __inline__ __m256bh __DEFAULT_FN_ATTRS512
_mm512_mask_cvtneps_pbh(__m256bh __W, __mmask16 __U, __m512 __A) {
return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A,
(__v16hi)__W,
(__v16bf)__W,
(__mmask16)__U);
}

@@ -153,7 +157,7 @@ _mm512_mask_cvtneps_pbh(__m256bh __W, __mmask16 __U, __m512 __A) {
static __inline__ __m256bh __DEFAULT_FN_ATTRS512
_mm512_maskz_cvtneps_pbh(__mmask16 __U, __m512 __A) {
return (__m256bh)__builtin_ia32_cvtneps2bf16_512_mask((__v16sf)__A,
(__v16hi)_mm256_setzero_si256(),
(__v16bf)_mm256_setzero_si256(),
(__mmask16)__U);
}

@@ -174,8 +178,8 @@ _mm512_maskz_cvtneps_pbh(__mmask16 __U, __m512 __A) {
static __inline__ __m512 __DEFAULT_FN_ATTRS512
_mm512_dpbf16_ps(__m512 __D, __m512bh __A, __m512bh __B) {
return (__m512)__builtin_ia32_dpbf16ps_512((__v16sf) __D,
(__v16si) __A,
(__v16si) __B);
(__v32bf) __A,
(__v32bf) __B);
}

/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
@@ -277,3 +281,4 @@ _mm512_mask_cvtpbh_ps(__m512 __S, __mmask16 __U, __m256bh __A) {
#undef __DEFAULT_FN_ATTRS512

#endif
#endif
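
A short usage sketch of the retyped 512-bit API (hypothetical user code; assumes a compiler and CPU with AVX512-BF16, e.g. built with `-mavx512bf16`):

```c
#include <immintrin.h>

float demo512(void) {
  __m512 a = _mm512_set1_ps(1.5f);
  __m512 b = _mm512_set1_ps(2.5f);
  /* Lower 256 bits come from __B, upper 256 bits from __A. */
  __m512bh packed = _mm512_cvtne2ps_pbh(a, b); /* 32 x __bf16 */
  __bf16 lane0 = packed[0];                    /* GNU vector subscripting */
  return _mm_cvtsbh_ss(lane0);                 /* now takes __bf16 */
}
```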
58 changes: 31 additions & 27 deletions clang/lib/Headers/avx512vlbf16intrin.h
@@ -10,10 +10,13 @@
#error "Never use <avx512vlbf16intrin.h> directly; include <immintrin.h> instead."
#endif

#ifdef __SSE2__

#ifndef __AVX512VLBF16INTRIN_H
#define __AVX512VLBF16INTRIN_H

typedef short __m128bh __attribute__((__vector_size__(16), __aligned__(16)));
typedef __bf16 __v8bf __attribute__((__vector_size__(16), __aligned__(16)));
typedef __bf16 __m128bh __attribute__((__vector_size__(16), __aligned__(16)));

#define __DEFAULT_FN_ATTRS128 \
__attribute__((__always_inline__, __nodebug__, \
@@ -59,9 +62,9 @@ _mm_cvtne2ps_pbh(__m128 __A, __m128 __B) {
/// conversion of __B, and higher 64 bits come from conversion of __A.
static __inline__ __m128bh __DEFAULT_FN_ATTRS128
_mm_mask_cvtne2ps_pbh(__m128bh __W, __mmask8 __U, __m128 __A, __m128 __B) {
return (__m128bh)__builtin_ia32_selectw_128((__mmask8)__U,
(__v8hi)_mm_cvtne2ps_pbh(__A, __B),
(__v8hi)__W);
return (__m128bh)__builtin_ia32_selectpbf_128((__mmask8)__U,
(__v8bf)_mm_cvtne2ps_pbh(__A, __B),
(__v8bf)__W);
}

/// Convert Two Packed Single Data to One Packed BF16 Data.
@@ -81,9 +84,9 @@ _mm_mask_cvtne2ps_pbh(__m128bh __W, __mmask8 __U, __m128 __A, __m128 __B) {
/// conversion of __B, and higher 64 bits come from conversion of __A.
static __inline__ __m128bh __DEFAULT_FN_ATTRS128
_mm_maskz_cvtne2ps_pbh(__mmask8 __U, __m128 __A, __m128 __B) {
return (__m128bh)__builtin_ia32_selectw_128((__mmask8)__U,
(__v8hi)_mm_cvtne2ps_pbh(__A, __B),
(__v8hi)_mm_setzero_si128());
return (__m128bh)__builtin_ia32_selectpbf_128((__mmask8)__U,
(__v8bf)_mm_cvtne2ps_pbh(__A, __B),
(__v8bf)_mm_setzero_si128());
}

/// Convert Two Packed Single Data to One Packed BF16 Data.
@@ -123,9 +126,9 @@ _mm256_cvtne2ps_pbh(__m256 __A, __m256 __B) {
/// conversion of __B, and higher 128 bits come from conversion of __A.
static __inline__ __m256bh __DEFAULT_FN_ATTRS256
_mm256_mask_cvtne2ps_pbh(__m256bh __W, __mmask16 __U, __m256 __A, __m256 __B) {
return (__m256bh)__builtin_ia32_selectw_256((__mmask16)__U,
(__v16hi)_mm256_cvtne2ps_pbh(__A, __B),
(__v16hi)__W);
return (__m256bh)__builtin_ia32_selectpbf_256((__mmask16)__U,
(__v16bf)_mm256_cvtne2ps_pbh(__A, __B),
(__v16bf)__W);
}

/// Convert Two Packed Single Data to One Packed BF16 Data.
@@ -145,9 +148,9 @@ _mm256_mask_cvtne2ps_pbh(__m256bh __W, __mmask16 __U, __m256 __A, __m256 __B) {
/// conversion of __B, and higher 128 bits come from conversion of __A.
static __inline__ __m256bh __DEFAULT_FN_ATTRS256
_mm256_maskz_cvtne2ps_pbh(__mmask16 __U, __m256 __A, __m256 __B) {
return (__m256bh)__builtin_ia32_selectw_256((__mmask16)__U,
(__v16hi)_mm256_cvtne2ps_pbh(__A, __B),
(__v16hi)_mm256_setzero_si256());
return (__m256bh)__builtin_ia32_selectpbf_256((__mmask16)__U,
(__v16bf)_mm256_cvtne2ps_pbh(__A, __B),
(__v16bf)_mm256_setzero_si256());
}

/// Convert Packed Single Data to Packed BF16 Data.
@@ -163,7 +166,7 @@ _mm256_maskz_cvtne2ps_pbh(__mmask16 __U, __m256 __A, __m256 __B) {
static __inline__ __m128bh __DEFAULT_FN_ATTRS128
_mm_cvtneps_pbh(__m128 __A) {
return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
(__v8hi)_mm_undefined_si128(),
(__v8bf)_mm_undefined_si128(),
(__mmask8)-1);
}

@@ -185,7 +188,7 @@ _mm_cvtneps_pbh(__m128 __A) {
static __inline__ __m128bh __DEFAULT_FN_ATTRS128
_mm_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m128 __A) {
return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
(__v8hi)__W,
(__v8bf)__W,
(__mmask8)__U);
}

@@ -205,7 +208,7 @@ _mm_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m128 __A) {
static __inline__ __m128bh __DEFAULT_FN_ATTRS128
_mm_maskz_cvtneps_pbh(__mmask8 __U, __m128 __A) {
return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
(__v8hi)_mm_setzero_si128(),
(__v8bf)_mm_setzero_si128(),
(__mmask8)__U);
}

@@ -221,7 +224,7 @@ _mm_maskz_cvtneps_pbh(__mmask8 __U, __m128 __A) {
static __inline__ __m128bh __DEFAULT_FN_ATTRS256
_mm256_cvtneps_pbh(__m256 __A) {
return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
(__v8hi)_mm_undefined_si128(),
(__v8bf)_mm_undefined_si128(),
(__mmask8)-1);
}

@@ -242,7 +245,7 @@ _mm256_cvtneps_pbh(__m256 __A) {
static __inline__ __m128bh __DEFAULT_FN_ATTRS256
_mm256_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m256 __A) {
return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
(__v8hi)__W,
(__v8bf)__W,
(__mmask8)__U);
}

@@ -261,7 +264,7 @@ _mm256_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m256 __A) {
static __inline__ __m128bh __DEFAULT_FN_ATTRS256
_mm256_maskz_cvtneps_pbh(__mmask8 __U, __m256 __A) {
return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
(__v8hi)_mm_setzero_si128(),
(__v8bf)_mm_setzero_si128(),
(__mmask8)__U);
}

@@ -282,8 +285,8 @@ _mm256_maskz_cvtneps_pbh(__mmask8 __U, __m256 __A) {
static __inline__ __m128 __DEFAULT_FN_ATTRS128
_mm_dpbf16_ps(__m128 __D, __m128bh __A, __m128bh __B) {
return (__m128)__builtin_ia32_dpbf16ps_128((__v4sf)__D,
(__v4si)__A,
(__v4si)__B);
(__v8bf)__A,
(__v8bf)__B);
}

/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
@@ -351,8 +354,8 @@ _mm_maskz_dpbf16_ps(__mmask8 __U, __m128 __D, __m128bh __A, __m128bh __B) {
static __inline__ __m256 __DEFAULT_FN_ATTRS256
_mm256_dpbf16_ps(__m256 __D, __m256bh __A, __m256bh __B) {
return (__m256)__builtin_ia32_dpbf16ps_256((__v8sf)__D,
(__v8si)__A,
(__v8si)__B);
(__v16bf)__A,
(__v16bf)__B);
}

/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
@@ -413,11 +416,11 @@ _mm256_maskz_dpbf16_ps(__mmask8 __U, __m256 __D, __m256bh __A, __m256bh __B) {
/// A float data.
/// \returns A bf16 data whose sign field and exponent field keep unchanged,
/// and fraction field is truncated to 7 bits.
static __inline__ __bfloat16 __DEFAULT_FN_ATTRS128 _mm_cvtness_sbh(float __A) {
static __inline__ __bf16 __DEFAULT_FN_ATTRS128 _mm_cvtness_sbh(float __A) {
__v4sf __V = {__A, 0, 0, 0};
__v8hi __R = __builtin_ia32_cvtneps2bf16_128_mask(
(__v4sf)__V, (__v8hi)_mm_undefined_si128(), (__mmask8)-1);
return (__bfloat16)__R[0];
__v8bf __R = __builtin_ia32_cvtneps2bf16_128_mask(
(__v4sf)__V, (__v8bf)_mm_undefined_si128(), (__mmask8)-1);
return (__bf16)__R[0];
}

/// Convert Packed BF16 Data to Packed float Data.
@@ -520,3 +523,4 @@ _mm256_mask_cvtpbh_ps(__m256 __S, __mmask8 __U, __m128bh __A) {
#undef __DEFAULT_FN_ATTRS256

#endif
#endif
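
And a matching sketch for the 128-bit VL forms (hypothetical user code; assumes `-mavx512bf16 -mavx512vl`):

```c
#include <immintrin.h>

float demo128(void) {
  __m128 x = _mm_set_ps(4.0f, 3.0f, 2.0f, 1.0f);
  __m128bh xb = _mm_cvtne2ps_pbh(x, x);     /* 8 x __bf16 */
  /* Each float lane accumulates the product of one adjacent BF16 pair. */
  __m128 acc = _mm_dpbf16_ps(_mm_setzero_ps(), xb, xb);
  __bf16 s = _mm_cvtness_sbh(0.5f);         /* float -> __bf16, new return type */
  return _mm_cvtss_f32(acc) + _mm_cvtsbh_ss(s);
}
```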
