
Commit 00aebbc
- Adding a dedicated sum kernel when on a strided dimension
- adding a dedicated benchmark, better unit test
- moving to a tile based approach to better handle big buffers
- trying to find better scheduling defaults
blefaudeux committed Dec 15, 2021
1 parent e75ec39 commit 00aebbc
Showing 12 changed files with 241 additions and 8 deletions.
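For reference, a minimal usage sketch of the new kernel, mirroring the unit test added below (assuming a CUDA device; the shape is illustrative):

```python
import torch
from xformers.triton.sum_strided import sum_2d_dim_0

# 2D input, contiguous along dim 1; the kernel reduces over the strided dim 0
a = torch.rand(512, 1024, device="cuda", dtype=torch.float16)

triton_sum = sum_2d_dim_0(a)     # tile-based Triton reduction added in this commit
torch_sum = torch.sum(a, dim=0)  # PyTorch reference
assert torch.allclose(torch_sum, triton_sum, rtol=0.01)
```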
2 changes: 1 addition & 1 deletion .gitignore
@@ -51,4 +51,4 @@ examples/data

# Hydra default output dir
multirun
outputs
outputs
Binary file added docs/plots/strided_sum/Strided_sum_fp16.png
Binary file added docs/plots/strided_sum/Strided_sum_fp32.png
45 changes: 44 additions & 1 deletion tests/test_triton_basics.py
@@ -4,14 +4,32 @@
# LICENSE file in the root directory of this source tree.


import pytest
import torch

SHAPES = [
(384, 128),
(8 * 384, 128),
(34, 128),
(16, 128),
(16, 512),
(8, 384),
(8, 1024),
(8, 2048),
(8, 4096),
(8, 4096),
(4, 12288),
]


_triton_available = torch.cuda.is_available()
if _triton_available:
try:
import triton
import triton.language as tl

from xformers.triton.sum_strided import sum_2d_dim_0

except (ImportError, ModuleNotFoundError):
_triton_available = False

@@ -39,7 +57,7 @@ def k_mean(X, Mean, Var, stride, N, **META):
# Compute variance
x_mean = tl.sum(x, axis=0) / N
x_zm = x - x_mean
x_zm = tl.where(cols < N, x_zm, 0.0) # THIS SHOULD NOT BE NEEDED
x_zm = tl.where(cols < N, x_zm, 0.0)
x_var = tl.sum(x_zm * x_zm, axis=0) / N
tl.store(Mean + row, x_mean)
tl.store(Var + row, x_var)
@@ -88,3 +106,28 @@ def test_mean():

assert torch.allclose(mean, t_mean, rtol=1e-1)
assert torch.allclose(var, t_var, rtol=1e-1)

@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_sum_strided(shape, dtype):
torch.random.manual_seed(0)
a = torch.rand(shape, device=torch.device("cuda"), dtype=dtype)

torch_sum = torch.sum(a, dim=0)
triton_sum = sum_2d_dim_0(a)
assert torch.allclose(
torch_sum, triton_sum, rtol=0.01
), f"{torch_sum}\n{triton_sum}"

def test_sum_strided_asserts():
torch.random.manual_seed(0)
a = torch.rand((128, 256), device=torch.device("cuda"), dtype=torch.float16)

with pytest.raises(AssertionError):
# This kernel is not useful in that case, assert to prevent misuse
sum_2d_dim_0(a.transpose(1, 0))

a = torch.rand((3, 128, 256), device=torch.device("cuda"), dtype=torch.float16)
with pytest.raises(AssertionError):
# This kernel expects 2D tensors, assert to prevent misuse
sum_2d_dim_0(a)
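On a CUDA machine, these new tests can be selected directly with standard pytest filtering, e.g. `pytest tests/test_triton_basics.py -k sum_strided`.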
4 changes: 2 additions & 2 deletions tests/test_triton_dropout.py
@@ -26,7 +26,7 @@
)
_triton_available = False

# Testing odd shapes on purpose
# Testing odd (non-power-of-two for instance) shapes on purpose
SHAPES = [
(384, 128),
(8, 384, 128),
@@ -158,4 +158,4 @@ def test_dropout_parity(shape, amp, bias, activation, p):
if bias:
assert torch.allclose(
torch.norm(b.grad), torch.norm(b_.grad), rtol=0.01
), f"{b.grad}\n{b_.grad}"
), f"{b.grad.norm()}\n{b_.grad.norm()}"
2 changes: 1 addition & 1 deletion xformers/benchmarks/benchmark_triton_dropout.py
@@ -18,8 +18,8 @@
(8, 512, 1024),
(4, 1024, 1024),
(2, 2048, 2048),
(2, 4096, 4096),
(1, 2048, 12288),
(2, 4096, 4096),
]

P = 0.1
71 changes: 71 additions & 0 deletions xformers/benchmarks/benchmark_triton_stride_sum.py
@@ -0,0 +1,71 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.triton.sum_strided import sum_2d_dim_0

SHAPES = [
(128, 128),
(384, 128),
(784, 512),
(1024, 768),
(2048, 1024),
(4096, 4096),
]


def to_gbs(a, ms):
# Read the full array, write the non-reduced dimension
return ((a.numel() + a.shape[1]) * a.element_size() * 1e-9) / (ms * 1e-3)


def bench_functions(
test_cases: List[TestCase], shapes, metric_transform, unit, title=""
):
device = torch.device("cuda")

for dtype in [torch.float16, torch.float32]:
results: Dict[str, Any] = {}

for M, N in shapes:
a = torch.rand(M, N, device=device, dtype=dtype, requires_grad=True)

for testcase in test_cases:
time = triton.testing.do_bench(lambda: testcase.function(a))[0]

metric = metric_transform(a, time)

key = f"M={M}, N={N}"
if key not in results:
results[key] = {}

results[key][testcase.name] = f"{metric:.1f}"

_type = " fp16" if dtype == torch.float16 else " fp32"

pretty_print(
results,
title=" ------------- Type: {} ------------- ".format(_type),
units=unit,
)

pretty_plot(results, title + _type, unit, dash_key="pytorch")


bench_functions(
[
TestCase(lambda x: torch.sum(x, dim=0), "pytorch"),
TestCase(sum_2d_dim_0, "triton"),
],
SHAPES,
to_gbs,
"GB/s",
"Strided_sum",
)
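To make the reported metric concrete, here is the bandwidth model of `to_gbs` worked through by hand (the timing is a hypothetical number, not a measured result):

```python
# Reads M*N elements and writes N, each of element_size() bytes.
M, N, bytes_per_element = 2048, 1024, 2   # fp16
ms = 0.05                                 # hypothetical kernel time in milliseconds
gbs = ((M * N + N) * bytes_per_element * 1e-9) / (ms * 1e-3)
print(f"{gbs:.1f} GB/s")                  # ~83.9 GB/s with these illustrative numbers
```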
3 changes: 2 additions & 1 deletion xformers/components/__init__.py
@@ -13,7 +13,8 @@
from .activations import Activation, build_activation # noqa
from .attention import Attention, build_attention # noqa
from .in_proj_container import InProjContainer, InProjParams # noqa
from .multi_head_dispatch import MultiHeadDispatch, MultiHeadDispatchConfig # noqa
from .multi_head_dispatch import MultiHeadDispatch # noqa
from .multi_head_dispatch import MultiHeadDispatchConfig
from .residual import LayerNormStyle, PostNorm, PreNorm, Residual # noqa

# automatically import any Python files in the directory
3 changes: 2 additions & 1 deletion xformers/triton/dropout.py
@@ -19,6 +19,7 @@
get_triton_activation_kernel,
)
from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw
from xformers.triton.sum_strided import sum_2d_dim_0


# Helper to handle the SPMD launch grid and error cases
@@ -104,7 +105,7 @@ def grid(meta):
# fmt: on

if ctx.trainable_bias:
grad_bias: Optional[torch.Tensor] = torch.sum(grad_in, dim=0)
grad_bias: Optional[torch.Tensor] = sum_2d_dim_0(grad_in)
else:
grad_bias = None

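For context on the one-line change above: the bias gradient is simply the reduction of the incoming gradient over the leading (strided) dimension, so `sum_2d_dim_0(grad_in)` is a drop-in replacement for `torch.sum(grad_in, dim=0)`. The same substitution is made in the fused matmul backward below.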
4 changes: 3 additions & 1 deletion xformers/triton/k_fused_matmul_bw.py
@@ -10,6 +10,8 @@
import triton
import triton.language as tl

from xformers.triton.sum_strided import sum_2d_dim_0


# fmt: off
@triton.heuristics({
@@ -144,6 +146,6 @@ def grid(META):
# The following ops can also be handled by triton
grad_in = grad_out_ @ weight
grad_weight = grad_out_.transpose(1, 0) @ inputs_ if trainable_weight else None
grad_bias = torch.sum(grad_out_, 0) if trainable_bias else None
grad_bias = sum_2d_dim_0(grad_out_) if trainable_bias else None

return grad_in.reshape_as(inputs), grad_weight, grad_bias
55 changes: 55 additions & 0 deletions xformers/triton/k_sum.py
@@ -0,0 +1,55 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import triton
import triton.language as tl


# fmt: off
@triton.jit
def k_sum_0(
Y, X,
stride_xm,
M, N,
is_fp16,
**meta,
):
# fmt: on

"""
Sum a 2d tensor over the first (strided) dimension.
Extra speed comes from also parallelizing across the second dimension: each program instance sums a slab of columns.
"""
BLOCK_M = meta["BLOCK_M"]
BLOCK_N = meta["BLOCK_N"]

# partial row indices. We'll reduce over this dimension
m = tl.arange(0, BLOCK_M)

# To get some extra parallelization, we handle several columns in the same thread block
rn = tl.program_id(axis=0) * BLOCK_N + tl.arange(0, BLOCK_N)

# the memory address of all the elements that we want to load can be computed as follows
x_ptrs = X + m[:, None] * stride_xm + rn[None, :]
x_sum = tl.zeros((BLOCK_N,), dtype=tl.float32)

tiles = M // BLOCK_M
if M % BLOCK_M > 0:
tiles += 1

col_mask = (rn[None, :] < N)

for _ in range(tiles):
# load input data; pad out-of-bounds elements with 0
# NOTE: make sure to accumulate in fp32 to prevent a trivial overflow
mask = (m[:, None] < M) & col_mask
x = tl.load(x_ptrs, mask=mask, other=0.0)
x_sum += tl.sum(x, 0)

# move the load pointer
x_ptrs += BLOCK_M * stride_xm
m += BLOCK_M # update the mask check

tl.store(Y + rn, x_sum, mask=rn < N)
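As a readability aid, a plain-PyTorch sketch of the same tile loop (my own reference, not part of the commit; the block size is illustrative):

```python
import torch

def sum_dim0_tiled_reference(x: torch.Tensor, block_m: int = 64) -> torch.Tensor:
    # Walk the strided dim 0 in block_m-sized tiles and accumulate the partial
    # column sums in fp32, mirroring the kernel's float32 accumulator.
    M, N = x.shape
    acc = torch.zeros(N, device=x.device, dtype=torch.float32)
    for start in range(0, M, block_m):
        acc += x[start:start + block_m].to(torch.float32).sum(dim=0)
    return acc.to(x.dtype)
```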
60 changes: 60 additions & 0 deletions xformers/triton/sum_strided.py
@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import torch
import triton

from xformers.triton.k_sum import k_sum_0


def sum_2d_dim_0(x: torch.Tensor):
"""
Sum a 2D tensor across the first dimension
"""

out = torch.empty(x.shape[1], device=x.device, dtype=x.dtype)

assert (
x.ndim == 2
), "This is a very specific kernel, only for 2-dim tensors and summing along dim 0"
M, N = x.shape

# This kernel is not competitive for these sizes
if M > 2048 or M < 8:
return x.sum(dim=0)

assert (
M >= 4
), "This is a very specific kernel, requires the reduction dimension to be bigger than 4"

assert x.stride(1) == 1, (
"We're expecting x to be contiguous along dim 1, and non contiguous along dim 0.\n"
" You would probably be better served with torch.sum()"
)

BLOCK_M = min(triton.next_power_of_2(M), 2048)
BLOCK_N = 32
if BLOCK_M > 256:
BLOCK_N = 16
if BLOCK_M > 1024:
BLOCK_N = 8

def grid(meta):
return (triton.cdiv(N, meta["BLOCK_N"]),)

# fmt: off
k_sum_0[grid](
out, x,
x.stride(0),
M, N,
x.dtype == torch.float16,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_stages=4,
)
# fmt: on

return out
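Walking the launch heuristics above through an illustrative (1024, 768) input (the shape is just an example):

```python
import triton

M, N = 1024, 768
BLOCK_M = min(triton.next_power_of_2(M), 2048)  # -> 1024 rows handled per tile step
BLOCK_N = 32
if BLOCK_M > 256:
    BLOCK_N = 16
if BLOCK_M > 1024:
    BLOCK_N = 8
# BLOCK_N ends up as 16 for this shape: one program instance per 16-column slab
grid = (triton.cdiv(N, BLOCK_N),)  # -> (48,)
```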
