Skip to content

Commit

Permalink
Update dpas* variables
Browse files Browse the repository at this point in the history
  • Loading branch information
muhammad-tanvir-1211 committed Jun 14, 2024
1 parent cf20ab4 commit 3755c7c
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ class CollectiveEpilogue<
static_assert(!is_layout<EpilogueTile>::value && is_tuple<EpilogueTile>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]");
static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M");
static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N");
//static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M");
//static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N");
static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");

Expand Down Expand Up @@ -262,8 +262,10 @@ class CollectiveEpilogue<
(void) smem;
using namespace cute;

static constexpr int DpasM = get<0>(shape(typename TiledMma::LayoutA_TV{})); // rows per dpas operation per sub_group for Matrix A
static constexpr int DpasN = get<1>(shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per sub_group for Matrix B
using DpasShape = typename TiledMma::Shape_MNK;

static constexpr int DpasM = get<0>(DpasShape()); // rows per dpas operation per sub_group for Matrix A
static constexpr int DpasN = get<1>(DpasShape()); // cols per dpas operation per sub_group for Matrix B

static constexpr int FragsM = get<0>(EpilogueTile{}) / DpasM; // A frags per sub_group
static constexpr int FragsN = get<1>(EpilogueTile{}) / DpasN; // B frags per sub_group
Expand Down

0 comments on commit 3755c7c

Please sign in to comment.