Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] PyTorch version of PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient. #304

Merged
merged 20 commits into from
Oct 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions configs/distill/mmdet/pkd/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# PKD

> [PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient](https://arxiv.org/abs/2207.02039)

<!-- [ALGORITHM] -->

## Abstract

Knowledge distillation(KD) is a widely-used technique to train compact models in object detection. However, there is still a lack of study on how to distill between heterogeneous detectors. In this paper, we empirically find that better FPN features from a heterogeneous teacher detector can help the student although their detection heads and label assignments are different. However, directly aligning the feature maps to distill detectors suffers from two problems. First, the difference in feature magnitude between the teacher and the student could enforce overly strict constraints on the student. Second, the FPN stages and channels with large feature magnitude from the teacher model could dominate the gradient of distillation loss, which will overwhelm the effects of other features in KD and introduce much noise. To address the above issues, we propose to imitate features with Pearson Correlation Coefficient to focus on the relational information from the teacher and relax constraints on the magnitude of the features. Our method consistently outperforms the existing detection KD methods and works for both homogeneous and heterogeneous student-teacher pairs. Furthermore, it converges faster. With a powerful MaskRCNN-Swin detector as the teacher, ResNet-50 based RetinaNet and FCOS achieve 41.5% and 43.9% mAP on COCO2017, which are 4.1% and 4.8% higher than the baseline, respectively.

![pipeline](https://user-images.githubusercontent.com/41630003/197719796-76fa5f33-1d54-4927-8a08-86f5c6e33879.png)

## Results and models

### Detection

| Location | Dataset | Teacher | Student | Lr schd | mAP | mAP(T) | mAP(S) | Config | Download |
| :------: | :-----: | :--------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------: | :-----: | :--: | :----: | :----: | :-----------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| FPN | COCO | [FCOS-X101](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/fcos/fcos_x101-64x4d_fpn_gn-head_ms-640-800-2x_coco.py) | [RetinaNet-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_1x_coco.py) | 1x | 40.3 | 42.6 | 36.5 | [config](pkd_fpn_fcos_x101_retina_r50_1x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/fcos/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco-ede514a8.pth) \|[model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_fcos_retina/pkd_fpn_fcos_x101_retina_r50_1x_coco_20220925_181547-9cac5059.pth?versionId=CAEQThiBgMCLyNC0oBgiIDBjY2FkY2JlNGFiYzRmM2RiZGUyYzM1NjQxYzQxODA4) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_fcos_retina/pkd_fpn_fcos_x101_retina_r50_1x_coco_20220925_181547-9cac5059.json?versionId=CAEQThiBgMDA0dS0oBgiIDM4ZjZlZmVkMzc4MjQxMGJiN2FlMDFlOTA2NGIzZGQ4) |
| FPN | COCO | [Faster-Rcnn-R101](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/faster_rcnn/faster-rcnn_r101_fpn_2x_coco.py) | [Faster-rcnn-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/faster_rcnn/faster-rcnn_r50_fpn_2x_coco.py) | 2x | 40.3 | 39.8 | 38.4 | [config](pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_2x_coco/faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth) \|[model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_frcnn/pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco_20221014_103040-3efbd439.pth?versionId=CAEQThiBgMDQr9C0oBgiIDMyZWE1Y2ZlMDA2ZDQ2ZGNhZmQ3NzMxODk3YzgzYWFl) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_frcnn/pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco_20221014_103040-3efbd439.json?versionId=CAEQThiBgICYsNC0oBgiIDdhNWY5ZjZlYjUyNzRjMGU4NGFhYzk4NzQwZDAxY2Rj) |
| FPN | COCO | [Mask-Rcnn-Swin](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/swin/mask-rcnn_swin-s-p4-w7_fpn_amp-ms-crop-3x_coco.py) | [RetinaNet-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_2x_coco.py) | 2x | 41.5 | 48.2 | 37.4 | [config](pkd_fpn_mask-rcnn_swin_retina_r50_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth) \|[model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_swin_retina/pkd_fpn_mask_rcnn_swin_retina_r50_2x_coco_20220925_142555-edec7433.pth?versionId=CAEQThiBgIDWqNC0oBgiIDViOGE0ZDU4ODgxNzQ5YmE5OGU3MzRkMjFiZGRjZmRm) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_swin_retina/pkd_fpn_mask_rcnn_swin_retina_r50_2x_coco_20220925_142555-edec7433.json?versionId=CAEQThiBgIDVqdC0oBgiIDU3YzFjOWRmNWY3NTRmYjFhMDdmNzU2ODE3MzdlZThk) |
| FPN | COCO | [Reppoints-X101-dcn](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/reppoints/reppoints-moment_x101-dconv-c3-c5_fpn-gn_head-gn_2x_coco.py) | [Reppoints-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/reppoints/reppoints-moment_r50_fpn-gn_head-gn_2x_coco.py) | 2x | 42.3 | 44.2 | 38.6 | [config](pkd_fpn_reppoints_x101-dcn_reppoints_r50_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/reppoints/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco_20200329-f87da1ea.pth) \|[model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_reppoints/pkd_fpn_reppoints_x101_dcn_reppoints_r50_2x_coco_20220926_145818-f8932e12.pth?versionId=CAEQThiBgIC8rNC0oBgiIGU2N2IxM2NkMjNlMjQyN2E4YmVlNmViNGI2MDY3OTE5) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_reppoints/pkd_fpn_reppoints_x101_dcn_reppoints_r50_2x_coco_20220926_145818-f8932e12.json?versionId=CAEQThiBgICordC0oBgiIDJhMjBjOGZiN2UxNjQxYmI5MzE3NWVhZDgxZDE2NmJm) |
| FPN | COCO | [RetinaNet-X101](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_x101-64x4d_fpn_1x_coco.py) | [RetinaNet-R50](https://github.com/open-mmlab/mmdetection/blob/dev-3.x/configs/retinanet/retinanet_r50_fpn_2x_coco.py) | 2x | 40.8 | 41.0 | 37.4 | [config](pkd_fpn_retina_x101_retina_r50_2x_coco.py) | [teacher](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth) \|[model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_retinax_retina/pkd_fpn_retina_x101_retina_r50_2x_coco_20221014_232526-4c0f8d96.pth?versionId=CAEQThiBgIDQqdC0oBgiIGFmZjNmZmE4NDFiMDQ4MzhiMzdjOGI2NzI4MTQxMjFi) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_retinax_retina/pkd_fpn_retina_x101_retina_r50_2x_coco_20221014_232526-4c0f8d96.json?versionId=CAEQThiBgMC2qdC0oBgiIGRkMTIzODYwMzliMDQ3M2JiYjNlYjA5N2I4Y2QzMGFl) |

## Citation

```latex
@article{cao2022pkd,
title={PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient},
author={Cao, Weihan and Zhang, Yifan and Gao, Jianfei and Cheng, Anda and Cheng, Ke and Cheng, Jian},
journal={arXiv preprint arXiv:2207.02039},
year={2022}
}
```
110 changes: 110 additions & 0 deletions configs/distill/mmdet/pkd/metafile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
Models:
- Name: pkd_fpn_fcos_x101_retina_r50_1x_coco
In Collection: PKD
Metadata:
Location: FPN
Student:
Metrics:
box AP: 36.5
Config: mmdet::retinanet/retinanet_r50_fpn_1x_coco.py
Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth
Teacher:
Metrics:
box AP: 42.6
Config: mmdet::fcos/fcos_x101-64x4d_fpn_gn-head_ms-640-800-2x_coco.py
Weights: https://download.openmmlab.com/mmdetection/v2.0/fcos/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco-ede514a8.pth
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 40.3
Config: configs/distill/mmdet/pkd/pkd_fpn_fcos_x101_retina_r50_1x_coco.py
Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_fcos_retina/pkd_fpn_fcos_x101_retina_r50_1x_coco_20220925_181547-9cac5059.pth?versionId=CAEQThiBgMCLyNC0oBgiIDBjY2FkY2JlNGFiYzRmM2RiZGUyYzM1NjQxYzQxODA4

- Name: pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco
In Collection: PKD
Metadata:
Location: FPN
Student:
Metrics:
box AP: 38.4
Config: mmdet::faster_rcnn/faster-rcnn_r50_fpn_2x_coco.py
Weights: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_2x_coco/faster_rcnn_r50_fpn_2x_coco_bbox_mAP-0.384_20200504_210434-a5d8aa15.pth
Teacher:
Metrics:
box AP: 39.8
Config: mmdet::faster_rcnn/faster-rcnn_r101_fpn_2x_coco.py
Weights: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_2x_coco/faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 40.4
Config: configs/distill/mmdet/pkd/pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco.py
Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_frcnn/pkd_fpn_faster-rcnn_r101_faster-rcnn_r50_2x_coco_20221014_103040-3efbd439.pth?versionId=CAEQThiBgMDQr9C0oBgiIDMyZWE1Y2ZlMDA2ZDQ2ZGNhZmQ3NzMxODk3YzgzYWFl

- Name: pkd_fpn_mask-rcnn_swin_retina_r50_2x_coco
In Collection: PKD
Metadata:
Location: FPN
Student:
Metrics:
box AP: 37.4
Config: mmdet::retinanet/retinanet_r50_fpn_2x_coco.py
Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_2x_coco/retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth
Teacher:
Metrics:
box AP: 48.2
Config: mmdet::swin/mask-rcnn_swin-s-p4-w7_fpn_amp-ms-crop-3x_coco.py
Weights: https://download.openmmlab.com/mmdetection/v2.0/swin/mask_rcnn_swin-t-p4-w7_fpn_1x_coco/mask_rcnn_swin-t-p4-w7_fpn_1x_coco_20210902_120937-9d6b7cfa.pth
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 41.5
Config: configs/distill/mmdet/pkd/pkd_fpn_mask-rcnn_swin_retina_r50_2x_coco.py
Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_swin_retina/pkd_fpn_mask_rcnn_swin_retina_r50_2x_coco_20220925_142555-edec7433.pth?versionId=CAEQThiBgIDWqNC0oBgiIDViOGE0ZDU4ODgxNzQ5YmE5OGU3MzRkMjFiZGRjZmRm

- Name: pkd_fpn_reppoints_x101-dcn_reppoints_r50_2x_coco
In Collection: PKD
Metadata:
Location: FPN
Student:
Metrics:
box AP: 38.6
Config: mmdet::reppoints/reppoints-moment_r50_fpn-gn_head-gn_2x_coco.py
Weights: https://download.openmmlab.com/mmdetection/v2.0/reppoints/reppoints_moment_r50_fpn_gn-neck%2Bhead_2x_coco/reppoints_moment_r50_fpn_gn-neck%2Bhead_2x_coco_20200329-91babaa2.pth
Teacher:
Metrics:
box AP: 44.2
Config: mmdet::reppoints/reppoints-moment_x101-dconv-c3-c5_fpn-gn_head-gn_2x_coco.py
Weights: https://download.openmmlab.com/mmdetection/v2.0/reppoints/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco_20200329-f87da1ea.pth
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 42.3
Config: configs/distill/mmdet/pkd/pkd_fpn_reppoints_x101-dcn_reppoints_r50_2x_coco.py
Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_reppoints/pkd_fpn_reppoints_x101_dcn_reppoints_r50_2x_coco_20220926_145818-f8932e12.pth?versionId=CAEQThiBgIC8rNC0oBgiIGU2N2IxM2NkMjNlMjQyN2E4YmVlNmViNGI2MDY3OTE5

- Name: pkd_fpn_retina_x101_retina_r50_2x_coco
In Collection: PKD
Metadata:
Location: FPN
Student:
Metrics:
box AP: 37.4
Config: mmdet::retinanet/retinanet_r50_fpn_2x_coco.py
Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_2x_coco/retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth
Teacher:
Metrics:
box AP: 41.0
Config: mmdet::retinanet/retinanet_x101-64x4d_fpn_1x_coco.py
Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 40.8
Config: configs/distill/mmdet/pkd/pkd_fpn_retina_x101_retina_r50_2x_coco.py
Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/pkd/pkd_retinax_retina/pkd_fpn_retina_x101_retina_r50_2x_coco_20221014_232526-4c0f8d96.pth?versionId=CAEQThiBgIDQqdC0oBgiIGFmZjNmZmE4NDFiMDQ4MzhiMzdjOGI2NzI4MTQxMjFi
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
_base_ = [
'mmdet::_base_/datasets/coco_detection.py',
'mmdet::_base_/schedules/schedule_2x.py',
'mmdet::_base_/default_runtime.py'
]

teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_2x_coco/faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth' # noqa: E501

model = dict(
_scope_='mmrazor',
type='FpnTeacherDistill',
architecture=dict(
cfg_path='mmdet::faster_rcnn/faster-rcnn_r50_fpn_2x_coco.py',
pretrained=False),
teacher=dict(
cfg_path='mmdet::faster_rcnn/faster-rcnn_r101_fpn_2x_coco.py',
pretrained=False),
teacher_ckpt=teacher_ckpt,
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')),
teacher_recorders=dict(fpn=dict(type='ModuleOutputs', source='neck')),
distill_losses=dict(
loss_pkd_fpn0=dict(type='PKDLoss', loss_weight=6),
loss_pkd_fpn1=dict(type='PKDLoss', loss_weight=6),
loss_pkd_fpn2=dict(type='PKDLoss', loss_weight=6),
loss_pkd_fpn3=dict(type='PKDLoss', loss_weight=6)),
loss_forward_mappings=dict(
loss_pkd_fpn0=dict(
preds_S=dict(from_student=True, recorder='fpn', data_idx=0),
preds_T=dict(from_student=False, recorder='fpn', data_idx=0)),
loss_pkd_fpn1=dict(
preds_S=dict(from_student=True, recorder='fpn', data_idx=1),
preds_T=dict(from_student=False, recorder='fpn', data_idx=1)),
loss_pkd_fpn2=dict(
preds_S=dict(from_student=True, recorder='fpn', data_idx=2),
preds_T=dict(from_student=False, recorder='fpn', data_idx=2)),
loss_pkd_fpn3=dict(
preds_S=dict(from_student=True, recorder='fpn', data_idx=3),
preds_T=dict(from_student=False, recorder='fpn',
data_idx=3)))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
27 changes: 27 additions & 0 deletions configs/distill/mmdet/pkd/pkd_fpn_fcos_x101_retina_r50_1x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
_base_ = ['./pkd_fpn_retina_x101_retina_r50_2x_coco.py']

teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/fcos/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco/fcos_x101_64x4d_fpn_gn-head_mstrain_640-800_2x_coco-ede514a8.pth' # noqa: E501

model = dict(
architecture=dict(
cfg_path='mmdet::retinanet/retinanet_r50_fpn_1x_coco.py'),
teacher=dict(
cfg_path= # noqa: E251
'mmdet::fcos/fcos_x101-64x4d_fpn_gn-head_ms-640-800-2x_coco.py'),
teacher_ckpt=teacher_ckpt)

# training schedule for 1x
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)

# learning rate
param_scheduler = [
dict(
type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
dict(
type='MultiStepLR',
begin=0,
end=12,
by_epoch=True,
milestones=[8, 11],
gamma=0.1)
]
Loading