https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html

In [5]:
# For tips on running notebooks in Google Colab, see
# https://pytorch.org/tutorials/beginner/colab
# %matplotlib inline

Introduction to `torch.compile`
===============================

**Author:** William Wen


`torch.compile` is the latest method to speed up your PyTorch code!
`torch.compile` makes PyTorch code run faster by JIT-compiling PyTorch
code into optimized kernels, all while requiring minimal code changes.

In this tutorial, we cover basic `torch.compile` usage, and demonstrate
the advantages of `torch.compile` over previous PyTorch compiler
solutions, such as
[TorchScript](https://pytorch.org/docs/stable/jit.html) and [FX
Tracing](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace).


**Required pip Dependencies**

-   `torch >= 2.0`
-   `torchvision`
-   `numpy`
-   `scipy`
-   `tabulate`

**System Requirements**

-   A C++ compiler, such as `g++`
-   Python development package (`python-devel`/`python-dev`)


NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this
tutorial in order to reproduce the speedup numbers shown below and
documented elsewhere.


In [6]:
import torch
import warnings

gpu_ok = False
if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    if device_cap in ((7, 0), (8, 0), (9, 0)):
        gpu_ok = True

if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )



In [7]:
torch.cuda.get_device_capability()

(8, 9)

Basic Usage
===========

`torch.compile` is included in PyTorch 2.0 and later. Running TorchInductor
on GPU requires Triton, which is included with the PyTorch 2.0 nightly
binary. If Triton is still missing, try installing `torchtriton` via pip
(`pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117"`
for CUDA 11.7).
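A minimal sketch for checking whether Triton is importable before relying on GPU compilation is shown below. It assumes Triton's import name is `triton`, and it only checks that the package can be found; it does not verify that TorchInductor can actually generate kernels with it.

```python
import importlib.util

import torch

# find_spec returns None when no module named `triton` is installed.
has_triton = importlib.util.find_spec("triton") is not None
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}, Triton importable: {has_triton}")
```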

Arbitrary Python functions can be optimized by passing the callable to
`torch.compile`. We can then call the returned optimized function in
place of the original function.


In [8]:
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b


opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))

tensor([[ 1.6441e+00, -2.3744e-01,  3.3466e-01,  1.3373e+00,  1.7755e+00,
          7.6003e-02,  1.0719e+00,  1.0100e-01,  7.3513e-01,  5.4297e-01],
        [-2.4148e-01,  2.8763e-02,  7.2524e-01,  1.0647e+00,  1.1237e+00,
          1.6473e+00,  3.6311e-01,  1.2705e+00,  3.1930e-01,  6.5589e-04],
        [ 1.0697e+00,  1.5630e+00,  1.3214e+00,  8.3187e-02,  8.5060e-01,
          1.3529e+00,  7.2161e-01,  1.3939e+00,  8.3523e-01, -1.0453e+00],
        [-7.9625e-01,  2.4606e-01,  1.9218e+00,  3.3712e-01,  1.1852e+00,
          1.8330e+00,  1.6982e+00,  3.2020e-01, -2.0006e-01, -2.6483e-01],
        [ 1.3713e+00,  2.9989e-01, -3.0108e-01,  6.7520e-01,  1.9782e+00,
          5.7397e-01,  1.2379e-02,  4.4264e-01,  7.7161e-01,  6.1352e-01],
        [ 2.8111e-01,  1.8953e+00,  1.4625e+00,  1.8992e+00,  3.5646e-01,
          6.5643e-01, -1.8666e-01,  7.1100e-01, -2.4304e-01,  4.8156e-01],
        [ 4.5385e-01,  2.4414e-01, -3.9878e-01,  1.1043e+00,  1.5947e+00,
         -7.7061e-01,  5.8219e-0

Alternatively, we can decorate the function.


In [9]:
t1 = torch.randn(10, 10)
t2 = torch.randn(10, 10)


@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b


print(opt_foo2(t1, t2))

tensor([[ 0.9878,  0.6664,  1.6616,  1.2878,  0.1221,  1.3020,  1.8696, -0.0464,
          0.9167,  0.5943],
        [-0.0081, -0.0586,  0.5976,  1.7362,  1.9753,  0.0370,  0.6115,  1.8201,
          0.8020,  0.4709],
        [ 1.0193, -0.2708,  0.7117,  1.0429,  1.3406,  1.9255, -1.0757, -0.3766,
          0.5571,  1.3056],
        [ 1.9479,  0.7625,  1.7719,  1.7156, -0.0299,  0.1075,  0.5167,  0.6288,
          0.9370,  0.0086],
        [ 1.3085,  1.4662,  0.9243,  1.4638, -0.6317,  0.3579,  1.8838,  1.7756,
         -0.9232,  0.4282],
        [ 0.1316,  0.8578,  0.1652,  1.0583,  0.7161,  1.1577, -0.3801,  0.9910,
          1.4017,  0.1863],
        [ 1.3836,  0.1763,  0.5551,  0.0372,  0.7654,  0.2410,  1.6998,  1.7749,
          0.9866,  1.9513],
        [-0.0477,  0.2768,  1.0746,  1.7253,  1.4814,  1.8925, -0.1557,  1.8058,
          0.2987, -0.6080],
        [ 1.2114,  0.0295, -0.1344,  0.0892,  0.1154,  0.3053,  0.4710,  1.8639,
          0.1864, -0.0249],
        [ 1.1651,  

We can also optimize `torch.nn.Module` instances.


In [10]:
t = torch.randn(10, 100)


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))


mod = MyModule()
mod.compile()
print(mod(t))
## or:
# opt_mod = torch.compile(mod)
# print(opt_mod(t))

tensor([[0.8690, 0.6797, 0.0000, 0.0000, 0.0000, 0.0626, 0.0000, 0.5468, 0.0000,
         1.0521],
        [0.0107, 0.0000, 0.5158, 0.0000, 0.0000, 0.1908, 0.2138, 0.0929, 0.0000,
         0.7257],
        [0.0000, 0.0000, 0.3076, 0.2883, 0.6478, 0.1089, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.2876, 0.0000, 0.0000, 0.0000, 0.5118, 0.0000, 0.6576,
         0.0000],
        [0.0000, 0.4161, 0.0000, 0.4233, 0.8046, 0.0000, 0.0000, 0.0000, 0.2612,
         0.0000],
        [0.0000, 0.0468, 0.4084, 0.3678, 0.2205, 0.3162, 0.2730, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0265, 0.0000, 0.2623, 0.1892, 0.0000, 0.0000, 0.0000,
         0.1240],
        [0.0000, 0.3190, 0.0000, 0.9979, 0.1024, 0.7692, 0.0000, 0.0000, 0.6181,
         0.0000],
        [0.0000, 0.0000, 0.4269, 0.0000, 0.3567, 0.0000, 0.2295, 0.0000, 0.0000,
         1.2362],
        [0.0000, 0.0000, 0.0000, 1.1740, 0.2125, 0.0792, 1.5206, 0.0000, 0.2972,
         0.2830]], grad_fn=<

torch.compile and Nested Calls
==============================

Nested function calls within the decorated function will also be
compiled.


In [11]:
def nested_function(x):
    return torch.sin(x)


@torch.compile
def outer_function(x, y):
    a = nested_function(x)
    b = torch.cos(y)
    return a + b


print(outer_function(t1, t2))

tensor([[ 0.9878,  0.6664,  1.6616,  1.2878,  0.1221,  1.3020,  1.8696, -0.0464,
          0.9167,  0.5943],
        [-0.0081, -0.0586,  0.5976,  1.7362,  1.9753,  0.0370,  0.6115,  1.8201,
          0.8020,  0.4709],
        [ 1.0193, -0.2708,  0.7117,  1.0429,  1.3406,  1.9255, -1.0757, -0.3766,
          0.5571,  1.3056],
        [ 1.9479,  0.7625,  1.7719,  1.7156, -0.0299,  0.1075,  0.5167,  0.6288,
          0.9370,  0.0086],
        [ 1.3085,  1.4662,  0.9243,  1.4638, -0.6317,  0.3579,  1.8838,  1.7756,
         -0.9232,  0.4282],
        [ 0.1316,  0.8578,  0.1652,  1.0583,  0.7161,  1.1577, -0.3801,  0.9910,
          1.4017,  0.1863],
        [ 1.3836,  0.1763,  0.5551,  0.0372,  0.7654,  0.2410,  1.6998,  1.7749,
          0.9866,  1.9513],
        [-0.0477,  0.2768,  1.0746,  1.7253,  1.4814,  1.8925, -0.1557,  1.8058,
          0.2987, -0.6080],
        [ 1.2114,  0.0295, -0.1344,  0.0892,  0.1154,  0.3053,  0.4710,  1.8639,
          0.1864, -0.0249],
        [ 1.1651,  

In the same fashion, when compiling a module, all sub-modules and
methods within it that are not in a skip list are also compiled.


In [12]:
class OuterModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner_module = MyModule()
        self.outer_lin = torch.nn.Linear(10, 2)

    def forward(self, x):
        x = self.inner_module(x)
        return torch.nn.functional.relu(self.outer_lin(x))


outer_mod = OuterModule()
outer_mod.compile()
print(outer_mod(t))

tensor([[0.0308, 0.0000],
        [0.2024, 0.1565],
        [0.2461, 0.1796],
        [0.2937, 0.0000],
        [0.1010, 0.0486],
        [0.4328, 0.1606],
        [0.0480, 0.0558],
        [0.1595, 0.0593],
        [0.4238, 0.3796],
        [0.2593, 0.0000]], grad_fn=<CompiledFunctionBackward>)


We can also exclude some functions from compilation by using
`torch.compiler.disable`. Suppose you want to disable tracing only for
the `complex_function` function, but want tracing to resume in
`complex_conjugate`, which it calls. In this case, use the
`torch.compiler.disable(recursive=False)` option. The default is
`recursive=True`, which also skips every function called from the
disabled one.


In [13]:
def complex_conjugate(z):
    return torch.conj(z)


@torch.compiler.disable(recursive=False)
def complex_function(real, imag):
    # Assume this function causes problems during compilation
    z = torch.complex(real, imag)
    return complex_conjugate(z)


def outer_function():
    real = torch.tensor([2, 3], dtype=torch.float32)
    imag = torch.tensor([4, 5], dtype=torch.float32)
    z = complex_function(real, imag)
    return torch.abs(z)


# Try to compile the outer_function
try:
    opt_outer_function = torch.compile(outer_function)
    print(opt_outer_function())
except Exception as e:
    print("Compilation of outer_function failed:", e)

tensor([4.4721, 5.8310])


Best Practices and Recommendations
==================================

**Behavior of `torch.compile` with Nested Modules and Function Calls**

When you use `torch.compile`, the compiler will try to recursively
compile every function call inside the target function or module that
is not in a skip list (such as built-ins and some functions in the
`torch.*` namespace).

**Best Practices:**

1.  **Top-Level Compilation:** One approach is to compile at the highest
    level possible (i.e., when the top-level module is initialized/called)
    and selectively disable compilation when encountering excessive graph
    breaks or errors. If there are still many compile issues, compile
    individual subcomponents instead.

2.  **Modular Testing:** Test individual functions and modules with
    `torch.compile` before integrating them into larger models to isolate
    potential issues.

3.  **Disable Compilation Selectively:** If certain functions or
    sub-modules cannot be handled by `torch.compile`, use
    `torch.compiler.disable` to exclude them from compilation (recursively,
    by default).

4.  **Compile Leaf Functions First:** In complex models with multiple
    nested functions and modules, start by compiling the leaf functions or
    modules first. For more information see [TorchDynamo APIs for
    fine-grained tracing](https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html).

5.  **Prefer `mod.compile()` over `torch.compile(mod)`:** This avoids the
    `_orig_mod.` prefix that `torch.compile(mod)` adds to `state_dict` keys.

6.  **Use `fullgraph=True` to catch graph breaks:** This helps ensure
    end-to-end compilation, maximizing speedup and compatibility with
    `torch.export`.


Demonstrating Speedups
======================

Let's now demonstrate that using `torch.compile` can speed up real
models. We will compare standard eager mode and `torch.compile` by
evaluating and training a `torchvision` model on random data.

Before we start, we need to define some utility functions.


In [14]:
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


# Generates random input and target data for the model, where `b` is the
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )


N_ITERS = 10

from torchvision.models import densenet121


def init_model():
    return densenet121().to(torch.float32).cuda()

First, let's compare inference.

Note that in the call to `torch.compile`, we have the additional `mode`
argument, which we will discuss below.


In [15]:
model = init_model()

# Reset since we are using a different mode.
import torch._dynamo

torch._dynamo.reset()

model_opt = torch.compile(model, mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])

eager: 0.17742454528808593
compile: 3.003368896484375


Notice that the first call to the compiled model takes much longer than
eager (about 3 s versus 0.18 s here). This is because `torch.compile`
compiles the model into optimized kernels the first time it executes. In
our example, the structure of the model and the input shape don't change,
so recompilation is not needed. So if we run our optimized model several
more times, we should see an improvement compared to eager.
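The reuse of compiled code can be observed with a counting backend. The sketch below uses a hypothetical `counting_backend` that records each graph TorchDynamo hands to it and then runs that graph unoptimized. Repeated calls with the same input shape reuse the cached graph, while a new shape fails TorchDynamo's guards and triggers a recompilation.

```python
import torch
import torch._dynamo

compiled_graphs = []  # one entry per graph TorchDynamo passes to the backend


def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    compiled_graphs.append(gm)
    return gm.forward  # run the captured graph without optimization


def f(x):
    return torch.sin(x) + 1


torch._dynamo.reset()
opt_f = torch.compile(f, backend=counting_backend)

opt_f(torch.randn(8))
opt_f(torch.randn(8))  # same shape and dtype: guards pass, no recompilation
after_same_shape = len(compiled_graphs)
print(after_same_shape)  # 1

opt_f(torch.randn(16))  # new shape: guards fail, TorchDynamo recompiles
print(len(compiled_graphs))  # 2
```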


In [16]:
eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, eager_time = timed(lambda: model(inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, compile_time = timed(lambda: model_opt(inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

import numpy as np

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)

eager eval time 0: 0.005596159934997558
eager eval time 1: 0.004498335838317871
eager eval time 2: 0.004241504192352295
eager eval time 3: 0.00440064001083374
eager eval time 4: 0.004104447841644287
eager eval time 5: 0.004055744171142578
eager eval time 6: 0.004042687892913818
eager eval time 7: 0.004048992156982422
eager eval time 8: 0.004025343894958496
eager eval time 9: 0.00404095983505249
~~~~~~~~~~
compile eval time 0: 0.13317814636230468
compile eval time 1: 0.004147424221038819
compile eval time 2: 0.004278207778930664
compile eval time 3: 0.003794368028640747
compile eval time 4: 0.0037889280319213865
compile eval time 5: 0.0037751040458679198
compile eval time 6: 0.0037930240631103515
compile eval time 7: 0.0037874879837036133
compile eval time 8: 0.003780832052230835
compile eval time 9: 0.003825632095336914
~~~~~~~~~~
(eval) eager median: 0.004080096006393433, compile median: 0.003793696045875549, speedup: 1.0754936497427763x
~~~~~~~~~~


And indeed, we can see that running our model with `torch.compile`
results in a speedup (about 1.08x in this run, on a GPU other than the
recommended V100/A100/H100). The speedup mainly comes from reducing
Python overhead and GPU memory reads/writes, so the observed speedup may
vary depending on factors such as model architecture and batch size. For
example, if a model's architecture is simple and the amount of data is
large, then the bottleneck would be GPU compute and the observed speedup
may be less significant.

You may also see different speedup results depending on the chosen
`mode` argument. The `"reduce-overhead"` mode uses CUDA graphs to
further reduce the overhead of Python. For your own models, you may need
to experiment with different modes to maximize speedup. You can read
more about modes
[here](https://pytorch.org/get-started/pytorch-2.0/#user-experience).

You might also notice that the second time we run our model with
`torch.compile` is significantly slower than the other runs, although it
is much faster than the first run. This is because the
`"reduce-overhead"` mode runs a few warm-up iterations for CUDA graphs.

For general PyTorch benchmarking, you can try using
`torch.utils.benchmark` instead of the `timed` function we defined
above. We wrote our own timing function in this tutorial to show
`torch.compile`'s compilation latency.
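A minimal `torch.utils.benchmark` sketch follows. `Timer.blocked_autorange` runs the statement repeatedly and reports statistics such as the median, and the timer synchronizes CUDA work when needed. The compiled function is called once before timing so that compilation latency is excluded. `backend="eager"` keeps the sketch runnable on CPU without a C++ toolchain, so it measures only TorchDynamo's overhead rather than TorchInductor's speedups; drop that argument to benchmark the default backend.

```python
import torch
import torch.utils.benchmark as benchmark


def foo(x, y):
    return torch.sin(x) + torch.cos(y)


x = torch.randn(256, 256)
y = torch.randn(256, 256)

opt_foo = torch.compile(foo, backend="eager")
opt_foo(x, y)  # warm-up: trigger compilation before timing

results = {}
for label, fn in [("eager", foo), ("compiled", opt_foo)]:
    timer = benchmark.Timer(
        stmt="fn(x, y)",
        globals={"fn": fn, "x": x, "y": y},
        label=label,
    )
    # Repeats the statement until at least 0.2 s of measurements are collected.
    results[label] = timer.blocked_autorange(min_run_time=0.2)
    print(label, results[label].median)
```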

Now, let's compare training.


In [17]:
model = init_model()
opt = torch.optim.Adam(model.parameters())


def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()


eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)

model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)

eager train time 0: 0.12498336029052734
eager train time 1: 0.015583711624145508
eager train time 2: 0.01411023998260498
eager train time 3: 0.013883199691772461
eager train time 4: 0.014187904357910156
eager train time 5: 0.01563488006591797
eager train time 6: 0.01409500789642334
eager train time 7: 0.013906368255615234
eager train time 8: 0.013944448471069336
eager train time 9: 0.014017951965332032
~~~~~~~~~~


W0614 21:17:03.766000 396503 site-packages/torch/_logging/_internal.py:1130] [4/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored


compile train time 0: 8.061078125
compile train time 1: 1.3932918701171875
compile train time 2: 0.010026816368103028
compile train time 3: 0.009671392440795899
compile train time 4: 0.009518303871154786
compile train time 5: 0.009332511901855469
compile train time 6: 0.009558431625366211
compile train time 7: 0.009229599952697754
compile train time 8: 0.009159647941589355
compile train time 9: 0.0091015043258667
~~~~~~~~~~
(train) eager median: 0.014102623939514159, compile median: 0.009538367748260498, speedup: 1.4785154348956655x
~~~~~~~~~~


Again, we can see that `torch.compile` takes longer in the first
iteration, as it must compile the model, but in subsequent iterations,
we see significant speedups compared to eager.

We remark that the speedup numbers presented in this tutorial are for
demonstration purposes only. Official speedup values can be seen at the
[TorchInductor performance
dashboard](https://hud.pytorch.org/benchmark/compilers).


Comparison to TorchScript and FX Tracing
========================================

We have seen that `torch.compile` can speed up PyTorch code. Why else
should we use `torch.compile` over existing PyTorch compiler solutions,
such as TorchScript or FX Tracing? Primarily, the advantage of
`torch.compile` lies in its ability to handle arbitrary Python code with
minimal changes to existing code.

One case that `torch.compile` can handle that other compiler solutions
struggle with is data-dependent control flow (the `if x.sum() < 0:` line
below).


In [18]:
def f1(x, y):
    if x.sum() < 0:
        return -y
    return y


# Test that `fn1` and `fn2` return the same result, given
# the same arguments `args`. Typically, `fn1` will be an eager function
# while `fn2` will be a compiled function (torch.compile, TorchScript, or FX graph).
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)


inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

TorchScript tracing `f1` results in silently incorrect results, since
only the actual control flow path is traced.


In [19]:
traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))

traced 1, 1: True
traced 1, 2: False




FX tracing `f1` results in an error due to the presence of
data-dependent control flow.


In [20]:
import traceback as tb

try:
    torch.fx.symbolic_trace(f1)
except Exception:
    tb.print_exc()

Traceback (most recent call last):
  File "/tmp/ipykernel_396503/3126115963.py", line 4, in <module>
    torch.fx.symbolic_trace(f1)
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 1314, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 838, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/tmp/ipykernel_396503/116533254.py", line 2, in f1
    if x.sum() < 0:
       ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/proxy.py", line 555, in __bool__
    return self.tracer.to_bool(self)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/fx/proxy.py", line 366, in to_bo

If we provide a value for `x` as we try to FX trace `f1`, then we run
into the same problem as TorchScript tracing, as the data-dependent
control flow is removed in the traced function.


In [21]:
fx_f1 = torch.fx.symbolic_trace(f1, concrete_args={"x": inp1})
print("fx 1, 1:", test_fns(f1, fx_f1, (inp1, inp2)))
print("fx 1, 2:", test_fns(f1, fx_f1, (-inp1, inp2)))

fx 1, 1: True
fx 1, 2: False




Now we can see that `torch.compile` correctly handles data-dependent
control flow.


In [22]:
# Reset TorchDynamo's caches before compiling a new function.
torch._dynamo.reset()

compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)

compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~


TorchScript scripting can handle data-dependent control flow, but this
solution comes with its own set of problems. Namely, TorchScript
scripting can require major code changes and will raise errors when
unsupported Python is used.

In the example below, we omit TorchScript type annotations and receive
a TorchScript error because the input type for argument `y`, an `int`,
does not match the type TorchScript infers for unannotated arguments,
`torch.Tensor`.


In [23]:
def f2(x, y):
    return x + y


inp1 = torch.randn(5, 5)
inp2 = 3

script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except Exception:
    tb.print_exc()

Traceback (most recent call last):
  File "/tmp/ipykernel_396503/2913434199.py", line 10, in <module>
    script_f2(inp1, inp2)
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor


However, `torch.compile` is easily able to handle `f2`.


In [24]:
compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)

compile 2: True
~~~~~~~~~~


Another case that `torch.compile` handles well compared to previous
compiler solutions is the usage of non-PyTorch functions.


In [25]:
!pip install scipy -q

In [26]:
import scipy


def f3(x):
    x = x * 2
    x = scipy.fft.dct(x.numpy())
    x = torch.from_numpy(x)
    x = x * 2
    return x

TorchScript tracing treats results from non-PyTorch function calls as
constants, and so our results can be silently wrong.


In [27]:
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
traced_f3 = torch.jit.trace(f3, (inp1,))
print("traced 3:", test_fns(f3, traced_f3, (inp2,)))

traced 3: False




TorchScript scripting and FX tracing disallow non-PyTorch function
calls.


In [28]:
try:
    torch.jit.script(f3)
except Exception:
    tb.print_exc()

try:
    torch.fx.symbolic_trace(f3)
except Exception:
    tb.print_exc()

Traceback (most recent call last):
  File "/tmp/ipykernel_396503/1991772511.py", line 2, in <module>
    torch.jit.script(f3)
  File "/opt/conda/lib/python3.11/site-packages/torch/jit/_script.py", line 1443, in script
    ret = _script_impl(
          ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/jit/_script.py", line 1214, in _script_impl
    fn = torch._C._jit_script_compile(
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_jit_internal.py", line 1233, in _try_get_dispatched_fn
    return boolean_dispatched.get(fn)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/weakref.py", line 452, in get
    return self.data.get(ref(key),default)
                         ^^^^^^^^
TypeError: cannot create weak reference to 'uarray._Function' object
Traceback (most recent call last):
  File "/tmp/ipykernel_396503/1991772511.py", line 7, in <module>
    torch.fx.symbolic_trace(f3)
  File "/opt/conda/lib/python

In comparison, `torch.compile` is easily able to handle the non-PyTorch
function call.


In [29]:
inp2

tensor([[ 0.3630,  0.5209, -0.7981, -0.8188, -1.1406],
        [ 1.4841, -1.4682,  0.1976, -0.8664,  0.3564],
        [ 3.0464, -0.0483, -0.0347, -0.7905,  1.6688],
        [-0.9529, -0.5755,  0.3267, -1.3545,  0.4998],
        [-0.3499, -0.4624,  0.1807,  0.7081, -0.6798]])

In [None]:
compile_f3 = torch.compile(f3)
print("compile 3:", test_fns(f3, compile_f3, (inp2,)))

If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))


InternalTorchDynamoError: AttributeError: 'str' object has no attribute 'IF_NEEDED'

from user code:
   File "/home/user-name-goes-here/.local/lib/python3.11/site-packages/scipy/fft/_realtransforms_backend.py", line 15, in torch_dynamo_resume_in__execute_at_12
    return xp.asarray(y)
  File "/home/user-name-goes-here/.local/lib/python3.11/site-packages/scipy/_lib/array_api_compat/numpy/_aliases.py", line 116, in asarray
    return np.array(obj, copy=copy, dtype=dtype, **kwargs)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


In [None]:
inp2.numpy()

array([[ 0.37013376, -0.25889605, -1.0555654 ,  1.2370975 ,  1.6417791 ],
       [-1.0730505 , -0.4954986 , -0.6134838 ,  0.04148984, -1.2115946 ],
       [ 0.26323107, -0.19715777,  0.55473197, -0.38094166, -1.2469018 ],
       [-1.7523555 ,  0.2557459 ,  0.2023534 , -0.17961223,  2.4302366 ],
       [ 1.2394569 , -1.192113  ,  0.6754096 ,  0.3486911 ,  0.5941279 ]],
      dtype=float32)

### !!!Does not work as shown in the official docs!!!

TorchDynamo and FX Graphs
=========================

One important component of `torch.compile` is TorchDynamo. TorchDynamo
is responsible for JIT compiling arbitrary Python code into [FX
graphs](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph), which
can then be further optimized. TorchDynamo extracts FX graphs by
analyzing Python bytecode during runtime and detecting calls to PyTorch
operations.

Normally, TorchInductor, another component of `torch.compile`, further
compiles the FX graphs into optimized kernels, but TorchDynamo allows
for different backends to be used. In order to inspect the FX graphs
that TorchDynamo outputs, let us create a custom backend that outputs
the FX graph and simply returns the graph's unoptimized forward method.


In [31]:
from typing import List


def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("custom backend called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward


# Reset since we are using a different backend.
torch._dynamo.reset()

opt_model = torch.compile(init_model(), backend=custom_backend)
opt_model(generate_data(16)[0])

custom backend called with FX graph:
opcode         name    target    args    kwargs
-------------  ------  --------  ------  --------

tensor([[-0.2560, -0.1072, -0.0482,  ...,  0.1529, -0.0687,  0.1681],
        [-0.0559,  0.1181, -0.0911,  ...,  0.1538, -0.0735,  0.0955],
        [-0.0462,  0.1132, -0.0403,  ...,  0.1321, -0.1376,  0.1901],
        ...,
        [-0.1068, -0.0489,  0.0927,  ...,  0.0408,  0.0424,  0.2202],
        [-0.0502, -0.0566, -0.0664,  ...,  0.0414, -0.0391,  0.1008],
        [-0.1472, -0.0846, -0.0037,  ...,  0.2491,  0.0128,  0.1111]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
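Besides `print_tabular`, a `GraphModule` exposes its graph as generated Python source through its `code` attribute, which is often easier to read for small graphs. A sketch with a hypothetical `code_printing_backend`, applied to a small `sin`/`cos` function like the one from the beginning of the tutorial:

```python
import torch
import torch._dynamo

captured_code = []


def code_printing_backend(gm: torch.fx.GraphModule, example_inputs):
    captured_code.append(gm.code)  # Python source generated for the FX graph
    print(gm.code)
    return gm.forward


def foo(x, y):
    return torch.sin(x) + torch.cos(y)


torch._dynamo.reset()
opt_foo = torch.compile(foo, backend=code_printing_backend)
opt_foo(torch.randn(3), torch.randn(3))
```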

Using our custom backend, we can now see how TorchDynamo is able to
handle data-dependent control flow. Consider the function below, where
the line `if b.sum() < 0` is the source of data-dependent control flow.


In [36]:
import torch


def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    # print('test')
    return x * b


opt_bar = torch.compile(bar, backend=custom_backend)
inp1 = torch.randn(10)
inp2 = torch.randn(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)

custom backend called with FX graph:
opcode         name    target                                                  args         kwargs
-------------  ------  ------------------------------------------------------  -----------  --------
placeholder    l_a_    L_a_                                                    ()           {}
placeholder    l_b_    L_b_                                                    ()           {}
call_function  abs_1   <built-in method abs of type object at 0x7614fe8f6f80>  (l_a_,)      {}
call_function  add     <built-in function add>                                 (abs_1, 1)   {}
call_function  x       <built-in function truediv>                             (l_a_, add)  {}
call_method    sum_1   sum                                                     (l_b_,)      {}
call_function  lt      <built-in function lt>                                  (sum_1, 0)   {}
output         output  output                                                  ((x, lt),)   {}
cus

tensor([ 0.2181,  0.8413,  0.2798, -0.0838, -0.5152, -0.2496,  0.2436, -0.0396,
         0.1016,  0.1305])

The output reveals that TorchDynamo extracted 3 different FX graphs
corresponding to the following code (order may differ from the output
above):

1.  `x = a / (torch.abs(a) + 1)`
2.  `b = b * -1; return x * b`
3.  `return x * b`

When TorchDynamo encounters unsupported Python features, such as
data-dependent control flow, it breaks the computation graph, lets the
default Python interpreter handle the unsupported code, then resumes
capturing the graph.

Let's investigate by example how TorchDynamo would step through `bar`.
If `b.sum() < 0`, then TorchDynamo would run graph 1, let Python
determine the result of the conditional, then run graph 2. On the other
hand, if `not b.sum() < 0`, then TorchDynamo would run graph 1, let
Python determine the result of the conditional, then run graph 3.

This highlights a major difference between TorchDynamo and previous
PyTorch compiler solutions. When encountering unsupported Python
features, previous solutions either raise an error or silently fail.
TorchDynamo, on the other hand, will break the computation graph.

We can see where TorchDynamo breaks the graph by using
`torch._dynamo.explain`:


In [37]:
# Reset since we are using a different backend.
torch._dynamo.reset()
explain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10))
print(explain_output)

Graph Count: 2
Graph Break Count: 1
Op Count: 6
Break Reasons:
  Break Reason 1:
    Reason: generic_jump TensorVariable()
    User Stack:
      <FrameSummary file /tmp/ipykernel_396503/2060524620.py, line 6 in bar>
Ops per Graph:
  Ops 1:
    <built-in method abs of type object at 0x7614fe8f6f80>
    <built-in function add>
    <built-in function truediv>
    <built-in function lt>
  Ops 2:
    <built-in function mul>
    <built-in function mul>
Out Guards:
  Guard 1:
    Name: "L['a']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['a'], '_dynamo_dynamic_indices') == False"]
    Object Weakref: <weakref at 0x76133af1a610; dead>
    Guarded Class Weakref: <weakref at 0x7614ff5158a0; to 'torch._C._TensorMeta' at 0x61fd64183b10 (Tensor)>
  Guard 2:
    Name: "L['b'].sum"
    Source: local
    Create Function: HASATTR
    Guard Types: ['HASATTR']
    Code List: ["hasattr(L['b'], 'sum')"]
    Object Weakref: <weakref at 0x7

In order to maximize speedup, graph breaks should be limited. We can
force TorchDynamo to raise an error upon the first graph break
encountered by using `fullgraph=True`:


In [38]:
opt_bar = torch.compile(bar, fullgraph=True)
try:
    opt_bar(torch.randn(10), torch.randn(10))
except Exception:
    tb.print_exc()


class GraphModule(torch.nn.Module):
    def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
        l_a_ = L_a_
        l_b_ = L_b_
        
         # File: /tmp/ipykernel_396503/2060524620.py:5 in bar, code: x = a / (torch.abs(a) + 1)
        abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
        add: "f32[10][1]cpu" = abs_1 + 1;  abs_1 = None
        x: "f32[10][1]cpu" = l_a_ / add;  l_a_ = add = x = None
        
         # File: /tmp/ipykernel_396503/2060524620.py:6 in bar, code: if b.sum() < 0:
        sum_1: "f32[][]cpu" = l_b_.sum();  l_b_ = None
        lt: "b8[][]cpu" = sum_1 < 0;  sum_1 = lt = None
        
Traceback (most recent call last):
  File "/tmp/ipykernel_396503/3610564610.py", line 3, in <module>
    opt_bar(torch.randn(10), torch.randn(10))
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 659, in _fn
    raise e.with_traceback(None) from None
torch._dynamo.exc.Unsupported: Data-dependent branching
  Explanation: Detected d
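A graph-break-free rewrite of `bar` is one way to satisfy `fullgraph=True`. The sketch below replaces the Python `if` with `torch.where`, which selects between `-b` and `b` inside the graph. `bar_no_break` is an illustrative name and `backend="eager"` is used only to avoid code generation. Note that both candidates for `b` are computed before one is selected, which is cheap here but not free in general.

```python
import torch


def bar(a, b):  # eager reference, same logic as above
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


def bar_no_break(a, b):
    x = a / (torch.abs(a) + 1)
    # The sign flip is selected by a tensor op instead of a Python `if`,
    # so there is no data-dependent branch for TorchDynamo to break on.
    b = torch.where(b.sum() < 0, -b, b)
    return x * b


opt_bar_no_break = torch.compile(bar_no_break, fullgraph=True, backend="eager")
a, b = torch.randn(10), torch.randn(10)
checks = [torch.allclose(opt_bar_no_break(a, bb), bar(a, bb)) for bb in (b, -b)]
print(checks)  # [True, True]
```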

And below, we demonstrate that TorchDynamo does not break the graph on
the model we used above for demonstrating speedups.


In [None]:
opt_model = torch.compile(init_model(), fullgraph=True)
print(opt_model(generate_data(16)[0]))

We can use `torch.export` (from PyTorch 2.1+) to extract a single,
exportable FX graph from the input PyTorch program. The exported graph
is intended to be run on different (i.e. Python-less) environments. One
important restriction is that `torch.export` does not support graph
breaks. Please check [this
tutorial](https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html)
for more details on `torch.export`.


Conclusion
==========

In this tutorial, we introduced `torch.compile` by covering basic usage,
demonstrating speedups over eager mode, comparing to previous PyTorch
compiler solutions, and briefly investigating TorchDynamo and its
interactions with FX graphs. We hope that you will give `torch.compile`
a try!
