From 32a6210f2a0cd00402269a2a5cd59fa07466c46b Mon Sep 17 00:00:00 2001 From: Vyacheslav N Klochkov Date: Mon, 3 Oct 2022 13:09:41 -0700 Subject: [PATCH] [ESIMDS] Support tfloat32 types in dpas() Signed-off-by: Vyacheslav N Klochkov --- .../include/sycl/ext/intel/esimd/xmx/dpas.hpp | 6 ++++-- sycl/test/esimd/dpas.cpp | 21 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp b/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp index 8b16e6b6b0119..69d27ef2b8f52 100644 --- a/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp +++ b/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp @@ -24,8 +24,10 @@ namespace ext::intel::esimd::xmx { namespace detail { template constexpr dpas_argument_type dpas_precision_from_type() { - // TODO: add support for tfloat32 here. - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) + return dpas_argument_type::tf32; + else if constexpr (std::is_same_v) return dpas_argument_type::fp16; else if constexpr (std::is_same_v) diff --git a/sycl/test/esimd/dpas.cpp b/sycl/test/esimd/dpas.cpp index 3e9b9d519d55f..207886aaa6eec 100644 --- a/sycl/test/esimd/dpas.cpp +++ b/sycl/test/esimd/dpas.cpp @@ -207,6 +207,7 @@ SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void xmx_func() { constexpr int K_half = 8 * 2; constexpr int K_bf16 = 8 * 2; constexpr int K_int8x2 = 8 * 4; + constexpr int K_tf32 = 8 * 1; constexpr int N_pvc = 16; constexpr int N_dg2 = 8; @@ -338,6 +339,26 @@ SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void xmx_func() { // CHECK: call <8 x float> @llvm.genx.dpasw.nosrc0.v8f32.v64i32.v4i32(<64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 17304074) } + { // ======= DPAS TFLOAT32 =================================================== + simd R_f = 0; + simd C_f = 0; + + simd B_tf = + 0; + simd A_tf = + 0; + + // ------------------- TFLOAT32: WITH ACC OPERAND -------------------------- + R_f = xmx::dpas<8, 1, float>(C_f, B_tf, A_tf); + zoo(R_f); + // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 12, i32 12, i32 8, i32 1, i32 1, i32 1) + + // ------------------- TFLOAT32: NO ACC OPERAND ---------------------------- + R_f = xmx::dpas<8, 1, float>(B_tf, A_tf); + zoo(R_f); + // CHECK: call <16 x float> @llvm.genx.dpas.nosrc0.v16f32.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 17304588) + } + xmx_func_end(); // CHECK: call spir_func void @_Z12xmx_func_endv() }