Skip to content

Commit 48d82e8

Browse files
committed
[CuTe] [Xe] make_block_2d_copy_{C,D} variants with subtiling
1 parent dec36a9 commit 48d82e8

File tree

1 file changed

+156
-1
lines changed

1 file changed

+156
-1
lines changed

include/cute/atom/copy_traits_xe_2d.hpp

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1008,7 +1008,162 @@ make_block_2d_copy_CD(CopyOp const& op, // Copy operation
10081008
make_tile(sg_to_vmn, _)); // (SG,V) -> (M,N)
10091009

10101010
// Derive copy tile layout and create TiledCopy
1011-
return make_block_2d_copy_X<ValType>(op, mma, gstride, x_mode, y_mode, tile_mn, svC);
1011+
return make_block_2d_copy_X<ValType>(op, gstride, x_mode, y_mode, tile_mn, svC);
1012+
}
1013+
1014+
// Variants of make_block_2d_copy_C/D where the C/D tile is further subdivided by the user.
1015+
// (e.g. split-k parallelization).
1016+
1017+
template <class TiledMMA,
1018+
class SubtileTVCoordLayout, class SubtileSGLayout,
1019+
class GEngine, class GLayout,
1020+
__CUTE_REQUIRES(is_layout_v<SubtileSGLayout>)>
1021+
CUTE_HOST_DEVICE
1022+
auto
1023+
make_block_2d_copy_C_subtiled(TiledMMA const& mma, // TiledMMA instance
1024+
SubtileTVCoordLayout const& stv_layout, // Subtile TV-layout: (T,V) -> coord
1025+
SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile)
1026+
Tensor<GEngine, GLayout> const& gmem) // Global tensor
1027+
{
1028+
using ValType = typename GEngine::value_type;
1029+
return make_block_2d_copy_C_subtiled<ValType>(mma, stv_layout, ssg_layout, gmem.stride()).with(gmem);
1030+
}
1031+
1032+
template <class TiledMMA,
1033+
class SubtileTVCoordLayout, class SubtileSGLayout,
1034+
class GEngine, class GLayout,
1035+
__CUTE_REQUIRES(is_layout_v<SubtileSGLayout>)>
1036+
CUTE_HOST_DEVICE
1037+
auto
1038+
make_block_2d_copy_D_subtiled(TiledMMA const& mma, // TiledMMA instance
1039+
SubtileTVCoordLayout const& stv_layout, // Subtile TV-layout: (T,V) -> coord
1040+
SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile)
1041+
Tensor<GEngine, GLayout> const& gmem) // Global tensor
1042+
{
1043+
using ValType = typename GEngine::value_type;
1044+
return make_block_2d_copy_D_subtiled<ValType>(mma, stv_layout, ssg_layout, gmem.stride()).with(gmem);
1045+
}
1046+
1047+
template <class TiledMMA,
1048+
class SubtileShape, class SubtileSGLayout,
1049+
class CopyOp, class GEngine, class GLayout,
1050+
__CUTE_REQUIRES(is_layout_v<SubtileSGLayout>)>
1051+
CUTE_HOST_DEVICE
1052+
auto
1053+
make_block_2d_copy_CD_subtiled(CopyOp const& op, // Copy operation
1054+
TiledMMA const& mma, // TiledMMA instance
1055+
SubtileShape const& sshape, // Subtile shape: (m,n)
1056+
SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile)
1057+
Tensor<GEngine, GLayout> const& gmem) // Global tensor
1058+
{
1059+
using ValType = typename GEngine::value_type;
1060+
return make_block_2d_copy_CD_subtiled<ValType>(op, sshape, ssg_layout, mma, gmem.stride()).with(gmem);
1061+
}
1062+
1063+
template <class ValType, class TiledMMA,
1064+
class SubtileTVCoordLayout, class SubtileSGLayout,
1065+
class... Strides,
1066+
__CUTE_REQUIRES(is_layout_v<SubtileSGLayout>)>
1067+
CUTE_HOST_DEVICE
1068+
auto
1069+
make_block_2d_copy_C_subtiled(TiledMMA const& mma, // TiledMMA instance
1070+
SubtileTVCoordLayout const& stv_layout, // Subtile TV-layout: (T,V) -> coord
1071+
SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile)
1072+
Stride<Strides...> const& gstride) // Global memory strides
1073+
{
1074+
using MMAType = typename TiledMMA::ValTypeA;
1075+
auto op = block_2d_selector<ValType, MMAType>(stv_layout, gstride);
1076+
return make_block_2d_copy_CD_subtiled<ValType>(op, mma, atuple_coshape(stv_layout), ssg_layout, gstride);
1077+
}
1078+
1079+
template <class ValType, class TiledMMA,
1080+
class SubtileTVCoordLayout, class SubtileSGLayout,
1081+
class... Strides,
1082+
__CUTE_REQUIRES(is_layout_v<SubtileSGLayout>)>
1083+
CUTE_HOST_DEVICE
1084+
auto
1085+
make_block_2d_copy_D_subtiled(TiledMMA const& mma, // TiledMMA instance
1086+
SubtileTVCoordLayout const& stv_layout, // Subtile TV-layout: (T,V) -> coord
1087+
SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile)
1088+
Stride<Strides...> const& gstride) // Global memory strides
1089+
{
1090+
using MMAType = typename TiledMMA::ValTypeA;
1091+
auto op = block_2d_selector<ValType, MMAType, true>(stv_layout, gstride);
1092+
return make_block_2d_copy_CD_subtiled<ValType>(op, mma, atuple_coshape(stv_layout), ssg_layout, gstride);
1093+
}
1094+
1095+
template <class ValType, class TiledMMA, class CopyOp,
1096+
class SubtileShape, class SubtileSGLayout,
1097+
class... Strides,
1098+
__CUTE_REQUIRES(is_layout_v<SubtileSGLayout>)>
1099+
CUTE_HOST_DEVICE
1100+
auto
1101+
make_block_2d_copy_CD_subtiled(CopyOp const& op, // Copy operation
1102+
TiledMMA const& mma, // TiledMMA instance
1103+
SubtileShape const& sshape, // Subtile shape: (m,n)
1104+
SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile)
1105+
Stride<Strides...> const& gstride) // Global memory strides
1106+
{
1107+
return make_block_2d_copy_CD_subtiled<ValType>(op, mma, sshape, ssg_layout, gstride,
1108+
find_x_mode(gstride), find_y_mode(gstride));
1109+
}
1110+
1111+
template <class ValType, class TiledMMA, class CopyOp,
1112+
class SubtileShape, class SubtileSGLayout,
1113+
class... Strides, class XMode, class YMode,
1114+
__CUTE_REQUIRES(is_layout_v<SubtileSGLayout>)>
1115+
CUTE_HOST_DEVICE
1116+
auto
1117+
make_block_2d_copy_CD_subtiled(CopyOp const& op, // Copy operation
1118+
TiledMMA const& mma, // TiledMMA instance
1119+
SubtileShape const& sshape, // Subtile shape: (m,n)
1120+
SubtileSGLayout const& ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile)
1121+
Stride<Strides...> const& gstride, // Global memory strides
1122+
XMode const& x_mode, // x, y modes
1123+
YMode const& y_mode)
1124+
{
1125+
// Expand subtile layout.
1126+
auto xssg_layout = make_layout(shape(ssg_layout),
1127+
elem_scale(stride(ssg_layout), sshape)); // SG_K -> (M,N)
1128+
1129+
// Retrieve MMA atom's (subgroup, value) -> (M,N) layout.
1130+
// Allow cross-MMA tiling.
1131+
auto tile_mn = round_up(select<0,1>(mma.tile_mnk()),
1132+
atuple_coshape(xssg_layout));
1133+
1134+
auto thr_vmnk = mma.get_thr_layout_vmnk(); // (ThrV,ThrM,ThrN,ThrK) -> thr
1135+
auto shape_vmnk = shape(thr_vmnk); // (ThrV,ThrM,ThrN,ThrK)
1136+
auto drop_k = replace<3>(make_layout(shape_vmnk),
1137+
make_layout(get<3>(shape_vmnk), _0{})); // (ThrV,ThrM,ThrN,ThrK) -> (ThrV,ThrM,ThrN)
1138+
1139+
auto thr_to_vmn = composition(drop_k, right_inverse(thr_vmnk)); // thr -> (ThrV,ThrM,ThrN)
1140+
auto sg_to_vmn = composition(thr_to_vmn,
1141+
make_layout(product(take<1,4>(shape_vmnk)), get<0>(shape_vmnk))); // SG -> (0,ThrM,ThrN)
1142+
1143+
auto svC = composition(mma.thrfrg_C(make_layout(tile_mn)),
1144+
make_tile(sg_to_vmn, _)); // (SG,V) -> (M,N)
1145+
1146+
// Add subtile modes. Limitations:
1147+
// - ThrK must be covered by a single mode in svC.
1148+
// - SubtileSGLayout must have a subtile for each ThrK, OR ThrK must be the last mode.
1149+
decltype(coalesce(get<0>(svC))) sC{};
1150+
constexpr auto mode_thr_k = find_if(stride(sC), [](auto const &x) { return C<is_constant_v<0, decltype(x)>>{}; });
1151+
static_assert(shape<mode_thr_k>(sC) == shape<3>(thr_vmnk), "ThrK split into multiple modes; unsupported");
1152+
1153+
auto k_to_mn = composition(make_layout(tile_mn), xssg_layout); // ThrK -> (M,N)
1154+
1155+
static_assert(size(SubtileSGLayout{}) == shape<3>(thr_vmnk) || mode_thr_k + 1 >= rank(sC),
1156+
"Unsupported partially occupied ThrK scenario");
1157+
1158+
// Remove subtile value modes.
1159+
auto drop_subtiles = make_layout(zip(sshape, shape_div(tile_mn, sshape)),
1160+
zip(stride(make_layout(tile_mn)), Stride<_0,_0>{}));
1161+
1162+
auto svC_tiled = make_layout(replace<mode_thr_k>(sC, k_to_mn),
1163+
coalesce(composition(drop_subtiles, get<1>(svC))));
1164+
1165+
// Derive copy tile layout and create TiledCopy
1166+
return make_block_2d_copy_X<ValType>(op, gstride, x_mode, y_mode, tile_mn, svC_tiled);
10121167
}
10131168

10141169
// Prefetch selection and creation.

0 commit comments

Comments
 (0)