In [1]:
# https://pytorch.org/tutorials/intermediate/nvfuser_intro_tutorial.html
# Authors: 
#     Christian Sarofeen > https://github.com/csarofeen
#     Piotr Bialecki > https://github.com/ptrblck
#     Kevin Stephano > https://github.com/kevinstephano
#     Jie Jiang > https://github.com/jjsjann123
#     Masaki Kozuki https://github.com/crcrpar
#     Neal Vaidya
#     IvanYashchuk https://github.com/IvanYashchuk > https://twitter.com/IvanYashchuk
#     Svetlana Karslioglu https://github.com/svekars > 
    
import torch
import torch.nn.functional as F
import functorch
from functorch.compile import memory_efficient_fusion
from copy import deepcopy
from typing import List
import time
import functools
import random

# Seed every RNG the notebook touches: Python's `random` picks the dynamic
# shapes below, and torch's generator drives tensor initialisation and dropout.
random.seed(42)
torch.manual_seed(42)

torch_version = torch.__version__
print(f'torch_version: {torch_version}')
# `torch.__version__` is a TorchVersion object, which can be compared with tuples.
if torch_version < (1, 12, 0):
    raise RuntimeError(
        "PyTorch >= 1.12.0 required, but your environment uses torch=={}".format(
            torch.__version__
        )
    )

functorch_version = functorch.__version__
print(f'functorch_version: {functorch_version}')
# Only the major/minor components matter for this check. The remainder of the
# string (patch number, local suffix such as "+cu117", extra dot-parts) is
# ignored rather than unpacked, so longer version strings do not break parsing.
major, minor = (int(part) for part in functorch_version.split(".")[:2])
if major == 0 and minor < 2:
    raise RuntimeError(
        "FuncTorch >= 0.2.0 required, but your environment uses functorch=={}".format(
            functorch.__version__
        )
    )

torch_version: 1.13.1+cu117
functorch_version: 1.13.1+cu117


In [2]:
def composite_definition(
    input1: torch.Tensor,
    input2: torch.Tensor,
    weight: torch.Tensor,
    bias1: torch.Tensor,
    bias2: torch.Tensor,
    normalization_axis: int,
    dropout_prob: float,
) -> torch.Tensor:
    """Bias-add -> dropout -> residual-add -> LayerNorm, using PyTorch's built-in ops.

    Args:
        input1: tensor that receives ``bias1`` and then dropout.
        input2: residual tensor added after dropout.
        weight: LayerNorm scale (the ``weight`` argument of ``F.layer_norm``).
        bias1: bias added to ``input1`` before dropout.
        bias2: LayerNorm shift (the ``bias`` argument of ``F.layer_norm``).
        normalization_axis: dimension of ``input1`` whose size is used as the
            LayerNorm ``normalized_shape``. ``F.layer_norm`` normalizes over
            trailing dimensions, so this is expected to be the last axis.
        dropout_prob: dropout probability; dropout is always applied
            (``training=True``).

    Returns:
        The normalized tensor.
    """
    residual = F.dropout(input1 + bias1, dropout_prob, training=True) + input2
    return F.layer_norm(
        residual, (input1.size(normalization_axis),), weight, bias2
    )

In [3]:
# Shared tensors and parameters for every benchmark below.
# `input_size` stays a list because the dynamic-shape cell copies and mutates it.
input_size = [64, 128, 1024]
device = "cuda"
dtype = torch.float32
hidden_size = input_size[2]  # length of the normalized (last) dimension
tensor_opts = dict(device=device, dtype=dtype)

# Sample activations; input2 takes its shape/device/dtype from input1 and
# is separately marked as requiring gradients.
input1 = torch.randn(*input_size, requires_grad=True, **tensor_opts)
input2 = torch.rand_like(input1).requires_grad_()

# Upstream gradient handed to .backward(); same shape as the activations.
grad_output = torch.rand_like(input1)

# Randomly initialised model parameters, created in the order weight, bias1, bias2.
weight, bias1, bias2 = (
    torch.nn.Parameter(torch.randn(hidden_size, **tensor_opts)) for _ in range(3)
)

parameters = [input1, input2, weight, bias1, bias2]

In [4]:
# Utility to profile the workload
def _run_forward_backward(forward_func, grad_output, params, count):
    """Run ``count`` forward+backward passes, resetting ``params`` grads after each."""
    for _ in range(count):
        output = forward_func()
        output.backward(grad_output)
        # Delete gradients so gradient accumulation is not part of what is measured.
        for p in params:
            p.grad = None


def _synchronize_if_cuda():
    """Block until queued CUDA work finishes, so wall-clock timings include GPU kernels."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def profile_workload(
    forward_func,
    grad_output,
    iteration_count=100,
    label="",
    params=None,
    warmup_count=3,
):
    """Time ``iteration_count`` forward+backward passes of ``forward_func``.

    Args:
        forward_func: zero-argument callable returning the tensor to backprop from.
        grad_output: gradient passed to ``output.backward``.
        iteration_count: number of timed iterations.
        label: optional heading printed before the result.
        params: tensors whose ``.grad`` is reset to ``None`` after every
            iteration. Defaults to the notebook-level ``parameters`` list.
        warmup_count: untimed iterations run first (e.g. so JIT compilation
            is not included in the measurement).

    Prints the average number of iterations per second; returns ``None``.
    """
    if params is None:
        # Backward-compatible default: the global list built in the setup cell.
        params = parameters

    _run_forward_backward(forward_func, grad_output, params, warmup_count)

    # Synchronize before starting and before stopping the timer so that
    # asynchronously launched GPU work is fully counted.
    _synchronize_if_cuda()
    start = time.perf_counter()
    _run_forward_backward(forward_func, grad_output, params, iteration_count)
    _synchronize_if_cuda()
    stop = time.perf_counter()

    iters_per_second = iteration_count / (stop - start)
    if label:
        print(label)
    print("Average iterations per second: {:.2f}".format(iters_per_second))

In [5]:
# Eager-mode baseline for the composite definition: bind every argument by
# name up front so the profiler can call the workload with no arguments.
func = functools.partial(
    composite_definition,
    input1=input1,
    input2=input2,
    weight=weight,
    bias1=bias1,
    bias2=bias2,
    normalization_axis=2,
    dropout_prob=0.1,
)
profile_workload(
    func,
    grad_output,
    iteration_count=100,
    label="Eager Mode - Composite definition",
)

Eager Mode - Composite definition
Average iterations per second: 216.45


### TorchScript & nvFuser
nvFuser is the default fusion system in TorchScript since PyTorch version 1.12, so to use nvFuser we need to enable TorchScript. This allows nvFuser to automatically generate fast kernels and take over execution of these operations. TorchScript can be challenging to get working, but with the current definition of our operations, all we need to do is wrap our function in `torch.jit.script`. We can then run our workload exactly as before.



In [6]:
# Compile the same function with TorchScript (which routes it to nvFuser)
# and profile it with identical inputs.
scripted_composite_definition = torch.jit.script(composite_definition)
func = functools.partial(
    scripted_composite_definition,
    input1=input1,
    input2=input2,
    weight=weight,
    bias1=bias1,
    bias2=bias2,
    normalization_axis=2,
    dropout_prob=0.1,
)
profile_workload(
    func,
    grad_output,
    iteration_count=100,
    label="TorchScript - Composite definition",
)

TorchScript - Composite definition
Average iterations per second: 344.64


### nvFuser & Dynamic Shapes
It is challenging for deep learning compilers to provide performance gains when the user changes the input sizes of the tensors. However, supporting changing shapes has always been a fundamental design criterion for nvFuser, as processing different-sized input tensors is critical to many applications like Natural Language Processing and Graph Neural Networks.

In [7]:
# Build SHAPE_COUNT input sets whose first two dimensions are jittered by up to
# +/-2 around `input_size`. The last dimension is left unchanged, so it still
# matches the size of weight/bias1/bias2.
SHAPE_COUNT = 20
dynamic_sizes = deepcopy(input_size)

inputs1: List[torch.Tensor] = []
inputs2: List[torch.Tensor] = []
grad_outputs: List[torch.Tensor] = []


# Create some random shapes
for _ in range(SHAPE_COUNT):
    # randrange(-2, 3) draws from {-2, -1, 0, 1, 2}.
    dynamic_sizes[0] = input_size[0] + random.randrange(-2, 3)
    dynamic_sizes[1] = input_size[1] + random.randrange(-2, 3)
    # Named `dynamic_input` (not `input`) to avoid shadowing the builtin.
    dynamic_input = torch.randn(
        *dynamic_sizes, device=device, dtype=dtype, requires_grad=True
    )
    inputs1.append(dynamic_input)
    # rand_like does not copy requires_grad; enable it explicitly to match the
    # static setup, so input2's gradient is computed in both benchmarks.
    inputs2.append(torch.rand_like(dynamic_input).requires_grad_())
    grad_outputs.append(torch.rand_like(dynamic_input))

In [8]:
# Perform warm-up iterations of the scripted function, using only the first
# pre-built shape; the timed loop that follows cycles through all shapes.
for _ in range(3):
    dynamic_input1 = inputs1[0]
    dynamic_input2 = inputs2[0]
    dynamic_grad_output = grad_outputs[0]
    # Run model, forward and backward
    output = scripted_composite_definition(
        dynamic_input1,
        dynamic_input2,
        weight,
        bias1,
        bias2,
        normalization_axis=2,
        dropout_prob=0.1,
    )
    output.backward(dynamic_grad_output)
    # Clear gradients, as profile_workload does. Otherwise warm-up gradients
    # accumulate and the first timed iteration pays for accumulation.
    for p in (dynamic_input1, dynamic_input2, weight, bias1, bias2):
        p.grad = None

In [9]:
# Timed manually: profile_workload calls one fixed-argument callable, whereas
# here every iteration switches to the next pre-built shape.
iteration_count = 100
torch.cuda.synchronize()  # flush pending GPU work before the clock starts
start = time.perf_counter()
for i in range(iteration_count):
    shape_idx = i % SHAPE_COUNT
    dynamic_input1, dynamic_input2, dynamic_grad_output = (
        inputs1[shape_idx],
        inputs2[shape_idx],
        grad_outputs[shape_idx],
    )
    dynamic_parameters = [dynamic_input1, dynamic_input2, weight, bias1, bias2]

    output = scripted_composite_definition(
        dynamic_input1,
        dynamic_input2,
        weight,
        bias1,
        bias2,
        normalization_axis=2,
        dropout_prob=0.1,
    )
    output.backward(dynamic_grad_output)
    # Reset gradients so accumulation is not part of the measurement.
    for p in dynamic_parameters:
        p.grad = None

torch.cuda.synchronize()  # wait for the final iteration's kernels before stopping
stop = time.perf_counter()
iters_per_second = iteration_count / (stop - start)
print("TorchScript - Random Sizes")
print("Average iterations per second: {:.2f}".format(iters_per_second))

TorchScript - Random Sizes
Average iterations per second: 376.99


### Defining novel operations with nvFuser and FuncTorch
One of the primary benefits of nvFuser is the ability to define novel operations composed of PyTorch “primitives” which are then just-in-time compiled into efficient kernels.

PyTorch has strong performance for any individual operation, especially composite operations like LayerNorm. However, if LayerNorm wasn’t already implemented in PyTorch as a composite operation, then you’d have to define it as a series of simpler (primitive) operations. Let’s make such a definition and run it without nvFuser.

In [10]:
def primitive_definition(
    input1: torch.Tensor,
    input2: torch.Tensor,
    weight: torch.Tensor,
    bias1: torch.Tensor,
    bias2: torch.Tensor,
    normalization_axis: int,
    dropout_prob: float,
    keepdim: bool,
    eps: float = 1e-12,
) -> torch.Tensor:
    """Bias-add -> dropout -> residual-add -> LayerNorm written with primitive ops.

    Mirrors ``composite_definition`` but spells LayerNorm out as mean /
    variance / normalize / scale-and-shift.

    Args:
        input1: tensor that receives ``bias1`` and then dropout.
        input2: residual tensor added after dropout.
        weight: scale applied to the normalized values.
        bias1: bias added to ``input1`` before dropout.
        bias2: shift added after scaling.
        normalization_axis: dimension reduced when computing mean/variance.
        dropout_prob: dropout probability; dropout is always applied
            (``training=True``).
        keepdim: keep the reduced dimension. It must be ``True`` for the
            statistics to broadcast back against ``norm_input``.
        eps: added to the variance before ``sqrt``. The default is 1e-12, the
            notebook's original value. ``F.layer_norm`` (used by
            ``composite_definition``) defaults to 1e-5; pass ``eps=1e-5`` to
            match it numerically.

    Returns:
        The normalized tensor.
    """
    bias1_out = input1 + bias1
    dropout_out = F.dropout(bias1_out, dropout_prob, training=True)
    norm_input = dropout_out + input2
    mean = norm_input.mean(normalization_axis, keepdim=keepdim)
    diff = norm_input - mean
    diff_sq = diff * diff
    # Biased variance (divides by N), as LayerNorm uses.
    var = diff_sq.mean(normalization_axis, keepdim=keepdim)
    # Reuse `diff` rather than recomputing `norm_input - mean` (one fewer op).
    pre_shift_scale_norm_output = diff / torch.sqrt(var + eps)
    norm_output = weight * pre_shift_scale_norm_output + bias2
    return norm_output


# Eager-mode run of the primitive (hand-written LayerNorm) definition, with all
# arguments bound by name so the profiler can call it argument-free.
func = functools.partial(
    primitive_definition,
    input1=input1,
    input2=input2,
    weight=weight,
    bias1=bias1,
    bias2=bias2,
    normalization_axis=2,
    dropout_prob=0.1,
    keepdim=True,
)
profile_workload(
    func,
    grad_output,
    iteration_count=100,
    label="Eager Mode - Primitive Definition",
)

Eager Mode - Primitive Definition
Average iterations per second: 65.23


In [11]:
# Same primitive definition, compiled with TorchScript so nvFuser can fuse it.
scripted_primitive_definition = torch.jit.script(primitive_definition)
func = functools.partial(
    scripted_primitive_definition,
    input1=input1,
    input2=input2,
    weight=weight,
    bias1=bias1,
    bias2=bias2,
    normalization_axis=2,
    dropout_prob=0.1,
    keepdim=True,
)
profile_workload(
    func,
    grad_output,
    iteration_count=100,
    label="TorchScript - Primitive definition",
)

TorchScript - Primitive definition
Average iterations per second: 187.01


In [12]:
def primitive_definition_for_memory_efficient_fusion(
    input1: torch.Tensor,
    input2: torch.Tensor,
    weight: torch.Tensor,
    bias1: torch.Tensor,
    bias2: torch.Tensor,
) -> torch.Tensor:
    """Primitive LayerNorm block with its hyper-parameters hard-coded.

    Same computation as ``primitive_definition`` with ``normalization_axis=2``,
    ``dropout_prob=0.1``, ``keepdim=True`` and eps ``1e-12``. Only tensors are
    taken as arguments; presumably this is so the function can be handed
    directly to ``memory_efficient_fusion``.

    Args:
        input1: tensor that receives ``bias1`` and then dropout.
        input2: residual tensor added after dropout.
        weight: scale applied to the normalized values.
        bias1: bias added to ``input1`` before dropout.
        bias2: shift added after scaling.

    Returns:
        The normalized tensor (normalized over dim 2).
    """
    bias1_out = input1 + bias1
    dropout_out = F.dropout(bias1_out, 0.1, training=True)
    norm_input = dropout_out + input2
    mean = norm_input.mean(2, keepdim=True)
    diff = norm_input - mean
    diff_sq = diff * diff
    var = diff_sq.mean(2, keepdim=True)
    # Reuse `diff` rather than recomputing `norm_input - mean` (one fewer op).
    pre_shift_scale_norm_output = diff / torch.sqrt(var + 1e-12)
    norm_output = weight * pre_shift_scale_norm_output + bias2
    return norm_output

In [13]:
# Trace the primitive definition with FuncTorch and apply the memory
# efficiency optimization pass.
memory_efficient_primitive_definition = memory_efficient_fusion(
    primitive_definition_for_memory_efficient_fusion
)

# The dropout probability and axis are hard-coded in the traced function, so
# only the five tensors are bound (positionally, in signature order).
fused_primitive_args = (input1, input2, weight, bias1, bias2)
func = functools.partial(memory_efficient_primitive_definition, *fused_primitive_args)
profile_workload(
    func, grad_output, iteration_count=100, label="FuncTorch - Primitive definition"
)



FuncTorch - Primitive definition
Average iterations per second: 351.15


### Transformer Block With a Novel Normalization
The ability to quickly execute chains of simple operations is important, as not every operation has a composite implementation in PyTorch. Previously, this meant researchers either had to define an entirely new operation in PyTorch – which takes a lot of time and knowledge of the lower-level PyTorch code as well as parallel programming – or write the operation in simpler PyTorch ops and settle for poor performance. For example, let’s replace LayerNorm in our example with RMSNorm. Even though RMSNorm is a bit simpler than LayerNorm, it doesn’t have an existing composite operation in PyTorch. See the Root Mean Square Layer Normalization paper for more information about RMSNorm. As before, we’ll define our new transformer block with primitive PyTorch operations.

In [14]:
def with_rms_norm(
    input1: torch.Tensor,
    input2: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    normalization_axis: int,
    dropout_prob: float,
    keepdim: bool,
) -> torch.Tensor:
    """Bias-add -> dropout -> residual-add -> RMSNorm written with primitive ops.

    RMSNorm divides by the root of the mean of squares. Unlike LayerNorm, it
    does not subtract the mean and has no shift term.

    Args:
        input1: tensor that receives ``bias`` and then dropout.
        input2: residual tensor added after dropout.
        weight: scale applied to the normalized values.
        bias: bias added to ``input1`` before dropout.
        normalization_axis: dimension reduced when computing the mean square.
        dropout_prob: dropout probability; dropout is always applied
            (``training=True``).
        keepdim: keep the reduced dimension so the statistic broadcasts back.

    Returns:
        The RMS-normalized, scaled tensor.
    """
    norm_input = F.dropout(input1 + bias, dropout_prob, training=True) + input2
    mean_square = (norm_input * norm_input).mean(normalization_axis, keepdim=keepdim)
    return weight * (norm_input / torch.sqrt(mean_square + 1e-12))

In [15]:
# Eager-mode run of the RMSNorm block; note the single bias parameter (bias1).
func = functools.partial(
    with_rms_norm,
    input1=input1,
    input2=input2,
    weight=weight,
    bias=bias1,
    normalization_axis=2,
    dropout_prob=0.1,
    keepdim=True,
)
profile_workload(
    func, grad_output, iteration_count=100, label="Eager Mode - RMS Norm"
)

Eager Mode - RMS Norm
Average iterations per second: 88.40


In [16]:
# RMSNorm block compiled with TorchScript so nvFuser can fuse it.
scripted_with_rms_norm = torch.jit.script(with_rms_norm)
func = functools.partial(
    scripted_with_rms_norm,
    input1=input1,
    input2=input2,
    weight=weight,
    bias=bias1,
    normalization_axis=2,
    dropout_prob=0.1,
    keepdim=True,
)
profile_workload(
    func, grad_output, iteration_count=100, label="TorchScript - RMS Norm"
)

TorchScript - RMS Norm
Average iterations per second: 231.52


In [17]:
def with_rms_norm_for_memory_efficient_fusion(
    input1: torch.Tensor, input2: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    """RMSNorm block with its hyper-parameters hard-coded.

    Same computation as ``with_rms_norm`` with ``normalization_axis=2``,
    ``dropout_prob=0.1`` and ``keepdim=True``. Only tensors are taken as
    arguments; presumably this is so the function can be handed directly to
    ``memory_efficient_fusion``.

    Args:
        input1: tensor that receives ``bias`` and then dropout.
        input2: residual tensor added after dropout.
        weight: scale applied to the normalized values.
        bias: bias added to ``input1`` before dropout.

    Returns:
        The RMS-normalized, scaled tensor (normalized over dim 2).
    """
    bias_out = input1 + bias
    # training=True is stated explicitly, like every other block in this
    # notebook, instead of relying on F.dropout's default.
    dropout_out = F.dropout(bias_out, 0.1, training=True)
    norm_input = dropout_out + input2
    var = norm_input.mul(norm_input).mean(2, keepdim=True)  # mean of squares
    pre_shift_scale_norm_output = norm_input / torch.sqrt(var + 1e-12)
    norm_output = weight * pre_shift_scale_norm_output
    return norm_output

In [18]:
# Trace the RMSNorm block with FuncTorch's memory-efficient fusion pass and
# profile it; the traced function takes its four tensors positionally.
memory_efficient_rms_norm = memory_efficient_fusion(
    with_rms_norm_for_memory_efficient_fusion
)
fused_rms_args = (input1, input2, weight, bias1)
func = functools.partial(memory_efficient_rms_norm, *fused_rms_args)
profile_workload(
    func, grad_output, iteration_count=100, label="FuncTorch - RMS Norm"
)

FuncTorch - RMS Norm
Average iterations per second: 290.35


Since RMSNorm is simpler than LayerNorm, the performance of our new transformer block is a little higher than the primitive LayerNorm definition without nvFuser (88.40 iterations per second compared with 65.23 iterations per second in the run above). With nvFuser, throughput increases by 2.62x and 3.28x, to 231.52 and 290.35 iterations per second with TorchScript and with FuncTorch’s memory-efficient optimization pass, respectively. The faster of these approaches the performance of the composite LayerNorm definition with TorchScript (344.64 iterations per second). Absolute numbers depend on the GPU: the original tutorial reports 260 vs. 354 iterations per second in eager mode, and 952 and 1,191 iterations per second (2.68x and 3.36x) with TorchScript and FuncTorch, respectively.

nvFuser is here to provide the ability to define novel operations in simple PyTorch and get performance that’s close to a highly optimized composite operation in PyTorch. We believe this will enable research into novel network topologies without paying a sometimes devastating cost in training speed. nvFuser provides this unique ability because it analyzes users’ programs to deliver performance as fast as a highly hand-tuned implementation, regardless of how the operations are defined. nvFuser still cannot support every operation in PyTorch; however, its capabilities will continue to grow over time.

https://pytorch.org/tutorials/intermediate/nvfuser_intro_tutorial.html