# PyTorch 2 Export

```{contents}
```

## torch.export 101

The main idea of `torch.export` is that it translates an Eager Mode PyTorch model into a graph-based intermediate representation called *Export IR*. This allows compiler backends to take this IR and further transform and optimize it for a target device. A general overview of the process is shown in the figure [below](torchexport).

:::{figure-md} torchexport
<img src="compilation.png" alt="torch.export" width=70%>

PyTorch 2 Export
:::

This IR needs to fulfill a couple of properties for it to be useful to compilers. For example:
1. Operators have to be general enough for backends to notice patterns and optimize them: many runtimes have specialized kernels for common operators like convolutions, or even for more complex ones like `conv2d + relu` (operator fusion, see examples [here](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#fusion-types)). If the IR reduces all operators to sums, products and views, noticing these patterns becomes too hard.
2. The number of operators has to be small enough for the backend to implement all of them. 
3. Operators have to be functional, that is, without side effects. For example, if two functions read and modify the same parameters, the order in which they run matters, and the compiler has to be careful when reordering or parallelizing them.

Notice that properties 1 and 2 are in conflict with each other. The more operators we have, the more expressive the IR is, but the harder it is to implement all of them. This is a trade-off that the PyTorch team has to balance. 

TODO:
- [ ] Introduce ATEN (dialects), fx.Graph and link to Export IR, functionalization

For now, let's get some practical intuition with an example.

## Hands on with torch.export

Let's use a simple network to see how `torch.export` works.

In [1]:
import torch
import pprint
from part3_artifacts.simple_net import SimpleNet
import torch.fx.graph_module
from myst_nb import glue

In [2]:
SimpleNet??

Init signature: SimpleNet()
Source:
class SimpleNet(nn.Module):
    """
    Just a simple network
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(3, 6, 5)
        ...

To export a model, we must first define a sample input, which is used to *trace* the model and generate the Export IR.

```{note}
*Tracing* refers to the process of recording the operations that a model executes on a specific input, along with their metadata.

Tracing is made efficient by `torch._subclasses.fake_tensor.FakeTensor`. FakeTensors are a special type of tensor that only store metadata such as `dtype`, `shape` and `device`, and they overload all operators to simulate the computation without ever looking at the values.

For example, multiplying FakeTensors of shapes `(N, M)` and `(M, K)` returns a FakeTensor of shape `(N, K)` almost instantly: only the output metadata is computed, instead of the `O(NMK)` arithmetic of a real matrix multiplication.
```
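To see this behavior directly, here is a minimal sketch using `FakeTensorMode` from the same private module (`torch._subclasses.fake_tensor`, so the API may change between releases). Tensors created inside the mode are FakeTensors, so the "multiplication" below neither allocates the multi-gigabyte operands nor does any arithmetic:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

with FakeTensorMode():
    # Factory functions inside the mode return FakeTensors: no real storage is allocated.
    a = torch.empty(10_000, 20_000)
    b = torch.empty(20_000, 30_000)
    # Only the output metadata (shape, dtype, device) is computed here.
    c = a @ b

print(type(c).__name__, c.shape, c.dtype, c.device)
```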

In our case, the model will be deployed on a camera with a fixed resolution, so a statically shaped input with a batch size of 1 is enough. If you want to support dynamically shaped inputs, refer to the [documentation](https://pytorch.org/docs/main/export.html#expressing-dynamism).

Once we have the input, we can call the `torch.export.export` function.


In [3]:
x = torch.randn(1, 3, 32, 32) 
ep: torch.export.ExportedProgram = torch.export.export(SimpleNet().eval(), (x,))

And that's it: we have exported our model. The result is a `torch.export.ExportedProgram`, which contains the model's graph and its parameters in the Export IR. Let's inspect its attributes one by one.

The first and most important attribute is the `graph_module` which stores the computational graph of the model. We can print it using the `print_readable` method:

In [4]:
graph_module: torch.fx.GraphModule = ep.graph_module
print(graph_module.print_readable(print_output=False, colored=True, include_device=True))

class GraphModule(torch.nn.Module):
    def forward(self, p_conv1_weight: "f32[6, 3, 5, 5]cpu", p_conv1_bias: "f32[6]cpu", p_conv2_weight: "f32[6, 3, 5, 5]cpu", p_conv2_bias: "f32[6]cpu", p_fc_weight: "f32[10, 4704]cpu", p_fc_bias: "f32[10]cpu", x: "f32[1, 3, 32, 32]cpu"):
         # File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:16 in forward, code: z = self.conv1(x)
        conv2d: "f32[1, 6, 28, 28]cpu" = torch.ops.aten.conv2d.default(x, p_conv1_weight, p_conv1_bias);  p_conv1_weight = p_conv1_bias = None
        
         # File: /home/dgcnz/development/amsterdam/edge/docs/s

Here we can see all the *nodes* (`conv2d`, `relu`, `conv2d_1`, etc.), their shapes, dtypes and devices, and the ATen operators they call (e.g. `torch.ops.aten.conv2d.default`), together with the file, line and code that produced them. We can also see that the graph takes as inputs not only the model's inputs but also its parameters (and, when present, its buffers and constants).

In [5]:
from IPython import get_ipython

# Plain-text formatters so that `glue` renders FX objects readably in the book.
plain = get_ipython().display_formatter.formatters['text/plain']

def graph_formatter(graph, pp, cycle):
    """Pretty-print an fx.Graph using its string representation."""
    pp.text(str(graph))

plain.for_type(torch.fx.Graph, graph_formatter)
glue("graphmodule_graph", graph_module.graph)
glue("graphmodule_graph_nodes", list(graph_module.graph.nodes))

class StackTrace(object):
    """Wrapper so that a node's stack trace is printed verbatim."""
    def __init__(self, stack_trace):
        self.stack_trace = stack_trace

def stack_trace_formatter(stack_trace, pp, cycle):
    pp.text(stack_trace.stack_trace)

plain.for_type(StackTrace, stack_trace_formatter)

# Look up the `relu_1` node and glue its attributes for the note below.
relu_1 = next(filter(lambda n: n.name == "relu_1", graph_module.graph.nodes))
glue("relu_1_op", relu_1.op, display=False)
glue("relu_1_target", relu_1.target, display=False)
glue("relu_1_args", relu_1.args, display=False)
glue("relu_1_stack_trace_2", StackTrace(relu_1.stack_trace), display=False)
glue("relu_1_name", relu_1.name, display=False)
glue("relu_1_meta", relu_1.meta, display=False)
glue("relu_1_users", list(relu_1.users), display=False)


graph():
    %p_conv1_weight : [num_users=1] = placeholder[target=p_conv1_weight]
    %p_conv1_bias : [num_users=1] = placeholder[target=p_conv1_bias]
    %p_conv2_weight : [num_users=1] = placeholder[target=p_conv2_weight]
    %p_conv2_bias : [num_users=1] = placeholder[target=p_conv2_bias]
    %p_fc_weight : [num_users=1] = placeholder[target=p_fc_weight]
    %p_fc_bias : [num_users=1] = placeholder[target=p_fc_bias]
    %x : [num_users=2] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv1_weight, %p_conv1_bias), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%conv2d,), kwargs = {})
    %conv2d_1 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv2_weight, %p_conv2_bias), kwargs = {})
    %relu_1 : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%conv2d_1,), kwargs = {})
    %add : [num_users=1]

[p_conv1_weight,
 p_conv1_bias,
 p_conv2_weight,
 p_conv2_bias,
 p_fc_weight,
 p_fc_bias,
 x,
 conv2d,
 relu,
 conv2d_1,
 relu_1,
 add,
 view,
 linear,
 output]

::::{note}

A `torch.fx.GraphModule` is an `nn.Module` that wraps an `fx.Graph` and generates its `forward` code from it; you can access the graph through `graph_module.graph`. This is useful for two reasons:
- Most of the compiler steps will work with `fx.Graph` directly, so it's good to get acquainted with its attributes in case you need to debug an error.
- You *might* need to manipulate the graph directly to ensure compatibility ([example](https://leimao.github.io/blog/PyTorch-Eager-Mode-Quantization-TensorRT-Acceleration/)).
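As a taste of what such manipulation looks like, here is a minimal sketch on a hypothetical `TinyNet`. It uses plain `torch.fx.symbolic_trace` rather than `torch.export`, so that the edited module can be re-run directly; the edit retargets every `torch.relu` call to `torch.sigmoid` and then regenerates the module's code:

```python
import torch
import torch.fx
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        return torch.relu(self.fc(x))


model = TinyNet().eval()
gm = torch.fx.symbolic_trace(model)

for node in gm.graph.nodes:
    # symbolic_trace records torch.relu(...) as a call_function node targeting torch.relu
    if node.op == "call_function" and node.target == torch.relu:
        node.target = torch.sigmoid

gm.graph.lint()  # check that the edited graph is still well-formed
gm.recompile()   # regenerate gm.forward from the edited graph

x = torch.randn(2, 16)
with torch.no_grad():
    print(torch.allclose(gm(x), torch.sigmoid(model.fc(x))))
```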


To start, if we want to print the underlying graph, we can do it like this:

```python
print(str(graph_module.graph))
```

```{glue} graphmodule_graph
```

This is similar enough to the `graph_module`'s output, so let's move on. Each "variable" in the graph is a `Node` object, and we can access them like this:

```python
print(list(graph_module.graph.nodes))
```

```{glue} graphmodule_graph_nodes
```

Specifically, if we're interested in a particular node, like the `relu_1` node, we can filter it by name:

```python
relu_1 = next(filter(lambda n: n.name == "relu_1", graph_module.graph.nodes))
```

Some of its most important attributes are the `name`, `op`, `args`, `stack_trace`, `target` and `users`. Let's print them and see what they store.

The `name` is just the unique name of the node:

```python
print(relu_1.name)
```

```{glue} relu_1_name    
```

The `op` is the node's opcode: it specifies what kind of node this is (for example `placeholder`, `call_function` or `output`). Together with the `target`, it defines the behavior of the node.
For example, `Node(op=placeholder, target=p_conv1_weight)` means that the node is a placeholder for the weight of the first convolutional layer. Inputs, weights, etc. are tagged as `placeholder` nodes.

On the other hand, `call_function` nodes represent a function call to their `target`. For example, `Node(op=call_function, target=torch.ops.aten.relu.default)` means that the node is a call to the `relu` function, as we can see next:

```python
print(relu_1.op)
```

```{glue} relu_1_op    
```

```python
print(relu_1.target)
```

```{glue} relu_1_target    
```

As this shows, *operator* and *function* are used almost interchangeably in this context: the `target` of a `call_function` node is an ATen operator.

The `args` are the arguments of the node's function. In our case, since `relu_1` takes as input the output of `conv2d_1`, we should see a reference to that node.

```python
print(relu_1.args)
```

```{glue} relu_1_args    
```

Similarly, the `users` are the nodes that take the output of `relu_1` as input. Both of these attributes are useful to traverse the graph and understand the dependencies between nodes.

```python
print(relu_1.users)
```

```{glue} relu_1_users    
```

Finally, `stack_trace` records the source location (file, line and code) that generated the node. This is also useful for debugging, since it points to the source code that should be rewritten in case of an error.

```python
print(relu_1.stack_trace)
```

```{glue} relu_1_stack_trace_2
```

For more information refer to the [documentation](https://pytorch.org/docs/main/export.ir_spec.html).

::::

Returning to the `ExportedProgram`, its second most important attribute is the `graph_signature`. This object describes the model's inputs (actual inputs, parameters, buffers, constant tensors, etc.) and outputs. It is particularly useful for checking whether a tensor is being folded as a constant.

We can print it like this:

In [6]:
pprint.pp(ep.graph_signature)

ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>,
                                            arg=TensorArgument(name='p_conv1_weight'),
                                            target='conv1.weight',
                                            persistent=None),
                                  InputSpec(kind=<InputKind.PARAMETER: 2>,
                                            arg=TensorArgument(name='p_conv1_bias'),
                                            target='conv1.bias',
                                            persistent=None),
                                  InputSpec(kind=<InputKind.PARAMETER: 2>,
                                            arg=TensorArgument(name='p_conv2_weight'),
                                            target='conv2.weight',
                                            persistent=None),
                                  InputSpec(kind=<InputKind.PARAMETER: 2>,
                                            arg=TensorAr

If you want to access the parameters and buffers directly, you can reference the `state_dict` attribute.

In [7]:
ep.state_dict.keys()

dict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc.weight', 'fc.bias'])

Constants are tensors whose values stay fixed during the forward pass but that are neither parameters nor buffers (think of a tensor that contains the shape of the input). They are less common, but sometimes making sure they are treated as constants helps the compiler parse the model correctly. Our simple network doesn't have any constants, but you can access them like this:

In [8]:
print(ep.constants)

{}


Finally, we can save our exported program using the `torch.export.save` function.

In [9]:
torch.export.save(ep, "simple_net.pt2")