@aghilann aghilann commented May 7, 2025

Summary:
The binomial function previously only accepted floating-point dtypes, even though it should typically return integers. This PR allows integer dtypes (e.g., int32, int64) while preserving backward compatibility by keeping floats as the default.
I verified the changes by:

  • Adding unit tests to check that the return type matches the specified dtype.
  • Ensuring existing tests that validate output values continue to pass.

Fixes #28457
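
For context, here is a minimal sketch of the behavior the linked issue describes, assuming a JAX version that provides `jax.random.binomial`. The parameter values and the explicit `astype` cast are illustrative only and are not part of this PR:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# By default the sampler returns a floating-point array, even though
# binomial draws are whole-number counts.
samples = jax.random.binomial(key, n=10.0, p=0.5, shape=(4,))

# A caller who wants integers can cast explicitly after sampling.
counts = samples.astype(jnp.int32)
```

The explicit cast leaves the decision about non-finite values with the caller instead of inside the sampler.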

Collaborator

@jakevdp jakevdp left a comment

Thanks for taking a look at this! A few comments below

@jakevdp jakevdp self-assigned this May 7, 2025
@aghilann aghilann force-pushed the binomial-float-to-int-fix branch from 852d921 to f2799c9 Compare May 8, 2025 02:16
Author

aghilann commented May 8, 2025

> Thanks for taking a look at this! A few comments below

Resolved comments, thanks for the detailed review. This is my first open-source contribution, so apologies for any newbie mistakes!

@aghilann aghilann requested a review from jakevdp May 8, 2025 02:18
shape = core.canonicalize_shape(shape)
return _binomial(key, n, p, shape, dtype)

float_type = dtypes.canonicalize_dtype(np.float32)
Collaborator

Let's use float_type = dtypes.to_inexact_dtype(dtype) to respect the input dtype if it's floating point.

if dtypes.issubdtype(dtype, np.integer):
  samples = jnp.round(samples)

return lax.convert_element_type(samples, dtype)
Collaborator

What happens here when the sampler returns NaN or inf and dtype is an integer type? Are those the correct semantics for such cases? (This is a hard question to answer, and probably why float dtypes were required in the original implementation.)

Author

I'm honestly not entirely sure what the best course of action would be. Could we disallow a dtype of int if a NaN or inf is present?

Collaborator

@jakevdp jakevdp May 8, 2025

No, that's not possible unfortunately. dtypes are specified statically (i.e. at compile time) and the value of p in general is not known until runtime. Static quantities cannot be dependent on dynamic quantities.
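
The static/dynamic split described above can be seen in a small sketch. `pick_dtype_from_value` is a hypothetical helper invented for illustration and is not part of JAX or of this PR. It tries to choose an output dtype from a runtime value, which works eagerly but fails under `jax.jit`:

```python
import jax
import jax.numpy as jnp


def pick_dtype_from_value(p):
    # Hypothetical helper: chooses the output dtype from the *value* of p.
    # The Python `if` needs a concrete boolean, so it cannot be traced.
    if jnp.isnan(p):
        return p.astype(jnp.float32)
    return p.astype(jnp.int32)


# Eager execution: p is a concrete value, so the branch can be evaluated.
eager = pick_dtype_from_value(jnp.float32(0.5))

# Under jit, p is an abstract tracer; its dtype is known but its value is not.
try:
    jax.jit(pick_dtype_from_value)(jnp.float32(0.5))
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True
```

Because the output dtype of a compiled function has to be fixed at trace time, a check of this kind cannot be expressed inside the sampler.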

Author

What do you think about adding logic like this to sanitize the values before the round and cast?

if dtypes.issubdtype(dtype, np.integer):
    # replace NaN→0, +∞→max, -∞→min before rounding
    info = np.iinfo(dtype)
    samples = jnp.nan_to_num(
        samples,
        nan=0.0,
        posinf=info.max,
        neginf=info.min,
    )
    samples = jnp.round(samples)
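
A self-contained version of the proposal above might look like the following sketch. `sanitize_and_cast` is a hypothetical name, the input array is made up for illustration, and the snippet uses the public `jnp.issubdtype` and `jnp.iinfo` in place of the internal `dtypes` module:

```python
import jax.numpy as jnp
from jax import lax


def sanitize_and_cast(samples, dtype):
    """Hypothetical helper: replace non-finite values, round, then cast."""
    if jnp.issubdtype(dtype, jnp.integer):
        info = jnp.iinfo(dtype)
        # NaN -> 0, +inf -> dtype max, -inf -> dtype min, before rounding.
        samples = jnp.nan_to_num(
            samples, nan=0.0, posinf=info.max, neginf=info.min
        )
        samples = jnp.round(samples)
    return lax.convert_element_type(samples, dtype)


# Illustrative input: one NaN among otherwise valid counts.
raw = jnp.array([3.0, jnp.nan, 7.0], dtype=jnp.float32)
out = sanitize_and_cast(raw, jnp.int32)
```

One consequence of this design is that a NaN becomes 0, which is indistinguishable from a genuine draw of zero successes. Also, the int32 maximum is not exactly representable in float32, so the infinity replacements would need extra care before the final cast.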

Collaborator

Something like that may work, but I wonder if it wouldn't be less confusing to just keep requiring float output? Why do we need to support integer output if it will lead to ambiguities?

For reference, I found the place where output type was originally discussed: #16134 (review)

Author

Reading that discussion, I understand now why a float dtype was chosen - it wasn’t just a bug. I think my solution could work, but I’m not sure it justifies the added complexity and potential confusion. The only reason I added support for int outputs is because someone raised the issue, and I initially thought it was a bug. Should we go ahead and close this PR?

By the way, I casually browsed the issues for any good first ones, but if you ever come across something small that you’d like help with, feel free to tag me - I'm happy to contribute.

@aghilann aghilann closed this May 8, 2025
Development

Successfully merging this pull request may close these issues.

jax.random.binomial returns float
