Something I ran into is that different backends prefer either single or double precision. I personally need double precision, or at least prefer to consistently use one precision, which is also much fairer for benchmarking. The main problem is when forming arrays: for example, to generate a (2, 2) random normal array with double precision, the required call differs by backend.
We can't just always supply the `dtype` argument, since numpy, dask and sparse throw an error when fed `dtype`. We could also generate an array of whatever dtype and then convert the result to double precision, but this doesn't really address the problem. This doesn't just hold for `random.normal` but for essentially any array-creating function, like `zeros` or `eye`, although there supplying `dtype` does work (for jax we still need to set `jax_enable_x64`).
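To illustrate both behaviours described above with plain numpy (a minimal sketch, not autoray code — numpy's legacy `random.normal` takes no `dtype` argument, so the only portable route is to generate first and cast afterwards):

```python
import numpy as np

# numpy's legacy random.normal has no dtype parameter, so this raises TypeError:
try:
    np.random.normal(size=(2, 2), dtype="float32")
except TypeError as e:
    print("dtype not accepted:", e)

# The workaround: generate in the backend's default dtype, then cast.
x = np.random.normal(size=(2, 2)).astype("float32")
print(x.dtype)  # float32
```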
I can also see from your `gen_rand` method in `test_autoray.py` that you encountered similar problems.
Suggested solutions
Make a wrapper for numpy, dask, sparse (and cupy?) that ignores the `dtype` keyword and then converts the result to the correct dtype after the fact (e.g. if `dtype` is `'float32'`). For jax we should maybe throw a warning when trying to generate double-precision random numbers without setting `jax_enable_x64` to `True`. In fact, for `zeros`, jax already throws a warning in this situation.
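A rough sketch of what such a wrapper could look like (the helper name `wrap_dtype` is made up, and it is shown here against plain numpy rather than autoray's dispatch machinery):

```python
import functools
import numpy as np

def wrap_dtype(fn):
    """Pop the dtype keyword before calling fn, then cast the result.

    Intended for backends (numpy, dask, sparse) whose random functions
    reject a dtype argument.
    """
    @functools.wraps(fn)
    def wrapped(*args, dtype=None, **kwargs):
        out = fn(*args, **kwargs)
        if dtype is not None and out.dtype != np.dtype(dtype):
            out = out.astype(dtype)
        return out
    return wrapped

normal = wrap_dtype(np.random.normal)
x = normal(size=(2, 2), dtype="float32")
print(x.dtype)  # float32
```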
Make an `autoray.random.Generator` object, like `numpy.random.Generator` but backend-aware. This may perform slightly better, and I think numpy is urging people to use this instead of calling e.g. `numpy.random.normal` directly (although it doesn't seem to be catching on).
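As a sketch of the idea only (this `Generator` class is hypothetical, wires up just the numpy backend, and handles `dtype` by casting):

```python
import numpy as np

class Generator:
    """Hypothetical backend-aware RNG; only the numpy path is sketched."""

    def __init__(self, backend="numpy", seed=None):
        self.backend = backend
        if backend == "numpy":
            # delegate to the modern numpy.random.Generator API
            self._rng = np.random.default_rng(seed)
        else:
            raise NotImplementedError(backend)

    def normal(self, loc=0.0, scale=1.0, size=None, dtype=None):
        out = self._rng.normal(loc, scale, size)
        if dtype is not None:
            out = np.asarray(out).astype(dtype)
        return out

rng = Generator("numpy", seed=42)
x = rng.normal(size=(2, 2), dtype="float64")
print(x.dtype, x.shape)  # float64 (2, 2)
```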
It might also be worthwhile to add translations for some more standard distributions like `binomial` or `poisson`, although I mostly use `normal` and `uniform` myself.
Yeah, this is a slightly fiddly one. My approach so far has been to focus on what happens once the arrays have been created, with the random functions included so far mostly for testing purposes. Jax already handles random numbers pretty differently.
Having said that, a guaranteed interface for calling (as a first draft of the things I too would find useful):
`random.normal` / `random.uniform`
`eye` / `identity`
`zeros`
`ones`
with `dtype` specified would be useful in and of itself, as well as for writing lots of algorithms which aren't 'creating arrays' as their overall purpose - e.g. randomized SVD.
Happy to accept any PRs along this direction. Some general thoughts:
I think supporting the more concise `do('random.normal', ...)` style calls would be simplest to begin with, even if `numpy.random.Generator` is used in the background.
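For illustration, that calling convention could be mimicked with a toy dispatcher (a sketch only; the real autoray `do` resolves functions quite differently, and the `dtype` handling here is what's being proposed, not current behaviour):

```python
import numpy as np

# Minimal stand-in for autoray-style do('random.normal', ...) dispatch;
# only a numpy backend is wired up, and dtype is applied by casting.
_BACKENDS = {"numpy": {"random.normal": np.random.normal}}

def do(fn_name, *args, like="numpy", dtype=None, **kwargs):
    fn = _BACKENDS[like][fn_name]
    out = fn(*args, **kwargs)
    if dtype is not None:
        out = np.asarray(out).astype(dtype)
    return out

x = do("random.normal", size=(2, 2), like="numpy", dtype="float32")
print(x.dtype)  # float32
```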
Unless they are exceptionally easy to add, I'm tempted to leave out stuff like `binomial` until someone needs them.
I think it would make sense to default to `dtype=None`, with the backend deciding the dtype, since jax and other backends have good reasons for forcing everything to the same precision by default (hardware acceleration). In other words, dtype control should probably be opt-in, with errors or loud warnings when it can't be honoured.
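That opt-in behaviour might look roughly like this (a sketch; `BACKENDS_WITHOUT_DTYPE`, the warning text, and the numpy-only fallback are all made up for illustration):

```python
import warnings
import numpy as np

# Hypothetical set of backends whose creation routines can't honour dtype.
BACKENDS_WITHOUT_DTYPE = {"sparse"}

def zeros(shape, dtype=None, like="numpy"):
    if dtype is not None and like in BACKENDS_WITHOUT_DTYPE:
        warnings.warn(
            f"backend {like!r} cannot honour dtype={dtype!r}; "
            "falling back to the backend default",
            stacklevel=2,
        )
        dtype = None
    # only the numpy path is actually implemented in this sketch
    return np.zeros(shape, dtype=dtype)

z = zeros((2, 2), dtype="float32", like="numpy")  # dtype honoured
print(z.dtype)  # float32

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    zeros((2, 2), dtype="float32", like="sparse")  # dtype dropped, loudly
print(len(caught))  # 1
```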