[chore][DRAFT] Updating triton to a recent release #418

Closed · wants to merge 1 commit
8 changes: 4 additions & 4 deletions BENCHMARKS.md
@@ -39,7 +39,7 @@ Some examples, generated with `python3 xformers/benchmarks/benchmark_encoder.py

### Fused softmax

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_softmax.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12.
You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_softmax.py`. The units are GB/s. These results are for a nvidia A6000, Triton 2.0 and PyTorch 1.12.
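For orientation, a minimal sketch of the operation this benchmark measures; the `xformers.triton.softmax` import path and the last-dimension behaviour are assumptions and may differ between releases:

```python
import torch
from xformers.triton import softmax  # assumed import path

x = torch.randn(8, 384, 1024, device="cuda", dtype=torch.float16)
y_triton = softmax(x)               # fused Triton softmax over the last dim
y_torch = torch.softmax(x, dim=-1)  # PyTorch baseline for comparison
assert torch.allclose(y_triton, y_torch, atol=1e-2)
```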

![Softmax throughput in fp16 - inference](docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png)

@@ -51,7 +51,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

### Fused linear layer

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_fused_linear_layer.py`. The units are TFlops/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12.
You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_fused_linear_layer.py`. The units are TFlops/s. These results are for a nvidia A6000, Triton 2.0 and PyTorch 1.12.

![Fused linear layers throughput in fp16 - inference](docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png)

@@ -75,7 +75,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

### Fused layer norm

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_layernorm.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12.
You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_layernorm.py`. The units are GB/s. These results are for a nvidia A6000, Triton 2.0 and PyTorch 1.12.

![Fused layer norm throughput in fp16 - inference](docs/plots/layer_norm/LayerNorm_FW_torch.float16.png)

@@ -87,7 +87,7 @@ You can reproduce these numbers locally by running `python3 xformers/benchmarks/

### Fused dropout + bias + activation

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 2.0 and PyTorch 1.12.
You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s. These results are for a nvidia A6000, Triton 2.0 and PyTorch 1.12.

![Fused dropout+ bias throughput in fp16 - inference](docs/plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_gelu.png)

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## TBD
### Fixed
- Updated triton dependency [#418]

### Added

(Several modified binary files could not be displayed.)
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_smelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_BW_squared_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_smelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp16_FW_squared_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_BW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_BW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_BW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_BW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_BW_squared_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_gelu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_leaky_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_none.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_relu.png
Binary file modified docs/plots/fused_linear/FusedLinear_fp32_FW_squared_relu.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp16.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_BW_fp32.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp16.png
Binary file modified docs/plots/fused_softmax/Softmax_Bandwidth_FW_fp32.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW+BW_torch.float16.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW_torch.float16.png
Binary file modified docs/plots/layer_norm/LayerNorm_FW_torch.float32.png
6 changes: 3 additions & 3 deletions docs/source/tutorials/triton.rst
@@ -61,7 +61,7 @@ This is a drop-in replacement to two PyTorch operands: a `torch.nn.Linear`, and
It is possible to skip either the bias or the activation (just use `None` in that case). As of September 2021, this layer is **faster than PyTorch for non-sigmoid activations and fp16**.
In all other usecases, you will be better served using PyTorch.
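A minimal usage sketch (hedged: the `FusedLinear` import path and constructor arguments below are assumptions, not taken from this diff):

.. code-block:: python

    import torch
    from xformers.components.activations import Activation
    from xformers.triton import FusedLinear  # assumed import path

    # Fused linear + bias + activation, replacing nn.Linear followed by GeLU
    layer = FusedLinear(1024, 4096, bias=True, activation=Activation.GeLU).cuda().half()
    x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)
    y = layer(x)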

The following is an example of the measured performance on a laptop nVidia 3080, using Triton 1.1 and PyTorch 1.10.
The following is an example of the measured performance on a nvidia A6000, using Triton 1.1 and PyTorch 1.10.

.. image:: ../../plots/fused_linear/FusedLinear_fp16_FW_gelu.png
:width: 600
@@ -107,7 +107,7 @@ The following is an example of the measured performance on a laptop nVidia 3080,
Fused layer norm
-----------------

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_layernorm.py`. The units are GB/s. These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.
You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_layernorm.py`. The units are GB/s. These results are for a nvidia A6000, Triton 1.1 and PyTorch 1.10.
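A hedged usage sketch of the layer being benchmarked (the `FusedLayerNorm` import path is an assumption):

.. code-block:: python

    import torch
    from xformers.triton import FusedLayerNorm  # assumed import path

    ln = FusedLayerNorm(1024).cuda().half()  # drop-in for torch.nn.LayerNorm(1024)
    y = ln(torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16))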

.. image:: ../../plots/layer_norm/LayerNorm_FW_torch.float16.png
:width: 600
@@ -130,7 +130,7 @@ Fused dropout + bias + activation
---------------------------------

You can reproduce these numbers locally by running `python3 xformers/benchmarks/benchmark_triton_dropout.py`. The units are GB/s.
These results are for a laptop nVidia 3080, Triton 1.1 and PyTorch 1.10.
These results are for a nvidia A6000, Triton 1.1 and PyTorch 1.10.


.. image:: ../../plots/fused_dropout/Dropout_Bias_True_FW_torch.float16_Act_gelu.png
2 changes: 1 addition & 1 deletion requirements-test.txt
@@ -27,4 +27,4 @@ hydra-core >= 1.1
fairscale >= 0.4.5

# Dependency for fused layers, optional
triton == 2.0.0.dev20220701
triton == 2.0.0.dev20220830
47 changes: 47 additions & 0 deletions tests/test_triton_basics.py
@@ -131,3 +131,50 @@ def test_sum_strided_asserts():
with pytest.raises(AssertionError):
# This kernel expects 2D tensors, assert to prevent misuse
sum_2d_dim_0(a)

@triton.jit
def k_rand(X, Y, SEED_X, SEED_Y, stride_x, stride_y, N: tl.constexpr):
# fmt: on
"""
Check the random number generation
"""

row = tl.program_id(0)

# Generate random numbers with seed A
rand_offsets = tl.arange(0, N)
seed_x = tl.load(SEED_X + row)
randx, _, _, _ = tl.randint4x(seed_x, rand_offsets)

rand_offsets = tl.arange(0, N)
seed_y = tl.load(SEED_Y + row)
randy, _, _, _ = tl.randint4x(seed_y, rand_offsets)

# Move to this row
tl.store(X + row * stride_x + tl.arange(0, N), randx)
tl.store(Y + row * stride_y + tl.arange(0, N), randy)

def test_rand():
# Check that the random generator used in triton works fine
torch.random.manual_seed(0)
x = torch.zeros((512, 32), device=torch.device("cuda"), dtype=torch.int32)
y = torch.zeros((512, 32), device=torch.device("cuda"), dtype=torch.int32)

M, N = x.shape

seeds_x = torch.randint(65536, (M,), device=x.device)
seeds_y = torch.randint(65536, (M,), device=x.device)

assert not torch.allclose(seeds_x, seeds_y)

# enqueue kernels, one per line
# fmt: off
k_rand[(M,)](
x, y,
seeds_x, seeds_y,
x.stride(0), y.stride(0),
N,
)
# fmt: on

assert not torch.allclose(x, y)
14 changes: 6 additions & 8 deletions tests/test_triton_dropout.py
@@ -30,13 +30,11 @@
# Testing odd (non-power-of-two for instance) shapes on purpose
SHAPES = [
(384, 512),
(8, 384, 128),
(8, 784, 512),
(4, 16, 384),
(4, 16, 1024),
(2, 16, 2048),
(2, 16, 4096),
(1, 16, 12288),
]


@@ -98,6 +96,11 @@ def test_dropout(shape, amp, bias, p):
x_ref = (x + b if bias else x).to(y.dtype)
assert not torch.allclose(x_ref, y, rtol=tol)

# Check that the drop probability is about right
y = triton_dropout(x, p=p)
drop_p = (y.numel() - y.count_nonzero()) / y.numel()
assert abs(drop_p - p) < 0.01

# Check that the drops are different for every row (could catch broken seeds per row)
y = triton_dropout(x, p=0.5)

@@ -114,12 +117,7 @@ def test_dropout(shape, amp, bias, p):
assert (
not torch.sum(torch.eq(y_a[0, :] == 0.0, y_b[0, :] == 0.0)).item()
== y.shape[1]
)

# Check that the drop probability is about right
y = triton_dropout(x, p=p)
drop_p = (y.numel() - y.count_nonzero()) / y.numel()
assert abs(drop_p - p) < 0.01
), f"{y_a}\n{y_b}"

# Check that the same seeds lead to the same dropout
torch.manual_seed(0)
55 changes: 17 additions & 38 deletions xformers/triton/dropout.py
@@ -4,32 +4,28 @@
# LICENSE file in the root directory of this source tree.


# CREDITS: This comes almost as-is from the Triton dropout tutorial
# CREDITS: This is heavily inspired by the Triton dropout tutorial
# https://raw.githubusercontent.com/openai/triton/master/python/tutorials/04-low-memory-dropout.py

from typing import Any, Optional
from typing import Optional

import torch
import triton
from torch.cuda.amp import custom_bwd, custom_fwd

from xformers.components.activations import Activation, build_activation
from xformers.triton.k_activations import (
get_triton_activation_bwd_kernel,
get_triton_activation_kernel,
)
from xformers.triton.k_activations import get_triton_activation_index
from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw

GROUP_M = 32
BLOCK_M = GROUP_M // 4
BLOCK_M = 32
BLOCK_N = 128


# Helper to handle the SPMD launch grid and error cases
class _dropout(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, x, p, bias, activation, activation_grad, trainable_bias):
def forward(ctx, x, p, bias, activation, trainable_bias):
# Soft-flatten an hypothetical 3rd dimension
x_ = x.reshape(-1, x.shape[-1]).contiguous()
y = torch.empty_like(x_)
@@ -39,19 +35,16 @@ def forward(ctx, x, p, bias, activation, activation_grad, trainable_bias):
assert p > 0.0

def grid(meta):
# NOTE: We use Triton Philox random number generator, which optimally generates 4 blocks for
# a given seed and offsets. "BLOCK_M" here describes the size of one of these blocks
# but we need to take this factor of 4 into account when scheduling all the kernels
return (
triton.cdiv(M, meta["BLOCK_M"] * 4),
triton.cdiv(M, meta["BLOCK_M"]),
triton.cdiv(N, meta["BLOCK_N"]),
)

N_BLOCK_N = triton.cdiv(N, BLOCK_N)

# Generate one seed per sample
# seed max is int32 max for positive numbers: 2**16
seeds = torch.randint(65536, (N_BLOCK_N,), device=x.device).to(torch.int32)
seeds = torch.randint(65536, (N_BLOCK_N,), device=x.device, dtype=torch.int32)

# fmt: off
bias_ptr = bias if bias is not None else x_ # Possibly not being used
@@ -77,7 +70,7 @@ def grid(meta):
ctx.save_for_backward(seeds, bias, None)

ctx.trainable_bias = bias is not None and trainable_bias
ctx.activation_grad = activation_grad
ctx.activation = activation
ctx.p = p

return y.reshape_as(x)
@@ -96,7 +89,7 @@ def backward(
M, N = grad_out_.shape

# Optional inputs to compute the activation contribution to the gradient
assert inputs is not None or ctx.activation_grad is None
assert inputs is not None or ctx.activation is None

if inputs is None:
inputs = grad_out_
@@ -105,11 +98,10 @@

# We split the problem in tiles:
# - over M there will be a follow up reduction
# - over M, we go by 4 tiles at at time (consequence of the random number generation)
# - over N we compromise in between trying to use as much memory paralellism as possible,
# (fill in the warps, there are 32 threads per warps, and 4 warps default), and not being too
# big because of register spilling
N_BLOCKS_M = triton.cdiv(M, GROUP_M)
N_BLOCKS_M = triton.cdiv(M, BLOCK_M)

if ctx.trainable_bias:
grad_bias = torch.empty(
@@ -129,7 +121,7 @@ def grid(meta):
# a given seed and offsets. "BLOCK_M" here describes the size of one of these blocks
# but we need to take this factor of 4 into account when scheduling all the kernels
return (
triton.cdiv(M, meta["BLOCK_M"] * 4),
triton.cdiv(M, meta["BLOCK_M"]),
triton.cdiv(N, meta["BLOCK_N"]),
)

Expand All @@ -143,7 +135,7 @@ def grid(meta):
ctx.p,
grad_in.dtype == torch.float16,
USE_BIAS=bias is not None,
ACTIVATION_GRAD=ctx.activation_grad,
ACTIVATION=ctx.activation,
TRAINABLE_BIAS=ctx.trainable_bias,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
@@ -185,14 +177,12 @@ def dropout(
return x

# The normal triton enabled codepath
act_kernel = get_triton_activation_kernel(activation)
act_grad_kernel = get_triton_activation_bwd_kernel(activation)
activation_index = get_triton_activation_index(activation)
return _dropout.apply(
x,
float(p),
bias,
act_kernel,
act_grad_kernel,
activation_index,
bias is not None and bias.requires_grad,
)
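
For reference, a hedged sketch of calling this functional entry point; the keyword names mirror the signature above, while the `xformers.triton` export path and the tensor shape are assumptions:

```python
import torch
from xformers.components.activations import Activation
from xformers.triton import dropout as triton_dropout  # assumed export path

x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)
# One fused kernel for (optional) bias + activation + dropout
y = triton_dropout(x, p=0.1, bias=None, activation=Activation.GeLU)
```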

@@ -224,9 +214,8 @@ def __init__(
else None
)

self.activation: Optional[Any] = None
self.activation_grad: Optional[Any] = None
self.activation_pytorch: Optional[Any] = None
self.activation = get_triton_activation_index(self.activation_type)
self.activation_pytorch = build_activation(self.activation_type)

def init_weights(self, *args, **kwargs):
with torch.no_grad():
@@ -238,14 +227,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.bias is not None:
self.bias = self.bias.to(dtype=x.dtype, device=x.device) # type: ignore

# Lazy init (helps with pickling)
if self.activation is None or self.activation_pytorch is None:
self.activation = get_triton_activation_kernel(self.activation_type)
self.activation_pytorch = build_activation(self.activation_type)
self.activation_grad = get_triton_activation_bwd_kernel(
self.activation_type
)

# Train/inference
p = self.p if self.training else 0.0

@@ -259,6 +240,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.dropout(x, p) if p > 0.0 else x

# The normal, Triton-backed path
return _dropout.apply(
x, p, self.bias, self.activation, self.activation_grad, True
)
return _dropout.apply(x, p, self.bias, self.activation, True)
18 changes: 16 additions & 2 deletions xformers/triton/k_activations.py
@@ -42,6 +42,20 @@ def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
)


def get_triton_activation_index(activation: Optional[Activation]) -> Optional[int]:
return (
{
Activation.ReLU: 1,
Activation.LeakyReLU: 2,
Activation.GeLU: 3,
Activation.SquaredReLU: 4,
Activation.SmeLU: 5,
}[activation]
if activation
else 0
)
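
# Editorial sketch, not part of this diff: one way such an integer index is
# consumed inside a Triton kernel. Because ACTIVATION is a tl.constexpr, the
# branches are resolved at compile time, so only the selected activation is
# code-generated. The helper name and the index values covered are illustrative.
@triton.jit
def _apply_activation_sketch(x, ACTIVATION: tl.constexpr):
    if ACTIVATION == 1:
        x = tl.where(x >= 0, x, 0.0)  # ReLU
    elif ACTIVATION == 4:
        x_ = tl.where(x >= 0, x, 0.0)
        x = x_ * x_  # Squared ReLU
    return x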


@triton.jit
def tanh(x):
# Tanh is just a scaled sigmoid
@@ -86,8 +100,8 @@ def squared_relu(x):

.. _Primer: https://arxiv.org/abs/2109.08668
"""
x_ = relu(x)
return (x_ * x_).to(x.dtype)
x_ = tl.where(x >= 0, x, 0.0)
return x_ * x_


@triton.jit