-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Add experimental precision doubling transform #3465
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
shoyer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool!
|
You can double-double, but can you double-double-double? Might be a nice consistency test, and conceivably if somebody cares about double-double they might be interested in even higher precision, too. |
Yes! That was a key goal (and win) from doing this in JAX. I'm sure it can work, though I'm not sure if there are any fiddly things left to get right in the current implementation (as of yesterday afternoon it was a WIP). |
|
One main question I have: I believe I'll need to flatten inputs to make this as general as possible – any pointers on that? |
|
@jakevdp the general pattern is in most api.py transformations, and might look something like this (untested): def doubledouble(f):
def wrapped(*args):
args_flat, in_tree = tree_flatten(args)
f_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
arg_pairs = [(x, jnp.zeros_like(x)) for x in args_flat]
out_pairs_flat = doubling_transform(f_flat).call_wrapped(*arg_pairs)
out_flat = [hi + lo for hi, lo in out_pairs_flat]
out = tree_unflatten(out_tree(), out_flat)
return out
return wrapped |
|
The logic there is: we basically want to generate a function def f_flat(*flat_args):
args = unflatten(flat_args)
out = f(*args)
out_flat = flatten(out)
return out_flatWhat should we flatten to and from? Well, the caller needs to flatten the That's what @lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = yield py_args, {}
out_flat, out_tree = tree_flatten(ans)
yield out_flat, out_treeIt's written in an unfamiliar way, but the logic is just the above: we pass in |
mattjj
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is so cool! Nice work, I can't wait to learn more. I also can't wait for doubledouble(doubledouble(f)), but no need to block on it :)
|
Tests pass internally; I'm going to merge and then keep iterating. |
|
Is it possible to use this transformation on linear algebra functions? I would like to apply it to matrix inversion. |
|
In theory, yes. In practice, no: doubling precision rules have not yet been
defined for any linear algebra routines.
…On Sat, Apr 17, 2021 at 12:48 PM gugar20 ***@***.***> wrote:
Is it possible to use this transformation on linear algebra functions? I
would like to apply it to matrix inversion.
—
You are receiving this because you commented.
Reply to this email directly, view it on GitHub
<#3465 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAJJFVVDP667YZYEHMOV3XDTJHQYZANCNFSM4N775S4Q>
.
|
This PR adds an experimental precision doubling transform, following the basic approach outlined in Dekker 1971 (pdf). When this transform is applied, the number of significant bits is approximately doubled compared to the base operation.
Simple demo:
This initial experiment supports basic arithmetic operators and inequalities.