Two slightly different processes for DEQ #11
Hi @SamChen , thanks for the question. The gradcheck fails because of the memory leak, which, I believe, is a PyTorch-related issue. I'm not entirely sure about the source of this problem, but PyTorch 1.6 and 1.7 should both work well (i.e., no memory leak). If you encounter the memory leak, you can spend one more NFE in the training forward pass and set up the backward hook on a detached copy:

```python
z1s_copy = z1s.clone().detach().requires_grad_()
new_z1s_copy = self.func(z1s_copy, *func_args)  # Spend one more NFE in training forward

def backward_hook(grad):
    new_grad = self.b_solver(lambda y: autograd.grad(new_z1s_copy, z1s_copy, y, retain_graph=True)[0] + grad,
                             torch.zeros_like(grad), threshold=b_thres)['result']
    return new_grad

new_z1s.register_hook(backward_hook)
```

Of course, this means you have to spend one more NFE in the forward pass of training, which means slightly more memory and computation (which is what the current implementation hoped to avoid). But this should help avoid the memory leak. Let me know if this helps!
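For context, below is a minimal, self-contained sketch of where the snippet above sits inside a DEQ layer's forward pass. The `DEQFixedPoint` module, the naive `forward_iteration` solver, and the `f(z, x)` call signature are illustrative assumptions; they only roughly play the roles of `self.func`, `self.b_solver`, and `b_thres` in the snippet above, not the repo's actual classes.

```python
import torch
import torch.nn as nn
from torch import autograd


def forward_iteration(g, z0, threshold=50, eps=1e-12):
    """Naive fixed-point solver (illustrative stand-in for Broyden/Anderson):
    iterate z <- g(z) until the update is tiny or the budget is spent."""
    z = z0
    for _ in range(threshold):
        z_next = g(z)
        if (z_next - z).norm() < eps * (1e-6 + z.norm()):
            return {'result': z_next}
        z = z_next
    return {'result': z}


class DEQFixedPoint(nn.Module):
    """Hypothetical minimal DEQ layer; `f` maps (z, x) -> z."""

    def __init__(self, f, solver, f_thres=50, b_thres=50):
        super().__init__()
        self.f = f
        self.solver = solver
        self.f_thres, self.b_thres = f_thres, b_thres

    def forward(self, x):
        # Forward pass: find z* = f(z*, x) without building an autograd graph
        # through the solver iterations.
        with torch.no_grad():
            z1s = self.solver(lambda z: self.f(z, x), torch.zeros_like(x),
                              threshold=self.f_thres)['result']

        # Re-attach the output to the autograd graph with one evaluation of f.
        new_z1s = self.f(z1s, x)

        if self.training:
            # The workaround above: evaluate f once more at a *detached* copy,
            # so the backward hook differentiates a small, self-contained graph.
            z1s_copy = z1s.clone().detach().requires_grad_()
            new_z1s_copy = self.f(z1s_copy, x)

            def backward_hook(grad):
                # Implicit-function-theorem backward: solve the linear fixed
                # point y = (vector-Jacobian product of f at z*) + grad.
                return self.solver(
                    lambda y: autograd.grad(new_z1s_copy, z1s_copy, y,
                                            retain_graph=True)[0] + grad,
                    torch.zeros_like(grad), threshold=self.b_thres)['result']

            new_z1s.register_hook(backward_hook)

        return new_z1s
```

In eval mode the hook is never registered, so the extra NFE is only paid during training, which matches the trade-off described in the reply.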
Thanks for the clear explanation. :-)
(Referencing deq/DEQ-Sequence/models/deq_transformer.py, line 374 at commit c161644.)
Hi Shaojie,
I found that there are two slightly different forward-backward processes for DEQ. One is in Chapter 4: Deep Equilibrium. The second one is in this repo: deq/DEQ-Sequence/models/deq_transformer.py, lines 355 to 380 at commit c161644.
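To make the comparison concrete, here is how I read the structural difference between the two variants, based on the reply above. This is an illustrative paraphrase, not the exact tutorial or repo code, and the helper names (`f`, `solver`, `z_star`) are assumptions.

```python
import torch
from torch import autograd


def attach_ift_backward_v1(f, solver, z_star, x, b_thres=50):
    """Chapter-4-style: spend an extra evaluation of f at a detached copy of z*,
    so the backward hook differentiates only that small, separate graph."""
    out = f(z_star, x)                            # attach the output to the main graph
    z0 = z_star.clone().detach().requires_grad_()
    f0 = f(z0, x)                                 # the extra NFE

    def backward_hook(grad):
        return solver(lambda y: autograd.grad(f0, z0, y, retain_graph=True)[0] + grad,
                      torch.zeros_like(grad), threshold=b_thres)['result']

    out.register_hook(backward_hook)
    return out


def attach_ift_backward_v2(f, solver, z_star, x, b_thres=50):
    """Repo-style (as I understand it): reuse a single evaluation for both the
    output and the hook, avoiding the extra NFE at the cost of differentiating
    through the very graph the hook is registered on."""
    z_star = z_star.requires_grad_()
    out = f(z_star, x)

    def backward_hook(grad):
        return solver(lambda y: autograd.grad(out, z_star, y, retain_graph=True)[0] + grad,
                      torch.zeros_like(grad), threshold=b_thres)['result']

    out.register_hook(backward_hook)
    return out
```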
I tried torch.autograd.gradcheck on both methods, using the exact same procedure from Chapter 4, on Colab:

```python
gradcheck(deq, torch.randn(1,2,3,3).cuda().double().requires_grad_(), check_undefined_grad=False)
```

Interestingly, only method 1 works properly. The second method breaks my experiment session.
Here is my experiment code https://colab.research.google.com/drive/19vGpV16nbF5HRRKlFGScO-N1Js3NC4hj#scrollTo=kg2UmSW1x1R3
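For reference, a CPU-only, self-contained version of that check could look like the following, reusing the `DEQFixedPoint` and `forward_iteration` sketch shown earlier. The tiny `SimpleCell` and its 0.1 gain are illustrative assumptions, not the actual Colab code.

```python
import torch
import torch.nn as nn
from torch.autograd import gradcheck


class SimpleCell(nn.Module):
    """Tiny illustrative cell z <- tanh(0.1 * conv(z) + x); the small gain keeps
    the map strongly contractive so the naive solver converges."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)

    def forward(self, z, x):
        return torch.tanh(0.1 * self.conv(z) + x)


f = SimpleCell(2).double()                         # double precision for gradcheck
deq = DEQFixedPoint(f, forward_iteration).train()  # training mode so the hook is used

x = torch.randn(1, 2, 3, 3, dtype=torch.double, requires_grad=True)
# Should print True when both the forward and backward solves converge tightly.
print(gradcheck(deq, x, check_undefined_grad=False))
```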
I also tried it on my workstation. I found that method 2 slowly ate all the GPU memory and eventually returned this message:

```
SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f65d099e770> returned NULL without setting an error
```

I think I triggered an infinite loop in the backward solver, although I already called torch.cuda.synchronize() in the backward_hook function. In this repo, I do not find similar code related to gradient checking. Moreover, method 2 is used in your Transformer-XL examples. I wonder whether this means the memory-hunger issue rarely happens in practical cases, like training a transformer.
Thanks :-)
My experiment environment:
- Workstation:
- Google Colab: default environment with GPU.