flipping the seeds so that it drops down from the top
using fewer seeds

tiling + vertical seeds
blefaudeux committed Dec 10, 2021
1 parent adfb645 commit 147e2c9
Showing 15 changed files with 503 additions and 144 deletions.
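
For orientation: the change draws the dropout RNG seeds per column (one seed per element of the last dimension) rather than per row, so the random pattern is generated from the top of each column downwards and the kernel can tile over rows. Below is a minimal PyTorch sketch of that seeding scheme, for illustration only; the helper name is made up and the real work happens in the Triton kernels changed further down.

import torch

def dropout_with_column_seeds(x: torch.Tensor, p: float) -> torch.Tensor:
    """Reference-only sketch: one RNG seed per column, mask drawn top-down.

    A column's mask depends only on its seed and the row index, which is what
    lets a kernel tile vertically without carrying any per-row state.
    """
    M, N = x.shape
    seeds = torch.randint(0, 2**16, (N,))  # one seed per column
    keep = torch.empty_like(x, dtype=torch.bool)
    for col, seed in enumerate(seeds.tolist()):
        gen = torch.Generator(device=x.device).manual_seed(seed)
        keep[:, col] = torch.rand(M, generator=gen, device=x.device) > p
    return x * keep / (1.0 - p)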
2 changes: 1 addition & 1 deletion .gitignore
@@ -51,4 +51,4 @@ examples/data

# Hydra default output dir
multirun
outputs
outputs
Binary file added docs/plots/strided_sum/Strided_sum_fp16.png
Binary file added docs/plots/strided_sum/Strided_sum_fp32.png
50 changes: 49 additions & 1 deletion tests/test_triton_basics.py
@@ -4,14 +4,32 @@
# LICENSE file in the root directory of this source tree.


import pytest
import torch

SHAPES = [
(384, 128),
(8 * 384, 128),
(34, 128),
(16, 128),
(16, 512),
(8, 384),
(8, 1024),
(8, 2048),
(8, 4096),
(8, 4096),
(4, 12288),
]


_triton_available = torch.cuda.is_available()
if _triton_available:
try:
import triton
import triton.language as tl

from xformers.triton.sum_strided import sum_2d_dim_0

except (ImportError, ModuleNotFoundError):
_triton_available = False

@@ -39,7 +57,7 @@ def k_mean(X, Mean, Var, stride, N, **META):
# Compute variance
x_mean = tl.sum(x, axis=0) / N
x_zm = x - x_mean
x_zm = tl.where(cols < N, x_zm, 0.0) # THIS SHOULD NOT BE NEEDED
x_zm = tl.where(cols < N, x_zm, 0.0)
x_var = tl.sum(x_zm * x_zm, axis=0) / N
tl.store(Mean + row, x_mean)
tl.store(Var + row, x_var)
@@ -88,3 +106,33 @@ def test_mean():

assert torch.allclose(mean, t_mean, rtol=1e-1)
assert torch.allclose(var, t_var, rtol=1e-1)

@pytest.mark.parametrize("shape", SHAPES)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_sum_strided(shape, dtype):
torch.random.manual_seed(0)
a = torch.rand(shape, device=torch.device("cuda"), dtype=dtype)

torch_sum = torch.sum(a, dim=0)
triton_sum = sum_2d_dim_0(a)
assert torch.allclose(
torch_sum, triton_sum, rtol=0.01
), f"{torch_sum}\n{triton_sum}"

def test_sum_strided_asserts():
torch.random.manual_seed(0)
a = torch.rand((128, 256), device=torch.device("cuda"), dtype=torch.float16)

with pytest.raises(AssertionError):
# This kernel is not useful in that case, assert to prevent misuse
sum_2d_dim_0(a.transpose(1, 0))

a = torch.rand((3, 128, 256), device=torch.device("cuda"), dtype=torch.float16)
with pytest.raises(AssertionError):
# This kernel expects 2D tensors, assert to prevent misuse
sum_2d_dim_0(a)

a = torch.rand((2, 128), device=torch.device("cuda"), dtype=torch.float16)
with pytest.raises(AssertionError):
# This kernel cannot sum over a dimension smaller than 4
sum_2d_dim_0(a)
8 changes: 6 additions & 2 deletions tests/test_triton_dropout.py
@@ -25,7 +25,7 @@
)
_triton_available = False

# Testing odd shapes on purpose
# Testing odd (non-power-of-two for instance) shapes on purpose
SHAPES = [
(384, 128),
(8, 384, 128),
@@ -90,6 +90,10 @@ def test_dropout(shape, amp, bias):
== y.shape[1]
)

# Check that the drop probability is about right
drop_p = (y_a.numel() - y_a.count_nonzero()) / y_a.numel()
assert drop_p < 0.55 and drop_p > 0.45


@pytest.mark.skipif(not _triton_available, reason="Triton is not available")
@pytest.mark.skipif(
@@ -151,4 +155,4 @@ def test_dropout_parity(shape, amp, bias, activation, p):
if bias:
assert torch.allclose(
torch.norm(b.grad), torch.norm(b_.grad), rtol=0.01
), f"{b.grad}\n{b_.grad}"
), f"{b.grad.norm()}\n{b_.grad.norm()}"
4 changes: 2 additions & 2 deletions xformers/benchmarks/benchmark_triton_dropout.py
@@ -18,8 +18,8 @@
(8, 512, 1024),
(4, 1024, 1024),
(2, 2048, 2048),
(2, 4096, 4096),
(1, 2048, 12288),
(2, 4096, 4096),
]

P = 0.1
@@ -105,7 +105,7 @@ def triton_step(x):
)


for activation in [Activation.GeLU, None]:
for activation in [Activation.SquaredReLU, Activation.GeLU, None]:
for bw in [True, False]:
for bias in [True, False]:
bench_dropout(bias, bw, activation)
71 changes: 71 additions & 0 deletions xformers/benchmarks/benchmark_triton_stride_sum.py
@@ -0,0 +1,71 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.triton.sum_strided import sum_2d_dim_0

SHAPES = [
(128, 128),
(384, 128),
(784, 512),
(1024, 768),
(2048, 1024),
(4096, 4096),
]


def to_gbs(a, ms):
# Read the full array, write the non-reduced dimension
return ((a.numel() + a.shape[1]) * a.element_size() * 1e-9) / (ms * 1e-3)


def bench_functions(
test_cases: List[TestCase], shapes, metric_transform, unit, title=""
):
device = torch.device("cuda")

for dtype in [torch.float16, torch.float32]:
results: Dict[str, Any] = {}

for M, N in shapes:
a = torch.rand(M, N, device=device, dtype=dtype, requires_grad=True)

for testcase in test_cases:
time = triton.testing.do_bench(lambda: testcase.function(a))[0]

metric = metric_transform(a, time)

key = f"M={M}, N={N}"
if key not in results:
results[key] = {}

results[key][testcase.name] = f"{metric:.1f}"

_type = " fp16" if dtype == torch.float16 else " fp32"

pretty_print(
results,
title=" ------------- Type: {} ------------- ".format(_type),
units=unit,
)

pretty_plot(results, title + _type, unit, dash_key="pytorch")


bench_functions(
[
TestCase(lambda x: torch.sum(x, dim=0), "pytorch"),
TestCase(sum_2d_dim_0, "triton"),
],
SHAPES,
to_gbs,
"GB/s",
"Strided_sum",
)
10 changes: 5 additions & 5 deletions xformers/benchmarks/utils.py
@@ -28,12 +28,12 @@
def pretty_print(results, title, units):
""" Printout the contents of a dict as a human-readable and Markdown compatible array"""
print(title)
header = " Units: {:<40}".format(units)
print("|" + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))
header = " Units: {:<45}".format(units)
print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))

offset = len(header)
print(
"|{}|".format("-" * offset)
"|-{}|".format("-" * offset)
+ "".join("{}|".format("-" * 20) for _ in results.keys())
)

@@ -44,7 +44,7 @@ def pretty_print(results, title, units):

for k, w in workloads.items():
print(
"|{0:<{offset}}|".format(k, offset=offset)
"| {0:<{offset}}|".format(k, offset=offset)
+ "".join("{:<20}|".format(v) for v in w)
)

@@ -85,7 +85,7 @@ def pretty_plot(results, title, units: str, filename=None, dash_key=""):
plt.xticks(rotation=45)

plt.savefig(filename, bbox_inches="tight")
plt.clf()
plt.close(f)


if _triton_is_available:
3 changes: 2 additions & 1 deletion xformers/components/__init__.py
@@ -13,7 +13,8 @@
from .activations import Activation, build_activation # noqa
from .attention import Attention, build_attention # noqa
from .in_proj_container import InProjContainer, InProjParams # noqa
from .multi_head_dispatch import MultiHeadDispatch, MultiHeadDispatchConfig # noqa
from .multi_head_dispatch import MultiHeadDispatch # noqa
from .multi_head_dispatch import MultiHeadDispatchConfig
from .residual import LayerNormStyle, PostNorm, PreNorm, Residual # noqa

# automatically import any Python files in the directory
81 changes: 48 additions & 33 deletions xformers/triton/dropout.py
@@ -25,42 +25,40 @@
class _dropout(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, x, p, bias, activation, activation_grad):
def forward(ctx, x, p, bias, activation, activation_grad, trainable_bias):
# Soft-flatten a hypothetical 3rd dimension
x_ = x.reshape(-1, x.shape[-1]).contiguous()
y = torch.empty_like(x_)
_, N = x_.shape
M, N = x_.shape

assert bias is None or bias.dtype == x.dtype, bias
assert bias is None or (bias.dtype == x.dtype and bias.shape[0] == N)

# Generate one seed per column ("vertical" seeds)
# seeds are drawn in [0, 2**16), which comfortably fits in a positive int32
seeds = torch.randint(65536, (x_.shape[0],), device=x.device).to(torch.int32)
seeds = torch.randint(65536, (N,), device=x.device).to(torch.int32)

# SPMD launch grid
def grid(meta):
return (
x_.shape[0],
triton.cdiv(x_.shape[1], meta["BLOCK_SIZE"]),
)
return (triton.cdiv(N, meta["BLOCK_N"]),)

# fmt: off
k_dropout_fw[grid](
y, x_, bias if bias is not None else x_,
y, x_,
bias if bias is not None else x_,
seeds,
y.stride(0),
N,
M, N,
p,
USE_BIAS=bias is not None,
ACTIVATION=activation
ACTIVATION=activation,
)
# fmt: on

if activation is not None:
ctx.save_for_backward(seeds, bias, x)
else:
ctx.save_for_backward(seeds, bias, None)
ctx.trainable_bias = bias is not None

ctx.trainable_bias = bias is not None and trainable_bias
ctx.activation_grad = activation_grad
ctx.p = p

@@ -75,40 +73,46 @@ def backward(ctx, grad_out):
grad_out_ = grad_out.reshape(-1, grad_out.shape[-1]).contiguous()
grad_in = torch.empty_like(grad_out_)

_, N = grad_out_.shape
M, N = grad_out_.shape

# Optional inputs to compute the activation contribution to the gradient
assert inputs is not None or ctx.activation_grad is None

if inputs is None:
inputs = grad_out_
elif inputs.ndim > 2:
inputs = inputs.reshape(-1, grad_out.shape[-1])
inputs = inputs.reshape(-1, N)

if ctx.trainable_bias:
grad_bias = torch.empty((N,), device=grad_in.device, dtype=grad_in.dtype)
else:
grad_bias = grad_in # will not be used

# SPMD launch grid
def grid(meta):
return (
grad_out_.shape[0],
triton.cdiv(grad_out_.shape[1], meta["BLOCK_SIZE"]),
)
return (triton.cdiv(N, meta["BLOCK_N"]),)

# fmt: off
k_dropout_bw[grid](
grad_in, grad_out_, inputs, bias if bias is not None else inputs,
grad_in, grad_bias, grad_out_,
inputs, bias if bias is not None else inputs,
seeds,
grad_out_.stride(0), inputs.stride(0),
N,
M, N,
ctx.p,
USE_BIAS=bias is not None,
ACTIVATION_GRAD=ctx.activation_grad)
ACTIVATION_GRAD=ctx.activation_grad,
TRAINABLE_BIAS=ctx.trainable_bias
)
# fmt: on

if ctx.trainable_bias:
grad_bias: Optional[torch.Tensor] = torch.sum(grad_in, dim=0)
else:
grad_bias = None

return grad_in.reshape_as(grad_out), None, grad_bias, None, None
return (
grad_in.reshape_as(grad_out),
None,
grad_bias if ctx.trainable_bias else None,
None,
None,
None,
)


def dropout(
@@ -128,7 +132,14 @@ def dropout(

act_kernel = get_triton_activation_kernel(activation)
act_grad_kernel = get_triton_activation_bwd_kernel(activation)
return _dropout.apply(x, p, bias, act_kernel, act_grad_kernel)
return _dropout.apply(
x,
p,
bias,
act_kernel,
act_grad_kernel,
bias is not None and bias.requires_grad,
)


class FusedDropoutBias(torch.nn.Module):
@@ -141,8 +152,10 @@ def __init__(
super().__init__()
self.p = p
self.activation = activation
self.register_buffer(
"bias", torch.zeros(bias_shape) if bias_shape is not None else None
self.bias = (
torch.zeros(bias_shape, requires_grad=True)
if bias_shape is not None
else None
)
self.activation = get_triton_activation_kernel(activation)
self.activation_grad = get_triton_activation_bwd_kernel(activation)
@@ -153,4 +166,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.bias = self.bias.to(dtype=x.dtype, device=x.device) # type: ignore

p = self.p if self.training else 0.0
return _dropout.apply(x, p, self.bias, self.activation, self.activation_grad)
return _dropout.apply(
x, p, self.bias, self.activation, self.activation_grad, True
)
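
For reference, a hedged usage sketch of the fused dropout path edited above. It assumes the functional dropout(x, p=..., bias=..., activation=...) signature and the FusedDropoutBias(p=..., bias_shape=..., activation=...) constructor whose bodies appear in this file; shapes and values are made up. The point being illustrated is that the new trainable_bias argument is derived from bias.requires_grad, so a bias gradient is only produced when the bias is actually trainable.

import torch

from xformers.components import Activation
from xformers.triton.dropout import FusedDropoutBias, dropout

x = torch.randn(8, 384, 128, device="cuda", dtype=torch.float16, requires_grad=True)
bias = torch.zeros(128, device="cuda", dtype=torch.float16, requires_grad=True)

# Functional form: the extra flag forwarded to _dropout.apply means that the
# kernel only accumulates a bias gradient when bias.requires_grad is True
y = dropout(x, p=0.1, bias=bias, activation=Activation.GeLU)
y.sum().backward()
assert bias.grad is not None and bias.grad.shape == (128,)

# Module form: the bias is created with requires_grad=True, so it always
# takes the trainable path
fused = FusedDropoutBias(p=0.1, bias_shape=128, activation=Activation.GeLU)
z = fused(x)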
