@jakevdp (Collaborator) commented Jun 16, 2020

This PR adds an experimental precision-doubling transform, following the basic approach outlined in Dekker 1971 (pdf). When this transform is applied, each operation carries approximately double the number of significant bits of the underlying floating-point type.

Simple demo:

In [1]: import jax.numpy as jnp                                                                                        

In [2]: from jax.experimental.doubledouble import doubledouble                                                         

In [3]: def f(a, b): 
   ...:     return a + b - a 
   ...:                                                                                                                

In [4]: f(1E20, 1.0)  # float64 loses precision
Out[4]: 0.0

In [5]: doubledouble(f)(1E20, 1.0)
Out[5]: DeviceArray(1., dtype=float64)

This initial experiment supports basic arithmetic operators and inequalities.
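
For background, the core Dekker-style building block is an error-free addition: it returns the rounded sum together with the exact rounding error, so no information is lost. A minimal illustrative sketch (the helper name here is illustrative, not necessarily the module's internal API):

import jax.numpy as jnp

def two_sum(a, b):
  # Error-free transformation: returns (s, err) such that s + err
  # equals a + b exactly, where s = fl(a + b) and err is the
  # rounding error of that floating-point addition.
  s = a + b
  bb = s - a
  err = (a - (s - bb)) + (b - bb)
  return s, err

Threading such (hi, lo) pairs through every arithmetic operation is what roughly doubles the available significant bits.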

@jakevdp jakevdp requested a review from mattjj June 16, 2020 21:28
@shoyer (Collaborator) left a comment

Very cool!

@shoyer (Collaborator) commented Jun 16, 2020

You can double-double, but can you double-double-double? Might be a nice consistency test, and conceivably if somebody cares about double-double they might be interested in even higher precision, too.
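
Something like this hypothetical snippet, if nesting composes the way other JAX transforms do (untested, and not something this PR demonstrates):

from jax.experimental.doubledouble import doubledouble

def f(a, b):
  return a + b - a

# Hypothetical: each application would need to roughly double the bits again.
quad_f = doubledouble(doubledouble(f))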

@mattjj (Collaborator) commented Jun 16, 2020

> You can double-double, but can you double-double-double? Might be a nice consistency test, and conceivably if somebody cares about double-double they might be interested in even higher precision, too.

Yes! That was a key goal (and win) from doing this in JAX.

I'm sure it can work, though I'm not sure if there are any fiddly things left to get right in the current implementation (as of yesterday afternoon it was a WIP).

@jakevdp jakevdp changed the title Experimental: add precision doubling transform WIP: add experimental precision doubling transform Jun 16, 2020
@jakevdp (Collaborator, Author) commented Jun 16, 2020

One main question I have: I believe I'll need to flatten inputs to make this as general as possible – any pointers on that?

@mattjj (Collaborator) commented Jun 17, 2020

@jakevdp the general pattern is in most api.py transformations, and might look something like this (untested):

from jax import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.tree_util import tree_flatten, tree_unflatten
import jax.numpy as jnp

def doubledouble(f):
  def wrapped(*args):
    # Flatten the pytree of arguments into a flat list of arrays.
    args_flat, in_tree = tree_flatten(args)
    f_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
    # Promote each array to a (hi, lo) pair with a zero low-order term.
    arg_pairs = [(x, jnp.zeros_like(x)) for x in args_flat]
    # doubling_transform is this PR's transform applied to the flat function.
    out_pairs_flat = doubling_transform(f_flat).call_wrapped(*arg_pairs)
    # Collapse each (hi, lo) output pair back into a single array.
    out_flat = [hi + lo for hi, lo in out_pairs_flat]
    out = tree_unflatten(out_tree(), out_flat)
    return out
  return wrapped

@mattjj (Collaborator) commented Jun 17, 2020

The logic there is: we basically want to generate a function f_flat that does a bit of processing to its input before calling f and processing to its output before returning, something like

def f_flat(*flat_args):
  args = unflatten(flat_args)  # rebuild the original pytree arguments
  out = f(*args)               # call the user's function as usual
  out_flat = flatten(out)      # flatten outputs for the caller
  return out_flat

What should we flatten to and from? Well, the caller needs to flatten the args it receives, so it can pass them to f_flat (which then turns around and unflattens them to call f). That's why we pass in in_tree from the caller: f_flat needs to know how to unflatten its flat inputs. The output side is similar but trickier: we need to plumb out the out_tree somehow, smuggling it out of the function so the caller can unflatten the flat outputs we give it.

That's what flatten_fun_nokwargs in api_util.py does:

@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
  py_args = tree_unflatten(in_tree, args_flat)
  ans = yield py_args, {}
  out_flat, out_tree = tree_flatten(ans)
  yield out_flat, out_tree

It's written in an unfamiliar way, but the logic is just the above: we pass in in_tree and args_flat, unflatten to get args, call the underlying function f, then flatten its outputs and return them while smuggling out out_tree. The first yield statement is about calling the underlying function, while the second yield statement is returning to the caller.
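
Putting it together, a minimal sketch of how the pieces interact (hypothetical f, assuming the 2020-era internals quoted above):

from jax import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.tree_util import tree_flatten, tree_unflatten

def f(d):
  return {'sum': d['a'] + d['b']}

args = ({'a': 1.0, 'b': 2.0},)
args_flat, in_tree = tree_flatten(args)   # [1.0, 2.0] plus tree structure
f_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
out_flat = f_flat.call_wrapped(*args_flat)  # [3.0]; out_tree is now populated
out = tree_unflatten(out_tree(), out_flat)  # {'sum': 3.0}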

@mattjj (Collaborator) left a comment

This is so cool! Nice work, I can't wait to learn more. I also can't wait for doubledouble(doubledouble(f)), but no need to block on it :)

@jakevdp jakevdp changed the title WIP: add experimental precision doubling transform Add experimental precision doubling transform Jun 17, 2020
@jakevdp (Collaborator, Author) commented Jun 17, 2020

Tests pass internally; I'm going to merge and then keep iterating.

@jakevdp jakevdp merged commit 5a74ebf into jax-ml:master Jun 17, 2020
@jakevdp jakevdp deleted the doubledouble branch June 17, 2020 21:59
@jakevdp jakevdp mentioned this pull request Jun 18, 2020
@gugar20 commented Apr 17, 2021

Is it possible to use this transformation on linear algebra functions? I would like to apply it to matrix inversion.

@shoyer (Collaborator) commented Apr 17, 2021 via email
