[Performance] Redirect AllocWorkspace to PyTorch's allocator if available #4199

Merged: 9 commits merged into dmlc:master on Jul 7, 2022

Conversation

@yaox12 (Collaborator) commented Jul 1, 2022

Description

Related issues: #3933 #3957.

Redirect AllocWorkspace/FreeWorkspace to PyTorch's allocator via raw_alloc and raw_delete.
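
For reference, here is a minimal sketch of the redirection (assumptions: a libtorch build with CUDA; the wrapper names are illustrative, not DGL's actual symbols):

```cpp
// Sketch only: back workspace allocations with PyTorch's caching allocator.
#include <cstddef>

#include <c10/cuda/CUDACachingAllocator.h>

void* CudaAllocWorkspace(std::size_t nbytes) {
  // The block comes out of PyTorch's cache, so it is counted by
  // torch.cuda.memory_allocated()/memory_reserved() on the Python side.
  return c10::cuda::CUDACachingAllocator::raw_alloc(nbytes);
}

void CudaFreeWorkspace(void* ptr) {
  // Memory from raw_alloc has no automatic lifetime; it must be returned
  // through raw_delete.
  c10::cuda::CUDACachingAllocator::raw_delete(ptr);
}
```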

I ran examples/pytorch/graphsage/node_classification.py and got the following GPU memory footprints:

|                          | nvidia-smi | max_allocated | allocated | max_reserved | reserved |
|--------------------------|-----------:|--------------:|----------:|-------------:|---------:|
| new allocator + pure_gpu | 10629      | 9944          | 5542      | 9971         | 9971     |
| old allocator + pure_gpu | 10645      | 7930          | 5243      | 7958         | 7958     |
| new allocator + uva      | 2531       | 550           | 266       | 1480         | 1480     |
| old allocator + uva      | 2591       | 550           | 265       | 1241         | 1241     |

*The four columns on the right are reported by torch.cuda.max_memory_allocated / memory_allocated / max_memory_reserved / memory_reserved.

The total GPU memory footprints are close between the two allocators. The advantages of the new one are:

  1. Users can release the reserved GPU memory via PyTorch's APIs if they'd like to (see the sketch after this list).
  2. When PyTorch reports an OOM, users won't see a big discrepancy between the memory used by PyTorch and the GPU's capacity.
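
As a concrete illustration of point 1: since the workspace blocks now live inside PyTorch's caching allocator, emptying that cache also returns unused workspace memory to the driver. A hedged C++ sketch (the Python equivalent is torch.cuda.empty_cache()):

```cpp
// Sketch: release cached (reserved but currently unused) GPU memory
// back to the driver via PyTorch's caching allocator.
#include <c10/cuda/CUDACachingAllocator.h>

void ReleaseCachedGpuMemory() {
  // torch.cuda.empty_cache() reaches the same entry point from Python.
  c10::cuda::CUDACachingAllocator::emptyCache();
}
```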

Checklist

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [$CATEGORY] (such as [NN], [Model], [Doc], [Feature])
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented
  • To the best of my knowledge, examples are either not affected by this change,
    or have been fixed to be compatible with this change
  • Related issue is referred to in this PR

@dgl-bot (Collaborator) commented Jul 1, 2022

To trigger regression tests:

  • @dgl-bot run [instance-type] [which tests] [compare-with-branch];
    For example: @dgl-bot run g4dn.4xlarge all dmlc/master or @dgl-bot run c5.9xlarge kernel,api dmlc/master

(Three @dgl-bot comments were marked as outdated.)

@jermainewang added the "Release Candidate" (Candidate PRs for the upcoming release) label on Jul 4, 2022
@dgl-bot (Collaborator) commented Jul 5, 2022

Commit ID: c07819c

Build ID: 4

Status: ✅ CI test succeeded

Report path: link

Full logs path: link

@yaox12 requested a review from nv-dlasalle on July 5, 2022 at 09:32
@yaox12 (Collaborator, Author) commented Jul 6, 2022

More background on PyTorch's CUDA allocator:

  1. Call stack of the relevant objects: CudaCachingAllocator device_allocator → THCCachingAllocator caching_allocator → std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocator.

  2. There are mainly two ways in PyTorch to allocate CUDA memory (see the sketch after this list):

    1. c10::cuda::CUDACachingAllocator::get()->allocate() [code]. It uses CudaCachingAllocator device_allocator and provides life-cycle management (the memory is freed automatically when the returned DataPtr goes out of scope). Internally, it calls THCCachingAllocator caching_allocator for the actual allocation. [code]
    2. c10::cuda::CUDACachingAllocator::raw_alloc()/raw_delete() [code]. It uses THCCachingAllocator caching_allocator directly, and the memory must be freed manually via raw_delete.
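
A small sketch contrasting the two paths (assuming a libtorch build with CUDA; the size is arbitrary):

```cpp
// Sketch of the two allocation paths described above.
#include <c10/core/Allocator.h>

#include <c10/cuda/CUDACachingAllocator.h>

void TwoWaysToAllocate() {
  // (1) allocate(): returns a c10::DataPtr whose deleter frees the block
  //     automatically when it goes out of scope.
  c10::DataPtr managed = c10::cuda::CUDACachingAllocator::get()->allocate(1024);

  // (2) raw_alloc()/raw_delete(): the caller owns the raw pointer and must
  //     free it explicitly; this is the path this PR uses for workspaces.
  void* raw = c10::cuda::CUDACachingAllocator::raw_alloc(1024);
  c10::cuda::CUDACachingAllocator::raw_delete(raw);
}  // `managed` is released here by its deleter.
```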

cc @nv-dlasalle

@dgl-bot (Collaborator) commented Jul 7, 2022

Commit ID: d210032

Build ID: 5

Status: ✅ CI test succeeded

Report path: link

Full logs path: link

@dgl-bot (Collaborator) commented Jul 7, 2022

Commit ID: e2ec146

Build ID: 6

Status: ✅ CI test succeeded

Report path: link

Full logs path: link

@yaox12 merged commit 9ee7ced into dmlc:master on Jul 7, 2022
BarclayII pushed a commit to BarclayII/dgl that referenced this pull request Aug 10, 2022
@frozenbugs removed the "Release Candidate" (Candidate PRs for the upcoming release) label on Jan 11, 2023
@shintarok111 commented

I encountered the same issue as in #3933. Has the problem been resolved?
