In [1]:
%matplotlib inline

Introduction to `torch.compile`
===============================

**Author:** William Wen


`torch.compile` is the latest method to speed up your PyTorch code!
`torch.compile` makes PyTorch code run faster by JIT-compiling PyTorch
code into optimized kernels, all while requiring minimal code changes.

In this tutorial, we cover basic `torch.compile` usage, and demonstrate
the advantages of `torch.compile` over previous PyTorch compiler
solutions, such as
[TorchScript](https://pytorch.org/docs/stable/jit.html) and [FX
Tracing](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace).


**Required pip Dependencies**

-   `torch >= 2.0`
-   `torchvision`
-   `numpy`
-   `scipy`
-   `tabulate`

**System Requirements**

-   A C++ compiler, such as `g++`
-   Python development package (`python-devel`/`python-dev`)


NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this
tutorial in order to reproduce the speedup numbers shown below and
documented elsewhere.


In [2]:
import torch
import warnings

gpu_ok = False
if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    if device_cap in ((7, 0), (8, 0), (9, 0)):
        gpu_ok = True

if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )



Basic Usage
===========

`torch.compile` is included in PyTorch 2.0 and later. Running TorchInductor
on GPU requires Triton, which is included with the PyTorch 2.0 nightly
binary. If Triton is still missing, try installing `torchtriton` via pip
(`pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117"`
for CUDA 11.7).

Arbitrary Python functions can be optimized by passing the callable to
`torch.compile`. We can then call the returned optimized function in
place of the original function.


In [3]:
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))

tensor([[ 0.7198, -0.1713,  1.7675,  1.8321,  1.2644, -0.1902,  0.3201,  1.0123,
          1.1733,  0.9254],
        [ 1.0384, -0.0285,  0.9342, -0.7283,  1.3724,  1.4058,  1.1847,  0.1546,
          0.4578,  0.4899],
        [-0.2054,  0.4608,  0.9911, -0.0846,  0.3779,  0.3779,  0.1887,  0.1492,
          0.0091,  0.7202],
        [ 1.1547,  0.2087,  1.3306,  1.1124,  0.8787,  0.5917,  0.8509, -0.2039,
          1.4512,  0.5375],
        [ 0.6124, -0.2881,  0.0305,  0.4831,  1.5665, -0.2783,  1.1604,  1.8423,
          0.9622,  0.6977],
        [ 0.6001,  0.8091, -0.1112,  1.0659,  1.0641,  0.8233, -0.0419,  1.3345,
          0.1695,  1.6929],
        [ 0.2633,  0.3079,  1.2722,  0.8152,  1.7059,  1.3629,  1.1429,  0.6334,
          1.3122,  0.1158],
        [-0.5457,  0.4802,  1.8297,  0.6048,  0.7095,  0.0288,  0.8733,  0.3300,
         -0.3680,  0.1668],
        [-0.0171, -0.3566,  0.9610,  0.7617, -0.3447,  0.5309,  1.8511,  0.7938,
          1.3473,  1.5914],
        [-0.2915, -

Alternatively, we can decorate the function.


In [4]:
t1 = torch.randn(10, 10)
t2 = torch.randn(10, 10)

@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
print(opt_foo2(t1, t2))

tensor([[ 7.9414e-01,  1.2794e+00, -1.3200e+00,  1.2390e+00, -8.6748e-01,
         -1.0112e-01, -2.0662e-02,  6.0547e-01,  1.4508e-01,  1.0391e+00],
        [ 6.3774e-01, -1.3560e+00,  1.3120e+00,  1.1189e+00,  1.9800e+00,
         -6.2696e-01,  6.8635e-01,  3.7673e-01,  6.9580e-01,  7.1447e-01],
        [-3.1533e-01,  7.6680e-01, -5.3394e-02, -7.5036e-01, -2.2990e-02,
          6.8529e-01,  1.3530e+00,  9.6356e-02, -1.0646e-01, -7.2923e-02],
        [ 2.7073e-01,  1.5346e-01,  7.8273e-01, -1.1649e-01, -9.3213e-01,
          1.1090e+00,  3.7734e-01,  1.7230e+00,  1.4655e+00,  1.8293e+00],
        [ 1.7805e+00,  8.7070e-01,  8.1110e-01,  1.6154e+00,  9.8415e-01,
          8.8941e-01,  1.6112e+00,  1.8030e+00,  1.8037e+00,  1.4651e-01],
        [ 2.2898e-01,  1.1648e+00,  9.8982e-01,  1.2475e+00,  7.1438e-01,
         -2.3520e-01,  6.7882e-01,  3.1488e-01,  4.1282e-01,  1.9857e+00],
        [ 7.9016e-01, -6.4244e-01, -7.1780e-02,  1.7425e+00,  5.9658e-01,
         -3.0514e-01,  1.1779e+0

We can also optimize `torch.nn.Module` instances.


In [5]:
t = torch.randn(10, 100)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod = MyModule()
mod.compile()
print(mod(t))
## or:
# opt_mod = torch.compile(mod)
# print(opt_mod(t))

tensor([[0.0000, 0.2006, 0.4798, 0.6102, 0.0000, 0.0000, 0.1879, 0.0000, 0.0452,
         0.0000],
        [0.0000, 0.0000, 0.2386, 0.0209, 0.0000, 0.3346, 0.0000, 0.4132, 0.2831,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1626, 0.1518, 0.0000, 0.0000,
         0.3174],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.6378, 0.1308, 0.0000, 0.0000, 0.0000,
         0.3700],
        [0.4734, 0.0000, 0.2960, 0.1057, 0.2784, 0.4522, 0.0000, 0.2040, 0.2996,
         0.0000],
        [0.1060, 0.0000, 0.2568, 0.1646, 0.0000, 0.0150, 0.1353, 0.0000, 0.0000,
         0.0000],
        [0.2909, 0.3487, 0.0000, 0.2984, 0.0000, 0.7762, 0.4334, 0.4595, 0.9832,
         0.7313],
        [0.0000, 1.1428, 1.0770, 0.0000, 1.1954, 0.0000, 0.0000, 0.2860, 0.0000,
         0.6307],
        [0.6294, 0.3150, 0.3151, 0.1927, 0.0000, 0.0000, 0.1318, 0.0000, 0.0000,
         0.7981],
        [0.0000, 1.0691, 0.0000, 0.0000, 0.0000, 0.0041, 0.0000, 0.0000, 0.0784,
         0.0000]], grad_fn=<

torch.compile and Nested Calls
==============================

Nested function calls within the decorated function will also be
compiled.


In [6]:
def nested_function(x):
    return torch.sin(x)

@torch.compile
def outer_function(x, y):
    a = nested_function(x)
    b = torch.cos(y)
    return a + b

print(outer_function(t1, t2))

tensor([[ 7.9414e-01,  1.2794e+00, -1.3200e+00,  1.2390e+00, -8.6748e-01,
         -1.0112e-01, -2.0662e-02,  6.0547e-01,  1.4508e-01,  1.0391e+00],
        [ 6.3774e-01, -1.3560e+00,  1.3120e+00,  1.1189e+00,  1.9800e+00,
         -6.2696e-01,  6.8635e-01,  3.7673e-01,  6.9580e-01,  7.1447e-01],
        [-3.1533e-01,  7.6680e-01, -5.3394e-02, -7.5036e-01, -2.2990e-02,
          6.8529e-01,  1.3530e+00,  9.6356e-02, -1.0646e-01, -7.2923e-02],
        [ 2.7073e-01,  1.5346e-01,  7.8273e-01, -1.1649e-01, -9.3213e-01,
          1.1090e+00,  3.7734e-01,  1.7230e+00,  1.4655e+00,  1.8293e+00],
        [ 1.7805e+00,  8.7070e-01,  8.1110e-01,  1.6154e+00,  9.8415e-01,
          8.8941e-01,  1.6112e+00,  1.8030e+00,  1.8037e+00,  1.4651e-01],
        [ 2.2898e-01,  1.1648e+00,  9.8982e-01,  1.2475e+00,  7.1438e-01,
         -2.3520e-01,  6.7882e-01,  3.1488e-01,  4.1282e-01,  1.9857e+00],
        [ 7.9016e-01, -6.4244e-01, -7.1780e-02,  1.7425e+00,  5.9658e-01,
         -3.0514e-01,  1.1779e+0

In the same fashion, compiling a module also compiles all of its
sub-modules and methods that are not in a skip list.


In [7]:
class OuterModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner_module = MyModule()
        self.outer_lin = torch.nn.Linear(10, 2)

    def forward(self, x):
        x = self.inner_module(x)
        return torch.nn.functional.relu(self.outer_lin(x))

outer_mod = OuterModule()
outer_mod.compile()
print(outer_mod(t))

tensor([[0.0000, 0.0000],
        [0.0000, 0.0587],
        [0.0000, 0.1209],
        [0.0000, 0.2177],
        [0.0000, 0.1731],
        [0.0000, 0.0648],
        [0.0000, 0.1839],
        [0.0000, 0.1274],
        [0.0430, 0.1260],
        [0.0000, 0.0983]], grad_fn=<CompiledFunctionBackward>)


We can also exclude functions from compilation with
`torch.compiler.disable`. Suppose you want to disable tracing on just
the `complex_function` function but resume tracing in
`complex_conjugate`, which it calls. In this case, use
`torch.compiler.disable(recursive=False)`. Otherwise, the default is
`recursive=True`, which also disables tracing for everything the
function calls.


In [8]:
def complex_conjugate(z):
    return torch.conj(z)

@torch.compiler.disable(recursive=False)
def complex_function(real, imag):
    # Assume this function causes problems during compilation
    z = torch.complex(real, imag)
    return complex_conjugate(z)

def outer_function():
    real = torch.tensor([2, 3], dtype=torch.float32)
    imag = torch.tensor([4, 5], dtype=torch.float32)
    z = complex_function(real, imag)
    return torch.abs(z)

# Try to compile the outer_function
try:
    opt_outer_function = torch.compile(outer_function)
    print(opt_outer_function())
except Exception as e:
    print("Compilation of outer_function failed:", e)

tensor([4.4721, 5.8310])




Best Practices and Recommendations
==================================

Behavior of `torch.compile` with Nested Modules and Function Calls

When you use `torch.compile`, the compiler will try to recursively
compile every function call inside the target function or module that
is not in a skip list (such as built-ins and some functions in the
torch.\* namespace).

**Best Practices:**

1\. **Top-Level Compilation:** One approach is to compile at the highest
level possible (i.e., when the top-level module is initialized/called)
and selectively disable compilation when encountering excessive graph
breaks or errors. If there are still many compile issues, compile
individual subcomponents instead.

2\. **Modular Testing:** Test individual functions and modules with
`torch.compile` before integrating them into larger models to isolate
potential issues.

3\. **Disable Compilation Selectively:** If certain functions or
sub-modules cannot be handled by `torch.compile`, use the
`torch.compiler.disable` decorator to recursively exclude them from
compilation.

4\. **Compile Leaf Functions First:** In complex models with multiple
nested functions and modules, start by compiling the leaf functions or
modules first. For more information see [TorchDynamo APIs for
fine-grained
tracing](https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html).

5\. **Prefer `mod.compile()` over `torch.compile(mod)`:** Avoids
`_orig_mod.` prefix issues in `state_dict`.

6\. **Use `fullgraph=True` to catch graph breaks:** Helps ensure
end-to-end compilation, maximizing speedup and compatibility with
`torch.export`.


Demonstrating Speedups
======================

Let's now demonstrate that using `torch.compile` can speed up real
models. We will compare standard eager mode and `torch.compile` by
evaluating and training a `torchvision` model on random data.

Before we start, we need to define some utility functions.


In [9]:
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )

N_ITERS = 10

from torchvision.models import densenet121
def init_model():
    return densenet121().to(torch.float32).cuda()

First, let's compare inference.

Note that in the call to `torch.compile`, we have the additional `mode`
argument, which we will discuss below.


In [10]:
model = init_model()

# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()

model_opt = torch.compile(model, mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])

eager: 1.0150154418945312


W0626 14:39:18.693000 547 torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode


compile: 172.362078125


Notice that `torch.compile` takes a lot longer to complete compared to
eager. This is because `torch.compile` compiles the model into optimized
kernels as it executes. In our example, the structure of the model
doesn't change, so recompilation is not needed. If we run our optimized
model several more times, we should see a significant improvement
compared to eager.


In [11]:
eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, eager_time = timed(lambda: model(inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, compile_time = timed(lambda: model_opt(inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

eager eval time 0: 0.04603209686279297
eager eval time 1: 0.024002752304077148
eager eval time 2: 0.021936384201049805
eager eval time 3: 0.02175939178466797
eager eval time 4: 0.02235638427734375
eager eval time 5: 0.02446214485168457
eager eval time 6: 0.02194361686706543
eager eval time 7: 0.025506591796875
eager eval time 8: 0.022314144134521485
eager eval time 9: 0.02146374320983887
~~~~~~~~~~
compile eval time 0: 0.8381399536132812
compile eval time 1: 0.017057119369506837
compile eval time 2: 0.01831158447265625
compile eval time 3: 0.01681635284423828
compile eval time 4: 0.017563455581665038
compile eval time 5: 0.01777302360534668
compile eval time 6: 0.01753606414794922
compile eval time 7: 0.017021856307983398
compile eval time 8: 0.01688688087463379
compile eval time 9: 0.017830591201782226
~~~~~~~~~~
(eval) eager median: 0.02233526420593262, compile median: 0.017549759864807127, speedup: 1.2726820411213693x
~~~~~~~~~~


And indeed, we can see that running our model with `torch.compile`
results in a speedup (about 1.27x in the run above). The speedup mainly
comes from reducing Python overhead and GPU reads/writes, so the
observed speedup may vary with factors such as model architecture and
batch size. For example, if a model's architecture is simple and the
amount of data is large, then the bottleneck would be GPU compute and
the observed speedup may be less significant.
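
To see that compilation happens on the first call and is reused afterwards, one option is to count how often a backend is invoked. The sketch below is illustrative only: `counting_backend` and `backend_calls` are hypothetical names, and the backend simply runs the captured graph unoptimized instead of generating kernels.

```python
import torch

backend_calls = []  # one entry per graph handed to the backend


def counting_backend(gm, example_inputs):
    # A custom backend receives the captured FX GraphModule and example
    # inputs. This one records the call and returns the graph's forward
    # method unchanged, so no optimization takes place.
    backend_calls.append(gm)
    return gm.forward


def foo(x, y):
    return torch.sin(x) + torch.cos(y)


opt_foo = torch.compile(foo, backend=counting_backend)

x, y = torch.randn(10, 10), torch.randn(10, 10)
first = opt_foo(x, y)   # first call: the function is traced and compiled
second = opt_foo(x, y)  # same input shapes and dtypes: compiled code is reused
print("backend invocations:", len(backend_calls))
```

Calling `opt_foo` with inputs that fail the recorded guards, such as a different dtype, would be expected to invoke the backend again.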

You may also see different speedup results depending on the chosen
`mode` argument. The `"reduce-overhead"` mode uses CUDA graphs to
further reduce the overhead of Python. For your own models, you may need
to experiment with different modes to maximize speedup. You can read
more about modes
[here](https://pytorch.org/get-started/pytorch-2.0/#user-experience).

You might also notice that the second time we run our model with
`torch.compile` is significantly slower than the other runs, although it
is much faster than the first run. This is because the
`"reduce-overhead"` mode runs a few warm-up iterations for CUDA graphs.

For general PyTorch benchmarking, you can try using
`torch.utils.benchmark` instead of the `timed` function we defined
above. We wrote our own timing function in this tutorial to show
`torch.compile`'s compilation latency.
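
A minimal sketch of the `torch.utils.benchmark` alternative follows. The helper name `median_seconds` and the tensor sizes are assumptions chosen for illustration; the sketch times an eager function on whatever device the tensors live on.

```python
import torch
import torch.utils.benchmark as benchmark


def foo(x, y):
    return torch.sin(x) + torch.cos(y)


def median_seconds(fn, *args, number=20):
    # benchmark.Timer runs the statement repeatedly and takes care of
    # CUDA synchronization when timing GPU code.
    timer = benchmark.Timer(
        stmt="fn(*args)",
        globals={"fn": fn, "args": args},
    )
    # timeit() returns a Measurement; .median is in seconds per run.
    return timer.timeit(number).median


x, y = torch.randn(64, 64), torch.randn(64, 64)
eager_median = median_seconds(foo, x, y)
print(f"eager median: {eager_median:.6f} s")
```

The same helper could be given a compiled callable. Calling that callable once beforehand keeps the one-time compilation latency out of the measurement.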

Now, let's compare training.


In [12]:
model = init_model()
opt = torch.optim.Adam(model.parameters())

def train(mod, data):
    opt.zero_grad(set_to_none=True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)

model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

eager train time 0: 0.6418272094726563
eager train time 1: 0.0738986587524414
eager train time 2: 0.07069900512695312
eager train time 3: 0.0754176025390625
eager train time 4: 0.07063961791992188
eager train time 5: 0.06987232208251953
eager train time 6: 0.07310902404785156
eager train time 7: 0.07381302642822266
eager train time 8: 0.08722188568115234
eager train time 9: 0.07340332794189453
~~~~~~~~~~


W0626 14:45:46.611000 547 torch/_logging/_internal.py:1089] [3/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored


compile train time 0: 372.4535625
compile train time 1: 9.7961962890625
compile train time 2: 0.08050096130371094
compile train time 3: 0.08053059387207032
compile train time 4: 0.06821478271484376
compile train time 5: 0.04811737442016602
compile train time 6: 0.047951744079589846
compile train time 7: 0.04827248001098633
compile train time 8: 0.04730879974365235
compile train time 9: 0.04599603271484375
~~~~~~~~~~
(train) eager median: 0.07360817718505859, compile median: 0.05824363136291504, speedup: 1.2637978687559388x
~~~~~~~~~~


Again, we can see that `torch.compile` takes longer in the first
iteration, as it must compile the model, but in subsequent iterations,
we see significant speedups compared to eager.

We remark that the speedup numbers presented in this tutorial are for
demonstration purposes only. Official speedup values can be seen at the
[TorchInductor performance
dashboard](https://hud.pytorch.org/benchmark/compilers).


Conclusion
==========

In this tutorial, we introduced `torch.compile` by covering basic usage,
demonstrating speedups over eager mode, comparing to previous PyTorch
compiler solutions, and briefly investigating TorchDynamo and its
interactions with FX graphs. We hope that you will give `torch.compile`
a try!
