Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions clang/lib/Headers/avx512bf16intrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ typedef __bf16 __bfloat16 __attribute__((deprecated("use __bf16 instead")));
#define __DEFAULT_FN_ATTRS \
__attribute__((__always_inline__, __nodebug__, __target__("avx512bf16")))

#if defined(__cplusplus) && (__cplusplus >= 201103L)
#define __DEFAULT_FN_ATTRS512_CONSTEXPR __DEFAULT_FN_ATTRS512 constexpr
#define __DEFAULT_FN_ATTRS_CONSTEXPR __DEFAULT_FN_ATTRS constexpr
#else
#define __DEFAULT_FN_ATTRS512_CONSTEXPR __DEFAULT_FN_ATTRS512
#define __DEFAULT_FN_ATTRS_CONSTEXPR __DEFAULT_FN_ATTRS
#endif

/// Convert One BF16 Data to One Single Float Data.
///
/// \headerfile <x86intrin.h>
Expand All @@ -35,7 +43,7 @@ typedef __bf16 __bfloat16 __attribute__((deprecated("use __bf16 instead")));
/// A bfloat data.
/// \returns A float data whose sign field and exponent field keep unchanged,
/// and fraction field is extended to 23 bits.
static __inline__ float __DEFAULT_FN_ATTRS _mm_cvtsbh_ss(__bf16 __A) {
static __inline__ float __DEFAULT_FN_ATTRS_CONSTEXPR _mm_cvtsbh_ss(__bf16 __A) {
return (float)(__A);
}

Expand Down Expand Up @@ -235,7 +243,8 @@ _mm512_maskz_dpbf16_ps(__mmask16 __U, __m512 __D, __m512bh __A, __m512bh __B) {
/// \param __A
/// A 256-bit vector of [16 x bfloat].
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_cvtpbh_ps(__m256bh __A) {
static __inline__ __m512 __DEFAULT_FN_ATTRS512_CONSTEXPR
_mm512_cvtpbh_ps(__m256bh __A) {
return (__m512) __builtin_convertvector(__A, __v16sf);
}

Expand All @@ -249,7 +258,7 @@ static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_cvtpbh_ps(__m256bh __A) {
/// \param __A
/// A 256-bit vector of [16 x bfloat].
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
static __inline__ __m512 __DEFAULT_FN_ATTRS512
static __inline__ __m512 __DEFAULT_FN_ATTRS512_CONSTEXPR
_mm512_maskz_cvtpbh_ps(__mmask16 __U, __m256bh __A) {
return (__m512)__builtin_ia32_selectps_512((__mmask16)__U,
(__v16sf)_mm512_cvtpbh_ps(__A),
Expand All @@ -268,14 +277,16 @@ _mm512_maskz_cvtpbh_ps(__mmask16 __U, __m256bh __A) {
/// \param __A
/// A 256-bit vector of [16 x bfloat].
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
static __inline__ __m512 __DEFAULT_FN_ATTRS512
static __inline__ __m512 __DEFAULT_FN_ATTRS512_CONSTEXPR
_mm512_mask_cvtpbh_ps(__m512 __S, __mmask16 __U, __m256bh __A) {
return (__m512)__builtin_ia32_selectps_512(
(__mmask16)__U, (__v16sf)_mm512_cvtpbh_ps(__A), (__v16sf)__S);
}

#undef __DEFAULT_FN_ATTRS
#undef __DEFAULT_FN_ATTRS_CONSTEXPR
#undef __DEFAULT_FN_ATTRS512
#undef __DEFAULT_FN_ATTRS512_CONSTEXPR

#endif
#endif
24 changes: 18 additions & 6 deletions clang/lib/Headers/avx512vlbf16intrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@
__target__("avx512vl,avx512bf16"), \
__min_vector_width__(256)))

#if defined(__cplusplus) && (__cplusplus >= 201103L)
#define __DEFAULT_FN_ATTRS128_CONSTEXPR __DEFAULT_FN_ATTRS128 constexpr
#define __DEFAULT_FN_ATTRS256_CONSTEXPR __DEFAULT_FN_ATTRS256 constexpr
#else
#define __DEFAULT_FN_ATTRS128_CONSTEXPR __DEFAULT_FN_ATTRS128
#define __DEFAULT_FN_ATTRS256_CONSTEXPR __DEFAULT_FN_ATTRS256
#endif

/// Convert Two Packed Single Data to One Packed BF16 Data.
///
/// \headerfile <x86intrin.h>
Expand Down Expand Up @@ -421,7 +429,8 @@ static __inline__ __bf16 __DEFAULT_FN_ATTRS128 _mm_cvtness_sbh(float __A) {
/// \param __A
/// A 128-bit vector of [4 x bfloat].
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
static __inline__ __m128 __DEFAULT_FN_ATTRS128 _mm_cvtpbh_ps(__m128bh __A) {
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
_mm_cvtpbh_ps(__m128bh __A) {
return (__m128)_mm256_castps256_ps128(
(__m256) __builtin_convertvector(__A, __v8sf));
}
Expand All @@ -433,7 +442,8 @@ static __inline__ __m128 __DEFAULT_FN_ATTRS128 _mm_cvtpbh_ps(__m128bh __A) {
/// \param __A
/// A 128-bit vector of [8 x bfloat].
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) {
static __inline__ __m256 __DEFAULT_FN_ATTRS256_CONSTEXPR
_mm256_cvtpbh_ps(__m128bh __A) {
return (__m256) __builtin_convertvector(__A, __v8sf);
}

Expand All @@ -447,7 +457,7 @@ static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) {
/// \param __A
/// A 128-bit vector of [4 x bfloat].
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
static __inline__ __m128 __DEFAULT_FN_ATTRS128
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
_mm_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
return (__m128)__builtin_ia32_selectps_128(
(__mmask8)__U, (__v4sf)_mm_cvtpbh_ps(__A), (__v4sf)_mm_setzero_ps());
Expand All @@ -463,7 +473,7 @@ _mm_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
/// \param __A
/// A 128-bit vector of [8 x bfloat].
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
static __inline__ __m256 __DEFAULT_FN_ATTRS256
static __inline__ __m256 __DEFAULT_FN_ATTRS256_CONSTEXPR
_mm256_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
return (__m256)__builtin_ia32_selectps_256((__mmask8)__U,
(__v8sf)_mm256_cvtpbh_ps(__A),
Expand All @@ -483,7 +493,7 @@ _mm256_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
/// \param __A
/// A 128-bit vector of [4 x bfloat].
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
static __inline__ __m128 __DEFAULT_FN_ATTRS128
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
_mm_mask_cvtpbh_ps(__m128 __S, __mmask8 __U, __m128bh __A) {
return (__m128)__builtin_ia32_selectps_128(
(__mmask8)__U, (__v4sf)_mm_cvtpbh_ps(__A), (__v4sf)__S);
Expand All @@ -502,14 +512,16 @@ _mm_mask_cvtpbh_ps(__m128 __S, __mmask8 __U, __m128bh __A) {
/// \param __A
/// A 128-bit vector of [8 x bfloat].
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
static __inline__ __m256 __DEFAULT_FN_ATTRS256
static __inline__ __m256 __DEFAULT_FN_ATTRS256_CONSTEXPR
_mm256_mask_cvtpbh_ps(__m256 __S, __mmask8 __U, __m128bh __A) {
return (__m256)__builtin_ia32_selectps_256(
(__mmask8)__U, (__v8sf)_mm256_cvtpbh_ps(__A), (__v8sf)__S);
}

#undef __DEFAULT_FN_ATTRS128
#undef __DEFAULT_FN_ATTRS256
#undef __DEFAULT_FN_ATTRS128_CONSTEXPR
#undef __DEFAULT_FN_ATTRS256_CONSTEXPR

#endif
#endif
5 changes: 5 additions & 0 deletions clang/test/CodeGen/X86/avx512bf16-builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
// RUN: %clang_cc1 -x c++ -flax-vector-conversions=none -ffreestanding %s -triple=i386-apple-darwin -target-feature +avx512bf16 -emit-llvm -o - -Wall -Werror -fexperimental-new-constant-interpreter | FileCheck %s

#include <immintrin.h>
#include "builtin_test_helpers.h"

float test_mm_cvtsbh_ss(__bf16 A) {
// CHECK-LABEL: test_mm_cvtsbh_ss
// CHECK: fpext bfloat %{{.*}} to float
// CHECK: ret float %{{.*}}
return _mm_cvtsbh_ss(A);
}
TEST_CONSTEXPR(_mm_cvtsbh_ss(-1.0f) == -1.0f);

__m512bh test_mm512_cvtne2ps_pbh(__m512 A, __m512 B) {
// CHECK-LABEL: test_mm512_cvtne2ps_pbh
Expand Down Expand Up @@ -82,17 +84,20 @@ __m512 test_mm512_cvtpbh_ps(__m256bh A) {
// CHECK: fpext <16 x bfloat> %{{.*}} to <16 x float>
return _mm512_cvtpbh_ps(A);
}
TEST_CONSTEXPR(match_m512(_mm512_cvtpbh_ps((__m256bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f, -128.0f, -0.5f, 0.25f, -0.125f, -4.0f, 2.0f, -1.0f, 0.0f}), -0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f, -128.0f, -0.5f, 0.25f, -0.125f, -4.0f, 2.0f, -1.0f, 0.0f));

__m512 test_mm512_maskz_cvtpbh_ps(__mmask16 M, __m256bh A) {
// CHECK-LABEL: test_mm512_maskz_cvtpbh_ps
// CHECK: fpext <16 x bfloat> %{{.*}} to <16 x float>
// CHECK: select <16 x i1> %{{.*}}, <16 x float> %{{.*}}, <16 x float> %{{.*}}
return _mm512_maskz_cvtpbh_ps(M, A);
}
TEST_CONSTEXPR(match_m512(_mm512_maskz_cvtpbh_ps(0xA753, (__m256bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f, -128.0f, -0.5f, 0.25f, -0.125f, -4.0f, 2.0f, -1.0f, 0.0f}), -0.0f, 1.0f, 0.0f, 0.0f, -8.0f, 0.0f, -32.0f, 0.0f, -128.0f, -0.5f, 0.25f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f));

__m512 test_mm512_mask_cvtpbh_ps(__m512 S, __mmask16 M, __m256bh A) {
// CHECK-LABEL: test_mm512_mask_cvtpbh_ps
// CHECK: fpext <16 x bfloat> %{{.*}} to <16 x float>
// CHECK: select <16 x i1> %{{.*}}, <16 x float> %{{.*}}, <16 x float> %{{.*}}
return _mm512_mask_cvtpbh_ps(S, M, A);
}
TEST_CONSTEXPR(match_m512(_mm512_mask_cvtpbh_ps((__m512){ 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f }, 0xA753, (__m256bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f, -128.0f, -0.5f, 0.25f, -0.125f, -4.0f, 2.0f, -1.0f, 0.0f}), -0.0f, 1.0f, 99.0f, 99.0f, -8.0f, 99.0f, -32.0f, 99.0f, -128.0f, -0.5f, 0.25f, 99.0f, 99.0f, 2.0f, 99.0f, 0.0f));
7 changes: 7 additions & 0 deletions clang/test/CodeGen/X86/avx512vlbf16-builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
// RUN: %clang_cc1 -x c++ -flax-vector-conversions=none -ffreestanding %s -triple=i386-apple-darwin -target-feature +avx512bf16 -target-feature +avx512vl -emit-llvm -o - -Wall -Werror -fexperimental-new-constant-interpreter | FileCheck %s

#include <immintrin.h>
#include "builtin_test_helpers.h"

__m128bh test_mm_cvtne2ps2bf16(__m128 A, __m128 B) {
// CHECK-LABEL: test_mm_cvtne2ps2bf16
Expand Down Expand Up @@ -160,12 +161,14 @@ __m128 test_mm_cvtpbh_ps(__m128bh A) {
// CHECK: shufflevector <8 x float> %{{.*}}, <8 x float> %{{.*}}, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
return _mm_cvtpbh_ps(A);
}
TEST_CONSTEXPR(match_m128(_mm_cvtpbh_ps((__m128bh){-8.0f, 16.0f, -32.0f, 64.0f, -0.0f, 1.0f, -2.0f, 4.0f}), -8.0f, 16.0f, -32.0f, 64.0f));

__m256 test_mm256_cvtpbh_ps(__m128bh A) {
// CHECK-LABEL: test_mm256_cvtpbh_ps
// CHECK: fpext <8 x bfloat> %{{.*}} to <8 x float>
return _mm256_cvtpbh_ps(A);
}
TEST_CONSTEXPR(match_m256(_mm256_cvtpbh_ps((__m128bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f}), -0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f));

__m128 test_mm_maskz_cvtpbh_ps(__mmask8 M, __m128bh A) {
// CHECK-LABEL: test_mm_maskz_cvtpbh_ps
Expand All @@ -174,13 +177,15 @@ __m128 test_mm_maskz_cvtpbh_ps(__mmask8 M, __m128bh A) {
// CHECK: select <4 x i1> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}}
return _mm_maskz_cvtpbh_ps(M, A);
}
TEST_CONSTEXPR(match_m128(_mm_maskz_cvtpbh_ps(0x01, (__m128bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f}), -0.0f, 0.0f, 0.0f, 0.0f));

__m256 test_mm256_maskz_cvtpbh_ps(__mmask8 M, __m128bh A) {
// CHECK-LABEL: test_mm256_maskz_cvtpbh_ps
// CHECK: fpext <8 x bfloat> %{{.*}} to <8 x float>
// CHECK: select <8 x i1> %{{.*}}, <8 x float> %{{.*}}, <8 x float> %{{.*}}
return _mm256_maskz_cvtpbh_ps(M, A);
}
TEST_CONSTEXPR(match_m256(_mm256_maskz_cvtpbh_ps(0x73, (__m128bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f}), -0.0f, 1.0f, 0.0f, 0.0f, -8.0f, 16.0f, -32.0f, 0.0f));

__m128 test_mm_mask_cvtpbh_ps(__m128 S, __mmask8 M, __m128bh A) {
// CHECK-LABEL: test_mm_mask_cvtpbh_ps
Expand All @@ -189,10 +194,12 @@ __m128 test_mm_mask_cvtpbh_ps(__m128 S, __mmask8 M, __m128bh A) {
// CHECK: select <4 x i1> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}}
return _mm_mask_cvtpbh_ps(S, M, A);
}
TEST_CONSTEXPR(match_m128(_mm_mask_cvtpbh_ps((__m128){ 99.0f, 99.0f, 99.0f, 99.0f }, 0x03, (__m128bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f}), -0.0f, 1.0f, 99.0f, 99.0f));

__m256 test_mm256_mask_cvtpbh_ps(__m256 S, __mmask8 M, __m128bh A) {
// CHECK-LABEL: test_mm256_mask_cvtpbh_ps
// CHECK: fpext <8 x bfloat> %{{.*}} to <8 x float>
// CHECK: select <8 x i1> %{{.*}}, <8 x float> %{{.*}}, <8 x float> %{{.*}}
return _mm256_mask_cvtpbh_ps(S, M, A);
}
TEST_CONSTEXPR(match_m256(_mm256_mask_cvtpbh_ps((__m256){ 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f }, 0x37, (__m128bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f}), -0.0f, 1.0f, -2.0f, 99.0f, -8.0f, 16.0f, 99.0f, 99.0f));
Loading