42 changes: 25 additions & 17 deletions neural_compressor/experimental/pruning_v2.py
@@ -1,5 +1,5 @@
"""pruning module."""
#!/usr/bin/env python
# !/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Intel Corporation
@@ -30,13 +30,15 @@
from ..utils.utility import LazyImport
from ..pruner.pruners import get_pruner
from ..conf.pythonic_config import Config

LazyImport('torch.nn')
torch = LazyImport('torch')

from deprecated import deprecated
import importlib
import re


class Pruning(Component):
"""This is base class of pruning object.

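This hunk binds torch = LazyImport('torch'), so importing pruning_v2 no longer loads torch eagerly; the module is imported the first time one of its attributes is used. A minimal sketch of the lazy-import idea (a simplified stand-in, not the actual neural_compressor.utils.utility.LazyImport implementation):

import importlib

class _LazyModule:
    """Defer importing a module until one of its attributes is first accessed."""
    def __init__(self, module_name):
        self.module_name = module_name
        self._module = None

    def __getattr__(self, name):
        # Only reached for attributes not set in __init__, e.g. torch.zeros.
        if self._module is None:
            self._module = importlib.import_module(self.module_name)
        return getattr(self._module, name)

torch = _LazyModule("torch")   # cheap: nothing imported yet
x = torch.zeros(2, 3)          # torch is actually imported here (assumes torch is installed)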
@@ -71,12 +73,12 @@ def __init__(self, conf_fname_or_obj=None):
# yaml file
raise NotImplementedError("Only WeightPruningConfig config is supported currently.")
self.pruners_info = process_config(self.conf)
# self.model = None # here skip
# self.model = None # here skip
# align with old Component based API
# self._init_with_conf()
self.callbacks = dict(tf_pruning=TfPruningCallback)
self.pruners = []
self.generate_hooks() # place generate hooks here, to get rid of prepare() function.
self.generate_hooks() # place generate hooks here, to get rid of prepare() function.

def update_config(self, *args, **kwargs):
"""Add user-defined arguments to the original configurations.
@@ -134,6 +136,11 @@ def get_sparsity_ratio(self):
elementwise_over_all = float(
element_sparsity_cnt) / param_cnt

logger.info(
f"elementwise_over_matmul_gemm_conv:{elementwise_over_matmul_gemm_conv},"
f" elementwise_over_all:{elementwise_over_all},"
f"blockwise_over_matmul_gemm_conv:{blockwise_over_matmul_gemm_conv}")

return elementwise_over_matmul_gemm_conv, elementwise_over_all, blockwise_over_matmul_gemm_conv

def _on_train_begin(self, dataloader=None):
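The new logger.info call reports the three ratios get_sparsity_ratio already returns: element-wise sparsity over matmul/GEMM/conv layers, element-wise sparsity over all parameters, and block-wise sparsity over matmul/GEMM/conv layers. A rough, self-contained illustration of how an element-wise ratio is derived from zero counts (a hypothetical helper, not the method in this file):

import torch

def elementwise_sparsity(weights):
    """Fraction of exactly-zero elements across a list of weight tensors."""
    zero_cnt = sum(int((w == 0).sum()) for w in weights)
    total_cnt = sum(w.numel() for w in weights)
    return float(zero_cnt) / total_cnt if total_cnt else 0.0

# Two pruned layers with 2/4 and 1/4 zeros give an overall ratio of 3/8.
w1 = torch.tensor([[0.0, 1.0], [0.0, 2.0]])
w2 = torch.tensor([0.0, 1.0, 2.0, 3.0])
print(elementwise_sparsity([w1, w2]))   # 0.375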
@@ -188,6 +195,7 @@ def _on_train_end(self):
"""Functions called after training."""
for pruner in self.pruners:
pruner.on_train_end()
self.get_sparsity_ratio()

def _on_before_eval(self):
"""Implement at the beginning of evaluation phase."""
@@ -227,16 +235,16 @@ def pre_process(self):
if self._train_dataloader is None and self._train_func is None:
train_dataloader_cfg = self.cfg.pruning.train.dataloader
assert train_dataloader_cfg is not None, \
'dataloader field of train field of pruning section ' \
'in yaml file should be configured as train_dataloader property is NOT set!'
'dataloader field of train field of pruning section ' \
'in yaml file should be configured as train_dataloader property is NOT set!'
train_dataloader_cfg.distributed = self.train_distributed
self._train_dataloader = create_dataloader(self.framework, train_dataloader_cfg)

if self._eval_dataloader is None and self._eval_func is None:
eval_dataloader_cfg = self.cfg.evaluation.accuracy.dataloader
assert eval_dataloader_cfg is not None, \
'dataloader field of evaluation ' \
'in yaml file should be configured as eval_dataloader property is NOT set!'
'dataloader field of evaluation ' \
'in yaml file should be configured as eval_dataloader property is NOT set!'
eval_dataloader_cfg.distributed = self.evaluation_distributed
self._eval_dataloader = create_dataloader(self.framework, eval_dataloader_cfg)

@@ -246,22 +254,22 @@ def pre_process(self):
assert train_cfg, "train field of pruning section in yaml file must " \
"be configured for pruning if pruning_func is NOT set."
self._train_func = create_train_func(self.framework, \
self.train_dataloader, \
self.adaptor, \
train_cfg, \
hooks=self.hooks, \
callbacks=self.callbacks)
self.train_dataloader, \
self.adaptor, \
train_cfg, \
hooks=self.hooks, \
callbacks=self.callbacks)
if self._eval_func is None:
# eval section in yaml file should be configured.
eval_cfg = self.cfg.evaluation
assert eval_cfg, "eval field of pruning section in yaml file must " \
"be configured for pruning if eval_func is NOT set."
"be configured for pruning if eval_func is NOT set."
self._eval_func = create_eval_func(self.framework, \
self.eval_dataloader, \
self.adaptor, \
eval_cfg.accuracy.metric, \
eval_cfg.accuracy.postprocess, \
fp32_baseline = False)
fp32_baseline=False)
if getattr(self.train_dataloader, 'distributed', False):
self.register_hook('on_train_begin', self.adaptor._pre_hook_for_hvd)

@@ -272,14 +280,14 @@ def execute(self):
"""
logger.info("Start to get the baseline model's score before pruning.")
self.baseline_score = self._eval_func(self._model if getattr(self._eval_func, 'builtin', None) \
else self._model.model)
else self._model.model)
logger.info("Baseline model's score is {}.".format(str(self.baseline_score)))
logger.info("Model pruning begins.")
self._train_func(self._model if getattr(self._train_func, 'builtin', None) \
else self._model.model)
else self._model.model)
logger.info("Model pruning is done. Start to evaluate the pruned model.")
self.last_score = self._eval_func(self._model if getattr(self._eval_func, 'builtin', None) \
else self._model.model)
else self._model.model)
logger.info("Pruned model score is {}.".format(str(self.last_score)))
return self._model

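Taken together, the pruning_v2.py changes make _on_train_end report the final sparsity automatically, so callers no longer need to invoke get_sparsity_ratio() themselves. A hedged sketch of how a hand-written training loop is expected to drive these hooks (hook and pruner method names are taken from this diff; the loop itself, and the assumption that _on_train_begin populates self.pruners, are illustrative):

def train_one_epoch(pruning, model, optimizer, loss_fn, dataloader):
    # 'pruning' is assumed to be a configured Pruning instance.
    pruning._on_train_begin()                  # assumed to set up pruning.pruners
    for step, (inputs, targets) in enumerate(dataloader):
        for pruner in pruning.pruners:
            pruner.on_step_begin(step)         # refresh masks on pruning steps
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        for pruner in pruning.pruners:
            pruner.on_before_optimizer_step()  # criteria score from fresh gradients
        optimizer.step()
        for pruner in pruning.pruners:
            pruner.on_after_optimizer_step()   # re-apply masks to updated weights
        optimizer.zero_grad()
    pruning._on_train_end()                    # now also logs the sparsity ratios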
10 changes: 7 additions & 3 deletions neural_compressor/pruner/criteria.py
@@ -61,6 +61,10 @@ def on_step_begin(self):
"""Calculate and store the pruning scores of pruning modules at the beginning of a step."""
pass

def on_before_optimizer_step(self):
"""Calculate and store the pruning scores of pruning modules before the optimizer step."""
pass

def on_after_optimizer_step(self):
"""Calculate and store the pruning scores of pruning modules after the optimizer step."""
pass
@@ -113,7 +117,7 @@ def __init__(self, modules, config):
super(GradientCriterion, self).__init__(modules, config)
assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion"

def on_after_optimizer_step(self):
def on_before_optimizer_step(self):
"""Calculate and store the pruning scores based on gradient criterion."""
with torch.no_grad():
for key in self.modules.keys():
@@ -143,7 +147,7 @@ def __init__(self, modules, config):
super(SnipCriterion, self).__init__(modules, config)
assert self.config.end_step > 0, "please set end_step > 0 for gradient based criterion"

def on_after_optimizer_step(self):
def on_before_optimizer_step(self):
"""Calculate and store the pruning scores based on snip criterion."""
##self.mask_weights()
with torch.no_grad():
@@ -180,7 +184,7 @@ def __init__(self, modules, config):
self.alpha = 0.9
self.beta = 1.0

def on_after_optimizer_step(self):
def on_before_optimizer_step(self):
"""Calculate and store the pruning scores based on snip_momentum criterion."""
with torch.no_grad():
for key in self.modules.keys():
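The criteria changes move score computation from on_after_optimizer_step to the new on_before_optimizer_step. The gradient, SNIP, and SNIP-momentum criteria all read the parameters' .grad, and those gradients are paired with the current weights only before optimizer.step() updates them. A minimal sketch of a criterion following the new hook, assuming, as the surrounding code suggests, that self.modules maps layer names to modules and self.scores collects per-layer scores:

import torch

class SnipLikeCriterion:
    """Toy criterion: score = |weight * grad|, captured before optimizer.step()."""
    def __init__(self, modules):
        self.modules = modules   # e.g. {"fc1": torch.nn.Linear(8, 8)}
        self.scores = {}

    def on_before_optimizer_step(self):
        # At this point the gradients still match the weights they were
        # computed for; after optimizer.step() the weights have moved on.
        with torch.no_grad():
            for key, module in self.modules.items():
                if module.weight.grad is not None:
                    self.scores[key] = (module.weight * module.weight.grad).abs()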
6 changes: 4 additions & 2 deletions neural_compressor/pruner/pruners.py
@@ -323,14 +323,15 @@ def update_masks(self, local_step):
def on_before_optimizer_step(self):
"""Implement before optimizer.step()."""
self.reg.on_before_optimizer_step()
self.criterion.on_before_optimizer_step()

def on_after_optimizer_step(self):
"""Prune the model after optimization."""
## the order of the following three lines cannot be exchanged
if self.global_step >= self.start_step and self.global_step <= self.end_step:
self.reg.on_after_optimizer_step()
self.mask_weights()
self.criterion.on_after_optimizer_step()

self.global_step += 1


@@ -563,6 +564,7 @@ def on_step_begin(self, local_step):
def on_before_optimizer_step(self):
"""Implement before optimizer.step()."""
self.reg.on_before_optimizer_step()
self.criterion.on_before_optimizer_step()

def on_after_optimizer_step(self):
"""Prune the model after optimization."""
@@ -573,7 +575,7 @@ def on_after_optimizer_step(self):
self.mask_weights()
else:
self.mask_weights_general(self.progressive_masks)
self.criterion.on_after_optimizer_step()

self.global_step += 1

def print_progressive_sparsity(self):
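On the pruner side, on_before_optimizer_step now triggers both the regularizer and the criterion, while on_after_optimizer_step only runs the regularizer and re-masks the weights inside the [start_step, end_step] window. A condensed sketch of the resulting ordering in one training step (simplified from the hooks in this diff; reg, criterion, and mask_weights stand in for the real members):

class PrunerStepOrder:
    """Illustrates the hook ordering only; not the full pruner implementation."""
    def __init__(self, reg, criterion, mask_weights, start_step, end_step):
        self.reg, self.criterion, self.mask_weights = reg, criterion, mask_weights
        self.start_step, self.end_step = start_step, end_step
        self.global_step = 0

    def on_before_optimizer_step(self):
        self.reg.on_before_optimizer_step()
        self.criterion.on_before_optimizer_step()  # scores computed from current grads

    def on_after_optimizer_step(self):
        # Order matters: regularize first, then re-apply masks so the weights
        # the optimizer just updated stay zero where they were pruned.
        if self.start_step <= self.global_step <= self.end_step:
            self.reg.on_after_optimizer_step()
            self.mask_weights()
        self.global_step += 1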