Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ namespace ext::intel::esimd::xmx {
namespace detail {

template <typename T> constexpr dpas_argument_type dpas_precision_from_type() {
// TODO: add support for tfloat32 here.
if constexpr (std::is_same_v<T, sycl::half>)
if constexpr (std::is_same_v<T,
sycl::ext::intel::experimental::esimd::tfloat32>)
return dpas_argument_type::tf32;
else if constexpr (std::is_same_v<T, sycl::half>)
return dpas_argument_type::fp16;
else if constexpr (std::is_same_v<T,
sycl::ext::oneapi::experimental::bfloat16>)
Expand Down
21 changes: 21 additions & 0 deletions sycl/test/esimd/dpas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void xmx_func() {
constexpr int K_half = 8 * 2;
constexpr int K_bf16 = 8 * 2;
constexpr int K_int8x2 = 8 * 4;
constexpr int K_tf32 = 8 * 1;
constexpr int N_pvc = 16;
constexpr int N_dg2 = 8;

Expand Down Expand Up @@ -338,6 +339,26 @@ SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void xmx_func() {
// CHECK: call <8 x float> @llvm.genx.dpasw.nosrc0.v8f32.v64i32.v4i32(<64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 17304074)
}

{ // ======= DPAS TFLOAT32 ===================================================
// Exercises xmx::dpas with tfloat32 A/B operands and float result/accumulator.
// R_f receives the result of each call; C_f is the accumulator passed to the
// first call only.
simd<float, M_one *N_pvc> R_f = 0;
simd<float, M_one *N_pvc> C_f = 0;

// B holds K_tf32 * N_pvc tfloat32 elements and A holds M_one * K_tf32.
// The FileCheck patterns below expect them to appear as <128 x i32> and
// <8 x i32> operands respectively (M_one is defined outside this scope;
// presumably 1 -- confirm against the declarations above).
simd<sycl::ext::intel::experimental::esimd::tfloat32, K_tf32 *N_pvc> B_tf =
0;
simd<sycl::ext::intel::experimental::esimd::tfloat32, M_one *K_tf32> A_tf =
0;

// ------------------- TFLOAT32: WITH ACC OPERAND --------------------------
// Three-operand form: accumulator, B, A. Expected to lower to the
// llvm.genx.dpas2 intrinsic with a <16 x float> accumulator argument.
// NOTE(review): the template arguments <8, 1, float> look like systolic
// depth, repeat count and result element type -- confirm against dpas.hpp.
R_f = xmx::dpas<8, 1, float>(C_f, B_tf, A_tf);
// zoo() is declared elsewhere; presumably it consumes the result so the
// call above is kept in the generated IR -- verify.
zoo(R_f);
// CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 12, i32 12, i32 8, i32 1, i32 1, i32 1)

// ------------------- TFLOAT32: NO ACC OPERAND ----------------------------
// Two-operand form: B, A only. Expected to lower to the
// llvm.genx.dpas.nosrc0 intrinsic, which takes no accumulator argument.
R_f = xmx::dpas<8, 1, float>(B_tf, A_tf);
zoo(R_f);
// CHECK: call <16 x float> @llvm.genx.dpas.nosrc0.v16f32.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 17304588)
}

xmx_func_end();
// CHECK: call spir_func void @_Z12xmx_func_endv()
}