Two slightly different processes for DEQ #11

Closed · SamChen opened this issue Aug 19, 2021 · 2 comments
SamChen commented Aug 19, 2021

Hi Shaojie,

I found that there are two slightly different forward-backward processes for DEQ. The first is from Chapter 4: Deep Equilibrium.

class DEQFixedPoint(nn.Module):
    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.kwargs = kwargs
        
    def forward(self, x):
        # compute forward pass and re-engage autograd tape
        with torch.no_grad():
            z, self.forward_res = self.solver(lambda z : self.f(z, x), torch.zeros_like(x), **self.kwargs)
        z = self.f(z,x)
        
        # set up Jacobian vector product (without additional forward calls)
        z0 = z.clone().detach().requires_grad_()
        f0 = self.f(z0,x)
        def backward_hook(grad):
            g, self.backward_res = self.solver(lambda y : autograd.grad(f0, z0, y, retain_graph=True)[0] + grad,
                                               grad, **self.kwargs)
            return g
                
        z.register_hook(backward_hook)
        return z
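
For context, a minimal sketch of how the DEQFixedPoint class above can be exercised end to end; forward_iteration and SimpleCell here are illustrative stand-ins rather than the tutorial's exact cell and solver, but they follow the (z, residuals) return signature that the class expects.

# Minimal sketch (illustrative, not the tutorial's exact code): a plain fixed-point
# iteration solver and a tiny cell playing the role of f(z, x).
import torch
import torch.nn as nn
from torch import autograd

def forward_iteration(f, x0, max_iter=50, tol=1e-5):
    # Iterate z <- f(z) until the relative residual falls below tol.
    z, res = x0, []
    for _ in range(max_iter):
        z_next = f(z)
        res.append((z_next - z).norm().item() / (1e-8 + z_next.norm().item()))
        z = z_next
        if res[-1] < tol:
            break
    return z, res

class SimpleCell(nn.Module):
    # Hypothetical stand-in for the DEQ cell f(z, x).
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
    def forward(self, z, x):
        return torch.tanh(self.conv(z) + x)

deq = DEQFixedPoint(SimpleCell(2), forward_iteration, max_iter=100, tol=1e-8)
out = deq(torch.randn(1, 2, 3, 3))
out.sum().backward()   # triggers the registered backward_hook once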

The second one is in this repo.

with torch.no_grad():
    result = self.f_solver(lambda z: self.func(z, *func_args), z1s, threshold=f_thres, stop_mode=self.stop_mode)
    z1s = result['result']
new_z1s = z1s

if (not self.training) and spectral_radius_mode:
    with torch.enable_grad():
        z1s.requires_grad_()
        new_z1s = self.func(z1s, *func_args)
    _, sradius = power_method(new_z1s, z1s, n_iters=150)

if self.training:
    z1s.requires_grad_()
    new_z1s = self.func(z1s, *func_args)
    if compute_jac_loss:
        jac_loss = jac_loss_estimate(new_z1s, z1s, vecs=1)

    def backward_hook(grad):
        if self.hook is not None:
            # To avoid infinite loop
            self.hook.remove()
            torch.cuda.synchronize()
        new_grad = self.b_solver(lambda y: autograd.grad(new_z1s, z1s, y, retain_graph=True)[0] + grad,
                                 torch.zeros_like(grad), threshold=b_thres)['result']
        return new_grad

    self.hook = new_z1s.register_hook(backward_hook)

I tried torch.autograd.gradcheck on both methods, using the exact same procedure from Chapter 4, on Colab:
gradcheck(deq, torch.randn(1,2,3,3).cuda().double().requires_grad_(), check_undefined_grad=False)

Interestingly, only method 1 works properly; the second method breaks my experiment session.
Here is my experiment code: https://colab.research.google.com/drive/19vGpV16nbF5HRRKlFGScO-N1Js3NC4hj#scrollTo=kg2UmSW1x1R3

I also tried it on my workstation and found that method 2 slowly ate all of the GPU memory and eventually returned this message: SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f65d099e770> returned NULL without setting an error.
I think I triggered an infinite loop in the backward solver, even though torch.cuda.synchronize() is already called in the backward_hook function.

In this repo, I do not find similar code related to gradient checking. Moreover, method 2 is used in your Transformer-XL examples. I wonder whether this means the memory-exhaustion issue rarely happens in practical cases, like training a transformer.

Thanks :-)

My experiment environment:
workstation:

  • python: 3.8
  • pytorch: 1.9
  • cuda: 11.1
  • GPU: NVIDIA 3090

Google colab:
default environment with GPU.

jerrybai1995 (Member) commented

Hi @SamChen,

Thanks for the question. The gradcheck fails because gradcheck works by backpropagating through the same computation graph multiple times (e.g., backpropagating once for each entry of the vector output), whereas self.hook.remove() already removes the hook upon the first backward call. Therefore, the code is correct for DEQ-Transformer training (where each iteration has exactly ONE backward pass through the DEQ), but is incorrect for repeated backward passes (which is what gradcheck does).
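
As a toy illustration (not repo code): once a hook handle has been removed after the first backward call, subsequent backward passes through the same graph simply never invoke it, which is exactly the situation gradcheck creates.

import torch

x = torch.randn(3, requires_grad=True)
y = 2 * x

calls = []
handle = y.register_hook(lambda grad: calls.append(1) or grad)

for i in range(3):                      # gradcheck-style repeated backward passes
    y.sum().backward(retain_graph=True)
    if i == 0:
        handle.remove()                 # mimics the effect of self.hook.remove() after the first pass

print(len(calls))                       # 1 -- the hook only fired on the first backward pass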

The memory leak, I believe, is a PyTorch-related issue. I'm not entirely sure about the source of the problem, but PyTorch 1.6 and 1.7 should both work well (i.e., no memory leak). If you encounter the SystemError and do not want to downgrade PyTorch, you can also use the tutorial implementation, basically replacing the current L372-380 with:

z1s_copy = z1s.clone().detach().requires_grad_()
new_z1s_copy = self.func(z1s_copy, *func_args)      # Spend one more NFE in training forward
def backward_hook(grad):
    new_grad = self.b_solver(lambda y: autograd.grad(new_z1s_copy, z1s_copy, y, retain_graph=True)[0] + grad,
                             torch.zeros_like(grad), threshold=b_thres)['result']
    return new_grad
new_z1s.register_hook(backward_hook)

Of course, this means you have to spend one more NFE in the forward pass of training, i.e., slightly more memory and computation (which is what the current implementation hoped to avoid). But this should help avoid the memory leak.
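
A toy counterpart to the snippet above (again illustrative stand-ins, not repo code) shows why the detached-copy version tolerates repeated backward passes: the inner autograd.grad only differentiates the copy's graph (f0 with respect to z0), so the hook never re-enters itself and can stay registered across backward calls.

import torch
from torch import autograd

x = torch.randn(4, dtype=torch.double, requires_grad=True)
z = torch.sin(x)                          # stands in for the converged fixed point
z0 = z.clone().detach().requires_grad_()  # detached copy, as in the snippet above
f0 = torch.tanh(z0)                       # stands in for the extra call to self.func

def backward_hook(grad):
    # One Jacobian-vector product through the detached copy's graph only.
    return autograd.grad(f0, z0, grad, retain_graph=True)[0] + grad

z.register_hook(backward_hook)

for _ in range(3):                        # gradcheck-style repeated backward passes
    z.sum().backward(retain_graph=True)   # the hook fires safely every time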

Let me know if this helps!


SamChen commented Aug 19, 2021

Thanks for the clear explanation. :-)
Your explanation of gradcheck explains why I saw it call the DEQ function over and over. And, of course, that is not related to the infinite loop referred to in this comment:

# To avoid infinite loop

SamChen closed this as completed Aug 19, 2021