From c7ab5b07787837c66f92b5c3b11bc58657f57238 Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Thu, 9 Dec 2021 21:56:14 -0800
Subject: [PATCH] yet another take, partial sum

---
 xformers/triton/dropout.py   | 26 ++++++++++++++++++++------
 xformers/triton/k_dropout.py | 34 +++++-----------------------------
 2 files changed, 25 insertions(+), 35 deletions(-)

diff --git a/xformers/triton/dropout.py b/xformers/triton/dropout.py
index 54799be51..d72901aab 100644
--- a/xformers/triton/dropout.py
+++ b/xformers/triton/dropout.py
@@ -44,6 +44,9 @@ def grid(meta):
                 triton.cdiv(N, meta["BLOCK_N"]),
             )

+        GROUP_M = 128
+        BLOCK_M = GROUP_M // 4
+
         # fmt: off
         k_dropout_fw[grid](
             y, x_,
@@ -53,7 +56,8 @@ def grid(meta):
             M, N,
             p,
             USE_BIAS=bias is not None,
-            ACTIVATION=activation
+            ACTIVATION=activation,
+            BLOCK_M=BLOCK_M
         )
         # fmt: on

@@ -87,12 +91,21 @@ def backward(ctx, grad_out):
         elif inputs.ndim > 2:
             inputs = inputs.reshape(-1, N)

+        GROUP_M = 256
+        BLOCK_M = GROUP_M // 4
+        N_BLOCKS_M = triton.cdiv(M, GROUP_M)
+
         if ctx.trainable_bias:
-            grad_bias = torch.empty((N,), device=grad_in.device, dtype=grad_in.dtype)
-            locks = torch.zeros(N // 2, dtype=torch.int32, device=grad_in.device)
+            grad_bias = torch.empty(
+                (
+                    N_BLOCKS_M,
+                    N,
+                ),
+                device=grad_in.device,
+                dtype=grad_in.dtype,
+            )
         else:
             grad_bias = grad_in  # will not be used
-            locks = grad_in

         def grid(meta):
             return (
@@ -104,13 +117,14 @@ def grid(meta):
         k_dropout_bw[grid](
             grad_in, grad_bias, grad_out_, inputs,
             bias if bias is not None else inputs,
-            seeds, locks,
+            seeds,
             grad_out_.stride(0), inputs.stride(0),
             M, N,
             ctx.p,
             USE_BIAS=bias is not None,
             ACTIVATION_GRAD=ctx.activation_grad,
-            TRAINABLE_BIAS=ctx.trainable_bias
+            TRAINABLE_BIAS=ctx.trainable_bias,
+            BLOCK_M=BLOCK_M
         )
         # fmt: on

diff --git a/xformers/triton/k_dropout.py b/xformers/triton/k_dropout.py
index 58589f0a1..13cff1911 100644
--- a/xformers/triton/k_dropout.py
+++ b/xformers/triton/k_dropout.py
@@ -12,15 +12,9 @@

 # WARNING: For now, the number of threads must be the same as the N buffer, and warps have to be 4 (will be fixed)
 k_configs = [
-    triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=1),
-    triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=1),
-    triton.Config({"BLOCK_M": 256, "BLOCK_N": 32}, num_warps=1),
-    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=2),
-    triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=2),
-    triton.Config({"BLOCK_M": 256, "BLOCK_N": 64}, num_warps=2),
-    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4),
-    triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4),
-    triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=4),
+    triton.Config({"BLOCK_N": 32}, num_warps=1),
+    triton.Config({"BLOCK_N": 64}, num_warps=2),
+    triton.Config({"BLOCK_N": 128}, num_warps=4),
 ]


@@ -131,7 +125,7 @@ def k_dropout_fw(
 @triton.jit
 def k_dropout_bw(
     GRAD_IN, GRAD_BIAS, GRAD_OUT,
-    INPUTS, BIAS, SEEDS, LOCKS,
+    INPUTS, BIAS, SEEDS,
     stride_grad, stride_inputs,
     M, N,
     p,
@@ -250,23 +244,5 @@ def k_dropout_bw(
         rand_mask = rand_mask1

     if TRAINABLE_BIAS:
-        lock_ptr = LOCKS + 2 * col_id
-        count_ptr = LOCKS + 2 * col_id + 1
-        grad_bias_ptr = GRAD_BIAS + cols
-
-        # Uniquely taking a lock over the col results
-        while tl.atomic_cas(lock_ptr, 0, 1) == 1:
-            pass
-
-        count = tl.load(count_ptr)
-        if count == 0:
-            # first store doesn't accumulate
-            tl.atomic_xchg(count_ptr, 1)
-        else:
-            # read and add back
-            grad_bias += tl.load(grad_bias_ptr, mask=cols < N)
-
+        grad_bias_ptr = GRAD_BIAS + row_id * N + cols
         tl.store(grad_bias_ptr, grad_bias, mask=cols < N)
-
-        # release lock
-        tl.atomic_xchg(lock_ptr, 0)
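
Note on the partial-sum scheme (not shown in the patch itself): with the locks gone, each row block of k_dropout_bw writes its own partial bias-gradient row into the (N_BLOCKS_M, N) buffer allocated in backward(), so the final (N,) bias gradient still needs a reduction over the block dimension on the host side. A minimal sketch of that reduction, with hypothetical sizes and assuming it is done in plain PyTorch after the kernel completes:

    import torch

    # Hypothetical sizes, for illustration only
    N_BLOCKS_M, N = 8, 1024

    # Stand-in for the buffer filled by k_dropout_bw:
    # one partial bias-gradient sum per row block
    partial_grad_bias = torch.randn(N_BLOCKS_M, N)

    # Collapse the per-block partial sums into the final per-feature bias gradient
    grad_bias = partial_grad_bias.sum(dim=0)  # shape: (N,)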