Fix BinaryBackend.select_and_scatter crash with :same padding (#1675) #1676

Merged: polvalente merged 3 commits into elixir-nx:main from blasphemetheus:fix/select-and-scatter-negative-padding on Mar 8, 2026
Conversation

@blasphemetheus
Contributor

Summary

select_and_scatter/8 computes absolute_index in the padded coordinate space but uses it against the unpadded output_shape, producing negative offsets that crash List.duplicate/2.

Fix: subtract low-padding from indices to map back to original coordinates, and filter out indices that land in the padding region.
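In isolation, the two-step correction looks like this; a minimal standalone sketch using illustrative values, not the actual backend internals:

```elixir
# Hypothetical window anchor computed in padded coordinate space,
# with per-dimension low padding and the unpadded tensor dimensions.
padded_index = [0, 0]
low_pads = [1, 1]
original_dims = [4, 4]

# Step 1: subtract low padding to map back to original coordinates.
index = Enum.zip_with(padded_index, low_pads, &-/2)
# => [-1, -1]

# Step 2: keep only indices that land inside the original tensor;
# windows anchored in the padding region are discarded.
in_bounds? =
  Enum.zip_reduce(index, original_dims, true, fn idx, dim, acc ->
    acc and idx >= 0 and idx < dim
  end)
# => false
```

Without step 1, the negative offset reaches List.duplicate/2 (which requires a non-negative count) and crashes; without step 2, padding-region windows would scatter into out-of-bounds positions.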

Changes

nx/lib/nx/binary_backend.ex

  • Subtract per-dimension low-padding from absolute_index after computing it in padded space
  • Filter output_windows to discard indices outside original tensor bounds

nx/test/nx/defn/grad_test.exs

  • Add regression test: grad(sum(window_max(x, {1,3,3,1}, padding: :same, strides: [1,2,2,1]))) on a {1,4,4,1} input — verifies output shape, finiteness, and gradient sum equals number of output elements
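A rough sketch of that regression test, assuming the Nx.Defn.grad/1 helper and the assert_all_close helper used elsewhere in grad_test.exs (names and exact assertions are illustrative, not the verbatim test from the PR):

```elixir
# Illustrative sketch of the regression test described above.
test "window_max grad with :same padding on {1, 4, 4, 1}" do
  x = Nx.iota({1, 4, 4, 1}, type: :f32)

  fun = fn t ->
    t
    |> Nx.window_max({1, 3, 3, 1}, padding: :same, strides: [1, 2, 2, 1])
    |> Nx.sum()
  end

  lhs = Nx.Defn.grad(fun).(x)

  # Gradient has the input shape (previously this crashed in List.duplicate/2)
  assert Nx.shape(lhs) == {1, 4, 4, 1}

  # The 2x2 output routes gradient 1.0 to each window's max,
  # so the gradient sums to the number of output elements
  assert_all_close(Nx.sum(lhs), Nx.tensor(4.0))
end
```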

Test Results

1170 tests, 0 failures, 1 skipped
1351 doctests, 0 failures

Closes #1675

Comment thread on nx/lib/nx/binary_backend.ex (Outdated), lines +1639 to +1642
absolute_index =
  padded_absolute_index
  |> Enum.zip(low_pads)
  |> Enum.map(fn {idx, lo} -> idx - lo end)
Contributor


Suggested change
absolute_index =
  padded_absolute_index
  |> Enum.zip(low_pads)
  |> Enum.map(fn {idx, lo} -> idx - lo end)
absolute_index = Enum.zip_with(padded_absolute_index, low_pads, &-/2)
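For reference, Enum.zip_with/3 (Elixir >= 1.12) pairs the two lists and applies the function in a single pass, so the capture &-/2 subtracts each low pad directly; a small example with illustrative values:

```elixir
# Enum.zip_with/3 combines both lists element-wise in one pass,
# avoiding the intermediate list of tuples from Enum.zip/2.
padded_absolute_index = [3, 5]
low_pads = [1, 1]

Enum.zip_with(padded_absolute_index, low_pads, &-/2)
# => [2, 4]
```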

Comment thread on nx/lib/nx/binary_backend.ex (Outdated), lines +1656 to +1657
Enum.zip(index, original_dims)
|> Enum.all?(fn {idx, dim} -> idx >= 0 and idx < dim end)
Contributor


Suggested change
Enum.zip(index, original_dims)
|> Enum.all?(fn {idx, dim} -> idx >= 0 and idx < dim end)
Enum.zip_reduce(index, original_dims, true, fn idx, dim, acc -> acc and idx >= 0 and idx < dim end)
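Similarly, Enum.zip_reduce/4 (also Elixir >= 1.12) folds over both lists together without materializing the zipped list; a small example with illustrative values:

```elixir
# Bounds check folded over both lists at once.
index = [2, 4]
original_dims = [4, 4]

Enum.zip_reduce(index, original_dims, true, fn idx, dim, acc ->
  acc and idx >= 0 and idx < dim
end)
# => false, since 4 is not < 4
```

Note this variant does not short-circuit on the first out-of-bounds dimension, but for the small rank of these index lists that is inconsequential.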

Comment thread on nx/test/nx/defn/grad_test.exs (Outdated), lines +1423 to +1434

# With kernel=3, stride=2, same padding on 4x4:
# output is 2x2, max positions are at bottom-right of each window
# Gradient should be 1.0 at the max positions, 0.0 elsewhere
assert Nx.shape(lhs) == {1, 4, 4, 1}
assert Nx.type(lhs) == {:f, 32}

# All gradient values should be finite (no NaN/Inf)
assert Nx.all(Nx.is_nan(lhs) |> Nx.logical_not()) == Nx.tensor(1, type: {:u, 8})

# Gradient should sum to the number of output elements (each max gets grad 1.0)
assert_all_close(Nx.sum(lhs), Nx.tensor(4.0))
Contributor


I think the tensor is small enough that there should be an assert_equal or assert_all_close assertion on the lhs values directly.

Also, there should be new tests for scatter_max directly besides this one.

Contributor

@polvalente commented on Mar 4, 2026


Also, for this assertion, we should double-check that JAX returns the same values both for the function as defined and for a compound version (say, passing the input through Nx.cos and the output through Nx.sin, to ensure non-linearity of the gradient propagation). This should be a new test.

Contributor

@polvalente left a comment


Thanks a lot for taking this on. Code generally looks good, I just think we should have more tests :)

@blasphemetheus
Contributor Author

No problem! I ran into it while building another repo. Also filed #1679, since that popped up while trying to add more tests.

blasphemetheus and others added 2 commits on March 8, 2026 03:19:

Commit 1 (title truncated):

Adjust absolute_index from padded to original coordinates and filter
out-of-bounds indices in BinaryBackend.select_and_scatter/8.
Adds regression test for window_max gradient with :same padding.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Commit 2 (…#1675):

Apply reviewer code style suggestions (Enum.zip_with, Enum.zip_reduce)
and expand test coverage: exact gradient assertions, window_min/max/sum
with same/explicit/asymmetric padding, 1D/2D/4D shapes, kernel > input,
kernel == input, stride > kernel, and non-linear composition (cos -> max -> sin).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@blasphemetheus force-pushed the fix/select-and-scatter-negative-padding branch from 122cf40 to 43d377f on March 8, 2026 08:19
@polvalente
Contributor

Thanks a lot! I'm looking through all of your issues/PRs :)

@polvalente polvalente merged commit 094b72f into elixir-nx:main Mar 8, 2026
8 of 9 checks passed
@blasphemetheus deleted the fix/select-and-scatter-negative-padding branch on March 8, 2026 20:53


Development

Successfully merging this pull request may close these issues.

BinaryBackend.select_and_scatter/8 crashes with negative padding during max_pool gradient
