Contrib IRFFT operator output dimensions calculation #13236
@HectorSVC, could you please help check this one? Thanks.
Hello @ytaous and @HectorSVC, I'd like to revive this issue. A minimal working example is below. It transforms a tensor into Fourier space and then back out, which should recover the exact input tensor. PyTorch does this, but the current IRFFT function in ORT does not. This is because, in the line mentioned above, the output dimension should be `2 * (in_dim - 1)`. In fact, looking at the lines just below the bug, it's clear what the reverse calculation (the fix) should be.
Minimal Example
Outputs
I am running in the 22.12 PyTorch Docker container. Using ORT
ONNX output is incorrect and of the wrong size.
Thank you for the support!
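For reference, the round-trip property this comment relies on can be sketched with NumPy's real FFT. NumPy is used here only as a stand-in reference (an assumption, in place of PyTorch). This is not the original minimal example and does not call ONNX Runtime. An RFFT followed by an IRFFT along the last axis should return a tensor with the same shape and values as the input:

```python
import numpy as np

def roundtrip(x: np.ndarray) -> np.ndarray:
    """Real FFT along the last axis, then inverse real FFT back to the signal domain."""
    spectrum = np.fft.rfft(x, axis=-1)      # one-sided spectrum: n // 2 + 1 complex bins
    return np.fft.irfft(spectrum, axis=-1)  # default output length: 2 * (bins - 1)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))  # even-length last axis, as the fix assumes
y = roundtrip(x)

print(x.shape, y.shape)     # (2, 8) (2, 8)
print(np.allclose(x, y))    # True
```

An IRFFT that used `in_dim * 2 - 1` would instead produce a last axis of length 9 here, so neither the shape check nor the value check could pass.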
Hi @justinchuby, thanks for your very helpful responses over on the PyTorch repo (pytorch/pytorch#105160). Any chance you know who to ping to get some action on this bug? (It's an easy one-line fix: #15662.) Many thanks.
### Description
Fixes the issue with the IRFFT output dimension calculation described in #13236.

### Motivation and Context
Please refer to #13236 for a detailed description. Specifically, [this code](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/math/fft_ops.cc#L103) computes the output dimension as:
```
out_dim = in_dim * 2 - 1
```
while it should be this instead:
```
out_dim = 2 * (in_dim - 1)
```
(assuming the original signal has an even number of samples, of course). For example, if the original signal has 4 samples, then the round trip should look something like:
```
4 -> (one-sided RFFT) -> 3 (complex) -> (one-sided IRFFT) -> 4
```
With the current code, the output will be a signal with 5 points.

---------

Co-authored-by: Alexey Kamenev <akamenev@nvidia.com>
Co-authored-by: Nick Geneva <nicholasgeneva@gmail.com>
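The dimension arithmetic in this description can be checked with a short Python sketch. The helpers `irfft_len_current` and `irfft_len_fixed` are hypothetical names that only encode the two formulas quoted above. NumPy's `np.fft.irfft` serves as an independent reference, since its default output length for `m` input bins is `2 * (m - 1)`. Nothing here calls ONNX Runtime.

```python
import numpy as np

def irfft_len_current(bins: int) -> int:
    # Hypothetical helper: the formula quoted from fft_ops.cc
    return bins * 2 - 1

def irfft_len_fixed(bins: int) -> int:
    # Hypothetical helper: the proposed fix, exact for even-length original signals
    return 2 * (bins - 1)

for n in (4, 8, 16):
    bins = n // 2 + 1  # length of the one-sided RFFT output for an n-sample signal
    numpy_len = len(np.fft.irfft(np.fft.rfft(np.zeros(n))))  # NumPy reference length
    print(n, bins, irfft_len_current(bins), irfft_len_fixed(bins), numpy_len)
```

For a 4-sample signal this reproduces the round trip in the description: 3 bins, 5 output samples under the current formula, and 4 under the fix. The fix agrees with NumPy's length for every even `n`.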
Describe the issue
It seems there is an issue with computing the dimensions of the output signal in the implementation of com.microsoft.Irfft.
Specifically, this code computes the output dimension as `out_dim = in_dim * 2 - 1`, while it should be `out_dim = 2 * (in_dim - 1)` instead (assuming the original signal has an even number of samples, of course).
For example, if the original signal has 4 samples, then the round trip should look something like `4 -> (one-sided RFFT) -> 3 (complex) -> (one-sided IRFFT) -> 4`.
With the current code, the output will be a signal with 5 points.
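The even-length assumption matters because a one-sided spectrum does not record whether the original length was even or odd. Both 4 and 5 samples give `n // 2 + 1 = 3` bins. A NumPy sketch illustrates this, using NumPy as a reference implementation rather than ORT. Without an explicit length, `np.fft.irfft` falls back to the even-length result. Passing the original length through its `n` argument recovers an odd-length signal:

```python
import numpy as np

# Two signals of different lengths produce one-sided spectra of the same length.
even = np.arange(4, dtype=np.float64)  # 4 samples -> 4 // 2 + 1 = 3 bins
odd = np.arange(5, dtype=np.float64)   # 5 samples -> 5 // 2 + 1 = 3 bins
spec_even = np.fft.rfft(even)
spec_odd = np.fft.rfft(odd)
print(spec_even.shape, spec_odd.shape)  # (3,) (3,)

# With no explicit length, irfft assumes an even original: 2 * (3 - 1) = 4 samples.
print(np.fft.irfft(spec_odd).shape)     # (4,)

# Passing the original length recovers the odd-length signal exactly.
restored = np.fft.irfft(spec_odd, n=5)
print(np.allclose(restored, odd))       # True
```

So `2 * (in_dim - 1)` is the right default for even-length signals. An odd-length original can only be recovered if its length is supplied separately.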
To reproduce
I don't have a small, isolated repro at this time, but I can try and create one, if needed.
Urgency
No response
Platform
Linux
OS Version
Ubuntu 20.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.12.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
CUDA
Execution Provider Library Version
CUDA 11.7