Skip to content
Merged
25 changes: 19 additions & 6 deletions sycl/include/sycl/ext/oneapi/experimental/bfloat16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
#include <CL/__spirv/spirv_ops.hpp>
#include <sycl/half_type.hpp>

#if !defined(__SYCL_DEVICE_ONLY__)
#include <cmath>
#endif

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
namespace ext {
Expand All @@ -35,9 +39,17 @@ class bfloat16 {
return __spirv_ConvertFToBF16INTEL(a);
#endif
#else
(void)a;
throw exception{errc::feature_not_supported,
"Bfloat16 conversion is not supported on host device"};
// In case of float value is nan - propagate bfloat16's qnan
if (std::isnan(a))
return 0xffc1;
union {
uint32_t intStorage;
float floatValue;
};
floatValue = a;
// Do RNE and truncate
uint32_t roundingBias = ((intStorage >> 16) & 0x1) + 0x00007FFF;
return static_cast<uint16_t>((intStorage + roundingBias) >> 16);
#endif
}
static float to_float(const storage_t &a) {
Expand All @@ -51,9 +63,10 @@ class bfloat16 {
return __spirv_ConvertBF16ToFINTEL(a);
#endif
#else
(void)a;
throw exception{errc::feature_not_supported,
"Bfloat16 conversion is not supported on host device"};
// Shift temporary variable to silence the warning
uint32_t bits = a;
bits <<= 16;
return static_cast<float>(bits);
#endif
}

Expand Down
88 changes: 88 additions & 0 deletions sycl/test/extensions/bfloat16_host.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//==------------ bfloat16_host.cpp - SYCL vectors test ---------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// RUN: %clangxx -fsycl %s -o %t.out
// RUN: %RUN_ON_HOST %t.out
#include <sycl/ext/oneapi/experimental/bfloat16.hpp>
#include <sycl/sycl.hpp>

#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <string>

// Helper to convert the expected bits to float value to compare with the result
typedef union {
float Value;
struct {
uint32_t Mantissa : 23;
uint32_t Exponent : 8;
uint32_t Sign : 1;
} RawData;
} floatConvHelper;

float bitsToFloatConv(std::string Bits) {
floatConvHelper Helper;
Helper.RawData.Sign = static_cast<uint32_t>(Bits[0] - '0');
uint32_t Exponent = 0;
for (size_t I = 1; I != 9; ++I)
Exponent = Exponent + static_cast<uint32_t>(Bits[I] - '0') * pow(2, 8 - I);
Helper.RawData.Exponent = Exponent;
uint32_t Mantissa = 0;
for (size_t I = 9; I != 32; ++I)
Mantissa = Mantissa + static_cast<uint32_t>(Bits[I] - '0') * pow(2, 31 - I);
Helper.RawData.Mantissa = Mantissa;
return Helper.Value;
}

bool check_bf16_from_float(float Val, uint16_t Expected) {
uint16_t Result = sycl::ext::oneapi::experimental::bfloat16::from_float(Val);
if (Result != Expected) {
std::cout << "from_float check for Val = " << Val << " failed!\n"
<< "Expected " << Expected << " Got " << Result << "\n";
return false;
}
return true;
}

bool check_bf16_to_float(uint16_t Val, float Expected) {
float Result = sycl::ext::oneapi::experimental::bfloat16::to_float(Val);
if (Result != Expected) {
std::cout << "to_float check for Val = " << Val << " failed!\n"
<< "Expected " << Expected << " Got " << Result << "\n";
return false;
}
return true;
}

int main() {
bool Success =
check_bf16_from_float(0.0f, std::stoi("0000000000000000", nullptr, 2));
Success &=
check_bf16_from_float(42.0f, std::stoi("100001000101000", nullptr, 2));
Success &= check_bf16_from_float(std::numeric_limits<float>::min(),
std::stoi("0000000010000000", nullptr, 2));
Success &= check_bf16_from_float(std::numeric_limits<float>::max(),
std::stoi("0111111110000000", nullptr, 2));
Success &= check_bf16_from_float(std::numeric_limits<float>::quiet_NaN(),
std::stoi("1111111111000001", nullptr, 2));

Success &= check_bf16_to_float(
0, bitsToFloatConv(std::string("00000000000000000000000000000000")));
Success &= check_bf16_to_float(
1, bitsToFloatConv(std::string("01000111100000000000000000000000")));
Success &= check_bf16_to_float(
42, bitsToFloatConv(std::string("01001010001010000000000000000000")));
Success &= check_bf16_to_float(
std::numeric_limits<uint16_t>::max(),
bitsToFloatConv(std::string("01001111011111111111111100000000")));
if (!Success)
return -1;
return 0;
}