Random numbers and dtypes #7

Closed
RikVoorhaar opened this issue Feb 22, 2021 · 1 comment · Fixed by #8

Comments

@RikVoorhaar
Contributor

Something I ran into is that different backends default to either single or double precision. I personally need double precision, or at least want to use one precision consistently; this also makes benchmarking fairer. The main problem arises when creating arrays. For example, to generate a (2, 2) random normal array with double precision on every backend, we currently have to do:

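import autoray as ar
import jax

# jax only produces float64 arrays once x64 mode is enabled
jax.config.update('jax_enable_x64', True)

for backend in ['numpy', 'tensorflow', 'torch', 'jax', 'dask', 'mars', 'sparse']:
    if backend in ('tensorflow', 'torch'):
        # these backends need float64 requested explicitly
        A = ar.do("random.normal", size=(2,2), like=backend, dtype=ar.to_backend_dtype('float64', backend))
    else:
        # numpy, dask and sparse raise an error if dtype is passed here
        A = ar.do("random.normal", size=(2,2), like=backend)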

We can't simply always supply the dtype argument, since numpy, dask and sparse throw an error when random.normal is given dtype. We could also generate an array of whatever dtype the backend produces and then convert the result to double precision, but this doesn't really address the problem. This holds not just for random.normal but for essentially any array-creating function, like zeros or eye, although for those supplying dtype does work. (For jax we still need to set jax_enable_x64.)
I can also see from your gen_rand method in test_autoray.py that you encountered similar problems.
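
For reference, the numpy side of this mismatch can be reproduced without autoray. This is only a small sketch of the behaviour described above, using plain numpy; the variable names are illustrative.

```python
import numpy as np

# The legacy numpy.random.normal signature has no dtype parameter,
# so passing one raises a TypeError.
try:
    np.random.normal(size=(2, 2), dtype="float64")
    normal_accepts_dtype = True
except TypeError:
    normal_accepts_dtype = False

# Deterministic creation routines such as zeros and eye do accept dtype.
Z = np.zeros((2, 2), dtype="float32")
I = np.eye(2, dtype="float32")

print(normal_accepts_dtype, Z.dtype, I.dtype)
```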

Suggested solutions

  • Make a wrapper for numpy, dask, sparse (and cupy?) that accepts the dtype keyword without forwarding it, and then converts the result to the correct dtype after the fact (only needed if dtype is 'float32'). For jax we should maybe throw a warning when trying to generate double precision random numbers without setting 'jax_enable_x64' to True. In fact, jax already throws a warning in this situation for e.g. 'zeros'.
  • Make an autoray.random.Generator object, like numpy.random.Generator but backend-aware. This may perform slightly better, and I think numpy is urging people to start using this over calling e.g. numpy.random.normal directly (although it doesn't seem to be catching on).
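
The first suggestion could look roughly like the following for the numpy case. This is a minimal sketch, not the change that was merged: `normal_with_dtype` is a hypothetical name, and it only handles floating point dtypes.

```python
import numpy as np


def normal_with_dtype(size, dtype=None, loc=0.0, scale=1.0):
    """Hypothetical wrapper: sample in numpy's default precision, then cast."""
    # numpy.random.normal always returns float64 and takes no dtype argument.
    x = np.random.normal(loc=loc, scale=scale, size=size)
    if dtype is None:
        # No dtype requested: keep whatever the backend produced.
        return x
    dtype = np.dtype(dtype)
    if dtype.kind != "f":
        raise TypeError(f"expected a floating point dtype, got {dtype}")
    # copy=False avoids an extra copy when the dtype is already float64.
    return x.astype(dtype, copy=False)


A = normal_with_dtype((2, 2), dtype="float32")
print(A.dtype)
```

A function like this could presumably be hooked into autoray's dispatch for the affected backends; how that registration should be done is left open here.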

It might also be worthwhile to add translations for some more standard distributions like binomial or poisson, although I mostly use normal and uniform myself.
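
To make the Generator idea a bit more concrete, here is a numpy-only sketch of what such an object might expose. The class name `NumpyRandom` and its method signatures are assumptions for illustration; a real backend-aware version would dispatch to other libraries as well. It relies on `numpy.random.Generator.standard_normal` and `Generator.random` accepting a `dtype` of float32 or float64.

```python
import numpy as np


class NumpyRandom:
    """Hypothetical numpy-only stand-in for a backend-aware generator."""

    def __init__(self, seed=None):
        self._rng = np.random.default_rng(seed)

    def normal(self, size, loc=0.0, scale=1.0, dtype="float64"):
        dtype = np.dtype(dtype)
        # Sample directly in the requested precision (float32 or float64).
        z = self._rng.standard_normal(size, dtype=dtype)
        return (loc + scale * z).astype(dtype, copy=False)

    def uniform(self, size, low=0.0, high=1.0, dtype="float64"):
        dtype = np.dtype(dtype)
        u = self._rng.random(size, dtype=dtype)
        return (low + (high - low) * u).astype(dtype, copy=False)

    def poisson(self, size, lam=1.0):
        # Discrete distributions are already available on the Generator.
        return self._rng.poisson(lam, size)


rng = NumpyRandom(seed=0)
x = rng.normal((2, 2), dtype="float32")
k = rng.poisson((3,), lam=2.0)
```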

@jcmgray
Owner

jcmgray commented Feb 24, 2021

Yeah, this is a slightly fiddly one. My approach so far has been to focus on what happens once you have created the arrays, with the random functions included so far mostly for testing purposes. Jax already handles random numbers pretty differently.

Having said that, a guaranteed interface for calling the following (as a first draft of the things I too would find useful):

  • random.normal / random.uniform
  • eye / identity
  • zeros
  • ones

with dtype specified would be useful in and of itself, as well as for writing lots of algorithms whose overall purpose isn't 'creating arrays', e.g. randomized SVD.
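
As an illustration of why dtype control matters inside such algorithms, here is a rough numpy-only sketch of the range-finding step of a randomized SVD. The helper `randomized_range` is hypothetical and assumes the input is float32 or float64.

```python
import numpy as np


def randomized_range(A, rank, seed=None):
    """Orthonormal basis approximating the range of A (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # Draw the test matrix in A's own dtype so nothing is silently upcast.
    omega = rng.standard_normal((A.shape[1], rank), dtype=A.dtype)
    Q, _ = np.linalg.qr(A @ omega)
    return Q


A = np.random.default_rng(0).standard_normal((50, 20), dtype=np.float32)
Q = randomized_range(A, rank=5, seed=1)

# For contrast: a float64 test matrix promotes the product to float64.
omega64 = np.random.default_rng(1).standard_normal((20, 5))
promoted = (A @ omega64).dtype
```

With a default-precision test matrix the single precision input ends up in a double precision computation, which is the kind of inconsistency the issue is about.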

Happy to accept any PRs along this direction. Some general thoughts:

  • I think supporting the more concise do('random.normal', ...) style calls would be simplest to begin with, even if numpy.random.Generator is used in the background.
  • Unless they are exceptionally easy to add, I'm tempted to leave out stuff like binomial until someone needs them.
  • I think it would make sense to default to dtype=None, with the backend deciding the dtype, since e.g. jax and other backends have goodish reasons for forcing everything to the same precision by default (hardware acceleration). In other words, dtype control should probably be opt-in, with errors (or loud warnings?) when it can't be provided.
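
The opt-in policy from the last point could be sketched as a small helper that decides what to pass on to a backend. Everything here is hypothetical (the name `resolve_dtype`, the `supported` set, the `strict` flag); it only illustrates the "None means backend default, otherwise error or warn loudly" behaviour.

```python
import warnings


def resolve_dtype(requested, supported, strict=True):
    """Hypothetical helper: decide which dtype to pass on to a backend.

    requested: dtype name asked for by the caller, or None.
    supported: dtype names the backend function can honor.
    """
    if requested is None:
        # Default: let the backend pick its own precision.
        return None
    if requested in supported:
        return requested
    msg = f"dtype {requested!r} cannot be provided; supported: {sorted(supported)}"
    if strict:
        raise TypeError(msg)
    # Non-strict mode: warn loudly and fall back to the backend default.
    warnings.warn(msg, RuntimeWarning, stacklevel=2)
    return None


print(resolve_dtype(None, {"float32"}))
print(resolve_dtype("float64", {"float32", "float64"}))
```

Whether an unsupported request should raise or merely warn is the open question in the comment above; the `strict` flag just makes both options visible.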
