From 007bacca85dab8d8b3a09fa8c5d069480c9ac837 Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Tue, 26 Aug 2025 10:59:43 -0700 Subject: [PATCH 01/16] [CuTe core] Add ScaledBasis division operator --- include/cute/numeric/arithmetic_tuple.hpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index 016ac5b6cb..6e6d915939 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -408,6 +408,15 @@ operator*(ScaledBasis const& e, B const& b) { return ScaledBasis{r}; } +// Division +template +CUTE_HOST_DEVICE constexpr +auto +operator/(ScaledBasis const& e, B const& b) { + auto r = e.value() / b; + return ScaledBasis{r}; +} + // Addition template CUTE_HOST_DEVICE constexpr From 4329c197d6ad0a7fb194819ec9e00c97ffcb830d Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Mon, 15 Sep 2025 15:20:57 -0700 Subject: [PATCH 02/16] [CuTe core] Make basis_{get,value} constexpr again --- include/cute/numeric/arithmetic_tuple.hpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index 6e6d915939..8fd3b98abf 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -259,14 +259,16 @@ using E = ScaledBasis,Ns...>; // Apply the Ns... pack to another Tuple template -CUTE_HOST_DEVICE decltype(auto) +CUTE_HOST_DEVICE constexpr +decltype(auto) basis_get(T const&, Tuple&& t) { return static_cast(t); } template -CUTE_HOST_DEVICE decltype(auto) +CUTE_HOST_DEVICE constexpr +decltype(auto) basis_get(ScaledBasis const&, Tuple&& t) { if constexpr (sizeof...(Ns) == 0) { @@ -278,7 +280,8 @@ basis_get(ScaledBasis const&, Tuple&& t) } template -CUTE_HOST_DEVICE decltype(auto) +CUTE_HOST_DEVICE constexpr +decltype(auto) basis_value(T const& e) { if constexpr (is_scaled_basis::value) { return e.value(); From 221453b8b14ed7be500a8e65aa3bdbcea1f5a35d Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Mon, 18 Aug 2025 22:03:31 -0700 Subject: [PATCH 03/16] [CuTe core] Enable make_fragment_like for coordinate layouts --- include/cute/algorithm/functional.hpp | 2 + include/cute/layout.hpp | 47 ++++++++++++++++++++++- include/cute/numeric/arithmetic_tuple.hpp | 45 ++++++++++++++++++++++ 3 files changed, 93 insertions(+), 1 deletion(-) diff --git a/include/cute/algorithm/functional.hpp b/include/cute/algorithm/functional.hpp index 5c56eb5cc8..c5f730b1c7 100644 --- a/include/cute/algorithm/functional.hpp +++ b/include/cute/algorithm/functional.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -188,6 +189,7 @@ CUTE_BINARY_OP(less_equal, <=); CUTE_NAMED_BINARY_OP(max_fn, cute::max); CUTE_NAMED_BINARY_OP(min_fn, cute::min); +CUTE_NAMED_BINARY_OP(gcd_fn, cute::gcd); #undef CUTE_BINARY_OP #undef CUTE_NAMED_BINARY_OP diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index cb161369cb..76446f0244 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -455,7 +456,9 @@ auto make_fragment_like(Layout const& layout) { constexpr int R = Layout::rank; - if constexpr (R > 1 && is_static::value) { + if constexpr (is_arithmetic_tuple_like::value) { + return make_fragment_like(project_strides(layout)); + } else if constexpr (R > 1 && is_static::value) { return tiled_product(make_layout(get<0>(layout.shape()), compact_major(filter_zeros(get<0>(layout.stride()), get<0>(layout.shape())))), make_ordered_layout(take<1,R>(layout.shape()), take<1,R>(layout.stride()))); @@ -654,6 +657,35 @@ coshape(Layout const& layout) return transform_leaf(co_coord, [](auto c) { return c + Int<1>{}; }); } +// Compute max(a_leaf * b_leaf) across all ArithmeticTuple leaf pairs of a/b, +// where 'max' and '*' are acting in elementwise fashion on tuples. +template +CUTE_HOST_DEVICE constexpr +auto +inner_product_atuple_max(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform_apply(a, b, [](auto const& x, auto const& y) { return inner_product_atuple_max(x,y); }, + [](auto const&... v) { return atuple_max(v...); }); + } else { + return a * b; + } + + CUTE_GCC_UNREACHABLE; +} + +// Return the codomain shape of an ArithmeticTuple-strided layout, +// treating each dimension of the ArithmeticTuple separately. +template +CUTE_HOST_DEVICE constexpr +auto +atuple_coshape(Layout const& layout) +{ + auto flayout = filter(flatten(layout)); + return inner_product_atuple_max(shape(flayout), stride(flayout)); +} + // Return the codomain size of a mode // @return M smallest integer such that // size(@a sub_layout(c)) < M for all c < size(@a sub_layout) @@ -681,6 +713,19 @@ crd2idx(Coord const& c, Layout const& layout) return crd2idx(c, layout.shape(), layout.stride()); } +// Project an ArithmeticTuple-strided layout to a standard layout, +// by composing it with a LayoutLeft layout. +template +CUTE_HOST_DEVICE constexpr +auto +project_strides(Layout const& layout) +{ + if constexpr (is_arithmetic_tuple_like::value) + return composition(make_layout(atuple_coshape(layout)), layout); + else + return layout; +} + // // Slice and Dice a layout // diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index 8fd3b98abf..404fd23aa9 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -243,6 +244,13 @@ struct is_scaled_basis : false_type {}; template struct is_scaled_basis> : true_type {}; +template +struct is_arithmetic_tuple_like : false_type {}; +template +struct is_arithmetic_tuple_like> : true_type {}; +template +struct is_arithmetic_tuple_like> : true_type {}; + template struct is_integral> : true_type {}; @@ -468,6 +476,43 @@ operator+(ScaledBasis const& t, C) { CUTE_GCC_UNREACHABLE; } +// Component-wise maximum +template +CUTE_HOST_DEVICE constexpr +auto +atuple_max(ArithmeticTuple const& t, ArithmeticTuple const& u) { + constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); + return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), max_fn{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +atuple_max(ScaledBasis const& t, ScaledBasis const& u) { + return atuple_max(as_arithmetic_tuple(t), as_arithmetic_tuple(u)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +atuple_max(ArithmeticTuple const& t, ScaledBasis const& u) { + return atuple_max(t, as_arithmetic_tuple(u)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +atuple_max(ScaledBasis const& t, ArithmeticTuple const& u) { + return atuple_max(as_arithmetic_tuple(t), u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +atuple_max(T0 const& t0, T1 const& t1, Ts const&... ts) { + return atuple_max(t0, atuple_max(t1, ts...)); +} + // // Display utilities // From 098f7ecc520650078a8842b791d96846060fad4f Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Mon, 18 Aug 2025 22:04:55 -0700 Subject: [PATCH 04/16] [CuTe core] Add ThrCopy::partition_fragment_{S,D} methods mirroring ThrMMA --- include/cute/atom/copy_atom.hpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 20bebfc53d..9149065097 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -399,6 +399,20 @@ struct ThrCopy // "Expected ValType for tiling DstTensor."); return make_tensor(static_cast(dtensor).data(), TiledCopy::retile(dtensor.layout())); } + + template + CUTE_HOST_DEVICE + auto + partition_fragment_S(STensor&& stensor) const { + return make_fragment_like(partition_S(stensor)); + } + + template + CUTE_HOST_DEVICE + auto + partition_fragment_D(DTensor&& dtensor) const { + return make_fragment_like(partition_D(dtensor)); + } }; From 80774cd7576c7693fe96da7972b8b1e0e30a4934 Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Thu, 21 Aug 2025 22:17:54 -0700 Subject: [PATCH 05/16] [CuTe core] Additional TiledCopy helpers --- include/cute/atom/copy_atom.hpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 9149065097..a709a658a4 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -196,6 +196,8 @@ struct TiledCopy : Copy_Atom using AtomNumThr = decltype(size<0>(AtomLayoutRef{})); using AtomNumVal = decltype(size<1>(AtomLayoutRef{})); + using Atom = Copy_Atom; + // Layout information for the TiledCopy using Tiler_MN = ShapeTiler_MN; using TiledLayout_TV = LayoutCopy_TV; @@ -205,6 +207,16 @@ struct TiledCopy : Copy_Atom CUTE_STATIC_ASSERT_V(TiledNumThr{} % AtomNumThr{} == Int<0>{}, "TiledCopy uses too few thrs for selected CopyAtom"); CUTE_STATIC_ASSERT_V(TiledNumVal{} % 
AtomNumVal{} == Int<0>{}, "TiledCopy uses too few vals for selected CopyAtom"); + // Additional Trait parameters/transformations + template + CUTE_HOST_DEVICE + auto + with(TraitsArgs&&... args) const { + TiledCopy result; + static_cast(result) = Copy_Atom::with(static_cast(args)...); + return result; + } + // Tile a tensor or a layout from shape // (M,N,...) // to shape From dcaa47faa62d91fabfae3fb0bbd5a66a5b2d0c26 Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Mon, 18 Aug 2025 22:13:23 -0700 Subject: [PATCH 06/16] [CuTe] Mark existing Xe atoms as legacy --- include/cute/arch/copy_xe.hpp | 37 +- include/cute/arch/copy_xe_legacy.hpp | 71 + ...copy_xe_U16.hpp => copy_xe_legacy_U16.hpp} | 0 ...copy_xe_U32.hpp => copy_xe_legacy_U32.hpp} | 0 .../{copy_xe_U4.hpp => copy_xe_legacy_U4.hpp} | 0 ...copy_xe_U64.hpp => copy_xe_legacy_U64.hpp} | 0 .../{copy_xe_U8.hpp => copy_xe_legacy_U8.hpp} | 0 ...builtin.hpp => copy_xe_legacy_builtin.hpp} | 0 ..._xe_spirv.hpp => copy_xe_legacy_spirv.hpp} | 0 .../arch/{mma_xe.hpp => mma_xe_legacy.hpp} | 5 +- ..._builtin.hpp => mma_xe_legacy_builtin.hpp} | 0 ...a_xe_spirv.hpp => mma_xe_legacy_spirv.hpp} | 0 include/cute/atom/copy_atom.hpp | 2 + include/cute/atom/copy_traits_xe.hpp | 2587 ---------------- include/cute/atom/copy_traits_xe_legacy.hpp | 2628 +++++++++++++++++ include/cute/atom/mma_atom.hpp | 23 +- ...traits_xe.hpp => mma_traits_xe_legacy.hpp} | 4 +- 17 files changed, 2720 insertions(+), 2637 deletions(-) create mode 100644 include/cute/arch/copy_xe_legacy.hpp rename include/cute/arch/{copy_xe_U16.hpp => copy_xe_legacy_U16.hpp} (100%) rename include/cute/arch/{copy_xe_U32.hpp => copy_xe_legacy_U32.hpp} (100%) rename include/cute/arch/{copy_xe_U4.hpp => copy_xe_legacy_U4.hpp} (100%) rename include/cute/arch/{copy_xe_U64.hpp => copy_xe_legacy_U64.hpp} (100%) rename include/cute/arch/{copy_xe_U8.hpp => copy_xe_legacy_U8.hpp} (100%) rename include/cute/arch/{copy_xe_builtin.hpp => copy_xe_legacy_builtin.hpp} (100%) rename include/cute/arch/{copy_xe_spirv.hpp => copy_xe_legacy_spirv.hpp} (100%) rename include/cute/arch/{mma_xe.hpp => mma_xe_legacy.hpp} (99%) rename include/cute/arch/{mma_xe_builtin.hpp => mma_xe_legacy_builtin.hpp} (100%) rename include/cute/arch/{mma_xe_spirv.hpp => mma_xe_legacy_spirv.hpp} (100%) create mode 100644 include/cute/atom/copy_traits_xe_legacy.hpp rename include/cute/atom/{mma_traits_xe.hpp => mma_traits_xe_legacy.hpp} (99%) diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index a233ecf073..ca6c007ad0 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -30,42 +31,10 @@ **************************************************************************************************/ #pragma once -#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_INTEL_TARGET) -#define CUTE_ARCH_COPY_XE_ENABLED -#endif - -#if defined(CUTE_ARCH_COPY_XE_ENABLED) && ((defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER < 20250200)) || defined(CUTLASS_SYCL_BUILTIN_ENABLE)) -#include -#elif defined(CUTE_ARCH_COPY_XE_ENABLED) -#include -#endif - -#include -#include -#include -#include -#include - -#ifdef __SYCL_DEVICE_ONLY__ -SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope, int memory_semantics); -SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierArriveINTEL(int execution_scope, int memory_scope, int memory_semantics); -#endif namespace cute { -// scope = 3 is for subgroup, scope = 2 is for workgroup -CUTE_HOST_DEVICE void barrier_arrive(int scope, int memory_scope = 0, int memory_semantics = 0) { -#ifdef __SYCL_DEVICE_ONLY__ - __spirv_ControlBarrierArriveINTEL(scope, memory_scope, memory_semantics); -#endif -} -CUTE_HOST_DEVICE void barrier_wait(int scope, int memory_scope = 0, int memory_semantics = 0) { -#ifdef __SYCL_DEVICE_ONLY__ - __spirv_ControlBarrierWaitINTEL(scope, memory_scope, memory_semantics); -#endif -} - template struct XE_ATOMIC { using SRegisters = S[1]; @@ -112,7 +81,7 @@ struct XE_1D_LDSM { sg, &src, *reinterpret_cast *>(&dst), props); #else CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-Xe hardware"); - #endif + #endif } }; @@ -141,7 +110,7 @@ struct XE_1D_LOAD_GLOBAL { sg, &src, *reinterpret_cast *>(&dst), props); #else CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-Xe hardware"); - #endif + #endif } }; diff --git a/include/cute/arch/copy_xe_legacy.hpp b/include/cute/arch/copy_xe_legacy.hpp new file mode 100644 index 0000000000..a414885033 --- /dev/null +++ b/include/cute/arch/copy_xe_legacy.hpp @@ -0,0 +1,71 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_INTEL_TARGET) +#define CUTE_ARCH_COPY_XE_ENABLED +#endif + +#if defined(CUTE_ARCH_COPY_XE_ENABLED) && ((defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER < 20250200)) || defined(CUTLASS_SYCL_BUILTIN_ENABLE)) +#include +#elif defined(CUTE_ARCH_COPY_XE_ENABLED) +#include +#endif + +#include +#include +#include +#include +#include + +// FIXME: these are not copy-related and should be declared elsewhere. +#ifdef __SYCL_DEVICE_ONLY__ +SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope, int memory_semantics); +SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierArriveINTEL(int execution_scope, int memory_scope, int memory_semantics); +#endif + +namespace cute +{ + +// scope = 3 is for subgroup, scop = 2 is for workgroup +CUTE_HOST_DEVICE void barrier_arrive(int scope, int memory_scope = 0, int memory_semantics = 0) { +#ifdef __SYCL_DEVICE_ONLY__ + __spirv_ControlBarrierArriveINTEL(scope, memory_scope, memory_semantics); +#endif +} +CUTE_HOST_DEVICE void barrier_wait(int scope, int memory_scope = 0, int memory_semantics = 0) { +#ifdef __SYCL_DEVICE_ONLY__ + __spirv_ControlBarrierWaitINTEL(scope, memory_scope, memory_semantics); +#endif +} + +} // end namespace cute diff --git a/include/cute/arch/copy_xe_U16.hpp b/include/cute/arch/copy_xe_legacy_U16.hpp similarity index 100% rename from include/cute/arch/copy_xe_U16.hpp rename to include/cute/arch/copy_xe_legacy_U16.hpp diff --git a/include/cute/arch/copy_xe_U32.hpp b/include/cute/arch/copy_xe_legacy_U32.hpp similarity index 100% rename from include/cute/arch/copy_xe_U32.hpp rename to include/cute/arch/copy_xe_legacy_U32.hpp diff --git a/include/cute/arch/copy_xe_U4.hpp b/include/cute/arch/copy_xe_legacy_U4.hpp similarity index 100% rename from include/cute/arch/copy_xe_U4.hpp rename to include/cute/arch/copy_xe_legacy_U4.hpp diff --git a/include/cute/arch/copy_xe_U64.hpp b/include/cute/arch/copy_xe_legacy_U64.hpp similarity index 100% rename from include/cute/arch/copy_xe_U64.hpp rename to include/cute/arch/copy_xe_legacy_U64.hpp diff --git a/include/cute/arch/copy_xe_U8.hpp b/include/cute/arch/copy_xe_legacy_U8.hpp similarity index 100% rename from include/cute/arch/copy_xe_U8.hpp rename to include/cute/arch/copy_xe_legacy_U8.hpp diff --git a/include/cute/arch/copy_xe_builtin.hpp b/include/cute/arch/copy_xe_legacy_builtin.hpp similarity index 100% rename from include/cute/arch/copy_xe_builtin.hpp rename to include/cute/arch/copy_xe_legacy_builtin.hpp diff --git a/include/cute/arch/copy_xe_spirv.hpp b/include/cute/arch/copy_xe_legacy_spirv.hpp similarity index 100% rename from include/cute/arch/copy_xe_spirv.hpp rename to include/cute/arch/copy_xe_legacy_spirv.hpp diff --git a/include/cute/arch/mma_xe.hpp b/include/cute/arch/mma_xe_legacy.hpp similarity index 
99% rename from include/cute/arch/mma_xe.hpp rename to include/cute/arch/mma_xe_legacy.hpp index 763da5020f..7c3b13507b 100644 --- a/include/cute/arch/mma_xe.hpp +++ b/include/cute/arch/mma_xe_legacy.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -35,9 +36,9 @@ #endif #if defined(CUTE_ARCH_MMA_XE_ENABLED) && ((defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER < 20250200)) || defined(CUTLASS_SYCL_BUILTIN_ENABLE)) -#include +#include #elif defined(CUTE_ARCH_MMA_XE_ENABLED) -#include +#include #endif #include diff --git a/include/cute/arch/mma_xe_builtin.hpp b/include/cute/arch/mma_xe_legacy_builtin.hpp similarity index 100% rename from include/cute/arch/mma_xe_builtin.hpp rename to include/cute/arch/mma_xe_legacy_builtin.hpp diff --git a/include/cute/arch/mma_xe_spirv.hpp b/include/cute/arch/mma_xe_legacy_spirv.hpp similarity index 100% rename from include/cute/arch/mma_xe_spirv.hpp rename to include/cute/arch/mma_xe_legacy_spirv.hpp diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index a709a658a4..797d2a22ae 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -715,6 +716,7 @@ print(ThrCopy const& thr_copy) #if defined(SYCL_INTEL_TARGET) #include +#include #endif //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index ffbf4ecd5e..fb3444952d 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -33,2442 +33,10 @@ #include #include -#include #include namespace cute { -namespace detail { - -static constexpr auto subgroup_size = 16; - -// ========== size_of_inst_bits ========== -template -static constexpr auto size_of_inst_bits = sizeof_bits_v; - -template -static constexpr auto size_of_inst_bits> = sizeof_bits_v; - - -// ========== is_transpose_load ========== -template -static constexpr bool is_transpose_load = false; - -template -static constexpr bool is_transpose_load>> = T::is_transpose; - - -// ========== is_stride_leftmost ========== -template -static constexpr bool is_stride_leftmost = std::is_same_v<_1, decltype(get<0>(T{}))>; - -template -static constexpr bool is_stride_leftmost> = std::is_same_v<_1, decltype(get<0>(T{}.stride()))>; - -// Swap the Src or Dst Layout of a Copy_Traits if the logical/memory layouts differ -template -auto get_logical_layout(LayoutIn &&, BlockShape &&) { - static_assert(cute::rank(BlockShape{}) == 2, "Expected 2D BlockShape for XE_2D copy op."); - static_assert(cute::rank(LayoutIn{}) == 2, "Expected 2D LayoutIn for XE_2D copy op."); - if constexpr (!is_matrix_B) { - return LayoutIn{}; - } else { - // (16, (32, 2)) - // ^-- the size of an element in bits - static_assert(size(LayoutIn{}) % size(BlockShape{}) == 0); - constexpr 
int ElemBitSize = size(LayoutIn{}) / size(BlockShape{}); - // Construct a generic row-major layout of the relevant size - using RowMajorLayout = - decltype(make_ordered_layout(Shape, BlockShape>{}, Step<_0, Step<_2, _1>>{})); - // Compose with LayoutIn to produce the transposed Copy_Traits layout - return right_inverse(RowMajorLayout{}).compose(LayoutIn{}); - } -} -} // end namespace detail - -template -struct choose_prefetch_for_type { - static_assert(dependent_false<>, "Invalid prefetch"); -}; - -// U4 -template <> -struct choose_prefetch_for_type<4, 1> { - using Prefetch = XE_2D_U16x1x32_LD_N; -}; - -template <> -struct choose_prefetch_for_type<4, 2> { - using Prefetch = XE_2D_U16x2x32_LD_N; -}; - -template <> -struct choose_prefetch_for_type<4, 4> { - using Prefetch = XE_2D_U16x4x32_LD_N; -}; - -template <> -struct choose_prefetch_for_type<4, 8> { - using Prefetch = XE_2D_U16x8x32_LD_N; -}; - -template <> -struct choose_prefetch_for_type<4, 16> { - using Prefetch = XE_2D_U16x16x32_LD_N; -}; - -template <> -struct choose_prefetch_for_type<4, 32> { - using Prefetch = XE_2D_U16x32x32_LD_N; -}; - -// U8 -template <> -struct choose_prefetch_for_type<8, 1> { - using Prefetch = XE_2D_Packed_U8x1x64_LD_N; -}; - -template <> -struct choose_prefetch_for_type<8, 2> { - using Prefetch = XE_2D_Packed_U8x2x64_LD_N; -}; - -template <> -struct choose_prefetch_for_type<8, 4> { - using Prefetch = XE_2D_Packed_U8x4x64_LD_N; -}; - -template <> -struct choose_prefetch_for_type<8, 8> { - using Prefetch = XE_2D_Packed_U8x8x64_LD_N; -}; - -template <> -struct choose_prefetch_for_type<8, 16> { - using Prefetch = XE_2D_Packed_U8x16x64_LD_N; -}; - -template <> -struct choose_prefetch_for_type<8, 32> { - using Prefetch = XE_2D_Packed_U8x32x64_LD_N; -}; - -// U16 -template <> -struct choose_prefetch_for_type<16, 1> { - using Prefetch = XE_2D_U16x1x32_LD_N; -}; - -template <> -struct choose_prefetch_for_type<16, 2> { - using Prefetch = XE_2D_U16x2x32_LD_N; -}; - -template <> -struct choose_prefetch_for_type<16, 4> { - using Prefetch = XE_2D_U16x4x32_LD_N; -}; - -template <> -struct choose_prefetch_for_type<16, 8> { - using Prefetch = XE_2D_U16x8x32_LD_N; -}; - -template <> -struct choose_prefetch_for_type<16, 16> { - using Prefetch = XE_2D_U16x16x32_LD_N; -}; - -template <> -struct choose_prefetch_for_type<16, 32> { - using Prefetch = XE_2D_U16x32x32_LD_N; -}; - -// U32 -template <> -struct choose_prefetch_for_type<32, 1> { - using Prefetch = XE_2D_U32x1x16_LD_N; -}; - -template <> -struct choose_prefetch_for_type<32, 2> { - using Prefetch = XE_2D_U32x2x16_LD_N; -}; - -template <> -struct choose_prefetch_for_type<32, 4> { - using Prefetch = XE_2D_U32x4x16_LD_N; -}; - -template <> -struct choose_prefetch_for_type<32, 8> { - using Prefetch = XE_2D_U32x8x16_LD_N; -}; - -template <> -struct choose_prefetch_for_type<32, 16> { - using Prefetch = XE_2D_U32x16x16_LD_N; -}; - -template <> -struct choose_prefetch_for_type<32, 32> { - using Prefetch = XE_2D_U32x32x16_LD_N; -}; - -template -CUTE_HOST_DEVICE auto prefetch_selector(Tensor const& tensor) { - constexpr size_t cacheline_bytes = 64; - using dtype = typename Tensor::value_type; - constexpr size_t dtype_size_bits = sizeof_bits_v; - constexpr bool is_tensor_M_major = detail::is_stride_leftmost; - using CopyThreadShape = std::conditional_t, _1>, - Shape<_1, Int>>; - - constexpr int tile_contig_size = is_tensor_M_major ? size<0>(TileShape{}) : size<1>(TileShape{}); - constexpr int tile_non_contig_size = is_tensor_M_major ? 
size<1>(TileShape{}) : size<0>(TileShape{}); - - // block here is what is prefetched in one atom execution - // min(32,32)-> 32 (256, 32) -> 32 - static constexpr auto block_contig_size = cute::min(tile_contig_size, cacheline_bytes * sizeof_bits_v / sizeof_bits_v); - // A: 1 -> trans or B 256/32 = 8 - static constexpr auto nums_blocks_contig = ceil_div(tile_contig_size, block_contig_size); - - // layout of sub groups - // A shape<32,1> / trans or B shape<4,8> - constexpr int sgs_contig = cute::gcd(Num_SGs, nums_blocks_contig); - constexpr int sgs_non_contig = Num_SGs / sgs_contig; - - constexpr auto block_non_contig_size = tile_non_contig_size / sgs_non_contig; - - using PrefetchTilingLayout = std::conditional_t, Int>, Int>, - Stride>, Int>>, - Layout, Shape, Int>>, - Stride, Stride<_1, Int>>> - >; - - using PrefetchOp = typename choose_prefetch_for_type::Prefetch; - using PrefetchTraits = Copy_Traits; - using PrefetchAtom = Copy_Atom; - using Scalar = Int / dtype_size_bits)>; - using ScalarLayout = std::conditional_t>, - Layout>>; - using ScalarPrefetchShape = decltype(product_each(raked_product(ScalarLayout{}, - Layout{}).shape())); - using PrefetchValLayout = decltype(make_layout(shape_div(ScalarPrefetchShape{}, CopyThreadShape{}))); - return make_tiled_copy(PrefetchAtom{}.with(tensor), PrefetchTilingLayout{}, PrefetchValLayout{}); - -} - -template -CUTE_HOST_DEVICE auto prefetch_selector(TiledCopy const& tiled_copy) { - using Tiled_Copy = TiledCopy; - constexpr int subgroup_size = size(typename Tiled_Copy::Traits_LD_t::ThrID{}); - int M, N; - if constexpr (Tiled_Copy::is_tensor_M_major) { - M = tiled_copy.width; - N = tiled_copy.height; - } else{ - M = tiled_copy.height; - N = tiled_copy.width; - } - // L is not used in prefetch_selector and we do not have the correct value here. Just set it to some arbitrary value. - int L = 1; - auto data = make_gmem_ptr(static_cast(tiled_copy.base_ptr)); - auto shape = make_shape(M, N, L); - auto stride = [=](){ - if constexpr (Tiled_Copy::is_tensor_M_major){ - return make_stride(_1{}, tiled_copy.pitch, tiled_copy.stride_l); - }else{ - return make_stride(tiled_copy.pitch, _1{}, tiled_copy.stride_l); - } - }(); - auto tensor = make_tensor(data, make_layout(shape, stride)); - return cute::prefetch_selector(tensor); -} - - -template , int64_t>> -struct XE_2D_LD_Unpack { - using BlockShape = typename CopyOp::BlockShape; // this is not the same as Traits_LD_t::BlockShape iff is_matrix_B - using Traits_LD_t = Copy_Traits; - static constexpr auto stride_rank = rank(StrideOrTensor{}); - static_assert(stride_rank == 2 || stride_rank == 3); - - // Assume LD_T/LD_N will be used for column/row major matrices respectively - static constexpr bool is_transpose_copy = detail::is_transpose_load; - - // We need to reverse some parameters becasue intel xe 2d copy intrinsic always assume the matrix use (M, N):(N, 1) layout - // M-major if we label the matrix shape (M,N,L). M-major for matrix A or C is col-major. For matrix B it is row-major. - static constexpr bool is_tensor_M_major = detail::is_stride_leftmost; - - // For matrix B cute internally has transposed representation compared to other matrices, for cute its shape is (N,K) - // Intel copy instructions, on the other hand follow blas convention, where matrix B has shape (K,N) - static constexpr bool is_matrix_B = is_tensor_M_major ^ is_transpose_copy; - - using CopyThreadShape = Shape<_1, Int>; - // we can not use Traits_LD_t::BlockShape as this is a parent class of Traits_LD_t, so that would be recursion. 
Recalculate it instead. - using DefaultValLayout = decltype(make_layout(shape_div(std::conditional_t{}, CopyThreadShape{}))); - - template - using DefaultTiledCopy = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, DefaultValLayout{})); - - // 2d copy parameters - const void *base_ptr; - uint32_t width; - uint32_t height; - uint32_t pitch; - uint32_t stride_l = 0; - - - XE_2D_LD_Unpack(const void *ptr, uint32_t y, - uint32_t x, uint32_t p = 0) : base_ptr(ptr) { - if constexpr (is_tensor_M_major) { - width = y; - height = x; - } - else { - width = x; - height = y; - } - - pitch = (p == 0 ? width : p); - } - - template - XE_2D_LD_Unpack(Tensor const &tensor) { - base_ptr = raw_pointer_cast(tensor.data()); - - if constexpr (is_tensor_M_major) - { - width = size<0>(tensor.shape()); - height = size<1>(tensor.shape()); - pitch = size<1>(tensor.stride()); - } - else - { - width = size<1>(tensor.shape()); - height = size<0>(tensor.shape()); - pitch = size<0>(tensor.stride()); - } - - if constexpr (stride_rank == 3) { - stride_l = size<2>(tensor.stride()); - } - } - - XE_2D_LD_Unpack(Traits_LD_t const &traits) : base_ptr(traits.base_ptr), - width(traits.width), height(traits.height), pitch(traits.pitch), - stride_l(traits.stride_l) {} - - XE_2D_LD_Unpack() {} - - template - CUTE_HOST_DEVICE friend constexpr void - copy_unpack(Traits_LD_t const &traits, Tensor const &src, - Tensor &dst) { - using dtype = typename Tensor::value_type; - constexpr int dtype_bits = sizeof_bits_v; - - static_assert(is_rmem::value); - static_assert(size(SLayout{}) * dtype_bits == size<1>(typename Traits_LD_t::SrcLayout{}), - "Src tensor size does not match copy atom size."); - static_assert(size(DLayout{}) * dtype_bits == size<1>(typename Traits_LD_t::DstLayout{}), - "Dst tensor size does not match copy atom size."); - - dtype *base_addr = (dtype *)traits.base_ptr; - - auto [m, n, l] = src.data().coord_; - int x = is_tensor_M_major ? m : n; - int y = is_tensor_M_major ? n : m; - - constexpr auto inst_size_bits = detail::size_of_inst_bits; - - CopyOp::copy(((uint8_t*)base_addr) + static_cast(l) * traits.stride_l * sizeof_bits_v / 8, - (traits.width * sizeof_bits_v) / sizeof_bits_v, traits.height, - (traits.pitch * sizeof_bits_v) / sizeof_bits_v, - intel::coord_t{(int)(x * sizeof_bits_v / inst_size_bits), y}, - raw_pointer_cast(&((&*dst.data())[0]))); - } - - template - CUTE_HOST_DEVICE friend constexpr void - prefetch(Copy_Atom, CopyType> const& atom, - Tensor const& src) { - using dtype = typename Copy_Atom, CopyType>::ValType; - - static_assert(detail::has_prefetch); - static_assert(size(SLayout{}) * sizeof_bits_v == size<1>(typename Copy_Atom, CopyType>::SrcLayout{}), - "Src tensor size does not match copy atom for prefetch size"); - - dtype *base_addr = (dtype *)atom.base_ptr; - - auto [m, n, l] = src.data().coord_; - - int x = is_tensor_M_major ? m : n; - int y = is_tensor_M_major ? n : m; - - constexpr auto inst_size_bits = detail::size_of_inst_bits; - - CopyOp::PREFETCH::copy(((uint8_t*)(base_addr)) + static_cast(l) * atom.stride_l * sizeof_bits_v / 8, - (atom.width * sizeof_bits_v) / sizeof_bits_v, atom.height, - (atom.pitch * sizeof_bits_v) / sizeof_bits_v, - intel::coord_t{(int)(x * sizeof_bits_v / inst_size_bits), y}); - } - - template - static constexpr auto with(Tensor const &tensor) { - return Traits_LD_t{tensor}; - } - - template - static constexpr auto with(T0 && arg0, T1 && arg1, Ts&&... 
args) { - return Traits_LD_t{arg0, arg1, args...}; - } -}; - -template , int64_t>> struct XE_2D_ST_Unpack { - using Traits_ST_t = Copy_Traits; - using BlockShape = typename CopyOp::BlockShape; - - static constexpr auto stride_rank = rank(StrideOrTensor{}); - static_assert(stride_rank == 2 || stride_rank == 3); - - static constexpr bool is_matrix_B = false; - - const void *base_ptr; - uint32_t width; - uint32_t height; - uint32_t pitch; - uint32_t stride_l = 0; - - XE_2D_ST_Unpack(const void *ptr, uint32_t y, - uint32_t x, uint32_t p = 0) : base_ptr(ptr) { - width = x; - height = y; - pitch = (p == 0 ? width : p); - } - - template - XE_2D_ST_Unpack(Tensor const &tensor) { - base_ptr = tensor.data().get(); - width = size<1>(tensor.shape()); - height = size<0>(tensor.shape()); - pitch = size<0>(tensor.stride()); - - if constexpr (stride_rank == 3) { - stride_l = size<2>(tensor.stride()); - } - } - - XE_2D_ST_Unpack(Traits_ST_t const &traits) : base_ptr(traits.base_ptr), - width(traits.width), height(traits.height), pitch(traits.pitch), - stride_l(traits.stride_l) {} - - XE_2D_ST_Unpack() {} - - - template - CUTE_HOST_DEVICE friend constexpr void - copy_unpack(Traits_ST_t const &traits, Tensor const &src, - Tensor &dst) { - - using dtype = typename Tensor::value_type; - constexpr int dtype_bits = sizeof_bits_v; - - static_assert(is_rmem::value); - static_assert(size(SLayout{}) * dtype_bits == size<1>(typename Traits_ST_t::SrcLayout{}), - "Src tensor size does not match copy atom size."); - static_assert(size(DLayout{}) * dtype_bits == size<1>(typename Traits_ST_t::DstLayout{}), - "Dst tensor size does not match copy atom size."); - - dtype *base_addr = (dtype *)traits.base_ptr; - - auto [m, n, l] = dst.data().coord_; - - CopyOp::copy(((uint8_t*)(base_addr)) + static_cast(l) * traits.stride_l * sizeof_bits_v / 8, - traits.width * sizeof(dtype), traits.height, - traits.pitch * sizeof(dtype), - intel::coord_t{(int)n, (int)m}, &*src.data()); - } - - template - static constexpr auto with(Tensor const &tensor) { - return Traits_ST_t{tensor}; - } - - template - static constexpr auto with(T0 && arg0, T1 && arg1, Ts&&... 
args) { - return Traits_ST_t{arg0, arg1, args...}; - } - -}; - -template -CUTE_HOST_DEVICE constexpr auto make_fragment_layout(TiledCopy &tiled_copy, - TLShape &&fragment_top_level_shape) { - // Shapes are reversed for col major case between registers and global memory, - // so all variables contain in their name whether they refer to the shape in registers or in global memory - - // TODO(Codeplay): reverse values in 2d (U8) MMA atoms instead - constexpr auto mma_atom_regs_shape = cute::reverse(get<0>(TLShape{})); - using MmaValsShapeRegs2d = std::conditional_t(mma_atom_regs_shape, _1{})), - decltype(append<2>(mma_atom_regs_shape, _1{}))>; - - using ThreadLayout_ = Shape, _1>; - using ThreadLayoutRegs = std::conditional_t; - using BlockShapeRegs = typename TiledCopy::BlockShape; - using TotalMmaAtomItersRegs = decltype(select<1,2>(TLShape{})); - - using CopyValsShapeRegs = decltype(shape_div(BlockShapeRegs{}, ThreadLayoutRegs{})); - // This case would need to rearrange data in registers between copy and mma calls - static_assert(get<0>(CopyValsShapeRegs{}) >= get<0>(MmaValsShapeRegs2d{}) || - get<1>(CopyValsShapeRegs{}) <= get<1>(MmaValsShapeRegs2d{}), - "It is not currently supported to have MMA atom be bigger than copy atom in one dimension and smaller in other dimension!"); - using MmaItersInCopyRegs = decltype(ceil_div(CopyValsShapeRegs{}, MmaValsShapeRegs2d{})); - using CopyItersRegs = decltype(shape_div(TotalMmaAtomItersRegs{}, MmaItersInCopyRegs{})); - - auto order = std::conditional_t, Step<_3, _5>, Step<_2, _4>>, - Step, Step<_2, _4>, Step<_3, _5>>>{}; - - auto res = make_ordered_layout( - prepend(cute::zip(MmaItersInCopyRegs{}, CopyItersRegs{}), MmaValsShapeRegs2d{}), - order); - - static_assert(size(res) == size(TLShape{}), "Internal eror in make_fragment_layout()."); - return res; -}; - -// clang-format off - -template -struct Copy_Traits_{ - static_assert(cute::dependent_false, "Copy_Traits_ not defined for this CopyOp"); -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - // mode 0 16:8 This will show there are 16 thread ehrtr each thread with stride of value with 8 bits away from the adjusent thread - // mode 1: <_8>:< _1> This says each thread will get 1 element each of them 8 bits. - using DstLayout = Layout, - Stride<_8, _1>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0, Stride <_1, _128, _256>>>; - // Map from (dst-thr,dst-val) to bit - // mode 0 16:8 This will show there are 16 thread ehrtr each thread with stride of value with 8 bits away from the adjusent thread - // mode 1: <_8, _2, _1>:< _1, _128, _256> This says each thread will get 2x1 element - // each of them 8 bits. The stired shows each thread jumps 16x8 bits for the next element in the block and 16x8x2 for the next row in the block - using DstLayout = Layout>, - Stride<_8, Stride < _1, _128, _256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) 
{} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0,_1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, - Stride<_16, _1>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_128,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_128,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_128,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_128,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_128,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_128,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_128,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, // 16 thread with stride of value with 8 bits away from second thread - // Mode1 :The second parameter shows the jump for each thread int bits//the third prameter is with of the row in bits(32x8 bits) - Stride<_8,Stride< _1,_128,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_8,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_8,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_128,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_8,Stride< _1,_128,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) 
{} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_8,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_8,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1, _4>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32,Stride< _1, _4>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_0,Stride< _1, _4, _16>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1, _4, _16>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_0,Stride <_1, _4, _16>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride<_1, _4, _16>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_0,Stride< _1, _4>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32,Stride<_1, _4>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_0,Stride< _1,_64>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_64>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_0,Stride< _1,_64>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_64>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0,_1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) 
{} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride<_1,_8,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride<_1,_8,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = cute::intel::ushort; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_512,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_512,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride<_1,_8,_256,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride<_1,_8,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = cute::intel::ushort; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (dst-thr,dst-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_512,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_512,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride<_1,_8,_256,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride<_1,_8,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = cute::intel::ushort; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_512,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) 
{} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride<_1,_8,_256,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride<_1,_8,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = cute::intel::ushort; -}; - - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_512,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_512,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_512,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_0, Stride<_1, _8, _16>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_256,Stride<_1, _8, _16>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_0, Stride<_1, _8, _16>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_256,Stride<_1, _8, _16>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) 
{} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_0,Stride< _1,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0, Stride< _1,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1, _16,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32,Stride< _1, _16,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) 
{} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_256,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_256,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_256,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) 
{} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_32,Stride< _1,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32,Stride< _1,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = cute::intel::ushort; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_256,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_32,Stride< _1,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32,Stride< _1,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = cute::intel::ushort; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_0, Stride< _1,_256,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16, Stride< _1,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_0, Stride<_1, _16, _512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32, Stride<_1, _16, _512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_32,Stride< _1,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32,Stride< _1,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = cute::intel::ushort; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,_32>, - Stride, _1>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,_32>, - Stride, _1>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) 
{} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_32, _2>>, - Stride,Stride< _1,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_32, _4>>, - Stride,Stride< _1,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_32, _8>>, - Stride,Stride< _1, _512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_32, _16>>, - Stride,Stride< _1,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_32, _2>>, - Stride,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_32, _2>>, - Stride,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_32, _2, _2>>, - Stride,Stride< _1,_256,_1024>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_32, _2, _4>>, - Stride,Stride< _1,_256,_1024>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) 
{} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_0,Stride< _512, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride< _0, Stride<_512, Int<512 * 16>, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0, Stride<_32, _1>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride< _32, Stride<_32, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, - Stride<_32, _1>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32,Stride< _1,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0, Stride< _1,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32,Stride< _1,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0, Stride< _1,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32,Stride< _1,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32,Stride< _1,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) 
{} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32,Stride< _1,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride<_1,_128>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride< _8,Stride<_1,_128>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _8,Stride<_1,_128>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride< _8,Stride<_1,_128>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride<_1,_256,_128>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride< _8,Stride<_1,_256,_128>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride<_1,_512, _8,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride< _8,Stride<_1,_512, _8,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - // TODO(joe): Not convinced that changing from <_16, _256> should be required here - // but get_logical_layout assumes get<1,0>(layout.shape) is the type size - using SrcLayout = Layout>, - Stride< _0,Stride<_1, _8>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_256,Stride<_1, _8>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - // TODO(joe): Not convinced that changing from <_16, _256> should be required here - // but get_logical_layout assumes get<1,0>(layout.shape) is the type size - using SrcLayout = Layout>, - Stride< _0,Stride<_1, _8>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_128,Stride<_1, _8>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) 
{} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - // TODO(joe): Not convinced that changing from <_16, _256> should be required here - // but get_logical_layout assumes get<1,0>(layout.shape) is the type size - using SrcLayout = Layout>, - Stride< _0,Stride<_1,_8>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_64,Stride<_1,_8>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_256,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_256,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_256,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_0,Stride< _1,_512,_256,_1024>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_512,_256,_1024>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_0,Stride< _1,_512,_256,_1024>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16,Stride< _1,_512,_256,_1024>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0, Stride< _1,_16>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_128,Stride< _1,_16>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - // TODO(joe): Not convinced that changing from <_16, _256> should be required here - // but get_logical_layout assumes get<1,0>(layout.shape) is the type size - using SrcLayout = Layout>, - Stride<_0,Stride< _1,_16>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_256,Stride< _1,_16>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_LD_Unpack(args...) 
{} -}; - -// template -// struct Copy_Traits -// : XE_2D_LD_Unpack { -// // Logical thread id to thread idx -// using ThrID = Layout<_16>; -// // Map from (src-thr,src-val) to bit -// using SrcLayout = Layout, -// Stride< _0, _1>>; -// // Map from (dst-thr,dst-val) to bit -// using DstLayout = Layout, -// Stride<_32, _1>>; -// // Reference map from (thr,val) to bit -// using RefLayout = DstLayout; -// }; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_32>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_64,Stride< _1,_32>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_32>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_64,Stride< _1,_32>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _0,Stride< _1,_32>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_128,Stride< _1,_32>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout,Shape <_32, _16>>, - Stride,Stride< _1,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_32, _16>>, - Stride,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,_64>, - Stride, _1>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_64, _2>>, - Stride,Stride< _1,_64>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) 
{} -}; - -template -struct Copy_Traits_ - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout,Shape <_64, _4>>, - Stride,Stride< _1,_64>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_LD_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_ST_Unpack { - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride< _1,_128,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride< _0,Stride< _1,_128,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_ST_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_ST_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride< _8,_1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, - Stride< _0,_1>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_ST_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_ST_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _8,Stride<_1,_128>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, - Stride< _0,_1>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_ST_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_ST_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _8,Stride<_1,_128>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, - Stride< _0,_1>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -template -struct Copy_Traits_ - : XE_2D_ST_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride< _8,Stride<_1,_128>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, - Stride< _0,_1>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - using CopyInternalType = uint8_t; - - template - Copy_Traits_(ArgT... args) - : XE_2D_ST_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_ST_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = - Layout>, Stride<_0, Stride<_0, _1>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, - Stride< _0,_1>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - using CopyInternalType = uint8_t; - - template - Copy_Traits_(ArgT... args) - : XE_2D_ST_Unpack(args...) 
{} -}; - -template -struct Copy_Traits_ - : XE_2D_ST_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride<_16, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, - Stride< _0, _1>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_ST_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_ST_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride< _0,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_ST_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_ST_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride< _0,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_ST_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_ST_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16,Stride< _1,_256>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride< _0,Stride< _1,_256>>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_ST_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_ST_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, - Stride<_32, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, - Stride< _0, _1>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_ST_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_ST_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_32,Stride< _1,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride< _0,Stride< _1,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_ST_Unpack(args...) {} -}; - -template -struct Copy_Traits_ - : XE_2D_ST_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_32,Stride< _1,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride< _0,Stride< _1,_512>>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - - template - Copy_Traits_(ArgT... args) - : XE_2D_ST_Unpack(args...) 
{} -}; - -template -struct Copy_Traits_ - : XE_2D_ST_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_32,Stride< _1,_512>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride< _0,Stride< _1,_512>>>; // 0 here makes all threads in a warp get the same base address - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - - template - Copy_Traits_(ArgTs... args) - : XE_2D_ST_Unpack(args...) {} -}; - template struct Copy_Traits> { // Logical thread id to thread idx (one-thread) @@ -2535,159 +103,4 @@ struct Copy_Traits> { using RefLayout = SrcLayout; }; -// This is the Copy_Traits for Xe 2D Block copies, which inherits from `Copy_Traits_` and handles -// transposing the traits depending on the layout of the tensor in memory (MN vs. K-major). -// Since we can't SFINAE this (Copy_Traits has no Enable = void template param), we are -// obliged to define the actual Copy_Traits for each instruction individually -// TODO(codeplay): Revisit this after SPIR-V copy builtins added -// TODO(codeplay): Is it safe that we default to row-major when constructing Copy_Traits, -// or should we insist that the developer provide the stride? -#define COPY_TRAIT_LD_DEF(COPY_OP) \ -template \ -struct Copy_Traits : Copy_Traits_{ \ - using CopyOp = COPY_OP; \ - using Base = Copy_Traits_; \ - using XE_2D_LD_Unpack::is_matrix_B; \ - using typename Base::ThrID; \ - using BlockShape = std::conditional_t; \ - using SrcLayout = decltype(detail::get_logical_layout(typename Base::SrcLayout{}, typename Base::BlockShape{})); \ - using DstLayout = decltype(detail::get_logical_layout(typename Base::DstLayout{}, typename Base::BlockShape{})); \ - using RefLayout = DstLayout; \ - template \ - Copy_Traits(ArgTs... args) \ - : Copy_Traits_(args...) {} \ -}; - -#define COPY_TRAIT_ST_DEF(COPY_OP) \ -template \ -struct Copy_Traits : Copy_Traits_{ \ - using CopyOp = COPY_OP; \ - using Base = Copy_Traits_; \ - using XE_2D_ST_Unpack::is_matrix_B; \ - using typename Base::ThrID; \ - using BlockShape = std::conditional_t; \ - using SrcLayout = decltype(detail::get_logical_layout(typename Base::SrcLayout{}, typename Base::BlockShape{})); \ - using DstLayout = decltype(detail::get_logical_layout(typename Base::DstLayout{}, typename Base::BlockShape{})); \ - using RefLayout = SrcLayout; \ - template \ - Copy_Traits(ArgTs... args) \ - : Copy_Traits_(args...) 
{} \ -}; - -COPY_TRAIT_LD_DEF(XE_2D_Packed_U4x1x128_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U8x1x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U8x1x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x1x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x2x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x4x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x8x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U8x8x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x1x64_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x2x64_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x4x64_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x8x64_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U8x32x8_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_U8x32x4_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_U64x8x1_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_U64x8x2_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_U64x8x4_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x16x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x32x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x16x64_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x32x64_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U8x32x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U16x1x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U16x2x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U16x4x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U16x8x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U16x1x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U16x2x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U16x4x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U16x8x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U16x16x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_TF32x1x8_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_TF32x2x8_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_TF32x4x8_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_TF32x8x8_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_TF32x8x8_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_U32x1x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U32x2x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U32x4x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U32x8x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U8x32x16_LD_V) -COPY_TRAIT_LD_DEF(XE_2D_U8x32x32_LD_V) -COPY_TRAIT_LD_DEF(XE_2D_U8x32x64_LD_V) -COPY_TRAIT_LD_DEF(XE_2D_U8x16x32_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_U8x16x16_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_U8x16x8_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_U16x16x8_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_U32x16x2_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_U32x16x4_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_U32x16x8_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_U16x16x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U16x16x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U16x32x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U16x32x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U16x32x32_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_TF32x16x8_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_TF32x32x8_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_TF32x1x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_TF32x2x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_TF32x4x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_TF32x8x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U32x16x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U32x32x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U16x16x16_LD_V) -COPY_TRAIT_LD_DEF(XE_2D_U16x32x16_LD_V) -COPY_TRAIT_LD_DEF(XE_2D_U16x32x32_LD_V) -COPY_TRAIT_LD_DEF(XE_2D_U16x16x32_LD_V) -COPY_TRAIT_LD_DEF(XE_2D_U16x16x16_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_TF32x16x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_TF32x32x16_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U4x32x64_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U4x16x64_LD_N) -COPY_TRAIT_LD_DEF(XE_2D_U4x32x16_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_U4x16x8_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_U4x16x16_LD_T) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x1x64_LD_N::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x2x64_LD_N::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x4x64_LD_N::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x8x64_LD_N::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x16x64_LD_N::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x32x64_LD_N::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_U16x8x16_LD_N::PREFETCH) 
-COPY_TRAIT_LD_DEF(XE_2D_U16x1x32_LD_N::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_U16x2x32_LD_N::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_U16x4x32_LD_N::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_U16x8x32_LD_N::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_U8x32x16_LD_V::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_U32x16x8_LD_T::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_U16x16x16_LD_N::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_U16x32x16_LD_N::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_U16x16x32_LD_N::PREFETCH) -COPY_TRAIT_LD_DEF(XE_2D_U16x32x32_LD_N::PREFETCH) - -COPY_TRAIT_ST_DEF(XE_2D_U8x2x32_ST_N) -COPY_TRAIT_ST_DEF(XE_2D_U8x1x16_ST_N) -COPY_TRAIT_ST_DEF(XE_2D_U8x2x16_ST_N) -COPY_TRAIT_ST_DEF(XE_2D_U8x4x16_ST_N) -COPY_TRAIT_ST_DEF(XE_2D_U8x8x16_ST_N) -COPY_TRAIT_ST_DEF(XE_2D_U8x8x32_ST_N) -COPY_TRAIT_ST_DEF(XE_2D_U16x1x16_ST_N) -COPY_TRAIT_ST_DEF(XE_2D_U16x2x16_ST_N) -COPY_TRAIT_ST_DEF(XE_2D_U16x4x16_ST_N) -COPY_TRAIT_ST_DEF(XE_2D_U16x8x16_ST_N) -COPY_TRAIT_ST_DEF(XE_2D_U32x1x16_ST_N) -COPY_TRAIT_ST_DEF(XE_2D_U32x2x16_ST_N) -COPY_TRAIT_ST_DEF(XE_2D_U32x4x16_ST_N) -COPY_TRAIT_ST_DEF(XE_2D_U32x8x16_ST_N) - -// Generate the Xe coordinate tensor -template -CUTE_HOST_DEVICE constexpr -auto -get_xe_tensor(GShape const& g_shape) { - return make_coord_tensor(make_identity_layout(g_shape)); -} - } // end namespace cute diff --git a/include/cute/atom/copy_traits_xe_legacy.hpp b/include/cute/atom/copy_traits_xe_legacy.hpp new file mode 100644 index 0000000000..fffd557e7f --- /dev/null +++ b/include/cute/atom/copy_traits_xe_legacy.hpp @@ -0,0 +1,2628 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include +#include + +namespace cute { + +namespace detail { + +static constexpr auto subgroup_size = 16; + +// ========== size_of_inst_bits ========== +template +static constexpr auto size_of_inst_bits = sizeof_bits_v; + +template +static constexpr auto size_of_inst_bits> = sizeof_bits_v; + + +// ========== is_transpose_load ========== +template +static constexpr bool is_transpose_load = false; + +template +static constexpr bool is_transpose_load>> = T::is_transpose; + + +// ========== is_stride_leftmost ========== +template +static constexpr bool is_stride_leftmost = std::is_same_v<_1, decltype(get<0>(T{}))>; + +template +static constexpr bool is_stride_leftmost> = std::is_same_v<_1, decltype(get<0>(T{}.stride()))>; + +// Swap the Src or Dst Layout of a Copy_Traits if the logical/memory layouts differ +template +auto get_logical_layout(LayoutIn &&, BlockShape &&) { + static_assert(cute::rank(BlockShape{}) == 2, "Expected 2D BlockShape for XE_2D copy op."); + static_assert(cute::rank(LayoutIn{}) == 2, "Expected 2D LayoutIn for XE_2D copy op."); + if constexpr (!is_matrix_B) { + return LayoutIn{}; + } else { + // (16, (32, 2)) + // ^-- the size of an element in bits + static_assert(size(LayoutIn{}) % size(BlockShape{}) == 0); + constexpr int ElemBitSize = size(LayoutIn{}) / size(BlockShape{}); + // Construct a generic row-major layout of the relevant size + using RowMajorLayout = + decltype(make_ordered_layout(Shape, BlockShape>{}, Step<_0, Step<_2, _1>>{})); + // Compose with LayoutIn to produce the transposed Copy_Traits layout + return right_inverse(RowMajorLayout{}).compose(LayoutIn{}); + } +} +} // end namespace detail + +template +struct choose_prefetch_for_type { + static_assert(dependent_false<>, "Invalid prefetch"); +}; + +// U4 +template <> +struct choose_prefetch_for_type<4, 1> { + using Prefetch = XE_2D_U16x1x32_LD_N; +}; + +template <> +struct choose_prefetch_for_type<4, 2> { + using Prefetch = XE_2D_U16x2x32_LD_N; +}; + +template <> +struct choose_prefetch_for_type<4, 4> { + using Prefetch = XE_2D_U16x4x32_LD_N; +}; + +template <> +struct choose_prefetch_for_type<4, 8> { + using Prefetch = XE_2D_U16x8x32_LD_N; +}; + +template <> +struct choose_prefetch_for_type<4, 16> { + using Prefetch = XE_2D_U16x16x32_LD_N; +}; + +template <> +struct choose_prefetch_for_type<4, 32> { + using Prefetch = XE_2D_U16x32x32_LD_N; +}; + +// U8 +template <> +struct choose_prefetch_for_type<8, 1> { + using Prefetch = XE_2D_Packed_U8x1x64_LD_N; +}; + +template <> +struct choose_prefetch_for_type<8, 2> { + using Prefetch = XE_2D_Packed_U8x2x64_LD_N; +}; + +template <> +struct choose_prefetch_for_type<8, 4> { + using Prefetch = XE_2D_Packed_U8x4x64_LD_N; +}; + +template <> +struct choose_prefetch_for_type<8, 8> { + using Prefetch = XE_2D_Packed_U8x8x64_LD_N; +}; + +template <> +struct choose_prefetch_for_type<8, 16> { + using Prefetch = XE_2D_Packed_U8x16x64_LD_N; +}; + +template <> +struct choose_prefetch_for_type<8, 32> { + using Prefetch = XE_2D_Packed_U8x32x64_LD_N; +}; + +// U16 +template <> +struct choose_prefetch_for_type<16, 1> { + using Prefetch = XE_2D_U16x1x32_LD_N; +}; + +template <> +struct choose_prefetch_for_type<16, 2> { + using Prefetch = XE_2D_U16x2x32_LD_N; +}; + +template <> +struct choose_prefetch_for_type<16, 4> { + using Prefetch = XE_2D_U16x4x32_LD_N; +}; + +template <> +struct choose_prefetch_for_type<16, 8> { + using Prefetch = 
XE_2D_U16x8x32_LD_N; +}; + +template <> +struct choose_prefetch_for_type<16, 16> { + using Prefetch = XE_2D_U16x16x32_LD_N; +}; + +template <> +struct choose_prefetch_for_type<16, 32> { + using Prefetch = XE_2D_U16x32x32_LD_N; +}; + +// U32 +template <> +struct choose_prefetch_for_type<32, 1> { + using Prefetch = XE_2D_U32x1x16_LD_N; +}; + +template <> +struct choose_prefetch_for_type<32, 2> { + using Prefetch = XE_2D_U32x2x16_LD_N; +}; + +template <> +struct choose_prefetch_for_type<32, 4> { + using Prefetch = XE_2D_U32x4x16_LD_N; +}; + +template <> +struct choose_prefetch_for_type<32, 8> { + using Prefetch = XE_2D_U32x8x16_LD_N; +}; + +template <> +struct choose_prefetch_for_type<32, 16> { + using Prefetch = XE_2D_U32x16x16_LD_N; +}; + +template <> +struct choose_prefetch_for_type<32, 32> { + using Prefetch = XE_2D_U32x32x16_LD_N; +}; + +template +CUTE_HOST_DEVICE auto prefetch_selector(Tensor const& tensor) { + constexpr size_t cacheline_bytes = 64; + using dtype = typename Tensor::value_type; + constexpr size_t dtype_size_bits = sizeof_bits_v; + constexpr bool is_tensor_M_major = detail::is_stride_leftmost; + using CopyThreadShape = std::conditional_t, _1>, + Shape<_1, Int>>; + + constexpr int tile_contig_size = is_tensor_M_major ? size<0>(TileShape{}) : size<1>(TileShape{}); + constexpr int tile_non_contig_size = is_tensor_M_major ? size<1>(TileShape{}) : size<0>(TileShape{}); + + // block here is what is prefetched in one atom execution + // min(32,32)-> 32 (256, 32) -> 32 + static constexpr auto block_contig_size = cute::min(tile_contig_size, cacheline_bytes * sizeof_bits_v / sizeof_bits_v); + // A: 1 -> trans or B 256/32 = 8 + static constexpr auto nums_blocks_contig = ceil_div(tile_contig_size, block_contig_size); + + // layout of sub groups + // A shape<32,1> / trans or B shape<4,8> + constexpr int sgs_contig = cute::gcd(Num_SGs, nums_blocks_contig); + constexpr int sgs_non_contig = Num_SGs / sgs_contig; + + constexpr auto block_non_contig_size = tile_non_contig_size / sgs_non_contig; + + using PrefetchTilingLayout = std::conditional_t, Int>, Int>, + Stride>, Int>>, + Layout, Shape, Int>>, + Stride, Stride<_1, Int>>> + >; + + using PrefetchOp = typename choose_prefetch_for_type::Prefetch; + using PrefetchTraits = Copy_Traits; + using PrefetchAtom = Copy_Atom; + using Scalar = Int / dtype_size_bits)>; + using ScalarLayout = std::conditional_t>, + Layout>>; + using ScalarPrefetchShape = decltype(product_each(raked_product(ScalarLayout{}, + Layout{}).shape())); + using PrefetchValLayout = decltype(make_layout(shape_div(ScalarPrefetchShape{}, CopyThreadShape{}))); + return make_tiled_copy(PrefetchAtom{}.with(tensor), PrefetchTilingLayout{}, PrefetchValLayout{}); + +} + +template +CUTE_HOST_DEVICE auto prefetch_selector(TiledCopy const& tiled_copy) { + using Tiled_Copy = TiledCopy; + constexpr int subgroup_size = size(typename Tiled_Copy::Traits_LD_t::ThrID{}); + int M, N; + if constexpr (Tiled_Copy::is_tensor_M_major) { + M = tiled_copy.width; + N = tiled_copy.height; + } else{ + M = tiled_copy.height; + N = tiled_copy.width; + } + // L is not used in prefetch_selector and we do not have the correct value here. Just set it to some arbitrary value. 
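+  // Rebuild an (M, N, L) global-memory tensor from the 2D block descriptor stored in the tiled
+  // copy (base_ptr, pitch, stride_l): the contiguous mode gets stride _1 and the other mode uses
+  // the pitch, matching is_tensor_M_major above; the tensor is then forwarded to the tensor-based
+  // prefetch_selector overload defined earlier in this file.
+  // Illustrative call site for that tensor-based overload (a sketch only; the tile shape value and
+  // the gmem_tensor argument are placeholders, not names from this file):
+  //   auto prefetch_copy = prefetch_selector<Shape<_256, _32>, Num_SGs>(gmem_tensor);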
+ int L = 1; + auto data = make_gmem_ptr(static_cast(tiled_copy.base_ptr)); + auto shape = make_shape(M, N, L); + auto stride = [=](){ + if constexpr (Tiled_Copy::is_tensor_M_major){ + return make_stride(_1{}, tiled_copy.pitch, tiled_copy.stride_l); + }else{ + return make_stride(tiled_copy.pitch, _1{}, tiled_copy.stride_l); + } + }(); + auto tensor = make_tensor(data, make_layout(shape, stride)); + return cute::prefetch_selector(tensor); +} + + +template , int64_t>> +struct XE_2D_LD_Unpack { + using BlockShape = typename CopyOp::BlockShape; // this is not the same as Traits_LD_t::BlockShape iff is_matrix_B + using Traits_LD_t = Copy_Traits; + static constexpr auto stride_rank = rank(StrideOrTensor{}); + static_assert(stride_rank == 2 || stride_rank == 3); + + // Assume LD_T/LD_N will be used for column/row major matrices respectively + static constexpr bool is_transpose_copy = detail::is_transpose_load; + + // We need to reverse some parameters becasue intel xe 2d copy intrinsic always assume the matrix use (M, N):(N, 1) layout + // M-major if we label the matrix shape (M,N,L). M-major for matrix A or C is col-major. For matrix B it is row-major. + static constexpr bool is_tensor_M_major = detail::is_stride_leftmost; + + // For matrix B cute internally has transposed representation compared to other matrices, for cute its shape is (N,K) + // Intel copy instructions, on the other hand follow blas convention, where matrix B has shape (K,N) + static constexpr bool is_matrix_B = is_tensor_M_major ^ is_transpose_copy; + + using CopyThreadShape = Shape<_1, Int>; + // we can not use Traits_LD_t::BlockShape as this is a parent class of Traits_LD_t, so that would be recursion. Recalculate it instead. + using DefaultValLayout = decltype(make_layout(shape_div(std::conditional_t{}, CopyThreadShape{}))); + + template + using DefaultTiledCopy = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, DefaultValLayout{})); + + // 2d copy parameters + const void *base_ptr; + uint32_t width; + uint32_t height; + uint32_t pitch; + uint32_t stride_l = 0; + + + XE_2D_LD_Unpack(const void *ptr, uint32_t y, + uint32_t x, uint32_t p = 0) : base_ptr(ptr) { + if constexpr (is_tensor_M_major) { + width = y; + height = x; + } + else { + width = x; + height = y; + } + + pitch = (p == 0 ? 
width : p); + } + + template + XE_2D_LD_Unpack(Tensor const &tensor) { + base_ptr = raw_pointer_cast(tensor.data()); + + if constexpr (is_tensor_M_major) + { + width = size<0>(tensor.shape()); + height = size<1>(tensor.shape()); + pitch = size<1>(tensor.stride()); + } + else + { + width = size<1>(tensor.shape()); + height = size<0>(tensor.shape()); + pitch = size<0>(tensor.stride()); + } + + if constexpr (stride_rank == 3) { + stride_l = size<2>(tensor.stride()); + } + } + + XE_2D_LD_Unpack(Traits_LD_t const &traits) : base_ptr(traits.base_ptr), + width(traits.width), height(traits.height), pitch(traits.pitch), + stride_l(traits.stride_l) {} + + XE_2D_LD_Unpack() {} + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Traits_LD_t const &traits, Tensor const &src, + Tensor &dst) { + using dtype = typename Tensor::value_type; + constexpr int dtype_bits = sizeof_bits_v; + + static_assert(is_rmem::value); + static_assert(size(SLayout{}) * dtype_bits == size<1>(typename Traits_LD_t::SrcLayout{}), + "Src tensor size does not match copy atom size."); + static_assert(size(DLayout{}) * dtype_bits == size<1>(typename Traits_LD_t::DstLayout{}), + "Dst tensor size does not match copy atom size."); + + dtype *base_addr = (dtype *)traits.base_ptr; + + auto [m, n, l] = src.data().coord_; + int x = is_tensor_M_major ? m : n; + int y = is_tensor_M_major ? n : m; + + constexpr auto inst_size_bits = detail::size_of_inst_bits; + + CopyOp::copy(((uint8_t*)base_addr) + static_cast(l) * traits.stride_l * sizeof_bits_v / 8, + (traits.width * sizeof_bits_v) / sizeof_bits_v, traits.height, + (traits.pitch * sizeof_bits_v) / sizeof_bits_v, + intel::coord_t{(int)(x * sizeof_bits_v / inst_size_bits), y}, + raw_pointer_cast(&((&*dst.data())[0]))); + } + + template + CUTE_HOST_DEVICE friend constexpr void + prefetch(Copy_Atom, CopyType> const& atom, + Tensor const& src) { + using dtype = typename Copy_Atom, CopyType>::ValType; + + static_assert(detail::has_prefetch); + static_assert(size(SLayout{}) * sizeof_bits_v == size<1>(typename Copy_Atom, CopyType>::SrcLayout{}), + "Src tensor size does not match copy atom for prefetch size"); + + dtype *base_addr = (dtype *)atom.base_ptr; + + auto [m, n, l] = src.data().coord_; + + int x = is_tensor_M_major ? m : n; + int y = is_tensor_M_major ? n : m; + + constexpr auto inst_size_bits = detail::size_of_inst_bits; + + CopyOp::PREFETCH::copy(((uint8_t*)(base_addr)) + static_cast(l) * atom.stride_l * sizeof_bits_v / 8, + (atom.width * sizeof_bits_v) / sizeof_bits_v, atom.height, + (atom.pitch * sizeof_bits_v) / sizeof_bits_v, + intel::coord_t{(int)(x * sizeof_bits_v / inst_size_bits), y}); + } + + template + static constexpr auto with(Tensor const &tensor) { + return Traits_LD_t{tensor}; + } + + template + static constexpr auto with(T0 && arg0, T1 && arg1, Ts&&... args) { + return Traits_LD_t{arg0, arg1, args...}; + } +}; + +template , int64_t>> struct XE_2D_ST_Unpack { + using Traits_ST_t = Copy_Traits; + using BlockShape = typename CopyOp::BlockShape; + + static constexpr auto stride_rank = rank(StrideOrTensor{}); + static_assert(stride_rank == 2 || stride_rank == 3); + + static constexpr bool is_matrix_B = false; + + const void *base_ptr; + uint32_t width; + uint32_t height; + uint32_t pitch; + uint32_t stride_l = 0; + + XE_2D_ST_Unpack(const void *ptr, uint32_t y, + uint32_t x, uint32_t p = 0) : base_ptr(ptr) { + width = x; + height = y; + pitch = (p == 0 ? 
width : p); + } + + template + XE_2D_ST_Unpack(Tensor const &tensor) { + base_ptr = tensor.data().get(); + width = size<1>(tensor.shape()); + height = size<0>(tensor.shape()); + pitch = size<0>(tensor.stride()); + + if constexpr (stride_rank == 3) { + stride_l = size<2>(tensor.stride()); + } + } + + XE_2D_ST_Unpack(Traits_ST_t const &traits) : base_ptr(traits.base_ptr), + width(traits.width), height(traits.height), pitch(traits.pitch), + stride_l(traits.stride_l) {} + + XE_2D_ST_Unpack() {} + + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Traits_ST_t const &traits, Tensor const &src, + Tensor &dst) { + + using dtype = typename Tensor::value_type; + constexpr int dtype_bits = sizeof_bits_v; + + static_assert(is_rmem::value); + static_assert(size(SLayout{}) * dtype_bits == size<1>(typename Traits_ST_t::SrcLayout{}), + "Src tensor size does not match copy atom size."); + static_assert(size(DLayout{}) * dtype_bits == size<1>(typename Traits_ST_t::DstLayout{}), + "Dst tensor size does not match copy atom size."); + + dtype *base_addr = (dtype *)traits.base_ptr; + + auto [m, n, l] = dst.data().coord_; + + CopyOp::copy(((uint8_t*)(base_addr)) + static_cast(l) * traits.stride_l * sizeof_bits_v / 8, + traits.width * sizeof(dtype), traits.height, + traits.pitch * sizeof(dtype), + intel::coord_t{(int)n, (int)m}, &*src.data()); + } + + template + static constexpr auto with(Tensor const &tensor) { + return Traits_ST_t{tensor}; + } + + template + static constexpr auto with(T0 && arg0, T1 && arg1, Ts&&... args) { + return Traits_ST_t{arg0, arg1, args...}; + } + +}; + +template +CUTE_HOST_DEVICE constexpr auto make_fragment_layout(TiledCopy &tiled_copy, + TLShape &&fragment_top_level_shape) { + // Shapes are reversed for col major case between registers and global memory, + // so all variables contain in their name whether they refer to the shape in registers or in global memory + + // TODO(Codeplay): reverse values in 2d (U8) MMA atoms instead + constexpr auto mma_atom_regs_shape = cute::reverse(get<0>(TLShape{})); + using MmaValsShapeRegs2d = std::conditional_t(mma_atom_regs_shape, _1{})), + decltype(append<2>(mma_atom_regs_shape, _1{}))>; + + using ThreadLayout_ = Shape, _1>; + using ThreadLayoutRegs = std::conditional_t; + using BlockShapeRegs = typename TiledCopy::BlockShape; + using TotalMmaAtomItersRegs = decltype(select<1,2>(TLShape{})); + + using CopyValsShapeRegs = decltype(shape_div(BlockShapeRegs{}, ThreadLayoutRegs{})); + // This case would need to rearrange data in registers between copy and mma calls + static_assert(get<0>(CopyValsShapeRegs{}) >= get<0>(MmaValsShapeRegs2d{}) || + get<1>(CopyValsShapeRegs{}) <= get<1>(MmaValsShapeRegs2d{}), + "It is not currently supported to have MMA atom be bigger than copy atom in one dimension and smaller in other dimension!"); + using MmaItersInCopyRegs = decltype(ceil_div(CopyValsShapeRegs{}, MmaValsShapeRegs2d{})); + using CopyItersRegs = decltype(shape_div(TotalMmaAtomItersRegs{}, MmaItersInCopyRegs{})); + + auto order = std::conditional_t, Step<_3, _5>, Step<_2, _4>>, + Step, Step<_2, _4>, Step<_3, _5>>>{}; + + auto res = make_ordered_layout( + prepend(cute::zip(MmaItersInCopyRegs{}, CopyItersRegs{}), MmaValsShapeRegs2d{}), + order); + + static_assert(size(res) == size(TLShape{}), "Internal eror in make_fragment_layout()."); + return res; +}; + +// clang-format off + +template +struct Copy_Traits_{ + static_assert(cute::dependent_false, "Copy_Traits_ not defined for this CopyOp"); +}; + +template +struct Copy_Traits_ + : 
XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + // mode 0 16:8 This will show there are 16 thread ehrtr each thread with stride of value with 8 bits away from the adjusent thread + // mode 1: <_8>:< _1> This says each thread will get 1 element each of them 8 bits. + using DstLayout = Layout, + Stride<_8, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0, Stride <_1, _128, _256>>>; + // Map from (dst-thr,dst-val) to bit + // mode 0 16:8 This will show there are 16 thread ehrtr each thread with stride of value with 8 bits away from the adjusent thread + // mode 1: <_8, _2, _1>:< _1, _128, _256> This says each thread will get 2x1 element + // each of them 8 bits. The stired shows each thread jumps 16x8 bits for the next element in the block and 16x8x2 for the next row in the block + using DstLayout = Layout>, + Stride<_8, Stride < _1, _128, _256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_16, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_128,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_128,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_128,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_128,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_128,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_128,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_128,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, // 16 thread with stride of value with 8 bits away from second thread + // Mode1 :The second parameter shows the jump for each thread int bits//the third prameter is with of the row in bits(32x8 bits) + Stride<_8,Stride< _1,_128,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_8,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_8,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_128,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_8,Stride< _1,_128,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_8,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_8,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1, _4>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1, _4>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride< _1, _4, _16>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1, _4, _16>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride <_1, _4, _16>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride<_1, _4, _16>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride< _1, _4>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride<_1, _4>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride< _1,_64>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_64>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride< _1,_64>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_64>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride<_1,_8,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride<_1,_8,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = cute::intel::ushort; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_512,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_512,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride<_1,_8,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride<_1,_8,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = cute::intel::ushort; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (dst-thr,dst-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_512,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_512,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride<_1,_8,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride<_1,_8,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = cute::intel::ushort; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_512,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride<_1,_8,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride<_1,_8,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = cute::intel::ushort; +}; + + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_512,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_512,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_512,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0, Stride<_1, _8, _16>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_256,Stride<_1, _8, _16>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0, Stride<_1, _8, _16>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_256,Stride<_1, _8, _16>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0, Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1, _16,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1, _16,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = cute::intel::ushort; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = cute::intel::ushort; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0, Stride< _1,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16, Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0, Stride<_1, _16, _512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32, Stride<_1, _16, _512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = cute::intel::ushort; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,_32>, + Stride, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,_32>, + Stride, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _4>>, + Stride,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _8>>, + Stride,Stride< _1, _512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _16>>, + Stride,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2, _2>>, + Stride,Stride< _1,_256,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2, _4>>, + Stride,Stride< _1,_256,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_0,Stride< _512, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _0, Stride<_512, Int<512 * 16>, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0, Stride<_32, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _32, Stride<_32, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0, Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0, Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride<_1,_128>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _8,Stride<_1,_128>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _8,Stride<_1,_128>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _8,Stride<_1,_128>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride<_1,_256,_128>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _8,Stride<_1,_256,_128>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride<_1,_512, _8,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _8,Stride<_1,_512, _8,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + // TODO(joe): Not convinced that changing from <_16, _256> should be required here + // but get_logical_layout assumes get<1,0>(layout.shape) is the type size + using SrcLayout = Layout>, + Stride< _0,Stride<_1, _8>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_256,Stride<_1, _8>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + // TODO(joe): Not convinced that changing from <_16, _256> should be required here + // but get_logical_layout assumes get<1,0>(layout.shape) is the type size + using SrcLayout = Layout>, + Stride< _0,Stride<_1, _8>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_128,Stride<_1, _8>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + // TODO(joe): Not convinced that changing from <_16, _256> should be required here + // but get_logical_layout assumes get<1,0>(layout.shape) is the type size + using SrcLayout = Layout>, + Stride< _0,Stride<_1,_8>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_64,Stride<_1,_8>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride< _1,_512,_256,_1024>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_512,_256,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_0,Stride< _1,_512,_256,_1024>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_512,_256,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0, Stride< _1,_16>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_128,Stride< _1,_16>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + // TODO(joe): Not convinced that changing from <_16, _256> should be required here + // but get_logical_layout assumes get<1,0>(layout.shape) is the type size + using SrcLayout = Layout>, + Stride<_0,Stride< _1,_16>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_256,Stride< _1,_16>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +// template +// struct Copy_Traits +// : XE_2D_LD_Unpack { +// // Logical thread id to thread idx +// using ThrID = Layout<_16>; +// // Map from (src-thr,src-val) to bit +// using SrcLayout = Layout, +// Stride< _0, _1>>; +// // Map from (dst-thr,dst-val) to bit +// using DstLayout = Layout, +// Stride<_32, _1>>; +// // Reference map from (thr,val) to bit +// using RefLayout = DstLayout; +// }; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_32>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_64,Stride< _1,_32>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_32>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_64,Stride< _1,_32>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _0,Stride< _1,_32>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_128,Stride< _1,_32>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,Shape <_32, _16>>, + Stride,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _16>>, + Stride,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,_64>, + Stride, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_64, _2>>, + Stride,Stride< _1,_64>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_64, _4>>, + Stride,Stride< _1,_64>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_128,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _0,Stride< _1,_128,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _8,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _8,Stride<_1,_128>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_ST_Unpack(args...) 
{} +}; + +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _8,Stride<_1,_128>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _8,Stride<_1,_128>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + using CopyInternalType = uint8_t; + + template + Copy_Traits_(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = + Layout>, Stride<_0, Stride<_0, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + using CopyInternalType = uint8_t; + + template + Copy_Traits_(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_16, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _0,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _0,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _0,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_ST_Unpack(args...) 
{} +}; + +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_32, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _0,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _0,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits_ + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _0,Stride< _1,_512>>>; // 0 here makes all threads in a warp get the same base address + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits_(ArgTs... args) + : XE_2D_ST_Unpack(args...) {} +}; + +// This is the Copy_Traits for Xe 2D Block copies, which inherits from `Copy_Traits_` and handles +// transposing the traits depending on the layout of the tensor in memory (MN vs. K-major). +// Since we can't SFINAE this (Copy_Traits has no Enable = void template param), we are +// obliged to define the actual Copy_Traits for each instruction individually +// TODO(codeplay): Revisit this after SPIR-V copy builtins added +// TODO(codeplay): Is it safe that we default to row-major when constructing Copy_Traits, +// or should we insist that the developer provide the stride? +#define COPY_TRAIT_LD_DEF(COPY_OP) \ +template \ +struct Copy_Traits : Copy_Traits_{ \ + using CopyOp = COPY_OP; \ + using Base = Copy_Traits_; \ + using XE_2D_LD_Unpack::is_matrix_B; \ + using typename Base::ThrID; \ + using BlockShape = std::conditional_t; \ + using SrcLayout = decltype(detail::get_logical_layout(typename Base::SrcLayout{}, typename Base::BlockShape{})); \ + using DstLayout = decltype(detail::get_logical_layout(typename Base::DstLayout{}, typename Base::BlockShape{})); \ + using RefLayout = DstLayout; \ + template \ + Copy_Traits(ArgTs... args) \ + : Copy_Traits_(args...) 
{} \ +}; + +#define COPY_TRAIT_ST_DEF(COPY_OP) \ +template \ +struct Copy_Traits : Copy_Traits_{ \ + using CopyOp = COPY_OP; \ + using Base = Copy_Traits_; \ + using XE_2D_ST_Unpack::is_matrix_B; \ + using typename Base::ThrID; \ + using BlockShape = std::conditional_t; \ + using SrcLayout = decltype(detail::get_logical_layout(typename Base::SrcLayout{}, typename Base::BlockShape{})); \ + using DstLayout = decltype(detail::get_logical_layout(typename Base::DstLayout{}, typename Base::BlockShape{})); \ + using RefLayout = SrcLayout; \ + template \ + Copy_Traits(ArgTs... args) \ + : Copy_Traits_(args...) {} \ +}; + +COPY_TRAIT_LD_DEF(XE_2D_Packed_U4x1x128_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U8x1x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U8x1x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x1x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x2x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x4x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x8x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U8x8x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x1x64_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x2x64_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x4x64_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x8x64_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U8x32x8_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U8x32x4_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U64x8x1_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U64x8x2_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U64x8x4_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x16x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x32x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x16x64_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x32x64_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U8x32x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U16x1x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U16x2x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U16x4x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U16x8x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U16x1x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U16x2x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U16x4x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U16x8x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U16x16x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_TF32x1x8_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_TF32x2x8_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_TF32x4x8_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_TF32x8x8_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_TF32x8x8_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U32x1x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U32x2x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U32x4x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U32x8x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U8x32x16_LD_V) +COPY_TRAIT_LD_DEF(XE_2D_U8x32x32_LD_V) +COPY_TRAIT_LD_DEF(XE_2D_U8x32x64_LD_V) +COPY_TRAIT_LD_DEF(XE_2D_U8x16x32_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U8x16x16_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U8x16x8_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U16x16x8_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U32x16x2_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U32x16x4_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U32x16x8_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U16x16x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U16x16x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U16x32x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U16x32x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U16x32x32_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_TF32x16x8_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_TF32x32x8_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_TF32x1x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_TF32x2x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_TF32x4x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_TF32x8x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U32x16x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U32x32x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U16x16x16_LD_V) +COPY_TRAIT_LD_DEF(XE_2D_U16x32x16_LD_V) +COPY_TRAIT_LD_DEF(XE_2D_U16x32x32_LD_V) +COPY_TRAIT_LD_DEF(XE_2D_U16x16x32_LD_V) +COPY_TRAIT_LD_DEF(XE_2D_U16x16x16_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_TF32x16x16_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_TF32x32x16_LD_N) 
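+// Usage sketch for the Copy_Traits generated above (a minimal illustration, not taken from this
+// patch; the instruction, element type and thread/value layouts are placeholder choices, and
+// binding the atom to a concrete global-memory tensor via the forwarded constructor arguments
+// is omitted):
+//
+//   using TraitsA = Copy_Traits<XE_2D_U16x16x16_LD_N>;        // Src/DstLayout here are the
+//                                                             // get_logical_layout-rewritten maps
+//   using AtomA   = Copy_Atom<TraitsA, cute::half_t>;
+//   auto tiled_a  = make_tiled_copy(AtomA{},
+//                                   Layout<Shape<_16,_1>>{},  // placeholder thread layout
+//                                   Layout<Shape< _1,_1>>{}); // placeholder value layout
+//
+// The generated specialization only forwards its constructor arguments to Copy_Traits_, which is
+// where the tensor's in-memory layout (MN- vs K-major, see the comment above the macro) is taken
+// into account.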
+COPY_TRAIT_LD_DEF(XE_2D_U4x32x64_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U4x16x64_LD_N) +COPY_TRAIT_LD_DEF(XE_2D_U4x32x16_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U4x16x8_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U4x16x16_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x1x64_LD_N::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x2x64_LD_N::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x4x64_LD_N::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x8x64_LD_N::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x16x64_LD_N::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_Packed_U8x32x64_LD_N::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_U16x8x16_LD_N::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_U16x1x32_LD_N::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_U16x2x32_LD_N::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_U16x4x32_LD_N::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_U16x8x32_LD_N::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_U8x32x16_LD_V::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_U32x16x8_LD_T::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_U16x16x16_LD_N::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_U16x32x16_LD_N::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_U16x16x32_LD_N::PREFETCH) +COPY_TRAIT_LD_DEF(XE_2D_U16x32x32_LD_N::PREFETCH) + +COPY_TRAIT_ST_DEF(XE_2D_U8x2x32_ST_N) +COPY_TRAIT_ST_DEF(XE_2D_U8x1x16_ST_N) +COPY_TRAIT_ST_DEF(XE_2D_U8x2x16_ST_N) +COPY_TRAIT_ST_DEF(XE_2D_U8x4x16_ST_N) +COPY_TRAIT_ST_DEF(XE_2D_U8x8x16_ST_N) +COPY_TRAIT_ST_DEF(XE_2D_U8x8x32_ST_N) +COPY_TRAIT_ST_DEF(XE_2D_U16x1x16_ST_N) +COPY_TRAIT_ST_DEF(XE_2D_U16x2x16_ST_N) +COPY_TRAIT_ST_DEF(XE_2D_U16x4x16_ST_N) +COPY_TRAIT_ST_DEF(XE_2D_U16x8x16_ST_N) +COPY_TRAIT_ST_DEF(XE_2D_U32x1x16_ST_N) +COPY_TRAIT_ST_DEF(XE_2D_U32x2x16_ST_N) +COPY_TRAIT_ST_DEF(XE_2D_U32x4x16_ST_N) +COPY_TRAIT_ST_DEF(XE_2D_U32x8x16_ST_N) + +// Generate the Xe coordinate tensor +template +CUTE_HOST_DEVICE constexpr +auto +get_xe_tensor(GShape const& g_shape) { + return make_coord_tensor(make_identity_layout(g_shape)); +} + +} // end namespace cute diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 789c90a1a1..c28d9b1d80 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -36,10 +37,6 @@ #include #include -#if defined(CUTLASS_ENABLE_SYCL) -#include -#endif - namespace cute { template @@ -561,10 +558,10 @@ make_tiled_mma(MMA_Op const&, // media/docs/cute/0t_mma_atom.md#tiledmmas to construct a scatter // permutation which ensures hardware operates on contiguous // chunks of the TiledMMA. The docs describe how the Layout -// implies a repetition of the atom across additional hardware. +// implies a repetition of the atom across additional hardware. // Permutations, in the simplest form, imply additional iterations // to cover a larger tile (i.e. CTALayout) than the hardware can handle -// at once. +// at once. // // Consider an example for Xe hardware: // using TiledMma = @@ -572,13 +569,13 @@ make_tiled_mma(MMA_Op const&, // Layout, Stride<_4, _1, _0>>, // Tile<_256, _256, _32>>; // -// This MMA_Atom is performed by a whole warp and operates on an 8x16x16 chunk. +// This MMA_Atom is performed by a whole warp and operates on an 8x16x16 chunk. // The second arg (Layout) defines a repetition of the atom across *additional warps*, // i.e. iterating across more hardware. 
The third arg (Tile) defines a repetition of this // MMA across *additional values*. For this example, in the M dimension, the atom produces // 8 values of C, the hardware repetition (8) scales this up to 64 values in M, and the -// requested permutation (256) scales this up to 256 values (implying 4 iterations in the -// M direction). +// requested permutation (256) scales this up to 256 values (implying 4 iterations in the +// M direction). // // By cute convention, the repetition of the atom across hardware is the inner // iteration, while the repetition across values is the outer. We can use a more complex @@ -595,11 +592,11 @@ make_tiled_mma(MMA_Op const&, // Layout, Stride<_1, _64, _16>>, // Permutation on N // _32>>; // K unpermuted // -// Consider only the M permutation (each mode's permutation is independent and in this +// Consider only the M permutation (each mode's permutation is independent and in this // example the M & N permutations are similar). This permutation maintains blocks of 8 // contiguous values from the canonical tiling (mode 0 is 8:1). -// It scatters 8 of these blocks of 8 to a spacing of 32 values (mode 1 is 8:32), leaving -// a 'gap' of 24. These gaps of 24 are filled by repeating the preceding pattern 4 times, +// It scatters 8 of these blocks of 8 to a spacing of 32 values (mode 1 is 8:32), leaving +// a 'gap' of 24. These gaps of 24 are filled by repeating the preceding pattern 4 times, // at a spacing of 8 values (mode 2 is 4:8). // In this manner, the tiling has been permuted so that the values handled by each thread are // closer together. @@ -780,6 +777,6 @@ print(ThrMMA const& thr_mma) #include #if defined(CUTLASS_ENABLE_SYCL) -#include +#include #endif //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/mma_traits_xe.hpp b/include/cute/atom/mma_traits_xe_legacy.hpp similarity index 99% rename from include/cute/atom/mma_traits_xe.hpp rename to include/cute/atom/mma_traits_xe_legacy.hpp index f99e171954..aafabd53b4 100644 --- a/include/cute/atom/mma_traits_xe.hpp +++ b/include/cute/atom/mma_traits_xe_legacy.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -28,9 +29,10 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ + #pragma once -#include +#include #include #include From 326bbdb3db985bf2a8bf21b6be18a4038f32b880 Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Mon, 18 Aug 2025 22:15:44 -0700 Subject: [PATCH 07/16] [CuTe] Additional SYCL vector helpers --- include/cute/util/sycl_vec.hpp | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/include/cute/util/sycl_vec.hpp b/include/cute/util/sycl_vec.hpp index 5c700398ab..bc0937a0ec 100644 --- a/include/cute/util/sycl_vec.hpp +++ b/include/cute/util/sycl_vec.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -30,19 +31,33 @@ **************************************************************************************************/ #pragma once -// fwd declare OCL function and OCL types -#include //for sycl::vec +#include // sycl::vec +#include "cute/numeric/int.hpp" // int_byte_t +#include "cute/numeric/numeric_types.hpp" // bfloat16_t, half_t, etc. -namespace cute -{ -namespace intel +namespace cute::intel { + +constexpr int sg_size = 16; +using _SGSize = Int; + #ifdef __SYCL_DEVICE_ONLY__ -template using vector_t = T __attribute__((ext_vector_type(N))); +template struct vector_element_helper { using type = T; }; +template <> struct vector_element_helper { using type = uint32_t; }; +template <> struct vector_element_helper { using type = uint16_t; }; +template <> struct vector_element_helper { using type = uint16_t; }; +template <> struct vector_element_helper { using type = uint8_t; }; +template <> struct vector_element_helper { using type = uint8_t; }; + +template using vector_t = typename vector_element_helper::type __attribute__((ext_vector_type(N))); #else template using vector_t = sycl::marray; #endif +template +using storage_vector_t = vector_t)>, + bits / bytes_to_bits(bits_to_bytes(sizeof_bits_v))>; + using uint = unsigned int; using ushort = unsigned short; using ulong = unsigned long; @@ -90,5 +105,4 @@ using ulong2 = vector_t; using ulong4 = vector_t; using coord_t = vector_t; -} // namespace intel end -} // namespace cute end +} // namespace cute::intel From 9220b7fed53e61efbde759b6551dace5d2d3dbd5 Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Mon, 18 Aug 2025 22:16:40 -0700 Subject: [PATCH 08/16] [CuTe] Introduce new MMA atoms --- include/cute/arch/mma_xe.hpp | 155 ++++++++++++++++++++++++++++ include/cute/atom/mma_atom.hpp | 1 + include/cute/atom/mma_traits_xe.hpp | 75 ++++++++++++++ 3 files changed, 231 insertions(+) create mode 100644 include/cute/arch/mma_xe.hpp create mode 100644 include/cute/atom/mma_traits_xe.hpp diff --git a/include/cute/arch/mma_xe.hpp b/include/cute/arch/mma_xe.hpp new file mode 100644 index 0000000000..20e3ce19a7 --- /dev/null +++ b/include/cute/arch/mma_xe.hpp @@ -0,0 +1,155 @@ +/*************************************************************************************************** +* Copyright (C) 2025 Intel Corporation, All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ + +#pragma once + +#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_INTEL_TARGET) +#define CUTE_ARCH_MMA_XE_ENABLED +#endif + +#include +#include +#include + +namespace cute { + +template +struct XE_DPAS_TT; + +template +struct XE_DPAS_TT_Base +{ + static constexpr int K = 256 / cute::max(sizeof_bits_v, sizeof_bits_v); + + using DVector = intel::vector_t; + using AVector = intel::vector_t; + using BVector = intel::vector_t; + using CVector = intel::vector_t; + + using DRegisters = DVector[1]; + using ARegisters = AVector[1]; + using BRegisters = BVector[1]; + using CRegisters = CVector[1]; +}; + +namespace dpas_type { + +using f = float; +using tf32 = tfloat32_t; +using bf = bfloat16_t; +using hf = half_t; +using ud = uint32_t; +using d = int32_t; +using u8 = uint8_t; +using s8 = int8_t; +using u4 = uint4_t; +using s4 = int4_t; + +}; /* namespace dpas_type */ + +#ifdef CUTE_ARCH_MMA_XE_ENABLED + +#define CUTE_DECLARE_XE_DPAS_TT(TD, TA, TB, TC) \ +template struct XE_DPAS_TT \ + : public XE_DPAS_TT_Base { \ + using Base = XE_DPAS_TT_Base; \ + using AVector = typename Base::AVector; \ + using BVector = typename Base::BVector; \ + using CVector = typename Base::CVector; \ + using DVector = typename Base::DVector; \ + template \ + CUTE_DEVICE static void \ + fma(DVector& d, AVector const& a, BVector const& b, CVector_ const& c) { \ + if constexpr (std::is_same_v) { \ + d = c; \ + asm ( \ + "{\n" \ + ".decl DST v_type=G type=" #TD " num_elts=%5 alias=<%0,0>\n" \ + ".decl SRC1_UD v_type=G type=UD num_elts=128 alias=<%2,0>\n" \ + ".decl SRC2_UD v_type=G type=UD num_elts=%4 alias=<%1,0>\n" \ + "dpas." #TB "." #TA ".8.%3 (M1, 16) DST.0 DST.0 SRC1_UD.0 SRC2_UD(0,0)\n" \ + "}\n" \ + : "+rw"(d) : "rw"(a), "rw"(b), "P"(M), "P"(M*8), "P"(M*16) \ + ); \ + } else { \ + asm ( \ + "{\n" \ + ".decl DST v_type=G type=" #TD " num_elts=%6 alias=<%0,0>\n" \ + ".decl SRC0 v_type=G type=" #TC " num_elts=%6 alias=<%3,0>\n" \ + ".decl SRC1_UD v_type=G type=UD num_elts=128 alias=<%2,0>\n" \ + ".decl SRC2_UD v_type=G type=UD num_elts=%5 alias=<%1,0>\n" \ + "dpas." #TB "." 
#TA ".8.%4 (M1, 16) DST.0 SRC0.0 SRC1_UD.0 SRC2_UD(0,0)\n" \ + "}\n" \ + : "=rw"(d) : "rw"(a), "rw"(b), "rw"(c), "P"(M), "P"(M*8), "P"(M*16) \ + ); \ + } \ + } \ +}; + +#else /* !defined(CUTE_ARCH_MMA_XE_ENABLED) */ + +#define CUTE_DECLARE_XE_DPAS_TT(TD, TA, TB, TC) \ +template struct XE_DPAS_TT \ + : public XE_DPAS_TT_Base { \ + using Base = XE_DPAS_TT_Base; \ + using AVector = typename Base::AVector; \ + using BVector = typename Base::BVector; \ + using CVector = typename Base::CVector; \ + using DVector = typename Base::DVector; \ + CUTE_HOST_DEVICE static void \ + fma(DVector& d, AVector const& a, BVector const& b, CVector const& c) { \ + CUTE_INVALID_CONTROL_PATH("Cannot use Xe DPAS MMA atom on non-Xe hardware"); \ + } \ +}; +#endif + + +CUTE_DECLARE_XE_DPAS_TT(f, tf32, tf32, f) + +CUTE_DECLARE_XE_DPAS_TT(f, bf, bf, f) +CUTE_DECLARE_XE_DPAS_TT(bf, bf, bf, f) +CUTE_DECLARE_XE_DPAS_TT(f, bf, bf, bf) +CUTE_DECLARE_XE_DPAS_TT(bf, bf, bf, bf) + +CUTE_DECLARE_XE_DPAS_TT(f, hf, hf, f) +CUTE_DECLARE_XE_DPAS_TT(f, hf, hf, hf) +CUTE_DECLARE_XE_DPAS_TT(hf, hf, hf, f) +CUTE_DECLARE_XE_DPAS_TT(hf, hf, hf, hf) + +CUTE_DECLARE_XE_DPAS_TT(ud, u8, u8, ud) +CUTE_DECLARE_XE_DPAS_TT(d, u8, u8, d) +CUTE_DECLARE_XE_DPAS_TT(d, u8, s8, d) +CUTE_DECLARE_XE_DPAS_TT(d, s8, u8, d) +CUTE_DECLARE_XE_DPAS_TT(d, s8, s8, d) + +#undef CUTE_DECLARE_XE_DPAS_TT + +} //namespace cute diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index c28d9b1d80..1e01fb1106 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -777,6 +777,7 @@ print(ThrMMA const& thr_mma) #include #if defined(CUTLASS_ENABLE_SYCL) +#include #include #endif //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/mma_traits_xe.hpp b/include/cute/atom/mma_traits_xe.hpp new file mode 100644 index 0000000000..06f812bacf --- /dev/null +++ b/include/cute/atom/mma_traits_xe.hpp @@ -0,0 +1,75 @@ +/*************************************************************************************************** +* Copyright (C) 2025 Intel Corporation, All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ + +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template +struct MMA_Traits> +{ + using Op = XE_DPAS_TT; + + static constexpr int BV = 32 / sizeof_bits_v; + static constexpr int K = Op::K; + + using ValTypeD = TD; + using ValTypeA = TA; + using ValTypeB = TB; + using ValTypeC = TC; + using _M = Int; + using _K = Int; + + using Shape_MNK = Shape<_M, _16, _K>; + using ThrID = Layout<_16>; + + // A layout: (T,V) -> (M,K) + // M x K row major, work-items interleaved. + using ALayout = decltype(composition(make_layout(make_shape(_K{}, _M{}), LayoutRight{}), + make_layout(make_shape(_16{}, Int{})))); + + // B layout: (T,V) -> (N,K) + // K x 16 VNNI-transformed row major, work-items interleaved. + using BLayout = Layout, Int<16/BV>>, Shape, Int<16/BV>>>, + Stride, Stride, Int<16*BV>>>>; + + // C layout: (T,V) -> (M,N) + // M x 16 row major, work-items interleaved. + using CLayout = Layout, Stride<_M, _1>>; +}; + +} /* namespace cute */ From c817f83b15f0c772891c6ab358209c4344faa7b6 Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Mon, 18 Aug 2025 22:18:01 -0700 Subject: [PATCH 09/16] [CuTe] Introduce new copy atoms --- include/cute/arch/copy_xe_2d.hpp | 176 ++++ include/cute/atom/copy_atom.hpp | 1 + include/cute/atom/copy_traits_xe_2d.hpp | 1163 +++++++++++++++++++++++ 3 files changed, 1340 insertions(+) create mode 100644 include/cute/arch/copy_xe_2d.hpp create mode 100644 include/cute/atom/copy_traits_xe_2d.hpp diff --git a/include/cute/arch/copy_xe_2d.hpp b/include/cute/arch/copy_xe_2d.hpp new file mode 100644 index 0000000000..02b4b019e1 --- /dev/null +++ b/include/cute/arch/copy_xe_2d.hpp @@ -0,0 +1,176 @@ +/*************************************************************************************************** +* Copyright (C) 2025 Intel Corporation, All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ + +#pragma once + +#include "cute/numeric/int.hpp" + +#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_INTEL_TARGET) +#define CUTE_ARCH_COPY_XE_ENABLED +#endif + +namespace cute { + +// Xe 2D copy atoms. +// Bits: bits per element in underlying memory operation. +// Height: number of elements in the gmem-strided matrix dimension +// Width: number of elements in the gmem-contiguous matrix dimension +// BlockWidth: blocking factor (in registers) for the width dimension. +// +// For a row-major-in-memory matrix: +// height = #rows, width = #columns. +// +// For a column-major-in-memory matrix: +// height = #columns, width = #rows. + +/* Base class for 2D copy ops, to support common queries */ +template +struct XE_Copy_Op_2D_Base +{ + static_assert(Height <= 32, "Height exceeds hardware limits"); + static_assert(Bits * Width <= 8 * 64, "Total width exceeds hardware limits"); + static_assert(Bits * Count <= 64 && Count <= 4 && Count != 3, "Unsupported block count"); + static_assert(Bits == 8 || Bits == 16 || Bits == 32 || Bits == 64, "Unsupported data size"); + + static constexpr int CopyBits = Bits; + static constexpr int AtomWidth = Width; + static constexpr int AtomHeight = Height; + static constexpr int BlockCount = Count; + static constexpr bool Transposing = Transpose; +}; + +template +struct XE_PREFETCH_2D; + + +template +struct XE_LOAD_2D : XE_Copy_Op_2D_Base +{ + template + CUTE_HOST_DEVICE static void copy(const int *payload, T *dst) { +#ifdef CUTE_ARCH_COPY_XE_ENABLED + auto &dv = *reinterpret_cast*>(dst); + asm ( + "lsc_load_block2d.ugm (M1, 1) %0:d%2.%3x%4x%5nn flat[%1+(0,0)]" + : "=rw"(dv) + : "rw.u"(payload), "P"(Bits), "P"(Width/BlockWidth), "P"(BlockWidth), "P"(Height) + ); +#else + CUTE_INVALID_CONTROL_PATH("Cannot use Xe block 2D copy atom on non-Xe hardware"); +#endif + } + + using PREFETCH = XE_PREFETCH_2D; +}; + +template +struct XE_LOAD_2D_VNNI : XE_Copy_Op_2D_Base +{ + static_assert(Bits == 8 || Bits == 16, "Unsupported data size"); + + template + CUTE_HOST_DEVICE static void copy(const int *payload, T *dst) { +#ifdef CUTE_ARCH_COPY_XE_ENABLED + auto &dv = *reinterpret_cast*>(dst); + asm ( + "lsc_load_block2d.ugm (M1, 1) %0:d%2.%3x%4x%5nt flat[%1+(0,0)]" + : "=rw"(dv) + : "rw.u"(payload), "P"(Bits), "P"(Width/BlockWidth), "P"(BlockWidth), "P"(Height) + ); +#else + CUTE_INVALID_CONTROL_PATH("Cannot use Xe block 2D copy atom on non-Xe hardware"); +#endif + } + + using PREFETCH = XE_PREFETCH_2D; +}; + +template +struct XE_LOAD_2D_TRANSPOSE : XE_Copy_Op_2D_Base +{ + static_assert(Bits == 32 || Bits == 64, "Unsupported data size"); + static_assert(Width <= 8, "Width exceeds hardware limits"); + static_assert(Bits != 64 || (Height == 8 && Width < 4), "Unsupported D64 transpose block size"); + + template + CUTE_HOST_DEVICE static void copy(const int *payload, T *dst) { +#ifdef CUTE_ARCH_COPY_XE_ENABLED + auto &dv = *reinterpret_cast*>(dst); + asm ( + 
"lsc_load_block2d.ugm (M1, 1) %0:d%2.%3x%4tn flat[%1+(0,0)]" + : "=rw"(dv) + : "rw.u"(payload), "P"(Bits), "P"(Width), "P"(Height) + ); +#else + CUTE_INVALID_CONTROL_PATH("Cannot use Xe block 2D copy atom on non-Xe hardware"); +#endif + } + + using PREFETCH = XE_PREFETCH_2D; +}; + +template +struct XE_PREFETCH_2D : XE_Copy_Op_2D_Base +{ + CUTE_HOST_DEVICE static void copy(const int *payload) { +#ifdef CUTE_ARCH_COPY_XE_ENABLED + asm ( + "lsc_load_block2d.ugm.ca.ca (M1, 1) %%null:d%1.%2x%3nn flat[%0+(0,0)]" + :: "rw.u"(payload), "P"(Bits), "P"(Width), "P"(Height) + ); +#else + CUTE_INVALID_CONTROL_PATH("Cannot use Xe block 2D copy atom on non-Xe hardware"); +#endif + } + + using PREFETCH = XE_PREFETCH_2D; +}; + +template +struct XE_STORE_2D : XE_Copy_Op_2D_Base +{ + static_assert(Height <= 8, "Height exceeds hardware limits"); + + template + CUTE_HOST_DEVICE static void copy(const int *payload, const T *src) { +#ifdef CUTE_ARCH_COPY_XE_ENABLED + auto &sv = *reinterpret_cast*>(src); \ + asm ( + "lsc_store_block2d.ugm (M1, 1) flat[%1+(0,0)] %0:d%2.%3x%4nn" + :: "rw"(sv), "rw.u"(payload), "P"(Bits), "P"(Width), "P"(Height) + ); +#else + CUTE_INVALID_CONTROL_PATH("Cannot use Xe block 2D copy atom on non-Xe hardware"); +#endif + } +}; + +} // end namespace cute diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 797d2a22ae..4c653c9433 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -716,6 +716,7 @@ print(ThrCopy const& thr_copy) #if defined(SYCL_INTEL_TARGET) #include +#include #include #endif diff --git a/include/cute/atom/copy_traits_xe_2d.hpp b/include/cute/atom/copy_traits_xe_2d.hpp new file mode 100644 index 0000000000..52aaa743d0 --- /dev/null +++ b/include/cute/atom/copy_traits_xe_2d.hpp @@ -0,0 +1,1163 @@ +/*************************************************************************************************** +* Copyright (C) 2025 Intel Corporation, All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+* +**************************************************************************************************/ + +#pragma once + +#include +#include + +#include +#include + +// 2D block payload intrinsics +SYCL_EXTERNAL extern "C" int* __builtin_IB_subgroup_createBlock2DAddressPayload(long base, int width_minus_one, int height_minus_one, int pitch_minus_one, + int blockX, int blockY, int blockWidth, int blockHeight, int numBlocks); +SYCL_EXTERNAL extern "C" int* __builtin_IB_subgroup_copyBlock2DAddressPayload(int* AP); + +SYCL_EXTERNAL extern "C" void __builtin_IB_subgroup_addBlock2DAddressPayloadBlockX(int* addrPayload, int blockX); +SYCL_EXTERNAL extern "C" void __builtin_IB_subgroup_addBlock2DAddressPayloadBlockY(int* addrPayload, int blockY); +SYCL_EXTERNAL extern "C" void __builtin_IB_subgroup_setBlock2DAddressPayloadBlockX(int* addrPayload, int blockX); +SYCL_EXTERNAL extern "C" void __builtin_IB_subgroup_setBlock2DAddressPayloadBlockY(int* addrPayload, int blockY); +SYCL_EXTERNAL extern "C" void __builtin_IB_subgroup_setBlock2DAddressPayloadBase(int* addrPayload, long base); +SYCL_EXTERNAL extern "C" void __builtin_IB_subgroup_setBlock2DAddressPayloadWidth(int* addrPayload, int width_minus_one); +SYCL_EXTERNAL extern "C" void __builtin_IB_subgroup_setBlock2DAddressPayloadHeigth(int* addrPayload, int height_minus_one); +SYCL_EXTERNAL extern "C" void __builtin_IB_subgroup_setBlock2DAddressPayloadPitch(int* addrPayload, int pitch_minus_one); + + +namespace cute { + +// Utility to check if a layout belongs to a coordinate tensor. +template +static constexpr bool is_counting_layout_v = is_arithmetic_tuple_like::value; + + + + +// Base traits class for block 2D messages. +// +// XMode and YMode are mode indices into the tensor, identifying which modes map to the block 2D dimensions. +// X: consecutive dimension +// Y: strided dimension internal to the copy atom +// While individual atoms perform 2D copies, additional dimensions are supported by tiling. +// +// If the value type of the tensor has a different size from the underlying copy atoms, +// it must be specified via the ValType template argument. Due to the SIMD-like layout of data +// in registers, the generic CuTe code for handling type size changes (via Copy_Atom) does not +// work properly in most cases. +template > +struct Xe2DTraitsBase +{ + using Traits = Copy_Traits; + using ThrID = Layout<_16>; + + static constexpr int ValBits = is_void_v ? Op::CopyBits + : int(sizeof_bits_v); + static_assert(Op::CopyBits % ValBits == 0, "Type is incompatible with this copy atom"); + + // Payload for 2D block message: + // - base pointer + // - matrix width/height/pitch in global memory + // - x/y offsets (overwritten during each copy operation) + // - block width/height/count + // Note the payload is mutable to allow x/y offsets to be dynamically updated for each use. + mutable int *payload; + + // Copy of base pointer, to allow payload updates for >2D tensors. + uint64_t base_ptr; + + // Copies of width/height/pitch, for constructing related traits (e.g. load->prefetch) + uint32_t width, height, pitch; + + // Strides not handled by block 2D operations (>2D tensors). + TiledStrides tiled_strides; + + static constexpr bool nontrivial_tiled_strides = !is_static_v + || !is_constant_v<0, decltype(cute::max(TiledStrides{}))>; + + // Uninitialized atom, available on host or device. + CUTE_HOST_DEVICE + Xe2DTraitsBase() {} + + // Initialized atom, device-only. 
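// (Note on the constructor below: it records the 2D surface in the units the hardware
//  payload expects -- width and pitch in bytes, height in rows -- and keeps any remaining
//  (>2D) strides in tiled_strides so they can be folded into the base pointer on each
//  copy; see update_payload().)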
+ template + CUTE_DEVICE + Xe2DTraitsBase(Tensor const& src) + : base_ptr((uint64_t) &*src.data()), + tiled_strides(replace(replace(src.stride(), _0{}), _0{})) + { + constexpr auto SBits = sizeof_bits_v; + width = (shape(src) * SBits) >> 3; + height = shape(src); + pitch = (stride(src) * SBits) >> 3; +#ifdef CUTE_ENABLE_XE_BLOCK_2D_ASSERT + assert((base_ptr % 64 == 0) && "CuTe runtime error: misaligned block 2D base pointer"); + assert((width % 4 == 0) && "CuTe runtime error: misaligned block 2D tensor width"); + assert((pitch % 4 == 0) && "CuTe runtime error: misaligned block 2D tensor pitch"); + assert((width <= 0xFFFFFF) && "CuTe runtime error: block 2D tensor width exceeds 2^24"); + assert((height <= 0xFFFFFF) && "CuTe runtime error: block 2D tensor height exceeds 2^24"); + assert((pitch <= 0xFFFFFF) && "CuTe runtime error: block 2D tensor pitch exceeds 2^24"); +#endif + init_payload(); + } + + template + CUTE_DEVICE explicit + Xe2DTraitsBase(Xe2DTraitsBase const& other) + : base_ptr(other.base_ptr), width(other.width), height(other.height), pitch(other.pitch), + tiled_strides(other.tiled_strides) + { + init_payload(); + } + + // Initialize a previously-uninitialized atom. + template + CUTE_DEVICE static auto + with(Args&&... args) { + return Traits(std::forward(args)...); + } + + CUTE_DEVICE + void init_payload() { +#ifdef __SYCL_DEVICE_ONLY__ + payload = __builtin_IB_subgroup_createBlock2DAddressPayload( + base_ptr, + width - 1, + height - 1, + pitch - 1, + 0, /* x offset, configured per-copy */ + 0, /* y offset, configured per-copy */ + Op::AtomWidth / Op::BlockCount, + Op::AtomHeight, + Op::BlockCount + ); +#endif + } + + template + CUTE_DEVICE + void update_payload(const Coord &coord) const + { +#ifdef __SYCL_DEVICE_ONLY__ + // Update x/y offsets in payload + int32_t x = get(coord) * Bits / Op::CopyBits; + int32_t y = get(coord); + __builtin_IB_subgroup_setBlock2DAddressPayloadBlockX(payload, x); + __builtin_IB_subgroup_setBlock2DAddressPayloadBlockY(payload, y); + +#ifdef CUTE_ENABLE_XE_BLOCK_2D_ASSERT + assert((x % 4 == 0) && "CuTe runtime error: misaligned block 2D x offset"); +#endif + + // Perform stride calculation and update base pointer for > 2D tensors + if constexpr (nontrivial_tiled_strides) { + auto offset = inner_product(coord, tiled_strides); + auto byte_offset = (offset * Bits) >> 3; + __builtin_IB_subgroup_setBlock2DAddressPayloadBase(payload, base_ptr + byte_offset); + +#ifdef CUTE_ENABLE_XE_BLOCK_2D_ASSERT + assert((byte_offset % 64 == 0) && "CuTe runtime error: misaligned block 2D base pointer"); +#endif + } +#endif /* __SYCL_DEVICE_ONLY__ */ + } + + static constexpr auto get_x_mode() { return XMode{}; } + static constexpr auto get_y_mode() { return YMode{}; } +}; + +template > +struct Xe2DLoadTraitsBase : Xe2DTraitsBase +{ + using Super = Xe2DTraitsBase; + using Traits = typename Super::Traits; + using ThrID = typename Super::ThrID; + + using Super::Super; + + // Execution. 
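// (Sketch of how this path is typically reached; `tiled_copy`, `cA`, `ValType`, and
//  `thread_idx` are illustrative names, not defined by this patch. `cA` is a coordinate
//  tensor, e.g. make_identity_tensor over the tile, and `tiled_copy` was built from the
//  global tensor via make_block_2d_copy(op, gmem):
//
//    auto   thr_copy = tiled_copy.get_slice(thread_idx);
//    Tensor tAcA     = thr_copy.partition_S(cA);             // per-thread coordinate slice
//    Tensor tArA     = make_fragment_like<ValType>(tAcA);    // rmem destination fragment
//    copy(tiled_copy, tAcA, tArA);                           // dispatches to copy_unpack below
//
//  copy_unpack() writes the block's x/y offset into the payload and issues a single
//  hardware block 2D load into the register fragment.)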
+ template + CUTE_DEVICE friend constexpr void + copy_unpack(Traits const& traits, + Tensor const& src, + Tensor & dst) { + using SType = typename SEngine::value_type; + using DType = typename DEngine::value_type; + using SrcLayout = typename Traits::SrcLayout; + using DstLayout = typename Traits::DstLayout; + constexpr auto DBits = sizeof_bits_v; + + static_assert(is_counting_layout_v, "Source tensor must be a coordinate tensor."); + static_assert(is_rmem_v, "Destination tensor must be in registers."); + static_assert(size(SLayout{}) * DBits == size<1>(SrcLayout{}), + "Source tensor size does not match copy atom size."); + static_assert(size(DLayout{}) * DBits == size<1>(DstLayout{}), + "Destination tensor size does not match copy atom size."); + + traits.template update_payload(src.data().coord_); + Op::copy(traits.payload, recast_ptr>(&*dst.data())); + } +}; + + +// Split a subgroup-level layout into a TV-layout. +template +struct XeInterleavedLayoutHelper { + // Underlying SIMD vector type's element width: + static constexpr int VecTypeBits = cute::max(ValBits, 8); + + // Expand from CopyBits to VecTypeBits in x dimension: + using Expanded = decltype(logical_product(Layout>>{}, InLayout{})); // V' -> (x', y) + + // Split elements between work-items, interleaving: + using TVLayout = decltype(composition(Expanded{}, make_layout(make_shape(Int{}, Int{})))); + + // Expand from elements to bits: + using PreResult = decltype(blocked_product(Layout>>{}, TVLayout{})); + + // Simplify for nicer-looking layouts: + using Result = decltype(coalesce(PreResult{}, Step<_1, _1>{})); + + // Examples: + + // U16 32x16 nontranspose -> U4/U8 + // In: (_32, _16):(_1, _32) V -> (x,y) + // Exp: (_2, _32, _16):(_1, _2, _64) Vbit -> (xbit,y) + // Compose with (_16, _64):(_1, _16) + // TV: (_16, _64):(_1, _16) + // Res: (_16, (_8, _64)):(_8, (_1, _128)) + + // U32 8x16 transpose -> U16 (16x16) LD_T + // In: (_16, _8):(_8, _1) V -> (x,y) + // Exp: (_2, _16, _8):(_1, _16, _2) V16 -> (x16,y) + // Compose with (_16, _16):(_1, _16) + // TV: ((_2, _8), (_2, _8)):((_1, _16), (_128, _2)) (T,V) -> (x16,y) + // Res: ((_2, _8), (_16, _2, _8)):((_16, _256), (_1, _2048, _32)) (T,V) -> (xbit,y) +}; + +template +using XeInterleavedLayout = typename XeInterleavedLayoutHelper::Result; + +// Block 2D load traits. +template +struct Copy_Traits, XMode, YMode, ValType, TiledStrides> + : Xe2DLoadTraitsBase, XMode, YMode, ValType, TiledStrides> +{ + using Super = Xe2DLoadTraitsBase, XMode, YMode, ValType, TiledStrides>; + using Super::Super; + + // (dst-thr, dst-val) -> (x, y) + using DstLayout = XeInterleavedLayout, Int, Int>, + Stride<_1, Int, Int>>, + CopyBits, + sizeof_bits_v>; + + using RefLayout = DstLayout; + using SrcLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); +}; + +// Block 2D VNNI load traits. +template +struct Copy_Traits, XMode, YMode, ValType, TiledStrides> + : Xe2DLoadTraitsBase, XMode, YMode, ValType, TiledStrides> +{ + using Super = Xe2DLoadTraitsBase, XMode, YMode, ValType, TiledStrides>; + using Super::Super; + + static constexpr int BV = 32 / CopyBits; + + // (dst-thr, dst-val) -> (x, y) + using DstLayout = XeInterleavedLayout, Int, Int, Int>, + Stride, _1, Int, Int>>, + CopyBits, + sizeof_bits_v>; + + using RefLayout = DstLayout; + using SrcLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); +}; + +// Block 2D transposed load traits. 
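// (Transposed loads operate on 32- or 64-bit elements only; narrower register types are
//  handled by loading at a wider CopyBits and re-typing through ValType, as in the
//  "U32 8x16 transpose -> U16 (16x16)" example in XeInterleavedLayoutHelper above.)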
+template +struct Copy_Traits, XMode, YMode, ValType, TiledStrides> + : Xe2DLoadTraitsBase, XMode, YMode, ValType, TiledStrides> +{ + using Super = Xe2DLoadTraitsBase, XMode, YMode, ValType, TiledStrides>; + using Super::Super; + + // (dst-thr, dst-val) -> (x, y) + using DstLayout = XeInterleavedLayout, Int>, + Stride, _1>>, + CopyBits, + sizeof_bits_v>; + + using RefLayout = DstLayout; + using SrcLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); +}; + +// Block 2D store traits. +template +struct Copy_Traits, XMode, YMode, ValType, TiledStrides> + : Xe2DTraitsBase, XMode, YMode, ValType, TiledStrides> +{ + // (src-thr, src-val) -> (x, y) + using SrcLayout = XeInterleavedLayout, Int>>, + CopyBits, + sizeof_bits_v>; + + using RefLayout = SrcLayout; + using DstLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); + + using Op = XE_STORE_2D; + using Super = Xe2DTraitsBase; + using Traits = typename Super::Traits; // a.k.a. this class + using ThrID = typename Super::ThrID; + + using Super::Super; + + // Execution. + template + CUTE_DEVICE friend constexpr void + copy_unpack(Traits const& traits, + Tensor const& src, + Tensor & dst) { + using SType = typename SEngine::value_type; + using DType = typename DEngine::value_type; + using SrcLayout = typename Traits::SrcLayout; + using DstLayout = typename Traits::DstLayout; + constexpr auto SBits = sizeof_bits_v; + + static_assert(is_counting_layout_v, "Destination tensor must be a coordinate tensor."); + static_assert(is_rmem_v, "Source tensor must be in registers."); + static_assert(size(SLayout{}) * SBits == size<1>(SrcLayout{}), + "Source tensor size does not match copy atom size."); + static_assert(size(DLayout{}) * SBits == size<1>(DstLayout{}), + "Destination tensor size does not match copy atom size."); + + traits.template update_payload(dst.data().coord_); + Op::copy(traits.payload, recast_ptr>(&*src.data())); + } +}; + +// Block 2D prefetch traits. +// +// Note prefetch does not use/need block width; it is present for template arg compatibility +// between loads and their prefetches. +template +struct Copy_Traits, XMode, YMode, ValType, TiledStrides> + : Xe2DTraitsBase, XMode, YMode, ValType, TiledStrides> +{ + // (dst-thr, dst-val) -> (x, y) + using DstLayout = XeInterleavedLayout, Int>>, + CopyBits, + sizeof_bits_v>; + + using RefLayout = DstLayout; + using SrcLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); + + using Op = XE_PREFETCH_2D; + using Super = Xe2DTraitsBase; + using Traits = typename Super::Traits; // a.k.a. this class + using ThrID = typename Super::ThrID; + + using Super::Super; + + // Execution. + template + CUTE_DEVICE friend constexpr void + copy_unpack(Traits const& traits, + Tensor const& src, + Tensor & dst) { + using SType = typename SEngine::value_type; + using SrcLayout = typename Traits::SrcLayout; + + static_assert(is_counting_layout_v, "Source tensor must be a coordinate tensor."); + static_assert(size(SLayout{}) * Super::ValBits == size<1>(SrcLayout{}), + "Source tensor size does not match copy atom size."); + + traits.template update_payload(src.data().coord_); + Op::copy(traits.payload); + } +}; + +// Helpers for creating a tiling of block 2D copy atoms for a given global memory tensor. +// +// The x/y modes are deduced according to the rules: +// x: innermost constant-stride-1 mode +// y: innermost dynamic-stride mode, or innermost non-1 stride if there are no dynamic strides. 
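// For example (illustrative; `lda` stands for a dynamic leading dimension):
//   strides (lda, _1) -- a row-major (M,K) matrix -- give x = mode 1, y = mode 0, while
//   strides (_1, lda) -- column-major -- give x = mode 0, y = mode 1
// (see find_x_mode / find_y_mode below).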
+template +CUTE_HOST_DEVICE +auto +make_block_2d_copy(const CopyOp& op, const Tensor& gmem) { + return make_block_2d_copy(op, gmem.stride()).with(gmem); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy(const CopyOp& op, const Stride&) +{ + // Configure traits for this atom, identifying x and y modes. + using ValType = std::conditional_t, + int_bit_t, + OptionalValType>; + + Stride strides{}; + return make_block_2d_copy(op, strides, find_x_mode(strides), find_y_mode(strides)); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy(const CopyOp& op, const Stride&, const XMode&, const YMode&) +{ + static constexpr auto ValBits = sizeof_bits_v; + + Stride strides{}; + XMode x_mode{}; + YMode y_mode{}; + + using TiledStrides = decltype(replace(replace(strides, _0{}), _0{})); + + using Traits = Copy_Traits; + using Atom = Copy_Atom; + + // Create tiler for the TiledCopy. + constexpr auto tile_1 = tuple_repeat(_1{}); + constexpr auto Width = CopyOp::AtomWidth * CopyOp::CopyBits / ValBits; + constexpr auto Height = CopyOp::AtomHeight; + using ShapeTiler_MN = decltype(replace(replace(tile_1, Int{}), Int{})); + + // Create proper TV-layout for the TiledCopy, using the copy atom's reference layout. + // + // ValLayoutRef for all block 2D atoms is (T,V)->(X,Y). + // If the x/y ordering in ValLayoutRef matches the order of XMode/YMode in the given strides, then + // the TiledCopy's TV-layout is just ValLayoutRef. Otherwise, we need to transpose x/y in the RefLayout. + constexpr bool transpose_tv = (y_mode < x_mode); + using MaybeTranspose = Layout, Int>, + Stride, + Int>>; + using LayoutCopy_TV = decltype(composition(MaybeTranspose{}, typename Atom::ValLayoutRef{})); + + return TiledCopy{}; +} + +// Low-level routine for creating a block 2D TiledCopy for multiple subgroups. +// In addition to the usual parameters, it takes: +// - atom_shape = "subgroup shape" (# of copy blocks in each dimension) +// - sv_layout = "subgroup-value layout" (ordering for subgroups in the tiling) +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy(const CopyOp& op, + const Stride& strides, + const XMode& x_mode, const YMode& y_mode, + const SGShape& atom_shape, // (SG_M, SG_N, ...) + const SVLayout& sv_layout) // (SG #, SG value) -> (SG_M, SG_N, ...) +{ + // Create TiledCopy for a single subgroup. + using SGCopy = decltype(make_block_2d_copy(op, strides, x_mode, y_mode)); + using Atom = typename SGCopy::Atom; + using ShapeTiler_MN = typename SGCopy::Tiler_MN; + using LayoutCopy_TV = typename SGCopy::TiledLayout_TV; + + // Expand the shape. + auto x_shape = elem_scale(ShapeTiler_MN{}, atom_shape); + + // Expand the single-SG TV layout to the full shape, then tile. + auto x_tv_layout1 = composition(make_layout(ShapeTiler_MN{}, make_layout(x_shape).stride()), LayoutCopy_TV{}); + auto x_tv_layout = blocked_product(x_tv_layout1, sv_layout); + + return TiledCopy{}; +} + + +template +CUTE_HOST_DEVICE +constexpr auto +find_x_mode(const Stride &) { + Stride strides{}; + return find_if(strides, [](auto const &x) { return C>{}; }); +} + +template +CUTE_HOST_DEVICE +constexpr auto +find_y_mode(const Stride&) { + Stride strides{}; + constexpr auto YModeDyn = find_if(strides, [](auto const &x) { return C>{}; }); + if constexpr (YModeDyn < rank(strides)) + return YModeDyn; + else + return find_if(strides, [](auto const &x) { return C>{}; }); +} + +// Copy selection and creation. 
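// (Sketch of the overall selection flow for an A operand, mirroring what
//  make_block_2d_copy_A(mma, gstride) does further below; `mma` and `gA` are illustrative
//  names, and the selector's template parameters follow the MemType/RegType description
//  given at block_2d_selector:
//
//    auto cA     = make_identity_tensor(select<0,2>(mma.tile_mnk()));
//    auto op     = block_2d_selector<MemType, RegType>(
//                      mma.get_slice(0).atom_partition_A(cA).layout(), gA.stride());
//    auto copy_A = make_block_2d_copy_A(op, mma, gA);
//
//  block_2d_selector() chooses between plain, VNNI, and transposed loads (or a store)
//  based on the desired register layout and the global-memory strides.)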
+ +template +struct find_first_basis_mode_pred { using type = C; }; + +template +struct find_first_basis_mode_pred, I>> { + using type = C<(S >= min_scale)>; +}; + +template +CUTE_HOST_DEVICE +constexpr auto +find_first_basis_mode(InLayout const&) { + return find_if(InLayout{}.stride(), [](auto const &x) { + using XType = remove_cvref_t; + return typename find_first_basis_mode_pred::type{}; + }); +} + +// Find the first block size (stride) in dimension N of size at least min_size. +// FIXME: should look through all matching strides and pick the smallest. +template +CUTE_HOST_DEVICE +constexpr auto +get_block_size(InLayout const&) { + InLayout layout{}; + constexpr auto block_mode = find_first_basis_mode(layout); + if constexpr (block_mode < rank(layout)) + return basis_value(stride(layout)); + else + return get(atuple_coshape(layout)); +} + +// Remove VNNI modes from a layout, if present. +// Returns a std::pair = (layout_out, has_vnni) +template +CUTE_HOST_DEVICE +constexpr auto +strip_vnni(InLayout const&) +{ + constexpr InLayout layout{}; + constexpr int R = rank(InLayout{}); + constexpr bool vnni = (R >= 2) + && (Bits < 32) + && is_constant_v<32 / Bits, decltype(size<0>(layout))>; + + if constexpr (vnni) { + // Coalesce VNNI mode with next mode in that dimension, if any, + // or else move it to the end of the layout. + constexpr auto vmode = get<0>(layout); + constexpr auto vdim = stride<0>(layout).mode(); + constexpr auto slayout = take<1,R>(layout); + constexpr auto next_vmode = find_first_basis_mode(slayout); + if constexpr (next_vmode < R - 1) + return std::make_pair(replace(slayout, coalesce(make_layout(vmode, get(slayout)))), true); + else + return std::make_pair(append(layout,vmode), true); + } else + return std::make_pair(layout, false); +} + +enum class Block2DTransform {N, T, V}; + +template +CUTE_HOST_DEVICE +constexpr Block2DTransform +block_2d_transform_selector(DesiredCoordLayout const& layout, + GlobalStride const& gstride) +{ + // Stores are always non-transpose. + if constexpr (Store) + return Block2DTransform::N; + + // Check if copy's consumer wants VNNI layout. + constexpr auto result = strip_vnni(DesiredCoordLayout{}); + constexpr auto slayout = get<0>(result); + constexpr bool vnni = get<1>(result); + constexpr bool transpose = !is_constant_v<1, decltype(basis_get(stride<0>(slayout), gstride))>; + + // If VNNI needed, use VNNI load for 8/16-bit types in memory, otherwise regular. + if constexpr (vnni && !transpose) + return (MemBits == 8 || MemBits == 16) ? Block2DTransform::V : Block2DTransform::N; + + // Otherwise, use transpose load if significant transposition required. + if constexpr (transpose && decltype(size<0>(slayout))::value * MemBits >= 16) + return Block2DTransform::T; + else + return Block2DTransform::N; +} + +// Heuristically select a block 2D copy operation. +// MemType: type of data in memory +// RegType: type of data in registers, as associated with CoordLayout +// Store: true for stores, false for loads (default) +// CoordLayout: desired subgroup coordinate layout in registers +// (Note: a reorder may be required to achieve data in this layout) +// GlobalStride: strides of data in memory +template +CUTE_HOST_DEVICE +constexpr auto +block_2d_selector(CoordLayout const&, GlobalStride const&) +{ + static_assert(is_static_v, "Coordinate layout must be static"); + + auto layout = coalesce(CoordLayout{}); + GlobalStride gstride{}; + + // Determine size of copy. 
+ constexpr int MemBits = sizeof_bits_v; + constexpr int RegBits = sizeof_bits_v; + + // Determine which kind of block 2D message to use (regular/VNNI/transpose) + constexpr auto kind = block_2d_transform_selector(layout, gstride); + + // Strip off VNNI mode if present. + constexpr auto slayout = get<0>(strip_vnni(layout)); + + constexpr auto x_mode = find_x_mode(gstride); + constexpr auto y_mode = find_y_mode(gstride); + + constexpr int grf_elems = 512 / RegBits; + constexpr bool resize = (MemBits != RegBits); + + auto shape = atuple_coshape(layout); + + if constexpr (kind != Block2DTransform::T) { + constexpr int CopyBits = cute::max(8, cute::min(64, MemBits)); + + // Determine block width. + // Get innermost stride in x dimension that is >= 1/2 GRF + // Block width = highest power of 2 divisor (up to 64b) + // Width = highest power of 2 divisor of full tile's width, up to 64b and 4x block width + constexpr int max_w = 64 * 8 / MemBits; + constexpr int x_stride = get_block_size(slayout); + constexpr int block_width = cute::gcd(max_w, x_stride); + constexpr int load_width = cute::gcd(cute::min(max_w, 4 * block_width), + get(shape)); + constexpr int width = Store ? block_width : load_width; + constexpr int block_cwidth = block_width * MemBits / CopyBits; + constexpr int cwidth = width * MemBits / CopyBits; + + // Determine block height. + // Get innermost stride in H dimension, besides VNNI stride if VNNI. + // However, if data resizing will occur, choose full tile height, up to block height limit. + // (Rationale: we are already moving data, so layouts don't need to match) + constexpr int y_stride = get_block_size(slayout); + constexpr int max_h = Store ? 8 : 32; + constexpr int height = cute::gcd(resize ? get(shape) : y_stride, max_h); + + if constexpr (Store) + return XE_STORE_2D {}; + else if constexpr (kind == Block2DTransform::V) + return XE_LOAD_2D_VNNI{}; + else + return XE_LOAD_2D {}; + } else { + // Similar process for transposing copies, but with width/height reversed. + constexpr int CopyBits = cute::max(32, cute::min(64, MemBits)); + + constexpr int y_stride = get_block_size(slayout); + constexpr int height = cute::gcd(32, y_stride); + + constexpr int max_w = 32 * 8 / MemBits; + constexpr int x_stride = get_block_size(slayout); + constexpr int width = cute::gcd(resize ? get(shape) : x_stride, max_w); + constexpr int cwidth = width * MemBits / CopyBits; + + return XE_LOAD_2D_TRANSPOSE{}; + } +} + +// Helper for make_block_2d_copy_* routines +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_X(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Stride const& gstride, // Global memory strides + XMode const& x_mode, // x, y modes + YMode const& y_mode, + MMAShape const& mma_shape, // Coordinate space + SVLayout const& sv_layout) // (SG,V) -> coord +{ + // Divide coordinate codomain into copy tiles. + constexpr int Width = CopyOp::AtomWidth * CopyOp::CopyBits / sizeof_bits_v; + constexpr int Height = CopyOp::AtomHeight; + auto op_tile = Int{} * E{} + + Int{} * E{}; + auto atom_shape = shape_div(mma_shape, op_tile); + + auto divide_by_op_tile = zip(make_layout(op_tile, make_stride(_0{}, _0{})), + make_layout(atom_shape)); // (M,K) -> (M tile, K tile) + + auto sv_layout_t0 = composition(divide_by_op_tile, sv_layout); // (SG,V) -> (M tile, K tile) + + // Filter out value modes that are internal to copy tiles. 
+ auto sv_layout_t = make_layout(get<0>(sv_layout_t0), + filter(get<1>(sv_layout_t0))); // (SG,V') -> (M tile, K tile) + + // Tile copy operation. + return make_block_2d_copy(op, gstride, x_mode, y_mode, atom_shape, sv_layout_t); +} + +// MMA-focused TiledCopy creation functions. +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_A(TiledMMA const& mma, // TiledMMA instance + Tensor const& gmem) // Global tensor +{ + using ValType = typename GEngine::value_type; + return make_block_2d_copy_A(mma, gmem.stride()).with(gmem); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_A(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Tensor const& gmem) // Global tensor +{ + using ValType = typename GEngine::value_type; + return make_block_2d_copy_A(op, mma, gmem.stride()).with(gmem); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_A(TiledMMA const& mma, // TiledMMA instance + Stride const& gstride) // Global memory strides +{ + using MMAType = typename TiledMMA::ValTypeA; + auto cA = make_identity_tensor(select<0,2>(mma.tile_mnk())); + auto op = block_2d_selector(mma.get_slice(0).atom_partition_A(cA).layout(), gstride); + return make_block_2d_copy_A(op, mma, gstride); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_A(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Stride const& gstride) // Global memory strides +{ + return make_block_2d_copy_A(op, mma, gstride, find_x_mode(gstride), find_y_mode(gstride)); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_A(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Stride const& gstride, // Global memory strides + XMode const& x_mode, // x, y modes + YMode const& y_mode) +{ + // Retrieve MMA atom's (subgroup, value) -> (M,K) layout + auto tile_mk = select<0,2>(mma.tile_mnk()); + + auto thr_vmnk = mma.get_thr_layout_vmnk(); // (ThrV,ThrM,ThrN,ThrK) -> thr + auto shape_vmnk = shape(thr_vmnk); // (ThrV,ThrM,ThrN,ThrK) + auto drop_n = make_layout(shape_vmnk, + make_stride(_1{}, get<0>(shape_vmnk), _0{}, + get<0>(shape_vmnk) * get<1>(shape_vmnk))); // (ThrV,ThrM,ThrN,ThrK) -> (ThrV,ThrM,ThrK) + + auto thr_to_vmk = composition(drop_n, right_inverse(thr_vmnk)); // thr -> (ThrV,ThrM,ThrK) + auto sg_to_vmk = composition(thr_to_vmk, + make_layout(product(take<1,4>(shape_vmnk)), get<0>(shape_vmnk))); // SG -> (0,ThrM,ThrK) + + auto svA = composition(mma.thrfrg_A(make_layout(tile_mk)), + make_tile(sg_to_vmk, _)); // (SG,V) -> (M,K) + + // Derive copy tile layout and create TiledCopy + return make_block_2d_copy_X(op, mma, gstride, x_mode, y_mode, tile_mk, svA); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_B(TiledMMA const& mma, // TiledMMA instance + Tensor const& gmem) // Global tensor +{ + using ValType = typename GEngine::value_type; + return make_block_2d_copy_B(mma, gmem.stride()).with(gmem); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_B(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Tensor const& gmem) // Global tensor +{ + using ValType = typename GEngine::value_type; + return make_block_2d_copy_B(op, mma, gmem.stride()).with(gmem); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_B(TiledMMA const& mma, // TiledMMA instance + Stride const& gstride) // Global memory strides +{ + using MMAType = typename TiledMMA::ValTypeB; + auto cB = make_identity_tensor(select<1,2>(mma.tile_mnk())); + auto op = 
block_2d_selector(mma.get_slice(0).atom_partition_B(cB).layout(), gstride); + return make_block_2d_copy_B(op, mma, gstride); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_B(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Stride const& gstride) // Global memory strides +{ + return make_block_2d_copy_B(op, mma, gstride, find_x_mode(gstride), find_y_mode(gstride)); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_B(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Stride const& gstride, // Global memory strides + XMode const& x_mode, // x, y modes + YMode const& y_mode) +{ + // Retrieve MMA atom's (subgroup, value) -> (N,K) layout + auto tile_nk = select<1,2>(mma.tile_mnk()); + + auto thr_vmnk = mma.get_thr_layout_vmnk(); // (ThrV,ThrM,ThrN,ThrK) -> thr + auto shape_vmnk = shape(thr_vmnk); // (ThrV,ThrM,ThrN,ThrK) + auto drop_m = make_layout(shape_vmnk, + make_stride(_1{}, _0{}, get<0>(shape_vmnk), _0{}, + get<0>(shape_vmnk) * get<2>(shape_vmnk))); // (ThrV,ThrM,ThrN,ThrK) -> (ThrV,ThrN,ThrK) + + auto thr_to_vnk = composition(drop_m, right_inverse(thr_vmnk)); // thr -> (ThrV,ThrN,ThrK) + auto sg_to_vnk = composition(thr_to_vnk, + make_layout(product(take<1,4>(shape_vmnk)), get<0>(shape_vmnk))); // SG -> (0,ThrN,ThrK) + + auto svB = composition(mma.thrfrg_B(make_layout(tile_nk)), + make_tile(sg_to_vnk, _)); // (SG,V) -> (N,K) + + // Derive copy tile layout and create TiledCopy + return make_block_2d_copy_X(op, mma, gstride, x_mode, y_mode, tile_nk, svB); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_C(TiledMMA const& mma, // TiledMMA instance + Tensor const& gmem) // Global tensor +{ + using ValType = typename GEngine::value_type; + return make_block_2d_copy_C(mma, gmem.stride()).with(gmem); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_C(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Tensor const& gmem) // Global tensor +{ + using ValType = typename GEngine::value_type; + return make_block_2d_copy_C(op, mma, gmem.stride()).with(gmem); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_C(TiledMMA const& mma, // TiledMMA instance + Stride const& gstride) // Global memory strides +{ + using MMAType = typename TiledMMA::ValTypeA; + auto cC = make_identity_tensor(select<0,1>(mma.tile_mnk())); + auto op = block_2d_selector( + mma.get_slice(0).atom_partition_C(cC).layout(), gstride + ); + return make_block_2d_copy_C(op, mma, gstride); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_C(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Stride const& gstride) // Global memory strides +{ + return make_block_2d_copy_C(op, mma, gstride, find_x_mode(gstride), find_y_mode(gstride)); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_C(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Stride const& gstride, // Global memory strides + XMode const& x_mode, // x, y modes + YMode const& y_mode) +{ + // Retrieve MMA atom's (subgroup, value) -> (M,N) layout + auto tile_mn = select<0,1>(mma.tile_mnk()); + + auto thr_vmnk = mma.get_thr_layout_vmnk(); // (ThrV,ThrM,ThrN,ThrK) -> thr + auto shape_vmnk = shape(thr_vmnk); // (ThrV,ThrM,ThrN,ThrK) + auto drop_k = replace<3>(make_layout(shape_vmnk), + make_layout(get<3>(shape_vmnk), _0{})); // (ThrV,ThrM,ThrN,ThrK) -> (ThrV,ThrM,ThrN) + + auto thr_to_vmn = composition(drop_k, right_inverse(thr_vmnk)); // thr -> (ThrV,ThrM,ThrN) + auto 
sg_to_vmn = composition(thr_to_vmn, + make_layout(product(take<1,4>(shape_vmnk)), get<0>(shape_vmnk))); // SG -> (0,ThrM,ThrN) + + auto svC = composition(mma.thrfrg_C(make_layout(tile_mn)), + make_tile(sg_to_vmn, _)); // (SG,V) -> (M,N) + + // Derive copy tile layout and create TiledCopy + return make_block_2d_copy_X(op, mma, gstride, x_mode, y_mode, tile_mn, svC); +} + +// Prefetch selection and creation. +namespace detail { + template + CUTE_HOST_DEVICE decltype(auto) + as_block_2d_traits(Xe2DTraitsBase const &o) { + return o; + } +}; + +template +CUTE_HOST_DEVICE +auto +make_block_2d_prefetch(TiledCopy const& tiled_copy) +{ + using TCopy = TiledCopy; + + constexpr auto sg_count = typename TCopy::TiledNumThr{} / typename TCopy::AtomNumThr{}; + auto &traits = detail::as_block_2d_traits(tiled_copy); + + return make_block_2d_prefetch( + ShapeTiler_MN{}, traits.tiled_strides, traits.get_x_mode(), traits.get_y_mode() + ).with(traits); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_prefetch(const Shape& shape, Tensor const& gmem) +{ + using ValType = typename Engine::value_type; + return make_block_2d_prefetch(shape, gmem.stride()).with(gmem); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_prefetch(const Shape& shape, Stride const& stride) +{ + return make_block_2d_prefetch(shape, stride, find_x_mode(stride), find_y_mode(stride)); +} + +template +CUTE_HOST_DEVICE +auto +make_block_2d_prefetch(const Shape&, Stride const& stride, const XMode& x_mode, const YMode& y_mode) +{ + constexpr auto shape_x = get(Shape{}); + constexpr auto shape_y = get(Shape{}); + + // Try to retrieve whole cache lines (contiguous dimension = x) + constexpr auto width = cute::min(shape_x, 512 / sizeof_bits_v); + + // Do a preliminary tiling to choose appropriate height. + constexpr int n_sg_x = cute::gcd(SGCount, ceil_div(shape_x, width)); + constexpr int n_sg_y = SGCount / n_sg_x; + + constexpr auto max_height = 32; + constexpr auto height = cute::min(max_height, ceil_div(shape_y, n_sg_y)); + + // Select op. + using CopyType = int_byte_t; + using CopyOp = XE_PREFETCH_2D, + height, + ceil_div(width * sizeof_bits_v, sizeof_bits_v)>; + + return make_block_2d_prefetch(CopyOp{}, Shape{}, stride, x_mode, y_mode); +} + +// Low-level prefetch creation utility. +template +CUTE_HOST_DEVICE +auto +make_block_2d_prefetch(PrefetchOp const& op, + Shape const& shape, + Stride const& stride, + XMode const& x_mode, + YMode const& y_mode) +{ + constexpr auto all_1s = tuple_repeat(_1{}); + constexpr auto width = PrefetchOp::AtomWidth * PrefetchOp::CopyBits / sizeof_bits_v; + constexpr auto height = PrefetchOp::AtomHeight; + + auto op_tile = replace(replace(all_1s, Int{}), Int{}); + + // Reduce shape to grid of atoms. + auto atom_shape = shape_div(shape, op_tile); + + // Replicate op tile across subgroups, traversing the innermost dimension first. + // Ensure the resulting collective tile goes evenly into the given shape (may not be a power of 2) + constexpr int n_sg_x = cute::gcd(SGCount, get(atom_shape)); + constexpr int n_sg_y = SGCount / n_sg_x; + + auto collective_op_tile = replace(replace(all_1s, + Int{}), + Int{}); + + // Tile atom grid across collective op tile. + auto sv_layout = zipped_divide(make_layout(collective_op_tile), atom_shape); + + // Create the TiledCopy object. 
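// (This reuses the generic multi-subgroup make_block_2d_copy() overload above, so
//  prefetch tilings are composed the same way as load/store tilings.)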
+ return make_block_2d_copy(op, stride, x_mode, y_mode, atom_shape, sv_layout); +} + + + +// +// Display utilities +// +template +CUTE_HOST_DEVICE +void +print_block_2d_traits(Xe2DTraitsBase const& traits) +{ + print(" Width: "); print(Op::AtomWidth); print("\n"); + print(" Height: "); print(Op::AtomHeight); print("\n"); + print(" CopyType: "); print(Op::CopyBits); print("b\n"); + print(" ValueType: "); print(sizeof_bits_v); print("b\n"); + print(" XMode: "); print(XMode{}); print("\n"); + print(" YMode: "); print(YMode{}); print("\n"); + print(" TiledStrides: "); print(traits.tiled_strides); print("\n"); +} + +template +CUTE_HOST_DEVICE +void +print_block_2d_atom(Copy_Atom const& atom) +{ + using Atom = remove_cvref_t; + print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); + print(" ValLayoutSrc: "); print(typename Atom::ValLayoutSrc{}); print("\n"); + print(" ValLayoutDst: "); print(typename Atom::ValLayoutDst{}); print("\n"); + print(" ValLayoutRef: "); print(typename Atom::ValLayoutRef{}); print("\n"); + if constexpr (sizeof_bits_v != sizeof_bits_v) { + print(" AtomValType: "); print(sizeof_bits_v); print("b\n"); + } +} + +template +CUTE_HOST_DEVICE +void +print(Copy_Atom, + XMode, YMode, ValType, TiledStrides>, AtomValType> const& atom) +{ + print("Copy_Atom (XE_LOAD_2D)\n"); + print(" BlockWidth: "); print(BlockWidth); print("\n"); + print_block_2d_traits(atom); + print("\n"); + print_block_2d_atom(atom); +} + +template +CUTE_HOST_DEVICE +void +print(Copy_Atom, + XMode, YMode, ValType, TiledStrides>, AtomValType> const& atom) +{ + print("Copy_Atom (XE_LOAD_2D_VNNI)\n"); + print(" BlockWidth: "); print(BlockWidth); print("\n"); + print_block_2d_traits(atom); + print("\n"); + print_block_2d_atom(atom); +} + +template +CUTE_HOST_DEVICE +void +print(Copy_Atom, + XMode, YMode, ValType, TiledStrides>, AtomValType> const& atom) +{ + print("Copy_Atom (XE_LOAD_2D_TRANSPOSE)\n"); + print_block_2d_traits(atom); + print("\n"); + print_block_2d_atom(atom); +} + +template +CUTE_HOST_DEVICE +void +print(Copy_Atom, + XMode, YMode, ValType, TiledStrides>, AtomValType> const& atom) +{ + print("Copy_Atom (XE_STORE_2D)\n"); + print_block_2d_traits(atom); + print("\n"); + print_block_2d_atom(atom); +} + +template +CUTE_HOST_DEVICE +void +print(Copy_Atom, + XMode, YMode, ValType, TiledStrides>, AtomValType> const& atom) +{ + print("Copy_Atom (XE_PREFETCH_2D)\n"); + print_block_2d_traits(atom); + print("\n"); + print_block_2d_atom(atom); +} + +} // end namespace cute From a28bb6b2eca11f6d9db53863f9a3deaef5ab9904 Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Mon, 15 Sep 2025 15:57:41 -0700 Subject: [PATCH 10/16] [CuTe core] Introduce SubgroupTensor class --- include/cute/atom/copy_atom.hpp | 16 +++++ include/cute/atom/mma_atom.hpp | 27 +++++++ include/cute/tensor.hpp | 2 + include/cute/tensor_sg.hpp | 122 ++++++++++++++++++++++++++++++++ 4 files changed, 167 insertions(+) create mode 100644 include/cute/tensor_sg.hpp diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 4c653c9433..8017f7128a 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -426,6 +426,22 @@ struct ThrCopy partition_fragment_D(DTensor&& dtensor) const { return make_fragment_like(partition_D(dtensor)); } + + template + CUTE_HOST_DEVICE + auto + partition_sg_fragment_S(STensor&& stensor) const { + return make_subgroup_tensor(partition_fragment_S(stensor), + layout(atom_partition_S(stensor))); + } + + template + CUTE_HOST_DEVICE + auto + 
partition_sg_fragment_D(DTensor&& dtensor) const { + return make_subgroup_tensor(partition_fragment_D(dtensor), + layout(atom_partition_D(dtensor))); + } }; diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 1e01fb1106..719ec8e156 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -518,6 +518,33 @@ struct ThrMMA : TiledMMA { return TiledMMA::make_fragment_B(partition_B(btensor)); } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_sg_fragment_C(CTensor&& ctensor) const + { + return make_subgroup_tensor(partition_fragment_C(ctensor), + layout(atom_partition_C(ctensor))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_sg_fragment_A(ATensor&& atensor) const + { + return make_subgroup_tensor(partition_fragment_A(atensor), + layout(atom_partition_A(atensor))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_sg_fragment_B(BTensor&& btensor) const + { + return make_subgroup_tensor(partition_fragment_B(btensor), + layout(atom_partition_B(btensor))); + } }; // diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp index 659c903f48..3171d214ee 100644 --- a/include/cute/tensor.hpp +++ b/include/cute/tensor.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -31,6 +32,7 @@ #pragma once #include +#include // // Extended Engines diff --git a/include/cute/tensor_sg.hpp b/include/cute/tensor_sg.hpp new file mode 100644 index 0000000000..b128bf4e13 --- /dev/null +++ b/include/cute/tensor_sg.hpp @@ -0,0 +1,122 @@ +/*************************************************************************************************** +* Copyright (C) 2025 Intel Corporation, All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ + +#pragma once + +#include // cute::Tensor + +namespace cute +{ + +// +// SubgroupTensor +// +// A SubgroupTensor represents a subgroup-scope tensor, +// e.g. the result of a block 2D load, or an input/output matrix to DPAS. +// +// SubgroupTensor wraps a standard CuTe rmem tensor ("fragment") holding the current +// work-item's slice of the tensor. It implicitly decays to this fragment, so it can +// be used as a regular rmem Tensor. +// +// In addition, a SubgroupTensor holds a thread-value layout identifying logical coordinates +// for each element of the tensor. The interpretation of the logical coordinates is user-defined, +// Reorder operations use these logical coordinates to identify corresponding values in +// the source and destination tensors. +// + +template V + class SubgroupTVLayout> // (T,V) -> coord in subgroup +struct SubgroupTensor : Tensor +{ + using Base = Tensor; + + using typename Base::iterator; + using typename Base::value_type; + using typename Base::element_type; + using typename Base::reference; + using typename Base::engine_type; + using typename Base::layout_type; + + CUTE_HOST_DEVICE constexpr + SubgroupTensor() {} + + CUTE_HOST_DEVICE constexpr explicit + SubgroupTensor(Base const& base) { + *this = static_cast(base); + } + + static constexpr int rank = Layout::rank; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + tensor() const { + return *static_cast(this); + } + + CUTE_HOST_DEVICE constexpr + auto + tv_layout() const { + return SubgroupTVLayout{}; + } +}; + +template +struct is_tensor> : true_type {}; + +template::value)> +CUTE_HOST_DEVICE +constexpr auto +make_subgroup_tensor(Tensor const& tensor, SubgroupTVLayout const&) +{ + static_assert(is_static_v, "Subgroup TV layout must be static"); + static_assert(is_rmem_v, "Expected an rmem tensor"); + return static_cast const&>(tensor); +} + + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(SubgroupTensor const& tensor) +{ + print("SubgroupTensor\n"); + print(" Tensor: "); print(static_cast const&>(tensor)); print("\n"); + print(" SubgroupTVLayout: "); print(SubgroupTVLayout{}); print("\n"); +} + +} // end namespace cute + From 9e00f6df5d33c96345d93f41c2e94336dd0022ad Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Tue, 26 Aug 2025 10:59:09 -0700 Subject: [PATCH 11/16] [CuTe core] Add atom_partition_* methods to MMA/copy atoms --- include/cute/atom/copy_atom.hpp | 26 +++++++++++++++++++++ include/cute/atom/mma_atom.hpp | 40 +++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 8017f7128a..69f492807d 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -395,6 +395,32 @@ struct ThrCopy return thr_tensor(thr_idx_, _, repeat>(_)); } + template + CUTE_HOST_DEVICE + auto + atom_partition_S(STensor&& stensor) const { + // Get fragment 
layout, and group atom thread modes (ThrV) since that is not done by tidfrg_D. + static constexpr auto RThrV = rank<0>(typename TiledCopy::AtomLayoutSrc{}); + auto tf_layout0 = TiledCopy::tidfrg_S(stensor.layout()); + auto tf_layout = replace<0>(tf_layout0, group<0,RThrV>(get<0>(tf_layout0))); + auto thr_tensor = make_tensor(static_cast(stensor).data(), tf_layout); + // Index, selecting full ThrV slice. + auto thr = idx2crd(thr_idx_, shape<0>(thr_tensor)); + return thr_tensor(replace<0>(thr, _), _, _); + } + + template + CUTE_HOST_DEVICE + auto + atom_partition_D(DTensor&& dtensor) const { + static constexpr auto RThrV = rank<0>(typename TiledCopy::AtomLayoutDst{}); + auto tf_layout0 = TiledCopy::tidfrg_D(dtensor.layout()); + auto tf_layout = replace<0>(tf_layout0, group<0,RThrV>(get<0>(tf_layout0))); + auto thr_tensor = make_tensor(static_cast(dtensor).data(), tf_layout); + auto thr = idx2crd(thr_idx_, shape<0>(thr_tensor)); + return thr_tensor(replace<0>(thr, _), _, _); + } + template CUTE_HOST_DEVICE static auto diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 719ec8e156..b950a8d64f 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -372,6 +372,12 @@ struct TiledMMA : MMA_Atom return get_slice(thr_idx); } + CUTE_HOST_DEVICE constexpr + auto + tile_mnk() const { + return make_tile(tile_size_mnk<0>(), tile_size_mnk<1>(), tile_size_mnk<2>()); + } + // // Utility for printing and visualization // @@ -495,6 +501,40 @@ struct ThrMMA : TiledMMA return thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); } + // Atom-level partitioning + template + CUTE_HOST_DEVICE constexpr + auto + atom_partition_C(CTensor&& ctensor) const + { + auto thr_tensor = make_tensor(static_cast(ctensor).data(), this->thrfrg_C(ctensor.layout())); + + auto atom_vmn = make_coord(_, make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_))); + return thr_tensor(atom_vmn, _); // (atom-local thr, val) -> coord + } + + template + CUTE_HOST_DEVICE constexpr + auto + atom_partition_A(ATensor&& atensor) const + { + auto thr_tensor = make_tensor(static_cast(atensor).data(), this->thrfrg_A(atensor.layout())); + + auto atom_vmk = make_coord(_, make_coord(get<1>(thr_vmnk_), get<3>(thr_vmnk_))); + return thr_tensor(atom_vmk, _); + } + + template + CUTE_HOST_DEVICE constexpr + auto + atom_partition_B(BTensor&& btensor) const + { + auto thr_tensor = make_tensor(static_cast(btensor).data(), this->thrfrg_B(btensor.layout())); + + auto atom_vnk = make_coord(_, make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_))); + return thr_tensor(atom_vnk, _); + } + template CUTE_HOST_DEVICE constexpr auto From 2a6739efba22ac9de5658b1b602a72173a20890b Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Mon, 18 Aug 2025 22:23:29 -0700 Subject: [PATCH 12/16] [CuTe] Introduce reorders --- include/cute/algorithm/reorder.hpp | 235 +++++ include/cute/arch/reorder.hpp | 48 + include/cute/arch/reorder_xe.hpp | 1230 +++++++++++++++++++++++++ include/cute/arch/util.hpp | 14 + include/cute/atom/reorder_atom.hpp | 51 + include/cute/atom/reorder_atom_xe.hpp | 241 +++++ include/cute/tensor.hpp | 1 + 7 files changed, 1820 insertions(+) create mode 100644 include/cute/algorithm/reorder.hpp create mode 100644 include/cute/arch/reorder.hpp create mode 100644 include/cute/arch/reorder_xe.hpp create mode 100644 include/cute/atom/reorder_atom.hpp create mode 100644 include/cute/atom/reorder_atom_xe.hpp diff --git a/include/cute/algorithm/reorder.hpp b/include/cute/algorithm/reorder.hpp new file mode 100644 index 
0000000000..ea2d6c6018 --- /dev/null +++ b/include/cute/algorithm/reorder.hpp @@ -0,0 +1,235 @@ +/*************************************************************************************************** +* Copyright (C) 2025 Intel Corporation, All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include + +// Subgroup-level ("warp" in CUDA terminology) register-to-register reorder operations. +// Currently implemented for Xe only. + + +namespace cute +{ + +namespace detail { + +// Modify subgroup TV layout for subbyte types. +// +// In general on Xe successive elements in registers are assigned to work-items in +// round-robin order (interleaved at element granularity). However, subbyte types are +// only interleaved at byte granularity. +// +// This routine modifies the incoming layout to appear as though work-item ownership for subbyte +// types is also at element granularity, to uniformize later logic. +template +CUTE_HOST_DEVICE +constexpr decltype(auto) +subbyte_sg_tv_swizzle(const InLayout &layout) +{ +#ifdef SYCL_INTEL_TARGET + using namespace intel; + if constexpr (sizeof_bits_v >= 8) + return layout; + else { + static_assert(is_static_v, "Layout must be static"); + constexpr auto values = size(InLayout{}) / sg_size; + constexpr auto per_byte = 8 / sizeof_bits_v; + static_assert(values % per_byte == 0, "Partially-occupied bytes in layout"); + return composition(layout, Layout, C>, Shape, C>>, + Stride, Stride, C>>>{}); + } +#else + return layout; +#endif +} + +} /* namespace detail */ + +// Subgroup-cooperative reorder. +// src, dst: WI-owned fragments +// slayout, dlayout: subgroup TV-layouts for these fragments. +// +// The layout of src/dst can be arbitrary. The TV layouts +// are used to map values in src to values in dst. 
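+//
+// A typical use pairs the SubgroupTensor fragments produced by ThrCopy::partition_sg_fragment_*
+// and ThrMMA::partition_sg_fragment_*, whose TV layouts travel with the tensors. Illustrative
+// sketch only; the tiled copy/MMA slices (thr_copy_a, tiled_copy_a, thr_mma) and the tile gA are
+// assumed to exist:
+//
+//   auto tAsA = thr_copy_a.partition_sg_fragment_D(gA);    // fragment in the copy atom's value layout
+//   auto tArA = thr_mma.partition_sg_fragment_A(gA);       // fragment in the MMA atom's value layout
+//   copy(tiled_copy_a, thr_copy_a.partition_S(gA), tAsA);  // load; SubgroupTensor acts as an rmem Tensor
+//   reorder(tAsA, tArA);                                   // TV layouts taken from the SubgroupTensors
+//
+// The four-argument overload below serves plain rmem fragments whose TV layouts are supplied
+// explicitly.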
+template +CUTE_HOST_DEVICE +void +reorder(Tensor const& src, // WI fragment + Tensor & dst, // WI fragment + SLayout const& slayout, // (src thr, src val) -> coord + DLayout const& dlayout) // (dst thr, dst val) -> coord +{ + using SType = typename SEngine::element_type; + using DType = typename DEngine::element_type; + + static_assert(is_static_v, "Reorder source layout must be static"); + static_assert(is_static_v, "Reorder destination layout must be static"); + + auto sl0 = detail::subbyte_sg_tv_swizzle(project_strides(slayout)); + auto dl0 = detail::subbyte_sg_tv_swizzle(project_strides(dlayout)); + +#ifdef SYCL_INTEL_TARGET + auto impl = choose_xe_reorder_impl(sl0, dl0); // -> atom or dispatch tag +#else + static_assert("Reorder only implemented on Xe"); +#endif + + reorder_impl(impl, src, dst, sl0, dl0); +} + +template +CUTE_HOST_DEVICE +void +reorder(SubgroupTensor const& src, + SubgroupTensor & dst) +{ + reorder(src, dst, src.tv_layout(), dst.tv_layout()); +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE +void +reorder(Tensor const& src, // WI fragment + Tensor && dst, // WI fragment + SLayout const& slayout, // (src thr, src val) -> coord + DLayout const& dlayout) // (dst thr, dst val) -> coord +{ + reorder(src, dst, slayout, dlayout); +} + +template +CUTE_HOST_DEVICE +void +reorder(SubgroupTensor const& src, + SubgroupTensor && dst) +{ + reorder(src, dst); +} + +// Base case for reorders: loop over reorder atoms +template +CUTE_HOST_DEVICE +void +reorder_impl(ReorderAtom const& atom, + Tensor const& src, // WI fragment + Tensor & dst, // WI fragment + SLayout const& slayout, // (src thr, src val) -> coord + DLayout const& dlayout) // (dst thr, dst val) -> coord +{ + using _SG = intel::_SGSize; + using SType = typename SEngine::element_type; + using RegistersSrc = typename ReorderAtom::SRegisters; + using RegistersDst = typename ReorderAtom::DRegisters; + using RegTypeSrc = typename remove_extent::type; + using RegTypeDst = typename remove_extent::type; + constexpr int RegNumSrc = extent::value; + constexpr int RegNumDst = extent::value; + constexpr int values = size(SLayout{}) / size<0>(SLayout{}); + constexpr int vchunk = sizeof_bits_v / sizeof_bits_v; + + // Calculate mapping from src val -> dst val on a chunk-by-chunk basis. Unlike a plain copy, there is no intrinsic + // correspondence of src/dst values for subgroup reorders. + auto rlayout = coalesce(composition(right_inverse(dlayout), slayout)); // src index -> dst index + auto vrlayout = composition(composition(Layout>, Stride<_0, _1>>{}, + rlayout), + Layout>, Stride<_0, _SG>>{}); // src val -> dst val + + CUTE_UNROLL + for (int sv = 0; sv < values; sv += vchunk) { + auto pS = recast_ptr(src.data() + sv); + auto pD = recast_ptr(dst.data() + vrlayout(sv)); + + detail::explode(detail::CallReorder{}, + pS, make_int_sequence{}, + pD, make_int_sequence{}); + } +} + +template +using upcast_subbyte_t = conditional_t, + conditional_t::is_integer, + conditional_t::is_signed, + int8_t, uint8_t>, + half_t>, + T>; + +// Reorder strategy: type conversion, then layout change. 
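+// The conversion targets the destination type (widened to at least 8 bits via upcast_subbyte_t
+// above when the destination is a subbyte type) so that the relayout step works on whole bytes;
+// the relayout call then performs any remaining narrowing.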
+template +CUTE_HOST_DEVICE +void +reorder_impl(ReorderDispatchConvertRelayout const&, + Tensor const& src, // WI fragment + Tensor & dst, // WI fragment + SLayout const& slayout, // (src thr, src val) -> coord + DLayout const& dlayout) // (dst thr, dst val) -> coord +{ + using SrcType = typename SEngine::element_type; + using DstType = typename DEngine::element_type; + using NewSrcType = conditional_t, upcast_subbyte_t, DstType>; + auto src_c = make_fragment_like(src); + + reorder(src, src_c, slayout, slayout); + reorder(src_c, dst, slayout, dlayout); +} + +// Reorder strategy: layout change, then type conversion +template +CUTE_HOST_DEVICE +void +reorder_impl(ReorderDispatchRelayoutConvert const&, + Tensor const& src, // WI fragment + Tensor & dst, // WI fragment + SLayout const& slayout, // (src thr, src val) -> coord + DLayout const& dlayout) // (dst thr, dst val) -> coord +{ + using SrcType = typename SEngine::element_type; + using DstType = typename DEngine::element_type; + using NewDstType = conditional_t, upcast_subbyte_t, SrcType>; + auto dst_c = make_fragment_like(dst); + + reorder(src, dst_c, slayout, dlayout); + reorder(dst_c, dst, dlayout, dlayout); +} + + +} // end namespace cute diff --git a/include/cute/arch/reorder.hpp b/include/cute/arch/reorder.hpp new file mode 100644 index 0000000000..a2caa033f0 --- /dev/null +++ b/include/cute/arch/reorder.hpp @@ -0,0 +1,48 @@ +/*************************************************************************************************** +* Copyright (C) 2025 Intel Corporation, All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ + +#pragma once + +namespace cute { + +// Universal reorder with no change of layout. 
+template +struct Universal_Reorder_UU { + using SRegisters = SrcType[1]; + using DRegisters = DstType[1]; + + CUTE_HOST_DEVICE static void + reorder(SrcType const& src0, DstType& dst0) { + dst0 = src0; + } +}; + +} // end namespace cute diff --git a/include/cute/arch/reorder_xe.hpp b/include/cute/arch/reorder_xe.hpp new file mode 100644 index 0000000000..496a0180a1 --- /dev/null +++ b/include/cute/arch/reorder_xe.hpp @@ -0,0 +1,1230 @@ +/*************************************************************************************************** +* Copyright (C) 2025 Intel Corporation, All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+* +**************************************************************************************************/ + +#pragma once + +#include // native vector types +#include // Universal_Reorder_UU + +#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_INTEL_TARGET) +#define CUTE_ARCH_REORDER_XE_ENABLED +#endif + +namespace cute { + +template +struct Xe_Reorder : Universal_Reorder_UU {}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 2 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + ".decl OUT_HF v_type=G type=HF num_elts=64 alias=<%0,0>\n" + "or (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<1;1,0> 0x6400:uw\n" + "or (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,32)<1;1,0> 0x6400:uw\n" + "add (M1_NM, 32) OUT_HF(0,0)<1> OUT_HF(0,0)<1;1,0> 0xE400:hf\n" + "add (M1_NM, 32) OUT_HF(1,0)<1> OUT_HF(1,0)<1;1,0> 0xE400:hf\n" + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 2 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + ".decl OUT_HF v_type=G type=HF num_elts=64 alias=<%0,0>\n" + "or (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<4;2,1> 0x6400:uw\n" + "or (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,2)<4;2,1> 0x6400:uw\n" + "add (M1_NM, 32) OUT_HF(0,0)<1> OUT_HF(0,0)<1;1,0> 0xE400:hf\n" + "add (M1_NM, 32) OUT_HF(1,0)<1> OUT_HF(1,0)<1;1,0> 0xE400:hf\n" + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 2 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + ".decl OUT_HF v_type=G type=HF num_elts=64 alias=<%0,0>\n" + "xor (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<1;1,0> 0x6480:uw\n" + "xor (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,32)<1;1,0> 0x6480:uw\n" + "add (M1_NM, 32) OUT_HF(0,0)<1> OUT_HF(0,0)<1;1,0> 0xE480:hf\n" + "add (M1_NM, 32) OUT_HF(1,0)<1> OUT_HF(1,0)<1;1,0> 0xE480:hf\n" + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 2 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + ".decl OUT_HF v_type=G type=HF num_elts=64 alias=<%0,0>\n" + "xor (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<4;2,1> 0x6480:uw\n" + "xor (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,2)<4;2,1> 0x6480:uw\n" + "add (M1_NM, 32) 
OUT_HF(0,0)<1> OUT_HF(0,0)<1;1,0> 0xE480:hf\n" + "add (M1_NM, 32) OUT_HF(1,0)<1> OUT_HF(1,0)<1;1,0> 0xE480:hf\n" + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +// Common uint8 -> bfloat16 conversion sequence, after unpacking bytes to words. +// This is defined as a macro as the compiler produces more efficient code +// when inline asm blocks are merged. +#define CUTE_XE_REORDER_U8_BF16_SEQ \ + ".decl OUT_HF v_type=G type=HF num_elts=64 alias=<%0,0>\n" \ + ".decl OUT_BF v_type=G type=BF num_elts=64 alias=<%0,0>\n" \ + "mul (M1_NM, 32) OUT_BF(0,0)<1> OUT_BF(0,0)<1;1,0> 0x7E000000:f\n" \ + "mul (M1_NM, 32) OUT_BF(1,0)<1> OUT_BF(1,0)<1;1,0> 0x7E000000:f\n" \ + "mul (M1_NM, 32) OUT_HF(0,0)<1> OUT_HF(0,0)<1;1,0> 0x4000:hf\n" \ + "mul (M1_NM, 32) OUT_HF(1,0)<1> OUT_HF(1,0)<1;1,0> 0x4000:hf\n" + + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 4 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + "mov (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<1;1,0>\n" + "mov (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,32)<1;1,0>\n" + CUTE_XE_REORDER_U8_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 4 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + "mov (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<4;2,1>\n" + "mov (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,2)<4;2,1>\n" + CUTE_XE_REORDER_U8_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + + +// Common int8 -> bfloat16 conversion sequence, after unpacking bytes to words. 
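+// The xor with 0x80 in the callers maps int8 v in [-128,127] to the biased byte b = v + 128, so the
+// same unpacking applies as in the unsigned path. Interpreted as a bf16 bit pattern, b encodes
+// b * 2^-133; the mad rescales by 2^125 and subtracts the 0.5 bias, leaving v/256 in bf16, and the
+// final multiply by 2.0:hf bumps the bf16 exponent field by 8 (x 2^8), recovering v exactly.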
+#define CUTE_XE_REORDER_S8_BF16_SEQ \ + ".decl OUT_HF v_type=G type=HF num_elts=64 alias=<%0,0>\n" \ + ".decl OUT_BF v_type=G type=BF num_elts=64 alias=<%0,0>\n" \ + ".decl F_7E000000 v_type=G type=F num_elts=1 alias=<%2,0>\n" \ + ".decl F_BF000000 v_type=G type=F num_elts=1 alias=<%3,0>\n" \ + "mad (M1_NM, 32) OUT_BF(0,0)<1> F_7E000000(0,0)<0;1,0> OUT_BF(0,0)<1;1,0> F_BF000000(0,0)<0;1,0>\n" \ + "mad (M1_NM, 32) OUT_BF(1,0)<1> F_7E000000(0,0)<0;1,0> OUT_BF(1,0)<1;1,0> F_BF000000(0,0)<0;1,0>\n" \ + "mul (M1_NM, 32) OUT_HF(0,0)<1> OUT_HF(0,0)<1;1,0> 0x4000:hf\n" \ + "mul (M1_NM, 32) OUT_HF(1,0)<1> OUT_HF(1,0)<1;1,0> 0x4000:hf\n" + + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t scale = 0x7E000000; + const uint32_t shift = 0xBF000000; + asm ( /* 4 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + "xor (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<1;1,0> 0x80:uw\n" + "xor (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,32)<1;1,0> 0x80:uw\n" + CUTE_XE_REORDER_S8_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(scale), "rw.u"(shift) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t scale = 0x7E000000; + const uint32_t shift = 0xBF000000; + asm ( /* 4 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + "xor (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<4;2,1> 0x80:uw\n" + "xor (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,2)<4;2,1> 0x80:uw\n" + CUTE_XE_REORDER_S8_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(scale), "rw.u"(shift) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 1 cycle/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + "shl (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<1;1,0> 8:uw\n" + "shl (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,32)<1;1,0> 8:uw\n" + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 1 cycle/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + "shl (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<4;2,1> 8:uw\n" + "shl (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,2)<4;2,1> 8:uw\n" + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +// Common e5m2 -> bfloat16 conversion sequence, 
after shl by 8. +#define CUTE_XE_REORDER_E5M2_BF16_SEQ \ + ".decl OUT_W v_type=G type=W num_elts=64 alias=<%0,0>\n" \ + ".decl OUT_UD v_type=G type=UD num_elts=32 alias=<%0,0>\n" \ + ".decl OUT_BF v_type=G type=BF num_elts=64 alias=<%0,0>\n" \ + "asr (M1_NM, 32) OUT_W(0,0)<1> OUT_W(0,0)<1;1,0> 3:uw\n" \ + "asr (M1_NM, 32) OUT_W(1,0)<1> OUT_W(1,0)<1;1,0> 3:uw\n" \ + "and (M1_NM, 32) OUT_UD(0,0)<1> OUT_UD(0,0)<1;1,0> 0x8FFF8FFF:ud\n" \ + "mul (M1_NM, 32) OUT_BF(0,0)<1> OUT_BF(0,0)<1;1,0> 0x77800000:f\n" \ + "mul (M1_NM, 32) OUT_BF(1,0)<1> OUT_BF(1,0)<1;1,0> 0x77800000:f\n" + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 5 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + "shl (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<1;1,0> 8:uw\n" + "shl (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,32)<1;1,0> 8:uw\n" + CUTE_XE_REORDER_E5M2_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 5 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + "shl (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<4;2,1> 8:uw\n" + "shl (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,2)<4;2,1> 8:uw\n" + CUTE_XE_REORDER_E5M2_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + + +// Common e4m3 -> half conversion sequence, after shl by 8. 
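+// After the shl by 8 in the callers, the e4m3 byte s.eeee.mmm sits in bits 15:8. The asr by 1 plus
+// the 0xBFFF mask align exponent and mantissa with half's fields while keeping the sign, leaving the
+// value scaled by 2^-8; the remaining steps rescale by 2^8 while also handling e4m3 NaN encodings
+// (see the note inside the sequence for the NaN-free shortcut).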
+#define CUTE_XE_REORDER_E4M3_HALF_SEQ \ + ".decl OUT_W v_type=G type=W num_elts=64 alias=<%0,0>\n" \ + ".decl OUT_UD v_type=G type=UD num_elts=32 alias=<%0,0>\n" \ + ".decl OUT_HF v_type=G type=HF num_elts=64 alias=<%0,0>\n" \ + "asr (M1_NM, 32) OUT_W(0,0)<1> OUT_W(0,0)<1;1,0> 1:uw\n" \ + "asr (M1_NM, 32) OUT_W(1,0)<1> OUT_W(1,0)<1;1,0> 1:uw\n" \ + "and (M1_NM, 32) OUT_UD(0,0)<1> OUT_UD(0,0)<1;1,0> 0xBFFFBFFF:ud\n" \ + /* If no NaN inputs, the rest of the sequence can be replaced with: */ \ + /* "mul (M1_NM, 32) OUT_HF(0,0)<1> OUT_HF(0,0)<1;1,0> 0x5C00:hf\n" */ \ + /* "mul (M1_NM, 32) OUT_HF(1,0)<1> OUT_HF(1,0)<1;1,0> 0x5C00:hf\n" */ \ + "mul (M1_NM, 32) OUT_HF(0,0)<1> OUT_HF(0,0)<1;1,0> 0x7880:hf\n" \ + "mul (M1_NM, 32) OUT_HF(1,0)<1> OUT_HF(1,0)<1;1,0> 0x7880:hf\n" \ + "mul (M1_NM, 32) OUT_HF(0,0)<1> OUT_HF(0,0)<1;1,0> 0x1F1C:hf\n" \ + "mul (M1_NM, 32) OUT_HF(1,0)<1> OUT_HF(1,0)<1;1,0> 0x1F1C:hf\n" \ + "mad (M1_NM, 32) OUT_HF(0,0)<1> 0x0:hf OUT_HF(0,0)<1;1,0> OUT_HF(0,0)<1;1,0>\n" \ + "mad (M1_NM, 32) OUT_HF(1,0)<1> 0x0:hf OUT_HF(1,0)<1;1,0> OUT_HF(1,0)<1;1,0>\n" + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 6 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + "shl (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<1;1,0> 8:uw\n" + "shl (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,32)<1;1,0> 8:uw\n" + CUTE_XE_REORDER_E4M3_HALF_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 6 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + "shl (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<4;2,1> 8:uw\n" + "shl (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,2)<4;2,1> 8:uw\n" + CUTE_XE_REORDER_E4M3_HALF_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +// Common e4m3 -> bfloat16 conversion sequence, after shl by 8. 
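+// After the shl by 8 in the callers, the asr by 4 plus the 0x87FF mask drop the e4m3 fields into
+// bf16's exponent/mantissa positions, leaving the value scaled by 2^-120; the multiply by
+// 0x7B800000 (= 2^120) restores the magnitude, and the predicated mov rewrites e4m3 NaN encodings
+// to bf16 NaN (0x7FC0).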
+#define CUTE_XE_REORDER_E4M3_BF16_SEQ \ + ".decl OUT_W v_type=G type=W num_elts=64 alias=<%0,0>\n" \ + ".decl OUT_UD v_type=G type=UD num_elts=32 alias=<%0,0>\n" \ + ".decl OUT_HF v_type=G type=HF num_elts=64 alias=<%0,0>\n" \ + ".decl OUT_BF v_type=G type=BF num_elts=64 alias=<%0,0>\n" \ + ".decl NZ_PRED0 v_type=P num_elts=32\n" \ + ".decl NZ_PRED1 v_type=P num_elts=32\n" \ + "asr (M1_NM, 32) OUT_W(0,0)<1> OUT_W(0,0)<1;1,0> 4:uw\n" \ + "asr (M1_NM, 32) OUT_W(1,0)<1> OUT_W(1,0)<1;1,0> 4:uw\n" \ + "and (M1_NM, 32) OUT_UD(0,0)<1> OUT_UD(0,0)<1;1,0> 0x87FF87FF:ud\n" \ + "cmp.ge (M1_NM, 32) NZ_PRED0 (abs)OUT_HF(0,0)<1;1,0> 0x07F0:hf\n" \ + "cmp.ge (M1_NM, 32) NZ_PRED1 (abs)OUT_HF(1,0)<1;1,0> 0x07F0:hf\n" \ + "mul (M1_NM, 32) OUT_BF(0,0)<1> OUT_BF(0,0)<1;1,0> 0x7B800000:f\n" \ + "mul (M1_NM, 32) OUT_BF(1,0)<1> OUT_BF(1,0)<1;1,0> 0x7B800000:f\n" \ + "(NZ_PRED0) mov (M1_NM, 32) OUT_UW(0,0)<1> 0x7FC0:uw\n" \ + "(NZ_PRED1) mov (M1_NM, 32) OUT_UW(1,0)<1> 0x7FC0:uw\n" + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 7 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + "shl (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<1;1,0> 8:uw\n" + "shl (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,32)<1;1,0> 8:uw\n" + CUTE_XE_REORDER_E4M3_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort4[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort4& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + asm ( /* 7 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=64 alias=<%0,0>\n" + "shl (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<4;2,1> 8:uw\n" + "shl (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,2)<4;2,1> 8:uw\n" + CUTE_XE_REORDER_E4M3_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +// Common uint4 -> half conversion sequence, after expanding nybbles to words. 
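+// Each nybble v is merged into the mantissa of half 1024.0 (0x6400) by the bfn below, giving
+// 1024 + v; adding -1024.0 (0xE400) then yields v exactly. The int4 and bf16 variants further down
+// use the same trick with adjusted bias/exponent constants.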
+#define CUTE_XE_REORDER_U4_HALF_SEQ \ + ".decl OUT_HF v_type=G type=HF num_elts=128 alias=<%0,0>\n" \ + "bfn.xCA (M1_NM, 32) OUT_UW(0,0)<1> 0x6400:uw OUT_UW(0,0)<1;1,0> 0xF:uw\n" \ + "bfn.xCA (M1_NM, 32) OUT_UW(1,0)<1> 0x6400:uw OUT_UW(1,0)<1;1,0> 0xF:uw\n" \ + "bfn.xCA (M1_NM, 32) OUT_UW(2,0)<1> 0x6400:uw OUT_UW(2,0)<1;1,0> 0xF:uw\n" \ + "bfn.xCA (M1_NM, 32) OUT_UW(3,0)<1> 0x6400:uw OUT_UW(3,0)<1;1,0> 0xF:uw\n" \ + "add (M1_NM, 32) OUT_HF(0,0)<1> OUT_HF(0,0)<1;1,0> 0xE400:hf\n" \ + "add (M1_NM, 32) OUT_HF(1,0)<1> OUT_HF(1,0)<1;1,0> 0xE400:hf\n" \ + "add (M1_NM, 32) OUT_HF(2,0)<1> OUT_HF(2,0)<1;1,0> 0xE400:hf\n" \ + "add (M1_NM, 32) OUT_HF(3,0)<1> OUT_HF(3,0)<1;1,0> 0xE400:hf\n" + + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x00040000; + asm ( /* 3 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shr (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,16)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(2,0)<1> IN_UB(0,32)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(3,0)<1> IN_UB(0,48)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_U4_HALF_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x00040000; + asm ( /* 4 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shr (M1_NM, 16) OUT_UW(0,0)<2> IN_UB(0,0)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(0,1)<2> IN_UB(0,8)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(1,0)<2> IN_UB(0,16)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(1,1)<2> IN_UB(0,24)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(2,0)<2> IN_UB(0,32)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(2,1)<2> IN_UB(0,40)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(3,0)<2> IN_UB(0,48)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(3,1)<2> IN_UB(0,56)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_U4_HALF_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x00040000; + asm ( /* 3 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shr (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) 
OUT_UW(1,0)<1> IN_UB(0,1)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(2,0)<1> IN_UB(0,2)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(3,0)<1> IN_UB(0,3)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_U4_HALF_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +// Common int4 -> half conversion sequence, after expanding nybbles to words. +#define CUTE_XE_REORDER_S4_HALF_SEQ \ + ".decl OUT_HF v_type=G type=HF num_elts=128 alias=<%0,0>\n" \ + "bfn.x6A (M1_NM, 32) OUT_UW(0,0)<1> 0x6408:uw OUT_UW(0,0)<1;1,0> 0xF:uw\n" \ + "bfn.x6A (M1_NM, 32) OUT_UW(1,0)<1> 0x6408:uw OUT_UW(1,0)<1;1,0> 0xF:uw\n" \ + "bfn.x6A (M1_NM, 32) OUT_UW(2,0)<1> 0x6408:uw OUT_UW(2,0)<1;1,0> 0xF:uw\n" \ + "bfn.x6A (M1_NM, 32) OUT_UW(3,0)<1> 0x6408:uw OUT_UW(3,0)<1;1,0> 0xF:uw\n" \ + "add (M1_NM, 32) OUT_HF(0,0)<1> OUT_HF(0,0)<1;1,0> 0xE408:hf\n" \ + "add (M1_NM, 32) OUT_HF(1,0)<1> OUT_HF(1,0)<1;1,0> 0xE408:hf\n" \ + "add (M1_NM, 32) OUT_HF(2,0)<1> OUT_HF(2,0)<1;1,0> 0xE408:hf\n" \ + "add (M1_NM, 32) OUT_HF(3,0)<1> OUT_HF(3,0)<1;1,0> 0xE408:hf\n" + + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x00040000; + asm ( /* 3 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shr (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,16)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(2,0)<1> IN_UB(0,32)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(3,0)<1> IN_UB(0,48)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_S4_HALF_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x00040000; + asm ( /* 4 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shr (M1_NM, 16) OUT_UW(0,0)<2> IN_UB(0,0)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(0,1)<2> IN_UB(0,8)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(1,0)<2> IN_UB(0,16)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(1,1)<2> IN_UB(0,24)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(2,0)<2> IN_UB(0,32)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(2,1)<2> IN_UB(0,40)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(3,0)<2> IN_UB(0,48)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(3,1)<2> IN_UB(0,56)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_S4_HALF_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, 
intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x00040000; + asm ( /* 3 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shr (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,1)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(2,0)<1> IN_UB(0,2)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(3,0)<1> IN_UB(0,3)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_S4_HALF_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +// Common uint4 -> bfloat16 conversion sequence, after expanding nybbles to words. +#define CUTE_XE_REORDER_U4_BF16_SEQ \ + ".decl OUT_BF v_type=G type=BF num_elts=128 alias=<%0,0>\n" \ + "bfn.xCA (M1_NM, 32) OUT_UW(0,0)<1> 0x4300:uw OUT_UW(0,0)<1;1,0> 0xF:uw\n" \ + "bfn.xCA (M1_NM, 32) OUT_UW(1,0)<1> 0x4300:uw OUT_UW(1,0)<1;1,0> 0xF:uw\n" \ + "bfn.xCA (M1_NM, 32) OUT_UW(2,0)<1> 0x4300:uw OUT_UW(2,0)<1;1,0> 0xF:uw\n" \ + "bfn.xCA (M1_NM, 32) OUT_UW(3,0)<1> 0x4300:uw OUT_UW(3,0)<1;1,0> 0xF:uw\n" \ + "add (M1_NM, 32) OUT_BF(0,0)<1> OUT_BF(0,0)<1;1,0> 0xC3000000:f\n" \ + "add (M1_NM, 32) OUT_BF(1,0)<1> OUT_BF(1,0)<1;1,0> 0xC3000000:f\n" \ + "add (M1_NM, 32) OUT_BF(2,0)<1> OUT_BF(2,0)<1;1,0> 0xC3000000:f\n" \ + "add (M1_NM, 32) OUT_BF(3,0)<1> OUT_BF(3,0)<1;1,0> 0xC3000000:f\n" + + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x00040000; + asm ( /* 4 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shr (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,16)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(2,0)<1> IN_UB(0,32)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(3,0)<1> IN_UB(0,48)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_U4_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x00040000; + asm ( /* 5 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shr (M1_NM, 16) OUT_UW(0,0)<2> IN_UB(0,0)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(0,1)<2> IN_UB(0,8)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(1,0)<2> IN_UB(0,16)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(1,1)<2> IN_UB(0,24)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(2,0)<2> IN_UB(0,32)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(2,1)<2> IN_UB(0,40)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(3,0)<2> 
IN_UB(0,48)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(3,1)<2> IN_UB(0,56)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_U4_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x00040000; + asm ( /* 4 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shr (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,1)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(2,0)<1> IN_UB(0,2)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(3,0)<1> IN_UB(0,3)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_U4_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +// Common int4 -> bfloat16 conversion sequence, after expanding nybbles to words. +#define CUTE_XE_REORDER_S4_BF16_SEQ \ + ".decl OUT_BF v_type=G type=BF num_elts=128 alias=<%0,0>\n" \ + "bfn.x6A (M1_NM, 32) OUT_UW(0,0)<1> 0x4308:uw OUT_UW(0,0)<1;1,0> 0xF:uw\n" \ + "bfn.x6A (M1_NM, 32) OUT_UW(1,0)<1> 0x4308:uw OUT_UW(1,0)<1;1,0> 0xF:uw\n" \ + "bfn.x6A (M1_NM, 32) OUT_UW(2,0)<1> 0x4308:uw OUT_UW(2,0)<1;1,0> 0xF:uw\n" \ + "bfn.x6A (M1_NM, 32) OUT_UW(3,0)<1> 0x4308:uw OUT_UW(3,0)<1;1,0> 0xF:uw\n" \ + "add (M1_NM, 32) OUT_BF(0,0)<1> OUT_BF(0,0)<1;1,0> 0xC3080000:f\n" \ + "add (M1_NM, 32) OUT_BF(1,0)<1> OUT_BF(1,0)<1;1,0> 0xC3080000:f\n" \ + "add (M1_NM, 32) OUT_BF(2,0)<1> OUT_BF(2,0)<1;1,0> 0xC3080000:f\n" \ + "add (M1_NM, 32) OUT_BF(3,0)<1> OUT_BF(3,0)<1;1,0> 0xC3080000:f\n" + + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x00040000; + asm ( /* 4 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shr (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,16)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(2,0)<1> IN_UB(0,32)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(3,0)<1> IN_UB(0,48)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_S4_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x00040000; + asm ( /* 5 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shr (M1_NM, 16) OUT_UW(0,0)<2> 
IN_UB(0,0)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(0,1)<2> IN_UB(0,8)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(1,0)<2> IN_UB(0,16)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(1,1)<2> IN_UB(0,24)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(2,0)<2> IN_UB(0,32)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(2,1)<2> IN_UB(0,40)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(3,0)<2> IN_UB(0,48)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 16) OUT_UW(3,1)<2> IN_UB(0,56)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_S4_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x00040000; + asm ( /* 4 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shr (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,1)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(2,0)<1> IN_UB(0,2)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shr (M1_NM, 32) OUT_UW(3,0)<1> IN_UB(0,3)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_S4_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +// Common e2m1 -> half conversion sequence, after moving nybbles to highest 4 bits of each word. 
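+// The shl in the callers parks each e2m1 nybble s.ee.m in bits 15:12. The asr by 3 sign-extends and
+// drops e/m into half's exponent/mantissa positions, the 0x8E00 mask clears the extra sign copies,
+// and what remains is the e2m1 value scaled by 2^-14, which the multiply by 0x7400 (= 16384.0) undoes.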
+#define CUTE_XE_REORDER_E2M1_HALF_SEQ \ + ".decl OUT_W v_type=G type=W num_elts=128 alias=<%0,0>\n" \ + ".decl OUT_UD v_type=G type=UD num_elts=64 alias=<%0,0>\n" \ + ".decl OUT_HF v_type=G type=HF num_elts=128 alias=<%0,0>\n" \ + "asr (M1_NM, 32) OUT_W(0,0)<1> OUT_W(0,0)<1;1,0> 3:uw\n" \ + "asr (M1_NM, 32) OUT_W(1,0)<1> OUT_W(1,0)<1;1,0> 3:uw\n" \ + "asr (M1_NM, 32) OUT_W(2,0)<1> OUT_W(2,0)<1;1,0> 3:uw\n" \ + "asr (M1_NM, 32) OUT_W(3,0)<1> OUT_W(3,0)<1;1,0> 3:uw\n" \ + "and (M1_NM, 32) OUT_UD(0,0)<1> OUT_UD(0,0)<1;1,0> 0x8E008E00:ud\n" \ + "and (M1_NM, 32) OUT_UD(2,0)<1> OUT_UD(2,0)<1;1,0> 0x8E008E00:ud\n" \ + "mul (M1_NM, 32) OUT_HF(0,0)<1> OUT_HF(0,0)<1;1,0> 0x7400:hf\n" \ + "mul (M1_NM, 32) OUT_HF(1,0)<1> OUT_HF(1,0)<1;1,0> 0x7400:hf\n" \ + "mul (M1_NM, 32) OUT_HF(2,0)<1> OUT_HF(2,0)<1;1,0> 0x7400:hf\n" \ + "mul (M1_NM, 32) OUT_HF(3,0)<1> OUT_HF(3,0)<1;1,0> 0x7400:hf\n" + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x0008000C; + asm ( /* 4 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shl (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,16)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 32) OUT_UW(2,0)<1> IN_UB(0,32)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 32) OUT_UW(3,0)<1> IN_UB(0,48)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_E2M1_HALF_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x0008000C; + asm ( /* 5 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shl (M1_NM, 16) OUT_UW(0,0)<2> IN_UB(0,0)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) OUT_UW(0,1)<2> IN_UB(0,8)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) OUT_UW(1,0)<2> IN_UB(0,16)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) OUT_UW(1,1)<2> IN_UB(0,24)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) OUT_UW(2,0)<2> IN_UB(0,32)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) OUT_UW(2,1)<2> IN_UB(0,40)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) OUT_UW(3,0)<2> IN_UB(0,48)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) OUT_UW(3,1)<2> IN_UB(0,56)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_E2M1_HALF_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x0008000C; + asm ( /* 4 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G 
type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shl (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,1)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 32) OUT_UW(2,0)<1> IN_UB(0,2)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 32) OUT_UW(3,0)<1> IN_UB(0,3)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_E2M1_HALF_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +// Common e2m1 -> bfloat16 conversion sequence, after moving nybbles to highest 4 bits of each word. +#define CUTE_XE_REORDER_E2M1_BF16_SEQ \ + ".decl OUT_W v_type=G type=W num_elts=128 alias=<%0,0>\n" \ + ".decl OUT_UD v_type=G type=UD num_elts=64 alias=<%0,0>\n" \ + ".decl OUT_BF v_type=G type=BF num_elts=128 alias=<%0,0>\n" \ + "asr (M1_NM, 32) OUT_W(0,0)<1> OUT_W(0,0)<1;1,0> 6:uw\n" \ + "asr (M1_NM, 32) OUT_W(1,0)<1> OUT_W(1,0)<1;1,0> 6:uw\n" \ + "asr (M1_NM, 32) OUT_W(2,0)<1> OUT_W(2,0)<1;1,0> 6:uw\n" \ + "asr (M1_NM, 32) OUT_W(3,0)<1> OUT_W(3,0)<1;1,0> 6:uw\n" \ + "and (M1_NM, 32) OUT_UD(0,0)<1> OUT_UD(0,0)<1;1,0> 0x81C081C0:ud\n" \ + "and (M1_NM, 32) OUT_UD(2,0)<1> OUT_UD(2,0)<1;1,0> 0x81C081C0:ud\n" \ + "mul (M1_NM, 32) OUT_BF(0,0)<1> OUT_BF(0,0)<1;1,0> 0x7E800000:f\n" \ + "mul (M1_NM, 32) OUT_BF(1,0)<1> OUT_BF(1,0)<1;1,0> 0x7E800000:f\n" \ + "mul (M1_NM, 32) OUT_BF(2,0)<1> OUT_BF(2,0)<1;1,0> 0x7E800000:f\n" \ + "mul (M1_NM, 32) OUT_BF(3,0)<1> OUT_BF(3,0)<1;1,0> 0x7E800000:f\n" + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x0008000C; + asm ( /* 5 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shl (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,16)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 32) OUT_UW(2,0)<1> IN_UB(0,32)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 32) OUT_UW(3,0)<1> IN_UB(0,48)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_E2M1_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x0008000C; + asm ( /* 6 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shl (M1_NM, 16) OUT_UW(0,0)<2> IN_UB(0,0)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) OUT_UW(0,1)<2> IN_UB(0,8)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) OUT_UW(1,0)<2> IN_UB(0,16)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) OUT_UW(1,1)<2> IN_UB(0,24)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) OUT_UW(2,0)<2> IN_UB(0,32)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) OUT_UW(2,1)<2> IN_UB(0,40)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) OUT_UW(3,0)<2> IN_UB(0,48)<1;2,0> 
SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 16) OUT_UW(3,1)<2> IN_UB(0,56)<1;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_E2M1_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + +template <> +struct Xe_Reorder +{ + using SRegisters = intel::uchar4[1]; + using DRegisters = intel::ushort8[1]; + + CUTE_HOST_DEVICE static void + reorder(intel::uchar4 const& src0, intel::ushort8& dst0) + { +#if defined(CUTE_ARCH_COPY_XE_ENABLED) + const uint32_t shifts = 0x0008000C; + asm ( /* 5 cycles/output register */ + "{\n" + ".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n" + ".decl OUT_UW v_type=G type=UW num_elts=128 alias=<%0,0>\n" + ".decl SHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n" + "shl (M1_NM, 32) OUT_UW(0,0)<1> IN_UB(0,0)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 32) OUT_UW(1,0)<1> IN_UB(0,1)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 32) OUT_UW(2,0)<1> IN_UB(0,2)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + "shl (M1_NM, 32) OUT_UW(3,0)<1> IN_UB(0,3)<4;2,0> SHIFTS(0,0)<0;2,1>\n" + CUTE_XE_REORDER_E2M1_BF16_SEQ + "}\n" + : "=rw"(dst0) + : "rw"(src0), "rw.u"(shifts) + ); +#else + CUTE_INVALID_CONTROL_PATH("Not Xe"); +#endif + } +}; + + + + +} // end namespace cute diff --git a/include/cute/arch/util.hpp b/include/cute/arch/util.hpp index c9ff9ef878..85d18c88a7 100644 --- a/include/cute/arch/util.hpp +++ b/include/cute/arch/util.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -163,6 +164,19 @@ struct CallCOPY { } }; +// +// Wrapper for ReorderOp::reorder +// + +template +struct CallReorder { + template + CUTE_HOST_DEVICE constexpr void + operator()(Args&&... args) const { + return ReorderOp::reorder(static_cast(args)...); + } +}; + // // Utility for exploding pointers/arrays/tensors into functions // diff --git a/include/cute/atom/reorder_atom.hpp b/include/cute/atom/reorder_atom.hpp new file mode 100644 index 0000000000..e89f8e589e --- /dev/null +++ b/include/cute/atom/reorder_atom.hpp @@ -0,0 +1,51 @@ +/*************************************************************************************************** +* Copyright (C) 2025 Intel Corporation, All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ + +#pragma once + +namespace cute +{ + +// Reorder dispatch tags +struct ReorderDispatchRelayoutConvert {}; // Change layout, then convert +struct ReorderDispatchConvertRelayout {}; // Convert, then change layout +#ifdef SYCL_INTEL_TARGET +struct ReorderDispatchXeGeneric {}; // Generic Xe subgroup reorder operation +#endif + +} // end namespace cute + + +#include + +#ifdef SYCL_INTEL_TARGET +#include +#endif diff --git a/include/cute/atom/reorder_atom_xe.hpp b/include/cute/atom/reorder_atom_xe.hpp new file mode 100644 index 0000000000..c7bc546fce --- /dev/null +++ b/include/cute/atom/reorder_atom_xe.hpp @@ -0,0 +1,241 @@ +/*************************************************************************************************** +* Copyright (C) 2025 Intel Corporation, All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ + +#pragma once + +#include + +namespace cute +{ + +// Reorder kinds. +// "U" -> unit stride in registers +// "V" -> VNNI format in registers +// e.g. UV means "unit to VNNI" +enum class ReorderKind : int { UU_Universal, UU, UV, VU, VV, Generic }; + +template +struct Xe_Reorder { + using Unimplemented = void; +}; + +// Check for the existence of an optimized reorder sequence. 
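+// (Detection idiom: the primary Xe_Reorder template defines an Unimplemented typedef, while optimized specializations do not. +// The int overload below SFINAEs on that typedef and returns false; for an optimized specialization the substitution fails, +// so overload resolution falls back to the char overload, which returns true.)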
+template +constexpr bool has_xe_optimized_reorder_impl(char) { return true; } +template ::Unimplemented> +constexpr bool has_xe_optimized_reorder_impl(int) { return false; } + +template +constexpr bool has_xe_optimized_reorder() { + return has_xe_optimized_reorder_impl(0); +} + +// Classify a subgroup-scope reorder. +template +constexpr ReorderKind classify_xe_reorder() +{ + constexpr int R = rank(ReorderLayout{}); + using Size0 = decltype(size<0>(ReorderLayout{})); + using Stride0 = decltype(stride<0>(ReorderLayout{})); + using _SV = Int; + using _DV = Int; + + constexpr int VL = 16 * cute::max(SV, DV); + if (is_constant_v<1, Stride0>) { + // Unit stride -> unit stride. Require whole GRFs for both src and dst. + if constexpr (is_constant_v<0, decltype(Size0{} % Int{})>) { + return ReorderKind::UU; + } + + // Fallback unit->unit when we do not have full GRFs. + if constexpr (is_constant_v<0, decltype(Size0{} % _16{})>) { + return ReorderKind::UU_Universal; + } + } + + // Check for VNNI reorders. + // Fundamental assumption: values associated with a single VNNI block are contiguous in val space + // in both src and dst (even if only one of those is in VNNI format). All others take the generic path. + if constexpr (R >= 2) { + constexpr auto Modes01 = take<0,2>(ReorderLayout{}); + + // Check for unit <-> VNNI reorders. + // unit->VNNI: (_16, _DV, ...):(_DV, _1, ...) + // VNNI->unit: (_SV, _16, ...):(_16, _1, ...) + if constexpr (Modes01 == Layout, Stride<_DV, _1>>{}) { + return ReorderKind::UV; + } + if constexpr (Modes01 == Layout, Stride<_16, _1>>{}) { + return ReorderKind::VU; + } + + // Check for VNNI -> VNNI reorders. + // SV > DV: (_DV, _SV/DV, _16, ...):(_1, _DV*16, _DV, ...) + // DV > SV: (_SV, _16, _DV/SV, ...):(_1, _SV*16, _SV, ...) + if constexpr (R >= 3 && SV != DV) { + constexpr auto Modes012 = take<0,3>(ReorderLayout()); + if constexpr (SV > DV && Modes012 == Layout, _16>, Stride<_1, Int, _DV>>{}) { + return ReorderKind::VV; + } else if constexpr (DV > SV && Modes012 == Layout>, Stride<_1, Int, _SV>>{}) { + return ReorderKind::VV; + } + } + } + + return ReorderKind::Generic; +} + +template +constexpr auto choose_xe_reorder_impl(SLayout const& slayout, // (src thr, src val) -> coord + DLayout const& dlayout) { // (dst thr, dst val) -> coord + // Calculate data transformation, interleaving WI-owned values: + // (thr0,val0) ... (thr15,val0), (thr0,val1), ..., (thr15,val1), ... + auto rlayout = coalesce(composition(right_inverse(dlayout), slayout)); // src index -> dst index + + // Classify reorder type. + constexpr auto SV = 32 / sizeof_bits_v; // src elements per 32-bit channel + constexpr auto DV = 32 / sizeof_bits_v; // dst elements per 32-bit channel + constexpr auto rclass = classify_xe_reorder(); + + if constexpr (rclass == ReorderKind::UU_Universal) + return Universal_Reorder_UU{}; + else if constexpr (has_xe_optimized_reorder()) + return Xe_Reorder{}; + else if constexpr (is_subbyte_v) + return ReorderDispatchConvertRelayout{}; + else if constexpr (is_subbyte_v) + return ReorderDispatchRelayoutConvert{}; + else if constexpr (!is_same_v, remove_cv_t>) + return ReorderDispatchRelayoutConvert{}; + else + return ReorderDispatchXeGeneric{}; +} + + +// Copy a strided vector to a strided vector in GRF. +// src and dst must each fit within a single register. 
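+// In the asm below: simd is the number of elements moved, sstride/dstride are the element strides within the +// source/destination registers, and sidx/didx are element indices into the subgroup-scope source/destination spans, +// used to select the GRF and the starting offset within it.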
+template +CUTE_HOST_DEVICE +void +reorder_span(Tensor const& src, + Tensor & dst) +{ + using namespace intel; + using ValType = typename SEngine::element_type; + using StorageType = storage_vector_t; + constexpr int grf_elems = 64 / sizeof(ValType); + const auto& sv = *recast_ptr(src.data() + ((sidx / grf_elems) * (grf_elems / sg_size))); + auto& dv = *recast_ptr(dst.data() + ((didx / grf_elems) * (grf_elems / sg_size))); + constexpr auto soff = sidx % grf_elems; + constexpr auto doff = didx % grf_elems; +#ifdef __SYCL_DEVICE_ONLY__ + asm ( + "mov (M1_NM, %2) %0(0,%5)<%3> %1(0,%6)<%4;1,0>" + : "+rw"(dv) + : "rw"(sv), "P"(simd), "P"(dstride), "P"(sstride), "P"(doff), "P"(soff) + ); +#endif +} + +// Generic Xe reorders, supporting arbitrary layout changes, but not type conversions. +template +CUTE_HOST_DEVICE +void +reorder_impl(ReorderDispatchXeGeneric const&, + Tensor const& src, // WI fragment + Tensor & dst, // WI fragment + SLayout const& slayout, // (src thr, src val) -> coord + DLayout const& dlayout) // (dst thr, dst val) -> coord +{ + using SrcType = typename SEngine::element_type; + using DstType = typename DEngine::element_type; + static_assert(is_same_v, "No type conversions allowed on this path"); + + auto rlayout = coalesce(composition(right_inverse(dlayout), slayout)); // src index -> dst index + auto ilayout = coalesce(composition(right_inverse(slayout), dlayout)); // dst index -> src index + + // Decide whether to stride on src or dst, depending on which allows a longer vector length. + static constexpr int elems_per_grf = 64 / sizeof(SrcType); + static constexpr int ds_vl = cute::min(32, cute::min(shape<0>(rlayout), elems_per_grf / stride<0>(rlayout))); + static constexpr int ss_vl = cute::min(32, cute::min(shape<0>(ilayout), elems_per_grf / stride<0>(ilayout))); + + // Make dst live, to prevent compiler from inserting its own initialization. +#ifdef __SYCL_DEVICE_ONLY__ + using StorageType = intel::storage_vector_t; + + CUTE_UNROLL + for (int i = 0; i < dst.size(); i += 4 / sizeof(DstType)) { + auto &dv = *recast_ptr(dst.data() + i); + asm("" : "=rw"(dv)); + } +#endif + + if constexpr (ss_vl >= ds_vl) { + // Stride on src. For simplicity, take 1 GRF at a time. + for_each(make_seq{}, [&](auto i) { + constexpr auto didx = i * ss_vl; + constexpr auto sidx = ilayout(didx); + reorder_span(decltype(ilayout){}), 1, sidx, didx>(src, dst); + }); + } else { + // Stride on dst. 
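+    // Each iteration moves ds_vl contiguous source elements into strided destination positions, +    // with the destination stride given by rlayout.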
+ for_each(make_seq{}, [&](auto i) { + constexpr auto sidx = i * ds_vl; + constexpr auto didx = rlayout(sidx); + reorder_span(decltype(rlayout){}), sidx, didx>(src, dst); + }); + } +} + +// +// Display utilities +// +CUTE_HOST_DEVICE +void +print(ReorderKind kind) { +#define CASE(x) if (kind == ReorderKind::x) print(#x); + CASE(UU_Universal) + CASE(UU) + CASE(UV) + CASE(VU) + CASE(VV) + CASE(Generic) +#undef CASE +} + +} // end namespace cute + + +#include diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp index 3171d214ee..a20f984b8b 100644 --- a/include/cute/tensor.hpp +++ b/include/cute/tensor.hpp @@ -54,6 +54,7 @@ #include #include #include +#include #include #include From 3ef685153817b5c819467871442e15595db73033 Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Thu, 4 Sep 2025 12:03:20 -0700 Subject: [PATCH 13/16] [CuTe] Add native int4 compute --- include/cute/arch/copy_xe_2d.hpp | 12 ++++-- include/cute/arch/mma_xe.hpp | 18 +++++++++ include/cute/arch/reorder_xe.hpp | 12 ++++++ include/cute/atom/copy_traits_xe_2d.hpp | 51 +++++++++++++++++-------- include/cute/atom/mma_traits_xe.hpp | 32 +++++++++++++--- include/cute/util/sycl_vec.hpp | 10 ++++- 6 files changed, 108 insertions(+), 27 deletions(-) diff --git a/include/cute/arch/copy_xe_2d.hpp b/include/cute/arch/copy_xe_2d.hpp index 02b4b019e1..6f24ef0515 100644 --- a/include/cute/arch/copy_xe_2d.hpp +++ b/include/cute/arch/copy_xe_2d.hpp @@ -77,7 +77,8 @@ struct XE_LOAD_2D : XE_Copy_Op_2D_Base template CUTE_HOST_DEVICE static void copy(const int *payload, T *dst) { #ifdef CUTE_ARCH_COPY_XE_ENABLED - auto &dv = *reinterpret_cast*>(dst); + using namespace intel; + auto &dv = *reinterpret_cast*>(dst); asm ( "lsc_load_block2d.ugm (M1, 1) %0:d%2.%3x%4x%5nn flat[%1+(0,0)]" : "=rw"(dv) @@ -99,7 +100,8 @@ struct XE_LOAD_2D_VNNI : XE_Copy_Op_2D_Base CUTE_HOST_DEVICE static void copy(const int *payload, T *dst) { #ifdef CUTE_ARCH_COPY_XE_ENABLED - auto &dv = *reinterpret_cast*>(dst); + using namespace intel; + auto &dv = *reinterpret_cast*>(dst); asm ( "lsc_load_block2d.ugm (M1, 1) %0:d%2.%3x%4x%5nt flat[%1+(0,0)]" : "=rw"(dv) @@ -123,7 +125,8 @@ struct XE_LOAD_2D_TRANSPOSE : XE_Copy_Op_2D_Base template CUTE_HOST_DEVICE static void copy(const int *payload, T *dst) { #ifdef CUTE_ARCH_COPY_XE_ENABLED - auto &dv = *reinterpret_cast*>(dst); + using namespace intel; + auto &dv = *reinterpret_cast*>(dst); asm ( "lsc_load_block2d.ugm (M1, 1) %0:d%2.%3x%4tn flat[%1+(0,0)]" : "=rw"(dv) @@ -162,7 +165,8 @@ struct XE_STORE_2D : XE_Copy_Op_2D_Base template CUTE_HOST_DEVICE static void copy(const int *payload, const T *src) { #ifdef CUTE_ARCH_COPY_XE_ENABLED - auto &sv = *reinterpret_cast*>(src); \ + using namespace intel; + auto &sv = *reinterpret_cast*>(src); \ asm ( "lsc_store_block2d.ugm (M1, 1) flat[%1+(0,0)] %0:d%2.%3x%4nn" :: "rw"(sv), "rw.u"(payload), "P"(Bits), "P"(Width), "P"(Height) diff --git a/include/cute/arch/mma_xe.hpp b/include/cute/arch/mma_xe.hpp index 20e3ce19a7..846fbadce5 100644 --- a/include/cute/arch/mma_xe.hpp +++ b/include/cute/arch/mma_xe.hpp @@ -150,6 +150,24 @@ CUTE_DECLARE_XE_DPAS_TT(d, u8, s8, d) CUTE_DECLARE_XE_DPAS_TT(d, s8, u8, d) CUTE_DECLARE_XE_DPAS_TT(d, s8, s8, d) +CUTE_DECLARE_XE_DPAS_TT(ud, u8, u4, ud) +CUTE_DECLARE_XE_DPAS_TT(d, u8, u4, d) +CUTE_DECLARE_XE_DPAS_TT(d, u8, s4, d) +CUTE_DECLARE_XE_DPAS_TT(d, s8, u4, d) +CUTE_DECLARE_XE_DPAS_TT(d, s8, s4, d) + +CUTE_DECLARE_XE_DPAS_TT(ud, u4, u8, ud) +CUTE_DECLARE_XE_DPAS_TT(d, u4, u8, d) +CUTE_DECLARE_XE_DPAS_TT(d, u4, s8, d) +CUTE_DECLARE_XE_DPAS_TT(d, s4, 
u8, d) +CUTE_DECLARE_XE_DPAS_TT(d, s4, s8, d) + +CUTE_DECLARE_XE_DPAS_TT(ud, u4, u4, ud) +CUTE_DECLARE_XE_DPAS_TT(d, u4, u4, d) +CUTE_DECLARE_XE_DPAS_TT(d, u4, s4, d) +CUTE_DECLARE_XE_DPAS_TT(d, s4, u4, d) +CUTE_DECLARE_XE_DPAS_TT(d, s4, s4, d) + #undef CUTE_DECLARE_XE_DPAS_TT } //namespace cute diff --git a/include/cute/arch/reorder_xe.hpp b/include/cute/arch/reorder_xe.hpp index 496a0180a1..42e701a4ce 100644 --- a/include/cute/arch/reorder_xe.hpp +++ b/include/cute/arch/reorder_xe.hpp @@ -43,6 +43,18 @@ namespace cute { template struct Xe_Reorder : Universal_Reorder_UU {}; +template +struct Xe_Reorder { + using StorageT = conditional_t<(sizeof_bits_v >= 8), T, uint8_t>; + using SRegisters = StorageT[1]; + using DRegisters = StorageT[1]; + + CUTE_HOST_DEVICE static void + reorder(StorageT const& src0, StorageT& dst0) { + dst0 = src0; + } +}; + template <> struct Xe_Reorder { diff --git a/include/cute/atom/copy_traits_xe_2d.hpp b/include/cute/atom/copy_traits_xe_2d.hpp index 52aaa743d0..2df8ae0a38 100644 --- a/include/cute/atom/copy_traits_xe_2d.hpp +++ b/include/cute/atom/copy_traits_xe_2d.hpp @@ -76,7 +76,7 @@ template ; - using ThrID = Layout<_16>; + using ThrID = Layout; static constexpr int ValBits = is_void_v ? Op::CopyBits : int(sizeof_bits_v); @@ -263,7 +263,7 @@ struct XeInterleavedLayoutHelper { // Res: ((_2, _8), (_16, _2, _8)):((_16, _256), (_1, _2048, _32)) (T,V) -> (xbit,y) }; -template +template using XeInterleavedLayout = typename XeInterleavedLayoutHelper::Result; // Block 2D load traits. @@ -282,7 +282,7 @@ struct Copy_Traits, XMode, YMode sizeof_bits_v>; using RefLayout = DstLayout; - using SrcLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); + using SrcLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); }; // Block 2D VNNI load traits. @@ -303,7 +303,7 @@ struct Copy_Traits, XMode, sizeof_bits_v>; using RefLayout = DstLayout; - using SrcLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); + using SrcLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); }; // Block 2D transposed load traits. @@ -322,7 +322,7 @@ struct Copy_Traits, XMode, YMode, sizeof_bits_v>; using RefLayout = DstLayout; - using SrcLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); + using SrcLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); }; // Block 2D store traits. @@ -337,7 +337,7 @@ struct Copy_Traits, XMode, YMode, ValType, sizeof_bits_v>; using RefLayout = SrcLayout; - using DstLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); + using DstLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); using Op = XE_STORE_2D; using Super = Xe2DTraitsBase; @@ -386,7 +386,7 @@ struct Copy_Traits, XMode, Y sizeof_bits_v>; using RefLayout = DstLayout; - using SrcLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); + using SrcLayout = decltype(replace<0>(RefLayout{}, Layout, Stride<_0>>{})); using Op = XE_PREFETCH_2D; using Super = Xe2DTraitsBase; @@ -564,15 +564,34 @@ get_block_size(InLayout const&) { return get(atuple_coshape(layout)); } -// Remove VNNI modes from a layout, if present. +// Remove subbyte packing modes from a layout, if present. 
+template +CUTE_HOST_DEVICE +constexpr auto +strip_subbyte(InLayout const& layout) +{ + using namespace cute::intel; + if constexpr (Bits >= 8) + return layout; + else { + static_assert(is_static_v, "Layout must be static"); + constexpr auto values = size(InLayout{}) / sg_size; + constexpr auto per_byte = 8 / Bits; + static_assert(values % per_byte == 0, "Partially-occupied bytes in layout"); + return coalesce(composition(layout, Layout, _SGSize, C>, + Stride<_SGSize, _1, C>>{})); + } +} + +// Remove VNNI and subbyte packing modes from a layout, if present. // Returns a std::pair = (layout_out, has_vnni) template CUTE_HOST_DEVICE constexpr auto -strip_vnni(InLayout const&) +strip_vnni_subbyte(InLayout const&) { - constexpr InLayout layout{}; - constexpr int R = rank(InLayout{}); + constexpr auto layout = strip_subbyte(InLayout{}); + constexpr int R = rank(layout); constexpr bool vnni = (R >= 2) && (Bits < 32) && is_constant_v<32 / Bits, decltype(size<0>(layout))>; @@ -606,7 +625,7 @@ block_2d_transform_selector(DesiredCoordLayout const& layout, return Block2DTransform::N; // Check if copy's consumer wants VNNI layout. - constexpr auto result = strip_vnni(DesiredCoordLayout{}); + constexpr auto result = strip_vnni_subbyte(DesiredCoordLayout{}); constexpr auto slayout = get<0>(result); constexpr bool vnni = get<1>(result); constexpr bool transpose = !is_constant_v<1, decltype(basis_get(stride<0>(slayout), gstride))>; @@ -648,12 +667,12 @@ block_2d_selector(CoordLayout const&, GlobalStride const&) constexpr auto kind = block_2d_transform_selector(layout, gstride); // Strip off VNNI mode if present. - constexpr auto slayout = get<0>(strip_vnni(layout)); + constexpr auto slayout = get<0>(strip_vnni_subbyte(layout)); constexpr auto x_mode = find_x_mode(gstride); constexpr auto y_mode = find_y_mode(gstride); - constexpr int grf_elems = 512 / RegBits; + constexpr int min_large_block = cute::min(256 / RegBits, 16); constexpr bool resize = (MemBits != RegBits); auto shape = atuple_coshape(layout); @@ -666,7 +685,7 @@ block_2d_selector(CoordLayout const&, GlobalStride const&) // Block width = highest power of 2 divisor (up to 64b) // Width = highest power of 2 divisor of full tile's width, up to 64b and 4x block width constexpr int max_w = 64 * 8 / MemBits; - constexpr int x_stride = get_block_size(slayout); + constexpr int x_stride = get_block_size(slayout); constexpr int block_width = cute::gcd(max_w, x_stride); constexpr int load_width = cute::gcd(cute::min(max_w, 4 * block_width), get(shape)); @@ -692,7 +711,7 @@ block_2d_selector(CoordLayout const&, GlobalStride const&) // Similar process for transposing copies, but with width/height reversed. 
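+  // Transposing block 2D loads support only 32- and 64-bit channels, so clamp the copy data size to that range.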
constexpr int CopyBits = cute::max(32, cute::min(64, MemBits)); - constexpr int y_stride = get_block_size(slayout); + constexpr int y_stride = get_block_size(slayout); constexpr int height = cute::gcd(32, y_stride); constexpr int max_w = 32 * 8 / MemBits; diff --git a/include/cute/atom/mma_traits_xe.hpp b/include/cute/atom/mma_traits_xe.hpp index 06f812bacf..5c1fc2b2de 100644 --- a/include/cute/atom/mma_traits_xe.hpp +++ b/include/cute/atom/mma_traits_xe.hpp @@ -39,6 +39,29 @@ namespace cute { +namespace detail +{ + +template +CUTE_HOST_DEVICE +constexpr auto +wi_interleave(LayoutIn const&) +{ + using namespace intel; + constexpr LayoutIn layout{}; + constexpr int per_byte = ceil_div(8, sizeof_bits_v); + constexpr int vals = ceil_div(size(layout), sg_size); + auto tv_interleaved = Layout, C>>, + Stride, Stride<_1, C>>>{}; + return coalesce(composition(layout, tv_interleaved), Step<_1,_1>{}); +} + +template +using wi_interleave_t = remove_cvref_t(LayoutIn{}))>; + +} // end namespace detail + + template struct MMA_Traits> { @@ -55,17 +78,16 @@ struct MMA_Traits> using _K = Int; using Shape_MNK = Shape<_M, _16, _K>; - using ThrID = Layout<_16>; + using ThrID = Layout; // A layout: (T,V) -> (M,K) // M x K row major, work-items interleaved. - using ALayout = decltype(composition(make_layout(make_shape(_K{}, _M{}), LayoutRight{}), - make_layout(make_shape(_16{}, Int{})))); + using ALayout = detail::wi_interleave_t, Stride<_M, _1>>>; // B layout: (T,V) -> (N,K) // K x 16 VNNI-transformed row major, work-items interleaved. - using BLayout = Layout, Int<16/BV>>, Shape, Int<16/BV>>>, - Stride, Stride, Int<16*BV>>>>; + using BLayout = detail::wi_interleave_t, _16, Int>, + Stride<_16, _1, Int<16*BV>>>>; // C layout: (T,V) -> (M,N) // M x 16 row major, work-items interleaved. 
diff --git a/include/cute/util/sycl_vec.hpp b/include/cute/util/sycl_vec.hpp index bc0937a0ec..eeecb98a39 100644 --- a/include/cute/util/sycl_vec.hpp +++ b/include/cute/util/sycl_vec.hpp @@ -42,14 +42,20 @@ constexpr int sg_size = 16; using _SGSize = Int; #ifdef __SYCL_DEVICE_ONLY__ -template struct vector_element_helper { using type = T; }; +template struct vector_element_helper { + using type = conditional_t<(sizeof_bits_v < 8), uint8_t, T>; +}; template <> struct vector_element_helper { using type = uint32_t; }; template <> struct vector_element_helper { using type = uint16_t; }; template <> struct vector_element_helper { using type = uint16_t; }; template <> struct vector_element_helper { using type = uint8_t; }; template <> struct vector_element_helper { using type = uint8_t; }; -template using vector_t = typename vector_element_helper::type __attribute__((ext_vector_type(N))); +template struct vector_helper { + using U = typename vector_element_helper::type; + using type = U __attribute__((ext_vector_type(ceil_div(N * sizeof_bits_v, sizeof_bits_v)))); +}; +template using vector_t = typename vector_helper::type; #else template using vector_t = sycl::marray; #endif From cf1cbdd20698c7cfb4c13a46875ec01da2583002 Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Mon, 18 Aug 2025 22:59:02 -0700 Subject: [PATCH 14/16] [Documentation] Document Xe architectural changes --- media/docs/cpp/xe_rearchitecture.md | 506 ++++++++++++++++++++++++++++ 1 file changed, 506 insertions(+) create mode 100644 media/docs/cpp/xe_rearchitecture.md diff --git a/media/docs/cpp/xe_rearchitecture.md b/media/docs/cpp/xe_rearchitecture.md new file mode 100644 index 0000000000..1b8d159b3f --- /dev/null +++ b/media/docs/cpp/xe_rearchitecture.md @@ -0,0 +1,506 @@ +# Xe CuTe Architecture Redesign + +## Limitations of Current Intel CuTe Architecture + +* VNNI layout used by DPAS and block 2D VNNI loads is hidden from CuTe/CUTLASS. + - Compiler inserts extra interleave/deinterleave operations if there is any computation between VNNI load and DPAS. + - Additionally, any such computation using the native B data type (instead of int) can lead to private memory traffic. +* MMA and copy fragments must be carefully set up to match layouts + - Fragile code, easy to break functionality + - Sometimes it's not possible or desirable to match layouts +* Block 2D atoms are somewhat disconnected from the actual operations in HW + - Thread-value layouts (of both atoms and TiledCopys) don't reflect the actual ownership of data by threads. + - Necessary to introduce "fake" copy atoms (like the U4 atoms) that do not reflect actual HW operations. + + +## Goals + +The large-scale goal of this re-architecture is to improve CUTLASS performance and flexibility on pre-Xe4 GPUs. 
In more detail, we want to: + +* Rewrite MMA/copy atoms to accurately expose HW capabilities + - Expose VNNI layouts for MMA B matrix and VNNI block 2D copies + - Expose the actual block 2D operations provided by HW + +* Improve performance + - Introduce high-performance subgroup-level reorder operations for copying data that can change data types and layouts at the same time + - Remove hidden overheads from MMA/copy atoms + +* Improve ease-of-use + - Introduce helper functions to construct accurate block 2D TiledCopy objects + - Allow user to partition tensors to copy and MMA fragments following the usual CuTe paradigm, without worrying about whether the fragments' layouts match perfectly + +* Improve maintainability + - Reduce boilerplate/duplication in atom definitions + - Allow more flexible sizes in block 2D atoms + - Make it easy to add new operations for CRI + +In the following sections, we will walk through the proposed changes. + + +## DPAS Atoms + +All DPAS MMA atoms are now instantiations of a single `XE_DPAS_TT` template. + +Following CuTe conventions, the `TT` part of the name reflects the interpretation of A/B as row-major matrices, although more accurately B is VNNI-transformed. + +```c++ +template +struct XE_DPAS_TT; +``` + +Hardware allows any `M` with 1 <= M <= 8, while `N` is fixed at 16 for all CUTLASS-supported GPUs (PVC and later). The `K` dimension depends on the data types, and can be queried from the atom. + +Default template parameters are available for B/C since they often match A/D respectively, but HW can support different combinations (e.g. f16/f16/f16/f32 or s32/u8/s4/s32). + +Initially, `XE_DPAS_TT` has been implemented using inline vISA; this is required in order to correctly expose matrix B's VNNI layout. + + +## Block 2D Copy Atoms + +### Background and Atom Definition + +PVC introduces 2D block messages that move fixed-size 2D tiles between global memory and registers. Their main features: + +* Automatic bounds-checking -- removing out-of-bounds global memory accesses and filling out-of-bounds loaded elements with zeros +* Layout changes that in many cases allow DPAS to work directly on loaded data, without additional data reordering in registers/SLM. + +These messages also come with important restrictions to know about: + +* Alignment restrictions: + - Buffer base address must be 64-byte-aligned + - Stride must be 16-byte-aligned (in some cases on PVC, 4 or 8-byte alignment is OK) + - Width (defined later) must be 4-byte-aligned +* Size restrictions: + - Width/pitch are restricted to 2^24 bytes + - Height is restricted to 2^24 elements + +For debugging, the user can compile with `-DCUTE_ENABLE_XE_BLOCK_2D_ASSERT=1` to enable (expensive) runtime checks on these restrictions. + +In hardware, block 2D copies come in 4 types: + +* 2D load (8/16/32/64 bits per element) +* 2D load with VNNI transform (8/16 bits per element) +* 2D load with transpose (32/64 bits per element) +* 2D store (8/16/32/64 bits per element) + +In addition, any load can be turned into a prefetch. However, the VNNI transform/transpose options are not relevant for prefetching, so can be safely disregarded. + +All said, there are 5 operations to consider, and each of these has a corresponding templated atom class: + +* Loads: `XE_LOAD_2D`, `XE_LOAD_2D_VNNI`, `XE_LOAD_2D_TRANSPOSE` +* Store: `XE_STORE_2D` +* Prefetch: `XE_PREFETCH_2D` + +Block 2D messages use image terminology: a 2D buffer has a _width_ (x dimension), _height_ (y dimension), and _pitch_ (y dimension stride). 
Width/x always refers to the dimension that is contiguous (stride-1) **in memory**; height/y is the strided dimension. A transposing load will store the y dimension contiguously in registers, while regular load/store messages have the x dimension contiguous in registers. + +In hardware, block 2D messages have 4 main configuration points: + +* Data size (8/16/32/64 bit) +* Block width +* Block height +* Block count + +_Block width/height_ define the size of each block. Again, this is the size as seen in global memory. + +The _block count_ parameter (available for regular and VNNI loads only) can load multiple blocks at once. In memory, blocks are adjacent in the width dimension, so the total tile loaded from memory has size `(width*count) x height`. The block count has the effect of rearranging loaded data in registers, and can be used to ensure that the loaded data can be fed directly to DPAS, without needing further rearrangement in registers. + +Finally, _data size_ controls how HW interprets other parameters (e.g. block width, transposition) but does not need to match the actual data type stored in memory. For instance, we can use a 32-bit transpose load with 8-bit or 16-bit data. This results in VNNI-format data in registers, perfect for the DPAS B matrix. + +These four parameters translate to template parameters for the CuTe copy atoms: + +```c++ +template +struct XE_LOAD_2D; + +template +struct XE_LOAD_2D_VNNI; + +template +struct XE_LOAD_2D_TRANSPOSE; + +template +struct XE_STORE_2D; + +template +struct XE_PREFETCH_2D; +``` + +Above, `Width` is the total width loaded (block width * block count). The `Width` and `Height` parameters are ordered to reflect a row-major view of 2D data in memory: height = #rows, width = #columns. + +### Block 2D Traits + +The traits classes come with further configuration parameters so we can apply these abstract copy operations onto specific global memory tensors. Most users don't need to worry about setting up these parameters, since they are automatically chosen by the `make_block_2d_copy` helper APIs (described in the following section). + + +```c++ +template +struct Copy_Traits; +``` + +* `XMode`, `YMode`: indices of the modes (dimensions) corresponding to x and y. + - For instance, a 3D k-major tensor with indices ordered (M,K,L) would have `XMode = 1`, `YMode = 0`. +* `ValType`: actual data type (possibly having a different size than the underlying block 2D operation). + - In CuTe, this template parameter is specified at the `Copy_Atom` level, but for Xe it's also required in the traits due to the compiler's SIMD-like register data mapping. +* `TiledStrides`: strides for modes other than x/y, for > 2D tensors. + +### Creating Block 2D Atoms + +Since it can be tricky to correctly choose block 2D parameters and set up an appropriate tiling, we introduce several helpers for creating TiledCopy objects. + +The high-level APIs `make_block_2d_copy_{A,B,C}` automatically create TiledCopy objects for use with an existing `TiledMMA`. They choose the copy operation and trait template parameters heuristically. + +```c++ +template +CUTE_DEVICE +TiledCopy<...> +make_block_2d_copy_A(const TiledMMA<...>&, + const Tensor& gmem); // (M,K,...) + +template +CUTE_DEVICE +TiledCopy<...> +make_block_2d_copy_B(const TiledMMA<...>&, + const Tensor& gmem); // (N,K,...) + +template +CUTE_DEVICE +TiledCopy<...> +make_block_2d_copy_C(const TiledMMA<...>&, + const Tensor& gmem); // (M,N,...)
+``` + +The user may also override the choice of copy operation: + +```c++ +template +CUTE_HOST_DEVICE +auto +make_block_2d_copy_A(CopyOp const& op, // Copy operation + TiledMMA const& mma, // TiledMMA instance + Tensor const& gmem); // Global tensor + +/* Similarly for B/C */ +``` + +The `make_block_2d_copy_*` family of functions create TiledCopy objects that match the scope of the TiledMMA. That is, the set of threads participating in the TiledMMA will also participate in the TiledCopy. + +By contrast, the `make_block_2d_copy` API creates a TiledCopy object in which a single subgroup participates: + +```c++ +// Create tiled block 2D copy (scope = single subgroup) for a global memory tensor. +template +CUTE_DEVICE +TiledCopy +make_block_2d_copy(const CopyOp& op, const Tensor& gmem); +``` + +For advanced usage, there are additional overloads of `make_block_2d_copy` that allow more general work distributions for copies (see `include/cute/atom/copy_traits_xe_2d.hpp`). + +As the `CUTE_DEVICE` decorators imply, all the APIs above should be called from device code only, as they set up internal state that cannot be transferred from host to device. + +Alternately, these APIs have variants that take strides and an optional data type override. The resulting TiledCopy objects are uninitialized. + +```c++ +template +TiledCopy +make_block_2d_copy_A(TiledMMA<...> const& mma, Stride const& strides); +``` + +After this _placeholder_ TiledCopy is created, it can be initialized on the device in the standard CuTe fashion, using `with`: + +```c++ +/* host code */ +TiledCopy copy_a_placeholder = make_block_2d_copy_A(...); +... + +/* device code */ +Tensor gA = make_tensor(make_gmem_ptr(...), ...); +TiledCopy copy_a = copy_a_placeholder.with(gA); /* copy_a is now ready to use */ +``` + +### Using Block 2D Atoms + +Block 2D copy atoms follow a "proxy copy" pattern, somewhat akin to the NVIDIA TMA copy atoms. + +At creation time, a block 2D copy atom is constructed with a global tensor, which becomes part of the atom's state. When copying, instead of partitioning the global tensor, the user partitions a corresponding coordinate tensor instead, providing the coordinates (within that global tensor) that should be copied. + +```c++ +Tensor gA = make_tensor(make_gmem_ptr(...), ...); + +/* Construct TiledCopy with global tensor */ +TiledCopy copy_a = make_block_2d_copy_A(mma, gA); +ThrCopy thr_copy_a = copy_a.get_slice(...); + +/* Create a proxy coordinate tensor and use it for actual copy operations */ +Tensor cA = make_identity_tensor(gA.shape()); +Tensor tAcA = thr_copy_a.partition_S(cA); + +Tensor tArA = thr_copy_a.partition_fragment_D(cA); + +/* Copy from global (via coordinate tensor) to registers */ +copy(copy_a, tAcA, tArA); +``` + + +## Subgroup Scope and Thread-Local Data + +DPAS and block 2D copy atoms are _subgroup_ operations, meaning that all 16 threads of the subgroup collectively execute these operations, and collectively own all input/output data. This "subgroupness" reflects the underlying SIMD nature of Intel GPUs: the 16 "threads" are in reality a single thread of execution in HW. + +In order to perform thread-level operations on subgroup-shared data, it's important to understand how the compiler splits ownership of subgroup-scoped private arrays (including register-resident data) among threads. 
Fortunately, there is a very simple rule for this: +```math + \text{thread\ } i\ \ \ \text{owns elements} \ \ \ i, i+16, i+32, \ldots +``` +That is to say, elements are assigned to threads in a round-robin fashion. Conversely, if we declare a vector variable (say `cute::intel::float8`) in SYCL, the compiler will interleave the vectors from each thread in the subgroup to form a length-128 (8 * 16) float array in registers. + +> [!IMPORTANT] +> Note that the thread mapping _depends on the element data size._ If an array of 32-bit data, say, is reinterpreted _on a register level_ as an array of 16-bit data, data ownership will change -- i.e., in SIMT terms, it is a shuffle operation. Contrast this operation with a SIMT bitcast/reinterpret_cast, which does not change data ownership, but _does_ shuffle data in registers. + +Now that we have the basic thread mapping rule, let's apply it to a simple block 2D load, with height = 8 rows and width = 4 columns. Recalling that the width dimension is contiguous in both memory and registers, we deduce the following mapping: +```math + \begin{array}{c} + \text{Subgroup view}\\ + \begin{array}{cccc} + 0 & 1 & 2 & 3\\ + 4 & 5 & 6 & 7\\ + 8 & 9 & 10 & 11\\ + 12 & 13 & 14 & 15\\ + 16 & 17 & 18 & 19\\ + 20 & 21 & 22 & 23\\ + 24 & 25 & 26 & 27\\ + 28 & 29 & 30 & 31 + \end{array} + \end{array} + \rightarrow + \begin{array}{c} + \text{Thread view}\\ + \begin{array}{cccc} + \text{T0V0} & \text{T1V0} & \text{T2V0} & \text{T3V0}\\ + \text{T4V0} & \text{T5V0} & \text{T6V0} & \text{T7V0}\\ + \text{T8V0} & \text{T9V0} & \text{T10V0} & \text{T11V0}\\ + \text{T12V0} & \text{T13V0} & \text{T14V0} & \text{T15V0}\\ + \text{T0V1} & \text{T1V1} & \text{T2V1} & \text{T3V1}\\ + \text{T4V1} & \text{T5V1} & \text{T6V1} & \text{T7V1}\\ + \text{T8V1} & \text{T9V1} & \text{T10V1} & \text{T11V1}\\ + \text{T12V1} & \text{T13V1} & \text{T14V1} & \text{T15V1} + \end{array} + \end{array} +``` +(Following CuTe convention, `TxVy` means thread `x`, value `y`.) + +An individual DPAS atom's A matrix follows the same pattern, with height ranging from 1 to 8, and width equal to 8 (tf32), 16 (f16/bf16), or 32 (s8/u8). The DPAS C matrix is also organized this way, except that its width is always 16. + +As a more complicated example, let's consider a 16-bit VNNI load, with height = 4, width = 16: +```math + \begin{array}{c} + \text{Subgroup view}\\ + \begin{array}{cccccc} + 0 & 2 & 4 & 6 & \cdots & 30\\ + 1 & 3 & 5 & 7 & \cdots & 31\\ + 32 & 34 & 36 & 38 & \cdots & 62\\ + 33 & 35 & 37 & 39 & \cdots & 63 + \end{array} + \end{array} + \rightarrow + \begin{array}{c} + \text{Thread view}\\ + \begin{array}{cccc} + \text{T0V0} & \text{T2V0} & \text{T4V0} & \cdots & \text{T14V0} & \text{T0V1} & \cdots & \text{T14V1}\\ + \text{T1V0} & \text{T3V0} & \text{T5V0} & \cdots & \text{T15V0} & \text{T1V1} & \cdots & \text{T15V1}\\ + \text{T0V2} & \text{T2V2} & \text{T4V2} & \cdots & \text{T14V2} & \text{T0V3} & \cdots & \text{T14V3}\\ + \text{T1V2} & \text{T3V2} & \text{T5V2} & \cdots & \text{T15V2} & \text{T1V3} & \cdots & \text{T15V3} + \end{array} + \end{array} +``` + +The DPAS B matrix follows the same pattern. + + +### The SubgroupTensor Class + +A new `SubgroupTensor` class represents a subgroup-scope tensor (fragment). SubgroupTensor wraps a standard CuTe rmem Tensor holding the current work-item's slice of the tensor. It implicitly decays to this fragment, so it can be used as a regular rmem Tensor. 
+ +In addition to tensor data, a SubgroupTensor holds a thread-value layout identifying logical coordinates for each element of the tensor. The interpretation of the logical coordinates is user-defined. + +```c++ +template V + class SubgroupTVLayout> // (T,V) -> coord in subgroup +struct SubgroupTensor; + +// Create a SubgroupTensor from an existing rmem Tensor +template <...> +auto make_subgroup_tensor(Tensor const&, SubgroupTVLayout const&); +``` + +To create a `SubgroupTensor`, use the new `partition_sg_fragment_*` methods of the `ThrCopy` and `ThrMMA` classes: +```c++ + template + CUTE_HOST_DEVICE constexpr auto + ThrMMA::partition_sg_fragment_A(ATensor&& atensor) const; // Similarly for B/C + + template + CUTE_HOST_DEVICE auto + ThrCopy::partition_sg_fragment_S(STensor&& stensor) const; // Similarly for D +``` + + +## Subgroup Reorders + +Each subgroup-scope operation (MMA atom, copy atom) requires specific logical layouts. For instance, the DPAS B matrix needs to be in VNNI format, in blocks of Kx16 (for some specific value of K). In many cases, we can use block 2D copy atoms to load data in exactly the right layout. In other cases, we can't, because: + +* The copy/MMA atoms are tiled and these tilings don't match. +* No block 2D copy produces data in the right layout, e.g.: + - Transposed A, if the data type is less than 32 bits wide. + - Dequantizing B, if the storage data type is smaller than the compute data type. +* The fastest load operation is not block 2D (e.g. some memory-bound kernels). +* Block 2D operations cannot be used (insufficient memory alignment, matrix size too large, etc.) +* Data needs to be broadcast/duplicated in special patterns (e.g. scales/zero points). + + +A flexible and powerful way to handle all these situations is via a subgroup-scope _reorder_ operation that allows arbitrary layout changes (shuffles) in registers. In addition to layout changes, reorders may include data type conversions, because these conversions can often be fused into the layout conversion for more efficient code. + +The general API for a reorder looks like: + +```c++ +reorder(SubgroupTensor<...> const& src_fragment, + SubgroupTensor<...> & dst_fragment); +``` + +`reorder` copies subgroup-scoped data between subgroup-owned fragments `src_fragment` and `dst_fragment`. `reorder` uses the subgroup TV layout (part of the SubgroupTensor) to determine which source values map to which destination values. These are computed automatically by the `partition_sg_fragment_*` family of methods. + +`reorder` acts as a "pipe" connecting copy and MMA operations (or any other subgroup-scope operations). With reorders, the kernel writer does not need to worry about perfectly matching layouts between copy and MMA atoms. In case the layouts do match perfectly (as `make_block_2d_copy_{A,B,C}` try to do), the compiler is able to remove the reorder entirely, making it a no-op. + +A longer version of the reorder API takes the source and destination fragments as ordinary `Tensor` objects, in which case the subgroup TV-layouts must be passed in as auxiliary parameters. + +```c++ +reorder(Tensor<...> const& src_fragment, + Tensor<...> & dst_fragment, + Layout<...> & src_sg_layout, /* (T,V) -> coord */ + Layout<...> & dst_sg_layout); /* (T,V) -> coord */ +``` + + +## Example CuTe GEMM + +Let's combine the pieces here to make a complete CuTe-level Xe GEMM kernel.
+ +```c++ +template +void +gemm_device(ATensor const& A, // (M,K) + BTensor const& B, // (N,K) + CTensor & C, // (M,N) + TiledMMA const & mma) +{ + using namespace cute; + + // ----- + // Setup + // ----- + + /* Create proxy coordinate tensors for each global tensor */ + Tensor cA = make_identity_tensor(A.shape()); // (M,K) + Tensor cB = make_identity_tensor(B.shape()); // (N,K) + Tensor cC = make_identity_tensor(C.shape()); // (M,N) + + /* Split GEMM into workgroup tiles, and identify our workgroup's tile (wg_coord) */ + auto wg_tile = mma.tile_mnk(); + auto wg_coord = make_coord(BlockIdxX(), BlockIdxY(), 0); + + Tensor gA = local_tile(cA, select<0,2>(wg_tile), make_coord(BlockIdxX(),_)); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(cB, select<1,2>(wg_tile), make_coord(BlockIdxY(),_)); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(cC, wg_tile, wg_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + /* Create block 2D TiledCopies */ + auto copy_a = make_block_2d_copy_A(mma, A); + auto copy_b = make_block_2d_copy_B(mma, B); + auto copy_c = make_block_2d_copy_C(mma, C); + + /* Slice TiledCopy/TiledMMA operations to thread (work-item) level */ + int thread_idx = int(ThreadIdxX()); + + auto thr_mma = mma.get_slice(thread_idx); + auto thr_copy_a = copy_a.get_slice(thread_idx); + auto thr_copy_b = copy_b.get_slice(thread_idx); + + /* Register fragments for MMA */ + auto tCrA = thr_mma.partition_sg_fragment_A(gA(_,_,0)); + auto tCrB = thr_mma.partition_sg_fragment_B(gB(_,_,0)); + + /* Register fragments for copies */ + auto tArA = thr_copy_a.partition_sg_fragment_D(gA(_,_,0)); + auto tBrB = thr_copy_b.partition_sg_fragment_D(gB(_,_,0)); + + /* Partition global tensor (proxies) for copies */ + Tensor tAgA = thr_copy_a.partition_S(gA); + Tensor tBgB = thr_copy_b.partition_S(gB); + + /* Partition C */ + Tensor tCrC = partition_fragment_C(mma, select<0,1>(wg_tile)); + Tensor tCgC = thr_mma.partition_C(gC); /* also matches copy_c's source layout */ + + /* Create prefetch TiledCopy instances */ + auto prefetch_a = make_block_2d_prefetch(copy_a); + auto prefetch_b = make_block_2d_prefetch(copy_b); + + auto thr_prefetch_A = prefetch_a.get_slice(thread_idx); + auto thr_prefetch_B = prefetch_b.get_slice(thread_idx); + + /* Partition global tensor (proxies) for prefetch */ + auto pAgA = thr_prefetch_A.partition_S(gA); + auto pBgB = thr_prefetch_B.partition_S(gB); + + // ------ + // Kernel + // ------ + + constexpr int barrier_scope = 2; + + int k_tile_count = ceil_div(get<2>(shape_MNK), get<2>(cta_tiler)); + int k_tile_prefetch = 0; + + /* Clear the accumulators */ + clear(tCrC); + + /* Warm up loops with prefetch to L1 */ + CUTLASS_PRAGMA_UNROLL + for (; k_tile_prefetch < stages; k_tile_prefetch++) { + prefetch(prefetch_a, pAgA(_,_,k_tile_prefetch)); + prefetch(prefetch_b, pBgB(_,_,k_tile_prefetch)); + } + + /* Main loop */ + CUTLASS_PRAGMA_UNROLL + for (int k_tile = 0; k_tile < k_tile_count; k_tile++, k_tile_prefetch++) { + /* Split barrier keeping threads loosely together */ + barrier_arrive(barrier_scope); + + /* Copy A/B from global memory (ideally L1 cache) to registers */ + copy(copy_a, tAgA(_,_,k_tile), tArA); + copy(copy_b, tBgB(_,_,k_tile), tBrB); + + /* Prefetch A/B tiles to L1 */ + prefetch(prefetch_a, pAgA(_,_,k_tile_prefetch)); + prefetch(prefetch_b, pBgB(_,_,k_tile_prefetch)); + + /* Shuffle data from copy fragments to MMA fragments */ + reorder(tArA, tCrA); + reorder(tBrB, tCrB); + + /* Accumulate C += A * B */ + gemm(mma, tCrA, tCrB, tCrC); + + /* Other half of split barrier */ + 
barrier_wait(barrier_scope); + } + + /* Write C to global memory */ + copy(copy_c, tCrC, tCgC); +} +``` + + +## New Collective MMAs + +... coming later! \ No newline at end of file From cc5d3bba2ed5b41d931fe57939154cee143a7ade Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Wed, 3 Sep 2025 10:05:39 -0700 Subject: [PATCH 15/16] [CUTLASS] Add missing platform::numeric_limits for tf32 --- include/cutlass/tfloat32.h | 60 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/include/cutlass/tfloat32.h b/include/cutlass/tfloat32.h index 17380e9041..aa73e0dcf0 100644 --- a/include/cutlass/tfloat32.h +++ b/include/cutlass/tfloat32.h @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -332,6 +333,65 @@ struct numeric_limits { } // namespace std +namespace cutlass { +namespace platform { + +/// Forward Declaration +template +struct numeric_limits; + +/// Numeric limits +template <> +struct numeric_limits { + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_infinity = true; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; +#if !defined(__CUDACC_RTC__) + static std::float_denorm_style const has_denorm = std::denorm_present; +#endif + static bool const has_denorm_loss = true; +#if !defined(__CUDACC_RTC__) + static std::float_round_style const round_style = std::round_to_nearest; +#endif + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = 19; + + /// Least positive value + static cutlass::tfloat32_t min() { return cutlass::tfloat32_t::bitcast(0x01); } + + /// Minimum finite value + static cutlass::tfloat32_t lowest() { return cutlass::tfloat32_t::bitcast(0xff7fffff); } + + /// Maximum finite value + static cutlass::tfloat32_t max() { return cutlass::tfloat32_t::bitcast(0x7f7fffff); } + + /// Returns smallest finite value + static cutlass::tfloat32_t epsilon() { return cutlass::tfloat32_t::bitcast(0x1000); } + + /// Returns smallest finite value + static cutlass::tfloat32_t round_error() { return cutlass::tfloat32_t(0.5f); } + + /// Returns smallest finite value + static cutlass::tfloat32_t infinity() { return cutlass::tfloat32_t::bitcast(0x7f800000); } + + /// Returns smallest finite value + static cutlass::tfloat32_t quiet_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); } + + /// Returns smallest finite value + static cutlass::tfloat32_t signaling_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); } + + /// Returns smallest finite value + static cutlass::tfloat32_t denorm_min() { return cutlass::tfloat32_t::bitcast(0x1); } +}; +} // namespace platform +} // namespace cutlass + /////////////////////////////////////////////////////////////////////////////////////////////////// // // Arithmetic operators From 3c0bc236e5298083bd08c6e1ddba9a7eaec8e6c8 Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Fri, 29 Aug 2025 15:37:38 -0700 Subject: [PATCH 16/16] [Examples] Add Xe GEMM CuTe example --- examples/common/sycl_cute_common.hpp | 166 +++++++ examples/cute/tutorial/CMakeLists.txt | 
15 +- .../{sgemm_bmg.cpp => bgemm_bmg_legacy.cpp} | 1 + examples/cute/tutorial/xe_gemm.cpp | 464 ++++++++++++++++++ media/docs/cpp/xe_rearchitecture.md | 52 +- 5 files changed, 669 insertions(+), 29 deletions(-) create mode 100644 examples/common/sycl_cute_common.hpp rename examples/cute/tutorial/{sgemm_bmg.cpp => bgemm_bmg_legacy.cpp} (99%) create mode 100644 examples/cute/tutorial/xe_gemm.cpp diff --git a/examples/common/sycl_cute_common.hpp b/examples/common/sycl_cute_common.hpp new file mode 100644 index 0000000000..4875a1a57b --- /dev/null +++ b/examples/common/sycl_cute_common.hpp @@ -0,0 +1,166 @@ +/*************************************************************************************************** +* Copyright (c) 2025 ---- + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +// +// Common routines for SYCL CuTe examples. +// + +// Variant of make_signed_t that works for both integer and floating point types. +template +auto ensure_signed_helper_t() { + if constexpr (cute::is_unsigned_v) + return cute::make_signed_t{}; + else + return T{}; +} + +template +using ensure_signed_t = decltype(ensure_signed_helper_t()); + +template +T random_value() +{ + using Limits = cutlass::platform::numeric_limits; + + static std::vector saved; + static constexpr size_t nsave = 65537; + static size_t idx = 0; + + if (saved.empty()) { + float range = Limits::is_integer ? 
10.f : 1.f; + float v_min = cute::max(-range, float(Limits::lowest())); + float v_max = cute::min(+range, float(Limits::max())); + + saved.resize(nsave); + for (auto &x: saved) + x = T(v_min + (v_max - v_min) * (float(rand()) / float(RAND_MAX))); + } + + auto v = saved[idx++]; + if (idx >= nsave) idx -= nsave; + + return v; +} + +template +void +random_fill(InTensor &X) +{ + using T = typename InTensor::element_type; + + for (int i = 0; i < size(X); i++) + X(i) = random_value(); +} + +template +void +zero_fill(InTensor &X) +{ + using T = typename InTensor::element_type; + + for (int i = 0; i < size(X); i++) + X(i) = T(0); +} + +// Pack sub-byte types in a gmem tensor. +// On input, the backing array holds one sub-byte value per byte. +// On exit, the backing array contains packed values. +template +void +subbyte_pack(InTensor &X) +{ + using namespace cute; + using T = typename InTensor::element_type; + + if constexpr (sizeof_bits_v % 8 != 0) { + static_assert(sizeof_bits_v == 4, "Unsupported sub-byte data size"); + + auto ptr = recast_ptr(&*X.data()); + auto bytes = X.size(); + + for (size_t i = 0; i < bytes/2; i++) + ptr[i] = ptr[2*i] | (ptr[2*i + 1] << 4); + if (bytes & 1) + ptr[bytes >> 1] = ptr[bytes - 1]; + } +} + +// Retrieve a user-friendly string representation of an element type. +template +const char *type_str() +{ + using namespace cute; + using T_ = remove_cvref_t; +#define CASE(x, y) if (is_same_v) return #y; +#define ICASE(x) CASE(x, x) + ICASE(double) + ICASE(float) + CASE(tfloat32_t, tf32) + CASE(half_t, half) + CASE(bfloat16_t, bf16) + CASE(float_e5m2_t, e5m2) + CASE(float_e4m3_t, e4m3) + CASE(float_e2m1_t, e2m1) + CASE(int32_t, int32) + CASE(uint32_t, uint32) + CASE(int8_t, int8) + CASE(uint8_t, uint8) + CASE(int4_t, int4) + CASE(uint4_t, uint4) +#undef CASE + return ""; +} + + +template +auto +make_shared_usm_tensor(sycl::queue &Q, int r, int c) +{ + using namespace cute; + auto ptr = make_gmem_ptr(sycl::malloc_shared(r*c, Q)); + auto shape = make_shape(r, c); + if constexpr (LayoutKind == 'C') + return make_tensor(ptr, make_layout(shape, make_stride(_1{}, r))); + else + return make_tensor(ptr, make_layout(shape, make_stride(c, _1{}))); +} + +template +void +free_usm_tensor(InTensor &X, sycl::queue &Q) +{ + // RAII? What's that? 
+ sycl::free(&*X.data(), Q); +} diff --git a/examples/cute/tutorial/CMakeLists.txt b/examples/cute/tutorial/CMakeLists.txt index a38af5deaa..06d1dade29 100644 --- a/examples/cute/tutorial/CMakeLists.txt +++ b/examples/cute/tutorial/CMakeLists.txt @@ -47,16 +47,21 @@ if (CUTLASS_ENABLE_SYCL) ) if (SYCL_INTEL_TARGET) - message(STATUS "Building CuTE BMG examples for Intel GPU targets") + message(STATUS "Building CuTE examples for Intel GPU targets") cutlass_example_add_executable( cute_tutorial_bmg - sgemm_bmg.cpp + bgemm_bmg_legacy.cpp + ) + + cutlass_example_add_executable( + cute_tutorial_xe_gemm + xe_gemm.cpp ) endif() - + if (SYCL_NVIDIA_TARGET) message(STATUS "Building CuTE examples for NVIDIA GPU targets with SYCL") - + cutlass_example_add_executable( cute_tutorial_sgemm_sm70 sgemm_sm70_sycl.cpp @@ -71,4 +76,4 @@ if (CUTLASS_ENABLE_SYCL) else() message(STATUS "CUTLASS_ENABLE_SYCL is OFF - CuTE tutorial examples will not be built") message(STATUS "Enable SYCL support to build CuTE tutorial examples") -endif() \ No newline at end of file +endif() diff --git a/examples/cute/tutorial/sgemm_bmg.cpp b/examples/cute/tutorial/bgemm_bmg_legacy.cpp similarity index 99% rename from examples/cute/tutorial/sgemm_bmg.cpp rename to examples/cute/tutorial/bgemm_bmg_legacy.cpp index 209510b8d0..9303360148 100644 --- a/examples/cute/tutorial/sgemm_bmg.cpp +++ b/examples/cute/tutorial/bgemm_bmg_legacy.cpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/cute/tutorial/xe_gemm.cpp b/examples/cute/tutorial/xe_gemm.cpp new file mode 100644 index 0000000000..f23ff199e7 --- /dev/null +++ b/examples/cute/tutorial/xe_gemm.cpp @@ -0,0 +1,464 @@ +/*************************************************************************************************** +* Copyright (C) 2025 Intel Corporation, All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ + +#include +#include +#include + +#include + +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/platform/platform.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/sycl_event_manager.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "../../common/sycl_cute_common.hpp" + +#pragma clang diagnostic ignored "-Wpass-failed" +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + +using namespace cute; + + +template +void +gemm_device(ATensor const& A, // (M,K) + BTensor const& B, // (N,K) + CTensor & C, // (M,N) + TiledMMA const & mma) +{ + // ----- + // Setup + // ----- + + /* Get workgroup and local IDs */ + auto item = sycl::ext::oneapi::this_work_item::get_nd_item<2>(); + auto wg_m = int(item.get_group(1)); + auto wg_n = int(item.get_group(0)); + auto local_id = int(item.get_local_id(0)); + + /* Create proxy coordinate tensors for each global tensor */ + Tensor cA = make_identity_tensor(A.shape()); // (M,K) + Tensor cB = make_identity_tensor(B.shape()); // (N,K) + Tensor cC = make_identity_tensor(C.shape()); // (M,N) + + /* Split GEMM into workgroup tiles, and identify our workgroup's tile (wg_coord) */ + auto wg_tile = mma.tile_mnk(); + auto wg_coord = make_coord(wg_m, wg_n, 0); + + Tensor gA = local_tile(cA, select<0,2>(wg_tile), make_coord(wg_m,_)); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(cB, select<1,2>(wg_tile), make_coord(wg_n,_)); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(cC, wg_tile, wg_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + /* Create block 2D TiledCopies */ + auto copy_a = make_block_2d_copy_A(mma, A); + auto copy_b = make_block_2d_copy_B(mma, B); + auto copy_c = make_block_2d_copy_C(mma, C); + + /* Slice TiledCopy/TiledMMA operations to thread (work-item) level */ + auto thr_mma = mma.get_slice(local_id); + auto thr_copy_a = copy_a.get_slice(local_id); + auto thr_copy_b = copy_b.get_slice(local_id); + + /* Register fragments for MMA */ + auto tCrA = thr_mma.partition_sg_fragment_A(gA(_,_,0)); + auto tCrB = thr_mma.partition_sg_fragment_B(gB(_,_,0)); + + /* Register fragments for copies */ + auto tArA = thr_copy_a.partition_sg_fragment_D(gA(_,_,0)); + auto tBrB = thr_copy_b.partition_sg_fragment_D(gB(_,_,0)); + + /* Partition global tensor (proxies) for copies */ + Tensor tAgA = thr_copy_a.partition_S(gA); + Tensor tBgB = thr_copy_b.partition_S(gB); + + /* Partition C */ + Tensor tCrC = partition_fragment_C(mma, select<0,1>(wg_tile)); + Tensor tCgC = thr_mma.partition_C(gC); /* also matches copy_c's source layout */ + + /* Create prefetch TiledCopy instances */ + auto prefetch_a = make_block_2d_prefetch(copy_a); + auto prefetch_b = make_block_2d_prefetch(copy_b); + + auto thr_prefetch_A = prefetch_a.get_slice(local_id); + auto thr_prefetch_B 
= prefetch_b.get_slice(local_id); + + /* Partition global tensor (proxies) for prefetch */ + auto pAgA = thr_prefetch_A.partition_S(gA); + auto pBgB = thr_prefetch_B.partition_S(gB); + + /* Prefetch distance, in units of k tiles */ + const int prefetch_dist = 3; + + // ------ + // Kernel + // ------ + + constexpr int barrier_scope = 2; + + int k_tile_count = ceil_div(shape<1>(A), get<2>(wg_tile)); + int k_tile_prefetch = 0; + + /* Clear the accumulators */ + clear(tCrC); + + /* Warm up loops with prefetch to L1 */ + CUTE_UNROLL + for (; k_tile_prefetch < prefetch_dist; k_tile_prefetch++) { + prefetch(prefetch_a, pAgA(_,_,_,k_tile_prefetch)); + prefetch(prefetch_b, pBgB(_,_,_,k_tile_prefetch)); + } + + /* Main loop */ + for (int k_tile = 0; k_tile < k_tile_count; k_tile++, k_tile_prefetch++) { + /* Split barrier keeping threads loosely together */ + barrier_arrive(barrier_scope); + + /* Copy A/B from global memory (ideally L1 cache) to registers */ + copy(copy_a, tAgA(_,_,_,k_tile), tArA); + copy(copy_b, tBgB(_,_,_,k_tile), tBrB); + + /* Prefetch A/B tiles to L1 */ + prefetch(prefetch_a, pAgA(_,_,_,k_tile_prefetch)); + prefetch(prefetch_b, pBgB(_,_,_,k_tile_prefetch)); + + /* Shuffle data from copy fragments to MMA fragments */ + reorder(tArA, tCrA); + reorder(tBrB, tCrB); + + /* Accumulate C += A * B */ + gemm(mma, tCrA, tCrB, tCrC); + + /* Other half of split barrier */ + barrier_wait(barrier_scope); + } + + /* Write C to global memory */ + copy(copy_c, tCrC, tCgC); +} + + + +template +struct is_complete : std::false_type {}; + +template +struct is_complete : std::true_type {}; + +template +static constexpr bool is_complete_v = is_complete::value; + + +template +auto +choose_mma_op() +{ + if constexpr (is_complete_v>) + return XE_DPAS_TT<8, TC, TA, TB>{}; + else if constexpr (is_same_v) + return XE_DPAS_TT<8, float, cute::bfloat16_t>{}; + else /* Use f16 by default as upconversion sequences are typically faster */ + return XE_DPAS_TT<8, float, cute::half_t>{}; +} + +template +auto +choose_tiled_mma(ATensor const& A, BTensor const& B, CTensor const&) +{ + using TA = typename ATensor::element_type; + using TB = typename BTensor::element_type; + using TC = typename CTensor::element_type; + + auto op = choose_mma_op(); + + constexpr bool byte = (cute::max(sizeof_bits_v, sizeof_bits_v) <= 8); + constexpr bool use_1x_dpas_per_k = is_constant_v<1, decltype(stride<0>(A))> // Use one DPAS in k dimension for A^T case + || (byte && is_constant_v<1, decltype(stride<0>(B))>); // pending compiler improvements (also int8 B^N) + + using _K = conditional_t, C>; + + using WGTile = Shape<_256, _256, _K>; // 256x256 WG tile size + using SGLayout = Layout, Stride<_4, _1, _0>>; // 8x4 SG tiling, n-major + + using MMA = typename TiledMMAHelper, Layout, SGLayout>::TiledMMA; + + return MMA{}; +} + +template +void +gemm_cute(sycl::queue &Q, + ATensor const& A, // (M,K) + BTensor const& B, // (N,K) + CTensor & C) // (M,N) +{ + auto mma = choose_tiled_mma(A, B, C); + + sycl::range<2> local = {size(mma), 1}; + sycl::range<2> global = {local[0] * ceil_div(shape<0>(B), get<1>(mma.tile_mnk())), + local[1] * ceil_div(shape<0>(A), get<0>(mma.tile_mnk()))}; + + namespace syclex = sycl::ext::oneapi::experimental; + namespace intelex = sycl::ext::intel::experimental; + + syclex::properties kernel_props { + syclex::sub_group_size<16>, + intelex::grf_size<256> + }; + + auto event = Q.parallel_for(sycl::nd_range<2>(global, local), kernel_props, + [=](auto) { + gemm_device(A, B, C, mma); + } + ); + + 
  EventManager::getInstance().addEvent(event);
+}
+
+template <class ATensor, class BTensor, class CTensor>
+bool
+gemm_verify(sycl::queue &Q,
+            ATensor const& A, // (M,K)
+            BTensor const& B, // (N,K)
+            CTensor const& C) // (M,N)
+{
+  int m = size<0>(A);
+  int n = size<0>(B);
+  int k = size<1>(A);
+
+  auto ok = sycl::malloc_shared<bool>(1, Q);
+  *ok = true;
+
+  Q.parallel_for(sycl::range<2>(m, n), [=](sycl::item<2> id) {
+    int i = id[0], j = id[1];
+
+    using AccType = typename CTensor::element_type;
+    using SignedAccType = ensure_signed_t<AccType>;
+
+    auto c = AccType(0);
+    for (int h = 0; h < k; h++)
+      c += AccType(A(i,h)) * AccType(B(j,h));
+
+    auto tol = AccType(1e-5f * k);
+    if (std::abs(SignedAccType(c - AccType(C(i,j)))) > tol) {
+#ifdef SHOW_DIFF
+      printf("Error at (%d,%d): got %f, expected %f\n", i, j, double(C(i,j)), double(c));
+#endif
+      *ok = false;
+    }
+  }).wait();
+
+  bool read_ok = *ok;
+
+  sycl::free(ok, Q);
+
+  return read_ok;
+}
+
+template <class TA, class TB, class TC, char layoutA, char layoutB>
+void
+test_case(sycl::queue &Q, int m, int n, int k)
+{
+  std::cout << type_str<TA>() << " (" << layoutA << ") x "
+            << type_str<TB>() << " (" << layoutB << ") -> "
+            << type_str<TC>() << ": \t";
+
+  // Transpose B to match CuTe conventions
+  constexpr char tlayoutB = layoutB ^ ('R' ^ 'C');
+
+  // Prepare data:
+  auto A = make_shared_usm_tensor<TA, layoutA>(Q, m, k);
+  auto B = make_shared_usm_tensor<TB, tlayoutB>(Q, n, k);
+  auto C = make_shared_usm_tensor(Q, m, n);
+
+  random_fill(A);
+  random_fill(B);
+  zero_fill(C);
+
+#ifndef SKIP_VERIFY
+  auto A_ref = make_shared_usm_tensor(Q, m, k);
+  auto B_ref = make_shared_usm_tensor(Q, n, k);
+
+  copy(A, A_ref);
+  copy(B, B_ref);
+#endif
+
+  subbyte_pack(A);
+  subbyte_pack(B);
+
+  // Test accuracy:
+  gemm_cute(Q, A, B, C);
+  Q.wait_and_throw();
+
+#ifdef SKIP_VERIFY
+  const bool ok = true;
+  std::cout << "verification skipped";
+#else
+  bool ok = gemm_verify(Q, A_ref, B_ref, C);
+  std::cout << (ok ? "passed" : "failed");
+#endif
+
+  if (ok) {
+    // Test performance:
+    const int timing_iterations = 100;
+    GPU_Clock timer;
+
+    timer.start();
+    for (int i = 0; i < timing_iterations; ++i)
+      gemm_cute(Q, A, B, C);
+    Q.wait_and_throw();
+
+    double avg = timer.seconds() / timing_iterations;
+    double tops = (2.0*m*n*k) * 1e-12;
+
+    printf(", %4.3f TF/s (%.3f ms)", tops / avg, avg*1000);
+  }
+
+  free_usm_tensor(A, Q);
+  free_usm_tensor(B, Q);
+  free_usm_tensor(C, Q);
+
+#ifndef SKIP_VERIFY
+  free_usm_tensor(A_ref, Q);
+  free_usm_tensor(B_ref, Q);
+#endif
+
+  std::cout << '\n';
+
+  // Pause for a short period of time to allow the GPU to cool.
+  static bool first = true;
+  if (first)
+    first = false;
+  else
+    sleep(1);
+}
+
+
+int main(int argc, char** argv)
+{
+  auto shift = [&] {
+    return (argc-- > 0) ?
*argv++ : nullptr; + }; + + auto parse_size = [&] { + static constexpr int default_size = 4096; + if (auto e = shift()) + return atoi(e); + else + return default_size; + }; + + (void) shift(); + + auto m = parse_size(); + auto n = parse_size(); + auto k = parse_size(); + + sycl::queue Q; + + // Native compute + test_case(Q, m, n, k); + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + + // Upconversion cases + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); + + test_case(Q, m, n, k); + test_case(Q, m, n, k); +} diff --git a/media/docs/cpp/xe_rearchitecture.md b/media/docs/cpp/xe_rearchitecture.md index 1b8d159b3f..7e5c49bfb4 100644 --- a/media/docs/cpp/xe_rearchitecture.md +++ b/media/docs/cpp/xe_rearchitecture.md @@ -382,7 +382,7 @@ reorder(Tensor<...> const& src_fragment, ## Example CuTe GEMM -Let's combine the pieces here to make a complete CuTe-level Xe GEMM kernel. +Let's combine the pieces here to make a complete CuTe-level Xe GEMM kernel. Fully runnable code can be found at [examples/cute/tutorial/xe_gemm.cpp](../../../examples/cute/tutorial/xe_gemm.cpp). 
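+
+The host side of the example (`gemm_cute` in xe_gemm.cpp) maps one workgroup to each (M,N)
+tile of C and launches the kernel with a sub-group size of 16. Below is a condensed sketch,
+adapted from the example source; the `syclex`/`intelex` namespace aliases are as defined there.
+
+```c++
+auto mma = choose_tiled_mma(A, B, C);   // select an XE_DPAS_TT-based TiledMMA for these types
+
+sycl::range<2> local  = {size(mma), 1};
+sycl::range<2> global = {local[0] * ceil_div(shape<0>(B), get<1>(mma.tile_mnk())),
+                         local[1] * ceil_div(shape<0>(A), get<0>(mma.tile_mnk()))};
+
+syclex::properties kernel_props{syclex::sub_group_size<16>, intelex::grf_size<256>};
+
+Q.parallel_for(sycl::nd_range<2>(global, local), kernel_props,
+               [=](auto) { gemm_device(A, B, C, mma); });
+```
+
+The example builds as the `cute_tutorial_xe_gemm` target added to the CuTe tutorial CMake list.
+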
```c++ template (); + auto wg_m = int(item.get_group(1)); + auto wg_n = int(item.get_group(0)); + auto local_id = int(item.get_local_id(0)); + /* Create proxy coordinate tensors for each global tensor */ Tensor cA = make_identity_tensor(A.shape()); // (M,K) Tensor cB = make_identity_tensor(B.shape()); // (N,K) @@ -406,11 +410,11 @@ gemm_device(ATensor const& A, // (M,K) /* Split GEMM into workgroup tiles, and identify our workgroup's tile (wg_coord) */ auto wg_tile = mma.tile_mnk(); - auto wg_coord = make_coord(BlockIdxX(), BlockIdxY(), 0); + auto wg_coord = make_coord(wg_m, wg_n, 0); - Tensor gA = local_tile(cA, select<0,2>(wg_tile), make_coord(BlockIdxX(),_)); // (BLK_M,BLK_K,k) - Tensor gB = local_tile(cB, select<1,2>(wg_tile), make_coord(BlockIdxY(),_)); // (BLK_N,BLK_K,k) - Tensor gC = local_tile(cC, wg_tile, wg_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + Tensor gA = local_tile(cA, select<0,2>(wg_tile), make_coord(wg_m,_)); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(cB, select<1,2>(wg_tile), make_coord(wg_n,_)); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(cC, wg_tile, wg_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) /* Create block 2D TiledCopies */ auto copy_a = make_block_2d_copy_A(mma, A); @@ -418,11 +422,9 @@ gemm_device(ATensor const& A, // (M,K) auto copy_c = make_block_2d_copy_C(mma, C); /* Slice TiledCopy/TiledMMA operations to thread (work-item) level */ - int thread_idx = int(ThreadIdxX()); - - auto thr_mma = mma.get_slice(thread_idx); - auto thr_copy_a = copy_a.get_slice(thread_idx); - auto thr_copy_b = copy_b.get_slice(thread_idx); + auto thr_mma = mma.get_slice(local_id); + auto thr_copy_a = copy_a.get_slice(local_id); + auto thr_copy_b = copy_b.get_slice(local_id); /* Register fragments for MMA */ auto tCrA = thr_mma.partition_sg_fragment_A(gA(_,_,0)); @@ -444,45 +446,47 @@ gemm_device(ATensor const& A, // (M,K) auto prefetch_a = make_block_2d_prefetch(copy_a); auto prefetch_b = make_block_2d_prefetch(copy_b); - auto thr_prefetch_A = prefetch_a.get_slice(thread_idx); - auto thr_prefetch_B = prefetch_b.get_slice(thread_idx); + auto thr_prefetch_A = prefetch_a.get_slice(local_id); + auto thr_prefetch_B = prefetch_b.get_slice(local_id); /* Partition global tensor (proxies) for prefetch */ auto pAgA = thr_prefetch_A.partition_S(gA); auto pBgB = thr_prefetch_B.partition_S(gB); + /* Prefetch distance, in units of k tiles */ + const int prefetch_dist = 2; + // ------ // Kernel // ------ constexpr int barrier_scope = 2; - int k_tile_count = ceil_div(get<2>(shape_MNK), get<2>(cta_tiler)); + int k_tile_count = ceil_div(shape<1>(A), get<2>(wg_tile)); int k_tile_prefetch = 0; /* Clear the accumulators */ clear(tCrC); /* Warm up loops with prefetch to L1 */ - CUTLASS_PRAGMA_UNROLL - for (; k_tile_prefetch < stages; k_tile_prefetch++) { - prefetch(prefetch_a, pAgA(_,_,k_tile_prefetch)); - prefetch(prefetch_b, pBgB(_,_,k_tile_prefetch)); + CUTE_UNROLL + for (; k_tile_prefetch < prefetch_dist; k_tile_prefetch++) { + prefetch(prefetch_a, pAgA(_,_,_,k_tile_prefetch)); + prefetch(prefetch_b, pBgB(_,_,_,k_tile_prefetch)); } /* Main loop */ - CUTLASS_PRAGMA_UNROLL for (int k_tile = 0; k_tile < k_tile_count; k_tile++, k_tile_prefetch++) { /* Split barrier keeping threads loosely together */ barrier_arrive(barrier_scope); /* Copy A/B from global memory (ideally L1 cache) to registers */ - copy(copy_a, tAgA(_,_,k_tile), tArA); - copy(copy_b, tBgB(_,_,k_tile), tBrB); + copy(copy_a, tAgA(_,_,_,k_tile), tArA); + copy(copy_b, tBgB(_,_,_,k_tile), tBrB); /* Prefetch A/B tiles to L1 */ - 
prefetch(prefetch_a, pAgA(_,_,k_tile_prefetch)); - prefetch(prefetch_b, pBgB(_,_,k_tile_prefetch)); + prefetch(prefetch_a, pAgA(_,_,_,k_tile_prefetch)); + prefetch(prefetch_b, pBgB(_,_,_,k_tile_prefetch)); /* Shuffle data from copy fragments to MMA fragments */ reorder(tArA, tCrA);