# Edit Draft

**TorchDynamo** is a Python-level Just-In-Time (JIT) compiler designed to make unmodified PyTorch programs faster. TorchDynamo hooks into the frame evaluation API in CPython ([PEP 523](https://peps.python.org/pep-0523/)) to dynamically modify Python bytecode right before it is executed. It rewrites Python bytecode to extract sequences of PyTorch operations into an [FX Graph](https://pytorch.org/docs/stable/fx.html) which is then compiled with a customizable backend. It creates this FX Graph through bytecode analysis and is designed to mix Python execution with compiled backends to get the best of both worlds: usability and performance.

*From [TorchDynamo Deep Dive](https://pytorch.org/docs/stable/torch.compiler_deepdive.html)*
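
To make "Python bytecode" concrete, here is a small sketch that uses the standard-library `dis` module to list the instructions CPython generates for a toy function. The function `scale` is a hypothetical stand-in for a model's `forward`. Dynamo's own analysis is far more involved, so this only illustrates the kind of input it works from.

```python
import dis


def scale(x):
    # Hypothetical stand-in for a forward pass: one multiply, one return.
    return x * 100


# Collect the opcode names CPython compiled `scale` into.
opnames = [instr.opname for instr in dis.get_instructions(scale)]
print(opnames)
```

The exact opcode names differ between Python versions, which is one reason Dynamo's supported Python versions are tied to specific PyTorch releases.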

## Setup (Ignore)

In [1]:
from typing import List

import torch
import torch.nn as nn

from nnsight import LanguageModel
from nnsight.util import WrapperModule



## 1 - Simple Example

Let's create a simple torch model to demonstrate how operations are translated into a Torch FX graph. The `WrappedLayer` module will demonstrate how user-defined modules are compiled by Dynamo.

In [2]:
class WrappedLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)

    def forward(self, x):
        x = self.layer1(x)
        x = x * 100
        return x

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)
        self.wrapped = WrappedLayer()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.wrapped(x)
        x = self.dropout(x)
        x = x.split(1, dim=-1)
        return x

mod = M()

input_tensor = torch.tensor([[1.0]])
output = mod(input_tensor)
print(output)

(tensor([[-18.2395]], grad_fn=<SplitBackward0>),)


`torch.compile` is PyTorch's interface for speeding up PyTorch code. It uses Dynamo under the hood to JIT-compile Python code. We'll use it as an easy interface for accessing the FX GraphModules captured by Dynamo.

In [3]:
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    gm.graph.print_tabular()
    gm.recompile()

    return gm.forward

torch._dynamo.reset()

opt_model = torch.compile(mod, backend=custom_backend, dynamic=True)
gm = opt_model(torch.tensor([1.0]))

opcode         name     target                       args           kwargs
-------------  -------  ---------------------------  -------------  -----------
placeholder    l_x_     L_x_                         ()             {}
call_module    x        L__self___layer1             (l_x_,)        {}
call_module    x_1      L__self___wrapped_layer1     (x,)           {}
call_function  x_3      <built-in function mul>      (x_1, 100)     {}
call_module    x_4      L__self___dropout            (x_3,)         {}
call_method    split    split                        (x_4, 1)       {'dim': -1}
call_function  getitem  <built-in function getitem>  (split, 0)     {}
output         output   output                       ((getitem,),)  {}


Notice how the functions and submodules we declared in the module are translated into nodes in an FX graph, each labeled with its operation. Dynamo traces through user-defined modules such as `WrappedLayer`, breaking the operations in its forward pass into separate nodes on the FX graph.
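
The same node structure can be inspected programmatically. The sketch below uses `torch.fx.symbolic_trace`, which is a different tracer from Dynamo but produces the same kind of FX graph, on a hypothetical toy module `Tiny` (unrelated to `M` above). It is only meant to show what the `opcode`, `name`, and `target` columns correspond to on each node object.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class Tiny(nn.Module):
    # Hypothetical toy module, separate from `M` above.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x) * 100


traced = symbolic_trace(Tiny())

# Each node records its opcode, its name, and what it calls.
node_summary = [(node.op, node.name, str(node.target)) for node in traced.graph.nodes]
for row in node_summary:
    print(row)
```

Here the `nn.Linear` call shows up as a single `call_module` node and the multiply as a `call_function` node, mirroring the rows in the table printed by `print_tabular()`.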

Now, let's see what happens if we load this module into NNsight and trace it. Note that we call `torch._dynamo.reset()` first to clear Dynamo's compilation caches, so the model is traced from scratch and handed to our backend again.

In [5]:
from nnsight import NNsight

nn_model = NNsight(mod)

torch._dynamo.reset()

opt_model = torch.compile(nn_model._model, backend=custom_backend, dynamic=True)
gm = opt_model(torch.tensor([1.0]))

opcode         name    target                   args       kwargs
-------------  ------  -----------------------  ---------  --------
placeholder    x       L_stack0_                ()         {}
call_function  x_1     <built-in function mul>  (x, 100)   {}
output         output  output                   ((x_1,),)  {}
opcode         name     target                       args           kwargs
-------------  -------  ---------------------------  -------------  -----------
placeholder    x        L_stack0_                    ()             {}
call_method    split    split                        (x, 1)         {'dim': -1}
call_function  getitem  <built-in function getitem>  (split, 0)     {}
output         output   output                       ((getitem,),)  {}


After loading our model with NNsight, we find that Dynamo has produced two separate graphs. When TorchDynamo encounters unsupported Python features, such as data-dependent control flow, it breaks the computation graph, lets the default Python interpreter handle the unsupported code, then resumes capturing the graph. (*From [TorchDynamo Deep Dive](https://pytorch.org/docs/stable/torch.compiler_deepdive.html)*)
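
To see a graph break in isolation, here is a minimal sketch built around data-dependent control flow, the example named in the quote. The function `branchy` and the backend `counting_backend` are hypothetical names introduced for illustration. The backend simply records each graph Dynamo hands it and runs it unmodified.

```python
import torch
import torch._dynamo

captured_graphs = []


def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    # Record every graph Dynamo hands to the backend, then run it unmodified.
    captured_graphs.append(gm)
    return gm.forward


def branchy(x):
    # Branching on a tensor's value is data-dependent control flow.
    if x.sum() > 0:
        return x + 1
    return x - 1


torch._dynamo.reset()
compiled = torch.compile(branchy, backend=counting_backend)
result = compiled(torch.tensor([1.0]))
print(len(captured_graphs))
```

Because the `if` depends on a runtime tensor value, Dynamo should capture the code before the branch and the code after it as separate graphs, so the backend is called more than once for a single call to `compiled`.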

We can see where TorchDynamo breaks the graph by using `torch._dynamo.explain`:

In [5]:
torch._dynamo.reset()
explain_output = torch._dynamo.explain(nn_model._model)(torch.tensor([1.0]))
print(explain_output)

Graph Count: 2
Graph Break Count: 1
Op Count: 2
Break Reasons:
Ops per Graph:
  Ops 1:
    <built-in function mul>
  Ops 2:
    <built-in function getitem>
Out Guards:
  Guard 1:
    Name: "L['x']"
    Source: local
    Create Function: TENSOR_MATCH
    Guard Types: ['TENSOR_MATCH']
    Code List: ["hasattr(L['x'], '_dynamo_dynamic_indices') == False"]
    Object Weakref: <weakref at 0x7f633b6984f0; dead>
    Guarded Class Weakref: <weakref at 0x7f631ff00ae0; to 'torch._C._TensorMeta' at 0x63781b0 (Tensor)>
  Guard 2:
    Name: ''
    Source: global
    Create Function: TORCH_FUNCTION_STATE
    Guard Types: None
    Code List: None
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 3:
    Name: ''
    Source: global
    Create Function: DEFAULT_DEVICE
    Guard Types: ['DEFAULT_DEVICE']
    Code List: ['utils_device.CURRENT_DEVICE == None']
    Object Weakref: None
    Guarded Class Weakref: None
  Guard 4:
    Name: ''
    Source: shape_env
    Create Function: SHAPE_ENV

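The report shows a single graph break, which splits the trace into two graphs: the first contains the multiply `x = x * 100` and the second contains the `split`. This matches the two tables printed above. The `Break Reasons` section is empty here, so the report alone doesn't tell us why the graph broke.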

We can force TorchDynamo to raise an error upon the first graph break encountered by using `fullgraph=True`. The stack trace will provide more details on exactly what is breaking our graph.

In [6]:
import traceback as tb

opt_bar = torch.compile(nn_model._model, fullgraph=True)
try:
    opt_bar(torch.tensor([1.0]))
except Exception:
    tb.print_exc()

Traceback (most recent call last):
  File "/tmp/ipykernel_1725725/3700699761.py", line 5, in <module>
    opt_bar(torch.tensor([1.0]))
  File "/share/u/caden/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/u/caden/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/u/caden/.local/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/share/u/caden/.local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/u/caden/.local/lib/python3.11/site-packages/torch/nn/modules/module.p

Expanding the error trace reveals this line toward the end. 

```
torch._dynamo.exc.Unsupported: call_function UserDefinedObjectVariable(_hook) [UnspecializedNNModuleVariable(Linear), TupleVariable(), ConstDictVariable(), TensorVariable()] {}
```

This message indicates that Dynamo ran into a call it does not support, in this case a forward hook (the `_hook` object called on a `Linear` module), and broke the graph there.
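
For readers unfamiliar with forward hooks, the sketch below shows the plain PyTorch mechanism involved: a callable registered on a module that runs after every forward call. The hook `record_output` is a hypothetical example, not NNsight's actual `_hook`.

```python
import torch
import torch.nn as nn

layer = nn.Linear(1, 1)
seen_outputs = []


def record_output(module, args, output):
    # Forward hooks receive the module, its positional inputs, and its output.
    seen_outputs.append(output.detach())


handle = layer.register_forward_hook(record_output)
layer(torch.tensor([[1.0]]))  # hook fires and records the output
handle.remove()
layer(torch.tensor([[1.0]]))  # hook has been removed, nothing is recorded
print(len(seen_outputs))
```

Hooks like this are arbitrary Python callables attached to the module, which is why a tracer may have to fall back to the regular interpreter when it reaches one.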

We can remove the NNsight hooks by accessing the underlying `._envoy` and clearing them with `.clear_hooks(propagate=True)`. Setting `propagate=True` tells NNsight to remove the hooks of an envoy's sub-envoys too.

In [8]:
nn_model._envoy.clear_hooks(propagate=True)

torch._dynamo.reset()

opt_model = torch.compile(nn_model._model, backend=custom_backend, dynamic=True)
gm = opt_model(torch.tensor([1.0]))

opcode         name                             target                                args                                                                 kwargs
-------------  -------------------------------  ------------------------------------  -------------------------------------------------------------------  -----------
placeholder    l_x_                             L_x_                                  ()                                                                   {}
get_attr       l__self___layer1_weight          L__self___layer1_weight               ()                                                                   {}
get_attr       l__self___layer1_bias            L__self___layer1_bias                 ()                                                                   {}
call_function  x                                <built-in function linear>            (l_x_, l__self___layer1_weight, l__self___layer1_bias)               {}
get_attr       l__self___wrapped_layer1

## 2 - Intervening on the FX Graph

TorchDynamo is a really powerful tool for compiling torch modules to improve performance and efficiency at scale. 

See also the depyf walkthrough of `torch.compile`: https://depyf.readthedocs.io/en/latest/walk_through.html

What if we used `torch.compile` to attach arbitrary modules at any point in an existing module's computation? There are a few obvious benefits:

1. Edit models to access arbitrary intermediate values that aren't normally available.
2. Add modules such as dictionaries or LoRA weights and access the hidden states of those modules, on a forward or backward pass, with hooks.
3. Host just one module on NDIF and use `torch.compile` to recompile existing modules. Compiling simply returns an optimized module wrapper over the existing module, so we don't have to host multiple models.

Let's declare a simple model to see how we can wrap one of its attributes below. 

In [6]:
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 1)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = self.layer1(x)

        value = x.item()
        x = x * value

        x = self.dropout(x)
        x = x.split(1, dim=-1)
        return x

mod = M()

input_tensor = torch.tensor([[1.0]])
output = mod(input_tensor)
print(output)

(tensor([[0.0057]], grad_fn=<SplitBackward0>),)


Suppose we'd like to access `value`. We wouldn't normally be able to do this with hooks because it's a local variable inside `forward`, not the input or output of a submodule.

In [8]:
class WrapperModule(torch.nn.Module):
    """Simple torch module which passes its input through. Useful for hooking.
    If there is only one argument, returns the first element.
    """

    def forward(self, *args, **kwargs):
        if len(args) == 1:
            args = args[0]

        return args
    
wrapper_module = WrapperModule()
wrapper_name = 'value_wrapper'

setattr(mod, wrapper_name, wrapper_module)
print(mod)

M(
  (layer1): Linear(in_features=1, out_features=1, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (value_wrapper): WrapperModule()
)


In [None]:
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):

    if wrapper_name not in gm._modules:
        gm.add_submodule(wrapper_name, wrapper_module)

    ### TODO: BELOW ###
    for node in gm.graph.nodes:    

        if node.op == 'call_method' and node.name == "tensor":
            if node.args[0].name == "query":
                print('found')
                with gm.graph.inserting_after(node):
                    wrapper_args = (node.args[0], )
                    wrapper_kwargs = node.kwargs
                    wrapper_node = gm.graph.call_module(wrapper_name, args=wrapper_args, kwargs=wrapper_kwargs)
                    node = wrapper_node

    gm.recompile()

    return gm.forward

torch._dynamo.reset()

opt_model = torch.compile(mod, backend=custom_backend, dynamic=True)
gm = opt_model(torch.tensor([1.0]))
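
As a reference for the node-insertion step, here is a self-contained sketch of the general FX pattern: insert a `call_module` node after a target node, reroute the target's consumers through it, then hook the wrapper to read the value. It uses `torch.fx.symbolic_trace` rather than a Dynamo backend to keep the example small, and the names `Tiny`, `PassThrough`, and the match on `"linear"` are assumptions made for illustration, not the matching logic the backend above will need.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class PassThrough(nn.Module):
    # Hypothetical identity module that exists only so it can be hooked.
    def forward(self, x):
        return x


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        x = self.linear(x)
        return x * 100


model = Tiny()
gm = symbolic_trace(model)
gm.add_submodule("value_wrapper", PassThrough())

for node in list(gm.graph.nodes):
    if node.op == "call_module" and node.target == "linear":
        with gm.graph.inserting_after(node):
            wrapper_node = gm.graph.call_module("value_wrapper", args=(node,))
        # Send every consumer of `node` through the wrapper instead...
        node.replace_all_uses_with(wrapper_node)
        # ...then restore the wrapper's own input, which the line above also rewrote.
        wrapper_node.args = (node,)

gm.graph.lint()
gm.recompile()

# A forward hook on the wrapper now exposes the intermediate value.
captured = []
gm.value_wrapper.register_forward_hook(
    lambda module, args, output: captured.append(output.detach())
)

x = torch.tensor([[1.0]])
out = gm(x)
```

After this rewrite the module computes the same output as before, and `captured` holds the output of `linear` from the forward pass.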