From 8c7cdb3c73524e4c772cfb9e46828ceead400b91 Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Tue, 25 Oct 2022 21:18:18 +0800
Subject: [PATCH] [Feature] Add deit-base (#332)

* WIP: support deit

* WIP: add deithead

* WIP: fix checkpoint hook

* fix data preprocessor

* fix cfg

* WIP: add readme

* reset single_teacher_distill

* add metafile

* add model to model-index

* fix configs and readme
---
 configs/distill/mmcls/deit/README.md          | 45 ++++++++++++
 .../deit-base_regnety160_pt-16xb64_in1k.py    | 64 +++++++++++++++++
 configs/distill/mmcls/deit/metafile.yml       | 34 +++++++++
 .../models/architectures/heads/__init__.py    |  3 +-
 .../models/architectures/heads/deit_head.py   | 69 +++++++++++++++++++
 mmrazor/models/losses/__init__.py             |  3 +-
 mmrazor/models/losses/cross_entropy_loss.py   | 23 +++++++
 model-index.yml                               |  1 +
 8 files changed, 240 insertions(+), 2 deletions(-)
 create mode 100644 configs/distill/mmcls/deit/README.md
 create mode 100644 configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
 create mode 100644 configs/distill/mmcls/deit/metafile.yml
 create mode 100644 mmrazor/models/architectures/heads/deit_head.py
 create mode 100644 mmrazor/models/losses/cross_entropy_loss.py

diff --git a/configs/distill/mmcls/deit/README.md b/configs/distill/mmcls/deit/README.md
new file mode 100644
index 000000000..9482cc344
--- /dev/null
+++ b/configs/distill/mmcls/deit/README.md
@@ -0,0 +1,45 @@
+# DeiT
+
+> [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877)
+
+## Abstract
+
+Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using an expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free transformer by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models.
+
+## Results and models
+
+### Classification
+
+| Dataset  | Model     | Teacher     | Top-1 (%) | Top-5 (%) | Configs                                          | Download |
+| -------- | --------- | ----------- | --------- | --------- | ------------------------------------------------ | -------- |
+| ImageNet | DeiT-base | RegNetY-160 | 83.24     | 96.33     | [config](deit-base_regnety160_pt-16xb64_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.json?versionId=CAEQThiBgIDGos20oBgiIGVlNDgyM2M2ZTk5MzQyYjFhNTgwNGIzMjllZjg3YmZm) |
+
+```{warning}
+Before training, please install `timm` first, either from PyPI:
+
+pip install timm
+
+or from source:
+
+git clone https://github.com/rwightman/pytorch-image-models
+cd pytorch-image-models && pip install -e .
+```
+
+## Citation
+
+```
+@InProceedings{pmlr-v139-touvron21a,
+  title = {Training data-efficient image transformers & distillation through attention},
+  author = {Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and Jegou, Herve},
+  booktitle = {International Conference on Machine Learning},
+  pages = {10347--10357},
+  year = {2021},
+  volume = {139},
+  month = {July}
+}
+```
diff --git a/configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py b/configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
new file mode 100644
index 000000000..c2cfaf56a
--- /dev/null
+++ b/configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
@@ -0,0 +1,64 @@
+_base_ = ['mmcls::deit/deit-base_pt-16xb64_in1k.py']
+
+# student settings
+student = _base_.model
+student.backbone.type = 'DistilledVisionTransformer'
+student.head = dict(
+    type='mmrazor.DeiTClsHead',
+    num_classes=1000,
+    in_channels=768,
+    loss=dict(
+        type='mmcls.LabelSmoothLoss',
+        label_smooth_val=0.1,
+        mode='original',
+        loss_weight=0.5))
+
+data_preprocessor = dict(
+    type='mmcls.ClsDataPreprocessor', batch_augments=student.train_cfg)
+
+# teacher settings
+checkpoint_path = 'https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'  # noqa: E501
+teacher = dict(
+    _scope_='mmcls',
+    type='ImageClassifier',
+    backbone=dict(
+        type='TIMMBackbone', model_name='regnety_160', pretrained=True),
+    neck=dict(type='GlobalAveragePooling'),
+    head=dict(
+        type='LinearClsHead',
+        num_classes=1000,
+        in_channels=3024,
+        loss=dict(
+            type='LabelSmoothLoss',
+            label_smooth_val=0.1,
+            mode='original',
+            loss_weight=0.5),
+        topk=(1, 5),
+        init_cfg=dict(
+            type='Pretrained', checkpoint=checkpoint_path, prefix='head.')))
+
+# distillation settings: record the student's distillation-head output and
+# the teacher's fc output, then match them with a cross-entropy loss
+model = dict(
+    _scope_='mmrazor',
+    _delete_=True,
+    type='SingleTeacherDistill',
+    architecture=student,
+    teacher=teacher,
+    distiller=dict(
+        type='ConfigurableDistiller',
+        student_recorders=dict(
+            fc=dict(type='ModuleOutputs', source='head.layers.head_dist')),
+        teacher_recorders=dict(
+            fc=dict(type='ModuleOutputs', source='head.fc')),
+        distill_losses=dict(
+            loss_distill=dict(
+                type='CrossEntropyLoss',
+                loss_weight=0.5,
+            )),
+        loss_forward_mappings=dict(
+            loss_distill=dict(
+                preds_S=dict(from_student=True, recorder='fc'),
+                preds_T=dict(from_student=False, recorder='fc')))))
+
+find_unused_parameters = True
+
+val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
diff --git a/configs/distill/mmcls/deit/metafile.yml b/configs/distill/mmcls/deit/metafile.yml
new file mode 100644
index 000000000..1c28545c1
--- /dev/null
+++ b/configs/distill/mmcls/deit/metafile.yml
@@ -0,0 +1,34 @@
+Collections:
+  - Name: DEIT
+    Metadata:
+      Training Data:
+        - ImageNet-1k
+    Paper:
+      URL: https://arxiv.org/abs/2012.12877
+      Title: Training data-efficient image transformers & distillation through attention
+    README: configs/distill/mmcls/deit/README.md
+
+Models:
+  - Name: deit-base_regnety160_pt-16xb64_in1k
+    In Collection: DEIT
+    Metadata:
+      Student:
+        Config: mmcls::deit/deit-base_pt-16xb64_in1k.py
+        Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth
+        Metrics:
+          Top 1 Accuracy: 81.76
+          Top 5 Accuracy: 95.81
+      Teacher:
+        Config: mmrazor::distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
+        Weights: https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth
+        Metrics:
+          Top 1 Accuracy: 82.83
+          Top 5 Accuracy: 96.42
+    Results:
+      - Task: Classification
+        Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 83.24
+          Top 5 Accuracy: 96.33
+    Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz
+    Config: configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
diff --git a/mmrazor/models/architectures/heads/__init__.py b/mmrazor/models/architectures/heads/__init__.py
index de84c30d5..0d7da475d 100644
--- a/mmrazor/models/architectures/heads/__init__.py
+++ b/mmrazor/models/architectures/heads/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .darts_subnet_head import DartsSubnetClsHead
+from .deit_head import DeiTClsHead
 
-__all__ = ['DartsSubnetClsHead']
+__all__ = ['DartsSubnetClsHead', 'DeiTClsHead']
diff --git a/mmrazor/models/architectures/heads/deit_head.py b/mmrazor/models/architectures/heads/deit_head.py
new file mode 100644
index 000000000..61d587d93
--- /dev/null
+++ b/mmrazor/models/architectures/heads/deit_head.py
@@ -0,0 +1,69 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+
+from mmrazor.registry import MODELS
+
+try:
+    from mmcls.models import VisionTransformerClsHead
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    VisionTransformerClsHead = get_placeholder('mmcls')
+
+
+@MODELS.register_module()
+class DeiTClsHead(VisionTransformerClsHead):
+    """Distilled Vision Transformer classifier head.
+
+    Compared with the :class:`DeiTClsHead` in mmcls, this head supports
+    training the distilled version of DeiT.
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        hidden_dim (int, optional): Number of dimensions of the hidden layer.
+            Defaults to None, which means no extra hidden layer.
+        act_cfg (dict): The activation config. Only available during
+            pre-training. Defaults to ``dict(type='Tanh')``.
+        init_cfg (dict): The extra initialization configs. Defaults to
+            ``dict(type='Constant', layer='Linear', val=0)``.
+    """
+
+    def _init_layers(self):
+        """Init the extra hidden linear layer to handle the dist token if it
+        exists."""
+        super(DeiTClsHead, self)._init_layers()
+        # The extra linear head that classifies the distillation token.
+        if self.hidden_dim is None:
+            head_dist = nn.Linear(self.in_channels, self.num_classes)
+        else:
+            head_dist = nn.Linear(self.hidden_dim, self.num_classes)
+        self.layers.add_module('head_dist', head_dist)
+
+    def pre_logits(
+        self, feats: Tuple[List[torch.Tensor]]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """The process before the final classification head.
+
+        The input ``feats`` is a tuple of lists of tensors, and each tensor is
+        the feature of a backbone stage. In ``DeiTClsHead``, we obtain the
+        feature of the last stage and forward it through the hidden layer if
+        it exists.
+        """
+        _, cls_token, dist_token = feats[-1]
+        if self.hidden_dim is None:
+            return cls_token, dist_token
+        else:
+            cls_token = self.layers.act(self.layers.pre_logits(cls_token))
+            dist_token = self.layers.act(self.layers.pre_logits(dist_token))
+            return cls_token, dist_token
+
+    def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
+        """The forward process."""
+        cls_token, dist_token = self.pre_logits(feats)
+        # The final classification head.
+        cls_score = self.layers.head(cls_token)
+        # Forward the distillation token so that the corresponding recorder
+        # can record its output.
+        _ = self.layers.head_dist(dist_token)
+        return cls_score
diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py
index a145ba914..e7df97168 100644
--- a/mmrazor/models/losses/__init__.py
+++ b/mmrazor/models/losses/__init__.py
@@ -2,6 +2,7 @@
 from .ab_loss import ABLoss
 from .at_loss import ATLoss
 from .crd_loss import CRDLoss
+from .cross_entropy_loss import CrossEntropyLoss
 from .cwd import ChannelWiseDivergence
 from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss
 from .decoupled_kd import DKDLoss
@@ -19,5 +20,5 @@
     'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
     'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
     'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss',
-    'L1Loss', 'FBKDLoss', 'CRDLoss'
+    'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss'
 ]
diff --git a/mmrazor/models/losses/cross_entropy_loss.py b/mmrazor/models/losses/cross_entropy_loss.py
new file mode 100644
index 000000000..685748092
--- /dev/null
+++ b/mmrazor/models/losses/cross_entropy_loss.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmrazor.registry import MODELS
+
+
+@MODELS.register_module()
+class CrossEntropyLoss(nn.Module):
+    """Cross entropy loss for distillation.
+
+    The teacher's argmax prediction serves as the hard classification
+    target for the student, as in DeiT's hard distillation.
+
+    Args:
+        loss_weight (float): Weight of the loss. Defaults to 1.0.
+    """
+
+    def __init__(self, loss_weight=1.0):
+        super(CrossEntropyLoss, self).__init__()
+        self.loss_weight = loss_weight
+
+    def forward(self, preds_S, preds_T):
+        """Compute the loss from student and teacher logits."""
+        preds_T = preds_T.detach()
+        loss = F.cross_entropy(preds_S, preds_T.argmax(dim=1))
+        return loss * self.loss_weight
diff --git a/model-index.yml b/model-index.yml
index b1a321c84..6594c923e 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -18,3 +18,4 @@ Import:
   - configs/nas/mmcls/autoslim/metafile.yml
   - configs/nas/mmcls/darts/metafile.yml
   - configs/nas/mmdet/detnas/metafile.yml
+  - configs/distill/mmcls/deit/metafile.yml
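For reference, the new `CrossEntropyLoss` implements DeiT-style hard distillation: the student's distillation-head logits are trained against the teacher's argmax prediction. Below is a minimal usage sketch; the tensor shapes are illustrative, and the import path follows the `mmrazor/models/losses/__init__.py` change above.

```python
import torch

from mmrazor.models.losses import CrossEntropyLoss

# Illustrative shapes: a batch of 4 samples over the 1000 ImageNet classes.
preds_S = torch.randn(4, 1000)  # student distillation-head logits
preds_T = torch.randn(4, 1000)  # teacher fc logits (detached inside the loss)

# loss_weight=0.5 matches deit-base_regnety160_pt-16xb64_in1k.py. The result
# equals 0.5 * F.cross_entropy(preds_S, preds_T.argmax(dim=1)), i.e. hard
# distillation with the teacher's top-1 prediction as the label.
criterion = CrossEntropyLoss(loss_weight=0.5)
loss = criterion(preds_S, preds_T)
```

In training, the `ConfigurableDistiller` recorders feed these two tensors automatically: the student side records `head.layers.head_dist` and the teacher side records `head.fc`, as wired up in `loss_forward_mappings`.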