Exporting torch model with DFT results in transform being taken along the wrong dimension #2704

@jamied157

I'm using onnxscript==0.5.4 and torch==2.9.0, but I expect this is still a problem with the latest onnxscript, since the code here hasn't been updated for a while: https://github.com/microsoft/onnxscript/blob/c1bfdfc5baf2dabf09191ad2ab7603ea21cfe614/onnxscript/function_libs/torch_lib/ops/fft.py

Exporting a PyTorch model like this:

import torch.nn as nn
import torch

in_data = torch.randn(2, 128, 512)

class DFTModel(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, a):
        return torch.fft.rfft(a)


torch.onnx.export(DFTModel(), (in_data,), "dft.onnx", dynamo=True)

results in a model that looks like:

Image

While the output shapes are correct here, the attributes of the model are not. The exported graph unsqueezes the final axis and then sets dft_length to 2. As a result, the DFT is calculated along the wrong dimension (the extra trailing dimension, with padding) instead of along the last axis of the input, which is what torch.fft.rfft transforms by default. I think the correct graph should look more like:

Image
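To make the difference concrete, here is a small pure-Python analogy of the behaviour described above. It is only a sketch and not a trace of the exported graph: the `dft` helper and the sample `row` are hypothetical, and the "wrong" case simply assumes that each sample is treated as its own length-1 signal zero-padded to length 2.

```python
import cmath


def dft(signal, n=None):
    """Naive DFT of a 1-D sequence, zero-padded or truncated to length n."""
    n = len(signal) if n is None else n
    padded = (list(signal) + [0.0] * n)[:n]
    return [
        sum(v * cmath.exp(-2j * cmath.pi * k * t / n) for t, v in enumerate(padded))
        for k in range(n)
    ]


# One row of a (hypothetical) input, standing in for the last axis of `a`.
row = [1.0, 2.0, 3.0, 4.0]

# Intended behaviour: transform along the signal axis and keep the
# one-sided half, which has len(row) // 2 + 1 frequency bins.
intended = dft(row)[: len(row) // 2 + 1]

# Behaviour described in this issue (as an analogy): the transform runs over
# the extra trailing axis, so every sample is padded to [v, 0] on its own.
wrong = [dft([v], n=2) for v in row]

print("intended:", intended)
print("wrong:   ", wrong)
```

In the second case each pair is just the original sample repeated, because the DFT of `[v, 0]` is `[v, v]`. No frequency content of the row is computed at all, which is consistent with the shapes looking plausible while the values are wrong.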

To verify that the model is wrong, you can run it. I get this result:

Image

Whereas with the correct model I get this:

Image
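For anyone who wants to reproduce the check without the screenshots, a minimal sketch follows. It makes several assumptions: onnxruntime is installed, the exported model has a single input and a single output, and a complex output is returned either as a complex array or as real/imaginary pairs in a trailing dimension of size 2. NumPy's `np.fft.rfft` is used as a stand-in reference for `torch.fft.rfft`, and the helper names are hypothetical.

```python
import numpy as np


def reference_rfft(x):
    """One-sided FFT along the last axis, returned as (..., n // 2 + 1, 2) real/imag pairs."""
    spec = np.fft.rfft(x, axis=-1)
    return np.stack([spec.real, spec.imag], axis=-1)


def max_abs_error(candidate, reference):
    """Largest absolute difference between a model output and the reference."""
    candidate = np.asarray(candidate)
    if np.iscomplexobj(candidate):
        # Convert a complex output to the same real/imag pair layout.
        candidate = np.stack([candidate.real, candidate.imag], axis=-1)
    if candidate.shape != reference.shape:
        raise ValueError(f"shape mismatch: {candidate.shape} vs {reference.shape}")
    return float(np.max(np.abs(candidate - reference)))


def check_exported_model(onnx_path, x):
    """Run the exported model on x and compare it against the reference."""
    import onnxruntime as ort  # imported lazily; only needed for this step

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    (onnx_out,) = session.run(None, {input_name: x})  # assumes one output
    return max_abs_error(onnx_out, reference_rfft(x))


if __name__ == "__main__":
    x = np.random.default_rng(0).standard_normal((2, 128, 512)).astype(np.float32)
    print("max abs error:", check_exported_model("dft.onnx", x))
```

With a correctly exported graph I would expect the error to be at float32 round-off level. A large error, or a shape mismatch, would point to the transform being taken along the wrong axis.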

Happy to go in and patch this with some pointers!
