Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8118,6 +8118,128 @@ def aten_std_mean_correction(
return op.Sqrt(var), mean


def _center_window_around_zeros_if_needed(
    window: TFloat, n_fft: int
) -> TFloat:
    """Zero-pad `window` symmetrically to length `n_fft` when it is shorter.

    ONNX's STFT operator expects the window to span the whole frame, so a
    shorter window is centered inside a frame of `n_fft` zeros.
    """
    # Window length taken from the first dimension, kept as a 1-D tensor.
    win_size = op.Shape(window, start=0, end=1)

    # Split the padding between the sides: the left gets the floor of the
    # half, the right absorbs any remainder.
    pad_left = op.Div(op.Sub(n_fft, win_size), op.Constant(value_ints=[2]))
    pad_right = op.Sub(op.Sub(n_fft, pad_left), win_size)
    pad_left = op.Reshape(pad_left, op.Constant(value_ints=[1]))
    pad_right = op.Reshape(pad_right, op.Constant(value_ints=[1]))

    zeros_left = op.CastLike(op.Expand(op.Constant(value_ints=[0]), pad_left), window)
    zeros_right = op.CastLike(op.Expand(op.Constant(value_ints=[0]), pad_right), window)
    padded = op.Concat(zeros_left, window, zeros_right, axis=0)

    # Only substitute the padded window when the original is actually shorter
    # (required by ONNX's STFT).
    return op.Where(op.Less(op.Squeeze(win_size), n_fft), padded, window)


def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat:
    """Build a rectangular window of `win_length` ones centered in `n_fft` zeros."""
    # Left padding is the floor of half the slack; the right side takes the rest.
    pad_left = op.Div(op.Sub(n_fft, win_length), op.Constant(value_ints=[2]))
    pad_right = op.Sub(op.Sub(n_fft, pad_left), win_length)

    # Make every length a 1-D tensor so it can serve as an Expand shape.
    pad_left = op.Reshape(pad_left, op.Constant(value_ints=[1]))
    pad_right = op.Reshape(pad_right, op.Constant(value_ints=[1]))
    ones_len = op.Reshape(win_length, op.Constant(value_ints=[1]))

    zeros_left = op.Expand(op.Constant(value_ints=[0]), pad_left)
    zeros_right = op.Expand(op.Constant(value_ints=[0]), pad_right)
    ones = op.Expand(op.Constant(value_ints=[1]), ones_len)
    return op.Concat(zeros_left, ones, zeros_right, axis=0)


def _create_window_from_n_fft(n_fft: int) -> TFloat:
    """Build an all-ones rectangular window of length `n_fft`."""
    shape_1d = op.Reshape(n_fft, op.Constant(value_ints=[1]))
    return op.Expand(op.Constant(value_ints=[1]), shape_1d)


def _normalize_fft_result(
    signal: TFloat, result: TFloat, n_fft: int
) -> TFloat:
    """Scale `result` by 1/sqrt(n_fft), matching PyTorch's `normalized=True`."""
    nfft_1d = op.Reshape(n_fft, op.Constant(value_ints=[1]))
    # Cast to the signal dtype so the division happens in the right precision.
    scale = op.Sqrt(op.CastLike(nfft_1d, signal))
    return op.Div(result, scale)


def _aten_stft_onnx(
    signal: TFloat,
    frame_step_const: INT64,
    window: Union[TFloat, INT64],
    frame_length_const: INT64,
    onesided: int,
) -> TFloat:
    """Run ONNX STFT and permute the output axes for the caller."""
    # The window dtype must match the signal dtype for ONNX STFT.
    typed_window = op.CastLike(window, signal)
    spectrum = op.STFT(
        signal, frame_step_const, typed_window, frame_length_const, onesided=onesided
    )
    # Swap the frame and frequency axes of the (batch, frames, freq, 2) output.
    return op.Transpose(spectrum, perm=[0, 2, 1, 3])


@torch_op("aten::stft", trace_only=True)
def aten_stft(
    self: TFloat,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[TFloat] = None,
    normalized: bool = False,
    onesided: Optional[bool] = None,
    return_complex: Optional[bool] = None,
) -> TFloat:
    """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor

    Short-time Fourier transform built on ONNX's STFT operator. The result is
    always a real tensor with a trailing size-2 axis for the real/imaginary
    parts, regardless of `return_complex`.
    """

    # NOTE: regardless of the value of return_complex, we always return a real representation.
    del return_complex

    # Get STFT sizes
    if hop_length is None:
        # core dump
        # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
        # PyTorch's documented default: a quarter of the FFT size.
        hop_length = n_fft // 4
    # Reshape the scalar sizes into 1-D tensors as expected by op.STFT.
    frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
    frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))

    # Pre-process input if needed
    is_signal_rank1 = self.shape is not None and len(self.shape) == 1
    if is_signal_rank1:
        # Add a batch dimension
        self = op.Identity(op.Unsqueeze(self, op.Constant(value_ints=[0])))

    # Get window and make sure it's the same size as `win_length` or `n_fft`
    # NOTE(review): when a window is provided but its length is unknown at
    # trace time (window.shape[0] is None), it is passed through unpadded.
    if window is not None and window.shape[0] is not None:
        window = _center_window_around_zeros_if_needed(window, n_fft)
    elif window is None:
        if win_length is not None:
            # Rectangular window of `win_length` ones centered in `n_fft` zeros.
            window = _create_window_from_win_length(win_length, n_fft)
        else:
            # All-ones window spanning the full FFT size.
            window = _create_window_from_n_fft(n_fft)

    # Map the tri-state flag to the int attribute op.STFT expects;
    # None defaults to one-sided output.
    if onesided is None or onesided:
        onesided = 1
    else:
        onesided = 0
    result = _aten_stft_onnx(
        self, frame_step_const, window, frame_length_const, onesided
    )
    # Remove batch dimension, if needed
    if is_signal_rank1:
        result = op.Squeeze(result, op.Constant(value_ints=[0]))

    # Normalize, if needed
    if normalized:
        # Divide by sqrt(n_fft), matching PyTorch's `normalized=True`.
        result = _normalize_fft_result(self, result, n_fft)

    return result


@torch_op(
(
"aten::sub.Tensor",
Expand Down
8 changes: 8 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1760,6 +1760,14 @@ def _where_input_wrangler(
TorchLibOpInfo("ops.aten.scatter.value", core_ops.aten_scatter_value),
TorchLibOpInfo("slice", core_ops.aten_slice),
TorchLibOpInfo("slice", core_ops.aten_slice_complex, complex=True),
TorchLibOpInfo(
"ops.aten.stft", # Custom from extra_opinfo
core_ops.aten_stft,
tolerance={torch.float32: (3.7e-5, 1.8e-4)},
).xfail(
dtypes=(torch.float16,),
reason="RuntimeError: MKL FFT doesn't support tensors of type: Half",
),
TorchLibOpInfo(
"sum",
core_ops.aten_sum_dim_IntList,
Expand Down
Loading