Showing 13 changed files with 939 additions and 0 deletions.
@@ -0,0 +1,88 @@
# Fine-tuning MAE on Some Image Classification Datasets

## Usage

### Setup Environment

Please refer to the [Get Started](https://mmpretrain.readthedocs.io/en/latest/get_started.html) documentation of MMPretrain to complete the installation.

### Data Preparation

Please download the datasets and unzip them into the `data` folder.
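
As a quick sanity check before training, the sketch below reports which dataset folders are still missing. The paths `data/cifar10`, `data/cifar100`, and `data/dtd` are taken from the config files in this commit; the helper name `find_missing_datasets` is hypothetical, and the other datasets would need their own entries.

```python
from pathlib import Path

# Dataset folders referenced by the configs shown in this commit.
# Assumption: one folder per dataset; extend the tuple for other datasets.
EXPECTED_DIRS = ("data/cifar10", "data/cifar100", "data/dtd")


def find_missing_datasets(root=".", expected=EXPECTED_DIRS):
    """Return the expected dataset folders that do not exist under root."""
    root = Path(root)
    return [d for d in expected if not (root / d).is_dir()]


if __name__ == "__main__":
    missing = find_missing_datasets()
    if missing:
        print("Missing dataset folders:", ", ".join(missing))
    else:
        print("All expected dataset folders are present.")
```

Run it from the `projects/mae_classification/` root directory so that the relative `data/` paths resolve the same way they do in the configs.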

### Fine-tuning Commands

First, add the current folder to `PYTHONPATH` so that Python can find your model files. From the `projects/mae_classification/` root directory, run the command below:

```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```

Then run the following commands to train the model:

#### On Local Single GPU

```bash
# train with mim
mim train mmpretrain ${CONFIG} --work-dir ${WORK_DIR}

# a specific command example
mim train mmpretrain configs/vit-base-p16_8xb8-coslr-100e_caltech101.py --work-dir work_dirs/vit-base-p16_8xb8-coslr-100e_caltech101
```

#### On Multiple GPUs

```bash
# train with mim
# a specific command example, using 8 GPUs
mim train mmpretrain configs/vit-base-p16_8xb8-coslr-100e_caltech101.py --work-dir work_dirs/vit-base-p16_8xb8-coslr-100e_caltech101 --launcher pytorch --gpus 8
```

Note:

- `CONFIG`: a config file under the `configs/` directory
- `WORK_DIR`: the working directory in which configs, logs, and checkpoints are saved

#### On Multiple GPUs with Slurm

```bash
# train with mim
mim train mmpretrain ${CONFIG} \
    --work-dir ${WORK_DIR} \
    --launcher slurm --gpus 16 --gpus-per-node 8 \
    --partition ${PARTITION}
```

Note:

- `CONFIG`: a config file under the `configs/` directory
- `WORK_DIR`: the working directory in which configs, logs, and checkpoints are saved
- `PARTITION`: the Slurm partition you are using

## Results

| Dataset | Backbone | Params | FLOPs | Accuracy (%) | Config |
| :-----------------: | :------: | :------: | :---------: | :----------: | :-------------------------------------------------------------: |
| Food-101 | MAE-base | 85876325 | 17581219584 | 91.57 | [config](configs/vit-base-p16_8xb32-coslr-100e_food101.py) |
| CIFAR-10 | MAE-base | 85806346 | 17581219584 | 98.45 | [config](configs/vit-base-p16_8xb32-coslr-100e_cifar10.py) |
| CIFAR-100 | MAE-base | 85875556 | 17581219584 | 90.06 | [config](configs/vit-base-p16_8xb16-coslr-100e_cifar100.py) |
| SUN397 | MAE-base | 86103949 | 17581219584 | 67.84 | [config](configs/vit-base-p16_8xb32-coslr-100e_sun397.py) |
| Stanford Cars | MAE-base | 85949380 | 17581219584 | 93.11 | [config](configs/vit-base-p16_8xb8-coslr-100e_stanfordcars.py) |
| FGVC Aircraft | MAE-base | 85875556 | 17581219584 | 88.24 | [config](configs/vit-base-p16_8xb8-coslr-100e_fgvcaircraft.py) |
| DTD | MAE-base | 85834799 | 17581219584 | 77.55 | [config](configs/vit-base-p16_8xb16-coslr-100e_dtd.py) |
| Oxford-IIIT Pets | MAE-base | 85827109 | 17581219584 | 91.66 | [config](configs/vit-base-p16_8xb8-coslr-100e_oxfordiiitpet.py) |
| Caltech-101 | MAE-base | 85877094 | 17581219584 | 93.22 | [config](configs/vit-base-p16_8xb8-coslr-100e_caltech101.py) |
| Oxford 102 Flowers | MAE-base | 85877094 | 17581219584 | 95.20 | [config](configs/vit-base-p16_8xb8-coslr-100e_flowers102.py) |
| PASCAL VOC 2007 cls | MAE-base | 85814036 | 17581219584 | 88.69 (mAP) | [config](configs/vit-base-p16_8xb8-coslr-100e_voc.py) |
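
The config names in the table encode the batch setup: `8xb16` means 8 GPUs with 16 samples per GPU, which matches `batch_size=16` and `auto_scale_lr = dict(base_batch_size=128)` in the CIFAR-100 config. The sketch below parses that field and applies a linear learning-rate scaling rule. It assumes this is the rule used when automatic LR scaling is enabled; the helper names `parse_batch_setting` and `scaled_lr` are hypothetical and not part of MMPretrain.

```python
import re


def parse_batch_setting(config_name):
    """Extract (num_gpus, samples_per_gpu) from a name like '..._8xb16-...'."""
    match = re.search(r"_(\d+)xb(\d+)-", config_name)
    if match is None:
        raise ValueError(f"no '<gpus>xb<batch>' field in {config_name!r}")
    return int(match.group(1)), int(match.group(2))


def scaled_lr(base_lr, base_batch_size, num_gpus, samples_per_gpu):
    """Linear scaling rule: the lr changes in proportion to the total batch size."""
    return base_lr * (num_gpus * samples_per_gpu) / base_batch_size


if __name__ == "__main__":
    name = "vit-base-p16_8xb16-coslr-100e_cifar100.py"
    gpus, per_gpu = parse_batch_setting(name)
    # Total batch size the config was tuned for.
    print(gpus * per_gpu)
    # Hypothetical scenario: the same config run on a single GPU.
    print(scaled_lr(5e-05, 128, num_gpus=1, samples_per_gpu=per_gpu))
```

When a config is run with the same GPU count and per-GPU batch size as its name states, the scaling factor is 1 and the learning rate from the config is used unchanged.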

## Citation

```bibtex
@article{He2021MaskedAA,
  title={Masked Autoencoders Are Scalable Vision Learners},
  author={Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and
  Piotr Doll{\'a}r and Ross B. Girshick},
  journal={arXiv},
  year={2021}
}
```
@@ -0,0 +1,75 @@
_base_ = 'mmpretrain::_base_/default_runtime.py'

# dataset settings
data_preprocessor = dict(
    num_classes=1000,
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

# model settings
pretrained = 'https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k_20220825-f7569ca2.pth'  # noqa

model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VisionTransformer',
        arch='base',
        img_size=224,
        patch_size=16,
        drop_path_rate=0.1,
        out_type='avg_featmap',
        final_norm=False,
        init_cfg=dict(
            type='Pretrained', checkpoint=pretrained, prefix='backbone')),
    neck=None,
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, loss_weight=1.0),
        init_cfg=[dict(type='Constant', layer='Linear', val=0.0)]),
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))

# schedule settings
optim_wrapper = dict(
    optimizer=dict(
        type='AdamW',
        lr=2.5e-05,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0.0,
        bias_decay_mult=0.0,
        flat_decay_mult=0.0,
        custom_keys=dict({
            '.absolute_pos_embed': dict(decay_mult=0.0),
            '.relative_position_bias_table': dict(decay_mult=0.0),
            '.ln': dict(decay_mult=0.0),
            '.bias': dict(decay_mult=0.0),
            '.cls_token': dict(decay_mult=0.0),
            '.pos_embed': dict(decay_mult=0.0)
        })))

param_scheduler = [
    dict(type='LinearLR', start_factor=0.1, by_epoch=True, begin=0, end=5),
    dict(type='CosineAnnealingLR', T_max=95, by_epoch=True, begin=5, end=100)
]

train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
val_cfg = dict()
test_cfg = dict()
auto_scale_lr = dict(base_batch_size=64)

# runtime settings
default_hooks = dict(
    # save checkpoint per epoch.
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1))
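
The `param_scheduler` in the base config above combines a 5-epoch linear warmup with 95 epochs of cosine annealing. The sketch below is an idealized, epoch-level approximation of that curve, not MMEngine's scheduler code: the exact per-step values may differ slightly, the function name `lr_at_epoch` is hypothetical, and `eta_min=0.0` is an assumption because the config does not set a minimum learning rate.

```python
import math


def lr_at_epoch(epoch, base_lr=2.5e-05, warmup_epochs=5, max_epochs=100,
                start_factor=0.1, eta_min=0.0):
    """Idealized learning rate at the start of a given epoch."""
    if epoch < warmup_epochs:
        # Linear warmup from start_factor * base_lr up to base_lr.
        factor = start_factor + (1.0 - start_factor) * epoch / warmup_epochs
        return base_lr * factor
    # Cosine annealing from base_lr down to eta_min over the remaining epochs.
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return eta_min + (base_lr - eta_min) * cosine


if __name__ == "__main__":
    for epoch in (0, 5, 50, 100):
        print(epoch, lr_at_epoch(epoch))
```

The defaults mirror the values in the base config (`lr=2.5e-05`, `start_factor=0.1`, `max_epochs=100`); dataset configs that override the learning rate or the warmup factor would pass their own values.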
73 changes: 73 additions & 0 deletions
projects/mae_classification/configs/vit-base-p16_8xb16-coslr-100e_cifar100.py
@@ -0,0 +1,73 @@
_base_ = './_base_.py'

# dataset settings
dataset_type = 'CIFAR100'
num_classes = 100
data_preprocessor = dict(
    num_classes=num_classes,
    # RGB format normalization parameters
    mean=[129.304, 124.070, 112.434],
    std=[68.170, 65.392, 70.418],
    # loaded images are already RGB format
    to_rgb=False)

train_pipeline = [
    dict(type='Resize', scale=224),
    dict(
        type='RandAugment',
        policies='timm_increasing',
        num_policies=2,
        total_level=10,
        magnitude_level=7,
        magnitude_std=0.5,
        hparams=dict(pad_val=[104, 116, 124])),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=0.333333333,
        fill_color=[103.53, 116.28, 123.675],
        fill_std=[57.375, 57.12, 58.395]),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs')
]

test_pipeline = [
    dict(type='Resize', scale=224),
    dict(type='PackInputs'),
]

train_dataloader = dict(
    batch_size=16,
    num_workers=2,
    dataset=dict(
        type=dataset_type,
        data_prefix='data/cifar100/',
        split='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
    batch_size=16,
    num_workers=2,
    dataset=dict(
        type=dataset_type,
        data_prefix='data/cifar100/',
        split='test',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, ))

test_dataloader = val_dataloader
test_evaluator = val_evaluator

# model settings
model = dict(head=dict(num_classes=num_classes))

# schedule settings
optim_wrapper = dict(optimizer=dict(lr=5e-05, weight_decay=0.01), )

auto_scale_lr = dict(base_batch_size=128)
62 changes: 62 additions & 0 deletions
projects/mae_classification/configs/vit-base-p16_8xb16-coslr-100e_dtd.py
@@ -0,0 +1,62 @@
_base_ = './_base_.py'

# dataset settings
dataset_type = 'DTD'
num_classes = 47
data_preprocessor = dict(num_classes=num_classes)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=224),
    dict(
        type='RandAugment',
        policies='timm_increasing',
        num_policies=2,
        total_level=10,
        magnitude_level=7,
        magnitude_std=0.5,
        hparams=dict(pad_val=[104, 116, 124])),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=224),
    dict(type='PackInputs')
]

train_dataloader = dict(
    batch_size=16,
    num_workers=2,
    dataset=dict(
        type=dataset_type,
        data_root='data/dtd',
        split='trainval',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
    batch_size=16,
    num_workers=2,
    dataset=dict(
        type=dataset_type,
        data_root='data/dtd',
        split='test',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)

val_evaluator = dict(type='Accuracy', topk=(1, ))

test_dataloader = val_dataloader
test_evaluator = val_evaluator

# model settings
model = dict(head=dict(num_classes=num_classes))

# optimizer wrapper
optim_wrapper = dict(optimizer=dict(lr=0.0001, weight_decay=0.005), )

auto_scale_lr = dict(base_batch_size=128)
73 changes: 73 additions & 0 deletions
projects/mae_classification/configs/vit-base-p16_8xb32-coslr-100e_cifar10.py
@@ -0,0 +1,73 @@
_base_ = './_base_.py'

# dataset settings
dataset_type = 'CIFAR10'
num_classes = 10
data_preprocessor = dict(
    num_classes=num_classes,
    # RGB format normalization parameters
    mean=[125.307, 122.961, 113.8575],
    std=[51.5865, 50.847, 51.255],
    # loaded images are already RGB format
    to_rgb=False)

train_pipeline = [
    dict(type='Resize', scale=224),
    dict(
        type='RandAugment',
        policies='timm_increasing',
        num_policies=2,
        total_level=10,
        magnitude_level=7,
        magnitude_std=0.5,
        hparams=dict(pad_val=[104, 116, 124])),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs')
]

test_pipeline = [
    dict(type='Resize', scale=224),
    dict(type='PackInputs'),
]

train_dataloader = dict(
    batch_size=32,
    num_workers=2,
    dataset=dict(
        type=dataset_type,
        data_prefix='data/cifar10',
        split='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
    batch_size=32,
    num_workers=2,
    dataset=dict(
        type=dataset_type,
        data_prefix='data/cifar10/',
        split='test',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, ))

test_dataloader = val_dataloader
test_evaluator = val_evaluator

# model settings
model = dict(head=dict(num_classes=num_classes))

param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=0.001,
        by_epoch=True,
        begin=0,
        end=5,
        convert_to_iter_based=True),
    dict(type='CosineAnnealingLR', T_max=95, by_epoch=True, begin=5, end=100)
]

auto_scale_lr = dict(base_batch_size=256)