Best practice: scan vs fori_loop/while_loop #3850

Answered by mattjj
AdrienCorenflos asked this question in General
Thanks for the question!

Our slogan is, "always scan when you can!"

This doesn't apply to your example, but in general it's also a good idea to use what distinguishes scan from fori_loop when you can, i.e. the scanned-over inputs and outputs rather than the loop carry (fori_loop only has the loop carry). Using scanned-over inputs and outputs instead of the loop carry lets JAX generate more efficient differentiation code. The reason is pretty straightforward: we need to save data from each loop iteration of the forward pass so that it can be consumed on the backward pass. When that data is in the loop carry we basically have to snapshot the whole loop carry for each iteration, …
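
To make the distinction concrete, here is a minimal sketch (not the code from the original question; the function names `cumsum_scan` and `cumsum_fori` and the toy loss are made up for illustration). It computes the same running sum twice: once with `lax.scan`, using scanned-over inputs and outputs, and once with `lax.fori_loop`, where the output buffer has to live in the loop carry.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cumsum_scan(xs):
    # The carry is only the scalar running total. Each x arrives as a
    # scanned-over input and each partial sum leaves as a scanned-over output.
    def step(total, x):
        total = total + x
        return total, total

    _, ys = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return ys


def cumsum_fori(xs):
    # fori_loop has no scanned-over outputs, so the full output buffer
    # has to be threaded through the loop carry alongside the running total.
    def body(i, carry):
        total, ys = carry
        total = total + xs[i]
        return total, ys.at[i].set(total)

    init = (jnp.zeros((), xs.dtype), jnp.zeros_like(xs))
    # The bounds are Python ints here, so the trip count is static.
    _, ys = lax.fori_loop(0, xs.shape[0], body, init)
    return ys


xs = jnp.arange(1.0, 6.0)
ys_scan = cumsum_scan(xs)
ys_fori = cumsum_fori(xs)

# Differentiate a toy loss through the scan version.
grads = jax.grad(lambda v: jnp.sum(cumsum_scan(v) ** 2))(xs)
```

Both functions return the same values. The difference is in what each iteration carries: the scan version carries one scalar, while the fori_loop version carries a buffer as large as the whole output, which is the kind of carry the answer above suggests avoiding when you plan to differentiate.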

Answer selected by shoyer
This discussion was converted from issue #3850 on July 30, 2020 23:55.