Flipping the seeds so that the dropout pattern drops down from the top
Using fewer seeds

Tiling + vertical seeds

Computing the FW and BW per tile over M

Working around atomics by reintroducing locks; should not be too costly

Yet another take: partial sums

Better scheduling defaults; improves results across the board
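A plain-PyTorch sketch of the idea behind these changes (illustrative only: the function and helper names, tile size, and RNG scheme are placeholders, not what the Triton kernels do bit for bit). One seed per column ("vertical seeds") lets every tile over M regenerate the same mask without storing it, and the bias gradient is accumulated as per-tile partial sums that are reduced once at the end, which sidesteps atomics:

```python
import torch

def _column_mask(seed: int, m: int, p: float) -> torch.Tensor:
    # One seed per column: the same seed always regenerates the same column
    # of the mask, so nothing has to be stored between forward and backward.
    gen = torch.Generator().manual_seed(seed)
    return torch.rand(m, generator=gen) > p

def dropout_fw_bw_sketch(x, grad_out, p=0.5, tile_m=16):
    M, N = x.shape
    seeds = torch.randint(0, 65536, (N,)).tolist()
    scale = 1.0 / (1.0 - p)

    # Forward: the mask is regenerated column by column from the seeds.
    mask = torch.stack([_column_mask(s, M, p) for s in seeds], dim=1)
    y = x * mask * scale

    # Backward: processed tile by tile over M; the bias gradient is kept as
    # per-tile partial sums and reduced once at the end (no atomics needed).
    n_tiles = (M + tile_m - 1) // tile_m
    grad_in = torch.empty_like(grad_out)
    partial_grad_bias = torch.zeros(n_tiles, N)
    for t in range(n_tiles):
        rows = slice(t * tile_m, min((t + 1) * tile_m, M))
        tile_mask = torch.stack(
            [_column_mask(s, M, p)[rows] for s in seeds], dim=1
        )
        grad_in[rows] = grad_out[rows] * tile_mask * scale
        partial_grad_bias[t] = grad_in[rows].sum(dim=0)

    grad_bias = partial_grad_bias.sum(dim=0)  # the role played by sum_2d_dim_0
    return y, grad_in, grad_bias
```

In the actual kernels below, the seed tensor has shape (N,), each program instance handles a BLOCK_M x BLOCK_N tile, and the partial bias gradients land in an (N_BLOCKS_M, N) buffer that sum_2d_dim_0 reduces.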
blefaudeux committed Dec 15, 2021
1 parent 00aebbc commit 157789b
Showing 6 changed files with 285 additions and 137 deletions.
8 changes: 6 additions & 2 deletions tests/test_triton_dropout.py
@@ -97,6 +97,10 @@ def test_dropout(shape, amp, bias):
== y.shape[1]
)

# Check that the drop probability is about right
drop_p = (y_a.numel() - y_a.count_nonzero()) / y_a.numel()
assert drop_p < 0.55 and drop_p > 0.45


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.skipif(
@@ -107,7 +111,7 @@ def test_dropout(shape, amp, bias):
@pytest.mark.parametrize("amp", [False, True])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("activation", [a.value for a in Activation])
@pytest.mark.parametrize("p", [0, 0.001, 0.5])
@pytest.mark.parametrize("p", [0, 0.01, 0.5])
def test_dropout_parity(shape, amp, bias, activation, p):
"""
Check some basic dropout properties
@@ -158,4 +162,4 @@ def test_dropout_parity(shape, amp, bias, activation, p):
if bias:
assert torch.allclose(
torch.norm(b.grad), torch.norm(b_.grad), rtol=0.01
), f"{b.grad.norm()}\n{b_.grad.norm()}"
), f"{b.grad.norm()} - {b_.grad.norm()}"
4 changes: 3 additions & 1 deletion xformers/benchmarks/benchmark_triton_dropout.py
@@ -62,12 +62,14 @@ def torch_step(x):
y = torch_act(y)

if backward:
y.grad = None
torch.norm(y).backward()
return y

def triton_step(x):
y = triton_dropout(x)
if backward:
y.grad = None
torch.norm(y).backward()
return y

@@ -105,7 +107,7 @@ def triton_step(x):
)


for activation in [Activation.GeLU, None]:
for activation in [Activation.GeLU, None, Activation.SquaredReLU]:
for bw in [True, False]:
for bias in [True, False]:
bench_dropout(bias, bw, activation)
10 changes: 5 additions & 5 deletions xformers/benchmarks/utils.py
@@ -28,12 +28,12 @@
def pretty_print(results, title, units):
""" Printout the contents of a dict as a human-readable and Markdown compatible array"""
print(title)
header = " Units: {:<40}".format(units)
print("|" + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))
header = " Units: {:<45}".format(units)
print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))

offset = len(header)
print(
"|{}|".format("-" * offset)
"|-{}|".format("-" * offset)
+ "".join("{}|".format("-" * 20) for _ in results.keys())
)

@@ -44,7 +44,7 @@ def pretty_print(results, title, units):

for k, w in workloads.items():
print(
"|{0:<{offset}}|".format(k, offset=offset)
"| {0:<{offset}}|".format(k, offset=offset)
+ "".join("{:<20}|".format(v) for v in w)
)
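To see what the spacing tweak buys: the leading "| " and the "|-" separator keep the printout a valid, aligned Markdown table. A self-contained reproduction of the new format strings (the units, keys, and values here are invented):

```python
# Made-up data, just to show the formatting produced by the strings above.
units = "runtime (ms)"
keys = ["fp16", "fp32"]

header = " Units: {:<45}".format(units)
print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in keys))

offset = len(header)
print("|-{}|".format("-" * offset) + "".join("{}|".format("-" * 20) for _ in keys))

print(
    "| {0:<{offset}}|".format("B=8, M=256, K=512", offset=offset)
    + "".join("{:<20}|".format(v) for v in (0.12, 0.34))
)
```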

@@ -85,7 +85,7 @@ def pretty_plot(results, title, units: str, filename=None, dash_key=""):
plt.xticks(rotation=45)

plt.savefig(filename, bbox_inches="tight")
plt.clf()
plt.close(f)


if _triton_is_available:
97 changes: 68 additions & 29 deletions xformers/triton/dropout.py
@@ -26,42 +26,48 @@
class _dropout(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, x, p, bias, activation, activation_grad):
def forward(ctx, x, p, bias, activation, activation_grad, trainable_bias):
# Soft-flatten an hypothetical 3rd dimension
x_ = x.reshape(-1, x.shape[-1]).contiguous()
y = torch.empty_like(x_)
_, N = x_.shape
M, N = x_.shape

assert bias is None or bias.dtype == x.dtype, bias
assert bias is None or (bias.dtype == x.dtype and bias.shape[0] == N)

# Generate one seed per sample
# seed max is int32 max for positive numbers: 2**16
seeds = torch.randint(65536, (x_.shape[0],), device=x.device).to(torch.int32)
# FIXME: adjust the number of seeds needed
seeds = torch.randint(65536, (N,), device=x.device).to(torch.int32)

# SPMD launch grid
def grid(meta):
return (
x_.shape[0],
triton.cdiv(x_.shape[1], meta["BLOCK_SIZE"]),
triton.cdiv(M, meta["BLOCK_M"] * 4),
triton.cdiv(N, meta["BLOCK_N"]),
)

GROUP_M = 16
BLOCK_M = GROUP_M // 4

# fmt: off
k_dropout_fw[grid](
y, x_, bias if bias is not None else x_,
y, x_,
bias if bias is not None else x_,
seeds,
y.stride(0),
N,
M, N,
p,
USE_BIAS=bias is not None,
ACTIVATION=activation
ACTIVATION=activation,
BLOCK_M=BLOCK_M
)
# fmt: on

if activation is not None:
ctx.save_for_backward(seeds, bias, x)
else:
ctx.save_for_backward(seeds, bias, None)
ctx.trainable_bias = bias is not None

ctx.trainable_bias = bias is not None and trainable_bias
ctx.activation_grad = activation_grad
ctx.p = p

@@ -76,40 +82,61 @@ def backward(ctx, grad_out):
grad_out_ = grad_out.reshape(-1, grad_out.shape[-1]).contiguous()
grad_in = torch.empty_like(grad_out_)

_, N = grad_out_.shape
M, N = grad_out_.shape

# Optional inputs to compute the activation contribution to the gradient
assert inputs is not None or ctx.activation_grad is None

if inputs is None:
inputs = grad_out_
elif inputs.ndim > 2:
inputs = inputs.reshape(-1, grad_out.shape[-1])
inputs = inputs.reshape(-1, N)

GROUP_M = 16 if M > 512 else M // 4
BLOCK_M = GROUP_M // 4
N_BLOCKS_M = triton.cdiv(M, GROUP_M)

if ctx.trainable_bias:
grad_bias = torch.empty(
(
N_BLOCKS_M,
N,
),
device=grad_in.device,
dtype=grad_in.dtype,
)
else:
grad_bias = grad_in # will not be used

# SPMD launch grid
def grid(meta):
return (
grad_out_.shape[0],
triton.cdiv(grad_out_.shape[1], meta["BLOCK_SIZE"]),
triton.cdiv(M, meta["BLOCK_M"] * 4),
triton.cdiv(N, meta["BLOCK_N"]),
)

# fmt: off
k_dropout_bw[grid](
grad_in, grad_out_, inputs, bias if bias is not None else inputs,
grad_in, grad_bias, grad_out_,
inputs, bias if bias is not None else inputs,
seeds,
grad_out_.stride(0), inputs.stride(0),
N,
M, N,
ctx.p,
USE_BIAS=bias is not None,
ACTIVATION_GRAD=ctx.activation_grad)
ACTIVATION_GRAD=ctx.activation_grad,
TRAINABLE_BIAS=ctx.trainable_bias,
BLOCK_M=BLOCK_M
)
# fmt: on

if ctx.trainable_bias:
grad_bias: Optional[torch.Tensor] = sum_2d_dim_0(grad_in)
else:
grad_bias = None

return grad_in.reshape_as(grad_out), None, grad_bias, None, None
return (
grad_in.reshape_as(grad_out),
None,
sum_2d_dim_0(grad_bias) if ctx.trainable_bias else None,
None,
None,
None,
)


def dropout(
@@ -129,7 +156,14 @@ def dropout(

act_kernel = get_triton_activation_kernel(activation)
act_grad_kernel = get_triton_activation_bwd_kernel(activation)
return _dropout.apply(x, p, bias, act_kernel, act_grad_kernel)
return _dropout.apply(
x,
p,
bias,
act_kernel,
act_grad_kernel,
bias is not None and bias.requires_grad,
)


class FusedDropoutBias(torch.nn.Module):
@@ -142,8 +176,10 @@ def __init__(
super().__init__()
self.p = p
self.activation_type = activation
self.register_buffer(
"bias", torch.zeros(bias_shape) if bias_shape is not None else None
self.bias = (
torch.zeros(bias_shape, requires_grad=True)
if bias_shape is not None
else None
)
self.activation = get_triton_activation_kernel(activation)
self.activation_grad = get_triton_activation_bwd_kernel(activation)
@@ -160,5 +196,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = activation(x)
return torch.nn.functional.dropout(x, self.p)

# The normal, Triton-backed path
p = self.p if self.training else 0.0
return _dropout.apply(x, p, self.bias, self.activation, self.activation_grad)
return _dropout.apply(
x, p, self.bias, self.activation, self.activation_grad, True
)
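A usage sketch for the fused path above. Assumptions: a CUDA device with Triton installed, that the Activation enum is importable from xformers.components, and that the keyword names match the dropout signature shown in this diff; treat the exact call as illustrative:

```python
import torch
from xformers.components import Activation           # assumed import path for the enum
from xformers.triton.dropout import dropout           # the helper defined above

# Arbitrary shapes; fp16 matches the custom_fwd cast in the autograd function.
x = torch.randn(8, 256, 512, device="cuda", dtype=torch.float16)
bias = torch.zeros(512, device="cuda", dtype=torch.float16, requires_grad=True)

# Bias add + activation + dropout fused into one kernel launch.
y = dropout(x, p=0.1, bias=bias, activation=Activation.GeLU)

# Backward regenerates the mask from the per-column seeds and reduces the
# per-tile partial sums into bias.grad.
y.norm().backward()
print(bias.grad.shape)  # torch.Size([512])
```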
9 changes: 3 additions & 6 deletions xformers/triton/k_activations.py
@@ -64,8 +64,7 @@ def relu(x):
.. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
"""
zero = 0.0
zero = zero.to(x.dtype)
return tl.where(x >= 0, x, zero)
return tl.where(x >= 0, x, zero.to(x.dtype))


@triton.jit
@@ -74,10 +73,8 @@ def relu_grad(x):
# in that it does not require the input to retrospectively compute its gradient
# here the input is the downstream gradient, and we return the upstream gradient directly
zero = 0.0
zero = zero.to(x.dtype)
one = 1.0
one = one.to(x.dtype)
return tl.where(x >= 0, one, zero)
return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))


@triton.jit
@@ -88,7 +85,7 @@ def squared_relu(x):
.. _Primer: https://arxiv.org/abs/2109.08668
"""
x_ = relu(x)
return x_ * x_
return (x_ * x_).to(x.dtype)


@triton.jit
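For reference, a plain-PyTorch rendering of the two activations touched in k_activations.py; only the math is taken from the kernels above, these are not the Triton implementations:

```python
import torch

def squared_relu(x: torch.Tensor) -> torch.Tensor:
    # Squared ReLU from the Primer paper (https://arxiv.org/abs/2109.08668):
    # relu(x) ** 2, kept in the input dtype as in the kernel above.
    x_ = torch.relu(x)
    return (x_ * x_).to(x.dtype)

def relu_grad(x: torch.Tensor) -> torch.Tensor:
    # Elementwise ReLU derivative, mirroring the tl.where in relu_grad above:
    # 1 where x >= 0, 0 elsewhere, returned in the input dtype.
    return (x >= 0).to(x.dtype)
```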