Add Median Filter 2D #20

Merged: 6 commits, May 18, 2024

santiago-imelio
Contributor

I was writing this notebook to learn Nx.Defn and thought that with some improvements it could potentially belong here :)

Would love to get feedback on how to improve my implementation. Once I have some feedback I can continue with documentation and tests.

Here is the result of applying it to a noisy image:

Original

[Image: Screenshot 2024-05-08 at 23 29 22]

Median filter applied (kernel_size = 5)

[Image: Screenshot 2024-05-08 at 23 29 51]

lib/nx_signal/filters.ex
Comment on lines 34 to 35
{_, _, _, _, _, result} =
while {i = 0, t, kernel_tensor, height, width, output}, i < height do
Collaborator

Suggested change
- {_, _, _, _, _, result} =
-   while {i = 0, t, kernel_tensor, height, width, output}, i < height do
+ {result, _} =
+   while {output, {i = 0, t, kernel_tensor, height, width}}, i < height do

Refactor the loops to follow this pattern: it's easier to match on and brings the output to the front.
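
As a minimal sketch of this loop shape (the module `LoopPatternSketch` and the function `double_rows` are hypothetical and not part of the PR), the accumulator goes first and the rest of the loop state is grouped into a nested tuple, so the result can be matched as `{result, _}`:

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule LoopPatternSketch do
  import Nx.Defn

  # Hypothetical example: doubles `t` one row at a time to show the loop shape.
  defn double_rows(t) do
    # Accumulator with the same shape and type as the input.
    output = t * 0

    # Output first, remaining loop state grouped in a nested tuple.
    {result, _} =
      while {output, {i = 0, t}}, i < Nx.axis_size(t, 0) do
        row = Nx.new_axis(t[i] * 2, 0)
        {Nx.put_slice(output, [i, 0], row), {i + 1, t}}
      end

    result
  end
end

IO.inspect(LoopPatternSketch.double_rows(Nx.tensor([[1, 2], [3, 4]])))
```

Adding another piece of loop state later only changes the inner tuple, while the `{result, _}` match stays the same.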

Collaborator

Actually, here's an idea for a refactor using vectorization.

Let's use hardcoded values for simplicity:

window_shape = {3, 3}
tensor = Nx.iota({5, 5})

# First, let's build a base indices tensor. These can be used with gather to access each element in the tensor.
idx =
  Nx.stack([Nx.iota(tensor.shape, axis: 0), Nx.iota(tensor.shape, axis: 1)], axis: -1)
  |> Nx.reshape({:auto, 2})
  |> Nx.vectorize(:elements)

# then, we can build an offset mask for each window:

padding_x = div(elem(window_shape, 0), 2)
padding_y = div(elem(window_shape, 1), 2)
window_offsets = Nx.stack([
  Nx.iota(window_shape, axis: 0) |> Nx.subtract(padding_x),
  Nx.iota(window_shape, axis: 1)  |> Nx.subtract(padding_y)
], axis: -1) |> Nx.reshape({:auto, 2}) |> Nx.vectorize(:offsets)

# because we're vectorizing with different names, when we operate on the tensors, we'll get
# a cross-operation between each element of one with all of the others. This can be used
# to build each median window.

idx = Nx.add(idx, window_offsets) |> Nx.add(Nx.tensor([padding_x, padding_y]))

tensor
|> Nx.pad(0, [{padding_x, padding_x, 0}, {padding_y, padding_y, 0}])
|> Nx.gather(idx) # the tensor here is now vectorized {elements: 25, offsets: 9} with scalar elements
|> Nx.devectorize(keep_names: false)
|> Nx.median(axis: 1)
|> Nx.reshape(tensor.shape)

I haven't cross-checked the results, so this might be wrong.
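
One way to cross-check the snippet above is a deliberately naive reference written with plain lists. The sketch below is an assumption about what "correct" means here (zero padding, odd square kernel); `NaiveMedianSketch` is a hypothetical helper, not code from the PR:

```elixir
defmodule NaiveMedianSketch do
  # Zero-padded k x k median filter over a list of rows (k must be odd).
  # Plain lists only: slow, but easy to check by hand.
  def filter(rows, k) when is_integer(k) and rem(k, 2) == 1 do
    pad = div(k, 2)
    height = length(rows)
    width = length(hd(rows))

    for i <- 0..(height - 1) do
      for j <- 0..(width - 1) do
        window =
          for di <- -pad..pad, dj <- -pad..pad do
            value_at(rows, i + di, j + dj, height, width)
          end

        # k * k is odd, so the median is the middle element of the sorted window.
        window |> Enum.sort() |> Enum.at(div(k * k, 2))
      end
    end
  end

  # Out-of-bounds positions read as 0, mirroring the zero padding above.
  defp value_at(rows, i, j, height, width)
       when i >= 0 and i < height and j >= 0 and j < width do
    rows |> Enum.at(i) |> Enum.at(j)
  end

  defp value_at(_rows, _i, _j, _height, _width), do: 0
end

# Same 5x5 input as Nx.iota({5, 5}), written as nested lists.
rows = Enum.chunk_every(Enum.to_list(0..24), 5)
IO.inspect(NaiveMedianSketch.filter(rows, 3))
```

If the vectorized pipeline is bound to a variable, converting it to nested lists (for instance with `Nx.to_list/1`) and comparing against this output is one possible check.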

Collaborator

@santiago-imelio did you get around to trying the vectorized version out?

Collaborator

Instead of Nx.pad, the boundaries could be handled with other functions (e.g. Nx.reflect or custom implementations) to support a mode option. I wouldn't worry about this in this PR, though.
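
A rough sketch of what swapping the padding step might look like, assuming `Nx.reflect/2` takes a `:padding_config` list of `{pre, post}` pairs per axis (my reading of its documentation; worth verifying against the Nx version in use). The variable names are illustrative only:

```elixir
Mix.install([{:nx, "~> 0.7"}])

tensor = Nx.iota({5, 5})
pad = 1

# Zero padding, as in the vectorized snippet earlier in this thread.
zero_padded = Nx.pad(tensor, 0, [{pad, pad, 0}, {pad, pad, 0}])

# Reflection padding: border values mirror the interior instead of being 0.
# Assumption: :padding_config is a list of {pre, post} tuples, one per axis.
reflected = Nx.reflect(tensor, padding_config: [{pad, pad}, {pad, pad}])

IO.inspect(Nx.shape(zero_padded))
IO.inspect(Nx.shape(reflected))
```

Both padded tensors have the same shape, so the gather step that follows the padding would not need to change.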

Contributor Author

I haven't yet. First I'd like to understand the concept of vectorization and play with it, so I need some time for that. Also, I'm somewhat of a noob to scientific programming, so bear with me haha.

Collaborator

No worries. I think there is a guide on vectorization in the main Nx repo.

Contributor Author

@polvalente I was able to use vectorization, although it's somewhat different from your suggestion. It seems to work fine on the same examples (with minor discrepancies). Let me know if the approach is correct.

Contributor Author

santiago-imelio left a comment

I think the approach I'm using to get the window with Nx.slice will cause problems if we want to implement different strategies for handling boundaries, like SciPy does:

https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.median_filter.html#scipy.ndimage.median_filter
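
To illustrate the boundary concern: as far as I understand, `Nx.slice` clamps start indices so that the requested window always fits inside the tensor, which means edge windows are shifted inward rather than padded. A small sketch under that assumption (variable names are illustrative):

```elixir
Mix.install([{:nx, "~> 0.7"}])

t = Nx.iota({5, 5})

# A 3x3 window starting at the bottom-right element would run past the edge.
# With tensor (dynamic) start indices, Nx.slice is expected to shift the start
# back so the window fits, rather than padding or raising.
corner = Nx.slice(t, [Nx.tensor(4), Nx.tensor(4)], [3, 3])

# For comparison: the window that starts at {2, 2} and fits without clamping.
inner = Nx.slice(t, [Nx.tensor(2), Nx.tensor(2)], [3, 3])

IO.inspect(corner)
IO.inspect(inner)
```

If the clamping assumption holds, both windows cover the same elements, so any boundary mode other than this implicit one would need explicit padding or index handling before the slice.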

|> Nx.vectorize(:elements)

t
|> Nx.slice([idx[0], idx[1]], [k0, k1])
Collaborator

Nice! This is even better than what I originally had in mind!

Collaborator

I think this approach should be easily extendable to n dimensions, actually!
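
A hedged sketch of how the n-dimensional extension might look, building one start index per axis instead of hardcoding two. `NdMedianSketch` is hypothetical and is not the code merged in this PR; it assumes `Nx.slice` and `Nx.median` accept vectorized inputs the same way the 2D version relies on:

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule NdMedianSketch do
  # Hypothetical generalization; not the code merged in this PR.
  # `kernel_shape` needs one entry per axis of `t`, each no larger than that axis.
  def median_filter(t, kernel_shape) do
    shape = Nx.shape(t)
    rank = Nx.rank(t)

    # One start index per element: rows of {i0, i1, ..., i(rank-1)}.
    idx =
      Nx.axes(t)
      |> Enum.map(&Nx.iota(shape, axis: &1))
      |> Nx.stack(axis: -1)
      |> Nx.reshape({:auto, rank})
      |> Nx.vectorize(:elements)

    start_indices = for axis <- Nx.axes(t), do: idx[axis]

    t
    |> Nx.slice(start_indices, Tuple.to_list(kernel_shape))
    |> Nx.median()
    |> Nx.devectorize(keep_names: false)
    |> Nx.reshape(shape)
  end
end

# A constant volume should come back unchanged, whatever the window placement.
volume = Nx.broadcast(7, {2, 3, 4})
IO.inspect(NdMedianSketch.median_filter(volume, {1, 3, 3}))
```

The constant input is only a sanity check; it does not exercise boundary behavior or window alignment.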

@polvalente polvalente merged commit d0b7df4 into elixir-nx:main May 18, 2024
2 checks passed