diff --git a/docs/plots/strided_sum/Strided_sum_fp16.png b/docs/plots/strided_sum/Strided_sum_fp16.png
deleted file mode 100644
index e8fd201f6..000000000
Binary files a/docs/plots/strided_sum/Strided_sum_fp16.png and /dev/null differ
diff --git a/docs/plots/strided_sum/Strided_sum_fp32.png b/docs/plots/strided_sum/Strided_sum_fp32.png
deleted file mode 100644
index 5e1c1f22a..000000000
Binary files a/docs/plots/strided_sum/Strided_sum_fp32.png and /dev/null differ
diff --git a/tests/test_triton_basics.py b/tests/test_triton_basics.py
index 1376aff45..e41ab7afa 100644
--- a/tests/test_triton_basics.py
+++ b/tests/test_triton_basics.py
@@ -4,32 +4,14 @@
 # LICENSE file in the root directory of this source tree.
 
 
-import pytest
 import torch
 
-SHAPES = [
-    (384, 128),
-    (8 * 384, 128),
-    (34, 128),
-    (16, 128),
-    (16, 512),
-    (8, 384),
-    (8, 1024),
-    (8, 2048),
-    (8, 4096),
-    (8, 4096),
-    (4, 12288),
-]
-
-
 _triton_available = torch.cuda.is_available()
 if _triton_available:
     try:
         import triton
         import triton.language as tl
 
-        from xformers.triton.sum_strided import sum_2d_dim_0
-
     except (ImportError, ModuleNotFoundError):
         _triton_available = False
 
@@ -57,7 +39,7 @@ def k_mean(X, Mean, Var, stride, N, **META):
         # Compute variance
         x_mean = tl.sum(x, axis=0) / N
         x_zm = x - x_mean
-        x_zm = tl.where(cols < N, x_zm, 0.0)
+        x_zm = tl.where(cols < N, x_zm, 0.0)  # THIS SHOULD NOT BE NEEDED
         x_var = tl.sum(x_zm * x_zm, axis=0) / N
         tl.store(Mean + row, x_mean)
         tl.store(Var + row, x_var)
@@ -106,33 +88,3 @@ def test_mean():
 
         assert torch.allclose(mean, t_mean, rtol=1e-1)
         assert torch.allclose(var, t_var, rtol=1e-1)
-
-    @pytest.mark.parametrize("shape", SHAPES)
-    @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
-    def test_sum_strided(shape, dtype):
-        torch.random.manual_seed(0)
-        a = torch.rand(shape, device=torch.device("cuda"), dtype=dtype)
-
-        torch_sum = torch.sum(a, dim=0)
-        triton_sum = sum_2d_dim_0(a)
-        assert torch.allclose(
-            torch_sum, triton_sum, rtol=0.01
-        ), f"{torch_sum}\n{triton_sum}"
-
-    def test_sum_strided_asserts():
-        torch.random.manual_seed(0)
-        a = torch.rand((128, 256), device=torch.device("cuda"), dtype=torch.float16)
-
-        with pytest.raises(AssertionError):
-            # This kernel is not useful in that case, assert to prevent misuse
-            sum_2d_dim_0(a.transpose(1, 0))
-
-        a = torch.rand((3, 128, 256), device=torch.device("cuda"), dtype=torch.float16)
-        with pytest.raises(AssertionError):
-            # This kernel expects 2D tensors, assert to prevent misuse
-            sum_2d_dim_0(a)
-
-        a = torch.rand((2, 128), device=torch.device("cuda"), dtype=torch.float16)
-        with pytest.raises(AssertionError):
-            # This kernel cannot sum over dimensions < 4
-            sum_2d_dim_0(a)
diff --git a/xformers/benchmarks/benchmark_triton_stride_sum.py b/xformers/benchmarks/benchmark_triton_stride_sum.py
deleted file mode 100644
index 6fb887e78..000000000
--- a/xformers/benchmarks/benchmark_triton_stride_sum.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-#
-# This source code is licensed under the BSD license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Any, Dict, List
-
-import torch
-import triton
-
-from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
-from xformers.triton.sum_strided import sum_2d_dim_0
-
-SHAPES = [
-    (128, 128),
-    (384, 128),
-    (784, 512),
-    (1024, 768),
-    (2048, 1024),
-    (4096, 4096),
-]
-
-
-def to_gbs(a, ms):
-    # Read the full array, write the non-reduced dimension
-    return ((a.numel() + a.shape[1]) * a.element_size() * 1e-9) / (ms * 1e-3)
-
-
-def bench_functions(
-    test_cases: List[TestCase], shapes, metric_transform, unit, title=""
-):
-    device = torch.device("cuda")
-
-    for dtype in [torch.float16, torch.float32]:
-        results: Dict[str, Any] = {}
-
-        for M, N in shapes:
-            a = torch.rand(M, N, device=device, dtype=dtype, requires_grad=True)
-
-            for testcase in test_cases:
-                time = triton.testing.do_bench(lambda: testcase.function(a))[0]
-
-                metric = metric_transform(a, time)
-
-                key = f"M={M}, N={N}"
-                if key not in results:
-                    results[key] = {}
-
-                results[key][testcase.name] = f"{metric:.1f}"
-
-        _type = " fp16" if dtype == torch.float16 else " fp32"
-
-        pretty_print(
-            results,
-            title=" ------------- Type: {} ------------- ".format(_type),
-            units=unit,
-        )
-
-        pretty_plot(results, title + _type, unit, dash_key="pytorch")
-
-
-bench_functions(
-    [
-        TestCase(lambda x: torch.sum(x, dim=0), "pytorch"),
-        TestCase(sum_2d_dim_0, "triton"),
-    ],
-    SHAPES,
-    to_gbs,
-    "GB/s",
-    "Strided_sum",
-)
diff --git a/xformers/triton/dropout.py b/xformers/triton/dropout.py
index 54799be51..b7178b950 100644
--- a/xformers/triton/dropout.py
+++ b/xformers/triton/dropout.py
@@ -44,6 +44,9 @@ def grid(meta):
                 triton.cdiv(N, meta["BLOCK_N"]),
             )
 
+        GROUP_M = 128
+        BLOCK_M = GROUP_M // 4
+
         # fmt: off
         k_dropout_fw[grid](
             y, x_,
@@ -53,7 +56,8 @@ def grid(meta):
             M, N,
             p,
             USE_BIAS=bias is not None,
-            ACTIVATION=activation
+            ACTIVATION=activation,
+            BLOCK_M=BLOCK_M
         )
         # fmt: on
 
@@ -87,12 +91,21 @@ def backward(ctx, grad_out):
         elif inputs.ndim > 2:
             inputs = inputs.reshape(-1, N)
 
+        GROUP_M = 128
+        BLOCK_M = GROUP_M // 4
+        N_BLOCKS_M = triton.cdiv(M, GROUP_M)
+
         if ctx.trainable_bias:
-            grad_bias = torch.empty((N,), device=grad_in.device, dtype=grad_in.dtype)
-            locks = torch.zeros(N // 2, dtype=torch.int32, device=grad_in.device)
+            grad_bias = torch.empty(
+                (
+                    N_BLOCKS_M,
+                    N,
+                ),
+                device=grad_in.device,
+                dtype=grad_in.dtype,
+            )
         else:
             grad_bias = grad_in  # will not be used
-            locks = grad_in
 
         def grid(meta):
             return (
@@ -104,20 +117,21 @@ def grid(meta):
         k_dropout_bw[grid](
             grad_in, grad_bias, grad_out_,
             inputs, bias if bias is not None else inputs,
-            seeds, locks,
+            seeds,
             grad_out_.stride(0), inputs.stride(0),
             M, N,
             ctx.p,
             USE_BIAS=bias is not None,
             ACTIVATION_GRAD=ctx.activation_grad,
-            TRAINABLE_BIAS=ctx.trainable_bias
+            TRAINABLE_BIAS=ctx.trainable_bias,
+            BLOCK_M=BLOCK_M
         )
         # fmt: on
 
         return (
             grad_in.reshape_as(grad_out),
             None,
-            grad_bias if ctx.trainable_bias else None,
+            grad_bias.sum(dim=0) if ctx.trainable_bias else None,
             None,
             None,
             None,
diff --git a/xformers/triton/k_dropout.py b/xformers/triton/k_dropout.py
index 58589f0a1..e7656204e 100644
--- a/xformers/triton/k_dropout.py
+++ b/xformers/triton/k_dropout.py
@@ -10,17 +10,11 @@
 import triton
 import triton.language as tl
 
-# WARNING: For now, the number of threads must be the same as the N buffer, and warps have to be 4 (will be fixed)
 k_configs = [
-    triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=1),
-    triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=1),
-    triton.Config({"BLOCK_M": 256, "BLOCK_N": 32}, num_warps=1),
-    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=2),
-    triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=2),
-    triton.Config({"BLOCK_M": 256, "BLOCK_N": 64}, num_warps=2),
-    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4),
-    triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4),
-    triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=4),
+    triton.Config({"BLOCK_N": 32}, num_warps=1),
+    triton.Config({"BLOCK_N": 64}, num_warps=2),
+    triton.Config({"BLOCK_N": 128}, num_warps=4),
+    triton.Config({"BLOCK_N": 256}, num_warps=8),
 ]
 
 
@@ -131,7 +125,7 @@ def k_dropout_fw(
 @triton.jit
 def k_dropout_bw(
     GRAD_IN, GRAD_BIAS, GRAD_OUT,
-    INPUTS, BIAS, SEEDS, LOCKS,
+    INPUTS, BIAS, SEEDS,
     stride_grad, stride_inputs,
     M, N,
    p,
@@ -250,23 +244,5 @@ def k_dropout_bw(
         rand_mask = rand_mask1
 
     if TRAINABLE_BIAS:
-        lock_ptr = LOCKS + 2 * col_id
-        count_ptr = LOCKS + 2 * col_id + 1
-        grad_bias_ptr = GRAD_BIAS + cols
-
-        # Uniquely taking a lock over the col results
-        while tl.atomic_cas(lock_ptr, 0, 1) == 1:
-            pass
-
-        count = tl.load(count_ptr)
-        if count == 0:
-            # first store doesn't accumulate
-            tl.atomic_xchg(count_ptr, 1)
-        else:
-            # read and add back
-            grad_bias += tl.load(grad_bias_ptr, mask=cols < N)
-
+        grad_bias_ptr = GRAD_BIAS + row_id * N + cols
         tl.store(grad_bias_ptr, grad_bias, mask=cols < N)
-
-        # release lock
-        tl.atomic_xchg(lock_ptr, 0)
diff --git a/xformers/triton/k_fused_matmul_bw.py b/xformers/triton/k_fused_matmul_bw.py
index bfe322924..49d25f3c3 100644
--- a/xformers/triton/k_fused_matmul_bw.py
+++ b/xformers/triton/k_fused_matmul_bw.py
@@ -10,8 +10,6 @@
 import triton
 import triton.language as tl
 
-from xformers.triton.sum_strided import sum_2d_dim_0
-
 
 # fmt: off
 @triton.heuristics({
@@ -146,6 +144,6 @@ def grid(META):
     # The following ops can also be handled by triton
     grad_in = grad_out_ @ weight
     grad_weight = grad_out_.transpose(1, 0) @ inputs_ if trainable_weight else None
-    grad_bias = sum_2d_dim_0(grad_out_) if trainable_bias else None
+    grad_bias = torch.sum(grad_out_, 0) if trainable_bias else None
 
     return grad_in.reshape_as(inputs), grad_weight, grad_bias
diff --git a/xformers/triton/k_sum.py b/xformers/triton/k_sum.py
deleted file mode 100644
index 998d6f04d..000000000
--- a/xformers/triton/k_sum.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-#
-# This source code is licensed under the BSD license found in the
-# LICENSE file in the root directory of this source tree.
-
-import triton
-import triton.language as tl
-
-
-# fmt: off
-@triton.autotune(
-    configs=[
-        triton.Config({"BLOCK_M": 32, "BLOCK_N": 16}, num_stages=5, num_warps=1),
-        triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_stages=5, num_warps=2),
-        triton.Config({"BLOCK_M": 128, "BLOCK_N": 16}, num_stages=5, num_warps=2),
-        triton.Config({"BLOCK_M": 256, "BLOCK_N": 16}, num_stages=5, num_warps=2),
-        triton.Config({"BLOCK_M": 512, "BLOCK_N": 16}, num_stages=5, num_warps=2),
-        triton.Config({"BLOCK_M": 1024, "BLOCK_N": 8}, num_stages=5, num_warps=2),
-        triton.Config({"BLOCK_M": 2048, "BLOCK_N": 8}, num_stages=5, num_warps=2),
-        triton.Config({"BLOCK_M": 4096, "BLOCK_N": 8}, num_stages=4, num_warps=2),
-    ],
-    key=["M", "N", "is_fp16"],
-)
-@triton.jit
-def k_sum_0(
-    Y, X,
-    stride_xm,
-    M, N,
-    is_fp16,
-    **meta,
-):
-    # fmt: om
-
-    """
-    Sum a 2d tensor over the first (strided) dimension.
-    This extracts some speed through a parallel sum across the second dimension
-    """
-    BLOCK_M = meta["BLOCK_M"]
-    BLOCK_N = meta["BLOCK_N"]
-
-    # partial row indices. We'll reduce over this dimension
-    m = tl.arange(0, BLOCK_M)
-
-    # To get some extra parallelization, we handle several columns in the same thread block
-    rn = tl.program_id(axis=0) * BLOCK_N + tl.arange(0, BLOCK_N)
-
-    # the memory address of all the elements that we want to load can be computed as follows
-    x_ptrs = X + m[:, None] * stride_xm + rn[None, :]
-    x_sum = tl.zeros((BLOCK_N,), dtype=tl.float32)
-
-    tiles = M // BLOCK_M
-    if M % BLOCK_M > 0:
-        tiles += 1
-
-    for _ in range(tiles):
-        # load input data; pad out-of-bounds elements with 0
-        # NOTE: make sure to accumulate in fp32 to prevent a trivial overflow
-        mask = (m[:, None] < M) & (rn[None, :] < N)
-        x = tl.load(x_ptrs, mask=mask, other=0.0)
-        x_sum += tl.sum(x, 0)
-
-        # move the load pointer
-        x_ptrs += BLOCK_M * stride_xm
-        m += BLOCK_M  # update the mask check
-
-    tl.store(Y + rn, x_sum, mask=rn < N)
diff --git a/xformers/triton/sum_strided.py b/xformers/triton/sum_strided.py
deleted file mode 100644
index 8d15bd5e7..000000000
--- a/xformers/triton/sum_strided.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-#
-# This source code is licensed under the BSD license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import torch
-import triton
-
-from xformers.triton.k_sum import k_sum_0
-
-
-def sum_2d_dim_0(x: torch.Tensor):
-    """
-    Sum a 2D tensor across the first dimension
-    """
-
-    out = torch.empty(x.shape[1], device=x.device, dtype=x.dtype)
-
-    assert (
-        x.ndim == 2
-    ), "This is a very specific kernel, only for 2-dim tensors and summing along dim 0"
-
-    assert (
-        x.shape[0] >= 4
-    ), "This is a very specific kernel, requires the reduction dimension to be bigger than 4"
-
-    assert x.stride(1) == 1, (
-        "We're expecting x to be contiguous along dim 1, and non contiguous along dim 0.\n"
-        " You would probably be better served with torch.sum()"
-    )
-
-    # Manually handle the scheduling
-    M, N = x.shape
-
-    def grid(meta):
-        return (triton.cdiv(N, meta["BLOCK_N"]),)
-
-    # fmt: off
-    k_sum_0[grid](
-        out, x,
-        x.stride(0),
-        M, N,
-        x.dtype == torch.float16,
-    )
-    # fmt: on
-
-    return out
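Reviewer note (not part of the patch): the dropout backward now has each kernel instance write one partial bias-gradient row into an `(N_BLOCKS_M, N)` buffer and finishes with a deterministic `grad_bias.sum(dim=0)` on the host, instead of serializing stores behind the removed `LOCKS` buffer. Below is a minimal pure-PyTorch sketch of that two-stage reduction pattern; `blocked_sum_dim0` and its `block_m` default are illustrative names only and do not exist in xformers.

```python
import torch


def blocked_sum_dim0(x: torch.Tensor, block_m: int = 32) -> torch.Tensor:
    """Two-stage sum over dim 0: per-block partial rows, then one final reduce."""
    m, n = x.shape
    n_blocks = (m + block_m - 1) // block_m  # plays the role of N_BLOCKS_M above
    partials = torch.empty((n_blocks, n), dtype=torch.float32, device=x.device)
    for i in range(n_blocks):
        # Stage 1: each "program instance" owns its own output row, so no
        # locks or atomics are needed (cf. GRAD_BIAS + row_id * N + cols).
        partials[i] = x[i * block_m : (i + 1) * block_m].float().sum(dim=0)
    # Stage 2: a single deterministic reduction, cf. grad_bias.sum(dim=0)
    # in dropout.py's backward().
    return partials.sum(dim=0).to(x.dtype)


if __name__ == "__main__":
    x = torch.rand(512, 256)
    assert torch.allclose(blocked_sum_dim0(x), x.sum(dim=0), rtol=1e-4, atol=1e-4)
```

Because every block writes to its own row, the result is independent of the order in which blocks finish, which is what makes the bias gradient reproducible run-to-run, at the cost of a small temporary buffer and one extra reduction.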