[Feature] Add deit-base #332

Merged
merged 12 commits into from
Oct 25, 2022
45 changes: 45 additions & 0 deletions configs/distill/mmcls/deit/README.md
@@ -0,0 +1,45 @@
# DeiT

> [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877)

<!-- [ALGORITHM] -->

## Abstract

Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using an expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free transformer by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models.

<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/143225703-c287c29e-82c9-4c85-a366-dfae30d198cd.png" width="40%"/>
</div>

## Results and models

### Classification

| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download |
| -------- | --------- | ----------- | --------- | --------- | ------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| ImageNet | Deit-base | RegNety-160 | 83.24 | 96.33 | [config](deit-base_regnety160_pt-16xb64_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.json?versionId=CAEQThiBgIDGos20oBgiIGVlNDgyM2M2ZTk5MzQyYjFhNTgwNGIzMjllZjg3YmZm) |

```{warning}
Before training, please first install `timm`.

pip install timm
or
git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .
```

## Citation

```
@InProceedings{pmlr-v139-touvron21a,
title = {Training data-efficient image transformers \& distillation through attention},
author = {Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and Jegou, Herve},
booktitle = {International Conference on Machine Learning},
pages = {10347--10357},
year = {2021},
volume = {139},
month = {July}
}
```
64 changes: 64 additions & 0 deletions configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
@@ -0,0 +1,64 @@
_base_ = ['mmcls::deit/deit-base_pt-16xb64_in1k.py']

# student settings
student = _base_.model
student.backbone.type = 'DistilledVisionTransformer'
student.head = dict(
    type='mmrazor.DeiTClsHead',
    num_classes=1000,
    in_channels=768,
    loss=dict(
        type='mmcls.LabelSmoothLoss',
        label_smooth_val=0.1,
        mode='original',
        loss_weight=0.5))

data_preprocessor = dict(
    type='mmcls.ClsDataPreprocessor', batch_augments=student.train_cfg)

# teacher settings
checkpoint_path = 'https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'  # noqa: E501
teacher = dict(
    _scope_='mmcls',
    type='ImageClassifier',
    backbone=dict(
        type='TIMMBackbone', model_name='regnety_160', pretrained=True),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=3024,
        loss=dict(
            type='LabelSmoothLoss',
            label_smooth_val=0.1,
            mode='original',
            loss_weight=0.5),
        topk=(1, 5),
        init_cfg=dict(
            type='Pretrained', checkpoint=checkpoint_path, prefix='head.')))

model = dict(
    _scope_='mmrazor',
    _delete_=True,
    type='SingleTeacherDistill',
    architecture=student,
    teacher=teacher,
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.layers.head_dist')),
        teacher_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        distill_losses=dict(
            loss_distill=dict(
                type='CrossEntropyLoss',
                loss_weight=0.5,
            )),
        loss_forward_mappings=dict(
            loss_distill=dict(
                preds_S=dict(from_student=True, recorder='fc'),
                preds_T=dict(from_student=False, recorder='fc')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
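Reviewer note: the config attaches two losses to the student, each with `loss_weight=0.5`: a label-smoothed classification loss on the class-token head and a cross-entropy distillation loss on `head_dist`. The plain-Python sketch below shows how those two terms could combine. It is not mmcls or mmrazor code. It assumes that `mode='original'` smoothing mixes the one-hot target with a uniform distribution, and that the total loss is the sum of the two weighted terms. All function names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def label_smooth_ce(logits, target, eps=0.1):
    """Cross entropy against a smoothed one-hot target.

    Assumption: 'original' mode means (1 - eps) * one_hot + eps / num_classes.
    """
    n = len(logits)
    logp = log_softmax(logits)
    smoothed = [(1.0 - eps) * (i == target) + eps / n for i in range(n)]
    return -sum(q * lp for q, lp in zip(smoothed, logp))


def hard_distill_ce(student_logits, teacher_logits):
    """Cross entropy against the teacher's argmax class."""
    teacher_label = max(
        range(len(teacher_logits)), key=teacher_logits.__getitem__)
    return -log_softmax(student_logits)[teacher_label]


def total_loss(cls_logits, dist_logits, teacher_logits, target,
               cls_weight=0.5, distill_weight=0.5):
    """Weighted sum mirroring the two loss_weight=0.5 entries (assumed)."""
    return (cls_weight * label_smooth_ce(cls_logits, target)
            + distill_weight * hard_distill_ce(dist_logits, teacher_logits))
```

With equal weights, neither the ground-truth label nor the teacher's prediction dominates the objective; changing either `loss_weight` in the config shifts that balance.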
34 changes: 34 additions & 0 deletions configs/distill/mmcls/deit/metafile.yml
@@ -0,0 +1,34 @@
Collections:
  - Name: DEIT
    Metadata:
      Training Data:
        - ImageNet-1k
    Paper:
      URL: https://arxiv.org/abs/2012.12877
      Title: Training data-efficient image transformers & distillation through attention
    README: configs/distill/mmcls/deit/README.md

Models:
  - Name: deit-base_regnety160_pt-16xb64_in1k
    In Collection: DEIT
    Metadata:
      Student:
        Config: mmcls::deit/deit-base_pt-16xb64_in1k.py
        Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth
        Metrics:
          Top 1 Accuracy: 81.76
          Top 5 Accuracy: 95.81
      Teacher:
        Config: mmrazor::distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
        Weights: https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth
        Metrics:
          Top 1 Accuracy: 82.83
          Top 5 Accuracy: 96.42
    Results:
      - Task: Classification
        Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.24
          Top 5 Accuracy: 96.33
    Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz
    Config: configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
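Reviewer note: the metafile lists three sets of accuracies, so the effect of distillation can be read off with a little arithmetic. The snippet below copies the numbers from the metafile. It assumes the `Student` entry describes the non-distilled DeiT-base baseline and the `Teacher` entry describes RegNetY-160 on its own. The `gain` helper is hypothetical.

```python
# Top-1 / top-5 accuracies (%) copied from the metafile.
student_baseline = {"top1": 81.76, "top5": 95.81}
teacher = {"top1": 82.83, "top5": 96.42}
distilled = {"top1": 83.24, "top5": 96.33}


def gain(result, reference):
    """Accuracy difference in percentage points, rounded to two decimals."""
    return {k: round(result[k] - reference[k], 2) for k in result}


print(gain(distilled, student_baseline))  # {'top1': 1.48, 'top5': 0.52}
print(gain(distilled, teacher))           # {'top1': 0.41, 'top5': -0.09}
```

Under those assumptions the distilled student gains 1.48 points of top-1 accuracy over its baseline. It also edges past the teacher on top-1 while trailing it slightly on top-5.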
3 changes: 2 additions & 1 deletion mmrazor/models/architectures/heads/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .darts_subnet_head import DartsSubnetClsHead
+from .deit_head import DeiTClsHead

-__all__ = ['DartsSubnetClsHead']
+__all__ = ['DartsSubnetClsHead', 'DeiTClsHead']
69 changes: 69 additions & 0 deletions mmrazor/models/architectures/heads/deit_head.py
@@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn as nn

from mmrazor.registry import MODELS

try:
    from mmcls.models import VisionTransformerClsHead
except ImportError:
    from mmrazor.utils import get_placeholder
    VisionTransformerClsHead = get_placeholder('mmcls')


@MODELS.register_module()
class DeiTClsHead(VisionTransformerClsHead):
    """Distilled Vision Transformer classifier head.

    Compared with the :class:`DeiTClsHead` in mmcls, this head supports
    training the distilled version of DeiT.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        hidden_dim (int, optional): Number of the dimensions for hidden layer.
            Defaults to None, which means no extra hidden layer.
        act_cfg (dict): The activation config. Only available during
            pre-training. Defaults to ``dict(type='Tanh')``.
        init_cfg (dict): The extra initialization configs. Defaults to
            ``dict(type='Constant', layer='Linear', val=0)``.
    """

    def _init_layers(self):
        """Init the extra linear layer that handles the dist token."""
        super(DeiTClsHead, self)._init_layers()
        if self.hidden_dim is None:
            head_dist = nn.Linear(self.in_channels, self.num_classes)
        else:
            head_dist = nn.Linear(self.hidden_dim, self.num_classes)
        self.layers.add_module('head_dist', head_dist)

    def pre_logits(
        self, feats: Tuple[List[torch.Tensor]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """The process before the final classification head.

        The input ``feats`` is a tuple of lists of tensors, and each element
        holds the features of a backbone stage. In ``DeiTClsHead``, we take
        the features of the last stage and forward them through the hidden
        layer if it exists.
        """
        _, cls_token, dist_token = feats[-1]
        if self.hidden_dim is None:
            return cls_token, dist_token
        else:
            cls_token = self.layers.act(self.layers.pre_logits(cls_token))
            dist_token = self.layers.act(self.layers.pre_logits(dist_token))
            return cls_token, dist_token

    def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
        """The forward process."""
        cls_token, dist_token = self.pre_logits(feats)
        # The final classification head.
        cls_score = self.layers.head(cls_token)
        # Forward so that the corresponding recorder can record the output
        # of the distillation token
        _ = self.layers.head_dist(dist_token)
        return cls_score
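Reviewer note: `forward` discards the result of `head_dist`, which only makes sense if something else observes that call. The sketch below illustrates the idea with a plain PyTorch forward hook. It assumes the `ModuleOutputs` recorder behaves like such a hook, which this diff does not show. The `layers` container, the toy sizes and `record_output` are hypothetical stand-ins.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the head's `layers` container: two linear heads
# with the same input width (toy sizes, not the real 768 -> 1000).
layers = nn.Sequential()
layers.add_module('head', nn.Linear(8, 4))
layers.add_module('head_dist', nn.Linear(8, 4))

recorded = []


def record_output(module, inputs, output):
    """Forward hook: keep whatever the hooked module just produced."""
    recorded.append(output)


handle = layers.head_dist.register_forward_hook(record_output)

cls_token = torch.randn(2, 8)
dist_token = torch.randn(2, 8)

cls_score = layers.head(cls_token)
# The return value is discarded, but the hook still fires on this call.
_ = layers.head_dist(dist_token)

handle.remove()
print(len(recorded), tuple(recorded[0].shape))
```

If the `head_dist` call were removed from `forward`, a hook-based recorder would have nothing to capture, so the distillation loss would never receive `preds_S`.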
3 changes: 2 additions & 1 deletion mmrazor/models/losses/__init__.py
@@ -2,6 +2,7 @@
 from .ab_loss import ABLoss
 from .at_loss import ATLoss
 from .crd_loss import CRDLoss
+from .cross_entropy_loss import CrossEntropyLoss
 from .cwd import ChannelWiseDivergence
 from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss
 from .decoupled_kd import DKDLoss
@@ -19,5 +20,5 @@
     'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
     'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
     'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss',
-    'L1Loss', 'FBKDLoss', 'CRDLoss'
+    'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss'
 ]
23 changes: 23 additions & 0 deletions mmrazor/models/losses/cross_entropy_loss.py
@@ -0,0 +1,23 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from mmrazor.registry import MODELS


@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
    """Cross entropy loss.

    Args:
        loss_weight (float): Weight of the loss. Defaults to 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super(CrossEntropyLoss, self).__init__()
        self.loss_weight = loss_weight

    def forward(self, preds_S, preds_T):
        # Use the teacher's argmax prediction as a hard label; detach so
        # that no gradient flows back into the teacher.
        preds_T = preds_T.detach()
        loss = F.cross_entropy(preds_S, preds_T.argmax(dim=1))
        return loss * self.loss_weight
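Reviewer note: two properties of this loss are easy to check in isolation. Only the teacher's argmax matters, so rescaling the teacher logits leaves the loss unchanged, and no gradient reaches the teacher logits. The sketch below repeats the computation from `forward` as a standalone function. The function name and the toy tensors are hypothetical, and the registry decorator is left out so that the snippet needs only PyTorch.

```python
import torch
import torch.nn.functional as F

preds_S = torch.tensor([[2.0, 0.5, -1.0],
                        [0.1, 0.2, 3.0]], requires_grad=True)
preds_T = torch.tensor([[0.3, 0.9, 0.1],
                        [0.0, 0.0, 5.0]], requires_grad=True)


def hard_distill_loss(student_logits, teacher_logits, loss_weight=1.0):
    """Same computation as forward(), written as a plain function."""
    hard_labels = teacher_logits.detach().argmax(dim=1)
    return F.cross_entropy(student_logits, hard_labels) * loss_weight


loss = hard_distill_loss(preds_S, preds_T, loss_weight=0.5)
loss.backward()

# Rescaling the teacher logits keeps the argmax, so the loss is unchanged.
scaled = hard_distill_loss(preds_S, preds_T * 10.0, loss_weight=0.5)

print(preds_S.grad is not None, preds_T.grad is None)
print(torch.allclose(loss, scaled))
```

Because only the argmax is used, the teacher's confidence carries no information into the student; a soft-label variant would need a different loss than the one added here.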
1 change: 1 addition & 0 deletions model-index.yml
@@ -18,3 +18,4 @@ Import:
   - configs/nas/mmcls/autoslim/metafile.yml
   - configs/nas/mmcls/darts/metafile.yml
   - configs/nas/mmdet/detnas/metafile.yml
+  - configs/distill/mmcls/deit/metafile.yml