How do you deal with the BN running mean/variance? The BatchNorm layer would be computed twice (once during the forward pass and once during recomputation in the backward pass), so the running mean and variance would be updated twice.
This is a good point. Ideally, PyTorch's batch norm layers should be smart enough to update the running mean/var appropriately with the checkpointing operation.
If this is not the case, then you should raise an issue with PyTorch, since the checkpointing/batch norm layers are part of their library, not this library.
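To make the concern in the question concrete, here is a minimal pure-Python sketch of what a double update would do to a running statistic. It assumes the exponential-moving-average rule documented for PyTorch's batch norm layers, `running = (1 - momentum) * running + momentum * batch_stat`, with the default momentum of 0.1. The helper `ema_update` and the compensated momentum are illustrative names and an illustrative workaround, not part of PyTorch or of this library.

```python
def ema_update(running, batch_stat, momentum=0.1):
    """One BatchNorm-style running-statistic update (hypothetical helper)."""
    return (1.0 - momentum) * running + momentum * batch_stat


momentum = 0.1
running_mean = 0.0   # running statistic before the step
batch_mean = 1.0     # statistic of the current mini-batch

# Normal training step: the statistic is updated once.
once = ema_update(running_mean, batch_mean, momentum)

# Forward pass plus recomputation: the same batch updates it a second time.
twice = ema_update(once, batch_mean, momentum)

# Two updates with the same batch act like one update with a larger momentum:
# 1 - (1 - m)^2, which is about 0.19 for m = 0.1.
effective_momentum = 1.0 - (1.0 - momentum) ** 2

# Possible workaround (an assumption, not an official fix): pick a smaller
# momentum m' with (1 - m')^2 = 1 - m, so two updates equal one at m.
compensated = 1.0 - (1.0 - momentum) ** 0.5
twice_compensated = ema_update(
    ema_update(running_mean, batch_mean, compensated), batch_mean, compensated
)

print(once, twice, effective_momentum, compensated, twice_compensated)
```

Under this assumption, the double update does not corrupt the statistics; it makes them track recent batches roughly twice as fast. Whether that matters in practice depends on the model, and the compensated momentum only applies to layers that really are evaluated twice per step.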