-
Notifications
You must be signed in to change notification settings - Fork 220
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add Chex ops, Chex mutator and algorithm (#404)
* finish chex mutator * add TestChexAlgorithm * add chex ops Co-authored-by: jacky <jacky@xx.com>
- Loading branch information
Showing
10 changed files
with
236 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,10 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .chex_algorithm import ChexAlgorithm | ||
from .chex_mutator import ChexMutator | ||
from .chex_ops import ChexConv2d, ChexLinear, ChexMixin | ||
from .chex_unit import ChexUnit | ||
|
||
__all__ = [ | ||
'ChexAlgorithm', 'ChexMutator', 'ChexUnit', 'ChexConv2d', 'ChexLinear', | ||
'ChexMixin' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,35 +1,51 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import math | ||
from typing import Dict, Optional, Union | ||
|
||
import torch.nn as nn | ||
from mmengine.model import BaseModel | ||
|
||
from mmrazor.models.algorithms import BaseAlgorithm | ||
from mmrazor.registry import MODELS | ||
from .chex_mutator import ChexMutator | ||
from .utils import RuntimeInfo | ||
|
||
|
||
class ChexAlgoritm(BaseAlgorithm): | ||
@MODELS.register_module() | ||
class ChexAlgorithm(BaseAlgorithm): | ||
|
||
def __init__(self, | ||
architecture: Union[BaseModel, Dict], | ||
data_preprocessor: Optional[Union[Dict, nn.Module]] = None, | ||
mutator_cfg=dict( | ||
type='ChexMutator', | ||
channel_unit_cfg=dict(type='ChexUnit')), | ||
delta_t=2, | ||
total_steps=10, | ||
init_growth_rate=0.3, | ||
init_cfg: Optional[Dict] = None): | ||
super().__init__(architecture, data_preprocessor, init_cfg) | ||
|
||
self.delta_t = delta_t | ||
self.total_steps = total_steps | ||
self.init_growth_rate = init_growth_rate | ||
|
||
self.mutator: ChexMutator = MODELS.build(mutator_cfg) | ||
self.mutator.prepare_from_supernet(self.architecture) | ||
|
||
def forward(self, inputs, data_samples=None, mode: str = 'tensor'): | ||
if True: # | ||
self.mutator.prune() | ||
self.mutator.grow(self.growth_ratio) | ||
if self.training: # | ||
if RuntimeInfo.iter() % self.delta_t == 0 and \ | ||
RuntimeInfo.iter() // self.delta_t < self.total_steps: | ||
self.mutator.prune() | ||
self.mutator.grow(self.growth_ratio) | ||
return super().forward(inputs, data_samples, mode) | ||
|
||
@property | ||
def _epoch(self): | ||
pass | ||
|
||
@property | ||
def growth_ratio(self): | ||
# return growth ratio in current epoch | ||
pass | ||
def cos(): | ||
a = math.pi * RuntimeInfo.epoch() / RuntimeInfo.max_epochs() | ||
return (math.cos(a) + 1) / 2 | ||
|
||
return self.init_growth_rate * cos() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmengine.logging import MessageHub | ||
|
||
|
||
class RuntimeInfo(): | ||
|
||
@classmethod | ||
def get_info(cls, key): | ||
hub = MessageHub.get_current_instance() | ||
if key in hub.runtime_info: | ||
return hub.runtime_info[key] | ||
else: | ||
raise KeyError(key) | ||
|
||
@classmethod | ||
def epoch(cls): | ||
return cls.get_info('epoch') | ||
|
||
@classmethod | ||
def max_epochs(cls): | ||
return cls.get_info('max_epochs') | ||
|
||
@classmethod | ||
def iter(cls): | ||
return cls.get_info('iter') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import unittest | ||
|
||
import torch | ||
from mmengine.logging import MessageHub | ||
|
||
from mmrazor.models.chex import ChexAlgorithm | ||
|
||
MODEL_CFG = dict( | ||
_scope_='mmcls', | ||
type='ImageClassifier', | ||
backbone=dict( | ||
type='ResNet', | ||
depth=18, | ||
num_stages=4, | ||
out_indices=(3, ), | ||
style='pytorch'), | ||
neck=dict(type='GlobalAveragePooling'), | ||
head=dict( | ||
type='LinearClsHead', | ||
num_classes=1000, | ||
in_channels=512, | ||
loss=dict(type='CrossEntropyLoss', loss_weight=1.0), | ||
topk=(1, 5), | ||
)) | ||
|
||
|
||
class TestChexAlgorithm(unittest.TestCase): | ||
|
||
def test_chex_algorithm(self): | ||
algorithm = ChexAlgorithm(MODEL_CFG) | ||
x = torch.rand([2, 3, 64, 64]) | ||
self._set_epoch_ite(0, 2, 100) | ||
_ = algorithm(x) | ||
|
||
def _set_epoch_ite(self, epoch, ite, max_epoch): | ||
iter_per_epoch = 10 | ||
message_hub = MessageHub.get_current_instance() | ||
message_hub.update_info('epoch', epoch) | ||
message_hub.update_info('max_epochs', max_epoch) | ||
message_hub.update_info('max_iters', max_epoch * 10) | ||
message_hub.update_info('iter', ite + iter_per_epoch * epoch) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import unittest | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from mmrazor.models.chex.chex_mutator import ChexMutator | ||
from mmrazor.models.chex.chex_unit import ChexUnit | ||
from ..data.models import SingleLineModel | ||
|
||
|
||
class TestChexMutator(unittest.TestCase): | ||
|
||
def test_chex_mutator(self): | ||
model = SingleLineModel() | ||
mutator = ChexMutator(channel_unit_cfg=ChexUnit) | ||
mutator.prepare_from_supernet(model) | ||
|
||
for module in model.modules(): | ||
if isinstance(module, nn.modules.batchnorm._BatchNorm): | ||
module.weight.data = torch.rand_like(module.weight.data) | ||
|
||
mutator.prune() | ||
print(mutator.current_choices) | ||
|
||
mutator.grow(0.2) | ||
print(mutator.current_choices) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,27 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import unittest | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from mmrazor.models.chex import ChexConv2d | ||
from mmrazor.models.mutables import SimpleMutableChannel | ||
|
||
|
||
class TestChexOps(unittest.TestCase): | ||
|
||
def test_ops(self): | ||
for in_c, out_c in [(4, 8), (8, 4), (8, 8)]: | ||
conv = nn.Conv2d(in_c, out_c, 3, 1, 1) | ||
conv: ChexConv2d = ChexConv2d.convert_from(conv) | ||
|
||
mutable_in = SimpleMutableChannel(in_c) | ||
mutable_out = SimpleMutableChannel(out_c) | ||
|
||
conv.register_mutable_attr('in_channels', mutable_in) | ||
conv.register_mutable_attr('out_channels', mutable_out) | ||
|
||
mutable_out.current_choice = torch.normal(0, 1, [out_c]) < 0 | ||
|
||
self.assertEqual(list(conv.prune_imp(4).shape), [out_c]) | ||
self.assertEqual(list(conv.growth_imp.shape), [out_c]) |