Skip to content

Commit

Permalink
Add workgroup-level tile
Browse files Browse the repository at this point in the history
  • Loading branch information
aacostadiaz committed Jun 13, 2024
1 parent d161fa7 commit f5e0a17
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 12 deletions.
5 changes: 3 additions & 2 deletions examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,11 +353,12 @@ int main(int argc, const char** argv)
using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;

using TileShape = Shape<_1, _1, _1>;
// Workgroup-level tile
using TileShape = Shape<_32, _256, _32>;

using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TN>,
Layout<Shape<_1,_1,_1>>,
Tile<_32,_64,_32>>;
Tile<_32,_64,_32>>; // Subgroup-level tile

using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated;

Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/gemm/collective/intel_pvc_mma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ struct CollectiveMma<
using DpasShape = typename TiledMma::Shape_MNK;
using TileDpasShape = decltype(tile_shape(TiledMma()));

static constexpr uint32_t MaxThreadsPerBlock = get<0>(DpasShape()) * get<1>(DpasShape());
static constexpr uint32_t MaxThreadsPerBlock = cute::size(TileShape{}) / cute::size(TileDpasShape{}) * SubgroupSize;

static constexpr int FragsM = get<0>(TileDpasShape{}) / get<0>(DpasShape()); // A frags per sub_group
static constexpr int FragsN = get<1>(TileDpasShape{}) / get<1>(DpasShape()); // B frags per sub_group
Expand Down
15 changes: 6 additions & 9 deletions include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,10 @@ class GemmUniversal<
static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size
static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
static constexpr uint32_t MinBlocksPerMultiprocessor = CollectiveMainloop::MinBlocksPerMultiprocessor;
static constexpr int num_sg = MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group
using DpasShape = typename CollectiveMainloop::DpasShape;
using TileDpasShape = typename CollectiveMainloop::TileDpasShape;
static constexpr int FragsM = CollectiveMainloop::FragsM;
static constexpr int FragsN = CollectiveMainloop::FragsN;
Expand Down Expand Up @@ -182,9 +178,9 @@ class GemmUniversal<
const int sg_n = (N - 1) / get<1>(TileDpasShape{}) + 1; // sub_groups required to process B fragments
return dim3(
sg_m,
cute::ceil_div(sg_n, num_sg),
batch_count
cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))),
cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))),
batch_count
);
}
Expand Down Expand Up @@ -218,9 +214,10 @@ class GemmUniversal<
// Get the appropriate blocks for this sub_group -- potential for sub_group locality
int thread_idx = int(ThreadIdxX());
auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K)
constexpr auto workgroup_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
constexpr auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K)
const int m_coord = BlockIdxX() * get<0>(subgroup_shape);
const int n_coord = (BlockIdxY() * num_sg + thread_idx / SubgroupSize) * get<1>(subgroup_shape);
const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape);
const int l_coord = BlockIdxZ();
Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, 0),
Expand Down

0 comments on commit f5e0a17

Please sign in to comment.