Description
Context
Before explaining my issue, please understand (and forgive) the following caveats:
- I'm not sure if my issue is expected behaviour or not. If it is expected behaviour, please feel free to close the issue; and
- I'm not interacting with `onnxscript` directly, but I believe I've traced my issue to a call to `onnxscript`, so I believe the issue belongs here. Please feel free to correct me if I'm wrong on this.
For further context, please see this issue.
The Issue
With that out of the way, let me explain the issue. It appears that `onnxscript` does not preserve any metadata when it applies rewrite rules to mergeable operations (in my example, consecutive transpose operations).
Consider the following code:
```python
import torch


class DoubleTranspose(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, k):
        k = k.transpose(1, 2)
        k = k.transpose(-2, -1)  # Two transposes get merged and have no metadata
        return 2 * k  # Some other op which does have metadata


k = torch.rand(1, 128, 12, 32)
model = DoubleTranspose()
output = model(k)  # Make sure the forward path runs

with torch.no_grad():
    torch.onnx.export(
        model,
        (k,),
        "minimal.onnx",
        input_names=['k'],
        opset_version=18,
        dynamo=True,
        optimize=True,  # set `optimize=False` to prevent the transpose nodes from being merged
    )
```
When `optimize=True`, the two transposes are merged into a single transpose and no longer carry any metadata, as shown:
With `optimize=False`, you get the two transposes with metadata intact:
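For completeness, the metadata can also be inspected programmatically. This is a minimal sketch against the `onnx` Python API, assuming the `minimal.onnx` path from the export above and that the exporter records its provenance in each node's `metadata_props` / `doc_string`:
```python
import onnx

# Load the exported model (path assumed from the export call above).
model = onnx.load("minimal.onnx")

for node in model.graph.node:
    print(f"{node.op_type} ({node.name})")
    # Assumption: the dynamo exporter records provenance (namespace, stack
    # trace, etc.) in metadata_props and/or doc_string on each node.
    for prop in node.metadata_props:
        print(f"  {prop.key}: {prop.value}")
    if node.doc_string:
        print(f"  doc_string: {node.doc_string}")
```
With `optimize=True`, the merged Transpose node shows up with these fields empty, while the surrounding ops keep theirs.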
In tracing the PyTorch export function, I'm led to the following function in `onnxscript` when `optimize=True` is set during PyTorch export. Is there any way to apply these optimization passes while automagically adding some metadata to the generated node? Either an argument to `torch.onnx.export` or an alternative post-export step would be of interest here.
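As a rough sketch of the post-export route (not confirmed API usage: I'm assuming `onnxscript.optimizer.optimize` accepts and returns an `onnx.ModelProto` and applies the same passes as `optimize=True`, and that `minimal.onnx` here was produced with `optimize=False`), one could run the optimizer separately and then stamp any node that came out of the rewrite without metadata:
```python
import onnx
import onnxscript.optimizer

# Assumption: "minimal.onnx" was exported with optimize=False, so the
# original nodes still carry their metadata.
model = onnx.load("minimal.onnx")

# Assumption: onnxscript.optimizer.optimize accepts and returns an
# onnx.ModelProto and applies the same passes as optimize=True in the export.
optimized = onnxscript.optimizer.optimize(model)

# Hypothetical post-processing step: mark any node that lost its metadata
# during the rewrite so it is at least traceable afterwards.
for node in optimized.graph.node:
    if not node.metadata_props and not node.doc_string:
        node.doc_string = "metadata dropped by optimizer rewrite"

onnx.save(optimized, "minimal_optimized.onnx")
```
This only tags the merged nodes with a placeholder rather than recovering the original metadata, which is why an exporter-side option would be preferable.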
Environment
- OS: Ubuntu 24.04
- PyTorch version: 2.7.1
- ONNXScript versions: 0.4.0, 0.5.3