diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp
index 5627b722af..90610ec736 100644
--- a/include/cute/atom/copy_atom.hpp
+++ b/include/cute/atom/copy_atom.hpp
@@ -769,7 +769,7 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS,  // (m,n) -> (tid,vid) and
 #include
 #endif
 
-#if defined(CUTLASS_ENABLE_SYCL)
+#if defined(SYCL_INTEL_TARGET)
 #include <cute/atom/copy_traits_xe.hpp>
 #endif
diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp
index af33cb68b6..b60ef8811c 100644
--- a/include/cute/atom/copy_traits_xe.hpp
+++ b/include/cute/atom/copy_traits_xe.hpp
@@ -53,8 +53,17 @@ struct XE_2D_LD_Unpack
     static_assert(is_rmem<TD>::value);
     int H = size<0>(traits.tensor);
     int W = size<1>(traits.tensor) * sizeof(typename Copy_Traits::CopyInternalType);
-    auto [y, x] = src.data().coord_;
-    CopyOp::copy(traits.tensor.data().get(), W, H, W, intel::coord_t{static_cast<int>(x), static_cast<int>(y)}, &*dst.data());
+    auto [y, x, z] = src.data().coord_;
+    CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*dst.data());
+  }
+
+  template <class GCoord, class GShape, class GStride>
+  CUTE_HOST_DEVICE constexpr auto
+  get_pvc_tensor(GCoord const& coord, GShape const& shape, GStride const& stride_mul) const
+  {
+    return make_tensor(make_inttuple_iter(coord),
+                       make_layout(make_shape(_1{}, get<0>(shape), get<1>(shape), get<2>(shape)),
+                                   make_stride(_1{}, E<0>{} * get<0>(stride_mul), E<1>{} * get<1>(stride_mul), E<2>{} * get<2>(stride(tensor)))));
+  }
 };
@@ -274,8 +283,17 @@ struct XE_2D_ST_Unpack
     static_assert(is_rmem<TS>::value);
     int H = size<0>(traits.tensor);
     int W = size<1>(traits.tensor) * sizeof(typename Copy_Traits::CopyInternalType);
-    auto [y, x] = dst.data().coord_;
-    CopyOp::copy(traits.tensor.data().get(), W, H, W, intel::coord_t{static_cast<int>(x), static_cast<int>(y)}, &*src.data());
+    auto [y, x, z] = dst.data().coord_;
+    CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*src.data());
+  }
+
+  template <class GCoord, class GShape, class GStride>
+  CUTE_HOST_DEVICE constexpr auto
+  get_pvc_tensor(GCoord const& coord, GShape const& shape, GStride const& stride_mul) const
+  {
+    return make_tensor(make_inttuple_iter(coord),
+                       make_layout(make_shape(_1{}, get<0>(shape), get<1>(shape), get<2>(shape)),
+                                   make_stride(_1{}, E<0>{} * get<0>(stride_mul), E<1>{} * get<1>(stride_mul), E<2>{} * get<2>(stride(tensor)))));
+  }
 };
diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp
index ffb6a08b0c..d7427ef802 100644
--- a/include/cute/atom/mma_atom.hpp
+++ b/include/cute/atom/mma_atom.hpp
@@ -38,6 +38,10 @@
 #include
 #include
 
+#if defined(CUTLASS_ENABLE_SYCL)
+#include <cute/atom/mma_traits_xe.hpp>
+#endif
+
 namespace cute {
 
 template
diff --git a/include/cute/atom/mma_traits_xe.hpp b/include/cute/atom/mma_traits_xe.hpp
index a24cc3386f..a5ef6dbec2 100644
--- a/include/cute/atom/mma_traits_xe.hpp
+++ b/include/cute/atom/mma_traits_xe.hpp
@@ -41,8 +41,8 @@ template <>
 struct MMA_Traits<XE_8x16x16_F32BF16BF16F32_TT>
 {
   using ValTypeD = float;
-  using ValTypeA = sycl::ext::oneapi::bfloat16;
-  using ValTypeB = sycl::ext::oneapi::bfloat16;
+  using ValTypeA = bfloat16_t;
+  using ValTypeB = bfloat16_t;
   using ValTypeC = float;
 
   using Shape_MNK = Shape<_8,_16,_16>;
diff --git a/include/cute/config.hpp b/include/cute/config.hpp
index 4cf38929e5..b669c6ba80 100644
--- a/include/cute/config.hpp
+++ b/include/cute/config.hpp
@@ -34,9 +34,9 @@
 # define CUTE_HOST_DEVICE __forceinline__ __host__ __device__
 # define CUTE_DEVICE      __forceinline__          __device__
 # define CUTE_HOST        __forceinline__ __host__
-#elif defined(__SYCL_CUDA_ARCH__)
-# define CUTE_HOST_DEVICE __attribute__((always_inline))
-# define CUTE_DEVICE      __attribute__((always_inline))
+#elif defined(__SYCL_DEVICE_ONLY__)
+# define CUTE_HOST_DEVICE __attribute__((always_inline)) inline
+# define CUTE_DEVICE      __attribute__((always_inline)) inline
 # define CUTE_HOST        inline
 #else
 # define CUTE_HOST_DEVICE inline
diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp
index 28d3ee67a9..85af8589eb 100644
--- a/include/cute/tensor.hpp
+++ b/include/cute/tensor.hpp
@@ -42,6 +42,7 @@
 #include
 #include
+#include
 
 namespace cute {
 
diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h
index 6c7941735d..b87d899a32 100644
--- a/include/cutlass/arch/arch.h
+++ b/include/cutlass/arch/arch.h
@@ -97,6 +97,13 @@ struct Sm90 {
   static int const kMinComputeCapability = 90;
 };
 
+#if defined(CUTLASS_ENABLE_SYCL)
+struct IntelPVC {
+  static int const kMinComputeCapability = 0;
+};
+
+#endif
+
 /// Triggers a breakpoint on the device
 CUTLASS_DEVICE
 void device_breakpoint() {
diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp
index 91c801762a..c80b2f1a34 100644
--- a/include/cutlass/gemm/collective/collective_mma.hpp
+++ b/include/cutlass/gemm/collective/collective_mma.hpp
@@ -75,4 +75,8 @@ struct CollectiveMma {
 #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp"
 #include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp"
 #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp"
+
+#if defined(SYCL_INTEL_TARGET)
+#include "cutlass/gemm/collective/intel_pvc_mma.hpp"
+#endif
 /////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp
new file mode 100644
index 0000000000..83d46afa69
--- /dev/null
+++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp
@@ -0,0 +1,220 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+
+#include "cute/algorithm/functional.hpp"
+#include "cute/atom/mma_atom.hpp"
+#include "cute/algorithm/gemm.hpp"
+#include "cute/atom/mma_atom.hpp"
+#include "cute/tensor_predicate.hpp"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass::gemm::collective {
+using namespace cute;
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  class TileShape_,
+  class ElementA_,
+  class StrideA_,
+  class ElementB_,
+  class StrideB_,
+  class TiledMma_,
+  class GmemTiledCopyA_,
+  class SmemLayoutAtomA_,
+  class SmemCopyAtomA_,
+  class TransformA_,
+  class GmemTiledCopyB_,
+  class SmemLayoutAtomB_,
+  class SmemCopyAtomB_,
+  class TransformB_>
+struct CollectiveMma<
+    MainloopIntelPVCUnpredicated,
+    TileShape_,
+    ElementA_,
+    StrideA_,
+    ElementB_,
+    StrideB_,
+    TiledMma_,
+    GmemTiledCopyA_,
+    SmemLayoutAtomA_,
+    SmemCopyAtomA_,
+    TransformA_,
+    GmemTiledCopyB_,
+    SmemLayoutAtomB_,
+    SmemCopyAtomB_,
+    TransformB_>
+{
+  //
+  // Type Aliases
+  //
+  using DispatchPolicy = MainloopIntelPVCUnpredicated;
+  using TileShape = TileShape_;
+  using ElementA = ElementA_;
+  using StrideA = StrideA_;
+  using ElementB = ElementB_;
+  using StrideB = StrideB_;
+  using TiledMma = TiledMma_;
+  using ElementAccumulator = typename TiledMma::ValTypeC;
+  using GmemTiledCopyA = GmemTiledCopyA_;
+  using GmemTiledCopyB = GmemTiledCopyB_;
+  using SmemLayoutAtomA = SmemLayoutAtomA_;
+  using SmemLayoutAtomB = SmemLayoutAtomB_;
+  using SmemCopyAtomA = SmemCopyAtomA_;
+  using SmemCopyAtomB = SmemCopyAtomB_;
+  using TransformA = TransformA_;
+  using TransformB = TransformB_;
+  using ArchTag = typename DispatchPolicy::ArchTag;
+
+  static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
+
+  static constexpr int DpasM = get<0>(shape(typename TiledMma::LayoutA_TV{})); // rows per dpas operation per sub_group for Matrix A
+  static constexpr int DpasN = get<1>(shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per sub_group for Matrix B
+  static constexpr int DpasK = get<1>(shape(typename TiledMma::LayoutA_TV{})); // cols per dpas operation per sub_group for Matrix A
+
+  static constexpr uint32_t MaxThreadsPerBlock = DpasM * DpasN;
+  static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
+
+  static constexpr int FragsM = get<0>(TileShape{}) / DpasM; // A frags per sub_group
+  static constexpr int FragsN = get<1>(TileShape{}) / DpasN; // B frags per sub_group
+  static constexpr int FragsK = get<2>(TileShape{}) / DpasK;
+
+  // Calculate the vector width based on the amount of registers
+  // required per work item by dividing the total fragment size by
+  // the sub_group size.
+  static constexpr int VecC = (DpasN * DpasM) / SubgroupSize;
+  static constexpr int VecA = (DpasM * DpasK) / SubgroupSize;
+  static constexpr int VecB = (DpasN * DpasK) / SubgroupSize;
+
+  // Host side kernel arguments
+  struct Arguments {
+    ElementA const* ptr_A;
+    StrideA dA;
+    ElementB const* ptr_B;
+    StrideB dB;
+  };
+
+  struct Params {
+    using XE_Copy_A = decltype(make_xe_2d_copy(make_tensor(static_cast<ElementA const*>(nullptr),
+                                 repeat_like(StrideA{}, int32_t(0)), StrideA{})));
+    using XE_Copy_B = decltype(make_xe_2d_copy(make_tensor(static_cast<ElementB const*>(nullptr),
+                                 repeat_like(StrideB{}, int32_t(0)), StrideB{})));
+    XE_Copy_A gmem_tiled_copy_a;
+    XE_Copy_B gmem_tiled_copy_b;
+  };
+
+  //
+  // Methods
+  //
+
+  CollectiveMma() = default;
+
+  template <class ProblemShape>
+  static constexpr Params
+  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
+    (void) workspace;
+
+    auto problem_shape_MNKL = append<4>(problem_shape, 1);
+    auto [M,N,K,L] = problem_shape_MNKL;
+
+    Tensor tensorA = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA));
+    Tensor tensorB = make_tensor(args.ptr_B, make_layout(make_shape(K,N,L), args.dB));
+
+    typename Params::XE_Copy_A copyA = make_xe_2d_copy(tensorA);
+    typename Params::XE_Copy_B copyB = make_xe_2d_copy(tensorB);
+    return Params{copyA, copyB};
+  }
+
+  /// Perform a subgroup-scoped matrix multiply-accumulate
+  template <
+    class FrgTensorD,
+    class TensorA,
+    class TensorB,
+    class FrgTensorC,
+    class KTileIterator,
+    class ResidueMNK
+  >
+  CUTLASS_DEVICE void
+  operator() (
+      FrgTensorD &accum,
+      TensorA gA,
+      TensorB gB,
+      FrgTensorC const &src_accum,
+      KTileIterator k_tile_iter, int k_tile_count,
+      ResidueMNK residue_mnk,
+      int thread_idx,
+      char *smem_buf,
+      Params const& mainloop)
+  {
+    (void)residue_mnk;
+    (void)thread_idx;
+    (void)smem_buf;
+
+    static_assert(is_rmem<FrgTensorD>::value, "D tensor must be rmem resident.");
+    static_assert(is_tuple::value, "A tensor must be a tuple iterator.");
+    static_assert(is_tuple::value, "B tensor must be a tuple iterator.");
+    static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
+
+    // Tensor to hold input data
+    Tensor tAr = make_tensor<ElementA>(Shape<Int<get<0>(TileShape{}) * FragsK>, Int<1>>{});
+    Tensor tBr = make_tensor<ElementB>(Shape<Int<get<2>(TileShape{}) / FragsN>, Int<FragsN>>{});
+
+    Tensor tAr_view = make_tensor(static_cast<decltype(tAr) &&>(tAr).data(),
+                                  Shape<Int<VecA>, Int<FragsM>, Int<FragsK>>{});
+    Tensor tBr_view = make_tensor(static_cast<decltype(tBr) &&>(tBr).data(),
+                                  Shape<Int<VecB>, Int<FragsK>, Int<FragsN>>{});
+
+    // Instantiate the MMA object
+    TiledMma tiled_mma;
+
+    //
+    // Mainloop
+    //
+    for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += DpasK * FragsK)
+    {
+      // Copy gmem to rmem for the first k_tile
+      copy(mainloop.gmem_tiled_copy_a, gA(_,_,k), tAr);
+      copy(mainloop.gmem_tiled_copy_b, gB(_,k/2,_), tBr);
+
+      for (int kl = 0; kl < FragsK; kl++) {
+        cute::gemm(tiled_mma, accum, tAr_view(_, _, kl), tBr_view(_, kl, _), src_accum);
+      }
+    }
+  }
+};
+
+} // namespace cutlass::gemm::collective
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h
index 283c85e3cb..680dd5d4ce 100644
--- a/include/cutlass/gemm/device/gemm_universal_adapter.h
+++ b/include/cutlass/gemm/device/gemm_universal_adapter.h
@@ -402,7 +402,11 @@ class GemmUniversalAdapter<
     const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z);
     const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z);
+#if defined (SYCL_INTEL_TARGET)
+    syclcompat::experimental::launch<device_kernel<GemmKernel>, DispatchPolicy::SubgroupSize>(sycl_grid, sycl_block, smem_size, params);
+#else
     syclcompat::launch<device_kernel<GemmKernel>>(sycl_grid, sycl_block, smem_size, params);
+#endif
 #else
     device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);
 #endif
diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp
index 3694d0a87b..a772a08d25 100644
--- a/include/cutlass/gemm/dispatch_policy.hpp
+++ b/include/cutlass/gemm/dispatch_policy.hpp
@@ -71,6 +71,7 @@ enum class KernelInputTransformType {
 //
 // Kernel schedule policies (the base class tags, one for each kernel layer file)
 //
+struct KernelSinglestage { };
 struct KernelMultistage { };
 struct KernelCpAsyncWarpSpecialized { };
 struct KernelCpAsyncWarpSpecializedPingpong { };
@@ -269,6 +270,19 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
     "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies");
 };
 
+
+#if defined(SYCL_INTEL_TARGET)
+struct MainloopIntelPVCBase {
+  constexpr static int Stages = 1;
+  using ArchTag = arch::IntelPVC;
+  using Schedule = KernelSinglestage;
+  using ClusterShape = Shape<_1,_1,_1>;
+  static constexpr int SubgroupSize = 16;
+};
+
+struct MainloopIntelPVCUnpredicated : MainloopIntelPVCBase{};
+#endif
+
 //////////////////////////////////////////////////////////////////////////////
 
 } // namespace cutlass::gemm
diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp
index d1ad5288f2..083ffcd233 100644
--- a/include/cutlass/gemm/kernel/gemm_universal.hpp
+++ b/include/cutlass/gemm/kernel/gemm_universal.hpp
@@ -89,4 +89,8 @@ struct IsCutlass3ArrayKernel
+
+#if defined(SYCL_INTEL_TARGET)
+#include "cutlass/gemm/kernel/intel_pvc_gemm.hpp"
+#endif
diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
new file mode 100644
--- /dev/null
+++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
+namespace cutlass::gemm::kernel {
+
+template <
+  class ProblemShape_,
+  class CollectiveMainloop_,
+  class CollectiveEpilogue_,
+  class TileScheduler_
+>
+class GemmUniversal<
+  ProblemShape_,
+  CollectiveMainloop_,
+  CollectiveEpilogue_,
+  TileScheduler_,
+  cute::enable_if_t<cute::is_base_of_v<KernelSinglestage, typename CollectiveMainloop_::DispatchPolicy::Schedule>>>
+{
+public:
+  //
+  // Type Aliases
+  //
+  using ProblemShape = ProblemShape_;
+
+  static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
+    "ProblemShape{} should be <M,N,K> or <M,N,K,L>");
+
+  // Mainloop derived types
+  using CollectiveMainloop = CollectiveMainloop_;
+  using TileShape = typename CollectiveMainloop::TileShape;
+  using TiledMma = typename CollectiveMainloop::TiledMma;
+  using ArchTag = typename CollectiveMainloop::ArchTag;
+  using ElementA = typename CollectiveMainloop::ElementA;
+  using StrideA = typename CollectiveMainloop::StrideA;
+  using ElementB = typename CollectiveMainloop::ElementB;
+  using StrideB = typename CollectiveMainloop::StrideB;
+  using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
+  using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
+  using MainloopArguments = typename CollectiveMainloop::Arguments;
+  using MainloopParams = typename CollectiveMainloop::Params;
+
+  static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler>,
+    "Intel PVC does not support specializing the tile scheduler.");
+  using TileSchedulerTag = TileScheduler_;
+  using TileScheduler = typename detail::TileSchedulerSelector<
+    TileScheduler_, ArchTag, TileShape,
+    cute::Shape<cute::Int<1>, cute::Int<1>, cute::Int<1>>>::Scheduler;
+  using TileSchedulerArguments = typename TileScheduler::Arguments;
+
+  // Epilogue derived types
+  using CollectiveEpilogue = CollectiveEpilogue_;
+  using ElementC = typename CollectiveEpilogue::ElementC;
+  using StrideC = typename CollectiveEpilogue::StrideC;
+  using ElementD = typename CollectiveEpilogue::ElementD;
+  using StrideD = typename CollectiveEpilogue::StrideD;
+  using EpilogueArguments = typename CollectiveEpilogue::Arguments;
+  using EpilogueParams = typename CollectiveEpilogue::Params;
+
+  static_assert(cute::is_same_v<ElementAccumulator, typename CollectiveEpilogue::ElementAccumulator>,
+    "Mainloop and epilogue do not agree on accumulator value type.");
+
+  // MSVC requires the cast to fix a warning-as-error.
+  static constexpr int SharedStorageSize = 0;
+
+  static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size
+  static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
+  static constexpr uint32_t MinBlocksPerMultiprocessor = CollectiveMainloop::MinBlocksPerMultiprocessor;
+
+  static constexpr int num_sg = MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group
+
+  static constexpr int DpasM = CollectiveMainloop::DpasM;
+  static constexpr int DpasN = CollectiveMainloop::DpasN;
+
+  static constexpr int FragsM = CollectiveMainloop::FragsM;
+  static constexpr int FragsN = CollectiveMainloop::FragsN;
+
+  static constexpr int VecC = CollectiveMainloop::VecC;
+
+  // Device side arguments
+  struct Arguments {
+    GemmUniversalMode mode{};
+    ProblemShape problem_shape{};
+    MainloopArguments mainloop{};
+    EpilogueArguments epilogue{};
+    KernelHardwareInfo hw_info{};
+    TileSchedulerArguments scheduler{};
+  };
+
+  // Kernel entry point API
+  struct Params {
+    GemmUniversalMode mode;
+    ProblemShape problem_shape;
+    MainloopParams mainloop;
+    EpilogueParams epilogue;
+  };
+
+  //
+  // Methods
+  //
+
+  // Convert to underlying arguments. In this case, a simple copy for the aliased type.
+  static
+  Params
+  to_underlying_arguments(Arguments const& args, void* workspace) {
+    (void) workspace;
+    return {
+      args.mode,
+      args.problem_shape,
+      CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
+      CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace)
+    };
+  }
+
+  static bool
+  can_implement(Arguments const& args) {
+    bool mode_implementable = args.mode == GemmUniversalMode::kGemm or
+        (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
+    return mode_implementable && TileScheduler::can_implement(args.scheduler);
+  }
+
+  static int
+  get_workspace_size(Arguments const& args) {
+    return 0;
+  }
+
+  static
+  cutlass::Status
+  initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
+    CudaHostAdapter* cuda_adapter = nullptr) {
+    return Status::kSuccess;
+  }
+
+  static dim3
+  get_grid_shape(Params const& params) {
+    int batch_count = 1;
+    if constexpr (cute::rank(ProblemShape{}) == 4) {
+      batch_count = cute::size<3>(params.problem_shape);
+    }
+
+    auto M = get<0>(params.problem_shape);
+    auto N = get<1>(params.problem_shape);
+
+    const int sg_m = (M - 1) / get<0>(TileShape{}) + 1; // sub_groups required to process A fragments
+    const int sg_n = (N - 1) / get<1>(TileShape{}) + 1; // sub_groups required to process B fragments
+
+    return dim3(
+      sg_m,
+      cute::ceil_div(sg_n, num_sg),
+      batch_count
+    );
+  }
+
+  static dim3
+  get_block_shape() {
+    return dim3(MaxThreadsPerBlock, 1, 1);
+  }
+
+  CUTLASS_DEVICE
+  void
+  operator()(Params const& params, char* smem_buf) {
+
+    (void)smem_buf;
+
+    // Preconditions
+    CUTE_STATIC_ASSERT(is_static<TileShape>::value);
+
+    // Separate out problem shape for convenience
+    // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
+    auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
+    auto M = get<0>(problem_shape_MNKL);
+    auto N = get<1>(problem_shape_MNKL);
+    auto K = get<2>(problem_shape_MNKL);
+    auto L = get<3>(problem_shape_MNKL);
+
+    // Preconditions
+    static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
+    static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
+    static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
+    static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
+
+    // Get the appropriate blocks for this sub_group -- potential for sub_group locality
+    int thread_idx = int(ThreadIdxX());
+    auto subgroup_shape = TileShape{};  // (SUB_M,SUB_N,SUB_K)
+    const int m_coord = BlockIdxX() * get<0>(subgroup_shape);
+    const int n_coord = (BlockIdxY() * num_sg + thread_idx / SubgroupSize) * get<1>(subgroup_shape);
+    const int l_coord = BlockIdxZ();
+
+    Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, l_coord),
+                                                                  make_shape(_1{}, K, L),
+                                                                  make_stride(Int<DpasM>{}, _1{}));
+
+    Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor(make_coord(0, n_coord, l_coord),
+                                                                  make_shape(K, Int<FragsN>{}, L),
+                                                                  make_stride(_1{}, Int<DpasN>{}));
+
+    // Compute tile residues for predication
+    auto m_max_coord = M - get<0>(subgroup_shape) * m_coord;                      // M - SUB_M * m_coord
+    auto n_max_coord = N - get<1>(subgroup_shape) * n_coord;                      // N - SUB_N * n_coord
+    auto k_residue   = K - get<2>(subgroup_shape) * (K / get<2>(subgroup_shape)); // K - SUB_K * k_coord_max
+    auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);
+
+    // Allocate the tiled_mma and the accumulators for the (M,N) subgroup_shape
+    TiledMma tiled_mma;
+
+    Tensor accumulators = make_tensor<ElementAccumulator>(Shape<Int<VecC>, Int<FragsM>, Int<FragsN>>{});
+    clear(accumulators);
+
+    auto k_tile_iter = cute::make_coord_iterator(make_shape(K / get<2>(subgroup_shape)));
+    int k_tile_count = K / get<2>(subgroup_shape);
+
+    // Perform the collective scoped MMA
+    CollectiveMainloop collective_mma;
+    collective_mma(
+      accumulators,
+      tAi(_,_,_,l_coord),
+      tBi(_,_,_,l_coord),
+      accumulators,
+      k_tile_iter, k_tile_count,
+      residue_mnk,
+      thread_idx,
+      smem_buf,
+      params.mainloop
+    );
+
+    auto gmem_tiled_copy_c = make_xe_2d_copy(make_tensor(params.epilogue.ptr_D, make_shape(M, N, L), params.epilogue.dD));
+
+    Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord),
+                                                  make_shape(Int<FragsM>{}, Int<FragsN>{}, L),
+                                                  make_stride(Int<DpasM>{}, Int<DpasN>{}));
+
+    copy(gmem_tiled_copy_c, accumulators, tCi(_,_,_,l_coord));
+  }
+};
+
+///////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass::gemm::kernel
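The fragment and launch-shape arithmetic in the patch is compact, so here is a rough, standalone sketch (plain C++, no CUTLASS or SYCL dependency) of how the same quantities fall out of the new CollectiveMma and GemmUniversal specializations. The 8x16x16 DPAS atom shape and the sub_group size of 16 are taken from the patch; the work-group tile (32x64x32), the problem size, and the ceil_div helper are illustrative assumptions, not values fixed by the diff.

// Hypothetical sketch of the PVC GEMM launch/fragment arithmetic above.
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }  // stands in for cute::ceil_div

int main() {
  // From MMA_Traits<XE_8x16x16_F32BF16BF16F32_TT> and MainloopIntelPVCBase in the patch.
  constexpr int DpasM = 8, DpasN = 16, DpasK = 16, SubgroupSize = 16;

  // Assumed work-group tile (TileShape) and problem size -- example values only.
  constexpr int TileM = 32, TileN = 64, TileK = 32;
  constexpr int M = 4096, N = 4096, K = 4096, L = 1;

  // Mirrors the CollectiveMma constants.
  constexpr int MaxThreadsPerBlock = DpasM * DpasN;          // 128 work-items per work-group
  constexpr int num_sg = MaxThreadsPerBlock / SubgroupSize;  // 8 sub_groups per work-group
  constexpr int FragsM = TileM / DpasM;                      // DPAS row fragments of A per sub_group
  constexpr int FragsN = TileN / DpasN;                      // DPAS col fragments of B per sub_group
  constexpr int FragsK = TileK / DpasK;                      // DPAS k-steps per k-tile
  constexpr int VecC   = (DpasN * DpasM) / SubgroupSize;     // accumulator elements per work-item per atom

  // Mirrors GemmUniversal::get_grid_shape(): one sub_group per (TileM x TileN) tile,
  // with num_sg tiles along N packed into each work-group.
  const int sg_m = ceil_div(M, TileM);
  const int sg_n = ceil_div(N, TileN);
  printf("grid = (%d, %d, %d), block = (%d, 1, 1)\n",
         sg_m, ceil_div(sg_n, num_sg), L, MaxThreadsPerBlock);

  // Matches the accumulator tensor Shape<Int<VecC>, Int<FragsM>, Int<FragsN>>.
  printf("accumulator per work-item = VecC * FragsM * FragsN = %d floats\n",
         VecC * FragsM * FragsN);
  printf("k-tiles = %d, each performing FragsK = %d DPAS k-steps\n", K / TileK, FragsK);
}

With the assumed sizes this prints a 128 x 8 x 1 grid of 128-thread work-groups and a 128-float accumulator per work-item, which is the register budget the VecC/FragsM/FragsN constants in the collective are sized around.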