diff --git a/tests/test_triton_dropout.py b/tests/test_triton_dropout.py index 74e0bc73b..9667be017 100644 --- a/tests/test_triton_dropout.py +++ b/tests/test_triton_dropout.py @@ -97,6 +97,10 @@ def test_dropout(shape, amp, bias): == y.shape[1] ) + # Check that the drop probability is about right + drop_p = (y_a.numel() - y_a.count_nonzero()) / y_a.numel() + assert drop_p < 0.55 and drop_p > 0.45 + @pytest.mark.skipif(not _triton_available, reason="Triton is not available") @pytest.mark.skipif( @@ -107,7 +111,7 @@ def test_dropout(shape, amp, bias): @pytest.mark.parametrize("amp", [False, True]) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("activation", [a.value for a in Activation]) -@pytest.mark.parametrize("p", [0, 0.001, 0.5]) +@pytest.mark.parametrize("p", [0, 0.01, 0.5]) def test_dropout_parity(shape, amp, bias, activation, p): """ Check some basic dropout properties @@ -158,4 +162,4 @@ def test_dropout_parity(shape, amp, bias, activation, p): if bias: assert torch.allclose( torch.norm(b.grad), torch.norm(b_.grad), rtol=0.01 - ), f"{b.grad.norm()}\n{b_.grad.norm()}" + ), f"{b.grad.norm()} - {b_.grad.norm()}" diff --git a/xformers/benchmarks/benchmark_triton_dropout.py b/xformers/benchmarks/benchmark_triton_dropout.py index 8aa05a33c..d13e3fe81 100644 --- a/xformers/benchmarks/benchmark_triton_dropout.py +++ b/xformers/benchmarks/benchmark_triton_dropout.py @@ -62,12 +62,14 @@ def torch_step(x): y = torch_act(y) if backward: + y.grad = None torch.norm(y).backward() return y def triton_step(x): y = triton_dropout(x) if backward: + y.grad = None torch.norm(y).backward() return y @@ -105,7 +107,7 @@ def triton_step(x): ) -for activation in [Activation.GeLU, None]: +for activation in [Activation.GeLU, None, Activation.SquaredReLU]: for bw in [True, False]: for bias in [True, False]: bench_dropout(bias, bw, activation) diff --git a/xformers/benchmarks/utils.py b/xformers/benchmarks/utils.py index d8ab2e386..365d010fa 100644 --- a/xformers/benchmarks/utils.py +++ b/xformers/benchmarks/utils.py @@ -28,12 +28,12 @@ def pretty_print(results, title, units): """ Printout the contents of a dict as a human-readable and Markdown compatible array""" print(title) - header = " Units: {:<40}".format(units) - print("|" + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys())) + header = " Units: {:<45}".format(units) + print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys())) offset = len(header) print( - "|{}|".format("-" * offset) + "|-{}|".format("-" * offset) + "".join("{}|".format("-" * 20) for _ in results.keys()) ) @@ -44,7 +44,7 @@ def pretty_print(results, title, units): for k, w in workloads.items(): print( - "|{0:<{offset}}|".format(k, offset=offset) + "| {0:<{offset}}|".format(k, offset=offset) + "".join("{:<20}|".format(v) for v in w) ) @@ -85,7 +85,7 @@ def pretty_plot(results, title, units: str, filename=None, dash_key=""): plt.xticks(rotation=45) plt.savefig(filename, bbox_inches="tight") - plt.clf() + plt.close(f) if _triton_is_available: diff --git a/xformers/triton/dropout.py b/xformers/triton/dropout.py index ec4cb09bd..ecda101f7 100644 --- a/xformers/triton/dropout.py +++ b/xformers/triton/dropout.py @@ -26,34 +26,39 @@ class _dropout(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx, x, p, bias, activation, activation_grad): + def forward(ctx, x, p, bias, activation, activation_grad, trainable_bias): # Soft-flatten an hypothetical 3rd dimension x_ = x.reshape(-1, 
x.shape[-1]).contiguous() y = torch.empty_like(x_) - _, N = x_.shape + M, N = x_.shape - assert bias is None or bias.dtype == x.dtype, bias + assert bias is None or (bias.dtype == x.dtype and bias.shape[0] == N) # Generate one seed per sample # seed max is int32 max for positive numbers: 2**16 - seeds = torch.randint(65536, (x_.shape[0],), device=x.device).to(torch.int32) + # FIXME: adjust the number of seeds needed + seeds = torch.randint(65536, (N,), device=x.device).to(torch.int32) - # SPMD launch grid def grid(meta): return ( - x_.shape[0], - triton.cdiv(x_.shape[1], meta["BLOCK_SIZE"]), + triton.cdiv(M, meta["BLOCK_M"] * 4), + triton.cdiv(N, meta["BLOCK_N"]), ) + GROUP_M = 16 + BLOCK_M = GROUP_M // 4 + # fmt: off k_dropout_fw[grid]( - y, x_, bias if bias is not None else x_, + y, x_, + bias if bias is not None else x_, seeds, y.stride(0), - N, + M, N, p, USE_BIAS=bias is not None, - ACTIVATION=activation + ACTIVATION=activation, + BLOCK_M=BLOCK_M ) # fmt: on @@ -61,7 +66,8 @@ def grid(meta): ctx.save_for_backward(seeds, bias, x) else: ctx.save_for_backward(seeds, bias, None) - ctx.trainable_bias = bias is not None + + ctx.trainable_bias = bias is not None and trainable_bias ctx.activation_grad = activation_grad ctx.p = p @@ -76,7 +82,7 @@ def backward(ctx, grad_out): grad_out_ = grad_out.reshape(-1, grad_out.shape[-1]).contiguous() grad_in = torch.empty_like(grad_out_) - _, N = grad_out_.shape + M, N = grad_out_.shape # Optional inputs to compute the activation contribution to the gradient assert inputs is not None or ctx.activation_grad is None @@ -84,32 +90,53 @@ def backward(ctx, grad_out): if inputs is None: inputs = grad_out_ elif inputs.ndim > 2: - inputs = inputs.reshape(-1, grad_out.shape[-1]) + inputs = inputs.reshape(-1, N) + + GROUP_M = 16 if M > 512 else M // 4 + BLOCK_M = GROUP_M // 4 + N_BLOCKS_M = triton.cdiv(M, GROUP_M) + + if ctx.trainable_bias: + grad_bias = torch.empty( + ( + N_BLOCKS_M, + N, + ), + device=grad_in.device, + dtype=grad_in.dtype, + ) + else: + grad_bias = grad_in # will not be used - # SPMD launch grid def grid(meta): return ( - grad_out_.shape[0], - triton.cdiv(grad_out_.shape[1], meta["BLOCK_SIZE"]), + triton.cdiv(M, meta["BLOCK_M"] * 4), + triton.cdiv(N, meta["BLOCK_N"]), ) # fmt: off k_dropout_bw[grid]( - grad_in, grad_out_, inputs, bias if bias is not None else inputs, + grad_in, grad_bias, grad_out_, + inputs, bias if bias is not None else inputs, seeds, grad_out_.stride(0), inputs.stride(0), - N, + M, N, ctx.p, USE_BIAS=bias is not None, - ACTIVATION_GRAD=ctx.activation_grad) + ACTIVATION_GRAD=ctx.activation_grad, + TRAINABLE_BIAS=ctx.trainable_bias, + BLOCK_M=BLOCK_M + ) # fmt: on - if ctx.trainable_bias: - grad_bias: Optional[torch.Tensor] = sum_2d_dim_0(grad_in) - else: - grad_bias = None - - return grad_in.reshape_as(grad_out), None, grad_bias, None, None + return ( + grad_in.reshape_as(grad_out), + None, + sum_2d_dim_0(grad_bias) if ctx.trainable_bias else None, + None, + None, + None, + ) def dropout( @@ -129,7 +156,14 @@ def dropout( act_kernel = get_triton_activation_kernel(activation) act_grad_kernel = get_triton_activation_bwd_kernel(activation) - return _dropout.apply(x, p, bias, act_kernel, act_grad_kernel) + return _dropout.apply( + x, + p, + bias, + act_kernel, + act_grad_kernel, + bias is not None and bias.requires_grad, + ) class FusedDropoutBias(torch.nn.Module): @@ -142,8 +176,10 @@ def __init__( super().__init__() self.p = p self.activation_type = activation - self.register_buffer( - "bias", torch.zeros(bias_shape) if 
bias_shape is not None else None + self.bias = ( + torch.zeros(bias_shape, requires_grad=True) + if bias_shape is not None + else None ) self.activation = get_triton_activation_kernel(activation) self.activation_grad = get_triton_activation_bwd_kernel(activation) @@ -160,5 +196,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = activation(x) return torch.nn.functional.dropout(x, self.p) + # The normal, Triton-backed path p = self.p if self.training else 0.0 - return _dropout.apply(x, p, self.bias, self.activation, self.activation_grad) + return _dropout.apply( + x, p, self.bias, self.activation, self.activation_grad, True + ) diff --git a/xformers/triton/k_activations.py b/xformers/triton/k_activations.py index 0964096d6..31049101c 100644 --- a/xformers/triton/k_activations.py +++ b/xformers/triton/k_activations.py @@ -64,8 +64,7 @@ def relu(x): .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html """ zero = 0.0 - zero = zero.to(x.dtype) - return tl.where(x >= 0, x, zero) + return tl.where(x >= 0, x, zero.to(x.dtype)) @triton.jit @@ -74,10 +73,8 @@ def relu_grad(x): # in that it does not require the input to retrospectively compute its gradient # here the input is the downstream gradient, and we return the upstream gradient directly zero = 0.0 - zero = zero.to(x.dtype) one = 1.0 - one = one.to(x.dtype) - return tl.where(x >= 0, one, zero) + return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype)) @triton.jit @@ -88,7 +85,7 @@ def squared_relu(x): .. _Primer: https://arxiv.org/abs/2109.08668 """ x_ = relu(x) - return x_ * x_ + return (x_ * x_).to(x.dtype) @triton.jit diff --git a/xformers/triton/k_dropout.py b/xformers/triton/k_dropout.py index 61878840f..a0a737050 100644 --- a/xformers/triton/k_dropout.py +++ b/xformers/triton/k_dropout.py @@ -10,138 +10,244 @@ import triton import triton.language as tl -_k_configs = [ - triton.Config({"BLOCK_SIZE": 128}, num_warps=1), - triton.Config({"BLOCK_SIZE": 512}, num_warps=2), - triton.Config({"BLOCK_SIZE": 1024}, num_warps=4), - triton.Config({"BLOCK_SIZE": 2048}, num_warps=8), - triton.Config({"BLOCK_SIZE": 4096}, num_warps=16), -] - - -@triton.jit -def _drop_and_scale(SEEDS, row, p, offsets, x): - # randomly prune the weights - seed = SEEDS + row - random = tl.rand(seed.to(tl.int32), offsets) - x_keep = random > p - - zero = 0.0 - zero = zero.to(x.dtype) - - # prune and normalize in one go - return tl.where(x_keep, (x / (1 - p)).to(x.dtype), zero) - # fmt: off +@triton.heuristics({"SIZE_RAND_BLOCK": lambda *_, **meta: meta["BLOCK_N"] * meta["BLOCK_M"]}) @triton.autotune( - configs=_k_configs, - key=["N"], + configs=[ + triton.Config({"BLOCK_N": 32}, num_warps=1), + triton.Config({"BLOCK_N": 64}, num_warps=2), + triton.Config({"BLOCK_N": 128}, num_warps=4), + triton.Config({"BLOCK_N": 256}, num_warps=8), + ], + key=["M", "N"], ) @triton.jit def k_dropout_fw( Y, X, BIAS, SEEDS, stride, - N, + M, N, p, - **META, + **meta, ): """ Apply dropout on an input tensor - Y : Output (M, N) - X : Input (M, N) - S : Seeds (M,) + Y : Output (M, N) + X : Input (M, N) + BIAS (N,) + SEEDS (M,) p : dropout probability + + This kernel goes through the tensor columns (N dimension), per block (to keep memory parallelism). + This allows the backward pass to follow the same path, with the same seeds, + and start reducing on the gradient bias. 
""" # fmt: on - BLOCK_SIZE = META["BLOCK_SIZE"] - row = tl.program_id(axis=0) - col = tl.program_id(axis=1) + BLOCK_M = meta["BLOCK_M"] + BLOCK_N = meta["BLOCK_N"] + SIZE_RAND_BLOCK = meta["SIZE_RAND_BLOCK"] + + row_id = tl.program_id(axis=0) + rows = row_id * BLOCK_M * 4 + tl.arange(0, BLOCK_M) + + col_id = tl.program_id(axis=1) + cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N) + seed = SEEDS + col_id # FIXME index the seed properly + + # pointers starting point + x_ptrs = X + rows[:, None] * stride + cols[None, :] + y_ptrs = Y + rows[:, None] * stride + cols[None, :] + + # go over all the tiles, one by one + rand_offsets = tl.arange(0, SIZE_RAND_BLOCK) + row_id * BLOCK_M * 4 + rand1, rand2, rand3, rand4 = tl.randint4x(seed.to(tl.int32), rand_offsets) + threshold = ((p - 0.5) * 2147483648.).to(tl.int32) - # compute memory offsets of elements handled by this instance - offsets = row * stride + col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) < N + # binarize masks, save registers + rand_mask1 = rand1 > threshold + rand_mask2 = rand2 > threshold + rand_mask3 = rand3 > threshold + rand_mask4 = rand4 > threshold - # load data from x - x_ptrs = X + offsets - x = tl.load(x_ptrs, mask=mask) + col_mask = cols[None, :] < N + p_scale = 1 / (1 - p) if p < 1. else 1. + zero = 0.0 + + if meta["USE_BIAS"]: + b_ptrs = BIAS + cols[None, :] + bias = tl.load(b_ptrs, mask=cols[None, :] < N, other=0.) + + for i in range(4): + # cycle through the binary masks (workaround / no indexing) + if i == 0: + rand_mask = rand_mask1 + elif i == 1: + rand_mask = rand_mask2 + elif i == 2: + rand_mask = rand_mask3 + else: + rand_mask = rand_mask4 + + block_mask = (rows[:, None] < M) & col_mask + x = tl.load(x_ptrs, mask=block_mask, other=0.) 
+ + # optionally apply a fused bias + if meta["USE_BIAS"]: + x += bias + + # optional: fused activation (while the data is in shared memory) + if meta["ACTIVATION"]: + x = meta["ACTIVATION"](x) - # optionally apply a fused bias - if META["USE_BIAS"]: - b_ptrs = BIAS + col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - b = tl.load(b_ptrs, mask=mask) - x += b + # randomly prune and scale + if p > 0.: + # generate all the random numbers for the block at once, then reshape + keep = tl.reshape(rand_mask, x.shape) - # optional: fused activation (while the data is in shared memory) - if META["ACTIVATION"]: - x = META["ACTIVATION"](x) + # prune and normalize in one go + output = tl.where(keep, (x * p_scale).to(x.dtype), zero.to(x.dtype)) + else: + output = x - # randomly prune it - if p > 0.: - output = _drop_and_scale(SEEDS, row, p, offsets, x) - else: - output = x + tl.store(y_ptrs, output, mask=block_mask) - y_ptrs = Y + offsets - tl.store(y_ptrs, output, mask=mask) + # Update the pointers + rows += BLOCK_M # needs to be updated for the mask to be correct + x_ptrs += BLOCK_M * stride + y_ptrs += BLOCK_M * stride # fmt: off +@triton.heuristics({"SIZE_RAND_BLOCK": lambda *_, **meta: meta["BLOCK_N"] * meta["BLOCK_M"]}) @triton.autotune( - configs=_k_configs, - key=["N"], + configs=[ + triton.Config({"BLOCK_N": 32}, num_warps=1), + triton.Config({"BLOCK_N": 64}, num_warps=2), + ], + key=["M", "N"], ) @triton.jit def k_dropout_bw( - GRAD_IN, GRAD_OUT, INPUTS, BIAS, SEEDS, + GRAD_IN, GRAD_BIAS, GRAD_OUT, + INPUTS, BIAS, SEEDS, stride_grad, stride_inputs, - N, + M, N, p, - **META, + **meta, ): """ Apply dropout on an input tensor GRAD_OUT (M, N) + GRAD_BIAS (N,) GRAD_IN (M, N) BIAS (N,) - SEEDS (M,) + SEEDS (N,) p : dropout probability """ # fmt: on - BLOCK_SIZE = META["BLOCK_SIZE"] - row = tl.program_id(axis=0) - col = tl.program_id(axis=1) - - # compute memory offsets of elements handled by this instance - grad_offsets = row * stride_grad + col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) < N - - # load data from x - grad_out_ptrs = GRAD_OUT + grad_offsets - grad_out = tl.load(grad_out_ptrs, mask=mask) - - # optional: fused activation (while the data is in shared memory) - if META["ACTIVATION_GRAD"]: - input_ptrs = INPUTS + row * stride_inputs + col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - inputs = tl.load(input_ptrs, mask=mask) - - # optionally apply a fused bias - if META["USE_BIAS"]: - b_ptrs = BIAS + col * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - b = tl.load(b_ptrs, mask=mask) - inputs += b - - act_grad = META["ACTIVATION_GRAD"](inputs) - grad_out *= act_grad - - # randomly prune it - if p > 0.: - output = _drop_and_scale(SEEDS, row, p, grad_offsets, grad_out) - else: - output = grad_out - - # write-back - y_ptrs = GRAD_IN + grad_offsets - tl.store(y_ptrs, output, mask=mask) + BLOCK_M = meta["BLOCK_M"] + BLOCK_N = meta["BLOCK_N"] + SIZE_RAND_BLOCK = meta["SIZE_RAND_BLOCK"] + TRAINABLE_BIAS = meta["TRAINABLE_BIAS"] + + rows = tl.arange(0, BLOCK_M) + row_id = tl.program_id(axis=0) + rows = row_id * BLOCK_M * 4 + tl.arange(0, BLOCK_M) + + col_id = tl.program_id(axis=1) + cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N) + seed = SEEDS + col_id # FIXME index the seed properly + + # pointers starting point + grad_out_ptrs = GRAD_OUT + rows[:, None] * stride_grad + cols[None, :] + grad_in_ptrs = GRAD_IN + rows[:, None] * stride_grad + cols[None, :] + input_ptrs = INPUTS + rows[:, None] * stride_inputs + cols[None, :] + + # random binary masks, save registers + 
rand_offsets = tl.arange(0, SIZE_RAND_BLOCK) + row_id * BLOCK_M * 4 + rand1, rand2, rand3, rand4 = tl.randint4x(seed.to(tl.int32), rand_offsets) + threshold = ((p - 0.5) * 2147483648.).to(tl.int32) + + rand_mask1 = rand1 > threshold + rand_mask2 = rand2 > threshold + rand_mask3 = rand3 > threshold + rand_mask4 = rand4 > threshold + rand_mask = rand_mask1 + + # now go over the tiles + grad_bias = tl.zeros((BLOCK_N,), dtype=tl.float32) + col_mask = cols[None, :] < N + zero = 0.0 + p_scale = 1 / (1 - p) if p < 1. else 1. + + if meta["USE_BIAS"]: + b_ptrs = BIAS + cols[None, :] + bias = tl.load(b_ptrs, mask=col_mask, other=0.) + + for i in range(4): + # cycle through the binary masks (workaround / no indexing) + if i == 0: + rand_mask = rand_mask1 + elif i == 1: + rand_mask = rand_mask2 + elif i == 2: + rand_mask = rand_mask3 + else: + rand_mask = rand_mask4 + + block_mask = (rows[:, None] < M) & col_mask + grad_out = tl.load(grad_out_ptrs, mask=block_mask, other=0.) + + # optional: fused activation (while the data is in shared memory) + if meta["ACTIVATION_GRAD"]: + inputs = tl.load(input_ptrs, mask=block_mask, other=0.) + + # optionally apply a fused bias + if meta["USE_BIAS"]: + inputs += bias + + act_grad = meta["ACTIVATION_GRAD"](inputs).to(grad_out.dtype) + grad_out *= act_grad + + # randomly prune and scale + if p > 0.: + # generate all the random numbers for the block at once, then reshape + keep = tl.reshape(rand_mask, grad_out.shape) + + # prune and normalize in one go + output = tl.where( + keep, + (grad_out * p_scale).to(grad_out.dtype), + zero.to(grad_out.dtype) + ) + else: + output = grad_out + + # write-back + tl.store(grad_in_ptrs, output, mask=block_mask) + + # optionally accumulate the bias gradient + if TRAINABLE_BIAS: + grad_bias += tl.sum(output, axis=0) + + # Update the pointers + rows += BLOCK_M # needs to be updated for the mask to be correct + grad_out_ptrs += BLOCK_M * stride_grad + input_ptrs += BLOCK_M * stride_inputs + grad_in_ptrs += BLOCK_M * stride_grad + + # cycle through the binary masks + if i == 0: + rand_mask = rand_mask2 + elif i == 1: + rand_mask = rand_mask3 + elif i == 2: + rand_mask = rand_mask4 + else: + rand_mask = rand_mask1 + + if TRAINABLE_BIAS: + grad_bias_ptr = GRAD_BIAS + row_id * N + cols + tl.store(grad_bias_ptr, grad_bias, mask=cols < N)
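
Usage sketch (illustrative, not part of the patch): a minimal example of how the fused bias + activation + dropout path changed above could be exercised end to end. It assumes a CUDA device with Triton installed; the import paths and the `p` / `bias` / `activation` keyword names follow the hunks in this diff, but the full `dropout()` signature is elided by the diff context, so treat them as assumptions rather than the exact API.

# Illustrative sketch only -- not part of the diff above. Assumes xformers is
# installed with Triton support and a CUDA device is available; import paths
# and keyword names are inferred from the hunks in this patch.
import torch

from xformers.components import Activation
from xformers.triton import dropout

x = torch.randn(8, 256, 512, device="cuda", requires_grad=True)
bias = torch.zeros(512, device="cuda", requires_grad=True)

# Fused bias + activation + dropout, routed through _dropout.apply(...) above
y = dropout(x, p=0.5, bias=bias, activation=Activation.GeLU)

# Rough check on the effective drop rate, mirroring the new unit test
drop_p = (y.numel() - y.count_nonzero()) / y.numel()
assert 0.45 < drop_p < 0.55

# Because bias.requires_grad is True, the backward kernel also accumulates the
# bias gradient (the new trainable_bias / GRAD_BIAS path) instead of skipping it
torch.norm(y).backward()
assert x.grad is not None and bias.grad is not None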