Skip to content

Commit

Permalink
Add subgroup tile information to tiledMma. (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
aacostadiaz committed Jun 12, 2024
1 parent 67e39cf commit d161fa7
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 34 deletions.
7 changes: 4 additions & 3 deletions examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,10 +353,11 @@ int main(int argc, const char** argv)
using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;

using TileShape = Shape<_32, _64, _32>;
using TileShape = Shape<_1, _1, _1>;

using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_BF16BF16F32F32_NN>,
Layout<Shape<_8,_16,_1>>>;
using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TN>,
Layout<Shape<_1,_1,_1>>,
Tile<_32,_64,_32>>;

using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated;

Expand Down
9 changes: 2 additions & 7 deletions include/cute/arch/mma_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,7 @@ SYCL_DEVICE_OCL(float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, cute::i
#undef SYCL_DEVICE_OCL

namespace cute {
//MxNxK_A,B,C,D
//# of vector component of a x subgroup-size x function name
//float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc);
//TODO: Is A really not transposed? Maybe better a macro than separate define for 1,2,4,8
struct XE_8x16x16_BF16BF16F32F32_NN
struct XE_8x16x16_F32BF16BF16F32_TN
{
using DRegisters = intel::float8[1];
using ARegisters = intel::short8[1];
Expand All @@ -69,8 +65,7 @@ struct XE_8x16x16_BF16BF16F32F32_NN
#endif
}
};
//float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc)
struct XE_1x16x16_BF16BF16F32F32_NN
struct XE_1x16x16_F32BF16BF16F32_TN
{
using DRegisters = float[1];
using ARegisters = short[1];
Expand Down
2 changes: 1 addition & 1 deletion include/cute/atom/mma_traits_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
namespace cute
{
template <>
struct MMA_Traits<XE_8x16x16_BF16BF16F32F32_NN>
struct MMA_Traits<XE_8x16x16_F32BF16BF16F32_TN>
{
using ValTypeD = float;
using ValTypeA = bfloat16_t;
Expand Down
28 changes: 13 additions & 15 deletions include/cutlass/gemm/collective/intel_pvc_mma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,24 +99,22 @@ struct CollectiveMma<
using ArchTag = typename DispatchPolicy::ArchTag;

static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

static constexpr int DpasM = get<0>(shape(typename TiledMma::LayoutA_TV{})); // rows per dpas operation per sub_group for Matrix A
static constexpr int DpasN = get<1>(shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per sub_group for Matrix B
static constexpr int DpasK = get<1>(shape(typename TiledMma::LayoutA_TV{})); // cols per dpas operation per sub_group for Matrix A

static constexpr uint32_t MaxThreadsPerBlock = DpasM * DpasN;
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
using DpasShape = typename TiledMma::Shape_MNK;
using TileDpasShape = decltype(tile_shape(TiledMma()));

static constexpr int FragsM = get<0>(TileShape{}) / DpasM; // A frags per sub_group
static constexpr int FragsN = get<1>(TileShape{}) / DpasN; // B frags per sub_group
static constexpr int FragsK = get<2>(TileShape{}) / DpasK;
static constexpr uint32_t MaxThreadsPerBlock = get<0>(DpasShape()) * get<1>(DpasShape());

static constexpr int FragsM = get<0>(TileDpasShape{}) / get<0>(DpasShape()); // A frags per sub_group
static constexpr int FragsN = get<1>(TileDpasShape{}) / get<1>(DpasShape()); // B frags per sub_group
static constexpr int FragsK = get<2>(TileDpasShape{}) / get<2>(DpasShape());

// Calculate the vector width based on the amount of registers
// required per work item by dividing the total fragment size by
// the sub_group size.
static constexpr int VecC = (DpasN * DpasM) / SubgroupSize;
static constexpr int VecA = (DpasM * DpasK) / SubgroupSize;
static constexpr int VecB = (DpasN * DpasK) / SubgroupSize;
static constexpr int VecC = (get<1>(DpasShape()) * get<0>(DpasShape())) / SubgroupSize;
static constexpr int VecA = (get<0>(DpasShape()) * get<2>(DpasShape())) / SubgroupSize;
static constexpr int VecB = (get<1>(DpasShape()) * get<2>(DpasShape())) / SubgroupSize;

// Host side kernel arguments
struct Arguments {
Expand Down Expand Up @@ -188,8 +186,8 @@ struct CollectiveMma<
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");

// Tensor to hold input data
Tensor tAr = make_tensor<typename TiledMma::ValTypeA>(Shape<Int<get<0>(TileShape{}) * FragsK>, Int<1>>{});
Tensor tBr = make_tensor<typename TiledMma::ValTypeB>(Shape<Int<FragsK * get<1>(TileShape{}) / FragsN>, Int<FragsN>>{});
Tensor tAr = make_tensor<typename TiledMma::ValTypeA>(Shape<Int<get<0>(TileDpasShape{}) * FragsK>, Int<1>>{});
Tensor tBr = make_tensor<typename TiledMma::ValTypeB>(Shape<Int<FragsK * get<1>(TileDpasShape{}) / FragsN>, Int<FragsN>>{});

Tensor tAr_view = make_tensor(static_cast<decltype(tAr) &&>(tAr).data(),
Shape<Int<VecA>, Int<FragsM>, Int<FragsK>>{});
Expand All @@ -202,7 +200,7 @@ struct CollectiveMma<
//
// Mainloop
//
for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += DpasK * FragsK)
for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += get<2>(DpasShape()) * FragsK)
{
// Copy gmem to rmem for the first k_tile
copy(mainloop.gmem_tiled_copy_a, gA(_,_,k), tAr);
Expand Down
17 changes: 9 additions & 8 deletions include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ class GemmUniversal<
static constexpr int num_sg = MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group
static constexpr int DpasM = CollectiveMainloop::DpasM;
static constexpr int DpasN = CollectiveMainloop::DpasN;
using DpasShape = typename CollectiveMainloop::DpasShape;
using TileDpasShape = typename CollectiveMainloop::TileDpasShape;
static constexpr int FragsM = CollectiveMainloop::FragsM;
static constexpr int FragsN = CollectiveMainloop::FragsN;
Expand Down Expand Up @@ -177,8 +178,8 @@ class GemmUniversal<
auto M = get<0>(params.problem_shape);
auto N = get<1>(params.problem_shape);
const int sg_m = (M - 1) / get<0>(TileShape{}) + 1; // sub_groups required to process A fragments
const int sg_n = (N - 1) / get<1>(TileShape{}) + 1; // sub_groups required to process B fragments
const int sg_m = (M - 1) / get<0>(TileDpasShape{}) + 1; // sub_groups required to process A fragments
const int sg_n = (N - 1) / get<1>(TileDpasShape{}) + 1; // sub_groups required to process B fragments
return dim3(
sg_m,
Expand Down Expand Up @@ -217,18 +218,18 @@ class GemmUniversal<
// Get the appropriate blocks for this sub_group -- potential for sub_group locality
int thread_idx = int(ThreadIdxX());
auto subgroup_shape = TileShape{}; // (SUB_M,SUB_N,SUB_K)
auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K)
const int m_coord = BlockIdxX() * get<0>(subgroup_shape);
const int n_coord = (BlockIdxY() * num_sg + thread_idx / SubgroupSize) * get<1>(subgroup_shape);
const int l_coord = BlockIdxZ();
Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, 0),
make_shape(_1{}, K, L),
make_stride(Int<FragsM * DpasM>{}, _1{}));
make_stride(Int<FragsM>{} * get<0>(DpasShape()), _1{}));
Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor(make_coord(0, n_coord, 0),
make_shape(K, Int<FragsN>{}, L),
make_stride(_1{}, Int<DpasN>{}));
make_stride(_1{}, get<1>(DpasShape())));
// Compute tile residues for predication
auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord
Expand Down Expand Up @@ -262,7 +263,7 @@ class GemmUniversal<
Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, 0),
make_shape(Int<FragsM>{}, Int<FragsN>{}, L),
make_stride(Int<DpasM>{}, Int<DpasN>{}));
make_stride(get<0>(DpasShape()), get<1>(DpasShape())));
copy(gmem_tiled_copy_c, accumulators, tCi(_,_,_,l_coord));
}
Expand Down

0 comments on commit d161fa7

Please sign in to comment.