Skip to content

Commit 8cde449

Browse files
committed
wip : 8 rows per simd group
1 parent b973258 commit 8cde449

File tree

2 files changed

+139
-44
lines changed

2 files changed

+139
-44
lines changed

ggml-metal.m

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2252,14 +2252,14 @@ static bool ggml_metal_graph_compute(
22522252
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
22532253
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
22542254

2255-
const int nwarps = 1;
2255+
const int64_t nwarps = 2;
22562256

2257-
const size_t shalf = sizeof(float)/2;
2257+
const size_t smem = nwarps*(2*8*nwarps*ne00 + 128)*(sizeof(float)/2);
22582258

2259-
GGML_ASSERT(2*32*nwarps*ne00*shalf <= ctx->device.maxThreadgroupMemoryLength);
2260-
[encoder setThreadgroupMemoryLength:2*32*nwarps*ne00*shalf atIndex:0];
2259+
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
2260+
[encoder setThreadgroupMemoryLength:smem atIndex:0];
22612261

2262-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 31)/32, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
2262+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne02 + 8*nwarps - 1)/(8*nwarps), ne03) threadsPerThreadgroup:MTLSizeMake(32*nwarps, 1, 1)];
22632263
} break;
22642264
case GGML_OP_DUP:
22652265
case GGML_OP_CPY:

ggml-metal.metal

Lines changed: 134 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2031,33 +2031,20 @@ kernel void kernel_flash_attn_ext_f16(
20312031
uint3 ntg[[threads_per_threadgroup]],
20322032
uint tiisg[[thread_index_in_simdgroup]],
20332033
uint sgitg[[simdgroup_index_in_threadgroup]]) {
2034-
const int64_t iq3 = tgpig[2];
2035-
const int64_t iq2 = tgpig[1];
2036-
const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg;
2037-
2038-
if (iq1 >= ne01) {
2039-
return;
2040-
}
2034+
//const int64_t iq3 = tgpig[2];
2035+
//const int64_t iq2 = tgpig[1];
2036+
//const int64_t iq1 = tgpig[0]*N_SIMDWIDTH + tiisg;
20412037

2042-
const int64_t D4 = D/4;
2038+
const uint nsg = ntg.x/N_SIMDWIDTH; // number of simdgroups
20432039

2044-
// TODO: can we move this to the stack?
2045-
threadgroup half4 * V16 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + tiisg)*D);
2040+
const int64_t iq3 = tgpig[2];
2041+
const int64_t iq2 = tgpig[1]*(8*nsg) + 8*sgitg + tiisg/4;
2042+
const int64_t iq1 = tgpig[0];
20462043

2047-
// initialize with zeros
2048-
for (int64_t d = 0; d < D4; ++d) {
2049-
V16[d] = 0.0h;
2044+
if (iq2 >= ne02) {
2045+
return;
20502046
}
20512047

2052-
threadgroup half4 * pq4 = (threadgroup half4 *) (shared + (2*sgitg*N_SIMDWIDTH + N_SIMDWIDTH)*D + tiisg*D);
2053-
2054-
half S = 0.0h;
2055-
half M = -INFINITY;
2056-
2057-
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
2058-
2059-
device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr;
2060-
20612048
// assume K and V are same shape
20622049
const int64_t ne22 = ne12;
20632050
const int64_t ne23 = ne13;
@@ -2081,11 +2068,97 @@ kernel void kernel_flash_attn_ext_f16(
20812068
const int64_t iv2 = iq2 / rv2;
20822069
const int64_t iv3 = iq3 / rv3;
20832070

2084-
// load Q to shared memory
2085-
for (int64_t d = 0; d < D4; ++d) {
2086-
pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d];
2071+
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
2072+
2073+
device const float * mp = mask ? (device const float *) (mask + (ir%ne31)*nb31) : nullptr;
2074+
2075+
// const int64_t D4 = D/4;
2076+
//
2077+
// // TODO: can we move this to the stack?
2078+
// threadgroup half4x4 * V16 = (threadgroup half4x4 *) (shared);
2079+
//
2080+
// // initialize with zeros
2081+
// for (int64_t d = 0; d < D4; ++d) {
2082+
//
2083+
// }
2084+
//
2085+
// threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 4*D);
2086+
//
2087+
// // load Q to shared memory
2088+
// for (int64_t d = 0; d < D4; ++d) {
2089+
// pq4[d] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[d];
2090+
// }
2091+
//
2092+
// half S = 0.0h;
2093+
// half M = -INFINITY;
2094+
//
2095+
// for (int64_t ic = 0; ic < ne11; ++ic) {
2096+
// const half mv = mp ? mp[ic] : 0.0h;
2097+
// if (mv == -INFINITY) {
2098+
// continue;
2099+
// }
2100+
//
2101+
// device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
2102+
// device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23));
2103+
//
2104+
// half4 s4 = 0.0h;
2105+
//
2106+
// for (int64_t d = 0; d < D4; ++d) {
2107+
// s4 += pk4[d] * pq4[d];
2108+
// }
2109+
//
2110+
// half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv;
2111+
//
2112+
// const half Mold = M;
2113+
//
2114+
// M = max(M, s);
2115+
//
2116+
// const half ms = exp(Mold - M);
2117+
// const half vs = exp(s - M);
2118+
//
2119+
// for (int64_t d = 0; d < D4; ++d) {
2120+
// V16[d] = V16[d]*ms + pv4[d]*vs;
2121+
// }
2122+
//
2123+
// S = S*ms + vs;
2124+
// }
2125+
//
2126+
// for (int64_t d = 0; d < D4; ++d) {
2127+
// V16[d] /= S;
2128+
// }
2129+
//
2130+
// // dst indices
2131+
// const int64_t i1 = iq1;
2132+
// const int64_t i2 = iq2;
2133+
// const int64_t i3 = iq3;
2134+
//
2135+
// device float4 * dst4 = (device float4 *) dst;
2136+
//
2137+
// for (int64_t d = 0; d < D4; ++d) {
2138+
// dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d];
2139+
// }
2140+
2141+
const int64_t D4 = D/4;
2142+
2143+
threadgroup half4 * pq4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) );
2144+
threadgroup half4 * ps4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 8*D);
2145+
threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(16*D + 128) + 16*D);
2146+
threadgroup half * ss = (threadgroup half *) (shared + sgitg*(16*D + 128) + 16*D);
2147+
2148+
const uint tiih = tiisg%4; // thread index in head
2149+
const uint hiisg = tiisg/4; // head index in simdgroup
2150+
2151+
// load 8 heads from Q to shared memory
2152+
for (int64_t i = 0; i < D4/4; ++i) {
2153+
pq4[hiisg*D4 + 4*i + tiih] = ((device const half4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[4*i + tiih];
2154+
ps4[hiisg*D4 + 4*i + tiih] = 0.0h;
20872155
}
20882156

2157+
simdgroup_barrier(mem_flags::mem_threadgroup);
2158+
2159+
half S = 0.0h;
2160+
half M = -INFINITY;
2161+
20892162
for (int64_t ic = 0; ic < ne11; ++ic) {
20902163
const half mv = mp ? mp[ic] : 0.0h;
20912164
if (mv == -INFINITY) {
@@ -2097,39 +2170,61 @@ kernel void kernel_flash_attn_ext_f16(
20972170

20982171
half4 s4 = 0.0h;
20992172

2100-
for (int64_t d = 0; d < D4; ++d) {
2101-
s4 += pk4[d] * pq4[d];
2173+
for (int64_t i = 0; i < D4/4; ++i) {
2174+
s4 += pk4[4*i + tiih] * pq4[hiisg*D4 + 4*i + tiih];
21022175
}
21032176

2104-
half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv;
2177+
ss4[hiisg*4 + tiih] = s4;
2178+
2179+
simdgroup_barrier(mem_flags::mem_threadgroup);
2180+
2181+
if (tiih == 0) {
2182+
s4 = ss4[4*hiisg + 0] + ss4[4*hiisg + 1] + ss4[4*hiisg + 2] + ss4[4*hiisg + 3];
2183+
2184+
half s = (s4.x + s4.y + s4.z + s4.w)*scale + mv;
21052185

2106-
const half Mold = M;
2186+
const half Mold = M;
21072187

2108-
M = max(M, s);
2188+
M = max(M, s);
21092189

2110-
const half ms = exp(Mold - M);
2111-
const half vs = exp(s - M);
2190+
const half ms = exp(Mold - M);
2191+
const half vs = exp(s - M);
21122192

2113-
for (int64_t d = 0; d < D4; ++d) {
2114-
V16[d] = V16[d]*ms + pv4[d]*vs;
2193+
S = S*ms + vs;
2194+
2195+
ss[2*hiisg + 0] = ms;
2196+
ss[2*hiisg + 1] = vs;
21152197
}
21162198

2117-
S = S*ms + vs;
2199+
simdgroup_barrier(mem_flags::mem_threadgroup);
2200+
2201+
const half ms = ss[2*hiisg + 0];
2202+
const half vs = ss[2*hiisg + 1];
2203+
2204+
for (int64_t i = 0; i < D4/4; ++i) {
2205+
ps4[hiisg*D4 + 4*i + tiih] = ps4[hiisg*D4 + 4*i + tiih]*ms + pv4[4*i + tiih]*vs;
2206+
}
21182207
}
21192208

2120-
for (int64_t d = 0; d < D4; ++d) {
2121-
V16[d] /= S;
2209+
simdgroup_barrier(mem_flags::mem_threadgroup);
2210+
2211+
if (tiih == 0) {
2212+
for (int64_t i = 0; i < D4; ++i) {
2213+
ps4[hiisg*D4 + i] /= S;
2214+
}
21222215
}
21232216

2217+
simdgroup_barrier(mem_flags::mem_threadgroup);
2218+
21242219
// dst indices
21252220
const int64_t i1 = iq1;
21262221
const int64_t i2 = iq2;
21272222
const int64_t i3 = iq3;
21282223

21292224
device float4 * dst4 = (device float4 *) dst;
21302225

2131-
for (int64_t d = 0; d < D4; ++d) {
2132-
dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + d] = (float4) V16[d];
2226+
for (int64_t i = 0; i < D4/4; ++i) {
2227+
dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + 4*i + tiih] = (float4) ps4[hiisg*D4 + 4*i + tiih];
21332228
}
21342229
}
21352230

0 commit comments

Comments
 (0)