Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimise as_windowed and apply :reflect padding to the whole input #17

Merged
merged 3 commits into from
Aug 30, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 30 additions & 107 deletions lib/nx_signal.ex
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,10 @@ defmodule NxSignal do

* `:window_length` - the number of samples in a window
* `:stride` - The number of samples to skip between windows. Defaults to `1`.
* `:padding` - A can be `:reflect` or a valid padding as per `Nx.pad/3` over the
input tensor's shape. Defaults to `:valid`. If `:reflect` or `:zeros`, the first window will be centered
at the start of the signal. For `:reflect`, each incomplete window will be reflected as if it was
periodic (see examples for `as_windowed/2`). For `:zeros`, each incomplete window will be zero-padded.
* `:padding` - Padding mode, can be `:reflect` or a valid padding as per `Nx.pad/3` over the
input tensor's shape. Defaults to `:valid`. If `:reflect` or `:same`, the first window will be centered
at the start of the signal. The padding is applied for the whole input, rather than individual
windows. For `:zeros`, effectively each incomplete window will be zero-padded.

## Examples

Expand Down Expand Up @@ -219,27 +219,29 @@ defmodule NxSignal do
iex> t = Nx.iota({7});
iex> NxSignal.as_windowed(t, window_length: 6, padding: :reflect, stride: 1)
#Nx.Tensor<
s64[7][6]
s64[8][6]
[
[1, 2, 1, 0, 1, 2],
[3, 2, 1, 0, 1, 2],
[2, 1, 0, 1, 2, 3],
[1, 0, 1, 2, 3, 4],
[0, 1, 2, 3, 4, 5],
[1, 2, 3, 4, 5, 6],
[2, 3, 4, 5, 6, 5],
[3, 4, 5, 6, 5, 4]
[3, 4, 5, 6, 5, 4],
[4, 5, 6, 5, 4, 3]
]
>

iex> NxSignal.as_windowed(Nx.iota({10}), window_length: 6, padding: :reflect, stride: 2)
#Nx.Tensor<
s64[5][6]
s64[6][6]
[
[1, 2, 1, 0, 1, 2],
[3, 2, 1, 0, 1, 2],
[1, 0, 1, 2, 3, 4],
[1, 2, 3, 4, 5, 6],
[3, 4, 5, 6, 7, 8],
[5, 6, 7, 8, 9, 8]
[5, 6, 7, 8, 9, 8],
[7, 8, 9, 8, 7, 6]
]
>
"""
Expand All @@ -257,7 +259,7 @@ defmodule NxSignal do

as_windowed_parse_non_reflect_opts(
shape,
Keyword.put(opts, :padding, [{div(window_length, 2), div(window_length, 2) - 1}])
Keyword.put(opts, :padding, [{div(window_length, 2), div(window_length, 2)}])
Comment on lines -260 to +262
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@polvalente this changes some output shapes (the currently failing tests), but that's what torch.stft does. wdyt?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it just decrease the number of windows?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The opposite, it adds an extra window, where half of the window is padded.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's ok then

)
end

Expand Down Expand Up @@ -333,114 +335,34 @@ defmodule NxSignal do
{window_length, stride, padding, output_shape} =
as_windowed_parse_non_reflect_opts(Nx.shape(tensor), opts)

output = Nx.broadcast(Nx.tensor(0, type: tensor.type), output_shape)
{num_windows, _} = Nx.shape(output)

index_template =
Nx.concatenate([Nx.broadcast(0, {window_length, 1}), Nx.iota({window_length, 1})], axis: 1)

{output, _, _, _, _} =
while {output, i = 0, current_window = 0, t = Nx.pad(tensor, 0, padding), index_template},
current_window < num_windows do
indices = index_template + Nx.stack([current_window, 0])
updates = t |> Nx.slice([i], [window_length]) |> Nx.flatten()

updated = Nx.indexed_add(output, indices, updates)
tensor = Nx.pad(tensor, 0, padding)

{updated, i + stride, current_window + 1, t, index_template}
end

output
as_windowed_apply(tensor, stride, output_shape, window_length)
end

defnp as_windowed_reflect_padding(tensor, opts \\ []) do
# current implementation only supports windowing 1D tensors
{window_length, stride, _padding, output_shape} =
as_windowed_parse_reflect_opts(Nx.shape(tensor), opts)

output = Nx.broadcast(Nx.tensor(0, type: tensor.type), output_shape)
{num_windows, _} = Nx.shape(output)

index_template =
Nx.concatenate([Nx.broadcast(0, {window_length, 1}), Nx.iota({window_length, 1})], axis: 1)

leading_window_indices = generate_leading_window_indices(window_length, stride)

trailing_window_indices =
generate_trailing_window_indices(Nx.size(tensor), window_length, stride)

half_window = div(window_length - 1, 2) + 1

{output, _, _, _, _} =
while {output, i = 0, current_window = 0, t = tensor, index_template},
current_window < num_windows do
# Here windows are centered at the current index

cond do
i < half_window ->
# We're indexing before we have a full window on the left

window = Nx.take(t, leading_window_indices[i])

indices = index_template + Nx.stack([current_window, 0])
updated = Nx.indexed_add(output, indices, window)

{updated, i + stride, current_window + 1, t, index_template}

i > Nx.size(t) - half_window ->
# We're indexing after the last full window on the right
window = Nx.take(t, trailing_window_indices[i - (Nx.size(t) - half_window + 1)])

indices = index_template + Nx.stack([current_window, 0])
updated = Nx.indexed_add(output, indices, window)

{updated, i + stride, current_window + 1, t, index_template}

true ->
# Case where we can index a full window
indices = index_template + Nx.stack([current_window, 0])
updates = t |> Nx.slice([i - half_window], [window_length]) |> Nx.flatten()

updated = Nx.indexed_add(output, indices, updates)

{updated, i + stride, current_window + 1, t, index_template}
end
end

# Now we need to handle the tail-end of the windows,
# since they are currently all the same value. We want to apply the tapering-off
# like we did with the initial windows.

output
end

deftransformp generate_leading_window_indices(window_length, stride) do
half_window = div(window_length, 2)
tensor = Nx.reflect(tensor, padding_config: [{half_window, half_window}])

for offset <- 0..half_window//stride do
partial_length = offset + half_window
padding_length = window_length - partial_length

{partial_length}
|> Nx.iota()
|> Nx.reflect(padding_config: [{padding_length, 0}])
end
|> Nx.stack()
as_windowed_apply(tensor, stride, output_shape, window_length)
end

deftransformp generate_trailing_window_indices(tensor_size, window_length, stride) do
min_index = tensor_size - window_length + 1
defnp as_windowed_apply(tensor, stride, output_shape, window_length) do
output = Nx.broadcast(Nx.tensor(0, type: tensor.type), output_shape)
{num_windows, _} = Nx.shape(output)

for {offset, add} <- Enum.with_index(min_index..(tensor_size - 1)//stride) do
partial_length = tensor_size - offset
padding_length = window_length - partial_length
{output, _, _, _} =
while {output, i = 0, current_window = 0, t = tensor}, current_window < num_windows do
window = t |> Nx.slice([i], [window_length])
updated = Nx.put_slice(output, [current_window, 0], Nx.new_axis(window, 0))
{updated, i + stride, current_window + 1, t}
end

{partial_length}
|> Nx.iota()
|> Nx.add(min_index + add - rem(window_length, 2))
|> Nx.reflect(padding_config: [{0, padding_length}])
end
|> Nx.stack()
output
end

@doc """
Expand Down Expand Up @@ -548,15 +470,16 @@ defmodule NxSignal do
iex> Nx.axis_size(z, :frequencies)
16
iex> Nx.axis_size(z, :frames)
5
6
iex> NxSignal.stft_to_mel(z, sampling_rate, fft_length: fft_length, mel_bins: 4)
#Nx.Tensor<
f32[frames: 5][mel: 4]
f32[frames: 6][mel: 4]
[
[0.2900530695915222, 0.17422175407409668, 0.18422472476959229, 0.09807997941970825],
[0.6093881130218506, 0.5647397041320801, 0.4353824257850647, 0.08635270595550537],
[0.7584103345870972, 0.7085014581680298, 0.5636920928955078, 0.179118812084198],
[0.8461772203445435, 0.7952491044998169, 0.6470762491226196, 0.2520409822463989],
[0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721],
[0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721]
]
>
Expand Down
Loading