[Feature] Add Activation Boundaries Loss #214

Merged 3 commits on Aug 5, 2022
73 changes: 73 additions & 0 deletions configs/distill/mmcls/abloss/README.md
@@ -0,0 +1,73 @@
# Activation Boundaries Loss (ABLoss)

> [Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons](https://arxiv.org/pdf/1811.03233.pdf)

<!-- [ALGORITHM] -->

## Abstract

An activation boundary for a neuron refers to a separating hyperplane that determines whether the neuron is activated or deactivated. It has been long considered in neural networks that the activations of neurons, rather than their exact output values, play the most important role in forming classification friendly partitions of the hidden feature space. However, as far as we know, this aspect of neural networks has not been considered in the literature of knowledge transfer. In this paper, we propose a knowledge transfer method via distillation of activation boundaries formed by hidden neurons. For the distillation, we propose an activation transfer loss that has the minimum value when the boundaries generated by the student coincide with those by the teacher. Since the activation transfer loss is not differentiable, we design a piecewise differentiable loss approximating the activation transfer loss. By the proposed method, the student learns a separating boundary between activation region and deactivation region formed by each neuron in the teacher. Through the experiments in various aspects of knowledge transfer, it is verified that the proposed method outperforms the current state-of-the-art ([official code](https://github.com/bhheo/AB_distillation)).

![pipeline](/docs/en/imgs/model_zoo/abloss/pipeline.png)
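
In plain terms: for teacher pre-activation `t`, student pre-activation `s` (after the connector), and margin `mu`, the piecewise differentiable alternative loss penalizes `max(0, mu - s)^2` wherever the teacher neuron is active (`t > 0`) and `max(0, mu + s)^2` wherever it is inactive (`t <= 0`), pushing the student past the margin on the correct side of each activation boundary. A minimal PyTorch sketch of that formula (the margin value and mean reduction here are illustrative, not necessarily what the registered `ABLoss` uses):

```python
import torch


def ab_loss_sketch(s_feat: torch.Tensor, t_feat: torch.Tensor,
                   margin: float = 1.0) -> torch.Tensor:
    """Piecewise-differentiable activation transfer loss (illustrative)."""
    # Teacher neuron active (t > 0): push the student response above +margin.
    pos = (s_feat - margin).pow(2) * ((s_feat <= margin) & (t_feat > 0)).float()
    # Teacher neuron inactive (t <= 0): push the student response below -margin.
    neg = (s_feat + margin).pow(2) * ((s_feat > -margin) & (t_feat <= 0)).float()
    return (pos + neg).mean()
```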

## Results and models

### Classification

| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :----------------------------------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :--------------------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------- |
| backbone (pretrain) & logits (train) | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet50_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 70.58 | 76.55 | 69.90 | [pretrain_config](./abloss_backbone_resnet50_resnet18_8xb32_in1k_pretrain.py) [train_config](./abloss_head_resnet50_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \| [model](<>) \| [log](<>) |

## Citation

```latex
@inproceedings{DBLP:conf/aaai/HeoLY019a,
  author    = {Byeongho Heo and Minsik Lee and Sangdoo Yun and Jin Young Choi},
  title     = {Knowledge Transfer via Distillation of Activation Boundaries Formed
               by Hidden Neurons},
  booktitle = {The Thirty-Third {AAAI} Conference on Artificial Intelligence, {AAAI}
               2019, The Thirty-First Innovative Applications of Artificial Intelligence
               Conference, {IAAI} 2019, The Ninth {AAAI} Symposium on Educational
               Advances in Artificial Intelligence, {EAAI} 2019, Honolulu, Hawaii,
               USA, January 27 - February 1, 2019},
  pages     = {3779--3787},
  publisher = {{AAAI} Press},
  year      = {2019},
  url       = {https://doi.org/10.1609/aaai.v33i01.33013779},
  doi       = {10.1609/aaai.v33i01.33013779},
  timestamp = {Fri, 07 May 2021 11:57:04 +0200},
  biburl    = {https://dblp.org/rec/conf/aaai/HeoLY019a.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```

## Getting Started

### Connector and student pre-training

```bash
sh tools/slurm_train.sh $PARTITION $JOB_NAME \
  configs/distill/mmcls/abloss/abloss_backbone_resnet50_resnet18_8xb32_in1k_pretrain.py \
  $PRETRAIN_WORK_DIR
```

### Modify the distillation training config

Open `configs/distill/mmcls/abloss/abloss_head_resnet50_resnet18_8xb32_in1k.py` and point `init_cfg` at the pre-trained checkpoint:

```python
# Modify init_cfg in the model settings.
# pretrain_work_dir is the PRETRAIN_WORK_DIR used in pre-training.
init_cfg=dict(
    type='Pretrained', checkpoint='pretrain_work_dir/last_checkpoint.pth'),
```

### Distillation training

```bash
sh tools/slurm_train.sh $PARTITION $JOB_NAME \
  configs/distill/mmcls/abloss/abloss_head_resnet50_resnet18_8xb32_in1k.py \
  $DISTILLATION_WORK_DIR
```
97 changes: 97 additions & 0 deletions configs/distill/mmcls/abloss/abloss_backbone_resnet50_resnet18_8xb32_in1k_pretrain.py
@@ -0,0 +1,97 @@
_base_ = [
    'mmcls::_base_/datasets/imagenet_bs32.py',
    'mmcls::_base_/schedules/imagenet_bs256.py',
    'mmcls::_base_/default_runtime.py'
]

train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)

model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        bgr_to_rgb=True),
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
    teacher=dict(
        cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=True),
    teacher_ckpt='resnet50_8xb32_in1k_20210831-ea4938fc.pth',
    calculate_student_loss=False,
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.1.conv2'),
            bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.1.conv2'),
            bb_s2=dict(type='ModuleOutputs', source='backbone.layer2.1.conv2'),
            bb_s1=dict(type='ModuleOutputs',
                       source='backbone.layer1.1.conv2')),
        teacher_recorders=dict(
            bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.2.conv3'),
            bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.5.conv3'),
            bb_s2=dict(type='ModuleOutputs', source='backbone.layer2.3.conv3'),
            bb_s1=dict(type='ModuleOutputs',
                       source='backbone.layer1.2.conv3')),
        distill_losses=dict(
            loss_s4=dict(type='ABLoss', loss_weight=1.0),
            loss_s3=dict(type='ABLoss', loss_weight=0.5),
            loss_s2=dict(type='ABLoss', loss_weight=0.25),
            loss_s1=dict(type='ABLoss', loss_weight=0.125)),
        connectors=dict(
            loss_s4_sfeat=dict(
                type='ConvModuleConncetor',
                in_channel=512,
                out_channel=2048,
                norm_cfg=dict(type='BN'),
                act_cfg=None),
            loss_s3_sfeat=dict(
                type='ConvModuleConncetor',
                in_channel=256,
                out_channel=1024,
                norm_cfg=dict(type='BN'),
                act_cfg=None),
            loss_s2_sfeat=dict(
                type='ConvModuleConncetor',
                in_channel=128,
                out_channel=512,
                norm_cfg=dict(type='BN'),
                act_cfg=None),
            loss_s1_sfeat=dict(
                type='ConvModuleConncetor',
                in_channel=64,
                out_channel=256,
                norm_cfg=dict(type='BN'),
                act_cfg=None)),
        loss_forward_mappings=dict(
            loss_s4=dict(
                s_feature=dict(
                    from_student=True,
                    recorder='bb_s4',
                    connector='loss_s4_sfeat'),
                t_feature=dict(from_student=False, recorder='bb_s4')),
            loss_s3=dict(
                s_feature=dict(
                    from_student=True,
                    recorder='bb_s3',
                    connector='loss_s3_sfeat'),
                t_feature=dict(from_student=False, recorder='bb_s3')),
            loss_s2=dict(
                s_feature=dict(
                    from_student=True,
                    recorder='bb_s2',
                    connector='loss_s2_sfeat'),
                t_feature=dict(from_student=False, recorder='bb_s2')),
            loss_s1=dict(
                s_feature=dict(
                    from_student=True,
                    recorder='bb_s1',
                    connector='loss_s1_sfeat'),
                t_feature=dict(from_student=False, recorder='bb_s1')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
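
To make the wiring above concrete: each `loss_sN_sfeat` connector is a 1×1 convolution plus BN (no activation, matching `act_cfg=None`) that lifts the student's stage-N channel count to the teacher's, so `ABLoss` can compare the two feature maps elementwise. A stripped-down sketch of the stage-4 branch in plain PyTorch (shapes assume 224×224 ImageNet inputs; this stands in for the registered `ConvModuleConncetor`, it is not a copy of it):

```python
import torch
import torch.nn as nn

# Student ResNet-18 layer4 emits 512 channels; teacher ResNet-50 layer4
# emits 2048. A 1x1 Conv + BN aligns the student feature with the teacher's.
connector = nn.Sequential(
    nn.Conv2d(512, 2048, kernel_size=1, bias=False),
    nn.BatchNorm2d(2048),
)
s_feat = torch.randn(2, 512, 7, 7)   # recorded from backbone.layer4.1.conv2
t_feat = torch.randn(2, 2048, 7, 7)  # recorded from backbone.layer4.2.conv3
assert connector(s_feat).shape == t_feat.shape
```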
39 changes: 39 additions & 0 deletions configs/distill/mmcls/abloss/abloss_head_resnet50_resnet18_8xb32_in1k.py
@@ -0,0 +1,39 @@
_base_ = [
    'mmcls::_base_/datasets/imagenet_bs32.py',
    'mmcls::_base_/schedules/imagenet_bs256.py',
    'mmcls::_base_/default_runtime.py'
]

model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        bgr_to_rgb=True),
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
    teacher=dict(
        cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False),
    init_cfg=dict(
        type='Pretrained', checkpoint='pretrain_work_dir/last_checkpoint.pth'),
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        teacher_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        distill_losses=dict(
            loss_kl=dict(
                type='KLDivergence', loss_weight=200, reduction='mean')),
        loss_forward_mappings=dict(
            loss_kl=dict(
                preds_S=dict(from_student=True, recorder='fc'),
                preds_T=dict(from_student=False, recorder='fc')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
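
The `loss_kl` mapping above feeds the recorded student and teacher `head.fc` logits into a KL-divergence loss. A rough sketch of the quantity being minimized (the temperature is fixed at 1 since the config sets no `tau`, and `batchmean` is an assumption; mmrazor's `KLDivergence` with `reduction='mean'` may average differently):

```python
import torch
import torch.nn.functional as F


def kd_kl_sketch(preds_S: torch.Tensor, preds_T: torch.Tensor,
                 tau: float = 1.0, loss_weight: float = 200.0) -> torch.Tensor:
    # KL(teacher || student) over temperature-softened class distributions;
    # the tau**2 factor rescales gradients when tau != 1.
    p_T = F.softmax(preds_T / tau, dim=1)
    log_p_S = F.log_softmax(preds_S / tau, dim=1)
    return loss_weight * tau**2 * F.kl_div(log_p_S, p_T, reduction='batchmean')
```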
@@ -36,12 +36,12 @@
                 type='KLDivergence', tau=6, loss_weight=10, reduction='mean')),
         connectors=dict(
             loss_s4_sfeat=dict(
-                type='ConvBNReLUConnector',
+                type='ConvModuleConncetor',
                 in_channel=512,
                 out_channel=2048,
                 norm_cfg=dict(type='BN')),
             loss_s3_sfeat=dict(
-                type='ConvBNReLUConnector',
+                type='ConvModuleConncetor',
                 in_channel=256,
                 out_channel=1024,
                 norm_cfg=dict(type='BN'))),
Binary file added docs/en/imgs/model_zoo/abloss/pipeline.png
5 changes: 2 additions & 3 deletions mmrazor/models/architectures/connectors/__init__.py
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .general_connector import (ConvBNConnector, ConvBNReLUConnector,
-                                SingleConvConnector)
+from .convmodule_connector import ConvModuleConncetor

-__all__ = ['ConvBNConnector', 'ConvBNReLUConnector', 'SingleConvConnector']
+__all__ = ['ConvModuleConncetor']
92 changes: 92 additions & 0 deletions mmrazor/models/architectures/connectors/convmodule_connector.py
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple, Union

import torch
from mmcv.cnn import ConvModule

from mmrazor.registry import MODELS
from .base_connector import BaseConnector


@MODELS.register_module()
class ConvModuleConncetor(BaseConnector):
    """Convolution connector that bundles conv/norm/activation layers.

    Args:
        in_channel (int): The input channel of the connector.
        out_channel (int): The output channel of the connector.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the convolution.
            Same as that in ``nn._ConvNd``.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input. Same as that in ``nn._ConvNd``.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Same as that in ``nn._ConvNd``.
        groups (int): Number of blocked connections from input channels to
            output channels. Same as that in ``nn._ConvNd``.
        bias (bool | str): If specified as ``auto``, it will be decided by
            ``norm_cfg``. Bias will be set as True if ``norm_cfg`` is None,
            otherwise False. Default: 'auto'.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        inplace (bool): Whether to use inplace mode for activation.
            Default: True.
        with_spectral_norm (bool): Whether to use spectral norm in the conv
            module. Default: False.
        padding_mode (str): If the ``padding_mode`` has not been supported
            by the current ``Conv2d`` in PyTorch, we will use our own
            padding layer instead. Currently, we support ['zeros',
            'circular'] with the official implementation and ['reflect']
            with our own implementation. Default: 'zeros'.
        order (tuple[str]): The order of conv/norm/activation layers. It is
            a sequence of "conv", "norm" and "act". Common examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
            Default: ('conv', 'norm', 'act').
        init_cfg (dict, optional): The config to control the initialization.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: Union[int, Tuple[int]] = 1,
        stride: Union[int, Tuple[int]] = 1,
        padding: Union[int, Tuple[int]] = 0,
        dilation: Union[int, Tuple[int]] = 1,
        groups: int = 1,
        bias: Union[str, bool] = 'auto',
        conv_cfg: Optional[Dict] = None,
        norm_cfg: Optional[Dict] = None,
        act_cfg: Dict = dict(type='ReLU'),
        inplace: bool = True,
        with_spectral_norm: bool = False,
        padding_mode: str = 'zeros',
        order: tuple = ('conv', 'norm', 'act'),
        init_cfg: Optional[Dict] = None,
    ) -> None:
        super().__init__(init_cfg)
        self.conv_module = ConvModule(in_channel, out_channel, kernel_size,
                                      stride, padding, dilation, groups, bias,
                                      conv_cfg, norm_cfg, act_cfg, inplace,
                                      with_spectral_norm, padding_mode, order)

    def forward_train(self, feature: torch.Tensor) -> torch.Tensor:
        """Forward computation.

        Args:
            feature (torch.Tensor): Input feature.
        """
        # Apply conv/norm/act in the configured order, skipping layers that
        # were not built (e.g. when norm_cfg or act_cfg is None).
        for layer in self.conv_module.order:
            if layer == 'conv':
                if self.conv_module.with_explicit_padding:
                    feature = self.conv_module.padding_layer(feature)
                feature = self.conv_module.conv(feature)
            elif layer == 'norm' and self.conv_module.with_norm:
                feature = self.conv_module.norm(feature)
            elif layer == 'act' and self.conv_module.with_activation:
                feature = self.conv_module.activate(feature)
        return feature
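
A hypothetical usage sketch of the connector (channel sizes mirror the `loss_s4_sfeat` entry of the pretrain config; `forward_train` is the hook the distiller invokes on recorded features):

```python
import torch

# Illustrative only: mirrors the loss_s4_sfeat entry in the pretrain config.
connector = ConvModuleConncetor(
    in_channel=512, out_channel=2048, norm_cfg=dict(type='BN'), act_cfg=None)
feature = torch.randn(2, 512, 7, 7)  # student backbone.layer4 output
out = connector.forward_train(feature)
assert out.shape == (2, 2048, 7, 7)
```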