
Multiple Backwards Pass Error with Retain_Graph=True #4078

Closed
smiret-intel opened this issue Jun 2, 2022 · 9 comments

Labels: bug:confirmed (Something isn't working)
@smiret-intel

🐛 Bug

In order to run DGL-based GNNs on use cases like the OpenCatalyst project, multiple backwards passes are required, which are currently not supported in the latest DGL versions. OpenCatalyst is a major dataset for science-based GNN applications and currently supports only PyTorch Geometric, which does support multiple backwards passes.
We believe we have tracked this error to sparse.py and tensor.py in the PyTorch backend folder (linked here), where the context (ctx) gets set to None at various points (example). We have attached preliminary patches that appear to circumvent this issue, but they cause problems with distributed data parallel (DDP) training, which is necessary for effective training on OpenCatalyst.

Files for the backend-patch and the minimal test script referenced below are attached.
dgl_bug_report_double_autograd.zip

To Reproduce

Steps to reproduce the behavior:

  1. Run the attached minimal test script (autograd_tester.py or simple-test.py inside the backend-patch directory):

Multiple backwards passes are not possible, giving the following error:

Traceback (most recent call last):
  File "autograd_tester.py", line 83, in <module>
    train(g, model)
  File "autograd_tester.py", line 71, in train
    loss.backward()
  File "/opt/conda/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 253, in apply
    return user_fn(self, *args)
  File "/opt/conda/lib/python3.8/site-packages/torch/cuda/amp/autocast_mode.py", line 135, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/dgl/backend/pytorch/sparse.py", line 149, in backward
    gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last = ctx.backward_cache
TypeError: cannot unpack non-iterable NoneType object

Expected behavior

The minimal test script should run without issues, since multiple backwards passes can be done using retain_graph=True with other PyTorch functions. This should also work in a distributed data parallel setting and integrate with all other DGL functionality.
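For reference, a hypothetical minimal sketch in the spirit of the attached test script (the actual autograd_tester.py is not reproduced here; the toy graph, features, and GraphConv layer are placeholders chosen only to illustrate the double-backward pattern that triggers the error):

import dgl
import torch
from dgl.nn import GraphConv

# Tiny placeholder graph with self-loops so GraphConv sees no zero-in-degree nodes.
g = dgl.add_self_loop(dgl.graph(([0, 1, 2], [1, 2, 0])))
feat = torch.randn(3, 4)

model = GraphConv(4, 2)
loss = model(g, feat).sum()

# The first backward keeps the autograd graph alive ...
loss.backward(retain_graph=True)
# ... but the second backward hits DGL's cleared ctx.backward_cache and raises
# "TypeError: cannot unpack non-iterable NoneType object".
loss.backward()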

Environment

  • DGL Version (e.g., 1.0): 0.8.2
  • Backend Library & Version: PyTorch
  • OS (e.g., Linux): Linux
  • How you installed DGL (conda, pip, source): Conda install
  • Build command you used (if compiling from source): conda install -c dglteam dgl-cuda11.3
  • Python version: 3.8
  • CUDA/cuDNN version (if applicable): 11.3
  • Any other relevant information:

Additional context

The simple patch we provided allows the simple example to go through, but distributed data parallel training still fails with a segmentation fault, which indicates that our patch does not address the underlying problem. For our use case, DDP training runs are essential. DDP and multiple backwards passes worked with DGL versions 0.5 and 0.6 and were referenced in a prior issue (#1046).

@BarclayII
Collaborator

When do we need multiple backward passes? Do any of the examples in OCP require multiple backward passes?

@jermainewang jermainewang added the bug:unconfirmed May be a bug. Need further investigation. label Jun 6, 2022
@smiret-intel
Author

It's needed whenever forces are computed via the gradient, which is the case for most models (e.g., DimeNet, SchNet). Since forces are the derivative of the energy, and the output of the GNN model is the energy, the first pass computes the forces and the second pass performs the gradient updates for the GNN model.
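A minimal sketch of this force-as-gradient pattern (not the actual OCP training code; a toy MLP stands in for the energy-predicting GNN, and the target forces are dummy values):

import torch
import torch.nn as nn

energy_model = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 1))

pos = torch.randn(10, 3, requires_grad=True)   # atomic positions
target_forces = torch.randn(10, 3)             # reference forces (dummy values)

energy = energy_model(pos).sum()               # per-atom energies summed to a scalar

# First pass: forces = -dE/dpos, kept differentiable via create_graph=True.
forces = -torch.autograd.grad(energy, pos, create_graph=True)[0]

# Second pass: the force loss is backpropagated into the model weights, which
# requires differentiating through the first gradient computation.
loss = (forces - target_forces).pow(2).mean()
loss.backward()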

@BarclayII
Collaborator

I see. So does that mean we need to support second-order gradients of our message passing functions?

@jermainewang jermainewang added this to Issue in triage in DGL Tracker via automation Jun 13, 2022
@smiret-intel
Author

Yes - ideally it would be multiple-order gradients (at minimum second order, but potentially higher), with an API that is compatible with PyTorch's retain_graph=True and create_graph=True and works natively with PyTorch tensors (and also PyG).

@BarclayII
Collaborator

I checked the code and it seems that the statement ctx.backward_cache = None is breaking the retain_graph functionality, and removing it no longer causes a memory leak in PyTorch 1.10. Could you try removing that statement and see if it fixes your problem?
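For context, a simplified self-contained sketch (not DGL's actual sparse.py code) of why clearing a cached context in backward breaks retain_graph=True:

import torch

class CachedOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.backward_cache = (x.shape,)   # metadata stashed for backward
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        shape, = ctx.backward_cache
        ctx.backward_cache = None         # clearing here breaks a second backward:
        return grad_out * 2               # with retain_graph=True the same ctx is
                                          # reused, and unpacking None then fails

x = torch.randn(3, requires_grad=True)
y = CachedOp.apply(x).sum()
y.backward(retain_graph=True)   # first pass succeeds
y.backward()                    # second pass: cannot unpack non-iterable NoneType object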

@smiret-intel
Author

It fixes things temporarily, but the distributed data parallel (DDP) training does not work in that case. Whenever I have tried running things beyond the simple examples provided, there is a segmentation fault (both on GPU and CPU). I am not sure where a segmentation fault could materialize beyond the simple examples, but it has shown up consistently when training with DDP.

@smiret-intel
Author

Is there a minimal script to test if DDP training can run with the retain_graph functionality?

@BarclayII
Collaborator

Hmm, after removing all the ctx.backward_cache = None statements in both python/dgl/backend/pytorch/{sparse,tensor}.py, I tweaked examples/pytorch/graphsage/multi_gpu_node_classification.py (which uses DDP) as follows:

diff --git a/examples/pytorch/graphsage/multi_gpu_node_classification.py b/examples/pytorch/graphsage/multi_gpu_node_classification.py
index 2631541b..b81b027e 100644
--- a/examples/pytorch/graphsage/multi_gpu_node_classification.py
+++ b/examples/pytorch/graphsage/multi_gpu_node_classification.py
@@ -127,7 +127,7 @@ def train(rank, world_size, graph, num_classes, split_idx):
             use_uva=True)

     durations = []
-    for _ in range(10):
+    for _ in range(50):
         model.train()
         t0 = time.time()
         for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
@@ -136,6 +136,7 @@ def train(rank, world_size, graph, num_classes, split_idx):
             y_hat = model(blocks, x)
             loss = F.cross_entropy(y_hat, y)
             opt.zero_grad()
+            loss.backward(retain_graph=True)
             loss.backward()
             opt.step()
             if it % 20 == 0 and rank == 0:

It works without crashing.

@jermainewang jermainewang added bug:confirmed Something isn't working and removed bug:unconfirmed May be a bug. Need further investigation. labels Jun 23, 2022
@BarclayII
Collaborator

Closing due to #4249 being merged. Please reopen if it does not fix your problem.

DGL Tracker automation moved this from Issue in triage to Done Jul 18, 2022