[feature] CONTRASTIVE REPRESENTATION DISTILLATION with dataset wrapper (#281)

* init

* TD: CRDLoss

* complete UT

* fix docstrings

* fix ci

* update

* fix CI

* DONE

* maintain CRD dataset unique funcs as a mixin

* maintain CRD dataset unique funcs as a mixin

* maintain CRD dataset unique funcs as a mixin

* add UT: CRD_ClsDataset

* init

* TODO: UT test formatting.

* init

* crd dataset wrapper

* update docstring

Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
fpshuang and huangpengsheng committed Sep 13, 2022
1 parent 1f1bcd1 commit eb25bb7
Showing 24 changed files with 1,109 additions and 6 deletions.
30 changes: 30 additions & 0 deletions configs/distill/mmcls/crd/README.md
@@ -0,0 +1,30 @@
# CONTRASTIVE REPRESENTATION DISTILLATION

> [CONTRASTIVE REPRESENTATION DISTILLATION](https://arxiv.org/abs/1910.10699)

## Abstract

Often we wish to transfer representational knowledge from one neural network to another. Examples include distilling a large network into a smaller one, transferring knowledge from one sensory modality to a second, or ensembling a collection of models into a single estimator. Knowledge distillation, the standard approach to these problems, minimizes the KL divergence between the probabilistic outputs of a teacher and student network. We demonstrate that this objective ignores important structural knowledge of the teacher network. This motivates an alternative objective by which we train a student to capture significantly more information in the teacher’s representation of the data. We formulate this objective as contrastive learning. Experiments demonstrate that our resulting new objective outperforms knowledge distillation and other cutting-edge distillers on a variety of knowledge transfer tasks, including single model compression, ensemble distillation, and cross-modal transfer. Our method sets a new state-of-the-art in many transfer tasks, and sometimes even outperforms the teacher network when combined with knowledge distillation. [Original code](http://github.com/HobbitLong/RepDistiller)

![pipeline](../../../../docs/en/imgs/model_zoo/crd/pipeline.jpg)
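
At a glance, CRD replaces the usual KL-divergence term with a contrastive objective over projected student/teacher embeddings: matching pairs should agree, all other pairings should not. The snippet below is only an illustrative sketch of that reading, with hypothetical class and function names and simple in-batch negatives instead of the memory-bank NCE (16384 negatives) used by the actual `CRDLoss` in this commit.

```python
# Illustrative sketch only: an in-batch InfoNCE-style contrastive distillation
# objective. The real CRDLoss uses a memory bank with `neg_num` negatives per
# anchor; this simplification is just for intuition.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ProjectionHead(nn.Module):
    """Project backbone features into a shared, L2-normalized embedding space."""

    def __init__(self, dim_in: int, dim_out: int = 128):
        super().__init__()
        self.fc = nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.normalize(self.fc(x.flatten(1)), dim=1)


def contrastive_distill_loss(s_feats, t_feats, temperature=0.07):
    """Matching student/teacher rows are positives, all other pairs in the
    batch are negatives; apply a cross-entropy (InfoNCE) loss."""
    logits = s_feats @ t_feats.t() / temperature          # (N, N) similarities
    targets = torch.arange(s_feats.size(0), device=s_feats.device)
    return F.cross_entropy(logits, targets)


# Example: student neck outputs 512-d, teacher neck outputs 2048-d (as in the
# ResNet-18 / ResNet-50 config below); both are embedded to 128-d.
s_proj, t_proj = ProjectionHead(512), ProjectionHead(2048)
loss = contrastive_distill_loss(s_proj(torch.randn(8, 512)),
                                t_proj(torch.randn(8, 2048)))
```

In the config below, the student's 512-d and the teacher's 2048-d neck features are both embedded to 128-d by `CRDConnector` before the loss is computed.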

## Citation

```latex
@article{tian2019contrastive,
title={Contrastive representation distillation},
author={Tian, Yonglong and Krishnan, Dilip and Isola, Phillip},
journal={arXiv preprint arXiv:1910.10699},
year={2019}
}
```

## Results and models

| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download |
| ------- | --------- | --------- | --------- | --------- | ------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
| CIFAR10 | ResNet-18 | ResNet-50 | 94.79 | 99.86 | [config](crd_neck_r50_r18_8xb16_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_b16x8_cifar10_20210528-f54bfad9.pth) \| [model](<>) \| [log](<>) |
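
Training follows the standard OpenMMLab workflow. A hypothetical launch sketch is shown below; the config path and teacher checkpoint are the ones listed in the table, while the working directory and the use of mmengine's `Runner` (instead of `tools/train.py`) are assumptions.

```python
# Hypothetical launch sketch; assumes MMRazor/MMClassification are installed
# and CIFAR-10 is available under data/cifar10.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/distill/mmcls/crd/crd_neck_r50_r18_8xb16_cifar10.py')
cfg.work_dir = 'work_dirs/crd_neck_r50_r18_8xb16_cifar10'  # hypothetical output dir
Runner.from_cfg(cfg).train()
```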

## Acknowledgement

Shout out to @chengshuang18 for his special contribution.
108 changes: 108 additions & 0 deletions configs/distill/mmcls/crd/crd_neck_r50_r18_8xb16_cifar10.py
@@ -0,0 +1,108 @@
_base_ = [
'mmcls::_base_/datasets/cifar10_bs16.py',
'mmcls::_base_/schedules/cifar10_bs128.py',
'mmcls::_base_/default_runtime.py'
]

model = dict(
_scope_='mmrazor',
type='SingleTeacherDistill',
data_preprocessor=dict(
type='ImgDataPreprocessor',
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
bgr_to_rgb=True),
architecture=dict(
cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False),
teacher=dict(
cfg_path='mmcls::resnet/resnet50_8xb16_cifar10.py', pretrained=True),
teacher_ckpt='resnet50_b16x8_cifar10_20210528-f54bfad9.pth',
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
neck=dict(type='ModuleOutputs', source='neck.gap'),
data_samples=dict(type='ModuleInputs', source='')),
teacher_recorders=dict(
neck=dict(type='ModuleOutputs', source='neck.gap')),
distill_losses=dict(loss_crd=dict(type='CRDLoss', loss_weight=0.8)),
connectors=dict(
loss_crd_stu=dict(type='CRDConnector', dim_in=512, dim_out=128),
loss_crd_tea=dict(type='CRDConnector', dim_in=2048, dim_out=128)),
loss_forward_mappings=dict(
loss_crd=dict(
s_feats=dict(
from_student=True,
recorder='neck',
connector='loss_crd_stu'),
t_feats=dict(
from_student=False,
recorder='neck',
connector='loss_crd_tea'),
data_samples=dict(
from_student=True, recorder='data_samples', data_idx=1)))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')

# Wrap the `CIFAR10` dataset with `CRDDataset` (a rough sketch of the wrapper
# idea follows this config).
dataset_type = 'CIFAR10'
train_pipeline = [
dict(_scope_='mmcls', type='RandomCrop', crop_size=32, padding=4),
dict(_scope_='mmcls', type='RandomFlip', prob=0.5, direction='horizontal'),
dict(_scope_='mmrazor', type='PackCRDClsInputs'),
]

test_pipeline = [
dict(_scope_='mmrazor', type='PackCRDClsInputs'),
]

ori_train_dataset = dict(
_scope_='mmcls',
type=dataset_type,
data_prefix='data/cifar10',
test_mode=False,
pipeline=train_pipeline)

crd_train_dataset = dict(
_scope_='mmrazor',
type='CRDDataset',
dataset=ori_train_dataset,
neg_num=16384,
sample_mode='exact',
percent=1.0)

ori_test_dataset = dict(
_scope_='mmcls',
type=dataset_type,
data_prefix='data/cifar10/',
test_mode=True,
pipeline=test_pipeline)

crd_test_dataset = dict(
_scope_='mmrazor',
type='CRDDataset',
dataset=ori_test_dataset,
neg_num=16384,
sample_mode='exact',
percent=1.0)

train_dataloader = dict(
_delete_=True,
batch_size=16,
num_workers=2,
dataset=crd_train_dataset,
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)

val_dataloader = dict(
_delete_=True,
batch_size=16,
num_workers=2,
dataset=crd_test_dataset,
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
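
The `CRDDataset` wrapper above is what turns an ordinary classification dataset into one whose items also carry contrastive sample indices. The sketch below illustrates the idea behind `neg_num` and `sample_mode` with hypothetical names; it is not the actual `mmrazor.datasets.CRDDataset` implementation, and `percent` is omitted for brevity.

```python
# Rough sketch: for every anchor image, pre-compute the indices of `neg_num`
# samples drawn from *other* classes, so CRDLoss can contrast one
# teacher/student positive pair against many negatives.
import numpy as np


class ContrastiveIndexWrapper:

    def __init__(self, labels, neg_num=16384, sample_mode='exact'):
        self.labels = np.asarray(labels)
        self.neg_num = neg_num
        self.sample_mode = sample_mode
        # Per class, the indices of all samples belonging to *other* classes.
        self.cls_negative = {
            c: np.where(self.labels != c)[0] for c in np.unique(self.labels)
        }

    def contrast_idxs(self, idx):
        """Return the anchor index followed by `neg_num` negative indices."""
        pool = self.cls_negative[self.labels[idx]]
        # 'exact' -> sample without replacement when the pool is large enough.
        replace = self.sample_mode != 'exact' or len(pool) < self.neg_num
        neg = np.random.choice(pool, self.neg_num, replace=replace)
        return np.hstack(([idx], neg))


# Example with CIFAR-10-like labels: 10 classes, 50000 training images.
wrapper = ContrastiveIndexWrapper(np.random.randint(0, 10, size=50000))
print(wrapper.contrast_idxs(0)[:5])
```
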
49 changes: 49 additions & 0 deletions configs/distill/mmcls/crd/datasets/crd_cifar10_bs16.py
@@ -0,0 +1,49 @@
# dataset settings
dataset_type = 'CIFAR10'
preprocess_cfg = dict(
# RGB format normalization parameters
mean=[125.307, 122.961, 113.8575],
std=[51.5865, 50.847, 51.255],
# loaded images are already RGB format
to_rgb=False)

train_pipeline = [
dict(type='RandomCrop', crop_size=32, padding=4),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackClsInputs'),
]

test_pipeline = [
dict(type='PackClsInputs'),
]

neg_num = 16384
train_dataloader = dict(
batch_size=16,
num_workers=2,
dataset=dict(
type=dataset_type,
data_prefix='data/cifar10',
test_mode=False,
pipeline=train_pipeline,
neg_num=neg_num),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
)

val_dataloader = dict(
batch_size=16,
num_workers=2,
dataset=dict(
type=dataset_type,
data_prefix='data/cifar10/',
test_mode=True,
pipeline=test_pipeline,
neg_num=neg_num),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, ))

test_dataloader = val_dataloader
test_evaluator = val_evaluator
5 changes: 5 additions & 0 deletions mmrazor/datasets/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .crd_dataset_wrapper import CRDDataset
from .transforms import PackCRDClsInputs

__all__ = ['PackCRDClsInputs', 'CRDDataset']