Skip to content

Commit

Permalink
Add periodic CUDA cache cleanup (#882)
Browse files Browse the repository at this point in the history
Summary:
This adds a periodic call to `torch.cuda.empty_cache()` in order to
mitigate memory fragmentation in the PyTorch CUDA cached allocator
that can cause OOMs on models approaching GPU memory limit.
By default, this will occur every 64 updates.

Performance considerations:

- I've benchmarked this on a reasonably large model with memory
  footprint 16 GB, and the overhead with the default setting is <0.2%.
  With `update-freq > 1`, the cost is mitigated even further.
- This behavior can be disabled with a value of zero.
Pull Request resolved: fairinternal/fairseq-py#882

Differential Revision: D17742386

Pulled By: jma127

fbshipit-source-id: 68d8f93f798d6818b5efc3d67d43b52dfb8b2865
  • Loading branch information
jma127 authored and facebook-github-bot committed Oct 4, 2019
1 parent de348d1 commit 315c463
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
2 changes: 2 additions & 0 deletions fairseq/options.py
Expand Up @@ -193,6 +193,8 @@ def get_parser(desc, default_task='translation'):
help='threshold FP16 loss scale from below')
parser.add_argument('--user-dir', default=None,
help='path to a python module containing custom extensions (tasks and/or architectures)')
parser.add_argument('--empty-cache-freq', default=0, type=int,
help='how often to clear the PyTorch CUDA cache (0 to disable)')

from fairseq.registry import REGISTRIES
for registry_name, REGISTRY in REGISTRIES.items():
Expand Down
8 changes: 8 additions & 0 deletions fairseq/trainer.py
Expand Up @@ -426,6 +426,14 @@ def maybe_no_sync():

if 'nll_loss' in logging_output:
self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)

# clear CUDA cache to reduce memory fragmentation
if (self.args.empty_cache_freq > 0 and
((self.get_num_updates() + self.args.empty_cache_freq - 1) %
self.args.empty_cache_freq) == 0 and
torch.cuda.is_available() and
not self.args.cpu):
torch.cuda.empty_cache()
except OverflowError as e:
print('| WARNING: overflow detected, ' + str(e))
self.zero_grad()
Expand Down

0 comments on commit 315c463

Please sign in to comment.