Fix edge cases: linspace n=1, scalar slice, gather error, window_scatter f64, floating point tests (IEEE 754 deferred to Complex library PR) #1707
BinaryBackend crashes with ArithmeticError in several cases where IEEE 754 requires returning Inf or NaN. This affects users whose data contains large values (common when logits/activations explode during training) or involves division by zero (for example, normalizing by a variance that can be zero).

Fixes:
1. Unary math overflow: exp(1000.0), sinh(1000.0), cosh(1000.0) etc. now return Inf instead of crashing. sigmoid returns 1.0 or 0.0 for extreme inputs.
2. Domain errors: asin(2.0), acos(2.0), acosh(0.5), atanh(2.0) now return NaN instead of crashing. atanh(1.0) returns Inf.
3. Division by zero: 1.0/0.0 returns Inf, -1.0/0.0 returns -Inf, 0.0/0.0 returns NaN. Respects -0.0 sign per IEEE 754.
4. window_scatter_max/min on f64: binary size calculation produced wrong-sized output. Fixed by casting scatter result to output type.
5. Nx.slice on scalar tensor: bin_slice crashed calling hd([]) on empty strides list. Added scalar guard clause.
6. Nx.linspace with n=1: divided by zero computing step size. Special-cased to return start value directly.
7. Nx.gather with scalar indices: gave unhelpful Erlang error instead of the intended "expected indices rank to be at least 1" message. Moved shape check before indexed_axes call.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
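To make the overflow and domain-error items concrete, here is a minimal sketch in plain Elixir (no Nx dependency) of a wrapper that turns an `ArithmeticError` raised by Erlang's `:math` into a non-finite result. `IeeeSketch` and its functions are hypothetical names used for illustration; this is not the code from this PR or from Complex. The atoms `:infinity`, `:neg_infinity` and `:nan` are assumed here as stand-ins for the non-finite values, since the BEAM does not produce non-finite floats.

```elixir
defmodule IeeeSketch do
  # Hypothetical helpers for illustration only; not Nx or Complex APIs.

  # :math.exp/1 only fails on overflow, and exp is always positive,
  # so the IEEE 754 result for a failed call is +Inf.
  def exp(x) when is_float(x) do
    :math.exp(x)
  rescue
    ArithmeticError -> :infinity
  end

  # asin is defined on [-1.0, 1.0]; outside that range IEEE 754 gives NaN.
  def asin(x) when is_float(x) do
    :math.asin(x)
  rescue
    ArithmeticError -> :nan
  end

  # atanh(1.0) is +Inf, atanh(-1.0) is -Inf, and |x| > 1 is NaN.
  # Handling these up front means :math.atanh/1 only sees valid input.
  def atanh(x) when x == 1.0, do: :infinity
  def atanh(x) when x == -1.0, do: :neg_infinity
  def atanh(x) when is_float(x) and (x > 1.0 or x < -1.0), do: :nan
  def atanh(x) when is_float(x), do: :math.atanh(x)
end

IeeeSketch.exp(1000.0)  # :infinity
IeeeSketch.asin(2.0)    # :nan
IeeeSketch.atanh(1.0)   # :infinity
```

The two styles shown (rescuing the error versus checking the domain before the call) are interchangeable for these functions; checking first avoids raising at all, while rescuing keeps the happy path free of extra comparisons.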
polvalente
left a comment
The floating point standard changes should be applied to Complex instead.
Also, keep in mind that -0.0 and 0.0 do not match in recent OTP, so we can leverage that instead of relying on binary.
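As a rough illustration of the two options mentioned here, the sketch below detects the sign of a zero either from the binary encoding or by comparing terms, and uses the sign to pick the result of a division by zero. `SignedZeroSketch` and its function names are hypothetical. The term-based check depends on the OTP version: on releases where `0.0` and `-0.0` are distinct terms it only matches `-0.0`, while on older releases it is also true for `0.0`.

```elixir
defmodule SignedZeroSketch do
  # Hypothetical module for illustration; not part of Nx or Complex.

  # Binary-based check: reads the IEEE 754 sign bit from the 64-bit
  # encoding. Returns 1 for negative floats (including -0.0), else 0.
  def sign_bit(x) when is_float(x) do
    <<sign::1, _rest::63>> = <<x::float-64>>
    sign
  end

  # Term-based check: relies on strict equality distinguishing -0.0
  # from 0.0, which only holds on recent OTP releases.
  def negative_zero?(x) when is_float(x), do: x === -0.0

  # Division by zero following IEEE 754: 0/0 is NaN; otherwise the
  # result is an infinity whose sign is sign(a) XOR sign(b).
  def divide(a, b) when is_float(a) and is_float(b) and b == 0.0 do
    cond do
      a == 0.0 -> :nan
      sign_bit(a) == sign_bit(b) -> :infinity
      true -> :neg_infinity
    end
  end

  # Non-zero divisor: plain division. Overflow of the quotient is not
  # handled in this sketch.
  def divide(a, b) when is_float(a) and is_float(b), do: a / b
end

SignedZeroSketch.divide(1.0, 0.0)   # :infinity
SignedZeroSketch.divide(1.0, -0.0)  # :neg_infinity
SignedZeroSketch.divide(0.0, 0.0)   # :nan
```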
Oh, and the PR seems to be doing a lot more than just FP corrections. My suggestion is to keep those here and rename the PR.
…ter f64

Four independent fixes:
1. window_scatter_max/min on f64: binary size mismatch. Fixed by casting scatter result to output type before to_binary.
2. Nx.slice on scalar tensor: bin_slice crashed calling hd([]) on empty strides list. Added scalar guard clause.
3. Nx.linspace with n=1: divided by zero computing step size. Special-cased to return start value directly.
4. Nx.gather with scalar indices: gave unhelpful Erlang error. Moved shape check before indexed_axes call.

IEEE 754 overflow/domain/divzero tests are skipped pending upstream fix in the Complex library (elixir-nx/complex#29).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
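The guard-clause pattern behind three of these fixes can be sketched on plain lists and tuples. The functions below are hypothetical simplifications written for illustration, not the Nx implementation: `linspace/3` returns a list rather than a tensor, `first_stride/1` stands in for the `hd([])` call on an empty strides list, and `gather_axis_count/1` assumes the last dimension of the indices shape is the value that later code reads.

```elixir
defmodule EdgeCaseSketch do
  # Hypothetical stand-ins for illustration; the real Nx functions
  # operate on tensors and binaries.

  # Evenly spaced points from start to stop, inclusive. The step is
  # (stop - start) / (n - 1), so n == 1 needs its own clause to avoid
  # dividing by zero.
  def linspace(start, _stop, 1), do: [start * 1.0]

  def linspace(start, stop, n) when is_integer(n) and n > 1 do
    step = (stop - start) / (n - 1)
    for i <- 0..(n - 1), do: start + i * step
  end

  # hd/1 raises on an empty list, and a scalar has no strides, so the
  # empty case gets a dedicated clause ahead of the general one.
  def first_stride([]), do: :scalar
  def first_stride([stride | _rest]), do: stride

  # Validate the rank first so the caller sees the intended message
  # instead of a low-level error from the code that follows.
  def gather_axis_count(indices_shape) when is_tuple(indices_shape) do
    if tuple_size(indices_shape) < 1 do
      raise ArgumentError, "expected indices rank to be at least 1"
    end

    # Only reached for rank >= 1, so reading the last dimension is safe.
    elem(indices_shape, tuple_size(indices_shape) - 1)
  end
end

EdgeCaseSketch.linspace(0.0, 1.0, 5)  # [0.0, 0.25, 0.5, 0.75, 1.0]
EdgeCaseSketch.linspace(0, 1, 1)      # [0.0]
EdgeCaseSketch.first_stride([])       # :scalar
```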
13482d9 to 6697362
Right! Removed the floating point fixes; these now live in Complex (PR elixir-nx/complex#29). We should add a chore to unskip the floating point tests that are currently skipped in this PR once that is merged upstream. Kept: the other edge case fixes (linspace n=1, scalar slice, gather error, window_scatter f64).
If you want to see any of these, try copying over the edge case tests (toward the end of the test file) and running them on main. test file renamed to

All of these fixes were found via fuzz testing (https://github.com/blasphemetheus/nx/tree/fork/fuzz-testing). There are a lot of new tests over there, most of them passing; most are probably not worth adding, since they usually duplicate existing tests.
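For readers unfamiliar with the approach, a fuzz loop of this kind can be very small. The sketch below is a generic illustration in plain Elixir and is not the harness from the linked branch; `FuzzSketch` and `failing_inputs/3` are hypothetical names. It feeds random floats to a function and collects the inputs that make it raise.

```elixir
defmodule FuzzSketch do
  # Hypothetical helper for illustration; not the harness in the fork.

  # Calls `fun` on `runs` random floats drawn from [-range, range)
  # and returns the inputs for which the call raised.
  def failing_inputs(fun, runs, range) do
    for _ <- 1..runs,
        x = (:rand.uniform() * 2 - 1) * range,
        raises?(fun, x),
        do: x
  end

  # True when calling `fun` with `x` raises any exception.
  defp raises?(fun, x) do
    fun.(x)
    false
  rescue
    _ -> true
  end
end

# Every input collected here is a value large enough to overflow exp.
FuzzSketch.failing_inputs(&:math.exp/1, 1_000, 2_000.0)
```

Each collected input is a candidate for a regression test like the edge case tests mentioned above.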
polvalente
left a comment
This PR is doing 3 unrelated things.
- FP tests - these should be removed altogether and either added when updating complex, or just kept in complex itself
- Fix for Nx.linspace
- Fix for select and scatter
We should make an effort to work on reduced-scope PRs.
Erlang's :math module doesn't abide by IEEE 754 and raises an ArithmeticError for some cases.