|
| 1 | +/*************************************************************************************************** |
| 2 | + * Copyright (C) 2025 Intel Corporation, All rights reserved. |
| 3 | + * SPDX-License-Identifier: BSD-3-Clause |
| 4 | + * |
| 5 | + * Redistribution and use in source and binary forms, with or without |
| 6 | + * modification, are permitted provided that the following conditions are met: |
| 7 | + * |
| 8 | + * 1. Redistributions of source code must retain the above copyright notice, this |
| 9 | + * list of conditions and the following disclaimer. |
| 10 | + * |
| 11 | + * 2. Redistributions in binary form must reproduce the above copyright notice, |
| 12 | + * this list of conditions and the following disclaimer in the documentation |
| 13 | + * and/or other materials provided with the distribution. |
| 14 | + * |
| 15 | + * 3. Neither the name of the copyright holder nor the names of its |
| 16 | + * contributors may be used to endorse or promote products derived from |
| 17 | + * this software without specific prior written permission. |
| 18 | + * |
| 19 | + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| 20 | + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| 21 | + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| 22 | + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE |
| 23 | + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL |
| 24 | + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |
| 25 | + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |
| 26 | + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, |
| 27 | + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| 28 | + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 29 | + * |
| 30 | + **************************************************************************************************/ |
| 31 | + |
| 32 | +#pragma once |
| 33 | + |
| 34 | +#include "cute/tensor.hpp" |
| 35 | +#include "cute/util/sycl_vec.hpp" |
| 36 | + |
| 37 | +namespace cute { |
| 38 | + |
| 39 | +// Uniformize a value, in case the compiler cannot prove it is subgroup-uniform. |
| 40 | +template <typename T> |
| 41 | +CUTE_HOST_DEVICE |
| 42 | +T |
| 43 | +assert_uniform(T x) { |
| 44 | + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); |
| 45 | + return group_broadcast(sg, x, 0); |
| 46 | +} |
| 47 | + |
// Set a value in a single work-item -- x[i] = val.
// WARNING: i _must_ be a compile-time constant.
// No diagnostics/error will be issued by the compiler if it is not.
template <typename T>
CUTE_HOST_DEVICE void
set_wi_value(T &x, int i, T val)
{
#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_INTEL_TARGET)
  // Single-element vISA mov writing subregister i of x in place ("+rw").
  // NOTE(review): the "P" constraint appears to require an operand the
  // front-end can resolve to a constant -- presumably why i must be a
  // compile-time constant; confirm against the vISA inline-asm docs.
  asm (
    "mov (M1_NM, 1) %0(0,%2)<1> %1(0,0)<1;1,0>"
    : "+rw"(x)
    : "rw.u"(val), "P"(i)
  );
#else
  // Portable fallback: only the work-item whose subgroup-local id equals i
  // performs the write; all other lanes leave x untouched.
  int lane = sycl::ext::oneapi::this_work_item::get_sub_group().get_local_id()[0];
  if (lane == i)
    x = val;
#endif
}
| 67 | + |
// Set an element of a 1D SG-shared fragment x.
// Logical element i lives in fragment slot (i / sg_size) on subgroup lane
// (i % sg_size); delegate the single-lane write to set_wi_value.
// WARNING: i _must_ be a compile-time constant.
// No diagnostics/error will be issued by the compiler if it is not.
template <typename FragX>
CUTE_HOST_DEVICE void
set_single_value(FragX& x, int i, typename FragX::element_type val) {
  set_wi_value(x(i / intel::sg_size), i % intel::sg_size, val);
}
| 76 | + |
// Broadcast the element from a 1D SG-shared fragment x
// corresponding to the Mode'th dimension of the logical coordinates of src(val).
template <int Mode, typename FragX, typename SGTensorSrc,
          __CUTE_REQUIRES(is_sg_tensor<SGTensorSrc>::value)>
CUTE_HOST_DEVICE
constexpr auto
broadcast(FragX const& x, SGTensorSrc const& src, int val)
{
  // Logical coordinate of value index `val` (queried for work-item 0), then
  // its Mode'th component.
  auto coord = src.tv_layout()(0, val);
  auto coord_i = get<Mode>(coord);

  // TMode indexes the last mode of the work-item ("T") stride tuple of the
  // TV-layout. NOTE(review): when TMode == Mode, the indexed element is taken
  // to be already resident in the calling lane's fragment slot, so no
  // cross-lane exchange is emitted -- confirm this invariant with the
  // SubgroupTensor layout conventions.
  constexpr auto TMode = rank(as_arithmetic_tuple(stride<0>(SGTensorSrc{}.tv_layout()))) - 1;
  if constexpr (TMode == Mode) {
    return x(coord_i / intel::sg_size);
  } else {
    // General case: fetch slot (coord_i / sg_size) from lane (coord_i % sg_size).
    auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
    return group_broadcast(sg, x(coord_i / intel::sg_size), coord_i % intel::sg_size);
  }
}
| 96 | + |
| 97 | +#if defined(__SYCL_DEVICE_ONLY__) && defined(SYCL_INTEL_TARGET) |
// hreduce16_float_<op>(x): 16 simultaneous cross-lane ("horizontal") float
// reductions in vISA inline assembly.
//
// Each x[j] is one SIMD16 register of per-lane floats; the returned register's
// lane j holds the op-reduction of x[j] across the subgroup -- effectively a
// 16x16 transpose-and-reduce. It proceeds as a log2(16) = 4 round butterfly:
// each round interleaves register pairs at doubling strides (the predicates
// 0x5555 / 0x3333 / 0x0F0F select alternating groups of 1 / 2 / 4 lanes for
// the `sel` moves) and then applies `op` SIMD-wide, halving the number of
// live registers per round.
//
// The UD (32-bit) aliases INn/RAn are used for the bit-exact data-movement
// `sel`s; the F aliases RFn are used for the arithmetic.
// NOTE(review): lane/element layout inferred from the interleave pattern and
// from the caller reduce() below -- confirm against the vISA specification.
#define DEFINE_HREDUCE16_FLOAT(op) \
  CUTE_DEVICE \
  float \
  hreduce16_float_ ## op(float x[16]) \
  { \
    float y; \
    asm ( \
      "{\n" \
      ".decl INTERLEAVE_2 v_type=P num_elts=16\n" \
      ".decl INTERLEAVE_4 v_type=P num_elts=16\n" \
      ".decl INTERLEAVE_8 v_type=P num_elts=16\n" \
      ".decl IN0 v_type=G type=UD num_elts=16 alias=<%1,0>\n" \
      ".decl IN1 v_type=G type=UD num_elts=16 alias=<%2,0>\n" \
      ".decl IN2 v_type=G type=UD num_elts=16 alias=<%3,0>\n" \
      ".decl IN3 v_type=G type=UD num_elts=16 alias=<%4,0>\n" \
      ".decl IN4 v_type=G type=UD num_elts=16 alias=<%5,0>\n" \
      ".decl IN5 v_type=G type=UD num_elts=16 alias=<%6,0>\n" \
      ".decl IN6 v_type=G type=UD num_elts=16 alias=<%7,0>\n" \
      ".decl IN7 v_type=G type=UD num_elts=16 alias=<%8,0>\n" \
      ".decl IN8 v_type=G type=UD num_elts=16 alias=<%9,0>\n" \
      ".decl IN9 v_type=G type=UD num_elts=16 alias=<%10,0>\n" \
      ".decl IN10 v_type=G type=UD num_elts=16 alias=<%11,0>\n" \
      ".decl IN11 v_type=G type=UD num_elts=16 alias=<%12,0>\n" \
      ".decl IN12 v_type=G type=UD num_elts=16 alias=<%13,0>\n" \
      ".decl IN13 v_type=G type=UD num_elts=16 alias=<%14,0>\n" \
      ".decl IN14 v_type=G type=UD num_elts=16 alias=<%15,0>\n" \
      ".decl IN15 v_type=G type=UD num_elts=16 alias=<%16,0>\n" \
      ".decl RA0 v_type=G type=UD num_elts=32 align=64\n" \
      ".decl RA2 v_type=G type=UD num_elts=32 align=64\n" \
      ".decl RA4 v_type=G type=UD num_elts=32 align=64\n" \
      ".decl RA6 v_type=G type=UD num_elts=32 align=64\n" \
      ".decl RA8 v_type=G type=UD num_elts=32 align=64\n" \
      ".decl RA10 v_type=G type=UD num_elts=32 align=64\n" \
      ".decl RA12 v_type=G type=UD num_elts=32 align=64\n" \
      ".decl RA14 v_type=G type=UD num_elts=32 align=64\n" \
      ".decl RF0 v_type=G type=F num_elts=16 alias=<RA0,0>\n" \
      ".decl RF1 v_type=G type=F num_elts=16 alias=<RA0,64>\n" \
      ".decl RF2 v_type=G type=F num_elts=16 alias=<RA2,0>\n" \
      ".decl RF3 v_type=G type=F num_elts=16 alias=<RA2,64>\n" \
      ".decl RF4 v_type=G type=F num_elts=16 alias=<RA4,0>\n" \
      ".decl RF5 v_type=G type=F num_elts=16 alias=<RA4,64>\n" \
      ".decl RF6 v_type=G type=F num_elts=16 alias=<RA6,0>\n" \
      ".decl RF7 v_type=G type=F num_elts=16 alias=<RA6,64>\n" \
      ".decl RF8 v_type=G type=F num_elts=16 alias=<RA8,0>\n" \
      ".decl RF9 v_type=G type=F num_elts=16 alias=<RA8,64>\n" \
      ".decl RF10 v_type=G type=F num_elts=16 alias=<RA10,0>\n" \
      ".decl RF11 v_type=G type=F num_elts=16 alias=<RA10,64>\n" \
      ".decl RF12 v_type=G type=F num_elts=16 alias=<RA12,0>\n" \
      ".decl RF13 v_type=G type=F num_elts=16 alias=<RA12,64>\n" \
      ".decl RF14 v_type=G type=F num_elts=16 alias=<RA14,0>\n" \
      ".decl RF15 v_type=G type=F num_elts=16 alias=<RA14,64>\n" \
      "setp (M1_NM,16) INTERLEAVE_2 0x5555:uw\n" \
      "setp (M1_NM,16) INTERLEAVE_4 0x3333:uw\n" \
      "setp (M1_NM,16) INTERLEAVE_8 0x0F0F:uw\n" \
      /* Round 1: interleave 2n with 2n+1 */ \
      "(!INTERLEAVE_2) sel (M1_NM,16) RA0(0,0)<1> IN1(0,0)<2;2,0> IN0(0,0)<1;1,0>\n" \
      " (INTERLEAVE_2) sel (M1_NM,16) RA0(1,0)<1> IN0(0,1)<2;2,0> IN1(0,0)<1;1,0>\n" \
      "(!INTERLEAVE_2) sel (M1_NM,16) RA2(0,0)<1> IN3(0,0)<2;2,0> IN2(0,0)<1;1,0>\n" \
      " (INTERLEAVE_2) sel (M1_NM,16) RA2(1,0)<1> IN2(0,1)<2;2,0> IN3(0,0)<1;1,0>\n" \
      "(!INTERLEAVE_2) sel (M1_NM,16) RA4(0,0)<1> IN5(0,0)<2;2,0> IN4(0,0)<1;1,0>\n" \
      " (INTERLEAVE_2) sel (M1_NM,16) RA4(1,0)<1> IN4(0,1)<2;2,0> IN5(0,0)<1;1,0>\n" \
      "(!INTERLEAVE_2) sel (M1_NM,16) RA6(0,0)<1> IN7(0,0)<2;2,0> IN6(0,0)<1;1,0>\n" \
      " (INTERLEAVE_2) sel (M1_NM,16) RA6(1,0)<1> IN6(0,1)<2;2,0> IN7(0,0)<1;1,0>\n" \
      "(!INTERLEAVE_2) sel (M1_NM,16) RA8(0,0)<1> IN9(0,0)<2;2,0> IN8(0,0)<1;1,0>\n" \
      " (INTERLEAVE_2) sel (M1_NM,16) RA8(1,0)<1> IN8(0,1)<2;2,0> IN9(0,0)<1;1,0>\n" \
      "(!INTERLEAVE_2) sel (M1_NM,16) RA10(0,0)<1> IN11(0,0)<2;2,0> IN10(0,0)<1;1,0>\n" \
      " (INTERLEAVE_2) sel (M1_NM,16) RA10(1,0)<1> IN10(0,1)<2;2,0> IN11(0,0)<1;1,0>\n" \
      "(!INTERLEAVE_2) sel (M1_NM,16) RA12(0,0)<1> IN13(0,0)<2;2,0> IN12(0,0)<1;1,0>\n" \
      " (INTERLEAVE_2) sel (M1_NM,16) RA12(1,0)<1> IN12(0,1)<2;2,0> IN13(0,0)<1;1,0>\n" \
      "(!INTERLEAVE_2) sel (M1_NM,16) RA14(0,0)<1> IN15(0,0)<2;2,0> IN14(0,0)<1;1,0>\n" \
      " (INTERLEAVE_2) sel (M1_NM,16) RA14(1,0)<1> IN14(0,1)<2;2,0> IN15(0,0)<1;1,0>\n" \
      /* Reduce */ \
      #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF3(0,0)<1> RF2(0,0)<1;1,0> RF3(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF4(0,0)<1> RF4(0,0)<1;1,0> RF5(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF7(0,0)<1> RF6(0,0)<1;1,0> RF7(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF8(0,0)<1> RF8(0,0)<1;1,0> RF9(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF11(0,0)<1> RF10(0,0)<1;1,0> RF11(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF12(0,0)<1> RF12(0,0)<1;1,0> RF13(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF15(0,0)<1> RF14(0,0)<1;1,0> RF15(0,0)<1;1,0>\n" \
      /* Round 2: interleave 4n+{0,1} with 4n+{2,3} */ \
      "(!INTERLEAVE_4) sel (M1_NM,16) RA0(1,0)<1> RA2(0,14)<1;1,0> RA0(0,0)<1;1,0>\n" \
      " (INTERLEAVE_4) sel (M1_NM,16) RA0(0,0)<1> RA0(0,2)<1;1,0> RA2(1,0)<1;1,0>\n" \
      "(!INTERLEAVE_4) sel (M1_NM,16) RA4(1,0)<1> RA6(0,14)<1;1,0> RA4(0,0)<1;1,0>\n" \
      " (INTERLEAVE_4) sel (M1_NM,16) RA4(0,0)<1> RA4(0,2)<1;1,0> RA6(1,0)<1;1,0>\n" \
      "(!INTERLEAVE_4) sel (M1_NM,16) RA8(1,0)<1> RA10(0,14)<1;1,0> RA8(0,0)<1;1,0>\n" \
      " (INTERLEAVE_4) sel (M1_NM,16) RA8(0,0)<1> RA8(0,2)<1;1,0> RA10(1,0)<1;1,0>\n" \
      "(!INTERLEAVE_4) sel (M1_NM,16) RA12(1,0)<1> RA14(0,14)<1;1,0> RA12(0,0)<1;1,0>\n" \
      " (INTERLEAVE_4) sel (M1_NM,16) RA12(0,0)<1> RA12(0,2)<1;1,0> RA14(1,0)<1;1,0>\n" \
      /* Reduce */ \
      #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF5(0,0)<1> RF4(0,0)<1;1,0> RF5(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF8(0,0)<1> RF8(0,0)<1;1,0> RF9(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF13(0,0)<1> RF12(0,0)<1;1,0> RF13(0,0)<1;1,0>\n" \
      /* Round 3: interleave 8n+{0,1,2,3} with 8n+{4,5,6,7} */ \
      "(!INTERLEAVE_8) sel (M1_NM,16) RA0(1,0)<1> RA4(0,12)<1;1,0> RA0(0,0)<1;1,0>\n" \
      " (INTERLEAVE_8) sel (M1_NM,16) RA0(0,0)<1> RA0(0,4)<1;1,0> RA4(1,0)<1;1,0>\n" \
      "(!INTERLEAVE_8) sel (M1_NM,16) RA8(1,0)<1> RA12(0,12)<1;1,0> RA8(0,0)<1;1,0>\n" \
      " (INTERLEAVE_8) sel (M1_NM,16) RA8(0,0)<1> RA8(0,4)<1;1,0> RA12(1,0)<1;1,0>\n" \
      /* Reduce */ \
      #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF8(0,0)<1> RF8(0,0)<1;1,0> RF9(0,0)<1;1,0>\n" \
      /* Round 4: final interleave */ \
      "mov (M1_NM, 8) RA0(1,0)<1> RA0(0,8)<1;1,0>\n" \
      "mov (M1_NM, 8) RA8(1,8)<1> RA8(0,0)<1;1,0>\n" \
      /* Reduce */ \
      #op " (M1_NM,8) %0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n" \
      #op " (M1_NM,8) %0(0,8)<1> RF8(0,8)<1;1,0> RF9(0,8)<1;1,0>\n" \
      "}\n" \
      : "=rw"(y) \
      : "rw"(x[0]), "rw"(x[1]), "rw"(x[2]), "rw"(x[3]), "rw"(x[4]), "rw"(x[5]), "rw"(x[6]), "rw"(x[7]), \
        "rw"(x[8]), "rw"(x[9]), "rw"(x[10]), "rw"(x[11]), "rw"(x[12]), "rw"(x[13]), "rw"(x[14]), "rw"(x[15]) \
      ); \
    return y; \
  }
| 213 | + |
// hreduce8_float_<op>(x): 8 simultaneous cross-lane ("horizontal") float
// reductions in vISA inline assembly -- the 8-register analogue of
// DEFINE_HREDUCE16_FLOAT, using a 3-round interleave/reduce butterfly plus a
// final top-half/bottom-half combine. The result is valid in lanes 0..7 of
// the returned register (lane j = op-reduction of x[j] across the subgroup);
// the upper 8 lanes are not written.
//
// BUGFIX: the input constraint list previously also passed x[8]..x[15]
// (copy-pasted from the 16-register variant). The asm body only references
// operands %1..%8, and x is only 8 elements long, so those extra operands
// read past the end of the array (undefined behavior). They are removed here.
#define DEFINE_HREDUCE8_FLOAT(op) \
  CUTE_DEVICE \
  float \
  hreduce8_float_ ## op(float x[8]) \
  { \
    float y; \
    asm ( \
      "{\n" \
      ".decl INTERLEAVE_2 v_type=P num_elts=16\n" \
      ".decl INTERLEAVE_4 v_type=P num_elts=16\n" \
      ".decl INTERLEAVE_8 v_type=P num_elts=16\n" \
      ".decl IN0 v_type=G type=UD num_elts=16 alias=<%1,0>\n" \
      ".decl IN1 v_type=G type=UD num_elts=16 alias=<%2,0>\n" \
      ".decl IN2 v_type=G type=UD num_elts=16 alias=<%3,0>\n" \
      ".decl IN3 v_type=G type=UD num_elts=16 alias=<%4,0>\n" \
      ".decl IN4 v_type=G type=UD num_elts=16 alias=<%5,0>\n" \
      ".decl IN5 v_type=G type=UD num_elts=16 alias=<%6,0>\n" \
      ".decl IN6 v_type=G type=UD num_elts=16 alias=<%7,0>\n" \
      ".decl IN7 v_type=G type=UD num_elts=16 alias=<%8,0>\n" \
      ".decl RA0 v_type=G type=UD num_elts=32 align=64\n" \
      ".decl RA2 v_type=G type=UD num_elts=32 align=64\n" \
      ".decl RA4 v_type=G type=UD num_elts=32 align=64\n" \
      ".decl RA6 v_type=G type=UD num_elts=32 align=64\n" \
      ".decl RF0 v_type=G type=F num_elts=16 alias=<RA0,0>\n" \
      ".decl RF1 v_type=G type=F num_elts=16 alias=<RA0,64>\n" \
      ".decl RF2 v_type=G type=F num_elts=16 alias=<RA2,0>\n" \
      ".decl RF3 v_type=G type=F num_elts=16 alias=<RA2,64>\n" \
      ".decl RF4 v_type=G type=F num_elts=16 alias=<RA4,0>\n" \
      ".decl RF5 v_type=G type=F num_elts=16 alias=<RA4,64>\n" \
      ".decl RF6 v_type=G type=F num_elts=16 alias=<RA6,0>\n" \
      ".decl RF7 v_type=G type=F num_elts=16 alias=<RA6,64>\n" \
      "setp (M1_NM,16) INTERLEAVE_2 0x5555:uw\n" \
      "setp (M1_NM,16) INTERLEAVE_4 0x3333:uw\n" \
      "setp (M1_NM,16) INTERLEAVE_8 0x0F0F:uw\n" \
      /* Round 1: interleave 2n with 2n+1 */ \
      "(!INTERLEAVE_2) sel (M1_NM,16) RA0(0,0)<1> IN1(0,0)<2;2,0> IN0(0,0)<1;1,0>\n" \
      " (INTERLEAVE_2) sel (M1_NM,16) RA0(1,0)<1> IN0(0,1)<2;2,0> IN1(0,0)<1;1,0>\n" \
      "(!INTERLEAVE_2) sel (M1_NM,16) RA2(0,0)<1> IN3(0,0)<2;2,0> IN2(0,0)<1;1,0>\n" \
      " (INTERLEAVE_2) sel (M1_NM,16) RA2(1,0)<1> IN2(0,1)<2;2,0> IN3(0,0)<1;1,0>\n" \
      "(!INTERLEAVE_2) sel (M1_NM,16) RA4(0,0)<1> IN5(0,0)<2;2,0> IN4(0,0)<1;1,0>\n" \
      " (INTERLEAVE_2) sel (M1_NM,16) RA4(1,0)<1> IN4(0,1)<2;2,0> IN5(0,0)<1;1,0>\n" \
      "(!INTERLEAVE_2) sel (M1_NM,16) RA6(0,0)<1> IN7(0,0)<2;2,0> IN6(0,0)<1;1,0>\n" \
      " (INTERLEAVE_2) sel (M1_NM,16) RA6(1,0)<1> IN6(0,1)<2;2,0> IN7(0,0)<1;1,0>\n" \
      /* Reduce */ \
      #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF3(0,0)<1> RF2(0,0)<1;1,0> RF3(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF4(0,0)<1> RF4(0,0)<1;1,0> RF5(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF7(0,0)<1> RF6(0,0)<1;1,0> RF7(0,0)<1;1,0>\n" \
      /* Round 2: interleave 4n+{0,1} with 4n+{2,3} */ \
      "(!INTERLEAVE_4) sel (M1_NM,16) RA0(1,0)<1> RA2(0,14)<1;1,0> RA0(0,0)<1;1,0>\n" \
      " (INTERLEAVE_4) sel (M1_NM,16) RA0(0,0)<1> RA0(0,2)<1;1,0> RA2(1,0)<1;1,0>\n" \
      "(!INTERLEAVE_4) sel (M1_NM,16) RA4(1,0)<1> RA6(0,14)<1;1,0> RA4(0,0)<1;1,0>\n" \
      " (INTERLEAVE_4) sel (M1_NM,16) RA4(0,0)<1> RA4(0,2)<1;1,0> RA6(1,0)<1;1,0>\n" \
      /* Reduce */ \
      #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n" \
      #op " (M1_NM,16) RF5(0,0)<1> RF4(0,0)<1;1,0> RF5(0,0)<1;1,0>\n" \
      /* Round 3: interleave 8n+{0,1,2,3} with 8n+{4,5,6,7} */ \
      "(!INTERLEAVE_8) sel (M1_NM,16) RA0(1,0)<1> RA4(0,12)<1;1,0> RA0(0,0)<1;1,0>\n" \
      " (INTERLEAVE_8) sel (M1_NM,16) RA0(0,0)<1> RA0(0,4)<1;1,0> RA4(1,0)<1;1,0>\n" \
      /* Reduce */ \
      #op " (M1_NM,16) RF0(0,0)<1> RF0(0,0)<1;1,0> RF1(0,0)<1;1,0>\n" \
      /* Round 4: reduce top and bottom halves */ \
      #op " (M1_NM,8) %0(0,0)<1> RF0(0,0)<1;1,0> RF0(0,8)<1;1,0>\n" \
      "}\n" \
      : "=rw"(y) \
      : "rw"(x[0]), "rw"(x[1]), "rw"(x[2]), "rw"(x[3]), "rw"(x[4]), "rw"(x[5]), "rw"(x[6]), "rw"(x[7]) \
      ); \
    return y; \
  }
| 284 | +#else |
// Fallback stubs for host compilation / non-Intel targets: the vISA path is
// unavailable, so these just return 0.f. They exist so the symbols resolve;
// the guarded device path above provides the real implementations.
#define DEFINE_HREDUCE16_FLOAT(op) \
  CUTE_DEVICE float hreduce16_float_ ## op(float x[16]) { return 0.f; }
#define DEFINE_HREDUCE8_FLOAT(op) \
  CUTE_DEVICE float hreduce8_float_ ## op(float x[8]) { return 0.f; }
| 289 | +#endif |
| 290 | + |
// Instantiate the optimized horizontal reductions used by reduce() below.
DEFINE_HREDUCE8_FLOAT(add)
DEFINE_HREDUCE8_FLOAT(max)
DEFINE_HREDUCE16_FLOAT(add)
DEFINE_HREDUCE16_FLOAT(max)
| 295 | + |
| 296 | +// Subgroup-cooperative reduction of a SubgroupTensor. |
| 297 | +template <int Mode, class BinaryOp, |
| 298 | + class Engine, class FragLayout, class SubgroupTVLayout> |
| 299 | +CUTE_HOST_DEVICE |
| 300 | +auto |
| 301 | +reduce(SubgroupTensor<Engine,FragLayout,SubgroupTVLayout> const& src, BinaryOp op) |
| 302 | +{ |
| 303 | + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); |
| 304 | + using T = typename Engine::value_type; |
| 305 | + using TVToV = Layout<Shape<intel::_SGSize,int>, Stride<_0,_1>>; |
| 306 | + |
| 307 | + /* Retrieve logical coordinate -> (T,V) mapping */ |
| 308 | + constexpr auto shape = atuple_coshape(SubgroupTVLayout{}); |
| 309 | + constexpr auto coord_to_tv = right_inverse(project_strides(SubgroupTVLayout{})).with_shape(shape); |
| 310 | + |
| 311 | + /* Move reduction coordinate to mode-0 and group the rest in mode-1. Then, remove work-item modes. */ |
| 312 | + constexpr auto rcoord_to_tv = make_layout(select<Mode>(coord_to_tv), remove<Mode>(coord_to_tv)); |
| 313 | + constexpr auto rcoord_to_v = filter(composition(TVToV{}, rcoord_to_tv), Step<_1,_1>{}); |
| 314 | + |
| 315 | + /* Regroup input tensor */ |
| 316 | + Tensor src_r = make_tensor(src.data(), rcoord_to_v); |
| 317 | + |
| 318 | + /* Create output tensor */ |
| 319 | + auto rshape = replace<Mode>(shape, _1{}); |
| 320 | + Tensor out = make_subgroup_tensor(make_tensor<T>(ceil_div(size(rshape), intel::_SGSize{})), |
| 321 | + make_identity_layout(rshape)); |
| 322 | + |
| 323 | + /* Check for reduction type */ |
| 324 | + constexpr bool horizontal = (size<0>(rcoord_to_tv) == intel::_SGSize{} * size<0>(rcoord_to_v)); |
| 325 | + constexpr bool vertical = (size<1>(rcoord_to_tv) == intel::_SGSize{} * size<1>(rcoord_to_v)); |
| 326 | + |
| 327 | + /* Check for optimized reductions */ |
| 328 | + constexpr bool align16 = is_constant_v<0, decltype(size<1>(rcoord_to_v) % _16{})>; |
| 329 | + constexpr bool align8 = is_constant_v<8, decltype(size<1>(rcoord_to_v))>; |
| 330 | + |
| 331 | + constexpr bool hadd = (horizontal && is_same_v<T, float> && is_same_v<BinaryOp, sycl::plus<void>>); |
| 332 | + constexpr bool hmax = (horizontal && is_same_v<T, float> && is_same_v<BinaryOp, sycl::maximum<void>>); |
| 333 | + |
| 334 | + constexpr bool hadd16 = hadd && align16; |
| 335 | + constexpr bool hmax16 = hmax && align16; |
| 336 | + |
| 337 | + constexpr bool hadd8 = hadd && align8; |
| 338 | + constexpr bool hmax8 = hmax && align8; |
| 339 | + |
| 340 | + [[maybe_unused]] T temp[size<1>(rcoord_to_v)]; /* array of partial reductions */ |
| 341 | + |
| 342 | + CUTE_UNROLL |
| 343 | + for (int j = 0; j < size<1>(rcoord_to_v); j++) { |
| 344 | + T acc = src_r(0, j); |
| 345 | + CUTE_UNROLL |
| 346 | + for (int i = 1; i < size<0>(rcoord_to_v); i++) { |
| 347 | + acc = op(acc, src_r(i, j)); |
| 348 | + } |
| 349 | + |
| 350 | + if constexpr (hadd16 || hmax16 || hadd8 || hmax8) |
| 351 | + temp[j] = acc; |
| 352 | + else if constexpr (horizontal) |
| 353 | + set_single_value(out, j, reduce_over_group(sg, acc, op)); |
| 354 | + else if constexpr (vertical) |
| 355 | + out(j) = acc; |
| 356 | + else |
| 357 | + static_assert("Unimplemented reduction type"); |
| 358 | + } |
| 359 | + |
| 360 | + if constexpr (hadd16) { |
| 361 | + CUTE_UNROLL |
| 362 | + for (int j = 0; j < size<1>(rcoord_to_v); j += 16) { |
| 363 | + out(j/16) = hreduce16_float_add(&temp[j]); |
| 364 | + } |
| 365 | + } else if constexpr (hmax16) { |
| 366 | + CUTE_UNROLL |
| 367 | + for (int j = 0; j < size<1>(rcoord_to_v); j += 16) { |
| 368 | + out(j/16) = hreduce16_float_max(&temp[j]); |
| 369 | + } |
| 370 | + } else if constexpr (hadd8) { |
| 371 | + out(0) = hreduce8_float_add(&temp[0]); |
| 372 | + } else if constexpr (hmax8) { |
| 373 | + out(0) = hreduce8_float_max(&temp[0]); |
| 374 | + } |
| 375 | + |
| 376 | + return out; |
| 377 | +} |
| 378 | + |
| 379 | +} // namespace cute |
0 commit comments