Skip to content

Commit 9a6aa27

Browse files
committed
[CuTe] [Xe] Subgroup-scope broadcast/reduction
1 parent 21fb89a commit 9a6aa27

File tree

1 file changed

+379
-0
lines changed

1 file changed

+379
-0
lines changed
Lines changed: 379 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,379 @@
1+
/***************************************************************************************************
2+
* Copyright (C) 2025 Intel Corporation, All rights reserved.
3+
* SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice, this
9+
* list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29+
*
30+
**************************************************************************************************/
31+
32+
#pragma once
33+
34+
#include "cute/tensor.hpp"
35+
#include "cute/util/sycl_vec.hpp"
36+
37+
namespace cute {
38+
39+
// Uniformize a value, in case the compiler cannot prove it is subgroup-uniform.
40+
template <typename T>
41+
CUTE_HOST_DEVICE
42+
T
43+
assert_uniform(T x) {
44+
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
45+
return group_broadcast(sg, x, 0);
46+
}
47+
48+
// Set a value in a single work-item -- x[i] = val.
// WARNING: i _must_ be a compile-time constant.
// No diagnostics/error will be issued by the compiler if it is not.
template <typename T>
CUTE_HOST_DEVICE void
set_wi_value(T &x, int i, T val)
{
#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_INTEL_TARGET)
  // Device path: a single-channel vISA mov writes val into element i of x's
  // register. The "P" constraint requires an immediate operand, hence the
  // compile-time-constant requirement stated above.
  asm (
    "mov (M1_NM, 1) %0(0,%2)<1> %1(0,0)<1;1,0>"
    : "+rw"(x)
    : "rw.u"(val), "P"(i)
  );
#else
  // Portable fallback: only the lane whose subgroup-local id equals i writes.
  int lane = sycl::ext::oneapi::this_work_item::get_sub_group().get_local_id()[0];
  if (lane == i)
    x = val;
#endif
}
67+
68+
// Set an element of a 1D SG-shared fragment x.
// Element i lives in fragment entry i / sg_size, on lane i % sg_size.
// WARNING: i _must_ be a compile-time constant.
// No diagnostics/error will be issued by the compiler if it is not.
// (The index expressions are passed directly so set_wi_value's immediate
//  operand constraint can still see a constant.)
template <typename FragX>
CUTE_HOST_DEVICE void
set_single_value(FragX& x, int i, typename FragX::element_type val) {
  set_wi_value(x(i / intel::sg_size), i % intel::sg_size, val);
}
76+
77+
// Broadcast the element from a 1D SG-shared fragment x
// corresponding to the Mode'th dimension of the logical coordinates of src(val).
template <int Mode, typename FragX, typename SGTensorSrc,
          __CUTE_REQUIRES(is_sg_tensor<SGTensorSrc>::value)>
CUTE_HOST_DEVICE
constexpr auto
broadcast(FragX const& x, SGTensorSrc const& src, int val)
{
  // Logical coordinate of value `val` in lane 0, and its Mode'th component.
  auto coord = src.tv_layout()(0, val);
  auto coord_i = get<Mode>(coord);

  // TMode: index of the last entry of the thread-stride tuple — presumably the
  // mode of src that is laid out across subgroup lanes (TODO confirm against
  // the SubgroupTensor layout conventions).
  constexpr auto TMode = rank(as_arithmetic_tuple(stride<0>(SGTensorSrc{}.tv_layout()))) - 1;
  if constexpr (TMode == Mode) {
    // Requested mode matches the lane-distributed mode: each lane already
    // holds its own element, so no cross-lane exchange is needed.
    return x(coord_i / intel::sg_size);
  } else {
    // Otherwise the element lives on a single lane; broadcast it subgroup-wide.
    auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
    return group_broadcast(sg, x(coord_i / intel::sg_size), coord_i % intel::sg_size);
  }
}
96+
97+
#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_INTEL_TARGET)
// Horizontal (cross-lane) reduction of 16 subgroup-wide float values at once.
// Each round uses predicated vISA `sel` instructions to interleave pairs of
// partial vectors so that one full-width `op` instruction combines two
// partials per round (log2 reduction tree); `op` is a vISA float mnemonic
// such as `add` or `max`. Returns, per lane, the reduction of one of the 16
// inputs across all lanes.
#define DEFINE_HREDUCE16_FLOAT(op)                                                      \
CUTE_DEVICE                                                                             \
float                                                                                   \
hreduce16_float_ ## op(float x[16])                                                     \
{                                                                                       \
  float y;                                                                              \
  asm (                                                                                 \
    "{\n"                                                                               \
    /* Predicates selecting alternating lane groups of width 1/2/4. */                  \
    ".decl INTERLEAVE_2 v_type=P num_elts=16\n"                                         \
    ".decl INTERLEAVE_4 v_type=P num_elts=16\n"                                         \
    ".decl INTERLEAVE_8 v_type=P num_elts=16\n"                                         \
    /* UD-typed aliases of the 16 float inputs (bit-level moves via sel). */            \
    ".decl IN0 v_type=G type=UD num_elts=16 alias=<%1,0>\n"                             \
    ".decl IN1 v_type=G type=UD num_elts=16 alias=<%2,0>\n"                             \
    ".decl IN2 v_type=G type=UD num_elts=16 alias=<%3,0>\n"                             \
    ".decl IN3 v_type=G type=UD num_elts=16 alias=<%4,0>\n"                             \
    ".decl IN4 v_type=G type=UD num_elts=16 alias=<%5,0>\n"                             \
    ".decl IN5 v_type=G type=UD num_elts=16 alias=<%6,0>\n"                             \
    ".decl IN6 v_type=G type=UD num_elts=16 alias=<%7,0>\n"                             \
    ".decl IN7 v_type=G type=UD num_elts=16 alias=<%8,0>\n"                             \
    ".decl IN8 v_type=G type=UD num_elts=16 alias=<%9,0>\n"                             \
    ".decl IN9 v_type=G type=UD num_elts=16 alias=<%10,0>\n"                            \
    ".decl IN10 v_type=G type=UD num_elts=16 alias=<%11,0>\n"                           \
    ".decl IN11 v_type=G type=UD num_elts=16 alias=<%12,0>\n"                           \
    ".decl IN12 v_type=G type=UD num_elts=16 alias=<%13,0>\n"                           \
    ".decl IN13 v_type=G type=UD num_elts=16 alias=<%14,0>\n"                           \
    ".decl IN14 v_type=G type=UD num_elts=16 alias=<%15,0>\n"                           \
    ".decl IN15 v_type=G type=UD num_elts=16 alias=<%16,0>\n"                           \
    /* Scratch register pairs; RFn are float views of their halves. */                  \
    ".decl RA0 v_type=G type=UD num_elts=32 align=64\n"                                 \
    ".decl RA2 v_type=G type=UD num_elts=32 align=64\n"                                 \
    ".decl RA4 v_type=G type=UD num_elts=32 align=64\n"                                 \
    ".decl RA6 v_type=G type=UD num_elts=32 align=64\n"                                 \
    ".decl RA8 v_type=G type=UD num_elts=32 align=64\n"                                 \
    ".decl RA10 v_type=G type=UD num_elts=32 align=64\n"                                \
    ".decl RA12 v_type=G type=UD num_elts=32 align=64\n"                                \
    ".decl RA14 v_type=G type=UD num_elts=32 align=64\n"                                \
    ".decl RF0 v_type=G type=F num_elts=16 alias=<RA0,0>\n"                             \
    ".decl RF1 v_type=G type=F num_elts=16 alias=<RA0,64>\n"                            \
    ".decl RF2 v_type=G type=F num_elts=16 alias=<RA2,0>\n"                             \
    ".decl RF3 v_type=G type=F num_elts=16 alias=<RA2,64>\n"                            \
    ".decl RF4 v_type=G type=F num_elts=16 alias=<RA4,0>\n"                             \
    ".decl RF5 v_type=G type=F num_elts=16 alias=<RA4,64>\n"                            \
    ".decl RF6 v_type=G type=F num_elts=16 alias=<RA6,0>\n"                             \
    ".decl RF7 v_type=G type=F num_elts=16 alias=<RA6,64>\n"                            \
    ".decl RF8 v_type=G type=F num_elts=16 alias=<RA8,0>\n"                             \
    ".decl RF9 v_type=G type=F num_elts=16 alias=<RA8,64>\n"                            \
    ".decl RF10 v_type=G type=F num_elts=16 alias=<RA10,0>\n"                           \
    ".decl RF11 v_type=G type=F num_elts=16 alias=<RA10,64>\n"                          \
    ".decl RF12 v_type=G type=F num_elts=16 alias=<RA12,0>\n"                           \
    ".decl RF13 v_type=G type=F num_elts=16 alias=<RA12,64>\n"                          \
    ".decl RF14 v_type=G type=F num_elts=16 alias=<RA14,0>\n"                           \
    ".decl RF15 v_type=G type=F num_elts=16 alias=<RA14,64>\n"                          \
    /* Lane masks: alternating single lanes, pairs, quads. */                           \
    "setp (M1_NM,16) INTERLEAVE_2 0x5555:uw\n"                                          \
    "setp (M1_NM,16) INTERLEAVE_4 0x3333:uw\n"                                          \
    "setp (M1_NM,16) INTERLEAVE_8 0x0F0F:uw\n"                                          \
    /* Round 1: interleave 2n with 2n+1 */                                              \
    "(!INTERLEAVE_2) sel (M1_NM,16) RA0(0,0)<1> IN1(0,0)<2;2,0> IN0(0,0)<1;1,0>\n"      \
    " (INTERLEAVE_2) sel (M1_NM,16) RA0(1,0)<1> IN0(0,1)<2;2,0> IN1(0,0)<1;1,0>\n"      \
    "(!INTERLEAVE_2) sel (M1_NM,16) RA2(0,0)<1> IN3(0,0)<2;2,0> IN2(0,0)<1;1,0>\n"      \
    " (INTERLEAVE_2) sel (M1_NM,16) RA2(1,0)<1> IN2(0,1)<2;2,0> IN3(0,0)<1;1,0>\n"      \
    "(!INTERLEAVE_2) sel (M1_NM,16) RA4(0,0)<1> IN5(0,0)<2;2,0> IN4(0,0)<1;1,0>\n"      \
    " (INTERLEAVE_2) sel (M1_NM,16) RA4(1,0)<1> IN4(0,1)<2;2,0> IN5(0,0)<1;1,0>\n"      \
    "(!INTERLEAVE_2) sel (M1_NM,16) RA6(0,0)<1> IN7(0,0)<2;2,0> IN6(0,0)<1;1,0>\n"      \
    " (INTERLEAVE_2) sel (M1_NM,16) RA6(1,0)<1> IN6(0,1)<2;2,0> IN7(0,0)<1;1,0>\n"      \
    "(!INTERLEAVE_2) sel (M1_NM,16) RA8(0,0)<1> IN9(0,0)<2;2,0> IN8(0,0)<1;1,0>\n"      \
    " (INTERLEAVE_2) sel (M1_NM,16) RA8(1,0)<1> IN8(0,1)<2;2,0> IN9(0,0)<1;1,0>\n"      \
    "(!INTERLEAVE_2) sel (M1_NM,16) RA10(0,0)<1> IN11(0,0)<2;2,0> IN10(0,0)<1;1,0>\n"   \
    " (INTERLEAVE_2) sel (M1_NM,16) RA10(1,0)<1> IN10(0,1)<2;2,0> IN11(0,0)<1;1,0>\n"   \
    "(!INTERLEAVE_2) sel (M1_NM,16) RA12(0,0)<1> IN13(0,0)<2;2,0> IN12(0,0)<1;1,0>\n"   \
    " (INTERLEAVE_2) sel (M1_NM,16) RA12(1,0)<1> IN12(0,1)<2;2,0> IN13(0,0)<1;1,0>\n"   \
    "(!INTERLEAVE_2) sel (M1_NM,16) RA14(0,0)<1> IN15(0,0)<2;2,0> IN14(0,0)<1;1,0>\n"   \
    " (INTERLEAVE_2) sel (M1_NM,16) RA14(1,0)<1> IN14(0,1)<2;2,0> IN15(0,0)<1;1,0>\n"   \
    /* Reduce */                                                                        \
    #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n"                     \
    #op " (M1_NM,16) RF3(0,0)<1> RF2(0,0)<1;1,0> RF3(0,0)<1;1,0>\n"                     \
    #op " (M1_NM,16) RF4(0,0)<1> RF4(0,0)<1;1,0> RF5(0,0)<1;1,0>\n"                     \
    #op " (M1_NM,16) RF7(0,0)<1> RF6(0,0)<1;1,0> RF7(0,0)<1;1,0>\n"                     \
    #op " (M1_NM,16) RF8(0,0)<1> RF8(0,0)<1;1,0> RF9(0,0)<1;1,0>\n"                     \
    #op " (M1_NM,16) RF11(0,0)<1> RF10(0,0)<1;1,0> RF11(0,0)<1;1,0>\n"                  \
    #op " (M1_NM,16) RF12(0,0)<1> RF12(0,0)<1;1,0> RF13(0,0)<1;1,0>\n"                  \
    #op " (M1_NM,16) RF15(0,0)<1> RF14(0,0)<1;1,0> RF15(0,0)<1;1,0>\n"                  \
    /* Round 2: interleave 4n+{0,1} with 4n+{2,3} */                                    \
    "(!INTERLEAVE_4) sel (M1_NM,16) RA0(1,0)<1> RA2(0,14)<1;1,0> RA0(0,0)<1;1,0>\n"     \
    " (INTERLEAVE_4) sel (M1_NM,16) RA0(0,0)<1> RA0(0,2)<1;1,0> RA2(1,0)<1;1,0>\n"      \
    "(!INTERLEAVE_4) sel (M1_NM,16) RA4(1,0)<1> RA6(0,14)<1;1,0> RA4(0,0)<1;1,0>\n"     \
    " (INTERLEAVE_4) sel (M1_NM,16) RA4(0,0)<1> RA4(0,2)<1;1,0> RA6(1,0)<1;1,0>\n"      \
    "(!INTERLEAVE_4) sel (M1_NM,16) RA8(1,0)<1> RA10(0,14)<1;1,0> RA8(0,0)<1;1,0>\n"    \
    " (INTERLEAVE_4) sel (M1_NM,16) RA8(0,0)<1> RA8(0,2)<1;1,0> RA10(1,0)<1;1,0>\n"     \
    "(!INTERLEAVE_4) sel (M1_NM,16) RA12(1,0)<1> RA14(0,14)<1;1,0> RA12(0,0)<1;1,0>\n"  \
    " (INTERLEAVE_4) sel (M1_NM,16) RA12(0,0)<1> RA12(0,2)<1;1,0> RA14(1,0)<1;1,0>\n"   \
    /* Reduce */                                                                        \
    #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n"                     \
    #op " (M1_NM,16) RF5(0,0)<1> RF4(0,0)<1;1,0> RF5(0,0)<1;1,0>\n"                     \
    #op " (M1_NM,16) RF8(0,0)<1> RF8(0,0)<1;1,0> RF9(0,0)<1;1,0>\n"                     \
    #op " (M1_NM,16) RF13(0,0)<1> RF12(0,0)<1;1,0> RF13(0,0)<1;1,0>\n"                  \
    /* Round 3: interleave 8n+{0,1,2,3} with 8n+{4,5,6,7} */                            \
    "(!INTERLEAVE_8) sel (M1_NM,16) RA0(1,0)<1> RA4(0,12)<1;1,0> RA0(0,0)<1;1,0>\n"     \
    " (INTERLEAVE_8) sel (M1_NM,16) RA0(0,0)<1> RA0(0,4)<1;1,0> RA4(1,0)<1;1,0>\n"      \
    "(!INTERLEAVE_8) sel (M1_NM,16) RA8(1,0)<1> RA12(0,12)<1;1,0> RA8(0,0)<1;1,0>\n"    \
    " (INTERLEAVE_8) sel (M1_NM,16) RA8(0,0)<1> RA8(0,4)<1;1,0> RA12(1,0)<1;1,0>\n"     \
    /* Reduce */                                                                        \
    #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n"                     \
    #op " (M1_NM,16) RF8(0,0)<1> RF8(0,0)<1;1,0> RF9(0,0)<1;1,0>\n"                     \
    /* Round 4: final interleave */                                                     \
    "mov (M1_NM, 8) RA0(1,0)<1> RA0(0,8)<1;1,0>\n"                                      \
    "mov (M1_NM, 8) RA8(1,8)<1> RA8(0,0)<1;1,0>\n"                                      \
    /* Reduce */                                                                        \
    #op " (M1_NM,8) %0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n"                       \
    #op " (M1_NM,8) %0(0,8)<1> RF8(0,8)<1;1,0> RF9(0,8)<1;1,0>\n"                       \
    "}\n"                                                                               \
    : "=rw"(y)                                                                          \
    : "rw"(x[0]), "rw"(x[1]), "rw"(x[2]), "rw"(x[3]), "rw"(x[4]), "rw"(x[5]), "rw"(x[6]), "rw"(x[7]), \
      "rw"(x[8]), "rw"(x[9]), "rw"(x[10]), "rw"(x[11]), "rw"(x[12]), "rw"(x[13]), "rw"(x[14]), "rw"(x[15]) \
  );                                                                                    \
  return y;                                                                             \
}
213+
214+
// Horizontal (cross-lane) reduction of 8 subgroup-wide float values at once:
// the 8-input analogue of DEFINE_HREDUCE16_FLOAT (three interleave/reduce
// rounds plus a final top/bottom-half combine). `op` is a vISA float
// mnemonic such as `add` or `max`.
//
// Fix vs. original: the asm input operand list passed x[8]..x[15] even though
// the parameter is `float x[8]` — out-of-bounds reads (UB). The asm template
// only declares IN0..IN7 (%1..%8), so those operands were never referenced;
// they are removed here.
#define DEFINE_HREDUCE8_FLOAT(op)                                                     \
CUTE_DEVICE                                                                           \
float                                                                                 \
hreduce8_float_ ## op(float x[8])                                                     \
{                                                                                     \
  float y;                                                                            \
  asm (                                                                               \
    "{\n"                                                                             \
    ".decl INTERLEAVE_2 v_type=P num_elts=16\n"                                       \
    ".decl INTERLEAVE_4 v_type=P num_elts=16\n"                                       \
    ".decl INTERLEAVE_8 v_type=P num_elts=16\n"                                       \
    ".decl IN0 v_type=G type=UD num_elts=16 alias=<%1,0>\n"                           \
    ".decl IN1 v_type=G type=UD num_elts=16 alias=<%2,0>\n"                           \
    ".decl IN2 v_type=G type=UD num_elts=16 alias=<%3,0>\n"                           \
    ".decl IN3 v_type=G type=UD num_elts=16 alias=<%4,0>\n"                           \
    ".decl IN4 v_type=G type=UD num_elts=16 alias=<%5,0>\n"                           \
    ".decl IN5 v_type=G type=UD num_elts=16 alias=<%6,0>\n"                           \
    ".decl IN6 v_type=G type=UD num_elts=16 alias=<%7,0>\n"                           \
    ".decl IN7 v_type=G type=UD num_elts=16 alias=<%8,0>\n"                           \
    ".decl RA0 v_type=G type=UD num_elts=32 align=64\n"                               \
    ".decl RA2 v_type=G type=UD num_elts=32 align=64\n"                               \
    ".decl RA4 v_type=G type=UD num_elts=32 align=64\n"                               \
    ".decl RA6 v_type=G type=UD num_elts=32 align=64\n"                               \
    ".decl RF0 v_type=G type=F num_elts=16 alias=<RA0,0>\n"                           \
    ".decl RF1 v_type=G type=F num_elts=16 alias=<RA0,64>\n"                          \
    ".decl RF2 v_type=G type=F num_elts=16 alias=<RA2,0>\n"                           \
    ".decl RF3 v_type=G type=F num_elts=16 alias=<RA2,64>\n"                          \
    ".decl RF4 v_type=G type=F num_elts=16 alias=<RA4,0>\n"                           \
    ".decl RF5 v_type=G type=F num_elts=16 alias=<RA4,64>\n"                          \
    ".decl RF6 v_type=G type=F num_elts=16 alias=<RA6,0>\n"                           \
    ".decl RF7 v_type=G type=F num_elts=16 alias=<RA6,64>\n"                          \
    "setp (M1_NM,16) INTERLEAVE_2 0x5555:uw\n"                                        \
    "setp (M1_NM,16) INTERLEAVE_4 0x3333:uw\n"                                        \
    "setp (M1_NM,16) INTERLEAVE_8 0x0F0F:uw\n"                                        \
    /* Round 1: interleave 2n with 2n+1 */                                            \
    "(!INTERLEAVE_2) sel (M1_NM,16) RA0(0,0)<1> IN1(0,0)<2;2,0> IN0(0,0)<1;1,0>\n"    \
    " (INTERLEAVE_2) sel (M1_NM,16) RA0(1,0)<1> IN0(0,1)<2;2,0> IN1(0,0)<1;1,0>\n"    \
    "(!INTERLEAVE_2) sel (M1_NM,16) RA2(0,0)<1> IN3(0,0)<2;2,0> IN2(0,0)<1;1,0>\n"    \
    " (INTERLEAVE_2) sel (M1_NM,16) RA2(1,0)<1> IN2(0,1)<2;2,0> IN3(0,0)<1;1,0>\n"    \
    "(!INTERLEAVE_2) sel (M1_NM,16) RA4(0,0)<1> IN5(0,0)<2;2,0> IN4(0,0)<1;1,0>\n"    \
    " (INTERLEAVE_2) sel (M1_NM,16) RA4(1,0)<1> IN4(0,1)<2;2,0> IN5(0,0)<1;1,0>\n"    \
    "(!INTERLEAVE_2) sel (M1_NM,16) RA6(0,0)<1> IN7(0,0)<2;2,0> IN6(0,0)<1;1,0>\n"    \
    " (INTERLEAVE_2) sel (M1_NM,16) RA6(1,0)<1> IN6(0,1)<2;2,0> IN7(0,0)<1;1,0>\n"    \
    /* Reduce */                                                                      \
    #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n"                   \
    #op " (M1_NM,16) RF3(0,0)<1> RF2(0,0)<1;1,0> RF3(0,0)<1;1,0>\n"                   \
    #op " (M1_NM,16) RF4(0,0)<1> RF4(0,0)<1;1,0> RF5(0,0)<1;1,0>\n"                   \
    #op " (M1_NM,16) RF7(0,0)<1> RF6(0,0)<1;1,0> RF7(0,0)<1;1,0>\n"                   \
    /* Round 2: interleave 4n+{0,1} with 4n+{2,3} */                                  \
    "(!INTERLEAVE_4) sel (M1_NM,16) RA0(1,0)<1> RA2(0,14)<1;1,0> RA0(0,0)<1;1,0>\n"   \
    " (INTERLEAVE_4) sel (M1_NM,16) RA0(0,0)<1> RA0(0,2)<1;1,0> RA2(1,0)<1;1,0>\n"    \
    "(!INTERLEAVE_4) sel (M1_NM,16) RA4(1,0)<1> RA6(0,14)<1;1,0> RA4(0,0)<1;1,0>\n"   \
    " (INTERLEAVE_4) sel (M1_NM,16) RA4(0,0)<1> RA4(0,2)<1;1,0> RA6(1,0)<1;1,0>\n"    \
    /* Reduce */                                                                      \
    #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n"                   \
    #op " (M1_NM,16) RF5(0,0)<1> RF4(0,0)<1;1,0> RF5(0,0)<1;1,0>\n"                   \
    /* Round 3: interleave 8n+{0,1,2,3} with 8n+{4,5,6,7} */                          \
    "(!INTERLEAVE_8) sel (M1_NM,16) RA0(1,0)<1> RA4(0,12)<1;1,0> RA0(0,0)<1;1,0>\n"   \
    " (INTERLEAVE_8) sel (M1_NM,16) RA0(0,0)<1> RA0(0,4)<1;1,0> RA4(1,0)<1;1,0>\n"    \
    /* Reduce */                                                                      \
    #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n"                   \
    /* Round 4: reduce top and bottom halves */                                       \
    #op " (M1_NM,8) %0(0,0)<1> RF0(0,0)<1;1,0> RF0(0,8)<1;1,0>\n"                     \
    "}\n"                                                                             \
    : "=rw"(y)                                                                        \
    : "rw"(x[0]), "rw"(x[1]), "rw"(x[2]), "rw"(x[3]), "rw"(x[4]), "rw"(x[5]), "rw"(x[6]), "rw"(x[7]) \
  );                                                                                  \
  return y;                                                                           \
}
284+
#else
// Host / non-Intel-target stubs: the horizontal reductions above are
// device-only (vISA inline asm), so these placeholders just return 0.
#define DEFINE_HREDUCE16_FLOAT(op) \
CUTE_DEVICE float hreduce16_float_ ## op(float x[16]) { return 0.f; }
#define DEFINE_HREDUCE8_FLOAT(op) \
CUTE_DEVICE float hreduce8_float_ ## op(float x[8]) { return 0.f; }
#endif
290+
291+
// Instantiate the optimized horizontal reductions (add/max, widths 8 and 16).
DEFINE_HREDUCE8_FLOAT(add)
DEFINE_HREDUCE8_FLOAT(max)
DEFINE_HREDUCE16_FLOAT(add)
DEFINE_HREDUCE16_FLOAT(max)
295+
296+
// Subgroup-cooperative reduction of a SubgroupTensor.
297+
template <int Mode, class BinaryOp,
298+
class Engine, class FragLayout, class SubgroupTVLayout>
299+
CUTE_HOST_DEVICE
300+
auto
301+
reduce(SubgroupTensor<Engine,FragLayout,SubgroupTVLayout> const& src, BinaryOp op)
302+
{
303+
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
304+
using T = typename Engine::value_type;
305+
using TVToV = Layout<Shape<intel::_SGSize,int>, Stride<_0,_1>>;
306+
307+
/* Retrieve logical coordinate -> (T,V) mapping */
308+
constexpr auto shape = atuple_coshape(SubgroupTVLayout{});
309+
constexpr auto coord_to_tv = right_inverse(project_strides(SubgroupTVLayout{})).with_shape(shape);
310+
311+
/* Move reduction coordinate to mode-0 and group the rest in mode-1. Then, remove work-item modes. */
312+
constexpr auto rcoord_to_tv = make_layout(select<Mode>(coord_to_tv), remove<Mode>(coord_to_tv));
313+
constexpr auto rcoord_to_v = filter(composition(TVToV{}, rcoord_to_tv), Step<_1,_1>{});
314+
315+
/* Regroup input tensor */
316+
Tensor src_r = make_tensor(src.data(), rcoord_to_v);
317+
318+
/* Create output tensor */
319+
auto rshape = replace<Mode>(shape, _1{});
320+
Tensor out = make_subgroup_tensor(make_tensor<T>(ceil_div(size(rshape), intel::_SGSize{})),
321+
make_identity_layout(rshape));
322+
323+
/* Check for reduction type */
324+
constexpr bool horizontal = (size<0>(rcoord_to_tv) == intel::_SGSize{} * size<0>(rcoord_to_v));
325+
constexpr bool vertical = (size<1>(rcoord_to_tv) == intel::_SGSize{} * size<1>(rcoord_to_v));
326+
327+
/* Check for optimized reductions */
328+
constexpr bool align16 = is_constant_v<0, decltype(size<1>(rcoord_to_v) % _16{})>;
329+
constexpr bool align8 = is_constant_v<8, decltype(size<1>(rcoord_to_v))>;
330+
331+
constexpr bool hadd = (horizontal && is_same_v<T, float> && is_same_v<BinaryOp, sycl::plus<void>>);
332+
constexpr bool hmax = (horizontal && is_same_v<T, float> && is_same_v<BinaryOp, sycl::maximum<void>>);
333+
334+
constexpr bool hadd16 = hadd && align16;
335+
constexpr bool hmax16 = hmax && align16;
336+
337+
constexpr bool hadd8 = hadd && align8;
338+
constexpr bool hmax8 = hmax && align8;
339+
340+
[[maybe_unused]] T temp[size<1>(rcoord_to_v)]; /* array of partial reductions */
341+
342+
CUTE_UNROLL
343+
for (int j = 0; j < size<1>(rcoord_to_v); j++) {
344+
T acc = src_r(0, j);
345+
CUTE_UNROLL
346+
for (int i = 1; i < size<0>(rcoord_to_v); i++) {
347+
acc = op(acc, src_r(i, j));
348+
}
349+
350+
if constexpr (hadd16 || hmax16 || hadd8 || hmax8)
351+
temp[j] = acc;
352+
else if constexpr (horizontal)
353+
set_single_value(out, j, reduce_over_group(sg, acc, op));
354+
else if constexpr (vertical)
355+
out(j) = acc;
356+
else
357+
static_assert("Unimplemented reduction type");
358+
}
359+
360+
if constexpr (hadd16) {
361+
CUTE_UNROLL
362+
for (int j = 0; j < size<1>(rcoord_to_v); j += 16) {
363+
out(j/16) = hreduce16_float_add(&temp[j]);
364+
}
365+
} else if constexpr (hmax16) {
366+
CUTE_UNROLL
367+
for (int j = 0; j < size<1>(rcoord_to_v); j += 16) {
368+
out(j/16) = hreduce16_float_max(&temp[j]);
369+
}
370+
} else if constexpr (hadd8) {
371+
out(0) = hreduce8_float_add(&temp[0]);
372+
} else if constexpr (hmax8) {
373+
out(0) = hreduce8_float_max(&temp[0]);
374+
}
375+
376+
return out;
377+
}
378+
379+
} // namespace cute

0 commit comments

Comments
 (0)