diff --git a/fairseq/options.py b/fairseq/options.py
index 06a52b62ba..c33e1ac8e9 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -193,6 +193,8 @@ def get_parser(desc, default_task='translation'):
                         help='threshold FP16 loss scale from below')
     parser.add_argument('--user-dir', default=None,
                         help='path to a python module containing custom extensions (tasks and/or architectures)')
+    parser.add_argument('--empty-cache-freq', default=0, type=int,
+                        help='how often to clear the PyTorch CUDA cache (0 to disable)')
 
     from fairseq.registry import REGISTRIES
     for registry_name, REGISTRY in REGISTRIES.items():
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 8e911a2174..03a1333ff1 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -426,6 +426,14 @@ def maybe_no_sync():
             if 'nll_loss' in logging_output:
                 self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
+
+            # clear CUDA cache to reduce memory fragmentation
+            if (self.args.empty_cache_freq > 0 and
+                    ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
+                     self.args.empty_cache_freq) == 0 and
+                    torch.cuda.is_available() and
+                    not self.args.cpu):
+                torch.cuda.empty_cache()
         except OverflowError as e:
             print('| WARNING: overflow detected, ' + str(e))
             self.zero_grad()
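
The trigger condition deserves a note: `(num_updates + empty_cache_freq - 1) % empty_cache_freq == 0` fires on the first update and then once every `empty_cache_freq` updates, rather than waiting `empty_cache_freq` updates before the first clear. Below is a minimal standalone sketch of the same arithmetic; the helper name `should_empty_cache` is hypothetical and not part of the patch.

import torch

def should_empty_cache(num_updates, empty_cache_freq):
    """Hypothetical helper mirroring the condition added to Trainer.train_step:
    fire on the first update and then once every `empty_cache_freq` updates."""
    return (
        empty_cache_freq > 0
        and (num_updates + empty_cache_freq - 1) % empty_cache_freq == 0
    )

# With --empty-cache-freq=100, updates 1, 101, 201, ... trigger a cache clear.
for n in (1, 50, 100, 101, 201):
    if should_empty_cache(n, 100) and torch.cuda.is_available():
        torch.cuda.empty_cache()  # release unused blocks held by the caching allocator

`torch.cuda.empty_cache()` returns cached but unused memory to the driver, which can reduce fragmentation-related OOMs at the cost of slower subsequent allocations; hence the option defaults to 0 (disabled) and is skipped entirely when running with `--cpu` or without CUDA.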