From 00f1cdf247e063cc40c27355ac691853acfcd2c6 Mon Sep 17 00:00:00 2001 From: Eldar Kurtic Date: Fri, 29 Jul 2022 12:27:27 +0200 Subject: [PATCH] Clear cache before and after the OBS pruning step Pytorch caching doesn't seem to work consistently across different versions, so let's make sure we clear it explicitly before and after the pruning step. --- .../pytorch/sparsification/pruning/modifier_pruning_obs.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sparseml/pytorch/sparsification/pruning/modifier_pruning_obs.py b/src/sparseml/pytorch/sparsification/pruning/modifier_pruning_obs.py index b63faef357d..104ef14ea18 100644 --- a/src/sparseml/pytorch/sparsification/pruning/modifier_pruning_obs.py +++ b/src/sparseml/pytorch/sparsification/pruning/modifier_pruning_obs.py @@ -246,6 +246,7 @@ def check_mask_update( return # not a one-shot run _LOGGER.info("Running OBS Pruning") + torch.cuda.empty_cache() if self._scorer._is_main_proc: # collect grads for empirical inverse Fisher estimation self._scorer._enabled_grad_buffering = True @@ -254,6 +255,7 @@ def check_mask_update( self._scorer._enabled_grad_buffering = False super().check_mask_update(module, epoch, steps_per_epoch, **kwargs) + torch.cuda.empty_cache() def _get_mask_creator( self, param_names: List[str], params: List[Parameter]