Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions neural_compressor/compression/pruner/patterns/ninm.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ def get_pattern_lock_masks(self, modules):
mask = torch.ones(orig_shape, device=weight.device)
pattern_lock_masks[key] = mask.bool()
continue
mask = self.get_least_ninm_mask_from_data(weight)
mask = self._reshape_2dims_to_orig(mask, orig_shape)
reduced_mask = self.get_reduced_masks_from_data(weight, key)
mask = self.reshape_reduced_to_orig(reduced_mask, key, orig_shape)
pattern_lock_masks[key] = mask
return pattern_lock_masks

Expand Down
17 changes: 9 additions & 8 deletions neural_compressor/compression/pruner/pruners/progressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
from ..regs import get_reg
from ..schedulers import get_scheduler
from ..utils import logger, torch
from .base import PytorchBasePruner, register_pruner
from .base import register_pruner
from .basic import PytorchBasicPruner


@register_pruner("pt_progressive")
class PytorchProgressivePruner(PytorchBasePruner):
class PytorchProgressivePruner(PytorchBasicPruner):
"""Pruning Pruner.

A Pruner class derived from BasicPruner. In this pruner, mask interpolation will be applied.
Expand Down Expand Up @@ -207,12 +208,12 @@ def update_masks_progressive(self, local_step):
for n in self.masks.keys():
self.pre_masks[n] = self.masks[n].clone()
# update new masks
if not self.use_progressive:
self.masks = self.pattern.get_masks(
self.criterion.scores,
current_target_sparsity_ratio,
self.masks,
)
# if not self.use_progressive:
# self.masks = self.pattern.get_masks(
# self.criterion.scores,
# current_target_sparsity_ratio,
# self.masks,
# )
self.masks = self.pattern.get_masks(
self.criterion.scores,
current_target_sparsity_ratio,
Expand Down