3232#pragma once
3333
3434#include < cute/tensor_impl.hpp> // cute::Tensor
35+ #include < cute/util/sycl_vec.hpp> // intel::_SGSize
3536
3637namespace cute
3738{
@@ -74,7 +75,7 @@ struct SubgroupTensor : Tensor<Engine, Layout>
7475 *this = static_cast <SubgroupTensor const &>(base);
7576 }
7677
77- static constexpr int rank = Layout::rank;
78+ static constexpr int rank = Layout::rank;
7879
7980 CUTE_HOST_DEVICE constexpr
8081 decltype (auto )
@@ -89,13 +90,18 @@ struct SubgroupTensor : Tensor<Engine, Layout>
8990 }
9091};
9192
93+ template <class T >
94+ struct is_sg_tensor : false_type {};
95+ template <class Engine , class Layout , class SubgroupTVLayout >
96+ struct is_sg_tensor <SubgroupTensor<Engine,Layout,SubgroupTVLayout>> : true_type {};
97+
9298template <class Engine , class Layout , class SubgroupTVLayout >
9399struct is_tensor <SubgroupTensor<Engine,Layout,SubgroupTVLayout>> : true_type {};
94100
95- template <class Engine ,
96- class Layout ,
97- class SubgroupTVLayout ,
98- __CUTE_REQUIRES (is_layout<SubgroupTVLayout>::value)>
101+ template <class Engine ,
102+ class Layout ,
103+ class SubgroupTVLayout ,
104+ __CUTE_REQUIRES (is_layout<SubgroupTVLayout>::value)>
99105CUTE_HOST_DEVICE
100106constexpr auto
101107make_subgroup_tensor(Tensor<Engine, Layout> const & tensor, SubgroupTVLayout const &)
@@ -105,6 +111,47 @@ make_subgroup_tensor(Tensor<Engine, Layout> const& tensor, SubgroupTVLayout cons
105111 return static_cast <SubgroupTensor<Engine,Layout,SubgroupTVLayout> const &>(tensor);
106112}
107113
114+ template <typename T, class Shape , class Stride >
115+ CUTE_HOST_DEVICE
116+ constexpr auto
117+ make_subgroup_tensor (Layout<Shape,Stride> const & sg_layout)
118+ {
119+ using _SG = intel::_SGSize;
120+ auto ilayout = make_layout (make_shape (_SG{}, size (sg_layout) / _SG{}),
121+ make_stride (_1{}, _16{}));
122+ auto sv_layout = sg_layout.compose (ilayout);
123+ return make_subgroup_tensor (make_fragment_like<T>(sv_layout (0 ,_)), sv_layout);
124+ }
125+
126+ template <typename T, class ... Args>
127+ CUTE_HOST_DEVICE
128+ constexpr auto
129+ make_subgroup_tensor (Args const &... args)
130+ {
131+ return make_subgroup_tensor<T>(make_layout (args...));
132+ }
133+
134+
135+ // Replicate a subgroup fragment in a given mode.
136+ template <int Mode, int Expand, typename EngineIn, typename LayoutIn, typename TVLayoutIn>
137+ CUTE_HOST_DEVICE
138+ constexpr auto
139+ expand_sg_fragment_helper (SubgroupTensor<EngineIn,LayoutIn,TVLayoutIn> const &)
140+ {
141+ constexpr SubgroupTensor<EngineIn,LayoutIn,TVLayoutIn> frag;
142+ constexpr int ModeSize = get<Mode>(atuple_coshape (frag.tv_layout ()));
143+
144+ auto xlayout = append (frag.layout (),
145+ Layout<C<Expand>, C<cosize_v<LayoutIn>>>{});
146+ auto xv_layout = append (get<1 >(frag.tv_layout ()),
147+ make_layout (C<Expand>{}, C<ModeSize>{} * E<Mode>{}));
148+ auto xtv_layout = make_layout (get<0 >(frag.tv_layout ()), xv_layout);
149+
150+ return make_subgroup_tensor (make_tensor<typename EngineIn::element_type>(xlayout), xtv_layout);
151+ }
152+
153+ template <typename SGTensor, int Mode, int Expand>
154+ using expand_sg_fragment_t = decltype (expand_sg_fragment_helper<Mode, Expand>(SGTensor{}));
108155
109156//
110157// Display utilities
0 commit comments