Skip to content

Commit

Permalink
[Feature] Add deit-base (#332)
Browse files Browse the repository at this point in the history
* WIP: support deit

* WIP: add deithead

* WIP: fix checkpoint hook

* fix data preprocessor

* fix cfg

* WIP: add readme

* reset single_teacher_distill

* add metafile

* add model to model-index

* fix configs and readme
  • Loading branch information
HIT-cwh committed Oct 25, 2022
1 parent 972fd8e commit 8c7cdb3
Show file tree
Hide file tree
Showing 8 changed files with 240 additions and 2 deletions.
45 changes: 45 additions & 0 deletions configs/distill/mmcls/deit/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# DeiT

> [](https://arxiv.org/abs/2012.12877)
> Training data-efficient image transformers & distillation through attention
<!-- [ALGORITHM] -->

## Abstract

Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using an expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free transformer by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models.

<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/143225703-c287c29e-82c9-4c85-a366-dfae30d198cd.png" width="40%"/>
</div>

## Results and models

### Classification

| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download |
| -------- | --------- | ----------- | --------- | --------- | ------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| ImageNet | Deit-base | RegNety-160 | 83.24 | 96.33 | [config](deit-base_regnety160_pt-16xb64_in1k.py) | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.json?versionId=CAEQThiBgIDGos20oBgiIGVlNDgyM2M2ZTk5MzQyYjFhNTgwNGIzMjllZjg3YmZm) |

```{warning}
Before training, please first install `timm`.
pip install timm
or
git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .
```

## Citation

```
@InProceedings{pmlr-v139-touvron21a,
title = {Training data-efficient image transformers &amp; distillation through attention},
author = {Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and Jegou, Herve},
booktitle = {International Conference on Machine Learning},
pages = {10347--10357},
year = {2021},
volume = {139},
month = {July}
}
```
64 changes: 64 additions & 0 deletions configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
_base_ = ['mmcls::deit/deit-base_pt-16xb64_in1k.py']

# student settings
student = _base_.model
student.backbone.type = 'DistilledVisionTransformer'
student.head = dict(
type='mmrazor.DeiTClsHead',
num_classes=1000,
in_channels=768,
loss=dict(
type='mmcls.LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
loss_weight=0.5))

data_preprocessor = dict(
type='mmcls.ClsDataPreprocessor', batch_augments=student.train_cfg)

# teacher settings
checkpoint_path = 'https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth' # noqa: E501
teacher = dict(
_scope_='mmcls',
type='ImageClassifier',
backbone=dict(
type='TIMMBackbone', model_name='regnety_160', pretrained=True),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=3024,
loss=dict(
type='LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
loss_weight=0.5),
topk=(1, 5),
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint_path, prefix='head.')))

model = dict(
_scope_='mmrazor',
_delete_=True,
type='SingleTeacherDistill',
architecture=student,
teacher=teacher,
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.layers.head_dist')),
teacher_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
distill_losses=dict(
loss_distill=dict(
type='CrossEntropyLoss',
loss_weight=0.5,
)),
loss_forward_mappings=dict(
loss_distill=dict(
preds_S=dict(from_student=True, recorder='fc'),
preds_T=dict(from_student=False, recorder='fc')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
34 changes: 34 additions & 0 deletions configs/distill/mmcls/deit/metafile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
Collections:
- Name: DEIT
Metadata:
Training Data:
- ImageNet-1k
Paper:
URL: https://arxiv.org/abs/2012.12877
Title: Training data-efficient image transformers & distillation through attention
README: configs/distill/mmcls/deit/README.md

Models:
- Name: deit-base_regnety160_pt-16xb64_in1k
In Collection: DEIT
Metadata:
Student:
Config: mmcls::deit/deit-base_pt-16xb64_in1k.py
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth
Metrics:
Top 1 Accuracy: 81.76
Top 5 Accuracy: 95.81
Teacher:
Config: mmrazor::distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
Weights: https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth
Metrics:
Top 1 Accuracy: 82.83
Top 5 Accuracy: 96.42
Results:
- Task: Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.24
Top 5 Accuracy: 96.33
Weights: https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz
Config: configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
3 changes: 2 additions & 1 deletion mmrazor/models/architectures/heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .darts_subnet_head import DartsSubnetClsHead
from .deit_head import DeiTClsHead

__all__ = ['DartsSubnetClsHead']
__all__ = ['DartsSubnetClsHead', 'DeiTClsHead']
69 changes: 69 additions & 0 deletions mmrazor/models/architectures/heads/deit_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
import torch.nn as nn

from mmrazor.registry import MODELS

try:
from mmcls.models import VisionTransformerClsHead
except ImportError:
from mmrazor.utils import get_placeholder
VisionTransformerClsHead = get_placeholder('mmcls')


@MODELS.register_module()
class DeiTClsHead(VisionTransformerClsHead):
"""Distilled Vision Transformer classifier head.
Comparing with the :class:`DeiTClsHead` in mmcls, this head support to
train the distilled version DeiT.
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
hidden_dim (int, optional): Number of the dimensions for hidden layer.
Defaults to None, which means no extra hidden layer.
act_cfg (dict): The activation config. Only available during
pre-training. Defaults to ``dict(type='Tanh')``.
init_cfg (dict): The extra initialization configs. Defaults to
``dict(type='Constant', layer='Linear', val=0)``.
"""

def _init_layers(self):
""""Init extra hidden linear layer to handle dist token if exists."""
super(DeiTClsHead, self)._init_layers()
if self.hidden_dim is None:
head_dist = nn.Linear(self.in_channels, self.num_classes)
else:
head_dist = nn.Linear(self.hidden_dim, self.num_classes)
self.layers.add_module('head_dist', head_dist)

def pre_logits(
self, feats: Tuple[List[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""The process before the final classification head.
The input ``feats`` is a tuple of list of tensor, and each tensor is
the feature of a backbone stage. In ``DeiTClsHead``, we obtain the
feature of the last stage and forward in hidden layer if exists.
"""
_, cls_token, dist_token = feats[-1]
if self.hidden_dim is None:
return cls_token, dist_token
else:
cls_token = self.layers.act(self.layers.pre_logits(cls_token))
dist_token = self.layers.act(self.layers.pre_logits(dist_token))
return cls_token, dist_token

def forward(self, feats: Tuple[List[torch.Tensor]]) -> torch.Tensor:
"""The forward process."""
cls_token, dist_token = self.pre_logits(feats)
# The final classification head.
cls_score = self.layers.head(cls_token)
# Forward so that the corresponding recorder can record the output
# of the distillation token
_ = self.layers.head_dist(dist_token)
return cls_score
3 changes: 2 additions & 1 deletion mmrazor/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .ab_loss import ABLoss
from .at_loss import ATLoss
from .crd_loss import CRDLoss
from .cross_entropy_loss import CrossEntropyLoss
from .cwd import ChannelWiseDivergence
from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss
from .decoupled_kd import DKDLoss
Expand All @@ -19,5 +20,5 @@
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss',
'L1Loss', 'FBKDLoss', 'CRDLoss'
'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss'
]
23 changes: 23 additions & 0 deletions mmrazor/models/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from mmrazor.registry import MODELS


@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
"""Cross entropy loss.
Args:
loss_weight (float): Weight of the loss. Defaults to 1.0.
"""

def __init__(self, loss_weight=1.0):
super(CrossEntropyLoss, self).__init__()
self.loss_weight = loss_weight

def forward(self, preds_S, preds_T):
preds_T = preds_T.detach()
loss = F.cross_entropy(preds_S, preds_T.argmax(dim=1))
return loss * self.loss_weight
1 change: 1 addition & 0 deletions model-index.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ Import:
- configs/nas/mmcls/autoslim/metafile.yml
- configs/nas/mmcls/darts/metafile.yml
- configs/nas/mmdet/detnas/metafile.yml
- configs/distill/mmcls/deit/metafile.yml

0 comments on commit 8c7cdb3

Please sign in to comment.