Skip to content

Commit

Permalink
use vectorization, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
santiago-imelio committed May 17, 2024
1 parent a8b8acb commit cf79294
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 53 deletions.
64 changes: 11 additions & 53 deletions lib/nx_signal/filters.ex
Original file line number Diff line number Diff line change
Expand Up @@ -28,62 +28,20 @@ defmodule NxSignal.Filters do

deftransform median(_t, _opts), do: raise(ArgumentError, message: "tensor must be of rank 1 or 2")

defnp median_n(t, opts) do
kernel_shape = opts[:kernel_shape]
defn median_n(t, opts) do
{k0, k1} = opts[:kernel_shape]

{height, width} = Nx.shape(t)
idx =
Nx.stack([Nx.iota(t.shape, axis: 0), Nx.iota(t.shape, axis: 1)], axis: -1)
|> Nx.reshape({:auto, 2})
|> Nx.vectorize(:elements)

kernel_tensor = Nx.broadcast(0.0, kernel_shape)
output = Nx.broadcast(0.0, Nx.shape(t))

{result, _} =
while {output, {i = 0, t, kernel_tensor, height, width}}, i < height do
row = Nx.broadcast(0.0, {elem(Nx.shape(t), 1)})

{ith_row, _} =
while {row, {j = 0, t, i, kernel_tensor, width}}, j < width do
median =
window_median(t, i, j, kernel_tensor)
|> Nx.broadcast({1})

{Nx.put_slice(row, [j], median), {j + 1, t, i, kernel_tensor, width}}
end

{Nx.put_slice(output, [i, 0], Nx.stack(ith_row, axis: 0)),
{i + 1, t, kernel_tensor, height, width}}
end

result
end

defnp window_median(t, i, j, kernel_tensor) do
{k0, k1} = Nx.shape(kernel_tensor)

padding_y = Nx.round((k0 - 1) / 2)
padding_x = Nx.round((k1 - 1) / 2)

y_axis_start_idx =
if i - padding_y <= 0 do
0
else
i - padding_y
end
|> Nx.as_type({:u, 32})

x_axis_start_idx =
if j - padding_x <= 0 do
0
else
j - padding_x
end
|> Nx.as_type({:u, 32})

Nx.slice(
t,
[y_axis_start_idx, x_axis_start_idx],
[k0, k1]
)
t
|> Nx.slice([idx[0], idx[1]], [k0, k1])
|> Nx.median()
|> Nx.devectorize(keep_names: false)
|> Nx.reshape(t.shape)
|> Nx.as_type({:f, 32})
end

deftransformp validate_median_opts!(t, opts) do
Expand Down
28 changes: 28 additions & 0 deletions test/nx_signal/filters_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,34 @@ defmodule NxSignal.FiltersTest do
doctest NxSignal.Filters

describe "median/2" do
test "performs 1D median filter" do
t = Nx.tensor([10, 9, 8, 7, 1, 4, 5, 3, 2, 6])
opts = [kernel_shape: {3}]
expected = Nx.tensor([9.0, 8.0, 7.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0])

assert NxSignal.Filters.median(t, opts) == expected
end

test "performs 2D median filter" do
t =
Nx.tensor([
[31, 11, 17, 13, 1],
[1, 3, 19, 23, 29],
[19, 5, 7, 37, 2]
])

opts = [kernel_shape: {3, 3}]

expected =
Nx.tensor([
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0]
])

assert NxSignal.Filters.median(t, opts) == expected
end

test "raises if kernel_shape is not compatible" do
t1 = Nx.iota({10})
opts1 = [kernel_shape: {5, 5}]
Expand Down

0 comments on commit cf79294

Please sign in to comment.