I think there may be a bug with state initialization in the optimizer. Because the gradients are on the GPU while the states are initialized on the CPU, a device-mismatch error is raised when tensors on two different devices are combined. I investigated the code, compared it to other PyTorch optimizers, and noticed a couple of things that could be causing this issue.
Typically, when the states are initialized, `torch.zeros_like` is called with `memory_format=torch.preserve_format` so that each state tensor matches the memory format (and device) of the input tensor, which is usually a model parameter. In this case, however, the initialization happens in `__init__`, where the model parameters might not be on the GPU yet. That is why PyTorch's built-in optimizers usually put the initialization code inside `step`, guarded by a `len(state) == 0` check, so the states are only created once the parameters are on their final device.
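For reference, here is a minimal sketch of that lazy-initialization pattern, assuming a simple SGD-with-momentum update; the class name `LazyInitSGD` and its hyperparameters are illustrative placeholders, not the actual code in this project:

```python
import torch
from torch.optim import Optimizer


class LazyInitSGD(Optimizer):
    """Illustrative SGD-with-momentum optimizer that defers state creation to step()."""

    def __init__(self, params, lr=1e-3, momentum=0.9):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)
        # No state tensors are allocated here, so it doesn't matter
        # whether the parameters have been moved to the GPU yet.

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                # Lazy initialization: by the time step() runs, p is on its
                # final device, so the buffer is created on the same device
                # (and in the same memory format) as the parameter.
                if len(state) == 0:
                    state["momentum_buffer"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                buf = state["momentum_buffer"]
                buf.mul_(group["momentum"]).add_(p.grad)
                p.add_(buf, alpha=-group["lr"])

        return loss
```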
I changed the optimizer code to follow this pattern and the code now runs without issue. I will point out that I am using fastai, so it is possible this is a fastai-specific issue, but it seems like it could be a significant issue for other users as well.
Another user has run into this issue as well. I've created the "inline" branch, which has an implementation that initializes the optimizer state inside `step` instead. I'm considering merging that branch; it sounds like a good idea.