[Fix] DCFF Deploy Revision #383

Merged · 12 commits · Dec 16, 2022
3 changes: 3 additions & 0 deletions configs/pruning/mmcls/dcff/dcff_compact_resnet_8xb32_in1k.py
@@ -2,4 +2,7 @@

# model settings
model = _base_.model
# Avoid pruning_ratio check in mutator
model['fix_subnet'] = 'configs/pruning/mmcls/dcff/fix_subnet.yaml'
model['target_pruning_ratio'] = None
model['is_deployed'] = True
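
For context, a minimal sketch (not part of this PR) of how a compact config like the one above might be consumed. The checkpoint path is hypothetical, and `init_model` is the usual mmcls helper, so treat the exact call as an assumption:

# Build the deployed DCFF model from the compact config above.
from mmcls.apis import init_model

cfg_path = 'configs/pruning/mmcls/dcff/dcff_compact_resnet_8xb32_in1k.py'
ckpt = 'work_dirs/dcff_resnet_8xb32_in1k/epoch_120.pth'  # hypothetical path

# Because the config sets is_deployed=True and points fix_subnet at the
# exported YAML, the algorithm loads the fixed subnet and swaps dynamic
# ops for static ones before the checkpoint weights are loaded.
model = init_model(cfg_path, ckpt, device='cpu')
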
4 changes: 4 additions & 0 deletions configs/pruning/mmcls/dcff/dcff_resnet_8xb32_in1k.py
@@ -75,7 +75,11 @@
parse_cfg=dict(
type='BackwardTracer',
loss_calculator=dict(type='ImageClassifierPseudoLoss'))),
fix_subnet=None,
data_preprocessor=None,
target_pruning_ratio=target_pruning_ratio,
step_freq=1,
linear_schedule=False,
is_deployed=False)

val_cfg = dict(_delete_=True, type='mmrazor.ItePruneValLoop')
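
For reference, the `target_pruning_ratio` passed above is a dict mapping channel-unit names to the fraction of channels to keep. A hedged excerpt with hypothetical unit names (the real names come from the mutator's traced units, not from this PR):

# Illustrative only; the unit names below are hypothetical.
target_pruning_ratio = {
    'backbone.layer1.0.conv1_(0, 64)_64': 0.65,
    'backbone.layer1.0.conv2_(0, 64)_64': 0.65,
}
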
@@ -2,4 +2,8 @@

# model settings
model = _base_.model
# Avoid pruning_ratio check in mutator
model['fix_subnet'] = 'configs/pruning/mmdet/dcff/fix_subnet.yaml'
model['target_pruning_ratio'] = None
model['is_deployed'] = True
@@ -65,19 +65,14 @@
_delete_=True)
train_cfg = dict(max_epochs=120, val_interval=1)

# !dataset config
# ==========================================================================
# data preprocessor

model = dict(
_scope_='mmrazor',
type='DCFF',
architecture=_base_.architecture,
mutator_cfg=dict(
type='DCFFChannelMutator',
channel_unit_cfg=dict(
type='DCFFChannelUnit',
units='configs/pruning/mmdet/dcff/resnet_det.json'),
type='DCFFChannelUnit', default_args=dict(choice_mode='ratio')),
parse_cfg=dict(
type='BackwardTracer',
loss_calculator=dict(type='TwoStageDetectorPseudoLoss'))),
@@ -89,4 +84,4 @@
model_wrapper = dict(
type='mmcv.MMDistributedDataParallel', find_unused_parameters=True)

val_cfg = dict(_delete_=True)
val_cfg = dict(_delete_=True, type='mmrazor.ItePruneValLoop')
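
The hunk above drops the precomputed `units` JSON in favor of `default_args=dict(choice_mode='ratio')`. Read in isolation, the resulting mutator config looks as follows; the comment is our reading of the change, not PR text:

# The mutator now traces channel units from the supernet itself instead of
# reading resnet_det.json; choice_mode='ratio' makes each DCFFChannelUnit
# interpret its choice as a fraction of channels rather than a raw count.
mutator_cfg = dict(
    type='DCFFChannelMutator',
    channel_unit_cfg=dict(
        type='DCFFChannelUnit', default_args=dict(choice_mode='ratio')),
    parse_cfg=dict(
        type='BackwardTracer',
        loss_calculator=dict(type='TwoStageDetectorPseudoLoss')))
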
@@ -2,4 +2,8 @@

# model settings
model = _base_.model
# Avoid pruning_ratio check in mutator
model['fix_subnet'] = 'configs/pruning/mmpose/dcff/fix_subnet.yaml'
model['target_pruning_ratio'] = None
model['is_deployed'] = True
@@ -108,13 +108,11 @@
model = dict(
_scope_='mmrazor',
type='DCFF',
architecture=dict(
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False),
architecture=architecture,
mutator_cfg=dict(
type='DCFFChannelMutator',
channel_unit_cfg=dict(
type='DCFFChannelUnit',
units='configs/pruning/mmpose/dcff/resnet_pose.json'),
type='DCFFChannelUnit', default_args=dict(choice_mode='ratio')),
parse_cfg=dict(
type='BackwardTracer',
loss_calculator=dict(type='TopdownPoseEstimatorPseudoLoss'))),
@@ -125,7 +123,7 @@

dataset_type = 'CocoDataset'
data_mode = 'topdown'
data_root = 'data/coco'
data_root = 'data/coco/'

file_client_args = dict(backend='disk')

@@ -186,3 +184,5 @@
type='mmpose.CocoMetric',
ann_file=data_root + 'annotations/person_keypoints_val2017.json')
test_evaluator = val_evaluator

val_cfg = dict(_delete_=True, type='mmrazor.ItePruneValLoop')
@@ -2,4 +2,7 @@

# model settings
model = _base_.model
# Avoid pruning_ratio check in mutator
model['fix_subnet'] = 'configs/pruning/mmseg/dcff/fix_subnet.yaml'
model['target_pruning_ratio'] = None
model['is_deployed'] = True
@@ -80,13 +80,11 @@
model = dict(
_scope_='mmrazor',
type='DCFF',
architecture=dict(
cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False),
architecture=_base_.architecture,
mutator_cfg=dict(
type='DCFFChannelMutator',
channel_unit_cfg=dict(
type='DCFFChannelUnit',
units='configs/pruning/mmseg/dcff/resnet_seg.json'),
type='DCFFChannelUnit', default_args=dict(choice_mode='ratio')),
parse_cfg=dict(
type='BackwardTracer',
loss_calculator=dict(type='CascadeEncoderDecoderPseudoLoss'))),
@@ -97,3 +95,5 @@

model_wrapper = dict(
type='mmcv.MMDistributedDataParallel', find_unused_parameters=True)

val_cfg = dict(_delete_=True, type='mmrazor.ItePruneValLoop')
7 changes: 4 additions & 3 deletions mmrazor/engine/__init__.py
@@ -3,13 +3,14 @@
from .optimizers import SeparateOptimWrapperConstructor
from .runner import (AutoSlimValLoop, DartsEpochBasedTrainLoop,
DartsIterBasedTrainLoop, EvolutionSearchLoop,
GreedySamplerTrainLoop, SelfDistillValLoop,
SingleTeacherDistillValLoop, SlimmableValLoop)
GreedySamplerTrainLoop, ItePruneValLoop,
SelfDistillValLoop, SingleTeacherDistillValLoop,
SlimmableValLoop)

__all__ = [
'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'EstimateResourcesHook',
'SelfDistillValLoop'
'SelfDistillValLoop', 'ItePruneValLoop'
]
4 changes: 3 additions & 1 deletion mmrazor/engine/runner/__init__.py
@@ -3,11 +3,13 @@
from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop
from .evolution_search_loop import EvolutionSearchLoop
from .iteprune_val_loop import ItePruneValLoop
from .slimmable_val_loop import SlimmableValLoop
from .subnet_sampler_loop import GreedySamplerTrainLoop

__all__ = [
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop'
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop',
'ItePruneValLoop'
]
46 changes: 46 additions & 0 deletions mmrazor/engine/runner/iteprune_val_loop.py
@@ -0,0 +1,46 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

from mmengine import fileio
from mmengine.runner import ValLoop

from mmrazor.registry import LOOPS
from mmrazor.structures import export_fix_subnet


@LOOPS.register_module()
class ItePruneValLoop(ValLoop):
"""Pruning loop for validation. Export fixed subnet configs.

Args:
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
evaluator (Evaluator or dict or list): Used for computing metrics.
fp16 (bool): Whether to enable fp16 validation. Defaults to
False.
"""

def run(self):
"""Launch validation."""
self.runner.call_hook('before_val')
self.runner.call_hook('before_val_epoch')
self.runner.model.eval()
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)

# compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
self._save_fix_subnet()
self.runner.call_hook('after_val_epoch', metrics=metrics)
self.runner.call_hook('after_val')
return metrics

def _save_fix_subnet(self):
"""Save model subnet config."""
# Export from the runner's model, unwrapping a possible distributed
# wrapper (ValLoop itself does not define `self.model`).
model = getattr(self.runner.model, 'module', self.runner.model)
fix_subnet = export_fix_subnet(model)
Reviewer comment (Contributor): Its weights should also be exported so that we can load them in the finetuning stage.

save_name = 'fix_subnet.yaml'
fileio.dump(fix_subnet, osp.join(self.runner.work_dir, save_name))
self.runner.logger.info(
'export finished and '
f'{save_name} saved in {self.runner.work_dir}.')
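
A short sketch of the round trip this loop enables; the work_dir path is hypothetical, and `fileio.load` is the mmengine counterpart of the `fileio.dump` call above:

from mmengine import fileio

# After validation, _save_fix_subnet() has written fix_subnet.yaml into
# the runner's work_dir (the path below is hypothetical).
subnet = fileio.load('work_dirs/dcff_resnet_8xb32_in1k/fix_subnet.yaml')

# The compact configs above then point model['fix_subnet'] at this file
# and set model['is_deployed'] = True to rebuild the pruned model.
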
32 changes: 7 additions & 25 deletions mmrazor/models/algorithms/pruning/dcff.py
@@ -8,10 +8,9 @@
from mmengine.model import BaseModel
from mmengine.structures import BaseDataElement

from mmrazor.models.mutables import BaseMutable
from mmrazor.models.mutators import DCFFChannelMutator
from mmrazor.registry import MODELS
from mmrazor.structures.subnet.fix_subnet import _dynamic_to_static
from mmrazor.utils import ValidFixMutable
from .ite_prune_algorithm import ItePruneAlgorithm, ItePruneConfigManager

LossResults = Dict[str, torch.Tensor]
@@ -30,8 +29,8 @@ class DCFF(ItePruneAlgorithm):
Args:
architecture (Union[BaseModel, Dict]): The model to be pruned.
mutator_cfg (Union[Dict, ChannelMutator], optional): The config
of a mutator. Defaults to dict( type='ChannelMutator',
channel_unit_cfg=dict( type='SequentialMutableChannelUnit')).
of a mutator. Defaults to dict( type='DCFFChannelMutator',
channel_unit_cfg=dict( type='DCFFChannelUnit')).
data_preprocessor (Optional[Union[Dict, nn.Module]], optional):
Defaults to None.
target_pruning_ratio (dict, optional): The prune-target. The template
@@ -56,6 +55,7 @@ def __init__(self,
mutator_cfg: Union[Dict, DCFFChannelMutator] = dict(
type='DCFFChannelMutator',
channel_unit_cfg=dict(type='DCFFChannelUnit')),
fix_subnet: Optional[ValidFixMutable] = None,
data_preprocessor: Optional[Union[Dict, nn.Module]] = None,
target_pruning_ratio: Optional[Dict[str, float]] = None,
step_freq=1,
@@ -64,27 +64,9 @@
linear_schedule=False,
is_deployed=False) -> None:
# prune_times is a placeholder here; it is reset once message_hub provides [max_epoch]
super().__init__(architecture, mutator_cfg, data_preprocessor,
target_pruning_ratio, step_freq, prune_times,
init_cfg, linear_schedule)
self.is_deployed = is_deployed
if (self.is_deployed):
# To static ops for loaded pruned network.
self._deploy()

def _fix_archtecture(self):
for module in self.architecture.modules():
if isinstance(module, BaseMutable):
if not module.is_fixed:
module.fix_chosen(None)

def _deploy(self):
config = self.prune_config_manager.prune_at(self._iter)
self.mutator.set_choices(config)
self.mutator.fix_channel_mutables()
self._fix_archtecture()
_dynamic_to_static(self.architecture)
self.is_deployed = True
super().__init__(architecture, mutator_cfg, fix_subnet,
data_preprocessor, target_pruning_ratio, step_freq,
prune_times, init_cfg, linear_schedule, is_deployed)

def _calc_temperature(self, cur_num: int, max_num: int):
"""Calculate temperature param."""
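
With the deploy branch moved into ItePruneAlgorithm, DCFF only forwards the new arguments. A hedged sketch of building a deployed instance directly; the cfg_path mirrors the string removed from the configs above, and the fix_subnet path matches the compact config:

from mmrazor.registry import MODELS

# Sketch: constructing DCFF in deployed mode. With is_deployed=True the
# parent __init__ asserts fix_subnet is set and calls load_fix_subnet
# instead of building a mutator.
algo = MODELS.build(
    dict(
        type='DCFF',
        architecture=dict(
            cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py',
            pretrained=False),
        mutator_cfg=dict(
            type='DCFFChannelMutator',
            channel_unit_cfg=dict(type='DCFFChannelUnit')),
        fix_subnet='configs/pruning/mmcls/dcff/fix_subnet.yaml',
        target_pruning_ratio=None,
        is_deployed=True))
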
26 changes: 20 additions & 6 deletions mmrazor/models/algorithms/pruning/ite_prune_algorithm.py
@@ -10,6 +10,7 @@
from mmrazor.models.mutables import MutableChannelUnit
from mmrazor.models.mutators import ChannelMutator
from mmrazor.registry import MODELS
from mmrazor.utils import ValidFixMutable
from ..base import BaseAlgorithm

LossResults = Dict[str, torch.Tensor]
@@ -97,6 +98,8 @@ class ItePruneAlgorithm(BaseAlgorithm):
mutator_cfg (Union[Dict, ChannelMutator], optional): The config
of a mutator. Defaults to dict( type='ChannelMutator',
channel_unit_cfg=dict( type='SequentialMutableChannelUnit')).
fix_subnet (str | dict | :obj:`FixSubnet`): The path of a yaml file,
a loaded dict, or a built :obj:`FixSubnet`. Defaults to None.
data_preprocessor (Optional[Union[Dict, nn.Module]], optional):
Defaults to None.
target_pruning_ratio (dict, optional): The prune-target. The template
Expand All @@ -110,6 +113,8 @@ class ItePruneAlgorithm(BaseAlgorithm):
Defaults to None.
linear_schedule (bool, optional): flag to set linear ratio schedule.
Defaults to True.
is_deployed (bool, optional): Whether to build the algorithm in
deployed mode. Defaults to False.
"""

def __init__(self,
@@ -118,12 +123,14 @@ def __init__(self,
type='ChannelMutator',
channel_unit_cfg=dict(
type='SequentialMutableChannelUnit')),
fix_subnet: Optional[ValidFixMutable] = None,
data_preprocessor: Optional[Union[Dict, nn.Module]] = None,
target_pruning_ratio: Optional[Dict[str, float]] = None,
step_freq=-1,
prune_times=-1,
step_freq=1,
prune_times=1,
init_cfg: Optional[Dict] = None,
linear_schedule=True) -> None:
linear_schedule=True,
is_deployed=False) -> None:

super().__init__(architecture, data_preprocessor, init_cfg)

@@ -132,10 +139,17 @@ def __init__(self,
self.step_freq = step_freq
self.prune_times = prune_times
self.linear_schedule = linear_schedule
self.is_deployed = is_deployed
Reviewer comment (Contributor): is_deployed is not needed.


# mutator
self.mutator: ChannelMutator = MODELS.build(mutator_cfg)
self.mutator.prepare_from_supernet(self.architecture)
if self.is_deployed:
assert fix_subnet is not None
# Avoid circular import
from mmrazor.structures import load_fix_subnet
load_fix_subnet(self.architecture, fix_subnet)
else:
# init mutator
self.mutator: ChannelMutator = MODELS.build(mutator_cfg)
self.mutator.prepare_from_supernet(self.architecture)

def group_target_pruning_ratio(
self, target: Dict[str, float],
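
To ground the linear_schedule flag documented above, a small illustrative sketch of a linear pruning-ratio schedule (our reading of the flag, not the exact ItePruneConfigManager formula):

def scheduled_ratio(target: float, times_pruned: int, prune_times: int) -> float:
    """Linearly interpolate from 1.0 (unpruned) down to the target ratio."""
    alpha = min(times_pruned / prune_times, 1.0)
    return 1.0 - alpha * (1.0 - target)

# Example: reaching a 0.65 keep-ratio over 4 pruning steps yields
# approximately 1.0, 0.9125, 0.825, 0.7375, 0.65.
ratios = [scheduled_ratio(0.65, t, 4) for t in range(5)]
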