# `TorchDynamo`
## and (Maybe) PEP-523

Dboy Liao

## `TorchDynamo`

- What?
   - What is `TorchDynamo`?
   - What can it do?

- Why?

- (Maybe) How?
    - [PEP-523](https://peps.python.org/pep-0523/)

## `TorchDynamo`: JIT Compiler Framework

- The workhorse for `torch.compile` under the hood
    - A Python-to-Python JIT compiler

![torchdynamo-jit](img/torchdynamo_jit.drawio.svg)

In [1]:
import torch

In [2]:
# A toy example
def foo(a: torch.Tensor, b: torch.Tensor):
    return torch.add(a, b)

# foo_ is the compiled (JIT-enabled) version of foo
foo_ = torch.compile(foo)

In [3]:
a = torch.randn(3, 4)
b = torch.randn(3, 4)

In [4]:
%time
foo(a, b)

CPU times: user 7 μs, sys: 1e+03 ns, total: 8 μs
Wall time: 14.5 μs


tensor([[ 1.1368, -2.5452,  0.3648, -0.1320],
        [ 0.2835,  0.3718, -2.8085, -1.2019],
        [-0.0500,  0.8236,  1.4256, -1.1588]])

In [5]:
%time
foo_(a, b)

CPU times: user 7 μs, sys: 1 μs, total: 8 μs
Wall time: 14.3 μs


tensor([[ 1.1368, -2.5452,  0.3648, -0.1320],
        [ 0.2835,  0.3718, -2.8085, -1.2019],
        [-0.0500,  0.8236,  1.4256, -1.1588]])

In [6]:
%time
foo_(a, b)

CPU times: user 7 μs, sys: 1 μs, total: 8 μs
Wall time: 15 μs


tensor([[ 1.1368, -2.5452,  0.3648, -0.1320],
        [ 0.2835,  0.3718, -2.8085, -1.2019],
        [-0.0500,  0.8236,  1.4256, -1.1588]])

In [7]:
from depyf import decompile
from dis import dis
from torch._dynamo.eval_frame import _debug_get_cache_entry_list, innermost_fn

In [8]:
cache_entries = _debug_get_cache_entry_list(innermost_fn(foo_))
cache_entry = cache_entries[0]
guard, code = cache_entry.check_fn, cache_entry.code
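Each cache entry pairs a guard (`check_fn`) with the code it protects: the code is reused only while the guard still holds, otherwise a new entry is compiled. The following is a rough, stdlib-only sketch of that lookup, assuming guards that check nothing but argument types. `Entry`, `GuardedCache`, and `compile_for_types` are hypothetical names, and real Dynamo guards are considerably richer.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple


@dataclass
class Entry:
    check_fn: Callable[[Tuple[Any, ...]], bool]  # the guard
    code: Callable[..., Any]                     # the "compiled" code


class GuardedCache:
    """Reuse a compiled entry while its guard passes, recompile otherwise."""

    def __init__(self, compile_fn: Callable[[Tuple[Any, ...]], Entry]):
        self.compile_fn = compile_fn
        self.entries: List[Entry] = []
        self.compile_count = 0

    def __call__(self, *args):
        for entry in self.entries:
            if entry.check_fn(args):      # guard hit: reuse cached code
                return entry.code(*args)
        self.compile_count += 1           # guard miss: compile a new entry
        entry = self.compile_fn(args)
        self.entries.append(entry)
        return entry.code(*args)


def compile_for_types(args: Tuple[Any, ...]) -> Entry:
    # Assumption for illustration: the guard only checks argument types.
    expected = tuple(type(a) for a in args)
    return Entry(
        check_fn=lambda new_args: tuple(type(a) for a in new_args) == expected,
        code=lambda a, b: a + b,
    )


cached_add = GuardedCache(compile_for_types)
r1 = cached_add(1, 2)      # first call: compiles an entry guarded on (int, int)
r2 = cached_add(3, 4)      # guard passes: no recompilation
r3 = cached_add("a", "b")  # guard fails: a second entry is compiled
```
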

In [9]:
print(decompile(foo.__code__))

def foo(a, b):
    return torch.add(a, b)



In [10]:
print(decompile(code))

def foo(a, b):
    __temp_3, = __compiled_fn_1(b, a)
    return __temp_3
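The decompiled output above suggests that the function keeps its name and signature while its body is replaced by a call into a compiled function. The following is a minimal stdlib-only sketch of that idea, not Dynamo's mechanism: Dynamo substitutes code when the frame is evaluated, whereas this sketch simply assigns `__code__`. `_compiled_add` and `_rewritten` are hypothetical names.

```python
import dis


def _compiled_add(b, a):
    # Hypothetical stand-in for a backend-compiled kernel; returns a tuple of outputs.
    return (a + b,)


def add(a, b):
    return a + b


def _rewritten(a, b):
    # Same signature as `add`, but the body only dispatches to the compiled function.
    result, = _compiled_add(b, a)
    return result


# Swap the code object: `add` keeps its identity but now runs the rewritten bytecode.
add.__code__ = _rewritten.__code__
swapped_result = add(1, 2)

dis.dis(add)  # the disassembly now references _compiled_add
```
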



In [11]:
fn = innermost_fn(__compiled_fn_1)

In [12]:
print(fn.__closure__[0].cell_contents.__closure__[0].cell_contents.__closure__[0].cell_contents.__closure__[0].cell_contents.source_code)

# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor

You can learn more about `ATen` at the following link:
- https://pytorch.org/cppdocs/

## Speed Up Your Module with `TorchDynamo`

In [13]:
from transformers import ViTImageProcessor, ViTModel
from PIL import Image
import requests

In [14]:
# https://huggingface.co/google/vit-base-patch16-224-in21k
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')

inputs = processor(images=[image for _ in range(10)], return_tensors="pt")

In [15]:
device = torch.device(0 if torch.cuda.is_available() else "cpu") 
device

device(type='cuda', index=0)

In [16]:
model = model.to(device)
inputs = {k: v.to(device) for k, v in inputs.items()}

In [17]:
%time
_ = model(**inputs)

CPU times: user 9 μs, sys: 1e+03 ns, total: 10 μs
Wall time: 18.4 μs


In [18]:
model_ = torch.compile(model)

In [19]:
# trigger JIT-compilation
_ = model_(**inputs)

In [20]:
%time
_ = model_(**inputs)

CPU times: user 7 μs, sys: 0 ns, total: 7 μs
Wall time: 14.3 μs


## Custom Backend/Compiler Support

- `torch.compile` lets you modify or rewrite your module with a custom backend function.
    - Built on top of `fx.Graph`, the graph IR for PyTorch.
    - **Custom optimization of the model is possible**
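A graph IR of this kind is essentially an ordered list of nodes that name their inputs. As a toy illustration in plain Python (not `torch.fx`; the `Node` class and `remove_nodes` helper are hypothetical), removing an op amounts to rewiring its consumers to its input and dropping the node:

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Node:
    opcode: str
    name: str
    target: str
    args: Tuple[str, ...] = ()


def remove_nodes(nodes: List[Node], target: str) -> List[Node]:
    """Return a new node list in which every node with `target` is bypassed."""
    replaced = {}  # name of a removed node -> name of its first input
    kept = []
    for node in nodes:
        # Rewire inputs that pointed at an already-removed node.
        args = tuple(replaced.get(arg, arg) for arg in node.args)
        if node.target == target:
            replaced[node.name] = args[0]
            continue
        kept.append(Node(node.opcode, node.name, node.target, args))
    return kept


graph = [
    Node("placeholder", "x", "x"),
    Node("call_module", "_linear", "_linear", ("x",)),
    Node("call_module", "_dropout", "_dropout", ("_linear",)),
    Node("output", "output", "output", ("_dropout",)),
]
pruned = remove_nodes(graph, "_dropout")
for node in pruned:
    print(node.opcode, node.name, node.target, node.args)
```
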

In [13]:
import torch.nn as nn

In [14]:
class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self._linear = nn.Linear(100, 3)
        self._dropout = nn.Dropout(p=0.3)
        
    def forward(self, x: torch.Tensor):
        out = self._linear(x)
        out = self._dropout(out)
        return out

my_model = MyModule()

In [15]:
ori_graph = torch.fx.symbolic_trace(my_model).graph
ori_graph.print_tabular()

opcode       name      target    args         kwargs
-----------  --------  --------  -----------  --------
placeholder  x         x         ()           {}
call_module  _linear   _linear   (x,)         {}
call_module  _dropout  _dropout  (_linear,)   {}
output       output    output    (_dropout,)  {}


In [16]:
new_graph = None

def my_backend(gm: torch.fx.GraphModule, sample_inputs: list[torch.Tensor]):
    global new_graph
    for node in gm.graph.nodes:
        if node.target == torch.nn.functional.dropout:
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)
    gm.recompile()
    new_graph = gm.graph
    return gm.forward

In [17]:
from copy import deepcopy
new_graph = None

def my_backend(gm: torch.fx.GraphModule, sample_inputs: list[torch.Tensor]):
    global new_graph
    new_graph = deepcopy(gm.graph)
    for node in new_graph.nodes:
        if node.target == torch.nn.functional.dropout:
            node.replace_all_uses_with(node.args[0])
            new_graph.erase_node(node)
    new_graph.lint()
    new_gm = torch.fx.GraphModule(root=gm, graph=new_graph)
    return new_gm.forward

In [18]:
my_model_ = torch.compile(my_model, backend=my_backend)

In [19]:
x = torch.randn(10, 100, dtype=torch.float32)

In [20]:
my_model_(x)

tensor([[-0.0533, -0.6154,  0.0078],
        [ 0.5948,  0.2216, -0.3222],
        [-0.4676, -0.2946, -1.0253],
        [-0.2791, -0.5381,  0.0247],
        [ 0.4472,  0.0657, -1.3394],
        [ 0.2567, -0.7131, -1.3041],
        [ 0.1124,  0.4636, -0.1534],
        [-0.9166,  0.7008, -0.0301],
        [-0.4780, -0.2450, -0.6456],
        [ 0.8653, -0.0872,  0.5471]], grad_fn=<AddmmBackward0>)

In [21]:
new_graph.print_tabular()

opcode         name                                      target                                    args                                                                                      kwargs
-------------  ----------------------------------------  ----------------------------------------  ----------------------------------------------------------------------------------------  --------
placeholder    l_self_modules_linear_parameters_weight_  L_self_modules_linear_parameters_weight_  ()                                                                                        {}
placeholder    l_self_modules_linear_parameters_bias_    L_self_modules_linear_parameters_bias_    ()                                                                                        {}
placeholder    l_x_                                      L_x_                                      ()                                                                                        {}
call_function  out            

- Even pattern matching on subgraphs is possible!
    - https://github.com/pytorch/examples/blob/main/fx/subgraph_rewriter_basic_use.py
 
![utensor-node-fusion](https://github.com/uTensor/utensor_cgen/blob/eccd6859028d0b6a350dced25ea72ff02faaf9ad/doc/source/_images/conv_pool_fuse.png?raw=true)

## Comparison with `torch.fx`

- `torch.fx.symbolic_trace`
    - It traces the `nn.Module` and converts it into a `GraphModule`.
    - `GraphModule.graph`: the graph IR of the module.
- `torch.compile` + custom backend
    - It also converts the `nn.Module` into one or more `GraphModule`s, with support for **graph breaks**.
    - It JIT-compiles the resulting `GraphModule`.

### What is a Graph Break?

- It's hard for a compiler to keep track of the execution of a function frame in a dynamic language like Python.
    - e.g., data-dependent control flow, `try`/`except`

- `torch.fx.symbolic_trace` takes an all-or-nothing approach: it fails when it finds code it cannot keep track of

- `TorchDynamo` introduces graph breaks to resolve the problem
    - It breaks a computational graph into multiple subgraphs
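To see why symbolic tracing is all-or-nothing, it helps to look at what a tracer's placeholder value can and cannot do. The sketch below is a simplified stand-in written in plain Python, not the `torch.fx` implementation (the `Proxy` and `TraceError` classes here are local, hypothetical definitions): operations on a proxy are recorded, but asking for its truth value is impossible because no concrete data exists at trace time.

```python
class TraceError(Exception):
    """Raised when the tracer meets code it cannot record."""


class Proxy:
    """Records operations instead of computing them."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def _record(self, op, *args):
        out = Proxy(f"v{len(self.log)}", self.log)
        arg_names = tuple(getattr(a, "name", a) for a in args)
        self.log.append((out.name, op, arg_names))
        return out

    def sum(self):
        return self._record("sum", self)

    def __mul__(self, other):
        return self._record("mul", self, other)

    def __rmul__(self, other):
        return self._record("mul", other, self)

    def __gt__(self, other):
        return self._record("gt", self, other)

    def __bool__(self):
        # A branch needs a concrete True/False, which a symbolic value cannot give.
        raise TraceError("symbolic value used in control flow")


def forward(x):
    if x.sum() > 0:
        return x * x
    return 3 * x


log = []
try:
    forward(Proxy("x", log))
    traced = True
except TraceError:
    traced = False

print(traced, log)
```

The trace records `sum` and `gt` and then aborts at the `if`, which is the point where a graph break would instead hand control back to Python.
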

In [21]:
import torch.nn as nn
from torch.fx.proxy import TraceError

In [18]:
# an example with dynamic control flow
class DynamicModule(nn.Module):
    def forward(self, x: torch.Tensor):
        if x.sum() > 0:
            return x*x
        return 3*x

In [22]:
try:
    _ = torch.fx.symbolic_trace(DynamicModule())
except TraceError:
    print("Tracing Error!!")

Tracing Error!!


In [24]:
graphs = []

def collect_graphs(gm, samples):
    graphs.append(gm.graph)
    return gm.forward

model = torch.compile(DynamicModule(), backend=collect_graphs)
x = torch.randn(10)
_ = model(x)
_ = model(-x)

In [32]:
for graph in graphs:
    print(graph.python_code("self").src.strip())
    print("=="*10)

def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    sum_1 = l_x_.sum();  l_x_ = None
    gt = sum_1 > 0;  sum_1 = None
    return (gt,)
====================
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    mul = 3 * l_x_;  l_x_ = None
    return (mul,)
====================
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    mul = l_x_ * l_x_;  l_x_ = None
    return (mul,)
====================
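Read together, the three captured graphs suggest how execution is stitched back together: the first graph computes the condition, ordinary Python decides the branch, and one of the two remaining graphs runs. A plain-Python sketch of that flow follows, with lists standing in for tensors. The function names are hypothetical, and the real resume logic is generated by Dynamo as bytecode rather than written by hand like this.

```python
def graph_condition(x):
    # Subgraph 1: everything up to the data-dependent branch.
    return (sum(x) > 0,)


def graph_square(x):
    # Subgraph for the `x * x` branch.
    return ([v * v for v in x],)


def graph_triple(x):
    # Subgraph for the `3 * x` branch.
    return ([3 * v for v in x],)


def forward(x):
    gt, = graph_condition(x)
    # Graph break: the branch is evaluated by the Python interpreter,
    # then execution resumes in whichever subgraph matches.
    if gt:
        out, = graph_square(x)
    else:
        out, = graph_triple(x)
    return out


positive = forward([1.0, 2.0])
negative = forward([-1.0, -2.0])
print(positive, negative)
```
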


## How: `TorchDynamo` and PEP-523

- PEP-523: https://peps.python.org/pep-0523/

![pep-523-ceval](img/pep-523-ceval.png)

Example: [Pyjion](https://github.com/microsoft/pyjion)

![pep-523](img/pep-523-pyjion.png)
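PEP 523 exposes the frame-evaluation hook only through the C API, so it cannot be installed from pure Python. As a loose analogy, `sys.setprofile` also lets a callback see every Python function call, although it can only observe frames and cannot replace how they are evaluated. The sketch below uses it to log the frames that are entered while the hook is active; `profiler` and `calls` are illustrative names.

```python
import sys

calls = []


def profiler(frame, event, arg):
    # "call" fires whenever a Python-level frame is entered.
    if event == "call":
        calls.append(frame.f_code.co_name)


def foo(x, y):
    return x + y


sys.setprofile(profiler)
try:
    result = foo(1, 2)
finally:
    # Always uninstall the hook, even if the call raises.
    sys.setprofile(None)

print(result, calls)
```
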

## `TorchDynamo` Workflow

![torch-workflow](img/torchdynamo_workflow.drawio.svg)

### The Dark Side: `set_eval_frame`

<img src='img/dark-side.jpg' alt='dark-side' width=300>

[credit](https://starwars.fandom.com/wiki/Dark_side_of_the_Force/Legends)

## `torch.compile` --> `torch._dynamo.optimize`

![torch-compile](img/torch-compile.png)

<div>
    <img src='img/torch-compile-optimize.png' alt='torch-compile-optimize' width=300>
    <img src='img/torch-compile-optimize-context.png' alt='torch-compile-optimize-context' width=320>
</div>

- `torch._dynamo.optimize`
    - Wraps your `backend` with `convert_frame.convert_frame`
    - Constructs an `OptimizeContext` with the wrapped `backend`

- Finally, the `backend` is wrapped into a `_TorchDynamoContext`
    - Its `__call__` is invoked once you run inference on the compiled module
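The pattern described here (a context object that installs a callback on entry, restores the previous one on exit, and can also wrap a callable) can be sketched in plain Python. This is an illustration of the pattern only, not the `_TorchDynamoContext` source: `_current_callback`, `set_callback`, and `HookContext` are hypothetical, and a module-level variable stands in for the interpreter-level eval-frame hook.

```python
import functools

_current_callback = None  # hypothetical stand-in for the eval-frame hook


def set_callback(callback):
    """Install `callback` and return the one that was installed before."""
    global _current_callback
    prior = _current_callback
    _current_callback = callback
    return prior


class HookContext:
    def __init__(self, callback):
        self.callback = callback
        self._prior_stack = []

    def __enter__(self):
        self._prior_stack.append(set_callback(self.callback))
        return self

    def __exit__(self, exc_type, exc, tb):
        # Restore whatever was installed before, even on error.
        set_callback(self._prior_stack.pop())
        return False

    def __call__(self, fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            with self:
                return fn(*args, **kwargs)
        return wrapper


seen = []


def probe():
    # Record which callback is active while the wrapped function runs.
    seen.append(_current_callback)
    return 42


my_callback = object()
wrapped = HookContext(my_callback)(probe)
value = wrapped()
active_after = _current_callback
```
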

![torch-context-set-eval](img/torch-context-set-eval.png)

The source code of the C implementation of `set_eval_frame` can be found [here](https://github.com/pytorch/pytorch/blob/v2.5.0/torch/csrc/dynamo/eval_frame.c#L758-L782).

## The Final Secret: `ConvertFrameAssert`

![torch-convert-frame-call](img/torch-convert-frame-call.png)

[impl](https://github.com/pytorch/pytorch/blob/v2.5.0/torch/_dynamo/convert_frame.py#L434-L542)

![torchdynamo-high-level-steps](img/torchdynamo-high-level-steps.png)

[YouTube: A Deep Dive on TorchDynamo](https://www.youtube.com/watch?v=5FNHwPIyHr8)

In [22]:
# Let's try it!
from typing import Optional
from torch._dynamo.types import (
    CacheEntry,
    DynamoFrameType,
    FrameState,
    GuardedCode,
)
from torch._guards import CompileId

hit = 0

def callback(
    frame: DynamoFrameType,
    cache: Optional[CacheEntry],
    frame_state: FrameState
):
    global hit
    hit += 1
    function_name = frame.f_code.co_name
    print(f"calling {function_name}")
    return GuardedCode(code=frame.f_code, check_fn=lambda local: hit, compile_id=CompileId(frame_id=1, frame_compile_id=1))

In [23]:
from torch._C._dynamo.eval_frame import set_eval_frame

In [24]:
def foo(x, y):
    return x+y

In [25]:
prior = set_eval_frame(callback)
print(foo(1, 2))
set_eval_frame(prior)

calling __get__
calling get
calling cast
calling helper
calling __init__
calling extra_flags
calling __enter__
calling showtraceback
calling _get_exc_info
calling contains_exceptiongroup
calling structured_traceback
calling structured_traceback
calling structured_traceback
calling format_exception_as_a_whole
calling prepare_header
calling __getattr__
calling get_terminal_size
calling get_terminal_size
calling __getitem__
calling encode
calling get_records
calling has_colors
calling get_style_by_name
calling _find_and_load
calling __init__
calling __enter__
calling _get_module_lock
calling __init__
calling acquire
calling _find_and_load_unlocked
calling _find_spec
calling __enter__
calling find_spec
calling __exit__
calling find_spec
calling find_spec
calling find_spec
calling _call_with_frames_removed
calling find_spec
calling _get_spec
calling _path_importer_cache
calling find_spec
calling _path_stat
calling _fill_cache
calling _relax_case
calling _path_join
calling <listcomp>
calling

## What's Next

- Symbolic Tracing for `TorchDynamo`: `inductor`
    - How to handle dynamic control flow?
    - Graph Break
        - splitting the computational graph into multiple subgraphs
        - continuation functions
    - Tracker API
- `ExecuTorch`: PyTorch platform for edge devices
    - Customizable code emission (?)

## References

- TorchDynamo
    - https://www.youtube.com/watch?v=egZB5Uxki0I&ab_channel=EdwardZ.Yang%27sPyTorchandPL
    - https://www.youtube.com/watch?v=5FNHwPIyHr8&ab_channel=PyTorch
    - https://dev-discuss.pytorch.org/t/a-minimal-working-example-of-standalone-usage-for-dynamo-eval-frame/1525
- `torch.fx`
    - https://www.youtube.com/watch?v=TexdGMdQya4&ab_channel=Unify
    - https://github.com/pytorch/examples/tree/main/fx