Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix np_uniform_impl3 for handling size=() #8997

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion numba/cpython/randomimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,9 +676,10 @@ def impl(context, builder, sig, args):

@overload(np.random.uniform)
def np_uniform_impl3(low, high, size):
is_empty_tuple = lambda x :isinstance(x, types.Tuple) and len(x.types) == 0
if (isinstance(low, (types.Float, types.Integer)) and isinstance(
high, (types.Float, types.Integer)) and
is_nonelike(size)):
(is_nonelike(size) or is_empty_tuple(size))):
return lambda low, high, size: np.random.uniform(low, high)
if (isinstance(low, (types.Float, types.Integer)) and isinstance(
high, (types.Float, types.Integer)) and
Expand Down
7 changes: 7 additions & 0 deletions numba/tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,13 @@ def test_numpy_uniform_kwargs(self):
paramlist=[{'low': 1.5, 'high': 1e6},
{'low': -2.5, 'high': 1e3},
{'low': 1.5, 'high': -2.5}])

def test_numpy_uniform_impl3(self):
self._check_any_distrib_kwargs(
jit_with_kwargs("np.random.uniform", ['low', 'high', 'size']),
get_np_state_ptr(),
'uniform',
paramlist=[{'low' : 0.0, 'high' : 1.0, 'size': None}])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The size needs to check for () and not None.

Suggested change
paramlist=[{'low' : 0.0, 'high' : 1.0, 'size': None}])
paramlist=[{'low' : 0.0, 'high' : 1.0, 'size': ()}])

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @kc611! I applied the suggestion and ran the test locally. The test fails because of dtype mismatch (see attached screenshot). I think I need to fiddle around with this method. But I am not sure. Could you help me with this?

image


def _check_triangular(self, func2, func3, ptr):
"""
Expand Down