# MutableChannelUnit

Each MutableChannelUnit is a basic unit for pruning. It records all channels that depend on each other.
Below, we introduce:
1. The data structure of MutableChannelUnit.
2. How to prune the model with a MutableChannelUnit.
3. How to get MutableChannelUnits.
4. How to develop a new MutableChannelUnit for a new pruning algorithm.
<p align="center"><img src="../../../../../docs/en/imgs/pruning/unit.png" alt="MutableChannelUnit" width="800"></p>

## The Data Structure of MutableChannelUnit

First, let's parse a model and get several MutableChannelUnits.

```python
# define a model
from mmengine.model import BaseModel
from torch import nn
import torch
from collections import OrderedDict


class MyModel(BaseModel):

    def __init__(self):
        super().__init__(None, None)
        self.net = nn.Sequential(
            OrderedDict([('conv0', nn.Conv2d(3, 8, 3, 1, 1)),
                         ('relu', nn.ReLU()),
                         ('conv1', nn.Conv2d(8, 16, 3, 1, 1))]))
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(16, 1000)

    def forward(self, x):
        feature = self.net(x)
        pool = self.pool(feature).flatten(1)
        return self.head(pool)
```

```python
# There are multiple types of MutableChannelUnits. Here, we take SequentialMutableChannelUnit as an example.
from mmrazor.models.mutables.mutable_channel.units import SequentialMutableChannelUnit
from mmrazor.structures.graph import ModuleGraph
from typing import List

model = MyModel()
graph = ModuleGraph.init_from_backward_tracer(model)
units: List[
    SequentialMutableChannelUnit] = SequentialMutableChannelUnit.init_from_graph(graph)  # type: ignore
print(
    f'This model has {len(units)} MutableChannelUnit(SequentialMutableChannelUnit).'
)
```

```
This model has 4 MutableChannelUnit(SequentialMutableChannelUnit).
```


```python
unit1 = units[1]
print(unit1)
```

```
SequentialMutableChannelUnit(
  name=net.conv0_(0, 8)_8
  (output_related): ModuleList(
    (0): Channel(net.conv0, index=(0, 8), is_output_channel=true, expand_ratio=1)
  )
  (input_related): ModuleList(
    (0): Channel(net.conv1, index=(0, 8), is_output_channel=false, expand_ratio=1)
  )
  (mutable_channel): SquentialMutableChannel(num_channels=8, activated_channels=8)
)
```


As shown above, each MutableChannelUnit has four important attributes:
1. name: str
2. output_related: ModuleList
3. input_related: ModuleList
4. mutable_channel: BaseMutableChannel

"name" is the identifier of the MutableChannelUnit. It is usually generated automatically.

"output_related" and "input_related" are two ModuleLists that store all Channels sharing a channel dependency.
The difference is that "output_related" contains output channels, while "input_related" contains input channels.
All these channels depend on each other, so they are pruned together.

"mutable_channel" is a BaseMutableChannel that controls the channel mask of the modules. It is registered to every module whose channels are stored in "output_related" or "input_related".

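The toy sketch below mirrors the structure printed above in plain Python, so it runs without mmrazor. `ToyMutableChannel`, `ToyChannel`, and `ToyChannelUnit` are hypothetical stand-ins, not mmrazor classes. It only illustrates the design point described above: every channel in a unit reads its state from one shared mutable, so setting the choice once updates the output channels of `net.conv0` and the input channels of `net.conv1` together.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ToyMutableChannel:
    """Hypothetical stand-in for a BaseMutableChannel."""
    num_channels: int
    activated_channels: int = 0

    def __post_init__(self) -> None:
        # Like the printed unit above, start with every channel activated.
        if self.activated_channels == 0:
            self.activated_channels = self.num_channels


@dataclass
class ToyChannel:
    """Hypothetical stand-in for a Channel: a slice of one module's channels."""
    module_name: str
    index: Tuple[int, int]
    is_output_channel: bool


@dataclass
class ToyChannelUnit:
    """Hypothetical unit: dependent channels that all share one mutable."""
    name: str
    mutable_channel: ToyMutableChannel
    output_related: List[ToyChannel] = field(default_factory=list)
    input_related: List[ToyChannel] = field(default_factory=list)

    @property
    def current_choice(self) -> int:
        return self.mutable_channel.activated_channels

    @current_choice.setter
    def current_choice(self, choice: int) -> None:
        if not 1 <= choice <= self.mutable_channel.num_channels:
            raise ValueError(f'invalid choice {choice}')
        self.mutable_channel.activated_channels = choice

    def activated_per_channel(self) -> List[Tuple[str, bool, int]]:
        # Every related channel reads the same mutable, so all entries agree.
        return [(ch.module_name, ch.is_output_channel,
                 self.mutable_channel.activated_channels)
                for ch in self.output_related + self.input_related]


toy_unit = ToyChannelUnit(
    name='net.conv0_(0, 8)_8',
    mutable_channel=ToyMutableChannel(num_channels=8),
    output_related=[ToyChannel('net.conv0', (0, 8), True)],
    input_related=[ToyChannel('net.conv1', (0, 8), False)],
)
toy_unit.current_choice = 2
print(toy_unit.activated_per_channel())
# [('net.conv0', True, 2), ('net.conv1', False, 2)]
```

Because the unit holds a single mutable rather than one counter per module, the two convolutions cannot drift out of sync.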
## How to prune the model with a MutableChannelUnit

There are three steps to prune a model using a MutableChannelUnit:
1. Replace the modules whose channels are stored in "output_related" and "input_related" with dynamic ops that can handle a mutable number of channels.
2. Register the "mutable_channel" to the replaced dynamic ops.
3. Change the choice of the "mutable_channel".

For simplicity, steps 1 and 2 are performed by a single method, "prepare_for_pruning".
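The effect of these steps on the weights can be sketched without mmrazor. The sketch below assumes, as the name suggests and as the `activated_channels` values printed later are consistent with, that a sequential mutable channel keeps the first `choice` channels. The helpers `sequential_mask` and `prune_pair` are hypothetical, and weights are nested lists indexed `[out][in]` with kernel dimensions dropped for brevity. Applying one shared mask to the producer's output dimension and the consumer's input dimension is what keeps the pruned layers shape-consistent; a dynamic op effectively does this slicing at forward time.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def sequential_mask(num_channels: int, choice: int) -> List[bool]:
    """Keep the first `choice` channels (assumed behaviour of a sequential unit)."""
    if not 1 <= choice <= num_channels:
        raise ValueError(f'choice must be in [1, {num_channels}], got {choice}')
    return [i < choice for i in range(num_channels)]


def prune_pair(producer: Matrix, consumer: Matrix,
               mask: List[bool]) -> Tuple[Matrix, Matrix]:
    """Apply one shared mask to the producer's output rows and the consumer's input columns."""
    pruned_producer = [row for row, keep in zip(producer, mask) if keep]
    pruned_consumer = [[w for w, keep in zip(row, mask) if keep] for row in consumer]
    return pruned_producer, pruned_consumer


# conv0: 3 -> 8 channels, conv1: 8 -> 16 channels (weights as [out][in]).
w0 = [[1.0] * 3 for _ in range(8)]
w1 = [[1.0] * 8 for _ in range(16)]
mask = sequential_mask(8, 2)
w0_pruned, w1_pruned = prune_pair(w0, w1, mask)
print(len(w0_pruned), len(w0_pruned[0]))  # 2 3
print(len(w1_pruned), len(w1_pruned[0]))  # 16 2
```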

```python
# Run "prepare_for_pruning" once before pruning; it performs steps 1 and 2 above.
unit1.prepare_for_pruning(model)
print(f'The current choice of unit1 is {unit1.current_choice}.')
print(model.net.conv0)
print(model.net.conv1)
```

```
The current choice of unit1 is 8.
DynamicConv2d(
  3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
  (mutable_attrs): ModuleDict(
    (in_channels): MutableChannelContainer(num_channels=3, activated_channels=3)
    (out_channels): MutableChannelContainer(num_channels=8, activated_channels=8)
  )
)
DynamicConv2d(
  8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
  (mutable_attrs): ModuleDict(
    (in_channels): MutableChannelContainer(num_channels=8, activated_channels=8)
    (out_channels): MutableChannelContainer(num_channels=16, activated_channels=16)
  )
)
```


We prune the model by changing the current_choice of the MutableChannelUnits.

```python
sampled_choice = unit1.sample_choice()
print(f'We get a sampled choice {sampled_choice}.')
unit1.current_choice = sampled_choice
print(model.net.conv0)
print(model.net.conv1)
```

```
We get a sampled choice 2.
DynamicConv2d(
  3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
  (mutable_attrs): ModuleDict(
    (in_channels): MutableChannelContainer(num_channels=3, activated_channels=3)
    (out_channels): MutableChannelContainer(num_channels=8, activated_channels=2)
  )
)
DynamicConv2d(
  8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
  (mutable_attrs): ModuleDict(
    (in_channels): MutableChannelContainer(num_channels=8, activated_channels=2)
    (out_channels): MutableChannelContainer(num_channels=16, activated_channels=16)
  )
)
```
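To see what a sampled choice saves, the parameter counts of the two convolutions can be worked out by hand: a `k x k` Conv2d with bias has `in * out * k * k + out` parameters. The helper `conv2d_params` below is plain arithmetic, not an mmrazor API, and it uses the sampled choice of 2 printed above; because `sample_choice` is random, a different run may give different numbers.

```python
def conv2d_params(in_channels: int, out_channels: int, kernel_size: int,
                  bias: bool = True) -> int:
    """Number of parameters of a standard (non-grouped) 2D convolution."""
    weights = in_channels * out_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


# conv0 (3 -> 8) and conv1 (8 -> 16) before pruning, and with 2 shared channels kept.
full = conv2d_params(3, 8, 3) + conv2d_params(8, 16, 3)
pruned = conv2d_params(3, 2, 3) + conv2d_params(2, 16, 3)
print(full, pruned)  # 1392 360
```

Both convolutions shrink at once because the pruned channels are the output channels of `conv0` and the input channels of `conv1`.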


Note that different types of MutableChannelUnit may accept different types of choices. Please refer to the documentation of each unit type for details.

## How to get MutableChannelUnits

There are three ways to get MutableChannelUnits.
1. Using a tracer.
   This approach first converts a model to a graph and then converts the graph to MutableChannelUnits. It automatically returns all available MutableChannelUnits.
2. Using a config.
   This approach uses a config to initialize a MutableChannelUnit.
3. Using a predefined model.
   This approach parses a predefined model with dynamic ops and returns all available MutableChannelUnits.

All three ways are documented in the README of ChannelMutator.
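The tracer-based route can be pictured on a simple chain of layers. In the toy sketch below, the helper `units_for_chain` is hypothetical and is not the real tracer, which works on an arbitrary ModuleGraph. Each boundary between consecutive layers becomes one group that joins the producer's output channels with the consumer's input channels, and the model input and the final output form groups of their own. Channel-preserving layers such as ReLU and pooling are left out because they do not change the number of channels. For `MyModel` this toy grouping gives four groups, which matches the count printed earlier, although the real tracer's grouping rules may differ in the details.

```python
from typing import List, Tuple

Layer = Tuple[str, int, int]  # (name, in_channels, out_channels)


def units_for_chain(layers: List[Layer]) -> List[dict]:
    """Group dependent channels of a purely sequential chain of layers."""
    groups = []
    # The model input feeds the input channels of the first layer.
    first_name, first_in, _ = layers[0]
    groups.append({'num_channels': first_in, 'output_related': [],
                   'input_related': [first_name]})
    # Each producer/consumer boundary shares one set of channels.
    for (p_name, _, p_out), (c_name, c_in, _) in zip(layers, layers[1:]):
        if p_out != c_in:
            raise ValueError(f'{p_name} -> {c_name}: {p_out} != {c_in}')
        groups.append({'num_channels': p_out, 'output_related': [p_name],
                       'input_related': [c_name]})
    # The output channels of the last layer form the final group.
    last_name, _, last_out = layers[-1]
    groups.append({'num_channels': last_out, 'output_related': [last_name],
                   'input_related': []})
    return groups


layers = [('net.conv0', 3, 8), ('net.conv1', 8, 16), ('head', 16, 1000)]
toy_units = units_for_chain(layers)
for group in toy_units:
    print(group)
```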

```python
# 1. Using a tracer
def get_mutable_channel_units_using_tracer(model):
    graph = ModuleGraph.init_from_backward_tracer(model)
    units = SequentialMutableChannelUnit.init_from_graph(graph)
    return units


model = MyModel()
units = get_mutable_channel_units_using_tracer(model)
print(f'The model has {len(units)} MutableChannelUnits.')
```

```
The model has 4 MutableChannelUnits.
```


```python
# 2. Using a config
config = {
    'init_args': {
        'num_channels': 8,
    },
    'channels': {
        'input_related': [{
            'name': 'net.conv1',
        }],
        'output_related': [{
            'name': 'net.conv0',
        }]
    },
    'choice': 8
}
unit = SequentialMutableChannelUnit.init_from_cfg(model, config)
print(unit)
```

```
SequentialMutableChannelUnit(
  name=net.conv0_(0, 8)_8
  (output_related): ModuleList(
    (0): Channel(net.conv0, index=(0, 8), is_output_channel=true, expand_ratio=1)
  )
  (input_related): ModuleList(
    (0): Channel(net.conv1, index=(0, 8), is_output_channel=false, expand_ratio=1)
  )
  (mutable_channel): SquentialMutableChannel(num_channels=8, activated_channels=8)
)
```
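Before passing a hand-written config to "init_from_cfg", it can help to sanity-check it against the model. The function `check_unit_cfg` below is hypothetical and not part of mmrazor. It assumes, as the printed `index=(0, 8)` suggests for this example, that each listed channel spans the module's full channel dimension, and it checks that `num_channels` matches each listed module and that `choice` lies in `[1, num_channels]`. The module sizes are supplied by hand here rather than read from the model.

```python
from typing import Dict, List, Tuple


def check_unit_cfg(cfg: dict, channel_sizes: Dict[str, Tuple[int, int]]) -> List[str]:
    """Return a list of problems; channel_sizes maps module name -> (in_channels, out_channels)."""
    problems = []
    num = cfg['init_args']['num_channels']
    for item in cfg['channels'].get('output_related', []):
        out_c = channel_sizes[item['name']][1]
        if out_c != num:
            problems.append(f"{item['name']} has {out_c} output channels, expected {num}")
    for item in cfg['channels'].get('input_related', []):
        in_c = channel_sizes[item['name']][0]
        if in_c != num:
            problems.append(f"{item['name']} has {in_c} input channels, expected {num}")
    choice = cfg.get('choice', num)
    if not 1 <= choice <= num:
        problems.append(f'choice {choice} is outside [1, {num}]')
    return problems


# Same structure as the config above; sizes of MyModel's convolutions given by hand.
example_cfg = {
    'init_args': {'num_channels': 8},
    'channels': {'input_related': [{'name': 'net.conv1'}],
                 'output_related': [{'name': 'net.conv0'}]},
    'choice': 8,
}
sizes = {'net.conv0': (3, 8), 'net.conv1': (8, 16)}
print(check_unit_cfg(example_cfg, sizes))  # []
```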


```python
# 3. Using a predefined model
from mmrazor.models.architectures.dynamic_ops import DynamicConv2d, DynamicLinear
from mmrazor.models.mutables import (MutableChannelUnit, MutableChannelContainer,
                                     SquentialMutableChannel)
from collections import OrderedDict


class MyDynamicModel(BaseModel):

    def __init__(self):
        super().__init__(None, None)
        self.net = nn.Sequential(
            OrderedDict([('conv0', DynamicConv2d(3, 8, 3, 1, 1)),
                         ('relu', nn.ReLU()),
                         ('conv1', DynamicConv2d(8, 16, 3, 1, 1))]))
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = DynamicLinear(16, 1000)

        # register MutableChannelContainer
        MutableChannelUnit._register_channel_container(
            self, MutableChannelContainer)
        self._register_mutables()

    def forward(self, x):
        feature = self.net(x)
        pool = self.pool(feature).flatten(1)
        return self.head(pool)

    def _register_mutables(self):
        # mutable1 links the output channels of conv0 and the input channels of conv1.
        mutable1 = SquentialMutableChannel(8)
        # mutable2 links the output channels of conv1 and the input channels of head.
        mutable2 = SquentialMutableChannel(16)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.net.conv0, mutable1, is_to_output_channel=True)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.net.conv1, mutable1, is_to_output_channel=False)

        MutableChannelContainer.register_mutable_channel_to_module(
            self.net.conv1, mutable2, is_to_output_channel=True)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.head, mutable2, is_to_output_channel=False)


model = MyDynamicModel()
units = SequentialMutableChannelUnit.init_from_predefined_model(model)
print(f'The model has {len(units)} MutableChannelUnits.')
```

```
The model has 2 MutableChannelUnits.
```
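Parsing a predefined model can be pictured as finding which modules share the same mutable object. The toy sketch below uses a hypothetical `ToyMutable` class and `group_by_shared_mutable` function, not mmrazor code, and mirrors "_register_mutables" above: `mutable1` is attached to the output of `net.conv0` and the input of `net.conv1`, and `mutable2` to the output of `net.conv1` and the input of `head`. Grouping the registrations by object identity yields two groups, in line with the two units reported above.

```python
from typing import Dict, List, Tuple


class ToyMutable:
    """Hypothetical stand-in for a mutable channel; only its identity matters here."""

    def __init__(self, num_channels: int) -> None:
        self.num_channels = num_channels


def group_by_shared_mutable(
        registrations: List[Tuple[str, bool, ToyMutable]]) -> List[Dict[str, list]]:
    """registrations: (module_name, is_to_output_channel, mutable) triples."""
    groups: Dict[int, Dict[str, list]] = {}
    for module_name, is_output, mutable in registrations:
        # Modules registered with the same mutable object end up in the same group.
        group = groups.setdefault(id(mutable), {'output_related': [], 'input_related': []})
        key = 'output_related' if is_output else 'input_related'
        group[key].append(module_name)
    return list(groups.values())


mutable1, mutable2 = ToyMutable(8), ToyMutable(16)
registrations = [
    ('net.conv0', True, mutable1),
    ('net.conv1', False, mutable1),
    ('net.conv1', True, mutable2),
    ('head', False, mutable2),
]
groups = group_by_shared_mutable(registrations)
print(len(groups))  # 2
```

The model input of `net.conv0` and the output of `head` have no mutable registered, which is why they do not appear as groups here.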
