

---

# **torch.export() and AOTInductor: Overview and Demo**

This notebook provides an overview of `torch.export()` in PyTorch, covering its usage, limitations, and how it compares to alternatives like `torch.compile()`. We’ll also explore methods for handling non-traceable components and dynamic input shapes.

---


# **1. Overview of `torch.export()`**

`torch.export()` (invoked as `torch.export.export()`) is an experimental PyTorch API that captures a model's entire forward pass as a single computation graph suitable for Ahead-of-Time (AOT) deployment. The result, an `ExportedProgram`, can be serialized and later reused in other runtime environments without the original Python code.

`torch.export()` differs from `torch.compile()` in that it requires a fully traceable graph. When `torch.compile()` encounters an operation it cannot trace, it falls back to the regular Python interpreter for that part. `torch.export()` instead enforces strict tracing and raises an error.

---


In [None]:
import torch
from torch import nn

---


# **2. Limitations of `torch.export()`**

## Graph Breaks

`torch.export()` requires a fully traceable graph, but some Python and PyTorch constructs are hard to trace. The most notable case is control flow whose outcome depends on tensor values. When `torch.export()` encounters such an operation, it raises an error. `torch.compile()`, by contrast, inserts a "graph break" and runs the offending code in regular Python.

The example below wraps a value-dependent `if` statement in a small `nn.Module`, which `torch.export()` cannot trace.

---


In [None]:
import torch
from torch import nn

# Wrap a value-dependent branch in an nn.Module subclass
class UntraceableModule(nn.Module):
    def forward(self, x):
        # The branch depends on tensor values, which are unknown while tracing
        if x.sum() > 0:
            return x * 2
        else:
            return x

# Define the model with the untraceable module as its final layer
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    UntraceableModule(),
)

# Attempt to export the model; example inputs are passed as a tuple
try:
    exported_model = torch.export.export(model, (torch.randn(10, 10),))
except Exception as e:
    print(f"Error during export: {e}")


The cell prints an `Error during export: ...` message. The exact wording varies across PyTorch versions, but the cause is the same: which branch runs depends on the value of `x.sum()`, so no single static graph can represent it.



---

## Using `torch.cond` for Data-Dependent Control Flow

In PyTorch, control flow operations that depend on tensor data or shapes are challenging for a tracing compiler because it would require generating code paths for all possible conditions. `torch.export()` supports data-dependent control flow by using `torch.cond`, a specialized operator for handling `if-else` logic in a traceable way.

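The example below shows what happens *without* `torch.cond`. The condition depends only on the input's static shape, so `torch.export()` evaluates it at trace time and keeps only the branch that was taken (shape specialization).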

---



In [None]:
class ConditionalModel(nn.Module):
    def forward(self, x):
        if x.shape[0] > 5:
            return x + 1
        else:
            return x - 1

example_input = (torch.rand(10, 2),)
exported_model = torch.export.export(ConditionalModel(), example_input)
print(exported_model)  # The condition does not appear due to shape specialization


ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 2]"):
             # File: <ipython-input-16-34662d95aa80>:4 in forward, code: return x + 1
            add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1);  x = None
            return (add,)
            
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}



---

# **3. Comparison with `torch.compile()`**

Both `torch.export()` and `torch.compile()` rely on PyTorch's underlying graph-capturing technology (specifically TorchDynamo and FX), but they serve different purposes and offer different levels of flexibility:

- **Purpose**:
  - `torch.compile()` is primarily designed for Just-In-Time (JIT) compilation, making it suitable for experimentation, training, and general-purpose optimizations.
  - `torch.export()`, on the other hand, is meant for Ahead-of-Time (AOT) deployment. It produces a single serializable graph that can be reused across environments and devices without the original Python code.

- **Flexibility with Untraceable Parts**:
  - `torch.compile()` is more flexible; if it encounters operations that cannot be traced (i.e., *graph breaks*), it will allow those parts to run in Python using the default runtime, effectively skipping full tracing for those segments.
  - `torch.export()` requires a single end-to-end graph, so it raises an error as soon as it encounters an untraceable part. This guarantees that the exported graph is self-contained and never falls back to Python code.

- **Use Case Differences**:
  - **torch.compile()** is suitable for training and inference with high flexibility, because it can handle models with complex, dynamic control flow.
  - **torch.export()** is more suited for deploying models in production where a standalone, traceable graph is required, and it’s acceptable to rewrite code for strict traceability.

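The following example runs a shape-dependent model through `torch.compile()`. Because the branch depends only on the input shape, `torch.compile()` resolves it at compile time and recompiles if a later input's shape requires the other branch. A branch on tensor *values* would instead cause a graph break, which `torch.compile()` tolerates and `torch.export()` rejects.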

---

In [None]:
# Define a simple model with dynamic control flow
class DynamicControlModel(nn.Module):
    def forward(self, x):
        if x.shape[0] > 5:  # Condition on shape, which could vary in different inputs
            return x + 1
        else:
            return x - 1

# Compile the model with torch.compile()
compiled_model = torch.compile(DynamicControlModel())

# Run the model; for a shape-based branch, torch.compile() guards on the input shape
output = compiled_model(torch.randn(10, 10))  # shape[0] == 10 > 5, so the x + 1 branch runs
print(output)

tensor([[ 0.2291, -0.0506,  0.8906,  1.7823,  1.0859,  1.3073,  1.0086,  1.6430,
          1.7016,  1.5336],
        [ 1.1889,  0.0981,  1.3893,  2.7723,  2.0871,  0.7327,  2.1079,  0.9839,
          0.5456,  1.2943],
        [-0.4916, -0.2345,  1.2255,  2.0321,  1.6659,  2.1579,  1.8077,  0.8497,
          0.4578,  0.0847],
        [ 3.1650,  0.7750,  0.8805,  1.0825, -0.0762,  1.2252,  0.4562,  0.7978,
          2.6023,  3.1258],
        [ 0.6714, -0.5205,  2.0268,  0.6293,  0.5531,  0.3261,  1.3184,  0.7365,
          1.5657,  2.8883],
        [ 0.5795,  2.3762,  1.1510,  0.2447, -0.1941,  3.5267,  0.9365,  2.6206,
          0.7494, -1.7936],
        [-1.4705, -1.1605,  1.6978,  1.1631,  0.7249, -0.0164,  1.0046,  0.9659,
          1.6990,  0.0542],
        [ 0.7296,  1.1931,  0.1662,  1.2432,  1.6827, -0.3016,  1.1667,  2.8555,
          2.3856,  1.0344],
        [ 1.2813,  1.3869,  0.2535,  2.5452,  2.7733,  0.9811,  3.2994,  0.0365,
          0.4662,  0.6400],
        [-0.3389,  
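To see an actual graph break, the sketch below branches on tensor values instead of shapes. `ValueDependentModel` is a hypothetical name. The sketch passes `backend="eager"` to `torch.compile()` so that it runs without a C++ toolchain. The default Inductor backend handles the graph break the same way.

```python
import torch
from torch import nn

class ValueDependentModel(nn.Module):
    def forward(self, x):
        if x.sum() > 0:  # depends on tensor values, not just shapes
            return x + 1
        return x - 1

model = ValueDependentModel()
x = torch.ones(4, 4)  # the sum is positive, so the "+ 1" branch runs

# torch.compile tolerates the data-dependent branch via a graph break
compiled_out = torch.compile(model, backend="eager")(x)
print(torch.equal(compiled_out, x + 1))

# torch.export must capture one graph, so the same branch is an error
export_failed = False
try:
    torch.export.export(model, (x,))
except Exception as e:
    export_failed = True
    print(f"export failed with {type(e).__name__}")
print(export_failed)
```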

---

# **4. Using Non-Strict Mode for Workarounds**

`torch.export()` also offers a non-strict mode (`strict=False`) for code that TorchDynamo cannot trace but that does not affect the core tensor computation. In non-strict mode, `torch.export()` does not analyze the Python bytecode. It runs the code with the normal Python interpreter while replacing tensors with ProxyTensors, which record every tensor operation into the graph.

The example below exports a model containing an in-place addition, once with `strict=True` and once with `strict=False`. Both exports succeed: `torch.export()` converts in-place operations into their functional equivalents, so an in-place operation alone is not a reason to switch to non-strict mode.

---



In [None]:
import torch
from torch import nn

# Define a model with an in-place operation
class InPlaceOperationModel(nn.Module):
    def forward(self, x):
        x += 1  # In-place operation; export functionalizes it
        return x

# Export the model in strict mode (TorchDynamo-based tracing)
print("Attempting export with strict=True")
try:
    exported_model_strict = torch.export.export(InPlaceOperationModel(), (torch.ones(3, 3),), strict=True)
    print("Export successful with strict=True")
except Exception as e:
    print(f"Error with strict=True: {e}")

# Export the model in non-strict mode (Python interpreter + ProxyTensors)
print("Attempting export with strict=False")
try:
    exported_model_non_strict = torch.export.export(InPlaceOperationModel(), (torch.ones(3, 3),), strict=False)
    print("Export successful with strict=False")
except Exception as e:
    print(f"Error with strict=False: {e}")


Attempting export with strict=True
Export successful with strict=True
Attempting export with strict=False
Export successful with strict=False



---

# **5. Expressing Dynamic Shapes with `torch.export.Dim()`**

By default, `torch.export()` specializes on input tensor shapes, assuming they remain static. However, dimensions like batch size may vary between runs. The `torch.export.Dim()` API allows these dimensions to be marked as dynamic, enabling the exported model to adapt to varying input shapes.

In the example below, we specify that the first dimension (typically the batch size) is dynamic.

---

In [None]:
from torch.export import Dim, export

class DynamicShapeModel(nn.Module):
    def forward(self, x):
        return x + torch.ones_like(x)

batch_dim = Dim("batch")
dynamic_shapes = {"x": {0: batch_dim}}
exported_model = export(DynamicShapeModel(), (torch.randn(32, 64),), dynamic_shapes=dynamic_shapes)
print(exported_model)


ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0, 64]"):
             # File: <ipython-input-19-a84b8747708b>:5 in forward, code: return x + torch.ones_like(x)
            ones_like: "f32[s0, 64]" = torch.ops.aten.ones_like.default(x, pin_memory = False)
            add_3: "f32[s0, 64]" = torch.ops.aten.add.Tensor(x, ones_like);  x = ones_like = None
            return (add_3,)
            
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_3'), target=None)])
Range constraints: {s0: VR[0, int_oo]}



---

# **6. Serialization with `torch.export.save()` and `torch.export.load()`**

`torch.export()` includes built-in support for serialization, allowing models to be saved as `.pt2` files and reloaded later. This makes it easy to share or deploy models without needing access to the original Python code.

The following example demonstrates how to save and load an exported program.

---

In [None]:
class SimpleModel(nn.Module):
    def forward(self, x):
        return x + 10

exported_model = export(SimpleModel(), (torch.randn(5),))
torch.export.save(exported_model, 'model.pt2')
loaded_model = torch.export.load('model.pt2')
print(loaded_model)

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[5]"):
             # File: <ipython-input-20-b6dc52742b59>:3 in forward, code: return x + 10
            add: "f32[5]" = torch.ops.aten.add.Tensor(x, 10);  x = None
            return (add,)
            
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add'), target=None)])
Range constraints: {}



---

# **7. Summary and Key Takeaways**

- `torch.export()` is designed for Ahead-of-Time (AOT) deployment and requires fully traceable graphs.
- Compared to `torch.compile()`, `torch.export()` is stricter, enforcing full graph tracing, making it suitable for production and deployment.
- Dynamic dimensions, such as batch size, can be specified using `torch.export.Dim()`, making the model adaptable to inputs with varying shapes.
- Non-strict mode (`strict=False`) runs the model's Python code directly and records only tensor operations, so Python features that TorchDynamo cannot trace no longer block export.
- `.pt2` serialization with `torch.export.save()` and `torch.export.load()` lets exported programs be shared and deployed without the original model code.

This notebook covers the usage, limitations, and advantages of `torch.export()`, offering a foundation for efficient model deployment with PyTorch.

---