Add fused matmul
cgbur committed Jul 31, 2023
1 parent be6614b commit 877f929
Showing 2 changed files with 80 additions and 8 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -11,7 +11,7 @@ requests are greatly appreciated. The ultimate goal is to create a fast,
portable, and user-friendly implementation of the llama2 model architecture.
The code prioritizes simplicity and readability without sacrificing
performance. Certain core functions have SIMD implementations using the Zig
`@Vector` feature, which provides a ~4x speed increase. For more details,
`@Vector` feature, which provides a ~5x speed increase. For more details,
please refer to the [performance](#performance) section.
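
As a rough illustration, the SIMD kernels follow a pattern like the sketch below (a simplified example, not the repository's exact code; the width of 8 is an assumption, the actual code uses its own `DEFAULT_VECTOR_WIDTH`):

```zig
// Simplified @Vector dot-product sketch; the width of 8 is an assumption.
fn dot(x: []const f32, y: []const f32) f32 {
    const W = 8;
    var sum: @Vector(W, f32) = @splat(0.0);
    var i: usize = 0;
    while (i + W <= x.len) : (i += W) {
        const xv: @Vector(W, f32) = x[i..][0..W].*;
        const yv: @Vector(W, f32) = y[i..][0..W].*;
        sum += xv * yv; // multiply-accumulate across 8 lanes at once
    }
    var acc: f32 = @reduce(.Add, sum); // horizontal sum of the lanes
    while (i < x.len) : (i += 1) acc += x[i] * y[i]; // scalar tail
    return acc;
}
```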

The `stories15.bin` file is a model checkpoint for a 15M parameter model that
@@ -36,7 +36,7 @@ Processor.
| ------------------------------------------------- | -------- |
| llama2.c `make run` | 116 |
| llama2.c `make runfast` | 375 |
| llama2.zig `zig build run -Doptimize=ReleaseFast` | 482 |
| llama2.zig `zig build run -Doptimize=ReleaseFast` | 525 |

## Todo

84 changes: 78 additions & 6 deletions src/main.zig
@@ -218,9 +218,15 @@ fn transformer(token: usize, pos: usize, config: *const Config, s: *RunState, w:
rmsnorm(s.xb, x, w.rms_att_weight[l * dim ..][0..dim]);

// qkv
matmul(s.q, s.xb, w.wq[l * dim * dim ..][0 .. dim * dim]);
matmul(s.k, s.xb, w.wk[l * dim * dim ..][0 .. dim * dim]);
matmul(s.v, s.xb, w.wv[l * dim * dim ..][0 .. dim * dim]);
// matmul(s.q, s.xb, w.wq[l * dim * dim ..][0 .. dim * dim]);
// matmul(s.k, s.xb, w.wk[l * dim * dim ..][0 .. dim * dim]);
// matmul(s.v, s.xb, w.wv[l * dim * dim ..][0 .. dim * dim]);
// fused version of the above
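// (each SIMD chunk of s.xb is loaded once and reused for wq, wk, and wv,
// instead of re-reading s.xb in three separate passes)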
matmul_fused(3, [_][]f32{ s.q, s.k, s.v }, s.xb, [_][]f32{
w.wq[l * dim * dim ..][0 .. dim * dim],
w.wk[l * dim * dim ..][0 .. dim * dim],
w.wv[l * dim * dim ..][0 .. dim * dim],
});

// apply RoPE rotation to the q and k vectors for each head
for (0..config.n_heads) |h| {
@@ -295,8 +301,13 @@ fn transformer(token: usize, pos: usize, config: *const Config, s: *RunState, w:

// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
matmul(s.hb, s.xb, w.w1[l * dim * hidden_dim ..][0 .. dim * hidden_dim]);
matmul(s.hb2, s.xb, w.w3[l * dim * hidden_dim ..][0 .. dim * hidden_dim]);
// matmul(s.hb, s.xb, w.w1[l * dim * hidden_dim ..][0 .. dim * hidden_dim]);
// matmul(s.hb2, s.xb, w.w3[l * dim * hidden_dim ..][0 .. dim * hidden_dim]);
// fused version of the above
matmul_fused(2, [_][]f32{ s.hb, s.hb2 }, s.xb, [_][]f32{
w.w1[l * dim * hidden_dim ..][0 .. dim * hidden_dim],
w.w3[l * dim * hidden_dim ..][0 .. dim * hidden_dim],
});

// F.silu; silu(x) = x*σ(x), where σ(x) is the logistic sigmoid
for (0..hidden_dim) |i| {
@@ -356,7 +367,6 @@ fn rmsnorm(o: []f32, x: []f32, w: []f32) void {
///
fn matmul(xout: []f32, x: []const f32, w: []const f32) void {
// This one function accounts for ~90% of the total runtime.
@setFloatMode(std.builtin.FloatMode.Optimized);
const d = xout.len;
const n = x.len;
assert(w.len == n * d);
@@ -398,6 +408,68 @@ fn vector_dot_product(x: []const f32, y: []const f32) f32 {
return @reduce(.Add, sum) + sum_rem;
}

/// Performs N matrix-vector multiplications over the same input vector x in a single
/// fused pass; the comptime parameter N lets the per-output steps be unrolled at compile time.
fn matmul_fused(comptime N: usize, outs: [N][]f32, x: []const f32, ws: [N][]const f32) void {
if (N == 0) @compileError("N must be greater than 0");
// go through and check that all the dimensions are correct
inline for (0..N) |i| {
assert(outs[i].len > 0);
assert(ws[i].len > 0);
assert(ws[i].len == x.len * outs[i].len);
if (i > 0) {
assert(outs[i].len == outs[i - 1].len);
assert(ws[i].len == ws[i - 1].len);
}
}

const vector_width = DEFAULT_VECTOR_WIDTH;
const vec_len = x.len / vector_width;
const vec_rem = x.len % vector_width;

const d = outs[0].len;
const n = x.len;

for (0..d) |i| {
// pick out rows of W
var wrows: [N][]const f32 = undefined;
inline for (0..N) |j| {
wrows[j] = ws[j][i * n ..][0..n];
}

// Initialize sums
var sums: [N]@Vector(vector_width, f32) = undefined;
inline for (0..N) |j| {
sums[j] = @splat(0.0);
}

var offset: usize = 0;
for (0..vec_len) |_| {
const xvec: @Vector(vector_width, f32) = x[offset..][0..vector_width].*;
inline for (0..N) |j| {
const wvec: @Vector(vector_width, f32) = wrows[j][offset..][0..vector_width].*;
sums[j] += xvec * wvec;
}
offset += vector_width;
}

// process remaining elements with scalar ops
var sums_rem: [N]f32 = undefined;
inline for (0..N) |j| {
sums_rem[j] = 0.0;
}
for (0..vec_rem) |a| {
inline for (0..N) |j| {
sums_rem[j] += x[offset + a] * wrows[j][offset + a];
}
}

// reduce SIMD vector to scalar
inline for (0..N) |j| {
outs[j][i] = @reduce(.Add, sums[j]) + sums_rem[j];
}
}
}
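
// A minimal usage sketch of matmul_fused (illustrative values only): fuse two
// 2x2 matrix-vector products over the same input vector.
test "matmul_fused sketch" {
    var out0 = [_]f32{ 0.0, 0.0 };
    var out1 = [_]f32{ 0.0, 0.0 };
    const x = [_]f32{ 1.0, 2.0 };
    const w0 = [_]f32{ 1.0, 0.0, 0.0, 1.0 }; // identity, rows stored contiguously
    const w1 = [_]f32{ 2.0, 0.0, 0.0, 2.0 }; // 2 * identity
    matmul_fused(2, [_][]f32{ &out0, &out1 }, &x, [_][]const f32{ &w0, &w1 });
    try std.testing.expectEqualSlices(f32, &[_]f32{ 1.0, 2.0 }, &out0);
    try std.testing.expectEqualSlices(f32, &[_]f32{ 2.0, 4.0 }, &out1);
}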

/// Computes vector vector multiplication elementwise and stores the result in the first vector.
fn vector_mul(x: []f32, y: []const f32) void {
assert(x.len == y.len);
