Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

[perf] Fused linear : small FW cleanup and much better perfs #283

Merged
merged 1 commit into from
Apr 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix some torchscriptability [#246]
- Fix FourierMix being compatible with AMP [#258]
- Better asserts on QKV dimensions [#264]
- Better perfs for FusedMLP and FusedLinearLayer [#283]

### Added
- Simplicial Embeddings [#259]
- Mem efficient attention, FW pass [#267]
- MHA benchmark
- MLP benchmark
- Move all triton kernels to triton v2 [#272]
- Mem efficient attention, BW pass [#281]

## [0.0.10] - 2022-03-14
### Fixed
Expand Down
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_none.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 5 additions & 3 deletions tests/test_triton_fused_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ def test_fused_matmul(shape, dtype):

# Test that not passing any bias is fine
res_torch = a @ b
res_triton, _ = fused_matmul(a, b.transpose(0, 1), None)
res_triton, _ = fused_matmul(a, b.transpose(0, 1).contiguous(), None)
assert torch.allclose(res_torch, res_triton), "Vanilla matmul is broken"

# Now test with a real FMA
c = -torch.rand((shape[-2],), dtype=dtype, device="cuda")
res_torch = torch.addmm(c, a, b)
res_triton, _ = fused_matmul(a, b.transpose(1, 0), c)
res_triton, _ = fused_matmul(a, b.transpose(1, 0).contiguous(), c)

assert torch.allclose(
res_torch, res_triton
Expand All @@ -65,7 +65,9 @@ def test_fused_matmul(shape, dtype):
res_torch = torch_activation(torch.addmm(c, a, b))

triton_activation = get_triton_activation_kernel(activation)
res_triton, _ = fused_matmul(a, b.transpose(1, 0), c, triton_activation)
res_triton, _ = fused_matmul(
a, b.transpose(1, 0).contiguous(), c, triton_activation
)

# NOTE: @lefaudeux
# GeLUs are not well handled for now, we use an approximation
Expand Down
26 changes: 14 additions & 12 deletions xformers/triton/k_fused_matmul_fw.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def kernel_fma(
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_om, stride_im,
stride_wn, stride_wk,
stride_wn,
# Meta-parameters
BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
Expand Down Expand Up @@ -93,8 +93,8 @@ def kernel_fma(
rk = tl.arange(0, BLOCK_K)

# the memory addresses of elements can follow numpy broadcasting
input_ptrs = INPUT + rm[:, None] * stride_im + rk[None, :]
weight_ptrs = WEIGHT + rk[:, None] * stride_wk + rn[None, :] * stride_wn
input_ptrs = INPUT + rm[:, None] * stride_im
weight_ptrs = WEIGHT + rn[None, :] * stride_wn

# initialize and iteratively update accumulator
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
Expand All @@ -105,27 +105,28 @@ def kernel_fma(

# block level matrix multiplication.
# We fetch a block memory block from both inputs, matmul and accumulate, then repeat
for _ in range(K, 0, -BLOCK_K):
a = tl.load(input_ptrs, mask=((rk[None, :] < K) & (rm[:, None] < M)), other=0.0)
w = tl.load(weight_ptrs, mask=((rk[:, None] < K) & (rn[None, :] < N)), other=0.0)
mask_rn = rn < N
mask_rm = rm < M

acc += tl.dot(a, w).to(tl.float32)
for i in range(0, K, BLOCK_K):
rk = tl.arange(0, BLOCK_K) + i
a = tl.load(input_ptrs + rk[None, :], mask=((rk[None, :] < K) & mask_rm[:, None]), other=0.0)
w = tl.load(weight_ptrs + rk[:, None], mask=((rk[:, None] < K) & mask_rn[None, :]), other=0.0)

input_ptrs += BLOCK_K
weight_ptrs += BLOCK_K * stride_wk
acc += tl.dot(a, w)

# optional: save the activation inputs
if SAVE_ACT_INPUTS:
act_in_ptrs = ACT_INPUTS + rm[:, None] * stride_om + rn[None, :]
tl.store(act_in_ptrs, acc, mask=(rm[:, None] < M) & (rn[None, :] < N))
tl.store(act_in_ptrs, acc, mask=mask_rm[:, None] & mask_rn[None, :])

# optional: fused activation (while the data is in shared memory)
if ACTIVATION:
acc = ACTIVATION(acc)

# write back result
out_ptrs = OUT + rm[:, None] * stride_om + rn[None, :]
tl.store(out_ptrs, acc, mask=(rm[:, None] < M) & (rn[None, :] < N))
tl.store(out_ptrs, acc, mask=mask_rm[:, None] & mask_rn[None, :])


# Activation needs to be a triton kernel
Expand Down Expand Up @@ -153,6 +154,7 @@ def fused_matmul(
assert (
bias is None or bias.shape[0] == weight.shape[0]
), "Incompatible dimensions in between weight and bias"
assert weight.is_contiguous()

M, K = x_.shape
N, K = weight.shape
Expand All @@ -169,7 +171,7 @@ def fused_matmul(
bias if bias is not None else x, # auto skip bias if not present
M, N, K, # shapes
outputs.stride(0), x_.stride(0), # strides
weight.stride(0), weight.stride(1),
weight.stride(0),
ACTIVATION=activation, # optional fused activation
BIAS=bias is not None, # optional fused bias
GROUP_M=8, # speed optimization: group the programs
Expand Down