# Torch Export Tips

## 1. Use dynamic_axes to keep the "Shape" op and prevent shape inline

Without dynamic_axes, the exporter removes the "Shape" op and hard-codes the shape of the input tensor into the exported model. For a model that must support dynamic shapes at inference time, this is not what we want.

In [9]:
import torch
from torch import nn

class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x.shape[-1]

The following export call does not pass the dynamic_axes argument, so the "Shape" op is removed and `x.shape[-1]` is exported as a constant, evaluated from the shape of the example input.

In [7]:
torch.onnx.export(
    Model(),
    torch.ones(12,8),
    "static_shape.onnx",
    verbose=True,
)

Exported graph: graph():
  %1 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={8}, onnx_name="/Constant"](), scope: __main__.Model:: # /tmp/ipykernel_673838/427168263.py:11:0
  return (%1)



The following export call passes the dynamic_axes argument, so the "Shape" op is kept.
The onnx graph looks like this:

![onnx graph of shape](./media/torch_dynamic_shape.png)

In [10]:

torch.onnx.export(
    Model(),
    torch.ones(12,8),
    "dynamic_shape.onnx",
    input_names=["input"],
    dynamic_axes={
        "input": {0: "batch_size", 1: "seq_len"},
    },
    verbose=True,
)

Exported graph: graph(%input : Float(*, *, strides=[8, 1], requires_grad=0, device=cpu)):
  %/Shape_output_0 : Long(2, strides=[1], device=cpu) = onnx::Shape[onnx_name="/Shape"](%input), scope: __main__.Model:: # /tmp/ipykernel_673838/1291048605.py:12:0
  %/Constant_output_0 : Long(device=cpu) = onnx::Constant[value={1}, onnx_name="/Constant"](), scope: __main__.Model:: # /tmp/ipykernel_673838/1291048605.py:12:0
  %3 : Long(requires_grad=0, device=cpu) = onnx::Gather[axis=0, onnx_name="/Gather"](%/Shape_output_0, %/Constant_output_0), scope: __main__.Model:: # /tmp/ipykernel_673838/1291048605.py:12:0
  return (%3)
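
To make the difference concrete, here is a minimal NumPy sketch that mimics what the two exported graphs compute. The function names `static_graph` and `dynamic_graph` are hypothetical stand-ins, not part of torch or ONNX; the constant 8 and the Gather index 1 are taken from the verbose logs above.

```python
import numpy as np

def static_graph(x):
    # Mimics static_shape.onnx: the log shows a single Constant node holding 8,
    # so the input is ignored entirely.
    return np.int64(8)

def dynamic_graph(x):
    # Mimics dynamic_shape.onnx: Shape followed by Gather(axis=0) with index 1.
    shape = np.array(x.shape, dtype=np.int64)  # Shape
    return np.take(shape, 1, axis=0)           # Gather

same_shape = np.ones((12, 8))    # the shape used at export time
other_shape = np.ones((4, 16))   # a different shape at inference time

print(static_graph(same_shape), dynamic_graph(same_shape))
print(static_graph(other_shape), dynamic_graph(other_shape))
```

For the `(4, 16)` input the static stand-in still returns 8, while the dynamic one returns 16, which is the behavior a dynamic-shape model needs.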



Using torch.Tensor.size() to read the extent of a particular dimension can produce different ONNX ops.

In the example below, `x.size(-1)` is exported as a "Shape -> Slice -> Squeeze" pattern, as in the following graph:

![onnx graph of size](./media/torch_dynamic_shape_size.png)

In [14]:
class ModelUseSize(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x.size(-1)

torch.onnx.export(
    ModelUseSize(),
    torch.ones(12,8),
    "dynamic_shape_size.onnx",
    input_names=["input"],
    dynamic_axes={
        "input": {0: "batch_size", 1: "seq_len"},
    },
    verbose=True,
)

Exported graph: graph(%input : Float(*, *, strides=[8, 1], requires_grad=0, device=cpu)):
  %/Shape_output_0 : Long(2, strides=[1], device=cpu) = onnx::Shape[onnx_name="/Shape"](%input), scope: __main__.ModelUseSize:: # /tmp/ipykernel_673838/1713131070.py:6:0
  %/Constant_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/Constant"](), scope: __main__.ModelUseSize:: # /tmp/ipykernel_673838/1713131070.py:6:0
  %/Constant_1_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={-1}, onnx_name="/Constant_1"](), scope: __main__.ModelUseSize:: # /tmp/ipykernel_673838/1713131070.py:6:0
  %/Constant_2_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={9223372036854775807}, onnx_name="/Constant_2"](), scope: __main__.ModelUseSize:: # /tmp/ipykernel_673838/1713131070.py:6:0
  %/Slice_output_0 : Long(1, strides=[1], device=cpu) = onnx::Slice[onnx_name="/Slice"](%/Shape_output_0, %/Constant_1_output_0, %/Constant_2_output_0, %/Constan

In other cases, such as `x.size(0)` below, torch.Tensor.size() is exported as a "Shape -> Gather" pattern, as in the following graph:

![onnx graph dynamic shape gather](./media/torch_dynamic_shape_gather.png)

In [19]:
class ModelUseSize2(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        return x.size(0)

torch.onnx.export(
    ModelUseSize2(),
    torch.ones(12,8),
    "dynamic_shape_gather.onnx",
    input_names=["input"],
    dynamic_axes={
        "input": {0: "batch_size", 1: "seq_len"},
    },
    verbose=True,
)

Exported graph: graph(%input : Float(*, *, strides=[8, 1], requires_grad=0, device=cpu)):
  %/Shape_output_0 : Long(2, strides=[1], device=cpu) = onnx::Shape[onnx_name="/Shape"](%input), scope: __main__.ModelUseSize2:: # /tmp/ipykernel_673838/2600646057.py:6:0
  %/Constant_output_0 : Long(device=cpu) = onnx::Constant[value={0}, onnx_name="/Constant"](), scope: __main__.ModelUseSize2:: # /tmp/ipykernel_673838/2600646057.py:6:0
  %3 : Long(requires_grad=0, device=cpu) = onnx::Gather[axis=0, onnx_name="/Gather"](%/Shape_output_0, %/Constant_output_0), scope: __main__.ModelUseSize2:: # /tmp/ipykernel_673838/2600646057.py:6:0
  return (%3)
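
Both patterns read a single entry of the shape vector. The NumPy sketch below mimics each one to show that they yield the same scalar. The helper names are hypothetical; the Slice bounds (-1 and 9223372036854775807) come from the constants in the `x.size(-1)` log, and slicing along axis 0 is an assumption because that log line is cut off.

```python
import numpy as np

INT64_MAX = np.iinfo(np.int64).max  # 9223372036854775807, the Slice end seen in the log

def shape_slice_squeeze(x, start):
    shape = np.array(x.shape, dtype=np.int64)  # Shape: 1-D tensor of extents
    sliced = shape[start:INT64_MAX]            # Slice: still 1-D, length 1 when start == -1
    return np.squeeze(sliced, axis=0)          # Squeeze: drop the length-1 axis

def shape_gather(x, index):
    shape = np.array(x.shape, dtype=np.int64)  # Shape
    return np.take(shape, index, axis=0)       # Gather with a scalar index returns a scalar

x = np.ones((12, 8))
print(shape_slice_squeeze(x, -1), shape_gather(x, 1))
print(shape_gather(x, 0))
```

Either way the result is a 0-d integer, so a downstream consumer of the graph should not need to care which pattern the exporter chose.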



## 2. Some common patterns

In [22]:
## Gather
x : torch.Tensor = torch.ones(12,8,8)

y = x[0]  # exported as: Gather op, axis=0, index=0

In [23]:
## Unsqueeze
y = x.view(1, 12, 8, 8) # exported as: Unsqueeze op, axes=[0]
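
As a rough sanity check of why these expressions correspond to the ops named in the comments, the NumPy sketch below compares the indexing and reshaping expressions with Gather-like and Unsqueeze-like counterparts. `np.take` and `np.expand_dims` are used here only as stand-ins for the ONNX ops; they are not what the exporter emits.

```python
import numpy as np

x = np.arange(12 * 8 * 8, dtype=np.float32).reshape(12, 8, 8)

# x[0] selects index 0 along axis 0, which is what Gather(axis=0, index=0) does.
gathered = np.take(x, 0, axis=0)
assert np.array_equal(gathered, x[0])

# Reshaping to (1, 12, 8, 8) only adds a leading axis, like Unsqueeze(axes=[0]).
unsqueezed = np.expand_dims(x, axis=0)
assert np.array_equal(unsqueezed, x.reshape(1, 12, 8, 8))

print(gathered.shape, unsqueezed.shape)
```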