@@ -2031,33 +2031,20 @@ kernel void kernel_flash_attn_ext_f16(
20312031 uint3 ntg[[threads_per_threadgroup]],
20322032 uint tiisg[[thread_index_in_simdgroup]],
20332033 uint sgitg[[simdgroup_index_in_threadgroup]]) {
2034- const int64_t iq3 = tgpig[2 ];
2035- const int64_t iq2 = tgpig[1 ];
2036- const int64_t iq1 = tgpig[0 ]*N_SIMDWIDTH + tiisg;
2037-
2038- if (iq1 >= ne01) {
2039- return ;
2040- }
2034+ // const int64_t iq3 = tgpig[2];
2035+ // const int64_t iq2 = tgpig[1];
2036+ // const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg;
20412037
2042- const int64_t D4 = D/ 4 ;
2038+ const uint nsg = ntg. x /N_SIMDWIDTH; // number of simdgroups
20432039
2044- // TODO: can we move this to the stack?
2045- threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2 *sgitg*N_SIMDWIDTH + tiisg)*D);
2040+ const int64_t iq3 = tgpig[2 ];
2041+ const int64_t iq2 = tgpig[1 ]*(8 *nsg) + 8 *sgitg + tiisg/4 ;
2042+ const int64_t iq1 = tgpig[0 ];
20462043
2047- // initialize with zeros
2048- for (int64_t d = 0 ; d < D4; ++d) {
2049- V16[d] = 0 .0h;
2044+ if (iq2 >= ne02) {
2045+ return ;
20502046 }
20512047
2052- threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2 *sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D);
2053-
2054- half S = 0 .0h;
2055- half M = -INFINITY;
2056-
2057- const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
2058-
2059- device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr ;
2060-
20612048 // assume K and V are same shape
20622049 const int64_t ne22 = ne12;
20632050 const int64_t ne23 = ne13;
@@ -2081,11 +2068,97 @@ kernel void kernel_flash_attn_ext_f16(
20812068 const int64_t iv2 = iq2 / rv2;
20822069 const int64_t iv3 = iq3 / rv3;
20832070
2084- // load Q to shared memory
2085- for (int64_t d = 0 ; d < D4; ++d) {
2086- pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d];
2071+ const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
2072+
2073+ device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr ;
2074+
2075+ // const int64_t D4 = D/4;
2076+ //
2077+ // // TODO: can we move this to the stack?
2078+ // threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared);
2079+ //
2080+ // // initialize with zeros
2081+ // for (int64_t d = 0; d < D4; ++d) {
2082+ //
2083+ // }
2084+ //
2085+ // threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D);
2086+ //
2087+ // // load Q to shared memory
2088+ // for (int64_t d = 0; d < D4; ++d) {
2089+ // pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d];
2090+ // }
2091+ //
2092+ // half S = 0.0h;
2093+ // half M = -INFINITY;
2094+ //
2095+ // for (int64_t ic = 0; ic < ne11; ++ic) {
2096+ // const half mv = mp ? mp[ic] : 0.0h;
2097+ // if (mv == -INFINITY) {
2098+ // continue;
2099+ // }
2100+ //
2101+ // device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
2102+ // device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
2103+ //
2104+ // half4 s4 = 0.0h;
2105+ //
2106+ // for (int64_t d = 0; d < D4; ++d) {
2107+ // s4 += pk4[d] * pq4[d];
2108+ // }
2109+ //
2110+ // half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv;
2111+ //
2112+ // const half Mold = M;
2113+ //
2114+ // M = max(M, s);
2115+ //
2116+ // const half ms = exp(Mold - M);
2117+ // const half vs = exp(s - M);
2118+ //
2119+ // for (int64_t d = 0; d < D4; ++d) {
2120+ // V16[d] = V16[d]*ms + pv4[d]*vs;
2121+ // }
2122+ //
2123+ // S = S*ms + vs;
2124+ // }
2125+ //
2126+ // for (int64_t d = 0; d < D4; ++d) {
2127+ // V16[d] /= S;
2128+ // }
2129+ //
2130+ // // dst indices
2131+ // const int64_t i1 = iq1;
2132+ // const int64_t i2 = iq2;
2133+ // const int64_t i3 = iq3;
2134+ //
2135+ // device float4 * dst4 = (device float4 *) dst;
2136+ //
2137+ // for (int64_t d = 0; d < D4; ++d) {
2138+ // dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d];
2139+ // }
2140+
2141+ const int64_t D4 = D/4 ;
2142+
2143+ threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16 *D + 128 ) );
2144+ threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16 *D + 128 ) + 8 *D);
2145+ threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16 *D + 128 ) + 16 *D);
2146+ threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16 *D + 128 ) + 16 *D);
2147+
2148+ const uint tiih = tiisg%4 ; // thread index in head
2149+ const uint hiisg = tiisg/4 ; // head index in simdgroup
2150+
2151+ // load 8 heads from Q to shared memory
2152+ for (int64_t i = 0 ; i < D4/4 ; ++i) {
2153+ pq4[hiisg*D4 + 4 *i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4 *i + tiih];
2154+ ps4[hiisg*D4 + 4 *i + tiih] = 0 .0h;
20872155 }
20882156
2157+ simdgroup_barrier (mem_flags::mem_threadgroup);
2158+
2159+ half S = 0 .0h;
2160+ half M = -INFINITY;
2161+
20892162 for (int64_t ic = 0 ; ic < ne11; ++ic) {
20902163 const half mv = mp ? mp[ic] : 0 .0h;
20912164 if (mv == -INFINITY) {
@@ -2097,39 +2170,61 @@ kernel void kernel_flash_attn_ext_f16(
20972170
20982171 half4 s4 = 0 .0h;
20992172
2100- for (int64_t d = 0 ; d < D4; ++d ) {
2101- s4 += pk4[d ] * pq4[d ];
2173+ for (int64_t i = 0 ; i < D4/ 4 ; ++i ) {
2174+ s4 += pk4[4 *i + tiih ] * pq4[hiisg*D4 + 4 *i + tiih ];
21022175 }
21032176
2104- half s = (s4.x + s4.y + s4.z + s4.w )*scale + mv;
2177+ ss4[hiisg*4 + tiih] = s4;
2178+
2179+ simdgroup_barrier (mem_flags::mem_threadgroup);
2180+
2181+ if (tiih == 0 ) {
2182+ s4 = ss4[4 *hiisg + 0 ] + ss4[4 *hiisg + 1 ] + ss4[4 *hiisg + 2 ] + ss4[4 *hiisg + 3 ];
2183+
2184+ half s = (s4.x + s4.y + s4.z + s4.w )*scale + mv;
21052185
2106- const half Mold = M;
2186+ const half Mold = M;
21072187
2108- M = max (M, s);
2188+ M = max (M, s);
21092189
2110- const half ms = exp (Mold - M);
2111- const half vs = exp (s - M);
2190+ const half ms = exp (Mold - M);
2191+ const half vs = exp (s - M);
21122192
2113- for (int64_t d = 0 ; d < D4; ++d) {
2114- V16[d] = V16[d]*ms + pv4[d]*vs;
2193+ S = S*ms + vs;
2194+
2195+ ss[2 *hiisg + 0 ] = ms;
2196+ ss[2 *hiisg + 1 ] = vs;
21152197 }
21162198
2117- S = S*ms + vs;
2199+ simdgroup_barrier (mem_flags::mem_threadgroup);
2200+
2201+ const half ms = ss[2 *hiisg + 0 ];
2202+ const half vs = ss[2 *hiisg + 1 ];
2203+
2204+ for (int64_t i = 0 ; i < D4/4 ; ++i) {
2205+ ps4[hiisg*D4 + 4 *i + tiih] = ps4[hiisg*D4 + 4 *i + tiih]*ms + pv4[4 *i + tiih]*vs;
2206+ }
21182207 }
21192208
2120- for (int64_t d = 0 ; d < D4; ++d) {
2121- V16[d] /= S;
2209+ simdgroup_barrier (mem_flags::mem_threadgroup);
2210+
2211+ if (tiih == 0 ) {
2212+ for (int64_t i = 0 ; i < D4; ++i) {
2213+ ps4[hiisg*D4 + i] /= S;
2214+ }
21222215 }
21232216
2217+ simdgroup_barrier (mem_flags::mem_threadgroup);
2218+
21242219 // dst indices
21252220 const int64_t i1 = iq1;
21262221 const int64_t i2 = iq2;
21272222 const int64_t i3 = iq3;
21282223
21292224 device float4 * dst4 = (device float4 *) dst;
21302225
2131- for (int64_t d = 0 ; d < D4; ++d ) {
2132- dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d ] = (float4) V16[d ];
2226+ for (int64_t i = 0 ; i < D4/ 4 ; ++i ) {
2227+ dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4 *i + tiih ] = (float4) ps4[hiisg*D4 + 4 *i + tiih ];
21332228 }
21342229}
21352230
0 commit comments