<a href="https://colab.research.google.com/github/gauravjain14/mlcompilers_and_kernels/blob/main/pytorch_etc/torch_export_101.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

If you get confused at any point, refer to the official documentation: https://pytorch.org/docs/stable/export.html

In [1]:
import torch
from torch.export import export

In [2]:
class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

Under the hood, `torch.export` leverages TorchDynamo (`torch._dynamo`) to capture the program, AOT Autograd to decompose it to the ATen operator set, and Torch FX (`torch.fx`) as the underlying graph representation, which allows flexible Python-based transformations.

In [3]:
example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: torch.export.ExportedProgram = export(Mod(), example_args)
print(exported_program)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
             # File: <ipython-input-2-93a00b9c2195>:6 in forward, code: a = torch.sin(x)
            sin: "f32[10, 10]" = torch.ops.aten.sin.default(x);  x = None
            
             # File: <ipython-input-2-93a00b9c2195>:7 in forward, code: b = torch.cos(y)
            cos: "f32[10, 10]" = torch.ops.aten.cos.default(y);  y = None
            
             # File: <ipython-input-2-93a00b9c2195>:8 in forward, code: return a + b
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos);  sin = cos = None
            return (add,)
            
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>,

How does `torch.export()` compare with `torch.compile()`, `torch.fx.symbolic_trace()`, etc.?

1. When `torch.compile()` runs into an untraceable part of a model, it will "graph break" and fall back to running that part in the eager Python runtime. `torch.export()`, in contrast, errors out when it reaches something untraceable.

2. `torch.export()` creates a full graph free of Python features and the Python runtime, which can be saved, loaded, and run in different environments and languages.

Compared to torch.fx.symbolic_trace(), torch.export traces using TorchDynamo which operates at the Python bytecode level, giving it the ability to trace arbitrary Python constructs not limited by what Python operator overloading supports. Additionally, torch.export keeps fine-grained track of tensor metadata, so that conditionals on things like tensor shapes do not fail tracing. In general, torch.export is expected to work on more user programs, and produce lower-level graphs (at the torch.ops.aten operator level). Note that users can still use torch.fx.symbolic_trace() as a preprocessing step before torch.export.

## Exporting a PyTorch Model

In [4]:
# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

In [5]:
example_args = (torch.randn(1, 3, 224, 224),)
example_kwargs = {"constant": torch.ones(1, 16, 224, 224)}

exported_program: torch.export.ExportedProgram = export(M(), example_args, example_kwargs)
print(exported_program)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 224, 224]", constant: "f32[1, 16, 224, 224]"):
             # File: <ipython-input-4-99131fc297ec>:12 in forward, code: a = self.conv(x)
            conv2d: "f32[1, 16, 224, 224]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]);  x = p_conv_weight = p_conv_bias = None
            
             # File: <ipython-input-4-99131fc297ec>:13 in forward, code: a.add_(constant)
            add: "f32[1, 16, 224, 224]" = torch.ops.aten.add.Tensor(conv2d, constant);  conv2d = constant = None
            
             # File: <ipython-input-4-99131fc297ec>:14 in forward, code: return self.maxpool(self.relu(a))
            relu: "f32[1, 16, 224, 224]" = torch.ops.aten.relu.default(add);  add = None
            max_pool2d: "f32[1, 16, 74, 74]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3]);  relu = None
  

A few things to take away from this output:

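Parameters such as `conv.weight` and `conv.bias` are lifted to graph inputs (`p_conv_weight`, `p_conv_bias`) with `InputKind.PARAMETER`. Their values are not hard-coded into the graph's code; they live in `exported_program.state_dict`, and `torch.export.save()` packages that state dict together with the graph into a single archive, so separate weight files don't need to be managed during deployment. (The `persistent` field of an `InputSpec` is only meaningful for buffers, which is why it is `None` for other input kinds.)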

The `torch.fx.Graph` contains the computation graph of the original program, along with records of the original code for easy debugging.

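The graph contains only `torch.ops.aten` operators. Note also that the in-place `a.add_(constant)` in `forward` appears as the out-of-place `torch.ops.aten.add.Tensor`: the exported graph is functionalized.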

The shape and dtype of the tensor produced by each node in the graph are recorded. For example, the convolution node produces a tensor of dtype `torch.float32` and shape `(1, 16, 224, 224)`.


### Next, we want to define a custom operator and see how it is traced.

In [6]:
torch.fx.Graph

# Handling and Expressing Dynamism

By default, `torch.export` will trace the program assuming all shapes are static, and specialize the exported program to those dimensions.

However, some dimensions, such as the batch dimension, can be dynamic and vary from run to run.

Such dimensions are declared with the `torch.export.Dim()` API and passed to `torch.export.export()` through the `dynamic_shapes` argument.

In [10]:
from torch.export import Dim

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)  # plain tensor attribute, not register_buffer(); export lifts it as the constant c_buffer

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

In [11]:
example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

exported_program = export(M(), example_args, dynamic_shapes=dynamic_shapes)
print(exported_program)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"):
             # File: <ipython-input-10-34a0f04649d0>:16 in forward, code: out1 = self.branch1(x1)
            linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias);  x1 = p_branch1_0_weight = p_branch1_0_bias = None
            relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear);  linear = None
            
             # File: <ipython-input-10-34a0f04649d0>:17 in forward, code: out2 = self.branch2(x2)
            linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias);  x2 = p_branch2_0_weight = p_branch2_0_bias = None
            relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1);  linear_1 = None
            
     

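Looking at the inputs `x1` and `x2`, they now have the symbolic shapes `(s0, 64)` and `(s0, 128)`. Both first dimensions share the single symbol `s0` because both were mapped to the same `batch` Dim.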

Also look at `exported_program.range_constraints` to see the range of values each symbol in the graph may take.

**Need to understand `range_constraints` in more depth.**

https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk


## Specializations
A key concept in understanding the behavior of torch.export is the difference between static and dynamic values.

A dynamic value is one that can change from run to run. These behave like normal arguments to a Python function—you can pass different values for an argument and expect your function to do the right thing. Tensor data is treated as dynamic.

A static value is a value that is fixed at export time and cannot change between executions of the exported program. When the value is encountered during tracing, the exporter will treat it as a constant and hard-code it into the graph.

When an operation is performed (e.g. x + y) and all inputs are static, then the output of the operation will be directly hard-coded into the graph, and the operation won’t show up (i.e. it will get constant-folded).

When a value has been hard-coded into the graph, we say that the graph has been specialized to that value.


In [12]:
# One implication of specialization: when shape-dependent control flow is
# encountered, `torch.export` specializes on the branch that is taken
# with the given sample inputs.

In [13]:
import torch
from torch.export import export

class Mod(torch.nn.Module):
    def forward(self, x):
        if x.shape[0] > 5:
            return x + 1
        else:
            return x - 1

# No Dim() is given, so the shape of x is treated as static
example_inputs = (torch.rand(10, 2),)
exported_program = export(Mod(), example_inputs)
print(exported_program)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 2]"):
             # File: <ipython-input-13-952f5532b81e>:7 in forward, code: return x + 1
            add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1);  x = None
            return (add,)
            
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}



As visible, the `if`/`else` on `x.shape[0]` does not appear in the graph. Because the input shape `(10, 2)` is static, `torch.export` evaluated `x.shape[0] > 5` at trace time and kept only the `x + 1` branch.

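To keep dimension 0 dynamic instead, it needs to be marked with `Dim()`. Even then, a plain Python `if` on that dimension still traces only one branch: export records the condition as a constraint on the symbol (and may ask you to narrow the `Dim` range to match it), so capturing both branches requires `torch.cond`.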

## Next - `torch.export` for Training and Inference

This section is a TODO because these features were introduced in PyTorch 2.5 and later.



In [9]:
# decomp_table = torch.export.default_decompositions()

In [8]:
torch.__version__

'2.5.1+cu124'