[Docs] Add files for KD algo DFND #586

Open
wants to merge 1 commit into base: main
31 changes: 31 additions & 0 deletions configs/distill/mmcls/dfnd/README.md
@@ -0,0 +1,31 @@
# Learning Student Networks in the Wild (DFND)

> [Learning Student Networks in the Wild](https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Learning_Student_Networks_in_the_Wild_CVPR_2021_paper.pdf)

<!-- [ALGORITHM] -->

## Abstract

Data-free learning for student networks is a new paradigm for addressing users' concerns about the privacy of the original training data. Since the architectures of modern convolutional neural networks (CNNs) are compact and sophisticated, the alternative images or meta-data generated from the teacher network are often broken, so the student network cannot achieve performance comparable to that of the pre-trained teacher network, especially on large-scale image datasets. Different from previous works, we propose to maximally utilize the massive unlabeled data available in the wild. Specifically, we first thoroughly analyze the output differences between the teacher and student networks on the original data and develop a data collection method. Then, a noisy knowledge distillation algorithm is proposed for improving the performance of the student network. In practice, an adaptation matrix is learned with the student network to correct the label noise produced by the teacher network on the collected unlabeled images. The effectiveness of our DFND (Data-Free Noisy Distillation) method is verified on several benchmarks, demonstrating its superiority over state-of-the-art data-free distillation methods. Experiments on various datasets demonstrate that student networks learned by the proposed method can achieve performance comparable to those trained on the original dataset.

<img width="910" alt="pipeline" src="./dfnd.PNG">

## Results and models

### Classification

| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :---------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 94.78 | 95.34 | 94.82 | [config](./dfnd_logits_resnet34_resnet18_8xb32_cifar10.py) | [student](https://drive.google.com/file/d/1_MekfTkCsEl68meWPqtdNZIxdJO2R2Eb/view?usp=drive_link) |
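
The checkpoint above was trained with the linked config. A config of this form can typically be launched with the standard MMEngine runner; the snippet below is a minimal sketch that assumes MMRazor and MMClassification are installed and that the dataset and checkpoint paths in the config have been adjusted. The working directory is a placeholder.

```python
from mmengine.config import Config
from mmengine.runner import Runner

# Build everything (model, dataloaders, loops) from the DFND config.
cfg = Config.fromfile(
    'configs/distill/mmcls/dfnd/dfnd_logits_resnet34_resnet18_8xb32_cifar10.py')
cfg.work_dir = 'work_dirs/dfnd_r34_r18_cifar10'  # placeholder output directory
runner = Runner.from_cfg(cfg)
runner.train()
```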

## Citation

```latex
@inproceedings{chen2021learning,
title={Learning student networks in the wild},
author={Chen, Hanting and Guo, Tianyu and Xu, Chang and Li, Wenshuo and Xu, Chunjing and Xu, Chao and Wang, Yunhe},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={6428--6437},
year={2021}
}
```
Binary file added configs/distill/mmcls/dfnd/dfnd.PNG
100 changes: 100 additions & 0 deletions configs/distill/mmcls/dfnd/dfnd_logits_resnet34_resnet18_8xb32_cifar10.py
@@ -0,0 +1,100 @@
_base_ = ['mmcls::_base_/default_runtime.py']

# optimizer
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001))
# learning policy
param_scheduler = dict(
    type='MultiStepLR', by_epoch=True, milestones=[320, 640], gamma=0.1)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=800, val_interval=1)
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=128)
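# With the train batch size of 256 below and auto scaling enabled at launch
# time, the LR above would be scaled linearly by 256 / 128 = 2.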

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=32),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackClsInputs'),
]

# DFND trains the student on a large pool of unlabeled images "in the wild";
# here the ImageNet training split plays that role, while CIFAR-10 below
# remains the target task used for validation.
train_dataloader = dict(
    batch_size=256,
    num_workers=5,
    dataset=dict(
        type='ImageNet',
        data_root='/cache/data/imagenet/',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

test_pipeline = [
    dict(type='PackClsInputs'),
]

# Validation and testing run on the target dataset, CIFAR-10.
val_dataloader = dict(
    batch_size=16,
    num_workers=2,
    dataset=dict(
        type='CIFAR10',
        data_prefix='/cache/data/cifar',
        test_mode=True,
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, ))

test_dataloader = val_dataloader
test_evaluator = val_evaluator

teacher_ckpt = '/cache/models/resnet_model.pth' # noqa: E501

model = dict(
    _scope_='mmrazor',
    type='DFNDDistill',
    calculate_student_loss=False,
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # ImageNet RGB normalization statistics for the wild training images
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        bgr_to_rgb=True),
    val_data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # CIFAR-10 RGB normalization statistics for validation images
        mean=[125.307, 122.961, 113.8575],
        std=[51.5865, 50.847, 51.255],
        # CIFAR-10 images are already RGB, so no BGR-to-RGB conversion
        bgr_to_rgb=False),
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False),
    teacher=dict(
        cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py', pretrained=False),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        teacher_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        distill_losses=dict(
            loss_kl=dict(
                type='DFNDLoss',
                tau=4,
                loss_weight=1,
                num_classes=10,
                batch_select=0.5)),
        loss_forward_mappings=dict(
            loss_kl=dict(
                preds_S=dict(from_student=True, recorder='fc'),
                preds_T=dict(from_student=False, recorder='fc')))))

# Not all parameters contribute to the distillation loss, so DDP must
# tolerate unused parameters.
find_unused_parameters = True

val_cfg = dict(type='mmrazor.DFNDValLoop')
5 changes: 3 additions & 2 deletions mmrazor/engine/runner/__init__.py
@@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .autoslim_greedy_search_loop import AutoSlimGreedySearchLoop
from .darts_loop import DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop
from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop
from .distill_val_loop import (DFNDValLoop, SelfDistillValLoop,
                               SingleTeacherDistillValLoop)
from .evolution_search_loop import EvolutionSearchLoop
from .iteprune_val_loop import ItePruneValLoop
from .quantization_loops import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop,
@@ -15,5 +16,5 @@
    'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
    'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop',
    'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'QATEpochBasedLoop',
    'PTQLoop', 'LSQEpochBasedLoop', 'QATValLoop', 'DFNDValLoop'
]
35 changes: 35 additions & 0 deletions mmrazor/engine/runner/distill_val_loop.py
@@ -125,3 +125,38 @@ def run(self):

        self.runner.call_hook('after_val_epoch', metrics=student_metrics)
        self.runner.call_hook('after_val')


@LOOPS.register_module()
class DFNDValLoop(SingleTeacherDistillValLoop):
    """Validation loop for DFND. DFND requires a different dataset for
    training and validation.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 validation. Defaults to False.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 fp16: bool = False) -> None:
        super().__init__(runner, dataloader, evaluator, fp16)
        if self.runner.distributed:
            assert hasattr(self.runner.model.module, 'teacher')
            # TODO: remove hard code after mmcls adds data_preprocessor
            data_preprocessor = self.runner.model.module.val_data_preprocessor
            self.teacher = self.runner.model.module.teacher
            self.teacher.data_preprocessor = data_preprocessor
        else:
            assert hasattr(self.runner.model, 'teacher')
            # TODO: remove hard code after mmcls adds data_preprocessor
            data_preprocessor = self.runner.model.val_data_preprocessor
            self.teacher = self.runner.model.teacher
            self.teacher.data_preprocessor = data_preprocessor
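
This loop is enabled through `val_cfg = dict(type='mmrazor.DFNDValLoop')` in the distillation config added by this PR; it reuses the parent loop's `run()` and only attaches the model's `val_data_preprocessor` to the teacher so that validation batches are normalized with the CIFAR-10 statistics.
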
3 changes: 2 additions & 1 deletion mmrazor/models/algorithms/distill/configurable/__init__.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .datafree_distillation import (DAFLDataFreeDistillation,
                                    DataFreeDistillation)
from .dfnd_distill import DFNDDistill
from .fpn_teacher_distill import FpnTeacherDistill
from .overhaul_feature_distillation import OverhaulFeatureDistillation
from .self_distill import SelfDistill
@@ -9,5 +10,5 @@
__all__ = [
    'SelfDistill', 'SingleTeacherDistill', 'FpnTeacherDistill',
    'DataFreeDistillation', 'DAFLDataFreeDistillation',
    'OverhaulFeatureDistillation', 'DFNDDistill'
]