diff --git a/.gitignore b/.gitignore index a8541986e..3364ebe5d 100644 --- a/.gitignore +++ b/.gitignore @@ -51,4 +51,4 @@ examples/data # Hydra default output dir multirun -outputs \ No newline at end of file +outputs diff --git a/docs/plots/strided_sum/Strided_sum_fp16.png b/docs/plots/strided_sum/Strided_sum_fp16.png new file mode 100644 index 000000000..e8fd201f6 Binary files /dev/null and b/docs/plots/strided_sum/Strided_sum_fp16.png differ diff --git a/docs/plots/strided_sum/Strided_sum_fp32.png b/docs/plots/strided_sum/Strided_sum_fp32.png new file mode 100644 index 000000000..5e1c1f22a Binary files /dev/null and b/docs/plots/strided_sum/Strided_sum_fp32.png differ diff --git a/tests/test_triton_basics.py b/tests/test_triton_basics.py index e41ab7afa..1376aff45 100644 --- a/tests/test_triton_basics.py +++ b/tests/test_triton_basics.py @@ -4,14 +4,32 @@ # LICENSE file in the root directory of this source tree. +import pytest import torch +SHAPES = [ + (384, 128), + (8 * 384, 128), + (34, 128), + (16, 128), + (16, 512), + (8, 384), + (8, 1024), + (8, 2048), + (8, 4096), + (8, 4096), + (4, 12288), +] + + _triton_available = torch.cuda.is_available() if _triton_available: try: import triton import triton.language as tl + from xformers.triton.sum_strided import sum_2d_dim_0 + except (ImportError, ModuleNotFoundError): _triton_available = False @@ -39,7 +57,7 @@ def k_mean(X, Mean, Var, stride, N, **META): # Compute variance x_mean = tl.sum(x, axis=0) / N x_zm = x - x_mean - x_zm = tl.where(cols < N, x_zm, 0.0) # THIS SHOULD NOT BE NEEDED + x_zm = tl.where(cols < N, x_zm, 0.0) x_var = tl.sum(x_zm * x_zm, axis=0) / N tl.store(Mean + row, x_mean) tl.store(Var + row, x_var) @@ -88,3 +106,33 @@ def test_mean(): assert torch.allclose(mean, t_mean, rtol=1e-1) assert torch.allclose(var, t_var, rtol=1e-1) + + @pytest.mark.parametrize("shape", SHAPES) + @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) + def test_sum_strided(shape, dtype): + torch.random.manual_seed(0) + a = torch.rand(shape, device=torch.device("cuda"), dtype=dtype) + + torch_sum = torch.sum(a, dim=0) + triton_sum = sum_2d_dim_0(a) + assert torch.allclose( + torch_sum, triton_sum, rtol=0.01 + ), f"{torch_sum}\n{triton_sum}" + + def test_sum_strided_asserts(): + torch.random.manual_seed(0) + a = torch.rand((128, 256), device=torch.device("cuda"), dtype=torch.float16) + + with pytest.raises(AssertionError): + # This kernel is not useful in that case, assert to prevent misuse + sum_2d_dim_0(a.transpose(1, 0)) + + a = torch.rand((3, 128, 256), device=torch.device("cuda"), dtype=torch.float16) + with pytest.raises(AssertionError): + # This kernel expects 2D tensors, assert to prevent misuse + sum_2d_dim_0(a) + + a = torch.rand((2, 128), device=torch.device("cuda"), dtype=torch.float16) + with pytest.raises(AssertionError): + # This kernel cannot sum over dimensions < 4 + sum_2d_dim_0(a) diff --git a/tests/test_triton_dropout.py b/tests/test_triton_dropout.py index 441f56de7..bf510ad61 100644 --- a/tests/test_triton_dropout.py +++ b/tests/test_triton_dropout.py @@ -25,7 +25,7 @@ ) _triton_available = False -# Testing odd shapes on purpose +# Testing odd (non-power-of-two for instance) shapes on purpose SHAPES = [ (384, 128), (8, 384, 128), @@ -90,6 +90,10 @@ def test_dropout(shape, amp, bias): == y.shape[1] ) + # Check that the drop probability is about right + drop_p = (y_a.numel() - y_a.count_nonzero()) / y_a.numel() + assert drop_p < 0.55 and drop_p > 0.45 + @pytest.mark.skipif(not _triton_available, reason="Triton is not available") @pytest.mark.skipif( @@ -151,4 +155,4 @@ def test_dropout_parity(shape, amp, bias, activation, p): if bias: assert torch.allclose( torch.norm(b.grad), torch.norm(b_.grad), rtol=0.01 - ), f"{b.grad}\n{b_.grad}" + ), f"{b.grad.norm()}\n{b_.grad.norm()}" diff --git a/xformers/benchmarks/benchmark_triton_dropout.py b/xformers/benchmarks/benchmark_triton_dropout.py index 376f2d70c..aa806e6f3 100644 --- a/xformers/benchmarks/benchmark_triton_dropout.py +++ b/xformers/benchmarks/benchmark_triton_dropout.py @@ -18,8 +18,8 @@ (8, 512, 1024), (4, 1024, 1024), (2, 2048, 2048), - (2, 4096, 4096), (1, 2048, 12288), + (2, 4096, 4096), ] P = 0.1 @@ -105,7 +105,7 @@ def triton_step(x): ) -for activation in [Activation.GeLU, None]: +for activation in [Activation.SquaredReLU, Activation.GeLU, None]: for bw in [True, False]: for bias in [True, False]: bench_dropout(bias, bw, activation) diff --git a/xformers/benchmarks/benchmark_triton_stride_sum.py b/xformers/benchmarks/benchmark_triton_stride_sum.py new file mode 100644 index 000000000..6fb887e78 --- /dev/null +++ b/xformers/benchmarks/benchmark_triton_stride_sum.py @@ -0,0 +1,71 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List + +import torch +import triton + +from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print +from xformers.triton.sum_strided import sum_2d_dim_0 + +SHAPES = [ + (128, 128), + (384, 128), + (784, 512), + (1024, 768), + (2048, 1024), + (4096, 4096), +] + + +def to_gbs(a, ms): + # Read the full array, write the non-reduced dimension + return ((a.numel() + a.shape[1]) * a.element_size() * 1e-9) / (ms * 1e-3) + + +def bench_functions( + test_cases: List[TestCase], shapes, metric_transform, unit, title="" +): + device = torch.device("cuda") + + for dtype in [torch.float16, torch.float32]: + results: Dict[str, Any] = {} + + for M, N in shapes: + a = torch.rand(M, N, device=device, dtype=dtype, requires_grad=True) + + for testcase in test_cases: + time = triton.testing.do_bench(lambda: testcase.function(a))[0] + + metric = metric_transform(a, time) + + key = f"M={M}, N={N}" + if key not in results: + results[key] = {} + + results[key][testcase.name] = f"{metric:.1f}" + + _type = " fp16" if dtype == torch.float16 else " fp32" + + pretty_print( + results, + title=" ------------- Type: {} ------------- ".format(_type), + units=unit, + ) + + pretty_plot(results, title + _type, unit, dash_key="pytorch") + + +bench_functions( + [ + TestCase(lambda x: torch.sum(x, dim=0), "pytorch"), + TestCase(sum_2d_dim_0, "triton"), + ], + SHAPES, + to_gbs, + "GB/s", + "Strided_sum", +) diff --git a/xformers/benchmarks/utils.py b/xformers/benchmarks/utils.py index d8ab2e386..365d010fa 100644 --- a/xformers/benchmarks/utils.py +++ b/xformers/benchmarks/utils.py @@ -28,12 +28,12 @@ def pretty_print(results, title, units): """ Printout the contents of a dict as a human-readable and Markdown compatible array""" print(title) - header = " Units: {:<40}".format(units) - print("|" + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys())) + header = " Units: {:<45}".format(units) + print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys())) offset = len(header) print( - "|{}|".format("-" * offset) + "|-{}|".format("-" * offset) + "".join("{}|".format("-" * 20) for _ in results.keys()) ) @@ -44,7 +44,7 @@ def pretty_print(results, title, units): for k, w in workloads.items(): print( - "|{0:<{offset}}|".format(k, offset=offset) + "| {0:<{offset}}|".format(k, offset=offset) + "".join("{:<20}|".format(v) for v in w) ) @@ -85,7 +85,7 @@ def pretty_plot(results, title, units: str, filename=None, dash_key=""): plt.xticks(rotation=45) plt.savefig(filename, bbox_inches="tight") - plt.clf() + plt.close(f) if _triton_is_available: diff --git a/xformers/components/__init__.py b/xformers/components/__init__.py index 60536ef0d..2bf470714 100644 --- a/xformers/components/__init__.py +++ b/xformers/components/__init__.py @@ -13,7 +13,8 @@ from .activations import Activation, build_activation # noqa from .attention import Attention, build_attention # noqa from .in_proj_container import InProjContainer, InProjParams # noqa -from .multi_head_dispatch import MultiHeadDispatch, MultiHeadDispatchConfig # noqa +from .multi_head_dispatch import MultiHeadDispatch # noqa +from .multi_head_dispatch import MultiHeadDispatchConfig from .residual import LayerNormStyle, PostNorm, PreNorm, Residual # noqa # automatically import any Python files in the directory diff --git a/xformers/triton/dropout.py b/xformers/triton/dropout.py index 30e44ff2d..965629ea9 100644 --- a/xformers/triton/dropout.py +++ b/xformers/triton/dropout.py @@ -25,34 +25,31 @@ class _dropout(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, x, p, bias, activation, activation_grad): + def forward(ctx, x, p, bias, activation, activation_grad, trainable_bias): # Soft-flatten an hypothetical 3rd dimension x_ = x.reshape(-1, x.shape[-1]).contiguous() y = torch.empty_like(x_) - _, N = x_.shape + M, N = x_.shape - assert bias is None or bias.dtype == x.dtype, bias + assert bias is None or (bias.dtype == x.dtype and bias.shape[0] == N) # Generate one seed per sample # seed max is int32 max for positive numbers: 2**16 - seeds = torch.randint(65536, (x_.shape[0],), device=x.device).to(torch.int32) + seeds = torch.randint(65536, (N,), device=x.device).to(torch.int32) - # SPMD launch grid def grid(meta): - return ( - x_.shape[0], - triton.cdiv(x_.shape[1], meta["BLOCK_SIZE"]), - ) + return (triton.cdiv(N, meta["BLOCK_N"]),) # fmt: off k_dropout_fw[grid]( - y, x_, bias if bias is not None else x_, + y, x_, + bias if bias is not None else x_, seeds, y.stride(0), - N, + M, N, p, USE_BIAS=bias is not None, - ACTIVATION=activation + ACTIVATION=activation, ) # fmt: on @@ -60,7 +57,8 @@ def grid(meta): ctx.save_for_backward(seeds, bias, x) else: ctx.save_for_backward(seeds, bias, None) - ctx.trainable_bias = bias is not None + + ctx.trainable_bias = bias is not None and trainable_bias ctx.activation_grad = activation_grad ctx.p = p @@ -75,7 +73,7 @@ def backward(ctx, grad_out): grad_out_ = grad_out.reshape(-1, grad_out.shape[-1]).contiguous() grad_in = torch.empty_like(grad_out_) - _, N = grad_out_.shape + M, N = grad_out_.shape # Optional inputs to compute the activation contribution to the gradient assert inputs is not None or ctx.activation_grad is None @@ -83,32 +81,38 @@ def backward(ctx, grad_out): if inputs is None: inputs = grad_out_ elif inputs.ndim > 2: - inputs = inputs.reshape(-1, grad_out.shape[-1]) + inputs = inputs.reshape(-1, N) + + if ctx.trainable_bias: + grad_bias = torch.empty((N,), device=grad_in.device, dtype=grad_in.dtype) + else: + grad_bias = grad_in # will not be used - # SPMD launch grid def grid(meta): - return ( - grad_out_.shape[0], - triton.cdiv(grad_out_.shape[1], meta["BLOCK_SIZE"]), - ) + return (triton.cdiv(N, meta["BLOCK_N"]),) # fmt: off k_dropout_bw[grid]( - grad_in, grad_out_, inputs, bias if bias is not None else inputs, + grad_in, grad_bias, grad_out_, + inputs, bias if bias is not None else inputs, seeds, grad_out_.stride(0), inputs.stride(0), - N, + M, N, ctx.p, USE_BIAS=bias is not None, - ACTIVATION_GRAD=ctx.activation_grad) + ACTIVATION_GRAD=ctx.activation_grad, + TRAINABLE_BIAS=ctx.trainable_bias + ) # fmt: on - if ctx.trainable_bias: - grad_bias: Optional[torch.Tensor] = torch.sum(grad_in, dim=0) - else: - grad_bias = None - - return grad_in.reshape_as(grad_out), None, grad_bias, None, None + return ( + grad_in.reshape_as(grad_out), + None, + grad_bias if ctx.trainable_bias else None, + None, + None, + None, + ) def dropout( @@ -128,7 +132,14 @@ def dropout( act_kernel = get_triton_activation_kernel(activation) act_grad_kernel = get_triton_activation_bwd_kernel(activation) - return _dropout.apply(x, p, bias, act_kernel, act_grad_kernel) + return _dropout.apply( + x, + p, + bias, + act_kernel, + act_grad_kernel, + bias is not None and bias.requires_grad, + ) class FusedDropoutBias(torch.nn.Module): @@ -141,8 +152,10 @@ def __init__( super().__init__() self.p = p self.activation = activation - self.register_buffer( - "bias", torch.zeros(bias_shape) if bias_shape is not None else None + self.bias = ( + torch.zeros(bias_shape, requires_grad=True) + if bias_shape is not None + else None ) self.activation = get_triton_activation_kernel(activation) self.activation_grad = get_triton_activation_bwd_kernel(activation) @@ -153,4 +166,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.bias = self.bias.to(dtype=x.dtype, device=x.device) # type: ignore p = self.p if self.training else 0.0 - return _dropout.apply(x, p, self.bias, self.activation, self.activation_grad) + return _dropout.apply( + x, p, self.bias, self.activation, self.activation_grad, True + ) diff --git a/xformers/triton/k_activations.py b/xformers/triton/k_activations.py index 0964096d6..31049101c 100644 --- a/xformers/triton/k_activations.py +++ b/xformers/triton/k_activations.py @@ -64,8 +64,7 @@ def relu(x): .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html """ zero = 0.0 - zero = zero.to(x.dtype) - return tl.where(x >= 0, x, zero) + return tl.where(x >= 0, x, zero.to(x.dtype)) @triton.jit @@ -74,10 +73,8 @@ def relu_grad(x): # in that it does not require the input to retrospectively compute its gradient # here the input is the downstream gradient, and we return the upstream gradient directly zero = 0.0 - zero = zero.to(x.dtype) one = 1.0 - one = one.to(x.dtype) - return tl.where(x >= 0, one, zero) + return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype)) @triton.jit @@ -88,7 +85,7 @@ def squared_relu(x): .. _Primer: https://arxiv.org/abs/2109.08668 """ x_ = relu(x) - return x_ * x_ + return (x_ * x_).to(x.dtype) @triton.jit diff --git a/xformers/triton/k_dropout.py b/xformers/triton/k_dropout.py index 61878840f..1fa4a18f7 100644 --- a/xformers/triton/k_dropout.py +++ b/xformers/triton/k_dropout.py @@ -10,138 +10,246 @@ import triton import triton.language as tl -_k_configs = [ - triton.Config({"BLOCK_SIZE": 128}, num_warps=1), - triton.Config({"BLOCK_SIZE": 512}, num_warps=2), - triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), - triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), - triton.Config({"BLOCK_SIZE": 4096}, num_warps=16), +k_configs = [ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 32}), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64}), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}), + triton.Config({"BLOCK_M": 512, "BLOCK_N": 32}), + triton.Config({"BLOCK_M": 512, "BLOCK_N": 64}), + triton.Config({"BLOCK_M": 512, "BLOCK_N": 128}), ] -@triton.jit -def _drop_and_scale(SEEDS, row, p, offsets, x): - # randomly prune the weights - seed = SEEDS + row - random = tl.rand(seed.to(tl.int32), offsets) - x_keep = random > p - - zero = 0.0 - zero = zero.to(x.dtype) - - # prune and normalize in one go - return tl.where(x_keep, (x / (1 - p)).to(x.dtype), zero) - - # fmt: off +@triton.heuristics({"SIZE_BLOCK": lambda *_, **meta: meta["BLOCK_M"]*meta["BLOCK_N"]}) @triton.autotune( - configs=_k_configs, - key=["N"], + configs=k_configs, + key=["M", "N"], ) @triton.jit def k_dropout_fw( Y, X, BIAS, SEEDS, stride, - N, + M, N, p, - **META, + **meta, ): """ Apply dropout on an input tensor - Y : Output (M, N) - X : Input (M, N) - S : Seeds (M,) + Y : Output (M, N) + X : Input (M, N) + BIAS (N,) + SEEDS (M,) p : dropout probability """ # fmt: on - BLOCK_SIZE = META["BLOCK_SIZE"] - row = tl.program_id(axis=0) - col = tl.program_id(axis=1) + BLOCK_M = meta["BLOCK_M"] + BLOCK_N = meta["BLOCK_N"] + SIZE_BLOCK = meta["SIZE_BLOCK"] + + rows = tl.arange(0, BLOCK_M) + col_id = tl.program_id(axis=0) + cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N) + seed = SEEDS + col_id + tiles = tl.cdiv(M, BLOCK_M) + + # pointers starting point + x_ptrs = X + rows[:, None] * stride + cols[None, :] + y_ptrs = Y + rows[:, None] * stride + cols[None, :] + + # go over all the tiles, one by one + rand_offsets = tl.arange(0, SIZE_BLOCK) + rand1, rand2, rand3, rand4 = tl.randint4x(seed.to(tl.int32), rand_offsets) + threshold = ((p - 0.5) * 2147483648.).to(tl.int32) + + # binarize masks, save registers + rand_mask1 = rand1 > threshold + rand_mask2 = rand2 > threshold + rand_mask3 = rand3 > threshold + rand_mask4 = rand4 > threshold + rand_mask = rand_mask1 + + col_mask = cols[None, :] < N + p_scale = 1/(1-p) if p < 1. else 1. + zero = 0.0 + + if meta["USE_BIAS"]: + b_ptrs = BIAS + cols[None, :] + bias = tl.load(b_ptrs, mask=cols[None, :] < N, other=0.) + + i = 0 - # compute memory offsets of elements handled by this instance - offsets = row * stride + col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) < N + for _ in range(tiles): + block_mask = (rows[:, None] < M) & col_mask + x = tl.load(x_ptrs, mask=block_mask, other=0.) - # load data from x - x_ptrs = X + offsets - x = tl.load(x_ptrs, mask=mask) + # optionally apply a fused bias + if meta["USE_BIAS"]: + x += bias - # optionally apply a fused bias - if META["USE_BIAS"]: - b_ptrs = BIAS + col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - b = tl.load(b_ptrs, mask=mask) - x += b + # optional: fused activation (while the data is in shared memory) + if meta["ACTIVATION"]: + x = meta["ACTIVATION"](x) - # optional: fused activation (while the data is in shared memory) - if META["ACTIVATION"]: - x = META["ACTIVATION"](x) + # randomly prune and scale + if p > 0.: + # generate all the random numbers for the block at once, then reshape + keep = tl.reshape(rand_mask, x.shape) - # randomly prune it - if p > 0.: - output = _drop_and_scale(SEEDS, row, p, offsets, x) - else: - output = x + # prune and normalize in one go + output = tl.where(keep, (x * p_scale).to(x.dtype), zero.to(x.dtype)) + else: + output = x - y_ptrs = Y + offsets - tl.store(y_ptrs, output, mask=mask) + tl.store(y_ptrs, output, mask=block_mask) + + # Update the pointers + rows += BLOCK_M # needs to be updated for the mask to be correct + x_ptrs += BLOCK_M * stride + y_ptrs += BLOCK_M * stride + + # update the seed offset + rand_offsets += SIZE_BLOCK + + # cycle through the binary masks + if i == 0: + rand_mask = rand_mask2 + elif i == 1: + rand_mask = rand_mask3 + elif i == 2: + rand_mask = rand_mask4 + else: + rand_mask = rand_mask1 + + i = (i+1) % 4 # fmt: off +@triton.heuristics({"SIZE_BLOCK": lambda *_, **meta: meta["BLOCK_M"]*meta["BLOCK_N"]}) @triton.autotune( - configs=_k_configs, - key=["N"], + configs=k_configs, + key=["M", "N"], ) @triton.jit def k_dropout_bw( - GRAD_IN, GRAD_OUT, INPUTS, BIAS, SEEDS, + GRAD_IN, GRAD_BIAS, GRAD_OUT, + INPUTS, BIAS, SEEDS, stride_grad, stride_inputs, - N, + M, N, p, - **META, + **meta, ): """ Apply dropout on an input tensor GRAD_OUT (M, N) + GRAD_BIAS (N,) GRAD_IN (M, N) BIAS (N,) - SEEDS (M,) + SEEDS (N,) p : dropout probability """ # fmt: on - BLOCK_SIZE = META["BLOCK_SIZE"] - row = tl.program_id(axis=0) - col = tl.program_id(axis=1) - - # compute memory offsets of elements handled by this instance - grad_offsets = row * stride_grad + col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) < N - - # load data from x - grad_out_ptrs = GRAD_OUT + grad_offsets - grad_out = tl.load(grad_out_ptrs, mask=mask) - - # optional: fused activation (while the data is in shared memory) - if META["ACTIVATION_GRAD"]: - input_ptrs = INPUTS + row * stride_inputs + col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - inputs = tl.load(input_ptrs, mask=mask) - - # optionally apply a fused bias - if META["USE_BIAS"]: - b_ptrs = BIAS + col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - b = tl.load(b_ptrs, mask=mask) - inputs += b - - act_grad = META["ACTIVATION_GRAD"](inputs) - grad_out *= act_grad - - # randomly prune it - if p > 0.: - output = _drop_and_scale(SEEDS, row, p, grad_offsets, grad_out) - else: - output = grad_out - - # write-back - y_ptrs = GRAD_IN + grad_offsets - tl.store(y_ptrs, output, mask=mask) + BLOCK_M = meta["BLOCK_M"] + BLOCK_N = meta["BLOCK_N"] + SIZE_BLOCK = meta["SIZE_BLOCK"] + TRAINABLE_BIAS = meta["TRAINABLE_BIAS"] + + rows = tl.arange(0, BLOCK_M) + col_id = tl.program_id(axis=0) + cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N) + seed = SEEDS + col_id + tiles = tl.cdiv(M, BLOCK_M) + + # pointers starting point + grad_out_ptrs = GRAD_OUT + rows[:, None] * stride_grad + cols[None, :] + grad_in_ptrs = GRAD_IN + rows[:, None] * stride_grad + cols[None, :] + input_ptrs = INPUTS + rows[:, None] * stride_inputs + cols[None, :] + + # random binary masks, save registers + rand_offsets = tl.arange(0, SIZE_BLOCK) + rand1, rand2, rand3, rand4 = tl.randint4x(seed.to(tl.int32), rand_offsets) + threshold = ((p - 0.5) * 2147483648.).to(tl.int32) + + rand_mask1 = rand1 > threshold + rand_mask2 = rand2 > threshold + rand_mask3 = rand3 > threshold + rand_mask4 = rand4 > threshold + rand_mask = rand_mask1 + + # now go over the tiles + grad_bias = tl.zeros((BLOCK_N,), dtype=tl.float32) + col_mask = cols[None, :] < N + zero = 0.0 + p_scale = 1/(1-p) if p < 1. else 1. + + if meta["USE_BIAS"]: + b_ptrs = BIAS + cols[None, :] + bias = tl.load(b_ptrs, mask=col_mask, other=0.) + + i = 0 + + for _ in range(tiles): + block_mask = (rows[:, None] < M) & col_mask + grad_out = tl.load(grad_out_ptrs, mask=block_mask, other=0.) + + # optional: fused activation (while the data is in shared memory) + if meta["ACTIVATION_GRAD"]: + inputs = tl.load(input_ptrs, mask=block_mask, other=0.) + + # optionally apply a fused bias + if meta["USE_BIAS"]: + inputs += bias + + act_grad = meta["ACTIVATION_GRAD"](inputs).to(grad_out.dtype) + grad_out *= act_grad + + # randomly prune and scale + if p > 0.: + # generate all the random numbers for the block at once, then reshape + keep = tl.reshape(rand_mask, grad_out.shape) + + # prune and normalize in one go + output = tl.where( + keep, + (grad_out * p_scale).to(grad_out.dtype), + zero.to(grad_out.dtype) + ) + else: + output = grad_out + + # write-back + tl.store(grad_in_ptrs, output, mask=block_mask) + + # optionally accumulate the bias gradient + if TRAINABLE_BIAS: + grad_bias += tl.sum(output, axis=0) + + # Update the pointers + rows += BLOCK_M # needs to be updated for the mask to be correct + grad_out_ptrs += BLOCK_M * stride_grad + input_ptrs += BLOCK_M * stride_inputs + grad_in_ptrs += BLOCK_M * stride_grad + + # update the seed offset + rand_offsets += SIZE_BLOCK + + # cycle through the binary masks + if i == 0: + rand_mask = rand_mask2 + elif i == 1: + rand_mask = rand_mask3 + elif i == 2: + rand_mask = rand_mask4 + else: + rand_mask = rand_mask1 + + i = (i+1) % 4 + + if TRAINABLE_BIAS: + grad_bias_ptr = GRAD_BIAS + cols + tl.store(grad_bias_ptr, grad_bias, mask=cols < N) diff --git a/xformers/triton/k_fused_matmul_bw.py b/xformers/triton/k_fused_matmul_bw.py index 49d25f3c3..bfe322924 100644 --- a/xformers/triton/k_fused_matmul_bw.py +++ b/xformers/triton/k_fused_matmul_bw.py @@ -10,6 +10,8 @@ import triton import triton.language as tl +from xformers.triton.sum_strided import sum_2d_dim_0 + # fmt: off @triton.heuristics({ @@ -144,6 +146,6 @@ def grid(META): # The following ops can also be handled by triton grad_in = grad_out_ @ weight grad_weight = grad_out_.transpose(1, 0) @ inputs_ if trainable_weight else None - grad_bias = torch.sum(grad_out_, 0) if trainable_bias else None + grad_bias = sum_2d_dim_0(grad_out_) if trainable_bias else None return grad_in.reshape_as(inputs), grad_weight, grad_bias diff --git a/xformers/triton/k_sum.py b/xformers/triton/k_sum.py new file mode 100644 index 000000000..998d6f04d --- /dev/null +++ b/xformers/triton/k_sum.py @@ -0,0 +1,66 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import triton +import triton.language as tl + + +# fmt: off +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 32, "BLOCK_N": 16}, num_stages=5, num_warps=1), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 16}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 16}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 16}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 512, "BLOCK_N": 16}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 1024, "BLOCK_N": 8}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 2048, "BLOCK_N": 8}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 4096, "BLOCK_N": 8}, num_stages=4, num_warps=2), + ], + key=["M", "N", "is_fp16"], +) +@triton.jit +def k_sum_0( + Y, X, + stride_xm, + M, N, + is_fp16, + **meta, +): + # fmt: om + + """ + Sum a 2d tensor over the first (strided) dimension. + This extracts some speed through a parallel sum across the second dimension + """ + BLOCK_M = meta["BLOCK_M"] + BLOCK_N = meta["BLOCK_N"] + + # partial row indices. We'll reduce over this dimension + m = tl.arange(0, BLOCK_M) + + # To get some extra parallelization, we handle several columns in the same thread block + rn = tl.program_id(axis=0) * BLOCK_N + tl.arange(0, BLOCK_N) + + # the memory address of all the elements that we want to load can be computed as follows + x_ptrs = X + m[:, None] * stride_xm + rn[None, :] + x_sum = tl.zeros((BLOCK_N,), dtype=tl.float32) + + tiles = M // BLOCK_M + if M % BLOCK_M > 0: + tiles += 1 + + for _ in range(tiles): + # load input data; pad out-of-bounds elements with 0 + # NOTE: make sure to accumulate in fp32 to prevent a trivial overflow + mask = (m[:, None] < M) & (rn[None, :] < N) + x = tl.load(x_ptrs, mask=mask, other=0.0) + x_sum += tl.sum(x, 0) + + # move the load pointer + x_ptrs += BLOCK_M * stride_xm + m += BLOCK_M # update the mask check + + tl.store(Y + rn, x_sum, mask=rn < N) diff --git a/xformers/triton/sum_strided.py b/xformers/triton/sum_strided.py new file mode 100644 index 000000000..8d15bd5e7 --- /dev/null +++ b/xformers/triton/sum_strided.py @@ -0,0 +1,48 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import triton + +from xformers.triton.k_sum import k_sum_0 + + +def sum_2d_dim_0(x: torch.Tensor): + """ + Sum a 2D tensor across the first dimension + """ + + out = torch.empty(x.shape[1], device=x.device, dtype=x.dtype) + + assert ( + x.ndim == 2 + ), "This is a very specific kernel, only for 2-dim tensors and summing along dim 0" + + assert ( + x.shape[0] >= 4 + ), "This is a very specific kernel, requires the reduction dimension to be bigger than 4" + + assert x.stride(1) == 1, ( + "We're expecting x to be contiguous along dim 1, and non contiguous along dim 0.\n" + " You would probably be better served with torch.sum()" + ) + + # Manually handle the scheduling + M, N = x.shape + + def grid(meta): + return (triton.cdiv(N, meta["BLOCK_N"]),) + + # fmt: off + k_sum_0[grid]( + out, x, + x.stride(0), + M, N, + x.dtype == torch.float16, + ) + # fmt: on + + return out