diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 36f2a70f8..759d766b4 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -8118,6 +8118,128 @@ def aten_std_mean_correction(
     return op.Sqrt(var), mean
 
 
+def _center_window_around_zeros_if_needed(
+    window: TFloat, n_fft: int
+) -> TFloat:
+    # first dimension
+    n_win = op.Shape(window, start=0, end=1)
+
+    left = op.Div(op.Sub(n_fft, n_win), op.Constant(value_ints=[2]))
+
+    right = op.Sub(op.Sub(n_fft, left), n_win)
+    left = op.Reshape(left, op.Constant(value_ints=[1]))
+    right = op.Reshape(right, op.Constant(value_ints=[1]))
+
+    left_win = op.Expand(op.Constant(value_ints=[0]), left)
+    right_win = op.Expand(op.Constant(value_ints=[0]), right)
+    right_win = op.CastLike(right_win, window)
+    left_win = op.CastLike(left_win, window)
+    window_padded = op.Concat(left_win, window, right_win, axis=0)
+
+    # Center window around zeros if needed (required by ONNX's STFT)
+    window = op.Where(op.Less(op.Squeeze(n_win), n_fft), window_padded, window)
+    return window
+
+
+def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat:
+    left = op.Div(op.Sub(n_fft, win_length), op.Constant(value_ints=[2]))
+
+    right = op.Sub(op.Sub(n_fft, left), win_length)
+    left = op.Reshape(left, op.Constant(value_ints=[1]))
+    right = op.Reshape(right, op.Constant(value_ints=[1]))
+    win_length = op.Reshape(win_length, op.Constant(value_ints=[1]))
+
+    left_win = op.Expand(op.Constant(value_ints=[0]), left)
+    right_win = op.Expand(op.Constant(value_ints=[0]), right)
+    window_list = op.Expand(op.Constant(value_ints=[1]), win_length)
+    return op.Concat(left_win, window_list, right_win, axis=0)
+
+
+def _create_window_from_n_fft(n_fft: int) -> TFloat:
+    n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
+    window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor)
+    return window
+
+
+def _normalize_fft_result(
+    signal: TFloat, result: TFloat, n_fft: int
+) -> TFloat:
+    n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
+    sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
+    result = op.Div(result, sqrt_nfft)
+    return result
+
+
+def _aten_stft_onnx(
+    signal: TFloat,
+    frame_step_const: INT64,
+    window: Union[TFloat, INT64],
+    frame_length_const: INT64,
+    onesided: int,
+) -> TFloat:
+    window = op.CastLike(window, signal)
+    result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided)
+    result = op.Transpose(result, perm=[0, 2, 1, 3])
+    return result
+
+
+@torch_op("aten::stft", trace_only=True)
+def aten_stft(
+    self: TFloat,
+    n_fft: int,
+    hop_length: Optional[int] = None,
+    win_length: Optional[int] = None,
+    window: Optional[TFloat] = None,
+    normalized: bool = False,
+    onesided: Optional[bool] = None,
+    return_complex: Optional[bool] = None,
+) -> TFloat:
+    """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
+
+    # NOTE: Regardless of the value of return_complex, we always return a real representation.
+    del return_complex
+
+    # Get STFT sizes
+    if hop_length is None:
+        # Computing hop_length with ONNX ops causes a core dump, so do it in Python:
+        # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
+        hop_length = n_fft // 4
+    frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
+    frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))
+
+    # Pre-process input if needed
+    is_signal_rank1 = self.shape is not None and len(self.shape) == 1
+    if is_signal_rank1:
+        # Add a batch dimension
+        self = op.Identity(op.Unsqueeze(self, op.Constant(value_ints=[0])))
+
+    # Get window and make sure it's the same size as `win_length` or `n_fft`
+    if window is not None and window.shape[0] is not None:
+        window = _center_window_around_zeros_if_needed(window, n_fft)
+    elif window is None:
+        if win_length is not None:
+            window = _create_window_from_win_length(win_length, n_fft)
+        else:
+            window = _create_window_from_n_fft(n_fft)
+
+    if onesided is None or onesided:
+        onesided = 1
+    else:
+        onesided = 0
+    result = _aten_stft_onnx(
+        self, frame_step_const, window, frame_length_const, onesided
+    )
+    # Remove batch dimension, if needed
+    if is_signal_rank1:
+        result = op.Squeeze(result, op.Constant(value_ints=[0]))
+
+    # Normalize, if needed
+    if normalized:
+        result = _normalize_fft_result(self, result, n_fft)
+
+    return result
+
+
 @torch_op(
     (
         "aten::sub.Tensor",
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index b60fd8cf3..4ef7550b6 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1760,6 +1760,14 @@ def _where_input_wrangler(
     TorchLibOpInfo("ops.aten.scatter.value", core_ops.aten_scatter_value),
     TorchLibOpInfo("slice", core_ops.aten_slice),
     TorchLibOpInfo("slice", core_ops.aten_slice_complex, complex=True),
+    TorchLibOpInfo(
+        "ops.aten.stft",  # Custom from extra_opinfo
+        core_ops.aten_stft,
+        tolerance={torch.float32: (3.7e-5, 1.8e-4)},
+    ).xfail(
+        dtypes=(torch.float16,),
+        reason="RuntimeError: MKL FFT doesn't support tensors of type: Half",
+    ),
     TorchLibOpInfo(
         "sum",
         core_ops.aten_sum_dim_IntList,
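For reference on the output layout: ONNX's `STFT` emits `[batch, n_frames, fft_bins, 2]`, which is why `_aten_stft_onnx` transposes with `perm=[0, 2, 1, 3]` to reach PyTorch's real-view layout `[batch, n_fft // 2 + 1, n_frames, 2]`. Below is a minimal sketch of that reference layout using only `torch`; it is illustrative and not part of the change. `center=False` is used because `torch.stft` applies its `center` padding itself before dispatching to `aten::stft`, which is the op implemented here.

```python
import torch

signal = torch.randn(2, 64)  # [batch, time]

# torch.stft's real view is the layout aten_stft must reproduce:
# [batch, n_fft // 2 + 1, n_frames, 2]
# An explicit all-ones window mirrors _create_window_from_n_fft above.
expected = torch.view_as_real(
    torch.stft(
        signal,
        n_fft=16,
        hop_length=4,
        window=torch.ones(16),
        center=False,
        return_complex=True,
    )
)
print(expected.shape)  # torch.Size([2, 9, 13, 2]): 16//2+1 bins, (64-16)//4+1 frames
```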