2 changes: 1 addition & 1 deletion src/sparseml/pytorch/optim/mask_pruning.py
@@ -421,7 +421,7 @@ def pre_optim_step_update(self):
updates scores and buffers that depend on gradients. Should be called
before Optimizer.step() to grab the latest gradients
"""
- self._scorer.pre_optim_step_update()
+ self._scorer.pre_optim_step_update(self._param_masks)

def pruning_end(self, leave_enabled: bool):
"""
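For orientation, here is a hedged sketch (names and objects invented for illustration, not taken from the PR) of where this call sits: the mask wrapper hands its current parameter masks to the scorer after backward() and before Optimizer.step(), as the docstring above requires.

```python
import torch

class _StubScorer:
    """Stand-in for any PruningParamsScorer implementation (illustrative only)."""

    def pre_optim_step_update(self, masks):
        # a real scorer would update gradient-dependent state here
        pass

model = torch.nn.Linear(8, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
masks = [torch.ones_like(model.weight)]  # the wrapper's current param masks
scorer = _StubScorer()

loss = model(torch.randn(16, 8)).sum()
loss.backward()

# called between backward() and step(), now carrying the latest masks
scorer.pre_optim_step_update(masks)
optimizer.step()
optimizer.zero_grad()
```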
24 changes: 14 additions & 10 deletions src/sparseml/pytorch/optim/mask_pruning_scorer.py
@@ -58,10 +58,12 @@ def score_parameters(self) -> List[Tensor]:
"""
raise NotImplementedError()

- def pre_optim_step_update(self):
+ def pre_optim_step_update(self, masks: List[Tensor]):
"""
Perform any required logic for tracking Parameter data and gradients before
an Optimizer step is applied to the model.
+
+ :param masks: latest masks that are applied to these parameters
"""
pass
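To make the new contract concrete, here is a toy scorer that honors the updated signature. This class is not part of sparseml; all names are invented for illustration.

```python
from typing import List

import torch
from torch import Tensor


class ToyGradScorer:
    """Accumulates |grad| only on weights the mask keeps (illustrative only)."""

    def __init__(self, params: List[Tensor]):
        self._params = params
        self._acc = [torch.zeros_like(p) for p in params]

    def pre_optim_step_update(self, masks: List[Tensor]):
        # restrict gradient bookkeeping to unpruned entries
        for acc, param, mask in zip(self._acc, self._params, masks):
            if param.grad is not None:
                acc += param.grad.abs() * mask

    def score_parameters(self) -> List[Tensor]:
        return [acc.clone() for acc in self._acc]
```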

@@ -226,9 +228,11 @@ def score_parameters(self) -> List[Tensor]:

return self._movement_scores

- def pre_optim_step_update(self):
+ def pre_optim_step_update(self, masks: List[Tensor]):
"""
Update movement scores based on the current Parameter weights and gradients
+
+ :param masks: latest masks that are applied to these parameters
"""
self.check_regen_param_vals()
for idx, param in enumerate(self._params):
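The body of this loop is collapsed in the diff. As a reminder of the general movement-pruning recipe (Sanh et al., 2020) it follows, each score accumulates the negative product of a weight and its gradient, so weights that training is driving toward zero end up with low scores. A generic sketch, not the exact sparseml body:

```python
import torch

# stand-ins for the scorer's tracked parameters and per-weight scores
params = [torch.nn.Parameter(torch.randn(4, 4))]
movement_scores = [torch.zeros_like(p) for p in params]

loss = (params[0] ** 2).sum()
loss.backward()

# generic movement-score update: decrease where grad and weight share a sign,
# i.e. where the optimizer is shrinking the weight toward zero
for score, param in zip(movement_scores, params):
    if param.grad is not None:
        score -= param.grad * param.data
```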
@@ -374,17 +378,19 @@ def score_parameters(self) -> List[Tensor]:

return param_scores

- def pre_optim_step_update(self):
+ def pre_optim_step_update(self, masks: List[Tensor]):
"""
Update the gradient buffer based on the current gradients
+
+ :param masks: latest masks that are applied to these parameters
"""

if any(param.grad is None for param in self._params):
# only update buffer if all gradients are computed
return

if self._grad_buffer is None:
- self._setup_grad_buffer()
+ self._setup_grad_buffer(masks)

# get non-pruned grads
non_pruned_grads = [
@@ -432,7 +438,7 @@ def mask_update(self, masks: List[Tensor], mask_diffs: List[Tensor]):

self._latest_h_inv_diag = None # clear h_inv
self._grads = None # clear grads
- self._setup_grad_buffer() # reset grad buffer
+ self._setup_grad_buffer(masks) # reset grad buffer
torch.cuda.empty_cache() # release GPU memory

@staticmethod
@@ -509,12 +515,10 @@ def _calc_params_perterb(self, mask_diffs):
h_inv, diag = self._latest_h_inv_diag
return h_inv.mul(-1.0 * weights_to_prune / diag)

- def _setup_grad_buffer(self):
+ def _setup_grad_buffer(self, masks: List[Tensor]):
total_nonzero = 0
- for idx, param in enumerate(self._params):
-     self._unpruned_idxs[idx] = (
-         param.view(-1).nonzero(as_tuple=False).reshape(-1)
-     )
+ for idx, mask in enumerate(masks):
+     self._unpruned_idxs[idx] = mask.view(-1).nonzero(as_tuple=False).reshape(-1)
total_nonzero += self._unpruned_idxs[idx].numel()
# only track nonzero grads
num_grads = self._mfac_options.get_num_grads_for_sparsity(
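The substance of the `_setup_grad_buffer` change is that unpruned indices are now derived from the mask rather than from the weight values, so a weight that happens to be exactly zero but is not pruned stays in the gradient buffer. A toy comparison with invented tensors:

```python
import torch

param = torch.tensor([0.5, 0.0, -1.2, 0.0])  # second entry is zero but NOT pruned
mask = torch.tensor([1.0, 1.0, 1.0, 0.0])    # only the last entry is pruned

# previous behavior: index by nonzero weights, so the zero-valued weight is dropped
old_idxs = param.view(-1).nonzero(as_tuple=False).reshape(-1)  # tensor([0, 2])

# new behavior: index by the mask, so every unpruned weight stays tracked
new_idxs = mask.view(-1).nonzero(as_tuple=False).reshape(-1)   # tensor([0, 1, 2])
```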
3 changes: 2 additions & 1 deletion tests/sparseml/pytorch/optim/test_mask_pruning_scorer.py
@@ -52,7 +52,8 @@ def test_pruning_scorer(score_type, n_updates):

for i in range(n_updates):
_fake_params_random_update(params)
- scorer.pre_optim_step_update()
+ fake_masks = [(param != 0).type(param.dtype) for param in params]
+ scorer.pre_optim_step_update(fake_masks)
scores = scorer.score_parameters()
assert len(scores) == len(params)

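Since the test has no real pruning masks, it rebuilds them from the parameters themselves; a quick illustration of what `(param != 0).type(param.dtype)` yields on a toy tensor:

```python
import torch

param = torch.tensor([0.3, 0.0, -2.0])
fake_mask = (param != 0).type(param.dtype)
# fake_mask -> tensor([1., 0., 1.]): ones wherever the weight is nonzero (unpruned)
```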