
Contrib IRFFT operator output dimensions calculation #13236

Open
Alexey-Kamenev opened this issue Oct 7, 2022 · 3 comments
Labels
ep:CUDA issues related to the CUDA execution provider

Comments

@Alexey-Kamenev
Contributor

Describe the issue

It seems there is an issue with computing the dimensions of the output signal in the implementation of com.microsoft.Irfft.

Specifically, [this code](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/math/fft_ops.cc#L103) computes the output dimension as:

out_dim = in_dim * 2 - 1

while it should be this instead:

out_dim = 2 * (in_dim - 1)

(assuming the original signal has an even number of samples, of course).

For example, if the original signal has 4 samples, then the round trip should look something like:

4 -> (one-sided RFFT) -> 3 (complex) -> (one-sided IRFFT) -> 4

With the current code, the output will be a signal with 5 points.
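
For reference, the expected round-trip shapes can be checked directly with torch.fft (a minimal sketch, not part of the original report; the contrib op is meant to mirror this behavior):

```python
import torch

x = torch.randn(4)                       # even-length real signal
spec = torch.fft.rfft(x)                 # one-sided spectrum
print(spec.shape)                        # torch.Size([3])  -> in_dim = 3
y = torch.fft.irfft(spec)                # default n = 2 * (in_dim - 1) = 4
print(y.shape)                           # torch.Size([4])
assert torch.allclose(x, y, atol=1e-5)   # round trip recovers the signal
```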

To reproduce

I don't have a small, isolated repro at this time, but I can try to create one if needed.

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.12.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 11.7

github-actions bot added the ep:CUDA label (issues related to the CUDA execution provider) Oct 7, 2022
@ytaous
Contributor

ytaous commented Oct 10, 2022

@HectorSVC - can you please help check on this one? Thanks.

@NickGeneva
Contributor

Hello @ytaous and @HectorSVC

I'd like to revive this issue. A minimal working example is below. This simple example transforms a tensor into Fourier space and then back out, which should recover the exact input tensor. The current IRFFT function in ORT does not do this, while PyTorch does.

This is because, in the line mentioned above, the output dimension should be out_dim = 2 * (in_dim - 1) for even-point IRFFTs. What is present works for domains with an odd number of points, which is not supported by PyTorch or the added ONNX RFFT. For reference, the NumPy docs discuss this difference in the Returns section of the API documentation.

In fact, looking at the line just a bit further down from the bug, it's clear what the reverse calculation should be (the fix).
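
To illustrate the ambiguity the NumPy docs describe (a small sketch, not from the original comment): one-sided spectra of even- and odd-length signals can have the same length, so irfft defaults to the even case, out_dim = 2 * (in_dim - 1), unless the output length is given explicitly.

```python
import numpy as np

even = np.arange(4.0)                            # 4 samples (even)
odd = np.arange(5.0)                             # 5 samples (odd)

# Both one-sided spectra have length 3, so length alone is ambiguous.
print(np.fft.rfft(even).shape)                   # (3,)
print(np.fft.rfft(odd).shape)                    # (3,)

# irfft defaults to the even case: out_dim = 2 * (in_dim - 1) = 4.
print(np.fft.irfft(np.fft.rfft(even)).shape)     # (4,)
# Recovering the odd case requires passing n explicitly.
print(np.fft.irfft(np.fft.rfft(odd), n=5).shape) # (5,)
```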

Minimal Example

import torch
import onnxruntime as ort
from torch.onnx.symbolic_helper import parse_args
from torch.autograd import Function

class OnnxIrfft(Function):
    """Auto-grad function to mimic irfft for ONNX exporting
    """
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        return torch.fft.irfft(torch.view_as_complex(input), dim=-1, norm="backward")

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        """Symbolic representation for onnx graph"""
        return g.op(
                "com.microsoft::Irfft",
                input,
                normalized_i=0,
                onesided_i=1,
                signal_ndim_i=1,
            )

class FooModel(torch.nn.Module):
    def __init__(self):
        super(FooModel, self).__init__()

    def forward(self, x):
        # Calling custom op

        if not torch.onnx.is_in_onnx_export():
            # Regular PyTorch path
            out = torch.fft.irfft(torch.view_as_complex(x), dim=-1, norm="backward")
        else:
            # ONNX export path: route through the contrib Irfft op
            out = OnnxIrfft.apply(x)
        return out


torch.manual_seed(0)
# RFFT and IRFFT only set up for even domains in PyTorch / ONNX
x = torch.randn(4,4)
input_c = torch.fft.rfft(x, dim=-1, norm="backward")
input = torch.view_as_real(input_c) # Complex tensors not supported for ONNX export
model = FooModel()

# Standard PyTorch
output = model(input)

# ONNX RT
torch.onnx.export(model, (input,), 'model.onnx')
options = ort.SessionOptions()
ort_sess = ort.InferenceSession('model.onnx', providers=["CUDAExecutionProvider", 'CPUExecutionProvider'])
ort_inputs = {inp.name: v.detach().cpu().numpy()
                for inp, v in zip(ort_sess.get_inputs(), (input,))}
ort_outputs = ort_sess.run(None, ort_inputs)
output_onnx = torch.Tensor(ort_outputs[0])


print("Input:\n", x, x.shape)
print("PyTorch Output:\n", output, output.shape)
print("ONNX Output:\n", output_onnx, output_onnx.shape)

Outputs

I am running in the 22.12 PyTorch Docker container, using ORT 1.15.1. Note that the output is as follows; all outputs should be the same:

Input:
 tensor([[-1.1258, -1.1524, -0.2506, -0.4339],
        [ 0.8487,  0.6920, -0.3160, -2.1152],
        [ 0.3223, -1.2633,  0.3500,  0.3081],
        [ 0.1198,  1.2377,  1.1168, -0.2473]]) torch.Size([4, 4])
PyTorch Output:
 tensor([[-1.1258, -1.1524, -0.2506, -0.4339],
        [ 0.8487,  0.6920, -0.3160, -2.1152],
        [ 0.3223, -1.2633,  0.3500,  0.3081],
        [ 0.1198,  1.2377,  1.1168, -0.2473]]) torch.Size([4, 4])
ONNX Output:
 tensor([[-0.8587, -1.0419, -0.4523, -0.1144, -0.4953],
        [ 1.0701,  0.4009,  0.3468, -0.9733, -1.7350],
        [ 0.5833, -1.1845, -0.2159,  0.5230,  0.0112],
        [ 0.1451,  0.8074,  1.1476,  0.4493, -0.3224]]) torch.Size([4, 5])

The ONNX output is incorrect and of the wrong size.
I built a new wheel with the suggested fix implemented, and the example now yields the expected output:

Input:
 tensor([[-1.1258, -1.1524, -0.2506, -0.4339],
        [ 0.8487,  0.6920, -0.3160, -2.1152],
        [ 0.3223, -1.2633,  0.3500,  0.3081],
        [ 0.1198,  1.2377,  1.1168, -0.2473]]) torch.Size([4, 4])
PyTorch Output:
 tensor([[-1.1258, -1.1524, -0.2506, -0.4339],
        [ 0.8487,  0.6920, -0.3160, -2.1152],
        [ 0.3223, -1.2633,  0.3500,  0.3081],
        [ 0.1198,  1.2377,  1.1168, -0.2473]]) torch.Size([4, 4])
ONNX Output:
 tensor([[-1.1258, -1.1524, -0.2506, -0.4339],
        [ 0.8487,  0.6920, -0.3160, -2.1152],
        [ 0.3223, -1.2633,  0.3500,  0.3081],
        [ 0.1198,  1.2377,  1.1168, -0.2473]]) torch.Size([4, 4])

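For completeness, a quick programmatic check that can be appended to the script above (a sketch reusing its output and output_onnx variables):

```python
# With the fix, shapes and values agree within float tolerance.
assert output.shape == output_onnx.shape
assert torch.allclose(output, output_onnx, atol=1e-5)
print("ORT IRFFT matches the PyTorch round trip")
```
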
Thank you for the support!

@NickGeneva
Contributor

NickGeneva commented Jul 24, 2023

Hi @justinchuby

Thanks for your very helpful responses over on the PyTorch repo: pytorch/pytorch#105160

Any chance you know who to ping for some action on this bug? (It's an easy one-line fix.) #15662

Many thanks.

justinchuby pushed a commit that referenced this issue Jul 28, 2023
### Description
Fixes the issue with IRFFT output dimension calculation as described in
#13236

### Motivation and Context
Please refer to #13236 for detailed description.

Specifically, [this code](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/contrib_ops/cuda/math/fft_ops.cc#L103) computes the output dimension as:
```
out_dim = in_dim * 2 - 1
```
while it should be this instead:
```
out_dim = 2 * (in_dim - 1)
```
(assuming the original signal has an even number of samples, of course).

For example, if the original signal has 4 samples, then the round trip should look something like:
```
4 -> (one-sided RFFT) -> 3 (complex) -> (one-sided IRFFT) -> 4
```
With the current code, the output will be a signal with 5 points.

---------

Co-authored-by: Alexey Kamenev <akamenev@nvidia.com>
Co-authored-by: Nick Geneva <nicholasgeneva@gmail.com>
jchen351 pushed a commit that referenced this issue Aug 12, 2023