yet another take, partial sum
blefaudeux committed Dec 10, 2021
1 parent ce1ab1c commit 73b7663
Showing 9 changed files with 29 additions and 274 deletions.
Binary file removed docs/plots/strided_sum/Strided_sum_fp16.png
Binary file removed docs/plots/strided_sum/Strided_sum_fp32.png
50 changes: 1 addition & 49 deletions tests/test_triton_basics.py
@@ -4,32 +4,14 @@
# LICENSE file in the root directory of this source tree.


import pytest
import torch

SHAPES = [
(384, 128),
(8 * 384, 128),
(34, 128),
(16, 128),
(16, 512),
(8, 384),
(8, 1024),
(8, 2048),
(8, 4096),
(8, 4096),
(4, 12288),
]


_triton_available = torch.cuda.is_available()
if _triton_available:
try:
import triton
import triton.language as tl

from xformers.triton.sum_strided import sum_2d_dim_0

except (ImportError, ModuleNotFoundError):
_triton_available = False

@@ -57,7 +39,7 @@ def k_mean(X, Mean, Var, stride, N, **META):
# Compute variance
x_mean = tl.sum(x, axis=0) / N
x_zm = x - x_mean
x_zm = tl.where(cols < N, x_zm, 0.0)
x_zm = tl.where(cols < N, x_zm, 0.0) # THIS SHOULD NOT BE NEEDED
x_var = tl.sum(x_zm * x_zm, axis=0) / N
tl.store(Mean + row, x_mean)
tl.store(Var + row, x_var)
@@ -106,33 +88,3 @@ def test_mean():

assert torch.allclose(mean, t_mean, rtol=1e-1)
assert torch.allclose(var, t_var, rtol=1e-1)

@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_sum_strided(shape, dtype):
torch.random.manual_seed(0)
a = torch.rand(shape, device=torch.device("cuda"), dtype=dtype)

torch_sum = torch.sum(a, dim=0)
triton_sum = sum_2d_dim_0(a)
assert torch.allclose(
torch_sum, triton_sum, rtol=0.01
), f"{torch_sum}\n{triton_sum}"

def test_sum_strided_asserts():
torch.random.manual_seed(0)
a = torch.rand((128, 256), device=torch.device("cuda"), dtype=torch.float16)

with pytest.raises(AssertionError):
# This kernel is not useful in that case, assert to prevent misuse
sum_2d_dim_0(a.transpose(1, 0))

a = torch.rand((3, 128, 256), device=torch.device("cuda"), dtype=torch.float16)
with pytest.raises(AssertionError):
# This kernel expects 2D tensors, assert to prevent misuse
sum_2d_dim_0(a)

a = torch.rand((2, 128), device=torch.device("cuda"), dtype=torch.float16)
with pytest.raises(AssertionError):
# This kernel cannot sum over a dimension smaller than 4
sum_2d_dim_0(a)
71 changes: 0 additions & 71 deletions xformers/benchmarks/benchmark_triton_stride_sum.py

This file was deleted.

28 changes: 21 additions & 7 deletions xformers/triton/dropout.py
@@ -44,6 +44,9 @@ def grid(meta):
triton.cdiv(N, meta["BLOCK_N"]),
)

GROUP_M = 128
BLOCK_M = GROUP_M // 4

# fmt: off
k_dropout_fw[grid](
y, x_,
@@ -53,7 +56,8 @@ def grid(meta):
M, N,
p,
USE_BIAS=bias is not None,
ACTIVATION=activation
ACTIVATION=activation,
BLOCK_M=BLOCK_M
)
# fmt: on

@@ -87,12 +91,21 @@ def backward(ctx, grad_out):
elif inputs.ndim > 2:
inputs = inputs.reshape(-1, N)

GROUP_M = 128
BLOCK_M = GROUP_M // 4
N_BLOCKS_M = triton.cdiv(M, GROUP_M)

if ctx.trainable_bias:
grad_bias = torch.empty((N,), device=grad_in.device, dtype=grad_in.dtype)
locks = torch.zeros(N // 2, dtype=torch.int32, device=grad_in.device)
grad_bias = torch.empty(
(
N_BLOCKS_M,
N,
),
device=grad_in.device,
dtype=grad_in.dtype,
)
else:
grad_bias = grad_in # will not be used
locks = grad_in

def grid(meta):
return (
@@ -104,20 +117,21 @@ def grid(meta):
k_dropout_bw[grid](
grad_in, grad_bias, grad_out_,
inputs, bias if bias is not None else inputs,
seeds, locks,
seeds,
grad_out_.stride(0), inputs.stride(0),
M, N,
ctx.p,
USE_BIAS=bias is not None,
ACTIVATION_GRAD=ctx.activation_grad,
TRAINABLE_BIAS=ctx.trainable_bias
TRAINABLE_BIAS=ctx.trainable_bias,
BLOCK_M=BLOCK_M
)
# fmt: on

return (
grad_in.reshape_as(grad_out),
None,
grad_bias if ctx.trainable_bias else None,
grad_bias.sum(dim=0) if ctx.trainable_bias else None,
None,
None,
None,
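The backward wrapper above now stages one partial bias gradient per block of GROUP_M rows in the new (N_BLOCKS_M, N) buffer and reduces it once with grad_bias.sum(dim=0), instead of accumulating into a single shared row under a lock. A minimal sketch of that pattern in plain PyTorch, assuming the GROUP_M = 128 tiling shown above; the Python loop stands in for the Triton program grid and the variable names are illustrative:

import torch

M, N, GROUP_M = 1000, 512, 128
grad_out = torch.randn(M, N)

n_blocks_m = (M + GROUP_M - 1) // GROUP_M          # same as triton.cdiv(M, GROUP_M)
partial = torch.zeros(n_blocks_m, N)               # stands in for the (N_BLOCKS_M, N) grad_bias buffer

for block_id in range(n_blocks_m):                 # one row block per kernel program
    rows = slice(block_id * GROUP_M, min((block_id + 1) * GROUP_M, M))
    partial[block_id] = grad_out[rows].sum(dim=0)  # each block writes its own row, no lock needed

grad_bias = partial.sum(dim=0)                     # the wrapper's final grad_bias.sum(dim=0)
assert torch.allclose(grad_bias, grad_out.sum(dim=0), atol=1e-4)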
36 changes: 6 additions & 30 deletions xformers/triton/k_dropout.py
@@ -10,17 +10,11 @@
import triton
import triton.language as tl

# WARNING: For now, the number of threads must be the same as the N buffer, and warps have to be 4 (will be fixed)
k_configs = [
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_warps=1),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=1),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 32}, num_warps=1),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=2),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64}, num_warps=2),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=4),
triton.Config({"BLOCK_N": 32}, num_warps=1),
triton.Config({"BLOCK_N": 64}, num_warps=2),
triton.Config({"BLOCK_N": 128}, num_warps=4),
triton.Config({"BLOCK_N": 256}, num_warps=8),
]


@@ -131,7 +125,7 @@ def k_dropout_fw(
@triton.jit
def k_dropout_bw(
GRAD_IN, GRAD_BIAS, GRAD_OUT,
INPUTS, BIAS, SEEDS, LOCKS,
INPUTS, BIAS, SEEDS,
stride_grad, stride_inputs,
M, N,
p,
@@ -250,23 +244,5 @@ def k_dropout_bw(
rand_mask = rand_mask1

if TRAINABLE_BIAS:
lock_ptr = LOCKS + 2 * col_id
count_ptr = LOCKS + 2 * col_id + 1
grad_bias_ptr = GRAD_BIAS + cols

# Uniquely taking a lock over the col results
while tl.atomic_cas(lock_ptr, 0, 1) == 1:
pass

count = tl.load(count_ptr)
if count == 0:
# first store doesn't accumulate
tl.atomic_xchg(count_ptr, 1)
else:
# read and add back
grad_bias += tl.load(grad_bias_ptr, mask=cols < N)

grad_bias_ptr = GRAD_BIAS + row_id * N + cols
tl.store(grad_bias_ptr, grad_bias, mask=cols < N)

# release lock
tl.atomic_xchg(lock_ptr, 0)
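The store that survives above writes to a region no other program touches: with GRAD_BIAS now holding N_BLOCKS_M rows of N columns, the program handling row block row_id stores its partial sum at GRAD_BIAS + row_id * N + cols, which is why the atomic_cas spin lock, the count flag, and the read-add-store round trip could all be deleted. A hedged illustration in plain PyTorch, with a tensor standing in for the GRAD_BIAS pointer and made-up sizes and program ids:

import torch

N, N_BLOCKS_M, BLOCK_N = 96, 3, 32
grad_bias_buf = torch.zeros(N_BLOCKS_M * N)       # flat view of the (N_BLOCKS_M, N) buffer

row_id, col_id = 1, 2                             # one particular program in the launch grid
cols = col_id * BLOCK_N + torch.arange(BLOCK_N)
grad_bias = torch.ones(BLOCK_N)                   # this program's partial bias gradient

# mirrors tl.store(GRAD_BIAS + row_id * N + cols, grad_bias, mask=cols < N)
grad_bias_buf[row_id * N + cols] = grad_bias      # slot owned by this program alone

# the autograd wrapper later reduces across row blocks in one go
final = grad_bias_buf.view(N_BLOCKS_M, N).sum(dim=0)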
4 changes: 1 addition & 3 deletions xformers/triton/k_fused_matmul_bw.py
@@ -10,8 +10,6 @@
import triton
import triton.language as tl

from xformers.triton.sum_strided import sum_2d_dim_0


# fmt: off
@triton.heuristics({
@@ -146,6 +144,6 @@ def grid(META):
# The following ops can also be handled by triton
grad_in = grad_out_ @ weight
grad_weight = grad_out_.transpose(1, 0) @ inputs_ if trainable_weight else None
grad_bias = sum_2d_dim_0(grad_out_) if trainable_bias else None
grad_bias = torch.sum(grad_out_, 0) if trainable_bias else None

return grad_in.reshape_as(inputs), grad_weight, grad_bias
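Swapping the deleted sum_2d_dim_0 kernel for torch.sum(grad_out_, 0) relies on the standard identity for a linear layer: the bias gradient is the column-wise sum of the upstream gradient. A quick, hedged sanity check in plain PyTorch (shapes and names are illustrative, not part of the xformers API):

import torch

B, K, N = 16, 64, 32
x = torch.randn(B, K)
w = torch.randn(N, K, requires_grad=True)
b = torch.zeros(N, requires_grad=True)

y = x @ w.t() + b                       # same shape as the fused kernel's output
grad_out = torch.randn_like(y)
y.backward(grad_out)

# autograd's bias gradient matches the plain column-wise sum
assert torch.allclose(b.grad, grad_out.sum(dim=0), atol=1e-5)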
66 changes: 0 additions & 66 deletions xformers/triton/k_sum.py

This file was deleted.
