-
Notifications
You must be signed in to change notification settings - Fork 220
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* WIP: support deit * WIP: add deithead * WIP: fix checkpoint hook * fix data preprocessor * fix cfg * WIP: add readme * reset single_teacher_distill * add metafile * add model to model-index * fix configs and readme
- Loading branch information
Showing
8 changed files
with
240 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# DeiT | ||
|
||
> [](https://arxiv.org/abs/2012.12877) | ||
> Training data-efficient image transformers & distillation through attention | ||
<!-- [ALGORITHM] --> | ||
|
||
## Abstract | ||
|
||
Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using an expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free transformer by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models. | ||
|
||
<div align=center> | ||
<img src="https://user-images.githubusercontent.com/26739999/143225703-c287c29e-82c9-4c85-a366-dfae30d198cd.png" width="40%"/> | ||
</div> | ||
|
||
## Results and models | ||
|
||
### Classification | ||
|
||
| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download | | ||
| -------- | --------- | ----------- | --------- | --------- | ------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||
| ImageNet | Deit-base | RegNety-160 | 83.24 | 96.33 | [config](deit-base_regnety160_pt-16xb64_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.json?versionId=CAEQThiBgIDGos20oBgiIGVlNDgyM2M2ZTk5MzQyYjFhNTgwNGIzMjllZjg3YmZm) | | ||
|
||
```{warning} | ||
Before training, please first install `timm`. | ||
pip install timm | ||
or | ||
git clone https://github.com/rwightman/pytorch-image-models | ||
cd pytorch-image-models && pip install -e . | ||
``` | ||
|
||
## Citation | ||
|
||
``` | ||
@InProceedings{pmlr-v139-touvron21a, | ||
title = {Training data-efficient image transformers & distillation through attention}, | ||
author = {Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and Jegou, Herve}, | ||
booktitle = {International Conference on Machine Learning}, | ||
pages = {10347--10357}, | ||
year = {2021}, | ||
volume = {139}, | ||
month = {July} | ||
} | ||
``` |
64 changes: 64 additions & 0 deletions
64
configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
_base_ = ['mmcls::deit/deit-base_pt-16xb64_in1k.py'] | ||
|
||
# student settings | ||
student = _base_.model | ||
student.backbone.type = 'DistilledVisionTransformer' | ||
student.head = dict( | ||
type='mmrazor.DeiTClsHead', | ||
num_classes=1000, | ||
in_channels=768, | ||
loss=dict( | ||
type='mmcls.LabelSmoothLoss', | ||
label_smooth_val=0.1, | ||
mode='original', | ||
loss_weight=0.5)) | ||
|
||
data_preprocessor = dict( | ||
type='mmcls.ClsDataPreprocessor', batch_augments=student.train_cfg) | ||
|
||
# teacher settings | ||
checkpoint_path = 'https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth' # noqa: E501 | ||
teacher = dict( | ||
_scope_='mmcls', | ||
type='ImageClassifier', | ||
backbone=dict( | ||
type='TIMMBackbone', model_name='regnety_160', pretrained=True), | ||
neck=dict(type='GlobalAveragePooling'), | ||
head=dict( | ||
type='LinearClsHead', | ||
num_classes=1000, | ||
in_channels=3024, | ||
loss=dict( | ||
type='LabelSmoothLoss', | ||
label_smooth_val=0.1, | ||
mode='original', | ||
loss_weight=0.5), | ||
topk=(1, 5), | ||
init_cfg=dict( | ||
type='Pretrained', checkpoint=checkpoint_path, prefix='head.'))) | ||
|
||
model = dict( | ||
_scope_='mmrazor', | ||
_delete_=True, | ||
type='SingleTeacherDistill', | ||
architecture=student, | ||
teacher=teacher, | ||
distiller=dict( | ||
type='ConfigurableDistiller', | ||
student_recorders=dict( | ||
fc=dict(type='ModuleOutputs', source='head.layers.head_dist')), | ||
teacher_recorders=dict( | ||
fc=dict(type='ModuleOutputs', source='head.fc')), | ||
distill_losses=dict( | ||
loss_distill=dict( | ||
type='CrossEntropyLoss', | ||
loss_weight=0.5, | ||
)), | ||
loss_forward_mappings=dict( | ||
loss_distill=dict( | ||
preds_S=dict(from_student=True, recorder='fc'), | ||
preds_T=dict(from_student=False, recorder='fc'))))) | ||
|
||
find_unused_parameters = True | ||
|
||
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
Collections: | ||
- Name: DEIT | ||
Metadata: | ||
Training Data: | ||
- ImageNet-1k | ||
Paper: | ||
URL: https://arxiv.org/abs/2012.12877 | ||
Title: Training data-efficient image transformers & distillation through attention | ||
README: configs/distill/mmcls/deit/README.md | ||
|
||
Models: | ||
- Name: deit-base_regnety160_pt-16xb64_in1k | ||
In Collection: DEIT | ||
Metadata: | ||
Student: | ||
Config: mmcls::deit/deit-base_pt-16xb64_in1k.py | ||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth | ||
Metrics: | ||
Top 1 Accuracy: 81.76 | ||
Top 5 Accuracy: 95.81 | ||
Teacher: | ||
Config: mmrazor::distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py | ||
Weights: https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth | ||
Metrics: | ||
Top 1 Accuracy: 82.83 | ||
Top 5 Accuracy: 96.42 | ||
Results: | ||
- Task: Classification | ||
Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 83.24 | ||
Top 5 Accuracy: 96.33 | ||
Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz | ||
Config: configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .darts_subnet_head import DartsSubnetClsHead | ||
from .deit_head import DeiTClsHead | ||
|
||
__all__ = ['DartsSubnetClsHead'] | ||
__all__ = ['DartsSubnetClsHead', 'DeiTClsHead'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import List, Tuple | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from mmrazor.registry import MODELS | ||
|
||
try: | ||
from mmcls.models import VisionTransformerClsHead | ||
except ImportError: | ||
from mmrazor.utils import get_placeholder | ||
VisionTransformerClsHead = get_placeholder('mmcls') | ||
|
||
|
||
@MODELS.register_module() | ||
class DeiTClsHead(VisionTransformerClsHead): | ||
"""Distilled Vision Transformer classifier head. | ||
Comparing with the :class:`DeiTClsHead` in mmcls, this head support to | ||
train the distilled version DeiT. | ||
Args: | ||
num_classes (int): Number of categories excluding the background | ||
category. | ||
in_channels (int): Number of channels in the input feature map. | ||
hidden_dim (int, optional): Number of the dimensions for hidden layer. | ||
Defaults to None, which means no extra hidden layer. | ||
act_cfg (dict): The activation config. Only available during | ||
pre-training. Defaults to ``dict(type='Tanh')``. | ||
init_cfg (dict): The extra initialization configs. Defaults to | ||
``dict(type='Constant', layer='Linear', val=0)``. | ||
""" | ||
|
||
def _init_layers(self): | ||
""""Init extra hidden linear layer to handle dist token if exists.""" | ||
super(DeiTClsHead, self)._init_layers() | ||
if self.hidden_dim is None: | ||
head_dist = nn.Linear(self.in_channels, self.num_classes) | ||
else: | ||
head_dist = nn.Linear(self.hidden_dim, self.num_classes) | ||
self.layers.add_module('head_dist', head_dist) | ||
|
||
def pre_logits( | ||
self, feats: Tuple[List[torch.Tensor]] | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
"""The process before the final classification head. | ||
The input ``feats`` is a tuple of list of tensor, and each tensor is | ||
the feature of a backbone stage. In ``DeiTClsHead``, we obtain the | ||
feature of the last stage and forward in hidden layer if exists. | ||
""" | ||
_, cls_token, dist_token = feats[-1] | ||
if self.hidden_dim is None: | ||
return cls_token, dist_token | ||
else: | ||
cls_token = self.layers.act(self.layers.pre_logits(cls_token)) | ||
dist_token = self.layers.act(self.layers.pre_logits(dist_token)) | ||
return cls_token, dist_token | ||
|
||
def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor: | ||
"""The forward process.""" | ||
cls_token, dist_token = self.pre_logits(feats) | ||
# The final classification head. | ||
cls_score = self.layers.head(cls_token) | ||
# Forward so that the corresponding recorder can record the output | ||
# of the distillation token | ||
_ = self.layers.head_dist(dist_token) | ||
return cls_score |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from mmrazor.registry import MODELS | ||
|
||
|
||
@MODELS.register_module() | ||
class CrossEntropyLoss(nn.Module): | ||
"""Cross entropy loss. | ||
Args: | ||
loss_weight (float): Weight of the loss. Defaults to 1.0. | ||
""" | ||
|
||
def __init__(self, loss_weight=1.0): | ||
super(CrossEntropyLoss, self).__init__() | ||
self.loss_weight = loss_weight | ||
|
||
def forward(self, preds_S, preds_T): | ||
preds_T = preds_T.detach() | ||
loss = F.cross_entropy(preds_S, preds_T.argmax(dim=1)) | ||
return loss * self.loss_weight |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters