Best practice: scan vs fori_loop/while_loop #3850
-
Hi, I'm a bit confused as to what syntax I should use to loop over some data. In practice virtually any loop can be coded using a scan-like syntax, but does that mean it should be? The only difference I am aware of is that backprop is supported in the scan case but not with fori_loop (or while_loop). Is there any other hidden one? Just to fix ideas, below is a toy example that computes the same result both ways but behaves differently with respect to gradients. Thanks, Adrien

import jax.numpy as jnp
from jax import lax, jit
from jax.test_util import check_grads
@jit
def scan_fun(x, y):
    # A loop using scan; note how I access y within the body
    n = x.shape[0]
    def body(carry, x):
        curr, i = carry
        return (curr + x * y[i], i + 1), None
    (res, _), _ = lax.scan(body, (0., 0), x)
    return res / n
@jit
def loop_fun(x, y):
    # A loop using fori_loop; note how I access both x and y within the body
    n = x.shape[0]
    def body(i, curr):
        return curr + x[i] * y[i]
    return lax.fori_loop(0, n, body, 0.) / n
n = 50
m = 150
x = jnp.arange(n, dtype=jnp.float32)
y = jnp.arange(m, step=2., dtype=jnp.float32)
assert loop_fun(x, y) == scan_fun(x, y)
check_grads(scan_fun, (x, y), 1, eps=1e-2) # All good
check_grads(loop_fun, (x, y), 1, eps=1e-2)  # raises ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.
-
Thanks for the question! Our slogan is, "always scan when you can!"
This doesn't apply to your example, but in general it's also a good idea to use what distinguishes scan from fori_loop when you can, i.e. the scanned-over inputs and outputs rather than the loop carry (since fori_loop only has the loop carry). When you use scanned-over inputs and outputs instead of the loop carry, it lets JAX generate more efficient differentiation code. The reason is pretty straightforward: we need to save data from each loop iteration on the forward pass to be consumed on the backward pass, and when that data is in the loop carry we basically have to snapshot the whole loop carry for each iteration, whereas scanned-over inputs and outputs need no such extra copying. That comment doesn't apply to the toy example here, whose carry is just a couple of scalars, but it matters when the carry is large.
Zooming back out to the big picture, the only reason not to use scan is when you can't, e.g. when the trip count isn't known statically at trace time; in that case you're left with while_loop, which doesn't support reverse-mode differentiation. (In fact, we'd even like to implement fori_loop in terms of scan whenever the trip count is static.)
WDYT?
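To make the scanned-over-inputs advice concrete, here is a sketch (an editorial illustration, not part of the reply above) of the question's scan_fun rewritten so that y is fed in as a scanned-over input instead of being indexed via a counter in the carry; the y[:n] slice is needed because scan requires all scanned-over arrays to share a leading dimension:

import jax.numpy as jnp
from jax import lax, jit

@jit
def scan_fun_scanned(x, y):
    # Scan over (x, y[:n]) together so the carry is just the running sum;
    # no index needs to be threaded through the carry.
    n = x.shape[0]
    def body(carry, xy):
        x_i, y_i = xy
        return carry + x_i * y_i, None
    res, _ = lax.scan(body, 0., (x, y[:n]))
    return res / n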
-
One other thought: my previous comment focused on the autodiff advantages of using scan instead of other loop constructs. But there may be performance advantages too! Because scan provides more information to the XLA compiler (namely a loop with a guaranteed static trip count, since XLA programs are shape-specialized), the compiler can often do more optimizations. So scan is often faster too, even without autodiff! Seems like a good deal.
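If you want to check this on the toy example above, a rough timing harness (an editorial sketch reusing scan_fun, loop_fun, x, and y from the question; numbers will vary by backend) could look like:

import timeit

# Warm up to exclude compilation time, then block on the asynchronous results.
scan_fun(x, y).block_until_ready()
loop_fun(x, y).block_until_ready()
print(timeit.timeit(lambda: scan_fun(x, y).block_until_ready(), number=1000))
print(timeit.timeit(lambda: loop_fun(x, y).block_until_ready(), number=1000))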
-
Thanks for the explanation, Matt, that's crystal clear.
One point you raise w.r.t. performance makes me wonder, though: I don't see how this can be true once you account for the second output (the accumulated one), which is useless in a number of applications. That being said, I suppose it would be discarded statically if not used in a jitted function?
Adrien
EDIT: I apologize for the originally horrendous format of the reply (it seems email answers don't play well with GitHub yet).
-
As an FYI, I'm posting my real use case for the question: I am doing a linear interpolation on some data I know is sorted. Of course the function below is numba'd in "real life".

import numpy as np
def sorted_interp(x, xp, fp):
""" Interpolates linearly between xp and fp BUT makes the assumption that x is sorted,
this allows the complextity to be O(n+m) instead of O(m * log(n)) and leverages array contiguity
(not like the simpler searchsorted implementation). Equivalent to np.interp(np.sort(x), xp, fp)
Parameters
----------
x: (M) array
locations at which to interpolate
xp: (N) array
locations of the given function values
fp: (N) array
values of the function at the given locations
Returns
-------
out: (M) array
the resulting interpolated values
"""
    m = x.shape[0]
    n = xp.shape[0]
    x = np.atleast_1d(x)
    j = 0
    xp_0 = xp[0]
    fp_0 = fp[0]
    res = np.empty(x.shape, fp.dtype)
    for i in range(m):
        x_i = x[i]
        if x_i <= xp_0:
            res[i] = fp_0
            continue
        else:
            while True:
                xp_j = xp[j]
                fp_j = fp[j]
                next_xp_j = xp[j + 1]
                next_fp_j = fp[j + 1]
                if x_i > next_xp_j:
                    if j + 1 == n - 1:
                        res[i] = next_fp_j
                        break
                    else:
                        j += 1
                        continue
                else:
                    if xp_j == next_xp_j:
                        val = fp_j
                    else:
                        val = fp_j + (next_fp_j - fp_j) * (x_i - xp_j) / (next_xp_j - xp_j)
                    res[i] = val
                    break
    return res

Now that I think of it, once I've implemented it in JAX, I may just leverage my use case (long live open source) to provide you guys with a JAX implementation of np.interp, as I think it's still missing. It may even prove faster than the numpy version...

def inv(argsort, arr):
"""Computes the inverse of the permutation, but might be better
for autodiff to simply compute the inverse permutation and then index the array
"""
inverse = np.empty_like(arr)
for i, p in enumerate(argsort):
inverse[p] = arr[i]
return inverse
def interp(x, xp, fp):
"""Equivalent of np.interp that doesn't rely on searchsorted"""
argsort = np.argsort(x)
sorted_x= x[argsort]
interpolated_f = sorted_interp(sorted_x, xp, fp)
return inv(argsort, interpolated_f) |
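For reference, a minimal vectorized JAX sketch of np.interp (an editorial addition, not the author's code; it uses jnp.searchsorted, so it is the O(m * log(n)) baseline this post is trying to beat, and it ignores repeated xp values):

import jax.numpy as jnp

def jax_interp_sketch(x, xp, fp):
    # Find, for each x, the left endpoint of its bracketing interval in xp.
    j = jnp.clip(jnp.searchsorted(xp, x, side='right') - 1, 0, xp.shape[0] - 2)
    slope = (fp[j + 1] - fp[j]) / (xp[j + 1] - xp[j])
    y = fp[j] + slope * (x - xp[j])
    # Clamp to the endpoint values outside the range of xp, like np.interp.
    return jnp.where(x <= xp[0], fp[0], jnp.where(x >= xp[-1], fp[-1], y))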
-
Side-note: a useful trick is that you can invert an argsort by argsorting the argsort 😁

def inv(argsort, arr):
    return arr[np.argsort(argsort)]
-
Yes, but it's a costly trick when one can do it in linear time :)
-
You can also calculate an inverse permutation in a single call to array indexing:

def inv(argsort, arr):
    inverse = np.empty_like(arr)
    inverse[argsort] = arr
    return inverse

In JAX, you would write this as an indexed update, since arrays are immutable.
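A sketch of how that might look with current jax.numpy (an editorial illustration, not the reply's original snippet):

import jax.numpy as jnp

def inv(argsort, arr):
    # JAX arrays are immutable, so the scatter becomes an .at[...].set(...) update.
    return jnp.zeros_like(arr).at[argsort].set(arr)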
-
Just opened a new issue to discuss specifics and not spam the "scan vs loop" question: #3860
-
Do you mean the scanned-over output? Not only will that be ignored if unused in a jitted function, it's also optional; you can just scan a function with None as its second output.
I think we covered this one, so I'm going to close this issue!
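A minimal illustration (an editorial sketch) of scanning with no per-step output — the body returns None as its second value, so nothing gets accumulated:

import jax.numpy as jnp
from jax import lax

def body(carry, x_i):
    return carry + x_i, None  # no scanned-over output to stack

total, ys = lax.scan(body, 0., jnp.arange(5.))
# total == 10.0 and ys is None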
-
Should we try converting this issue into a “Discussion” for posterity?
-
@shoyer you just want to exercise your discussion-moving GitHub powers! SGTM!
-
How could we bail out of scans early to make a differentiable while loop?
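One common workaround (an editorial sketch, not an answer from this thread): scan itself can't exit early, but you can run it for a fixed upper bound on iterations and mask the update once the stopping condition fires, which keeps the whole loop reverse-mode differentiable at the cost of always paying for the maximum number of steps. For example, Newton's method for sqrt(a) with an iteration cap:

import jax.numpy as jnp
from jax import lax

def sqrt_newton(a, max_steps=20, tol=1e-6):
    def step(x, _):
        done = jnp.abs(x * x - a) < tol
        x_new = 0.5 * (x + a / x)
        # After convergence, pass the old value through so the remaining
        # iterations are no-ops and don't disturb the gradient.
        return jnp.where(done, x, x_new), None
    x, _ = lax.scan(step, a, None, length=max_steps)
    return x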