Add get_prune_config and a demo config_pruning #389

Merged
merged 8 commits on Dec 13, 2022
Changes from 7 commits
702 changes: 702 additions & 0 deletions demo/pruning/config_pruning.ipynb

Large diffs are not rendered by default.
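The notebook's contents are therefore not shown in this diff. As a rough sketch, the workflow it demonstrates presumably mirrors the API exercised by tools/pruning/get_l1_prune_config.py later in this PR (all names below are taken from that script; the notebook itself may differ, and the input config name is hypothetical):

# A minimal sketch, assuming the demo follows get_l1_prune_config.py.
from mmengine import Config

from mmrazor.models.mutators import ChannelMutator
from mmrazor.registry import MODELS

config = Config.fromfile('pretrain.py')  # hypothetical input config
model = MODELS.build(config['model'])
mutator = ChannelMutator(
    channel_unit_cfg=dict(
        type='L1MutableChannelUnit',
        default_args=dict(choice_mode='ratio')),
    parse_cfg=dict(
        type='ChannelAnalyzer',
        demo_input=dict(type='DefaultDemoInput',
                        scope=config['default_scope']),
        tracer_type='FxTracer'))
mutator.prepare_from_supernet(model)
# choice_template maps each prunable channel unit to an editable ratio.
print(mutator.choice_template)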

2 changes: 2 additions & 0 deletions docs/en/user_guides/pruning_user_guide.md
@@ -146,3 +146,5 @@ Please refer to the following documents for more details.
- [MutableChannel](../../../mmrazor/models/mutables/mutable_channel/MutableChannel.md)
- [ChannelMutator](../../../mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb)
- [MutableChannelUnit](../../../mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb)
- Demos
  - [Config pruning](../../../demo/pruning/config_pruning.ipynb)
2 changes: 2 additions & 0 deletions requirements/tests.txt
@@ -2,6 +2,8 @@ codecov
flake8
interrogate
isort==4.3.21
nbconvert
nbformat
pytest
xdoctest >= 0.10.0
yapf
32 changes: 32 additions & 0 deletions tests/test_doc.py
@@ -0,0 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from unittest import TestCase

import nbformat
from nbconvert.preprocessors import ExecutePreprocessor

TEST_DOC = os.getenv('TEST_DOC') == 'true'
notebook_paths = [
    './mmrazor/models/mutators/channel_mutator/channel_mutator.ipynb',
    './mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.ipynb',  # noqa
    './demo/pruning/config_pruning.ipynb'
]


class TestDocs(TestCase):

    def setUp(self) -> None:
        if not TEST_DOC:
            self.skipTest('disabled')

    def test_notebooks(self):
        for path in notebook_paths:
            with self.subTest(path=path):
                with open(path) as file:
                    nb_in = nbformat.read(file, nbformat.NO_CONVERT)
                ep = ExecutePreprocessor(
                    timeout=600, kernel_name='python3')
                try:
                    _ = ep.preprocess(nb_in)
                except Exception as e:
                    self.fail(f'notebook {path} failed to execute: {e}')
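These tests are skipped unless the TEST_DOC environment variable is set. A minimal sketch of a local run, assuming pytest is installed:

import os
import subprocess

# Enable the TEST_DOC gate defined in tests/test_doc.py, then run it.
env = dict(os.environ, TEST_DOC='true')
subprocess.run(['pytest', 'tests/test_doc.py'], env=env, check=True)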
1 change: 1 addition & 0 deletions tests/test_tools/__init__.py
@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
101 changes: 101 additions & 0 deletions tests/test_tools/test_tools.py
@@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import shutil
import subprocess
from unittest import TestCase

import torch

from mmrazor import digit_version

TEST_TOOLS = os.getenv('TEST_TOOLS') == 'true'


class TestTools(TestCase):
    _config_path = None

    def setUp(self) -> None:
        if not TEST_TOOLS:
            self.skipTest('disabled')

    @property
    def config_path(self):
        if self._config_path is None:
            self._config_path = self._get_config_path()
        return self._config_path

    def _setUp(self) -> None:
        # Not the unittest hook: called manually per subtest to create a
        # fresh working directory.
        self.workdir = os.path.dirname(__file__) + '/tmp/'
        if not os.path.exists(self.workdir):
            os.mkdir(self.workdir)

    def save_to_config(self, name, content):
        with open(self.workdir + f'/{name}', 'w') as f:
            f.write(content)

    def test_get_channel_unit(self):
        if digit_version(torch.__version__) < digit_version('1.12.0'):
            self.skipTest('version of torch < 1.12.0')

        for path in self.config_path:
            with self.subTest(path=path):
                self._setUp()
                self.save_to_config('pretrain.py', f"""_base_=['{path}']""")
                try:
                    subprocess.run([
                        'python', './tools/get_channel_units.py',
                        f'{self.workdir}/pretrain.py', '-o',
                        f'{self.workdir}/unit.json'
                    ], check=True)
                except Exception as e:
                    self.fail(f'{e}')
                self.assertTrue(os.path.exists(f'{self.workdir}/unit.json'))

                self._tearDown()

    def test_get_prune_config(self):
        if digit_version(torch.__version__) < digit_version('1.12.0'):
            self.skipTest('version of torch < 1.12.0')
        for path in self.config_path:
            with self.subTest(path=path):
                self._setUp()
                self.save_to_config('pretrain.py', f"""_base_=['{path}']""")
                try:
                    subprocess.run([
                        'python',
                        './tools/pruning/get_l1_prune_config.py',
                        f'{self.workdir}/pretrain.py',
                        '-o',
                        f'{self.workdir}/prune.py',
                    ], check=True)
                except Exception as e:
                    self.fail(f'{e}')
                self.assertTrue(os.path.exists(f'{self.workdir}/prune.py'))

                self._tearDown()

    def _tearDown(self) -> None:
        shutil.rmtree(self.workdir)

    def _get_config_path(self):
        config_paths = []
        paths = [
            ('mmcls', 'mmcls::resnet/resnet34_8xb32_in1k.py'),
            ('mmdet', 'mmdet::retinanet/retinanet_r18_fpn_1x_coco.py'),
            (
                'mmseg',
                'mmseg::deeplabv3plus/deeplabv3plus_r50-d8_4xb4-20k_voc12aug-512x512.py'  # noqa
            ),
            ('mmyolo',
             'mmyolo::yolov5/yolov5_m-p6-v62_syncbn_fast_8xb16-300e_coco.py')
        ]
        for repo_name, path in paths:
            try:
                __import__(repo_name)
                config_paths.append(path)
            except Exception:
                pass
        return config_paths
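These tests are likewise gated, by TEST_TOOLS, and only cover configs whose host package (mmcls, mmdet, mmseg, mmyolo) can be imported. A minimal sketch of a local run, assuming pytest and at least one of those packages are installed:

import os
import subprocess

# Enable the TEST_TOOLS gate; torch >= 1.12.0 is also required by the
# individual tests above.
env = dict(os.environ, TEST_TOOLS='true')
subprocess.run(['pytest', 'tests/test_tools/test_tools.py'], env=env,
               check=True)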
19 changes: 18 additions & 1 deletion tools/get_channel_units.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import sys

import torch.nn as nn
from mmengine import MODELS
@@ -9,6 +10,8 @@
from mmrazor.models import BaseAlgorithm
from mmrazor.models.mutators import ChannelMutator

sys.setrecursionlimit(int(pow(2, 20)))


def parse_args():
    parser = argparse.ArgumentParser(
@@ -40,11 +43,25 @@ def parse_args():
def main():
    args = parse_args()
    config = Config.fromfile(args.config)
    default_scope = config['default_scope']

    model = MODELS.build(config['model'])
    if isinstance(model, BaseAlgorithm):
        mutator = model.mutator
    elif isinstance(model, nn.Module):
        mutator: ChannelMutator = ChannelMutator(
            channel_unit_cfg=dict(
                type='L1MutableChannelUnit',
                default_args=dict(choice_mode='ratio'),
            ),
            parse_cfg={
                'type': 'ChannelAnalyzer',
                'demo_input': {
                    'type': 'DefaultDemoInput',
                    'scope': default_scope
                },
                'tracer_type': 'FxTracer'
            })
        mutator.prepare_from_supernet(model)
    if args.choice:
        config = mutator.choice_template
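For reference, a sketch of invoking the updated script, matching the flags used in tests/test_tools/test_tools.py above (the input config name is hypothetical):

import subprocess

# Dump the channel units of the model defined by pretrain.py to unit.json.
subprocess.run([
    'python', './tools/get_channel_units.py', 'pretrain.py',
    '-o', 'unit.json'
], check=True)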
127 changes: 127 additions & 0 deletions tools/pruning/get_l1_prune_config.py
@@ -0,0 +1,127 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
from typing import Dict

from mmengine import Config, fileio

from mmrazor.models.mutators import ChannelMutator
from mmrazor.registry import MODELS


def parse_args():
    parser = argparse.ArgumentParser(
        description='Get the config to prune a model.')
    parser.add_argument('config', help='config of the model')
    parser.add_argument(
        '--checkpoint',
        default=None,
        type=str,
        help='checkpoint path of the model')
    parser.add_argument(
        '--subnet',
        default=None,
        type=str,
        help='pruning structure for the model')
    parser.add_argument(
        '-o',
        type=str,
        default='./prune.py',
        help='output path to store the pruning config.')
    args = parser.parse_args()
    return args


def wrap_prune_config(config: Config, prune_target: Dict,
                      checkpoint_path: str):
    config = copy.deepcopy(config)
    default_scope = config['default_scope']
    arch_config: Dict = config['model']

    # update checkpoint_path
    if checkpoint_path is not None:
        arch_config.update({
            'init_cfg': {
                'type': 'Pretrained',
                'checkpoint': checkpoint_path  # noqa
            },
        })

    # deal with data_preprocessor
    if 'data_preprocessor' in config:
        data_preprocessor = config['data_preprocessor']
        arch_config.update({'data_preprocessor': data_preprocessor})
        config['data_preprocessor'] = None
    else:
        data_preprocessor = None

    # prepare algorithm
    algorithm_config = dict(
        _scope_='mmrazor',
        type='ItePruneAlgorithm',
        architecture=arch_config,
        target_pruning_ratio=prune_target,
        mutator_cfg=dict(
            type='ChannelMutator',
            channel_unit_cfg=dict(
                type='L1MutableChannelUnit',
                default_args=dict(choice_mode='ratio')),
            parse_cfg=dict(
                type='ChannelAnalyzer',
                tracer_type='FxTracer',
                demo_input=dict(type='DefaultDemoInput',
                                scope=default_scope))))
    config['model'] = algorithm_config

    return config


def change_config(config):
    scope = config['default_scope']
    config['model']['_scope_'] = scope
    return config


if __name__ == '__main__':
    args = parse_args()
    config_path = args.config
    checkpoint_path = args.checkpoint
    target_path = args.o

    origin_config = Config.fromfile(config_path)
    origin_config = change_config(origin_config)
    default_scope = origin_config['default_scope']

    # get subnet config
    model = MODELS.build(copy.deepcopy(origin_config['model']))
    mutator: ChannelMutator = ChannelMutator(
        channel_unit_cfg=dict(
            type='L1MutableChannelUnit',
            default_args=dict(choice_mode='ratio'),
        ),
        parse_cfg={
            'type': 'ChannelAnalyzer',
            'demo_input': {
                'type': 'DefaultDemoInput',
                'scope': default_scope
            },
            'tracer_type': 'FxTracer'
        })
    mutator.prepare_from_supernet(model)
    if args.subnet is None:
        choice_template = mutator.choice_template
    else:
        input_choices = fileio.load(args.subnet)
        try:
            mutator.set_choices(input_choices)
            choice_template = input_choices
        except Exception as e:
            print(f'error when applying input subnet: {e}')
            choice_template = mutator.choice_template

    # prune and finetune
    prune_config: Config = wrap_prune_config(origin_config, choice_template,
                                             checkpoint_path)
    prune_config.dump(target_path)
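Finally, a usage sketch for the new script: generate a pruning config with the default choice template, or supply a handcrafted subnet via --subnet. The subnet file format is assumed from the fileio.load and set_choices calls above, and the unit name below is illustrative, not taken from a real run:

import subprocess

from mmengine import fileio

# 1. Generate prune.py from a plain config with the default template.
subprocess.run([
    'python', './tools/pruning/get_l1_prune_config.py', 'pretrain.py',
    '-o', 'prune.py'
], check=True)

# 2. Or dump a target structure first; keys are channel-unit names and
#    values are kept-channel ratios (hypothetical unit name below).
subnet = {'backbone.conv1_(0, 64)_64': 0.5}
fileio.dump(subnet, 'subnet.json')
subprocess.run([
    'python', './tools/pruning/get_l1_prune_config.py', 'pretrain.py',
    '--subnet', 'subnet.json', '-o', 'prune.py'
], check=True)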