From a91e2c7de2e662ee81b569844674ded710877c6b Mon Sep 17 00:00:00 2001
From: LKJacky <108643365+LKJacky@users.noreply.github.com>
Date: Tue, 20 Dec 2022 05:11:53 -0300
Subject: [PATCH] add Chex ops, Chex mutator and algorithm (#404)

* finish chex mutator

* add TestChexAlgorithm

* add chex ops

* review fixes (line counts of every hunk are unchanged):
  - ChexUnit.prune: evaluate module.prune_imp(num_remaining) once per
    channel and reuse the result instead of calling it twice.
  - ChexMutator._get_prune_choices: keep at least one channel so that
    topk() never receives k == 0 (indexing its empty result raised
    IndexError).
  - TestChexAlgorithm._set_epoch_ite: derive max_iters from
    iter_per_epoch instead of a second hard-coded 10.

Co-authored-by: jacky
---
 mmrazor/models/__init__.py             |  1 +
 mmrazor/models/chex/__init__.py        |  9 +++++
 mmrazor/models/chex/chex_algorithm.py  | 34 ++++++++++++-----
 mmrazor/models/chex/chex_mutator.py    | 37 +++++++++++++++---
 mmrazor/models/chex/chex_ops.py        | 52 ++++++++++++++++++++------
 mmrazor/models/chex/chex_unit.py       | 12 ++++--
 mmrazor/models/chex/utils.py           | 25 +++++++++++++
 tests/test_chex/test_chex_algorithm.py | 42 +++++++++++++++++++++
 tests/test_chex/test_chex_mutator.py   | 27 +++++++++++++
 tests/test_chex/test_chex_ops.py       | 26 +++++++++++++
 10 files changed, 236 insertions(+), 29 deletions(-)
 create mode 100644 mmrazor/models/chex/utils.py
 create mode 100644 tests/test_chex/test_chex_algorithm.py
 create mode 100644 tests/test_chex/test_chex_mutator.py

diff --git a/mmrazor/models/__init__.py b/mmrazor/models/__init__.py
index f5295aa9e..90d69825d 100644
--- a/mmrazor/models/__init__.py
+++ b/mmrazor/models/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .algorithms import *  # noqa: F401,F403
 from .architectures import *  # noqa: F401,F403
+from .chex import *  # noqa: F401,F403
 from .distillers import *  # noqa: F401,F403
 from .losses import *  # noqa: F401,F403
 from .mutables import *  # noqa: F401,F403
diff --git a/mmrazor/models/chex/__init__.py b/mmrazor/models/chex/__init__.py
index ef101fec6..311cba977 100644
--- a/mmrazor/models/chex/__init__.py
+++ b/mmrazor/models/chex/__init__.py
@@ -1 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .chex_algorithm import ChexAlgorithm
+from .chex_mutator import ChexMutator
+from .chex_ops import ChexConv2d, ChexLinear, ChexMixin
+from .chex_unit import ChexUnit
+
+__all__ = [
+    'ChexAlgorithm', 'ChexMutator', 'ChexUnit', 'ChexConv2d', 'ChexLinear',
+    'ChexMixin'
+]
diff --git a/mmrazor/models/chex/chex_algorithm.py b/mmrazor/models/chex/chex_algorithm.py
index 368d3fcd3..fcbcc1beb 100644
--- a/mmrazor/models/chex/chex_algorithm.py
+++ b/mmrazor/models/chex/chex_algorithm.py
@@ -1,35 +1,51 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import math
 from typing import Dict, Optional, Union
 
 import torch.nn as nn
 from mmengine.model import BaseModel
 
 from mmrazor.models.algorithms import BaseAlgorithm
+from mmrazor.registry import MODELS
+from .chex_mutator import ChexMutator
+from .utils import RuntimeInfo
 
 
-class ChexAlgoritm(BaseAlgorithm):
+@MODELS.register_module()
+class ChexAlgorithm(BaseAlgorithm):
 
     def __init__(self,
                  architecture: Union[BaseModel, Dict],
                  data_preprocessor: Optional[Union[Dict, nn.Module]] = None,
+                 mutator_cfg=dict(
+                     type='ChexMutator',
+                     channel_unit_cfg=dict(type='ChexUnit')),
                  delta_t=2,
                  total_steps=10,
                  init_growth_rate=0.3,
                  init_cfg: Optional[Dict] = None):
+        super().__init__(architecture, data_preprocessor, init_cfg)
+        self.delta_t = delta_t
+        self.total_steps = total_steps
         self.init_growth_rate = init_growth_rate
 
+        self.mutator: ChexMutator = MODELS.build(mutator_cfg)
+        self.mutator.prepare_from_supernet(self.architecture)
+
     def forward(self, inputs, data_samples=None, mode: str = 'tensor'):
-        if True:  #
-            self.mutator.prune()
-            self.mutator.grow(self.growth_ratio)
+        if self.training:  #
+            if RuntimeInfo.iter() % self.delta_t == 0 and \
+                    RuntimeInfo.iter() // self.delta_t < self.total_steps:
+                self.mutator.prune()
+                self.mutator.grow(self.growth_ratio)
 
         return super().forward(inputs, data_samples, mode)
 
-    @property
-    def _epoch(self):
-        pass
-
     @property
     def growth_ratio(self):
         # return growth ratio in current epoch
-        pass
+        def cos():
+            a = math.pi * RuntimeInfo.epoch() / RuntimeInfo.max_epochs()
+            return (math.cos(a) + 1) / 2
+
+        return self.init_growth_rate * cos()
diff --git a/mmrazor/models/chex/chex_mutator.py b/mmrazor/models/chex/chex_mutator.py
index 52d138fb9..46c986c6a 100644
--- a/mmrazor/models/chex/chex_mutator.py
+++ b/mmrazor/models/chex/chex_mutator.py
@@ -1,9 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Dict, List, Optional
 
+import torch
+
 from mmrazor.models.mutators import ChannelMutator
+from mmrazor.registry import MODELS
+from .chex_unit import ChexUnit
 
 
+@MODELS.register_module()
 class ChexMutator(ChannelMutator):
 
     def __init__(self,
@@ -24,8 +29,9 @@ def prune(self):
         step1: get pruning structure
         step2: prune based on ChexMixin.prune_imp
         """
-        _ = self._get_prune_choices()
-        pass
+        choices = self._get_prune_choices()
+        for unit in self.mutable_units:
+            unit.prune(choices[unit.name])
 
     def grow(self, growth_ratio=0.0):
         """Make the model grow.
@@ -33,11 +39,30 @@ def grow(self, growth_ratio=0.0):
         step1: get growth choices
         step2: grow based on ChexMixin.growth_imp
         """
-        _ = self._get_grow_choices(growth_ratio)
-        pass
+        choices = self._get_grow_choices(growth_ratio)
+        for unit in self.mutable_units:
+            unit: ChexUnit
+            unit.grow(choices[unit.name] - unit.current_choice)
 
     def _get_grow_choices(self, growth_choice):
-        pass
+        choices = {}
+        for unit in self.mutable_units:
+            unit: ChexUnit
+            choices[unit.name] = min(
+                unit.num_channels,
+                (unit.current_choice + int(unit.num_channels * growth_choice)))
+        return choices
 
     def _get_prune_choices(self):
-        pass
+        choices = {}
+        bn_imps = {}
+        for unit in self.mutable_units:
+            unit: ChexUnit
+            bn_imps[unit.name] = unit.bn_imp
+        bn_imp: torch.Tensor = torch.cat(list(bn_imps.values()), dim=0)
+        num_remain = max(1, int(self.channel_ratio * len(bn_imp)))
+        threshold = bn_imp.topk(num_remain)[0][-1]
+        for unit in self.mutable_units:
+            num = (bn_imps[unit.name] >= threshold).float().sum().long().item()
+            choices[unit.name] = num
+        return choices
diff --git a/mmrazor/models/chex/chex_ops.py b/mmrazor/models/chex/chex_ops.py
index 4317b5185..83f324f6c 100644
--- a/mmrazor/models/chex/chex_ops.py
+++ b/mmrazor/models/chex/chex_ops.py
@@ -7,30 +7,60 @@
 
 class ChexMixin:
 
-    @property
-    def prune_imp(self):
+    def prune_imp(self, num_remain):
         # compute channel importance for pruning
-        return self._prune_imp(self.weight)
+        return self._prune_imp(self.get_weight_matrix(), num_remain)
 
     @property
     def growth_imp(self):
         # compute channel importance for growth
-        return self._growth_imp(self.weight)
+        return self._growth_imp(self.get_weight_matrix())
+
+    def get_weight_matrix(self):
+        raise NotImplementedError()
 
-    def _prune_imp(self, weight):
+    def _prune_imp(self, weight, num_remain):
         # weight: out * in. return the importance of each channel
-        out_channel = weight.shape[0]
-        return torch.rand(out_channel)
+        # modified from https://github.com/zejiangh/Filter-GaP
+        assert num_remain <= weight.shape[0]
+        weight_t = weight.T  # in out
+        if weight_t.shape[0] >= weight_t.shape[1]:  # in >= out
+            _, _, V = torch.svd(weight_t, some=True)  # out out
+            Vk = V[:, :num_remain]  # out out'
+            lvs = torch.norm(Vk, dim=1)  # out
+            return lvs
+        else:
+            # l1-norm
+            return weight.abs().mean(-1)
 
     def _growth_imp(self, weight):
         # weight: out * in. return the importance of each channel when growth
-        out_channel = weight.shape[0]
-        return torch.rand(out_channel)
+
+        def get_proj(weight):
+            # out' in
+            wt = weight.T  # in out'
+            scatter = torch.matmul(wt.T, wt)  # out' out'
+            inv = torch.pinverse(scatter)  # out' out'
+            return torch.matmul(torch.matmul(wt, inv), wt.T)  # in in
+
+        mask = self.get_mutable_attr('out_channels').current_mask
+        n_mask = ~mask
+        proj = get_proj(weight[mask])  # in in
+        weight_c = weight[n_mask]  # out'' in
+
+        error = (weight_c - weight_c @ proj).norm(dim=-1)
+        all_errors = torch.zeros([weight.shape[0]], device=weight.device)
+        all_errors.masked_scatter_(n_mask, error)
+        return all_errors
 
 
 class ChexConv2d(DynamicConv2d, ChexMixin):
-    pass
+
+    def get_weight_matrix(self):
+        return self.weight.flatten(1)
 
 
 class ChexLinear(DynamicLinear, ChexMixin):
-    pass
+
+    def get_weight_matrix(self):
+        return self.weight
diff --git a/mmrazor/models/chex/chex_unit.py b/mmrazor/models/chex/chex_unit.py
index b29615bb5..bdc6b384f 100644
--- a/mmrazor/models/chex/chex_unit.py
+++ b/mmrazor/models/chex/chex_unit.py
@@ -9,9 +9,11 @@
 import mmrazor.models.architectures.dynamic_ops as dynamic_ops
 from mmrazor.models.mutables.mutable_channel import MutableChannelContainer
 from mmrazor.models.mutables.mutable_channel.units import L1MutableChannelUnit
+from mmrazor.registry import MODELS
 from .chex_ops import ChexConv2d, ChexLinear, ChexMixin
 
 
+@MODELS.register_module()
 class ChexUnit(L1MutableChannelUnit):
 
     def prepare_for_pruning(self, model: nn.Module):
@@ -33,9 +35,10 @@ def get_prune_imp():
             prune_imp: torch.Tensor = torch.zeros([self.num_channels])
             for channel in self.chex_channels:
                 module = channel.module
-                prune_imp = prune_imp.to(module.prune_imp.device)
-                prune_imp = prune_imp + module.prune_imp[channel.start:channel.
-                                                         end]
+                # evaluate the importance once: prune_imp() may run an SVD
+                imp = module.prune_imp(num_remaining)
+                prune_imp = prune_imp.to(imp.device)
+                prune_imp = prune_imp + imp[channel.start:channel.end]
             return prune_imp
 
         prune_imp = get_prune_imp()
@@ -47,6 +50,9 @@ def get_prune_imp():
         self.mutable_channel.current_choice.data = mask
 
     def grow(self, num):
+        assert num >= 0
+        if num == 0:
+            return
 
         def get_growth_imp():
             growth_imp: torch.Tensor = torch.zeros([self.num_channels])
diff --git a/mmrazor/models/chex/utils.py b/mmrazor/models/chex/utils.py
new file mode 100644
index 000000000..1688cfcd8
--- /dev/null
+++ b/mmrazor/models/chex/utils.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.logging import MessageHub
+
+
+class RuntimeInfo():
+
+    @classmethod
+    def get_info(cls, key):
+        hub = MessageHub.get_current_instance()
+        if key in hub.runtime_info:
+            return hub.runtime_info[key]
+        else:
+            raise KeyError(key)
+
+    @classmethod
+    def epoch(cls):
+        return cls.get_info('epoch')
+
+    @classmethod
+    def max_epochs(cls):
+        return cls.get_info('max_epochs')
+
+    @classmethod
+    def iter(cls):
+        return cls.get_info('iter')
diff --git a/tests/test_chex/test_chex_algorithm.py b/tests/test_chex/test_chex_algorithm.py
new file mode 100644
index 000000000..1d5dd4881
--- /dev/null
+++ b/tests/test_chex/test_chex_algorithm.py
@@ -0,0 +1,42 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import unittest
+
+import torch
+from mmengine.logging import MessageHub
+
+from mmrazor.models.chex import ChexAlgorithm
+
+MODEL_CFG = dict(
+    _scope_='mmcls',
+    type='ImageClassifier',
+    backbone=dict(
+        type='ResNet',
+        depth=18,
+        num_stages=4,
+        out_indices=(3, ),
+        style='pytorch'),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=512,
+        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
+        topk=(1, 5),
+    ))
+
+
+class TestChexAlgorithm(unittest.TestCase):
+
+    def test_chex_algorithm(self):
+        algorithm = ChexAlgorithm(MODEL_CFG)
+        x = torch.rand([2, 3, 64, 64])
+        self._set_epoch_ite(0, 2, 100)
+        _ = algorithm(x)
+
+    def _set_epoch_ite(self, epoch, ite, max_epoch):
+        iter_per_epoch = 10
+        message_hub = MessageHub.get_current_instance()
+        message_hub.update_info('epoch', epoch)
+        message_hub.update_info('max_epochs', max_epoch)
+        message_hub.update_info('max_iters', max_epoch * iter_per_epoch)
+        message_hub.update_info('iter', ite + iter_per_epoch * epoch)
diff --git a/tests/test_chex/test_chex_mutator.py b/tests/test_chex/test_chex_mutator.py
new file mode 100644
index 000000000..ecdab4488
--- /dev/null
+++ b/tests/test_chex/test_chex_mutator.py
@@ -0,0 +1,27 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import unittest
+
+import torch
+import torch.nn as nn
+
+from mmrazor.models.chex.chex_mutator import ChexMutator
+from mmrazor.models.chex.chex_unit import ChexUnit
+from ..data.models import SingleLineModel
+
+
+class TestChexMutator(unittest.TestCase):
+
+    def test_chex_mutator(self):
+        model = SingleLineModel()
+        mutator = ChexMutator(channel_unit_cfg=ChexUnit)
+        mutator.prepare_from_supernet(model)
+
+        for module in model.modules():
+            if isinstance(module, nn.modules.batchnorm._BatchNorm):
+                module.weight.data = torch.rand_like(module.weight.data)
+
+        mutator.prune()
+        print(mutator.current_choices)
+
+        mutator.grow(0.2)
+        print(mutator.current_choices)
diff --git a/tests/test_chex/test_chex_ops.py b/tests/test_chex/test_chex_ops.py
index ef101fec6..5cc1aebda 100644
--- a/tests/test_chex/test_chex_ops.py
+++ b/tests/test_chex/test_chex_ops.py
@@ -1 +1,27 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import unittest
+
+import torch
+import torch.nn as nn
+
+from mmrazor.models.chex import ChexConv2d
+from mmrazor.models.mutables import SimpleMutableChannel
+
+
+class TestChexOps(unittest.TestCase):
+
+    def test_ops(self):
+        for in_c, out_c in [(4, 8), (8, 4), (8, 8)]:
+            conv = nn.Conv2d(in_c, out_c, 3, 1, 1)
+            conv: ChexConv2d = ChexConv2d.convert_from(conv)
+
+            mutable_in = SimpleMutableChannel(in_c)
+            mutable_out = SimpleMutableChannel(out_c)
+
+            conv.register_mutable_attr('in_channels', mutable_in)
+            conv.register_mutable_attr('out_channels', mutable_out)
+
+            mutable_out.current_choice = torch.normal(0, 1, [out_c]) < 0
+
+            self.assertEqual(list(conv.prune_imp(4).shape), [out_c])
+            self.assertEqual(list(conv.growth_imp.shape), [out_c])