
[BUGFIX] Fix the lazy device copy issue of DGL node/edge features. #6564

Merged

czkkkkkk merged 3 commits into dmlc:master from fix6542 on Nov 23, 2023

Conversation

czkkkkkk (Collaborator)

Description

Resolves #6542.

In the following code snippet,

```python
import dgl
import torch
from dgl.dataloading import GraphDataLoader

dataset = dgl.data.QM9EdgeDataset()
dataloader = GraphDataLoader(dataset, batch_size=64)

graph_list = []

for batch_graph, _ in dataloader:
    batch_graph = batch_graph.to('cuda:0')
    split_graphs = dgl.unbatch(batch_graph)

    # Copy each split graph back to CPU; the GPU copies should now be freeable.
    graph_list.extend([graph.cpu() for graph in split_graphs])

    print(f'memory allocation: {torch.cuda.memory_allocated() / 1024**2} MiB')

    torch.cuda.empty_cache()
```

the GPU memory of the node and edge features of `split_graphs` is not released properly. This is because, when the split graphs are copied to CPU, the `Column` (DGL's internal feature storage) does not copy its data when `Column.to()` is called; the copy only happens when `Column.data` is accessed. In my fix, I simply copy the data to the target device in `Column.to()`. I am not sure whether there is any risk in doing it this way. @BarclayII, do you have concerns?
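To make the mechanism concrete, here is a minimal sketch of the lazy-copy pattern and the eager fix. `LazyColumn` and its members are hypothetical simplifications for illustration, not DGL's actual `Column` in python/dgl/frame.py.

```python
import torch


class LazyColumn:
    """Hypothetical, simplified stand-in for DGL's internal Column storage."""

    def __init__(self, data, target_device=None):
        self._data = data              # underlying feature tensor
        self._target = target_device   # device the column should move to

    def to(self, device):
        # Lazy behavior (the reported bug): only remember the target device.
        # The original CUDA tensor stays referenced, so moving a split graph
        # to CPU does not free its GPU memory.
        return LazyColumn(self._data, device)

    def to_eager(self, device):
        # The fix described above: copy immediately, so the GPU storage can
        # be released as soon as no other reference holds it.
        return LazyColumn(self._data.to(device))

    @property
    def data(self):
        # Under the lazy scheme, the actual copy happens only here.
        if self._target is not None:
            self._data = self._data.to(self._target)
            self._target = None
        return self._data
```

With the lazy variant, `graph.cpu()` in the snippet above leaves the CUDA tensors reachable until `.data` is read, which is why `torch.cuda.empty_cache()` cannot reclaim them; the eager variant drops the GPU reference right away.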

Checklist

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [$CATEGORY] (such as [NN], [Model], [Doc], [Feature])
  • I've leveraged the tools to beautify the Python and C++ code.
  • The PR is complete and small; read the Google eng practice (a CL equals a PR) to understand more about small PRs. In DGL, we consider PRs with fewer than 200 lines of core code change to be small (examples, tests, and documentation can be exempted).
  • All changes have test coverage
  • Code is well-documented
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change
  • Related issue is referred to in this PR
  • If the PR is for a new model/paper, I've updated the example index here.

Changes


dgl-bot commented Nov 14, 2023

To trigger regression tests:

  • @dgl-bot run [instance-type] [which tests] [compare-with-branch];
    For example: @dgl-bot run g4dn.4xlarge all dmlc/master or @dgl-bot run c5.9xlarge kernel,api dmlc/master

czkkkkkk changed the title from "[BUGFIX] Fix lazy the device copy issue of DGL node/edge features." to "[BUGFIX] Fix the lazy device copy issue of DGL node/edge features." on Nov 14, 2023

dgl-bot commented Nov 14, 2023

Commit ID: cd1d0a3

Build ID: 1

Status: ⚪️ CI test cancelled due to overrun.

Report path: link

Full logs path: link


dgl-bot commented Nov 14, 2023

Commit ID: 00148d6

Build ID: 2

Status: ❌ CI test failed in Stage [Torch GPU Unit test].

Report path: link

Full logs path: link

Review thread on python/dgl/frame.py (outdated, resolved)
czkkkkkk (Collaborator, Author) commented:

Some test cases failed. The lazy copy mechanism appears to be an intended feature.
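For context on why laziness can be desirable (my reading of the failing tests, not something confirmed from DGL's internals): with a lazy `to()`, moving a graph across devices is cheap until a column is actually read, so code that touches only a few features never pays for a full transfer. A sketch, reusing the hypothetical `LazyColumn` above:

```python
# Reuses the hypothetical LazyColumn sketch from earlier in this thread.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

feat = LazyColumn(torch.randn(100_000, 128))
moved = feat.to(device)  # lazy variant: O(1), nothing is transferred yet
tensor = moved.data      # the copy is paid only here, and only if read
```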


dgl-bot commented Nov 21, 2023

Commit ID: 68ca87009c18c7aa9e0e55b7e10ee42c33e67621

Build ID: 3

Status: ❌ CI test failed in Stage [Lint Check].

Report path: link

Full logs path: link


dgl-bot commented Nov 22, 2023

Commit ID: fed626d61eb5e8f769431527786112261514f6e8

Build ID: 4

Status: ✅ CI test succeeded.

Report path: link

Full logs path: link


dgl-bot commented Nov 22, 2023

Commit ID: c6c1dcd0ea70606bf1979c6ca6d96f6c1881f047

Build ID: 5

Status: ✅ CI test succeeded.

Report path: link

Full logs path: link

czkkkkkk merged commit c08f77b into dmlc:master on Nov 23, 2023 (2 checks passed)
czkkkkkk deleted the fix6542 branch on November 23, 2023 at 01:53
Development

Successfully merging this pull request may close these issues:

  • OOM issue with dgl.unbatch on GPU (#6542)