Skip to content

Commit

Permalink
Add Median Filter 2D (#20)
Browse files Browse the repository at this point in the history
* bump nx to 0.7

* implement median filter 2D

* refactor median filter, move sinc to waveforms

* use deftransform, minor tweaks

* use vectorization, add tests

* run mix format
  • Loading branch information
santiago-imelio committed May 18, 2024
1 parent f6c83e8 commit d0b7df4
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 16 deletions.
55 changes: 41 additions & 14 deletions lib/nx_signal/filters.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,52 @@ defmodule NxSignal.Filters do
"""
import Nx.Defn

@pi :math.pi()

@doc ~S"""
Calculates the normalized sinc function $sinc(t) = \frac{sin(\pi t)}{\pi t}$
Performs a median filter on a rank 1 or rank 2 tensor.
## Examples
## Options
iex> NxSignal.Filters.sinc(Nx.tensor([0, 0.25, 1]))
#Nx.Tensor<
f32[3]
[1.0, 0.9003162980079651, -2.7827534054836178e-8]
>
* `:kernel_shape` - the shape of the sliding window.
It must be compatible with the shape of the tensor.
"""
@doc type: :filters
defn sinc(t) do
t = t * @pi
zero_idx = Nx.equal(t, 0)
deftransform median(t = %Nx.Tensor{shape: {length}}, opts) do
validate_median_opts!(t, opts)
{kernel_length} = opts[:kernel_shape]

median(Nx.reshape(t, {1, length}), kernel_shape: {1, kernel_length})
|> Nx.squeeze()
end

deftransform median(t = %Nx.Tensor{shape: {_h, _w}}, opts) do
validate_median_opts!(t, opts)
median_n(t, opts)
end

deftransform median(_t, _opts),
do: raise(ArgumentError, message: "tensor must be of rank 1 or 2")

defn median_n(t, opts) do
{k0, k1} = opts[:kernel_shape]

idx =
Nx.stack([Nx.iota(t.shape, axis: 0), Nx.iota(t.shape, axis: 1)], axis: -1)
|> Nx.reshape({:auto, 2})
|> Nx.vectorize(:elements)

t
|> Nx.slice([idx[0], idx[1]], [k0, k1])
|> Nx.median()
|> Nx.devectorize(keep_names: false)
|> Nx.reshape(t.shape)
|> Nx.as_type({:f, 32})
end

deftransformp validate_median_opts!(t, opts) do
Keyword.validate!(opts, [:kernel_shape])

# Define sinc(0) = 1
Nx.select(zero_idx, 1, Nx.sin(t) / t)
if Nx.rank(t) != Nx.rank(opts[:kernel_shape]) do
raise ArgumentError, message: "kernel shape must be of the same rank as the tensor"
end
end
end
20 changes: 20 additions & 0 deletions lib/nx_signal/waveforms.ex
Original file line number Diff line number Diff line change
Expand Up @@ -435,4 +435,24 @@ defmodule NxSignal.Waveforms do
Nx.reshape(index, {n})
end
end

@doc ~S"""
Calculates the normalized sinc function $sinc(t) = \frac{sin(\pi t)}{\pi t}$
## Examples
iex> NxSignal.Waveforms.sinc(Nx.tensor([0, 0.25, 1]))
#Nx.Tensor<
f32[3]
[1.0, 0.9003162980079651, -2.7827534054836178e-8]
>
"""
@doc type: :waveforms
defn sinc(t) do
t = t * pi()
zero_idx = Nx.equal(t, 0)

# Define sinc(0) = 1
Nx.select(zero_idx, 1, Nx.sin(t) / t)
end
end
2 changes: 1 addition & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ defmodule NxSignal.MixProject do
# Run "mix help deps" to learn about dependencies.
defp deps do
[
{:nx, "~> 0.6"},
{:nx, "~> 0.7"},
{:ex_doc, "~> 0.29", only: :docs}
]
end
Expand Down
2 changes: 1 addition & 1 deletion mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"},
"nimble_parsec": {:hex, :nimble_parsec, "1.2.3", "244836e6e3f1200c7f30cb56733fd808744eca61fd182f731eac4af635cc6d0b", [:mix], [], "hexpm", "c8d789e39b9131acf7b99291e93dae60ab48ef14a7ee9d58c6964f59efb570b0"},
"nx": {:hex, :nx, "0.6.0", "37c86eae824125a7e298dd1ee896953d9d671ce3630dcff74c77db17d734a85f", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e1ad3cc70a5828a1aedb156b71e90863d9623a2dc9b35a5588f8627a07ee6cb4"},
"nx": {:hex, :nx, "0.7.1", "5f6376e3d18408116e8a84b8f4ac851fb07dfe61764a5410ebf0b5dcb69c1b7e", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e3ddd6a3f2a9bac79c67b3933368c25bb5ec814a883fc68aba8fd8a236751777"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"xla": {:hex, :xla, "0.4.3", "cf6201aaa44d990298996156a83a16b9a87c5fbb257758dbf4c3e83c5e1c4b96", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "caae164b56dcaec6fbcabcd7dea14303afde07623b0cfa4a3cd2576b923105f5"},
}
70 changes: 70 additions & 0 deletions test/nx_signal/filters_test.exs
Original file line number Diff line number Diff line change
@@ -1,4 +1,74 @@
defmodule NxSignal.FiltersTest do
use NxSignal.Case
doctest NxSignal.Filters

describe "median/2" do
test "performs 1D median filter" do
t = Nx.tensor([10, 9, 8, 7, 1, 4, 5, 3, 2, 6])
opts = [kernel_shape: {3}]
expected = Nx.tensor([9.0, 8.0, 7.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0])

assert NxSignal.Filters.median(t, opts) == expected
end

test "performs 2D median filter" do
t =
Nx.tensor([
[31, 11, 17, 13, 1],
[1, 3, 19, 23, 29],
[19, 5, 7, 37, 2]
])

opts = [kernel_shape: {3, 3}]

expected =
Nx.tensor([
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0],
[11.0, 13.0, 17.0, 17.0, 17.0]
])

assert NxSignal.Filters.median(t, opts) == expected
end

test "raises if kernel_shape is not compatible" do
t1 = Nx.iota({10})
opts1 = [kernel_shape: {5, 5}]

assert_raise(
ArgumentError,
"kernel shape must be of the same rank as the tensor",
fn -> NxSignal.Filters.median(t1, opts1) end
)

t2 = Nx.iota({5, 5})
opts2 = [kernel_shape: {5, 5, 5}]

assert_raise(
ArgumentError,
"kernel shape must be of the same rank as the tensor",
fn -> NxSignal.Filters.median(t2, opts2) end
)
end

test "raises if tensor rank is not 1 or 2" do
t1 = Nx.tensor(1)
opts1 = [kernel_shape: {1}]

assert_raise(
ArgumentError,
"tensor must be of rank 1 or 2",
fn -> NxSignal.Filters.median(t1, opts1) end
)

t2 = Nx.iota({5, 5, 5})
opts2 = [kernel_shape: {3, 3, 3}]

assert_raise(
ArgumentError,
"tensor must be of rank 1 or 2",
fn -> NxSignal.Filters.median(t2, opts2) end
)
end
end
end

0 comments on commit d0b7df4

Please sign in to comment.