# PyTorch 2 Export

```{contents}
```

## torch.export 101

`torch.export` translates an eager-mode PyTorch model into a graph-based intermediate representation (IR) called *Export IR*. Compiler backends can then take this IR and further transform and optimize it for a target device. A general overview of the process is shown in the figure [below](torchexport).

:::{figure-md} torchexport
<img src="compilation.png" alt="torch.export" width=70%>

PyTorch 2 Export
:::

This IR needs to fulfill a couple of properties for it to be useful to compilers. For example:
1. Operators have to be general enough for backends to notice patterns and optimize them. Many runtimes have specialized kernels for common operators like convolutions, or even for more complex combinations like `conv2d + batchnorm` (operator fusion). If the IR reduces all operators to sums, products and views, noticing these patterns becomes too hard.
2. The number of operators has to be small enough for the backend to implement all of them.
3. Operators have to be functional, that is, without side effects. For example, if two functions read and modify the same parameters, the result depends on the order of execution, so the compiler has to be careful when reordering or parallelizing them.

Notice that properties 1 and 2 are in conflict with each other. The more operators we have, the more expressive the IR is, but the harder it is to implement all of them. This is a trade-off that the PyTorch team has to balance. 

TODO:
- [ ] Introduce ATEN (dialects), fx.Graph and link to Export IR

For now, let's get some practical intuition with an example.

## Hands on with torch.export

Let's use a simple network to see how `torch.export` works.

In [1]:
import torch
import pprint
from part3_artifacts.simple_net import SimpleNet
import torch.fx.graph_module

In [2]:
SimpleNet??

Init signature: SimpleNet()
Source:
class SimpleNet(nn.Module):
    """
    Just a simple network
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 9, 5)
        self

To export a model we must first define a sample input. This is used to trace the model and generate the Export IR. Tracing is efficient because it relies on `torch._subclasses.fake_tensor.FakeTensor`. FakeTensors are a special type of tensor that only store metadata such as `dtype`, `shape` and `device`, and overload all operators to simulate the computation without ever looking at the values. For example, multiplying FakeTensors of shapes `(N, M)` and `(M, K)` returns a FakeTensor of shape `(N, K)` in constant time, instead of the `O(N·M·K)` cost of a real matrix multiplication.
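The shape-propagation behaviour described above can be tried directly. This is a minimal sketch that assumes `FakeTensorMode` from the private `torch._subclasses.fake_tensor` module behaves as in recent PyTorch 2.x releases; being a private API, it may change between versions.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

# Inside the mode, factory functions return FakeTensors: metadata only, no data.
with FakeTensorMode():
    a = torch.empty(1024, 2048)
    b = torch.empty(2048, 512)
    # Only shape, dtype and device are propagated; no multiplication is performed.
    c = a @ b

print(type(c).__name__, tuple(c.shape), c.dtype, c.device)
```

The result carries the shape `(1024, 512)` even though no element of `a` or `b` was ever allocated or read.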



In [3]:
x = torch.randn(1, 3, 32, 32) 
ep: torch.export.ExportedProgram = torch.export.export(SimpleNet().eval(), (x,))

And that's it, we have exported our model. The new object is a `torch.export.ExportedProgram`, which contains the model and its parameters in the Export IR. Let's inspect its attributes one by one.

The first and most important attribute is the `graph_module` which stores the computational graph of the model. We can print it using the `print_readable` method:

In [4]:
graph_module: torch.fx.GraphModule = ep.graph_module
print(graph_module.print_readable(print_output=False, colored=True, include_device=True))

class GraphModule(torch.nn.Module):
    def forward(self, p_conv1_weight: "f32[6, 3, 5, 5]cpu", p_conv1_bias: "f32[6]cpu", p_conv2_weight: "f32[9, 6, 5, 5]cpu", p_conv2_bias: "f32[9]cpu", p_fc_weight: "f32[10, 5184]cpu", p_fc_bias: "f32[10]cpu", x: "f32[1, 3, 32, 32]cpu"):
         # File: /home/dgcnz/development/amsterdam/edge/docs/src/part3/part3_artifacts/simple_net.py:16 in forward, code: x = self.conv1(x)
        conv2d: "f32[1, 6, 28, 28]cpu" = torch.ops.aten.conv2d.default(x, p_conv1_weight, p_conv1_bias);  x = p_conv1_weight = p_conv1_bias = None

         # File: /home/dgcnz/development/amsterdam/edge/do

Here we can see all nodes (`conv2d`, `relu`, `conv2d_1`, etc.), their shapes, dtypes and devices, and the ATen operators being used (such as `torch.ops.aten.conv2d.default`), each annotated with the file, line and source code it came from. We can also see that the graph takes as inputs not only the model inputs but also the model's parameters (and its buffers and constants, if any).

A `torch.fx.GraphModule` is just a wrapper around an `fx.Graph`, which you can access through `graph_module.graph`. This is useful because `fx.Graph` offers many ways to inspect and manipulate the graph. For example, `graph_module.graph.nodes` iterates over all nodes, and `node.args` gives the arguments of a given node. Note that `graph.nodes` is an iterable rather than a list, so convert it with `list(graph_module.graph.nodes)` before indexing.

In [12]:
print(graph_module.graph)

graph():
    %p_conv1_weight : [num_users=1] = placeholder[target=p_conv1_weight]
    %p_conv1_bias : [num_users=1] = placeholder[target=p_conv1_bias]
    %p_conv2_weight : [num_users=1] = placeholder[target=p_conv2_weight]
    %p_conv2_bias : [num_users=1] = placeholder[target=p_conv2_bias]
    %p_fc_weight : [num_users=1] = placeholder[target=p_fc_weight]
    %p_fc_bias : [num_users=1] = placeholder[target=p_fc_bias]
    %x : [num_users=1] = placeholder[target=x]
    %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv1_weight, %p_conv1_bias), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%conv2d,), kwargs = {})
    %conv2d_1 : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%relu, %p_conv2_weight, %p_conv2_bias), kwargs = {})
    %relu_1 : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%conv2d_1,), kwargs = {})
    %view : [num_user

In [15]:
pprint.pp(list(graph_module.graph.nodes))

[p_conv1_weight,
 p_conv1_bias,
 p_conv2_weight,
 p_conv2_bias,
 p_fc_weight,
 p_fc_bias,
 x,
 conv2d,
 relu,
 conv2d_1,
 relu_1,
 view,
 linear,
 output]


In [31]:
print(graph_module.code)




def forward(self, p_conv1_weight, p_conv1_bias, p_conv2_weight, p_conv2_bias, p_fc_weight, p_fc_bias, x):
    conv2d = torch.ops.aten.conv2d.default(x, p_conv1_weight, p_conv1_bias);  x = p_conv1_weight = p_conv1_bias = None
    relu = torch.ops.aten.relu.default(conv2d);  conv2d = None
    conv2d_1 = torch.ops.aten.conv2d.default(relu, p_conv2_weight, p_conv2_bias);  relu = p_conv2_weight = p_conv2_bias = None
    relu_1 = torch.ops.aten.relu.default(conv2d_1);  conv2d_1 = None
    view = torch.ops.aten.view.default(relu_1, [1, 5184]);  relu_1 = None
    linear = torch.ops.aten.linear.default(view, p_fc_weight, p_fc_bias);  view = p_fc_weight = p_fc_bias = None
    return (linear,)
    


In [5]:
pprint.pp(ep.graph_signature)

ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>,
                                            arg=TensorArgument(name='p_conv1_weight'),
                                            target='conv1.weight',
                                            persistent=None),
                                  InputSpec(kind=<InputKind.PARAMETER: 2>,
                                            arg=TensorArgument(name='p_conv1_bias'),
                                            target='conv1.bias',
                                            persistent=None),
                                  InputSpec(kind=<InputKind.PARAMETER: 2>,
                                            arg=TensorArgument(name='p_conv2_weight'),
                                            target='conv2.weight',
                                            persistent=None),
                                  InputSpec(kind=<InputKind.PARAMETER: 2>,
                                            arg=TensorAr

In [6]:
ep.state_dict.keys()

dict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc.weight', 'fc.bias'])

In [7]:
torch.export.save(ep, "simple_net.pt2")

In [8]:
x = [
    torch.rand(1, 3, 150, 100),
    torch.rand(1, 3, 75, 50),
    torch.rand(1, 3, 37, 25),
    torch.rand(1, 3, 19, 13),
]

In [9]:
ep.constants

{}
