-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Fix random.binomial(...) to accept dtype=int #28568
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
jakevdp
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for taking a look at this! A few comments below
852d921 to
f2799c9
Compare
Resolved comments, thanks for the detailed review. This is my first open-source contribution, so apologies for any newbie mistakes! |
| shape = core.canonicalize_shape(shape) | ||
| return _binomial(key, n, p, shape, dtype) | ||
|
|
||
| float_type = dtypes.canonicalize_dtype(np.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's use float_type = dtypes.to_inexact_dtype(dtype) to respect the input dtype if it's floating point.
| if dtypes.issubdtype(dtype, np.integer): | ||
| samples = jnp.round(samples) | ||
|
|
||
| return lax.convert_element_type(samples, dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens here when the sampler returns NaN or inf and dtype is integer? Is it the correct semantics for those cases? (this is a hard question to answer, and probably why float dtypes were required in the original implementation).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm honestly not entirely sure what the best course of action would be. Could we disallow a dtype of int if a NaN or inf is present?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, that's not possible unfortunately. dtypes are specified statically (i.e. at compile time) and the value of p in general is not known until runtime. Static quantities cannot be dependent on dynamic quantities.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think about adding logic like this to sanitize the values before the round and cast?
if dtypes.issubdtype(dtype, np.integer):
# replace NaN→0, +∞→max, -∞→min before rounding
info = np.iinfo(dtype)
samples = jnp.nan_to_num(
samples,
nan=0.0,
posinf=info.max,
neginf=info.min,
)
samples = jnp.round(samples)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something like that may work, but I wonder if it wouldn't be less confusing to just keep requiring float output? Why do we need to support integer output if it will lead to ambiguities?
For reference, I found the place where output type was originally discussed: #16134 (review)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reading that discussion, I understand now why a float dtype was chosen - it wasn’t just a bug. I think my solution could work, but I’m not sure it justifies the added complexity and potential confusion. The only reason I added support for int outputs is because someone raised the issue, and I initially thought it was a bug. Should we go ahead and close this PR?
By the way, I casually browsed the issues for any good first ones, but if you ever come across something small that you’d like help with, feel free to tag me - I'm happy to contribute.
Summary:
The
binomialfunction previously only accepted floating-point dtypes, even though it should typically return integers. This PR allows integer dtypes (e.g.,int32,int64) while preserving backward compatibility by keeping floats as the default.I verified the changes by:
Fixes #28457