In [1]:
# For tips on running notebooks in Google Colab, see
# https://docs.pytorch.org/tutorials/beginner/colab
%matplotlib inline

Introduction to `torch.compile`
===============================

**Author:** William Wen


`torch.compile` is the new way to speed up your PyTorch code!
`torch.compile` makes PyTorch code run faster by JIT-compiling PyTorch
code into optimized kernels, while requiring minimal code changes.

`torch.compile` accomplishes this by tracing through your Python code,
looking for PyTorch operations. Code that is difficult to trace will
result in a **graph break**, which is a lost optimization opportunity
rather than an error or silent incorrectness.

`torch.compile` is available in PyTorch 2.0 and later.

This introduction covers basic `torch.compile` usage and demonstrates
the advantages of `torch.compile` over our previous PyTorch compiler
solution, [TorchScript](https://pytorch.org/docs/stable/jit.html).

For an end-to-end example on a real model, check out our [end-to-end
torch.compile
tutorial](https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html).

To troubleshoot issues and to gain a deeper understanding of how to
apply `torch.compile` to your code, check out [the torch.compile
programming
model](https://docs.pytorch.org/docs/main/compile/programming_model.html).


**Required pip dependencies for this tutorial**

-   `torch >= 2.0`
-   `numpy`
-   `scipy`

**System requirements**

-   A C++ compiler, such as `g++`
-   Python development package (`python-devel`/`python-dev`)


Basic Usage
===========

We turn on some logging to help us see what `torch.compile` is doing
under the hood in this tutorial. The following code will print out the
PyTorch ops that `torch.compile` traced.


In [None]:
import torch


torch._logging.set_logs(graph_code=True)
# The setting below had no effect here, so it is left commented out.
# Enable verbose logging to see compilation details
# torch._dynamo.config.verbose = True

In [6]:
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b


opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(3, 3), torch.randn(3, 3)))
print(opt_foo1(torch.randn(4, 4), torch.randn(4, 4)))

V0125 15:13:05.402000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [3/0] [__graph_code] TRACED GRAPH
V0125 15:13:05.402000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [3/0] [__graph_code]  ===== __compiled_fn_7_9931b1aa_c92a_4185_8547_d7e651d37223 =====
V0125 15:13:05.402000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [3/0] [__graph_code]  /home/hliu/anaconda3/lib/python3.13/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0125 15:13:05.402000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [3/0] [__graph_code]     def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
V0125 15:13:05.402000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [3/0] [__graph_code]         l_x_ = L_x_
V0125 15:13:05.402000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [3/0] [__graph_code]         l_y_ = L_y_
V0125 15:13:05.402000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [3/0]

tensor([[ 1.9546,  1.8503,  0.8476],
        [-0.0234, -0.1408,  1.9060],
        [-1.5669,  0.2384,  1.4532]])
tensor([[ 0.0889,  0.9144,  1.0922,  0.3792],
        [ 0.8120,  0.0853, -0.2555,  0.7229],
        [ 0.2868,  0.0549,  0.6464,  0.5942],
        [-0.0562,  0.0907, -0.3454,  0.6626]])


In [18]:
x = torch.randn(3, 4)
print(x.shape[0].__class__)  # <class 'int'>
print(x.shape[0])

<class 'int'>
3


In [9]:
print(opt_foo1(torch.randn(3, 3), torch.randn(3, 3)))
print(opt_foo1(torch.randn(4, 4), torch.randn(4, 4)))

tensor([[0.6818, 1.7326, 0.1336],
        [0.0587, 0.4029, 0.8624],
        [0.7774, 0.9512, 0.2888]])
tensor([[ 0.7467,  0.0485,  1.4156,  0.0289],
        [ 0.9113,  0.9799,  1.4499,  0.6286],
        [ 0.0259, -0.1759, -0.9992, -1.0882],
        [ 0.1066,  0.9046, -0.4007,  1.6108]])


`torch.compile` can also be used as a decorator on an arbitrary Python
function.


In [None]:
@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

print(opt_foo2(torch.randn(3, 3), torch.randn(3, 3)))

V0125 14:42:53.693000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [4/0] [__graph_code] TRACED GRAPH
V0125 14:42:53.693000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [4/0] [__graph_code]  ===== __compiled_fn_9_5e62b208_6380_4342_b318_8c393cd6df6f =====
V0125 14:42:53.693000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [4/0] [__graph_code]  /home/hliu/anaconda3/lib/python3.13/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0125 14:42:53.693000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [4/0] [__graph_code]     def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
V0125 14:42:53.693000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [4/0] [__graph_code]         l_x_ = L_x_
V0125 14:42:53.693000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [4/0] [__graph_code]         l_y_ = L_y_
V0125 14:42:53.693000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [4/0]

tensor([[ 0.0671,  0.8626,  1.3965],
        [ 1.4243,  0.3922,  0.3412],
        [-0.9585, -1.1444, -0.2240]])


`torch.compile` is applied recursively, so nested function calls within
the top-level compiled function will also be compiled.


In [8]:
def inner(x):
    return torch.sin(x)


@torch.compile
def outer(x, y):
    a = inner(x)
    b = torch.cos(y)
    return a + b


print(outer(torch.randn(3, 3), torch.randn(3, 3)))

V0125 14:45:58.618000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [6/0] [__graph_code] TRACED GRAPH
V0125 14:45:58.618000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [6/0] [__graph_code]  ===== __compiled_fn_13_6f0e438d_f42f_4d18_af2a_f6be1b9424d2 =====
V0125 14:45:58.618000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [6/0] [__graph_code]  /home/hliu/anaconda3/lib/python3.13/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0125 14:45:58.618000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [6/0] [__graph_code]     def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
V0125 14:45:58.618000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [6/0] [__graph_code]         l_x_ = L_x_
V0125 14:45:58.618000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [6/0] [__graph_code]         l_y_ = L_y_
V0125 14:45:58.618000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [6/0

tensor([[ 1.7620,  0.5823, -0.2230],
        [ 0.4700,  1.6073,  0.1784],
        [ 1.5852,  0.1588,  0.6764]])


We can also optimize `torch.nn.Module` instances by either calling their
`.compile()` method or by directly `torch.compile`-ing the module. This
is equivalent to `torch.compile`-ing the module's `__call__` method
(which indirectly calls `forward`).


In [19]:
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(3, 3)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod2 = MyModule()
mod2 = torch.compile(mod2)
print(mod2(torch.randn(3, 3)))

V0125 14:51:49.922000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [13/0] [__graph_code] TRACED GRAPH
V0125 14:51:49.922000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [13/0] [__graph_code]  ===== __compiled_fn_27_9c390f3e_399a_4717_9ceb_a8f78871a303 =====
V0125 14:51:49.922000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [13/0] [__graph_code]  /home/hliu/anaconda3/lib/python3.13/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0125 14:51:49.922000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [13/0] [__graph_code]     def forward(self, L_self_modules_lin_parameters_weight_: "f32[3, 3][3, 1]cpu", L_self_modules_lin_parameters_bias_: "f32[3][1]cpu", L_x_: "f32[3, 3][3, 1]cpu"):
V0125 14:51:49.922000 179889 site-packages/torch/_dynamo/output_graph.py:2184] [13/0] [__graph_code]         l_self_modules_lin_parameters_weight_ = L_self_modules_lin_parameters_weight_
V0125 14:51:49.922000 179889 site-packages/tor

tensor([[0.7187, 0.9459, 0.6886],
        [0.0000, 0.5362, 0.3013],
        [0.0000, 0.4415, 0.0248]], grad_fn=<CompiledFunctionBackward>)


In [None]:
# A model is compiled only once; calling compile() repeatedly does not recompile it.
# The code below has the same effect as the code above.
mod1 = MyModule()
mod1.compile()
print(mod1(torch.randn(3, 3)))

tensor([[0.0000, 0.0596, 0.1275],
        [0.0000, 0.7664, 0.0000],
        [0.1984, 0.2656, 0.0000]], grad_fn=<CompiledFunctionBackward>)


Demonstrating Speedups
======================

Now let's demonstrate how `torch.compile` speeds up a simple PyTorch
example. For a demonstration on a more complex model, see our
[end-to-end torch.compile
tutorial](https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html).


In [None]:
def foo3(x):
    y = x + 1
    z = torch.nn.functional.relu(y)
    u = z * 2
    return u

# Enable logging of the low-level code generated by Inductor
torch._logging.set_logs(graph_code=True, output_code=True)
opt_foo3 = torch.compile(foo3)


# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

# Warmup: run once to trigger compilation
inp = torch.randn(4096, 4096).cuda()
print("compile:", timed(lambda: opt_foo3(inp))[1])
print("eager:", timed(lambda: foo3(inp))[1])

V0125 16:20:30.437000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [10/0] [__graph_code] TRACED GRAPH
V0125 16:20:30.437000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [10/0] [__graph_code]  ===== __compiled_fn_23_60b27506_864b_4e24_b498_00cccf75ee57 =====
V0125 16:20:30.437000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [10/0] [__graph_code]  /home/hliu/anaconda3/lib/python3.13/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0125 16:20:30.437000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [10/0] [__graph_code]     def forward(self, L_x_: "f32[4096, 4096][4096, 1]cuda:0"):
V0125 16:20:30.437000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [10/0] [__graph_code]         l_x_ = L_x_
V0125 16:20:30.437000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [10/0] [__graph_code]         
V0125 16:20:30.437000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [10/0] [__graph_code]    

compile: 0.02723142433166504
eager: 0.0008028159737586975


In [None]:
inp = torch.randn(4096, 4096).cuda()
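# Why is the compiled version faster? The generated code contains a Triton
# kernel whose name indicates that add, mul, and relu were fused together:
# @triton.jit
# def triton_poi_fused_add_mul_relu_0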
print("compile:", timed(lambda: opt_foo3(inp))[1])
print("eager:", timed(lambda: foo3(inp))[1])

compile: 0.00048732799291610715
eager: 0.0009031040072441101


Notice that on the first run `torch.compile` takes a lot longer to
complete than eager. This is because `torch.compile` takes extra time to
compile the model on the first few executions. `torch.compile` re-uses
compiled code wherever possible, so if we run our optimized model several
more times, we should see a significant improvement compared to eager.
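
The one-time cost of the first call can also be observed without a GPU. The following is a minimal sketch and not part of the benchmark in this tutorial. It assumes `backend="eager"`, a debugging backend that skips code generation, so it needs no C++ compiler and gives no speedup. The `wall_time` helper is hypothetical and is based on `time.perf_counter`. The sketch only illustrates that the first call pays for tracing while later calls reuse the cached result.

```python
import time

import torch


def foo3(x):
    y = x + 1
    z = torch.nn.functional.relu(y)
    return z * 2


# Assumption: backend="eager" is used so the sketch runs on CPU-only machines.
# It traces the function with TorchDynamo but does not generate fused kernels.
opt_foo3 = torch.compile(foo3, backend="eager")


def wall_time(fn):
    # Hypothetical CPU helper: returns fn()'s result and the elapsed seconds.
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start


inp = torch.randn(256, 256)

# The first call includes tracing; later calls hit the cache.
_, first_call = wall_time(lambda: opt_foo3(inp))
later_calls = sorted(wall_time(lambda: opt_foo3(inp))[1] for _ in range(5))
later_median = later_calls[len(later_calls) // 2]

print(f"first call:  {first_call:.6f} s")
print(f"later calls: {later_median:.6f} s (median of 5)")
```

For real measurements, use the default backend and a timing method appropriate for the device, such as the CUDA-event-based `timed` function above.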


In [31]:
# turn off logging for now to prevent spam
torch._logging.set_logs(graph_code=False, output_code=False)

eager_times = []
for i in range(10):
    _, eager_time = timed(lambda: foo3(inp))
    eager_times.append(eager_time)
    print(f"eager time {i}: {eager_time}")
print("~" * 10)

compile_times = []
for i in range(10):
    _, compile_time = timed(lambda: opt_foo3(inp))
    compile_times.append(compile_time)
    print(f"compile time {i}: {compile_time}")
print("~" * 10)

import numpy as np

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)

eager time 0: 0.0008069120049476623
eager time 1: 0.0007383040189743042
eager time 2: 0.000735647976398468
eager time 3: 0.000735584020614624
eager time 4: 0.0007297919988632203
eager time 5: 0.0007301120162010193
eager time 6: 0.000729088008403778
eager time 7: 0.0007279679775238037
eager time 8: 0.0007321599721908569
eager time 9: 0.0007331839799880981
~~~~~~~~~~
compile time 0: 0.000374783992767334
compile time 1: 0.0002744320034980774
compile time 2: 0.0002652159929275513
compile time 3: 0.0002631680071353912
compile time 4: 0.00026214399933815005
compile time 5: 0.0002611199915409088
compile time 6: 0.0002682879865169525
compile time 7: 0.00026624000072479246
compile time 8: 0.00028569599986076356
compile time 9: 0.0002703680098056793
~~~~~~~~~~
(eval) eager median: 0.0007326719760894775, compile median: 0.0002672639936208725, speedup: 2.74137928631273x
~~~~~~~~~~


And indeed, we can see that running our model with `torch.compile`
results in a significant speedup. The speedup mainly comes from reducing
Python overhead and GPU reads/writes, so the observed speedup may vary
depending on factors such as model architecture and batch size. For
example, if a model's architecture is simple and the amount of data is
large, then the bottleneck would be GPU compute and the observed speedup
may be less significant.

To see speedups on a real model, check out our [end-to-end torch.compile
tutorial](https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html).


Benefits over TorchScript
=========================

Why should we use `torch.compile` over TorchScript? Primarily, the
advantage of `torch.compile` lies in its ability to handle arbitrary
Python code with minimal changes to existing code.

Compare this to TorchScript, which has a tracing mode (`torch.jit.trace`)
and a scripting mode (`torch.jit.script`). Tracing mode is susceptible to
silent incorrectness, while scripting mode requires significant code
changes and will raise errors on unsupported Python code.

For example, TorchScript tracing silently produces incorrect results on
data-dependent control flow (the `if x.sum() < 0:` line below) because
only the control flow path taken during tracing is recorded. In
comparison, `torch.compile` is able to handle it correctly.


In [32]:
def f1(x, y):
    if x.sum() < 0:
        return -y
    return y


# Test that `fn1` and `fn2` return the same result, given the same arguments `args`.
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)


inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))

compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)

traced 1, 1: True
traced 1, 2: False
compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~




TorchScript scripting can handle data-dependent control flow, but it can
require major code changes and will raise errors when unsupported Python
is used.

In the example below, we omit the TorchScript type annotations and
receive a TorchScript error because the type of the value passed for
argument `y`, an `int`, does not match the default argument type,
`torch.Tensor`. In comparison, `torch.compile` works without requiring
any type annotations.


In [33]:
import traceback as tb

torch._logging.set_logs(graph_code=True)


def f2(x, y):
    return x + y


inp1 = torch.randn(5, 5)
inp2 = 3

script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except Exception:
    tb.print_exc()

compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)

Traceback (most recent call last):
  File "/tmp/ipykernel_197418/3652677659.py", line 15, in <module>
    script_f2(inp1, inp2)
    ~~~~~~~~~^^^^^^^^^^^^
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor
V0125 16:42:57.616000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [19/0] [__graph_code] TRACED GRAPH
V0125 16:42:57.616000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [19/0] [__graph_code]  ===== __compiled_fn_43_15c554a6_e2c3_4bf5_a206_0da60a135e49 =====
V0125 16:42:57.616000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [19/0] [__graph_code]  /home/hliu/anaconda3/lib/python3.13/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0125 16:42:57.616000 197418 sit

compile 2: True
~~~~~~~~~~


Graph Breaks
============

The graph break is one of the most fundamental concepts within
`torch.compile`. It allows `torch.compile` to handle arbitrary Python
code by interrupting compilation, running the unsupported code, then
resuming compilation. The term "graph break" comes from the fact that
`torch.compile` attempts to capture and optimize the PyTorch operation
graph. When unsupported Python code is encountered, this graph must
be "broken". Graph breaks result in lost optimization opportunities,
which may still be undesirable, but this is better than silent
incorrectness or a hard crash.

Let's look at a data-dependent control flow example to better see how
graph breaks work.


In [34]:
def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


opt_bar = torch.compile(bar)
inp1 = torch.ones(10)
inp2 = torch.ones(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)

V0125 16:56:29.807000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [20/0] [__graph_code] TRACED GRAPH
V0125 16:56:29.807000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [20/0] [__graph_code]  ===== __compiled_fn_45_196dda0d_23c7_4766_b135_c5a1aebde606 =====
V0125 16:56:29.807000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [20/0] [__graph_code]  /home/hliu/anaconda3/lib/python3.13/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0125 16:56:29.807000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [20/0] [__graph_code]     def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
V0125 16:56:29.807000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [20/0] [__graph_code]         l_a_ = L_a_
V0125 16:56:29.807000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [20/0] [__graph_code]         l_b_ = L_b_
V0125 16:56:29.807000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [20/0] [

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000])

The first time we run `bar`, we see that `torch.compile` traced 2 graphs
corresponding to the following code (noting that `b.sum() < 0` is
False):

1.  `x = a / (torch.abs(a) + 1); b.sum()`
2.  `return x * b`

The second time we run `bar`, we take the other branch of the if
statement and we get 1 traced graph corresponding to the code
`b = b * -1; return x * b`. We do not see a graph of
`x = a / (torch.abs(a) + 1)` outputted the second time since
`torch.compile` cached this graph from the first run and re-used it.

Let's investigate by example how TorchDynamo would step through `bar`.
TorchDynamo always runs graph 1 first and lets Python determine the
result of the conditional. If `b.sum() < 0` is False, it then runs
graph 2 (`return x * b`). If `b.sum() < 0` is True, it instead runs the
third graph (`b = b * -1; return x * b`).
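
One way to check how many graphs are traced is to pass a custom backend to `torch.compile`. The following is a minimal sketch. The names `counting_backend` and `captured_graphs` are hypothetical, and the backend simply records each `torch.fx.GraphModule` it receives and runs it without any optimization. Under the behavior described above, the count should be 2 after the first call and 3 after the second, since the first graph is reused from the cache.

```python
import torch

captured_graphs = []


def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    # Hypothetical backend: remember every graph TorchDynamo hands over,
    # then return the unoptimized forward so the graph still runs.
    captured_graphs.append(gm)
    return gm.forward


def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:  # data-dependent branch: causes a graph break
        b = b * -1
    return x * b


opt_bar = torch.compile(bar, backend=counting_backend)
inp1 = torch.ones(10)
inp2 = torch.ones(10)

opt_bar(inp1, inp2)  # condition is False
graphs_after_first_call = len(captured_graphs)

opt_bar(inp1, -inp2)  # condition is True
graphs_after_second_call = len(captured_graphs)

print("graphs after first call:", graphs_after_first_call)
print("graphs after second call:", graphs_after_second_call)
```

Calling `print(gm.code)` inside the backend would show which lines of `bar` ended up in each graph.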

We can see all graph breaks by using
`torch._logging.set_logs(graph_breaks=True)`.


In [35]:
torch._logging.set_logs(graph_breaks=True)
# Reset to clear the torch.compile cache
torch._dynamo.reset()
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)

V0125 17:04:01.233000 197418 site-packages/torch/_dynamo/symbolic_convert.py:4275] [0/0] [__graph_breaks] Graph break in user code at /tmp/ipykernel_197418/49326488.py:3
V0125 17:04:01.233000 197418 site-packages/torch/_dynamo/symbolic_convert.py:4275] [0/0] [__graph_breaks] Graph Break Reason: Data-dependent branching
V0125 17:04:01.233000 197418 site-packages/torch/_dynamo/symbolic_convert.py:4275] [0/0] [__graph_breaks]   Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
V0125 17:04:01.233000 197418 site-packages/torch/_dynamo/symbolic_convert.py:4275] [0/0] [__graph_breaks]   Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
V0125 17:04:01.233000 197418 site-packages/torch/_dynamo/symbolic_convert.py:4275] [0/0] [__graph_breaks]   Hint: Use `torch.cond` to express dynamic control flow.
V0125 17:04:01.23300

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000])

In order to maximize speedup, graph breaks should be limited. We can
force TorchDynamo to raise an error upon the first graph break
encountered by using `fullgraph=True`:


In [36]:
# Reset to clear the torch.compile cache
torch._dynamo.reset()

opt_bar_fullgraph = torch.compile(bar, fullgraph=True)
try:
    opt_bar_fullgraph(torch.randn(10), torch.randn(10))
except Exception:
    tb.print_exc()

Traceback (most recent call last):
  File "/tmp/ipykernel_197418/387069252.py", line 6, in <module>
    opt_bar_fullgraph(torch.randn(10), torch.randn(10))
    ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hliu/anaconda3/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 953, in compile_wrapper
    return fn(*args, **kwargs)
  File "/home/hliu/anaconda3/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 2202, in __call__
    result = self._torchdynamo_orig_backend(
        frame, cache_entry, self.hooks, frame_state, skip=1
    )
  File "/home/hliu/anaconda3/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 707, in __call__
    result = _compile(
        frame.f_code,
    ...<16 lines>...
        convert_frame_box=self._box,
    )
  File "/home/hliu/anaconda3/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 1752, in _compile
    guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
       

In our example above, we can work around this graph break by replacing
the if statement with a `torch.cond`:


In [38]:
from functorch.experimental.control_flow import cond

torch._logging.set_logs(graph_code=True)

@torch.compile(fullgraph=True)
def bar_fixed(a, b):
    x = a / (torch.abs(a) + 1)

    def true_branch(y):
        return y * -1

    def false_branch(y):
        # NOTE: torch.cond doesn't allow aliased outputs
        return y.clone()

    # Bind the result to b (not x) so that x * b matches the original bar.
    b = cond(b.sum() < 0, true_branch, false_branch, (b,))
    return x * b


bar_fixed(inp1, inp2)
bar_fixed(inp1, -inp2)

V0125 17:11:05.755000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [2/0] [__graph_code] TRACED GRAPH
V0125 17:11:05.755000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [2/0] [__graph_code]  ===== __compiled_fn_64_8ac9cba9_53bc_47ee_b0fe_8854e6cc4561 =====
V0125 17:11:05.755000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [2/0] [__graph_code]  /home/hliu/anaconda3/lib/python3.13/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0125 17:11:05.755000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [2/0] [__graph_code]     def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
V0125 17:11:05.755000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [2/0] [__graph_code]         l_a_ = L_a_
V0125 17:11:05.755000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [2/0] [__graph_code]         l_b_ = L_b_
V0125 17:11:05.755000 197418 site-packages/torch/_dynamo/output_graph.py:2184] [2/0] [__graph

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000])

To serialize graphs or to run graphs in different (for example,
Python-less) environments, consider using `torch.export` instead
(available in PyTorch 2.1 and later). One important restriction is that
`torch.export` does not support graph breaks. Please check [the
torch.export
tutorial](https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html)
for more details on `torch.export`.

Check out our [section on graph breaks in the torch.compile programming
model](https://docs.pytorch.org/docs/main/compile/programming_model.graph_breaks_index.html)
for tips on how to work around graph breaks.


Troubleshooting
===============

Is `torch.compile` failing to speed up your model? Is compile time
unreasonably long? Is your code recompiling excessively? Are you having
difficulties dealing with graph breaks? Are you looking for tips on how
to best use `torch.compile`? Or maybe you simply want to learn more
about the inner workings of `torch.compile`?

Check out [the torch.compile programming
model](https://docs.pytorch.org/docs/main/compile/programming_model.html).


Conclusion
==========

In this tutorial, we introduced `torch.compile` by covering basic usage,
demonstrating speedups over eager mode, comparing to TorchScript, and
briefly describing graph breaks.

For an end-to-end example on a real model, check out our [end-to-end
torch.compile
tutorial](https://pytorch.org/tutorials/intermediate/torch_compile_full_example.html).

To troubleshoot issues and to gain a deeper understanding of how to
apply `torch.compile` to your code, check out [the torch.compile
programming
model](https://docs.pytorch.org/docs/main/compile/programming_model.html).

We hope that you will give `torch.compile` a try!
