diff --git a/README.md b/README.md
index 073376e..167bba1 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ requests are greatly appreciated.
 The ultimate goal is to create a fast, portable, and user-friendly implementation of the
 llama2 model architecture. The code prioritizes simplicity and readability without
 sacrificing performance. Certain core functions have SIMD implementations using the Zig
-`@Vector` feature, which provides a ~4x speed increase. For more details,
+`@Vector` feature, which provides a ~5x speed increase. For more details,
 please refer to the [performance](#performance) section.
 
 The `stories15.bin` file is a model checkpoint for a 15M parameter model that
@@ -36,7 +36,7 @@ Processor.
 | ------------------------------------------------- | -------- |
 | llama2.c `make run`                                | 116      |
 | llama2.c `make runfast`                            | 375      |
-| llama2.zig `zig build run -Doptimize=ReleaseFast`  | 482      |
+| llama2.zig `zig build run -Doptimize=ReleaseFast`  | 525      |
 
 ## Todo
 
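The speedup cited in the README comes from hand-vectorizing the hot loops with Zig's `@Vector` builtin. The pattern looks roughly like the following minimal, self-contained sketch (the fixed width of 8 and the `dotProduct` name are illustrative assumptions, not the repo's code; the repo's `vector_dot_product` uses its `DEFAULT_VECTOR_WIDTH` constant instead):

```zig
const std = @import("std");

/// Dot product vectorized with @Vector. The SIMD width of 8 is an assumption
/// for illustration; real code would pick a width suited to the target CPU.
fn dotProduct(x: []const f32, y: []const f32) f32 {
    std.debug.assert(x.len == y.len);
    const width = 8;

    var sum: @Vector(width, f32) = @splat(0.0);
    var offset: usize = 0;
    // consume `width` elements per iteration with a single SIMD multiply-add
    while (offset + width <= x.len) : (offset += width) {
        const xv: @Vector(width, f32) = x[offset..][0..width].*;
        const yv: @Vector(width, f32) = y[offset..][0..width].*;
        sum += xv * yv;
    }
    // scalar tail for the elements that do not fill a whole vector
    var rem: f32 = 0.0;
    for (x[offset..], y[offset..]) |a, b| rem += a * b;
    return @reduce(.Add, sum) + rem;
}

pub fn main() void {
    const x = [_]f32{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
    const y = [_]f32{ 10, 9, 8, 7, 6, 5, 4, 3, 2, 1 };
    std.debug.print("dot = {d}\n", .{dotProduct(&x, &y)});
}
```

Each iteration multiplies `width` pairs of floats at once and accumulates per lane; a single `@reduce(.Add, ...)` at the end collapses the lanes to a scalar.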
diff --git a/src/main.zig b/src/main.zig
index df080e0..9416e0d 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -218,9 +218,15 @@ fn transformer(token: usize, pos: usize, config: *const Config, s: *RunState, w:
     rmsnorm(s.xb, x, w.rms_att_weight[l * dim ..][0..dim]);
 
     // qkv
-    matmul(s.q, s.xb, w.wq[l * dim * dim ..][0 .. dim * dim]);
-    matmul(s.k, s.xb, w.wk[l * dim * dim ..][0 .. dim * dim]);
-    matmul(s.v, s.xb, w.wv[l * dim * dim ..][0 .. dim * dim]);
+    // matmul(s.q, s.xb, w.wq[l * dim * dim ..][0 .. dim * dim]);
+    // matmul(s.k, s.xb, w.wk[l * dim * dim ..][0 .. dim * dim]);
+    // matmul(s.v, s.xb, w.wv[l * dim * dim ..][0 .. dim * dim]);
+    // fused version of the above
+    matmul_fused(3, [_][]f32{ s.q, s.k, s.v }, s.xb, [_][]const f32{
+        w.wq[l * dim * dim ..][0 .. dim * dim],
+        w.wk[l * dim * dim ..][0 .. dim * dim],
+        w.wv[l * dim * dim ..][0 .. dim * dim],
+    });
 
     // apply RoPE rotation to the q and k vectors for each head
     for (0..config.n_heads) |h| {
@@ -295,8 +301,13 @@ fn transformer(token: usize, pos: usize, config: *const Config, s: *RunState, w:
 
     // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
     // first calculate self.w1(x) and self.w3(x)
-    matmul(s.hb, s.xb, w.w1[l * dim * hidden_dim ..][0 .. dim * hidden_dim]);
-    matmul(s.hb2, s.xb, w.w3[l * dim * hidden_dim ..][0 .. dim * hidden_dim]);
+    // matmul(s.hb, s.xb, w.w1[l * dim * hidden_dim ..][0 .. dim * hidden_dim]);
+    // matmul(s.hb2, s.xb, w.w3[l * dim * hidden_dim ..][0 .. dim * hidden_dim]);
+    // fused version of the above
+    matmul_fused(2, [_][]f32{ s.hb, s.hb2 }, s.xb, [_][]const f32{
+        w.w1[l * dim * hidden_dim ..][0 .. dim * hidden_dim],
+        w.w3[l * dim * hidden_dim ..][0 .. dim * hidden_dim],
+    });
 
     // F.silu; silu(x)=x*σ(x),where σ(x) is the logistic sigmoid
     for (0..hidden_dim) |i| {
@@ -356,7 +367,6 @@ fn rmsnorm(o: []f32, x: []f32, w: []f32) void {
 ///
 fn matmul(xout: []f32, x: []const f32, w: []const f32) void {
     // This one function accounts for ~90% of the total runtime.
-    @setFloatMode(std.builtin.FloatMode.Optimized);
    const d = xout.len;
    const n = x.len;
    assert(w.len == n * d);
@@ -398,6 +408,68 @@ fn vector_dot_product(x: []const f32, y: []const f32) f32 {
     return @reduce(.Add, sum) + sum_rem;
 }
 
+/// Performs N matrix-vector multiplications in a single fused pass over x, using comptime to generate the per-matrix steps.
+fn matmul_fused(comptime N: usize, outs: [N][]f32, x: []const f32, ws: [N][]const f32) void {
+    if (N == 0) @compileError("N must be greater than 0");
+    // go through and check that all the dimensions are correct
+    inline for (0..N) |i| {
+        assert(outs[i].len > 0);
+        assert(ws[i].len > 0);
+        assert(ws[i].len == x.len * outs[i].len);
+        if (i > 0) {
+            assert(outs[i].len == outs[i - 1].len);
+            assert(ws[i].len == ws[i - 1].len);
+        }
+    }
+
+    const vector_width = DEFAULT_VECTOR_WIDTH;
+    const vec_len = x.len / vector_width;
+    const vec_rem = x.len % vector_width;
+
+    const d = outs[0].len;
+    const n = x.len;
+
+    for (0..d) |i| {
+        // pick out row i of each W
+        var wrows: [N][]const f32 = undefined;
+        inline for (0..N) |j| {
+            wrows[j] = ws[j][i * n ..][0..n];
+        }
+
+        // initialize the SIMD accumulators
+        var sums: [N]@Vector(vector_width, f32) = undefined;
+        inline for (0..N) |j| {
+            sums[j] = @splat(0.0);
+        }
+
+        var offset: usize = 0;
+        for (0..vec_len) |_| {
+            const xvec: @Vector(vector_width, f32) = x[offset..][0..vector_width].*;
+            inline for (0..N) |j| {
+                const wvec: @Vector(vector_width, f32) = wrows[j][offset..][0..vector_width].*;
+                sums[j] += xvec * wvec;
+            }
+            offset += vector_width;
+        }
+
+        // process remaining elements with scalar ops
+        var sums_rem: [N]f32 = undefined;
+        inline for (0..N) |j| {
+            sums_rem[j] = 0.0;
+        }
+        for (0..vec_rem) |a| {
+            inline for (0..N) |j| {
+                sums_rem[j] += x[offset + a] * wrows[j][offset + a];
+            }
+        }
+
+        // reduce each SIMD accumulator to a scalar
+        inline for (0..N) |j| {
+            outs[j][i] = @reduce(.Add, sums[j]) + sums_rem[j];
+        }
+    }
+}
+
 /// Computes vector vector multiplication elementwise and stores the result in the first vector.
 fn vector_mul(x: []f32, y: []const f32) void {
     assert(x.len == y.len);
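The point of `matmul_fused` is that each SIMD load of `x` is reused against all `N` weight matrices instead of being re-read once per matrix, and the `inline for` over `j` unrolls at compile time, so a fixed `N` adds no runtime branching. Below is a scalar model of that idea with a test checking the fused result against running the multiplications separately (a self-contained sketch runnable with `zig test`; `matmul` and `matmulFused` here are local stand-ins, not the repo's functions):

```zig
const std = @import("std");
const expectApproxEqAbs = std.testing.expectApproxEqAbs;

/// Plain matrix-vector multiply: xout = W * x, with W stored row-major.
fn matmul(xout: []f32, x: []const f32, w: []const f32) void {
    const n = x.len;
    for (xout, 0..) |*out, i| {
        var sum: f32 = 0.0;
        for (x, w[i * n ..][0..n]) |xa, wa| sum += xa * wa;
        out.* = sum;
    }
}

/// Scalar model of the fused kernel: each element of x is read once per row
/// step and folded into the running sums for all N matrices.
fn matmulFused(comptime N: usize, outs: [N][]f32, x: []const f32, ws: [N][]const f32) void {
    const n = x.len;
    const d = outs[0].len;
    for (0..d) |i| {
        var sums = [_]f32{0.0} ** N;
        for (0..n) |a| {
            const xa = x[a]; // loaded once, reused for every matrix
            inline for (0..N) |j| sums[j] += xa * ws[j][i * n + a];
        }
        inline for (0..N) |j| outs[j][i] = sums[j];
    }
}

test "fused matmul matches separate matmuls" {
    const x = [_]f32{ 1, 2, 3 };
    const w1 = [_]f32{ 1, 0, 0, 0, 1, 0 }; // 2x3
    const w2 = [_]f32{ 2, 2, 2, 1, 1, 1 }; // 2x3

    var sep1: [2]f32 = undefined;
    var sep2: [2]f32 = undefined;
    matmul(&sep1, &x, &w1);
    matmul(&sep2, &x, &w2);

    var fused1: [2]f32 = undefined;
    var fused2: [2]f32 = undefined;
    matmulFused(2, .{ &fused1, &fused2 }, &x, .{ &w1, &w2 });

    for (sep1, fused1) |s, f| try expectApproxEqAbs(s, f, 1e-6);
    for (sep2, fused2) |s, f| try expectApproxEqAbs(s, f, 1e-6);
}
```

In the transformer above, fusing the q/k/v projections this way turns three passes over `s.xb` into one, and the FFN fusion does the same for the `w1` and `w3` projections.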