170 changes: 17 additions & 153 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -3974,8 +3974,6 @@ def aten_hspmm(mat1: TensorType, mat2: TensorType) -> TensorType:


# Do not register hstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918


def aten_hstack(tensors: Sequence[TTensor]) -> TTensor:
"""hstack(Tensor[] tensors) -> Tensor"""

@@ -7887,14 +7885,14 @@ def aten_stack(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrString:
return op.ConcatFromSequence(tensors, axis=dim, new_axis=1)
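For context (not part of this diff): `ConcatFromSequence` with `new_axis=1` gives `aten_stack` the semantics of stacking along a fresh axis at `dim`. A minimal NumPy sketch of the expected behaviour, with hypothetical sample inputs:

import numpy as np

tensors = [np.zeros((2, 3)), np.ones((2, 3))]
stacked = np.stack(tensors, axis=0)  # mirrors new_axis=1: a fresh axis at `dim`
assert stacked.shape == (2, 2, 3)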


@torch_op("aten::std", trace_only=True)
# std is decomposed by PyTorch
def aten_std(self: TReal, unbiased: bool = True) -> TReal:
"""std(Tensor self, bool unbiased=True) -> Tensor"""
var = _aten_var_onnx(self, correction=float(unbiased), keepdim=False)
return op.Sqrt(var)
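A reference check of this decomposition (illustrative only, assuming NumPy): `correction = float(unbiased)` is Bessel's correction, so the unbiased result should match `std` with `ddof=1`.

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
unbiased = True
var = np.sum((x - x.mean()) ** 2) / (x.size - float(unbiased))  # correction = 1.0
assert np.isclose(np.sqrt(var), x.std(ddof=1))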


@torch_op("aten::std.dim", trace_only=True)
# std_dim is decomposed by PyTorch
def aten_std_dim(
self: TReal,
dim: Sequence[int],
@@ -7907,7 +7905,7 @@ def aten_std_dim(
return op.Sqrt(var)


@torch_op("aten::var.correction", trace_only=True)
# std is decomposed by PyTorch
def aten_std_correction(
self: TReal,
# FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -7927,7 +7925,7 @@ def aten_std_correction(
return op.Sqrt(var)


@torch_op("aten::std_mean", trace_only=True)
# std_mean is decomposed by PyTorch
def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
"""std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""

@@ -7937,7 +7935,7 @@ def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
return op.Sqrt(var), mean


@torch_op("aten::std_mean.dim", trace_only=True)
# std_mean is decomposed by PyTorch
def aten_std_mean_dim(
self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False
) -> Tuple[TReal, TReal]:
@@ -7951,7 +7949,7 @@ def aten_std_mean_dim(
return op.Sqrt(var), mean


@torch_op("aten::std_mean.correction", trace_only=True)
# std_mean is decomposed by PyTorch
def aten_std_mean_correction(
self: TReal,
# FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -7973,139 +7971,6 @@ def aten_std_mean_correction(
return op.Sqrt(var), mean


@torch_op("aten::stft", private=True)
def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]:
signal_rank = Rank(self)
if signal_rank == 1:
# Add a batch dimension
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
return op.Identity(self), signal_rank


@torch_op("aten::stft", private=True)
def _center_window_around_zeros_if_needed(
window: TFloatOrBFloat16, n_fft: int
) -> TFloatOrBFloat16:
# Window length: size of the first dimension
n_win = op.Shape(window, start=0, end=1)
# Center window around zeros if needed (required by ONNX's STFT)
if n_win < n_fft:
left = (n_fft - n_win) / 2

right = n_fft - left - n_win
left = op.Reshape(left, op.Constant(value_ints=[1]))
right = op.Reshape(right, op.Constant(value_ints=[1]))

left_win = op.Expand(op.Constant(value_ints=[0]), left)
right_win = op.Expand(op.Constant(value_ints=[0]), right)
right_win = op.CastLike(right_win, window)
left_win = op.CastLike(left_win, window)
window = op.Concat(left_win, window, right_win, axis=0)
return window
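The padding arithmetic above can be checked outside ONNX. A hypothetical NumPy transcription of the same centering (zero-pad to `n_fft`, with the odd remainder going to the right):

import numpy as np

def center_window(window, n_fft):
    n_win = window.shape[0]
    if n_win >= n_fft:
        return window
    left = (n_fft - n_win) // 2   # integer division, as in the INT64 graph ops
    right = n_fft - left - n_win  # absorbs the odd remainder
    left_pad = np.zeros(left, dtype=window.dtype)   # CastLike in the graph matches dtypes
    right_pad = np.zeros(right, dtype=window.dtype)
    return np.concatenate([left_pad, window, right_pad])

assert center_window(np.hanning(5), 8).shape == (8,)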


@torch_op("aten::stft", private=True)
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16:
left = (n_fft - win_length) / 2

right = n_fft - left - win_length
left = op.Reshape(left, op.Constant(value_ints=[1]))
right = op.Reshape(right, op.Constant(value_ints=[1]))
win_length = op.Reshape(win_length, op.Constant(value_ints=[1]))

left_win = op.Expand(op.Constant(value_ints=[0]), left)
right_win = op.Expand(op.Constant(value_ints=[0]), right)
window_list = op.Expand(op.Constant(value_ints=[1]), win_length)
return op.Concat(left_win, window_list, right_win, axis=0)


@torch_op("aten::stft", private=True)
def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16:
n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor)
return window
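Equivalently (illustrative NumPy, not part of the change): when no window tensor is supplied, these helpers build a rectangular window of `win_length` ones centered in `n_fft`, or all ones when only `n_fft` is known.

import numpy as np

def rect_window(n_fft, win_length=None):
    if win_length is None:
        return np.ones(n_fft)
    left = (n_fft - win_length) // 2
    right = n_fft - left - win_length
    return np.concatenate([np.zeros(left), np.ones(win_length), np.zeros(right)])

assert rect_window(8, 4).tolist() == [0, 0, 1, 1, 1, 1, 0, 0]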


@torch_op("aten::stft", private=True)
def _normalize_fft_result(
signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int
) -> TFloatOrBFloat16:
n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
result = result / sqrt_nfft
return result


@torch_op("aten::stft", private=True)
def _aten_stft_onnx(
signal: TFloatOrBFloat16,
frame_step_const: INT64,
window: Union[TFloatOrBFloat16, INT64],
frame_length_const: INT64,
signal_rank: INT64,
onesided: int,
) -> TFloatOrBFloat16:
window = op.CastLike(window, signal)
result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided)
result = op.Transpose(result, perm=[0, 2, 1, 3])
# Remove batch dimension, if needed
if signal_rank == 1:
result = op.Squeeze(result, op.Constant(value_ints=[0]))
return result


@torch_op("aten::stft", trace_only=True)
def aten_stft(
self: TFloatOrBFloat16,
n_fft: int,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
window: Optional[TFloatOrBFloat16] = None,
normalized: bool = False,
onesided: Optional[bool] = None,
return_complex: Optional[bool] = None,
) -> TFloatOrBFloat16:
"""stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""

# NOTE: regardless of the value of return_complex, we always return a real representation.
del return_complex

# Get STFT sizes
if hop_length is None:
# Computing hop_length in the graph causes a core dump, so use Python arithmetic:
# hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
hop_length = n_fft // 4
frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))

# Pre-process input if needed
self, signal_rank = _add_batch_dimension(self)

# Get window and make sure it's the same size as `win_length` or `n_fft`
if window is not None and window.shape[0] is not None:
window = _center_window_around_zeros_if_needed(window, n_fft)
elif window is None:
if win_length is not None:
window = _create_window_from_win_length(win_length, n_fft)
else:
window = _create_window_from_n_fft(n_fft)

if onesided is None or onesided:
onesided = 1
else:
onesided = 0
# The helper squeezes out the batch dimension added above, if any
result = _aten_stft_onnx(
self, frame_step_const, window, frame_length_const, signal_rank, onesided
)

# Normalize, if needed
if normalized:
result = _normalize_fft_result(self, result, n_fft)

return result
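A hedged eager-mode reference (assuming PyTorch; the values here are illustrative): viewed as real pairs, `torch.stft` output has the same layout the graph above produces after the transpose and squeeze.

import torch

x = torch.randn(1000)
# center=False mirrors the framing of ONNX STFT; torch defaults to center=True.
spec = torch.stft(
    x, n_fft=400, hop_length=100, window=torch.hann_window(400),
    center=False, return_complex=True,
)
real_view = torch.view_as_real(spec)  # shape: (n_fft // 2 + 1, n_frames, 2)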


@torch_op(
(
"aten::sub.Tensor",
@@ -8738,7 +8603,7 @@ def aten_vander(
raise NotImplementedError()


@torch_op("aten::var", trace_only=True)
# var is decomposed by PyTorch
def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal:
"""var(Tensor self, bool unbiased=True) -> Tensor"""

@@ -8747,7 +8612,7 @@ def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal:
return _aten_var_onnx(self, correction=float(unbiased), keepdim=False)


@torch_op("aten::var.dim", trace_only=True)
# var is decomposed by PyTorch
def aten_var_dim(
self: TReal,
dim: Sequence[int],
@@ -8759,7 +8624,7 @@ def aten_var_dim(
return _aten_var_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim)


@torch_op("aten::var.correction", trace_only=True)
# var is decomposed by PyTorch
def aten_var_correction(
self: TReal,
# FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -8779,7 +8644,7 @@ def aten_var_correction(
return var


@torch_op("aten::var", private=True, traceable=True)
# var is decomposed by PyTorch
def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TReal:
mean = op.ReduceMean(self, keepdims=keepdim)
sub_mean = op.Sub(self, mean)
@@ -8796,7 +8661,7 @@ def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TReal:
return var


@torch_op("aten::var.dim", private=True, traceable=True)
# var is decomposed by PyTorch
def _aten_var_dim_onnx(
self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False
) -> TReal:
@@ -8817,7 +8682,7 @@ def _aten_var_dim_onnx(
return var
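For reference, a hypothetical NumPy transcription of the dim-wise variance with correction (same ReduceMean / Sub / ReduceSum structure as the helper):

import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
dims, correction, keepdim = (1, 2), 1.0, False
mean = x.mean(axis=dims, keepdims=True)            # keepdims for broadcasting
n = np.prod([x.shape[d] for d in dims])            # element count over reduced dims
var = np.sum((x - mean) ** 2, axis=dims, keepdims=keepdim) / (n - correction)
assert np.allclose(var, x.var(axis=dims, ddof=1))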


@torch_op("aten::var_mean", trace_only=True)
# var_mean is decomposed by PyTorch
def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
"""var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""

@@ -8826,7 +8691,7 @@ def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
return _aten_var_mean_onnx(self, correction=float(unbiased), keepdim=False)


@torch_op("aten::var_mean.dim", trace_only=True)
# var_mean is decomposed by PyTorch
def aten_var_mean_dim(
self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False
) -> Tuple[TReal, TReal]:
@@ -8837,7 +8702,7 @@ def aten_var_mean_dim(
return _aten_var_mean_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim)


@torch_op("aten::var_mean.correction", trace_only=True)
# var_mean is decomposed by PyTorch
def aten_var_mean_correction(
self: TReal,
# FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -8859,7 +8724,7 @@ def aten_var_mean_correction(
return var, mean


@torch_op("aten::var_mean", private=True)
# var_mean is decomposed by PyTorch
def _aten_var_mean_onnx(
self: TReal, correction: float = 1.0, keepdim: bool = False
) -> Tuple[TReal, TReal]:
@@ -8879,7 +8744,7 @@ def _aten_var_mean_onnx(
return var, mean


@torch_op("aten::var_mean.dim", private=True)
# var_mean is decomposed by PyTorch
def _aten_var_mean_dim_onnx(
self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False
) -> Tuple[TReal, TReal]:
@@ -8977,8 +8842,6 @@ def aten_view_copy(self: TTensor, size: IntType) -> TTensor:


# Do not register vstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918


def aten_vstack(tensors: Sequence[TTensor]) -> TTensor:
"""vstack(Tensor[] tensors) -> Tensor"""

@@ -8998,6 +8861,7 @@ def reshape_to_2d(tensor):

@torch_op(
(
"aten::where",
"aten::where.Scalar",
"aten::where.ScalarSelf",
"aten::where.ScalarOther",