It would be great to have `vmap` support for `lax.cond`, similar to the recently added `vmap` support for `lax.while_loop`. I know that I can use `np.where` as a workaround in some cases, but when the guarded branch is computationally expensive, a batched `cond` might be better.
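For reference, here is a minimal sketch of the `np.where` workaround (written with `jax.numpy.where`). `expensive` is a hypothetical stand-in for a costly branch. The cost problem is visible in the code: `expensive(x)` is evaluated for every element, including those where the condition is false, and the result is then selected elementwise.

```python
import jax
import jax.numpy as jnp


def expensive(x):
    # Hypothetical stand-in for a costly computation.
    return jnp.sin(x) ** 2 + jnp.cos(x)


def f_where(x):
    # Workaround: both branches are computed, then one is selected.
    # There is no short-circuiting, so expensive(x) always runs.
    return jnp.where(x > 0, expensive(x), x)


xs = jnp.array([-1.0, 0.5, 2.0])
ys = jax.vmap(f_where)(xs)
```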
As I usually only need the `if` branch (no `else` part), I am considering a single-iteration `while_loop` instead of `cond` as another workaround. What do you think?
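A minimal sketch of that single-iteration `while_loop` idea, assuming an `if`-only branch whose output has the same shape and dtype as its input (a requirement of `lax.while_loop`'s carry). `if_only` and `expensive` are hypothetical helpers. The carry holds a "still to run" flag: the loop runs the body once when the predicate is true and then clears the flag. When the predicate is false, the body never runs.

```python
import jax
import jax.numpy as jnp
from jax import lax


def expensive(x):
    # Hypothetical stand-in for a costly computation.
    return jnp.sin(x) ** 2 + jnp.cos(x)


def if_only(pred, x):
    # Emulates `if pred: x = expensive(x)` with lax.while_loop.
    def cond_fun(carry):
        run, _ = carry
        return run  # keep looping only while the flag is set

    def body_fun(carry):
        run, v = carry
        # Clear the flag (same dtype/shape as before) so the body runs at most once.
        return (jnp.zeros_like(run), expensive(v))

    _, out = lax.while_loop(cond_fun, body_fun, (pred, x))
    return out


def f_loop(x):
    return if_only(x > 0, x)


xs = jnp.array([-1.0, 0.5, 2.0])
ys = jax.vmap(f_loop)(xs)
```

One caveat, as far as I understand the batching rule: under `vmap` with a batched predicate, the loop keeps running while *any* element's flag is set, and the body is evaluated on the whole batch with per-element selection. In the batched case this may therefore save no work compared with `np.where`.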