[backend] 3/3 Triton 2 update (#272)
* parent be72b26
author Kashif Rasul <kashif.rasul@gmail.com> 1648069860 +0100
committer Benjamin Lefaudeux <benjamin.lefaudeux@pm.me> 1650256563 -0700

Move to Triton 2

Author:    Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Benjamin Lefaudeux <benjamin.lefaudeux@pm.me>

Tentatively fixing layernorm

- faster all around
- bugfix

Better handling of sparse tensors: put the layout on the correct device.
Update the pip packages, minor cleanup.

* Account for Triton blocksparse probably being more reliable in fp16

* faster layernorm

* Minor blocksparse refactoring, update block size restrictions, relax power of two constraint (#277)

* Relax device size restrictions

* Refactor device creation and run all tests

* linting

Co-authored-by: Cole Hawkins <colehawk@amazon.com>

* code review, thanks @fmassa !

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: colepshawkins <31542048+colehawkins@users.noreply.github.com>
Co-authored-by: Cole Hawkins <colehawk@amazon.com>
4 people committed Apr 21, 2022
1 parent e3b57de commit 4ecbec1
Showing 69 changed files with 461 additions and 496 deletions.
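
Most of the mechanical changes in this diff follow one pattern: Triton 2 drops the `**meta` keyword dictionary and instead declares compile-time constants as `tl.constexpr` parameters in the kernel signature, passed by keyword at launch. The sketch below is a minimal, hypothetical illustration of that migration; the kernel and its names are made up for clarity and are not code from this commit.

```python
import torch
import triton
import triton.language as tl

# Triton 1.x style (the pattern removed in this diff):
#   @triton.jit
#   def copy_kernel(X, Y, n, **meta):
#       BLOCK = meta["BLOCK"]
#       ...
#   copy_kernel[grid](x, y, n, BLOCK=1024)

# Triton 2 style: compile-time constants are declared in the signature.
@triton.jit
def copy_kernel(X, Y, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n  # guard the tail of the tensor
    tl.store(Y + offsets, tl.load(X + offsets, mask=mask), mask=mask)


def copy(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    copy_kernel[grid](x, y, n, BLOCK=1024)  # constexpr args still passed by keyword
    return y
```

The `garbage_pad_ragged_acts.py` and `tests/test_triton_basics.py` diffs below apply exactly this change to the real kernels.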
10 changes: 4 additions & 6 deletions BENCHMARKS.md
@@ -39,8 +39,7 @@ Some examples, generated with `python3 xformers/benchmarks/benchmark_encoder.py`

### Fused softmax

- You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_softmax.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.
+ You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_softmax.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12.

![Softmax throughput in fp16 - inference](docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png)

@@ -52,8 +51,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

### Fused linear layer

- You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_fused_linear_layer.py`. The units are TFlops/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.
+ You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_fused_linear_layer.py`. The units are TFlops/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12.

![Fused linear layers throughput in fp16 - inference](docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png)

@@ -77,7 +75,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

### Fused layer norm

- You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_layernorm.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.
+ You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_layernorm.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12.

![Fused layer norm throughput in fp16 - inference](docs/plots/layer_norm/LayerNorm_FW_torch.float16.png)

@@ -89,7 +87,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

### Fused dropout + bias + activation

- You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.
+ You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12.

![Fused dropout+ bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_gelu.png)

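
For context on the units above: the GB/s figures are effective memory bandwidth, i.e. bytes read plus bytes written divided by the measured kernel time. A rough, hypothetical way to estimate such a number for a plain fp16 softmax is sketched below; the shape and the exact byte accounting in the xformers benchmark scripts may differ.

```python
import torch


def softmax_bandwidth_gbs(shape=(32, 1024, 1024), iters=10) -> float:
    # Assumes a CUDA device; counts one read and one write of the tensor.
    x = torch.randn(shape, device="cuda", dtype=torch.float16)
    torch.softmax(x, dim=-1)  # warm-up so compilation/allocations are not timed
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        torch.softmax(x, dim=-1)
    end.record()
    torch.cuda.synchronize()
    ms_per_iter = start.elapsed_time(end) / iters
    bytes_moved = 2 * x.numel() * x.element_size()
    return bytes_moved / (ms_per_iter * 1e-3) / 1e9
```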
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Mem efficient attention, FW pass [#267]
- MHA benchmark
- MLP benchmark
- Move all triton kernels to triton v2 [#272]

## [0.0.10] - 2022-03-14
### Fixed
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp16.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp32.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp32.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW+BW_torch.float16.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW_torch.float16.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW_torch.float32.png
1 change: 0 additions & 1 deletion docs/source/conf.py
@@ -37,7 +37,6 @@
# The full version, including alpha/beta/rc tags
release = "0.0.10"


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
2 changes: 1 addition & 1 deletion examples/microGPT.py
@@ -21,7 +21,7 @@


class GPT(pl.LightningModule):
""" the full GPT language model, with a context size of block_size """
"""the full GPT language model, with a context size of block_size"""

def __init__(
self,
7 changes: 3 additions & 4 deletions experimental/ragged_inference/garbage_pad_ragged_acts.py
@@ -18,9 +18,9 @@ def garbage_pad_ragged_acts_kernel(
ragged_acts_offset_per_seq_ptr,
n_ctx_per_seq_ptr,
padded_acts_ptr,
- **meta, # Optional meta-parameters for the kernel
+ BLOCK_SIZE: tl.constexpr, # How many inputs each program should process
+ n_ctx_max: tl.constexpr,
):
- BLOCK_SIZE = meta["d_model"] # How many inputs each program should process
# There are multiple 'program's processing different data. We identify which program
# we are here

@@ -47,7 +47,6 @@ def garbage_pad_ragged_acts_kernel(
acts = tl.load(ragged_acts_ptr + ragged_acts_offsets, mask=ctx_idx_too_large_mask)

# Calculate the offsets for the padded acts
- n_ctx_max = meta["n_ctx_max"]
padded_acts_offset = n_ctx_max * seq_idx * BLOCK_SIZE

# Write things back, again masking out the sections that would be garbage
@@ -153,7 +152,7 @@ def triton_to_garbage_padded(self) -> torch.Tensor:
torch.tensor(ragged_acts_offset_per_seq, device="cuda"),
torch.tensor(self.n_ctx_per_seq, device="cuda"),
padded_acts,
- d_model=d_model,
+ BLOCK_SIZE=d_model,
n_ctx_max=n_ctx_max,
)
return padded_acts
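
For readers unfamiliar with the kernel above: it scatters a concatenation of variable-length ("ragged") sequence activations into a dense (n_seqs, n_ctx_max, d_model) buffer, leaving the tail of each short sequence uninitialized. A dense PyTorch reference of that behaviour might look like the sketch below; this is an illustration only, not code from this file.

```python
import torch


def garbage_pad_ragged_acts_reference(ragged_acts, n_ctx_per_seq):
    # ragged_acts: (sum(n_ctx_per_seq), d_model), sequences concatenated along dim 0.
    d_model = ragged_acts.shape[-1]
    n_ctx_max = max(n_ctx_per_seq)
    padded = torch.empty(
        len(n_ctx_per_seq), n_ctx_max, d_model,
        device=ragged_acts.device, dtype=ragged_acts.dtype,
    )  # entries past each sequence's length are left as "garbage"
    offset = 0
    for seq_idx, n_ctx in enumerate(n_ctx_per_seq):
        padded[seq_idx, :n_ctx] = ragged_acts[offset : offset + n_ctx]
        offset += n_ctx
    return padded
```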
1 change: 0 additions & 1 deletion requirements-benchmark.txt
@@ -7,5 +7,4 @@ scikit-learn == 0.24.1
tqdm == 4.59.0
pandas == 1.2.4
seaborn == 0.11.1
- triton == 1.1.2.dev20220106
pytorch-lightning >= 1.3
3 changes: 3 additions & 0 deletions requirements-test.txt
@@ -24,3 +24,6 @@ hydra-core >= 1.1

# Dependency for Mixture of Experts
fairscale >= 0.4.5

# Dependency for fused layers, optional
triton == 2.0.0.dev20220403
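
With Triton now an optional dependency ("Dependency for fused layers, optional"), code that uses the fused kernels typically guards the import, as `tests/test_triton_basics.py` below does with `_triton_available`. A minimal sketch of that guard follows; the fallback comment is an assumption, not this repo's exact code.

```python
import logging

import torch

try:
    import triton  # noqa: F401  (optional: only needed for the fused Triton layers)

    _triton_available = torch.cuda.is_available()
except ImportError:
    logging.warning("Triton is not available, fused layers will not be used.")
    _triton_available = False

# Callers can then switch implementations, e.g.
#   norm = FusedLayerNorm(d_model) if _triton_available else torch.nn.LayerNorm(d_model)
```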
66 changes: 49 additions & 17 deletions tests/test_sparse_tensors.py
@@ -19,7 +19,7 @@
def _create_blocksparse_tensor(
device, block_size=32, Z=8, C=2, H=64, W=64, dtype=torch.float32
):
- layout = torch.randint(2, (C, H // block_size, W // block_size))
+ layout = torch.randint(2, (C, H // block_size, W // block_size), device=device)
layout[:, :, 0] = 1
layout[:, 0, :] = 1
values = torch.randn(Z, layout.sum(), block_size, block_size, device=device).to(
@@ -56,6 +56,29 @@ def _create_tensor(tensor_type, device, dtype, shape, sparsity):
)


def _seed():
torch.random.manual_seed(42)
torch.cuda.manual_seed_all(42)


def _get_dtype_atol(tensor_type, device: str):
_seed()

if tensor_type == BlockSparseTensor and "cuda" in device:
# Upstream GPU blocksparse (Triton op) uses TF32 by default for all internal computations
# TF32 has the precision of fp16 but the range of fp32
# See https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
return torch.float32, 1e-1

# Force pytorch to keep its computations as float32 (will default to tf32 with recent cuda and ampere+ GPU)
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

return torch.float32, 1e-5


@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("func", [torch.add, torch.mul])
def test_sparse_binary_ops(func, device):
@@ -83,6 +106,7 @@ def test_sparse_binary_ops(func, device):
def test_masked_matmul(tensor_type, device):
N, C, H, W, L = 8, 2, 64, 64, 32
sparsity = 0.7
dtype, atol = _get_dtype_atol(tensor_type, device)

shape0 = (N, C, H, W)
shape1 = (N, C, H, L)
@@ -98,8 +122,8 @@
)
mask = mask_sparse.to_dense()

- a = torch.randn(shape1, device=device)
- b = torch.randn(shape2, device=device)
+ a = torch.randn(shape1, device=device, dtype=dtype)
+ b = torch.randn(shape2, device=device, dtype=dtype)

aa = a.clone()
bb = b.clone()
@@ -119,24 +143,23 @@ def test_masked_matmul(tensor_type, device):
res_dense = torch.where(mask, res_dense, torch.full_like(res_dense, float("-inf")))

assert res.dtype == res_gt.dtype
- assert torch.allclose(res_dense, res_gt, atol=5e-6)
+ assert torch.allclose(res_dense, res_gt, atol=atol)

# try to workaround non-contiguous issues with triton for now
res_gt.backward(torch.ones_like(res_gt))
res.values().backward(torch.ones_like(res.values()))
- # TODO: this is not passing for BlockSparse!!!
- if tensor_type != BlockSparseTensor:
-     assert torch.allclose(a.grad, aa.grad, atol=5e-6)
-     assert torch.allclose(b.grad, bb.grad, atol=5e-6)

+ assert torch.allclose(a.grad, aa.grad, atol=atol)
+ assert torch.allclose(b.grad, bb.grad, atol=atol)


@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_bmm(tensor_type, device):
N, C, H, W, L = 8, 2, 64, 64, 32
- dtype = torch.float32
- sparsity = 0.8
+ dtype, atol = _get_dtype_atol(tensor_type, device)
+
+ sparsity = 0.8
shape0 = (N, C, H, W)
shape1 = (N, C, W, L)

@@ -153,7 +176,7 @@ def test_bmm(tensor_type, device):
a_sparse.requires_grad_(True)
a.requires_grad_(True)

- b = torch.randn(shape1, device=device)
+ b = torch.randn(shape1, device=device, dtype=dtype)
b2 = b.clone()

b.requires_grad_(True)
@@ -163,23 +186,28 @@
res = a_sparse @ b2

assert res.dtype == res_gt.dtype
- assert torch.allclose(res, res_gt, atol=1e-5)
+ assert torch.allclose(
+     res, res_gt, atol=atol
+ ), f"{torch.max(torch.abs(res-res_gt))} - tolerance: {atol}"

res_gt.sum().backward()
res.sum().backward()

a_grad = a.grad.clone().detach()
a_grad[~mask] = 0

- assert torch.allclose(b.grad, b2.grad, atol=1e-5)
- assert torch.allclose(a_grad, a_sparse.grad.to_dense(), atol=1e-5)
+ assert torch.allclose(b.grad, b2.grad, atol=atol)
+ assert torch.allclose(
+     a_grad, a_sparse.grad.to_dense(), atol=atol
+ ), f"{torch.max(torch.abs(a_grad-a_sparse.grad.to_dense()))}"


@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax(tensor_type, device):
N, C, H, W = 8, 2, 64, 64
- dtype = torch.float32
+ dtype, atol = _get_dtype_atol(tensor_type, device)

sparsity = 0.8

shape0 = (N, C, H, W)
@@ -203,7 +231,9 @@
res = res_sparse.to_dense()

assert res.dtype == res_gt.dtype
- assert torch.allclose(res, res_gt)
+ assert torch.allclose(
+     res, res_gt, atol=atol
+ ), f"{torch.max(torch.abs(res- res_gt))}"

# WARNING: gradients are modified in-place!
res_sparse.values().backward(torch.ones_like(res_sparse.values()))
@@ -212,7 +242,9 @@
a_grad = a.grad.clone()
a_grad[~mask] = 0

- assert torch.allclose(a_grad, a_sparse.grad.to_dense(), atol=1e-6)
+ assert torch.allclose(
+     a_grad, a_sparse.grad.to_dense(), atol=atol
+ ), f"{torch.max(torch.abs(a_grad- a_sparse.grad.to_dense()))}"


@pytest.mark.parametrize("tensor_type", _tensor_types)
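
The `_get_dtype_atol` helper added above deals with a precision mismatch rather than a bug: Triton's blocksparse ops use TF32 internally on Ampere-class GPUs, so the dense PyTorch reference must run in the same mode and the comparison tolerance is loosened to 1e-1. A small, hypothetical illustration of why the tolerance changes (not part of this commit):

```python
import torch


def max_matmul_error(allow_tf32: bool) -> float:
    # Compare a CUDA fp32 matmul against an fp64 reference, toggling TF32.
    torch.manual_seed(0)
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    ref = (a.double() @ b.double()).float()
    return (a @ b - ref).abs().max().item()


# On an Ampere GPU the TF32 error is typically orders of magnitude larger than the
# strict fp32 error, which is why the blocksparse path uses atol=1e-1 on CUDA.
print(max_matmul_error(True), max_matmul_error(False))
```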
4 changes: 2 additions & 2 deletions tests/test_triton_basics.py
@@ -36,7 +36,7 @@
if _triton_available:

@triton.jit
- def k_mean(X, Mean, Var, stride, N, **META):
+ def k_mean(X, Mean, Var, stride, N, BLOCK_SIZE_N: tl.constexpr):
# fmt: on
"""
Fused layernorm kernel over a 3d tensor.
@@ -47,7 +47,7 @@ def k_mean(X, Mean, Var, stride, N, **META):
"""

row = tl.program_id(0)
- cols = tl.arange(0, META["BLOCK_SIZE_N"])
+ cols = tl.arange(0, BLOCK_SIZE_N)

# Move to this row
x_ptrs = X + row * stride + cols
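
After the change above, `BLOCK_SIZE_N` is supplied as a keyword `tl.constexpr` argument at launch time instead of being read from `**META`. A hypothetical launch wrapper for `k_mean` might look like the sketch below, assuming `k_mean` is in scope and the input has been flattened to a 2D `(rows, N)` CUDA tensor with `N` a power of two, so no masking is needed.

```python
import torch


def mean_var_rows(x: torch.Tensor):
    # One Triton program per row; BLOCK_SIZE_N is a compile-time constant.
    rows, n = x.shape
    mean = torch.empty(rows, device=x.device, dtype=torch.float32)
    var = torch.empty(rows, device=x.device, dtype=torch.float32)
    k_mean[(rows,)](x, mean, var, x.stride(0), n, BLOCK_SIZE_N=n)
    return mean, var
```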
