From e6ed5905726e448143ff08fcd5be23737a01c66f Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Mon, 14 Jun 2021 16:58:11 -0400
Subject: [PATCH 1/2] initialize M-FAC grad buffer based on masks, not param values

---
 src/sparseml/pytorch/optim/mask_pruning.py    |  2 +-
 .../pytorch/optim/mask_pruning_scorer.py      | 24 +++++++++++--------
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/src/sparseml/pytorch/optim/mask_pruning.py b/src/sparseml/pytorch/optim/mask_pruning.py
index e68c2558418..698a6426383 100644
--- a/src/sparseml/pytorch/optim/mask_pruning.py
+++ b/src/sparseml/pytorch/optim/mask_pruning.py
@@ -421,7 +421,7 @@ def pre_optim_step_update(self):
         updates scores and buffers that depend on gradients.
         Should be called before Optimizer.step() to grab the latest gradients
         """
-        self._scorer.pre_optim_step_update()
+        self._scorer.pre_optim_step_update(self._param_masks)
 
     def pruning_end(self, leave_enabled: bool):
         """
diff --git a/src/sparseml/pytorch/optim/mask_pruning_scorer.py b/src/sparseml/pytorch/optim/mask_pruning_scorer.py
index 482d903f551..c30de20ecb8 100644
--- a/src/sparseml/pytorch/optim/mask_pruning_scorer.py
+++ b/src/sparseml/pytorch/optim/mask_pruning_scorer.py
@@ -58,10 +58,12 @@ def score_parameters(self) -> List[Tensor]:
         """
         raise NotImplementedError()
 
-    def pre_optim_step_update(self):
+    def pre_optim_step_update(self, masks: List[Tensor]):
         """
         Perform any required logic for tracking Parameter data and gradients
         before an Optimizer step is applied to the model.
+
+        :param masks: latest masks that are applied to these parameters
         """
         pass
 
@@ -226,9 +228,11 @@ def score_parameters(self) -> List[Tensor]:
 
         return self._movement_scores
 
-    def pre_optim_step_update(self):
+    def pre_optim_step_update(self, masks: List[Tensor]):
         """
         Update movement scores based on the current Parameter weights and gradients
+
+        :param masks: latest masks that are applied to these parameters
         """
         self.check_regen_param_vals()
         for idx, param in enumerate(self._params):
@@ -374,9 +378,11 @@ def score_parameters(self) -> List[Tensor]:
 
         return param_scores
 
-    def pre_optim_step_update(self):
+    def pre_optim_step_update(self, masks: List[Tensor]):
         """
         Update the gradient buffer based on the current gradients
+
+        :param masks: latest masks that are applied to these parameters
         """
 
         if any(param.grad is None for param in self._params):
@@ -384,7 +390,7 @@ def pre_optim_step_update(self):
             return
 
         if self._grad_buffer is None:
-            self._setup_grad_buffer()
+            self._setup_grad_buffer(masks)
 
         # get non-pruned grads
         non_pruned_grads = [
@@ -432,7 +438,7 @@ def mask_update(self, masks: List[Tensor], mask_diffs: List[Tensor]):
 
         self._latest_h_inv_diag = None  # clear h_inv
         self._grads = None  # clear grads
-        self._setup_grad_buffer()  # reset grad buffer
+        self._setup_grad_buffer(masks)  # reset grad buffer
         torch.cuda.empty_cache()  # release GPU memory
 
     @staticmethod
@@ -509,12 +515,10 @@ def _calc_params_perterb(self, mask_diffs):
         h_inv, diag = self._latest_h_inv_diag
         return h_inv.mul(-1.0 * weights_to_prune / diag)
 
-    def _setup_grad_buffer(self):
+    def _setup_grad_buffer(self, masks: List[Tensor]):
        total_nonzero = 0
-        for idx, param in enumerate(self._params):
-            self._unpruned_idxs[idx] = (
-                param.view(-1).nonzero(as_tuple=False).reshape(-1)
-            )
+        for idx, mask in enumerate(masks):
+            self._unpruned_idxs[idx] = mask.view(-1).nonzero(as_tuple=False).reshape(-1)
             total_nonzero += self._unpruned_idxs[idx].numel()
         # only track nonzero grads
         num_grads = self._mfac_options.get_num_grads_for_sparsity(
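Why this patch pulls the unpruned indices from the masks rather than the parameter values: a weight can be exactly zero without having been pruned, so nonzero() over the parameter values under-counts the entries the gradient buffer must track, while nonzero() over the mask matches exactly what the pruning modifier applied. A minimal standalone sketch of that difference, with made-up tensor values; only the nonzero(as_tuple=False).reshape(-1) pattern mirrors _setup_grad_buffer:

import torch

# masked parameter: indices 2 and 3 are pruned; index 1 is kept but happens to equal 0.0
param = torch.tensor([0.5, 0.0, 0.0, 0.0])
mask = torch.tensor([1.0, 1.0, 0.0, 0.0])

value_based_idxs = param.view(-1).nonzero(as_tuple=False).reshape(-1)  # tensor([0]) - drops kept index 1
mask_based_idxs = mask.view(-1).nonzero(as_tuple=False).reshape(-1)    # tensor([0, 1]) - every kept weight
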
From 4651a0caee955984b1d741dbc079d62a3225c8ca Mon Sep 17 00:00:00 2001
From: Benjamin
Date: Mon, 14 Jun 2021 17:46:50 -0400
Subject: [PATCH 2/2] update tests

---
 tests/sparseml/pytorch/optim/test_mask_pruning_scorer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/sparseml/pytorch/optim/test_mask_pruning_scorer.py b/tests/sparseml/pytorch/optim/test_mask_pruning_scorer.py
index 8cc3031f127..0c88d4afbac 100644
--- a/tests/sparseml/pytorch/optim/test_mask_pruning_scorer.py
+++ b/tests/sparseml/pytorch/optim/test_mask_pruning_scorer.py
@@ -52,7 +52,8 @@ def test_pruning_scorer(score_type, n_updates):
 
     for i in range(n_updates):
         _fake_params_random_update(params)
-        scorer.pre_optim_step_update()
+        fake_masks = [(param != 0).type(param.dtype) for param in params]
+        scorer.pre_optim_step_update(fake_masks)
 
     scores = scorer.score_parameters()
     assert len(scores) == len(params)
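In the updated test, the masks handed to pre_optim_step_update are rebuilt from the fake parameters, which is safe there because every zero entry in those tensors really is a pruned entry; in the modifier itself the masks come from self._param_masks rather than being re-derived from values. A small standalone check of that helper expression, with made-up tensors:

import torch

# fake parameters in the state the test leaves them: some entries already zeroed out (as-if pruned)
params = [torch.tensor([[0.3, 0.0], [-0.7, 0.0]]), torch.tensor([1.5, 0.0, 2.0])]

# same expression as in the test: 1.0 where a weight is nonzero, 0.0 where it is zero
fake_masks = [(param != 0).type(param.dtype) for param in params]

assert fake_masks[0].tolist() == [[1.0, 0.0], [1.0, 0.0]]
assert fake_masks[1].tolist() == [1.0, 0.0, 1.0]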