vit_update: initialize method and do pre-scale before softmax; update configuration yaml and ckpt files (#765)
sageyou committed Mar 15, 2024
1 parent 20b366f commit 6cfc6e6
Showing 5 changed files with 89 additions and 91 deletions.
10 changes: 5 additions & 5 deletions configs/vit/README.md
@@ -34,11 +34,11 @@ Our reproduced model performance on ImageNet-1K is reported as follows.

<div align="center">

| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download |
|--------------|----------|-----------|-----------|------------|-----------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------|
| vit_b_32_224 | D910x8-G | 75.86 | 92.08 | 87.46 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_b32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-f50866e8.ckpt) |
| vit_l_16_224 | D910x8-G | 76.34 | 92.79 | 303.31 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l16_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-97d0fdbc.ckpt) |
| vit_l_32_224 | D910x8-G | 73.71 | 90.92 | 305.52 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-b80441df.ckpt) |
| Model | Context | Top-1 (%) | Top-5 (%) | Params (M) | Recipe | Download |
|--------------|----------|--|-----------|------------|-----------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
| vit_b_32_224 | D910x8-G | 77.45 | 93.27 | 87.46 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_b32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-4a1c9d8e.ckpt) |
| vit_l_16_224 | D910x8-G | 81.25 | 95.53 | 303.31 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l16_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-d2635f8b.ckpt) |
| vit_l_32_224 | D910x8-G | 74.57 | 91.01 | 305.52 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/configs/vit/vit_l32_224_ascend.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-8c8ea164.ckpt) |

</div>
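
The refreshed checkpoints in the new table rows can be loaded through mindcv's model factory once this commit's registry entries are in place. A minimal sketch, assuming the standard `mindcv.create_model` entry point (model name taken from the table above; `pretrained=True` downloads the checkpoint URL registered in `mindcv/models/vit.py`):

```python
import mindcv

# build vit_b_32_224 and pull the updated pretrained checkpoint registered for it
model = mindcv.create_model("vit_b_32_224", pretrained=True, num_classes=1000)
model.set_train(False)  # switch to inference mode
```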

45 changes: 23 additions & 22 deletions configs/vit/vit_b32_224_ascend.yaml
@@ -9,52 +9,53 @@ val_interval: 1
dataset: "imagenet"
data_dir: "/path/to/imagenet"
shuffle: True
dataset_download: False
batch_size: 256
drop_remainder: True
batch_size: 512

# augmentation
image_resize: 224
scale: [0.08, 1.0]
ratio: [0.75, 1.333]
hflip: 0.5
vflip: 0.0
interpolation: "bicubic"
re_prob: 0.1
mixup: 0.2
auto_augment: "randaug-m9-mstd0.5"
re_prob: 0.0
cutmix: 1.0
cutmix_prob: 1.0
crop_pct: 0.875
color_jitter: [0.4, 0.4, 0.4]
auto_augment: "randaug-m7-mstd0.5"
mixup: 0.8

# model
model: "vit_b_32_224"
num_classes: 1000
drop_rate: 0.1
drop_path_rate: 0.1
num_classes: 1000
pretrained: False
ckpt_path: ""
keep_checkpoint_max: 10
ckpt_save_policy: "top_k"
ckpt_save_dir: "./ckpt"
epoch_size: 600
ckpt_save_interval: 1
ckpt_save_policy: "top_k"
epoch_size: 300
dataset_sink_mode: True
amp_level: "O2"

# loss
loss: "CE"
loss_scale: 1024.0
label_smoothing: 0.1

# lr scheduler
scheduler: "warmup_cosine_decay"
lr: 0.003
min_lr: 1e-6
warmup_epochs: 32
decay_epochs: 568
lr_epoch_stair: False
scheduler: "cosine_decay"
lr: 1.6e-3
min_lr: 0.0
warmup_epochs: 30
decay_epochs: 270
warmup_factor: 0.01

# optimizer
opt: "adamw"
weight_decay: 0.025
weight_decay_filter: "norm_and_bias"
weight_decay: 0.3
use_nesterov: False

# amp
amp_level: "O2"
val_amp_level: 'O2'
loss_scale_type: 'fixed'
loss_scale: 1024
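
For reference, the new scheduler fields above (cosine_decay, lr 1.6e-3, min_lr 0.0, 30 warmup epochs, 270 decay epochs, warmup_factor 0.01) describe a linear warmup followed by cosine decay. A rough per-epoch sketch of that curve, assuming warmup starts at warmup_factor × lr (mindcv's scheduler actually steps per batch, so exact values will differ slightly):

```python
import math

lr, min_lr = 1.6e-3, 0.0
warmup_epochs, decay_epochs, warmup_factor = 30, 270, 0.01

def lr_at(epoch: int) -> float:
    if epoch < warmup_epochs:
        # linear warmup from warmup_factor * lr up to lr
        alpha = epoch / warmup_epochs
        return lr * (warmup_factor + (1.0 - warmup_factor) * alpha)
    # cosine decay from lr down to min_lr over decay_epochs
    t = min(epoch - warmup_epochs, decay_epochs) / decay_epochs
    return min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(math.pi * t))

for e in (0, 15, 30, 150, 299):
    print(e, f"{lr_at(e):.2e}")
```
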
47 changes: 25 additions & 22 deletions configs/vit/vit_l16_224_ascend.yaml
@@ -9,52 +9,55 @@ val_interval: 1
dataset: "imagenet"
data_dir: "/path/to/imagenet"
shuffle: True
dataset_download: False
batch_size: 48
drop_remainder: True
batch_size: 512

# augmentation
image_resize: 224
scale: [0.08, 1.0]
ratio: [0.75, 1.333]
hflip: 0.5
vflip: 0.0
interpolation: "bicubic"
re_prob: 0.15
mixup: 0.2
auto_augment: "randaug-m9-mstd0.5"
re_prob: 0.0
cutmix: 1.0
cutmix_prob: 1.0
crop_pct: 0.875
color_jitter: [0.4, 0.4, 0.4]
auto_augment: "randaug-m9-mstd0.5"
mixup: 0.8

# model
model: "vit_l_16_224"
drop_rate: 0.12
drop_path_rate: 0.1
num_classes: 1000
pretrained: False
ckpt_path: ""
drop_rate: 0.15
drop_path_rate: 0.2
keep_checkpoint_max: 10
ckpt_save_policy: "top_k"
ckpt_save_dir: "./ckpt"
ckpt_save_interval: 1
ckpt_save_policy: "top_k"
epoch_size: 300
dataset_sink_mode: True
amp_level: "O2"

# loss
loss: "CE"
loss_scale: 1024.0
label_smoothing: 0.1

# lr scheduler
scheduler: "warmup_cosine_decay"
lr: 0.0005
min_lr: 1e-5
warmup_epochs: 32
decay_epochs: 268
lr_epoch_stair: False
scheduler: "cosine_decay"
lr: 1.6e-3
min_lr: 0.0
warmup_epochs: 30
decay_epochs: 270
warmup_factor: 0.01

# optimizer
opt: "adamw"
weight_decay: 0.05
weight_decay_filter: "norm_and_bias"
weight_decay: 0.3
use_nesterov: False

# amp
amp_level: "O2"
val_amp_level: 'O2'
loss_scale_type: 'fixed'
loss_scale: 1024

ema: True
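
Unlike the other two recipes, the l16 config also turns on EMA of the model weights. Conceptually the trainer keeps a shadow copy that is blended toward the live weights every step; a generic sketch (the 0.9999 decay is a common default, not a value set in this yaml, and mindcv's actual EMA lives in its training wrapper):

```python
def ema_update(shadow: dict, weights: dict, decay: float = 0.9999) -> None:
    """Exponential moving average: shadow <- decay * shadow + (1 - decay) * weight."""
    for name, w in weights.items():
        shadow[name] = decay * shadow[name] + (1.0 - decay) * w
```
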
45 changes: 23 additions & 22 deletions configs/vit/vit_l32_224_ascend.yaml
@@ -9,52 +9,53 @@ val_interval: 1
dataset: "imagenet"
data_dir: "/path/to/imagenet"
shuffle: True
dataset_download: False
batch_size: 128
drop_remainder: True
batch_size: 512

# augmentation
image_resize: 224
scale: [0.08, 1.0]
ratio: [0.75, 1.333]
hflip: 0.5
vflip: 0.0
interpolation: "bicubic"
re_prob: 0.1
mixup: 0.2
auto_augment: "randaug-m9-mstd0.5"
re_prob: 0.0
cutmix: 1.0
cutmix_prob: 1.0
crop_pct: 0.875
color_jitter: [0.4, 0.4, 0.4]
auto_augment: "randaug-m7-mstd0.5"
mixup: 0.8

# model
model: "vit_l_32_224"
drop_rate: 0.1
drop_path_rate: 0.1
num_classes: 1000
pretrained: False
ckpt_path: ""
drop_rate: 0.1
drop_path_rate: 0.2
keep_checkpoint_max: 10
ckpt_save_policy: "top_k"
ckpt_save_dir: "./ckpt"
ckpt_save_interval: 1
ckpt_save_policy: "top_k"
epoch_size: 300
dataset_sink_mode: True
amp_level: "O2"

# loss
loss: "CE"
loss_scale: 1024.0
label_smoothing: 0.1

# lr scheduler
scheduler: "warmup_cosine_decay"
lr: 0.0015
min_lr: 1e-6
warmup_epochs: 32
decay_epochs: 268
lr_epoch_stair: False
scheduler: "cosine_decay"
lr: 1.6e-3
min_lr: 0.0
warmup_epochs: 30
decay_epochs: 270
warmup_factor: 0.01

# optimizer
opt: "adamw"
weight_decay: 0.025
weight_decay_filter: "norm_and_bias"
weight_decay: 0.3
use_nesterov: False

# amp
amp_level: "O2"
val_amp_level: 'O2'
loss_scale_type: 'fixed'
loss_scale: 1024
33 changes: 13 additions & 20 deletions mindcv/models/vit.py
@@ -1,11 +1,12 @@
"""ViT"""
import functools
from typing import Callable, Optional

import numpy as np

import mindspore as ms
from mindspore import Parameter, Tensor, nn, ops
from mindspore.common.initializer import HeUniform, TruncatedNormal, initializer
from mindspore.common.initializer import TruncatedNormal, XavierUniform, initializer

from .helpers import load_pretrained
from .layers.compatibility import Dropout
@@ -34,7 +35,7 @@ def _cfg(url="", **kwargs):
"num_classes": 1000,
"input_size": (3, 224, 224),
"first_conv": "patch_embed.proj",
"classifier": "head.classifier",
"classifier": "head",
**kwargs,
}

@@ -44,15 +45,15 @@ def _cfg(url="", **kwargs):
"vit_b_16_384": _cfg(
url="", input_size=(3, 384, 384)
),
"vit_l_16_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-97d0fdbc.ckpt"),
"vit_l_16_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_16_224-d2635f8b.ckpt"),
"vit_l_16_384": _cfg(
url="", input_size=(3, 384, 384)
),
"vit_b_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-f50866e8.ckpt"),
"vit_b_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_b_32_224-4a1c9d8e.ckpt"),
"vit_b_32_384": _cfg(
url="", input_size=(3, 384, 384)
),
"vit_l_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-b80441df.ckpt"),
"vit_l_32_224": _cfg(url="https://download.mindspore.cn/toolkits/mindcv/vit/vit_l_32_224-8c8ea164.ckpt"),
}


@@ -114,11 +115,11 @@ def construct(self, x):
q, k, v = self.unstack(qkv)
q, k = self.q_norm(q), self.k_norm(k)

q = self.mul(q, self.scale**0.5)
k = self.mul(k, self.scale**0.5)
attn = self.q_matmul_k(q, k)
attn = self.mul(attn, self.scale)

attn = attn.astype(ms.float32)
attn = ops.softmax(attn, axis=-1)
attn = ops.softmax(attn.astype(ms.float32), axis=-1).astype(attn.dtype)
attn = self.attn_drop(attn)

out = self.attn_matmul_v(attn, v)
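
The reordering above applies the attention scale (assumed here to be head_dim**-0.5, the usual ViT definition of self.scale) to q and k, each multiplied by scale**0.5, before the matmul instead of scaling the q·kᵀ logits afterwards; the softmax is still taken in float32 and then cast back to the activation dtype. Mathematically (scale**0.5 · q)(scale**0.5 · k)ᵀ = scale · (q kᵀ), but pre-scaling keeps the matmul operands smaller, which is friendlier to the float16 kernels used under amp_level O2. A minimal NumPy sketch of the equivalence (shapes and head_dim are illustrative):

```python
import numpy as np

head_dim = 64
scale = head_dim ** -0.5                      # assumed value of self.scale (1 / sqrt(head_dim))
q = np.random.randn(2, 8, 197, head_dim).astype(np.float32)
k = np.random.randn(2, 8, 197, head_dim).astype(np.float32)

# previous behaviour: scale the full attention logits after the matmul
attn_post = (q @ k.transpose(0, 1, 3, 2)) * scale

# new behaviour: scale q and k by sqrt(scale) before the matmul
attn_pre = (q * scale ** 0.5) @ (k * scale ** 0.5).transpose(0, 1, 3, 2)

print(np.allclose(attn_post, attn_pre, atol=1e-5))   # True: both orderings give the same logits
```
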
@@ -325,14 +326,14 @@ def __init__(
def get_num_layers(self):
return len(self.blocks)

def no_weight_decay(self):
return {'pos_embed', 'cls_token'}

def _init_weights(self):
w = self.patch_embed.proj.weight
w_shape_flatted = (w.shape[0], functools.reduce(lambda x, y: x*y, w.shape[1:]))
w.set_data(initializer(XavierUniform(), w_shape_flatted, w.dtype).reshape(w.shape))
for _, cell in self.cells_and_names():
if isinstance(cell, nn.Dense):
cell.weight.set_data(
initializer(TruncatedNormal(0.02), cell.weight.shape, cell.weight.dtype)
initializer(XavierUniform(), cell.weight.shape, cell.weight.dtype)
)
if cell.bias is not None:
cell.bias.set_data(
@@ -345,14 +346,6 @@ def _init_weights(self):
cell.beta.set_data(
initializer('zeros', cell.beta.shape, cell.beta.dtype)
)
elif isinstance(cell, nn.Conv2d):
cell.weight.set_data(
initializer(HeUniform(), cell.weight.shape, cell.weight.dtype)
)
if cell.bias is not None:
cell.bias.set_data(
initializer("zeros", cell.bias.shape, cell.bias.dtype)
)

def _pos_embed(self, x):
if self.dynamic_img_size or self.dynamic_img_pad:
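
The rewritten _init_weights drops the HeUniform branch for Conv2d and switches Dense layers from TruncatedNormal(0.02) to Xavier uniform; the patch-embedding convolution is initialized explicitly over its flattened (out, in·kh·kw) shape, so its fan-in/fan-out match those of an equivalent linear patch projection. A minimal NumPy sketch of that flattened-fan draw (shape values are illustrative, roughly vit_b_32):

```python
import numpy as np

def xavier_uniform(fan_out: int, fan_in: int) -> np.ndarray:
    # Glorot/Xavier uniform: U(-a, a) with a = sqrt(6 / (fan_in + fan_out))
    a = np.sqrt(6.0 / (fan_in + fan_out))
    return np.random.uniform(-a, a, size=(fan_out, fan_in))

# patch-embed conv weight, e.g. (embed_dim, in_chans, patch, patch) = (768, 3, 32, 32)
w_shape = (768, 3, 32, 32)
flat = (w_shape[0], int(np.prod(w_shape[1:])))   # (768, 3072), mirroring _init_weights
w = xavier_uniform(*flat).reshape(w_shape)       # draw over the flattened fans, then reshape back
print(w.shape, float(w.std()))
```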
