Commit 7dfd991: mae finetune

zzc98 committed Jun 8, 2023 (1 parent: 057d7c6)

Showing 13 changed files with 939 additions and 0 deletions.
88 changes: 88 additions & 0 deletions projects/mae_classification/README.md
@@ -0,0 +1,88 @@
# Fine-tuning MAE on Some Image Classification Datasets

## Usage

### Setup Environment

Please refer to the [Get Started](https://mmpretrain.readthedocs.io/en/latest/get_started.html) documentation of MMPretrain to complete the installation.

### Data Preparation

Please download the datasets and unzip them into the `data` folder, as sketched below.
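
The configs in this project expect the data under paths like `data/cifar10/`, `data/cifar100/`, and `data/dtd/`. As a minimal sketch for DTD (the download URL is the dataset's official release and is an assumption here; other datasets follow the same pattern):

```shell
# fetch and unpack DTD into the layout the configs expect
mkdir -p data
wget https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz -P data
tar -xzf data/dtd-r1.0.1.tar.gz -C data  # creates data/dtd/
```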

### Fine-tuning Commands

First, add the current folder to `PYTHONPATH` so that Python can find your model files. From the `projects/mae_classification/` root directory, run the command below:

```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```

Then run the following commands to train the model:

#### On Local Single GPU

```bash
# train with mim
mim train mmpretrain ${CONFIG} --work-dir ${WORK_DIR}

# a specific command example
mim train mmpretrain configs/vit-base-p16_8xb8-coslr-100e_caltech101.py --work-dir work_dirs/vit-base-p16_8xb8-coslr-100e_caltech101
```

#### On Multiple GPUs

```bash
# train with mim
# a specific command example with 8 GPUs
mim train mmpretrain configs/vit-base-p16_8xb8-coslr-100e_caltech101.py --work-dir work_dirs/vit-base-p16_8xb8-coslr-100e_caltech101 --launcher pytorch --gpus 8
```

Note:

- CONFIG: a config file under the `configs/` directory
- WORK_DIR: the working directory for saving configs, logs, and checkpoints

#### On Multiple GPUs with Slurm

```bash
# train with mim
mim train mmpretrain ${CONFIG} \
--work-dir ${WORK_DIR} \
--launcher slurm --gpus 16 --gpus-per-node 8 \
--partition ${PARTITION}
```

Note:

- CONFIG: a config file under the `configs/` directory
- WORK_DIR: the working directory for saving configs, logs, and checkpoints
- PARTITION: the Slurm partition you are using
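
After fine-tuning, a checkpoint can be evaluated with `mim test` in the same way (a sketch; `CHECKPOINT` is a hypothetical path to a checkpoint saved under your work directory):

```shell
# test with mim
mim test mmpretrain ${CONFIG} --checkpoint ${CHECKPOINT}
```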

## Results

| Dataset | Backbone | Params (M) | Flops (G) | Accuracy (%) | Config |
| :-----------------: | :------: | :--------: | :-------: | :----------: | :-------------------------------------------------------------: |
| Food-101 | MAE-base | 85.88 | 17.58 | 91.57 | [config](configs/vit-base-p16_8xb32-coslr-100e_food101.py) |
| CIFAR-10 | MAE-base | 85.81 | 17.58 | 98.45 | [config](configs/vit-base-p16_8xb32-coslr-100e_cifar10.py) |
| CIFAR-100 | MAE-base | 85.88 | 17.58 | 90.06 | [config](configs/vit-base-p16_8xb16-coslr-100e_cifar100.py) |
| SUN397 | MAE-base | 86.10 | 17.58 | 67.84 | [config](configs/vit-base-p16_8xb32-coslr-100e_sun397.py) |
| Stanford Cars | MAE-base | 85.95 | 17.58 | 93.11 | [config](configs/vit-base-p16_8xb8-coslr-100e_stanfordcars.py) |
| FGVC Aircraft | MAE-base | 85.88 | 17.58 | 88.24 | [config](configs/vit-base-p16_8xb8-coslr-100e_fgvcaircraft.py) |
| DTD | MAE-base | 85.83 | 17.58 | 77.55 | [config](configs/vit-base-p16_8xb16-coslr-100e_dtd.py) |
| Oxford-IIIT Pets | MAE-base | 85.83 | 17.58 | 91.66 | [config](configs/vit-base-p16_8xb8-coslr-100e_oxfordiiitpet.py) |
| Caltech-101 | MAE-base | 85.88 | 17.58 | 93.22 | [config](configs/vit-base-p16_8xb8-coslr-100e_caltech101.py) |
| Oxford 102 Flowers | MAE-base | 85.88 | 17.58 | 95.20 | [config](configs/vit-base-p16_8xb8-coslr-100e_flowers102.py) |
| PASCAL VOC 2007 cls | MAE-base | 85.81 | 17.58 | 88.69 (mAP) | [config](configs/vit-base-p16_8xb8-coslr-100e_voc.py) |

## Citation

```bibtex
@article{He2021MaskedAA,
  title={Masked Autoencoders Are Scalable Vision Learners},
  author={Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Doll{\'a}r and Ross B. Girshick},
  journal={arXiv preprint arXiv:2111.06377},
  year={2021}
}
```
75 changes: 75 additions & 0 deletions projects/mae_classification/configs/_base_.py
@@ -0,0 +1,75 @@
_base_ = 'mmpretrain::_base_/default_runtime.py'

# dataset settings
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)

# model settings
pretrained = 'https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k_20220825-f7569ca2.pth' # noqa

model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='base',
img_size=224,
patch_size=16,
drop_path_rate=0.1,
out_type='avg_featmap',
final_norm=False,
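        # initialize the backbone from the MAE pre-training checkpoint; the
        # 'backbone' prefix selects only the backbone weights from the state dict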
init_cfg=dict(
type='Pretrained', checkpoint=pretrained, prefix='backbone')),
neck=None,
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, loss_weight=1.0),
init_cfg=[dict(type='Constant', layer='Linear', val=0.0)]),
train_cfg=dict(augments=[
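        # one batch augmentation is sampled per batch: either Mixup or CutMix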
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]))

# schedule settings
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=2.5e-05,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=0.05),
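    # exclude normalization layers, biases, and embedding/token parameters
    # from weight decay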
paramwise_cfg=dict(
norm_decay_mult=0.0,
bias_decay_mult=0.0,
flat_decay_mult=0.0,
custom_keys=dict({
'.absolute_pos_embed': dict(decay_mult=0.0),
'.relative_position_bias_table': dict(decay_mult=0.0),
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0)
})))

param_scheduler = [
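    # 5-epoch linear warm-up, then cosine decay over the remaining 95 epochs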
dict(type='LinearLR', start_factor=0.1, by_epoch=True, begin=0, end=5),
dict(type='CosineAnnealingLR', T_max=95, by_epoch=True, begin=5, end=100)
]

train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
val_cfg = dict()
test_cfg = dict()
auto_scale_lr = dict(base_batch_size=64)
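# base_batch_size is 8 GPUs x 8 samples per GPU; when auto scaling is enabled,
# the learning rate is scaled linearly with the actual total batch size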

# runtime settings
default_hooks = dict(
    # save a checkpoint every epoch and keep only the most recent one
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1))
73 changes: 73 additions & 0 deletions projects/mae_classification/configs/vit-base-p16_8xb16-coslr-100e_cifar100.py
@@ -0,0 +1,73 @@
_base_ = './_base_.py'

# dataset settings
dataset_type = 'CIFAR100'
num_classes = 100
data_preprocessor = dict(
num_classes=num_classes,
# RGB format normalization parameters
mean=[129.304, 124.070, 112.434],
std=[68.170, 65.392, 70.418],
# loaded images are already RGB format
to_rgb=False)

train_pipeline = [
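    # CIFAR-100 images are 32x32; upsample them to the 224x224 ViT input size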
dict(type='Resize', scale=224),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=7,
magnitude_std=0.5,
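        # pad_val below is the rounded ImageNet channel mean in BGR order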
hparams=dict(pad_val=[104, 116, 124])),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=0.333333333,
fill_color=[103.53, 116.28, 123.675],
fill_std=[57.375, 57.12, 58.395]),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs')
]

test_pipeline = [
dict(type='Resize', scale=224),
dict(type='PackInputs'),
]

train_dataloader = dict(
batch_size=16,
num_workers=2,
dataset=dict(
type=dataset_type,
data_prefix='data/cifar100/',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
batch_size=16,
num_workers=2,
dataset=dict(
type=dataset_type,
data_prefix='data/cifar100/',
split='test',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, ))

test_dataloader = val_dataloader
test_evaluator = val_evaluator

# model settings
model = dict(head=dict(num_classes=num_classes))

# schedule settings
optim_wrapper = dict(optimizer=dict(lr=5e-05, weight_decay=0.01))

auto_scale_lr = dict(base_batch_size=128)
62 changes: 62 additions & 0 deletions projects/mae_classification/configs/vit-base-p16_8xb16-coslr-100e_dtd.py
@@ -0,0 +1,62 @@
_base_ = './_base_.py'

# dataset settings
dataset_type = 'DTD'
num_classes = 47
data_preprocessor = dict(num_classes=num_classes)

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=224),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=7,
magnitude_std=0.5,
hparams=dict(pad_val=[104, 116, 124])),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs')
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=224),
dict(type='PackInputs')
]

train_dataloader = dict(
batch_size=16,
num_workers=2,
dataset=dict(
type=dataset_type,
data_root='data/dtd',
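        # train on the combined train+val split; evaluate on the test split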
split='trainval',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
batch_size=16,
num_workers=2,
dataset=dict(
type=dataset_type,
data_root='data/dtd',
split='test',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)

val_evaluator = dict(type='Accuracy', topk=(1, ))

test_dataloader = val_dataloader
test_evaluator = val_evaluator

# model settings
model = dict(head=dict(num_classes=num_classes))

# optimizer wrapper
optim_wrapper = dict(optimizer=dict(lr=0.0001, weight_decay=0.005))

auto_scale_lr = dict(base_batch_size=128)
73 changes: 73 additions & 0 deletions projects/mae_classification/configs/vit-base-p16_8xb32-coslr-100e_cifar10.py
@@ -0,0 +1,73 @@
_base_ = './_base_.py'

# dataset settings
dataset_type = 'CIFAR10'
num_classes = 10
data_preprocessor = dict(
num_classes=num_classes,
# RGB format normalization parameters
mean=[125.307, 122.961, 113.8575],
std=[51.5865, 50.847, 51.255],
# loaded images are already RGB format
to_rgb=False)

train_pipeline = [
dict(type='Resize', scale=224),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=7,
magnitude_std=0.5,
hparams=dict(pad_val=[104, 116, 124])),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs')
]

test_pipeline = [
dict(type='Resize', scale=224),
dict(type='PackInputs'),
]

train_dataloader = dict(
batch_size=32,
num_workers=2,
dataset=dict(
type=dataset_type,
        data_prefix='data/cifar10/',
split='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
batch_size=32,
num_workers=2,
dataset=dict(
type=dataset_type,
data_prefix='data/cifar10/',
split='test',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, ))

test_dataloader = val_dataloader
test_evaluator = val_evaluator

# model settings
model = dict(head=dict(num_classes=num_classes))

param_scheduler = [
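    # override the base schedule: the warm-up updates per iteration
    # (convert_to_iter_based=True) with a smaller start factor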
dict(
type='LinearLR',
start_factor=0.001,
by_epoch=True,
begin=0,
end=5,
convert_to_iter_based=True),
dict(type='CosineAnnealingLR', T_max=95, by_epoch=True, begin=5, end=100)
]

auto_scale_lr = dict(base_batch_size=256)