# PyTorch 2 Export

```{contents}
```

## torch.export 101

The main idea of `torch.export` is that it translates an Eager Mode PyTorch model into a graph-based intermediate representation called *Export IR*. This allows compiler backends to take this IR and further transform and optimize it for a target device. A general overview of the process is shown in the figure [below](torchexport).

:::{figure-md} torchexport
<img src="compilation.png" alt="torch.export" width=70%>

PyTorch 2 Export
:::

For this IR to be useful to compilers, it needs to satisfy a few properties:
1. Operators have to be general enough for backends to recognize patterns and optimize them. Many runtimes have specialized kernels for common operators such as convolutions, or even for compound patterns such as `conv2d + batchnorm` (operator fusion). If the IR reduces every operator to sums, products and views, these patterns become too hard to spot.
2. The number of operators has to be small enough for a backend to implement all of them.
3. Operators have to be functional, that is, free of side effects. For example, if two functions read and modify the same parameters, their execution order matters and the compiler has to be careful when parallelizing them.
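To make property 1 concrete, here is a toy sketch that is not PyTorch's IR or any real backend's API: a "graph" is just a list of operator names, and a backend pass looks for a `conv2d` immediately followed by `batch_norm` and replaces the pair with a hypothetical fused kernel `conv_bn`. Once `batch_norm` has been lowered into arithmetic primitives, the pattern is no longer visible to the pass.

```python
from typing import Dict, List, Tuple

# Hypothetical operator names. A real backend matches IR nodes, not strings,
# and must also check that the conv output has no other users.
FUSION_PATTERNS: Dict[Tuple[str, str], str] = {("conv2d", "batch_norm"): "conv_bn"}


def fuse(ops: List[str]) -> List[str]:
    """Greedily replace adjacent operator pairs that have a fused kernel."""
    fused: List[str] = []
    i = 0
    while i < len(ops):
        pair = tuple(ops[i:i + 2])
        if pair in FUSION_PATTERNS:
            fused.append(FUSION_PATTERNS[pair])
            i += 2  # both operators are consumed by the fused kernel
        else:
            fused.append(ops[i])
            i += 1
    return fused


print(fuse(["conv2d", "batch_norm", "relu", "linear"]))
# ['conv_bn', 'relu', 'linear']

# With batch_norm already decomposed into primitives, nothing matches.
print(fuse(["conv2d", "sub", "mul", "add"]))
# ['conv2d', 'sub', 'mul', 'add']
```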

Notice that properties 1 and 2 are in conflict with each other: the more operators the IR has, the more expressive it is, but the harder it is for a backend to implement all of them. This is a trade-off the PyTorch team has to balance.

TODO:
- [ ] Introduce ATen (dialects), fx.Graph and link to Export IR

For now, let's get some practical intuition with an example.

## Hands on with torch.export

Let's use a simple network to see how `torch.export` works.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

```python
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, 5)
        self.fc1 = nn.Linear(6 * 14 * 14, 10)
        self.fc2 = nn.Linear(6 * 14 * 14, 10)
        self.register_buffer("mask", torch.randn(6, 14, 14) > 0.5)

    def forward(self, x: torch.Tensor):
        x = self.conv(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (2, 2))
        x = x * self.mask
        x = torch.flatten(x, 1)
        y = self.fc1(x)
        z = self.fc2(x)
        y = F.relu(y)
        z = F.relu(z)
        return z, y
```
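The `6 * 14 * 14` input size of `fc1` and `fc2` (and the `(6, 14, 14)` shape of `mask`) follows from the `(1, 3, 32, 32)` sample input used for export. The standard output-size formula for a convolution or pooling layer with dilation 1, `(size + 2 * padding - kernel) // stride + 1`, gives the arithmetic; `out_size` is a helper written only for this illustration.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv/pool layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


h = 32
h = out_size(h, kernel=5)            # Conv2d(3, 6, 5): 32 -> 28
h = out_size(h, kernel=2, stride=2)  # max_pool2d(x, (2, 2)) uses stride = kernel: 28 -> 14
channels = 6                         # out_channels of the Conv2d
print(h, channels * h * h)  # 14 1176
```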


To export a model, we must first define a sample input. It is used to trace the model and generate the Export IR. Tracing stays cheap because it runs on `torch._subclasses.fake_tensor.FakeTensor`s: special tensors that only store metadata such as `dtype`, `shape` and `device`, and that overload all operators to propagate this metadata without ever looking at, or computing, the values. For example, multiplying FakeTensors of shapes `(N, M)` and `(M, K)` returns a FakeTensor of shape `(N, K)` in constant time, instead of performing the O(NMK) arithmetic of a real matrix multiplication.
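A small sketch of this behavior, assuming the private module path `torch._subclasses.fake_tensor` mentioned above (being private, it may move between releases): under `FakeTensorMode`, factory functions return FakeTensors, and the matrix product below produces a correctly shaped result without allocating or computing anything. Real tensors can be converted with `from_tensor`, which keeps their metadata but not their values.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

mode = FakeTensorMode()

with mode:
    # No real storage is allocated for these FakeTensors.
    a = torch.empty(10_000, 20_000)
    b = torch.empty(20_000, 30_000)
    c = a @ b  # only shape/dtype/device propagation happens here

print(type(c).__name__, c.shape, c.dtype, c.device)

# Converting a real tensor keeps its metadata but drops its values.
real = torch.randn(1, 3, 32, 32)
fake = mode.from_tensor(real)
print(fake.shape, fake.dtype)
```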



```ipython
torch._subclasses.fake_tensor.FakeTensor??
```

```
Init signature:
torch._subclasses.fake_tensor.FakeTensor(
    fake_mode: 'FakeTensorMode',
    elem: 'Tensor',
    device: 'torch.device',
    constant: 'Optional[Tensor]' = None,
    real_tensor: 'Optional[Tensor]' = None,
) -> 'Self'
Source:
class FakeTensor(Tensor):
    """
    Meta tensors give you the ability to run PyTorch code without having to
```

```python
x = torch.randn(1, 3, 32, 32)
ep: torch.export.ExportedProgram = torch.export.export(SimpleNet(), (x,))
```

```python
ep.p
```

```python
torch.export.save(ep, "simple_net.pt2")
```