[Bug] Error when backward with retain_graph=True #1046
Comments
Hi, why did you set retain_graph to True? Do you need second-order derivatives?
Hypothetically yes, but I did not encounter the problem for that reason. For me it was just a quick fix after mistakenly constructing ... Anyway, I don't understand why we need to explicitly clear the backward cache.
We found that if we didn't do this, the cache would not be properly cleared by PyTorch, which resulted in a memory leak.
@BarclayII is re-confirming the memory leak issue. The latest PyTorch may have already resolved it. Either way, this is likely a bug on our side, and we never tested ...
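For context, here is a minimal sketch (not DGL's actual code; `ScaleByNorm` is a made-up example) of the standard PyTorch alternative the thread alludes to: tensors handed to `ctx.save_for_backward` are managed by autograd and released together with the graph, so no explicit clearing is needed and a second backward with `retain_graph=True` keeps working.

```python
import torch


class ScaleByNorm(torch.autograd.Function):
    """Hypothetical custom op: y = x * ||x||, used only to illustrate save_for_backward."""

    @staticmethod
    def forward(ctx, x):
        norm = x.norm()
        ctx.save_for_backward(x, norm)   # managed by autograd; no manual clearing needed
        return x * norm

    @staticmethod
    def backward(ctx, grad_out):
        x, norm = ctx.saved_tensors      # still available on a second backward pass
        return grad_out * norm + (grad_out * x).sum() * x / norm


x = torch.randn(5, requires_grad=True)
y = ScaleByNorm.apply(x).sum()
y.backward(retain_graph=True)            # first backward
y.backward()                             # second backward also succeeds
```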
@maximillian91 Could you give a minimal example of the failure? I tried the following:

```python
import dgl
import numpy as np
import scipy.sparse as ssp
from dgl.nn.pytorch import SAGEConv
import torch
import torch.nn as nn

g = dgl.DGLGraph(ssp.random(20, 20, 0.2))
x = torch.randn(20, 10)
m = nn.Linear(10, 20)
g.ndata['x'] = x
g.ndata['w'] = m(x)
g.update_all(dgl.function.copy_u('w', 'm'), dgl.function.sum('m', 'y'))
g.update_all(dgl.function.copy_u('w', 'm'), dgl.function.max('m', 'z'))
loss = g.ndata['y'].sum()
loss2 = g.ndata['z'].sum()
loss.backward(retain_graph=True)
loss2.backward()
```

As for second-order derivatives, I'm not sure if or when we would support them, since that requires another kernel to compute such a derivative, and so far I'm not aware of any model that needs this functionality. As for the memory leak, please see #1060, although I still need to confirm a failing example of ...
My "minimal" failing example became this based on my implementation of the Deep Tensor Neural Network K. Schütt 2017, where the error can be produced (and thereby also resolved) under 2 circumstances:
Here are the last 20 lines of the failing code; the rest is in the .zip.
It fails with the following error message:
I know that my example is basically wrong and sloppy in the sense that I do not need to backprop gradients through the ...
@maximillian91 Your code ran fine with the latest ...
@BarclayII I tried it now, and the issue was solved when building from source from the latest ...
❓ Questions and Help
Any reason for clearing `ctx.backward_cache` explicitly in every `backward()` method?
This causes an error when calling `loss.backward(retain_graph=True)` twice: `TypeError: 'NoneType' object is not iterable`.
Here's an example:
typeerror-nonetype-object-is-not-iterable-in-loss-backward-in-pytorch
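For illustration, here is a minimal sketch (assumed, not DGL's exact implementation; `CachedDouble` and its cached scale are hypothetical) of the pattern described above: `forward()` stashes state on `ctx.backward_cache`, and `backward()` unpacks it and then clears it. With `retain_graph=True` the same `ctx` is reused on the second backward, where unpacking `None` raises the reported TypeError.

```python
import torch


class CachedDouble(torch.autograd.Function):
    """Hypothetical op y = 2x that mimics a hand-rolled backward cache on ctx."""

    @staticmethod
    def forward(ctx, x):
        ctx.backward_cache = (2.0,)        # hand-rolled cache stored on ctx
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        (scale,) = ctx.backward_cache      # unpacking None here raises the TypeError
        ctx.backward_cache = None          # explicit clearing, as described in the issue
        return grad_out * scale


x = torch.randn(3, requires_grad=True)
loss = CachedDouble.apply(x).sum()
loss.backward(retain_graph=True)           # first call succeeds
loss.backward()                            # second call fails: None is not iterable/unpackable
```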