Skip to content

Commit 21fb89a

Browse files
committed
[CuTe] [Xe] New make_subgroup_tensor helpers
1 parent 8819b01 commit 21fb89a

File tree

1 file changed

+52
-5
lines changed

1 file changed

+52
-5
lines changed

include/cute/tensor_sg.hpp

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#pragma once
3333

3434
#include <cute/tensor_impl.hpp> // cute::Tensor
35+
#include <cute/util/sycl_vec.hpp> // intel::_SGSize
3536

3637
namespace cute
3738
{
@@ -74,7 +75,7 @@ struct SubgroupTensor : Tensor<Engine, Layout>
7475
*this = static_cast<SubgroupTensor const&>(base);
7576
}
7677

77-
static constexpr int rank = Layout::rank;
78+
static constexpr int rank = Layout::rank;
7879

7980
CUTE_HOST_DEVICE constexpr
8081
decltype(auto)
@@ -89,13 +90,18 @@ struct SubgroupTensor : Tensor<Engine, Layout>
8990
}
9091
};
9192

93+
template <class T>
94+
struct is_sg_tensor : false_type {};
95+
template <class Engine, class Layout, class SubgroupTVLayout>
96+
struct is_sg_tensor<SubgroupTensor<Engine,Layout,SubgroupTVLayout>> : true_type {};
97+
9298
template <class Engine, class Layout, class SubgroupTVLayout>
9399
struct is_tensor<SubgroupTensor<Engine,Layout,SubgroupTVLayout>> : true_type {};
94100

95-
template<class Engine,
96-
class Layout,
97-
class SubgroupTVLayout,
98-
__CUTE_REQUIRES(is_layout<SubgroupTVLayout>::value)>
101+
template <class Engine,
102+
class Layout,
103+
class SubgroupTVLayout,
104+
__CUTE_REQUIRES(is_layout<SubgroupTVLayout>::value)>
99105
CUTE_HOST_DEVICE
100106
constexpr auto
101107
make_subgroup_tensor(Tensor<Engine, Layout> const& tensor, SubgroupTVLayout const&)
@@ -105,6 +111,47 @@ make_subgroup_tensor(Tensor<Engine, Layout> const& tensor, SubgroupTVLayout cons
105111
return static_cast<SubgroupTensor<Engine,Layout,SubgroupTVLayout> const&>(tensor);
106112
}
107113

114+
template <typename T, class Shape, class Stride>
115+
CUTE_HOST_DEVICE
116+
constexpr auto
117+
make_subgroup_tensor(Layout<Shape,Stride> const& sg_layout)
118+
{
119+
using _SG = intel::_SGSize;
120+
auto ilayout = make_layout(make_shape(_SG{}, size(sg_layout) / _SG{}),
121+
make_stride(_1{}, _16{}));
122+
auto sv_layout = sg_layout.compose(ilayout);
123+
return make_subgroup_tensor(make_fragment_like<T>(sv_layout(0,_)), sv_layout);
124+
}
125+
126+
template <typename T, class... Args>
127+
CUTE_HOST_DEVICE
128+
constexpr auto
129+
make_subgroup_tensor(Args const&... args)
130+
{
131+
return make_subgroup_tensor<T>(make_layout(args...));
132+
}
133+
134+
135+
// Replicate a subgroup fragment in a given mode.
136+
template <int Mode, int Expand, typename EngineIn, typename LayoutIn, typename TVLayoutIn>
137+
CUTE_HOST_DEVICE
138+
constexpr auto
139+
expand_sg_fragment_helper(SubgroupTensor<EngineIn,LayoutIn,TVLayoutIn> const&)
140+
{
141+
constexpr SubgroupTensor<EngineIn,LayoutIn,TVLayoutIn> frag;
142+
constexpr int ModeSize = get<Mode>(atuple_coshape(frag.tv_layout()));
143+
144+
auto xlayout = append(frag.layout(),
145+
Layout<C<Expand>, C<cosize_v<LayoutIn>>>{});
146+
auto xv_layout = append(get<1>(frag.tv_layout()),
147+
make_layout(C<Expand>{}, C<ModeSize>{} * E<Mode>{}));
148+
auto xtv_layout = make_layout(get<0>(frag.tv_layout()), xv_layout);
149+
150+
return make_subgroup_tensor(make_tensor<typename EngineIn::element_type>(xlayout), xtv_layout);
151+
}
152+
153+
template <typename SGTensor, int Mode, int Expand>
154+
using expand_sg_fragment_t = decltype(expand_sg_fragment_helper<Mode, Expand>(SGTensor{}));
108155

109156
//
110157
// Display utilities

0 commit comments

Comments
 (0)