diff --git a/neural_compressor/compression/pruner/patterns/ninm.py b/neural_compressor/compression/pruner/patterns/ninm.py index b14edc8bc9f..d02508cbbcd 100644 --- a/neural_compressor/compression/pruner/patterns/ninm.py +++ b/neural_compressor/compression/pruner/patterns/ninm.py @@ -368,8 +368,8 @@ def get_pattern_lock_masks(self, modules): mask = torch.ones(orig_shape, device=weight.device) pattern_lock_masks[key] = mask.bool() continue - mask = self.get_least_ninm_mask_from_data(weight) - mask = self._reshape_2dims_to_orig(mask, orig_shape) + reduced_mask = self.get_reduced_masks_from_data(weight, key) + mask = self.reshape_reduced_to_orig(reduced_mask, key, orig_shape) pattern_lock_masks[key] = mask return pattern_lock_masks diff --git a/neural_compressor/compression/pruner/pruners/progressive.py b/neural_compressor/compression/pruner/pruners/progressive.py index 2e2a460889c..d972c712a18 100644 --- a/neural_compressor/compression/pruner/pruners/progressive.py +++ b/neural_compressor/compression/pruner/pruners/progressive.py @@ -23,11 +23,12 @@ from ..regs import get_reg from ..schedulers import get_scheduler from ..utils import logger, torch -from .base import PytorchBasePruner, register_pruner +from .base import register_pruner +from .basic import PytorchBasicPruner @register_pruner("pt_progressive") -class PytorchProgressivePruner(PytorchBasePruner): +class PytorchProgressivePruner(PytorchBasicPruner): """Pruning Pruner. A Pruner class derived from BasicPruner. In this pruner, mask interpolation will be applied. @@ -207,12 +208,12 @@ def update_masks_progressive(self, local_step): for n in self.masks.keys(): self.pre_masks[n] = self.masks[n].clone() # update new masks - if not self.use_progressive: - self.masks = self.pattern.get_masks( - self.criterion.scores, - current_target_sparsity_ratio, - self.masks, - ) + # if not self.use_progressive: + # self.masks = self.pattern.get_masks( + # self.criterion.scores, + # current_target_sparsity_ratio, + # self.masks, + # ) self.masks = self.pattern.get_masks( self.criterion.scores, current_target_sparsity_ratio,