From 8c7cdb3c73524e4c772cfb9e46828ceead400b91 Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Tue, 25 Oct 2022 21:18:18 +0800
Subject: [PATCH] [Feature] Add deit-base (#332)
* WIP: support deit
* WIP: add deithead
* WIP: fix checkpoint hook
* fix data preprocessor
* fix cfg
* WIP: add readme
* reset single_teacher_distill
* add metafile
* add model to model-index
* fix configs and readme
---
configs/distill/mmcls/deit/README.md | 45 ++++++++++++
.../deit-base_regnety160_pt-16xb64_in1k.py | 64 +++++++++++++++++
configs/distill/mmcls/deit/metafile.yml | 34 +++++++++
.../models/architectures/heads/__init__.py | 3 +-
.../models/architectures/heads/deit_head.py | 69 +++++++++++++++++++
mmrazor/models/losses/__init__.py | 3 +-
mmrazor/models/losses/cross_entropy_loss.py | 23 +++++++
model-index.yml | 1 +
8 files changed, 240 insertions(+), 2 deletions(-)
create mode 100644 configs/distill/mmcls/deit/README.md
create mode 100644 configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
create mode 100644 configs/distill/mmcls/deit/metafile.yml
create mode 100644 mmrazor/models/architectures/heads/deit_head.py
create mode 100644 mmrazor/models/losses/cross_entropy_loss.py
diff --git a/configs/distill/mmcls/deit/README.md b/configs/distill/mmcls/deit/README.md
new file mode 100644
index 000000000..9482cc344
--- /dev/null
+++ b/configs/distill/mmcls/deit/README.md
@@ -0,0 +1,45 @@
+# DeiT
+
+> [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877)
+> (International Conference on Machine Learning, 2021)
+
+
+
+## Abstract
+
+Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using an expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free transformer by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models.
+
+
+
+
+
+## Results and models
+
+### Classification
+
+| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download |
+| -------- | --------- | ----------- | --------- | --------- | ------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| ImageNet | Deit-base | RegNety-160 | 83.24 | 96.33 | [config](deit-base_regnety160_pt-16xb64_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.json?versionId=CAEQThiBgIDGos20oBgiIGVlNDgyM2M2ZTk5MzQyYjFhNTgwNGIzMjllZjg3YmZm) |
+
+```{warning}
+Before training, please first install `timm`.
+
+pip install timm
+or
+git clone https://github.com/rwightman/pytorch-image-models
+cd pytorch-image-models && pip install -e .
+```
+
+## Citation
+
+```
+@InProceedings{pmlr-v139-touvron21a,
+ title = {Training data-efficient image transformers & distillation through attention},
+ author = {Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and Jegou, Herve},
+ booktitle = {International Conference on Machine Learning},
+ pages = {10347--10357},
+ year = {2021},
+ volume = {139},
+ month = {July}
+}
+```
diff --git a/configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py b/configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
new file mode 100644
index 000000000..c2cfaf56a
--- /dev/null
+++ b/configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
@@ -0,0 +1,64 @@
+_base_ = ['mmcls::deit/deit-base_pt-16xb64_in1k.py']
+
+# student: the base mmcls model, switched to the distilled ViT backbone/head
+student = _base_.model
+student.backbone.type = 'DistilledVisionTransformer'
+student.head = dict(
+ type='mmrazor.DeiTClsHead',
+ num_classes=1000,
+ in_channels=768,
+ loss=dict(
+ type='mmcls.LabelSmoothLoss',
+ label_smooth_val=0.1,
+ mode='original',
+ loss_weight=0.5))
+
+data_preprocessor = dict(
+ type='mmcls.ClsDataPreprocessor', batch_augments=student.train_cfg)
+
+# teacher: timm regnety_160 backbone; head is initialized from checkpoint_path
+checkpoint_path = 'https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth' # noqa: E501
+teacher = dict(
+ _scope_='mmcls',
+ type='ImageClassifier',
+ backbone=dict(
+ type='TIMMBackbone', model_name='regnety_160', pretrained=True),
+ neck=dict(type='GlobalAveragePooling'),
+ head=dict(
+ type='LinearClsHead',
+ num_classes=1000,
+ in_channels=3024,
+ loss=dict(
+ type='LabelSmoothLoss',
+ label_smooth_val=0.1,
+ mode='original',
+ loss_weight=0.5),
+ topk=(1, 5),
+ init_cfg=dict(
+ type='Pretrained', checkpoint=checkpoint_path, prefix='head.')))
+
+model = dict(
+ _scope_='mmrazor',
+ _delete_=True,
+ type='SingleTeacherDistill',
+ architecture=student,
+ teacher=teacher,
+ distiller=dict(
+ type='ConfigurableDistiller',
+ student_recorders=dict(
+ fc=dict(type='ModuleOutputs', source='head.layers.head_dist')),
+ teacher_recorders=dict(
+ fc=dict(type='ModuleOutputs', source='head.fc')),
+ distill_losses=dict(
+ loss_distill=dict(
+ type='CrossEntropyLoss',
+ loss_weight=0.5,
+ )),
+ loss_forward_mappings=dict(
+ loss_distill=dict(
+ preds_S=dict(from_student=True, recorder='fc'),
+ preds_T=dict(from_student=False, recorder='fc')))))
+
+find_unused_parameters = True
+
+val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
diff --git a/configs/distill/mmcls/deit/metafile.yml b/configs/distill/mmcls/deit/metafile.yml
new file mode 100644
index 000000000..1c28545c1
--- /dev/null
+++ b/configs/distill/mmcls/deit/metafile.yml
@@ -0,0 +1,34 @@
+Collections:
+ - Name: DEIT
+ Metadata:
+ Training Data:
+ - ImageNet-1k
+ Paper:
+ URL: https://arxiv.org/abs/2012.12877
+ Title: Training data-efficient image transformers & distillation through attention
+ README: configs/distill/mmcls/deit/README.md
+
+Models:
+ - Name: deit-base_regnety160_pt-16xb64_in1k
+ In Collection: DEIT
+ Metadata:
+ Student:
+ Config: mmcls::deit/deit-base_pt-16xb64_in1k.py
+ Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth
+ Metrics:
+ Top 1 Accuracy: 81.76
+ Top 5 Accuracy: 95.81
+ Teacher:
+ Config: mmrazor::distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
+ Weights: https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth
+ Metrics:
+ Top 1 Accuracy: 82.83
+ Top 5 Accuracy: 96.42
+ Results:
+ - Task: Classification
+ Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 83.24
+ Top 5 Accuracy: 96.33
+ Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz
+ Config: configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
diff --git a/mmrazor/models/architectures/heads/__init__.py b/mmrazor/models/architectures/heads/__init__.py
index de84c30d5..0d7da475d 100644
--- a/mmrazor/models/architectures/heads/__init__.py
+++ b/mmrazor/models/architectures/heads/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .darts_subnet_head import DartsSubnetClsHead
+from .deit_head import DeiTClsHead
-__all__ = ['DartsSubnetClsHead']
+__all__ = ['DartsSubnetClsHead', 'DeiTClsHead']
diff --git a/mmrazor/models/architectures/heads/deit_head.py b/mmrazor/models/architectures/heads/deit_head.py
new file mode 100644
index 000000000..61d587d93
--- /dev/null
+++ b/mmrazor/models/architectures/heads/deit_head.py
@@ -0,0 +1,69 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+from mmrazor.registry import MODELS
+
+try:
+ from mmcls.models import VisionTransformerClsHead
+except ImportError:
+ from mmrazor.utils import get_placeholder
+ VisionTransformerClsHead = get_placeholder('mmcls')
+
+
+@MODELS.register_module()
+class DeiTClsHead(VisionTransformerClsHead):
+ """Distilled Vision Transformer classifier head.
+
+ Compared with the :class:`DeiTClsHead` in mmcls, this head supports
+ training the distilled version of DeiT.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ hidden_dim (int, optional): Number of the dimensions for hidden layer.
+ Defaults to None, which means no extra hidden layer.
+ act_cfg (dict): The activation config. Only available during
+ pre-training. Defaults to ``dict(type='Tanh')``.
+ init_cfg (dict): The extra initialization configs. Defaults to
+ ``dict(type='Constant', layer='Linear', val=0)``.
+ """
+
+ def _init_layers(self):
+ """Init base layers, then add the ``head_dist`` linear layer."""
+ super(DeiTClsHead, self)._init_layers()
+ if self.hidden_dim is None:
+ head_dist = nn.Linear(self.in_channels, self.num_classes)
+ else:
+ head_dist = nn.Linear(self.hidden_dim, self.num_classes)
+ self.layers.add_module('head_dist', head_dist)
+
+ def pre_logits(
+ self, feats: Tuple[List[torch.Tensor]]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """The process before the final classification head.
+
+ The input ``feats`` is a tuple of lists of tensors. The last entry is
+ unpacked as ``(_, cls_token, dist_token)``; both tokens are forwarded
+ through the hidden layer and activation if ``hidden_dim`` is not None.
+ """
+ _, cls_token, dist_token = feats[-1]
+ if self.hidden_dim is None:
+ return cls_token, dist_token
+ else:
+ cls_token = self.layers.act(self.layers.pre_logits(cls_token))
+ dist_token = self.layers.act(self.layers.pre_logits(dist_token))
+ return cls_token, dist_token
+
+ def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
+ """Return the cls score; ``head_dist`` runs only to be recorded."""
+ cls_token, dist_token = self.pre_logits(feats)
+ # The final classification head.
+ cls_score = self.layers.head(cls_token)
+ # Forward so that the corresponding recorder can record the output
+ # of the distillation token
+ _ = self.layers.head_dist(dist_token)
+ return cls_score
diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py
index a145ba914..e7df97168 100644
--- a/mmrazor/models/losses/__init__.py
+++ b/mmrazor/models/losses/__init__.py
@@ -2,6 +2,7 @@
from .ab_loss import ABLoss
from .at_loss import ATLoss
from .crd_loss import CRDLoss
+from .cross_entropy_loss import CrossEntropyLoss
from .cwd import ChannelWiseDivergence
from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss
from .decoupled_kd import DKDLoss
@@ -19,5 +20,5 @@
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss',
- 'L1Loss', 'FBKDLoss', 'CRDLoss'
+ 'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss'
]
diff --git a/mmrazor/models/losses/cross_entropy_loss.py b/mmrazor/models/losses/cross_entropy_loss.py
new file mode 100644
index 000000000..685748092
--- /dev/null
+++ b/mmrazor/models/losses/cross_entropy_loss.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmrazor.registry import MODELS
+
+
+@MODELS.register_module()
+class CrossEntropyLoss(nn.Module):
+ """Cross entropy of student predictions against the teacher's argmax.
+
+ Args:
+ loss_weight (float): Weight of the loss. Defaults to 1.0.
+ """
+
+ def __init__(self, loss_weight=1.0):
+ super(CrossEntropyLoss, self).__init__()
+ self.loss_weight = loss_weight
+
+ def forward(self, preds_S, preds_T):
+ preds_T = preds_T.detach()
+ loss = F.cross_entropy(preds_S, preds_T.argmax(dim=1))
+ return loss * self.loss_weight
diff --git a/model-index.yml b/model-index.yml
index b1a321c84..6594c923e 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -18,3 +18,4 @@ Import:
- configs/nas/mmcls/autoslim/metafile.yml
- configs/nas/mmcls/darts/metafile.yml
- configs/nas/mmdet/detnas/metafile.yml
+ - configs/distill/mmcls/deit/metafile.yml