# PyTorch 2 Export

The general idea of `torch.export` is that it translates an eager-mode PyTorch model into a graph-based intermediate representation (IR) called Export IR. Compiler backends can then take this IR and further transform and optimize it for a target device. The figure [below](torchexport) gives an overview of the process.

:::{figure-md} torchexport
<img src="compilation.png" alt="torch.export">

PyTorch 2 Export
:::



Thus, Export IR is designed to be convenient for compilers in a few ways:
1. Operators have to be functional, so that compilers don't have to deal with side effects.
2. Operators have to be general enough for backends to notice patterns and optimize them.
3. The number of operators has to be small enough for the backend to implement all of them.

With this IR, compilers can take the graph representing the model and perform multiple optimizations on it. Examples include fusing operators (merging multiple operators into one that has a specialized kernel, which avoids several CUDA kernel launches by doing the work once) and partitioning the graph (splitting it into multiple subgraphs that can be executed in parallel). We'll talk about these "passes" because knowing what they do helps when debugging the model if something goes wrong.
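To make the idea of a fusion pass concrete, here is a minimal toy sketch in plain Python. It is not how any real backend works: the model is assumed to be a flat list of operator names rather than a graph, and the function `fuse_linear_relu` and the fused name `"linear_relu"` are hypothetical, chosen only for illustration.

```python
def fuse_linear_relu(ops):
    """Toy fusion pass: merge each adjacent ("linear", "relu") pair into one op.

    `ops` is a flat list of operator names. A real pass would walk a graph
    and check data dependencies; this sketch only looks at neighbours.
    """
    fused = []
    i = 0
    while i < len(ops):
        if i + 1 < len(ops) and ops[i] == "linear" and ops[i + 1] == "relu":
            # Replace the pair with a single (hypothetical) fused operator.
            fused.append("linear_relu")
            i += 2
        else:
            fused.append(ops[i])
            i += 1
    return fused


ops = ["conv2d", "relu", "max_pool2d", "flatten", "linear", "relu"]
fused_ops = fuse_linear_relu(ops)
print(fused_ops)
# ['conv2d', 'relu', 'max_pool2d', 'flatten', 'linear_relu']
```

The pass only rewrites the pattern it recognizes and leaves everything else untouched, which is why a small, general operator set makes such patterns easier for a backend to spot.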

For now, let's get some practical intuition with an example.

Let's use the following simple network:

In [3]:
import torch.nn as nn
import torch
import torch.nn.functional as F

In [19]:
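class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, 5)
        # For a 32x32 input, the conv output is 6 x 28 x 28 and the 2x2 max pool
        # halves it to 6 x 14 x 14, hence 6 * 14 * 14 input features below.
        self.fc1 = nn.Linear(6 * 14 * 14, 10)
        self.fc2 = nn.Linear(6 * 14 * 14, 10)
        # Fixed boolean mask, registered as a buffer (saved with the model, not trained).
        self.register_buffer("mask", torch.randn(6, 14, 14) > 0.5)

    def forward(self, x: torch.Tensor):
        x = self.conv(x)             # (N, 3, 32, 32) -> (N, 6, 28, 28)
        x = F.relu(x)
        x = F.max_pool2d(x, (2, 2))  # -> (N, 6, 14, 14)
        x = x * self.mask            # mask broadcasts over the batch dimension
        x = torch.flatten(x, 1)      # -> (N, 6 * 14 * 14)
        y = self.fc1(x)
        z = self.fc2(x)
        y = F.relu(y)
        z = F.relu(z)
        return z, y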


In [20]:
x = torch.randn(1, 3, 32, 32) 
ep: torch.export.ExportedProgram = torch.export.export(SimpleNet(), (x,))

In [21]:
torch.export.save(ep, "simple_net.pt2")