
[Enhance] Support dynamic input shape for ViT-based algorithms #706

Merged
merged 13 commits into from
Mar 3, 2022

Conversation

@mzr1996 (Member) commented Feb 22, 2022

Motivation

Currently, our ViT-based algorithms require the input image shape to be specified in advance, but some downstream tasks need to feed them images of various shapes.

Modification

Refactor these ViT-based algorithms to support dynamic input shapes wherever possible.

  • ✔️ VisionTransformer
  • ✔️ SwinTransformer
  • ✔️ DeiT
  • ❌ MLP-Mixer (The structure doesn't support dynamic input shape)
  • ✔️ T2T-ViT
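
The commit list below mentions a `resize_pos_embed` helper moved to `mmcls.models.utils`. The following is an illustrative sketch (not the actual mmcls implementation; the real signature may differ) of the usual technique: interpolating the learned patch position embeddings to the new patch-grid size at forward time, while keeping extra tokens such as the class token untouched.

```python
import torch
import torch.nn.functional as F


def resize_pos_embed(pos_embed, src_shape, dst_shape, num_extra_tokens=1):
    """Interpolate position embeddings from src_shape to dst_shape.

    pos_embed: (1, num_extra_tokens + H*W, C)
    """
    if src_shape == dst_shape:
        return pos_embed
    extra = pos_embed[:, :num_extra_tokens]       # e.g. the cls token embedding
    patch_pe = pos_embed[:, num_extra_tokens:]
    C = patch_pe.shape[-1]
    # (1, H*W, C) -> (1, C, H, W) so we can use 2-D interpolation
    patch_pe = patch_pe.reshape(1, *src_shape, C).permute(0, 3, 1, 2)
    patch_pe = F.interpolate(patch_pe, size=dst_shape, mode='bicubic',
                             align_corners=False)
    # back to (1, H'*W', C) and re-attach the extra tokens
    patch_pe = patch_pe.permute(0, 2, 3, 1).reshape(1, -1, C)
    return torch.cat([extra, patch_pe], dim=1)


# 224x224 input with 16x16 patches gives a 14x14 grid; resize to 16x21
pe = torch.randn(1, 1 + 14 * 14, 768)
new_pe = resize_pos_embed(pe, (14, 14), (16, 21))
print(new_pe.shape)  # torch.Size([1, 337, 768])
```

With this in place, the fixed-size positional table no longer constrains the input resolution, which is what lets the backbones above accept arbitrary shapes.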

BC-breaking (Optional)

The `input_resolution` and `auto_pad` arguments are deprecated in `ShiftWindowMSA`, and the `auto_pad` argument is removed from `SwinTransformer`.
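
Since `auto_pad` is removed, padding presumably happens automatically inside the attention module whenever the feature map is not a multiple of the window size. A minimal sketch of that kind of padding (hypothetical helper name, not the mmcls API):

```python
import torch
import torch.nn.functional as F


def pad_to_window_multiple(x, window_size):
    """Pad an NHWC feature map so H and W are divisible by window_size."""
    B, H, W, C = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    # F.pad pads trailing dims first: (C_left, C_right, W_left, W_right, H_left, H_right)
    return F.pad(x, (0, 0, 0, pad_w, 0, pad_h))


x = torch.randn(2, 16, 21, 96)           # arbitrary H, W from a dynamic input
padded = pad_to_window_multiple(x, 7)
print(padded.shape)  # torch.Size([2, 21, 21, 96])
```

Making this automatic is what allows the `auto_pad` switch to disappear from the public interface without restricting input shapes.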

Use cases (Optional)

>>> from mmcls.models import VisionTransformer
>>> import torch
>>> model = VisionTransformer(arch="base")
>>> inputs = torch.rand(1, 3, 256, 323)  # The input shape is not limited.
>>> patch_embed, cls_token = model(inputs)[0]
>>> print(patch_embed.shape)
torch.Size([1, 768, 16, 21])
>>> print(cls_token.shape)
torch.Size([1, 768])

Checklist

Before PR:

  • Pre-commit or other linting tools are used to fix the potential lint issues.
  • Bug fixes are fully covered by unit tests; the case that caused the bug should be added to the unit tests.
  • The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  • The documentation has been modified accordingly, like docstring or example tutorials.

After PR:

  • If the modification has potential influence on downstream or other related projects, this PR should be tested with those projects, like MMDet or MMSeg.
  • CLA has been signed and all committers have signed the CLA in this PR.

@codecov codecov bot commented Feb 25, 2022

Codecov Report

Merging #706 (d0d81e6) into dev (28e3bc8) will decrease coverage by 0.27%.
The diff coverage is 96.63%.

Impacted file tree graph

@@            Coverage Diff             @@
##              dev     #706      +/-   ##
==========================================
- Coverage   85.22%   84.94%   -0.28%     
==========================================
  Files         121      121              
  Lines        7432     7548     +116     
  Branches     1278     1303      +25     
==========================================
+ Hits         6334     6412      +78     
- Misses        911      944      +33     
- Partials      187      192       +5     
Flag        Coverage Δ
unittests   84.87% <96.63%> (-0.31%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
mmcls/models/backbones/swin_transformer.py 91.97% <87.80%> (-0.03%) ⬇️
mmcls/models/backbones/t2t_vit.py 95.12% <96.15%> (+0.59%) ⬆️
mmcls/models/backbones/deit.py 97.82% <100.00%> (+0.32%) ⬆️
mmcls/models/backbones/mlp_mixer.py 96.77% <100.00%> (+1.16%) ⬆️
mmcls/models/backbones/vision_transformer.py 96.37% <100.00%> (-0.75%) ⬇️
mmcls/models/utils/__init__.py 100.00% <100.00%> (ø)
mmcls/models/utils/attention.py 100.00% <100.00%> (ø)
mmcls/models/utils/embed.py 80.00% <100.00%> (-0.59%) ⬇️
mmcls/datasets/dataset_wrappers.py 70.58% <0.00%> (-21.17%) ⬇️
... and 8 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 28e3bc8...d0d81e6.

@Ezra-Yu Ezra-Yu changed the base branch from master to dev February 25, 2022 07:42
@Ezra-Yu Ezra-Yu self-requested a review March 1, 2022 11:08
@Ezra-Yu (Collaborator) left a comment

Please add a DeprecationWarning in PatchEmbed and PatchMerge.

mmcls/models/utils/embed.py (resolved)
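
The requested warning could follow the standard Python pattern sketched below; `input_size` here is a hypothetical argument name used only for illustration, not necessarily the one deprecated in mmcls.

```python
import warnings


def warn_deprecated_arg(arg_name, cls_name):
    # Standard pattern for flagging a deprecated constructor argument.
    warnings.warn(
        f'The `{arg_name}` argument of `{cls_name}` is deprecated and will '
        'have no effect; the input size is now inferred dynamically.',
        DeprecationWarning)


warn_deprecated_arg('input_size', 'PatchEmbed')
```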
@mzr1996 mzr1996 requested a review from Ezra-Yu March 2, 2022 07:24
@Ezra-Yu (Collaborator) left a comment

LGTM.

img_size=224,
patch_size=16,
in_channels=3,
A Contributor commented:
This introduces BC-breaking in mmselfsup 0.6.0.

@mzr1996 mzr1996 deleted the dyn-vit branch July 20, 2022 06:58
mzr1996 added a commit to mzr1996/mmpretrain that referenced this pull request Nov 24, 2022
…-mmlab#706)

* Move `resize_pos_embed` to `mmcls.models.utils`

* Refactor Vision Transformer

* Refactor DeiT

* Refactor MLP-Mixer

* Refactor Swin-Transformer

* Remove `indexing` arg

* Support dynamic inputs for t2t_vit

* Add copyright

* Fix bugs in swin transformer

* Add `pad_small_maps` option

* Update swin transformer

* Handle `attn_mask` in checkpoints of swin

* Improve by comments
@ShangWeiKuo commented Mar 7, 2024

Excuse me @mzr1996, could you please explain how you made the ViT-based algorithms support dynamic input shapes?

4 participants