25 changes: 14 additions & 11 deletions examples/00_bmg_gemm/00_bmg_gemm.cpp
@@ -345,30 +345,33 @@ int main(int argc, const char** argv)
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;

// The 2D block copy operations used for the A and B matrices
using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
// [New Copy Atom] When left unspecified (void), make_block_2d_copy_* automatically selects
// appropriate 2D block copy operations for matrices A and B. Alternatively, you can
// explicitly specify new copy atom operations such as XE_LOAD_2D, XE_LOAD_2D_VNNI
// (applicable only to matrix B), or XE_LOAD_2D_TRANSPOSE.
// Refer to https://github.com/intel/sycl-tla/blob/petercad/rearchitecture/media/docs/cpp/xe_rearchitecture.md for details.
using GmemTiledCopyA = void; //XE_LOAD_2D<16, 32, 32>;
using GmemTiledCopyB = void; //XE_LOAD_2D_VNNI<16, 32, 32>;
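// [Illustration only] A hypothetical variant that selects the transposing load for A explicitly
// (e.g. if A were provided column-major). The <Bits, Height, Width> values below are placeholders,
// not a configuration exercised by this example:
// using GmemTiledCopyA = XE_LOAD_2D_TRANSPOSE<16, 32, 32>;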

// Workgroup-level tile
using TileShape = Shape<_256, _256, _32>;

// A TiledMMA struct defines a tiling of an MMA atom over M, N and K, combining both additional
// hardware (sub-groups for Intel BMG) and iterations by each sub-group.
//
// The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom
// (XE_8x16x16_F32BF16BF16F32_TT), TileShape (<256, 256, 32>) and sub-group layout (8x4x1). The
// TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
// The TiledMMAHelper struct defines a specific TiledMMA for a given MMA atom, TileShape and
// sub-group layout. This example uses the XE_DPAS_TT<8, float, cute::bfloat16_t> atom (an 8x16x16
// DPAS operation with float32 accumulation and bfloat16 inputs), TileShape <256, 256, 32> and sub-group layout (8x4x1).
// The TiledMMA constructed using TiledMMAHelper has the property that each sub-group operates on a
// single contiguous chunk of the work-group TileShape. For this configuration, this implies that
// each sub-group operates on a contiguous 32x64x32 chunk (4x4x2 iterations). See
// 0t_mma_atom.md#TiledMMAs for more info. Sub-groups are arranged row-major (stride 4,1,0) for
// performance reasons.
using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32
typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>, Layout<TileShape>,
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
using TiledMma = typename TiledMMAHelper<MMA_Atom<XE_DPAS_TT<8, float, cute::bfloat16_t>>, Layout<TileShape>,
Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
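// Sanity-checking the numbers quoted above for this configuration (work-group tile 256x256x32,
// sub-group grid 8x4x1, DPAS atom shape 8x16x16):
//   M: 256 / 8 sub-groups = 32 rows per sub-group -> 32 / 8  = 4 MMA iterations
//   N: 256 / 4 sub-groups = 64 cols per sub-group -> 64 / 16 = 4 MMA iterations
//   K: 32 (whole tile per sub-group)              -> 32 / 16 = 2 MMA iterations
// i.e. the 32x64x32 chunk and 4x4x2 iterations mentioned in the comment above.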

// For Intel BMG, PipelineStages defines how many k-blocks ahead to prefetch from A and B.
constexpr int PipelineStages = 2;
using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;
// For older versions of the copy/MMA atoms, use cutlass::gemm::MainloopIntelXeXMX16 as the dispatch policy.
using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
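// [Illustration only] For comparison, the legacy configuration this example replaces pairs the
// old copy atoms with the old dispatch policy:
// using GmemTiledCopyA = XE_2D_U16x32x32_LD_N;
// using GmemTiledCopyB = XE_2D_U16x32x32_LD_V;
// using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16<PipelineStages>;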
using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;

// This is the 'default' epilogue operation (Linear Combination) which performs everything in:
36 changes: 36 additions & 0 deletions include/cute/atom/copy_traits_xe_2d.hpp
@@ -756,6 +756,31 @@ make_block_2d_copy_X(CopyOp const& op, // Copy operation
return make_block_2d_copy<ValType>(op, gstride, x_mode, y_mode, atom_shape, sv_layout_t);
}

// Helper trait to detect new XE copy ops
template<typename T>
struct is_new_xe_atom : cute::false_type {};

// Helper trait specifically for XE_LOAD_2D_VNNI (for copy B)
template<typename T>
struct is_new_xe_atom_vnni : cute::false_type {};

// Helper trait specifically for XE_STORE_2D (for copy C)
template<typename T>
struct is_new_xe_atom_store : cute::false_type {};

// Check if T is an instantiation of XE_LOAD_2D
template<int Bits, int Height, int Width, int BlockWidth>
struct is_new_xe_atom<XE_LOAD_2D<Bits, Height, Width, BlockWidth>> : cute::true_type {};

// Check if T is an instantiation of XE_LOAD_2D_TRANSPOSE
template<int Bits, int Height, int Width>
struct is_new_xe_atom<XE_LOAD_2D_TRANSPOSE<Bits, Height, Width>> : cute::true_type {};

// Check if T is an instantiation of XE_LOAD_2D_VNNI
template<int Bits, int Height, int Width, int BlockWidth>
struct is_new_xe_atom_vnni<XE_LOAD_2D_VNNI<Bits, Height, Width, BlockWidth>> : cute::true_type {};
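// A minimal usage sketch of these traits (illustrative only; the values mirror those referenced
// in the static_asserts below):
//
//   static_assert( is_new_xe_atom<XE_LOAD_2D<16, 32, 32>>::value);           // new load atom accepted
//   static_assert( is_new_xe_atom_vnni<XE_LOAD_2D_VNNI<16, 32, 32>>::value); // new VNNI atom accepted
//   static_assert(!is_new_xe_atom<XE_2D_U16x32x32_LD_N>::value);             // legacy atom rejected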


// MMA-focused TiledCopy creation functions.
template <class TiledMMA, class GEngine, class GLayout>
CUTE_HOST_DEVICE
@@ -774,6 +799,12 @@ make_block_2d_copy_A(CopyOp const& op, // Copy operation
TiledMMA const& mma, // TiledMMA instance
Tensor<GEngine, GLayout> const& gmem) // Global tensor
{
// This will pass for new atoms like XE_LOAD_2D<16, 32, 32>
// and fail for old atoms like XE_2D_U16x32x32_LD_N
static_assert(is_new_xe_atom<CopyOp>::value,
"Legacy XE copy atoms are not compatible with make_block_2d_copy_A. "
"Please use the new templated atoms XE_LOAD_2D<Bits, Height, Width> or XE_LOAD_2D_TRANSPOSE<Bits, Height, Width>. "
"Example: XE_2D_U16x32x32_LD_N -> XE_LOAD_2D<16, 32, 32>");
using ValType = typename GEngine::value_type;
return make_block_2d_copy_A<ValType>(op, mma, gmem.stride()).with(gmem);
}
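// Example (illustrative): with the new atoms, a call that specifies the copy operation explicitly
// looks like the following, where tiled_mma is a TiledMMA instance and mA the global A tensor
// (both names are placeholders):
//
//   auto copy_a = make_block_2d_copy_A(XE_LOAD_2D<16, 32, 32>{}, tiled_mma, mA);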
@@ -846,6 +877,11 @@ make_block_2d_copy_B(CopyOp const& op, // Copy operation
TiledMMA const& mma, // TiledMMA instance
Tensor<GEngine, GLayout> const& gmem) // Global tensor
{
// Only accept XE_LOAD_2D_VNNI for copy B
static_assert(is_new_xe_atom_vnni<CopyOp>::value,
"Legacy XE copy atoms are not compatible with make_block_2d_copy_B. "
"Please use the new templated atom XE_LOAD_2D_VNNI<Bits, Height, Width, BlockWidth>. "
"Example: XE_2D_U16x32x32_LD_V -> XE_LOAD_2D_VNNI<16, 32, 32, 32>");
using ValType = typename GEngine::value_type;
return make_block_2d_copy_B<ValType>(op, mma, gmem.stride()).with(gmem);
}
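// Example (illustrative): only the VNNI load atom is accepted here, e.g.
//
//   auto copy_b = make_block_2d_copy_B(XE_LOAD_2D_VNNI<16, 32, 32>{}, tiled_mma, mB);
//
// with tiled_mma a TiledMMA instance and mB the global B tensor (placeholder names).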
1 change: 1 addition & 0 deletions include/cutlass/gemm/collective/collective_mma.hpp
@@ -78,6 +78,7 @@

#if defined(SYCL_INTEL_TARGET)
#include "cutlass/gemm/collective/xe_mma.hpp"
#include "cutlass/gemm/collective/xe_mma_legacy.hpp"
#include "cutlass/gemm/collective/xe_array_mma.hpp"
#include "cutlass/gemm/collective/xe_array_mma_fp8.hpp"
#include "cutlass/gemm/collective/xe_mma_mixed_input.hpp"
125 changes: 73 additions & 52 deletions include/cutlass/gemm/collective/xe_mma.hpp
@@ -47,13 +47,13 @@ using namespace cute;
template <int Stages, class Schedule, class TileShape_, class ElementA_, class StrideA_, class ElementB_, class StrideB_,
class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_, class SmemCopyAtomA_, class TransformA_,
class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_, class TransformB_>
struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
struct CollectiveMma<MainloopXeL1Staged<Stages, Schedule>, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_,
SmemCopyAtomB_, TransformB_> {
//
// Type Aliases
//
using DispatchPolicy = MainloopIntelXeXMX16<Stages, Schedule>;
using DispatchPolicy = MainloopXeL1Staged<Stages, Schedule>;
using WorkgroupTileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
@@ -71,7 +71,7 @@ struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, Element
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;

static_assert(platform::is_same<ElementA, ElementB>::value, "MainloopIntelXeXMX16 requires that A and B have same type.");
static_assert(platform::is_same<ElementA, ElementB>::value, "MainloopXeL1Staged requires that A and B have same type.");
static_assert(std::is_same_v<TransformA, cute::identity>, "Transformation for A is not currently supported on Intel PVC");
static_assert(std::is_same_v<TransformB, cute::identity>, "Transformation for B is not currently supported on Intel PVC");

@@ -100,9 +100,6 @@ struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, Element
static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K;
static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});

using Copy_A = typename Copy_Traits<GmemTiledCopyA, StrideA>::template DefaultTiledCopy<ElementA>;
using Copy_B = typename Copy_Traits<GmemTiledCopyB, StrideB>::template DefaultTiledCopy<ElementB>;

// Host side kernel arguments
struct Arguments {
ElementA const* ptr_A;
@@ -112,8 +109,11 @@ struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, Element
};

struct Params {
Copy_A tiled_copy_a;
Copy_B tiled_copy_b;
ElementA const* ptr_A;
StrideA dA;
ElementB const* ptr_B;
StrideB dB;
int M, N, K, L;
};

//
Expand All @@ -129,12 +129,11 @@ struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, Element

auto [M,N,K,L] = problem_shape;

auto mA_mkl = make_tensor(make_gmem_ptr(args.ptr_A), make_layout(make_shape(M, K, L), args.dA));
auto mB_nkl = make_tensor(make_gmem_ptr(args.ptr_B), make_layout(make_shape(N, K, L), args.dB));
Copy_A tiled_copy_a{Copy_A{}.with(mA_mkl)};
Copy_B tiled_copy_b{Copy_B{}.with(mB_nkl)};

return Params{tiled_copy_a, tiled_copy_b};
return Params{args.ptr_A,
args.dA,
args.ptr_B,
args.dB,
M, N, K, L};
}

template<class ProblemShape>
@@ -177,59 +176,76 @@ struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, Element
static_assert(is_rmem<FrgTensorD>::value, "D tensor must be rmem resident.");
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");

auto thr_copy_A = mainloop.tiled_copy_a.get_slice(thread_idx);
auto thr_copy_B = mainloop.tiled_copy_b.get_slice(thread_idx);
auto mA_mkl = make_tensor(make_gmem_ptr(mainloop.ptr_A),
make_layout(make_shape(mainloop.M, mainloop.K, mainloop.L), cute::take<0,2>(mainloop.dA)));
auto mB_nkl = make_tensor(make_gmem_ptr(mainloop.ptr_B),
make_layout(make_shape(mainloop.N, mainloop.K, mainloop.L), cute::take<0,2>(mainloop.dB)));
auto copy_a = [&]() {
if constexpr (!std::is_void_v<GmemTiledCopyA>) {
// User provided copy operation
return make_block_2d_copy_A(GmemTiledCopyA{}, TiledMma{}, mA_mkl);
} else {
// make_block_2d_copy_A automatically selects copy operation
return make_block_2d_copy_A(TiledMma{}, mA_mkl);
}
}();

auto copy_b = [&]() {
if constexpr (!std::is_void_v<GmemTiledCopyB>) {
// User provided copy operation
return make_block_2d_copy_B(GmemTiledCopyB{}, TiledMma{}, mB_nkl);
} else {
// make_block_2d_copy_B automatically selects copy operation
return make_block_2d_copy_B(TiledMma{}, mB_nkl);
}
}();

auto thr_copy_a = copy_a.get_slice(thread_idx);
auto thr_copy_b = copy_b.get_slice(thread_idx);

// Instantiate the MMA object and get thread slice
TiledMma tiled_mma;
// TODO(Codeplay): see if we can make this nicer
// To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup
auto sg = compat::get_nd_item<1>().get_sub_group();
auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize;
auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx);

// Partition global counting tensors for MMA
Tensor tCgA = thr_mma.partition_A(gA);
Tensor tCgB = thr_mma.partition_B(gB);

Tensor tCrA = make_tensor<ElementA>(make_fragment_layout(mainloop.tiled_copy_a, tCgA(_,_,_,0).shape()));
Tensor tCrB = make_tensor<ElementB>(make_fragment_layout(mainloop.tiled_copy_b, tCgB(_,_,_,0).shape()));

// Retile registers for copies
Tensor tArA = thr_copy_A.retile_D(tCrA);
Tensor tBrB = thr_copy_B.retile_D(tCrB);

// Retile global counting tensors for copies
Tensor tAgA = thr_copy_A.retile_S(tCgA);
Tensor tBgB = thr_copy_B.retile_S(tCgB);

auto tiled_prefetch_a = cute::prefetch_selector<Shape<Int<BLK_M>,Int<BLK_K>>, Num_SGs>(mainloop.tiled_copy_a);
auto tiled_prefetch_b = cute::prefetch_selector<Shape<Int<BLK_N>,Int<BLK_K>>, Num_SGs>(mainloop.tiled_copy_b);
auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx);
auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx);
auto thr_mma = tiled_mma.get_slice(thread_idx);

/* Register fragments for MMA */
auto tCrA = thr_mma.partition_sg_fragment_A(gA(_,_,0));
auto tCrB = thr_mma.partition_sg_fragment_B(gB(_,_,0));

/* Register fragments for copies */
auto tArA = thr_copy_a.partition_sg_fragment_D(gA(_,_,0));
auto tBrB = thr_copy_b.partition_sg_fragment_D(gB(_,_,0));

/* Partition global tensor (proxies) for copies */
Tensor tAgA = thr_copy_a.partition_S(gA);
Tensor tBgB = thr_copy_b.partition_S(gB);

// Partition global tile for prefetch
/* Create prefetch TiledCopy instances */
auto prefetch_a = make_block_2d_prefetch(copy_a);
auto prefetch_b = make_block_2d_prefetch(copy_b);

auto thr_prefetch_A = prefetch_a.get_slice(thread_idx);
auto thr_prefetch_B = prefetch_b.get_slice(thread_idx);

/* Partition global tensor (proxies) for prefetch */
auto pAgA = thr_prefetch_A.partition_S(gA);
auto pBgB = thr_prefetch_B.partition_S(gB);

#if CUTLASS_ENABLE_DEBUG_PRINTS
#define PRINT(x) print(#x ": "); print(x); print("\n");
if (cute::thread(LOG_THREAD, LOG_GROUP)) {
print("======================= A: \n");
PRINT(tCgA);
PRINT(tAgA);

PRINT(tCrA);
PRINT(tArA);
PRINT(mainloop.tiled_copy_a);
PRINT(copy_a);

print("======================= B: \n");
PRINT(tCgB);
PRINT(tBgB);

PRINT(tCrB);
PRINT(tBrB);
PRINT(mainloop.tiled_copy_b);
PRINT(copy_b);
}
#undef PRINT
#endif
@@ -243,21 +259,25 @@

CUTLASS_PRAGMA_UNROLL
for (; prefetch_k < DispatchPolicy::Stages; prefetch_k++) {
prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k));
prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k));
prefetch(prefetch_a, pAgA(_, _, _, prefetch_k));
prefetch(prefetch_b, pBgB(_, _, _, prefetch_k));
}

for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) {
barrier_arrive(barrier_scope);
// Copy gmem to rmem for the first k_tile
copy(mainloop.tiled_copy_a, tAgA(_,_,_,k_tile), tArA);
copy(mainloop.tiled_copy_b, tBgB(_,_,_,k_tile), tBrB);
copy(copy_a, tAgA(_,_,_,k_tile), tArA);
copy(copy_b, tBgB(_,_,_,k_tile), tBrB);

if (prefetch_k < k_tile_count) {
prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k));
prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k));
prefetch(prefetch_a, pAgA(_, _, _, prefetch_k));
prefetch(prefetch_b, pBgB(_, _, _, prefetch_k));
}

/* Shuffle data from copy fragments to MMA fragments */
reorder(tArA, tCrA);
reorder(tBrB, tCrB);

cute::gemm(tiled_mma, tCrA, tCrB, accum);
barrier_wait(barrier_scope);
}
Expand All @@ -267,3 +287,4 @@ struct CollectiveMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, Element
} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
