pytorch prune new api (#1212)
YIYANGCAI committed Sep 6, 2022
1 parent 498ac48 commit 6cec70b
Showing 10 changed files with 1,031 additions and 29 deletions.
38 changes: 34 additions & 4 deletions neural_compressor/conf/config.py
@@ -50,16 +50,27 @@ def constructor(loader, node):
 @constructor_register
 class Pruner():
     def __init__(self, start_epoch=None, end_epoch=None, initial_sparsity=None,
-                 target_sparsity=None, update_frequency=1, prune_type='basic_magnitude',
-                 method='per_tensor', names=[], parameters=None):
+                 target_sparsity=None, update_frequency=1,
+                 method='per_tensor',
+                 prune_type='basic_magnitude',  # for pytorch pruning, these values should be None
+                 start_step=None, end_step=None, update_frequency_on_step=None, prune_domain=None,
+                 sparsity_decay_type=None, pattern="tile_pattern_1x1", names=None, exclude_names=None, parameters=None):
         self.start_epoch = start_epoch
         self.end_epoch = end_epoch
         self.update_frequency = update_frequency
         self.target_sparsity = target_sparsity
         self.initial_sparsity = initial_sparsity
         self.update_frequency = update_frequency
-        assert prune_type.replace('_', '') in [i.lower() for i in PRUNERS], \
-            'now only support {}'.format(PRUNERS.keys())
+        self.start_step = start_step
+        self.end_step = end_step
+        self.update_frequency_on_step = update_frequency_on_step
+        self.prune_domain = prune_domain
+        self.sparsity_decay_type = sparsity_decay_type
+        self.exclude_names = exclude_names
+        self.pattern = pattern
+        # moved to experimental/pruning to support dynamic pruning
+        # assert prune_type.replace('_', '') in [i.lower() for i in PRUNERS], \
+        #     'now only support {}'.format(PRUNERS.keys())
         self.prune_type = prune_type
         self.method = method
         self.names = names
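
The widened signature keeps the epoch-based fields and adds the step-based ones consumed by the new PyTorch pruner. A minimal sketch of a step-driven Pruner built against this constructor; every value below is illustrative, not a shipped default:

from neural_compressor.conf.config import Pruner

# Hypothetical configuration: step-based scheduling via the new arguments,
# with the legacy epoch fields left at their defaults.
pruner = Pruner(
    start_step=0,
    end_step=10000,
    update_frequency_on_step=500,
    target_sparsity=0.9,
    prune_domain="global",
    sparsity_decay_type="exp",
    pattern="tile_pattern_1x1",
    exclude_names=["classifier"],   # example layer name, not from this repo
)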
@@ -663,15 +674,33 @@ def percent_to_float(data):
 weight_compression_schema = Schema({
     Optional('initial_sparsity', default=0): And(float, lambda s: s < 1.0 and s >= 0.0),
     Optional('target_sparsity', default=0.97): float,
+    Optional('max_sparsity_ratio_per_layer', default=0.98): float,
     Optional('prune_type', default="basic_magnitude"): str,
     Optional('start_epoch', default=0): int,
     Optional('end_epoch', default=4): int,
+    Optional('start_step', default=0): int,
+    Optional('end_step', default=0): int,
     Optional('update_frequency', default=1.0): float,
+    Optional('update_frequency_on_step', default=1): int,
+    Optional('not_to_prune_names', default=[]): list,
+    Optional('prune_domain', default="global"): str,
     Optional('names', default=[]): list,
+    Optional('exclude_names', default=None): list,
+    Optional('prune_layer_type', default=None): list,
+    Optional('sparsity_decay_type', default="exp"): str,
+    Optional('pattern', default="tile_pattern_1x1"): str,
+
     Optional('pruners'): And(list, \
                              lambda s: all(isinstance(i, Pruner) for i in s))
 })

+# weight_compression_pytorch_schema = Schema({}, ignore_extra_keys=True)
+
 approach_schema = Schema({
     Hook('weight_compression', handler=_valid_prune_sparsity): object,
+    Hook('weight_compression_pytorch', handler=_valid_prune_sparsity): object,
     Optional('weight_compression'): weight_compression_schema,
+    Optional('weight_compression_pytorch'): weight_compression_schema,
 })

 default_workspace = './nc_workspace/{}/'.format(
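
Both approach keys validate against the same weight_compression_schema, so a weight_compression_pytorch section can mix the new step-based keys with the old epoch-based ones. A hedged sketch of a section the updated schema should accept (all values are examples only):

# Assumes the YAML loader has already turned the user's section into this dict.
section = {
    'target_sparsity': 0.9,
    'max_sparsity_ratio_per_layer': 0.98,
    'start_step': 0,
    'end_step': 10000,
    'update_frequency_on_step': 500,
    'prune_domain': 'global',
    'sparsity_decay_type': 'exp',
    'pattern': 'tile_pattern_1x1',
}
validated = weight_compression_schema.validate(section)  # remaining keys get their defaults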
@@ -1498,6 +1527,7 @@ class Pruning_Conf(Conf):

     def __init__(self, cfg=None):
         if isinstance(cfg, str):
+            self._read_cfg(cfg)
             self.usr_cfg = DotDict(self._read_cfg(cfg))
         elif isinstance(cfg, DotDict):
             self.usr_cfg = DotDict(schema.validate(self._convert_cfg(
69 changes: 44 additions & 25 deletions neural_compressor/experimental/pruning.py
@@ -23,6 +23,7 @@
 from ..model import BaseModel
 from ..adaptor import FRAMEWORKS
 from ..conf.config import PruningConf
+
 from warnings import warn

 class Pruning(Component):
@@ -86,16 +87,22 @@ def _on_epoch_end(self):
         res = []
         for pruner in self.pruners:
             res.append(pruner.on_epoch_end())
-        stats, sparsity = self._model.report_sparsity()
-        logger.info(stats)
-        logger.info(sparsity)
+        if hasattr(self, "_model"):
+            stats, sparsity = self._model.report_sparsity()
+            logger.info(stats)
+            logger.info(sparsity)
         return res

     def _on_train_end(self):
         """ called after training """
         for pruner in self.pruners:
             pruner.on_train_end()

+    def _on_after_optimizer_step(self):
+        """ called after optimizer step """
+        for pruner in self.pruners:
+            pruner.on_after_optimizer_step()
+
     def pre_process(self):
         assert isinstance(self._model, BaseModel), 'need set neural_compressor Model for pruning....'

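The new _on_after_optimizer_step hook is what lets masks be re-applied after every weight update instead of only at epoch boundaries. A rough sketch of where it would fire in a user training loop, assuming the generated hooks are exposed through the usual callback wiring (all names below are placeholders):

# `prune` stands for a Pruning component whose generated hooks are reachable
# as callbacks; model, optimizer, criterion and dataloader are user objects.
for epoch in range(num_epochs):
    prune.on_epoch_begin(epoch)
    for inputs, labels in dataloader:
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        prune.on_after_optimizer_step()  # pruners re-apply their masks here
        optimizer.zero_grad()
    prune.on_epoch_end()
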
@@ -181,28 +188,40 @@ def generate_hooks(self):

     def generate_pruners(self):
         for name in self.cfg.pruning.approach:
-            assert name == 'weight_compression', 'now we only support weight_compression'
-            for pruner in self.cfg.pruning.approach.weight_compression.pruners:
-                if pruner.prune_type == 'basic_magnitude':
-                    self.pruners.append(PRUNERS['BasicMagnitude'](\
-                        self._model, \
-                        pruner,
-                        self.cfg.pruning.approach.weight_compression))
-                if pruner.prune_type == 'pattern_lock':
-                    self.pruners.append(PRUNERS['PatternLock'](\
-                        self._model, \
-                        pruner,
-                        self.cfg.pruning.approach.weight_compression))
-                elif pruner.prune_type == 'gradient_sensitivity':
-                    self.pruners.append(PRUNERS['GradientSensitivity'](\
-                        self._model, \
-                        pruner,
-                        self.cfg.pruning.approach.weight_compression))
-                elif pruner.prune_type == 'group_lasso':
-                    self.pruners.append(PRUNERS['GroupLasso'](\
-                        self._model, \
-                        pruner,
-                        self.cfg.pruning.approach.weight_compression))
+            assert name == 'weight_compression' or name == "weight_compression_pytorch", \
+                'now we only support weight_compression and weight_compression_pytorch'
+
+            if self.cfg.pruning.approach.weight_compression_pytorch != None:
+                from .pytorch_pruner.pruning import Pruning as PytorchPruning
+                self.pytorch_pruner = PytorchPruning(self.cfg)
+                self.pruners.append(self.pytorch_pruner)
+
+
+            if self.cfg.pruning.approach.weight_compression != None:
+                for pruner in self.cfg.pruning.approach.weight_compression.pruners:
+                    if pruner.prune_type == 'basic_magnitude':
+                        self.pruners.append(PRUNERS['BasicMagnitude'](\
+                            self._model, \
+                            pruner,
+                            self.cfg.pruning.approach.weight_compression))
+                    elif pruner.prune_type == 'pattern_lock':
+                        self.pruners.append(PRUNERS['PatternLock'](\
+                            self._model, \
+                            pruner,
+                            self.cfg.pruning.approach.weight_compression))
+                    elif pruner.prune_type == 'gradient_sensitivity':
+                        self.pruners.append(PRUNERS['GradientSensitivity'](\
+                            self._model, \
+                            pruner,
+                            self.cfg.pruning.approach.weight_compression))
+                    elif pruner.prune_type == 'group_lasso':
+                        self.pruners.append(PRUNERS['GroupLasso'](\
+                            self._model, \
+                            pruner,
+                            self.cfg.pruning.approach.weight_compression))
+                    else:
+                        # print(pruner.prune_type)
+                        assert False, 'now only support {}'.format(PRUNERS.keys())

     def __call__(self):
         """The main entry point of pruning.
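
With this dispatch, a config whose approach is weight_compression_pytorch routes the whole job through the new pytorch_pruner backend, while the four legacy prune types keep their existing paths. A hedged end-to-end sketch; the YAML path is a placeholder and the attribute used to hand over the training function is assumed, so it may differ by version:

from neural_compressor.experimental import Pruning

prune = Pruning("conf.yaml")      # placeholder YAML using weight_compression_pytorch
prune.model = model               # torch model wrapped by neural_compressor
prune.pruning_func = train_func   # assumed attribute name for the user loop
pruned_model = prune()            # __call__ is the main entry point of pruning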
16 changes: 16 additions & 0 deletions neural_compressor/experimental/pytorch_pruner/__init__.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
22 changes: 22 additions & 0 deletions neural_compressor/experimental/pytorch_pruner/logger.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+try:
+    from ...utils import logger
+except:
+    import logging
+    logger = logging.getLogger(__name__)
