From 085401d7939eb380cb37bace58bc0758ff61d336 Mon Sep 17 00:00:00 2001 From: Tomoaki Kobayashi Date: Sat, 23 Aug 2025 20:16:05 +0900 Subject: [PATCH 1/5] Revert "[torchlib] Unregister stft, var, var_mean, std, std_mean (#1867)" This reverts commit 1eef63304555f4ce7686d9ed20657367b64ae323. --- .../function_libs/torch_lib/ops/core.py | 133 ++++++++++++++++++ .../function_libs/torch_lib/ops_test_data.py | 8 ++ 2 files changed, 141 insertions(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 36f2a70f8..01dcc2ee3 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8118,6 +8118,139 @@ def aten_std_mean_correction( return op.Sqrt(var), mean +@torch_op("aten::stft", private=True) +def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]: + signal_rank = Rank(self) + if signal_rank == 1: + # Add a batch dimension + self = op.Unsqueeze(self, op.Constant(value_ints=[0])) + return op.Identity(self), signal_rank + + +@torch_op("aten::stft", private=True) +def _center_window_around_zeros_if_needed( + window: TFloatOrBFloat16, n_fft: int +) -> TFloatOrBFloat16: + # first dimension + n_win = op.Shape(window, start=0, end=1) + # Center window around zeros if needed (required by ONNX's STFT) + if n_win < n_fft: + left = (n_fft - n_win) / 2 + + right = n_fft - left - n_win + left = op.Reshape(left, op.Constant(value_ints=[1])) + right = op.Reshape(right, op.Constant(value_ints=[1])) + + left_win = op.Expand(op.Constant(value_ints=[0]), left) + right_win = op.Expand(op.Constant(value_ints=[0]), right) + right_win = op.CastLike(right_win, window) + left_win = op.CastLike(left_win, window) + window = op.Concat(left_win, window, right_win, axis=0) + return window + + +@torch_op("aten::stft", private=True) +def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16: + left = (n_fft - win_length) / 2 + + right = n_fft - left - win_length + left = op.Reshape(left, op.Constant(value_ints=[1])) + right = op.Reshape(right, op.Constant(value_ints=[1])) + win_length = op.Reshape(win_length, op.Constant(value_ints=[1])) + + left_win = op.Expand(op.Constant(value_ints=[0]), left) + right_win = op.Expand(op.Constant(value_ints=[0]), right) + window_list = op.Expand(op.Constant(value_ints=[1]), win_length) + return op.Concat(left_win, window_list, right_win, axis=0) + + +@torch_op("aten::stft", private=True) +def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16: + n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) + window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor) + return window + + +@torch_op("aten::stft", private=True) +def _normalize_fft_result( + signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int +) -> TFloatOrBFloat16: + n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) + sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal)) + result = result / sqrt_nfft + return result + + +@torch_op("aten::stft", private=True) +def _aten_stft_onnx( + signal: TFloatOrBFloat16, + frame_step_const: INT64, + window: Union[TFloatOrBFloat16, INT64], + frame_length_const: INT64, + signal_rank: INT64, + onesided: int, +) -> TFloatOrBFloat16: + window = op.CastLike(window, signal) + result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided) + result = op.Transpose(result, perm=[0, 2, 1, 3]) + # Remove batch dimension, if needed + if signal_rank == 1: + result = op.Squeeze(result, op.Constant(value_ints=[0])) + return result + + +@torch_op("aten::stft", trace_only=True) +def aten_stft( + self: TFloatOrBFloat16, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[TFloatOrBFloat16] = None, + normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None, +) -> TFloatOrBFloat16: + """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor""" + + # NOTE: regarless of the value of return_complex, we always return a real representation. + del return_complex + + # Get STFT sizes + if hop_length is None: + # core dump + # hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4])) + hop_length = n_fft // 4 + frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1])) + frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1])) + + # Pre-process input if needed + self, signal_rank = _add_batch_dimension(self) + + # Get window and make sure it's the same size as `win_length` or `n_fft` + if window is not None and window.shape[0] is not None: + window = _center_window_around_zeros_if_needed(window, n_fft) + elif window is None: + if win_length is not None: + window = _create_window_from_win_length(win_length, n_fft) + else: + window = _create_window_from_n_fft(n_fft) + + if onesided is None or onesided: + onesided = 1 + else: + onesided = 0 + # remove batch dimension included + result = _aten_stft_onnx( + self, frame_step_const, window, frame_length_const, signal_rank, onesided + ) + + # Normalize, if needed + if normalized: + result = _normalize_fft_result(self, result, n_fft) + + return result + + @torch_op( ( "aten::sub.Tensor", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index b60fd8cf3..4ef7550b6 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1760,6 +1760,14 @@ def _where_input_wrangler( TorchLibOpInfo("ops.aten.scatter.value", core_ops.aten_scatter_value), TorchLibOpInfo("slice", core_ops.aten_slice), TorchLibOpInfo("slice", core_ops.aten_slice_complex, complex=True), + TorchLibOpInfo( + "ops.aten.stft", # Custom from extra_opinfo + core_ops.aten_stft, + tolerance={torch.float32: (3.7e-5, 1.8e-4)}, + ).xfail( + dtypes=(torch.float16,), + reason="RuntimeError: MKL FFT doesn't support tensors of type: Half", + ), TorchLibOpInfo( "sum", core_ops.aten_sum_dim_IntList, From 449e1fe2f93eb7d95f94dbf9f97fd6f80f358773 Mon Sep 17 00:00:00 2001 From: Tomoaki Kobayashi Date: Sat, 18 Oct 2025 03:57:05 +0900 Subject: [PATCH 2/5] Fix aten_stft --- .../function_libs/torch_lib/ops/core.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 01dcc2ee3..0810c4f74 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8119,8 +8119,8 @@ def aten_std_mean_correction( @torch_op("aten::stft", private=True) -def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]: - signal_rank = Rank(self) +def _add_batch_dimension(self: TFloat) -> Tuple[TFloat, INT64]: + signal_rank = op.Size(op.Shape(self)) if signal_rank == 1: # Add a batch dimension self = op.Unsqueeze(self, op.Constant(value_ints=[0])) @@ -8129,8 +8129,8 @@ def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT6 @torch_op("aten::stft", private=True) def _center_window_around_zeros_if_needed( - window: TFloatOrBFloat16, n_fft: int -) -> TFloatOrBFloat16: + window: TFloat, n_fft: int +) -> TFloat: # first dimension n_win = op.Shape(window, start=0, end=1) # Center window around zeros if needed (required by ONNX's STFT) @@ -8150,7 +8150,7 @@ def _center_window_around_zeros_if_needed( @torch_op("aten::stft", private=True) -def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16: +def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat: left = (n_fft - win_length) / 2 right = n_fft - left - win_length @@ -8165,7 +8165,7 @@ def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloa @torch_op("aten::stft", private=True) -def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16: +def _create_window_from_n_fft(n_fft: int) -> TFloat: n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor) return window @@ -8173,8 +8173,8 @@ def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16: @torch_op("aten::stft", private=True) def _normalize_fft_result( - signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int -) -> TFloatOrBFloat16: + signal: TFloat, result: TFloat, n_fft: int +) -> TFloat: n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal)) result = result / sqrt_nfft @@ -8183,13 +8183,13 @@ def _normalize_fft_result( @torch_op("aten::stft", private=True) def _aten_stft_onnx( - signal: TFloatOrBFloat16, + signal: TFloat, frame_step_const: INT64, - window: Union[TFloatOrBFloat16, INT64], + window: Union[TFloat, INT64], frame_length_const: INT64, signal_rank: INT64, onesided: int, -) -> TFloatOrBFloat16: +) -> TFloat: window = op.CastLike(window, signal) result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided) result = op.Transpose(result, perm=[0, 2, 1, 3]) @@ -8201,15 +8201,15 @@ def _aten_stft_onnx( @torch_op("aten::stft", trace_only=True) def aten_stft( - self: TFloatOrBFloat16, + self: TFloat, n_fft: int, hop_length: Optional[int] = None, win_length: Optional[int] = None, - window: Optional[TFloatOrBFloat16] = None, + window: Optional[TFloat] = None, normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None, -) -> TFloatOrBFloat16: +) -> TFloat: """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor""" # NOTE: regarless of the value of return_complex, we always return a real representation. From ae878c8c53d5da1b035c641d5b8c5ed0f4accaa9 Mon Sep 17 00:00:00 2001 From: Tomoaki Kobayashi Date: Tue, 21 Oct 2025 22:08:09 +0900 Subject: [PATCH 3/5] Fix typos --- onnxscript/function_libs/torch_lib/ops/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 0810c4f74..12189731b 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8212,13 +8212,13 @@ def aten_stft( ) -> TFloat: """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor""" - # NOTE: regarless of the value of return_complex, we always return a real representation. + # NOTE: regardless of the value of return_complex, we always return a real representation. del return_complex # Get STFT sizes if hop_length is None: # core dump - # hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4])) + # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4])) hop_length = n_fft // 4 frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1])) frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1])) From 839a5bb858e2b168df4e2e6b82130331b3ab63eb Mon Sep 17 00:00:00 2001 From: Tomoaki Kobayashi Date: Sat, 25 Oct 2025 15:00:23 +0900 Subject: [PATCH 4/5] Remove deprecated annotations --- onnxscript/function_libs/torch_lib/ops/core.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 12189731b..750184d12 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8118,7 +8118,6 @@ def aten_std_mean_correction( return op.Sqrt(var), mean -@torch_op("aten::stft", private=True) def _add_batch_dimension(self: TFloat) -> Tuple[TFloat, INT64]: signal_rank = op.Size(op.Shape(self)) if signal_rank == 1: @@ -8127,7 +8126,6 @@ def _add_batch_dimension(self: TFloat) -> Tuple[TFloat, INT64]: return op.Identity(self), signal_rank -@torch_op("aten::stft", private=True) def _center_window_around_zeros_if_needed( window: TFloat, n_fft: int ) -> TFloat: @@ -8149,7 +8147,6 @@ def _center_window_around_zeros_if_needed( return window -@torch_op("aten::stft", private=True) def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat: left = (n_fft - win_length) / 2 @@ -8164,14 +8161,12 @@ def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat: return op.Concat(left_win, window_list, right_win, axis=0) -@torch_op("aten::stft", private=True) def _create_window_from_n_fft(n_fft: int) -> TFloat: n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor) return window -@torch_op("aten::stft", private=True) def _normalize_fft_result( signal: TFloat, result: TFloat, n_fft: int ) -> TFloat: @@ -8181,7 +8176,6 @@ def _normalize_fft_result( return result -@torch_op("aten::stft", private=True) def _aten_stft_onnx( signal: TFloat, frame_step_const: INT64, From 3c12aae51096d1722e34132683493508e61f8047 Mon Sep 17 00:00:00 2001 From: Tomoaki Kobayashi Date: Sat, 25 Oct 2025 17:44:52 +0900 Subject: [PATCH 5/5] Use op --- .../function_libs/torch_lib/ops/core.py | 55 +++++++++---------- 1 file changed, 25 insertions(+), 30 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 750184d12..759d766b4 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8118,39 +8118,33 @@ def aten_std_mean_correction( return op.Sqrt(var), mean -def _add_batch_dimension(self: TFloat) -> Tuple[TFloat, INT64]: - signal_rank = op.Size(op.Shape(self)) - if signal_rank == 1: - # Add a batch dimension - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - return op.Identity(self), signal_rank - - def _center_window_around_zeros_if_needed( window: TFloat, n_fft: int ) -> TFloat: # first dimension n_win = op.Shape(window, start=0, end=1) + + left = op.Div(op.Sub(n_fft, n_win), op.Constant(value_ints=[2])) + + right = op.Sub(op.Sub(n_fft, left), n_win) + left = op.Reshape(left, op.Constant(value_ints=[1])) + right = op.Reshape(right, op.Constant(value_ints=[1])) + + left_win = op.Expand(op.Constant(value_ints=[0]), left) + right_win = op.Expand(op.Constant(value_ints=[0]), right) + right_win = op.CastLike(right_win, window) + left_win = op.CastLike(left_win, window) + window_padded = op.Concat(left_win, window, right_win, axis=0) + # Center window around zeros if needed (required by ONNX's STFT) - if n_win < n_fft: - left = (n_fft - n_win) / 2 - - right = n_fft - left - n_win - left = op.Reshape(left, op.Constant(value_ints=[1])) - right = op.Reshape(right, op.Constant(value_ints=[1])) - - left_win = op.Expand(op.Constant(value_ints=[0]), left) - right_win = op.Expand(op.Constant(value_ints=[0]), right) - right_win = op.CastLike(right_win, window) - left_win = op.CastLike(left_win, window) - window = op.Concat(left_win, window, right_win, axis=0) + window = op.Where(op.Less(op.Squeeze(n_win), n_fft), window_padded, window) return window def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat: - left = (n_fft - win_length) / 2 + left = op.Div(op.Sub(n_fft, win_length), op.Constant(value_ints=[2])) - right = n_fft - left - win_length + right = op.Sub(op.Sub(n_fft, left), win_length) left = op.Reshape(left, op.Constant(value_ints=[1])) right = op.Reshape(right, op.Constant(value_ints=[1])) win_length = op.Reshape(win_length, op.Constant(value_ints=[1])) @@ -8172,7 +8166,7 @@ def _normalize_fft_result( ) -> TFloat: n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal)) - result = result / sqrt_nfft + result = op.Div(result, sqrt_nfft) return result @@ -8181,15 +8175,11 @@ def _aten_stft_onnx( frame_step_const: INT64, window: Union[TFloat, INT64], frame_length_const: INT64, - signal_rank: INT64, onesided: int, ) -> TFloat: window = op.CastLike(window, signal) result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided) result = op.Transpose(result, perm=[0, 2, 1, 3]) - # Remove batch dimension, if needed - if signal_rank == 1: - result = op.Squeeze(result, op.Constant(value_ints=[0])) return result @@ -8218,7 +8208,10 @@ def aten_stft( frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1])) # Pre-process input if needed - self, signal_rank = _add_batch_dimension(self) + is_signal_rank1 = self.shape is not None and len(self.shape) == 1 + if is_signal_rank1: + # Add a batch dimension + self = op.Identity(op.Unsqueeze(self, op.Constant(value_ints=[0]))) # Get window and make sure it's the same size as `win_length` or `n_fft` if window is not None and window.shape[0] is not None: @@ -8233,10 +8226,12 @@ def aten_stft( onesided = 1 else: onesided = 0 - # remove batch dimension included result = _aten_stft_onnx( - self, frame_step_const, window, frame_length_const, signal_rank, onesided + self, frame_step_const, window, frame_length_const, onesided ) + # Remove batch dimension, if needed + if is_signal_rank1: + result = op.Squeeze(result, op.Constant(value_ints=[0])) # Normalize, if needed if normalized: