
Using MetaAdam results in 'RuntimeError: Trying to backward through the graph a second time' #191

Closed · Answered by XuehaiPan
LabChameleon asked this question in Q&A


Doesn't moving the line torchopt.recover_state_dict(net, net_state) into the inner loop prevent the gradients from flowing from one inner loop iteration to the next one? Instead, now they get detached after every inner loop iteration, don't they?

Yes, you are correct.

The problem is that, in your code snippet, the inner optimizer is shared across multiple outer-loop iterations, so its state still references the previous iteration's computation graph. You should either extract, detach, and recover the state of the inner optimizer the same way you do for the network parameters, or create a new inner optimizer at the beginning of each outer loop.

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt


d…
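
For reference, here is a minimal sketch of the second option (recreating the inner optimizer at the start of every outer iteration). The linear model, random data, and loop counts are placeholders rather than your actual setup; only the loop structure and the torchopt calls (MetaAdam, extract_state_dict, recover_state_dict) are the point.

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchopt

net = nn.Linear(4, 1)  # placeholder model
meta_optim = torch.optim.Adam(net.parameters(), lr=1e-3)

for outer_iter in range(10):
    # Fresh differentiable inner optimizer per outer iteration, so no
    # optimizer state (and no stale graph) is carried across outer loops.
    inner_optim = torchopt.MetaAdam(net, lr=1e-1)

    # Checkpoint the pre-adaptation network parameters.
    net_state = torchopt.extract_state_dict(net)

    x = torch.randn(8, 4)  # placeholder data
    y = torch.randn(8, 1)

    for inner_iter in range(5):
        inner_loss = F.mse_loss(net(x), y)
        inner_optim.step(inner_loss)  # differentiable update of net's parameters

    outer_loss = F.mse_loss(net(x), y)

    # Restore the pre-adaptation parameters; the graph of the inner updates
    # still lives in outer_loss, so backward() flows back to the originals.
    torchopt.recover_state_dict(net, net_state)

    meta_optim.zero_grad()
    outer_loss.backward()
    meta_optim.step()

The other option mentioned above would keep a single inner optimizer, extract its initial state once with torchopt.extract_state_dict(inner_optim), and recover it at the end of each outer iteration, which resets the moments to detached values instead of recreating the optimizer.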
