* [feature] CONTRASTIVE REPRESENTATION DISTILLATION with dataset wrapper (#281)
  * init
  * TD: CRDLoss
  * complete UT
  * fix docstrings
  * fix ci
  * update
  * fix CI
  * DONE
  * maintain CRD dataset unique funcs as a mixin
  * add UT: CRD_ClsDataset
  * TODO: UT test formatting
  * crd dataset wrapper
  * update docstring
* [Improvement] Update estimator with api revision (#277)
  * update estimator usage and fix bugs
  * refactor api of estimator & add inner check methods
  * fix docstrings
  * update search loop and config
  * fix lint
  * update unittest
  * decouple mmdet dependency and fix lint
* [Fix] Fix tracer (#273)
  * test image_classifier_loss_calculator
  * fix backward tracer
  * update SingleStageDetectorPseudoLoss
  * merge
* [Feature] Add Dsnas Algorithm (#226)
  * [tmp] Update Dsnas
  * [tmp] refactor arch_loss & flops_loss
  * Update Dsnas & MMRAZOR_EVALUATOR: 1. finalized compute_loss & handle_grads in algorithm; 2. add MMRAZOR_EVALUATOR; 3. fix bugs
  * Update lr scheduler & fix a bug: 1. update param_scheduler & lr_scheduler for dsnas; 2. fix a bug of switching to finetune stage
  * remove old evaluators
  * update param_scheduler config
  * merge dev-1.x into gy/estimator
  * add flops_loss in Dsnas using ResourcesEstimator
  * get resources before mutator.prepare_from_supernet
  * delete unness broadcast api from gml
  * broadcast spec_modules_resources when estimating
  * update early fix mechanism for Dsnas
  * fix merge
  * update units in estimator
  * minor change
  * fix data_preprocessor api
  * add flops_loss_coef
  * remove DsnasOptimWrapper
  * fix bn eps and data_preprocessor
  * fix bn weight decay bug
  * add betas for mutator optimizer
  * set diff_rank_seed=True for dsnas
  * fix start_factor of lr when warm up
  * remove .module in non-ddp mode
  * add GlobalAveragePoolingWithDropout
  * add UT for dsnas
  * remove unness channel adjustment for shufflenetv2
  * update supernet configs
  * delete unness dropout
  * delete unness part with minor change on dsnas
  * minor change on the flag of search stage
  * update README and subnet configs
  * add UT for OneHotMutableOP
* [Feature] Update train (#279)
  * support auto resume
  * add enable auto_scale_lr in train.py
  * support '--amp' option
* [Fix] Fix darts metafile (#278)
* fix ci (#284)
  * fix ci for circle ci
  * fix bug in test_metafiles
  * add pr_stage_test for github ci
  * add multiple version
  * fix ut
  * fix lint
  * Temporarily skip dataset UT
  * update github ci
  * add github lint ci
  * install wheel
  * remove timm from requirements
  * install wheel when test on windows
  * fix error
  * fix bug
  * remove github windows ci
  * fix device error of arch_params when DsnasDDP
  * fix CRD dataset ut
  * fix scope error
  * rm test_cuda in workflows of github
* [Doc] fix typos in en/usr_guides

Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: humu789 <humu@pjlab.org.cn>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: pppppM <gjf_mail@126.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: SheffieldCao <1751899@tongji.edu.cn>
Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Commit 8330b62 (1 parent: d137b67). Showing 68 changed files with 2,895 additions and 411 deletions.
New file (+28 lines):

```python
norm_cfg = dict(type='BN', eps=0.01)

_STAGE_MUTABLE = dict(
    type='mmrazor.OneHotMutableOP',
    fix_threshold=0.3,
    candidates=dict(
        shuffle_3x3=dict(
            type='ShuffleBlock', kernel_size=3, norm_cfg=norm_cfg),
        shuffle_5x5=dict(
            type='ShuffleBlock', kernel_size=5, norm_cfg=norm_cfg),
        shuffle_7x7=dict(
            type='ShuffleBlock', kernel_size=7, norm_cfg=norm_cfg),
        shuffle_xception=dict(type='ShuffleXception', norm_cfg=norm_cfg)))

arch_setting = [
    # Parameters to build layers. 3 parameters are needed to construct a
    # layer, from left to right: channel, num_blocks, mutable_cfg.
    [64, 4, _STAGE_MUTABLE],
    [160, 4, _STAGE_MUTABLE],
    [320, 8, _STAGE_MUTABLE],
    [640, 4, _STAGE_MUTABLE]
]

nas_backbone = dict(
    type='mmrazor.SearchableShuffleNetV2',
    widen_factor=1.0,
    arch_setting=arch_setting,
    norm_cfg=norm_cfg)
```
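For orientation, here is a minimal sketch of inspecting this backbone config with `mmengine.Config`, the standard OpenMMLab config loader. The filename `nas_backbone.py` is an assumption for the example; each of the `num_blocks` layers in a stage is a `OneHotMutableOP` that picks among the four candidate ops.

```python
# Sketch: load the config above (assumed saved as `nas_backbone.py`) and
# list, per stage, how many blocks are built and which candidate ops each
# OneHotMutableOP chooses from.
from mmengine.config import Config

cfg = Config.fromfile('nas_backbone.py')
for channels, num_blocks, mutable in cfg.arch_setting:
    candidates = list(mutable['candidates'])
    print(f'{num_blocks} blocks @ {channels} channels, ops: {candidates}')
```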
New file (+102 lines):

```python
# dataset settings
dataset_type = 'mmcls.ImageNet'
data_preprocessor = dict(
    type='mmcls.ClsDataPreprocessor',
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(type='mmcls.RandomResizedCrop', scale=224),
    dict(type='mmcls.RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='mmcls.PackClsInputs'),
]

test_pipeline = [
    dict(type='mmcls.LoadImageFromFile'),
    dict(type='mmcls.ResizeEdge', scale=256, edge='short'),
    dict(type='mmcls.CenterCrop', crop_size=224),
    dict(type='mmcls.PackClsInputs'),
]

train_dataloader = dict(
    batch_size=128,
    num_workers=4,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='mmcls.DefaultSampler', shuffle=True),
    persistent_workers=True,
)

val_dataloader = dict(
    batch_size=128,
    num_workers=4,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/val.txt',
        data_prefix='val',
        pipeline=test_pipeline),
    sampler=dict(type='mmcls.DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='mmcls.Accuracy', topk=(1, 5))

# If you want a standard test, please configure the test dataset manually
test_dataloader = val_dataloader
test_evaluator = val_evaluator

# optimizer
paramwise_cfg = dict(bias_decay_mult=0.0, norm_decay_mult=0.0)

optim_wrapper = dict(
    constructor='mmrazor.SeparateOptimWrapperConstructor',
    architecture=dict(
        optimizer=dict(
            type='mmcls.SGD', lr=0.5, momentum=0.9, weight_decay=4e-5),
        paramwise_cfg=paramwise_cfg),
    mutator=dict(
        optimizer=dict(
            type='mmcls.Adam', lr=0.001, weight_decay=0.0,
            betas=(0.5, 0.999))))

search_epochs = 85
# learning policy
param_scheduler = dict(
    architecture=[
        dict(
            type='mmcls.LinearLR',
            end=5,
            start_factor=0.2,
            by_epoch=True,
            convert_to_iter_based=True),
        dict(
            type='mmcls.CosineAnnealingLR',
            T_max=240,
            begin=5,
            end=search_epochs,
            by_epoch=True,
            convert_to_iter_based=True),
        dict(
            type='mmcls.CosineAnnealingLR',
            T_max=160,
            begin=search_epochs,
            end=240,
            eta_min=0.0,
            by_epoch=True,
            convert_to_iter_based=True)
    ],
    mutator=[])

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=240)
val_cfg = dict()
test_cfg = dict()
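```

To make the three-phase architecture schedule above concrete, here is a rough closed-form reading of it in plain Python: a 5-epoch linear warmup from `0.2 * lr`, the early segment of a `T_max=240` cosine during the 85-epoch search stage, then a `T_max=160` cosine decayed toward `eta_min=0` for finetuning. This is not mmengine's scheduler implementation (mmengine steps schedulers recursively, per iteration once `convert_to_iter_based=True`, and chains segments from the current LR), so treat the values as approximate; `base_lr` mirrors the SGD `lr=0.5`.

```python
# Approximate architecture LR per epoch implied by the config above.
import math

base_lr, start_factor = 0.5, 0.2
warmup_end, search_epochs = 5, 85

def approx_lr(epoch: float) -> float:
    if epoch < warmup_end:  # LinearLR: factor ramps 0.2 -> 1.0
        return base_lr * (start_factor + (1 - start_factor) * epoch / warmup_end)
    if epoch < search_epochs:  # early part of a T_max=240 cosine
        return base_lr * 0.5 * (1 + math.cos(math.pi * (epoch - warmup_end) / 240))
    # finetune stage: fresh T_max=160 cosine annealed to eta_min=0
    return base_lr * 0.5 * (1 + math.cos(math.pi * (epoch - search_epochs) / 160))

for e in (0, 5, 45, 85, 160, 239):
    print(f'epoch {e:3d}: lr ~ {approx_lr(e):.4f}')
```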
New file (+30 lines):
# CONTRASTIVE REPRESENTATION DISTILLATION

> [CONTRASTIVE REPRESENTATION DISTILLATION](https://arxiv.org/abs/1910.10699)

## Abstract

Often we wish to transfer representational knowledge from one neural network to another. Examples include distilling a large network into a smaller one, transferring knowledge from one sensory modality to a second, or ensembling a collection of models into a single estimator. Knowledge distillation, the standard approach to these problems, minimizes the KL divergence between the probabilistic outputs of a teacher and student network. We demonstrate that this objective ignores important structural knowledge of the teacher network. This motivates an alternative objective by which we train a student to capture significantly more information in the teacher's representation of the data. We formulate this objective as contrastive learning. Experiments demonstrate that our resulting new objective outperforms knowledge distillation and other cutting-edge distillers on a variety of knowledge transfer tasks, including single model compression, ensemble distillation, and cross-modal transfer. Our method sets a new state-of-the-art in many transfer tasks, and sometimes even outperforms the teacher network when combined with knowledge distillation. [Original code](http://github.com/HobbitLong/RepDistiller)

![pipeline](../../../../docs/en/imgs/model_zoo/crd/pipeline.jpg)
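The objective the abstract describes is contrastive: the student is pushed to agree with the teacher's embedding of the same input and to disagree with its embeddings of other inputs. As a rough orientation, here is a simplified in-batch, InfoNCE-style sketch in PyTorch. It is not mmrazor's `CRDLoss` (which draws `neg_num` negatives via a memory bank and a dataset wrapper); the function name and the temperature value are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def info_nce_distill_loss(s_feats: torch.Tensor,
                          t_feats: torch.Tensor,
                          temperature: float = 0.07) -> torch.Tensor:
    """In-batch contrastive loss between student and teacher embeddings.

    For each sample, the teacher embedding of the SAME image is the positive;
    teacher embeddings of the other images in the batch act as negatives.
    """
    s = F.normalize(s_feats, dim=1)          # (N, D)
    t = F.normalize(t_feats, dim=1)          # (N, D)
    logits = s @ t.t() / temperature         # (N, N) similarity matrix
    labels = torch.arange(s.size(0), device=s.device)
    return F.cross_entropy(logits, labels)   # diagonal entries are positives

# usage: embeddings as produced by the CRDConnector projections (dim_out=128)
loss = info_nce_distill_loss(torch.randn(16, 128), torch.randn(16, 128))
```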
## Citation

```latex
@article{tian2019contrastive,
  title={Contrastive representation distillation},
  author={Tian, Yonglong and Krishnan, Dilip and Isola, Phillip},
  journal={arXiv preprint arXiv:1910.10699},
  year={2019}
}
```

## Results and models

| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download |
| ------- | --------- | --------- | --------- | --------- | ------------------------------------------- | -------- |
| CIFAR10 | ResNet-18 | ResNet-50 | 94.79 | 99.86 | [config](crd_neck_r50_r18_8xb16_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth) \| [model](<>) \| [log](<>) |

## Acknowledgement

Shout out to @chengshuang18 for his special contribution.
configs/distill/mmcls/crd/crd_neck_r50_r18_8xb16_cifar10.py (new file, +108 lines):

```python
_base_ = [
    'mmcls::_base_/datasets/cifar10_bs16.py',
    'mmcls::_base_/schedules/cifar10_bs128.py',
    'mmcls::_base_/default_runtime.py'
]

model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        bgr_to_rgb=True),
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False),
    teacher=dict(
        cfg_path='mmcls::resnet/resnet50_8xb16_cifar10.py', pretrained=True),
    teacher_ckpt='resnet50_b16x8_cifar10_20210528-f54bfad9.pth',
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            neck=dict(type='ModuleOutputs', source='neck.gap'),
            data_samples=dict(type='ModuleInputs', source='')),
        teacher_recorders=dict(
            neck=dict(type='ModuleOutputs', source='neck.gap')),
        distill_losses=dict(loss_crd=dict(type='CRDLoss', loss_weight=0.8)),
        connectors=dict(
            loss_crd_stu=dict(type='CRDConnector', dim_in=512, dim_out=128),
            loss_crd_tea=dict(type='CRDConnector', dim_in=2048, dim_out=128)),
        loss_forward_mappings=dict(
            loss_crd=dict(
                s_feats=dict(
                    from_student=True,
                    recorder='neck',
                    connector='loss_crd_stu'),
                t_feats=dict(
                    from_student=False,
                    recorder='neck',
                    connector='loss_crd_tea'),
                data_samples=dict(
                    from_student=True, recorder='data_samples', data_idx=1)))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')

# Wrap the `CIFAR10` dataset with the `CRDDataset` wrapper.
dataset_type = 'CIFAR10'
train_pipeline = [
    dict(_scope_='mmcls', type='RandomCrop', crop_size=32, padding=4),
    dict(_scope_='mmcls', type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(_scope_='mmrazor', type='PackCRDClsInputs'),
]

test_pipeline = [
    dict(_scope_='mmrazor', type='PackCRDClsInputs'),
]

ori_train_dataset = dict(
    _scope_='mmcls',
    type=dataset_type,
    data_prefix='data/cifar10',
    test_mode=False,
    pipeline=train_pipeline)

crd_train_dataset = dict(
    _scope_='mmrazor',
    type='CRDDataset',
    dataset=ori_train_dataset,
    neg_num=16384,
    sample_mode='exact',
    percent=1.0)

ori_test_dataset = dict(
    _scope_='mmcls',
    type=dataset_type,
    data_prefix='data/cifar10/',
    test_mode=True,
    pipeline=test_pipeline)

crd_test_dataset = dict(
    _scope_='mmrazor',
    type='CRDDataset',
    dataset=ori_test_dataset,
    neg_num=16384,
    sample_mode='exact',
    percent=1.0)

train_dataloader = dict(
    _delete_=True,
    batch_size=16,
    num_workers=2,
    dataset=crd_train_dataset,
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
)

val_dataloader = dict(
    _delete_=True,
    batch_size=16,
    num_workers=2,
    dataset=crd_test_dataset,
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
```
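The `CRDDataset` wrapper above is what supplies the contrastive loss with its negatives: alongside each image it returns the indices of `neg_num` contrast samples. Below is a hypothetical, heavily simplified sketch of that idea, not mmrazor's implementation; the class name, method name, and 'exact'-mode interpretation (negatives drawn from other classes only) are illustrative assumptions.

```python
# Toy sketch of a CRD-style dataset wrapper: for a given sample index,
# return the positive (the sample itself) followed by `neg_num` indices
# of samples with a different label.
import numpy as np

class ToyCRDWrapper:
    def __init__(self, labels, neg_num=16384, seed=0):
        self.labels = np.asarray(labels)
        self.neg_num = neg_num
        self.rng = np.random.default_rng(seed)

    def sample_contrast_idx(self, idx: int) -> np.ndarray:
        negatives = np.flatnonzero(self.labels != self.labels[idx])
        neg = self.rng.choice(negatives, self.neg_num, replace=True)
        return np.hstack([[idx], neg])  # positive first, then negatives

labels = np.random.randint(0, 10, size=50000)  # CIFAR-10-sized toy labels
wrapper = ToyCRDWrapper(labels, neg_num=16384)
print(wrapper.sample_contrast_idx(0)[:5])
```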