@@ -1008,7 +1008,162 @@ make_block_2d_copy_CD(CopyOp const& op, // Copy operation
10081008 make_tile (sg_to_vmn, _)); // (SG,V) -> (M,N)
10091009
10101010 // Derive copy tile layout and create TiledCopy
1011- return make_block_2d_copy_X<ValType>(op, mma, gstride, x_mode, y_mode, tile_mn, svC);
1011+ return make_block_2d_copy_X<ValType>(op, gstride, x_mode, y_mode, tile_mn, svC);
1012+ }
1013+
1014+ // Variants of make_block_2d_copy_C/D where the C/D tile is further subdivided by the user.
1015+ // (e.g. split-k parallelization).
1016+
1017+ template <class TiledMMA ,
1018+ class SubtileTVCoordLayout , class SubtileSGLayout ,
1019+ class GEngine , class GLayout ,
1020+ __CUTE_REQUIRES (is_layout_v<SubtileSGLayout>)>
1021+ CUTE_HOST_DEVICE
1022+ auto
1023+ make_block_2d_copy_C_subtiled(TiledMMA const & mma, // TiledMMA instance
1024+ SubtileTVCoordLayout const & stv_layout, // Subtile TV-layout: (T,V) -> coord
1025+ SubtileSGLayout const & ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile)
1026+ Tensor<GEngine, GLayout> const & gmem) // Global tensor
1027+ {
1028+ using ValType = typename GEngine::value_type;
1029+ return make_block_2d_copy_C_subtiled<ValType>(mma, stv_layout, ssg_layout, gmem.stride ()).with (gmem);
1030+ }
1031+
1032+ template <class TiledMMA ,
1033+ class SubtileTVCoordLayout , class SubtileSGLayout ,
1034+ class GEngine , class GLayout ,
1035+ __CUTE_REQUIRES (is_layout_v<SubtileSGLayout>)>
1036+ CUTE_HOST_DEVICE
1037+ auto
1038+ make_block_2d_copy_D_subtiled(TiledMMA const & mma, // TiledMMA instance
1039+ SubtileTVCoordLayout const & stv_layout, // Subtile TV-layout: (T,V) -> coord
1040+ SubtileSGLayout const & ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile)
1041+ Tensor<GEngine, GLayout> const & gmem) // Global tensor
1042+ {
1043+ using ValType = typename GEngine::value_type;
1044+ return make_block_2d_copy_D_subtiled<ValType>(mma, stv_layout, ssg_layout, gmem.stride ()).with (gmem);
1045+ }
1046+
1047+ template <class TiledMMA ,
1048+ class SubtileShape , class SubtileSGLayout ,
1049+ class CopyOp , class GEngine , class GLayout ,
1050+ __CUTE_REQUIRES (is_layout_v<SubtileSGLayout>)>
1051+ CUTE_HOST_DEVICE
1052+ auto
1053+ make_block_2d_copy_CD_subtiled(CopyOp const & op, // Copy operation
1054+ TiledMMA const & mma, // TiledMMA instance
1055+ SubtileShape const & sshape, // Subtile shape: (m,n)
1056+ SubtileSGLayout const & ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile)
1057+ Tensor<GEngine, GLayout> const & gmem) // Global tensor
1058+ {
1059+ using ValType = typename GEngine::value_type;
1060+ return make_block_2d_copy_CD_subtiled<ValType>(op, sshape, ssg_layout, mma, gmem.stride ()).with (gmem);
1061+ }
1062+
1063+ template <class ValType , class TiledMMA ,
1064+ class SubtileTVCoordLayout , class SubtileSGLayout ,
1065+ class ... Strides,
1066+ __CUTE_REQUIRES (is_layout_v<SubtileSGLayout>)>
1067+ CUTE_HOST_DEVICE
1068+ auto
1069+ make_block_2d_copy_C_subtiled(TiledMMA const & mma, // TiledMMA instance
1070+ SubtileTVCoordLayout const & stv_layout, // Subtile TV-layout: (T,V) -> coord
1071+ SubtileSGLayout const & ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile)
1072+ Stride<Strides...> const & gstride) // Global memory strides
1073+ {
1074+ using MMAType = typename TiledMMA::ValTypeA;
1075+ auto op = block_2d_selector<ValType, MMAType>(stv_layout, gstride);
1076+ return make_block_2d_copy_CD_subtiled<ValType>(op, mma, atuple_coshape (stv_layout), ssg_layout, gstride);
1077+ }
1078+
1079+ template <class ValType , class TiledMMA ,
1080+ class SubtileTVCoordLayout , class SubtileSGLayout ,
1081+ class ... Strides,
1082+ __CUTE_REQUIRES (is_layout_v<SubtileSGLayout>)>
1083+ CUTE_HOST_DEVICE
1084+ auto
1085+ make_block_2d_copy_D_subtiled(TiledMMA const & mma, // TiledMMA instance
1086+ SubtileTVCoordLayout const & stv_layout, // Subtile TV-layout: (T,V) -> coord
1087+ SubtileSGLayout const & ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile)
1088+ Stride<Strides...> const & gstride) // Global memory strides
1089+ {
1090+ using MMAType = typename TiledMMA::ValTypeA;
1091+ auto op = block_2d_selector<ValType, MMAType, true >(stv_layout, gstride);
1092+ return make_block_2d_copy_CD_subtiled<ValType>(op, mma, atuple_coshape (stv_layout), ssg_layout, gstride);
1093+ }
1094+
1095+ template <class ValType , class TiledMMA , class CopyOp ,
1096+ class SubtileShape , class SubtileSGLayout ,
1097+ class ... Strides,
1098+ __CUTE_REQUIRES (is_layout_v<SubtileSGLayout>)>
1099+ CUTE_HOST_DEVICE
1100+ auto
1101+ make_block_2d_copy_CD_subtiled(CopyOp const & op, // Copy operation
1102+ TiledMMA const & mma, // TiledMMA instance
1103+ SubtileShape const & sshape, // Subtile shape: (m,n)
1104+ SubtileSGLayout const & ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile)
1105+ Stride<Strides...> const & gstride) // Global memory strides
1106+ {
1107+ return make_block_2d_copy_CD_subtiled<ValType>(op, mma, sshape, ssg_layout, gstride,
1108+ find_x_mode (gstride), find_y_mode (gstride));
1109+ }
1110+
1111+ template <class ValType , class TiledMMA , class CopyOp ,
1112+ class SubtileShape , class SubtileSGLayout ,
1113+ class ... Strides, class XMode , class YMode ,
1114+ __CUTE_REQUIRES (is_layout_v<SubtileSGLayout>)>
1115+ CUTE_HOST_DEVICE
1116+ auto
1117+ make_block_2d_copy_CD_subtiled(CopyOp const & op, // Copy operation
1118+ TiledMMA const & mma, // TiledMMA instance
1119+ SubtileShape const & sshape, // Subtile shape: (m,n)
1120+ SubtileSGLayout const & ssg_layout, // Subtile subgroup layout: SG_K -> (m_subtile,n_subtile)
1121+ Stride<Strides...> const & gstride, // Global memory strides
1122+ XMode const & x_mode, // x, y modes
1123+ YMode const & y_mode)
1124+ {
1125+ // Expand subtile layout.
1126+ auto xssg_layout = make_layout (shape (ssg_layout),
1127+ elem_scale (stride (ssg_layout), sshape)); // SG_K -> (M,N)
1128+
1129+ // Retrieve MMA atom's (subgroup, value) -> (M,N) layout.
1130+ // Allow cross-MMA tiling.
1131+ auto tile_mn = round_up (select<0 ,1 >(mma.tile_mnk ()),
1132+ atuple_coshape (xssg_layout));
1133+
1134+ auto thr_vmnk = mma.get_thr_layout_vmnk (); // (ThrV,ThrM,ThrN,ThrK) -> thr
1135+ auto shape_vmnk = shape (thr_vmnk); // (ThrV,ThrM,ThrN,ThrK)
1136+ auto drop_k = replace<3 >(make_layout (shape_vmnk),
1137+ make_layout (get<3 >(shape_vmnk), _0{})); // (ThrV,ThrM,ThrN,ThrK) -> (ThrV,ThrM,ThrN)
1138+
1139+ auto thr_to_vmn = composition (drop_k, right_inverse (thr_vmnk)); // thr -> (ThrV,ThrM,ThrN)
1140+ auto sg_to_vmn = composition (thr_to_vmn,
1141+ make_layout (product (take<1 ,4 >(shape_vmnk)), get<0 >(shape_vmnk))); // SG -> (0,ThrM,ThrN)
1142+
1143+ auto svC = composition (mma.thrfrg_C (make_layout (tile_mn)),
1144+ make_tile (sg_to_vmn, _)); // (SG,V) -> (M,N)
1145+
1146+ // Add subtile modes. Limitations:
1147+ // - ThrK must be covered by a single mode in svC.
1148+ // - SubtileSGLayout must have a subtile for each ThrK, OR ThrK must be the last mode.
1149+ decltype (coalesce (get<0 >(svC))) sC {};
1150+ constexpr auto mode_thr_k = find_if (stride (sC ), [](auto const &x) { return C<is_constant_v<0 , decltype (x)>>{}; });
1151+ static_assert (shape<mode_thr_k>(sC ) == shape<3 >(thr_vmnk), " ThrK split into multiple modes; unsupported" );
1152+
1153+ auto k_to_mn = composition (make_layout (tile_mn), xssg_layout); // ThrK -> (M,N)
1154+
1155+ static_assert (size (SubtileSGLayout{}) == shape<3 >(thr_vmnk) || mode_thr_k + 1 >= rank (sC ),
1156+ " Unsupported partially occupied ThrK scenario" );
1157+
1158+ // Remove subtile value modes.
1159+ auto drop_subtiles = make_layout (zip (sshape, shape_div (tile_mn, sshape)),
1160+ zip (stride (make_layout (tile_mn)), Stride<_0,_0>{}));
1161+
1162+ auto svC_tiled = make_layout (replace<mode_thr_k>(sC , k_to_mn),
1163+ coalesce (composition (drop_subtiles, get<1 >(svC))));
1164+
1165+ // Derive copy tile layout and create TiledCopy
1166+ return make_block_2d_copy_X<ValType>(op, gstride, x_mode, y_mode, tile_mn, svC_tiled);
10121167}
10131168
10141169// Prefetch selection and creation.
0 commit comments