add Chex ops, Chex mutator and algorithm (#404)
* finish chex mutator

* add TestChexAlgorithm

* add chex ops

Co-authored-by: jacky <jacky@xx.com>
LKJacky and jacky committed Dec 20, 2022
1 parent 5a932cc commit a91e2c7
Showing 10 changed files with 236 additions and 29 deletions.
1 change: 1 addition & 0 deletions mmrazor/models/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .algorithms import *  # noqa: F401,F403
 from .architectures import *  # noqa: F401,F403
+from .chex import *  # noqa: F401,F403
 from .distillers import *  # noqa: F401,F403
 from .losses import *  # noqa: F401,F403
 from .mutables import *  # noqa: F401,F403
9 changes: 9 additions & 0 deletions mmrazor/models/chex/__init__.py
@@ -1 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .chex_algorithm import ChexAlgorithm
from .chex_mutator import ChexMutator
from .chex_ops import ChexConv2d, ChexLinear, ChexMixin
from .chex_unit import ChexUnit

__all__ = [
    'ChexAlgorithm', 'ChexMutator', 'ChexUnit', 'ChexConv2d', 'ChexLinear',
    'ChexMixin'
]
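Since ChexMutator and ChexUnit are registered with MODELS (see the decorators in the diffs below), the pieces can be built from config dicts as well as imported directly. A minimal sketch of registry-based construction, mirroring ChexAlgorithm's default mutator_cfg:

from mmrazor.registry import MODELS

# Build the mutator from a config dict, as ChexAlgorithm does internally
# via MODELS.build(mutator_cfg).
mutator = MODELS.build(
    dict(type='ChexMutator', channel_unit_cfg=dict(type='ChexUnit')))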
34 changes: 25 additions & 9 deletions mmrazor/models/chex/chex_algorithm.py
@@ -1,35 +1,51 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import math
 from typing import Dict, Optional, Union

 import torch.nn as nn
 from mmengine.model import BaseModel

 from mmrazor.models.algorithms import BaseAlgorithm
 from mmrazor.registry import MODELS
 from .chex_mutator import ChexMutator
+from .utils import RuntimeInfo


-class ChexAlgoritm(BaseAlgorithm):
+@MODELS.register_module()
+class ChexAlgorithm(BaseAlgorithm):

     def __init__(self,
                  architecture: Union[BaseModel, Dict],
                  data_preprocessor: Optional[Union[Dict, nn.Module]] = None,
                  mutator_cfg=dict(
                      type='ChexMutator',
                      channel_unit_cfg=dict(type='ChexUnit')),
                  delta_t=2,
                  total_steps=10,
                  init_growth_rate=0.3,
                  init_cfg: Optional[Dict] = None):
         super().__init__(architecture, data_preprocessor, init_cfg)

         self.delta_t = delta_t
         self.total_steps = total_steps
         self.init_growth_rate = init_growth_rate

         self.mutator: ChexMutator = MODELS.build(mutator_cfg)
         self.mutator.prepare_from_supernet(self.architecture)

     def forward(self, inputs, data_samples=None, mode: str = 'tensor'):
-        if True:  #
-            self.mutator.prune()
-            self.mutator.grow(self.growth_ratio)
+        if self.training:  #
+            if RuntimeInfo.iter() % self.delta_t == 0 and \
+                    RuntimeInfo.iter() // self.delta_t < self.total_steps:
+                self.mutator.prune()
+                self.mutator.grow(self.growth_ratio)
         return super().forward(inputs, data_samples, mode)

-    @property
-    def _epoch(self):
-        pass
-
     @property
     def growth_ratio(self):
         # return growth ratio in current epoch
-        pass
+        def cos():
+            a = math.pi * RuntimeInfo.epoch() / RuntimeInfo.max_epochs()
+            return (math.cos(a) + 1) / 2
+
+        return self.init_growth_rate * cos()
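The growth_ratio property implements a half-cosine schedule that decays from init_growth_rate to zero over training. A standalone sketch of the same computation (the epoch numbers are illustrative):

import math

def growth_ratio(epoch: int, max_epochs: int, init_growth_rate: float = 0.3) -> float:
    # Half-cosine decay: init_growth_rate at epoch 0, zero at max_epochs.
    a = math.pi * epoch / max_epochs
    return init_growth_rate * (math.cos(a) + 1) / 2

assert growth_ratio(0, 100) == 0.3    # full growth rate at the start
assert growth_ratio(100, 100) == 0.0  # no growth at the end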
37 changes: 31 additions & 6 deletions mmrazor/models/chex/chex_mutator.py
@@ -1,9 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Dict, List, Optional

+import torch
+
 from mmrazor.models.mutators import ChannelMutator
+from mmrazor.registry import MODELS
+from .chex_unit import ChexUnit


+@MODELS.register_module()
 class ChexMutator(ChannelMutator):

     def __init__(self,
@@ -24,20 +29,40 @@ def prune(self):
         step1: get pruning structure
         step2: prune based on ChexMixin.prune_imp
         """
-        _ = self._get_prune_choices()
-        pass
+        choices = self._get_prune_choices()
+        for unit in self.mutable_units:
+            unit.prune(choices[unit.name])

     def grow(self, growth_ratio=0.0):
         """Make the model grow.
         step1: get growth choices
         step2: grow based on ChexMixin.growth_imp
         """
-        _ = self._get_grow_choices(growth_ratio)
-        pass
+        choices = self._get_grow_choices(growth_ratio)
+        for unit in self.mutable_units:
+            unit: ChexUnit
+            unit.grow(choices[unit.name] - unit.current_choice)

     def _get_grow_choices(self, growth_choice):
-        pass
+        choices = {}
+        for unit in self.mutable_units:
+            unit: ChexUnit
+            choices[unit.name] = min(
+                unit.num_channels,
+                (unit.current_choice + int(unit.num_channels * growth_choice)))
+        return choices

     def _get_prune_choices(self):
-        pass
+        choices = {}
+        bn_imps = {}
+        for unit in self.mutable_units:
+            unit: ChexUnit
+            bn_imps[unit.name] = unit.bn_imp
+        bn_imp: torch.Tensor = torch.cat(list(bn_imps.values()), dim=0)
+        num_remain = int(self.channel_ratio * len(bn_imp))
+        threshold = bn_imp.topk(num_remain)[0][-1]
+        for unit in self.mutable_units:
+            num = (bn_imps[unit.name] >= threshold).float().sum().long().item()
+            choices[unit.name] = num
+        return choices
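_get_prune_choices ranks channels globally rather than per layer: all units' BN-scale importances are concatenated, one threshold keeps the top channel_ratio fraction overall, and each unit keeps however many of its channels clear that bar. A self-contained sketch with random tensors standing in for unit.bn_imp:

import torch

bn_imps = {'unit0': torch.rand(8), 'unit1': torch.rand(16)}  # stand-ins for unit.bn_imp
channel_ratio = 0.5

bn_imp = torch.cat(list(bn_imps.values()), dim=0)
num_remain = int(channel_ratio * len(bn_imp))
threshold = bn_imp.topk(num_remain)[0][-1]  # k-th largest importance overall
choices = {name: (imp >= threshold).sum().item() for name, imp in bn_imps.items()}
print(choices)  # per-unit keep counts; they sum to num_remain barring ties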
52 changes: 41 additions & 11 deletions mmrazor/models/chex/chex_ops.py
@@ -7,30 +7,60 @@

 class ChexMixin:

-    @property
-    def prune_imp(self):
+    def prune_imp(self, num_remain):
         # compute channel importance for pruning
-        return self._prune_imp(self.weight)
+        return self._prune_imp(self.get_weight_matrix(), num_remain)

     @property
     def growth_imp(self):
         # compute channel importance for growth
-        return self._growth_imp(self.weight)
+        return self._growth_imp(self.get_weight_matrix())

+    def get_weight_matrix(self):
+        raise NotImplementedError()
+
-    def _prune_imp(self, weight):
+    def _prune_imp(self, weight, num_remain):
         # weight: out * in. return the importance of each channel
-        out_channel = weight.shape[0]
-        return torch.rand(out_channel)
+        # modified from https://github.com/zejiangh/Filter-GaP
+        assert num_remain <= weight.shape[0]
+        weight_t = weight.T  # in out
+        if weight_t.shape[0] >= weight_t.shape[1]:  # in >= out
+            _, _, V = torch.svd(weight_t, some=True)  # out out
+            Vk = V[:, :num_remain]  # out out'
+            lvs = torch.norm(Vk, dim=1)  # out
+            return lvs
+        else:
+            # l1-norm
+            return weight.abs().mean(-1)

     def _growth_imp(self, weight):
         # weight: out * in. return the importance of each channel when growth
-        out_channel = weight.shape[0]
-        return torch.rand(out_channel)
+
+        def get_proj(weight):
+            # out' in
+            wt = weight.T  # in out'
+            scatter = torch.matmul(wt.T, wt)  # out' out'
+            inv = torch.pinverse(scatter)  # out' out'
+            return torch.matmul(torch.matmul(wt, inv), wt.T)  # in in
+
+        mask = self.get_mutable_attr('out_channels').current_mask
+        n_mask = ~mask
+        proj = get_proj(weight[mask])  # in in
+        weight_c = weight[n_mask]  # out'' in
+
+        error = (weight_c - weight_c @ proj).norm(dim=-1)
+        all_errors = torch.zeros([weight.shape[0]], device=weight.device)
+        all_errors.masked_scatter_(n_mask, error)
+        return all_errors


 class ChexConv2d(DynamicConv2d, ChexMixin):
-    pass
+
+    def get_weight_matrix(self):
+        return self.weight.flatten(1)


 class ChexLinear(DynamicLinear, ChexMixin):
-    pass
+
+    def get_weight_matrix(self):
+        return self.weight
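Two importance measures are at work above. _prune_imp scores output channels by leverage: when the flattened kernel has at least as many input columns as output rows, it takes row norms of the top-num_remain right singular vectors of weight.T, and otherwise falls back to an L1 norm. _growth_imp scores currently-masked channels by how poorly they are reconstructed from the span of the kept channels' rows. A sketch of both on a random weight (shapes illustrative):

import torch

weight = torch.rand(8, 16)  # out x in, e.g. a flattened conv kernel
num_remain = 4

# Pruning importance: leverage scores from the top-k right singular vectors.
_, _, V = torch.svd(weight.T, some=True)          # V: out x out
prune_imp = torch.norm(V[:, :num_remain], dim=1)  # one score per out-channel

# Growth importance: residual of masked rows outside the kept rows' span.
mask = torch.tensor([True] * 4 + [False] * 4)  # pretend channels 0-3 are kept
wt = weight[mask].T                            # in x out'
proj = wt @ torch.pinverse(wt.T @ wt) @ wt.T   # in x in projection matrix
residual = (weight[~mask] - weight[~mask] @ proj).norm(dim=-1)

print(prune_imp.shape, residual.shape)  # torch.Size([8]) torch.Size([4])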
12 changes: 9 additions & 3 deletions mmrazor/models/chex/chex_unit.py
@@ -9,9 +9,11 @@
 import mmrazor.models.architectures.dynamic_ops as dynamic_ops
 from mmrazor.models.mutables.mutable_channel import MutableChannelContainer
 from mmrazor.models.mutables.mutable_channel.units import L1MutableChannelUnit
+from mmrazor.registry import MODELS
 from .chex_ops import ChexConv2d, ChexLinear, ChexMixin


+@MODELS.register_module()
 class ChexUnit(L1MutableChannelUnit):

     def prepare_for_pruning(self, model: nn.Module):
@@ -33,9 +35,10 @@ def get_prune_imp():
             prune_imp: torch.Tensor = torch.zeros([self.num_channels])
             for channel in self.chex_channels:
                 module = channel.module
-                prune_imp = prune_imp.to(module.prune_imp.device)
-                prune_imp = prune_imp + module.prune_imp[channel.start:channel.
-                                                         end]
+                prune_imp = prune_imp.to(
+                    module.prune_imp(num_remaining).device)
+                prune_imp = prune_imp + module.prune_imp(
+                    num_remaining)[channel.start:channel.end]
             return prune_imp

         prune_imp = get_prune_imp()
@@ -47,6 +50,9 @@ def get_prune_imp():
         self.mutable_channel.current_choice.data = mask

     def grow(self, num):
+        assert num >= 0
+        if num == 0:
+            return

         def get_growth_imp():
             growth_imp: torch.Tensor = torch.zeros([self.num_channels])
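ChexUnit.prune (its tail is elided above) sums each wrapped op's prune_imp over the unit's channel segments, then writes a boolean keep-mask into self.mutable_channel.current_choice. A plausible sketch of the elided mask construction, assuming a plain top-k selection — the helper below is hypothetical, not the unit's actual code:

import torch

def topk_mask(prune_imp: torch.Tensor, num_remaining: int) -> torch.Tensor:
    # Hypothetical reconstruction: keep the num_remaining highest-importance
    # channels; the real elided body may differ in detail.
    mask = torch.zeros_like(prune_imp, dtype=torch.bool)
    mask[prune_imp.topk(num_remaining).indices] = True
    return mask

print(topk_mask(torch.rand(8), 5).sum())  # tensor(5)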
25 changes: 25 additions & 0 deletions mmrazor/models/chex/utils.py
@@ -0,0 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.logging import MessageHub


class RuntimeInfo():

    @classmethod
    def get_info(cls, key):
        hub = MessageHub.get_current_instance()
        if key in hub.runtime_info:
            return hub.runtime_info[key]
        else:
            raise KeyError(key)

    @classmethod
    def epoch(cls):
        return cls.get_info('epoch')

    @classmethod
    def max_epochs(cls):
        return cls.get_info('max_epochs')

    @classmethod
    def iter(cls):
        return cls.get_info('iter')
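RuntimeInfo is a thin reader over mmengine's MessageHub: the training loop populates these keys, and ChexAlgorithm polls them each forward pass to decide when to prune and grow. A usage sketch mirroring what the test below sets up:

from mmengine.logging import MessageHub

from mmrazor.models.chex.utils import RuntimeInfo

hub = MessageHub.get_current_instance()
hub.update_info('epoch', 3)
hub.update_info('max_epochs', 100)
hub.update_info('iter', 32)

assert RuntimeInfo.epoch() == 3
assert RuntimeInfo.iter() == 32  # an unset key raises KeyError instead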
42 changes: 42 additions & 0 deletions tests/test_chex/test_chex_algorithm.py
@@ -0,0 +1,42 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import torch
from mmengine.logging import MessageHub

from mmrazor.models.chex import ChexAlgorithm

MODEL_CFG = dict(
    _scope_='mmcls',
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))


class TestChexAlgorithm(unittest.TestCase):

    def test_chex_algorithm(self):
        algorithm = ChexAlgorithm(MODEL_CFG)
        x = torch.rand([2, 3, 64, 64])
        self._set_epoch_ite(0, 2, 100)
        _ = algorithm(x)

    def _set_epoch_ite(self, epoch, ite, max_epoch):
        iter_per_epoch = 10
        message_hub = MessageHub.get_current_instance()
        message_hub.update_info('epoch', epoch)
        message_hub.update_info('max_epochs', max_epoch)
        message_hub.update_info('max_iters', max_epoch * 10)
        message_hub.update_info('iter', ite + iter_per_epoch * epoch)
27 changes: 27 additions & 0 deletions tests/test_chex/test_chex_mutator.py
@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import torch
import torch.nn as nn

from mmrazor.models.chex.chex_mutator import ChexMutator
from mmrazor.models.chex.chex_unit import ChexUnit
from ..data.models import SingleLineModel


class TestChexMutator(unittest.TestCase):

    def test_chex_mutator(self):
        model = SingleLineModel()
        mutator = ChexMutator(channel_unit_cfg=ChexUnit)
        mutator.prepare_from_supernet(model)

        for module in model.modules():
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.weight.data = torch.rand_like(module.weight.data)

        mutator.prune()
        print(mutator.current_choices)

        mutator.grow(0.2)
        print(mutator.current_choices)
26 changes: 26 additions & 0 deletions tests/test_chex/test_chex_ops.py
@@ -1 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import torch
import torch.nn as nn

from mmrazor.models.chex import ChexConv2d
from mmrazor.models.mutables import SimpleMutableChannel


class TestChexOps(unittest.TestCase):

    def test_ops(self):
        for in_c, out_c in [(4, 8), (8, 4), (8, 8)]:
            conv = nn.Conv2d(in_c, out_c, 3, 1, 1)
            conv: ChexConv2d = ChexConv2d.convert_from(conv)

            mutable_in = SimpleMutableChannel(in_c)
            mutable_out = SimpleMutableChannel(out_c)

            conv.register_mutable_attr('in_channels', mutable_in)
            conv.register_mutable_attr('out_channels', mutable_out)

            mutable_out.current_choice = torch.normal(0, 1, [out_c]) < 0

            self.assertEqual(list(conv.prune_imp(4).shape), [out_c])
            self.assertEqual(list(conv.growth_imp.shape), [out_c])
