Implement aten.stft #2645
Conversation
@microsoft-github-policy-service agree
pytorch/pytorch#147052 (comment)
I think this PR's implementation already uses `result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided)`, so I don't think a fallback to … is needed.
Hi! pytorch/pytorch#147052 (comment)
Could you elaborate a bit on this comment at your convenience? (If you think this simplification is necessary.)
Pull Request Overview
This PR implements the aten::stft (Short-Time Fourier Transform) operator to resolve issue #147052. The implementation includes handling for various optional parameters like hop_length, win_length, window, normalized, onesided, and return_complex.
Key changes:
- Added STFT operator implementation with helper functions for batch dimension handling, window centering, and FFT normalization
- Registered the operator in test data with appropriate tolerance settings and xfail for float16 dtype
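For reference, this is roughly how those optional parameters behave in plain PyTorch; a minimal sketch of the semantics the torch_lib op has to reproduce (my own illustration, not code from this PR):

```python
import torch

signal = torch.randn(2, 1024)           # (batch, time)
n_fft = 400
hop_length = n_fft // 4                 # torch.stft default when hop_length is None
win_length = n_fft                      # torch.stft default when win_length is None
window = torch.hann_window(win_length)  # None means a rectangular (all-ones) window

spec = torch.stft(
    signal,
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    window=window,
    normalized=False,
    onesided=True,        # real input defaults to a one-sided spectrum
    return_complex=True,
)
print(spec.shape)  # (batch, n_fft // 2 + 1, n_frames)
```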
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| onnxscript/function_libs/torch_lib/ops/core.py | Implements aten_stft and five helper functions for STFT processing |
| tests/function_libs/torch_lib/ops_test_data.py | Registers the new operator in test suite with tolerance and xfail configuration |
Codecov Report
❌ Patch coverage is …

Additional details and impacted files (coverage diff, main vs. #2645):

| | main | #2645 | +/- |
|---|---|---|---|
| Coverage | 70.46% | 70.47% | +0.01% |
| Files | 224 | 224 | |
| Lines | 26572 | 26634 | +62 |
| Branches | 2637 | 2645 | +8 |
| Hits | 18723 | 18770 | +47 |
| Misses | 6928 | 6940 | +12 |
| Partials | 921 | 924 | +3 |

☔ View full report in Codecov by Sentry.
    if signal_rank == 1:
        # Add a batch dimension
        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
    return op.Identity(self), signal_rank
@justinchuby Is the Identity op necessary to return self?
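For intuition, the batch-dimension bookkeeping this helper performs, written in plain PyTorch (a sketch only; `op.Identity` has no counterpart here and presumably just gives the traced graph an explicit output node):

```python
import torch

def add_batch_dimension(signal: torch.Tensor) -> tuple[torch.Tensor, int]:
    # Promote a 1-D signal to (1, time) and remember the original rank so the
    # extra batch dimension can be squeezed away again after the STFT.
    rank = signal.dim()
    if rank == 1:
        signal = signal.unsqueeze(0)
    return signal, rank

batched, rank = add_batch_dimension(torch.randn(1024))
print(batched.shape, rank)  # torch.Size([1, 1024]) 1
```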
    ) -> TFloat:
        n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
        sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
        result = result / sqrt_nfft
Use `op` for this kind of calculation when you remove the private flag.
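For context, if I read the ATen semantics correctly, `normalized=True` divides the unnormalized STFT by `sqrt(n_fft)`, which is what the snippet above computes; a quick plain-PyTorch check (my own sketch, not part of the PR):

```python
import torch

x = torch.randn(2, 512)
n_fft = 128
win = torch.hann_window(n_fft)

norm = torch.stft(x, n_fft, window=win, normalized=True, return_complex=True)
raw = torch.stft(x, n_fft, window=win, normalized=False, return_complex=True)

# normalized=True should match the unnormalized result scaled by 1/sqrt(n_fft)
assert torch.allclose(norm, raw / n_fft**0.5, atol=1e-6)
```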
    def _add_batch_dimension(self: TFloat) -> Tuple[TFloat, INT64]:
        signal_rank = op.Size(op.Shape(self))
        if signal_rank == 1:
I am still investigating, but I don't know how to handle conditionals with op right now (op.If and op.While don't work well...).
For example, in the following expression op.Equal does not evaluate to true:
    self = op.Where(
        op.Equal(signal_rank, op.Constant(value_int=1)),
        op.Unsqueeze(self, op.Constant(value_ints=[0])),
        self
    )
> I am still investigating, but I don't know how to handle conditionals with op right now (op.If and op.While don't work well...).
I resolved this by moving the conditional parts into aten_stft, except for _center_window_around_zeros_if_needed.
Perhaps we should add a test that checks `window = op.Where(op.Less(op.Squeeze(n_win), n_fft), window_padded, window)`.
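A rough sketch of what such a test could assert, written against plain torch.stft rather than the ONNX ops (assuming torch.stft centers a window shorter than n_fft by zero-padding it on both sides):

```python
import torch
import torch.nn.functional as F

n_fft, win_length, hop = 16, 10, 4
x = torch.randn(1, 64)
win = torch.hann_window(win_length)

# Emulate the centering: pad the short window with zeros up to n_fft
left = (n_fft - win_length) // 2
padded = F.pad(win, (left, n_fft - win_length - left))

a = torch.stft(x, n_fft, hop_length=hop, win_length=win_length, window=win,
               return_complex=True)
b = torch.stft(x, n_fft, hop_length=hop, window=padded, return_complex=True)
assert torch.allclose(a, b, atol=1e-6)
```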
    # first dimension
    n_win = op.Shape(window, start=0, end=1)
    # Center window around zeros if needed (required by ONNX's STFT)
    if n_win < n_fft:
@justinchuby Is there a good way to trace this?
Fixed pytorch/pytorch#147052