[backend] 3/3 Triton 2 update #272

Merged (5 commits, Apr 21, 2022)
10 changes: 4 additions & 6 deletions BENCHMARKS.md
@@ -39,8 +39,7 @@ Some examples, generated with `python3 xformers/benchmarks/benchmark_encoder.py

### Fused softmax

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_softmax.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_softmax.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12.

![Softmax throughput in fp16 - inference](docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png)
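For context on what these GB/s figures measure: the number is a memory bandwidth, the bytes the softmax has to read and write divided by the measured runtime. A rough, hypothetical sketch of the same computation in plain PyTorch (shapes and the timing loop are illustrative assumptions, not the settings of `benchmark_triton_softmax.py`):

```python
import torch

def softmax_bandwidth_gbs(batch=32, seq=1024, dtype=torch.float16, iters=100):
    x = torch.randn(batch, seq, seq, device="cuda", dtype=dtype)
    for _ in range(10):  # warmup
        torch.softmax(x, dim=-1)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        torch.softmax(x, dim=-1)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / iters
    # Softmax reads its input once and writes its output once
    bytes_moved = 2 * x.numel() * x.element_size()
    return bytes_moved / (ms * 1e-3) / 1e9

if torch.cuda.is_available():
    print(f"torch.softmax: ~{softmax_bandwidth_gbs():.0f} GB/s")
```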

@@ -52,8 +51,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

### Fused linear layer

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_fused_linear_layer.py`. The units are TFlops/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_fused_linear_layer.py`. The units are TFlops/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12.

![Fused linear layers throughput in fp16 - inference](docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png)
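The TFlops/s figures are compute throughput rather than bandwidth: the multiply-add count of the underlying matmul divided by runtime. A hypothetical back-of-the-envelope version (shapes and the 2 * M * N * K flop count are assumptions, not taken from `benchmark_triton_fused_linear_layer.py`):

```python
import torch

def linear_tflops(tokens=4096, d_in=1024, d_out=4096, dtype=torch.float16, iters=100):
    x = torch.randn(tokens, d_in, device="cuda", dtype=dtype)
    linear = torch.nn.Linear(d_in, d_out, device="cuda", dtype=dtype)
    for _ in range(10):  # warmup
        linear(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        linear(x)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / iters
    flops = 2 * tokens * d_in * d_out  # multiply-adds of the matmul
    return flops / (ms * 1e-3) / 1e12

if torch.cuda.is_available():
    print(f"nn.Linear: ~{linear_tflops():.1f} TFlops/s")
```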

@@ -77,7 +75,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

### Fused layer norm

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_layernorm.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.
You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_layernorm.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12.

![Fused layer norm throughput in fp16 - inference](docs/plots/layer_norm/LayerNorm_FW_torch.float16.png)

@@ -89,7 +87,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

### Fused dropout + bias + activation

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.
You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12.

![Fused dropout+ bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_gelu.png)
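For reference, the operation being fused here is three elementwise steps: bias add, activation, then dropout. A plain PyTorch equivalent of the math (a sketch assuming the usual bias, activation, dropout order; this is not the xformers API):

```python
import torch
import torch.nn.functional as F

def dropout_bias_act(x, bias, p=0.1, activation=F.gelu, training=True):
    # The fused Triton kernel performs these three steps in a single memory pass
    return F.dropout(activation(x + bias), p=p, training=training)

if torch.cuda.is_available():
    x = torch.randn(16, 1024, 1024, device="cuda", dtype=torch.float16)
    bias = torch.zeros(1024, device="cuda", dtype=torch.float16)
    y = dropout_bias_act(x, bias)
```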

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Mem efficient attention, FW pass [#267]
- MHA benchmark
- MLP benchmark
- Move all triton kernels to triton v2 [#272]

## [0.0.10] - 2022-03-14
### Fixed
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp16.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp32.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp32.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW+BW_torch.float16.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW_torch.float16.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW_torch.float32.png
1 change: 0 additions & 1 deletion docs/source/conf.py
@@ -37,7 +37,6 @@
# The full version, including alpha/beta/rc tags
release = "0.0.10"


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
2 changes: 1 addition & 1 deletion examples/microGPT.py
@@ -21,7 +21,7 @@


class GPT(pl.LightningModule):
""" the full GPT language model, with a context size of block_size """
"""the full GPT language model, with a context size of block_size"""

def __init__(
self,
7 changes: 3 additions & 4 deletions experimental/ragged_inference/garbage_pad_ragged_acts.py
@@ -18,9 +18,9 @@ def garbage_pad_ragged_acts_kernel(
ragged_acts_offset_per_seq_ptr,
n_ctx_per_seq_ptr,
padded_acts_ptr,
**meta, # Optional meta-parameters for the kernel
BLOCK_SIZE: tl.constexpr, # How many inputs each program should process
n_ctx_max: tl.constexpr,
):
BLOCK_SIZE = meta["d_model"] # How many inputs each program should process
# There are multiple 'program's processing different data. We identify which program
# we are here

@@ -47,7 +47,6 @@ def garbage_pad_ragged_acts_kernel(
acts = tl.load(ragged_acts_ptr + ragged_acts_offsets, mask=ctx_idx_too_large_mask)

# Calculate the offsets for the padded acts
n_ctx_max = meta["n_ctx_max"]
padded_acts_offset = n_ctx_max * seq_idx * BLOCK_SIZE

# Write things back, again masking out the sections that would be garbage
@@ -153,7 +152,7 @@ def triton_to_garbage_padded(self) -> torch.Tensor:
torch.tensor(ragged_acts_offset_per_seq, device="cuda"),
torch.tensor(self.n_ctx_per_seq, device="cuda"),
padded_acts,
d_model=d_model,
BLOCK_SIZE=d_model,
n_ctx_max=n_ctx_max,
)
return padded_acts
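The changes above are the essence of the Triton 1 to Triton 2 migration: compile-time values no longer travel through a `**meta` dict inside the kernel; they are declared as `tl.constexpr` arguments and passed by keyword at launch time (here `d_model=d_model` becomes `BLOCK_SIZE=d_model`). A minimal, self-contained sketch of the same pattern, using a hypothetical copy kernel that is not part of this PR:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # BLOCK_SIZE is a compile-time constant, so it can be used by tl.arange
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(src_ptr + offsets, mask=mask)
    tl.store(dst_ptr + offsets, x, mask=mask)


def copy(src: torch.Tensor) -> torch.Tensor:
    assert src.is_cuda and src.is_contiguous()
    dst = torch.empty_like(src)
    n = src.numel()
    grid = (triton.cdiv(n, 1024),)
    # The constexpr is supplied by name at launch, like BLOCK_SIZE=d_model above
    copy_kernel[grid](src, dst, n, BLOCK_SIZE=1024)
    return dst
```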
1 change: 0 additions & 1 deletion requirements-benchmark.txt
@@ -7,5 +7,4 @@ scikit-learn == 0.24.1
tqdm == 4.59.0
pandas == 1.2.4
seaborn == 0.11.1
triton == 1.1.2.dev20220106
pytorch-lightning >= 1.3
3 changes: 3 additions & 0 deletions requirements-test.txt
@@ -24,3 +24,6 @@ hydra-core >= 1.1

# Dependency for Mixture of Experts
fairscale >= 0.4.5

# Dependency for fused layers, optional
triton == 2.0.0.dev20220403
66 changes: 49 additions & 17 deletions tests/test_sparse_tensors.py
@@ -19,7 +19,7 @@
def _create_blocksparse_tensor(
device, block_size=32, Z=8, C=2, H=64, W=64, dtype=torch.float32
):
layout = torch.randint(2, (C, H // block_size, W // block_size))
layout = torch.randint(2, (C, H // block_size, W // block_size), device=device)
layout[:, :, 0] = 1
layout[:, 0, :] = 1
values = torch.randn(Z, layout.sum(), block_size, block_size, device=device).to(
@@ -56,6 +56,29 @@ def _create_tensor(tensor_type, device, dtype, shape, sparsity):
)


def _seed():
torch.random.manual_seed(42)
torch.cuda.manual_seed_all(42)


def _get_dtype_atol(tensor_type, device: str):
_seed()
Review comment (Contributor Author): this was to remove some reproducibility issues in between circleci and my machine..


if tensor_type == BlockSparseTensor and "cuda" in device:
# Upstream GPU blocksparse (Triton op) uses TF32 by default for all internal computations
# TF32 has the precision of fp16 but the range of fp32
# See https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
torch.backends.cuda.matmul.allow_tf32 = True
Review comment (Contributor Author): @fmassa this seems to be a better fit following the switch to triton2, which internally moved all tl.dot() operations to tf32
Review comment (Contributor Author): cc @ptillet, just swapping triton 1.1 for 2.dev meant that this test would not pass anymore, as we discussed
Review comment (Contributor): SGTM wrt the tests!

torch.backends.cudnn.allow_tf32 = True
return torch.float32, 1e-1
Review comment (Contributor): wow, that is quite some low precision...


# Force pytorch to keep its computations as float32 (will default to tf32 with recent cuda and ampere+ GPU)
torch.backends.cuda.matmul.allow_tf32 = False
Review comment (Contributor Author): @fmassa this fixed issues that I was seeing with these unit tests on an ampere GPU, which I presume stemmed from the fact that the sparse kernels were fp32 while pytorch defaulted to tf32
Review comment (Contributor): oh wow, thanks for spotting this! One more instance where tf32 is being somewhat harmful. Maybe worth commenting on pytorch/pytorch#67384 ?
Review comment (Contributor Author): It's a strange format, range of fp32 but precision of fp16, it's also kind of peculiar that it's really 18bits but named tf32..

torch.backends.cudnn.allow_tf32 = False

return torch.float32, 1e-5
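The relaxed 1e-1 tolerance above reflects TF32 keeping only 10 mantissa bits for matmul inputs. A quick, hypothetical way to see the gap on an Ampere-or-newer GPU (an illustration, not part of the test suite):

```python
import torch

def tf32_vs_fp32_max_error(n=1024):
    a = torch.randn(n, n, device="cuda")
    b = torch.randn(n, n, device="cuda")
    torch.backends.cuda.matmul.allow_tf32 = False
    ref = a @ b       # full fp32 matmul
    torch.backends.cuda.matmul.allow_tf32 = True
    approx = a @ b    # inputs rounded to TF32 on Ampere and newer GPUs
    return (ref - approx).abs().max().item()

if torch.cuda.is_available():
    # Much larger than fp32 rounding error, hence the relaxed atol above
    print(tf32_vs_fp32_max_error())
```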


@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("func", [torch.add, torch.mul])
def test_sparse_binary_ops(func, device):
@@ -83,6 +106,7 @@ def test_sparse_binary_ops(func, device):
def test_masked_matmul(tensor_type, device):
N, C, H, W, L = 8, 2, 64, 64, 32
sparsity = 0.7
dtype, atol = _get_dtype_atol(tensor_type, device)

shape0 = (N, C, H, W)
shape1 = (N, C, H, L)
@@ -98,8 +122,8 @@ def test_masked_matmul(tensor_type, device):
)
mask = mask_sparse.to_dense()

a = torch.randn(shape1, device=device)
b = torch.randn(shape2, device=device)
a = torch.randn(shape1, device=device, dtype=dtype)
b = torch.randn(shape2, device=device, dtype=dtype)

aa = a.clone()
bb = b.clone()
@@ -119,24 +143,23 @@ def test_masked_matmul(tensor_type, device):
res_dense = torch.where(mask, res_dense, torch.full_like(res_dense, float("-inf")))

assert res.dtype == res_gt.dtype
assert torch.allclose(res_dense, res_gt, atol=5e-6)
assert torch.allclose(res_dense, res_gt, atol=atol)

# try to workaround non-contiguous issues with triton for now
res_gt.backward(torch.ones_like(res_gt))
res.values().backward(torch.ones_like(res.values()))
# TODO: this is not passing for BlockSparse!!!
if tensor_type != BlockSparseTensor:
assert torch.allclose(a.grad, aa.grad, atol=5e-6)
assert torch.allclose(b.grad, bb.grad, atol=5e-6)

assert torch.allclose(a.grad, aa.grad, atol=atol)
assert torch.allclose(b.grad, bb.grad, atol=atol)


@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_bmm(tensor_type, device):
N, C, H, W, L = 8, 2, 64, 64, 32
dtype = torch.float32
sparsity = 0.8
dtype, atol = _get_dtype_atol(tensor_type, device)

sparsity = 0.8
shape0 = (N, C, H, W)
shape1 = (N, C, W, L)

@@ -153,7 +176,7 @@ def test_bmm(tensor_type, device):
a_sparse.requires_grad_(True)
a.requires_grad_(True)

b = torch.randn(shape1, device=device)
b = torch.randn(shape1, device=device, dtype=dtype)
b2 = b.clone()

b.requires_grad_(True)
@@ -163,23 +186,28 @@ def test_bmm(tensor_type, device):
res = a_sparse @ b2

assert res.dtype == res_gt.dtype
assert torch.allclose(res, res_gt, atol=1e-5)
assert torch.allclose(
res, res_gt, atol=atol
), f"{torch.max(torch.abs(res-res_gt))} - tolerance: {atol}"

res_gt.sum().backward()
res.sum().backward()

a_grad = a.grad.clone().detach()
a_grad[~mask] = 0

assert torch.allclose(b.grad, b2.grad, atol=1e-5)
assert torch.allclose(a_grad, a_sparse.grad.to_dense(), atol=1e-5)
assert torch.allclose(b.grad, b2.grad, atol=atol)
assert torch.allclose(
a_grad, a_sparse.grad.to_dense(), atol=atol
), f"{torch.max(torch.abs(a_grad-a_sparse.grad.to_dense()))}"


@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax(tensor_type, device):
N, C, H, W = 8, 2, 64, 64
dtype = torch.float32
dtype, atol = _get_dtype_atol(tensor_type, device)

sparsity = 0.8

shape0 = (N, C, H, W)
@@ -203,7 +231,9 @@ def test_sparse_softmax(tensor_type, device):
res = res_sparse.to_dense()

assert res.dtype == res_gt.dtype
assert torch.allclose(res, res_gt)
assert torch.allclose(
res, res_gt, atol=atol
), f"{torch.max(torch.abs(res- res_gt))}"

# WARNING: gradients are modified in-place!
res_sparse.values().backward(torch.ones_like(res_sparse.values()))
@@ -212,7 +242,9 @@ def test_sparse_softmax(tensor_type, device):
a_grad = a.grad.clone()
a_grad[~mask] = 0

assert torch.allclose(a_grad, a_sparse.grad.to_dense(), atol=1e-6)
assert torch.allclose(
a_grad, a_sparse.grad.to_dense(), atol=atol
), f"{torch.max(torch.abs(a_grad- a_sparse.grad.to_dense()))}"


@pytest.mark.parametrize("tensor_type", _tensor_types)
4 changes: 2 additions & 2 deletions tests/test_triton_basics.py
@@ -36,7 +36,7 @@
if _triton_available:

@triton.jit
def k_mean(X, Mean, Var, stride, N, **META):
def k_mean(X, Mean, Var, stride, N, BLOCK_SIZE_N: tl.constexpr):
# fmt: on
"""
Fused layernorm kernel over a 3d tensor.
@@ -47,7 +47,7 @@ def k_mean(X, Mean, Var, stride, N, **META):
"""

row = tl.program_id(0)
cols = tl.arange(0, META["BLOCK_SIZE_N"])
cols = tl.arange(0, BLOCK_SIZE_N)

# Move to this row
x_ptrs = X + row * stride + cols
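With the Triton 2 signature, `BLOCK_SIZE_N` is a `tl.constexpr` passed by keyword at launch rather than read from `META`. A hypothetical invocation of this test kernel (grid, buffer shapes, and dtypes are illustrative, and it assumes `k_mean` from above is in scope and masks columns past `N`):

```python
import torch
import triton

def run_k_mean(x: torch.Tensor):
    # x viewed as a 2D (rows, cols) CUDA tensor; one Triton program per row
    rows, cols = x.shape
    mean = torch.empty(rows, device=x.device, dtype=torch.float32)
    var = torch.empty(rows, device=x.device, dtype=torch.float32)
    # tl.arange needs a power-of-two extent covering the row length
    block = triton.next_power_of_2(cols)
    k_mean[(rows,)](x, mean, var, x.stride(0), cols, BLOCK_SIZE_N=block)
    return mean, var
```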